diff --git a/Cargo.lock b/Cargo.lock index d5e3f10..d30af81 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -253,6 +253,16 @@ dependencies = [ "clap_derive", ] +[[package]] +name = "clap-verbosity-flag" +version = "2.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63d19864d6b68464c59f7162c9914a0b569ddc2926b4a2d71afe62a9738eff53" +dependencies = [ + "clap", + "log", +] + [[package]] name = "clap_builder" version = "4.5.15" @@ -993,6 +1003,9 @@ name = "log" version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +dependencies = [ + "value-bag", +] [[package]] name = "lru" @@ -1064,6 +1077,7 @@ dependencies = [ "async-bincode", "bincode", "clap", + "clap-verbosity-flag", "clap_complete", "derive_more", "dialoguer", @@ -1082,6 +1096,7 @@ dependencies = [ "serde", "serde_json", "sqlx", + "systemd-journal-logger", "tokio", "tokio-serde", "tokio-stream", @@ -1996,6 +2011,16 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "systemd-journal-logger" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5f3848dd723f2a54ac1d96da793b32923b52de8dfcced8722516dac312a5b2a" +dependencies = [ + "log", + "rustix", +] + [[package]] name = "tempfile" version = "3.12.0" @@ -2287,6 +2312,12 @@ dependencies = [ "getrandom", ] +[[package]] +name = "value-bag" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a84c137d37ab0142f0f2ddfe332651fdbf252e7b7dbb4e67b6c1f1b2e925101" + [[package]] name = "vcpkg" version = "0.2.15" diff --git a/Cargo.toml b/Cargo.toml index faa78e2..e94df3f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,7 @@ anyhow = "1.0.86" async-bincode = "0.7.2" bincode = "1.3.3" clap = { version = "4.5.16", features = ["derive"] } +clap-verbosity-flag = "2.2.1" clap_complete = "4.5.18" derive_more = { version = "1.0.0", features = ["display", "error"] } dialoguer = "0.11.0" @@ -25,6 +26,7 @@ sd-notify = "0.4.2" serde = "1.0.208" serde_json = { version = "1.0.125", features = ["preserve_order"] } sqlx = { version = "0.8.0", features = ["runtime-tokio", "mysql", "tls-rustls"] } +systemd-journal-logger = "2.1.1" tokio = { version = "1.39.3", features = ["rt", "macros"] } tokio-serde = { version = "0.9.0", features = ["bincode"] } tokio-stream = "0.1.15" diff --git a/deny.toml b/deny.toml new file mode 100644 index 0000000..77ed81e --- /dev/null +++ b/deny.toml @@ -0,0 +1,78 @@ +[graph] +targets = [ + "x86_64-unknown-linux-gnu", + "aarch64-unknown-linux-gnu", + "armv7-unknown-linux-gnueabihf", + + "x86_64-unknown-freebsd", + "aarch64-unknown-freebsd", + "armv7-unknown-freebsd", + + "x86_64-apple-darwin", + "aarch64-apple-darwin", +] + +all-features = false +no-default-features = false + +#features = [] + +[output] +feature-depth = 1 + +[advisories] +#db-path = "$CARGO_HOME/advisory-dbs" +#db-urls = ["https://github.com/rustsec/advisory-db"] +ignore = [] + +[licenses] +allow = [ + "GPL-2.0", + "MIT", + "Apache-2.0", + "ISC", + "MPL-2.0", + "Unicode-DFS-2016", + "BSD-3-Clause", + "OpenSSL", +] +confidence-threshold = 0.8 +exceptions = [] + +[[licenses.clarify]] +crate = "ring" +expression = "MIT AND ISC AND OpenSSL" +license-files = [ + { path = "LICENSE", hash = 0xbd0eed23 } +] + +[licenses.private] +ignore = false +registries = [] + +[bans] +multiple-versions = "allow" +wildcards = "allow" +highlight = "all" +workspace-default-features = "allow" +external-default-features = "allow" +allow = [] +deny = [] + +#[[bans.features]] +#crate = "reqwest" +#deny = ["json"] +#allow = [] +#exact = true + +skip = [] +skip-tree = [] + +[sources] +unknown-registry = "warn" +unknown-git = "warn" +allow-registry = ["https://github.com/rust-lang/crates.io-index"] +allow-git = [] + +[sources.allow-org] + diff --git a/flake.nix b/flake.nix index 60ff337..fe92eee 100644 --- a/flake.nix +++ b/flake.nix @@ -47,6 +47,7 @@ toolchain mysql-client cargo-nextest + cargo-deny ]; RUST_SRC_PATH = "${toolchain}/lib/rustlib/src/rust/library"; diff --git a/nix/module.nix b/nix/module.nix index dd39f4f..514753c 100644 --- a/nix/module.nix +++ b/nix/module.nix @@ -15,6 +15,20 @@ in description = "Create a local database user for mysqladm-rs"; }; + logLevel = lib.mkOption { + type = lib.types.enum [ "quiet" "error" "warn" "info" "debug" "trace" ]; + default = "debug"; + description = "Log level for mysqladm-rs"; + apply = level: { + "quiet" = "-q"; + "error" = ""; + "warn" = "-v"; + "info" = "-vv"; + "debug" = "-vvv"; + "trace" = "-vvvv"; + }.${level}; + }; + settings = lib.mkOption { default = { }; type = lib.types.submodule { @@ -76,7 +90,7 @@ in name = cfg.settings.mysql.username; ensurePermissions = { "mysql.*" = "SELECT, INSERT, UPDATE, DELETE"; - "*.*" = "CREATE USER, GRANT OPTION"; + "*.*" = "GRANT OPTION, CREATE, DROP"; }; } ]; @@ -86,7 +100,9 @@ in environment.RUST_LOG = "debug"; serviceConfig = { Type = "notify"; - ExecStart = "${lib.getExe cfg.package} server socket-activate --config ${configFile}"; + ExecStart = "${lib.getExe cfg.package} ${cfg.logLevel} server --systemd socket-activate --config ${configFile}"; + + WatchdogSec = 15; User = "mysqladm"; Group = "mysqladm"; @@ -95,7 +111,18 @@ in # This is required to read unix user/group details. PrivateUsers = false; - CapabilityBoundingSet = ""; + # Needed to communicate with MySQL. + PrivateNetwork = false; + + IPAddressDeny = + lib.optionals (lib.elem cfg.settings.mysql.host [ null "localhost" "127.0.0.1" ]) [ "any" ]; + + RestrictAddressFamilies = [ "AF_UNIX" ] + ++ (lib.optionals (cfg.settings.mysql.host != null) [ "AF_INET" "AF_INET6" ]); + + AmbientCapabilities = [ "" ]; + CapabilityBoundingSet = [ "" ]; + DeviceAllow = [ "" ]; LockPersonality = true; MemoryDenyWriteExecute = true; NoNewPrivileges = true; @@ -113,12 +140,12 @@ in ProtectProc = "invisible"; ProtectSystem = "strict"; RemoveIPC = true; - UMask = "0000"; - RestrictAddressFamilies = [ "AF_UNIX" "AF_INET" "AF_INET6" ]; + UMask = "0777"; RestrictNamespaces = true; RestrictRealtime = true; RestrictSUIDSGID = true; SystemCallArchitectures = "native"; + SocketBindDeny = [ "any" ]; SystemCallFilter = [ "@system-service" "~@privileged" diff --git a/src/cli/database_command.rs b/src/cli/database_command.rs index 3d2f9ab..868497b 100644 --- a/src/cli/database_command.rs +++ b/src/cli/database_command.rs @@ -15,9 +15,10 @@ use crate::{ parse_privilege_table_cli_arg, }, protocol::{ - print_create_databases_output_status, print_drop_databases_output_status, - print_modify_database_privileges_output_status, ClientToServerMessageStream, Request, - Response, + print_create_databases_output_status, print_create_databases_output_status_json, + print_drop_databases_output_status, print_drop_databases_output_status_json, + print_modify_database_privileges_output_status, ClientToServerMessageStream, + MySQLDatabase, Request, Response, }, }, server::sql::database_privilege_operations::{DatabasePrivilegeRow, DATABASE_PRIVILEGE_FIELDS}, @@ -102,49 +103,57 @@ pub enum DatabaseCommand { #[derive(Parser, Debug, Clone)] pub struct DatabaseCreateArgs { - /// The name of the database(s) to create. + /// The name of the database(s) to create #[arg(num_args = 1..)] - name: Vec, + name: Vec, + + /// Print the information as JSON + #[arg(short, long)] + json: bool, } #[derive(Parser, Debug, Clone)] pub struct DatabaseDropArgs { - /// The name of the database(s) to drop. + /// The name of the database(s) to drop #[arg(num_args = 1..)] - name: Vec, + name: Vec, + + /// Print the information as JSON + #[arg(short, long)] + json: bool, } #[derive(Parser, Debug, Clone)] pub struct DatabaseShowArgs { - /// The name of the database(s) to show. + /// The name of the database(s) to show #[arg(num_args = 0..)] - name: Vec, + name: Vec, - /// Whether to output the information in JSON format. + /// Print the information as JSON #[arg(short, long)] json: bool, } #[derive(Parser, Debug, Clone)] pub struct DatabaseShowPrivsArgs { - /// The name of the database(s) to show. + /// The name of the database(s) to show #[arg(num_args = 0..)] - name: Vec, + name: Vec, - /// Whether to output the information in JSON format. + /// Print the information as JSON #[arg(short, long)] json: bool, } #[derive(Parser, Debug, Clone)] pub struct DatabaseEditPrivsArgs { - /// The name of the database to edit privileges for. - pub name: Option, + /// The name of the database to edit privileges for + pub name: Option, #[arg(short, long, value_name = "[DATABASE:]USER:PRIVILEGES", num_args = 0..)] pub privs: Vec, - /// Whether to output the information in JSON format. + /// Print the information as JSON #[arg(short, long)] pub json: bool, @@ -152,7 +161,7 @@ pub struct DatabaseEditPrivsArgs { #[arg(short, long)] pub editor: Option, - /// Disable interactive confirmation before saving changes. + /// Disable interactive confirmation before saving changes #[arg(short, long)] pub yes: bool, } @@ -182,7 +191,7 @@ async fn create_databases( anyhow::bail!("No database names provided"); } - let message = Request::CreateDatabases(args.name.clone()); + let message = Request::CreateDatabases(args.name.to_owned()); server_connection.send(message).await?; let result = match server_connection.next().await { @@ -192,7 +201,11 @@ async fn create_databases( server_connection.send(Request::Exit).await?; - print_create_databases_output_status(&result); + if args.json { + print_create_databases_output_status_json(&result); + } else { + print_create_databases_output_status(&result); + } Ok(()) } @@ -205,7 +218,7 @@ async fn drop_databases( anyhow::bail!("No database names provided"); } - let message = Request::DropDatabases(args.name.clone()); + let message = Request::DropDatabases(args.name.to_owned()); server_connection.send(message).await?; let result = match server_connection.next().await { @@ -215,7 +228,11 @@ async fn drop_databases( server_connection.send(Request::Exit).await?; - print_drop_databases_output_status(&result); + if args.json { + print_drop_databases_output_status_json(&result); + } else { + print_drop_databases_output_status(&result); + }; Ok(()) } @@ -227,11 +244,13 @@ async fn show_databases( let message = if args.name.is_empty() { Request::ListDatabases(None) } else { - Request::ListDatabases(Some(args.name.clone())) + Request::ListDatabases(Some(args.name.to_owned())) }; server_connection.send(message).await?; + // TODO: collect errors for json output. + let database_list = match server_connection.next().await { Some(Ok(Response::ListDatabases(databases))) => databases .into_iter() @@ -282,7 +301,7 @@ async fn show_database_privileges( let message = if args.name.is_empty() { Request::ListPrivileges(None) } else { - Request::ListPrivileges(Some(args.name.clone())) + Request::ListPrivileges(Some(args.name.to_owned())) }; server_connection.send(message).await?; @@ -354,7 +373,7 @@ pub async fn edit_database_privileges( args: DatabaseEditPrivsArgs, mut server_connection: ClientToServerMessageStream, ) -> anyhow::Result<()> { - let message = Request::ListPrivileges(args.name.clone().map(|name| vec![name])); + let message = Request::ListPrivileges(args.name.to_owned().map(|name| vec![name])); server_connection.send(message).await?; @@ -386,13 +405,14 @@ pub async fn edit_database_privileges( let privileges_to_change = if !args.privs.is_empty() { parse_privilege_tables_from_args(&args)? } else { - edit_privileges_with_editor(&privilege_data, args.name.as_deref())? + edit_privileges_with_editor(&privilege_data, args.name.as_ref())? }; let diffs = diff_privileges(&privilege_data, &privileges_to_change); if diffs.is_empty() { println!("No changes to make."); + server_connection.send(Request::Exit).await?; return Ok(()); } @@ -451,7 +471,7 @@ fn parse_privilege_tables_from_args( fn edit_privileges_with_editor( privilege_data: &[DatabasePrivilegeRow], - database_name: Option<&str>, + database_name: Option<&MySQLDatabase>, ) -> anyhow::Result> { let unix_user = User::from_uid(getuid()) .context("Failed to look up your UNIX username") diff --git a/src/cli/mysql_admutils_compatibility/common.rs b/src/cli/mysql_admutils_compatibility/common.rs index f2c7f19..d3d829d 100644 --- a/src/cli/mysql_admutils_compatibility/common.rs +++ b/src/cli/mysql_admutils_compatibility/common.rs @@ -1,4 +1,11 @@ +use crate::core::protocol::{MySQLDatabase, MySQLUser}; + #[inline] -pub fn trim_to_32_chars(name: &str) -> String { - name.chars().take(32).collect() +pub fn trim_db_name_to_32_chars(db_name: &MySQLDatabase) -> MySQLDatabase { + db_name.chars().take(32).collect::().into() +} + +#[inline] +pub fn trim_user_name_to_32_chars(user_name: &MySQLUser) -> MySQLUser { + user_name.chars().take(32).collect::().into() } diff --git a/src/cli/mysql_admutils_compatibility/mysql_dbadm.rs b/src/cli/mysql_admutils_compatibility/mysql_dbadm.rs index c028791..9c63fc2 100644 --- a/src/cli/mysql_admutils_compatibility/mysql_dbadm.rs +++ b/src/cli/mysql_admutils_compatibility/mysql_dbadm.rs @@ -9,7 +9,7 @@ use crate::{ common::erroneous_server_response, database_command, mysql_admutils_compatibility::{ - common::trim_to_32_chars, + common::trim_db_name_to_32_chars, error_messages::{ format_show_database_error_message, handle_create_database_error, handle_drop_database_error, @@ -20,7 +20,7 @@ use crate::{ bootstrap::bootstrap_server_connection_and_drop_privileges, protocol::{ create_client_to_server_message_stream, ClientToServerMessageStream, - GetDatabasesPrivilegeDataError, Request, Response, + GetDatabasesPrivilegeDataError, MySQLDatabase, Request, Response, }, }, server::sql::database_privilege_operations::DatabasePrivilegeRow, @@ -120,27 +120,27 @@ pub enum Command { pub struct CreateArgs { /// The name of the DATABASE(s) to create. #[arg(num_args = 1..)] - name: Vec, + name: Vec, } #[derive(Parser)] pub struct DatabaseDropArgs { /// The name of the DATABASE(s) to drop. #[arg(num_args = 1..)] - name: Vec, + name: Vec, } #[derive(Parser)] pub struct DatabaseShowArgs { /// The name of the DATABASE(s) to show. #[arg(num_args = 0..)] - name: Vec, + name: Vec, } #[derive(Parser)] pub struct EditPermArgs { /// The name of the DATABASE to edit permissions for. - pub database: String, + pub database: MySQLDatabase, } pub fn main() -> anyhow::Result<()> { @@ -202,11 +202,7 @@ async fn create_databases( args: CreateArgs, mut server_connection: ClientToServerMessageStream, ) -> anyhow::Result<()> { - let database_names = args - .name - .iter() - .map(|name| trim_to_32_chars(name)) - .collect(); + let database_names = args.name.iter().map(trim_db_name_to_32_chars).collect(); let message = Request::CreateDatabases(database_names); server_connection.send(message).await?; @@ -232,11 +228,7 @@ async fn drop_databases( args: DatabaseDropArgs, mut server_connection: ClientToServerMessageStream, ) -> anyhow::Result<()> { - let database_names = args - .name - .iter() - .map(|name| trim_to_32_chars(name)) - .collect(); + let database_names = args.name.iter().map(trim_db_name_to_32_chars).collect(); let message = Request::DropDatabases(database_names); server_connection.send(message).await?; @@ -262,11 +254,8 @@ async fn show_databases( args: DatabaseShowArgs, mut server_connection: ClientToServerMessageStream, ) -> anyhow::Result<()> { - let database_names: Vec = args - .name - .iter() - .map(|name| trim_to_32_chars(name)) - .collect(); + let database_names: Vec = + args.name.iter().map(trim_db_name_to_32_chars).collect(); let message = if database_names.is_empty() { let message = Request::ListDatabases(None); @@ -291,14 +280,16 @@ async fn show_databases( // NOTE: mysql-dbadm show has a quirk where valid database names // for non-existent databases will report with no users. - let results: Vec), String>> = match response { + let results: Vec), String>> = match response { Some(Ok(Response::ListPrivileges(result))) => result .into_iter() - .map(|(name, rows)| match rows.map(|rows| (name.clone(), rows)) { - Ok(rows) => Ok(rows), - Err(GetDatabasesPrivilegeDataError::DatabaseDoesNotExist) => Ok((name, vec![])), - Err(err) => Err(format_show_database_error_message(err, &name)), - }) + .map( + |(name, rows)| match rows.map(|rows| (name.to_owned(), rows)) { + Ok(rows) => Ok(rows), + Err(GetDatabasesPrivilegeDataError::DatabaseDoesNotExist) => Ok((name, vec![])), + Err(err) => Err(format_show_database_error_message(err, &name)), + }, + ) .collect(), response => return erroneous_server_response(response), }; diff --git a/src/cli/mysql_admutils_compatibility/mysql_useradm.rs b/src/cli/mysql_admutils_compatibility/mysql_useradm.rs index cdc7254..fc1dabb 100644 --- a/src/cli/mysql_admutils_compatibility/mysql_useradm.rs +++ b/src/cli/mysql_admutils_compatibility/mysql_useradm.rs @@ -9,7 +9,7 @@ use crate::{ cli::{ common::erroneous_server_response, mysql_admutils_compatibility::{ - common::trim_to_32_chars, + common::trim_user_name_to_32_chars, error_messages::{ handle_create_user_error, handle_drop_user_error, handle_list_users_error, }, @@ -19,7 +19,8 @@ use crate::{ core::{ bootstrap::bootstrap_server_connection_and_drop_privileges, protocol::{ - create_client_to_server_message_stream, ClientToServerMessageStream, Request, Response, + create_client_to_server_message_stream, ClientToServerMessageStream, MySQLUser, + Request, Response, }, }, server::sql::user_operations::DatabaseUser, @@ -83,28 +84,28 @@ pub enum Command { pub struct CreateArgs { /// The name of the USER(s) to create. #[arg(num_args = 1..)] - name: Vec, + name: Vec, } #[derive(Parser)] pub struct DeleteArgs { /// The name of the USER(s) to delete. #[arg(num_args = 1..)] - name: Vec, + name: Vec, } #[derive(Parser)] pub struct PasswdArgs { /// The name of the USER(s) to change the password for. #[arg(num_args = 1..)] - name: Vec, + name: Vec, } #[derive(Parser)] pub struct ShowArgs { /// The name of the USER(s) to show. #[arg(num_args = 0..)] - name: Vec, + name: Vec, } pub fn main() -> anyhow::Result<()> { @@ -152,13 +153,9 @@ async fn create_user( args: CreateArgs, mut server_connection: ClientToServerMessageStream, ) -> anyhow::Result<()> { - let usernames = args - .name - .iter() - .map(|name| trim_to_32_chars(name)) - .collect(); + let db_users = args.name.iter().map(trim_user_name_to_32_chars).collect(); - let message = Request::CreateUsers(usernames); + let message = Request::CreateUsers(db_users); server_connection.send(message).await?; let result = match server_connection.next().await { @@ -182,13 +179,9 @@ async fn drop_users( args: DeleteArgs, mut server_connection: ClientToServerMessageStream, ) -> anyhow::Result<()> { - let usernames = args - .name - .iter() - .map(|name| trim_to_32_chars(name)) - .collect(); + let db_users = args.name.iter().map(trim_user_name_to_32_chars).collect(); - let message = Request::DropUsers(usernames); + let message = Request::DropUsers(db_users); server_connection.send(message).await?; let result = match server_connection.next().await { @@ -212,13 +205,9 @@ async fn passwd_users( args: PasswdArgs, mut server_connection: ClientToServerMessageStream, ) -> anyhow::Result<()> { - let usernames = args - .name - .iter() - .map(|name| trim_to_32_chars(name)) - .collect(); + let db_users = args.name.iter().map(trim_user_name_to_32_chars).collect(); - let message = Request::ListUsers(Some(usernames)); + let message = Request::ListUsers(Some(db_users)); server_connection.send(message).await?; let response = match server_connection.next().await { @@ -243,11 +232,11 @@ async fn passwd_users( for user in users { let password = read_password_from_stdin_with_double_check(&user.user)?; - let message = Request::PasswdUser(user.user.clone(), password); + let message = Request::PasswdUser(user.user.to_owned(), password); server_connection.send(message).await?; match server_connection.next().await { Some(Ok(Response::PasswdUser(result))) => match result { - Ok(()) => println!("Password updated for user '{}'.", user.user), + Ok(()) => println!("Password updated for user '{}'.", &user.user), Err(_) => eprintln!( "{}: Failed to update password for user '{}'.", argv0, user.user, @@ -266,16 +255,12 @@ async fn show_users( args: ShowArgs, mut server_connection: ClientToServerMessageStream, ) -> anyhow::Result<()> { - let usernames: Vec<_> = args - .name - .iter() - .map(|name| trim_to_32_chars(name)) - .collect(); + let db_users: Vec<_> = args.name.iter().map(trim_user_name_to_32_chars).collect(); - let message = if usernames.is_empty() { + let message = if db_users.is_empty() { Request::ListUsers(None) } else { - Request::ListUsers(Some(usernames)) + Request::ListUsers(Some(db_users)) }; server_connection.send(message).await?; diff --git a/src/cli/user_command.rs b/src/cli/user_command.rs index a99e596..d92ca34 100644 --- a/src/cli/user_command.rs +++ b/src/cli/user_command.rs @@ -4,10 +4,12 @@ use dialoguer::{Confirm, Password}; use futures_util::{SinkExt, StreamExt}; use crate::core::protocol::{ - print_create_users_output_status, print_drop_users_output_status, - print_lock_users_output_status, print_set_password_output_status, - print_unlock_users_output_status, ClientToServerMessageStream, ListUsersError, Request, - Response, + print_create_users_output_status, print_create_users_output_status_json, + print_drop_users_output_status, print_drop_users_output_status_json, + print_lock_users_output_status, print_lock_users_output_status_json, + print_set_password_output_status, print_unlock_users_output_status, + print_unlock_users_output_status_json, ClientToServerMessageStream, ListUsersError, MySQLUser, + Request, Response, }; use super::common::erroneous_server_response; @@ -51,46 +53,69 @@ pub enum UserCommand { #[derive(Parser, Debug, Clone)] pub struct UserCreateArgs { #[arg(num_args = 1..)] - username: Vec, + username: Vec, /// Do not ask for a password, leave it unset #[clap(long)] no_password: bool, + + /// Print the information as JSON + /// + /// Note that this implies `--no-password`, since the command will become non-interactive. + #[arg(short, long)] + json: bool, } #[derive(Parser, Debug, Clone)] pub struct UserDeleteArgs { #[arg(num_args = 1..)] - username: Vec, + username: Vec, + + /// Print the information as JSON + #[arg(short, long)] + json: bool, } #[derive(Parser, Debug, Clone)] pub struct UserPasswdArgs { - username: String, + username: MySQLUser, #[clap(short, long)] password_file: Option, + + /// Print the information as JSON + #[arg(short, long)] + json: bool, } #[derive(Parser, Debug, Clone)] pub struct UserShowArgs { #[arg(num_args = 0..)] - username: Vec, + username: Vec, - #[clap(short, long)] + /// Print the information as JSON + #[arg(short, long)] json: bool, } #[derive(Parser, Debug, Clone)] pub struct UserLockArgs { #[arg(num_args = 1..)] - username: Vec, + username: Vec, + + /// Print the information as JSON + #[arg(short, long)] + json: bool, } #[derive(Parser, Debug, Clone)] pub struct UserUnlockArgs { #[arg(num_args = 1..)] - username: Vec, + username: Vec, + + /// Print the information as JSON + #[arg(short, long)] + json: bool, } pub async fn handle_command( @@ -115,7 +140,7 @@ async fn create_users( anyhow::bail!("No usernames provided"); } - let message = Request::CreateUsers(args.username.clone()); + let message = Request::CreateUsers(args.username.to_owned()); if let Err(err) = server_connection.send(message).await { server_connection.close().await.ok(); anyhow::bail!(anyhow::Error::from(err).context("Failed to communicate with server")); @@ -126,39 +151,43 @@ async fn create_users( response => return erroneous_server_response(response), }; - print_create_users_output_status(&result); + if args.json { + print_create_users_output_status_json(&result); + } else { + print_create_users_output_status(&result); - let successfully_created_users = result - .iter() - .filter_map(|(username, result)| result.as_ref().ok().map(|_| username)) - .collect::>(); + let successfully_created_users = result + .iter() + .filter_map(|(username, result)| result.as_ref().ok().map(|_| username)) + .collect::>(); - for username in successfully_created_users { - if !args.no_password - && Confirm::new() - .with_prompt(format!( - "Do you want to set a password for user '{}'?", - username - )) - .default(false) - .interact()? - { - let password = read_password_from_stdin_with_double_check(username)?; - let message = Request::PasswdUser(username.clone(), password); + for username in successfully_created_users { + if !args.no_password + && Confirm::new() + .with_prompt(format!( + "Do you want to set a password for user '{}'?", + username + )) + .default(false) + .interact()? + { + let password = read_password_from_stdin_with_double_check(username)?; + let message = Request::PasswdUser(username.to_owned(), password); - if let Err(err) = server_connection.send(message).await { - server_connection.close().await.ok(); - anyhow::bail!(err); - } - - match server_connection.next().await { - Some(Ok(Response::PasswdUser(result))) => { - print_set_password_output_status(&result, username) + if let Err(err) = server_connection.send(message).await { + server_connection.close().await.ok(); + anyhow::bail!(err); } - response => return erroneous_server_response(response), - } - println!(); + match server_connection.next().await { + Some(Ok(Response::PasswdUser(result))) => { + print_set_password_output_status(&result, username) + } + response => return erroneous_server_response(response), + } + + println!(); + } } } @@ -175,7 +204,7 @@ async fn drop_users( anyhow::bail!("No usernames provided"); } - let message = Request::DropUsers(args.username.clone()); + let message = Request::DropUsers(args.username.to_owned()); if let Err(err) = server_connection.send(message).await { server_connection.close().await.ok(); @@ -189,12 +218,16 @@ async fn drop_users( server_connection.send(Request::Exit).await?; - print_drop_users_output_status(&result); + if args.json { + print_drop_users_output_status_json(&result); + } else { + print_drop_users_output_status(&result); + } Ok(()) } -pub fn read_password_from_stdin_with_double_check(username: &str) -> anyhow::Result { +pub fn read_password_from_stdin_with_double_check(username: &MySQLUser) -> anyhow::Result { Password::new() .with_prompt(format!("New MySQL password for user '{}'", username)) .with_confirmation( @@ -210,7 +243,7 @@ async fn passwd_user( mut server_connection: ClientToServerMessageStream, ) -> anyhow::Result<()> { // TODO: create a "user" exists check" command - let message = Request::ListUsers(Some(vec![args.username.clone()])); + let message = Request::ListUsers(Some(vec![args.username.to_owned()])); if let Err(err) = server_connection.send(message).await { server_connection.close().await.ok(); anyhow::bail!(err); @@ -240,7 +273,7 @@ async fn passwd_user( read_password_from_stdin_with_double_check(&args.username)? }; - let message = Request::PasswdUser(args.username.clone(), password); + let message = Request::PasswdUser(args.username.to_owned(), password); if let Err(err) = server_connection.send(message).await { server_connection.close().await.ok(); @@ -266,7 +299,7 @@ async fn show_users( let message = if args.username.is_empty() { Request::ListUsers(None) } else { - Request::ListUsers(Some(args.username.clone())) + Request::ListUsers(Some(args.username.to_owned())) }; if let Err(err) = server_connection.send(message).await { @@ -337,7 +370,7 @@ async fn lock_users( anyhow::bail!("No usernames provided"); } - let message = Request::LockUsers(args.username.clone()); + let message = Request::LockUsers(args.username.to_owned()); if let Err(err) = server_connection.send(message).await { server_connection.close().await.ok(); @@ -351,7 +384,11 @@ async fn lock_users( server_connection.send(Request::Exit).await?; - print_lock_users_output_status(&result); + if args.json { + print_lock_users_output_status_json(&result); + } else { + print_lock_users_output_status(&result); + } Ok(()) } @@ -364,7 +401,7 @@ async fn unlock_users( anyhow::bail!("No usernames provided"); } - let message = Request::UnlockUsers(args.username.clone()); + let message = Request::UnlockUsers(args.username.to_owned()); if let Err(err) = server_connection.send(message).await { server_connection.close().await.ok(); @@ -378,7 +415,11 @@ async fn unlock_users( server_connection.send(Request::Exit).await?; - print_unlock_users_output_status(&result); + if args.json { + print_unlock_users_output_status_json(&result); + } else { + print_unlock_users_output_status(&result); + } Ok(()) } diff --git a/src/core/common.rs b/src/core/common.rs index 2d4045c..5b30e53 100644 --- a/src/core/common.rs +++ b/src/core/common.rs @@ -54,7 +54,7 @@ impl UnixUser { Ok(UnixUser { username: libc_user.name, - groups: groups.iter().map(|g| g.name.clone()).collect(), + groups: groups.iter().map(|g| g.name.to_owned()).collect(), }) } diff --git a/src/core/database_privileges.rs b/src/core/database_privileges.rs index a92f40a..a731452 100644 --- a/src/core/database_privileges.rs +++ b/src/core/database_privileges.rs @@ -7,7 +7,10 @@ use std::{ collections::{BTreeSet, HashMap}, }; -use super::common::{rev_yn, yn}; +use super::{ + common::{rev_yn, yn}, + protocol::{MySQLDatabase, MySQLUser}, +}; use crate::server::sql::database_privilege_operations::{ DatabasePrivilegeRow, DATABASE_PRIVILEGE_FIELDS, }; @@ -35,8 +38,8 @@ pub fn diff(row1: &DatabasePrivilegeRow, row2: &DatabasePrivilegeRow) -> Databas debug_assert!(row1.db == row2.db && row1.user == row2.user); DatabasePrivilegeRowDiff { - db: row1.db.clone(), - user: row1.user.clone(), + db: row1.db.to_owned(), + user: row1.user.to_owned(), diff: DATABASE_PRIVILEGE_FIELDS .into_iter() .skip(2) @@ -70,8 +73,8 @@ pub fn parse_privilege_table_cli_arg(arg: &str) -> anyhow::Result, + database_name: Option<&MySQLDatabase>, ) -> String { let example_user = format!("{}_user", unix_user); let example_db = database_name - .unwrap_or(&format!("{}_db", unix_user)) + .unwrap_or(&format!("{}_db", unix_user).into()) .to_string(); // NOTE: `.max()`` fails when the iterator is empty. @@ -206,8 +209,8 @@ pub fn generate_editor_content_from_privilege_data( let example_line = format_privileges_line_for_editor( &DatabasePrivilegeRow { - db: example_db, - user: example_user, + db: example_db.into(), + user: example_user.into(), select_priv: true, insert_priv: true, update_priv: true, @@ -298,8 +301,8 @@ fn parse_privilege_row_from_editor(row: &str) -> PrivilegeRowParseResult { } let row = DatabasePrivilegeRow { - db: (*parts.first().unwrap()).to_owned(), - user: (*parts.get(1).unwrap()).to_owned(), + db: (*parts.first().unwrap()).into(), + user: (*parts.get(1).unwrap()).into(), select_priv: match parse_privilege_cell_from_editor( parts.get(2).unwrap(), DATABASE_PRIVILEGE_FIELDS[2], @@ -423,8 +426,8 @@ pub fn parse_privilege_data_from_editor_content( /// The `User` and `Database` are the same for both instances. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] pub struct DatabasePrivilegeRowDiff { - pub db: String, - pub user: String, + pub db: MySQLDatabase, + pub user: MySQLUser, pub diff: BTreeSet, } @@ -454,7 +457,7 @@ pub enum DatabasePrivilegesDiff { } impl DatabasePrivilegesDiff { - pub fn get_database_name(&self) -> &str { + pub fn get_database_name(&self) -> &MySQLDatabase { match self { DatabasePrivilegesDiff::New(p) => &p.db, DatabasePrivilegesDiff::Modified(p) => &p.db, @@ -462,7 +465,7 @@ impl DatabasePrivilegesDiff { } } - pub fn get_user_name(&self) -> &str { + pub fn get_user_name(&self) -> &MySQLUser { match self { DatabasePrivilegesDiff::New(p) => &p.user, DatabasePrivilegesDiff::Modified(p) => &p.user, @@ -478,34 +481,36 @@ pub fn diff_privileges( from: &[DatabasePrivilegeRow], to: &[DatabasePrivilegeRow], ) -> BTreeSet { - let from_lookup_table: HashMap<(String, String), DatabasePrivilegeRow> = HashMap::from_iter( - from.iter() - .cloned() - .map(|p| ((p.db.clone(), p.user.clone()), p)), - ); + let from_lookup_table: HashMap<(MySQLDatabase, MySQLUser), DatabasePrivilegeRow> = + HashMap::from_iter( + from.iter() + .cloned() + .map(|p| ((p.db.to_owned(), p.user.to_owned()), p)), + ); - let to_lookup_table: HashMap<(String, String), DatabasePrivilegeRow> = HashMap::from_iter( - to.iter() - .cloned() - .map(|p| ((p.db.clone(), p.user.clone()), p)), - ); + let to_lookup_table: HashMap<(MySQLDatabase, MySQLUser), DatabasePrivilegeRow> = + HashMap::from_iter( + to.iter() + .cloned() + .map(|p| ((p.db.to_owned(), p.user.to_owned()), p)), + ); let mut result = BTreeSet::new(); for p in to { - if let Some(old_p) = from_lookup_table.get(&(p.db.clone(), p.user.clone())) { + if let Some(old_p) = from_lookup_table.get(&(p.db.to_owned(), p.user.to_owned())) { let diff = diff(old_p, p); if !diff.diff.is_empty() { result.insert(DatabasePrivilegesDiff::Modified(diff)); } } else { - result.insert(DatabasePrivilegesDiff::New(p.clone())); + result.insert(DatabasePrivilegesDiff::New(p.to_owned())); } } for p in from { - if !to_lookup_table.contains_key(&(p.db.clone(), p.user.clone())) { - result.insert(DatabasePrivilegesDiff::Deleted(p.clone())); + if !to_lookup_table.contains_key(&(p.db.to_owned(), p.user.to_owned())) { + result.insert(DatabasePrivilegesDiff::Deleted(p.to_owned())); } } @@ -593,8 +598,8 @@ mod tests { assert_eq!( result.ok(), Some(DatabasePrivilegeRow { - db: "db".to_owned(), - user: "user".to_owned(), + db: "db".into(), + user: "user".into(), select_priv: true, insert_priv: true, update_priv: true, @@ -613,8 +618,8 @@ mod tests { assert_eq!( result.ok(), Some(DatabasePrivilegeRow { - db: "db".to_owned(), - user: "user".to_owned(), + db: "db".into(), + user: "user".into(), select_priv: false, insert_priv: false, update_priv: false, @@ -633,8 +638,8 @@ mod tests { assert_eq!( result.ok(), Some(DatabasePrivilegeRow { - db: "db".to_owned(), - user: "user".to_owned(), + db: "db".into(), + user: "user".into(), select_priv: true, insert_priv: true, update_priv: true, @@ -668,8 +673,8 @@ mod tests { #[test] fn test_diff_privileges() { let row_to_be_modified = DatabasePrivilegeRow { - db: "db".to_owned(), - user: "user".to_owned(), + db: "db".into(), + user: "user".into(), select_priv: true, insert_priv: true, update_priv: true, @@ -683,20 +688,20 @@ mod tests { references_priv: false, }; - let mut row_to_be_deleted = row_to_be_modified.clone(); + let mut row_to_be_deleted = row_to_be_modified.to_owned(); "user2".clone_into(&mut row_to_be_deleted.user); - let from = vec![row_to_be_modified.clone(), row_to_be_deleted.clone()]; + let from = vec![row_to_be_modified.to_owned(), row_to_be_deleted.to_owned()]; - let mut modified_row = row_to_be_modified.clone(); + let mut modified_row = row_to_be_modified.to_owned(); modified_row.select_priv = false; modified_row.insert_priv = false; modified_row.index_priv = true; - let mut new_row = row_to_be_modified.clone(); + let mut new_row = row_to_be_modified.to_owned(); "user3".clone_into(&mut new_row.user); - let to = vec![modified_row.clone(), new_row.clone()]; + let to = vec![modified_row.to_owned(), new_row.to_owned()]; let diffs = diff_privileges(&from, &to); @@ -705,8 +710,8 @@ mod tests { BTreeSet::from_iter(vec![ DatabasePrivilegesDiff::Deleted(row_to_be_deleted), DatabasePrivilegesDiff::Modified(DatabasePrivilegeRowDiff { - db: "db".to_owned(), - user: "user".to_owned(), + db: "db".into(), + user: "user".into(), diff: BTreeSet::from_iter(vec![ DatabasePrivilegeChange::YesToNo("select_priv".to_owned()), DatabasePrivilegeChange::YesToNo("insert_priv".to_owned()), @@ -722,8 +727,8 @@ mod tests { fn ensure_generated_and_parsed_editor_content_is_equal() { let permissions = vec![ DatabasePrivilegeRow { - db: "db".to_owned(), - user: "user".to_owned(), + db: "db".into(), + user: "user".into(), select_priv: true, insert_priv: true, update_priv: true, @@ -737,8 +742,8 @@ mod tests { references_priv: true, }, DatabasePrivilegeRow { - db: "db2".to_owned(), - user: "user2".to_owned(), + db: "db".into(), + user: "user".into(), select_priv: false, insert_priv: false, update_priv: false, diff --git a/src/core/protocol/request_response.rs b/src/core/protocol/request_response.rs index 2f194af..ef0d029 100644 --- a/src/core/protocol/request_response.rs +++ b/src/core/protocol/request_response.rs @@ -1,4 +1,9 @@ -use std::collections::BTreeSet; +use std::{ + collections::BTreeSet, + fmt::{Display, Formatter}, + ops::{Deref, DerefMut}, + str::FromStr, +}; use serde::{Deserialize, Serialize}; use tokio::net::UnixStream; @@ -31,21 +36,107 @@ pub fn create_client_to_server_message_stream(socket: UnixStream) -> ClientToSer tokio_serde::Framed::new(length_delimited, Bincode::default()) } +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] +pub struct MySQLUser(String); + +impl FromStr for MySQLUser { + type Err = String; + + fn from_str(s: &str) -> Result { + Ok(MySQLUser(s.to_string())) + } +} + +impl Deref for MySQLUser { + type Target = String; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for MySQLUser { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl Display for MySQLUser { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl From<&str> for MySQLUser { + fn from(s: &str) -> Self { + MySQLUser(s.to_string()) + } +} + +impl From for MySQLUser { + fn from(s: String) -> Self { + MySQLUser(s) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] +pub struct MySQLDatabase(String); + +impl FromStr for MySQLDatabase { + type Err = String; + + fn from_str(s: &str) -> Result { + Ok(MySQLDatabase(s.to_string())) + } +} + +impl Deref for MySQLDatabase { + type Target = String; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for MySQLDatabase { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl Display for MySQLDatabase { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl From<&str> for MySQLDatabase { + fn from(s: &str) -> Self { + MySQLDatabase(s.to_string()) + } +} + +impl From for MySQLDatabase { + fn from(s: String) -> Self { + MySQLDatabase(s) + } +} + #[non_exhaustive] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum Request { - CreateDatabases(Vec), - DropDatabases(Vec), - ListDatabases(Option>), - ListPrivileges(Option>), + CreateDatabases(Vec), + DropDatabases(Vec), + ListDatabases(Option>), + ListPrivileges(Option>), ModifyPrivileges(BTreeSet), - CreateUsers(Vec), - DropUsers(Vec), - PasswdUser(String, String), - ListUsers(Option>), - LockUsers(Vec), - UnlockUsers(Vec), + CreateUsers(Vec), + DropUsers(Vec), + PasswdUser(MySQLUser, String), + ListUsers(Option>), + LockUsers(Vec), + UnlockUsers(Vec), // Commit, Exit, diff --git a/src/core/protocol/server_responses.rs b/src/core/protocol/server_responses.rs index 4bb857d..c3fc061 100644 --- a/src/core/protocol/server_responses.rs +++ b/src/core/protocol/server_responses.rs @@ -3,6 +3,7 @@ use std::collections::BTreeMap; use indoc::indoc; use itertools::Itertools; use serde::{Deserialize, Serialize}; +use serde_json::json; use crate::{ core::{common::UnixUser, database_privileges::DatabasePrivilegeRowDiff}, @@ -12,6 +13,8 @@ use crate::{ }, }; +use super::{MySQLDatabase, MySQLUser}; + /// This enum is used to differentiate between database and user operations. /// Their output are very similar, but there are slight differences in the words used. #[derive(Debug, PartialEq, Eq, Clone, Copy)] @@ -21,17 +24,17 @@ pub enum DbOrUser { } impl DbOrUser { - pub fn lowercased(&self) -> String { + pub fn lowercased(&self) -> &'static str { match self { - DbOrUser::Database => "database".to_string(), - DbOrUser::User => "user".to_string(), + DbOrUser::Database => "database", + DbOrUser::User => "user", } } - pub fn capitalized(&self) -> String { + pub fn capitalized(&self) -> &'static str { match self { - DbOrUser::Database => "Database".to_string(), - DbOrUser::User => "User".to_string(), + DbOrUser::Database => "Database", + DbOrUser::User => "User", } } } @@ -72,6 +75,11 @@ impl OwnerValidationError { pub fn to_error_message(self, name: &str, db_or_user: DbOrUser) -> String { let user = UnixUser::from_enviroment(); + let UnixUser { username, groups } = user.unwrap_or(UnixUser { + username: "???".to_string(), + groups: vec![], + }); + match self { OwnerValidationError::NoMatch => format!( indoc! {r#" @@ -87,11 +95,8 @@ impl OwnerValidationError { name, db_or_user.lowercased(), db_or_user.lowercased(), - user.as_ref() - .map(|u| u.username.clone()) - .unwrap_or("???".to_string()), - user.map(|u| u.groups) - .unwrap_or_default() + username, + groups .iter() .map(|g| format!(" - {}", g)) .sorted() @@ -117,7 +122,7 @@ pub enum OwnerValidationError { StringEmpty, } -pub type CreateDatabasesOutput = BTreeMap>; +pub type CreateDatabasesOutput = BTreeMap>; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum CreateDatabaseError { SanitizationError(NameValidationError), @@ -141,8 +146,29 @@ pub fn print_create_databases_output_status(output: &CreateDatabasesOutput) { } } +pub fn print_create_databases_output_status_json(output: &CreateDatabasesOutput) { + let value = output + .iter() + .map(|(name, result)| match result { + Ok(()) => (name.to_string(), json!({ "status": "success" })), + Err(err) => ( + name.to_string(), + json!({ + "status": "error", + "error": err.to_error_message(name), + }), + ), + }) + .collect::>(); + println!( + "{}", + serde_json::to_string_pretty(&value) + .unwrap_or("Failed to serialize result to JSON".to_string()) + ); +} + impl CreateDatabaseError { - pub fn to_error_message(&self, database_name: &str) -> String { + pub fn to_error_message(&self, database_name: &MySQLDatabase) -> String { match self { CreateDatabaseError::SanitizationError(err) => { err.to_error_message(database_name, DbOrUser::Database) @@ -160,7 +186,7 @@ impl CreateDatabaseError { } } -pub type DropDatabasesOutput = BTreeMap>; +pub type DropDatabasesOutput = BTreeMap>; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum DropDatabaseError { SanitizationError(NameValidationError), @@ -173,7 +199,10 @@ pub fn print_drop_databases_output_status(output: &DropDatabasesOutput) { for (database_name, result) in output { match result { Ok(()) => { - println!("Database '{}' dropped successfully.", database_name); + println!( + "Database '{}' dropped successfully.", + database_name.as_str() + ); } Err(err) => { println!("{}", err.to_error_message(database_name)); @@ -184,8 +213,29 @@ pub fn print_drop_databases_output_status(output: &DropDatabasesOutput) { } } +pub fn print_drop_databases_output_status_json(output: &DropDatabasesOutput) { + let value = output + .iter() + .map(|(name, result)| match result { + Ok(()) => (name.to_string(), json!({ "status": "success" })), + Err(err) => ( + name.to_string(), + json!({ + "status": "error", + "error": err.to_error_message(name), + }), + ), + }) + .collect::>(); + println!( + "{}", + serde_json::to_string_pretty(&value) + .unwrap_or("Failed to serialize result to JSON".to_string()) + ); +} + impl DropDatabaseError { - pub fn to_error_message(&self, database_name: &str) -> String { + pub fn to_error_message(&self, database_name: &MySQLDatabase) -> String { match self { DropDatabaseError::SanitizationError(err) => { err.to_error_message(database_name, DbOrUser::Database) @@ -203,7 +253,7 @@ impl DropDatabaseError { } } -pub type ListDatabasesOutput = BTreeMap>; +pub type ListDatabasesOutput = BTreeMap>; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum ListDatabasesError { SanitizationError(NameValidationError), @@ -213,7 +263,7 @@ pub enum ListDatabasesError { } impl ListDatabasesError { - pub fn to_error_message(&self, database_name: &str) -> String { + pub fn to_error_message(&self, database_name: &MySQLDatabase) -> String { match self { ListDatabasesError::SanitizationError(err) => { err.to_error_message(database_name, DbOrUser::Database) @@ -250,7 +300,7 @@ impl ListAllDatabasesError { // no need to index by database name. pub type GetDatabasesPrivilegeData = - BTreeMap, GetDatabasesPrivilegeDataError>>; + BTreeMap, GetDatabasesPrivilegeDataError>>; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum GetDatabasesPrivilegeDataError { SanitizationError(NameValidationError), @@ -260,7 +310,7 @@ pub enum GetDatabasesPrivilegeDataError { } impl GetDatabasesPrivilegeDataError { - pub fn to_error_message(&self, database_name: &str) -> String { + pub fn to_error_message(&self, database_name: &MySQLDatabase) -> String { match self { GetDatabasesPrivilegeDataError::SanitizationError(err) => { err.to_error_message(database_name, DbOrUser::Database) @@ -294,7 +344,7 @@ impl GetAllDatabasesPrivilegeDataError { } pub type ModifyDatabasePrivilegesOutput = - BTreeMap<(String, String), Result<(), ModifyDatabasePrivilegesError>>; + BTreeMap<(MySQLDatabase, MySQLUser), Result<(), ModifyDatabasePrivilegesError>>; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum ModifyDatabasePrivilegesError { DatabaseSanitizationError(NameValidationError), @@ -309,8 +359,8 @@ pub enum ModifyDatabasePrivilegesError { #[allow(clippy::enum_variant_names)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum DiffDoesNotApplyError { - RowAlreadyExists(String, String), - RowDoesNotExist(String, String), + RowAlreadyExists(MySQLDatabase, MySQLUser), + RowDoesNotExist(MySQLDatabase, MySQLUser), RowPrivilegeChangeDoesNotApply(DatabasePrivilegeRowDiff, DatabasePrivilegeRow), } @@ -333,7 +383,7 @@ pub fn print_modify_database_privileges_output_status(output: &ModifyDatabasePri } impl ModifyDatabasePrivilegesError { - pub fn to_error_message(&self, database_name: &str, username: &str) -> String { + pub fn to_error_message(&self, database_name: &MySQLDatabase, username: &MySQLUser) -> String { match self { ModifyDatabasePrivilegesError::DatabaseSanitizationError(err) => { err.to_error_message(database_name, DbOrUser::Database) @@ -388,7 +438,7 @@ impl DiffDoesNotApplyError { } } -pub type CreateUsersOutput = BTreeMap>; +pub type CreateUsersOutput = BTreeMap>; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum CreateUserError { SanitizationError(NameValidationError), @@ -412,8 +462,29 @@ pub fn print_create_users_output_status(output: &CreateUsersOutput) { } } +pub fn print_create_users_output_status_json(output: &CreateUsersOutput) { + let value = output + .iter() + .map(|(name, result)| match result { + Ok(()) => (name.to_string(), json!({ "status": "success" })), + Err(err) => ( + name.to_string(), + json!({ + "status": "error", + "error": err.to_error_message(name), + }), + ), + }) + .collect::>(); + println!( + "{}", + serde_json::to_string_pretty(&value) + .unwrap_or("Failed to serialize result to JSON".to_string()) + ); +} + impl CreateUserError { - pub fn to_error_message(&self, username: &str) -> String { + pub fn to_error_message(&self, username: &MySQLUser) -> String { match self { CreateUserError::SanitizationError(err) => { err.to_error_message(username, DbOrUser::User) @@ -429,7 +500,7 @@ impl CreateUserError { } } -pub type DropUsersOutput = BTreeMap>; +pub type DropUsersOutput = BTreeMap>; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum DropUserError { SanitizationError(NameValidationError), @@ -453,8 +524,29 @@ pub fn print_drop_users_output_status(output: &DropUsersOutput) { } } +pub fn print_drop_users_output_status_json(output: &DropUsersOutput) { + let value = output + .iter() + .map(|(name, result)| match result { + Ok(()) => (name.to_string(), json!({ "status": "success" })), + Err(err) => ( + name.to_string(), + json!({ + "status": "error", + "error": err.to_error_message(name), + }), + ), + }) + .collect::>(); + println!( + "{}", + serde_json::to_string_pretty(&value) + .unwrap_or("Failed to serialize result to JSON".to_string()) + ); +} + impl DropUserError { - pub fn to_error_message(&self, username: &str) -> String { + pub fn to_error_message(&self, username: &MySQLUser) -> String { match self { DropUserError::SanitizationError(err) => err.to_error_message(username, DbOrUser::User), DropUserError::OwnershipError(err) => err.to_error_message(username, DbOrUser::User), @@ -477,7 +569,7 @@ pub enum SetPasswordError { MySqlError(String), } -pub fn print_set_password_output_status(output: &SetPasswordOutput, username: &str) { +pub fn print_set_password_output_status(output: &SetPasswordOutput, username: &MySQLUser) { match output { Ok(()) => { println!("Password for user '{}' set successfully.", username); @@ -490,7 +582,7 @@ pub fn print_set_password_output_status(output: &SetPasswordOutput, username: &s } impl SetPasswordError { - pub fn to_error_message(&self, username: &str) -> String { + pub fn to_error_message(&self, username: &MySQLUser) -> String { match self { SetPasswordError::SanitizationError(err) => { err.to_error_message(username, DbOrUser::User) @@ -506,7 +598,7 @@ impl SetPasswordError { } } -pub type LockUsersOutput = BTreeMap>; +pub type LockUsersOutput = BTreeMap>; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum LockUserError { SanitizationError(NameValidationError), @@ -531,8 +623,29 @@ pub fn print_lock_users_output_status(output: &LockUsersOutput) { } } +pub fn print_lock_users_output_status_json(output: &LockUsersOutput) { + let value = output + .iter() + .map(|(name, result)| match result { + Ok(()) => (name.to_string(), json!({ "status": "success" })), + Err(err) => ( + name.to_string(), + json!({ + "status": "error", + "error": err.to_error_message(name), + }), + ), + }) + .collect::>(); + println!( + "{}", + serde_json::to_string_pretty(&value) + .unwrap_or("Failed to serialize result to JSON".to_string()) + ); +} + impl LockUserError { - pub fn to_error_message(&self, username: &str) -> String { + pub fn to_error_message(&self, username: &MySQLUser) -> String { match self { LockUserError::SanitizationError(err) => err.to_error_message(username, DbOrUser::User), LockUserError::OwnershipError(err) => err.to_error_message(username, DbOrUser::User), @@ -549,7 +662,7 @@ impl LockUserError { } } -pub type UnlockUsersOutput = BTreeMap>; +pub type UnlockUsersOutput = BTreeMap>; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum UnlockUserError { SanitizationError(NameValidationError), @@ -574,8 +687,29 @@ pub fn print_unlock_users_output_status(output: &UnlockUsersOutput) { } } +pub fn print_unlock_users_output_status_json(output: &UnlockUsersOutput) { + let value = output + .iter() + .map(|(name, result)| match result { + Ok(()) => (name.to_string(), json!({ "status": "success" })), + Err(err) => ( + name.to_string(), + json!({ + "status": "error", + "error": err.to_error_message(name), + }), + ), + }) + .collect::>(); + println!( + "{}", + serde_json::to_string_pretty(&value) + .unwrap_or("Failed to serialize result to JSON".to_string()) + ); +} + impl UnlockUserError { - pub fn to_error_message(&self, username: &str) -> String { + pub fn to_error_message(&self, username: &MySQLUser) -> String { match self { UnlockUserError::SanitizationError(err) => { err.to_error_message(username, DbOrUser::User) @@ -594,7 +728,7 @@ impl UnlockUserError { } } -pub type ListUsersOutput = BTreeMap>; +pub type ListUsersOutput = BTreeMap>; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum ListUsersError { SanitizationError(NameValidationError), @@ -604,7 +738,7 @@ pub enum ListUsersError { } impl ListUsersError { - pub fn to_error_message(&self, username: &str) -> String { + pub fn to_error_message(&self, username: &MySQLUser) -> String { match self { ListUsersError::SanitizationError(err) => { err.to_error_message(username, DbOrUser::User) diff --git a/src/main.rs b/src/main.rs index 549072f..3c2f574 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,6 +3,7 @@ extern crate prettytable; use clap::{CommandFactory, Parser, ValueEnum}; use clap_complete::{generate, Shell}; +use clap_verbosity_flag::Verbosity; use std::path::PathBuf; @@ -62,6 +63,10 @@ struct Args { )] config: Option, + #[command(flatten)] + verbose: Verbosity, + + /// Run in TUI mode. #[cfg(feature = "tui")] #[arg(short, long, alias = "tui", global = true)] interactive: bool, @@ -103,11 +108,6 @@ enum ToplevelCommands { // comments emphasizing the need for caution. fn main() -> anyhow::Result<()> { - // TODO: find out if there are any security risks of running - // env_logger and clap with elevated privileges. - - env_logger::init(); - #[cfg(feature = "mysql-admutils-compatibility")] if handle_mysql_admutils_command()?.is_some() { return Ok(()); @@ -126,6 +126,10 @@ fn main() -> anyhow::Result<()> { let server_connection = bootstrap_server_connection_and_drop_privileges(args.server_socket_path, args.config)?; + env_logger::Builder::new() + .filter_level(args.verbose.log_level_filter()) + .init(); + tokio_run_command(args.command, server_connection)?; Ok(()) @@ -149,9 +153,10 @@ fn handle_server_command(args: &Args) -> anyhow::Result> { match args.command { Command::Server(ref command) => { tokio_start_server( - args.server_socket_path.clone(), - args.config.clone(), - command.clone(), + args.server_socket_path.to_owned(), + args.config.to_owned(), + args.verbose.to_owned(), + command.to_owned(), )?; Ok(Some(())) } @@ -188,6 +193,7 @@ fn handle_generate_completions_command(args: &Args) -> anyhow::Result fn tokio_start_server( server_socket_path: Option, config_path: Option, + verbosity: Verbosity, args: ServerArgs, ) -> anyhow::Result<()> { tokio::runtime::Builder::new_current_thread() @@ -195,7 +201,7 @@ fn tokio_start_server( .build() .unwrap() .block_on(async { - server::command::handle_command(server_socket_path, config_path, args).await + server::command::handle_command(server_socket_path, config_path, verbosity, args).await }) } diff --git a/src/server/command.rs b/src/server/command.rs index cb279ab..57e69e2 100644 --- a/src/server/command.rs +++ b/src/server/command.rs @@ -3,11 +3,16 @@ use std::path::PathBuf; use anyhow::Context; use clap::Parser; +use clap_verbosity_flag::Verbosity; +use futures::SinkExt; +use indoc::concatdoc; +use systemd_journal_logger::JournalLog; use std::os::unix::net::UnixStream as StdUnixStream; use tokio::net::UnixStream as TokioUnixStream; use crate::core::common::UnixUser; +use crate::core::protocol::{create_server_to_client_message_stream, Response}; use crate::server::config::read_config_from_path_with_arg_overrides; use crate::server::server_loop::listen_for_incoming_connections; use crate::server::{ @@ -22,6 +27,9 @@ pub struct ServerArgs { #[command(flatten)] config_overrides: ServerConfigArgs, + + #[arg(long)] + systemd: bool, } #[derive(Parser, Debug, Clone)] @@ -33,27 +41,148 @@ pub enum ServerCommand { SocketActivate, } +const LOG_LEVEL_WARNING: &str = r#" +=================================================== +== WARNING: LOG LEVEL IS SET TO 'TRACE'! == +== THIS WILL CAUSE THE SERVER TO LOG SQL QUERIES == +== THAT MAY CONTAIN SENSITIVE INFORMATION LIKE == +== PASSWORDS AND AUTHENTICATION TOKENS. == +== THIS IS INTENDED FOR DEBUGGING PURPOSES ONLY == +== AND SHOULD *NEVER* BE USED IN PRODUCTION. == +=================================================== +"#; + pub async fn handle_command( socket_path: Option, config_path: Option, + verbosity: Verbosity, args: ServerArgs, ) -> anyhow::Result<()> { + let mut auto_detected_systemd_mode = false; + let systemd_mode = args.systemd || { + if let Ok(true) = sd_notify::booted() { + auto_detected_systemd_mode = true; + true + } else { + false + } + }; + if systemd_mode { + JournalLog::new() + .context("Failed to initialize journald logging")? + .install() + .context("Failed to install journald logger")?; + + log::set_max_level(verbosity.log_level_filter()); + + if verbosity.log_level_filter() >= log::LevelFilter::Trace { + log::warn!("{}", LOG_LEVEL_WARNING.trim()); + } + + if auto_detected_systemd_mode { + log::info!("Running in systemd mode, auto-detected"); + } else { + log::info!("Running in systemd mode"); + } + + start_watchdog_thread_if_enabled(); + } else { + env_logger::Builder::new() + .filter_level(verbosity.log_level_filter()) + .init(); + + log::info!("Running in standalone mode"); + } + let config = read_config_from_path_with_arg_overrides(config_path, args.config_overrides)?; match args.subcmd { ServerCommand::Listen => listen_for_incoming_connections(socket_path, config).await, - ServerCommand::SocketActivate => socket_activate(config).await, + ServerCommand::SocketActivate => { + if !args.systemd { + anyhow::bail!(concat!( + "The `--systemd` flag must be used with the `socket-activate` command.\n", + "This command currently only supports socket activation under systemd." + )); + } + + socket_activate(config).await + } + } +} + +fn start_watchdog_thread_if_enabled() { + let mut micro_seconds: u64 = 0; + let watchdog_enabled = sd_notify::watchdog_enabled(true, &mut micro_seconds); + + if watchdog_enabled { + micro_seconds = micro_seconds.max(2_000_000).div_ceil(2); + + tokio::spawn(async move { + log::debug!( + "Starting systemd watchdog thread with {} millisecond interval", + micro_seconds.div_ceil(1000) + ); + loop { + tokio::time::sleep(tokio::time::Duration::from_micros(micro_seconds)).await; + if let Err(err) = sd_notify::notify(false, &[sd_notify::NotifyState::Watchdog]) { + log::warn!("Failed to notify systemd watchdog: {}", err); + } else { + log::trace!("Ping sent to systemd watchdog"); + } + } + }); + } else { + log::debug!("Systemd watchdog not enabled, skipping watchdog thread"); } } async fn socket_activate(config: ServerConfig) -> anyhow::Result<()> { let conn = get_socket_from_systemd().await?; - let uid = conn.peer_cred()?.uid(); - let unix_user = UnixUser::from_uid(uid)?; + + let uid = match conn.peer_cred() { + Ok(cred) => cred.uid(), + Err(e) => { + log::error!("Failed to get peer credentials from socket: {}", e); + let mut message_stream = create_server_to_client_message_stream(conn); + message_stream + .send(Response::Error( + (concatdoc! { + "Server failed to get peer credentials from socket\n", + "Please check the server logs or contact the system administrators" + }) + .to_string(), + )) + .await + .ok(); + anyhow::bail!("Failed to get peer credentials from socket"); + } + }; + + log::debug!("Accepted connection from uid {}", uid); + + let unix_user = match UnixUser::from_uid(uid) { + Ok(user) => user, + Err(e) => { + log::error!("Failed to get username from uid: {}", e); + let mut message_stream = create_server_to_client_message_stream(conn); + message_stream + .send(Response::Error( + (concatdoc! { + "Server failed to get user data from the system\n", + "Please check the server logs or contact the system administrators" + }) + .to_string(), + )) + .await + .ok(); + anyhow::bail!("Failed to get username from uid"); + } + }; log::info!("Accepted connection from {}", unix_user.username); - sd_notify::notify(true, &[sd_notify::NotifyState::Ready]).ok(); + sd_notify::notify(false, &[sd_notify::NotifyState::Ready]).ok(); handle_requests_for_single_session(conn, &unix_user, &config).await?; @@ -66,7 +195,12 @@ async fn get_socket_from_systemd() -> anyhow::Result { .next() .context("No file descriptors received from systemd")?; - log::debug!("Received file descriptor from systemd: {}", fd); + debug_assert!(fd == 3, "Unexpected file descriptor from systemd: {}", fd); + + log::debug!( + "Received file descriptor from systemd with id: '{}', assuming socket", + fd + ); let std_unix_stream = unsafe { StdUnixStream::from_raw_fd(fd) }; let socket = TokioUnixStream::from_std(std_unix_stream)?; diff --git a/src/server/config.rs b/src/server/config.rs index 3010e43..ef45353 100644 --- a/src/server/config.rs +++ b/src/server/config.rs @@ -109,22 +109,16 @@ pub fn read_config_from_path_with_arg_overrides( pub fn read_config_from_path(config_path: Option) -> anyhow::Result { let config_path = config_path.unwrap_or_else(|| PathBuf::from(DEFAULT_CONFIG_PATH)); - log::debug!("Reading config from {:?}", &config_path); + log::debug!("Reading config file at {:?}", &config_path); fs::read_to_string(&config_path) - .context(format!( - "Failed to read config file from {:?}", - &config_path - )) + .context(format!("Failed to read config file at {:?}", &config_path)) .and_then(|c| toml::from_str(&c).context("Failed to parse config file")) - .context(format!( - "Failed to parse config file from {:?}", - &config_path - )) + .context(format!("Failed to parse config file at {:?}", &config_path)) } fn log_config(config: &MysqlConfig) { - let mut display_config = config.clone(); + let mut display_config = config.to_owned(); display_config.password = display_config .password .as_ref() @@ -141,7 +135,9 @@ pub async fn create_mysql_connection_from_config( ) -> anyhow::Result { log_config(config); - let mut mysql_options = MySqlConnectOptions::new().database("mysql"); + let mut mysql_options = MySqlConnectOptions::new() + .database("mysql") + .log_statements(log::LevelFilter::Trace); if let Some(username) = &config.username { mysql_options = mysql_options.username(username); diff --git a/src/server/input_sanitization.rs b/src/server/input_sanitization.rs index bd6dd22..a16f69c 100644 --- a/src/server/input_sanitization.rs +++ b/src/server/input_sanitization.rs @@ -24,7 +24,7 @@ pub fn validate_ownership_by_unix_user( name: &str, user: &UnixUser, ) -> Result<(), OwnerValidationError> { - let prefixes = std::iter::once(user.username.clone()) + let prefixes = std::iter::once(user.username.to_owned()) .chain(user.groups.iter().cloned()) .collect::>(); diff --git a/src/server/server_loop.rs b/src/server/server_loop.rs index f1dcf1a..9e60ebf 100644 --- a/src/server/server_loop.rs +++ b/src/server/server_loop.rs @@ -7,6 +7,7 @@ use tokio::net::{UnixListener, UnixStream}; use sqlx::prelude::*; use sqlx::MySqlConnection; +use crate::core::protocol::SetPasswordError; use crate::server::sql::database_operations::list_databases; use crate::{ core::{ @@ -56,7 +57,7 @@ pub async fn listen_for_incoming_connections( let listener = UnixListener::bind(socket_path)?; - sd_notify::notify(true, &[sd_notify::NotifyState::Ready]).ok(); + sd_notify::notify(false, &[sd_notify::NotifyState::Ready]).ok(); while let Ok((conn, _addr)) = listener.accept().await { let uid = match conn.peer_cred() { @@ -78,7 +79,7 @@ pub async fn listen_for_incoming_connections( } }; - log::trace!("Accepted connection from uid {}", uid); + log::debug!("Accepted connection from uid {}", uid); let unix_user = match UnixUser::from_uid(uid) { Ok(user) => user, @@ -173,47 +174,47 @@ pub async fn handle_requests_for_single_session_with_db_connection( } }; - log::trace!("Received request: {:?}", request); + // TODO: don't clone the request + let request_to_display = match &request { + Request::PasswdUser(db_user, _) => { + Request::PasswdUser(db_user.to_owned(), "".to_string()) + } + request => request.to_owned(), + }; + log::info!("Received request: {:#?}", request_to_display); - match request { + let response = match request { Request::CreateDatabases(databases_names) => { let result = create_databases(databases_names, unix_user, db_connection).await; - stream.send(Response::CreateDatabases(result)).await?; + Response::CreateDatabases(result) } Request::DropDatabases(databases_names) => { let result = drop_databases(databases_names, unix_user, db_connection).await; - stream.send(Response::DropDatabases(result)).await?; - } - Request::ListDatabases(database_names) => { - let response = match database_names { - Some(database_names) => { - let result = list_databases(database_names, unix_user, db_connection).await; - Response::ListDatabases(result) - } - None => { - let result = list_all_databases_for_user(unix_user, db_connection).await; - Response::ListAllDatabases(result) - } - }; - stream.send(response).await?; - } - Request::ListPrivileges(database_names) => { - let response = match database_names { - Some(database_names) => { - let privilege_data = - get_databases_privilege_data(database_names, unix_user, db_connection) - .await; - Response::ListPrivileges(privilege_data) - } - None => { - let privilege_data = - get_all_database_privileges(unix_user, db_connection).await; - Response::ListAllPrivileges(privilege_data) - } - }; - - stream.send(response).await?; + Response::DropDatabases(result) } + Request::ListDatabases(database_names) => match database_names { + Some(database_names) => { + let result = list_databases(database_names, unix_user, db_connection).await; + Response::ListDatabases(result) + } + None => { + let result = list_all_databases_for_user(unix_user, db_connection).await; + Response::ListAllDatabases(result) + } + }, + Request::ListPrivileges(database_names) => match database_names { + Some(database_names) => { + let privilege_data = + get_databases_privilege_data(database_names, unix_user, db_connection) + .await; + Response::ListPrivileges(privilege_data) + } + None => { + let privilege_data = + get_all_database_privileges(unix_user, db_connection).await; + Response::ListAllPrivileges(privilege_data) + } + }, Request::ModifyPrivileges(database_privilege_diffs) => { let result = apply_privilege_diffs( BTreeSet::from_iter(database_privilege_diffs), @@ -221,50 +222,58 @@ pub async fn handle_requests_for_single_session_with_db_connection( db_connection, ) .await; - stream.send(Response::ModifyPrivileges(result)).await?; + Response::ModifyPrivileges(result) } Request::CreateUsers(db_users) => { let result = create_database_users(db_users, unix_user, db_connection).await; - stream.send(Response::CreateUsers(result)).await?; + Response::CreateUsers(result) } Request::DropUsers(db_users) => { let result = drop_database_users(db_users, unix_user, db_connection).await; - stream.send(Response::DropUsers(result)).await?; + Response::DropUsers(result) } Request::PasswdUser(db_user, password) => { let result = set_password_for_database_user(&db_user, &password, unix_user, db_connection) .await; - stream.send(Response::PasswdUser(result)).await?; - } - Request::ListUsers(db_users) => { - let response = match db_users { - Some(db_users) => { - let result = list_database_users(db_users, unix_user, db_connection).await; - Response::ListUsers(result) - } - None => { - let result = - list_all_database_users_for_unix_user(unix_user, db_connection).await; - Response::ListAllUsers(result) - } - }; - stream.send(response).await?; + Response::PasswdUser(result) } + Request::ListUsers(db_users) => match db_users { + Some(db_users) => { + let result = list_database_users(db_users, unix_user, db_connection).await; + Response::ListUsers(result) + } + None => { + let result = + list_all_database_users_for_unix_user(unix_user, db_connection).await; + Response::ListAllUsers(result) + } + }, Request::LockUsers(db_users) => { let result = lock_database_users(db_users, unix_user, db_connection).await; - stream.send(Response::LockUsers(result)).await?; + Response::LockUsers(result) } Request::UnlockUsers(db_users) => { let result = unlock_database_users(db_users, unix_user, db_connection).await; - stream.send(Response::UnlockUsers(result)).await?; + Response::UnlockUsers(result) } Request::Exit => { break; } - } + }; + // TODO: don't clone the response + let response_to_display = match &response { + Response::PasswdUser(Err(SetPasswordError::MySqlError(_))) => { + Response::PasswdUser(Err(SetPasswordError::MySqlError("".to_string()))) + } + response => response.to_owned(), + }; + log::info!("Response: {:#?}", response_to_display); + + stream.send(response).await?; stream.flush().await?; + log::debug!("Successfully processed request"); } Ok(()) diff --git a/src/server/sql/database_operations.rs b/src/server/sql/database_operations.rs index 3c87725..ddd4566 100644 --- a/src/server/sql/database_operations.rs +++ b/src/server/sql/database_operations.rs @@ -5,6 +5,7 @@ use sqlx::MySqlConnection; use serde::{Deserialize, Serialize}; +use crate::core::protocol::MySQLDatabase; use crate::{ core::{ common::UnixUser, @@ -42,7 +43,7 @@ pub(super) async fn unsafe_database_exists( } pub async fn create_databases( - database_names: Vec, + database_names: Vec, unix_user: &UnixUser, connection: &mut MySqlConnection, ) -> CreateDatabasesOutput { @@ -51,7 +52,7 @@ pub async fn create_databases( for database_name in database_names { if let Err(err) = validate_name(&database_name) { results.insert( - database_name.clone(), + database_name.to_owned(), Err(CreateDatabaseError::SanitizationError(err)), ); continue; @@ -59,7 +60,7 @@ pub async fn create_databases( if let Err(err) = validate_ownership_by_unix_user(&database_name, unix_user) { results.insert( - database_name.clone(), + database_name.to_owned(), Err(CreateDatabaseError::OwnershipError(err)), ); continue; @@ -68,14 +69,14 @@ pub async fn create_databases( match unsafe_database_exists(&database_name, &mut *connection).await { Ok(true) => { results.insert( - database_name.clone(), + database_name.to_owned(), Err(CreateDatabaseError::DatabaseAlreadyExists), ); continue; } Err(err) => { results.insert( - database_name.clone(), + database_name.to_owned(), Err(CreateDatabaseError::MySqlError(err.to_string())), ); continue; @@ -101,7 +102,7 @@ pub async fn create_databases( } pub async fn drop_databases( - database_names: Vec, + database_names: Vec, unix_user: &UnixUser, connection: &mut MySqlConnection, ) -> DropDatabasesOutput { @@ -110,7 +111,7 @@ pub async fn drop_databases( for database_name in database_names { if let Err(err) = validate_name(&database_name) { results.insert( - database_name.clone(), + database_name.to_owned(), Err(DropDatabaseError::SanitizationError(err)), ); continue; @@ -118,7 +119,7 @@ pub async fn drop_databases( if let Err(err) = validate_ownership_by_unix_user(&database_name, unix_user) { results.insert( - database_name.clone(), + database_name.to_owned(), Err(DropDatabaseError::OwnershipError(err)), ); continue; @@ -127,14 +128,14 @@ pub async fn drop_databases( match unsafe_database_exists(&database_name, &mut *connection).await { Ok(false) => { results.insert( - database_name.clone(), + database_name.to_owned(), Err(DropDatabaseError::DatabaseDoesNotExist), ); continue; } Err(err) => { results.insert( - database_name.clone(), + database_name.to_owned(), Err(DropDatabaseError::MySqlError(err.to_string())), ); continue; @@ -159,13 +160,21 @@ pub async fn drop_databases( results } -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct DatabaseRow { - pub database: String, + pub database: MySQLDatabase, +} + +impl FromRow<'_, sqlx::mysql::MySqlRow> for DatabaseRow { + fn from_row(row: &sqlx::mysql::MySqlRow) -> Result { + Ok(DatabaseRow { + database: row.try_get::("database")?.into(), + }) + } } pub async fn list_databases( - database_names: Vec, + database_names: Vec, unix_user: &UnixUser, connection: &mut MySqlConnection, ) -> ListDatabasesOutput { @@ -174,7 +183,7 @@ pub async fn list_databases( for database_name in database_names { if let Err(err) = validate_name(&database_name) { results.insert( - database_name.clone(), + database_name.to_owned(), Err(ListDatabasesError::SanitizationError(err)), ); continue; @@ -182,7 +191,7 @@ pub async fn list_databases( if let Err(err) = validate_ownership_by_unix_user(&database_name, unix_user) { results.insert( - database_name.clone(), + database_name.to_owned(), Err(ListDatabasesError::OwnershipError(err)), ); continue; @@ -195,7 +204,7 @@ pub async fn list_databases( WHERE `SCHEMA_NAME` = ? "#, ) - .bind(&database_name) + .bind(database_name.to_string()) .fetch_optional(&mut *connection) .await .map_err(|err| ListDatabasesError::MySqlError(err.to_string())) diff --git a/src/server/sql/database_privilege_operations.rs b/src/server/sql/database_privilege_operations.rs index a9d1dc5..da40e07 100644 --- a/src/server/sql/database_privilege_operations.rs +++ b/src/server/sql/database_privilege_operations.rs @@ -28,7 +28,8 @@ use crate::{ protocol::{ DiffDoesNotApplyError, GetAllDatabasesPrivilegeData, GetAllDatabasesPrivilegeDataError, GetDatabasesPrivilegeData, GetDatabasesPrivilegeDataError, - ModifyDatabasePrivilegesError, ModifyDatabasePrivilegesOutput, + ModifyDatabasePrivilegesError, ModifyDatabasePrivilegesOutput, MySQLDatabase, + MySQLUser, }, }, server::{ @@ -63,8 +64,8 @@ pub const DATABASE_PRIVILEGE_FIELDS: [&str; 13] = [ /// This struct represents the set of privileges for a single user on a single database. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] pub struct DatabasePrivilegeRow { - pub db: String, - pub user: String, + pub db: MySQLDatabase, + pub user: MySQLUser, pub select_priv: bool, pub insert_priv: bool, pub update_priv: bool, @@ -115,8 +116,8 @@ fn get_mysql_row_priv_field(row: &MySqlRow, position: usize) -> Result for DatabasePrivilegeRow { fn from_row(row: &MySqlRow) -> Result { Ok(Self { - db: try_get_with_binary_fallback(row, "Db")?, - user: try_get_with_binary_fallback(row, "User")?, + db: try_get_with_binary_fallback(row, "Db")?.into(), + user: try_get_with_binary_fallback(row, "User")?.into(), select_priv: get_mysql_row_priv_field(row, 2)?, insert_priv: get_mysql_row_priv_field(row, 3)?, update_priv: get_mysql_row_priv_field(row, 4)?, @@ -163,8 +164,8 @@ async fn unsafe_get_database_privileges( // NOTE: this function is unsafe because it does no input validation. /// Get all users + privileges for a single database-user pair. pub async fn unsafe_get_database_privileges_for_db_user_pair( - database_name: &str, - user_name: &str, + database_name: &MySQLDatabase, + user_name: &MySQLUser, connection: &mut MySqlConnection, ) -> Result, sqlx::Error> { let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!( @@ -174,8 +175,8 @@ pub async fn unsafe_get_database_privileges_for_db_user_pair( .map(|field| quote_identifier(field)) .join(","), )) - .bind(database_name) - .bind(user_name) + .bind(database_name.as_str()) + .bind(user_name.as_str()) .fetch_optional(connection) .await; @@ -192,7 +193,7 @@ pub async fn unsafe_get_database_privileges_for_db_user_pair( } pub async fn get_databases_privilege_data( - database_names: Vec, + database_names: Vec, unix_user: &UnixUser, connection: &mut MySqlConnection, ) -> GetDatabasesPrivilegeData { @@ -201,7 +202,7 @@ pub async fn get_databases_privilege_data( for database_name in database_names.iter() { if let Err(err) = validate_name(database_name) { results.insert( - database_name.clone(), + database_name.to_owned(), Err(GetDatabasesPrivilegeDataError::SanitizationError(err)), ); continue; @@ -209,7 +210,7 @@ pub async fn get_databases_privilege_data( if let Err(err) = validate_ownership_by_unix_user(database_name, unix_user) { results.insert( - database_name.clone(), + database_name.to_owned(), Err(GetDatabasesPrivilegeDataError::OwnershipError(err)), ); continue; @@ -220,7 +221,7 @@ pub async fn get_databases_privilege_data( .unwrap() { results.insert( - database_name.clone(), + database_name.to_owned(), Err(GetDatabasesPrivilegeDataError::DatabaseDoesNotExist), ); continue; @@ -230,7 +231,7 @@ pub async fn get_databases_privilege_data( .await .map_err(|e| GetDatabasesPrivilegeDataError::MySqlError(e.to_string())); - results.insert(database_name.clone(), result); + results.insert(database_name.to_owned(), result); } debug_assert!(database_names.len() == results.len()); @@ -364,8 +365,8 @@ async fn validate_diff( if privilege_row.is_some() { Err(ModifyDatabasePrivilegesError::DiffDoesNotApply( DiffDoesNotApplyError::RowAlreadyExists( - diff.get_user_name().to_string(), - diff.get_database_name().to_string(), + diff.get_database_name().to_owned(), + diff.get_user_name().to_owned(), ), )) } else { @@ -375,8 +376,8 @@ async fn validate_diff( DatabasePrivilegesDiff::Modified(_) if privilege_row.is_none() => { Err(ModifyDatabasePrivilegesError::DiffDoesNotApply( DiffDoesNotApplyError::RowDoesNotExist( - diff.get_user_name().to_string(), - diff.get_database_name().to_string(), + diff.get_database_name().to_owned(), + diff.get_user_name().to_owned(), ), )) } @@ -390,7 +391,7 @@ async fn validate_diff( if error_exists { Err(ModifyDatabasePrivilegesError::DiffDoesNotApply( - DiffDoesNotApplyError::RowPrivilegeChangeDoesNotApply(row_diff.clone(), row), + DiffDoesNotApplyError::RowPrivilegeChangeDoesNotApply(row_diff.to_owned(), row), )) } else { Ok(()) @@ -400,8 +401,8 @@ async fn validate_diff( if privilege_row.is_none() { Err(ModifyDatabasePrivilegesError::DiffDoesNotApply( DiffDoesNotApplyError::RowDoesNotExist( - diff.get_user_name().to_string(), - diff.get_database_name().to_string(), + diff.get_database_name().to_owned(), + diff.get_user_name().to_owned(), ), )) } else { @@ -419,12 +420,12 @@ pub async fn apply_privilege_diffs( unix_user: &UnixUser, connection: &mut MySqlConnection, ) -> ModifyDatabasePrivilegesOutput { - let mut results: BTreeMap<(String, String), _> = BTreeMap::new(); + let mut results: BTreeMap<(MySQLDatabase, MySQLUser), _> = BTreeMap::new(); for diff in database_privilege_diffs { let key = ( - diff.get_database_name().to_string(), - diff.get_user_name().to_string(), + diff.get_database_name().to_owned(), + diff.get_user_name().to_owned(), ); if let Err(err) = validate_name(diff.get_database_name()) { results.insert( diff --git a/src/server/sql/user_operations.rs b/src/server/sql/user_operations.rs index 120903b..468b47b 100644 --- a/src/server/sql/user_operations.rs +++ b/src/server/sql/user_operations.rs @@ -7,18 +7,17 @@ use serde::{Deserialize, Serialize}; use sqlx::prelude::*; use sqlx::MySqlConnection; -use crate::server::common::try_get_with_binary_fallback; use crate::{ core::{ common::UnixUser, protocol::{ CreateUserError, CreateUsersOutput, DropUserError, DropUsersOutput, ListAllUsersError, ListAllUsersOutput, ListUsersError, ListUsersOutput, LockUserError, LockUsersOutput, - SetPasswordError, SetPasswordOutput, UnlockUserError, UnlockUsersOutput, + MySQLUser, SetPasswordError, SetPasswordOutput, UnlockUserError, UnlockUsersOutput, }, }, server::{ - common::create_user_group_matching_regex, + common::{create_user_group_matching_regex, try_get_with_binary_fallback}, input_sanitization::{quote_literal, validate_name, validate_ownership_by_unix_user}, }, }; @@ -52,7 +51,7 @@ async fn unsafe_user_exists( } pub async fn create_database_users( - db_users: Vec, + db_users: Vec, unix_user: &UnixUser, connection: &mut MySqlConnection, ) -> CreateUsersOutput { @@ -98,7 +97,7 @@ pub async fn create_database_users( } pub async fn drop_database_users( - db_users: Vec, + db_users: Vec, unix_user: &UnixUser, connection: &mut MySqlConnection, ) -> DropUsersOutput { @@ -144,7 +143,7 @@ pub async fn drop_database_users( } pub async fn set_password_for_database_user( - db_user: &str, + db_user: &MySQLUser, password: &str, unix_user: &UnixUser, connection: &mut MySqlConnection, @@ -167,7 +166,7 @@ pub async fn set_password_for_database_user( format!( "ALTER USER {}@'%' IDENTIFIED BY {}", quote_literal(db_user), - quote_literal(password).as_str() + quote_literal(password).as_str(), ) .as_str(), ) @@ -176,11 +175,10 @@ pub async fn set_password_for_database_user( .map(|_| ()) .map_err(|err| SetPasswordError::MySqlError(err.to_string())); - if let Err(err) = &result { + if result.is_err() { log::error!( - "Failed to set password for database user '{}': {:?}", + "Failed to set password for database user '{}': ", &db_user, - err ); } @@ -220,7 +218,7 @@ async fn database_user_is_locked_unsafe( } pub async fn lock_database_users( - db_users: Vec, + db_users: Vec, unix_user: &UnixUser, connection: &mut MySqlConnection, ) -> LockUsersOutput { @@ -280,7 +278,7 @@ pub async fn lock_database_users( } pub async fn unlock_database_users( - db_users: Vec, + db_users: Vec, unix_user: &UnixUser, connection: &mut MySqlConnection, ) -> UnlockUsersOutput { @@ -343,7 +341,7 @@ pub async fn unlock_database_users( /// This can be extended if we need more information in the future. #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct DatabaseUser { - pub user: String, + pub user: MySQLUser, #[serde(skip)] pub host: String, pub has_password: bool, @@ -354,7 +352,7 @@ pub struct DatabaseUser { impl FromRow<'_, sqlx::mysql::MySqlRow> for DatabaseUser { fn from_row(row: &sqlx::mysql::MySqlRow) -> Result { Ok(Self { - user: try_get_with_binary_fallback(row, "User")?, + user: try_get_with_binary_fallback(row, "User")?.into(), host: try_get_with_binary_fallback(row, "Host")?, has_password: row.try_get("has_password")?, is_locked: row.try_get("is_locked")?, @@ -379,7 +377,7 @@ JOIN `global_priv` ON "#; pub async fn list_database_users( - db_users: Vec, + db_users: Vec, unix_user: &UnixUser, connection: &mut MySqlConnection, ) -> ListUsersOutput { @@ -399,7 +397,7 @@ pub async fn list_database_users( let mut result = sqlx::query_as::<_, DatabaseUser>( &(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `mysql`.`user`.`User` = ?"), ) - .bind(&db_user) + .bind(db_user.as_str()) .fetch_optional(&mut *connection) .await; @@ -464,7 +462,7 @@ pub async fn append_databases_where_user_has_privileges( ) .as_str(), ) - .bind(db_user.user.clone()) + .bind(db_user.user.as_str()) .fetch_all(&mut *connection) .await;