diff --git a/Cargo.lock b/Cargo.lock index 58fa798..02ca9cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -99,6 +99,21 @@ version = "1.0.82" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f538837af36e6f6a9be0faa67f9a314f8119e4e4b5867c6ab40ed60360142519" +[[package]] +name = "async-bincode" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21849a990d47109757e820904d7c0b569a8013f6595bf14d911884634d58795f" +dependencies = [ + "bincode", + "byteorder", + "bytes", + "futures-core", + "futures-sink", + "serde", + "tokio", +] + [[package]] name = "atoi" version = "2.0.0" @@ -141,6 +156,15 @@ version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" +[[package]] +name = "bincode" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +dependencies = [ + "serde", +] + [[package]] name = "bitflags" version = "1.3.2" @@ -555,6 +579,21 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "futures" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.30" @@ -599,6 +638,17 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" +[[package]] +name = "futures-macro" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.60", +] + [[package]] name = "futures-sink" version = "0.3.30" @@ -617,8 +667,10 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" dependencies = [ + "futures-channel", "futures-core", "futures-io", + "futures-macro", "futures-sink", "futures-task", "memchr", @@ -906,20 +958,27 @@ name = "mysqladm-rs" version = "0.1.0" dependencies = [ "anyhow", + "async-bincode", + "bincode", "clap", "dialoguer", "env_logger", + "futures", "indoc", "itertools", "log", "nix", "prettytable", + "rand", "ratatui", "serde", "serde_json", "sqlx", + "thiserror", "tokio", + "tokio-util", "toml", + "uuid", ] [[package]] @@ -1812,18 +1871,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.58" +version = "1.0.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03468839009160513471e86a034bb2c5c0e4baae3b43f79ffc55c4a5427b3297" +checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.58" +version = "1.0.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" +checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261" dependencies = [ "proc-macro2", "quote", @@ -1883,6 +1942,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-util" +version = "0.7.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cf6b47b3771c49ac75ad09a6162f53ad4b8088b76ac60e8ec1455b31a189fe1" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + [[package]] name = "toml" version = "0.8.12" @@ -2023,6 +2095,15 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" +[[package]] +name = "uuid" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314" +dependencies = [ + "getrandom", +] + [[package]] name = "vcpkg" version = "0.2.15" diff --git a/Cargo.toml b/Cargo.toml index 9242ae5..003066a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,20 +5,27 @@ edition = "2021" [dependencies] anyhow = "1.0.82" +async-bincode = "0.7.2" +bincode = "1.3.3" clap = { version = "4.5.4", features = ["derive"] } dialoguer = "0.11.0" env_logger = "0.11.3" +futures = "0.3.30" indoc = "2.0.5" itertools = "0.12.1" log = "0.4.21" -nix = { version = "0.28.0", features = ["user"] } +nix = { version = "0.28.0", features = ["fs", "user"] } prettytable = "0.10.0" +rand = "0.8.5" ratatui = { version = "0.26.2", optional = true } serde = "1.0.198" serde_json = { version = "1.0.116", features = ["preserve_order"] } sqlx = { version = "0.7.4", features = ["runtime-tokio", "mysql", "tls-rustls"] } +thiserror = "1.0.63" tokio = { version = "1.37.0", features = ["rt", "macros"] } +tokio-util = "0.7.11" toml = "0.8.12" +uuid = { version = "1.10.0", features = ["v4"] } [features] default = ["mysql-admutils-compatibility"] diff --git a/src/authenticated_unix_socket.rs b/src/authenticated_unix_socket.rs new file mode 100644 index 0000000..0092b5f --- /dev/null +++ b/src/authenticated_unix_socket.rs @@ -0,0 +1,454 @@ +//! This module provides a way to authenticate a client uid to a server over a Unix socket. +//! This is needed so that the server can trust the client's uid, which it depends on to +//! make modifications for that user in the database. It is crucial that the server can trust +//! that the client is the unix user it claims to be. +//! +//! This works by having the client respond to a challenge on a socket that is verifiably owned +//! by the client. In more detailed steps, the following should happen: +//! +//! 1. Before initializing it's request, the client should open an "authentication" socket with permissions 644 +//! and owned by the uid of the current user. +//! 2. The client opens a request to the server on the "normal" socket where the server is listening, +//! In this request, the client should include the following: +//! - The address of it's authentication socket +//! - The uid of the user currently using the client +//! - A challenge string that has been randomly generated +//! 3. The server validates the following: +//! - The address of the auth socket is valid +//! - The owner of the auth socket address is the same as the uid +//! 4. Server connects to the auth socket address and receives another challenge string. +//! The server should close the connection after receiving the challenge string. +//! 5. Server verifies that the challenge is the same as the one it originally received. +//! It responds to the client with an "Authenticated" message if the challenge matches, +//! or an error message if it does not. +//! 6. Client closes the authentication socket. Normal socket is used for communication. +//! +//! Note that the server can at any point in the process send an error message to the client +//! over it's initial connection, and the client should respond by closing the authentication +//! socket, it's connection to the normal socket, and reporting the error to the user. +//! +//! Also note that it is essential that the client does not send any sensitive information +//! over it's authentication socket, since it is readable by any user on the system. + +use std::os::unix::io::AsRawFd; +use std::path::PathBuf; + +use async_bincode::{tokio::AsyncBincodeStream, AsyncDestination}; +use futures::{SinkExt, StreamExt}; +use nix::{sys::stat, unistd::Uid}; +use rand::distributions::Alphanumeric; +use rand::Rng; +use serde::{Deserialize, Serialize}; +use tokio::net::{UnixListener, UnixStream}; +use tokio_util::sync::CancellationToken; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum ClientRequest { + Initialize { + uid: u32, + challenge: u64, + auth_socket: String, + }, + Cancel, +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum ServerResponse { + Authenticated, + ChallengeDidNotMatch, + ServerError(ServerError), +} + +// TODO: wrap more data into the errors + +#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] +pub enum ServerError { + InvalidRequest, + UnableToReadPermissionsFromAuthSocket, + CouldNotConnectToAuthSocket, + AuthSocketClosedEarly, + UidMismatch, + ChallengeMismatch, + InvalidChallenge, +} + +#[derive(Debug, PartialEq)] +pub enum ClientError { + UnableToConnectToServer, + UnableToOpenAuthSocket, + UnableToConfigureAuthSocket, + AuthSocketClosedEarly, + UnableToCloseAuthSocket, + AuthenticationError, + InvalidServerResponse(ServerResponse), + UnableToParseServerResponse, + NoServerResponse, + ServerError(ServerError), +} + +async fn create_auth_socket(socket_addr: &str) -> Result { + let auth_socket = + UnixListener::bind(socket_addr).map_err(|_err| ClientError::UnableToOpenAuthSocket)?; + + stat::fchmod( + auth_socket.as_raw_fd(), + stat::Mode::S_IRUSR | stat::Mode::S_IWUSR | stat::Mode::S_IRGRP, + ) + .map_err(|_err| ClientError::UnableToConfigureAuthSocket)?; + + Ok(auth_socket) +} + +type ClientToServerStream<'a> = + AsyncBincodeStream<&'a mut UnixStream, ServerResponse, ClientRequest, AsyncDestination>; +type ServerToClientStream<'a> = + AsyncBincodeStream<&'a mut UnixStream, ClientRequest, ServerResponse, AsyncDestination>; + +// TODO: make the challenge constant size and use socket directly, this is overkill +type AuthStream<'a> = AsyncBincodeStream<&'a mut UnixStream, u64, u64, AsyncDestination>; + +// TODO: add timeout + +const AUTH_SOCKET_NAME: &str = "mysqladm-rs-cli-auth.sock"; +pub async fn client_authenticate( + normal_socket: &mut UnixStream, + #[cfg(not(test))] auth_socket_dir: Option, + #[cfg(test)] auth_socket_file: Option, +) -> Result<(), ClientError> { + let random_prefix: String = rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(16) + .map(char::from) + .collect(); + + let socket_name = format!("{}-{}", random_prefix, AUTH_SOCKET_NAME); + + #[cfg(not(test))] + let auth_socket_address = match auth_socket_dir { + Some(dir) => dir.join(socket_name).to_str().unwrap().to_string(), + None => std::env::temp_dir() + .join(socket_name) + .to_str() + .unwrap() + .to_string(), + }; + + #[cfg(test)] + let auth_socket_address = match auth_socket_file { + Some(file) => file.to_str().unwrap().to_string(), + None => std::env::temp_dir() + .join(socket_name) + .to_str() + .unwrap() + .to_string(), + }; + + client_authenticate_with_auth_socket_address(normal_socket, &auth_socket_address).await +} + +async fn client_authenticate_with_auth_socket_address( + normal_socket: &mut UnixStream, + auth_socket_address: &str, +) -> Result<(), ClientError> { + let auth_socket = create_auth_socket(auth_socket_address).await?; + + let result = + client_authenticate_with_auth_socket(normal_socket, auth_socket, auth_socket_address).await; + + std::fs::remove_file(auth_socket_address) + .map_err(|_err| ClientError::UnableToCloseAuthSocket)?; + + result +} + +async fn client_authenticate_with_auth_socket( + normal_socket: &mut UnixStream, + auth_socket: UnixListener, + auth_socket_address: &str, +) -> Result<(), ClientError> { + let challenge = rand::random::(); + let uid = nix::unistd::getuid(); + + let mut normal_socket: ClientToServerStream = + AsyncBincodeStream::from(normal_socket).for_async(); + + let challenge_replier_cancellation_token = CancellationToken::new(); + let challenge_replier_cancellation_token_clone = challenge_replier_cancellation_token.clone(); + let challenge_replier_handle = tokio::spawn(async move { + loop { + tokio::select! { + socket = auth_socket.accept() => + { + match socket { + Ok((mut conn, _addr)) => { + let mut stream: AuthStream = AsyncBincodeStream::from(&mut conn).for_async(); + stream.send(challenge).await.ok(); + stream.close().await.ok(); + } + Err(_err) => return Err(ClientError::AuthSocketClosedEarly), + } + } + + _ = challenge_replier_cancellation_token_clone.cancelled() => { + break Ok(()); + } + } + } + }); + + let client_hello = ClientRequest::Initialize { + uid: uid.into(), + challenge, + auth_socket: auth_socket_address.to_string(), + }; + + normal_socket + .send(client_hello) + .await + .map_err(|err| match *err { + bincode::ErrorKind::Io(_err) => ClientError::UnableToConnectToServer, + _ => ClientError::NoServerResponse, + })?; + + match normal_socket.next().await { + Some(Ok(ServerResponse::Authenticated)) => {} + Some(Ok(ServerResponse::ChallengeDidNotMatch)) => { + return Err(ClientError::AuthenticationError) + } + Some(Ok(ServerResponse::ServerError(err))) => return Err(ClientError::ServerError(err)), + Some(Err(err)) => match *err { + bincode::ErrorKind::Io(_err) => return Err(ClientError::NoServerResponse), + _ => return Err(ClientError::UnableToParseServerResponse), + }, + None => return Err(ClientError::NoServerResponse), + } + + challenge_replier_cancellation_token.cancel(); + challenge_replier_handle.await.ok(); + + Ok(()) +} + +macro_rules! report_server_error_and_return { + ($normal_socket:expr, $err:expr) => {{ + $normal_socket + .send(ServerResponse::ServerError($err)) + .await + .ok(); + return Err($err); + }}; +} + +async fn server_authenticate( + normal_socket: &mut UnixStream, + #[cfg(test)] unix_user_uid: Option, +) -> Result { + let mut normal_socket: ServerToClientStream = + AsyncBincodeStream::from(normal_socket).for_async(); + + let (uid, challenge, auth_socket) = match normal_socket.next().await { + Some(Ok(ClientRequest::Initialize { + uid, + challenge, + auth_socket, + })) => (uid, challenge, auth_socket), + // TODO: more granular errros + _ => report_server_error_and_return!(normal_socket, ServerError::InvalidRequest), + }; + + #[cfg(test)] + let auth_socket_uid = match unix_user_uid { + Some(uid) => uid, + None => report_server_error_and_return!( + normal_socket, + ServerError::UnableToReadPermissionsFromAuthSocket + ), + }; + + #[cfg(not(test))] + let auth_socket_uid = match stat::stat(auth_socket.as_str()) { + Ok(stat) => stat.st_uid, + Err(_err) => report_server_error_and_return!( + normal_socket, + ServerError::UnableToReadPermissionsFromAuthSocket + ), + }; + + if uid != auth_socket_uid { + report_server_error_and_return!(normal_socket, ServerError::UidMismatch); + } + + let mut authenticated_unix_socket = match UnixStream::connect(auth_socket).await { + Ok(socket) => socket, + Err(_err) => { + report_server_error_and_return!(normal_socket, ServerError::CouldNotConnectToAuthSocket) + } + }; + let mut authenticated_unix_socket: AuthStream = + AsyncBincodeStream::from(&mut authenticated_unix_socket).for_async(); + + let challenge_2 = match authenticated_unix_socket.next().await { + Some(Ok(challenge)) => challenge, + Some(Err(_)) => { + report_server_error_and_return!(normal_socket, ServerError::InvalidChallenge) + } + None => report_server_error_and_return!(normal_socket, ServerError::AuthSocketClosedEarly), + }; + + authenticated_unix_socket.close().await.ok(); + + if challenge != challenge_2 { + normal_socket + .send(ServerResponse::ChallengeDidNotMatch) + .await + .ok(); + return Err(ServerError::ChallengeMismatch); + } + + normal_socket.send(ServerResponse::Authenticated).await.ok(); + + Ok(Uid::from_raw(uid)) +} + +#[cfg(test)] +mod test { + use super::*; + + use std::time::Duration; + use tokio::time::sleep; + + #[tokio::test] + async fn test_valid_authentication() { + let (mut client, mut server) = UnixStream::pair().unwrap(); + + let client_handle = + tokio::spawn(async move { client_authenticate(&mut client, None).await }); + + let server_handle = tokio::spawn(async move { + let uid = nix::unistd::getuid().into(); + server_authenticate(&mut server, Some(uid)).await + }); + + client_handle.await.unwrap().unwrap(); + server_handle.await.unwrap().unwrap(); + } + + #[tokio::test] + async fn test_ensure_auth_socket_does_not_exist() { + let (mut client, mut server) = UnixStream::pair().unwrap(); + + let client_handle = tokio::spawn(async move { + client_authenticate_with_auth_socket_address( + &mut client, + "/tmp/test_auth_socket_does_not_exist.sock", + ) + .await + }); + + let server_handle = tokio::spawn(async move { + let uid = nix::unistd::getuid().into(); + server_authenticate(&mut server, Some(uid)).await + }); + + client_handle.await.unwrap().unwrap(); + server_handle.await.unwrap().unwrap(); + } + + #[tokio::test] + async fn test_uid_mismatch() { + let (mut client, mut server) = UnixStream::pair().unwrap(); + + let client_handle = tokio::spawn(async move { + let err = client_authenticate(&mut client, None).await; + assert_eq!(err, Err(ClientError::ServerError(ServerError::UidMismatch))); + }); + + let server_handle = tokio::spawn(async move { + let uid: u32 = nix::unistd::getuid().into(); + let err = server_authenticate(&mut server, Some(uid + 1)).await; + assert_eq!(err, Err(ServerError::UidMismatch)); + }); + + client_handle.await.unwrap(); + server_handle.await.unwrap(); + } + + #[tokio::test] + async fn test_snooping_connection() { + let (mut client, mut server) = UnixStream::pair().unwrap(); + + let socket_path = std::env::temp_dir().join("socket_to_snoop.sock"); + let socket_path_clone = socket_path.clone(); + let client_handle = + tokio::spawn( + async move { client_authenticate(&mut client, Some(socket_path_clone)).await }, + ); + + while !socket_path.exists() { + sleep(std::time::Duration::from_millis(10)).await; + } + + let mut snooper = UnixStream::connect(socket_path.clone()).await.unwrap(); + let mut snooper: AuthStream = AsyncBincodeStream::from(&mut snooper).for_async(); + let message = snooper.next().await.unwrap().unwrap(); + + let mut other_snooper = UnixStream::connect(socket_path.clone()).await.unwrap(); + let mut other_snooper: AuthStream = + AsyncBincodeStream::from(&mut other_snooper).for_async(); + let other_message = other_snooper.next().await.unwrap().unwrap(); + + assert_eq!(message, other_message); + + let third_snooper_handle = tokio::spawn(async move { + let mut third_snooper = UnixStream::connect(socket_path.clone()).await.unwrap(); + let mut third_snooper: AuthStream = + AsyncBincodeStream::from(&mut third_snooper).for_async(); + // NOTE: Should hang + third_snooper.send(1234).await.unwrap() + }); + + sleep(Duration::from_millis(10)).await; + + let server_handle = tokio::spawn(async move { + let uid: u32 = nix::unistd::getuid().into(); + server_authenticate(&mut server, Some(uid)).await + }); + + client_handle.await.unwrap().unwrap(); + server_handle.await.unwrap().unwrap(); + + third_snooper_handle.abort(); + } + + #[tokio::test] + async fn test_dead_server() { + let (mut client, server) = UnixStream::pair().unwrap(); + std::mem::drop(server); + + let client_handle = tokio::spawn(async move { + let err = client_authenticate(&mut client, None).await; + assert_eq!(err, Err(ClientError::UnableToConnectToServer)); + }); + + client_handle.await.unwrap(); + } + + #[tokio::test] + async fn test_no_response_from_server() { + let (mut client, server) = UnixStream::pair().unwrap(); + + let client_handle = tokio::spawn(async move { + let err = client_authenticate(&mut client, None).await; + assert_eq!(err, Err(ClientError::NoServerResponse)); + }); + + sleep(Duration::from_millis(200)).await; + + std::mem::drop(server); + + client_handle.await.unwrap(); + } + + // TODO: Test challenge mismatch + // TODO: Test invoking server with junk data +} diff --git a/src/main.rs b/src/main.rs index d46d92f..2622c50 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,6 +10,7 @@ use crate::cli::mysql_admutils_compatibility::{mysql_dbadm, mysql_useradm}; use clap::Parser; +mod authenticated_unix_socket; mod cli; mod core;