Add protocol for authenticating a unix socket
This commit is contained in:
parent
dc29dd274a
commit
20e60ca5c7
|
@ -99,6 +99,21 @@ version = "1.0.82"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "f538837af36e6f6a9be0faa67f9a314f8119e4e4b5867c6ab40ed60360142519"
|
checksum = "f538837af36e6f6a9be0faa67f9a314f8119e4e4b5867c6ab40ed60360142519"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "async-bincode"
|
||||||
|
version = "0.7.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "21849a990d47109757e820904d7c0b569a8013f6595bf14d911884634d58795f"
|
||||||
|
dependencies = [
|
||||||
|
"bincode",
|
||||||
|
"byteorder",
|
||||||
|
"bytes",
|
||||||
|
"futures-core",
|
||||||
|
"futures-sink",
|
||||||
|
"serde",
|
||||||
|
"tokio",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "atoi"
|
name = "atoi"
|
||||||
version = "2.0.0"
|
version = "2.0.0"
|
||||||
|
@ -141,6 +156,15 @@ version = "1.6.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b"
|
checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "bincode"
|
||||||
|
version = "1.3.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad"
|
||||||
|
dependencies = [
|
||||||
|
"serde",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "bitflags"
|
name = "bitflags"
|
||||||
version = "1.3.2"
|
version = "1.3.2"
|
||||||
|
@ -555,6 +579,21 @@ dependencies = [
|
||||||
"percent-encoding",
|
"percent-encoding",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "futures"
|
||||||
|
version = "0.3.30"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0"
|
||||||
|
dependencies = [
|
||||||
|
"futures-channel",
|
||||||
|
"futures-core",
|
||||||
|
"futures-executor",
|
||||||
|
"futures-io",
|
||||||
|
"futures-sink",
|
||||||
|
"futures-task",
|
||||||
|
"futures-util",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "futures-channel"
|
name = "futures-channel"
|
||||||
version = "0.3.30"
|
version = "0.3.30"
|
||||||
|
@ -599,6 +638,17 @@ version = "0.3.30"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1"
|
checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "futures-macro"
|
||||||
|
version = "0.3.30"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn 2.0.60",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "futures-sink"
|
name = "futures-sink"
|
||||||
version = "0.3.30"
|
version = "0.3.30"
|
||||||
|
@ -617,8 +667,10 @@ version = "0.3.30"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48"
|
checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"futures-channel",
|
||||||
"futures-core",
|
"futures-core",
|
||||||
"futures-io",
|
"futures-io",
|
||||||
|
"futures-macro",
|
||||||
"futures-sink",
|
"futures-sink",
|
||||||
"futures-task",
|
"futures-task",
|
||||||
"memchr",
|
"memchr",
|
||||||
|
@ -906,20 +958,27 @@ name = "mysqladm-rs"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
|
"async-bincode",
|
||||||
|
"bincode",
|
||||||
"clap",
|
"clap",
|
||||||
"dialoguer",
|
"dialoguer",
|
||||||
"env_logger",
|
"env_logger",
|
||||||
|
"futures",
|
||||||
"indoc",
|
"indoc",
|
||||||
"itertools",
|
"itertools",
|
||||||
"log",
|
"log",
|
||||||
"nix",
|
"nix",
|
||||||
"prettytable",
|
"prettytable",
|
||||||
|
"rand",
|
||||||
"ratatui",
|
"ratatui",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"sqlx",
|
"sqlx",
|
||||||
|
"thiserror",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
"tokio-util",
|
||||||
"toml",
|
"toml",
|
||||||
|
"uuid",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -1812,18 +1871,18 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "thiserror"
|
name = "thiserror"
|
||||||
version = "1.0.58"
|
version = "1.0.63"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "03468839009160513471e86a034bb2c5c0e4baae3b43f79ffc55c4a5427b3297"
|
checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"thiserror-impl",
|
"thiserror-impl",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "thiserror-impl"
|
name = "thiserror-impl"
|
||||||
version = "1.0.58"
|
version = "1.0.63"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7"
|
checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
|
@ -1883,6 +1942,19 @@ dependencies = [
|
||||||
"tokio",
|
"tokio",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tokio-util"
|
||||||
|
version = "0.7.11"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9cf6b47b3771c49ac75ad09a6162f53ad4b8088b76ac60e8ec1455b31a189fe1"
|
||||||
|
dependencies = [
|
||||||
|
"bytes",
|
||||||
|
"futures-core",
|
||||||
|
"futures-sink",
|
||||||
|
"pin-project-lite",
|
||||||
|
"tokio",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "toml"
|
name = "toml"
|
||||||
version = "0.8.12"
|
version = "0.8.12"
|
||||||
|
@ -2023,6 +2095,15 @@ version = "0.2.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a"
|
checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "uuid"
|
||||||
|
version = "1.10.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314"
|
||||||
|
dependencies = [
|
||||||
|
"getrandom",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "vcpkg"
|
name = "vcpkg"
|
||||||
version = "0.2.15"
|
version = "0.2.15"
|
||||||
|
|
|
@ -5,20 +5,27 @@ edition = "2021"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow = "1.0.82"
|
anyhow = "1.0.82"
|
||||||
|
async-bincode = "0.7.2"
|
||||||
|
bincode = "1.3.3"
|
||||||
clap = { version = "4.5.4", features = ["derive"] }
|
clap = { version = "4.5.4", features = ["derive"] }
|
||||||
dialoguer = "0.11.0"
|
dialoguer = "0.11.0"
|
||||||
env_logger = "0.11.3"
|
env_logger = "0.11.3"
|
||||||
|
futures = "0.3.30"
|
||||||
indoc = "2.0.5"
|
indoc = "2.0.5"
|
||||||
itertools = "0.12.1"
|
itertools = "0.12.1"
|
||||||
log = "0.4.21"
|
log = "0.4.21"
|
||||||
nix = { version = "0.28.0", features = ["user"] }
|
nix = { version = "0.28.0", features = ["fs", "user"] }
|
||||||
prettytable = "0.10.0"
|
prettytable = "0.10.0"
|
||||||
|
rand = "0.8.5"
|
||||||
ratatui = { version = "0.26.2", optional = true }
|
ratatui = { version = "0.26.2", optional = true }
|
||||||
serde = "1.0.198"
|
serde = "1.0.198"
|
||||||
serde_json = { version = "1.0.116", features = ["preserve_order"] }
|
serde_json = { version = "1.0.116", features = ["preserve_order"] }
|
||||||
sqlx = { version = "0.7.4", features = ["runtime-tokio", "mysql", "tls-rustls"] }
|
sqlx = { version = "0.7.4", features = ["runtime-tokio", "mysql", "tls-rustls"] }
|
||||||
|
thiserror = "1.0.63"
|
||||||
tokio = { version = "1.37.0", features = ["rt", "macros"] }
|
tokio = { version = "1.37.0", features = ["rt", "macros"] }
|
||||||
|
tokio-util = "0.7.11"
|
||||||
toml = "0.8.12"
|
toml = "0.8.12"
|
||||||
|
uuid = { version = "1.10.0", features = ["v4"] }
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = ["mysql-admutils-compatibility"]
|
default = ["mysql-admutils-compatibility"]
|
||||||
|
|
|
@ -0,0 +1,454 @@
|
||||||
|
//! This module provides a way to authenticate a client uid to a server over a Unix socket.
|
||||||
|
//! This is needed so that the server can trust the client's uid, which it depends on to
|
||||||
|
//! make modifications for that user in the database. It is crucial that the server can trust
|
||||||
|
//! that the client is the unix user it claims to be.
|
||||||
|
//!
|
||||||
|
//! This works by having the client respond to a challenge on a socket that is verifiably owned
|
||||||
|
//! by the client. In more detailed steps, the following should happen:
|
||||||
|
//!
|
||||||
|
//! 1. Before initializing it's request, the client should open an "authentication" socket with permissions 644
|
||||||
|
//! and owned by the uid of the current user.
|
||||||
|
//! 2. The client opens a request to the server on the "normal" socket where the server is listening,
|
||||||
|
//! In this request, the client should include the following:
|
||||||
|
//! - The address of it's authentication socket
|
||||||
|
//! - The uid of the user currently using the client
|
||||||
|
//! - A challenge string that has been randomly generated
|
||||||
|
//! 3. The server validates the following:
|
||||||
|
//! - The address of the auth socket is valid
|
||||||
|
//! - The owner of the auth socket address is the same as the uid
|
||||||
|
//! 4. Server connects to the auth socket address and receives another challenge string.
|
||||||
|
//! The server should close the connection after receiving the challenge string.
|
||||||
|
//! 5. Server verifies that the challenge is the same as the one it originally received.
|
||||||
|
//! It responds to the client with an "Authenticated" message if the challenge matches,
|
||||||
|
//! or an error message if it does not.
|
||||||
|
//! 6. Client closes the authentication socket. Normal socket is used for communication.
|
||||||
|
//!
|
||||||
|
//! Note that the server can at any point in the process send an error message to the client
|
||||||
|
//! over it's initial connection, and the client should respond by closing the authentication
|
||||||
|
//! socket, it's connection to the normal socket, and reporting the error to the user.
|
||||||
|
//!
|
||||||
|
//! Also note that it is essential that the client does not send any sensitive information
|
||||||
|
//! over it's authentication socket, since it is readable by any user on the system.
|
||||||
|
|
||||||
|
use std::os::unix::io::AsRawFd;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
use async_bincode::{tokio::AsyncBincodeStream, AsyncDestination};
|
||||||
|
use futures::{SinkExt, StreamExt};
|
||||||
|
use nix::{sys::stat, unistd::Uid};
|
||||||
|
use rand::distributions::Alphanumeric;
|
||||||
|
use rand::Rng;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use tokio::net::{UnixListener, UnixStream};
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||||
|
pub enum ClientRequest {
|
||||||
|
Initialize {
|
||||||
|
uid: u32,
|
||||||
|
challenge: u64,
|
||||||
|
auth_socket: String,
|
||||||
|
},
|
||||||
|
Cancel,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||||
|
pub enum ServerResponse {
|
||||||
|
Authenticated,
|
||||||
|
ChallengeDidNotMatch,
|
||||||
|
ServerError(ServerError),
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: wrap more data into the errors
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
|
||||||
|
pub enum ServerError {
|
||||||
|
InvalidRequest,
|
||||||
|
UnableToReadPermissionsFromAuthSocket,
|
||||||
|
CouldNotConnectToAuthSocket,
|
||||||
|
AuthSocketClosedEarly,
|
||||||
|
UidMismatch,
|
||||||
|
ChallengeMismatch,
|
||||||
|
InvalidChallenge,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq)]
|
||||||
|
pub enum ClientError {
|
||||||
|
UnableToConnectToServer,
|
||||||
|
UnableToOpenAuthSocket,
|
||||||
|
UnableToConfigureAuthSocket,
|
||||||
|
AuthSocketClosedEarly,
|
||||||
|
UnableToCloseAuthSocket,
|
||||||
|
AuthenticationError,
|
||||||
|
InvalidServerResponse(ServerResponse),
|
||||||
|
UnableToParseServerResponse,
|
||||||
|
NoServerResponse,
|
||||||
|
ServerError(ServerError),
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn create_auth_socket(socket_addr: &str) -> Result<UnixListener, ClientError> {
|
||||||
|
let auth_socket =
|
||||||
|
UnixListener::bind(socket_addr).map_err(|_err| ClientError::UnableToOpenAuthSocket)?;
|
||||||
|
|
||||||
|
stat::fchmod(
|
||||||
|
auth_socket.as_raw_fd(),
|
||||||
|
stat::Mode::S_IRUSR | stat::Mode::S_IWUSR | stat::Mode::S_IRGRP,
|
||||||
|
)
|
||||||
|
.map_err(|_err| ClientError::UnableToConfigureAuthSocket)?;
|
||||||
|
|
||||||
|
Ok(auth_socket)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClientToServerStream<'a> =
|
||||||
|
AsyncBincodeStream<&'a mut UnixStream, ServerResponse, ClientRequest, AsyncDestination>;
|
||||||
|
type ServerToClientStream<'a> =
|
||||||
|
AsyncBincodeStream<&'a mut UnixStream, ClientRequest, ServerResponse, AsyncDestination>;
|
||||||
|
|
||||||
|
// TODO: make the challenge constant size and use socket directly, this is overkill
|
||||||
|
type AuthStream<'a> = AsyncBincodeStream<&'a mut UnixStream, u64, u64, AsyncDestination>;
|
||||||
|
|
||||||
|
// TODO: add timeout
|
||||||
|
|
||||||
|
const AUTH_SOCKET_NAME: &str = "mysqladm-rs-cli-auth.sock";
|
||||||
|
pub async fn client_authenticate(
|
||||||
|
normal_socket: &mut UnixStream,
|
||||||
|
#[cfg(not(test))] auth_socket_dir: Option<PathBuf>,
|
||||||
|
#[cfg(test)] auth_socket_file: Option<PathBuf>,
|
||||||
|
) -> Result<(), ClientError> {
|
||||||
|
let random_prefix: String = rand::thread_rng()
|
||||||
|
.sample_iter(&Alphanumeric)
|
||||||
|
.take(16)
|
||||||
|
.map(char::from)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let socket_name = format!("{}-{}", random_prefix, AUTH_SOCKET_NAME);
|
||||||
|
|
||||||
|
#[cfg(not(test))]
|
||||||
|
let auth_socket_address = match auth_socket_dir {
|
||||||
|
Some(dir) => dir.join(socket_name).to_str().unwrap().to_string(),
|
||||||
|
None => std::env::temp_dir()
|
||||||
|
.join(socket_name)
|
||||||
|
.to_str()
|
||||||
|
.unwrap()
|
||||||
|
.to_string(),
|
||||||
|
};
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
let auth_socket_address = match auth_socket_file {
|
||||||
|
Some(file) => file.to_str().unwrap().to_string(),
|
||||||
|
None => std::env::temp_dir()
|
||||||
|
.join(socket_name)
|
||||||
|
.to_str()
|
||||||
|
.unwrap()
|
||||||
|
.to_string(),
|
||||||
|
};
|
||||||
|
|
||||||
|
client_authenticate_with_auth_socket_address(normal_socket, &auth_socket_address).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn client_authenticate_with_auth_socket_address(
|
||||||
|
normal_socket: &mut UnixStream,
|
||||||
|
auth_socket_address: &str,
|
||||||
|
) -> Result<(), ClientError> {
|
||||||
|
let auth_socket = create_auth_socket(auth_socket_address).await?;
|
||||||
|
|
||||||
|
let result =
|
||||||
|
client_authenticate_with_auth_socket(normal_socket, auth_socket, auth_socket_address).await;
|
||||||
|
|
||||||
|
std::fs::remove_file(auth_socket_address)
|
||||||
|
.map_err(|_err| ClientError::UnableToCloseAuthSocket)?;
|
||||||
|
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn client_authenticate_with_auth_socket(
|
||||||
|
normal_socket: &mut UnixStream,
|
||||||
|
auth_socket: UnixListener,
|
||||||
|
auth_socket_address: &str,
|
||||||
|
) -> Result<(), ClientError> {
|
||||||
|
let challenge = rand::random::<u64>();
|
||||||
|
let uid = nix::unistd::getuid();
|
||||||
|
|
||||||
|
let mut normal_socket: ClientToServerStream =
|
||||||
|
AsyncBincodeStream::from(normal_socket).for_async();
|
||||||
|
|
||||||
|
let challenge_replier_cancellation_token = CancellationToken::new();
|
||||||
|
let challenge_replier_cancellation_token_clone = challenge_replier_cancellation_token.clone();
|
||||||
|
let challenge_replier_handle = tokio::spawn(async move {
|
||||||
|
loop {
|
||||||
|
tokio::select! {
|
||||||
|
socket = auth_socket.accept() =>
|
||||||
|
{
|
||||||
|
match socket {
|
||||||
|
Ok((mut conn, _addr)) => {
|
||||||
|
let mut stream: AuthStream = AsyncBincodeStream::from(&mut conn).for_async();
|
||||||
|
stream.send(challenge).await.ok();
|
||||||
|
stream.close().await.ok();
|
||||||
|
}
|
||||||
|
Err(_err) => return Err(ClientError::AuthSocketClosedEarly),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = challenge_replier_cancellation_token_clone.cancelled() => {
|
||||||
|
break Ok(());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let client_hello = ClientRequest::Initialize {
|
||||||
|
uid: uid.into(),
|
||||||
|
challenge,
|
||||||
|
auth_socket: auth_socket_address.to_string(),
|
||||||
|
};
|
||||||
|
|
||||||
|
normal_socket
|
||||||
|
.send(client_hello)
|
||||||
|
.await
|
||||||
|
.map_err(|err| match *err {
|
||||||
|
bincode::ErrorKind::Io(_err) => ClientError::UnableToConnectToServer,
|
||||||
|
_ => ClientError::NoServerResponse,
|
||||||
|
})?;
|
||||||
|
|
||||||
|
match normal_socket.next().await {
|
||||||
|
Some(Ok(ServerResponse::Authenticated)) => {}
|
||||||
|
Some(Ok(ServerResponse::ChallengeDidNotMatch)) => {
|
||||||
|
return Err(ClientError::AuthenticationError)
|
||||||
|
}
|
||||||
|
Some(Ok(ServerResponse::ServerError(err))) => return Err(ClientError::ServerError(err)),
|
||||||
|
Some(Err(err)) => match *err {
|
||||||
|
bincode::ErrorKind::Io(_err) => return Err(ClientError::NoServerResponse),
|
||||||
|
_ => return Err(ClientError::UnableToParseServerResponse),
|
||||||
|
},
|
||||||
|
None => return Err(ClientError::NoServerResponse),
|
||||||
|
}
|
||||||
|
|
||||||
|
challenge_replier_cancellation_token.cancel();
|
||||||
|
challenge_replier_handle.await.ok();
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
macro_rules! report_server_error_and_return {
|
||||||
|
($normal_socket:expr, $err:expr) => {{
|
||||||
|
$normal_socket
|
||||||
|
.send(ServerResponse::ServerError($err))
|
||||||
|
.await
|
||||||
|
.ok();
|
||||||
|
return Err($err);
|
||||||
|
}};
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn server_authenticate(
|
||||||
|
normal_socket: &mut UnixStream,
|
||||||
|
#[cfg(test)] unix_user_uid: Option<u32>,
|
||||||
|
) -> Result<Uid, ServerError> {
|
||||||
|
let mut normal_socket: ServerToClientStream =
|
||||||
|
AsyncBincodeStream::from(normal_socket).for_async();
|
||||||
|
|
||||||
|
let (uid, challenge, auth_socket) = match normal_socket.next().await {
|
||||||
|
Some(Ok(ClientRequest::Initialize {
|
||||||
|
uid,
|
||||||
|
challenge,
|
||||||
|
auth_socket,
|
||||||
|
})) => (uid, challenge, auth_socket),
|
||||||
|
// TODO: more granular errros
|
||||||
|
_ => report_server_error_and_return!(normal_socket, ServerError::InvalidRequest),
|
||||||
|
};
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
let auth_socket_uid = match unix_user_uid {
|
||||||
|
Some(uid) => uid,
|
||||||
|
None => report_server_error_and_return!(
|
||||||
|
normal_socket,
|
||||||
|
ServerError::UnableToReadPermissionsFromAuthSocket
|
||||||
|
),
|
||||||
|
};
|
||||||
|
|
||||||
|
#[cfg(not(test))]
|
||||||
|
let auth_socket_uid = match stat::stat(auth_socket.as_str()) {
|
||||||
|
Ok(stat) => stat.st_uid,
|
||||||
|
Err(_err) => report_server_error_and_return!(
|
||||||
|
normal_socket,
|
||||||
|
ServerError::UnableToReadPermissionsFromAuthSocket
|
||||||
|
),
|
||||||
|
};
|
||||||
|
|
||||||
|
if uid != auth_socket_uid {
|
||||||
|
report_server_error_and_return!(normal_socket, ServerError::UidMismatch);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut authenticated_unix_socket = match UnixStream::connect(auth_socket).await {
|
||||||
|
Ok(socket) => socket,
|
||||||
|
Err(_err) => {
|
||||||
|
report_server_error_and_return!(normal_socket, ServerError::CouldNotConnectToAuthSocket)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let mut authenticated_unix_socket: AuthStream =
|
||||||
|
AsyncBincodeStream::from(&mut authenticated_unix_socket).for_async();
|
||||||
|
|
||||||
|
let challenge_2 = match authenticated_unix_socket.next().await {
|
||||||
|
Some(Ok(challenge)) => challenge,
|
||||||
|
Some(Err(_)) => {
|
||||||
|
report_server_error_and_return!(normal_socket, ServerError::InvalidChallenge)
|
||||||
|
}
|
||||||
|
None => report_server_error_and_return!(normal_socket, ServerError::AuthSocketClosedEarly),
|
||||||
|
};
|
||||||
|
|
||||||
|
authenticated_unix_socket.close().await.ok();
|
||||||
|
|
||||||
|
if challenge != challenge_2 {
|
||||||
|
normal_socket
|
||||||
|
.send(ServerResponse::ChallengeDidNotMatch)
|
||||||
|
.await
|
||||||
|
.ok();
|
||||||
|
return Err(ServerError::ChallengeMismatch);
|
||||||
|
}
|
||||||
|
|
||||||
|
normal_socket.send(ServerResponse::Authenticated).await.ok();
|
||||||
|
|
||||||
|
Ok(Uid::from_raw(uid))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod test {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
use std::time::Duration;
|
||||||
|
use tokio::time::sleep;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_valid_authentication() {
|
||||||
|
let (mut client, mut server) = UnixStream::pair().unwrap();
|
||||||
|
|
||||||
|
let client_handle =
|
||||||
|
tokio::spawn(async move { client_authenticate(&mut client, None).await });
|
||||||
|
|
||||||
|
let server_handle = tokio::spawn(async move {
|
||||||
|
let uid = nix::unistd::getuid().into();
|
||||||
|
server_authenticate(&mut server, Some(uid)).await
|
||||||
|
});
|
||||||
|
|
||||||
|
client_handle.await.unwrap().unwrap();
|
||||||
|
server_handle.await.unwrap().unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_ensure_auth_socket_does_not_exist() {
|
||||||
|
let (mut client, mut server) = UnixStream::pair().unwrap();
|
||||||
|
|
||||||
|
let client_handle = tokio::spawn(async move {
|
||||||
|
client_authenticate_with_auth_socket_address(
|
||||||
|
&mut client,
|
||||||
|
"/tmp/test_auth_socket_does_not_exist.sock",
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
});
|
||||||
|
|
||||||
|
let server_handle = tokio::spawn(async move {
|
||||||
|
let uid = nix::unistd::getuid().into();
|
||||||
|
server_authenticate(&mut server, Some(uid)).await
|
||||||
|
});
|
||||||
|
|
||||||
|
client_handle.await.unwrap().unwrap();
|
||||||
|
server_handle.await.unwrap().unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_uid_mismatch() {
|
||||||
|
let (mut client, mut server) = UnixStream::pair().unwrap();
|
||||||
|
|
||||||
|
let client_handle = tokio::spawn(async move {
|
||||||
|
let err = client_authenticate(&mut client, None).await;
|
||||||
|
assert_eq!(err, Err(ClientError::ServerError(ServerError::UidMismatch)));
|
||||||
|
});
|
||||||
|
|
||||||
|
let server_handle = tokio::spawn(async move {
|
||||||
|
let uid: u32 = nix::unistd::getuid().into();
|
||||||
|
let err = server_authenticate(&mut server, Some(uid + 1)).await;
|
||||||
|
assert_eq!(err, Err(ServerError::UidMismatch));
|
||||||
|
});
|
||||||
|
|
||||||
|
client_handle.await.unwrap();
|
||||||
|
server_handle.await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_snooping_connection() {
|
||||||
|
let (mut client, mut server) = UnixStream::pair().unwrap();
|
||||||
|
|
||||||
|
let socket_path = std::env::temp_dir().join("socket_to_snoop.sock");
|
||||||
|
let socket_path_clone = socket_path.clone();
|
||||||
|
let client_handle =
|
||||||
|
tokio::spawn(
|
||||||
|
async move { client_authenticate(&mut client, Some(socket_path_clone)).await },
|
||||||
|
);
|
||||||
|
|
||||||
|
while !socket_path.exists() {
|
||||||
|
sleep(std::time::Duration::from_millis(10)).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut snooper = UnixStream::connect(socket_path.clone()).await.unwrap();
|
||||||
|
let mut snooper: AuthStream = AsyncBincodeStream::from(&mut snooper).for_async();
|
||||||
|
let message = snooper.next().await.unwrap().unwrap();
|
||||||
|
|
||||||
|
let mut other_snooper = UnixStream::connect(socket_path.clone()).await.unwrap();
|
||||||
|
let mut other_snooper: AuthStream =
|
||||||
|
AsyncBincodeStream::from(&mut other_snooper).for_async();
|
||||||
|
let other_message = other_snooper.next().await.unwrap().unwrap();
|
||||||
|
|
||||||
|
assert_eq!(message, other_message);
|
||||||
|
|
||||||
|
let third_snooper_handle = tokio::spawn(async move {
|
||||||
|
let mut third_snooper = UnixStream::connect(socket_path.clone()).await.unwrap();
|
||||||
|
let mut third_snooper: AuthStream =
|
||||||
|
AsyncBincodeStream::from(&mut third_snooper).for_async();
|
||||||
|
// NOTE: Should hang
|
||||||
|
third_snooper.send(1234).await.unwrap()
|
||||||
|
});
|
||||||
|
|
||||||
|
sleep(Duration::from_millis(10)).await;
|
||||||
|
|
||||||
|
let server_handle = tokio::spawn(async move {
|
||||||
|
let uid: u32 = nix::unistd::getuid().into();
|
||||||
|
server_authenticate(&mut server, Some(uid)).await
|
||||||
|
});
|
||||||
|
|
||||||
|
client_handle.await.unwrap().unwrap();
|
||||||
|
server_handle.await.unwrap().unwrap();
|
||||||
|
|
||||||
|
third_snooper_handle.abort();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_dead_server() {
|
||||||
|
let (mut client, server) = UnixStream::pair().unwrap();
|
||||||
|
std::mem::drop(server);
|
||||||
|
|
||||||
|
let client_handle = tokio::spawn(async move {
|
||||||
|
let err = client_authenticate(&mut client, None).await;
|
||||||
|
assert_eq!(err, Err(ClientError::UnableToConnectToServer));
|
||||||
|
});
|
||||||
|
|
||||||
|
client_handle.await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_no_response_from_server() {
|
||||||
|
let (mut client, server) = UnixStream::pair().unwrap();
|
||||||
|
|
||||||
|
let client_handle = tokio::spawn(async move {
|
||||||
|
let err = client_authenticate(&mut client, None).await;
|
||||||
|
assert_eq!(err, Err(ClientError::NoServerResponse));
|
||||||
|
});
|
||||||
|
|
||||||
|
sleep(Duration::from_millis(200)).await;
|
||||||
|
|
||||||
|
std::mem::drop(server);
|
||||||
|
|
||||||
|
client_handle.await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Test challenge mismatch
|
||||||
|
// TODO: Test invoking server with junk data
|
||||||
|
}
|
|
@ -10,6 +10,7 @@ use crate::cli::mysql_admutils_compatibility::{mysql_dbadm, mysql_useradm};
|
||||||
|
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
|
|
||||||
|
mod authenticated_unix_socket;
|
||||||
mod cli;
|
mod cli;
|
||||||
mod core;
|
mod core;
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue