diff --git a/Cargo.lock b/Cargo.lock index ad147a3..dee32b3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1010,6 +1010,15 @@ version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" +[[package]] +name = "memoffset" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -1058,6 +1067,7 @@ dependencies = [ "prettytable", "rand", "ratatui", + "sd-notify", "serde", "serde_json", "sqlx", @@ -1079,6 +1089,7 @@ dependencies = [ "cfg-if", "cfg_aliases", "libc", + "memoffset", ] [[package]] @@ -1524,6 +1535,12 @@ dependencies = [ "untrusted", ] +[[package]] +name = "sd-notify" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4646d6f919800cd25c50edb49438a1381e2cd4833c027e75e8897981c50b8b5e" + [[package]] name = "serde" version = "1.0.208" diff --git a/Cargo.toml b/Cargo.toml index dc81b61..34877bc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,10 +16,11 @@ futures-util = "0.3.30" indoc = "2.0.5" itertools = "0.13.0" log = "0.4.22" -nix = { version = "0.29.0", features = ["fs", "process", "user"] } +nix = { version = "0.29.0", features = ["fs", "process", "socket", "user"] } prettytable = "0.10.0" rand = "0.8.5" ratatui = { version = "0.28.0", optional = true } +sd-notify = "0.4.2" serde = "1.0.208" serde_json = { version = "1.0.125", features = ["preserve_order"] } sqlx = { version = "0.8.0", features = ["runtime-tokio", "mysql", "tls-rustls"] } diff --git a/src/core/bootstrap.rs b/src/core/bootstrap.rs index fdb6408..f64d89e 100644 --- a/src/core/bootstrap.rs +++ b/src/core/bootstrap.rs @@ -6,15 +6,10 @@ use std::os::unix::net::UnixStream as StdUnixStream; use tokio::net::UnixStream as TokioUnixStream; use crate::{ - core::{ - bootstrap::authenticated_unix_socket::client_authenticate, - common::{UnixUser, DEFAULT_CONFIG_PATH, DEFAULT_SOCKET_PATH}, - }, + core::common::{UnixUser, DEFAULT_CONFIG_PATH, DEFAULT_SOCKET_PATH}, server::{config::read_config_form_path, server_loop::handle_requests_for_single_session}, }; -pub mod authenticated_unix_socket; - // TODO: this function is security critical, it should be integration tested // in isolation. /// Drop privileges to the real user and group of the process. @@ -57,25 +52,11 @@ pub fn bootstrap_server_connection_and_drop_privileges( log::debug!("Starting the server connection bootstrap process"); - let (socket, do_authenticate) = bootstrap_server_connection(server_socket_path, config_path)?; + let socket = bootstrap_server_connection(server_socket_path, config_path)?; drop_privs()?; - let result: anyhow::Result = if do_authenticate { - tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .unwrap() - .block_on(async { - let mut socket = TokioUnixStream::from_std(socket)?; - client_authenticate(&mut socket, None).await?; - Ok(socket.into_std()?) - }) - } else { - Ok(socket) - }; - - result + Ok(socket) } /// Inner function for [`bootstrap_server_connection_and_drop_privileges`]. @@ -83,12 +64,12 @@ pub fn bootstrap_server_connection_and_drop_privileges( fn bootstrap_server_connection( socket_path: Option, config_path: Option, -) -> anyhow::Result<(StdUnixStream, bool)> { +) -> anyhow::Result { // TODO: ensure this is both readable and writable if let Some(socket_path) = socket_path { log::debug!("Connecting to socket at {:?}", socket_path); return match StdUnixStream::connect(socket_path) { - Ok(socket) => Ok((socket, true)), + Ok(socket) => Ok(socket), Err(e) => match e.kind() { std::io::ErrorKind::NotFound => Err(anyhow::anyhow!("Socket not found")), std::io::ErrorKind::PermissionDenied => Err(anyhow::anyhow!("Permission denied")), @@ -103,12 +84,13 @@ fn bootstrap_server_connection( } log::debug!("Starting server with config at {:?}", config_path); - return invoke_server_with_config(config_path).map(|socket| (socket, false)); + return invoke_server_with_config(config_path); } if fs::metadata(DEFAULT_SOCKET_PATH).is_ok() { + log::debug!("Connecting to default socket at {:?}", DEFAULT_SOCKET_PATH); return match StdUnixStream::connect(DEFAULT_SOCKET_PATH) { - Ok(socket) => Ok((socket, true)), + Ok(socket) => Ok(socket), Err(e) => match e.kind() { std::io::ErrorKind::NotFound => Err(anyhow::anyhow!("Socket not found")), std::io::ErrorKind::PermissionDenied => Err(anyhow::anyhow!("Permission denied")), @@ -119,7 +101,8 @@ fn bootstrap_server_connection( let config_path = PathBuf::from(DEFAULT_CONFIG_PATH); if fs::metadata(&config_path).is_ok() { - return invoke_server_with_config(config_path).map(|socket| (socket, false)); + log::debug!("Starting server with default config at {:?}", config_path); + return invoke_server_with_config(config_path); } anyhow::bail!("No socket path or config path provided, and no default socket or config found"); diff --git a/src/core/bootstrap/authenticated_unix_socket.rs b/src/core/bootstrap/authenticated_unix_socket.rs deleted file mode 100644 index cabfcd8..0000000 --- a/src/core/bootstrap/authenticated_unix_socket.rs +++ /dev/null @@ -1,439 +0,0 @@ -//! This module provides a way to authenticate a client uid to a server over a Unix socket. -//! This is needed so that the server can trust the client's uid, which it depends on to -//! make modifications for that user in the database. It is crucial that the server can trust -//! that the client is the unix user it claims to be. -//! -//! This works by having the client respond to a challenge on a socket that is verifiably owned -//! by the client. In more detailed steps, the following should happen: -//! -//! 1. Before initializing it's request, the client should open an "authentication" socket with permissions 644 -//! and owned by the uid of the current user. -//! 2. The client opens a request to the server on the "normal" socket where the server is listening, -//! In this request, the client should include the following: -//! - The address of it's authentication socket -//! - The uid of the user currently using the client -//! - A challenge string that has been randomly generated -//! 3. The server validates the following: -//! - The address of the auth socket is valid -//! - The owner of the auth socket address is the same as the uid -//! 4. Server connects to the auth socket address and receives another challenge string. -//! The server should close the connection after receiving the challenge string. -//! 5. Server verifies that the challenge is the same as the one it originally received. -//! It responds to the client with an "Authenticated" message if the challenge matches, -//! or an error message if it does not. -//! 6. Client closes the authentication socket. Normal socket is used for communication. -//! -//! Note that the server can at any point in the process send an error message to the client -//! over it's initial connection, and the client should respond by closing the authentication -//! socket, it's connection to the normal socket, and reporting the error to the user. -//! -//! Also note that it is essential that the client does not send any sensitive information -//! over it's authentication socket, since it is readable by any user on the system. - -// TODO: rewrite this so that it can be used with a normal std::os::unix::net::UnixStream - -use std::os::unix::io::AsRawFd; -use std::path::{Path, PathBuf}; - -use async_bincode::{tokio::AsyncBincodeStream, AsyncDestination}; -use derive_more::derive::{Display, Error}; -use futures::{SinkExt, StreamExt}; -use nix::{sys::stat, unistd::Uid}; -use rand::distributions::Alphanumeric; -use rand::Rng; -use serde::{Deserialize, Serialize}; -use tokio::net::{UnixListener, UnixStream}; -use tokio_util::sync::CancellationToken; - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub enum ClientRequest { - Initialize { - uid: u32, - challenge: u64, - auth_socket: String, - }, - Cancel, -} - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Display, Error)] -pub enum ServerResponse { - Authenticated, - ChallengeDidNotMatch, - ServerError(ServerError), -} - -// TODO: wrap more data into the errors - -#[derive(Debug, Display, PartialEq, Serialize, Deserialize, Clone, Error)] -pub enum ServerError { - InvalidRequest, - UnableToReadPermissionsFromAuthSocket, - CouldNotConnectToAuthSocket, - AuthSocketClosedEarly, - UidMismatch, - ChallengeMismatch, - InvalidChallenge, -} - -#[derive(Debug, PartialEq, Display, Error)] -pub enum ClientError { - UnableToConnectToServer, - UnableToOpenAuthSocket, - UnableToConfigureAuthSocket, - AuthSocketClosedEarly, - UnableToCloseAuthSocket, - AuthenticationError, - UnableToParseServerResponse, - NoServerResponse, - ServerError(ServerError), -} - -async fn create_auth_socket(socket_addr: &PathBuf) -> Result { - let auth_socket = - UnixListener::bind(socket_addr).map_err(|_err| ClientError::UnableToOpenAuthSocket)?; - - stat::fchmod( - auth_socket.as_raw_fd(), - stat::Mode::S_IRUSR | stat::Mode::S_IWUSR | stat::Mode::S_IRGRP, - ) - .map_err(|_err| ClientError::UnableToConfigureAuthSocket)?; - - Ok(auth_socket) -} - -type ClientToServerStream<'a> = - AsyncBincodeStream<&'a mut UnixStream, ServerResponse, ClientRequest, AsyncDestination>; -type ServerToClientStream<'a> = - AsyncBincodeStream<&'a mut UnixStream, ClientRequest, ServerResponse, AsyncDestination>; - -// TODO: make the challenge constant size and use socket directly, this is overkill -type AuthStream<'a> = AsyncBincodeStream<&'a mut UnixStream, u64, u64, AsyncDestination>; - -// TODO: add timeout - -// TODO: respect $XDG_RUNTIME_DIR and $TMPDIR - -const AUTH_SOCKET_NAME: &str = "mysqladm-rs-cli-auth.sock"; - -pub async fn client_authenticate( - normal_socket: &mut UnixStream, - auth_socket_dir: Option, -) -> Result<(), ClientError> { - let random_prefix: String = rand::thread_rng() - .sample_iter(&Alphanumeric) - .take(16) - .map(char::from) - .collect(); - - let socket_name = format!("{}-{}", random_prefix, AUTH_SOCKET_NAME); - - let auth_socket_address = auth_socket_dir - .unwrap_or(std::env::temp_dir()) - .join(socket_name); - - client_authenticate_with_auth_socket_address(normal_socket, &auth_socket_address).await -} - -async fn client_authenticate_with_auth_socket_address( - normal_socket: &mut UnixStream, - auth_socket_address: &PathBuf, -) -> Result<(), ClientError> { - let auth_socket = create_auth_socket(auth_socket_address).await?; - - let result = - client_authenticate_with_auth_socket(normal_socket, auth_socket, auth_socket_address).await; - - std::fs::remove_file(auth_socket_address) - .map_err(|_err| ClientError::UnableToCloseAuthSocket)?; - - result -} - -async fn client_authenticate_with_auth_socket( - normal_socket: &mut UnixStream, - auth_socket: UnixListener, - auth_socket_address: &Path, -) -> Result<(), ClientError> { - let challenge = rand::random::(); - let uid = nix::unistd::getuid(); - - let mut normal_socket: ClientToServerStream = - AsyncBincodeStream::from(normal_socket).for_async(); - - let challenge_replier_cancellation_token = CancellationToken::new(); - let challenge_replier_cancellation_token_clone = challenge_replier_cancellation_token.clone(); - let challenge_replier_handle = tokio::spawn(async move { - loop { - tokio::select! { - socket = auth_socket.accept() => - { - match socket { - Ok((mut conn, _addr)) => { - let mut stream: AuthStream = AsyncBincodeStream::from(&mut conn).for_async(); - stream.send(challenge).await.ok(); - stream.close().await.ok(); - } - Err(_err) => return Err(ClientError::AuthSocketClosedEarly), - } - } - - _ = challenge_replier_cancellation_token_clone.cancelled() => { - break Ok(()); - } - } - } - }); - - let client_hello = ClientRequest::Initialize { - uid: uid.into(), - challenge, - auth_socket: auth_socket_address - .to_str() - .ok_or(ClientError::UnableToConfigureAuthSocket)? - .to_owned(), - }; - - normal_socket - .send(client_hello) - .await - .map_err(|err| match *err { - bincode::ErrorKind::Io(_err) => ClientError::UnableToConnectToServer, - _ => ClientError::NoServerResponse, - })?; - - match normal_socket.next().await { - Some(Ok(ServerResponse::Authenticated)) => {} - Some(Ok(ServerResponse::ChallengeDidNotMatch)) => { - return Err(ClientError::AuthenticationError) - } - Some(Ok(ServerResponse::ServerError(err))) => return Err(ClientError::ServerError(err)), - Some(Err(err)) => match *err { - bincode::ErrorKind::Io(_err) => return Err(ClientError::NoServerResponse), - _ => return Err(ClientError::UnableToParseServerResponse), - }, - None => return Err(ClientError::NoServerResponse), - } - - challenge_replier_cancellation_token.cancel(); - challenge_replier_handle.await.ok(); - - Ok(()) -} - -macro_rules! report_server_error_and_return { - ($normal_socket:expr, $err:expr) => {{ - $normal_socket - .send(ServerResponse::ServerError($err)) - .await - .ok(); - return Err($err); - }}; -} - -pub async fn server_authenticate(normal_socket: &mut UnixStream) -> Result { - _server_authenticate(normal_socket, None).await -} - -pub async fn _server_authenticate( - normal_socket: &mut UnixStream, - unix_user_uid: Option, -) -> Result { - let mut normal_socket: ServerToClientStream = - AsyncBincodeStream::from(normal_socket).for_async(); - - let (uid, challenge, auth_socket) = match normal_socket.next().await { - Some(Ok(ClientRequest::Initialize { - uid, - challenge, - auth_socket, - })) => (uid, challenge, auth_socket), - // TODO: more granular errros - _ => report_server_error_and_return!(normal_socket, ServerError::InvalidRequest), - }; - - let auth_socket_uid = match unix_user_uid { - Some(uid) => uid, - None => match stat::stat(auth_socket.as_str()) { - Ok(stat) => stat.st_uid, - Err(_err) => report_server_error_and_return!( - normal_socket, - ServerError::UnableToReadPermissionsFromAuthSocket - ), - }, - }; - - if uid != auth_socket_uid { - report_server_error_and_return!(normal_socket, ServerError::UidMismatch); - } - - let mut authenticated_unix_socket = match UnixStream::connect(auth_socket).await { - Ok(socket) => socket, - Err(_err) => { - report_server_error_and_return!(normal_socket, ServerError::CouldNotConnectToAuthSocket) - } - }; - let mut authenticated_unix_socket: AuthStream = - AsyncBincodeStream::from(&mut authenticated_unix_socket).for_async(); - - let challenge_2 = match authenticated_unix_socket.next().await { - Some(Ok(challenge)) => challenge, - Some(Err(_)) => { - report_server_error_and_return!(normal_socket, ServerError::InvalidChallenge) - } - None => report_server_error_and_return!(normal_socket, ServerError::AuthSocketClosedEarly), - }; - - authenticated_unix_socket.close().await.ok(); - - if challenge != challenge_2 { - normal_socket - .send(ServerResponse::ChallengeDidNotMatch) - .await - .ok(); - return Err(ServerError::ChallengeMismatch); - } - - normal_socket.send(ServerResponse::Authenticated).await.ok(); - - Ok(Uid::from_raw(uid)) -} - -#[cfg(test)] -mod test { - use super::*; - - use std::time::Duration; - use tokio::time::sleep; - - #[tokio::test] - async fn test_valid_authentication() { - let (mut client, mut server) = UnixStream::pair().unwrap(); - - let client_handle = - tokio::spawn(async move { client_authenticate(&mut client, None).await }); - - let server_handle = tokio::spawn(async move { server_authenticate(&mut server).await }); - - client_handle.await.unwrap().unwrap(); - server_handle.await.unwrap().unwrap(); - } - - #[tokio::test] - async fn test_ensure_auth_socket_does_not_exist() { - let (mut client, mut server) = UnixStream::pair().unwrap(); - - let client_handle = tokio::spawn(async move { - client_authenticate_with_auth_socket_address( - &mut client, - &PathBuf::from("/tmp/test_auth_socket_does_not_exist.sock"), - ) - .await - }); - - let server_handle = tokio::spawn(async move { server_authenticate(&mut server).await }); - - client_handle.await.unwrap().unwrap(); - server_handle.await.unwrap().unwrap(); - } - - #[tokio::test] - async fn test_uid_mismatch() { - let (mut client, mut server) = UnixStream::pair().unwrap(); - - let client_handle = tokio::spawn(async move { - let err = client_authenticate(&mut client, None).await; - assert_eq!(err, Err(ClientError::ServerError(ServerError::UidMismatch))); - }); - - let server_handle = tokio::spawn(async move { - let uid: u32 = nix::unistd::getuid().into(); - let err = _server_authenticate(&mut server, Some(uid + 1)).await; - assert_eq!(err, Err(ServerError::UidMismatch)); - }); - - client_handle.await.unwrap(); - server_handle.await.unwrap(); - } - - #[tokio::test] - async fn test_snooping_connection() { - let (mut client, mut server) = UnixStream::pair().unwrap(); - - let socket_path = std::env::temp_dir().join("socket_to_snoop.sock"); - let socket_path_clone = socket_path.clone(); - let client_handle = tokio::spawn(async move { - client_authenticate_with_auth_socket_address(&mut client, &socket_path_clone).await - }); - - for i in 0..100 { - if socket_path.exists() { - break; - } - sleep(Duration::from_millis(10)).await; - - if i == 99 { - panic!("Socket not created after 1 second, assuming test failure"); - } - } - - let mut snooper = UnixStream::connect(socket_path.clone()).await.unwrap(); - let mut snooper: AuthStream = AsyncBincodeStream::from(&mut snooper).for_async(); - let message = snooper.next().await.unwrap().unwrap(); - - let mut other_snooper = UnixStream::connect(socket_path.clone()).await.unwrap(); - let mut other_snooper: AuthStream = - AsyncBincodeStream::from(&mut other_snooper).for_async(); - let other_message = other_snooper.next().await.unwrap().unwrap(); - - assert_eq!(message, other_message); - - let third_snooper_handle = tokio::spawn(async move { - let mut third_snooper = UnixStream::connect(socket_path.clone()).await.unwrap(); - let mut third_snooper: AuthStream = - AsyncBincodeStream::from(&mut third_snooper).for_async(); - // NOTE: Should hang - third_snooper.send(1234).await.unwrap() - }); - - sleep(Duration::from_millis(10)).await; - - let server_handle = tokio::spawn(async move { server_authenticate(&mut server).await }); - - client_handle.await.unwrap().unwrap(); - server_handle.await.unwrap().unwrap(); - - third_snooper_handle.abort(); - } - - #[tokio::test] - async fn test_dead_server() { - let (mut client, server) = UnixStream::pair().unwrap(); - std::mem::drop(server); - - let client_handle = tokio::spawn(async move { - let err = client_authenticate(&mut client, None).await; - assert_eq!(err, Err(ClientError::UnableToConnectToServer)); - }); - - client_handle.await.unwrap(); - } - - #[tokio::test] - async fn test_no_response_from_server() { - let (mut client, server) = UnixStream::pair().unwrap(); - - let client_handle = tokio::spawn(async move { - let err = client_authenticate(&mut client, None).await; - assert_eq!(err, Err(ClientError::NoServerResponse)); - }); - - sleep(Duration::from_millis(200)).await; - - std::mem::drop(server); - - client_handle.await.unwrap(); - } - - // TODO: Test challenge mismatch - // TODO: Test invoking server with junk data -} diff --git a/src/server/command.rs b/src/server/command.rs index 7a77122..3ad7dbe 100644 --- a/src/server/command.rs +++ b/src/server/command.rs @@ -7,7 +7,6 @@ use clap::Parser; use std::os::unix::net::UnixStream as StdUnixStream; use tokio::net::UnixStream as TokioUnixStream; -use crate::core::bootstrap::authenticated_unix_socket; use crate::core::common::UnixUser; use crate::server::config::read_config_from_path_with_arg_overrides; use crate::server::server_loop::listen_for_incoming_connections; @@ -41,10 +40,6 @@ pub async fn handle_command( ) -> anyhow::Result<()> { let config = read_config_from_path_with_arg_overrides(config_path, args.config_overrides)?; - // if let Err(e) = &result { - // eprintln!("{}", e); - // } - match args.subcmd { ServerCommand::Listen => listen_for_incoming_connections(socket_path, config).await, ServerCommand::SocketActivate => socket_activate(config).await, @@ -53,23 +48,26 @@ pub async fn handle_command( async fn socket_activate(config: ServerConfig) -> anyhow::Result<()> { // TODO: allow getting socket path from other socket activation sources - let mut conn = get_socket_from_systemd().await?; - let uid = authenticated_unix_socket::server_authenticate(&mut conn).await?; - let unix_user = UnixUser::from_uid(uid.into())?; + let conn = get_socket_from_systemd().await?; + let uid = conn.peer_cred()?.uid(); + let unix_user = UnixUser::from_uid(uid)?; + + log::info!("Accepted connection from {}", unix_user.username); + + sd_notify::notify(true, &[sd_notify::NotifyState::Ready]).ok(); + handle_requests_for_single_session(conn, &unix_user, &config).await?; Ok(()) } async fn get_socket_from_systemd() -> anyhow::Result { - let fd = std::env::var("LISTEN_FDS") - .context("LISTEN_FDS not set, not running under systemd?")? - .parse::() - .context("Failed to parse LISTEN_FDS")?; + let fd = sd_notify::listen_fds() + .context("Failed to get file descriptors from systemd")? + .next() + .context("No file descriptors received from systemd")?; - if fd != 1 { - return Err(anyhow::anyhow!("Unexpected LISTEN_FDS value: {}", fd)); - } + log::debug!("Received file descriptor from systemd: {}", fd); let std_unix_stream = unsafe { StdUnixStream::from_raw_fd(fd) }; let socket = TokioUnixStream::from_std(std_unix_stream)?; diff --git a/src/server/config.rs b/src/server/config.rs index 3d0ca99..e3639ca 100644 --- a/src/server/config.rs +++ b/src/server/config.rs @@ -83,6 +83,8 @@ pub fn read_config_from_path_with_arg_overrides( pub fn read_config_form_path(config_path: Option) -> anyhow::Result { let config_path = config_path.unwrap_or_else(|| PathBuf::from(DEFAULT_CONFIG_PATH)); + log::debug!("Reading config from {:?}", &config_path); + fs::read_to_string(&config_path) .context(format!( "Failed to read config file from {:?}", @@ -99,6 +101,13 @@ pub fn read_config_form_path(config_path: Option) -> anyhow::Result anyhow::Result { + let mut display_config = config.clone(); + "".clone_into(&mut display_config.password); + log::debug!( + "Connecting to MySQL server with parameters: {:#?}", + display_config + ); + match tokio::time::timeout( Duration::from_secs(config.timeout.unwrap_or(DEFAULT_TIMEOUT)), MySqlConnectOptions::new() diff --git a/src/server/server_loop.rs b/src/server/server_loop.rs index 86d9411..03d9b4d 100644 --- a/src/server/server_loop.rs +++ b/src/server/server_loop.rs @@ -9,7 +9,6 @@ use sqlx::MySqlConnection; use crate::{ core::{ - bootstrap::authenticated_unix_socket, common::{UnixUser, DEFAULT_SOCKET_PATH}, protocol::request_response::{ create_server_to_client_message_stream, Request, Response, ServerToClientMessageStream, @@ -44,11 +43,12 @@ pub async fn listen_for_incoming_connections( let parent_directory = socket_path.parent().unwrap(); if !parent_directory.exists() { - println!("Creating directory {:?}", parent_directory); + log::debug!("Creating directory {:?}", parent_directory); fs::create_dir_all(parent_directory)?; } - println!("Listening on {:?}", socket_path); + log::info!("Listening on socket {:?}", socket_path); + match fs::remove_file(socket_path.as_path()) { Ok(_) => {} Err(e) if e.kind() == std::io::ErrorKind::NotFound => {} @@ -57,16 +57,13 @@ pub async fn listen_for_incoming_connections( let listener = UnixListener::bind(socket_path)?; + sd_notify::notify(true, &[sd_notify::NotifyState::Ready]).ok(); + while let Ok((mut conn, _addr)) = listener.accept().await { - let uid = match authenticated_unix_socket::server_authenticate(&mut conn).await { - Ok(uid) => uid, - Err(e) => { - eprintln!("Failed to authenticate client: {}", e); - conn.shutdown().await?; - continue; - } - }; - let unix_user = match UnixUser::from_uid(uid.into()) { + let uid = conn.peer_cred()?.uid(); + log::trace!("Accepted connection from uid {}", uid); + + let unix_user = match UnixUser::from_uid(uid) { Ok(user) => user, Err(e) => { eprintln!("Failed to get UnixUser from uid: {}", e); @@ -74,6 +71,9 @@ pub async fn listen_for_incoming_connections( continue; } }; + + log::info!("Accepted connection from {}", unix_user.username); + match handle_requests_for_single_session(conn, &unix_user, &config).await { Ok(_) => {} Err(e) => { @@ -92,6 +92,7 @@ pub async fn handle_requests_for_single_session( ) -> anyhow::Result<()> { let message_stream = create_server_to_client_message_stream(socket); let mut db_connection = create_mysql_connection_from_config(&config.mysql).await?; + log::debug!("Successfully connected to database"); let result = handle_requests_for_single_session_with_db_connection( message_stream, @@ -128,6 +129,8 @@ pub async fn handle_requests_for_single_session_with_db_connection( } }; + log::trace!("Received request: {:?}", request); + match request { Request::CreateDatabases(databases_names) => { let result = create_databases(databases_names, unix_user, db_connection).await; diff --git a/src/server/sql/database_operations.rs b/src/server/sql/database_operations.rs index 9ddddda..2af32f5 100644 --- a/src/server/sql/database_operations.rs +++ b/src/server/sql/database_operations.rs @@ -26,9 +26,17 @@ pub(super) async fn unsafe_database_exists( sqlx::query("SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?") .bind(database_name) .fetch_optional(connection) - .await?; + .await; - Ok(result.is_some()) + if let Err(err) = &result { + log::error!( + "Failed to check if database '{}' exists: {:?}", + &database_name, + err + ); + } + + Ok(result?.is_some()) } pub async fn create_databases( @@ -80,6 +88,10 @@ pub async fn create_databases( .map(|_| ()) .map_err(|err| CreateDatabaseError::MySqlError(err.to_string())); + if let Err(err) = &result { + log::error!("Failed to create database '{}': {:?}", &database_name, err); + } + results.insert(database_name, result); } @@ -135,6 +147,10 @@ pub async fn drop_databases( .map(|_| ()) .map_err(|err| DropDatabaseError::MySqlError(err.to_string())); + if let Err(err) = &result { + log::error!("Failed to drop database '{}': {:?}", &database_name, err); + } + results.insert(database_name, result); } @@ -145,7 +161,7 @@ pub async fn list_databases_for_user( unix_user: &UnixUser, connection: &mut MySqlConnection, ) -> Result, ListDatabasesError> { - sqlx::query( + let result = sqlx::query( r#" SELECT `SCHEMA_NAME` AS `database` FROM `information_schema`.`SCHEMATA` @@ -161,5 +177,15 @@ pub async fn list_databases_for_user( .map(|row| row.try_get::("database")) .collect::, sqlx::Error>>() }) - .map_err(|err| ListDatabasesError::MySqlError(err.to_string())) + .map_err(|err| ListDatabasesError::MySqlError(err.to_string())); + + if let Err(err) = &result { + log::error!( + "Failed to list databases for user '{}': {:?}", + unix_user.username, + err + ); + } + + result } diff --git a/src/server/sql/database_privilege_operations.rs b/src/server/sql/database_privilege_operations.rs index b5ba9a0..ec6b776 100644 --- a/src/server/sql/database_privilege_operations.rs +++ b/src/server/sql/database_privilege_operations.rs @@ -136,7 +136,7 @@ async fn unsafe_get_database_privileges( database_name: &str, connection: &mut MySqlConnection, ) -> Result, sqlx::Error> { - sqlx::query_as::<_, DatabasePrivilegeRow>(&format!( + let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!( "SELECT {} FROM `db` WHERE `db` = ?", DATABASE_PRIVILEGE_FIELDS .iter() @@ -145,7 +145,17 @@ async fn unsafe_get_database_privileges( )) .bind(database_name) .fetch_all(connection) - .await + .await; + + if let Err(e) = &result { + log::error!( + "Failed to get database privileges for '{}': {}", + &database_name, + e + ); + } + + result } // NOTE: this function is unsafe because it does no input validation. @@ -155,7 +165,7 @@ pub async fn unsafe_get_database_privileges_for_db_user_pair( user_name: &str, connection: &mut MySqlConnection, ) -> Result, sqlx::Error> { - sqlx::query_as::<_, DatabasePrivilegeRow>(&format!( + let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!( "SELECT {} FROM `db` WHERE `db` = ? AND `user` = ?", DATABASE_PRIVILEGE_FIELDS .iter() @@ -165,7 +175,18 @@ pub async fn unsafe_get_database_privileges_for_db_user_pair( .bind(database_name) .bind(user_name) .fetch_optional(connection) - .await + .await; + + if let Err(e) = &result { + log::error!( + "Failed to get database privileges for '{}.{}': {}", + &database_name, + &user_name, + e + ); + } + + result } pub async fn get_databases_privilege_data( @@ -220,7 +241,7 @@ pub async fn get_all_database_privileges( unix_user: &UnixUser, connection: &mut MySqlConnection, ) -> GetAllDatabasesPrivilegeData { - sqlx::query_as::<_, DatabasePrivilegeRow>(&format!( + let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!( indoc! {r#" SELECT {} FROM `db` WHERE `db` IN (SELECT DISTINCT `SCHEMA_NAME` AS `database` @@ -236,14 +257,20 @@ pub async fn get_all_database_privileges( .bind(create_user_group_matching_regex(unix_user)) .fetch_all(connection) .await - .map_err(|e| GetAllDatabasesPrivilegeDataError::MySqlError(e.to_string())) + .map_err(|e| GetAllDatabasesPrivilegeDataError::MySqlError(e.to_string())); + + if let Err(e) = &result { + log::error!("Failed to get all database privileges: {:?}", e); + } + + result } async fn unsafe_apply_privilege_diff( database_privilege_diff: &DatabasePrivilegesDiff, connection: &mut MySqlConnection, ) -> Result<(), sqlx::Error> { - match database_privilege_diff { + let result = match database_privilege_diff { DatabasePrivilegesDiff::New(p) => { let tables = DATABASE_PRIVILEGE_FIELDS .iter() @@ -305,7 +332,13 @@ async fn unsafe_apply_privilege_diff( .await .map(|_| ()) } + }; + + if let Err(e) = &result { + log::error!("Failed to apply database privilege diff: {}", e); } + + result } async fn validate_diff( diff --git a/src/server/sql/user_operations.rs b/src/server/sql/user_operations.rs index c73333e..06da219 100644 --- a/src/server/sql/user_operations.rs +++ b/src/server/sql/user_operations.rs @@ -1,6 +1,6 @@ +use indoc::formatdoc; use itertools::Itertools; use std::collections::BTreeMap; -use indoc::formatdoc; use serde::{Deserialize, Serialize}; @@ -29,7 +29,7 @@ async fn unsafe_user_exists( db_user: &str, connection: &mut MySqlConnection, ) -> Result { - sqlx::query( + let result = sqlx::query( r#" SELECT EXISTS( SELECT 1 @@ -41,7 +41,13 @@ async fn unsafe_user_exists( .bind(db_user) .fetch_one(connection) .await - .map(|row| row.get::(0)) + .map(|row| row.get::(0)); + + if let Err(err) = &result { + log::error!("Failed to check if database user exists: {:?}", err); + } + + result } pub async fn create_database_users( @@ -80,6 +86,10 @@ pub async fn create_database_users( .map(|_| ()) .map_err(|err| CreateUserError::MySqlError(err.to_string())); + if let Err(err) = &result { + log::error!("Failed to create database user '{}': {:?}", &db_user, err); + } + results.insert(db_user, result); } @@ -122,6 +132,10 @@ pub async fn drop_database_users( .map(|_| ()) .map_err(|err| DropUserError::MySqlError(err.to_string())); + if let Err(err) = &result { + log::error!("Failed to drop database user '{}': {:?}", &db_user, err); + } + results.insert(db_user, result); } @@ -148,7 +162,7 @@ pub async fn set_password_for_database_user( _ => {} } - sqlx::query( + let result = sqlx::query( format!( "ALTER USER {}@'%' IDENTIFIED BY {}", quote_literal(db_user), @@ -159,7 +173,17 @@ pub async fn set_password_for_database_user( .execute(&mut *connection) .await .map(|_| ()) - .map_err(|err| SetPasswordError::MySqlError(err.to_string())) + .map_err(|err| SetPasswordError::MySqlError(err.to_string())); + + if let Err(err) = &result { + log::error!( + "Failed to set password for database user '{}': {:?}", + &db_user, + err + ); + } + + result } // NOTE: this function is unsafe because it does no input validation. @@ -167,7 +191,7 @@ async fn database_user_is_locked_unsafe( db_user: &str, connection: &mut MySqlConnection, ) -> Result { - sqlx::query( + let result = sqlx::query( r#" SELECT COALESCE( JSON_EXTRACT(`mysql`.`global_priv`.`priv`, "$.account_locked"), @@ -181,7 +205,17 @@ async fn database_user_is_locked_unsafe( .bind(db_user) .fetch_one(connection) .await - .map(|row| row.get::(0)) + .map(|row| row.get::(0)); + + if let Err(err) = &result { + log::error!( + "Failed to check if database user is locked '{}': {:?}", + &db_user, + err + ); + } + + result } pub async fn lock_database_users( @@ -234,6 +268,10 @@ pub async fn lock_database_users( .map(|_| ()) .map_err(|err| LockUserError::MySqlError(err.to_string())); + if let Err(err) = &result { + log::error!("Failed to lock database user '{}': {:?}", &db_user, err); + } + results.insert(db_user, result); } @@ -290,6 +328,10 @@ pub async fn unlock_database_users( .map(|_| ()) .map_err(|err| UnlockUserError::MySqlError(err.to_string())); + if let Err(err) = &result { + log::error!("Failed to unlock database user '{}': {:?}", &db_user, err); + } + results.insert(db_user, result); } @@ -298,39 +340,55 @@ pub async fn unlock_database_users( /// This struct contains information about a database user. /// This can be extended if we need more information in the future. -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct DatabaseUser { - #[sqlx(rename = "User")] pub user: String, - - #[allow(dead_code)] #[serde(skip)] - #[sqlx(rename = "Host")] pub host: String, - - #[sqlx(rename = "has_password")] pub has_password: bool, - - #[sqlx(rename = "is_locked")] pub is_locked: bool, - - #[sqlx(skip)] pub databases: Vec, } +/// Some mysql versions with some collations mark some columns as binary fields, +/// which in the current version of sqlx is not parsable as string. +/// See: https://github.com/launchbadge/sqlx/issues/3387 +#[inline] +fn try_get_with_binary_fallback( + row: &sqlx::mysql::MySqlRow, + column: &str, +) -> Result { + row.try_get(column).or_else(|_| { + row.try_get::, _>(column) + .map(|v| String::from_utf8_lossy(&v).to_string()) + }) +} + +impl FromRow<'_, sqlx::mysql::MySqlRow> for DatabaseUser { + fn from_row(row: &sqlx::mysql::MySqlRow) -> Result { + Ok(Self { + user: try_get_with_binary_fallback(row, "User")?, + host: try_get_with_binary_fallback(row, "Host")?, + has_password: row.try_get("has_password")?, + is_locked: row.try_get("is_locked")?, + databases: Vec::new(), + }) + } +} + const DB_USER_SELECT_STATEMENT: &str = r#" SELECT - `mysql`.`user`.`User`, - `mysql`.`user`.`Host`, - `mysql`.`user`.`Password` != '' OR `mysql`.`user`.`authentication_string` != '' AS `has_password`, + `user`.`User`, + `user`.`Host`, + `user`.`Password` != '' OR `user`.`authentication_string` != '' AS `has_password`, COALESCE( - JSON_EXTRACT(`mysql`.`global_priv`.`priv`, "$.account_locked"), + JSON_EXTRACT(`global_priv`.`priv`, "$.account_locked"), 'false' ) != 'false' AS `is_locked` -FROM `mysql`.`user` -JOIN `mysql`.`global_priv` ON - `mysql`.`user`.`User` = `mysql`.`global_priv`.`User` - AND `mysql`.`user`.`Host` = `mysql`.`global_priv`.`Host` +FROM `user` +JOIN `global_priv` ON + `user`.`User` = `global_priv`.`User` + AND `user`.`Host` = `global_priv`.`Host` "#; pub async fn list_database_users( @@ -358,6 +416,10 @@ pub async fn list_database_users( .fetch_optional(&mut *connection) .await; + if let Err(err) = &result { + log::error!("Failed to list database user '{}': {:?}", &db_user, err); + } + if let Ok(Some(user)) = result.as_mut() { append_databases_where_user_has_privileges(user, &mut *connection).await; } @@ -377,13 +439,17 @@ pub async fn list_all_database_users_for_unix_user( connection: &mut MySqlConnection, ) -> ListAllUsersOutput { let mut result = sqlx::query_as::<_, DatabaseUser>( - &(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `mysql`.`user`.`User` REGEXP ?"), + &(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `user`.`User` REGEXP ?"), ) .bind(create_user_group_matching_regex(unix_user)) .fetch_all(&mut *connection) .await .map_err(|err| ListAllUsersError::MySqlError(err.to_string())); + if let Err(err) = &result { + log::error!("Failed to list all database users: {:?}", err); + } + if let Ok(users) = result.as_mut() { for user in users { append_databases_where_user_has_privileges(user, &mut *connection).await; @@ -394,15 +460,15 @@ pub async fn list_all_database_users_for_unix_user( } pub async fn append_databases_where_user_has_privileges( - database_user: &mut DatabaseUser, + db_user: &mut DatabaseUser, connection: &mut MySqlConnection, ) { let database_list = sqlx::query( formatdoc!( r#" - SELECT `db` AS `database` + SELECT `Db` AS `database` FROM `db` - WHERE `user` = ? AND ({}) + WHERE `User` = ? AND ({}) "#, DATABASE_PRIVILEGE_FIELDS .iter() @@ -411,14 +477,22 @@ pub async fn append_databases_where_user_has_privileges( ) .as_str(), ) - .bind(database_user.user.clone()) + .bind(db_user.user.clone()) .fetch_all(&mut *connection) .await; - database_user.databases = database_list + if let Err(err) = &database_list { + log::error!( + "Failed to list databases for user '{}': {:?}", + &db_user.user, + err + ); + } + + db_user.databases = database_list .map(|rows| { rows.into_iter() - .map(|row| row.get::("database")) + .map(|row| try_get_with_binary_fallback(&row, "database").unwrap()) .collect() }) .unwrap_or_default();