From 20e60ca5c79788effa2c501f7a770c4b4477a258 Mon Sep 17 00:00:00 2001 From: h7x4 Date: Fri, 9 Aug 2024 19:08:48 +0200 Subject: [PATCH 1/2] Add protocol for authenticating a unix socket --- Cargo.lock | 89 +++++- Cargo.toml | 9 +- src/authenticated_unix_socket.rs | 454 +++++++++++++++++++++++++++++++ src/main.rs | 1 + 4 files changed, 548 insertions(+), 5 deletions(-) create mode 100644 src/authenticated_unix_socket.rs diff --git a/Cargo.lock b/Cargo.lock index 58fa798..02ca9cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -99,6 +99,21 @@ version = "1.0.82" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f538837af36e6f6a9be0faa67f9a314f8119e4e4b5867c6ab40ed60360142519" +[[package]] +name = "async-bincode" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21849a990d47109757e820904d7c0b569a8013f6595bf14d911884634d58795f" +dependencies = [ + "bincode", + "byteorder", + "bytes", + "futures-core", + "futures-sink", + "serde", + "tokio", +] + [[package]] name = "atoi" version = "2.0.0" @@ -141,6 +156,15 @@ version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" +[[package]] +name = "bincode" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +dependencies = [ + "serde", +] + [[package]] name = "bitflags" version = "1.3.2" @@ -555,6 +579,21 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "futures" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.30" @@ -599,6 +638,17 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" +[[package]] +name = "futures-macro" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.60", +] + [[package]] name = "futures-sink" version = "0.3.30" @@ -617,8 +667,10 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" dependencies = [ + "futures-channel", "futures-core", "futures-io", + "futures-macro", "futures-sink", "futures-task", "memchr", @@ -906,20 +958,27 @@ name = "mysqladm-rs" version = "0.1.0" dependencies = [ "anyhow", + "async-bincode", + "bincode", "clap", "dialoguer", "env_logger", + "futures", "indoc", "itertools", "log", "nix", "prettytable", + "rand", "ratatui", "serde", "serde_json", "sqlx", + "thiserror", "tokio", + "tokio-util", "toml", + "uuid", ] [[package]] @@ -1812,18 +1871,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.58" +version = "1.0.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03468839009160513471e86a034bb2c5c0e4baae3b43f79ffc55c4a5427b3297" +checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.58" +version = "1.0.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" +checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261" dependencies = [ "proc-macro2", "quote", @@ -1883,6 +1942,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-util" +version = "0.7.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cf6b47b3771c49ac75ad09a6162f53ad4b8088b76ac60e8ec1455b31a189fe1" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + [[package]] name = "toml" version = "0.8.12" @@ -2023,6 +2095,15 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" +[[package]] +name = "uuid" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314" +dependencies = [ + "getrandom", +] + [[package]] name = "vcpkg" version = "0.2.15" diff --git a/Cargo.toml b/Cargo.toml index 9242ae5..003066a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,20 +5,27 @@ edition = "2021" [dependencies] anyhow = "1.0.82" +async-bincode = "0.7.2" +bincode = "1.3.3" clap = { version = "4.5.4", features = ["derive"] } dialoguer = "0.11.0" env_logger = "0.11.3" +futures = "0.3.30" indoc = "2.0.5" itertools = "0.12.1" log = "0.4.21" -nix = { version = "0.28.0", features = ["user"] } +nix = { version = "0.28.0", features = ["fs", "user"] } prettytable = "0.10.0" +rand = "0.8.5" ratatui = { version = "0.26.2", optional = true } serde = "1.0.198" serde_json = { version = "1.0.116", features = ["preserve_order"] } sqlx = { version = "0.7.4", features = ["runtime-tokio", "mysql", "tls-rustls"] } +thiserror = "1.0.63" tokio = { version = "1.37.0", features = ["rt", "macros"] } +tokio-util = "0.7.11" toml = "0.8.12" +uuid = { version = "1.10.0", features = ["v4"] } [features] default = ["mysql-admutils-compatibility"] diff --git a/src/authenticated_unix_socket.rs b/src/authenticated_unix_socket.rs new file mode 100644 index 0000000..0092b5f --- /dev/null +++ b/src/authenticated_unix_socket.rs @@ -0,0 +1,454 @@ +//! This module provides a way to authenticate a client uid to a server over a Unix socket. +//! This is needed so that the server can trust the client's uid, which it depends on to +//! make modifications for that user in the database. It is crucial that the server can trust +//! that the client is the unix user it claims to be. +//! +//! This works by having the client respond to a challenge on a socket that is verifiably owned +//! by the client. In more detailed steps, the following should happen: +//! +//! 1. Before initializing it's request, the client should open an "authentication" socket with permissions 644 +//! and owned by the uid of the current user. +//! 2. The client opens a request to the server on the "normal" socket where the server is listening, +//! In this request, the client should include the following: +//! - The address of it's authentication socket +//! - The uid of the user currently using the client +//! - A challenge string that has been randomly generated +//! 3. The server validates the following: +//! - The address of the auth socket is valid +//! - The owner of the auth socket address is the same as the uid +//! 4. Server connects to the auth socket address and receives another challenge string. +//! The server should close the connection after receiving the challenge string. +//! 5. Server verifies that the challenge is the same as the one it originally received. +//! It responds to the client with an "Authenticated" message if the challenge matches, +//! or an error message if it does not. +//! 6. Client closes the authentication socket. Normal socket is used for communication. +//! +//! Note that the server can at any point in the process send an error message to the client +//! over it's initial connection, and the client should respond by closing the authentication +//! socket, it's connection to the normal socket, and reporting the error to the user. +//! +//! Also note that it is essential that the client does not send any sensitive information +//! over it's authentication socket, since it is readable by any user on the system. + +use std::os::unix::io::AsRawFd; +use std::path::PathBuf; + +use async_bincode::{tokio::AsyncBincodeStream, AsyncDestination}; +use futures::{SinkExt, StreamExt}; +use nix::{sys::stat, unistd::Uid}; +use rand::distributions::Alphanumeric; +use rand::Rng; +use serde::{Deserialize, Serialize}; +use tokio::net::{UnixListener, UnixStream}; +use tokio_util::sync::CancellationToken; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum ClientRequest { + Initialize { + uid: u32, + challenge: u64, + auth_socket: String, + }, + Cancel, +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum ServerResponse { + Authenticated, + ChallengeDidNotMatch, + ServerError(ServerError), +} + +// TODO: wrap more data into the errors + +#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] +pub enum ServerError { + InvalidRequest, + UnableToReadPermissionsFromAuthSocket, + CouldNotConnectToAuthSocket, + AuthSocketClosedEarly, + UidMismatch, + ChallengeMismatch, + InvalidChallenge, +} + +#[derive(Debug, PartialEq)] +pub enum ClientError { + UnableToConnectToServer, + UnableToOpenAuthSocket, + UnableToConfigureAuthSocket, + AuthSocketClosedEarly, + UnableToCloseAuthSocket, + AuthenticationError, + InvalidServerResponse(ServerResponse), + UnableToParseServerResponse, + NoServerResponse, + ServerError(ServerError), +} + +async fn create_auth_socket(socket_addr: &str) -> Result { + let auth_socket = + UnixListener::bind(socket_addr).map_err(|_err| ClientError::UnableToOpenAuthSocket)?; + + stat::fchmod( + auth_socket.as_raw_fd(), + stat::Mode::S_IRUSR | stat::Mode::S_IWUSR | stat::Mode::S_IRGRP, + ) + .map_err(|_err| ClientError::UnableToConfigureAuthSocket)?; + + Ok(auth_socket) +} + +type ClientToServerStream<'a> = + AsyncBincodeStream<&'a mut UnixStream, ServerResponse, ClientRequest, AsyncDestination>; +type ServerToClientStream<'a> = + AsyncBincodeStream<&'a mut UnixStream, ClientRequest, ServerResponse, AsyncDestination>; + +// TODO: make the challenge constant size and use socket directly, this is overkill +type AuthStream<'a> = AsyncBincodeStream<&'a mut UnixStream, u64, u64, AsyncDestination>; + +// TODO: add timeout + +const AUTH_SOCKET_NAME: &str = "mysqladm-rs-cli-auth.sock"; +pub async fn client_authenticate( + normal_socket: &mut UnixStream, + #[cfg(not(test))] auth_socket_dir: Option, + #[cfg(test)] auth_socket_file: Option, +) -> Result<(), ClientError> { + let random_prefix: String = rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(16) + .map(char::from) + .collect(); + + let socket_name = format!("{}-{}", random_prefix, AUTH_SOCKET_NAME); + + #[cfg(not(test))] + let auth_socket_address = match auth_socket_dir { + Some(dir) => dir.join(socket_name).to_str().unwrap().to_string(), + None => std::env::temp_dir() + .join(socket_name) + .to_str() + .unwrap() + .to_string(), + }; + + #[cfg(test)] + let auth_socket_address = match auth_socket_file { + Some(file) => file.to_str().unwrap().to_string(), + None => std::env::temp_dir() + .join(socket_name) + .to_str() + .unwrap() + .to_string(), + }; + + client_authenticate_with_auth_socket_address(normal_socket, &auth_socket_address).await +} + +async fn client_authenticate_with_auth_socket_address( + normal_socket: &mut UnixStream, + auth_socket_address: &str, +) -> Result<(), ClientError> { + let auth_socket = create_auth_socket(auth_socket_address).await?; + + let result = + client_authenticate_with_auth_socket(normal_socket, auth_socket, auth_socket_address).await; + + std::fs::remove_file(auth_socket_address) + .map_err(|_err| ClientError::UnableToCloseAuthSocket)?; + + result +} + +async fn client_authenticate_with_auth_socket( + normal_socket: &mut UnixStream, + auth_socket: UnixListener, + auth_socket_address: &str, +) -> Result<(), ClientError> { + let challenge = rand::random::(); + let uid = nix::unistd::getuid(); + + let mut normal_socket: ClientToServerStream = + AsyncBincodeStream::from(normal_socket).for_async(); + + let challenge_replier_cancellation_token = CancellationToken::new(); + let challenge_replier_cancellation_token_clone = challenge_replier_cancellation_token.clone(); + let challenge_replier_handle = tokio::spawn(async move { + loop { + tokio::select! { + socket = auth_socket.accept() => + { + match socket { + Ok((mut conn, _addr)) => { + let mut stream: AuthStream = AsyncBincodeStream::from(&mut conn).for_async(); + stream.send(challenge).await.ok(); + stream.close().await.ok(); + } + Err(_err) => return Err(ClientError::AuthSocketClosedEarly), + } + } + + _ = challenge_replier_cancellation_token_clone.cancelled() => { + break Ok(()); + } + } + } + }); + + let client_hello = ClientRequest::Initialize { + uid: uid.into(), + challenge, + auth_socket: auth_socket_address.to_string(), + }; + + normal_socket + .send(client_hello) + .await + .map_err(|err| match *err { + bincode::ErrorKind::Io(_err) => ClientError::UnableToConnectToServer, + _ => ClientError::NoServerResponse, + })?; + + match normal_socket.next().await { + Some(Ok(ServerResponse::Authenticated)) => {} + Some(Ok(ServerResponse::ChallengeDidNotMatch)) => { + return Err(ClientError::AuthenticationError) + } + Some(Ok(ServerResponse::ServerError(err))) => return Err(ClientError::ServerError(err)), + Some(Err(err)) => match *err { + bincode::ErrorKind::Io(_err) => return Err(ClientError::NoServerResponse), + _ => return Err(ClientError::UnableToParseServerResponse), + }, + None => return Err(ClientError::NoServerResponse), + } + + challenge_replier_cancellation_token.cancel(); + challenge_replier_handle.await.ok(); + + Ok(()) +} + +macro_rules! report_server_error_and_return { + ($normal_socket:expr, $err:expr) => {{ + $normal_socket + .send(ServerResponse::ServerError($err)) + .await + .ok(); + return Err($err); + }}; +} + +async fn server_authenticate( + normal_socket: &mut UnixStream, + #[cfg(test)] unix_user_uid: Option, +) -> Result { + let mut normal_socket: ServerToClientStream = + AsyncBincodeStream::from(normal_socket).for_async(); + + let (uid, challenge, auth_socket) = match normal_socket.next().await { + Some(Ok(ClientRequest::Initialize { + uid, + challenge, + auth_socket, + })) => (uid, challenge, auth_socket), + // TODO: more granular errros + _ => report_server_error_and_return!(normal_socket, ServerError::InvalidRequest), + }; + + #[cfg(test)] + let auth_socket_uid = match unix_user_uid { + Some(uid) => uid, + None => report_server_error_and_return!( + normal_socket, + ServerError::UnableToReadPermissionsFromAuthSocket + ), + }; + + #[cfg(not(test))] + let auth_socket_uid = match stat::stat(auth_socket.as_str()) { + Ok(stat) => stat.st_uid, + Err(_err) => report_server_error_and_return!( + normal_socket, + ServerError::UnableToReadPermissionsFromAuthSocket + ), + }; + + if uid != auth_socket_uid { + report_server_error_and_return!(normal_socket, ServerError::UidMismatch); + } + + let mut authenticated_unix_socket = match UnixStream::connect(auth_socket).await { + Ok(socket) => socket, + Err(_err) => { + report_server_error_and_return!(normal_socket, ServerError::CouldNotConnectToAuthSocket) + } + }; + let mut authenticated_unix_socket: AuthStream = + AsyncBincodeStream::from(&mut authenticated_unix_socket).for_async(); + + let challenge_2 = match authenticated_unix_socket.next().await { + Some(Ok(challenge)) => challenge, + Some(Err(_)) => { + report_server_error_and_return!(normal_socket, ServerError::InvalidChallenge) + } + None => report_server_error_and_return!(normal_socket, ServerError::AuthSocketClosedEarly), + }; + + authenticated_unix_socket.close().await.ok(); + + if challenge != challenge_2 { + normal_socket + .send(ServerResponse::ChallengeDidNotMatch) + .await + .ok(); + return Err(ServerError::ChallengeMismatch); + } + + normal_socket.send(ServerResponse::Authenticated).await.ok(); + + Ok(Uid::from_raw(uid)) +} + +#[cfg(test)] +mod test { + use super::*; + + use std::time::Duration; + use tokio::time::sleep; + + #[tokio::test] + async fn test_valid_authentication() { + let (mut client, mut server) = UnixStream::pair().unwrap(); + + let client_handle = + tokio::spawn(async move { client_authenticate(&mut client, None).await }); + + let server_handle = tokio::spawn(async move { + let uid = nix::unistd::getuid().into(); + server_authenticate(&mut server, Some(uid)).await + }); + + client_handle.await.unwrap().unwrap(); + server_handle.await.unwrap().unwrap(); + } + + #[tokio::test] + async fn test_ensure_auth_socket_does_not_exist() { + let (mut client, mut server) = UnixStream::pair().unwrap(); + + let client_handle = tokio::spawn(async move { + client_authenticate_with_auth_socket_address( + &mut client, + "/tmp/test_auth_socket_does_not_exist.sock", + ) + .await + }); + + let server_handle = tokio::spawn(async move { + let uid = nix::unistd::getuid().into(); + server_authenticate(&mut server, Some(uid)).await + }); + + client_handle.await.unwrap().unwrap(); + server_handle.await.unwrap().unwrap(); + } + + #[tokio::test] + async fn test_uid_mismatch() { + let (mut client, mut server) = UnixStream::pair().unwrap(); + + let client_handle = tokio::spawn(async move { + let err = client_authenticate(&mut client, None).await; + assert_eq!(err, Err(ClientError::ServerError(ServerError::UidMismatch))); + }); + + let server_handle = tokio::spawn(async move { + let uid: u32 = nix::unistd::getuid().into(); + let err = server_authenticate(&mut server, Some(uid + 1)).await; + assert_eq!(err, Err(ServerError::UidMismatch)); + }); + + client_handle.await.unwrap(); + server_handle.await.unwrap(); + } + + #[tokio::test] + async fn test_snooping_connection() { + let (mut client, mut server) = UnixStream::pair().unwrap(); + + let socket_path = std::env::temp_dir().join("socket_to_snoop.sock"); + let socket_path_clone = socket_path.clone(); + let client_handle = + tokio::spawn( + async move { client_authenticate(&mut client, Some(socket_path_clone)).await }, + ); + + while !socket_path.exists() { + sleep(std::time::Duration::from_millis(10)).await; + } + + let mut snooper = UnixStream::connect(socket_path.clone()).await.unwrap(); + let mut snooper: AuthStream = AsyncBincodeStream::from(&mut snooper).for_async(); + let message = snooper.next().await.unwrap().unwrap(); + + let mut other_snooper = UnixStream::connect(socket_path.clone()).await.unwrap(); + let mut other_snooper: AuthStream = + AsyncBincodeStream::from(&mut other_snooper).for_async(); + let other_message = other_snooper.next().await.unwrap().unwrap(); + + assert_eq!(message, other_message); + + let third_snooper_handle = tokio::spawn(async move { + let mut third_snooper = UnixStream::connect(socket_path.clone()).await.unwrap(); + let mut third_snooper: AuthStream = + AsyncBincodeStream::from(&mut third_snooper).for_async(); + // NOTE: Should hang + third_snooper.send(1234).await.unwrap() + }); + + sleep(Duration::from_millis(10)).await; + + let server_handle = tokio::spawn(async move { + let uid: u32 = nix::unistd::getuid().into(); + server_authenticate(&mut server, Some(uid)).await + }); + + client_handle.await.unwrap().unwrap(); + server_handle.await.unwrap().unwrap(); + + third_snooper_handle.abort(); + } + + #[tokio::test] + async fn test_dead_server() { + let (mut client, server) = UnixStream::pair().unwrap(); + std::mem::drop(server); + + let client_handle = tokio::spawn(async move { + let err = client_authenticate(&mut client, None).await; + assert_eq!(err, Err(ClientError::UnableToConnectToServer)); + }); + + client_handle.await.unwrap(); + } + + #[tokio::test] + async fn test_no_response_from_server() { + let (mut client, server) = UnixStream::pair().unwrap(); + + let client_handle = tokio::spawn(async move { + let err = client_authenticate(&mut client, None).await; + assert_eq!(err, Err(ClientError::NoServerResponse)); + }); + + sleep(Duration::from_millis(200)).await; + + std::mem::drop(server); + + client_handle.await.unwrap(); + } + + // TODO: Test challenge mismatch + // TODO: Test invoking server with junk data +} diff --git a/src/main.rs b/src/main.rs index d46d92f..2622c50 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,6 +10,7 @@ use crate::cli::mysql_admutils_compatibility::{mysql_dbadm, mysql_useradm}; use clap::Parser; +mod authenticated_unix_socket; mod cli; mod core; -- 2.44.2 From af86893acf94f49d40cc2b42ff15987cae21e16f Mon Sep 17 00:00:00 2001 From: h7x4 Date: Sat, 10 Aug 2024 02:16:38 +0200 Subject: [PATCH 2/2] Rewrite entire codebase to split into client and server --- Cargo.lock | 99 ++- Cargo.toml | 9 +- src/cli.rs | 5 +- src/cli/common.rs | 20 + src/cli/database_command.rs | 295 +++++---- src/cli/mysql_admutils_compatibility.rs | 1 + .../mysql_admutils_compatibility/common.rs | 59 +- .../error_messages.rs | 176 +++++ .../mysql_dbadm.rs | 253 ++++++-- .../mysql_useradm.rs | 283 +++++--- src/cli/user_command.rs | 322 ++++----- src/core.rs | 7 +- src/core/bootstrap.rs | 177 +++++ .../bootstrap}/authenticated_unix_socket.rs | 117 ++-- src/core/common.rs | 343 +--------- src/core/database_operations.rs | 120 ---- ...e_operations.rs => database_privileges.rs} | 366 +++-------- src/core/protocol.rs | 5 + src/core/protocol/request_response.rs | 79 +++ src/core/protocol/server_responses.rs | 611 ++++++++++++++++++ src/core/user_operations.rs | 249 ------- src/main.rs | 147 +++-- src/server.rs | 6 + src/server/command.rs | 77 +++ src/server/common.rs | 11 + src/{core => server}/config.rs | 88 ++- src/server/input_sanitization.rs | 158 +++++ src/server/server_loop.rs | 229 +++++++ src/server/sql.rs | 3 + src/server/sql/database_operations.rs | 165 +++++ .../sql/database_privilege_operations.rs | 452 +++++++++++++ src/server/sql/user_operations.rs | 375 +++++++++++ 32 files changed, 3708 insertions(+), 1599 deletions(-) create mode 100644 src/cli/common.rs create mode 100644 src/cli/mysql_admutils_compatibility/error_messages.rs create mode 100644 src/core/bootstrap.rs rename src/{ => core/bootstrap}/authenticated_unix_socket.rs (83%) delete mode 100644 src/core/database_operations.rs rename src/core/{database_privilege_operations.rs => database_privileges.rs} (66%) create mode 100644 src/core/protocol.rs create mode 100644 src/core/protocol/request_response.rs create mode 100644 src/core/protocol/server_responses.rs delete mode 100644 src/core/user_operations.rs create mode 100644 src/server.rs create mode 100644 src/server/command.rs create mode 100644 src/server/common.rs rename src/{core => server}/config.rs (66%) create mode 100644 src/server/input_sanitization.rs create mode 100644 src/server/server_loop.rs create mode 100644 src/server/sql.rs create mode 100644 src/server/sql/database_operations.rs create mode 100644 src/server/sql/database_privilege_operations.rs create mode 100644 src/server/sql/user_operations.rs diff --git a/Cargo.lock b/Cargo.lock index 02ca9cf..deecfbe 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -418,6 +418,27 @@ dependencies = [ "zeroize", ] +[[package]] +name = "derive_more" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a9b99b9cbbe49445b21764dc0625032a89b145a2642e67603e1c936f5458d05" +dependencies = [ + "derive_more-impl", +] + +[[package]] +name = "derive_more-impl" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb7330aeadfbe296029522e6c40f315320aba36fc43a5b3632f3795348f3bd22" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.60", + "unicode-xid", +] + [[package]] name = "dialoguer" version = "0.11.0" @@ -470,6 +491,18 @@ version = "0.15.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b" +[[package]] +name = "educe" +version = "0.5.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4bd92664bf78c4d3dba9b7cdafce6fa15b13ed3ed16175218196942e99168a8" +dependencies = [ + "enum-ordinalize", + "proc-macro2", + "quote", + "syn 2.0.60", +] + [[package]] name = "either" version = "1.11.0" @@ -491,6 +524,26 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" +[[package]] +name = "enum-ordinalize" +version = "4.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fea0dcfa4e54eeb516fe454635a95753ddd39acda650ce703031c6973e315dd5" +dependencies = [ + "enum-ordinalize-derive", +] + +[[package]] +name = "enum-ordinalize-derive" +version = "4.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d28318a75d4aead5c4db25382e8ef717932d0346600cacae6357eb5941bc5ff" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.60", +] + [[package]] name = "env_filter" version = "0.1.0" @@ -961,9 +1014,11 @@ dependencies = [ "async-bincode", "bincode", "clap", + "derive_more", "dialoguer", "env_logger", "futures", + "futures-util", "indoc", "itertools", "log", @@ -974,8 +1029,9 @@ dependencies = [ "serde", "serde_json", "sqlx", - "thiserror", "tokio", + "tokio-serde", + "tokio-stream", "tokio-util", "toml", "uuid", @@ -1109,6 +1165,26 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +[[package]] +name = "pin-project" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6bf43b791c5b9e34c3d182969b4abb522f9343702850a2e57f460d00d09b4b3" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.60", +] + [[package]] name = "pin-project-lite" version = "0.2.14" @@ -1931,6 +2007,21 @@ dependencies = [ "syn 2.0.60", ] +[[package]] +name = "tokio-serde" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "caf600e7036b17782571dd44fa0a5cea3c82f60db5137f774a325a76a0d6852b" +dependencies = [ + "bincode", + "bytes", + "educe", + "futures-core", + "futures-sink", + "pin-project", + "serde", +] + [[package]] name = "tokio-stream" version = "0.1.15" @@ -2060,6 +2151,12 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e51733f11c9c4f72aa0c160008246859e340b00807569a0da0e7a1079b27ba85" +[[package]] +name = "unicode-xid" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c" + [[package]] name = "unicode_categories" version = "0.1.1" diff --git a/Cargo.toml b/Cargo.toml index 003066a..4fc032e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,22 +8,25 @@ anyhow = "1.0.82" async-bincode = "0.7.2" bincode = "1.3.3" clap = { version = "4.5.4", features = ["derive"] } +derive_more = { version = "1.0.0", features = ["display", "error"] } dialoguer = "0.11.0" env_logger = "0.11.3" futures = "0.3.30" +futures-util = "0.3.30" indoc = "2.0.5" itertools = "0.12.1" log = "0.4.21" -nix = { version = "0.28.0", features = ["fs", "user"] } +nix = { version = "0.28.0", features = ["fs", "process", "user"] } prettytable = "0.10.0" rand = "0.8.5" ratatui = { version = "0.26.2", optional = true } serde = "1.0.198" serde_json = { version = "1.0.116", features = ["preserve_order"] } sqlx = { version = "0.7.4", features = ["runtime-tokio", "mysql", "tls-rustls"] } -thiserror = "1.0.63" tokio = { version = "1.37.0", features = ["rt", "macros"] } -tokio-util = "0.7.11" +tokio-serde = { version = "0.9.0", features = ["bincode"] } +tokio-stream = "0.1.15" +tokio-util = { version = "0.7.11", features = ["codec"] } toml = "0.8.12" uuid = { version = "1.10.0", features = ["v4"] } diff --git a/src/cli.rs b/src/cli.rs index 7c80cfc..1b29138 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -1,3 +1,6 @@ +mod common; pub mod database_command; -pub mod mysql_admutils_compatibility; pub mod user_command; + +#[cfg(feature = "mysql-admutils-compatibility")] +pub mod mysql_admutils_compatibility; diff --git a/src/cli/common.rs b/src/cli/common.rs new file mode 100644 index 0000000..02087b1 --- /dev/null +++ b/src/cli/common.rs @@ -0,0 +1,20 @@ +use crate::core::protocol::Response; + +pub fn erroneous_server_response( + response: Option>, +) -> anyhow::Result<()> { + match response { + Some(Ok(Response::Error(e))) => { + anyhow::bail!("Server returned error: {}", e); + } + Some(Err(e)) => { + anyhow::bail!(e); + } + Some(response) => { + anyhow::bail!("Unexpected response from server: {:?}", response); + } + None => { + anyhow::bail!("No response from server"); + } + } +} diff --git a/src/cli/database_command.rs b/src/cli/database_command.rs index b41ed8c..1d1eaa4 100644 --- a/src/cli/database_command.rs +++ b/src/cli/database_command.rs @@ -1,17 +1,29 @@ use anyhow::Context; use clap::Parser; use dialoguer::{Confirm, Editor}; +use futures_util::{SinkExt, StreamExt}; +use nix::unistd::{getuid, User}; use prettytable::{Cell, Row, Table}; -use sqlx::{Connection, MySqlConnection}; -use crate::core::{ - common::{close_database_connection, get_current_unix_user, yn, CommandStatus}, - database_operations::*, - database_privilege_operations::*, - user_operations::user_exists, +use crate::{ + cli::common::erroneous_server_response, + core::{ + common::yn, + database_privileges::{ + db_priv_field_human_readable_name, diff_privileges, display_privilege_diffs, + generate_editor_content_from_privilege_data, parse_privilege_data_from_editor_content, + parse_privilege_table_cli_arg, + }, + protocol::{ + print_create_databases_output_status, print_drop_databases_output_status, + print_modify_database_privileges_output_status, ClientToServerMessageStream, Request, + Response, + }, + }, + server::sql::database_privilege_operations::{DatabasePrivilegeRow, DATABASE_PRIVILEGE_FIELDS}, }; -#[derive(Parser)] +#[derive(Parser, Debug, Clone)] // #[command(next_help_heading = Some(DATABASE_COMMAND_HEADER))] pub enum DatabaseCommand { /// Create one or more databases @@ -86,28 +98,28 @@ pub enum DatabaseCommand { EditDbPrivs(DatabaseEditPrivsArgs), } -#[derive(Parser)] +#[derive(Parser, Debug, Clone)] pub struct DatabaseCreateArgs { /// The name of the database(s) to create. #[arg(num_args = 1..)] name: Vec, } -#[derive(Parser)] +#[derive(Parser, Debug, Clone)] pub struct DatabaseDropArgs { /// The name of the database(s) to drop. #[arg(num_args = 1..)] name: Vec, } -#[derive(Parser)] +#[derive(Parser, Debug, Clone)] pub struct DatabaseListArgs { /// Whether to output the information in JSON format. #[arg(short, long)] json: bool, } -#[derive(Parser)] +#[derive(Parser, Debug, Clone)] pub struct DatabaseShowPrivsArgs { /// The name of the database(s) to show. #[arg(num_args = 0..)] @@ -118,7 +130,7 @@ pub struct DatabaseShowPrivsArgs { json: bool, } -#[derive(Parser)] +#[derive(Parser, Debug, Clone)] pub struct DatabaseEditPrivsArgs { /// The name of the database to edit privileges for. pub name: Option, @@ -141,125 +153,143 @@ pub struct DatabaseEditPrivsArgs { pub async fn handle_command( command: DatabaseCommand, - mut connection: MySqlConnection, -) -> anyhow::Result { - let result = connection - .transaction(|txn| { - Box::pin(async move { - match command { - DatabaseCommand::CreateDb(args) => create_databases(args, txn).await, - DatabaseCommand::DropDb(args) => drop_databases(args, txn).await, - DatabaseCommand::ListDb(args) => list_databases(args, txn).await, - DatabaseCommand::ShowDbPrivs(args) => show_database_privileges(args, txn).await, - DatabaseCommand::EditDbPrivs(args) => edit_privileges(args, txn).await, - } - }) - }) - .await; - - close_database_connection(connection).await; - - result + server_connection: ClientToServerMessageStream, +) -> anyhow::Result<()> { + match command { + DatabaseCommand::CreateDb(args) => create_databases(args, server_connection).await, + DatabaseCommand::DropDb(args) => drop_databases(args, server_connection).await, + DatabaseCommand::ListDb(args) => list_databases(args, server_connection).await, + DatabaseCommand::ShowDbPrivs(args) => { + show_database_privileges(args, server_connection).await + } + DatabaseCommand::EditDbPrivs(args) => { + edit_database_privileges(args, server_connection).await + } + } } async fn create_databases( args: DatabaseCreateArgs, - connection: &mut MySqlConnection, -) -> anyhow::Result { + mut server_connection: ClientToServerMessageStream, +) -> anyhow::Result<()> { if args.name.is_empty() { anyhow::bail!("No database names provided"); } - let mut result = CommandStatus::SuccessfullyModified; + let message = Request::CreateDatabases(args.name.clone()); + server_connection.send(message).await?; - for name in args.name { - // TODO: This can be optimized by fetching all the database privileges in one query. - if let Err(e) = create_database(&name, connection).await { - eprintln!("Failed to create database '{}': {}", name, e); - eprintln!("Skipping..."); - result = CommandStatus::PartiallySuccessfullyModified; - } else { - println!("Database '{}' created.", name); - } - } + let result = match server_connection.next().await { + Some(Ok(Response::CreateDatabases(result))) => result, + response => return erroneous_server_response(response), + }; - Ok(result) + server_connection.send(Request::Exit).await?; + + print_create_databases_output_status(&result); + + Ok(()) } async fn drop_databases( args: DatabaseDropArgs, - connection: &mut MySqlConnection, -) -> anyhow::Result { + mut server_connection: ClientToServerMessageStream, +) -> anyhow::Result<()> { if args.name.is_empty() { anyhow::bail!("No database names provided"); } - let mut result = CommandStatus::SuccessfullyModified; + let message = Request::DropDatabases(args.name.clone()); + server_connection.send(message).await?; - for name in args.name { - // TODO: This can be optimized by fetching all the database privileges in one query. - if let Err(e) = drop_database(&name, connection).await { - eprintln!("Failed to drop database '{}': {}", name, e); - eprintln!("Skipping..."); - result = CommandStatus::PartiallySuccessfullyModified; - } else { - println!("Database '{}' dropped.", name); - } - } + let result = match server_connection.next().await { + Some(Ok(Response::DropDatabases(result))) => result, + response => return erroneous_server_response(response), + }; - Ok(result) + server_connection.send(Request::Exit).await?; + + print_drop_databases_output_status(&result); + + Ok(()) } async fn list_databases( args: DatabaseListArgs, - connection: &mut MySqlConnection, -) -> anyhow::Result { - let databases = get_database_list(connection).await?; + mut server_connection: ClientToServerMessageStream, +) -> anyhow::Result<()> { + let message = Request::ListDatabases; + server_connection.send(message).await?; + + let result = match server_connection.next().await { + Some(Ok(Response::ListAllDatabases(result))) => result, + response => return erroneous_server_response(response), + }; + + server_connection.send(Request::Exit).await?; + + let database_list = match result { + Ok(list) => list, + Err(err) => { + return Err(anyhow::anyhow!(err.to_error_message()).context("Failed to list databases")) + } + }; if args.json { - println!("{}", serde_json::to_string_pretty(&databases)?); - return Ok(CommandStatus::NoModificationsIntended); - } - - if databases.is_empty() { + println!("{}", serde_json::to_string_pretty(&database_list)?); + } else if database_list.is_empty() { println!("No databases to show."); } else { - for db in databases { + for db in database_list { println!("{}", db); } } - Ok(CommandStatus::NoModificationsIntended) + Ok(()) } async fn show_database_privileges( args: DatabaseShowPrivsArgs, - connection: &mut MySqlConnection, -) -> anyhow::Result { - let database_users_to_show = if args.name.is_empty() { - get_all_database_privileges(connection).await? + mut server_connection: ClientToServerMessageStream, +) -> anyhow::Result<()> { + let message = if args.name.is_empty() { + Request::ListPrivileges(None) } else { - // TODO: This can be optimized by fetching all the database privileges in one query. - let mut result = Vec::with_capacity(args.name.len()); - for name in args.name { - match get_database_privileges(&name, connection).await { - Ok(db) => result.extend(db), - Err(e) => { - eprintln!("Failed to show database '{}': {}", name, e); + Request::ListPrivileges(Some(args.name.clone())) + }; + server_connection.send(message).await?; + + let privilege_data = match server_connection.next().await { + Some(Ok(Response::ListPrivileges(databases))) => databases + .into_iter() + .filter_map(|(database_name, result)| match result { + Ok(privileges) => Some(privileges), + Err(err) => { + eprintln!("{}", err.to_error_message(&database_name)); eprintln!("Skipping..."); + println!(); + None } + }) + .flatten() + .collect::>(), + Some(Ok(Response::ListAllPrivileges(privilege_rows))) => match privilege_rows { + Ok(list) => list, + Err(err) => { + server_connection.send(Request::Exit).await?; + return Err(anyhow::anyhow!(err.to_error_message()) + .context("Failed to list database privileges")); } - } - result + }, + response => return erroneous_server_response(response), }; - if args.json { - println!("{}", serde_json::to_string_pretty(&database_users_to_show)?); - return Ok(CommandStatus::NoModificationsIntended); - } + server_connection.send(Request::Exit).await?; - if database_users_to_show.is_empty() { - println!("No database users to show."); + if args.json { + println!("{}", serde_json::to_string_pretty(&privilege_data)?); + } else if privilege_data.is_empty() { + println!("No database privileges to show."); } else { let mut table = Table::new(); table.add_row(Row::new( @@ -270,7 +300,7 @@ async fn show_database_privileges( .collect(), )); - for row in database_users_to_show { + for row in privilege_data { table.add_row(row![ row.db, row.user, @@ -290,17 +320,40 @@ async fn show_database_privileges( table.printstd(); } - Ok(CommandStatus::NoModificationsIntended) + Ok(()) } -pub async fn edit_privileges( +pub async fn edit_database_privileges( args: DatabaseEditPrivsArgs, - connection: &mut MySqlConnection, -) -> anyhow::Result { - let privilege_data = if let Some(name) = &args.name { - get_database_privileges(name, connection).await? - } else { - get_all_database_privileges(connection).await? + mut server_connection: ClientToServerMessageStream, +) -> anyhow::Result<()> { + let message = Request::ListPrivileges(args.name.clone().map(|name| vec![name])); + + server_connection.send(message).await?; + + let privilege_data = match server_connection.next().await { + Some(Ok(Response::ListPrivileges(databases))) => databases + .into_iter() + .filter_map(|(database_name, result)| match result { + Ok(privileges) => Some(privileges), + Err(err) => { + eprintln!("{}", err.to_error_message(&database_name)); + eprintln!("Skipping..."); + println!(); + None + } + }) + .flatten() + .collect::>(), + Some(Ok(Response::ListAllPrivileges(privilege_rows))) => match privilege_rows { + Ok(list) => list, + Err(err) => { + server_connection.send(Request::Exit).await?; + return Err(anyhow::anyhow!(err.to_error_message()) + .context("Failed to list database privileges")); + } + }, + response => return erroneous_server_response(response), }; // TODO: The data from args should not be absolute. @@ -316,22 +369,16 @@ pub async fn edit_privileges( edit_privileges_with_editor(&privilege_data)? }; - for row in privileges_to_change.iter() { - if !user_exists(&row.user, connection).await? { - // TODO: allow user to return and correct their mistake - anyhow::bail!("User {} does not exist", row.user); - } - } - let diffs = diff_privileges(&privilege_data, &privileges_to_change); if diffs.is_empty() { println!("No changes to make."); - return Ok(CommandStatus::NoModificationsNeeded); + return Ok(()); } println!("The following changes will be made:\n"); println!("{}", display_privilege_diffs(&diffs)); + if !args.yes && !Confirm::new() .with_prompt("Do you want to apply these changes?") @@ -339,15 +386,27 @@ pub async fn edit_privileges( .show_default(true) .interact()? { - return Ok(CommandStatus::Cancelled); + server_connection.send(Request::Exit).await?; + return Ok(()); } - apply_privilege_diffs(diffs, connection).await?; + let message = Request::ModifyPrivileges(diffs); + server_connection.send(message).await?; - Ok(CommandStatus::SuccessfullyModified) + let result = match server_connection.next().await { + Some(Ok(Response::ModifyPrivileges(result))) => result, + response => return erroneous_server_response(response), + }; + + // TODO: allow user to return and correct their mistake + print_modify_database_privileges_output_status(&result); + + server_connection.send(Request::Exit).await?; + + Ok(()) } -pub fn parse_privilege_tables_from_args( +fn parse_privilege_tables_from_args( args: &DatabaseEditPrivsArgs, ) -> anyhow::Result> { debug_assert!(!args.privs.is_empty()); @@ -371,20 +430,22 @@ pub fn parse_privilege_tables_from_args( Ok(result) } -pub fn edit_privileges_with_editor( +fn edit_privileges_with_editor( privilege_data: &[DatabasePrivilegeRow], ) -> anyhow::Result> { - let unix_user = get_current_unix_user()?; + let unix_user = User::from_uid(getuid()) + .context("Failed to look up your UNIX username") + .and_then(|u| u.ok_or(anyhow::anyhow!("Failed to look up your UNIX username")))?; let editor_content = generate_editor_content_from_privilege_data(privilege_data, &unix_user.name); // TODO: handle errors better here - let result = Editor::new() - .extension("tsv") - .edit(&editor_content)? - .unwrap(); + let result = Editor::new().extension("tsv").edit(&editor_content)?; - parse_privilege_data_from_editor_content(result) - .context("Could not parse privilege data from editor") + match result { + None => Ok(privilege_data.to_vec()), + Some(result) => parse_privilege_data_from_editor_content(result) + .context("Could not parse privilege data from editor"), + } } diff --git a/src/cli/mysql_admutils_compatibility.rs b/src/cli/mysql_admutils_compatibility.rs index df4bf14..561b340 100644 --- a/src/cli/mysql_admutils_compatibility.rs +++ b/src/cli/mysql_admutils_compatibility.rs @@ -1,3 +1,4 @@ pub mod common; +mod error_messages; pub mod mysql_dbadm; pub mod mysql_useradm; diff --git a/src/cli/mysql_admutils_compatibility/common.rs b/src/cli/mysql_admutils_compatibility/common.rs index 506b9c0..f2c7f19 100644 --- a/src/cli/mysql_admutils_compatibility/common.rs +++ b/src/cli/mysql_admutils_compatibility/common.rs @@ -1,57 +1,4 @@ -use crate::core::common::{ - get_current_unix_user, validate_name_or_error, validate_ownership_or_error, DbOrUser, -}; - -/// In contrast to the new implementation which reports errors on any invalid name -/// for any reason, mysql-admutils would only log the error and skip that particular -/// name. This function replicates that behavior. -pub fn filter_db_or_user_names( - names: Vec, - db_or_user: DbOrUser, -) -> anyhow::Result> { - let unix_user = get_current_unix_user()?; - let argv0 = std::env::args().next().unwrap_or_else(|| match db_or_user { - DbOrUser::Database => "mysql-dbadm".to_string(), - DbOrUser::User => "mysql-useradm".to_string(), - }); - - let filtered_names = names - .into_iter() - // NOTE: The original implementation would only copy the first 32 characters - // of the argument into it's internal buffer. We replicate that behavior - // here. - .map(|name| name.chars().take(32).collect::()) - .filter(|name| { - if let Err(_err) = validate_ownership_or_error(name, &unix_user, db_or_user) { - println!( - "You are not in charge of mysql-{}: '{}'. Skipping.", - db_or_user.lowercased(), - name - ); - return false; - } - true - }) - .filter(|name| { - // NOTE: while this also checks for the length of the name, - // the name is already truncated to 32 characters. So - // if there is an error, it's guaranteed to be due to - // invalid characters. - if let Err(_err) = validate_name_or_error(name, db_or_user) { - println!( - concat!( - "{}: {} name '{}' contains invalid characters.\n", - "Only A-Z, a-z, 0-9, _ (underscore) and - (dash) permitted. Skipping.", - ), - argv0, - db_or_user.capitalized(), - name - ); - return false; - } - true - }) - .collect(); - - Ok(filtered_names) +#[inline] +pub fn trim_to_32_chars(name: &str) -> String { + name.chars().take(32).collect() } diff --git a/src/cli/mysql_admutils_compatibility/error_messages.rs b/src/cli/mysql_admutils_compatibility/error_messages.rs new file mode 100644 index 0000000..76fdc44 --- /dev/null +++ b/src/cli/mysql_admutils_compatibility/error_messages.rs @@ -0,0 +1,176 @@ +use crate::core::protocol::{ + CreateDatabaseError, CreateUserError, DbOrUser, DropDatabaseError, DropUserError, + GetDatabasesPrivilegeDataError, ListUsersError, +}; + +pub fn name_validation_error_to_error_message(name: &str, db_or_user: DbOrUser) -> String { + let argv0 = std::env::args().next().unwrap_or_else(|| match db_or_user { + DbOrUser::Database => "mysql-dbadm".to_string(), + DbOrUser::User => "mysql-useradm".to_string(), + }); + + format!( + concat!( + "{}: {} name '{}' contains invalid characters.\n", + "Only A-Z, a-z, 0-9, _ (underscore) and - (dash) permitted. Skipping.", + ), + argv0, + db_or_user.capitalized(), + name, + ) +} + +pub fn owner_validation_error_message(name: &str, db_or_user: DbOrUser) -> String { + format!( + "You are not in charge of mysql-{}: '{}'. Skipping.", + db_or_user.lowercased(), + name + ) +} + +pub fn handle_create_user_error(error: CreateUserError, name: &str) { + let argv0 = std::env::args() + .next() + .unwrap_or_else(|| "mysql-useradm".to_string()); + match error { + CreateUserError::SanitizationError(_) => { + eprintln!( + "{}", + name_validation_error_to_error_message(name, DbOrUser::User) + ); + } + CreateUserError::OwnershipError(_) => { + eprintln!("{}", owner_validation_error_message(name, DbOrUser::User)); + } + CreateUserError::MySqlError(_) | CreateUserError::UserAlreadyExists => { + eprintln!("{}: Failed to create user '{}'.", argv0, name); + } + } +} + +pub fn handle_drop_user_error(error: DropUserError, name: &str) { + let argv0 = std::env::args() + .next() + .unwrap_or_else(|| "mysql-useradm".to_string()); + match error { + DropUserError::SanitizationError(_) => { + eprintln!( + "{}", + name_validation_error_to_error_message(name, DbOrUser::User) + ); + } + DropUserError::OwnershipError(_) => { + eprintln!("{}", owner_validation_error_message(name, DbOrUser::User)); + } + DropUserError::MySqlError(_) | DropUserError::UserDoesNotExist => { + eprintln!("{}: Failed to delete user '{}'.", argv0, name); + } + } +} + +pub fn handle_list_users_error(error: ListUsersError, name: &str) { + let argv0 = std::env::args() + .next() + .unwrap_or_else(|| "mysql-useradm".to_string()); + match error { + ListUsersError::SanitizationError(_) => { + eprintln!( + "{}", + name_validation_error_to_error_message(name, DbOrUser::User) + ); + } + ListUsersError::OwnershipError(_) => { + eprintln!("{}", owner_validation_error_message(name, DbOrUser::User)); + } + ListUsersError::UserDoesNotExist => { + eprintln!( + "{}: User '{}' does not exist. You must create it first.", + argv0, name, + ); + } + ListUsersError::MySqlError(_) => { + eprintln!("{}: Failed to look up password for user '{}'", argv0, name); + } + } +} + +// ---------------------------------------------------------------------------- + +pub fn handle_create_database_error(error: CreateDatabaseError, name: &str) { + let argv0 = std::env::args() + .next() + .unwrap_or_else(|| "mysql-dbadm".to_string()); + match error { + CreateDatabaseError::SanitizationError(_) => { + eprintln!( + "{}", + name_validation_error_to_error_message(name, DbOrUser::Database) + ); + } + CreateDatabaseError::OwnershipError(_) => { + eprintln!( + "{}", + owner_validation_error_message(name, DbOrUser::Database) + ); + } + CreateDatabaseError::MySqlError(_) => { + eprintln!("{}: Cannot create database '{}'.", argv0, name); + } + CreateDatabaseError::DatabaseAlreadyExists => { + eprintln!("{}: Database '{}' already exists.", argv0, name); + } + } +} + +pub fn handle_drop_database_error(error: DropDatabaseError, name: &str) { + let argv0 = std::env::args() + .next() + .unwrap_or_else(|| "mysql-dbadm".to_string()); + match error { + DropDatabaseError::SanitizationError(_) => { + eprintln!( + "{}", + name_validation_error_to_error_message(name, DbOrUser::Database) + ); + } + DropDatabaseError::OwnershipError(_) => { + eprintln!( + "{}", + owner_validation_error_message(name, DbOrUser::Database) + ); + } + DropDatabaseError::MySqlError(_) => { + eprintln!("{}: Cannot drop database '{}'.", argv0, name); + } + DropDatabaseError::DatabaseDoesNotExist => { + eprintln!("{}: Database '{}' doesn't exist.", argv0, name); + } + } +} + +pub fn format_show_database_error_message( + error: GetDatabasesPrivilegeDataError, + name: &str, +) -> String { + let argv0 = std::env::args() + .next() + .unwrap_or_else(|| "mysql-dbadm".to_string()); + + match error { + GetDatabasesPrivilegeDataError::SanitizationError(_) => { + name_validation_error_to_error_message(name, DbOrUser::Database) + } + GetDatabasesPrivilegeDataError::OwnershipError(_) => { + owner_validation_error_message(name, DbOrUser::Database) + } + GetDatabasesPrivilegeDataError::MySqlError(err) => { + format!( + "{}: Failed to look up privileges for database '{}': {}", + argv0, name, err + ) + } + GetDatabasesPrivilegeDataError::DatabaseDoesNotExist => { + format!("{}: Database '{}' doesn't exist.", argv0, name) + } + } +} diff --git a/src/cli/mysql_admutils_compatibility/mysql_dbadm.rs b/src/cli/mysql_admutils_compatibility/mysql_dbadm.rs index b556914..63d6d3a 100644 --- a/src/cli/mysql_admutils_compatibility/mysql_dbadm.rs +++ b/src/cli/mysql_admutils_compatibility/mysql_dbadm.rs @@ -1,14 +1,29 @@ use clap::Parser; -use sqlx::MySqlConnection; +use futures_util::{SinkExt, StreamExt}; +use std::os::unix::net::UnixStream as StdUnixStream; +use std::path::PathBuf; +use tokio::net::UnixStream as TokioUnixStream; use crate::{ - cli::{database_command, mysql_admutils_compatibility::common::filter_db_or_user_names}, - core::{ - common::{yn, DbOrUser}, - config::{create_mysql_connection_from_config, get_config, GlobalConfigArgs}, - database_operations::{create_database, drop_database, get_database_list}, - database_privilege_operations, + cli::{ + common::erroneous_server_response, + database_command, + mysql_admutils_compatibility::{ + common::trim_to_32_chars, + error_messages::{ + format_show_database_error_message, handle_create_database_error, + handle_drop_database_error, + }, + }, }, + core::{ + bootstrap::bootstrap_server_connection_and_drop_privileges, + protocol::{ + create_client_to_server_message_stream, ClientToServerMessageStream, + GetDatabasesPrivilegeDataError, Request, Response, + }, + }, + server::sql::database_privilege_operations::DatabasePrivilegeRow, }; const HELP_DB_PERM: &str = r#" @@ -39,8 +54,25 @@ pub struct Args { #[command(subcommand)] pub command: Option, - #[command(flatten)] - config_overrides: GlobalConfigArgs, + /// Path to the socket of the server, if it already exists. + #[arg( + short, + long, + value_name = "PATH", + global = true, + hide_short_help = true + )] + server_socket_path: Option, + + /// Config file to use for the server. + #[arg( + short, + long, + value_name = "PATH", + global = true, + hide_short_help = true + )] + config: Option, /// Print help for the 'editperm' subcommand. #[arg(long, global = true)] @@ -76,7 +108,7 @@ pub enum Command { /// to make changes to the permission table. /// Run 'mysql-dbadm --help-editperm' for more /// information. - EditPerm(EditPermArgs), + Editperm(EditPermArgs), } #[derive(Parser)] @@ -106,7 +138,7 @@ pub struct EditPermArgs { pub database: String, } -pub async fn main() -> anyhow::Result<()> { +pub fn main() -> anyhow::Result<()> { let args: Args = Args::parse(); if args.help_editperm { @@ -114,6 +146,9 @@ pub async fn main() -> anyhow::Result<()> { return Ok(()); } + let server_connection = + bootstrap_server_connection_and_drop_privileges(args.server_socket_path, args.config)?; + let command = match args.command { Some(command) => command, None => { @@ -125,64 +160,164 @@ pub async fn main() -> anyhow::Result<()> { } }; - let config = get_config(args.config_overrides)?; - let mut connection = create_mysql_connection_from_config(config.mysql).await?; + tokio_run_command(command, server_connection)?; - match command { - Command::Create(args) => { - let filtered_names = filter_db_or_user_names(args.name, DbOrUser::Database)?; - for name in filtered_names { - create_database(&name, &mut connection).await?; - println!("Database {} created.", name); - } - } - Command::Drop(args) => { - let filtered_names = filter_db_or_user_names(args.name, DbOrUser::Database)?; - for name in filtered_names { - drop_database(&name, &mut connection).await?; - println!("Database {} dropped.", name); - } - } - Command::Show(args) => { - let names = if args.name.is_empty() { - get_database_list(&mut connection).await? - } else { - filter_db_or_user_names(args.name, DbOrUser::Database)? - }; + Ok(()) +} - for name in names { - show_db(&name, &mut connection).await?; - } - } - Command::EditPerm(args) => { - // TODO: This does not accurately replicate the behavior of the old implementation. - // Hopefully, not many people rely on this in an automated fashion, as it - // is made to be interactive in nature. However, we should still try to - // replicate the old behavior as closely as possible. - let edit_privileges_args = database_command::DatabaseEditPrivsArgs { - name: Some(args.database), - privs: vec![], - json: false, - editor: None, - yes: false, - }; +fn tokio_run_command(command: Command, server_connection: StdUnixStream) -> anyhow::Result<()> { + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap() + .block_on(async { + let tokio_socket = TokioUnixStream::from_std(server_connection)?; + let message_stream = create_client_to_server_message_stream(tokio_socket); + match command { + Command::Create(args) => create_databases(args, message_stream).await, + Command::Drop(args) => drop_databases(args, message_stream).await, + Command::Show(args) => show_databases(args, message_stream).await, + Command::Editperm(args) => { + let edit_privileges_args = database_command::DatabaseEditPrivsArgs { + name: Some(args.database), + privs: vec![], + json: false, + // TODO: use this to mimic the old editor-finding logic + editor: None, + yes: false, + }; - database_command::edit_privileges(edit_privileges_args, &mut connection).await?; + database_command::edit_database_privileges(edit_privileges_args, message_stream) + .await + } + } + }) +} + +async fn create_databases( + args: CreateArgs, + mut server_connection: ClientToServerMessageStream, +) -> anyhow::Result<()> { + let database_names = args + .name + .iter() + .map(|name| trim_to_32_chars(name)) + .collect(); + + let message = Request::CreateDatabases(database_names); + server_connection.send(message).await?; + + let result = match server_connection.next().await { + Some(Ok(Response::CreateDatabases(result))) => result, + response => return erroneous_server_response(response), + }; + + server_connection.send(Request::Exit).await?; + + for (name, result) in result { + match result { + Ok(()) => println!("Database {} created.", name), + Err(err) => handle_create_database_error(err, &name), } } Ok(()) } -async fn show_db(name: &str, connection: &mut MySqlConnection) -> anyhow::Result<()> { +async fn drop_databases( + args: DatabaseDropArgs, + mut server_connection: ClientToServerMessageStream, +) -> anyhow::Result<()> { + let database_names = args + .name + .iter() + .map(|name| trim_to_32_chars(name)) + .collect(); + + let message = Request::DropDatabases(database_names); + server_connection.send(message).await?; + + let result = match server_connection.next().await { + Some(Ok(Response::DropDatabases(result))) => result, + response => return erroneous_server_response(response), + }; + + server_connection.send(Request::Exit).await?; + + for (name, result) in result { + match result { + Ok(()) => println!("Database {} dropped.", name), + Err(err) => handle_drop_database_error(err, &name), + } + } + + Ok(()) +} + +async fn show_databases( + args: DatabaseShowArgs, + mut server_connection: ClientToServerMessageStream, +) -> anyhow::Result<()> { + let database_names: Vec = args + .name + .iter() + .map(|name| trim_to_32_chars(name)) + .collect(); + + let message = if database_names.is_empty() { + let message = Request::ListDatabases; + server_connection.send(message).await?; + let response = server_connection.next().await; + let databases = match response { + Some(Ok(Response::ListAllDatabases(databases))) => databases.unwrap_or(vec![]), + response => return erroneous_server_response(response), + }; + + Request::ListPrivileges(Some(databases)) + } else { + Request::ListPrivileges(Some(database_names)) + }; + server_connection.send(message).await?; + + let response = server_connection.next().await; + + server_connection.send(Request::Exit).await?; + // NOTE: mysql-dbadm show has a quirk where valid database names // for non-existent databases will report with no users. - // This function should *not* check for db existence, only - // validate the names. - let privileges = database_privilege_operations::get_database_privileges(name, connection) - .await - .unwrap_or(vec![]); + let results: Vec), String>> = match response { + Some(Ok(Response::ListPrivileges(result))) => result + .into_iter() + .map(|(name, rows)| match rows.map(|rows| (name.clone(), rows)) { + Ok(rows) => Ok(rows), + Err(GetDatabasesPrivilegeDataError::DatabaseDoesNotExist) => Ok((name, vec![])), + Err(err) => Err(format_show_database_error_message(err, &name)), + }) + .collect(), + response => return erroneous_server_response(response), + }; + results.into_iter().try_for_each(|result| match result { + Ok((name, rows)) => print_db_privs(&name, rows), + Err(err) => { + eprintln!("{}", err); + Ok(()) + } + })?; + + Ok(()) +} + +#[inline] +fn yn(value: bool) -> &'static str { + if value { + "Y" + } else { + "N" + } +} + +fn print_db_privs(name: &str, rows: Vec) -> anyhow::Result<()> { println!( concat!( "Database '{}':\n", @@ -191,10 +326,10 @@ async fn show_db(name: &str, connection: &mut MySqlConnection) -> anyhow::Result ), name, ); - if privileges.is_empty() { + if rows.is_empty() { println!("# (no permissions currently granted to any users)"); } else { - for privilege in privileges { + for privilege in rows { println!( " {:<16} {:<7} {:<7} {:<7} {:<7} {:<7} {:<7} {:<7} {:<7} {:<7} {:<7} {}", privilege.user, diff --git a/src/cli/mysql_admutils_compatibility/mysql_useradm.rs b/src/cli/mysql_admutils_compatibility/mysql_useradm.rs index cce2f44..5f17494 100644 --- a/src/cli/mysql_admutils_compatibility/mysql_useradm.rs +++ b/src/cli/mysql_admutils_compatibility/mysql_useradm.rs @@ -1,13 +1,28 @@ use clap::Parser; -use sqlx::MySqlConnection; +use futures_util::{SinkExt, StreamExt}; +use std::path::PathBuf; + +use std::os::unix::net::UnixStream as StdUnixStream; +use tokio::net::UnixStream as TokioUnixStream; use crate::{ - cli::{mysql_admutils_compatibility::common::filter_db_or_user_names, user_command}, - core::{ - common::{close_database_connection, get_current_unix_user, DbOrUser}, - config::{create_mysql_connection_from_config, get_config, GlobalConfigArgs}, - user_operations::*, + cli::{ + common::erroneous_server_response, + mysql_admutils_compatibility::{ + common::trim_to_32_chars, + error_messages::{ + handle_create_user_error, handle_drop_user_error, handle_list_users_error, + }, + }, + user_command::read_password_from_stdin_with_double_check, }, + core::{ + bootstrap::bootstrap_server_connection_and_drop_privileges, + protocol::{ + create_client_to_server_message_stream, ClientToServerMessageStream, Request, Response, + }, + }, + server::sql::user_operations::DatabaseUser, }; #[derive(Parser)] @@ -15,8 +30,25 @@ pub struct Args { #[command(subcommand)] pub command: Option, - #[command(flatten)] - config_overrides: GlobalConfigArgs, + /// Path to the socket of the server, if it already exists. + #[arg( + short, + long, + value_name = "PATH", + global = true, + hide_short_help = true + )] + server_socket_path: Option, + + /// Config file to use for the server. + #[arg( + short, + long, + value_name = "PATH", + global = true, + hide_short_help = true + )] + config: Option, } /// Create, delete or change password for the USER(s), @@ -69,7 +101,7 @@ pub struct ShowArgs { name: Vec, } -pub async fn main() -> anyhow::Result<()> { +pub fn main() -> anyhow::Result<()> { let args: Args = Args::parse(); let command = match args.command { @@ -85,78 +117,185 @@ pub async fn main() -> anyhow::Result<()> { } }; - let config = get_config(args.config_overrides)?; - let mut connection = create_mysql_connection_from_config(config.mysql).await?; + let server_connection = + bootstrap_server_connection_and_drop_privileges(args.server_socket_path, args.config)?; - match command { - Command::Create(args) => { - let filtered_names = filter_db_or_user_names(args.name, DbOrUser::User)?; - for name in filtered_names { - create_database_user(&name, &mut connection).await?; - } - } - Command::Delete(args) => { - let filtered_names = filter_db_or_user_names(args.name, DbOrUser::User)?; - for name in filtered_names { - delete_database_user(&name, &mut connection).await?; - } - } - Command::Passwd(args) => passwd(args, &mut connection).await?, - Command::Show(args) => show(args, &mut connection).await?, - } - - close_database_connection(connection).await; + tokio_run_command(command, server_connection)?; Ok(()) } -async fn passwd(args: PasswdArgs, connection: &mut MySqlConnection) -> anyhow::Result<()> { - let filtered_names = filter_db_or_user_names(args.name, DbOrUser::User)?; - - // NOTE: this gets doubly checked during the call to `set_password_for_database_user`. - // This is moving the check before asking the user for the password, - // to avoid having them figure out that the user does not exist after they - // have entered the password twice. - let mut better_filtered_names = Vec::with_capacity(filtered_names.len()); - for name in filtered_names.into_iter() { - if !user_exists(&name, connection).await? { - println!( - "{}: User '{}' does not exist. You must create it first.", - std::env::args() - .next() - .unwrap_or("mysql-useradm".to_string()), - name, - ); - } else { - better_filtered_names.push(name); - } - } - - for name in better_filtered_names { - let password = user_command::read_password_from_stdin_with_double_check(&name)?; - set_password_for_database_user(&name, &password, connection).await?; - println!("Password updated for user '{}'.", name); - } - - Ok(()) +fn tokio_run_command(command: Command, server_connection: StdUnixStream) -> anyhow::Result<()> { + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap() + .block_on(async { + let tokio_socket = TokioUnixStream::from_std(server_connection)?; + let message_stream = create_client_to_server_message_stream(tokio_socket); + match command { + Command::Create(args) => create_user(args, message_stream).await, + Command::Delete(args) => drop_users(args, message_stream).await, + Command::Passwd(args) => passwd_users(args, message_stream).await, + Command::Show(args) => show_users(args, message_stream).await, + } + }) } -async fn show(args: ShowArgs, connection: &mut MySqlConnection) -> anyhow::Result<()> { - let users = if args.name.is_empty() { - let unix_user = get_current_unix_user()?; - get_all_database_users_for_unix_user(&unix_user, connection).await? - } else { - let filtered_usernames = filter_db_or_user_names(args.name, DbOrUser::User)?; - let mut result = Vec::with_capacity(filtered_usernames.len()); - for username in filtered_usernames.iter() { - // TODO: fetch all users in one query - if let Some(user) = get_database_user_for_user(username, connection).await? { - result.push(user) - } - } - result +async fn create_user( + args: CreateArgs, + mut server_connection: ClientToServerMessageStream, +) -> anyhow::Result<()> { + let usernames = args + .name + .iter() + .map(|name| trim_to_32_chars(name)) + .collect(); + + let message = Request::CreateUsers(usernames); + server_connection.send(message).await?; + + let result = match server_connection.next().await { + Some(Ok(Response::CreateUsers(result))) => result, + response => return erroneous_server_response(response), }; + server_connection.send(Request::Exit).await?; + + for (name, result) in result { + match result { + Ok(()) => println!("User '{}' created.", name), + Err(err) => handle_create_user_error(err, &name), + } + } + + Ok(()) +} + +async fn drop_users( + args: DeleteArgs, + mut server_connection: ClientToServerMessageStream, +) -> anyhow::Result<()> { + let usernames = args + .name + .iter() + .map(|name| trim_to_32_chars(name)) + .collect(); + + let message = Request::DropUsers(usernames); + server_connection.send(message).await?; + + let result = match server_connection.next().await { + Some(Ok(Response::DropUsers(result))) => result, + response => return erroneous_server_response(response), + }; + + server_connection.send(Request::Exit).await?; + + for (name, result) in result { + match result { + Ok(()) => println!("User '{}' deleted.", name), + Err(err) => handle_drop_user_error(err, &name), + } + } + + Ok(()) +} + +async fn passwd_users( + args: PasswdArgs, + mut server_connection: ClientToServerMessageStream, +) -> anyhow::Result<()> { + let usernames = args + .name + .iter() + .map(|name| trim_to_32_chars(name)) + .collect(); + + let message = Request::ListUsers(Some(usernames)); + server_connection.send(message).await?; + + let response = match server_connection.next().await { + Some(Ok(Response::ListUsers(result))) => result, + response => return erroneous_server_response(response), + }; + + let argv0 = std::env::args() + .next() + .unwrap_or("mysql-useradm".to_string()); + + let users = response + .into_iter() + .filter_map(|(name, result)| match result { + Ok(user) => Some(user), + Err(err) => { + handle_list_users_error(err, &name); + None + } + }) + .collect::>(); + + for user in users { + let password = read_password_from_stdin_with_double_check(&user.user)?; + let message = Request::PasswdUser(user.user.clone(), password); + server_connection.send(message).await?; + match server_connection.next().await { + Some(Ok(Response::PasswdUser(result))) => match result { + Ok(()) => println!("Password updated for user '{}'.", user.user), + Err(_) => eprintln!( + "{}: Failed to update password for user '{}'.", + argv0, user.user, + ), + }, + response => return erroneous_server_response(response), + } + } + + server_connection.send(Request::Exit).await?; + + Ok(()) +} + +async fn show_users( + args: ShowArgs, + mut server_connection: ClientToServerMessageStream, +) -> anyhow::Result<()> { + let usernames: Vec<_> = args + .name + .iter() + .map(|name| trim_to_32_chars(name)) + .collect(); + + let message = if usernames.is_empty() { + Request::ListUsers(None) + } else { + Request::ListUsers(Some(usernames)) + }; + server_connection.send(message).await?; + + let users: Vec = match server_connection.next().await { + Some(Ok(Response::ListAllUsers(result))) => match result { + Ok(users) => users, + Err(err) => { + println!("Failed to list users: {:?}", err); + return Ok(()); + } + }, + Some(Ok(Response::ListUsers(result))) => result + .into_iter() + .filter_map(|(name, result)| match result { + Ok(user) => Some(user), + Err(err) => { + handle_list_users_error(err, &name); + None + } + }) + .collect(), + response => return erroneous_server_response(response), + }; + + server_connection.send(Request::Exit).await?; + for user in users { if user.has_password { println!("User '{}': password set.", user.user); diff --git a/src/cli/user_command.rs b/src/cli/user_command.rs index fd79acc..2b841e8 100644 --- a/src/cli/user_command.rs +++ b/src/cli/user_command.rs @@ -1,27 +1,24 @@ -use std::collections::BTreeMap; -use std::vec; - use anyhow::Context; use clap::Parser; use dialoguer::{Confirm, Password}; -use prettytable::Table; -use serde_json::json; -use sqlx::{Connection, MySqlConnection}; +use futures_util::{SinkExt, StreamExt}; -use crate::core::{ - common::{close_database_connection, get_current_unix_user, CommandStatus}, - database_operations::*, - user_operations::*, +use crate::core::protocol::{ + print_create_users_output_status, print_drop_users_output_status, + print_lock_users_output_status, print_set_password_output_status, + print_unlock_users_output_status, ClientToServerMessageStream, Request, Response, }; -#[derive(Parser)] +use super::common::erroneous_server_response; + +#[derive(Parser, Debug, Clone)] pub struct UserArgs { #[clap(subcommand)] subcmd: UserCommand, } #[allow(clippy::enum_variant_names)] -#[derive(Parser)] +#[derive(Parser, Debug, Clone)] pub enum UserCommand { /// Create one or more users #[command()] @@ -50,7 +47,7 @@ pub enum UserCommand { UnlockUser(UserUnlockArgs), } -#[derive(Parser)] +#[derive(Parser, Debug, Clone)] pub struct UserCreateArgs { #[arg(num_args = 1..)] username: Vec, @@ -60,13 +57,13 @@ pub struct UserCreateArgs { no_password: bool, } -#[derive(Parser)] +#[derive(Parser, Debug, Clone)] pub struct UserDeleteArgs { #[arg(num_args = 1..)] username: Vec, } -#[derive(Parser)] +#[derive(Parser, Debug, Clone)] pub struct UserPasswdArgs { username: String, @@ -74,7 +71,7 @@ pub struct UserPasswdArgs { password_file: Option, } -#[derive(Parser)] +#[derive(Parser, Debug, Clone)] pub struct UserShowArgs { #[arg(num_args = 0..)] username: Vec, @@ -83,13 +80,13 @@ pub struct UserShowArgs { json: bool, } -#[derive(Parser)] +#[derive(Parser, Debug, Clone)] pub struct UserLockArgs { #[arg(num_args = 1..)] username: Vec, } -#[derive(Parser)] +#[derive(Parser, Debug, Clone)] pub struct UserUnlockArgs { #[arg(num_args = 1..)] username: Vec, @@ -97,48 +94,45 @@ pub struct UserUnlockArgs { pub async fn handle_command( command: UserCommand, - mut connection: MySqlConnection, -) -> anyhow::Result { - let result = connection - .transaction(|txn| { - Box::pin(async move { - match command { - UserCommand::CreateUser(args) => create_users(args, txn).await, - UserCommand::DropUser(args) => drop_users(args, txn).await, - UserCommand::PasswdUser(args) => change_password_for_user(args, txn).await, - UserCommand::ShowUser(args) => show_users(args, txn).await, - UserCommand::LockUser(args) => lock_users(args, txn).await, - UserCommand::UnlockUser(args) => unlock_users(args, txn).await, - } - }) - }) - .await; - - close_database_connection(connection).await; - - result + server_connection: ClientToServerMessageStream, +) -> anyhow::Result<()> { + match command { + UserCommand::CreateUser(args) => create_users(args, server_connection).await, + UserCommand::DropUser(args) => drop_users(args, server_connection).await, + UserCommand::PasswdUser(args) => passwd_user(args, server_connection).await, + UserCommand::ShowUser(args) => show_users(args, server_connection).await, + UserCommand::LockUser(args) => lock_users(args, server_connection).await, + UserCommand::UnlockUser(args) => unlock_users(args, server_connection).await, + } } async fn create_users( args: UserCreateArgs, - connection: &mut MySqlConnection, -) -> anyhow::Result { + mut server_connection: ClientToServerMessageStream, +) -> anyhow::Result<()> { if args.username.is_empty() { anyhow::bail!("No usernames provided"); } - let mut result = CommandStatus::SuccessfullyModified; + let message = Request::CreateUsers(args.username.clone()); + if let Err(err) = server_connection.send(message).await { + server_connection.close().await.ok(); + anyhow::bail!(anyhow::Error::from(err).context("Failed to communicate with server")); + } - for username in args.username { - if let Err(e) = create_database_user(&username, connection).await { - eprintln!("{}", e); - eprintln!("Skipping...\n"); - result = CommandStatus::PartiallySuccessfullyModified; - continue; - } else { - println!("User '{}' created.", username); - } + let result = match server_connection.next().await { + Some(Ok(Response::CreateUsers(result))) => result, + response => return erroneous_server_response(response), + }; + print_create_users_output_status(&result); + + let successfully_created_users = result + .iter() + .filter_map(|(username, result)| result.as_ref().ok().map(|_| username)) + .collect::>(); + + for username in successfully_created_users { if !args.no_password && Confirm::new() .with_prompt(format!( @@ -147,41 +141,55 @@ async fn create_users( )) .interact()? { - change_password_for_user( - UserPasswdArgs { - username, - password_file: None, - }, - connection, - ) - .await?; + let password = read_password_from_stdin_with_double_check(username)?; + let message = Request::PasswdUser(username.clone(), password); + + if let Err(err) = server_connection.send(message).await { + server_connection.close().await.ok(); + anyhow::bail!(err); + } + + match server_connection.next().await { + Some(Ok(Response::PasswdUser(result))) => { + print_set_password_output_status(&result, username) + } + response => return erroneous_server_response(response), + } + + println!(); } - println!(); } - Ok(result) + + server_connection.send(Request::Exit).await?; + + Ok(()) } async fn drop_users( args: UserDeleteArgs, - connection: &mut MySqlConnection, -) -> anyhow::Result { + mut server_connection: ClientToServerMessageStream, +) -> anyhow::Result<()> { if args.username.is_empty() { anyhow::bail!("No usernames provided"); } - let mut result = CommandStatus::SuccessfullyModified; + let message = Request::DropUsers(args.username.clone()); - for username in args.username { - if let Err(e) = delete_database_user(&username, connection).await { - eprintln!("{}", e); - eprintln!("Skipping..."); - result = CommandStatus::PartiallySuccessfullyModified; - } else { - println!("User '{}' dropped.", username); - } + if let Err(err) = server_connection.send(message).await { + server_connection.close().await.ok(); + anyhow::bail!(err); } - Ok(result) + let result = match server_connection.next().await { + Some(Ok(Response::DropUsers(result))) => result, + response => return erroneous_server_response(response), + }; + + server_connection.send(Request::Exit).await?; + + print_drop_users_output_status(&result); + + Ok(()) } pub fn read_password_from_stdin_with_double_check(username: &str) -> anyhow::Result { @@ -195,15 +203,10 @@ pub fn read_password_from_stdin_with_double_check(username: &str) -> anyhow::Res .map_err(Into::into) } -async fn change_password_for_user( +async fn passwd_user( args: UserPasswdArgs, - connection: &mut MySqlConnection, -) -> anyhow::Result { - // NOTE: although this also is checked in `set_password_for_database_user`, we check it here - // to provide a more natural order of error messages. - let unix_user = get_current_unix_user()?; - validate_user_name(&args.username, &unix_user)?; - + mut server_connection: ClientToServerMessageStream, +) -> anyhow::Result<()> { let password = if let Some(password_file) = args.password_file { std::fs::read_to_string(password_file) .context("Failed to read password file")? @@ -213,129 +216,146 @@ async fn change_password_for_user( read_password_from_stdin_with_double_check(&args.username)? }; - set_password_for_database_user(&args.username, &password, connection).await?; + let message = Request::PasswdUser(args.username.clone(), password); - Ok(CommandStatus::SuccessfullyModified) + if let Err(err) = server_connection.send(message).await { + server_connection.close().await.ok(); + anyhow::bail!(err); + } + + let result = match server_connection.next().await { + Some(Ok(Response::PasswdUser(result))) => result, + response => return erroneous_server_response(response), + }; + + server_connection.send(Request::Exit).await?; + + print_set_password_output_status(&result, &args.username); + + Ok(()) } async fn show_users( args: UserShowArgs, - connection: &mut MySqlConnection, -) -> anyhow::Result { - let unix_user = get_current_unix_user()?; - - let users = if args.username.is_empty() { - get_all_database_users_for_unix_user(&unix_user, connection).await? + mut server_connection: ClientToServerMessageStream, +) -> anyhow::Result<()> { + let message = if args.username.is_empty() { + Request::ListUsers(None) } else { - let mut result = vec![]; - for username in args.username { - if let Err(e) = validate_user_name(&username, &unix_user) { - eprintln!("{}", e); - eprintln!("Skipping..."); - continue; - } - - let user = get_database_user_for_user(&username, connection).await?; - if let Some(user) = user { - result.push(user); - } else { - eprintln!("User not found: {}", username); - } - } - result + Request::ListUsers(Some(args.username.clone())) }; - let mut user_databases: BTreeMap> = BTreeMap::new(); - for user in users.iter() { - user_databases.insert( - user.user.clone(), - get_databases_where_user_has_privileges(&user.user, connection).await?, - ); + if let Err(err) = server_connection.send(message).await { + server_connection.close().await.ok(); + anyhow::bail!(err); } - if args.json { - let users_json = users + let users = match server_connection.next().await { + Some(Ok(Response::ListUsers(users))) => users .into_iter() - .map(|user| { - json!({ - "user": user.user, - "has_password": user.has_password, - "is_locked": user.is_locked, - "databases": user_databases.get(&user.user).unwrap_or(&vec![]), - }) + .filter_map(|(username, result)| match result { + Ok(user) => Some(user), + Err(err) => { + eprintln!("{}", err.to_error_message(&username)); + eprintln!("Skipping..."); + None + } }) - .collect::(); + .collect::>(), + Some(Ok(Response::ListAllUsers(users))) => match users { + Ok(users) => users, + Err(err) => { + server_connection.send(Request::Exit).await?; + return Err( + anyhow::anyhow!(err.to_error_message()).context("Failed to list all users") + ); + } + }, + response => return erroneous_server_response(response), + }; + + server_connection.send(Request::Exit).await?; + + // TODO: print databases where user has privileges + if args.json { println!( "{}", - serde_json::to_string_pretty(&users_json) - .context("Failed to serialize users to JSON")? + serde_json::to_string_pretty(&users).context("Failed to serialize users to JSON")? ); } else if users.is_empty() { - println!("No users found."); + println!("No users to show."); } else { - let mut table = Table::new(); + let mut table = prettytable::Table::new(); table.add_row(row![ "User", "Password is set", "Locked", - "Databases where user has privileges" + // "Databases where user has privileges" ]); for user in users { table.add_row(row![ user.user, user.has_password, user.is_locked, - user_databases.get(&user.user).unwrap_or(&vec![]).join("\n") + // user.databases.join("\n") ]); } table.printstd(); } - Ok(CommandStatus::NoModificationsIntended) + Ok(()) } async fn lock_users( args: UserLockArgs, - connection: &mut MySqlConnection, -) -> anyhow::Result { + mut server_connection: ClientToServerMessageStream, +) -> anyhow::Result<()> { if args.username.is_empty() { anyhow::bail!("No usernames provided"); } - let mut result = CommandStatus::SuccessfullyModified; + let message = Request::LockUsers(args.username.clone()); - for username in args.username { - if let Err(e) = lock_database_user(&username, connection).await { - eprintln!("{}", e); - eprintln!("Skipping..."); - result = CommandStatus::PartiallySuccessfullyModified; - } else { - println!("User '{}' locked.", username); - } + if let Err(err) = server_connection.send(message).await { + server_connection.close().await.ok(); + anyhow::bail!(err); } - Ok(result) + let result = match server_connection.next().await { + Some(Ok(Response::LockUsers(result))) => result, + response => return erroneous_server_response(response), + }; + + server_connection.send(Request::Exit).await?; + + print_lock_users_output_status(&result); + + Ok(()) } async fn unlock_users( args: UserUnlockArgs, - connection: &mut MySqlConnection, -) -> anyhow::Result { + mut server_connection: ClientToServerMessageStream, +) -> anyhow::Result<()> { if args.username.is_empty() { anyhow::bail!("No usernames provided"); } - let mut result = CommandStatus::SuccessfullyModified; + let message = Request::UnlockUsers(args.username.clone()); - for username in args.username { - if let Err(e) = unlock_database_user(&username, connection).await { - eprintln!("{}", e); - eprintln!("Skipping..."); - result = CommandStatus::PartiallySuccessfullyModified; - } else { - println!("User '{}' unlocked.", username); - } + if let Err(err) = server_connection.send(message).await { + server_connection.close().await.ok(); + anyhow::bail!(err); } - Ok(result) + let result = match server_connection.next().await { + Some(Ok(Response::UnlockUsers(result))) => result, + response => return erroneous_server_response(response), + }; + + server_connection.send(Request::Exit).await?; + + print_unlock_users_output_status(&result); + + Ok(()) } diff --git a/src/core.rs b/src/core.rs index aa51dca..565d038 100644 --- a/src/core.rs +++ b/src/core.rs @@ -1,5 +1,4 @@ +pub mod bootstrap; pub mod common; -pub mod config; -pub mod database_operations; -pub mod database_privilege_operations; -pub mod user_operations; +pub mod database_privileges; +pub mod protocol; diff --git a/src/core/bootstrap.rs b/src/core/bootstrap.rs new file mode 100644 index 0000000..fdb6408 --- /dev/null +++ b/src/core/bootstrap.rs @@ -0,0 +1,177 @@ +use std::{fs, path::PathBuf}; + +use anyhow::Context; +use nix::libc::{exit, EXIT_SUCCESS}; +use std::os::unix::net::UnixStream as StdUnixStream; +use tokio::net::UnixStream as TokioUnixStream; + +use crate::{ + core::{ + bootstrap::authenticated_unix_socket::client_authenticate, + common::{UnixUser, DEFAULT_CONFIG_PATH, DEFAULT_SOCKET_PATH}, + }, + server::{config::read_config_form_path, server_loop::handle_requests_for_single_session}, +}; + +pub mod authenticated_unix_socket; + +// TODO: this function is security critical, it should be integration tested +// in isolation. +/// Drop privileges to the real user and group of the process. +/// If the process is not running with elevated privileges, this function +/// is a no-op. +pub fn drop_privs() -> anyhow::Result<()> { + log::debug!("Dropping privileges"); + let real_uid = nix::unistd::getuid(); + let real_gid = nix::unistd::getgid(); + + nix::unistd::setuid(real_uid).context("Failed to drop privileges")?; + nix::unistd::setgid(real_gid).context("Failed to drop privileges")?; + + debug_assert_eq!(nix::unistd::getuid(), real_uid); + debug_assert_eq!(nix::unistd::getgid(), real_gid); + + log::debug!("Privileges dropped successfully"); + Ok(()) +} + +/// This function is used to bootstrap the connection to the server. +/// This can happen in two ways: +/// 1. If a socket path is provided, or exists in the default location, +/// the function will connect to the socket and authenticate with the +/// server to ensure that the server knows the uid of the client. +/// 2. If a config path is provided, or exists in the default location, +/// and the config is readable, the function will assume it is either +/// setuid or setgid, and will fork a child process to run the server +/// with the provided config. The server will exit silently by itself +/// when it is done, and this function will only return for the client +/// with the socket for the server. +/// If neither of these options are available, the function will fail. +pub fn bootstrap_server_connection_and_drop_privileges( + server_socket_path: Option, + config_path: Option, +) -> anyhow::Result { + if server_socket_path.is_some() && config_path.is_some() { + anyhow::bail!("Cannot provide both a socket path and a config path"); + } + + log::debug!("Starting the server connection bootstrap process"); + + let (socket, do_authenticate) = bootstrap_server_connection(server_socket_path, config_path)?; + + drop_privs()?; + + let result: anyhow::Result = if do_authenticate { + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap() + .block_on(async { + let mut socket = TokioUnixStream::from_std(socket)?; + client_authenticate(&mut socket, None).await?; + Ok(socket.into_std()?) + }) + } else { + Ok(socket) + }; + + result +} + +/// Inner function for [`bootstrap_server_connection_and_drop_privileges`]. +/// See that function for more information. +fn bootstrap_server_connection( + socket_path: Option, + config_path: Option, +) -> anyhow::Result<(StdUnixStream, bool)> { + // TODO: ensure this is both readable and writable + if let Some(socket_path) = socket_path { + log::debug!("Connecting to socket at {:?}", socket_path); + return match StdUnixStream::connect(socket_path) { + Ok(socket) => Ok((socket, true)), + Err(e) => match e.kind() { + std::io::ErrorKind::NotFound => Err(anyhow::anyhow!("Socket not found")), + std::io::ErrorKind::PermissionDenied => Err(anyhow::anyhow!("Permission denied")), + _ => Err(anyhow::anyhow!("Failed to connect to socket: {}", e)), + }, + }; + } + if let Some(config_path) = config_path { + // ensure config exists and is readable + if fs::metadata(&config_path).is_err() { + return Err(anyhow::anyhow!("Config file not found or not readable")); + } + + log::debug!("Starting server with config at {:?}", config_path); + return invoke_server_with_config(config_path).map(|socket| (socket, false)); + } + + if fs::metadata(DEFAULT_SOCKET_PATH).is_ok() { + return match StdUnixStream::connect(DEFAULT_SOCKET_PATH) { + Ok(socket) => Ok((socket, true)), + Err(e) => match e.kind() { + std::io::ErrorKind::NotFound => Err(anyhow::anyhow!("Socket not found")), + std::io::ErrorKind::PermissionDenied => Err(anyhow::anyhow!("Permission denied")), + _ => Err(anyhow::anyhow!("Failed to connect to socket: {}", e)), + }, + }; + } + + let config_path = PathBuf::from(DEFAULT_CONFIG_PATH); + if fs::metadata(&config_path).is_ok() { + return invoke_server_with_config(config_path).map(|socket| (socket, false)); + } + + anyhow::bail!("No socket path or config path provided, and no default socket or config found"); +} + +// TODO: we should somehow ensure that the forked process is killed on completion, +// just in case the client does not behave properly. +/// Fork a child process to run the server with the provided config. +/// The server will exit silently by itself when it is done, and this function +/// will only return for the client with the socket for the server. +fn invoke_server_with_config(config_path: PathBuf) -> anyhow::Result { + let (server_socket, client_socket) = StdUnixStream::pair()?; + let unix_user = UnixUser::from_uid(nix::unistd::getuid().as_raw())?; + + match (unsafe { nix::unistd::fork() }).context("Failed to fork")? { + nix::unistd::ForkResult::Parent { child } => { + log::debug!("Forked child process with PID {}", child); + Ok(client_socket) + } + nix::unistd::ForkResult::Child => { + log::debug!("Running server in child process"); + + match run_forked_server(config_path, server_socket, unix_user) { + Err(e) => Err(e), + Ok(_) => unreachable!(), + } + } + } +} + +/// Run the server in the forked child process. +/// This function will not return, but will exit the process with a success code. +fn run_forked_server( + config_path: PathBuf, + server_socket: StdUnixStream, + unix_user: UnixUser, +) -> anyhow::Result<()> { + let config = read_config_form_path(Some(config_path))?; + + let result: anyhow::Result<()> = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap() + .block_on(async { + let socket = TokioUnixStream::from_std(server_socket)?; + handle_requests_for_single_session(socket, &unix_user, &config).await?; + Ok(()) + }); + + result?; + + unsafe { + exit(EXIT_SUCCESS); + } +} diff --git a/src/authenticated_unix_socket.rs b/src/core/bootstrap/authenticated_unix_socket.rs similarity index 83% rename from src/authenticated_unix_socket.rs rename to src/core/bootstrap/authenticated_unix_socket.rs index 0092b5f..cabfcd8 100644 --- a/src/authenticated_unix_socket.rs +++ b/src/core/bootstrap/authenticated_unix_socket.rs @@ -30,10 +30,13 @@ //! Also note that it is essential that the client does not send any sensitive information //! over it's authentication socket, since it is readable by any user on the system. +// TODO: rewrite this so that it can be used with a normal std::os::unix::net::UnixStream + use std::os::unix::io::AsRawFd; -use std::path::PathBuf; +use std::path::{Path, PathBuf}; use async_bincode::{tokio::AsyncBincodeStream, AsyncDestination}; +use derive_more::derive::{Display, Error}; use futures::{SinkExt, StreamExt}; use nix::{sys::stat, unistd::Uid}; use rand::distributions::Alphanumeric; @@ -52,7 +55,7 @@ pub enum ClientRequest { Cancel, } -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Display, Error)] pub enum ServerResponse { Authenticated, ChallengeDidNotMatch, @@ -61,7 +64,7 @@ pub enum ServerResponse { // TODO: wrap more data into the errors -#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] +#[derive(Debug, Display, PartialEq, Serialize, Deserialize, Clone, Error)] pub enum ServerError { InvalidRequest, UnableToReadPermissionsFromAuthSocket, @@ -72,7 +75,7 @@ pub enum ServerError { InvalidChallenge, } -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Display, Error)] pub enum ClientError { UnableToConnectToServer, UnableToOpenAuthSocket, @@ -80,13 +83,12 @@ pub enum ClientError { AuthSocketClosedEarly, UnableToCloseAuthSocket, AuthenticationError, - InvalidServerResponse(ServerResponse), UnableToParseServerResponse, NoServerResponse, ServerError(ServerError), } -async fn create_auth_socket(socket_addr: &str) -> Result { +async fn create_auth_socket(socket_addr: &PathBuf) -> Result { let auth_socket = UnixListener::bind(socket_addr).map_err(|_err| ClientError::UnableToOpenAuthSocket)?; @@ -109,11 +111,13 @@ type AuthStream<'a> = AsyncBincodeStream<&'a mut UnixStream, u64, u64, AsyncDest // TODO: add timeout +// TODO: respect $XDG_RUNTIME_DIR and $TMPDIR + const AUTH_SOCKET_NAME: &str = "mysqladm-rs-cli-auth.sock"; + pub async fn client_authenticate( normal_socket: &mut UnixStream, - #[cfg(not(test))] auth_socket_dir: Option, - #[cfg(test)] auth_socket_file: Option, + auth_socket_dir: Option, ) -> Result<(), ClientError> { let random_prefix: String = rand::thread_rng() .sample_iter(&Alphanumeric) @@ -123,32 +127,16 @@ pub async fn client_authenticate( let socket_name = format!("{}-{}", random_prefix, AUTH_SOCKET_NAME); - #[cfg(not(test))] - let auth_socket_address = match auth_socket_dir { - Some(dir) => dir.join(socket_name).to_str().unwrap().to_string(), - None => std::env::temp_dir() - .join(socket_name) - .to_str() - .unwrap() - .to_string(), - }; - - #[cfg(test)] - let auth_socket_address = match auth_socket_file { - Some(file) => file.to_str().unwrap().to_string(), - None => std::env::temp_dir() - .join(socket_name) - .to_str() - .unwrap() - .to_string(), - }; + let auth_socket_address = auth_socket_dir + .unwrap_or(std::env::temp_dir()) + .join(socket_name); client_authenticate_with_auth_socket_address(normal_socket, &auth_socket_address).await } async fn client_authenticate_with_auth_socket_address( normal_socket: &mut UnixStream, - auth_socket_address: &str, + auth_socket_address: &PathBuf, ) -> Result<(), ClientError> { let auth_socket = create_auth_socket(auth_socket_address).await?; @@ -164,7 +152,7 @@ async fn client_authenticate_with_auth_socket_address( async fn client_authenticate_with_auth_socket( normal_socket: &mut UnixStream, auth_socket: UnixListener, - auth_socket_address: &str, + auth_socket_address: &Path, ) -> Result<(), ClientError> { let challenge = rand::random::(); let uid = nix::unistd::getuid(); @@ -199,7 +187,10 @@ async fn client_authenticate_with_auth_socket( let client_hello = ClientRequest::Initialize { uid: uid.into(), challenge, - auth_socket: auth_socket_address.to_string(), + auth_socket: auth_socket_address + .to_str() + .ok_or(ClientError::UnableToConfigureAuthSocket)? + .to_owned(), }; normal_socket @@ -239,9 +230,13 @@ macro_rules! report_server_error_and_return { }}; } -async fn server_authenticate( +pub async fn server_authenticate(normal_socket: &mut UnixStream) -> Result { + _server_authenticate(normal_socket, None).await +} + +pub async fn _server_authenticate( normal_socket: &mut UnixStream, - #[cfg(test)] unix_user_uid: Option, + unix_user_uid: Option, ) -> Result { let mut normal_socket: ServerToClientStream = AsyncBincodeStream::from(normal_socket).for_async(); @@ -256,22 +251,15 @@ async fn server_authenticate( _ => report_server_error_and_return!(normal_socket, ServerError::InvalidRequest), }; - #[cfg(test)] let auth_socket_uid = match unix_user_uid { Some(uid) => uid, - None => report_server_error_and_return!( - normal_socket, - ServerError::UnableToReadPermissionsFromAuthSocket - ), - }; - - #[cfg(not(test))] - let auth_socket_uid = match stat::stat(auth_socket.as_str()) { - Ok(stat) => stat.st_uid, - Err(_err) => report_server_error_and_return!( - normal_socket, - ServerError::UnableToReadPermissionsFromAuthSocket - ), + None => match stat::stat(auth_socket.as_str()) { + Ok(stat) => stat.st_uid, + Err(_err) => report_server_error_and_return!( + normal_socket, + ServerError::UnableToReadPermissionsFromAuthSocket + ), + }, }; if uid != auth_socket_uid { @@ -324,10 +312,7 @@ mod test { let client_handle = tokio::spawn(async move { client_authenticate(&mut client, None).await }); - let server_handle = tokio::spawn(async move { - let uid = nix::unistd::getuid().into(); - server_authenticate(&mut server, Some(uid)).await - }); + let server_handle = tokio::spawn(async move { server_authenticate(&mut server).await }); client_handle.await.unwrap().unwrap(); server_handle.await.unwrap().unwrap(); @@ -340,15 +325,12 @@ mod test { let client_handle = tokio::spawn(async move { client_authenticate_with_auth_socket_address( &mut client, - "/tmp/test_auth_socket_does_not_exist.sock", + &PathBuf::from("/tmp/test_auth_socket_does_not_exist.sock"), ) .await }); - let server_handle = tokio::spawn(async move { - let uid = nix::unistd::getuid().into(); - server_authenticate(&mut server, Some(uid)).await - }); + let server_handle = tokio::spawn(async move { server_authenticate(&mut server).await }); client_handle.await.unwrap().unwrap(); server_handle.await.unwrap().unwrap(); @@ -365,7 +347,7 @@ mod test { let server_handle = tokio::spawn(async move { let uid: u32 = nix::unistd::getuid().into(); - let err = server_authenticate(&mut server, Some(uid + 1)).await; + let err = _server_authenticate(&mut server, Some(uid + 1)).await; assert_eq!(err, Err(ServerError::UidMismatch)); }); @@ -379,13 +361,19 @@ mod test { let socket_path = std::env::temp_dir().join("socket_to_snoop.sock"); let socket_path_clone = socket_path.clone(); - let client_handle = - tokio::spawn( - async move { client_authenticate(&mut client, Some(socket_path_clone)).await }, - ); + let client_handle = tokio::spawn(async move { + client_authenticate_with_auth_socket_address(&mut client, &socket_path_clone).await + }); - while !socket_path.exists() { - sleep(std::time::Duration::from_millis(10)).await; + for i in 0..100 { + if socket_path.exists() { + break; + } + sleep(Duration::from_millis(10)).await; + + if i == 99 { + panic!("Socket not created after 1 second, assuming test failure"); + } } let mut snooper = UnixStream::connect(socket_path.clone()).await.unwrap(); @@ -409,10 +397,7 @@ mod test { sleep(Duration::from_millis(10)).await; - let server_handle = tokio::spawn(async move { - let uid: u32 = nix::unistd::getuid().into(); - server_authenticate(&mut server, Some(uid)).await - }); + let server_handle = tokio::spawn(async move { server_authenticate(&mut server).await }); client_handle.await.unwrap().unwrap(); server_handle.await.unwrap().unwrap(); diff --git a/src/core/common.rs b/src/core/common.rs index 9beb916..2d4045c 100644 --- a/src/core/common.rs +++ b/src/core/common.rs @@ -1,56 +1,32 @@ use anyhow::Context; -use indoc::indoc; -use itertools::Itertools; -use nix::unistd::{getuid, Group, User}; -use sqlx::{Connection, MySqlConnection}; +use nix::unistd::{Group as LibcGroup, User as LibcUser}; #[cfg(not(target_os = "macos"))] use std::ffi::CString; -/// Report the result status of a command. -/// This is used to display a status message to the user. -pub enum CommandStatus { - /// The command was successful, - /// and made modification to the database. - SuccessfullyModified, +pub const DEFAULT_CONFIG_PATH: &str = "/etc/mysqladm/config.toml"; +pub const DEFAULT_SOCKET_PATH: &str = "/run/mysqladm/mysqladm.sock"; - /// The command was mostly successful, - /// and modifications have been made to the database. - /// However, some of the requested modifications failed. - PartiallySuccessfullyModified, - - /// The command was successful, - /// but no modifications were needed. - NoModificationsNeeded, - - /// The command was successful, - /// and made no modification to the database. - NoModificationsIntended, - - /// The command was cancelled, either through a dialog or a signal. - /// No modifications have been made to the database. - Cancelled, +pub struct UnixUser { + pub username: String, + pub groups: Vec, } -pub fn get_current_unix_user() -> anyhow::Result { - User::from_uid(getuid()) - .context("Failed to look up your UNIX username") - .and_then(|u| u.ok_or(anyhow::anyhow!("Failed to look up your UNIX username"))) -} +// TODO: these functions are somewhat critical, and should have integration tests #[cfg(target_os = "macos")] -pub fn get_unix_groups(_user: &User) -> anyhow::Result> { +fn get_unix_groups(_user: &LibcUser) -> anyhow::Result> { // Return an empty list on macOS since there is no `getgrouplist` function Ok(vec![]) } #[cfg(not(target_os = "macos"))] -pub fn get_unix_groups(user: &User) -> anyhow::Result> { +fn get_unix_groups(user: &LibcUser) -> anyhow::Result> { let user_cstr = CString::new(user.name.as_bytes()).context("Failed to convert username to CStr")?; let groups = nix::unistd::getgrouplist(&user_cstr, user.gid)? .iter() - .filter_map(|gid| match Group::from_gid(*gid) { + .filter_map(|gid| match LibcGroup::from_gid(*gid) { Ok(Some(group)) => Some(group), Ok(None) => None, Err(e) => { @@ -62,211 +38,32 @@ pub fn get_unix_groups(user: &User) -> anyhow::Result> { None } }) - .collect::>(); + .collect::>(); Ok(groups) } -/// This function creates a regex that matches items (users, databases) -/// that belong to the user or any of the user's groups. -pub fn create_user_group_matching_regex(user: &User) -> String { - let groups = get_unix_groups(user).unwrap_or_default(); +impl UnixUser { + pub fn from_uid(uid: u32) -> anyhow::Result { + let libc_uid = nix::unistd::Uid::from_raw(uid); + let libc_user = LibcUser::from_uid(libc_uid) + .context("Failed to look up your UNIX username")? + .ok_or(anyhow::anyhow!("Failed to look up your UNIX username"))?; - if groups.is_empty() { - format!("{}(_.+)?", user.name) - } else { - format!( - "({}|{})(_.+)?", - user.name, - groups - .iter() - .map(|g| g.name.as_str()) - .collect::>() - .join("|") - ) - } -} + let groups = get_unix_groups(&libc_user)?; -/// This enum is used to differentiate between database and user operations. -/// Their output are very similar, but there are slight differences in the words used. -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum DbOrUser { - Database, - User, -} - -impl DbOrUser { - pub fn lowercased(&self) -> String { - match self { - DbOrUser::Database => "database".to_string(), - DbOrUser::User => "user".to_string(), - } + Ok(UnixUser { + username: libc_user.name, + groups: groups.iter().map(|g| g.name.clone()).collect(), + }) } - pub fn capitalized(&self) -> String { - match self { - DbOrUser::Database => "Database".to_string(), - DbOrUser::User => "User".to_string(), - } + pub fn from_enviroment() -> anyhow::Result { + let libc_uid = nix::unistd::getuid(); + UnixUser::from_uid(libc_uid.as_raw()) } } -#[derive(Debug, PartialEq, Eq)] -pub enum NameValidationResult { - Valid, - EmptyString, - InvalidCharacters, - TooLong, -} - -pub fn validate_name(name: &str) -> NameValidationResult { - if name.is_empty() { - NameValidationResult::EmptyString - } else if name.len() > 64 { - NameValidationResult::TooLong - } else if !name - .chars() - .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-') - { - NameValidationResult::InvalidCharacters - } else { - NameValidationResult::Valid - } -} - -pub fn validate_name_or_error(name: &str, db_or_user: DbOrUser) -> anyhow::Result<()> { - match validate_name(name) { - NameValidationResult::Valid => Ok(()), - NameValidationResult::EmptyString => { - anyhow::bail!("{} name cannot be empty.", db_or_user.capitalized()) - } - NameValidationResult::TooLong => anyhow::bail!( - "{} is too long. Maximum length is 64 characters.", - db_or_user.capitalized() - ), - NameValidationResult::InvalidCharacters => anyhow::bail!( - indoc! {r#" - Invalid characters in {} name: '{}' - - Only A-Z, a-z, 0-9, _ (underscore) and - (dash) are permitted. - "#}, - db_or_user.lowercased(), - name - ), - } -} - -#[derive(Debug, PartialEq, Eq)] -pub enum OwnerValidationResult { - // The name is valid and matches one of the given prefixes - Match, - - // The name is valid, but none of the given prefixes matched the name - NoMatch, - - // The name is empty, which is invalid - StringEmpty, - - // The name is in the format "_", which is invalid - MissingPrefix, - - // The name is in the format "_", which is invalid - MissingPostfix, -} - -/// Core logic for validating the ownership of a database name. -/// This function checks if the given name matches any of the given prefixes. -/// These prefixes will in most cases be the user's unix username and any -/// unix groups the user is a member of. -pub fn validate_ownership_by_prefixes(name: &str, prefixes: &[String]) -> OwnerValidationResult { - if name.is_empty() { - return OwnerValidationResult::StringEmpty; - } - - if name.starts_with('_') { - return OwnerValidationResult::MissingPrefix; - } - - let (prefix, _) = match name.split_once('_') { - Some(pair) => pair, - None => return OwnerValidationResult::MissingPostfix, - }; - - if prefixes.iter().any(|g| g == prefix) { - OwnerValidationResult::Match - } else { - OwnerValidationResult::NoMatch - } -} - -/// Validate the ownership of a database name or database user name. -/// This function takes the name of a database or user and a unix user, -/// for which it fetches the user's groups. It then checks if the name -/// is prefixed with the user's username or any of the user's groups. -pub fn validate_ownership_or_error<'a>( - name: &'a str, - user: &User, - db_or_user: DbOrUser, -) -> anyhow::Result<&'a str> { - let user_groups = get_unix_groups(user)?; - let prefixes = std::iter::once(user.name.clone()) - .chain(user_groups.iter().map(|g| g.name.clone())) - .collect::>(); - - match validate_ownership_by_prefixes(name, &prefixes) { - OwnerValidationResult::Match => Ok(name), - OwnerValidationResult::NoMatch => { - anyhow::bail!( - indoc! {r#" - Invalid {} name prefix: '{}' does not match your username or any of your groups. - Are you sure you are allowed to create {} names with this prefix? - - Allowed prefixes: - - {} - {} - "#}, - db_or_user.lowercased(), - name, - db_or_user.lowercased(), - user.name, - user_groups - .iter() - .filter(|g| g.name != user.name) - .map(|g| format!(" - {}", g.name)) - .sorted() - .join("\n"), - ); - } - _ => anyhow::bail!( - "'{}' is not a valid {} name.", - name, - db_or_user.lowercased() - ), - } -} - -/// Gracefully close a MySQL connection. -pub async fn close_database_connection(connection: MySqlConnection) { - if let Err(e) = connection - .close() - .await - .context("Failed to close connection properly") - { - eprintln!("{}", e); - eprintln!("Ignoring..."); - } -} - -#[inline] -pub fn quote_literal(s: &str) -> String { - format!("'{}'", s.replace('\'', r"\'")) -} - -#[inline] -pub fn quote_identifier(s: &str) -> String { - format!("`{}`", s.replace('`', r"\`")) -} - #[inline] pub(crate) fn yn(b: bool) -> &'static str { if b { @@ -303,94 +100,4 @@ mod test { assert_eq!(rev_yn("n"), Some(false)); assert_eq!(rev_yn("X"), None); } - - #[test] - fn test_quote_literal() { - let payload = "' OR 1=1 --"; - assert_eq!(quote_literal(payload), r#"'\' OR 1=1 --'"#); - } - - #[test] - fn test_quote_identifier() { - let payload = "` OR 1=1 --"; - assert_eq!(quote_identifier(payload), r#"`\` OR 1=1 --`"#); - } - - #[test] - fn test_validate_name() { - assert_eq!(validate_name(""), NameValidationResult::EmptyString); - assert_eq!( - validate_name("abcdefghijklmnopqrstuvwxyz"), - NameValidationResult::Valid - ); - assert_eq!( - validate_name("ABCDEFGHIJKLMNOPQRSTUVWXYZ"), - NameValidationResult::Valid - ); - assert_eq!(validate_name("0123456789_-"), NameValidationResult::Valid); - - for c in "\n\t\r !@#$%^&*()+=[]{}|;:,.<>?/".chars() { - assert_eq!( - validate_name(&c.to_string()), - NameValidationResult::InvalidCharacters - ); - } - - assert_eq!(validate_name(&"a".repeat(64)), NameValidationResult::Valid); - - assert_eq!( - validate_name(&"a".repeat(65)), - NameValidationResult::TooLong - ); - } - - #[test] - fn test_validate_owner_by_prefixes() { - let prefixes = vec!["user".to_string(), "group".to_string()]; - - assert_eq!( - validate_ownership_by_prefixes("", &prefixes), - OwnerValidationResult::StringEmpty - ); - - assert_eq!( - validate_ownership_by_prefixes("user", &prefixes), - OwnerValidationResult::MissingPostfix - ); - assert_eq!( - validate_ownership_by_prefixes("something", &prefixes), - OwnerValidationResult::MissingPostfix - ); - assert_eq!( - validate_ownership_by_prefixes("user-testdb", &prefixes), - OwnerValidationResult::MissingPostfix - ); - - assert_eq!( - validate_ownership_by_prefixes("_testdb", &prefixes), - OwnerValidationResult::MissingPrefix - ); - - assert_eq!( - validate_ownership_by_prefixes("user_testdb", &prefixes), - OwnerValidationResult::Match - ); - assert_eq!( - validate_ownership_by_prefixes("group_testdb", &prefixes), - OwnerValidationResult::Match - ); - assert_eq!( - validate_ownership_by_prefixes("group_test_db", &prefixes), - OwnerValidationResult::Match - ); - assert_eq!( - validate_ownership_by_prefixes("group_test-db", &prefixes), - OwnerValidationResult::Match - ); - - assert_eq!( - validate_ownership_by_prefixes("nonexistent_testdb", &prefixes), - OwnerValidationResult::NoMatch - ); - } } diff --git a/src/core/database_operations.rs b/src/core/database_operations.rs deleted file mode 100644 index 9cbc3c6..0000000 --- a/src/core/database_operations.rs +++ /dev/null @@ -1,120 +0,0 @@ -use anyhow::Context; -use indoc::formatdoc; -use itertools::Itertools; -use nix::unistd::User; -use sqlx::{prelude::*, MySqlConnection}; - -use crate::core::{ - common::{ - create_user_group_matching_regex, get_current_unix_user, quote_identifier, - validate_name_or_error, validate_ownership_or_error, DbOrUser, - }, - database_privilege_operations::DATABASE_PRIVILEGE_FIELDS, -}; - -pub async fn create_database(name: &str, connection: &mut MySqlConnection) -> anyhow::Result<()> { - let user = get_current_unix_user()?; - validate_database_name(name, &user)?; - - // NOTE: see the note about SQL injections in `validate_owner_of_database_name` - sqlx::query(&format!("CREATE DATABASE {}", quote_identifier(name))) - .execute(connection) - .await - .map_err(|e| { - if e.to_string().contains("database exists") { - anyhow::anyhow!("Database '{}' already exists", name) - } else { - e.into() - } - })?; - - Ok(()) -} - -pub async fn drop_database(name: &str, connection: &mut MySqlConnection) -> anyhow::Result<()> { - let user = get_current_unix_user()?; - validate_database_name(name, &user)?; - - // NOTE: see the note about SQL injections in `validate_owner_of_database_name` - sqlx::query(&format!("DROP DATABASE {}", quote_identifier(name))) - .execute(connection) - .await - .map_err(|e| { - if e.to_string().contains("doesn't exist") { - anyhow::anyhow!("Database '{}' does not exist", name) - } else { - e.into() - } - })?; - - Ok(()) -} - -pub async fn get_database_list(connection: &mut MySqlConnection) -> anyhow::Result> { - let unix_user = get_current_unix_user()?; - - let databases: Vec = sqlx::query( - r#" - SELECT `SCHEMA_NAME` AS `database` - FROM `information_schema`.`SCHEMATA` - WHERE `SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys') - AND `SCHEMA_NAME` REGEXP ? - "#, - ) - .bind(create_user_group_matching_regex(&unix_user)) - .fetch_all(connection) - .await - .and_then(|row| { - row.into_iter() - .map(|row| row.try_get::("database")) - .collect::>() - }) - .context(format!( - "Failed to get databases for user '{}'", - unix_user.name - ))?; - - Ok(databases) -} - -pub async fn get_databases_where_user_has_privileges( - username: &str, - connection: &mut MySqlConnection, -) -> anyhow::Result> { - let result = sqlx::query( - formatdoc!( - r#" - SELECT `db` AS `database` - FROM `db` - WHERE `user` = ? - AND ({}) - "#, - DATABASE_PRIVILEGE_FIELDS - .iter() - .map(|field| format!("`{}` = 'Y'", field)) - .join(" OR "), - ) - .as_str(), - ) - .bind(username) - .fetch_all(connection) - .await? - .into_iter() - .map(|databases| databases.try_get::("database").unwrap()) - .collect(); - - Ok(result) -} - -/// NOTE: It is very critical that this function validates the database name -/// properly. MySQL does not seem to allow for prepared statements, binding -/// the database name as a parameter to the query. This means that we have -/// to validate the database name ourselves to prevent SQL injection. -pub fn validate_database_name(name: &str, user: &User) -> anyhow::Result<()> { - validate_name_or_error(name, DbOrUser::Database) - .context(format!("Invalid database name: '{}'", name))?; - validate_ownership_or_error(name, user, DbOrUser::Database) - .context(format!("Invalid database name: '{}'", name))?; - - Ok(()) -} diff --git a/src/core/database_privilege_operations.rs b/src/core/database_privileges.rs similarity index 66% rename from src/core/database_privilege_operations.rs rename to src/core/database_privileges.rs index d5cf850..f8df3d4 100644 --- a/src/core/database_privilege_operations.rs +++ b/src/core/database_privileges.rs @@ -1,52 +1,16 @@ -//! Database privilege operations -//! -//! This module contains functions for querying, modifying, -//! displaying and comparing database privileges. -//! -//! A lot of the complexity comes from two core components: -//! -//! - The privilege editor that needs to be able to print -//! an editable table of privileges and reparse the content -//! after the user has made manual changes. -//! -//! - The comparison functionality that tells the user what -//! changes will be made when applying a set of changes -//! to the list of database privileges. - -use std::collections::{BTreeSet, HashMap}; - use anyhow::{anyhow, Context}; -use indoc::indoc; use itertools::Itertools; use prettytable::Table; use serde::{Deserialize, Serialize}; -use sqlx::{mysql::MySqlRow, prelude::*, MySqlConnection}; - -use crate::core::{ - common::{ - create_user_group_matching_regex, get_current_unix_user, quote_identifier, rev_yn, yn, - }, - database_operations::validate_database_name, +use std::{ + cmp::max, + collections::{BTreeSet, HashMap}, }; -/// This is the list of fields that are used to fetch the db + user + privileges -/// from the `db` table in the database. If you need to add or remove privilege -/// fields, this is a good place to start. -pub const DATABASE_PRIVILEGE_FIELDS: [&str; 13] = [ - "db", - "user", - "select_priv", - "insert_priv", - "update_priv", - "delete_priv", - "create_priv", - "drop_priv", - "alter_priv", - "index_priv", - "create_tmp_table_priv", - "lock_tables_priv", - "references_priv", -]; +use super::common::{rev_yn, yn}; +use crate::server::sql::database_privilege_operations::{ + DatabasePrivilegeRow, DATABASE_PRIVILEGE_FIELDS, +}; pub fn db_priv_field_human_readable_name(name: &str) -> String { match name { @@ -67,162 +31,24 @@ pub fn db_priv_field_human_readable_name(name: &str) -> String { } } -/// This struct represents the set of privileges for a single user on a single database. -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] -pub struct DatabasePrivilegeRow { - pub db: String, - pub user: String, - pub select_priv: bool, - pub insert_priv: bool, - pub update_priv: bool, - pub delete_priv: bool, - pub create_priv: bool, - pub drop_priv: bool, - pub alter_priv: bool, - pub index_priv: bool, - pub create_tmp_table_priv: bool, - pub lock_tables_priv: bool, - pub references_priv: bool, -} +pub fn diff(row1: &DatabasePrivilegeRow, row2: &DatabasePrivilegeRow) -> DatabasePrivilegeRowDiff { + debug_assert!(row1.db == row2.db && row1.user == row2.user); -impl DatabasePrivilegeRow { - pub fn empty(db: &str, user: &str) -> Self { - Self { - db: db.to_owned(), - user: user.to_owned(), - select_priv: false, - insert_priv: false, - update_priv: false, - delete_priv: false, - create_priv: false, - drop_priv: false, - alter_priv: false, - index_priv: false, - create_tmp_table_priv: false, - lock_tables_priv: false, - references_priv: false, - } + DatabasePrivilegeRowDiff { + db: row1.db.clone(), + user: row1.user.clone(), + diff: DATABASE_PRIVILEGE_FIELDS + .into_iter() + .skip(2) + .filter_map(|field| { + DatabasePrivilegeChange::new( + row1.get_privilege_by_name(field), + row2.get_privilege_by_name(field), + field, + ) + }) + .collect(), } - - pub fn get_privilege_by_name(&self, name: &str) -> bool { - match name { - "select_priv" => self.select_priv, - "insert_priv" => self.insert_priv, - "update_priv" => self.update_priv, - "delete_priv" => self.delete_priv, - "create_priv" => self.create_priv, - "drop_priv" => self.drop_priv, - "alter_priv" => self.alter_priv, - "index_priv" => self.index_priv, - "create_tmp_table_priv" => self.create_tmp_table_priv, - "lock_tables_priv" => self.lock_tables_priv, - "references_priv" => self.references_priv, - _ => false, - } - } - - pub fn diff(&self, other: &DatabasePrivilegeRow) -> DatabasePrivilegeRowDiff { - debug_assert!(self.db == other.db && self.user == other.user); - - DatabasePrivilegeRowDiff { - db: self.db.clone(), - user: self.user.clone(), - diff: DATABASE_PRIVILEGE_FIELDS - .into_iter() - .skip(2) - .filter_map(|field| { - DatabasePrivilegeChange::new( - self.get_privilege_by_name(field), - other.get_privilege_by_name(field), - field, - ) - }) - .collect(), - } - } -} - -#[inline] -fn get_mysql_row_priv_field(row: &MySqlRow, position: usize) -> Result { - let field = DATABASE_PRIVILEGE_FIELDS[position]; - let value = row.try_get(position)?; - match rev_yn(value) { - Some(val) => Ok(val), - _ => { - log::warn!(r#"Invalid value for privilege "{}": '{}'"#, field, value); - Ok(false) - } - } -} - -impl FromRow<'_, MySqlRow> for DatabasePrivilegeRow { - fn from_row(row: &MySqlRow) -> Result { - Ok(Self { - db: row.try_get("db")?, - user: row.try_get("user")?, - select_priv: get_mysql_row_priv_field(row, 2)?, - insert_priv: get_mysql_row_priv_field(row, 3)?, - update_priv: get_mysql_row_priv_field(row, 4)?, - delete_priv: get_mysql_row_priv_field(row, 5)?, - create_priv: get_mysql_row_priv_field(row, 6)?, - drop_priv: get_mysql_row_priv_field(row, 7)?, - alter_priv: get_mysql_row_priv_field(row, 8)?, - index_priv: get_mysql_row_priv_field(row, 9)?, - create_tmp_table_priv: get_mysql_row_priv_field(row, 10)?, - lock_tables_priv: get_mysql_row_priv_field(row, 11)?, - references_priv: get_mysql_row_priv_field(row, 12)?, - }) - } -} - -/// Get all users + privileges for a single database. -pub async fn get_database_privileges( - database_name: &str, - connection: &mut MySqlConnection, -) -> anyhow::Result> { - let unix_user = get_current_unix_user()?; - validate_database_name(database_name, &unix_user)?; - - let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!( - "SELECT {} FROM `db` WHERE `db` = ?", - DATABASE_PRIVILEGE_FIELDS - .iter() - .map(|field| quote_identifier(field)) - .join(","), - )) - .bind(database_name) - .fetch_all(connection) - .await - .context("Failed to show database")?; - - Ok(result) -} - -/// Get all database + user + privileges pairs that are owned by the current user. -pub async fn get_all_database_privileges( - connection: &mut MySqlConnection, -) -> anyhow::Result> { - let unix_user = get_current_unix_user()?; - - let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!( - indoc! {r#" - SELECT {} FROM `db` WHERE `db` IN - (SELECT DISTINCT `SCHEMA_NAME` AS `database` - FROM `information_schema`.`SCHEMATA` - WHERE `SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys') - AND `SCHEMA_NAME` REGEXP ?) - "#}, - DATABASE_PRIVILEGE_FIELDS - .iter() - .map(|field| format!("`{field}`")) - .join(","), - )) - .bind(create_user_group_matching_regex(&unix_user)) - .fetch_all(connection) - .await - .context("Failed to show databases")?; - - Ok(result) } /*************************/ @@ -340,17 +166,23 @@ pub fn generate_editor_content_from_privilege_data( // editor will be the example user and example db name. // Hence, it's put as the fallback value, despite not really // being a "fallback" in the normal sense. - let longest_username = privilege_data - .iter() - .map(|p| p.user.len()) - .max() - .unwrap_or(example_user.len()); + let longest_username = max( + privilege_data + .iter() + .map(|p| p.user.len()) + .max() + .unwrap_or(example_user.len()), + "User".len(), + ); - let longest_database_name = privilege_data - .iter() - .map(|p| p.db.len()) - .max() - .unwrap_or(example_db.len()); + let longest_database_name = max( + privilege_data + .iter() + .map(|p| p.db.len()) + .max() + .unwrap_or(example_db.len()), + "Database".len(), + ); let mut header: Vec<_> = DATABASE_PRIVILEGE_FIELDS .into_iter() @@ -578,7 +410,7 @@ pub fn parse_privilege_data_from_editor_content( /// instances of privilege sets for a single user on a single database. /// /// The `User` and `Database` are the same for both instances. -#[derive(Debug, Clone, PartialEq, Eq, Serialize, PartialOrd, Ord)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] pub struct DatabasePrivilegeRowDiff { pub db: String, pub user: String, @@ -586,7 +418,7 @@ pub struct DatabasePrivilegeRowDiff { } /// This enum represents a change for a single privilege. -#[derive(Debug, Clone, PartialEq, Eq, Serialize, PartialOrd, Ord)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] pub enum DatabasePrivilegeChange { YesToNo(String), NoToYes(String), @@ -603,13 +435,31 @@ impl DatabasePrivilegeChange { } /// This enum encapsulates whether a [`DatabasePrivilegeRow`] was intrduced, modified or deleted. -#[derive(Debug, Clone, PartialEq, Eq, Serialize, PartialOrd, Ord)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] pub enum DatabasePrivilegesDiff { New(DatabasePrivilegeRow), Modified(DatabasePrivilegeRowDiff), Deleted(DatabasePrivilegeRow), } +impl DatabasePrivilegesDiff { + pub fn get_database_name(&self) -> &str { + match self { + DatabasePrivilegesDiff::New(p) => &p.db, + DatabasePrivilegesDiff::Modified(p) => &p.db, + DatabasePrivilegesDiff::Deleted(p) => &p.db, + } + } + + pub fn get_user_name(&self) -> &str { + match self { + DatabasePrivilegesDiff::New(p) => &p.user, + DatabasePrivilegesDiff::Modified(p) => &p.user, + DatabasePrivilegesDiff::Deleted(p) => &p.user, + } + } +} + /// This function calculates the differences between two sets of database privileges. /// It returns a set of [`DatabasePrivilegesDiff`] that can be used to display or /// apply a set of privilege modifications to the database. @@ -633,7 +483,7 @@ pub fn diff_privileges( for p in to { if let Some(old_p) = from_lookup_table.get(&(p.db.clone(), p.user.clone())) { - let diff = old_p.diff(p); + let diff = diff(old_p, p); if !diff.diff.is_empty() { result.insert(DatabasePrivilegesDiff::Modified(diff)); } @@ -651,72 +501,6 @@ pub fn diff_privileges( result } -/// Uses the result of [`diff_privileges`] to modify privileges in the database. -pub async fn apply_privilege_diffs( - diffs: BTreeSet, - connection: &mut MySqlConnection, -) -> anyhow::Result<()> { - for diff in diffs { - match diff { - DatabasePrivilegesDiff::New(p) => { - let tables = DATABASE_PRIVILEGE_FIELDS - .iter() - .map(|field| format!("`{field}`")) - .join(","); - - let question_marks = std::iter::repeat("?") - .take(DATABASE_PRIVILEGE_FIELDS.len()) - .join(","); - - sqlx::query( - format!("INSERT INTO `db` ({}) VALUES ({})", tables, question_marks).as_str(), - ) - .bind(p.db) - .bind(p.user) - .bind(yn(p.select_priv)) - .bind(yn(p.insert_priv)) - .bind(yn(p.update_priv)) - .bind(yn(p.delete_priv)) - .bind(yn(p.create_priv)) - .bind(yn(p.drop_priv)) - .bind(yn(p.alter_priv)) - .bind(yn(p.index_priv)) - .bind(yn(p.create_tmp_table_priv)) - .bind(yn(p.lock_tables_priv)) - .bind(yn(p.references_priv)) - .execute(&mut *connection) - .await?; - } - DatabasePrivilegesDiff::Modified(p) => { - let tables = p - .diff - .iter() - .map(|diff| match diff { - DatabasePrivilegeChange::YesToNo(name) => format!("`{}` = 'N'", name), - DatabasePrivilegeChange::NoToYes(name) => format!("`{}` = 'Y'", name), - }) - .join(","); - - sqlx::query( - format!("UPDATE `db` SET {} WHERE `db` = ? AND `user` = ?", tables).as_str(), - ) - .bind(p.db) - .bind(p.user) - .execute(&mut *connection) - .await?; - } - DatabasePrivilegesDiff::Deleted(p) => { - sqlx::query("DELETE FROM `db` WHERE `db` = ? AND `user` = ?") - .bind(p.db) - .bind(p.user) - .execute(&mut *connection) - .await?; - } - } - } - Ok(()) -} - fn display_privilege_cell(diff: &DatabasePrivilegeRowDiff) -> String { diff.diff .iter() @@ -731,6 +515,20 @@ fn display_privilege_cell(diff: &DatabasePrivilegeRowDiff) -> String { .join("\n") } +fn display_new_privileges_list(row: &DatabasePrivilegeRow) -> String { + DATABASE_PRIVILEGE_FIELDS + .into_iter() + .skip(2) + .map(|field| { + if row.get_privilege_by_name(field) { + format!("{}: Y", db_priv_field_human_readable_name(field)) + } else { + format!("{}: N", db_priv_field_human_readable_name(field)) + } + }) + .join("\n") +} + /// Displays the difference between two sets of database privileges. pub fn display_privilege_diffs(diffs: &BTreeSet) -> String { let mut table = Table::new(); @@ -741,24 +539,14 @@ pub fn display_privilege_diffs(diffs: &BTreeSet) -> Stri table.add_row(row![ p.db, p.user, - "(New user)\n".to_string() - + &display_privilege_cell( - &DatabasePrivilegeRow::empty(&p.db, &p.user).diff(p) - ) + "(New user)\n".to_string() + &display_new_privileges_list(p) ]); } DatabasePrivilegesDiff::Modified(p) => { table.add_row(row![p.db, p.user, display_privilege_cell(p),]); } DatabasePrivilegesDiff::Deleted(p) => { - table.add_row(row![ - p.db, - p.user, - "(All privileges removed)\n".to_string() - + &display_privilege_cell( - &p.diff(&DatabasePrivilegeRow::empty(&p.db, &p.user)) - ) - ]); + table.add_row(row![p.db, p.user, "Removed".to_string()]); } } } diff --git a/src/core/protocol.rs b/src/core/protocol.rs new file mode 100644 index 0000000..1048569 --- /dev/null +++ b/src/core/protocol.rs @@ -0,0 +1,5 @@ +pub mod request_response; +pub mod server_responses; + +pub use request_response::*; +pub use server_responses::*; diff --git a/src/core/protocol/request_response.rs b/src/core/protocol/request_response.rs new file mode 100644 index 0000000..13bc013 --- /dev/null +++ b/src/core/protocol/request_response.rs @@ -0,0 +1,79 @@ +use std::collections::BTreeSet; + +use serde::{Deserialize, Serialize}; +use tokio::net::UnixStream; +use tokio_serde::{formats::Bincode, Framed as SerdeFramed}; +use tokio_util::codec::{Framed, LengthDelimitedCodec}; + +use crate::core::{database_privileges::DatabasePrivilegesDiff, protocol::*}; + +pub type ServerToClientMessageStream = SerdeFramed< + Framed, + Request, + Response, + Bincode, +>; + +pub type ClientToServerMessageStream = SerdeFramed< + Framed, + Response, + Request, + Bincode, +>; + +pub fn create_server_to_client_message_stream(socket: UnixStream) -> ServerToClientMessageStream { + let length_delimited = Framed::new(socket, LengthDelimitedCodec::new()); + tokio_serde::Framed::new(length_delimited, Bincode::default()) +} + +pub fn create_client_to_server_message_stream(socket: UnixStream) -> ClientToServerMessageStream { + let length_delimited = Framed::new(socket, LengthDelimitedCodec::new()); + tokio_serde::Framed::new(length_delimited, Bincode::default()) +} + +#[non_exhaustive] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum Request { + CreateDatabases(Vec), + DropDatabases(Vec), + ListDatabases, + ListPrivileges(Option>), + ModifyPrivileges(BTreeSet), + + CreateUsers(Vec), + DropUsers(Vec), + PasswdUser(String, String), + ListUsers(Option>), + LockUsers(Vec), + UnlockUsers(Vec), + + // Commit, + Exit, +} + +// TODO: include a generic "message" that will display a message to the user? + +#[non_exhaustive] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum Response { + // Specific data for specific commands + CreateDatabases(CreateDatabasesOutput), + DropDatabases(DropDatabasesOutput), + ListAllDatabases(ListAllDatabasesOutput), + ListPrivileges(GetDatabasesPrivilegeData), + ListAllPrivileges(GetAllDatabasesPrivilegeData), + ModifyPrivileges(ModifyDatabasePrivilegesOutput), + + CreateUsers(CreateUsersOutput), + DropUsers(DropUsersOutput), + PasswdUser(SetPasswordOutput), + ListUsers(ListUsersOutput), + ListAllUsers(ListAllUsersOutput), + LockUsers(LockUsersOutput), + UnlockUsers(UnlockUsersOutput), + + // Generic responses + OperationAborted, + Error(String), + Exit, +} diff --git a/src/core/protocol/server_responses.rs b/src/core/protocol/server_responses.rs new file mode 100644 index 0000000..23b89c2 --- /dev/null +++ b/src/core/protocol/server_responses.rs @@ -0,0 +1,611 @@ +use std::collections::BTreeMap; + +use indoc::indoc; +use itertools::Itertools; +use serde::{Deserialize, Serialize}; + +use crate::{ + core::{common::UnixUser, database_privileges::DatabasePrivilegeRowDiff}, + server::sql::{ + database_privilege_operations::DatabasePrivilegeRow, user_operations::DatabaseUser, + }, +}; + +/// This enum is used to differentiate between database and user operations. +/// Their output are very similar, but there are slight differences in the words used. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum DbOrUser { + Database, + User, +} + +impl DbOrUser { + pub fn lowercased(&self) -> String { + match self { + DbOrUser::Database => "database".to_string(), + DbOrUser::User => "user".to_string(), + } + } + + pub fn capitalized(&self) -> String { + match self { + DbOrUser::Database => "Database".to_string(), + DbOrUser::User => "User".to_string(), + } + } +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)] +pub enum NameValidationError { + EmptyString, + InvalidCharacters, + TooLong, +} + +impl NameValidationError { + pub fn to_error_message(self, name: &str, db_or_user: DbOrUser) -> String { + match self { + NameValidationError::EmptyString => { + format!("{} name cannot be empty.", db_or_user.capitalized()).to_owned() + } + NameValidationError::TooLong => format!( + "{} is too long. Maximum length is 64 characters.", + db_or_user.capitalized() + ) + .to_owned(), + NameValidationError::InvalidCharacters => format!( + indoc! {r#" + Invalid characters in {} name: '{}' + + Only A-Z, a-z, 0-9, _ (underscore) and - (dash) are permitted. + "#}, + db_or_user.lowercased(), + name + ) + .to_owned(), + } + } +} + +impl OwnerValidationError { + pub fn to_error_message(self, name: &str, db_or_user: DbOrUser) -> String { + let user = UnixUser::from_enviroment(); + + match self { + OwnerValidationError::NoMatch => format!( + indoc! {r#" + Invalid {} name prefix: '{}' does not match your username or any of your groups. + Are you sure you are allowed to create {} names with this prefix? + + Allowed prefixes: + - {} + {} + "#}, + db_or_user.lowercased(), + name, + db_or_user.lowercased(), + user.as_ref() + .map(|u| u.username.clone()) + .unwrap_or("???".to_string()), + user.map(|u| u.groups) + .unwrap_or_default() + .iter() + .map(|g| format!(" - {}", g)) + .sorted() + .join("\n"), + ) + .to_owned(), + + _ => format!( + "'{}' is not a valid {} name.", + name, + db_or_user.lowercased() + ) + .to_string(), + } + } +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)] +pub enum OwnerValidationError { + // The name is valid, but none of the given prefixes matched the name + NoMatch, + + // The name is empty, which is invalid + StringEmpty, + + // The name is in the format "_", which is invalid + MissingPrefix, + + // The name is in the format "_", which is invalid + MissingPostfix, +} + +pub type CreateDatabasesOutput = BTreeMap>; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum CreateDatabaseError { + SanitizationError(NameValidationError), + OwnershipError(OwnerValidationError), + DatabaseAlreadyExists, + MySqlError(String), +} + +pub fn print_create_databases_output_status(output: &CreateDatabasesOutput) { + for (database_name, result) in output { + match result { + Ok(()) => { + println!("Database '{}' created successfully.", database_name); + } + Err(err) => { + println!("{}", err.to_error_message(database_name)); + println!("Skipping..."); + } + } + println!(); + } +} + +impl CreateDatabaseError { + pub fn to_error_message(&self, database_name: &str) -> String { + match self { + CreateDatabaseError::SanitizationError(err) => { + err.to_error_message(database_name, DbOrUser::Database) + } + CreateDatabaseError::OwnershipError(err) => { + err.to_error_message(database_name, DbOrUser::Database) + } + CreateDatabaseError::DatabaseAlreadyExists => { + format!("Database {} already exists.", database_name) + } + CreateDatabaseError::MySqlError(err) => { + format!("MySQL error: {}", err) + } + } + } +} + +pub type DropDatabasesOutput = BTreeMap>; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum DropDatabaseError { + SanitizationError(NameValidationError), + OwnershipError(OwnerValidationError), + DatabaseDoesNotExist, + MySqlError(String), +} + +pub fn print_drop_databases_output_status(output: &DropDatabasesOutput) { + for (database_name, result) in output { + match result { + Ok(()) => { + println!("Database '{}' dropped successfully.", database_name); + } + Err(err) => { + println!("{}", err.to_error_message(database_name)); + println!("Skipping..."); + } + } + println!(); + } +} + +impl DropDatabaseError { + pub fn to_error_message(&self, database_name: &str) -> String { + match self { + DropDatabaseError::SanitizationError(err) => { + err.to_error_message(database_name, DbOrUser::Database) + } + DropDatabaseError::OwnershipError(err) => { + err.to_error_message(database_name, DbOrUser::Database) + } + DropDatabaseError::DatabaseDoesNotExist => { + format!("Database {} does not exist.", database_name) + } + DropDatabaseError::MySqlError(err) => { + format!("MySQL error: {}", err) + } + } + } +} + +pub type ListAllDatabasesOutput = Result, ListDatabasesError>; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum ListDatabasesError { + MySqlError(String), +} + +impl ListDatabasesError { + pub fn to_error_message(&self) -> String { + match self { + ListDatabasesError::MySqlError(err) => format!("MySQL error: {}", err), + } + } +} + +// TODO: merge all rows into a single collection. +// they already contain which database they belong to. +// no need to index by database name. + +pub type GetDatabasesPrivilegeData = + BTreeMap, GetDatabasesPrivilegeDataError>>; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum GetDatabasesPrivilegeDataError { + SanitizationError(NameValidationError), + OwnershipError(OwnerValidationError), + DatabaseDoesNotExist, + MySqlError(String), +} + +impl GetDatabasesPrivilegeDataError { + pub fn to_error_message(&self, database_name: &str) -> String { + match self { + GetDatabasesPrivilegeDataError::SanitizationError(err) => { + err.to_error_message(database_name, DbOrUser::Database) + } + GetDatabasesPrivilegeDataError::OwnershipError(err) => { + err.to_error_message(database_name, DbOrUser::Database) + } + GetDatabasesPrivilegeDataError::DatabaseDoesNotExist => { + format!("Database '{}' does not exist.", database_name) + } + GetDatabasesPrivilegeDataError::MySqlError(err) => { + format!("MySQL error: {}", err) + } + } + } +} + +pub type GetAllDatabasesPrivilegeData = + Result, GetAllDatabasesPrivilegeDataError>; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum GetAllDatabasesPrivilegeDataError { + MySqlError(String), +} + +impl GetAllDatabasesPrivilegeDataError { + pub fn to_error_message(&self) -> String { + match self { + GetAllDatabasesPrivilegeDataError::MySqlError(err) => format!("MySQL error: {}", err), + } + } +} + +pub type ModifyDatabasePrivilegesOutput = + BTreeMap<(String, String), Result<(), ModifyDatabasePrivilegesError>>; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum ModifyDatabasePrivilegesError { + DatabaseSanitizationError(NameValidationError), + DatabaseOwnershipError(OwnerValidationError), + UserSanitizationError(NameValidationError), + UserOwnershipError(OwnerValidationError), + DatabaseDoesNotExist, + DiffDoesNotApply(DiffDoesNotApplyError), + MySqlError(String), +} + +#[allow(clippy::enum_variant_names)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum DiffDoesNotApplyError { + RowAlreadyExists(String, String), + RowDoesNotExist(String, String), + RowPrivilegeChangeDoesNotApply(DatabasePrivilegeRowDiff, DatabasePrivilegeRow), +} + +pub fn print_modify_database_privileges_output_status(output: &ModifyDatabasePrivilegesOutput) { + for ((database_name, username), result) in output { + match result { + Ok(()) => { + println!( + "Privileges for user '{}' on database '{}' modified successfully.", + username, database_name + ); + } + Err(err) => { + println!("{}", err.to_error_message(database_name, username)); + println!("Skipping..."); + } + } + println!(); + } +} + +impl ModifyDatabasePrivilegesError { + pub fn to_error_message(&self, database_name: &str, username: &str) -> String { + match self { + ModifyDatabasePrivilegesError::DatabaseSanitizationError(err) => { + err.to_error_message(database_name, DbOrUser::Database) + } + ModifyDatabasePrivilegesError::DatabaseOwnershipError(err) => { + err.to_error_message(database_name, DbOrUser::Database) + } + ModifyDatabasePrivilegesError::UserSanitizationError(err) => { + err.to_error_message(username, DbOrUser::User) + } + ModifyDatabasePrivilegesError::UserOwnershipError(err) => { + err.to_error_message(username, DbOrUser::User) + } + ModifyDatabasePrivilegesError::DatabaseDoesNotExist => { + format!("Database '{}' does not exist.", database_name) + } + ModifyDatabasePrivilegesError::DiffDoesNotApply(diff) => { + format!( + "Could not apply privilege change:\n{}", + diff.to_error_message() + ) + } + ModifyDatabasePrivilegesError::MySqlError(err) => { + format!("MySQL error: {}", err) + } + } + } +} + +impl DiffDoesNotApplyError { + pub fn to_error_message(&self) -> String { + match self { + DiffDoesNotApplyError::RowAlreadyExists(database_name, username) => { + format!( + "Privileges for user '{}' on database '{}' already exist.", + username, database_name + ) + } + DiffDoesNotApplyError::RowDoesNotExist(database_name, username) => { + format!( + "Privileges for user '{}' on database '{}' do not exist.", + username, database_name + ) + } + DiffDoesNotApplyError::RowPrivilegeChangeDoesNotApply(diff, row) => { + format!( + "Could not apply privilege change {:?} to row {:?}", + diff, row + ) + } + } + } +} + +pub type CreateUsersOutput = BTreeMap>; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum CreateUserError { + SanitizationError(NameValidationError), + OwnershipError(OwnerValidationError), + UserAlreadyExists, + MySqlError(String), +} + +pub fn print_create_users_output_status(output: &CreateUsersOutput) { + for (username, result) in output { + match result { + Ok(()) => { + println!("User '{}' created successfully.", username); + } + Err(err) => { + println!("{}", err.to_error_message(username)); + println!("Skipping..."); + } + } + println!(); + } +} + +impl CreateUserError { + pub fn to_error_message(&self, username: &str) -> String { + match self { + CreateUserError::SanitizationError(err) => { + err.to_error_message(username, DbOrUser::User) + } + CreateUserError::OwnershipError(err) => err.to_error_message(username, DbOrUser::User), + CreateUserError::UserAlreadyExists => { + format!("User '{}' already exists.", username) + } + CreateUserError::MySqlError(err) => { + format!("MySQL error: {}", err) + } + } + } +} + +pub type DropUsersOutput = BTreeMap>; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum DropUserError { + SanitizationError(NameValidationError), + OwnershipError(OwnerValidationError), + UserDoesNotExist, + MySqlError(String), +} + +pub fn print_drop_users_output_status(output: &DropUsersOutput) { + for (username, result) in output { + match result { + Ok(()) => { + println!("User '{}' dropped successfully.", username); + } + Err(err) => { + println!("{}", err.to_error_message(username)); + println!("Skipping..."); + } + } + println!(); + } +} + +impl DropUserError { + pub fn to_error_message(&self, username: &str) -> String { + match self { + DropUserError::SanitizationError(err) => err.to_error_message(username, DbOrUser::User), + DropUserError::OwnershipError(err) => err.to_error_message(username, DbOrUser::User), + DropUserError::UserDoesNotExist => { + format!("User '{}' does not exist.", username) + } + DropUserError::MySqlError(err) => { + format!("MySQL error: {}", err) + } + } + } +} + +pub type SetPasswordOutput = Result<(), SetPasswordError>; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum SetPasswordError { + SanitizationError(NameValidationError), + OwnershipError(OwnerValidationError), + UserDoesNotExist, + MySqlError(String), +} + +pub fn print_set_password_output_status(output: &SetPasswordOutput, username: &str) { + match output { + Ok(()) => { + println!("Password for user '{}' set successfully.", username); + } + Err(err) => { + println!("{}", err.to_error_message(username)); + println!("Skipping..."); + } + } +} + +impl SetPasswordError { + pub fn to_error_message(&self, username: &str) -> String { + match self { + SetPasswordError::SanitizationError(err) => { + err.to_error_message(username, DbOrUser::User) + } + SetPasswordError::OwnershipError(err) => err.to_error_message(username, DbOrUser::User), + SetPasswordError::UserDoesNotExist => { + format!("User '{}' does not exist.", username) + } + SetPasswordError::MySqlError(err) => { + format!("MySQL error: {}", err) + } + } + } +} + +pub type LockUsersOutput = BTreeMap>; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum LockUserError { + SanitizationError(NameValidationError), + OwnershipError(OwnerValidationError), + UserDoesNotExist, + UserIsAlreadyLocked, + MySqlError(String), +} + +pub fn print_lock_users_output_status(output: &LockUsersOutput) { + for (username, result) in output { + match result { + Ok(()) => { + println!("User '{}' locked successfully.", username); + } + Err(err) => { + println!("{}", err.to_error_message(username)); + println!("Skipping..."); + } + } + println!(); + } +} + +impl LockUserError { + pub fn to_error_message(&self, username: &str) -> String { + match self { + LockUserError::SanitizationError(err) => err.to_error_message(username, DbOrUser::User), + LockUserError::OwnershipError(err) => err.to_error_message(username, DbOrUser::User), + LockUserError::UserDoesNotExist => { + format!("User '{}' does not exist.", username) + } + LockUserError::UserIsAlreadyLocked => { + format!("User '{}' is already locked.", username) + } + LockUserError::MySqlError(err) => { + format!("MySQL error: {}", err) + } + } + } +} + +pub type UnlockUsersOutput = BTreeMap>; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum UnlockUserError { + SanitizationError(NameValidationError), + OwnershipError(OwnerValidationError), + UserDoesNotExist, + UserIsAlreadyUnlocked, + MySqlError(String), +} + +pub fn print_unlock_users_output_status(output: &UnlockUsersOutput) { + for (username, result) in output { + match result { + Ok(()) => { + println!("User '{}' unlocked successfully.", username); + } + Err(err) => { + println!("{}", err.to_error_message(username)); + println!("Skipping..."); + } + } + println!(); + } +} + +impl UnlockUserError { + pub fn to_error_message(&self, username: &str) -> String { + match self { + UnlockUserError::SanitizationError(err) => { + err.to_error_message(username, DbOrUser::User) + } + UnlockUserError::OwnershipError(err) => err.to_error_message(username, DbOrUser::User), + UnlockUserError::UserDoesNotExist => { + format!("User '{}' does not exist.", username) + } + UnlockUserError::UserIsAlreadyUnlocked => { + format!("User '{}' is already unlocked.", username) + } + UnlockUserError::MySqlError(err) => { + format!("MySQL error: {}", err) + } + } + } +} + +pub type ListUsersOutput = BTreeMap>; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum ListUsersError { + SanitizationError(NameValidationError), + OwnershipError(OwnerValidationError), + UserDoesNotExist, + MySqlError(String), +} + +impl ListUsersError { + pub fn to_error_message(&self, username: &str) -> String { + match self { + ListUsersError::SanitizationError(err) => { + err.to_error_message(username, DbOrUser::User) + } + ListUsersError::OwnershipError(err) => err.to_error_message(username, DbOrUser::User), + ListUsersError::UserDoesNotExist => { + format!("User '{}' does not exist.", username) + } + ListUsersError::MySqlError(err) => { + format!("MySQL error: {}", err) + } + } + } +} + +pub type ListAllUsersOutput = Result, ListAllUsersError>; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum ListAllUsersError { + MySqlError(String), +} + +impl ListAllUsersError { + pub fn to_error_message(&self) -> String { + match self { + ListAllUsersError::MySqlError(err) => format!("MySQL error: {}", err), + } + } +} diff --git a/src/core/user_operations.rs b/src/core/user_operations.rs deleted file mode 100644 index 5d3ac8a..0000000 --- a/src/core/user_operations.rs +++ /dev/null @@ -1,249 +0,0 @@ -use anyhow::Context; -use nix::unistd::User; -use serde::{Deserialize, Serialize}; -use sqlx::{prelude::*, MySqlConnection}; - -use crate::core::common::{ - create_user_group_matching_regex, get_current_unix_user, quote_literal, validate_name_or_error, - validate_ownership_or_error, DbOrUser, -}; - -pub async fn user_exists(db_user: &str, connection: &mut MySqlConnection) -> anyhow::Result { - let unix_user = get_current_unix_user()?; - - validate_user_name(db_user, &unix_user)?; - - let user_exists = sqlx::query( - r#" - SELECT EXISTS( - SELECT 1 - FROM `mysql`.`user` - WHERE `User` = ? - ) - "#, - ) - .bind(db_user) - .fetch_one(connection) - .await? - .get::(0); - - Ok(user_exists) -} - -pub async fn create_database_user( - db_user: &str, - connection: &mut MySqlConnection, -) -> anyhow::Result<()> { - let unix_user = get_current_unix_user()?; - - validate_user_name(db_user, &unix_user)?; - - if user_exists(db_user, connection).await? { - anyhow::bail!("User '{}' already exists", db_user); - } - - // NOTE: see the note about SQL injections in `validate_ownership_of_user_name` - sqlx::query(format!("CREATE USER {}@'%'", quote_literal(db_user),).as_str()) - .execute(connection) - .await?; - - Ok(()) -} - -pub async fn delete_database_user( - db_user: &str, - connection: &mut MySqlConnection, -) -> anyhow::Result<()> { - let unix_user = get_current_unix_user()?; - - validate_user_name(db_user, &unix_user)?; - - if !user_exists(db_user, connection).await? { - anyhow::bail!("User '{}' does not exist", db_user); - } - - // NOTE: see the note about SQL injections in `validate_ownership_of_user_name` - sqlx::query(format!("DROP USER {}@'%'", quote_literal(db_user),).as_str()) - .execute(connection) - .await?; - - Ok(()) -} - -pub async fn set_password_for_database_user( - db_user: &str, - password: &str, - connection: &mut MySqlConnection, -) -> anyhow::Result<()> { - let unix_user = crate::core::common::get_current_unix_user()?; - validate_user_name(db_user, &unix_user)?; - - if !user_exists(db_user, connection).await? { - anyhow::bail!("User '{}' does not exist", db_user); - } - - // NOTE: see the note about SQL injections in `validate_ownership_of_user_name` - sqlx::query( - format!( - "ALTER USER {}@'%' IDENTIFIED BY {}", - quote_literal(db_user), - quote_literal(password).as_str() - ) - .as_str(), - ) - .execute(connection) - .await?; - - Ok(()) -} - -async fn user_is_locked(db_user: &str, connection: &mut MySqlConnection) -> anyhow::Result { - let unix_user = get_current_unix_user()?; - - validate_user_name(db_user, &unix_user)?; - - if !user_exists(db_user, connection).await? { - anyhow::bail!("User '{}' does not exist", db_user); - } - - let is_locked = sqlx::query( - r#" - SELECT JSON_EXTRACT(`mysql`.`global_priv`.`priv`, "$.account_locked") = 'true' - FROM `mysql`.`global_priv` - WHERE `User` = ? - AND `Host` = '%' - "#, - ) - .bind(db_user) - .fetch_one(connection) - .await? - .get::(0); - - Ok(is_locked) -} - -pub async fn lock_database_user( - db_user: &str, - connection: &mut MySqlConnection, -) -> anyhow::Result<()> { - let unix_user = get_current_unix_user()?; - - validate_user_name(db_user, &unix_user)?; - - if !user_exists(db_user, connection).await? { - anyhow::bail!("User '{}' does not exist", db_user); - } - - if user_is_locked(db_user, connection).await? { - anyhow::bail!("User '{}' is already locked", db_user); - } - - // NOTE: see the note about SQL injections in `validate_ownership_of_user_name` - sqlx::query(format!("ALTER USER {}@'%' ACCOUNT LOCK", quote_literal(db_user),).as_str()) - .execute(connection) - .await?; - - Ok(()) -} - -pub async fn unlock_database_user( - db_user: &str, - connection: &mut MySqlConnection, -) -> anyhow::Result<()> { - let unix_user = get_current_unix_user()?; - - validate_user_name(db_user, &unix_user)?; - - if !user_exists(db_user, connection).await? { - anyhow::bail!("User '{}' does not exist", db_user); - } - - if !user_is_locked(db_user, connection).await? { - anyhow::bail!("User '{}' is already unlocked", db_user); - } - - // NOTE: see the note about SQL injections in `validate_ownership_of_user_name` - sqlx::query(format!("ALTER USER {}@'%' ACCOUNT UNLOCK", quote_literal(db_user),).as_str()) - .execute(connection) - .await?; - - Ok(()) -} - -/// This struct contains information about a database user. -/// This can be extended if we need more information in the future. -#[derive(Debug, Clone, FromRow, Serialize, Deserialize)] -pub struct DatabaseUser { - #[sqlx(rename = "User")] - pub user: String, - - #[allow(dead_code)] - #[serde(skip)] - #[sqlx(rename = "Host")] - pub host: String, - - #[sqlx(rename = "has_password")] - pub has_password: bool, - - #[sqlx(rename = "is_locked")] - pub is_locked: bool, -} - -const DB_USER_SELECT_STATEMENT: &str = r#" -SELECT - `mysql`.`user`.`User`, - `mysql`.`user`.`Host`, - `mysql`.`user`.`Password` != '' OR `mysql`.`user`.`authentication_string` != '' AS `has_password`, - COALESCE( - JSON_EXTRACT(`mysql`.`global_priv`.`priv`, "$.account_locked"), - 'false' - ) != 'false' AS `is_locked` -FROM `mysql`.`user` -JOIN `mysql`.`global_priv` ON - `mysql`.`user`.`User` = `mysql`.`global_priv`.`User` - AND `mysql`.`user`.`Host` = `mysql`.`global_priv`.`Host` -"#; - -/// This function fetches all database users that have a prefix matching the -/// unix username and group names of the given unix user. -pub async fn get_all_database_users_for_unix_user( - unix_user: &User, - connection: &mut MySqlConnection, -) -> anyhow::Result> { - let users = sqlx::query_as::<_, DatabaseUser>( - &(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `mysql`.`user`.`User` REGEXP ?"), - ) - .bind(create_user_group_matching_regex(unix_user)) - .fetch_all(connection) - .await?; - - Ok(users) -} - -/// This function fetches a database user if it exists. -pub async fn get_database_user_for_user( - username: &str, - connection: &mut MySqlConnection, -) -> anyhow::Result> { - let user = sqlx::query_as::<_, DatabaseUser>( - &(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `mysql`.`user`.`User` = ?"), - ) - .bind(username) - .fetch_optional(connection) - .await?; - - Ok(user) -} - -/// NOTE: It is very critical that this function validates the database name -/// properly. MySQL does not seem to allow for prepared statements, binding -/// the database name as a parameter to the query. This means that we have -/// to validate the database name ourselves to prevent SQL injection. -pub fn validate_user_name(name: &str, user: &User) -> anyhow::Result<()> { - validate_name_or_error(name, DbOrUser::User) - .context(format!("Invalid username: '{}'", name))?; - validate_ownership_or_error(name, user, DbOrUser::User) - .context(format!("Invalid username: '{}'", name))?; - - Ok(()) -} diff --git a/src/main.rs b/src/main.rs index 2622c50..0ca5515 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,42 +1,69 @@ #[macro_use] extern crate prettytable; -use core::common::CommandStatus; -#[cfg(feature = "mysql-admutils-compatibility")] +use clap::Parser; + use std::path::PathBuf; +use std::os::unix::net::UnixStream as StdUnixStream; +use tokio::net::UnixStream as TokioUnixStream; + +use crate::{ + core::{ + bootstrap::{bootstrap_server_connection_and_drop_privileges, drop_privs}, + protocol::create_client_to_server_message_stream, + }, + server::command::ServerArgs, +}; + #[cfg(feature = "mysql-admutils-compatibility")] use crate::cli::mysql_admutils_compatibility::{mysql_dbadm, mysql_useradm}; -use clap::Parser; +mod server; -mod authenticated_unix_socket; mod cli; mod core; #[cfg(feature = "tui")] mod tui; -#[derive(Parser)] +#[derive(Parser, Debug)] struct Args { #[command(subcommand)] command: Command, - #[command(flatten)] - config_overrides: core::config::GlobalConfigArgs, + /// Path to the socket of the server, if it already exists. + #[arg( + short, + long, + value_name = "PATH", + global = true, + hide_short_help = true + )] + server_socket_path: Option, + + /// Config file to use for the server. + #[arg( + short, + long, + value_name = "PATH", + global = true, + hide_short_help = true + )] + config: Option, #[cfg(feature = "tui")] #[arg(short, long, alias = "tui", global = true)] interactive: bool, } -/// Database administration tool for non-admin users to manage their own MySQL databases and users. -/// -/// This tool allows you to manage users and databases in MySQL. -/// -/// You are only allowed to manage databases and users that are prefixed with -/// either your username, or a group that you are a member of. -#[derive(Parser)] +// Database administration tool for non-admin users to manage their own MySQL databases and users. +// +// This tool allows you to manage users and databases in MySQL. +// +// You are only allowed to manage databases and users that are prefixed with +// either your username, or a group that you are a member of. +#[derive(Parser, Debug, Clone)] #[command(version, about, disable_help_subcommand = true)] enum Command { #[command(flatten)] @@ -44,10 +71,18 @@ enum Command { #[command(flatten)] User(cli::user_command::UserCommand), + + #[command(hide = true)] + Server(server::command::ServerArgs), } -#[tokio::main(flavor = "current_thread")] -async fn main() -> anyhow::Result<()> { +// TODO: tag all functions that are run with elevated privileges with +// comments emphasizing the need for caution. + +fn main() -> anyhow::Result<()> { + // TODO: find out if there are any security risks of running + // env_logger and clap with elevated privileges. + env_logger::init(); #[cfg(feature = "mysql-admutils-compatibility")] @@ -59,42 +94,60 @@ async fn main() -> anyhow::Result<()> { }); match argv0.as_deref() { - Some("mysql-dbadm") => return mysql_dbadm::main().await, - Some("mysql-useradm") => return mysql_useradm::main().await, + Some("mysql-dbadm") => return mysql_dbadm::main(), + Some("mysql-useradm") => return mysql_useradm::main(), _ => { /* fall through */ } } } let args: Args = Args::parse(); - let config = core::config::get_config(args.config_overrides)?; - let connection = core::config::create_mysql_connection_from_config(config.mysql).await?; - - let result = match args.command { - Command::Db(command) => cli::database_command::handle_command(command, connection).await, - Command::User(user_args) => cli::user_command::handle_command(user_args, connection).await, - }; - - match result { - Ok(CommandStatus::SuccessfullyModified) => { - println!("Modifications committed successfully"); - Ok(()) + match args.command { + Command::Server(ref command) => { + drop_privs()?; + tokio_start_server(args.server_socket_path, args.config, command.clone())?; + return Ok(()); } - Ok(CommandStatus::PartiallySuccessfullyModified) => { - println!("Some modifications committed successfully"); - Ok(()) - } - Ok(CommandStatus::NoModificationsNeeded) => { - println!("No modifications made"); - Ok(()) - } - Ok(CommandStatus::NoModificationsIntended) => { - /* Don't report anything */ - Ok(()) - } - Ok(CommandStatus::Cancelled) => { - println!("Command cancelled successfully"); - Ok(()) - } - Err(e) => Err(e), + _ => { /* fall through */ } } + + let server_connection = + bootstrap_server_connection_and_drop_privileges(args.server_socket_path, args.config)?; + + tokio_run_command(args.command, server_connection)?; + + Ok(()) +} + +fn tokio_start_server( + server_socket_path: Option, + config_path: Option, + args: ServerArgs, +) -> anyhow::Result<()> { + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap() + .block_on(async { + server::command::handle_command(server_socket_path, config_path, args).await + }) +} + +fn tokio_run_command(command: Command, server_connection: StdUnixStream) -> anyhow::Result<()> { + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap() + .block_on(async { + let tokio_socket = TokioUnixStream::from_std(server_connection)?; + let message_stream = create_client_to_server_message_stream(tokio_socket); + match command { + Command::User(user_args) => { + cli::user_command::handle_command(user_args, message_stream).await + } + Command::Db(db_args) => { + cli::database_command::handle_command(db_args, message_stream).await + } + Command::Server(_) => unreachable!(), + } + }) } diff --git a/src/server.rs b/src/server.rs new file mode 100644 index 0000000..a285bbb --- /dev/null +++ b/src/server.rs @@ -0,0 +1,6 @@ +pub mod command; +mod common; +pub mod config; +pub mod input_sanitization; +pub mod server_loop; +pub mod sql; diff --git a/src/server/command.rs b/src/server/command.rs new file mode 100644 index 0000000..7a77122 --- /dev/null +++ b/src/server/command.rs @@ -0,0 +1,77 @@ +use std::os::fd::FromRawFd; +use std::path::PathBuf; + +use anyhow::Context; +use clap::Parser; + +use std::os::unix::net::UnixStream as StdUnixStream; +use tokio::net::UnixStream as TokioUnixStream; + +use crate::core::bootstrap::authenticated_unix_socket; +use crate::core::common::UnixUser; +use crate::server::config::read_config_from_path_with_arg_overrides; +use crate::server::server_loop::listen_for_incoming_connections; +use crate::server::{ + config::{ServerConfig, ServerConfigArgs}, + server_loop::handle_requests_for_single_session, +}; + +#[derive(Parser, Debug, Clone)] +pub struct ServerArgs { + #[command(subcommand)] + subcmd: ServerCommand, + + #[command(flatten)] + config_overrides: ServerConfigArgs, +} + +#[derive(Parser, Debug, Clone)] +pub enum ServerCommand { + #[command()] + Listen, + + #[command()] + SocketActivate, +} + +pub async fn handle_command( + socket_path: Option, + config_path: Option, + args: ServerArgs, +) -> anyhow::Result<()> { + let config = read_config_from_path_with_arg_overrides(config_path, args.config_overrides)?; + + // if let Err(e) = &result { + // eprintln!("{}", e); + // } + + match args.subcmd { + ServerCommand::Listen => listen_for_incoming_connections(socket_path, config).await, + ServerCommand::SocketActivate => socket_activate(config).await, + } +} + +async fn socket_activate(config: ServerConfig) -> anyhow::Result<()> { + // TODO: allow getting socket path from other socket activation sources + let mut conn = get_socket_from_systemd().await?; + let uid = authenticated_unix_socket::server_authenticate(&mut conn).await?; + let unix_user = UnixUser::from_uid(uid.into())?; + handle_requests_for_single_session(conn, &unix_user, &config).await?; + + Ok(()) +} + +async fn get_socket_from_systemd() -> anyhow::Result { + let fd = std::env::var("LISTEN_FDS") + .context("LISTEN_FDS not set, not running under systemd?")? + .parse::() + .context("Failed to parse LISTEN_FDS")?; + + if fd != 1 { + return Err(anyhow::anyhow!("Unexpected LISTEN_FDS value: {}", fd)); + } + + let std_unix_stream = unsafe { StdUnixStream::from_raw_fd(fd) }; + let socket = TokioUnixStream::from_std(std_unix_stream)?; + Ok(socket) +} diff --git a/src/server/common.rs b/src/server/common.rs new file mode 100644 index 0000000..db540ce --- /dev/null +++ b/src/server/common.rs @@ -0,0 +1,11 @@ +use crate::core::common::UnixUser; + +/// This function creates a regex that matches items (users, databases) +/// that belong to the user or any of the user's groups. +pub fn create_user_group_matching_regex(user: &UnixUser) -> String { + if user.groups.is_empty() { + format!("{}(_.+)?", user.username) + } else { + format!("({}|{})(_.+)?", user.username, user.groups.join("|")) + } +} diff --git a/src/core/config.rs b/src/server/config.rs similarity index 66% rename from src/core/config.rs rename to src/server/config.rs index 83ada90..3d0ca99 100644 --- a/src/core/config.rs +++ b/src/server/config.rs @@ -5,11 +5,16 @@ use clap::Parser; use serde::{Deserialize, Serialize}; use sqlx::{mysql::MySqlConnectOptions, ConnectOptions, MySqlConnection}; +use crate::core::common::DEFAULT_CONFIG_PATH; + +pub const DEFAULT_PORT: u16 = 3306; +pub const DEFAULT_TIMEOUT: u64 = 2; + // NOTE: this might look empty now, and the extra wrapping for the mysql // config seems unnecessary, but it will be useful later when we // add more configuration options. #[derive(Debug, Clone, Deserialize, Serialize)] -pub struct Config { +pub struct ServerConfig { pub mysql: MysqlConfig, } @@ -23,58 +28,36 @@ pub struct MysqlConfig { pub timeout: Option, } -const DEFAULT_PORT: u16 = 3306; -const DEFAULT_TIMEOUT: u64 = 2; - -#[derive(Parser)] -pub struct GlobalConfigArgs { - /// Path to the configuration file. - #[arg( - short, - long, - value_name = "PATH", - global = true, - hide_short_help = true, - default_value = "/etc/mysqladm/config.toml" - )] - config_file: String, - +#[derive(Parser, Debug, Clone)] +pub struct ServerConfigArgs { /// Hostname of the MySQL server. - #[arg(long, value_name = "HOST", global = true, hide_short_help = true)] + #[arg(long, value_name = "HOST", global = true)] mysql_host: Option, /// Port of the MySQL server. - #[arg(long, value_name = "PORT", global = true, hide_short_help = true)] + #[arg(long, value_name = "PORT", global = true)] mysql_port: Option, /// Username to use for the MySQL connection. - #[arg(long, value_name = "USER", global = true, hide_short_help = true)] + #[arg(long, value_name = "USER", global = true)] mysql_user: Option, /// Path to a file containing the MySQL password. - #[arg(long, value_name = "PATH", global = true, hide_short_help = true)] + #[arg(long, value_name = "PATH", global = true)] mysql_password_file: Option, /// Seconds to wait for the MySQL connection to be established. - #[arg(long, value_name = "SECONDS", global = true, hide_short_help = true)] + #[arg(long, value_name = "SECONDS", global = true)] mysql_connect_timeout: Option, } /// Use the arguments and whichever configuration file which might or might not /// be found and default values to determine the configuration for the program. -pub fn get_config(args: GlobalConfigArgs) -> anyhow::Result { - let config_path = PathBuf::from(args.config_file); - - let config: Config = fs::read_to_string(&config_path) - .context(format!( - "Failed to read config file from {:?}", - &config_path - )) - .and_then(|c| toml::from_str(&c).context("Failed to parse config file")) - .context(format!( - "Failed to parse config file from {:?}", - &config_path - ))?; +pub fn read_config_from_path_with_arg_overrides( + config_path: Option, + args: ServerConfigArgs, +) -> anyhow::Result { + let config = read_config_form_path(config_path)?; let mysql = &config.mysql; @@ -86,22 +69,35 @@ pub fn get_config(args: GlobalConfigArgs) -> anyhow::Result { mysql.password.to_owned() }; - let mysql_config = MysqlConfig { - host: args.mysql_host.unwrap_or(mysql.host.to_owned()), - port: args.mysql_port.or(mysql.port), - username: args.mysql_user.unwrap_or(mysql.username.to_owned()), - password, - timeout: args.mysql_connect_timeout.or(mysql.timeout), - }; - - Ok(Config { - mysql: mysql_config, + Ok(ServerConfig { + mysql: MysqlConfig { + host: args.mysql_host.unwrap_or(mysql.host.to_owned()), + port: args.mysql_port.or(mysql.port), + username: args.mysql_user.unwrap_or(mysql.username.to_owned()), + password, + timeout: args.mysql_connect_timeout.or(mysql.timeout), + }, }) } +pub fn read_config_form_path(config_path: Option) -> anyhow::Result { + let config_path = config_path.unwrap_or_else(|| PathBuf::from(DEFAULT_CONFIG_PATH)); + + fs::read_to_string(&config_path) + .context(format!( + "Failed to read config file from {:?}", + &config_path + )) + .and_then(|c| toml::from_str(&c).context("Failed to parse config file")) + .context(format!( + "Failed to parse config file from {:?}", + &config_path + )) +} + /// Use the provided configuration to establish a connection to a MySQL server. pub async fn create_mysql_connection_from_config( - config: MysqlConfig, + config: &MysqlConfig, ) -> anyhow::Result { match tokio::time::timeout( Duration::from_secs(config.timeout.unwrap_or(DEFAULT_TIMEOUT)), diff --git a/src/server/input_sanitization.rs b/src/server/input_sanitization.rs new file mode 100644 index 0000000..6ce5201 --- /dev/null +++ b/src/server/input_sanitization.rs @@ -0,0 +1,158 @@ +use crate::core::{ + common::UnixUser, + protocol::server_responses::{NameValidationError, OwnerValidationError}, +}; + +const MAX_NAME_LENGTH: usize = 64; + +pub fn validate_name(name: &str) -> Result<(), NameValidationError> { + if name.is_empty() { + Err(NameValidationError::EmptyString) + } else if name.len() > MAX_NAME_LENGTH { + Err(NameValidationError::TooLong) + } else if !name + .chars() + .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-') + { + Err(NameValidationError::InvalidCharacters) + } else { + Ok(()) + } +} + +pub fn validate_ownership_by_unix_user( + name: &str, + user: &UnixUser, +) -> Result<(), OwnerValidationError> { + let prefixes = std::iter::once(user.username.clone()) + .chain(user.groups.iter().cloned()) + .collect::>(); + + validate_ownership_by_prefixes(name, &prefixes) +} + +/// Core logic for validating the ownership of a database name. +/// This function checks if the given name matches any of the given prefixes. +/// These prefixes will in most cases be the user's unix username and any +/// unix groups the user is a member of. +pub fn validate_ownership_by_prefixes( + name: &str, + prefixes: &[String], +) -> Result<(), OwnerValidationError> { + if name.is_empty() { + return Err(OwnerValidationError::StringEmpty); + } + + if name.starts_with('_') { + return Err(OwnerValidationError::MissingPrefix); + } + + let (prefix, _) = match name.split_once('_') { + Some(pair) => pair, + None => return Err(OwnerValidationError::MissingPostfix), + }; + + if !prefixes.iter().any(|g| g == prefix) { + return Err(OwnerValidationError::NoMatch); + } + + Ok(()) +} + +#[inline] +pub fn quote_literal(s: &str) -> String { + format!("'{}'", s.replace('\'', r"\'")) +} + +#[inline] +pub fn quote_identifier(s: &str) -> String { + format!("`{}`", s.replace('`', r"\`")) +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_quote_literal() { + let payload = "' OR 1=1 --"; + assert_eq!(quote_literal(payload), r#"'\' OR 1=1 --'"#); + } + + #[test] + fn test_quote_identifier() { + let payload = "` OR 1=1 --"; + assert_eq!(quote_identifier(payload), r#"`\` OR 1=1 --`"#); + } + + #[test] + fn test_validate_name() { + assert_eq!(validate_name(""), Err(NameValidationError::EmptyString)); + assert_eq!(validate_name("abcdefghijklmnopqrstuvwxyz"), Ok(())); + assert_eq!(validate_name("ABCDEFGHIJKLMNOPQRSTUVWXYZ"), Ok(())); + assert_eq!(validate_name("0123456789_-"), Ok(())); + + for c in "\n\t\r !@#$%^&*()+=[]{}|;:,.<>?/".chars() { + assert_eq!( + validate_name(&c.to_string()), + Err(NameValidationError::InvalidCharacters) + ); + } + + assert_eq!(validate_name(&"a".repeat(MAX_NAME_LENGTH)), Ok(())); + + assert_eq!( + validate_name(&"a".repeat(MAX_NAME_LENGTH + 1)), + Err(NameValidationError::TooLong) + ); + } + + #[test] + fn test_validate_owner_by_prefixes() { + let prefixes = vec!["user".to_string(), "group".to_string()]; + + assert_eq!( + validate_ownership_by_prefixes("", &prefixes), + Err(OwnerValidationError::StringEmpty) + ); + + assert_eq!( + validate_ownership_by_prefixes("user", &prefixes), + Err(OwnerValidationError::MissingPostfix) + ); + assert_eq!( + validate_ownership_by_prefixes("something", &prefixes), + Err(OwnerValidationError::MissingPostfix) + ); + assert_eq!( + validate_ownership_by_prefixes("user-testdb", &prefixes), + Err(OwnerValidationError::MissingPostfix) + ); + + assert_eq!( + validate_ownership_by_prefixes("_testdb", &prefixes), + Err(OwnerValidationError::MissingPrefix) + ); + + assert_eq!( + validate_ownership_by_prefixes("user_testdb", &prefixes), + Ok(()) + ); + assert_eq!( + validate_ownership_by_prefixes("group_testdb", &prefixes), + Ok(()) + ); + assert_eq!( + validate_ownership_by_prefixes("group_test_db", &prefixes), + Ok(()) + ); + assert_eq!( + validate_ownership_by_prefixes("group_test-db", &prefixes), + Ok(()) + ); + + assert_eq!( + validate_ownership_by_prefixes("nonexistent_testdb", &prefixes), + Err(OwnerValidationError::NoMatch) + ); + } +} diff --git a/src/server/server_loop.rs b/src/server/server_loop.rs new file mode 100644 index 0000000..f0b8b27 --- /dev/null +++ b/src/server/server_loop.rs @@ -0,0 +1,229 @@ +use std::{collections::BTreeSet, fs, path::PathBuf}; + +use anyhow::Context; + +use futures_util::{SinkExt, StreamExt}; +use tokio::io::AsyncWriteExt; +use tokio::net::{UnixListener, UnixStream}; + +use sqlx::prelude::*; +use sqlx::MySqlConnection; + +use crate::{ + core::{ + bootstrap::authenticated_unix_socket, + common::{UnixUser, DEFAULT_SOCKET_PATH}, + protocol::request_response::{ + create_server_to_client_message_stream, Request, Response, ServerToClientMessageStream, + }, + }, + server::{ + config::{create_mysql_connection_from_config, ServerConfig}, + sql::{ + database_operations::{create_databases, drop_databases, list_databases_for_user}, + database_privilege_operations::{ + apply_privilege_diffs, get_all_database_privileges, get_databases_privilege_data, + }, + user_operations::{ + create_database_users, drop_database_users, list_all_database_users_for_unix_user, + list_database_users, lock_database_users, set_password_for_database_user, + unlock_database_users, + }, + }, + }, +}; + +// TODO: consider using a connection pool + +// TODO: use tracing for login, so we can scope the log messages per incoming connection + +pub async fn listen_for_incoming_connections( + socket_path: Option, + config: ServerConfig, + // db_connection: &mut MySqlConnection, +) -> anyhow::Result<()> { + let socket_path = socket_path.unwrap_or(PathBuf::from(DEFAULT_SOCKET_PATH)); + + let parent_directory = socket_path.parent().unwrap(); + if !parent_directory.exists() { + println!("Creating directory {:?}", parent_directory); + fs::create_dir_all(parent_directory)?; + } + + println!("Listening on {:?}", socket_path); + match fs::remove_file(socket_path.as_path()) { + Ok(_) => {} + Err(e) if e.kind() == std::io::ErrorKind::NotFound => {} + Err(e) => return Err(e.into()), + } + + let listener = UnixListener::bind(socket_path)?; + + while let Ok((mut conn, _addr)) = listener.accept().await { + let uid = match authenticated_unix_socket::server_authenticate(&mut conn).await { + Ok(uid) => uid, + Err(e) => { + eprintln!("Failed to authenticate client: {}", e); + conn.shutdown().await?; + continue; + } + }; + let unix_user = match UnixUser::from_uid(uid.into()) { + Ok(user) => user, + Err(e) => { + eprintln!("Failed to get UnixUser from uid: {}", e); + conn.shutdown().await?; + continue; + } + }; + match handle_requests_for_single_session(conn, &unix_user, &config).await { + Ok(_) => {} + Err(e) => { + eprintln!("Failed to run server: {}", e); + } + } + } + + Ok(()) +} + +pub async fn handle_requests_for_single_session( + socket: UnixStream, + unix_user: &UnixUser, + config: &ServerConfig, +) -> anyhow::Result<()> { + let message_stream = create_server_to_client_message_stream(socket); + let mut db_connection = create_mysql_connection_from_config(&config.mysql).await?; + + let result = handle_requests_for_single_session_with_db_connection( + message_stream, + unix_user, + &mut db_connection, + ) + .await; + + if let Err(e) = db_connection + .close() + .await + .context("Failed to close connection properly") + { + eprintln!("{}", e); + eprintln!("Ignoring..."); + } + + result +} + +// TODO: ensure proper db_connection hygiene for functions that invoke +// this function + +pub async fn handle_requests_for_single_session_with_db_connection( + mut stream: ServerToClientMessageStream, + unix_user: &UnixUser, + db_connection: &mut MySqlConnection, +) -> anyhow::Result<()> { + loop { + // TODO: better error handling + let request = match stream.next().await { + Some(Ok(request)) => request, + Some(Err(e)) => return Err(e.into()), + None => { + log::warn!("Client disconnected without sending an exit message"); + break; + } + }; + + match request { + Request::CreateDatabases(databases_names) => { + let result = create_databases(databases_names, unix_user, db_connection).await; + stream.send(Response::CreateDatabases(result)).await?; + stream.flush().await?; + } + Request::DropDatabases(databases_names) => { + let result = drop_databases(databases_names, unix_user, db_connection).await; + stream.send(Response::DropDatabases(result)).await?; + stream.flush().await?; + } + Request::ListDatabases => { + let result = list_databases_for_user(unix_user, db_connection).await; + stream.send(Response::ListAllDatabases(result)).await?; + stream.flush().await?; + } + Request::ListPrivileges(database_names) => { + let response = match database_names { + Some(database_names) => { + let privilege_data = + get_databases_privilege_data(database_names, unix_user, db_connection) + .await; + Response::ListPrivileges(privilege_data) + } + None => { + let privilege_data = + get_all_database_privileges(unix_user, db_connection).await; + Response::ListAllPrivileges(privilege_data) + } + }; + + stream.send(response).await?; + stream.flush().await?; + } + Request::ModifyPrivileges(database_privilege_diffs) => { + let result = apply_privilege_diffs( + BTreeSet::from_iter(database_privilege_diffs), + unix_user, + db_connection, + ) + .await; + stream.send(Response::ModifyPrivileges(result)).await?; + stream.flush().await?; + } + Request::CreateUsers(db_users) => { + let result = create_database_users(db_users, unix_user, db_connection).await; + stream.send(Response::CreateUsers(result)).await?; + stream.flush().await?; + } + Request::DropUsers(db_users) => { + let result = drop_database_users(db_users, unix_user, db_connection).await; + stream.send(Response::DropUsers(result)).await?; + stream.flush().await?; + } + Request::PasswdUser(db_user, password) => { + let result = + set_password_for_database_user(&db_user, &password, unix_user, db_connection) + .await; + stream.send(Response::PasswdUser(result)).await?; + stream.flush().await?; + } + Request::ListUsers(db_users) => { + let response = match db_users { + Some(db_users) => { + let result = list_database_users(db_users, unix_user, db_connection).await; + Response::ListUsers(result) + } + None => { + let result = + list_all_database_users_for_unix_user(unix_user, db_connection).await; + Response::ListAllUsers(result) + } + }; + stream.send(response).await?; + stream.flush().await?; + } + Request::LockUsers(db_users) => { + let result = lock_database_users(db_users, unix_user, db_connection).await; + stream.send(Response::LockUsers(result)).await?; + stream.flush().await?; + } + Request::UnlockUsers(db_users) => { + let result = unlock_database_users(db_users, unix_user, db_connection).await; + stream.send(Response::UnlockUsers(result)).await?; + stream.flush().await?; + } + Request::Exit => { + break; + } + } + } + + Ok(()) +} diff --git a/src/server/sql.rs b/src/server/sql.rs new file mode 100644 index 0000000..8db4e8b --- /dev/null +++ b/src/server/sql.rs @@ -0,0 +1,3 @@ +pub mod database_operations; +pub mod database_privilege_operations; +pub mod user_operations; diff --git a/src/server/sql/database_operations.rs b/src/server/sql/database_operations.rs new file mode 100644 index 0000000..9ddddda --- /dev/null +++ b/src/server/sql/database_operations.rs @@ -0,0 +1,165 @@ +use crate::{ + core::{ + common::UnixUser, + protocol::{ + CreateDatabaseError, CreateDatabasesOutput, DropDatabaseError, DropDatabasesOutput, + ListDatabasesError, + }, + }, + server::{ + common::create_user_group_matching_regex, + input_sanitization::{quote_identifier, validate_name, validate_ownership_by_unix_user}, + }, +}; + +use sqlx::prelude::*; + +use sqlx::MySqlConnection; +use std::collections::BTreeMap; + +// NOTE: this function is unsafe because it does no input validation. +pub(super) async fn unsafe_database_exists( + database_name: &str, + connection: &mut MySqlConnection, +) -> Result { + let result = + sqlx::query("SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?") + .bind(database_name) + .fetch_optional(connection) + .await?; + + Ok(result.is_some()) +} + +pub async fn create_databases( + database_names: Vec, + unix_user: &UnixUser, + connection: &mut MySqlConnection, +) -> CreateDatabasesOutput { + let mut results = BTreeMap::new(); + + for database_name in database_names { + if let Err(err) = validate_name(&database_name) { + results.insert( + database_name.clone(), + Err(CreateDatabaseError::SanitizationError(err)), + ); + continue; + } + + if let Err(err) = validate_ownership_by_unix_user(&database_name, unix_user) { + results.insert( + database_name.clone(), + Err(CreateDatabaseError::OwnershipError(err)), + ); + continue; + } + + match unsafe_database_exists(&database_name, &mut *connection).await { + Ok(true) => { + results.insert( + database_name.clone(), + Err(CreateDatabaseError::DatabaseAlreadyExists), + ); + continue; + } + Err(err) => { + results.insert( + database_name.clone(), + Err(CreateDatabaseError::MySqlError(err.to_string())), + ); + continue; + } + _ => {} + } + + let result = + sqlx::query(format!("CREATE DATABASE {}", quote_identifier(&database_name)).as_str()) + .execute(&mut *connection) + .await + .map(|_| ()) + .map_err(|err| CreateDatabaseError::MySqlError(err.to_string())); + + results.insert(database_name, result); + } + + results +} + +pub async fn drop_databases( + database_names: Vec, + unix_user: &UnixUser, + connection: &mut MySqlConnection, +) -> DropDatabasesOutput { + let mut results = BTreeMap::new(); + + for database_name in database_names { + if let Err(err) = validate_name(&database_name) { + results.insert( + database_name.clone(), + Err(DropDatabaseError::SanitizationError(err)), + ); + continue; + } + + if let Err(err) = validate_ownership_by_unix_user(&database_name, unix_user) { + results.insert( + database_name.clone(), + Err(DropDatabaseError::OwnershipError(err)), + ); + continue; + } + + match unsafe_database_exists(&database_name, &mut *connection).await { + Ok(false) => { + results.insert( + database_name.clone(), + Err(DropDatabaseError::DatabaseDoesNotExist), + ); + continue; + } + Err(err) => { + results.insert( + database_name.clone(), + Err(DropDatabaseError::MySqlError(err.to_string())), + ); + continue; + } + _ => {} + } + + let result = + sqlx::query(format!("DROP DATABASE {}", quote_identifier(&database_name)).as_str()) + .execute(&mut *connection) + .await + .map(|_| ()) + .map_err(|err| DropDatabaseError::MySqlError(err.to_string())); + + results.insert(database_name, result); + } + + results +} + +pub async fn list_databases_for_user( + unix_user: &UnixUser, + connection: &mut MySqlConnection, +) -> Result, ListDatabasesError> { + sqlx::query( + r#" + SELECT `SCHEMA_NAME` AS `database` + FROM `information_schema`.`SCHEMATA` + WHERE `SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys') + AND `SCHEMA_NAME` REGEXP ? + "#, + ) + .bind(create_user_group_matching_regex(unix_user)) + .fetch_all(connection) + .await + .and_then(|rows| { + rows.into_iter() + .map(|row| row.try_get::("database")) + .collect::, sqlx::Error>>() + }) + .map_err(|err| ListDatabasesError::MySqlError(err.to_string())) +} diff --git a/src/server/sql/database_privilege_operations.rs b/src/server/sql/database_privilege_operations.rs new file mode 100644 index 0000000..b5ba9a0 --- /dev/null +++ b/src/server/sql/database_privilege_operations.rs @@ -0,0 +1,452 @@ +// TODO: fix comment +//! Database privilege operations +//! +//! This module contains functions for querying, modifying, +//! displaying and comparing database privileges. +//! +//! A lot of the complexity comes from two core components: +//! +//! - The privilege editor that needs to be able to print +//! an editable table of privileges and reparse the content +//! after the user has made manual changes. +//! +//! - The comparison functionality that tells the user what +//! changes will be made when applying a set of changes +//! to the list of database privileges. + +use std::collections::{BTreeMap, BTreeSet}; + +use indoc::indoc; +use itertools::Itertools; +use serde::{Deserialize, Serialize}; +use sqlx::{mysql::MySqlRow, prelude::*, MySqlConnection}; + +use crate::{ + core::{ + common::{rev_yn, yn, UnixUser}, + database_privileges::{DatabasePrivilegeChange, DatabasePrivilegesDiff}, + protocol::{ + DiffDoesNotApplyError, GetAllDatabasesPrivilegeData, GetAllDatabasesPrivilegeDataError, + GetDatabasesPrivilegeData, GetDatabasesPrivilegeDataError, + ModifyDatabasePrivilegesError, ModifyDatabasePrivilegesOutput, + }, + }, + server::{ + common::create_user_group_matching_regex, + input_sanitization::{quote_identifier, validate_name, validate_ownership_by_unix_user}, + sql::database_operations::unsafe_database_exists, + }, +}; + +/// This is the list of fields that are used to fetch the db + user + privileges +/// from the `db` table in the database. If you need to add or remove privilege +/// fields, this is a good place to start. +pub const DATABASE_PRIVILEGE_FIELDS: [&str; 13] = [ + "db", + "user", + "select_priv", + "insert_priv", + "update_priv", + "delete_priv", + "create_priv", + "drop_priv", + "alter_priv", + "index_priv", + "create_tmp_table_priv", + "lock_tables_priv", + "references_priv", +]; + +// NOTE: ord is needed for BTreeSet to accept the type, but it +// doesn't have any natural implementation semantics. + +/// This struct represents the set of privileges for a single user on a single database. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] +pub struct DatabasePrivilegeRow { + pub db: String, + pub user: String, + pub select_priv: bool, + pub insert_priv: bool, + pub update_priv: bool, + pub delete_priv: bool, + pub create_priv: bool, + pub drop_priv: bool, + pub alter_priv: bool, + pub index_priv: bool, + pub create_tmp_table_priv: bool, + pub lock_tables_priv: bool, + pub references_priv: bool, +} + +impl DatabasePrivilegeRow { + pub fn get_privilege_by_name(&self, name: &str) -> bool { + match name { + "select_priv" => self.select_priv, + "insert_priv" => self.insert_priv, + "update_priv" => self.update_priv, + "delete_priv" => self.delete_priv, + "create_priv" => self.create_priv, + "drop_priv" => self.drop_priv, + "alter_priv" => self.alter_priv, + "index_priv" => self.index_priv, + "create_tmp_table_priv" => self.create_tmp_table_priv, + "lock_tables_priv" => self.lock_tables_priv, + "references_priv" => self.references_priv, + _ => false, + } + } +} + +#[inline] +fn get_mysql_row_priv_field(row: &MySqlRow, position: usize) -> Result { + let field = DATABASE_PRIVILEGE_FIELDS[position]; + let value = row.try_get(position)?; + match rev_yn(value) { + Some(val) => Ok(val), + _ => { + log::warn!(r#"Invalid value for privilege "{}": '{}'"#, field, value); + Ok(false) + } + } +} + +impl FromRow<'_, MySqlRow> for DatabasePrivilegeRow { + fn from_row(row: &MySqlRow) -> Result { + Ok(Self { + db: row.try_get("db")?, + user: row.try_get("user")?, + select_priv: get_mysql_row_priv_field(row, 2)?, + insert_priv: get_mysql_row_priv_field(row, 3)?, + update_priv: get_mysql_row_priv_field(row, 4)?, + delete_priv: get_mysql_row_priv_field(row, 5)?, + create_priv: get_mysql_row_priv_field(row, 6)?, + drop_priv: get_mysql_row_priv_field(row, 7)?, + alter_priv: get_mysql_row_priv_field(row, 8)?, + index_priv: get_mysql_row_priv_field(row, 9)?, + create_tmp_table_priv: get_mysql_row_priv_field(row, 10)?, + lock_tables_priv: get_mysql_row_priv_field(row, 11)?, + references_priv: get_mysql_row_priv_field(row, 12)?, + }) + } +} + +// NOTE: this function is unsafe because it does no input validation. +/// Get all users + privileges for a single database. +async fn unsafe_get_database_privileges( + database_name: &str, + connection: &mut MySqlConnection, +) -> Result, sqlx::Error> { + sqlx::query_as::<_, DatabasePrivilegeRow>(&format!( + "SELECT {} FROM `db` WHERE `db` = ?", + DATABASE_PRIVILEGE_FIELDS + .iter() + .map(|field| quote_identifier(field)) + .join(","), + )) + .bind(database_name) + .fetch_all(connection) + .await +} + +// NOTE: this function is unsafe because it does no input validation. +/// Get all users + privileges for a single database-user pair. +pub async fn unsafe_get_database_privileges_for_db_user_pair( + database_name: &str, + user_name: &str, + connection: &mut MySqlConnection, +) -> Result, sqlx::Error> { + sqlx::query_as::<_, DatabasePrivilegeRow>(&format!( + "SELECT {} FROM `db` WHERE `db` = ? AND `user` = ?", + DATABASE_PRIVILEGE_FIELDS + .iter() + .map(|field| quote_identifier(field)) + .join(","), + )) + .bind(database_name) + .bind(user_name) + .fetch_optional(connection) + .await +} + +pub async fn get_databases_privilege_data( + database_names: Vec, + unix_user: &UnixUser, + connection: &mut MySqlConnection, +) -> GetDatabasesPrivilegeData { + let mut results = BTreeMap::new(); + + for database_name in database_names.iter() { + if let Err(err) = validate_name(database_name) { + results.insert( + database_name.clone(), + Err(GetDatabasesPrivilegeDataError::SanitizationError(err)), + ); + continue; + } + + if let Err(err) = validate_ownership_by_unix_user(database_name, unix_user) { + results.insert( + database_name.clone(), + Err(GetDatabasesPrivilegeDataError::OwnershipError(err)), + ); + continue; + } + + if !unsafe_database_exists(database_name, connection) + .await + .unwrap() + { + results.insert( + database_name.clone(), + Err(GetDatabasesPrivilegeDataError::DatabaseDoesNotExist), + ); + continue; + } + + let result = unsafe_get_database_privileges(database_name, connection) + .await + .map_err(|e| GetDatabasesPrivilegeDataError::MySqlError(e.to_string())); + + results.insert(database_name.clone(), result); + } + + debug_assert!(database_names.len() == results.len()); + + results +} + +/// Get all database + user + privileges pairs that are owned by the current user. +pub async fn get_all_database_privileges( + unix_user: &UnixUser, + connection: &mut MySqlConnection, +) -> GetAllDatabasesPrivilegeData { + sqlx::query_as::<_, DatabasePrivilegeRow>(&format!( + indoc! {r#" + SELECT {} FROM `db` WHERE `db` IN + (SELECT DISTINCT `SCHEMA_NAME` AS `database` + FROM `information_schema`.`SCHEMATA` + WHERE `SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys') + AND `SCHEMA_NAME` REGEXP ?) + "#}, + DATABASE_PRIVILEGE_FIELDS + .iter() + .map(|field| quote_identifier(field)) + .join(","), + )) + .bind(create_user_group_matching_regex(unix_user)) + .fetch_all(connection) + .await + .map_err(|e| GetAllDatabasesPrivilegeDataError::MySqlError(e.to_string())) +} + +async fn unsafe_apply_privilege_diff( + database_privilege_diff: &DatabasePrivilegesDiff, + connection: &mut MySqlConnection, +) -> Result<(), sqlx::Error> { + match database_privilege_diff { + DatabasePrivilegesDiff::New(p) => { + let tables = DATABASE_PRIVILEGE_FIELDS + .iter() + .map(|field| quote_identifier(field)) + .join(","); + + let question_marks = std::iter::repeat("?") + .take(DATABASE_PRIVILEGE_FIELDS.len()) + .join(","); + + sqlx::query( + format!("INSERT INTO `db` ({}) VALUES ({})", tables, question_marks).as_str(), + ) + .bind(p.db.to_string()) + .bind(p.user.to_string()) + .bind(yn(p.select_priv)) + .bind(yn(p.insert_priv)) + .bind(yn(p.update_priv)) + .bind(yn(p.delete_priv)) + .bind(yn(p.create_priv)) + .bind(yn(p.drop_priv)) + .bind(yn(p.alter_priv)) + .bind(yn(p.index_priv)) + .bind(yn(p.create_tmp_table_priv)) + .bind(yn(p.lock_tables_priv)) + .bind(yn(p.references_priv)) + .execute(connection) + .await + .map(|_| ()) + } + DatabasePrivilegesDiff::Modified(p) => { + let changes = p + .diff + .iter() + .map(|diff| match diff { + DatabasePrivilegeChange::YesToNo(name) => { + format!("{} = 'N'", quote_identifier(name)) + } + DatabasePrivilegeChange::NoToYes(name) => { + format!("{} = 'Y'", quote_identifier(name)) + } + }) + .join(","); + + sqlx::query( + format!("UPDATE `db` SET {} WHERE `db` = ? AND `user` = ?", changes).as_str(), + ) + .bind(p.db.to_string()) + .bind(p.user.to_string()) + .execute(connection) + .await + .map(|_| ()) + } + DatabasePrivilegesDiff::Deleted(p) => { + sqlx::query("DELETE FROM `db` WHERE `db` = ? AND `user` = ?") + .bind(p.db.to_string()) + .bind(p.user.to_string()) + .execute(connection) + .await + .map(|_| ()) + } + } +} + +async fn validate_diff( + diff: &DatabasePrivilegesDiff, + connection: &mut MySqlConnection, +) -> Result<(), ModifyDatabasePrivilegesError> { + let privilege_row = unsafe_get_database_privileges_for_db_user_pair( + diff.get_database_name(), + diff.get_user_name(), + connection, + ) + .await; + + let privilege_row = match privilege_row { + Ok(privilege_row) => privilege_row, + Err(e) => return Err(ModifyDatabasePrivilegesError::MySqlError(e.to_string())), + }; + + let result = match diff { + DatabasePrivilegesDiff::New(_) => { + if privilege_row.is_some() { + Err(ModifyDatabasePrivilegesError::DiffDoesNotApply( + DiffDoesNotApplyError::RowAlreadyExists( + diff.get_user_name().to_string(), + diff.get_database_name().to_string(), + ), + )) + } else { + Ok(()) + } + } + DatabasePrivilegesDiff::Modified(_) if privilege_row.is_none() => { + Err(ModifyDatabasePrivilegesError::DiffDoesNotApply( + DiffDoesNotApplyError::RowDoesNotExist( + diff.get_user_name().to_string(), + diff.get_database_name().to_string(), + ), + )) + } + DatabasePrivilegesDiff::Modified(row_diff) => { + let row = privilege_row.unwrap(); + + let error_exists = row_diff.diff.iter().any(|change| match change { + DatabasePrivilegeChange::YesToNo(name) => !row.get_privilege_by_name(name), + DatabasePrivilegeChange::NoToYes(name) => row.get_privilege_by_name(name), + }); + + if error_exists { + Err(ModifyDatabasePrivilegesError::DiffDoesNotApply( + DiffDoesNotApplyError::RowPrivilegeChangeDoesNotApply(row_diff.clone(), row), + )) + } else { + Ok(()) + } + } + DatabasePrivilegesDiff::Deleted(_) => { + if privilege_row.is_none() { + Err(ModifyDatabasePrivilegesError::DiffDoesNotApply( + DiffDoesNotApplyError::RowDoesNotExist( + diff.get_user_name().to_string(), + diff.get_database_name().to_string(), + ), + )) + } else { + Ok(()) + } + } + }; + + result +} + +/// Uses the result of [`diff_privileges`] to modify privileges in the database. +pub async fn apply_privilege_diffs( + database_privilege_diffs: BTreeSet, + unix_user: &UnixUser, + connection: &mut MySqlConnection, +) -> ModifyDatabasePrivilegesOutput { + let mut results: BTreeMap<(String, String), _> = BTreeMap::new(); + + for diff in database_privilege_diffs { + let key = ( + diff.get_database_name().to_string(), + diff.get_user_name().to_string(), + ); + if let Err(err) = validate_name(diff.get_database_name()) { + results.insert( + key, + Err(ModifyDatabasePrivilegesError::DatabaseSanitizationError( + err, + )), + ); + continue; + } + + if let Err(err) = validate_ownership_by_unix_user(diff.get_database_name(), unix_user) { + results.insert( + key, + Err(ModifyDatabasePrivilegesError::DatabaseOwnershipError(err)), + ); + continue; + } + + if let Err(err) = validate_name(diff.get_user_name()) { + results.insert( + key, + Err(ModifyDatabasePrivilegesError::UserSanitizationError(err)), + ); + continue; + } + + if let Err(err) = validate_ownership_by_unix_user(diff.get_user_name(), unix_user) { + results.insert( + key, + Err(ModifyDatabasePrivilegesError::UserOwnershipError(err)), + ); + continue; + } + + if !unsafe_database_exists(diff.get_database_name(), connection) + .await + .unwrap() + { + results.insert( + key, + Err(ModifyDatabasePrivilegesError::DatabaseDoesNotExist), + ); + continue; + } + + if let Err(err) = validate_diff(&diff, connection).await { + results.insert(key, Err(err)); + continue; + } + + let result = unsafe_apply_privilege_diff(&diff, connection) + .await + .map_err(|e| ModifyDatabasePrivilegesError::MySqlError(e.to_string())); + + results.insert(key, result); + } + + results +} diff --git a/src/server/sql/user_operations.rs b/src/server/sql/user_operations.rs new file mode 100644 index 0000000..4c83692 --- /dev/null +++ b/src/server/sql/user_operations.rs @@ -0,0 +1,375 @@ +use std::collections::BTreeMap; + +use serde::{Deserialize, Serialize}; + +use sqlx::prelude::*; +use sqlx::MySqlConnection; + +use crate::{ + core::{ + common::UnixUser, + protocol::{ + CreateUserError, CreateUsersOutput, DropUserError, DropUsersOutput, ListAllUsersError, + ListAllUsersOutput, ListUsersError, ListUsersOutput, LockUserError, LockUsersOutput, + SetPasswordError, SetPasswordOutput, UnlockUserError, UnlockUsersOutput, + }, + }, + server::{ + common::create_user_group_matching_regex, + input_sanitization::{quote_literal, validate_name, validate_ownership_by_unix_user}, + }, +}; + +// NOTE: this function is unsafe because it does no input validation. +async fn unsafe_user_exists( + db_user: &str, + connection: &mut MySqlConnection, +) -> Result { + sqlx::query( + r#" + SELECT EXISTS( + SELECT 1 + FROM `mysql`.`user` + WHERE `User` = ? + ) + "#, + ) + .bind(db_user) + .fetch_one(connection) + .await + .map(|row| row.get::(0)) +} + +pub async fn create_database_users( + db_users: Vec, + unix_user: &UnixUser, + connection: &mut MySqlConnection, +) -> CreateUsersOutput { + let mut results = BTreeMap::new(); + + for db_user in db_users { + if let Err(err) = validate_name(&db_user) { + results.insert(db_user, Err(CreateUserError::SanitizationError(err))); + continue; + } + + if let Err(err) = validate_ownership_by_unix_user(&db_user, unix_user) { + results.insert(db_user, Err(CreateUserError::OwnershipError(err))); + continue; + } + + match unsafe_user_exists(&db_user, &mut *connection).await { + Ok(true) => { + results.insert(db_user, Err(CreateUserError::UserAlreadyExists)); + continue; + } + Err(err) => { + results.insert(db_user, Err(CreateUserError::MySqlError(err.to_string()))); + continue; + } + _ => {} + } + + let result = sqlx::query(format!("CREATE USER {}@'%'", quote_literal(&db_user),).as_str()) + .execute(&mut *connection) + .await + .map(|_| ()) + .map_err(|err| CreateUserError::MySqlError(err.to_string())); + + results.insert(db_user, result); + } + + results +} + +pub async fn drop_database_users( + db_users: Vec, + unix_user: &UnixUser, + connection: &mut MySqlConnection, +) -> DropUsersOutput { + let mut results = BTreeMap::new(); + + for db_user in db_users { + if let Err(err) = validate_name(&db_user) { + results.insert(db_user, Err(DropUserError::SanitizationError(err))); + continue; + } + + if let Err(err) = validate_ownership_by_unix_user(&db_user, unix_user) { + results.insert(db_user, Err(DropUserError::OwnershipError(err))); + continue; + } + + match unsafe_user_exists(&db_user, &mut *connection).await { + Ok(false) => { + results.insert(db_user, Err(DropUserError::UserDoesNotExist)); + continue; + } + Err(err) => { + results.insert(db_user, Err(DropUserError::MySqlError(err.to_string()))); + continue; + } + _ => {} + } + + let result = sqlx::query(format!("DROP USER {}@'%'", quote_literal(&db_user),).as_str()) + .execute(&mut *connection) + .await + .map(|_| ()) + .map_err(|err| DropUserError::MySqlError(err.to_string())); + + results.insert(db_user, result); + } + + results +} + +pub async fn set_password_for_database_user( + db_user: &str, + password: &str, + unix_user: &UnixUser, + connection: &mut MySqlConnection, +) -> SetPasswordOutput { + if let Err(err) = validate_name(db_user) { + return Err(SetPasswordError::SanitizationError(err)); + } + + if let Err(err) = validate_ownership_by_unix_user(db_user, unix_user) { + return Err(SetPasswordError::OwnershipError(err)); + } + + match unsafe_user_exists(db_user, &mut *connection).await { + Ok(false) => return Err(SetPasswordError::UserDoesNotExist), + Err(err) => return Err(SetPasswordError::MySqlError(err.to_string())), + _ => {} + } + + sqlx::query( + format!( + "ALTER USER {}@'%' IDENTIFIED BY {}", + quote_literal(db_user), + quote_literal(password).as_str() + ) + .as_str(), + ) + .execute(&mut *connection) + .await + .map(|_| ()) + .map_err(|err| SetPasswordError::MySqlError(err.to_string())) +} + +// NOTE: this function is unsafe because it does no input validation. +async fn database_user_is_locked_unsafe( + db_user: &str, + connection: &mut MySqlConnection, +) -> Result { + sqlx::query( + r#" + SELECT COALESCE( + JSON_EXTRACT(`mysql`.`global_priv`.`priv`, "$.account_locked"), + 'false' + ) != 'false' + FROM `mysql`.`global_priv` + WHERE `User` = ? + AND `Host` = '%' + "#, + ) + .bind(db_user) + .fetch_one(connection) + .await + .map(|row| row.get::(0)) +} + +pub async fn lock_database_users( + db_users: Vec, + unix_user: &UnixUser, + connection: &mut MySqlConnection, +) -> LockUsersOutput { + let mut results = BTreeMap::new(); + + for db_user in db_users { + if let Err(err) = validate_name(&db_user) { + results.insert(db_user, Err(LockUserError::SanitizationError(err))); + continue; + } + + if let Err(err) = validate_ownership_by_unix_user(&db_user, unix_user) { + results.insert(db_user, Err(LockUserError::OwnershipError(err))); + continue; + } + + match unsafe_user_exists(&db_user, &mut *connection).await { + Ok(true) => {} + Ok(false) => { + results.insert(db_user, Err(LockUserError::UserDoesNotExist)); + continue; + } + Err(err) => { + results.insert(db_user, Err(LockUserError::MySqlError(err.to_string()))); + continue; + } + } + + match database_user_is_locked_unsafe(&db_user, &mut *connection).await { + Ok(false) => {} + Ok(true) => { + results.insert(db_user, Err(LockUserError::UserIsAlreadyLocked)); + continue; + } + Err(err) => { + results.insert(db_user, Err(LockUserError::MySqlError(err.to_string()))); + continue; + } + } + + let result = sqlx::query( + format!("ALTER USER {}@'%' ACCOUNT LOCK", quote_literal(&db_user),).as_str(), + ) + .execute(&mut *connection) + .await + .map(|_| ()) + .map_err(|err| LockUserError::MySqlError(err.to_string())); + + results.insert(db_user, result); + } + + results +} + +pub async fn unlock_database_users( + db_users: Vec, + unix_user: &UnixUser, + connection: &mut MySqlConnection, +) -> UnlockUsersOutput { + let mut results = BTreeMap::new(); + + for db_user in db_users { + if let Err(err) = validate_name(&db_user) { + results.insert(db_user, Err(UnlockUserError::SanitizationError(err))); + continue; + } + + if let Err(err) = validate_ownership_by_unix_user(&db_user, unix_user) { + results.insert(db_user, Err(UnlockUserError::OwnershipError(err))); + continue; + } + + match unsafe_user_exists(&db_user, &mut *connection).await { + Ok(false) => { + results.insert(db_user, Err(UnlockUserError::UserDoesNotExist)); + continue; + } + Err(err) => { + results.insert(db_user, Err(UnlockUserError::MySqlError(err.to_string()))); + continue; + } + _ => {} + } + + match database_user_is_locked_unsafe(&db_user, &mut *connection).await { + Ok(false) => { + results.insert(db_user, Err(UnlockUserError::UserIsAlreadyUnlocked)); + continue; + } + Err(err) => { + results.insert(db_user, Err(UnlockUserError::MySqlError(err.to_string()))); + continue; + } + _ => {} + } + + let result = sqlx::query( + format!("ALTER USER {}@'%' ACCOUNT UNLOCK", quote_literal(&db_user),).as_str(), + ) + .execute(&mut *connection) + .await + .map(|_| ()) + .map_err(|err| UnlockUserError::MySqlError(err.to_string())); + + results.insert(db_user, result); + } + + results +} + +/// This struct contains information about a database user. +/// This can be extended if we need more information in the future. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct DatabaseUser { + #[sqlx(rename = "User")] + pub user: String, + + #[allow(dead_code)] + #[serde(skip)] + #[sqlx(rename = "Host")] + pub host: String, + + #[sqlx(rename = "has_password")] + pub has_password: bool, + + #[sqlx(rename = "is_locked")] + pub is_locked: bool, +} + +const DB_USER_SELECT_STATEMENT: &str = r#" +SELECT + `mysql`.`user`.`User`, + `mysql`.`user`.`Host`, + `mysql`.`user`.`Password` != '' OR `mysql`.`user`.`authentication_string` != '' AS `has_password`, + COALESCE( + JSON_EXTRACT(`mysql`.`global_priv`.`priv`, "$.account_locked"), + 'false' + ) != 'false' AS `is_locked` +FROM `mysql`.`user` +JOIN `mysql`.`global_priv` ON + `mysql`.`user`.`User` = `mysql`.`global_priv`.`User` + AND `mysql`.`user`.`Host` = `mysql`.`global_priv`.`Host` +"#; + +pub async fn list_database_users( + db_users: Vec, + unix_user: &UnixUser, + connection: &mut MySqlConnection, +) -> ListUsersOutput { + let mut results = BTreeMap::new(); + + for db_user in db_users { + if let Err(err) = validate_name(&db_user) { + results.insert(db_user, Err(ListUsersError::SanitizationError(err))); + continue; + } + + if let Err(err) = validate_ownership_by_unix_user(&db_user, unix_user) { + results.insert(db_user, Err(ListUsersError::OwnershipError(err))); + continue; + } + + let result = sqlx::query_as::<_, DatabaseUser>( + &(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `mysql`.`user`.`User` = ?"), + ) + .bind(&db_user) + .fetch_optional(&mut *connection) + .await; + + match result { + Ok(Some(user)) => results.insert(db_user, Ok(user)), + Ok(None) => results.insert(db_user, Err(ListUsersError::UserDoesNotExist)), + Err(err) => results.insert(db_user, Err(ListUsersError::MySqlError(err.to_string()))), + }; + } + + results +} + +pub async fn list_all_database_users_for_unix_user( + unix_user: &UnixUser, + connection: &mut MySqlConnection, +) -> ListAllUsersOutput { + sqlx::query_as::<_, DatabaseUser>( + &(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `mysql`.`user`.`User` REGEXP ?"), + ) + .bind(create_user_group_matching_regex(unix_user)) + .fetch_all(connection) + .await + .map_err(|err| ListAllUsersError::MySqlError(err.to_string())) +} -- 2.44.2