diff --git a/Cargo.lock b/Cargo.lock index d5e3f10..d30af81 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -253,6 +253,16 @@ dependencies = [ "clap_derive", ] +[[package]] +name = "clap-verbosity-flag" +version = "2.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63d19864d6b68464c59f7162c9914a0b569ddc2926b4a2d71afe62a9738eff53" +dependencies = [ + "clap", + "log", +] + [[package]] name = "clap_builder" version = "4.5.15" @@ -993,6 +1003,9 @@ name = "log" version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +dependencies = [ + "value-bag", +] [[package]] name = "lru" @@ -1064,6 +1077,7 @@ dependencies = [ "async-bincode", "bincode", "clap", + "clap-verbosity-flag", "clap_complete", "derive_more", "dialoguer", @@ -1082,6 +1096,7 @@ dependencies = [ "serde", "serde_json", "sqlx", + "systemd-journal-logger", "tokio", "tokio-serde", "tokio-stream", @@ -1996,6 +2011,16 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "systemd-journal-logger" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5f3848dd723f2a54ac1d96da793b32923b52de8dfcced8722516dac312a5b2a" +dependencies = [ + "log", + "rustix", +] + [[package]] name = "tempfile" version = "3.12.0" @@ -2287,6 +2312,12 @@ dependencies = [ "getrandom", ] +[[package]] +name = "value-bag" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a84c137d37ab0142f0f2ddfe332651fdbf252e7b7dbb4e67b6c1f1b2e925101" + [[package]] name = "vcpkg" version = "0.2.15" diff --git a/Cargo.toml b/Cargo.toml index faa78e2..e94df3f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,7 @@ anyhow = "1.0.86" async-bincode = "0.7.2" bincode = "1.3.3" clap = { version = "4.5.16", features = ["derive"] } +clap-verbosity-flag = "2.2.1" clap_complete = "4.5.18" derive_more = { version = "1.0.0", features = ["display", "error"] } dialoguer = "0.11.0" @@ -25,6 +26,7 @@ sd-notify = "0.4.2" serde = "1.0.208" serde_json = { version = "1.0.125", features = ["preserve_order"] } sqlx = { version = "0.8.0", features = ["runtime-tokio", "mysql", "tls-rustls"] } +systemd-journal-logger = "2.1.1" tokio = { version = "1.39.3", features = ["rt", "macros"] } tokio-serde = { version = "0.9.0", features = ["bincode"] } tokio-stream = "0.1.15" diff --git a/nix/module.nix b/nix/module.nix index dd39f4f..514753c 100644 --- a/nix/module.nix +++ b/nix/module.nix @@ -15,6 +15,20 @@ in description = "Create a local database user for mysqladm-rs"; }; + logLevel = lib.mkOption { + type = lib.types.enum [ "quiet" "error" "warn" "info" "debug" "trace" ]; + default = "debug"; + description = "Log level for mysqladm-rs"; + apply = level: { + "quiet" = "-q"; + "error" = ""; + "warn" = "-v"; + "info" = "-vv"; + "debug" = "-vvv"; + "trace" = "-vvvv"; + }.${level}; + }; + settings = lib.mkOption { default = { }; type = lib.types.submodule { @@ -76,7 +90,7 @@ in name = cfg.settings.mysql.username; ensurePermissions = { "mysql.*" = "SELECT, INSERT, UPDATE, DELETE"; - "*.*" = "CREATE USER, GRANT OPTION"; + "*.*" = "GRANT OPTION, CREATE, DROP"; }; } ]; @@ -86,7 +100,9 @@ in environment.RUST_LOG = "debug"; serviceConfig = { Type = "notify"; - ExecStart = "${lib.getExe cfg.package} server socket-activate --config ${configFile}"; + ExecStart = "${lib.getExe cfg.package} ${cfg.logLevel} server --systemd socket-activate --config ${configFile}"; + + WatchdogSec = 15; User = "mysqladm"; Group = "mysqladm"; @@ -95,7 +111,18 @@ in # This is required to read unix user/group details. PrivateUsers = false; - CapabilityBoundingSet = ""; + # Needed to communicate with MySQL. + PrivateNetwork = false; + + IPAddressDeny = + lib.optionals (lib.elem cfg.settings.mysql.host [ null "localhost" "127.0.0.1" ]) [ "any" ]; + + RestrictAddressFamilies = [ "AF_UNIX" ] + ++ (lib.optionals (cfg.settings.mysql.host != null) [ "AF_INET" "AF_INET6" ]); + + AmbientCapabilities = [ "" ]; + CapabilityBoundingSet = [ "" ]; + DeviceAllow = [ "" ]; LockPersonality = true; MemoryDenyWriteExecute = true; NoNewPrivileges = true; @@ -113,12 +140,12 @@ in ProtectProc = "invisible"; ProtectSystem = "strict"; RemoveIPC = true; - UMask = "0000"; - RestrictAddressFamilies = [ "AF_UNIX" "AF_INET" "AF_INET6" ]; + UMask = "0777"; RestrictNamespaces = true; RestrictRealtime = true; RestrictSUIDSGID = true; SystemCallArchitectures = "native"; + SocketBindDeny = [ "any" ]; SystemCallFilter = [ "@system-service" "~@privileged" diff --git a/src/cli/database_command.rs b/src/cli/database_command.rs index 3d2f9ab..4435257 100644 --- a/src/cli/database_command.rs +++ b/src/cli/database_command.rs @@ -393,6 +393,7 @@ pub async fn edit_database_privileges( if diffs.is_empty() { println!("No changes to make."); + server_connection.send(Request::Exit).await?; return Ok(()); } diff --git a/src/main.rs b/src/main.rs index 549072f..02671b8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,6 +3,7 @@ extern crate prettytable; use clap::{CommandFactory, Parser, ValueEnum}; use clap_complete::{generate, Shell}; +use clap_verbosity_flag::Verbosity; use std::path::PathBuf; @@ -62,6 +63,10 @@ struct Args { )] config: Option, + #[command(flatten)] + verbose: Verbosity, + + /// Run in TUI mode. #[cfg(feature = "tui")] #[arg(short, long, alias = "tui", global = true)] interactive: bool, @@ -103,11 +108,6 @@ enum ToplevelCommands { // comments emphasizing the need for caution. fn main() -> anyhow::Result<()> { - // TODO: find out if there are any security risks of running - // env_logger and clap with elevated privileges. - - env_logger::init(); - #[cfg(feature = "mysql-admutils-compatibility")] if handle_mysql_admutils_command()?.is_some() { return Ok(()); @@ -126,6 +126,10 @@ fn main() -> anyhow::Result<()> { let server_connection = bootstrap_server_connection_and_drop_privileges(args.server_socket_path, args.config)?; + env_logger::Builder::new() + .filter_level(args.verbose.log_level_filter()) + .init(); + tokio_run_command(args.command, server_connection)?; Ok(()) @@ -151,6 +155,7 @@ fn handle_server_command(args: &Args) -> anyhow::Result> { tokio_start_server( args.server_socket_path.clone(), args.config.clone(), + args.verbose.clone(), command.clone(), )?; Ok(Some(())) @@ -188,6 +193,7 @@ fn handle_generate_completions_command(args: &Args) -> anyhow::Result fn tokio_start_server( server_socket_path: Option, config_path: Option, + verbosity: Verbosity, args: ServerArgs, ) -> anyhow::Result<()> { tokio::runtime::Builder::new_current_thread() @@ -195,7 +201,7 @@ fn tokio_start_server( .build() .unwrap() .block_on(async { - server::command::handle_command(server_socket_path, config_path, args).await + server::command::handle_command(server_socket_path, config_path, verbosity, args).await }) } diff --git a/src/server/command.rs b/src/server/command.rs index cb279ab..57e69e2 100644 --- a/src/server/command.rs +++ b/src/server/command.rs @@ -3,11 +3,16 @@ use std::path::PathBuf; use anyhow::Context; use clap::Parser; +use clap_verbosity_flag::Verbosity; +use futures::SinkExt; +use indoc::concatdoc; +use systemd_journal_logger::JournalLog; use std::os::unix::net::UnixStream as StdUnixStream; use tokio::net::UnixStream as TokioUnixStream; use crate::core::common::UnixUser; +use crate::core::protocol::{create_server_to_client_message_stream, Response}; use crate::server::config::read_config_from_path_with_arg_overrides; use crate::server::server_loop::listen_for_incoming_connections; use crate::server::{ @@ -22,6 +27,9 @@ pub struct ServerArgs { #[command(flatten)] config_overrides: ServerConfigArgs, + + #[arg(long)] + systemd: bool, } #[derive(Parser, Debug, Clone)] @@ -33,27 +41,148 @@ pub enum ServerCommand { SocketActivate, } +const LOG_LEVEL_WARNING: &str = r#" +=================================================== +== WARNING: LOG LEVEL IS SET TO 'TRACE'! == +== THIS WILL CAUSE THE SERVER TO LOG SQL QUERIES == +== THAT MAY CONTAIN SENSITIVE INFORMATION LIKE == +== PASSWORDS AND AUTHENTICATION TOKENS. == +== THIS IS INTENDED FOR DEBUGGING PURPOSES ONLY == +== AND SHOULD *NEVER* BE USED IN PRODUCTION. == +=================================================== +"#; + pub async fn handle_command( socket_path: Option, config_path: Option, + verbosity: Verbosity, args: ServerArgs, ) -> anyhow::Result<()> { + let mut auto_detected_systemd_mode = false; + let systemd_mode = args.systemd || { + if let Ok(true) = sd_notify::booted() { + auto_detected_systemd_mode = true; + true + } else { + false + } + }; + if systemd_mode { + JournalLog::new() + .context("Failed to initialize journald logging")? + .install() + .context("Failed to install journald logger")?; + + log::set_max_level(verbosity.log_level_filter()); + + if verbosity.log_level_filter() >= log::LevelFilter::Trace { + log::warn!("{}", LOG_LEVEL_WARNING.trim()); + } + + if auto_detected_systemd_mode { + log::info!("Running in systemd mode, auto-detected"); + } else { + log::info!("Running in systemd mode"); + } + + start_watchdog_thread_if_enabled(); + } else { + env_logger::Builder::new() + .filter_level(verbosity.log_level_filter()) + .init(); + + log::info!("Running in standalone mode"); + } + let config = read_config_from_path_with_arg_overrides(config_path, args.config_overrides)?; match args.subcmd { ServerCommand::Listen => listen_for_incoming_connections(socket_path, config).await, - ServerCommand::SocketActivate => socket_activate(config).await, + ServerCommand::SocketActivate => { + if !args.systemd { + anyhow::bail!(concat!( + "The `--systemd` flag must be used with the `socket-activate` command.\n", + "This command currently only supports socket activation under systemd." + )); + } + + socket_activate(config).await + } + } +} + +fn start_watchdog_thread_if_enabled() { + let mut micro_seconds: u64 = 0; + let watchdog_enabled = sd_notify::watchdog_enabled(true, &mut micro_seconds); + + if watchdog_enabled { + micro_seconds = micro_seconds.max(2_000_000).div_ceil(2); + + tokio::spawn(async move { + log::debug!( + "Starting systemd watchdog thread with {} millisecond interval", + micro_seconds.div_ceil(1000) + ); + loop { + tokio::time::sleep(tokio::time::Duration::from_micros(micro_seconds)).await; + if let Err(err) = sd_notify::notify(false, &[sd_notify::NotifyState::Watchdog]) { + log::warn!("Failed to notify systemd watchdog: {}", err); + } else { + log::trace!("Ping sent to systemd watchdog"); + } + } + }); + } else { + log::debug!("Systemd watchdog not enabled, skipping watchdog thread"); } } async fn socket_activate(config: ServerConfig) -> anyhow::Result<()> { let conn = get_socket_from_systemd().await?; - let uid = conn.peer_cred()?.uid(); - let unix_user = UnixUser::from_uid(uid)?; + + let uid = match conn.peer_cred() { + Ok(cred) => cred.uid(), + Err(e) => { + log::error!("Failed to get peer credentials from socket: {}", e); + let mut message_stream = create_server_to_client_message_stream(conn); + message_stream + .send(Response::Error( + (concatdoc! { + "Server failed to get peer credentials from socket\n", + "Please check the server logs or contact the system administrators" + }) + .to_string(), + )) + .await + .ok(); + anyhow::bail!("Failed to get peer credentials from socket"); + } + }; + + log::debug!("Accepted connection from uid {}", uid); + + let unix_user = match UnixUser::from_uid(uid) { + Ok(user) => user, + Err(e) => { + log::error!("Failed to get username from uid: {}", e); + let mut message_stream = create_server_to_client_message_stream(conn); + message_stream + .send(Response::Error( + (concatdoc! { + "Server failed to get user data from the system\n", + "Please check the server logs or contact the system administrators" + }) + .to_string(), + )) + .await + .ok(); + anyhow::bail!("Failed to get username from uid"); + } + }; log::info!("Accepted connection from {}", unix_user.username); - sd_notify::notify(true, &[sd_notify::NotifyState::Ready]).ok(); + sd_notify::notify(false, &[sd_notify::NotifyState::Ready]).ok(); handle_requests_for_single_session(conn, &unix_user, &config).await?; @@ -66,7 +195,12 @@ async fn get_socket_from_systemd() -> anyhow::Result { .next() .context("No file descriptors received from systemd")?; - log::debug!("Received file descriptor from systemd: {}", fd); + debug_assert!(fd == 3, "Unexpected file descriptor from systemd: {}", fd); + + log::debug!( + "Received file descriptor from systemd with id: '{}', assuming socket", + fd + ); let std_unix_stream = unsafe { StdUnixStream::from_raw_fd(fd) }; let socket = TokioUnixStream::from_std(std_unix_stream)?; diff --git a/src/server/config.rs b/src/server/config.rs index 3010e43..e973200 100644 --- a/src/server/config.rs +++ b/src/server/config.rs @@ -109,18 +109,12 @@ pub fn read_config_from_path_with_arg_overrides( pub fn read_config_from_path(config_path: Option) -> anyhow::Result { let config_path = config_path.unwrap_or_else(|| PathBuf::from(DEFAULT_CONFIG_PATH)); - log::debug!("Reading config from {:?}", &config_path); + log::debug!("Reading config file at {:?}", &config_path); fs::read_to_string(&config_path) - .context(format!( - "Failed to read config file from {:?}", - &config_path - )) + .context(format!("Failed to read config file at {:?}", &config_path)) .and_then(|c| toml::from_str(&c).context("Failed to parse config file")) - .context(format!( - "Failed to parse config file from {:?}", - &config_path - )) + .context(format!("Failed to parse config file at {:?}", &config_path)) } fn log_config(config: &MysqlConfig) { @@ -141,7 +135,9 @@ pub async fn create_mysql_connection_from_config( ) -> anyhow::Result { log_config(config); - let mut mysql_options = MySqlConnectOptions::new().database("mysql"); + let mut mysql_options = MySqlConnectOptions::new() + .database("mysql") + .log_statements(log::LevelFilter::Trace); if let Some(username) = &config.username { mysql_options = mysql_options.username(username); diff --git a/src/server/server_loop.rs b/src/server/server_loop.rs index f1dcf1a..9e60ebf 100644 --- a/src/server/server_loop.rs +++ b/src/server/server_loop.rs @@ -7,6 +7,7 @@ use tokio::net::{UnixListener, UnixStream}; use sqlx::prelude::*; use sqlx::MySqlConnection; +use crate::core::protocol::SetPasswordError; use crate::server::sql::database_operations::list_databases; use crate::{ core::{ @@ -56,7 +57,7 @@ pub async fn listen_for_incoming_connections( let listener = UnixListener::bind(socket_path)?; - sd_notify::notify(true, &[sd_notify::NotifyState::Ready]).ok(); + sd_notify::notify(false, &[sd_notify::NotifyState::Ready]).ok(); while let Ok((conn, _addr)) = listener.accept().await { let uid = match conn.peer_cred() { @@ -78,7 +79,7 @@ pub async fn listen_for_incoming_connections( } }; - log::trace!("Accepted connection from uid {}", uid); + log::debug!("Accepted connection from uid {}", uid); let unix_user = match UnixUser::from_uid(uid) { Ok(user) => user, @@ -173,47 +174,47 @@ pub async fn handle_requests_for_single_session_with_db_connection( } }; - log::trace!("Received request: {:?}", request); + // TODO: don't clone the request + let request_to_display = match &request { + Request::PasswdUser(db_user, _) => { + Request::PasswdUser(db_user.to_owned(), "".to_string()) + } + request => request.to_owned(), + }; + log::info!("Received request: {:#?}", request_to_display); - match request { + let response = match request { Request::CreateDatabases(databases_names) => { let result = create_databases(databases_names, unix_user, db_connection).await; - stream.send(Response::CreateDatabases(result)).await?; + Response::CreateDatabases(result) } Request::DropDatabases(databases_names) => { let result = drop_databases(databases_names, unix_user, db_connection).await; - stream.send(Response::DropDatabases(result)).await?; - } - Request::ListDatabases(database_names) => { - let response = match database_names { - Some(database_names) => { - let result = list_databases(database_names, unix_user, db_connection).await; - Response::ListDatabases(result) - } - None => { - let result = list_all_databases_for_user(unix_user, db_connection).await; - Response::ListAllDatabases(result) - } - }; - stream.send(response).await?; - } - Request::ListPrivileges(database_names) => { - let response = match database_names { - Some(database_names) => { - let privilege_data = - get_databases_privilege_data(database_names, unix_user, db_connection) - .await; - Response::ListPrivileges(privilege_data) - } - None => { - let privilege_data = - get_all_database_privileges(unix_user, db_connection).await; - Response::ListAllPrivileges(privilege_data) - } - }; - - stream.send(response).await?; + Response::DropDatabases(result) } + Request::ListDatabases(database_names) => match database_names { + Some(database_names) => { + let result = list_databases(database_names, unix_user, db_connection).await; + Response::ListDatabases(result) + } + None => { + let result = list_all_databases_for_user(unix_user, db_connection).await; + Response::ListAllDatabases(result) + } + }, + Request::ListPrivileges(database_names) => match database_names { + Some(database_names) => { + let privilege_data = + get_databases_privilege_data(database_names, unix_user, db_connection) + .await; + Response::ListPrivileges(privilege_data) + } + None => { + let privilege_data = + get_all_database_privileges(unix_user, db_connection).await; + Response::ListAllPrivileges(privilege_data) + } + }, Request::ModifyPrivileges(database_privilege_diffs) => { let result = apply_privilege_diffs( BTreeSet::from_iter(database_privilege_diffs), @@ -221,50 +222,58 @@ pub async fn handle_requests_for_single_session_with_db_connection( db_connection, ) .await; - stream.send(Response::ModifyPrivileges(result)).await?; + Response::ModifyPrivileges(result) } Request::CreateUsers(db_users) => { let result = create_database_users(db_users, unix_user, db_connection).await; - stream.send(Response::CreateUsers(result)).await?; + Response::CreateUsers(result) } Request::DropUsers(db_users) => { let result = drop_database_users(db_users, unix_user, db_connection).await; - stream.send(Response::DropUsers(result)).await?; + Response::DropUsers(result) } Request::PasswdUser(db_user, password) => { let result = set_password_for_database_user(&db_user, &password, unix_user, db_connection) .await; - stream.send(Response::PasswdUser(result)).await?; - } - Request::ListUsers(db_users) => { - let response = match db_users { - Some(db_users) => { - let result = list_database_users(db_users, unix_user, db_connection).await; - Response::ListUsers(result) - } - None => { - let result = - list_all_database_users_for_unix_user(unix_user, db_connection).await; - Response::ListAllUsers(result) - } - }; - stream.send(response).await?; + Response::PasswdUser(result) } + Request::ListUsers(db_users) => match db_users { + Some(db_users) => { + let result = list_database_users(db_users, unix_user, db_connection).await; + Response::ListUsers(result) + } + None => { + let result = + list_all_database_users_for_unix_user(unix_user, db_connection).await; + Response::ListAllUsers(result) + } + }, Request::LockUsers(db_users) => { let result = lock_database_users(db_users, unix_user, db_connection).await; - stream.send(Response::LockUsers(result)).await?; + Response::LockUsers(result) } Request::UnlockUsers(db_users) => { let result = unlock_database_users(db_users, unix_user, db_connection).await; - stream.send(Response::UnlockUsers(result)).await?; + Response::UnlockUsers(result) } Request::Exit => { break; } - } + }; + // TODO: don't clone the response + let response_to_display = match &response { + Response::PasswdUser(Err(SetPasswordError::MySqlError(_))) => { + Response::PasswdUser(Err(SetPasswordError::MySqlError("".to_string()))) + } + response => response.to_owned(), + }; + log::info!("Response: {:#?}", response_to_display); + + stream.send(response).await?; stream.flush().await?; + log::debug!("Successfully processed request"); } Ok(()) diff --git a/src/server/sql/user_operations.rs b/src/server/sql/user_operations.rs index 120903b..a003ffb 100644 --- a/src/server/sql/user_operations.rs +++ b/src/server/sql/user_operations.rs @@ -167,7 +167,7 @@ pub async fn set_password_for_database_user( format!( "ALTER USER {}@'%' IDENTIFIED BY {}", quote_literal(db_user), - quote_literal(password).as_str() + quote_literal(password).as_str(), ) .as_str(), ) @@ -176,11 +176,10 @@ pub async fn set_password_for_database_user( .map(|_| ()) .map_err(|err| SetPasswordError::MySqlError(err.to_string())); - if let Err(err) = &result { + if result.is_err() { log::error!( - "Failed to set password for database user '{}': {:?}", + "Failed to set password for database user '{}': ", &db_user, - err ); }