From 20e60ca5c79788effa2c501f7a770c4b4477a258 Mon Sep 17 00:00:00 2001
From: h7x4 <h7x4@nani.wtf>
Date: Fri, 9 Aug 2024 19:08:48 +0200
Subject: [PATCH] Add protocol for authenticating a unix socket

---
 Cargo.lock                       |  89 +++++-
 Cargo.toml                       |   9 +-
 src/authenticated_unix_socket.rs | 454 +++++++++++++++++++++++++++++++
 src/main.rs                      |   1 +
 4 files changed, 548 insertions(+), 5 deletions(-)
 create mode 100644 src/authenticated_unix_socket.rs

diff --git a/Cargo.lock b/Cargo.lock
index 58fa798..02ca9cf 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -99,6 +99,21 @@ version = "1.0.82"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "f538837af36e6f6a9be0faa67f9a314f8119e4e4b5867c6ab40ed60360142519"
 
+[[package]]
+name = "async-bincode"
+version = "0.7.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "21849a990d47109757e820904d7c0b569a8013f6595bf14d911884634d58795f"
+dependencies = [
+ "bincode",
+ "byteorder",
+ "bytes",
+ "futures-core",
+ "futures-sink",
+ "serde",
+ "tokio",
+]
+
 [[package]]
 name = "atoi"
 version = "2.0.0"
@@ -141,6 +156,15 @@ version = "1.6.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b"
 
+[[package]]
+name = "bincode"
+version = "1.3.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad"
+dependencies = [
+ "serde",
+]
+
 [[package]]
 name = "bitflags"
 version = "1.3.2"
@@ -555,6 +579,21 @@ dependencies = [
  "percent-encoding",
 ]
 
+[[package]]
+name = "futures"
+version = "0.3.30"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0"
+dependencies = [
+ "futures-channel",
+ "futures-core",
+ "futures-executor",
+ "futures-io",
+ "futures-sink",
+ "futures-task",
+ "futures-util",
+]
+
 [[package]]
 name = "futures-channel"
 version = "0.3.30"
@@ -599,6 +638,17 @@ version = "0.3.30"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1"
 
+[[package]]
+name = "futures-macro"
+version = "0.3.30"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.60",
+]
+
 [[package]]
 name = "futures-sink"
 version = "0.3.30"
@@ -617,8 +667,10 @@ version = "0.3.30"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48"
 dependencies = [
+ "futures-channel",
  "futures-core",
  "futures-io",
+ "futures-macro",
  "futures-sink",
  "futures-task",
  "memchr",
@@ -906,20 +958,27 @@ name = "mysqladm-rs"
 version = "0.1.0"
 dependencies = [
  "anyhow",
+ "async-bincode",
+ "bincode",
  "clap",
  "dialoguer",
  "env_logger",
+ "futures",
  "indoc",
  "itertools",
  "log",
  "nix",
  "prettytable",
+ "rand",
  "ratatui",
  "serde",
  "serde_json",
  "sqlx",
+ "thiserror",
  "tokio",
+ "tokio-util",
  "toml",
+ "uuid",
 ]
 
 [[package]]
@@ -1812,18 +1871,18 @@ dependencies = [
 
 [[package]]
 name = "thiserror"
-version = "1.0.58"
+version = "1.0.63"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "03468839009160513471e86a034bb2c5c0e4baae3b43f79ffc55c4a5427b3297"
+checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724"
 dependencies = [
  "thiserror-impl",
 ]
 
 [[package]]
 name = "thiserror-impl"
-version = "1.0.58"
+version = "1.0.63"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7"
+checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -1883,6 +1942,19 @@ dependencies = [
  "tokio",
 ]
 
+[[package]]
+name = "tokio-util"
+version = "0.7.11"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9cf6b47b3771c49ac75ad09a6162f53ad4b8088b76ac60e8ec1455b31a189fe1"
+dependencies = [
+ "bytes",
+ "futures-core",
+ "futures-sink",
+ "pin-project-lite",
+ "tokio",
+]
+
 [[package]]
 name = "toml"
 version = "0.8.12"
@@ -2023,6 +2095,15 @@ version = "0.2.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a"
 
+[[package]]
+name = "uuid"
+version = "1.10.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314"
+dependencies = [
+ "getrandom",
+]
+
 [[package]]
 name = "vcpkg"
 version = "0.2.15"
diff --git a/Cargo.toml b/Cargo.toml
index 9242ae5..003066a 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -5,20 +5,27 @@ edition = "2021"
 
 [dependencies]
 anyhow = "1.0.82"
+async-bincode = "0.7.2"
+bincode = "1.3.3"
 clap = { version = "4.5.4", features = ["derive"] }
 dialoguer = "0.11.0"
 env_logger = "0.11.3"
+futures = "0.3.30"
 indoc = "2.0.5"
 itertools = "0.12.1"
 log = "0.4.21"
-nix = { version = "0.28.0", features = ["user"] }
+nix = { version = "0.28.0", features = ["fs", "user"] }
 prettytable = "0.10.0"
+rand = "0.8.5"
 ratatui = { version = "0.26.2", optional = true }
 serde = "1.0.198"
 serde_json = { version = "1.0.116", features = ["preserve_order"] }
 sqlx = { version = "0.7.4", features = ["runtime-tokio", "mysql", "tls-rustls"] }
+thiserror = "1.0.63"
 tokio = { version = "1.37.0", features = ["rt", "macros"] }
+tokio-util = "0.7.11"
 toml = "0.8.12"
+uuid = { version = "1.10.0", features = ["v4"] }
 
 [features]
 default = ["mysql-admutils-compatibility"]
diff --git a/src/authenticated_unix_socket.rs b/src/authenticated_unix_socket.rs
new file mode 100644
index 0000000..0092b5f
--- /dev/null
+++ b/src/authenticated_unix_socket.rs
@@ -0,0 +1,454 @@
+//! This module provides a way to authenticate a client uid to a server over a Unix socket.
+//! This is needed so that the server can trust the client's uid, which it depends on to
+//! make modifications for that user in the database. It is crucial that the server can trust
+//! that the client is the unix user it claims to be.
+//!
+//! This works by having the client respond to a challenge on a socket that is verifiably owned
+//! by the client. In more detailed steps, the following should happen:
+//!
+//! 1. Before initializing it's request, the client should open an "authentication" socket with permissions 644
+//!    and owned by the uid of the current user.
+//! 2. The client opens a request to the server on the "normal" socket where the server is listening,
+//!    In this request, the client should include the following:
+//!      - The address of it's authentication socket
+//!      - The uid of the user currently using the client
+//!      - A challenge string that has been randomly generated
+//! 3. The server validates the following:
+//!      - The address of the auth socket is valid
+//!      - The owner of the auth socket address is the same as the uid
+//! 4. Server connects to the auth socket address and receives another challenge string.
+//!    The server should close the connection after receiving the challenge string.
+//! 5. Server verifies that the challenge is the same as the one it originally received.
+//!    It responds to the client with an "Authenticated" message if the challenge matches,
+//!    or an error message if it does not.
+//! 6. Client closes the authentication socket. Normal socket is used for communication.
+//!
+//! Note that the server can at any point in the process send an error message to the client
+//! over it's initial connection, and the client should respond by closing the authentication
+//! socket, it's connection to the normal socket, and reporting the error to the user.
+//!
+//! Also note that it is essential that the client does not send any sensitive information
+//! over it's authentication socket, since it is readable by any user on the system.
+
+use std::os::unix::io::AsRawFd;
+use std::path::PathBuf;
+
+use async_bincode::{tokio::AsyncBincodeStream, AsyncDestination};
+use futures::{SinkExt, StreamExt};
+use nix::{sys::stat, unistd::Uid};
+use rand::distributions::Alphanumeric;
+use rand::Rng;
+use serde::{Deserialize, Serialize};
+use tokio::net::{UnixListener, UnixStream};
+use tokio_util::sync::CancellationToken;
+
+#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
+pub enum ClientRequest {
+    Initialize {
+        uid: u32,
+        challenge: u64,
+        auth_socket: String,
+    },
+    Cancel,
+}
+
+#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
+pub enum ServerResponse {
+    Authenticated,
+    ChallengeDidNotMatch,
+    ServerError(ServerError),
+}
+
+// TODO: wrap more data into the errors
+
+#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
+pub enum ServerError {
+    InvalidRequest,
+    UnableToReadPermissionsFromAuthSocket,
+    CouldNotConnectToAuthSocket,
+    AuthSocketClosedEarly,
+    UidMismatch,
+    ChallengeMismatch,
+    InvalidChallenge,
+}
+
+#[derive(Debug, PartialEq)]
+pub enum ClientError {
+    UnableToConnectToServer,
+    UnableToOpenAuthSocket,
+    UnableToConfigureAuthSocket,
+    AuthSocketClosedEarly,
+    UnableToCloseAuthSocket,
+    AuthenticationError,
+    InvalidServerResponse(ServerResponse),
+    UnableToParseServerResponse,
+    NoServerResponse,
+    ServerError(ServerError),
+}
+
+async fn create_auth_socket(socket_addr: &str) -> Result<UnixListener, ClientError> {
+    let auth_socket =
+        UnixListener::bind(socket_addr).map_err(|_err| ClientError::UnableToOpenAuthSocket)?;
+
+    stat::fchmod(
+        auth_socket.as_raw_fd(),
+        stat::Mode::S_IRUSR | stat::Mode::S_IWUSR | stat::Mode::S_IRGRP,
+    )
+    .map_err(|_err| ClientError::UnableToConfigureAuthSocket)?;
+
+    Ok(auth_socket)
+}
+
+type ClientToServerStream<'a> =
+    AsyncBincodeStream<&'a mut UnixStream, ServerResponse, ClientRequest, AsyncDestination>;
+type ServerToClientStream<'a> =
+    AsyncBincodeStream<&'a mut UnixStream, ClientRequest, ServerResponse, AsyncDestination>;
+
+// TODO: make the challenge constant size and use socket directly, this is overkill
+type AuthStream<'a> = AsyncBincodeStream<&'a mut UnixStream, u64, u64, AsyncDestination>;
+
+// TODO: add timeout
+
+const AUTH_SOCKET_NAME: &str = "mysqladm-rs-cli-auth.sock";
+pub async fn client_authenticate(
+    normal_socket: &mut UnixStream,
+    #[cfg(not(test))] auth_socket_dir: Option<PathBuf>,
+    #[cfg(test)] auth_socket_file: Option<PathBuf>,
+) -> Result<(), ClientError> {
+    let random_prefix: String = rand::thread_rng()
+        .sample_iter(&Alphanumeric)
+        .take(16)
+        .map(char::from)
+        .collect();
+
+    let socket_name = format!("{}-{}", random_prefix, AUTH_SOCKET_NAME);
+
+    #[cfg(not(test))]
+    let auth_socket_address = match auth_socket_dir {
+        Some(dir) => dir.join(socket_name).to_str().unwrap().to_string(),
+        None => std::env::temp_dir()
+            .join(socket_name)
+            .to_str()
+            .unwrap()
+            .to_string(),
+    };
+
+    #[cfg(test)]
+    let auth_socket_address = match auth_socket_file {
+        Some(file) => file.to_str().unwrap().to_string(),
+        None => std::env::temp_dir()
+            .join(socket_name)
+            .to_str()
+            .unwrap()
+            .to_string(),
+    };
+
+    client_authenticate_with_auth_socket_address(normal_socket, &auth_socket_address).await
+}
+
+async fn client_authenticate_with_auth_socket_address(
+    normal_socket: &mut UnixStream,
+    auth_socket_address: &str,
+) -> Result<(), ClientError> {
+    let auth_socket = create_auth_socket(auth_socket_address).await?;
+
+    let result =
+        client_authenticate_with_auth_socket(normal_socket, auth_socket, auth_socket_address).await;
+
+    std::fs::remove_file(auth_socket_address)
+        .map_err(|_err| ClientError::UnableToCloseAuthSocket)?;
+
+    result
+}
+
+async fn client_authenticate_with_auth_socket(
+    normal_socket: &mut UnixStream,
+    auth_socket: UnixListener,
+    auth_socket_address: &str,
+) -> Result<(), ClientError> {
+    let challenge = rand::random::<u64>();
+    let uid = nix::unistd::getuid();
+
+    let mut normal_socket: ClientToServerStream =
+        AsyncBincodeStream::from(normal_socket).for_async();
+
+    let challenge_replier_cancellation_token = CancellationToken::new();
+    let challenge_replier_cancellation_token_clone = challenge_replier_cancellation_token.clone();
+    let challenge_replier_handle = tokio::spawn(async move {
+        loop {
+            tokio::select! {
+                socket = auth_socket.accept() =>
+                  {
+                  match socket {
+                      Ok((mut conn, _addr)) => {
+                          let mut stream: AuthStream = AsyncBincodeStream::from(&mut conn).for_async();
+                          stream.send(challenge).await.ok();
+                          stream.close().await.ok();
+                      }
+                      Err(_err) => return Err(ClientError::AuthSocketClosedEarly),
+                  }
+              }
+
+                _ = challenge_replier_cancellation_token_clone.cancelled() => {
+                    break Ok(());
+                }
+            }
+        }
+    });
+
+    let client_hello = ClientRequest::Initialize {
+        uid: uid.into(),
+        challenge,
+        auth_socket: auth_socket_address.to_string(),
+    };
+
+    normal_socket
+        .send(client_hello)
+        .await
+        .map_err(|err| match *err {
+            bincode::ErrorKind::Io(_err) => ClientError::UnableToConnectToServer,
+            _ => ClientError::NoServerResponse,
+        })?;
+
+    match normal_socket.next().await {
+        Some(Ok(ServerResponse::Authenticated)) => {}
+        Some(Ok(ServerResponse::ChallengeDidNotMatch)) => {
+            return Err(ClientError::AuthenticationError)
+        }
+        Some(Ok(ServerResponse::ServerError(err))) => return Err(ClientError::ServerError(err)),
+        Some(Err(err)) => match *err {
+            bincode::ErrorKind::Io(_err) => return Err(ClientError::NoServerResponse),
+            _ => return Err(ClientError::UnableToParseServerResponse),
+        },
+        None => return Err(ClientError::NoServerResponse),
+    }
+
+    challenge_replier_cancellation_token.cancel();
+    challenge_replier_handle.await.ok();
+
+    Ok(())
+}
+
+macro_rules! report_server_error_and_return {
+    ($normal_socket:expr, $err:expr) => {{
+        $normal_socket
+            .send(ServerResponse::ServerError($err))
+            .await
+            .ok();
+        return Err($err);
+    }};
+}
+
+async fn server_authenticate(
+    normal_socket: &mut UnixStream,
+    #[cfg(test)] unix_user_uid: Option<u32>,
+) -> Result<Uid, ServerError> {
+    let mut normal_socket: ServerToClientStream =
+        AsyncBincodeStream::from(normal_socket).for_async();
+
+    let (uid, challenge, auth_socket) = match normal_socket.next().await {
+        Some(Ok(ClientRequest::Initialize {
+            uid,
+            challenge,
+            auth_socket,
+        })) => (uid, challenge, auth_socket),
+        // TODO: more granular errros
+        _ => report_server_error_and_return!(normal_socket, ServerError::InvalidRequest),
+    };
+
+    #[cfg(test)]
+    let auth_socket_uid = match unix_user_uid {
+        Some(uid) => uid,
+        None => report_server_error_and_return!(
+            normal_socket,
+            ServerError::UnableToReadPermissionsFromAuthSocket
+        ),
+    };
+
+    #[cfg(not(test))]
+    let auth_socket_uid = match stat::stat(auth_socket.as_str()) {
+        Ok(stat) => stat.st_uid,
+        Err(_err) => report_server_error_and_return!(
+            normal_socket,
+            ServerError::UnableToReadPermissionsFromAuthSocket
+        ),
+    };
+
+    if uid != auth_socket_uid {
+        report_server_error_and_return!(normal_socket, ServerError::UidMismatch);
+    }
+
+    let mut authenticated_unix_socket = match UnixStream::connect(auth_socket).await {
+        Ok(socket) => socket,
+        Err(_err) => {
+            report_server_error_and_return!(normal_socket, ServerError::CouldNotConnectToAuthSocket)
+        }
+    };
+    let mut authenticated_unix_socket: AuthStream =
+        AsyncBincodeStream::from(&mut authenticated_unix_socket).for_async();
+
+    let challenge_2 = match authenticated_unix_socket.next().await {
+        Some(Ok(challenge)) => challenge,
+        Some(Err(_)) => {
+            report_server_error_and_return!(normal_socket, ServerError::InvalidChallenge)
+        }
+        None => report_server_error_and_return!(normal_socket, ServerError::AuthSocketClosedEarly),
+    };
+
+    authenticated_unix_socket.close().await.ok();
+
+    if challenge != challenge_2 {
+        normal_socket
+            .send(ServerResponse::ChallengeDidNotMatch)
+            .await
+            .ok();
+        return Err(ServerError::ChallengeMismatch);
+    }
+
+    normal_socket.send(ServerResponse::Authenticated).await.ok();
+
+    Ok(Uid::from_raw(uid))
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    use std::time::Duration;
+    use tokio::time::sleep;
+
+    #[tokio::test]
+    async fn test_valid_authentication() {
+        let (mut client, mut server) = UnixStream::pair().unwrap();
+
+        let client_handle =
+            tokio::spawn(async move { client_authenticate(&mut client, None).await });
+
+        let server_handle = tokio::spawn(async move {
+            let uid = nix::unistd::getuid().into();
+            server_authenticate(&mut server, Some(uid)).await
+        });
+
+        client_handle.await.unwrap().unwrap();
+        server_handle.await.unwrap().unwrap();
+    }
+
+    #[tokio::test]
+    async fn test_ensure_auth_socket_does_not_exist() {
+        let (mut client, mut server) = UnixStream::pair().unwrap();
+
+        let client_handle = tokio::spawn(async move {
+            client_authenticate_with_auth_socket_address(
+                &mut client,
+                "/tmp/test_auth_socket_does_not_exist.sock",
+            )
+            .await
+        });
+
+        let server_handle = tokio::spawn(async move {
+            let uid = nix::unistd::getuid().into();
+            server_authenticate(&mut server, Some(uid)).await
+        });
+
+        client_handle.await.unwrap().unwrap();
+        server_handle.await.unwrap().unwrap();
+    }
+
+    #[tokio::test]
+    async fn test_uid_mismatch() {
+        let (mut client, mut server) = UnixStream::pair().unwrap();
+
+        let client_handle = tokio::spawn(async move {
+            let err = client_authenticate(&mut client, None).await;
+            assert_eq!(err, Err(ClientError::ServerError(ServerError::UidMismatch)));
+        });
+
+        let server_handle = tokio::spawn(async move {
+            let uid: u32 = nix::unistd::getuid().into();
+            let err = server_authenticate(&mut server, Some(uid + 1)).await;
+            assert_eq!(err, Err(ServerError::UidMismatch));
+        });
+
+        client_handle.await.unwrap();
+        server_handle.await.unwrap();
+    }
+
+    #[tokio::test]
+    async fn test_snooping_connection() {
+        let (mut client, mut server) = UnixStream::pair().unwrap();
+
+        let socket_path = std::env::temp_dir().join("socket_to_snoop.sock");
+        let socket_path_clone = socket_path.clone();
+        let client_handle =
+            tokio::spawn(
+                async move { client_authenticate(&mut client, Some(socket_path_clone)).await },
+            );
+
+        while !socket_path.exists() {
+            sleep(std::time::Duration::from_millis(10)).await;
+        }
+
+        let mut snooper = UnixStream::connect(socket_path.clone()).await.unwrap();
+        let mut snooper: AuthStream = AsyncBincodeStream::from(&mut snooper).for_async();
+        let message = snooper.next().await.unwrap().unwrap();
+
+        let mut other_snooper = UnixStream::connect(socket_path.clone()).await.unwrap();
+        let mut other_snooper: AuthStream =
+            AsyncBincodeStream::from(&mut other_snooper).for_async();
+        let other_message = other_snooper.next().await.unwrap().unwrap();
+
+        assert_eq!(message, other_message);
+
+        let third_snooper_handle = tokio::spawn(async move {
+            let mut third_snooper = UnixStream::connect(socket_path.clone()).await.unwrap();
+            let mut third_snooper: AuthStream =
+                AsyncBincodeStream::from(&mut third_snooper).for_async();
+            // NOTE: Should hang
+            third_snooper.send(1234).await.unwrap()
+        });
+
+        sleep(Duration::from_millis(10)).await;
+
+        let server_handle = tokio::spawn(async move {
+            let uid: u32 = nix::unistd::getuid().into();
+            server_authenticate(&mut server, Some(uid)).await
+        });
+
+        client_handle.await.unwrap().unwrap();
+        server_handle.await.unwrap().unwrap();
+
+        third_snooper_handle.abort();
+    }
+
+    #[tokio::test]
+    async fn test_dead_server() {
+        let (mut client, server) = UnixStream::pair().unwrap();
+        std::mem::drop(server);
+
+        let client_handle = tokio::spawn(async move {
+            let err = client_authenticate(&mut client, None).await;
+            assert_eq!(err, Err(ClientError::UnableToConnectToServer));
+        });
+
+        client_handle.await.unwrap();
+    }
+
+    #[tokio::test]
+    async fn test_no_response_from_server() {
+        let (mut client, server) = UnixStream::pair().unwrap();
+
+        let client_handle = tokio::spawn(async move {
+            let err = client_authenticate(&mut client, None).await;
+            assert_eq!(err, Err(ClientError::NoServerResponse));
+        });
+
+        sleep(Duration::from_millis(200)).await;
+
+        std::mem::drop(server);
+
+        client_handle.await.unwrap();
+    }
+
+    // TODO: Test challenge mismatch
+    // TODO: Test invoking server with junk data
+}
diff --git a/src/main.rs b/src/main.rs
index d46d92f..2622c50 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -10,6 +10,7 @@ use crate::cli::mysql_admutils_compatibility::{mysql_dbadm, mysql_useradm};
 
 use clap::Parser;
 
+mod authenticated_unix_socket;
 mod cli;
 mod core;