Integrate better with systemd + better logs and protocol usage
This commits adds the following: - Better systemd integration and usage: - More hardening - A watchdog thread - Journald native logging as well as - Better logs - Some protocol usage fixes
This commit is contained in:
		
							
								
								
									
										31
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										31
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							@@ -253,6 +253,16 @@ dependencies = [
 | 
			
		||||
 "clap_derive",
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
[[package]]
 | 
			
		||||
name = "clap-verbosity-flag"
 | 
			
		||||
version = "2.2.1"
 | 
			
		||||
source = "registry+https://github.com/rust-lang/crates.io-index"
 | 
			
		||||
checksum = "63d19864d6b68464c59f7162c9914a0b569ddc2926b4a2d71afe62a9738eff53"
 | 
			
		||||
dependencies = [
 | 
			
		||||
 "clap",
 | 
			
		||||
 "log",
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
[[package]]
 | 
			
		||||
name = "clap_builder"
 | 
			
		||||
version = "4.5.15"
 | 
			
		||||
@@ -993,6 +1003,9 @@ name = "log"
 | 
			
		||||
version = "0.4.22"
 | 
			
		||||
source = "registry+https://github.com/rust-lang/crates.io-index"
 | 
			
		||||
checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24"
 | 
			
		||||
dependencies = [
 | 
			
		||||
 "value-bag",
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
[[package]]
 | 
			
		||||
name = "lru"
 | 
			
		||||
@@ -1064,6 +1077,7 @@ dependencies = [
 | 
			
		||||
 "async-bincode",
 | 
			
		||||
 "bincode",
 | 
			
		||||
 "clap",
 | 
			
		||||
 "clap-verbosity-flag",
 | 
			
		||||
 "clap_complete",
 | 
			
		||||
 "derive_more",
 | 
			
		||||
 "dialoguer",
 | 
			
		||||
@@ -1082,6 +1096,7 @@ dependencies = [
 | 
			
		||||
 "serde",
 | 
			
		||||
 "serde_json",
 | 
			
		||||
 "sqlx",
 | 
			
		||||
 "systemd-journal-logger",
 | 
			
		||||
 "tokio",
 | 
			
		||||
 "tokio-serde",
 | 
			
		||||
 "tokio-stream",
 | 
			
		||||
@@ -1996,6 +2011,16 @@ dependencies = [
 | 
			
		||||
 "unicode-ident",
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
[[package]]
 | 
			
		||||
name = "systemd-journal-logger"
 | 
			
		||||
version = "2.1.1"
 | 
			
		||||
source = "registry+https://github.com/rust-lang/crates.io-index"
 | 
			
		||||
checksum = "b5f3848dd723f2a54ac1d96da793b32923b52de8dfcced8722516dac312a5b2a"
 | 
			
		||||
dependencies = [
 | 
			
		||||
 "log",
 | 
			
		||||
 "rustix",
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
[[package]]
 | 
			
		||||
name = "tempfile"
 | 
			
		||||
version = "3.12.0"
 | 
			
		||||
@@ -2287,6 +2312,12 @@ dependencies = [
 | 
			
		||||
 "getrandom",
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
[[package]]
 | 
			
		||||
name = "value-bag"
 | 
			
		||||
version = "1.9.0"
 | 
			
		||||
source = "registry+https://github.com/rust-lang/crates.io-index"
 | 
			
		||||
checksum = "5a84c137d37ab0142f0f2ddfe332651fdbf252e7b7dbb4e67b6c1f1b2e925101"
 | 
			
		||||
 | 
			
		||||
[[package]]
 | 
			
		||||
name = "vcpkg"
 | 
			
		||||
version = "0.2.15"
 | 
			
		||||
 
 | 
			
		||||
@@ -8,6 +8,7 @@ anyhow = "1.0.86"
 | 
			
		||||
async-bincode = "0.7.2"
 | 
			
		||||
bincode = "1.3.3"
 | 
			
		||||
clap = { version = "4.5.16", features = ["derive"] }
 | 
			
		||||
clap-verbosity-flag = "2.2.1"
 | 
			
		||||
clap_complete = "4.5.18"
 | 
			
		||||
derive_more = { version = "1.0.0", features = ["display", "error"] }
 | 
			
		||||
dialoguer = "0.11.0"
 | 
			
		||||
@@ -25,6 +26,7 @@ sd-notify = "0.4.2"
 | 
			
		||||
serde = "1.0.208"
 | 
			
		||||
serde_json = { version = "1.0.125", features = ["preserve_order"] }
 | 
			
		||||
sqlx = { version = "0.8.0", features = ["runtime-tokio", "mysql", "tls-rustls"] }
 | 
			
		||||
systemd-journal-logger = "2.1.1"
 | 
			
		||||
tokio = { version = "1.39.3", features = ["rt", "macros"] }
 | 
			
		||||
tokio-serde = { version = "0.9.0", features = ["bincode"] }
 | 
			
		||||
tokio-stream = "0.1.15"
 | 
			
		||||
 
 | 
			
		||||
@@ -15,6 +15,20 @@ in
 | 
			
		||||
      description = "Create a local database user for mysqladm-rs";
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    logLevel = lib.mkOption {
 | 
			
		||||
      type = lib.types.enum [ "quiet" "error" "warn" "info" "debug" "trace" ];
 | 
			
		||||
      default = "debug";
 | 
			
		||||
      description = "Log level for mysqladm-rs";
 | 
			
		||||
      apply = level: {
 | 
			
		||||
        "quiet" = "-q";
 | 
			
		||||
        "error" = "";
 | 
			
		||||
        "warn" = "-v";
 | 
			
		||||
        "info" = "-vv";
 | 
			
		||||
        "debug" = "-vvv";
 | 
			
		||||
        "trace" = "-vvvv";
 | 
			
		||||
      }.${level};
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    settings = lib.mkOption {
 | 
			
		||||
      default = { };
 | 
			
		||||
      type = lib.types.submodule {
 | 
			
		||||
@@ -76,7 +90,7 @@ in
 | 
			
		||||
        name = cfg.settings.mysql.username;
 | 
			
		||||
        ensurePermissions = {
 | 
			
		||||
          "mysql.*" = "SELECT, INSERT, UPDATE, DELETE";
 | 
			
		||||
          "*.*" = "CREATE USER, GRANT OPTION";
 | 
			
		||||
          "*.*" = "GRANT OPTION, CREATE, DROP";
 | 
			
		||||
        };
 | 
			
		||||
      }
 | 
			
		||||
    ];
 | 
			
		||||
@@ -86,7 +100,9 @@ in
 | 
			
		||||
      environment.RUST_LOG = "debug";
 | 
			
		||||
      serviceConfig = {
 | 
			
		||||
        Type = "notify";
 | 
			
		||||
        ExecStart = "${lib.getExe cfg.package} server socket-activate --config ${configFile}";
 | 
			
		||||
        ExecStart = "${lib.getExe cfg.package} ${cfg.logLevel} server --systemd socket-activate --config ${configFile}";
 | 
			
		||||
 | 
			
		||||
        WatchdogSec = 15;
 | 
			
		||||
 | 
			
		||||
        User = "mysqladm";
 | 
			
		||||
        Group = "mysqladm";
 | 
			
		||||
@@ -95,7 +111,18 @@ in
 | 
			
		||||
        # This is required to read unix user/group details.
 | 
			
		||||
        PrivateUsers = false;
 | 
			
		||||
 | 
			
		||||
        CapabilityBoundingSet = "";
 | 
			
		||||
        # Needed to communicate with MySQL.
 | 
			
		||||
        PrivateNetwork = false;
 | 
			
		||||
 | 
			
		||||
        IPAddressDeny =
 | 
			
		||||
          lib.optionals (lib.elem cfg.settings.mysql.host [ null "localhost" "127.0.0.1" ]) [ "any" ];
 | 
			
		||||
 | 
			
		||||
        RestrictAddressFamilies = [ "AF_UNIX" ]
 | 
			
		||||
          ++ (lib.optionals (cfg.settings.mysql.host != null) [ "AF_INET" "AF_INET6" ]);
 | 
			
		||||
 | 
			
		||||
        AmbientCapabilities = [ "" ];
 | 
			
		||||
        CapabilityBoundingSet = [ "" ];
 | 
			
		||||
        DeviceAllow = [ "" ];
 | 
			
		||||
        LockPersonality = true;
 | 
			
		||||
        MemoryDenyWriteExecute = true;
 | 
			
		||||
        NoNewPrivileges = true;
 | 
			
		||||
@@ -113,12 +140,12 @@ in
 | 
			
		||||
        ProtectProc = "invisible";
 | 
			
		||||
        ProtectSystem = "strict";
 | 
			
		||||
        RemoveIPC = true;
 | 
			
		||||
        UMask = "0000";
 | 
			
		||||
        RestrictAddressFamilies = [ "AF_UNIX" "AF_INET" "AF_INET6" ];
 | 
			
		||||
        UMask = "0777";
 | 
			
		||||
        RestrictNamespaces = true;
 | 
			
		||||
        RestrictRealtime = true;
 | 
			
		||||
        RestrictSUIDSGID = true;
 | 
			
		||||
        SystemCallArchitectures = "native";
 | 
			
		||||
        SocketBindDeny = [ "any" ];
 | 
			
		||||
        SystemCallFilter = [
 | 
			
		||||
          "@system-service"
 | 
			
		||||
          "~@privileged"
 | 
			
		||||
 
 | 
			
		||||
@@ -393,6 +393,7 @@ pub async fn edit_database_privileges(
 | 
			
		||||
 | 
			
		||||
    if diffs.is_empty() {
 | 
			
		||||
        println!("No changes to make.");
 | 
			
		||||
        server_connection.send(Request::Exit).await?;
 | 
			
		||||
        return Ok(());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										18
									
								
								src/main.rs
									
									
									
									
									
								
							
							
						
						
									
										18
									
								
								src/main.rs
									
									
									
									
									
								
							@@ -3,6 +3,7 @@ extern crate prettytable;
 | 
			
		||||
 | 
			
		||||
use clap::{CommandFactory, Parser, ValueEnum};
 | 
			
		||||
use clap_complete::{generate, Shell};
 | 
			
		||||
use clap_verbosity_flag::Verbosity;
 | 
			
		||||
 | 
			
		||||
use std::path::PathBuf;
 | 
			
		||||
 | 
			
		||||
@@ -62,6 +63,10 @@ struct Args {
 | 
			
		||||
    )]
 | 
			
		||||
    config: Option<PathBuf>,
 | 
			
		||||
 | 
			
		||||
    #[command(flatten)]
 | 
			
		||||
    verbose: Verbosity,
 | 
			
		||||
 | 
			
		||||
    /// Run in TUI mode.
 | 
			
		||||
    #[cfg(feature = "tui")]
 | 
			
		||||
    #[arg(short, long, alias = "tui", global = true)]
 | 
			
		||||
    interactive: bool,
 | 
			
		||||
@@ -103,11 +108,6 @@ enum ToplevelCommands {
 | 
			
		||||
//       comments emphasizing the need for caution.
 | 
			
		||||
 | 
			
		||||
fn main() -> anyhow::Result<()> {
 | 
			
		||||
    // TODO: find out if there are any security risks of running
 | 
			
		||||
    //       env_logger and clap with elevated privileges.
 | 
			
		||||
 | 
			
		||||
    env_logger::init();
 | 
			
		||||
 | 
			
		||||
    #[cfg(feature = "mysql-admutils-compatibility")]
 | 
			
		||||
    if handle_mysql_admutils_command()?.is_some() {
 | 
			
		||||
        return Ok(());
 | 
			
		||||
@@ -126,6 +126,10 @@ fn main() -> anyhow::Result<()> {
 | 
			
		||||
    let server_connection =
 | 
			
		||||
        bootstrap_server_connection_and_drop_privileges(args.server_socket_path, args.config)?;
 | 
			
		||||
 | 
			
		||||
    env_logger::Builder::new()
 | 
			
		||||
        .filter_level(args.verbose.log_level_filter())
 | 
			
		||||
        .init();
 | 
			
		||||
 | 
			
		||||
    tokio_run_command(args.command, server_connection)?;
 | 
			
		||||
 | 
			
		||||
    Ok(())
 | 
			
		||||
@@ -151,6 +155,7 @@ fn handle_server_command(args: &Args) -> anyhow::Result<Option<()>> {
 | 
			
		||||
            tokio_start_server(
 | 
			
		||||
                args.server_socket_path.clone(),
 | 
			
		||||
                args.config.clone(),
 | 
			
		||||
                args.verbose.clone(),
 | 
			
		||||
                command.clone(),
 | 
			
		||||
            )?;
 | 
			
		||||
            Ok(Some(()))
 | 
			
		||||
@@ -188,6 +193,7 @@ fn handle_generate_completions_command(args: &Args) -> anyhow::Result<Option<()>
 | 
			
		||||
fn tokio_start_server(
 | 
			
		||||
    server_socket_path: Option<PathBuf>,
 | 
			
		||||
    config_path: Option<PathBuf>,
 | 
			
		||||
    verbosity: Verbosity,
 | 
			
		||||
    args: ServerArgs,
 | 
			
		||||
) -> anyhow::Result<()> {
 | 
			
		||||
    tokio::runtime::Builder::new_current_thread()
 | 
			
		||||
@@ -195,7 +201,7 @@ fn tokio_start_server(
 | 
			
		||||
        .build()
 | 
			
		||||
        .unwrap()
 | 
			
		||||
        .block_on(async {
 | 
			
		||||
            server::command::handle_command(server_socket_path, config_path, args).await
 | 
			
		||||
            server::command::handle_command(server_socket_path, config_path, verbosity, args).await
 | 
			
		||||
        })
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -3,11 +3,16 @@ use std::path::PathBuf;
 | 
			
		||||
 | 
			
		||||
use anyhow::Context;
 | 
			
		||||
use clap::Parser;
 | 
			
		||||
use clap_verbosity_flag::Verbosity;
 | 
			
		||||
use futures::SinkExt;
 | 
			
		||||
use indoc::concatdoc;
 | 
			
		||||
use systemd_journal_logger::JournalLog;
 | 
			
		||||
 | 
			
		||||
use std::os::unix::net::UnixStream as StdUnixStream;
 | 
			
		||||
use tokio::net::UnixStream as TokioUnixStream;
 | 
			
		||||
 | 
			
		||||
use crate::core::common::UnixUser;
 | 
			
		||||
use crate::core::protocol::{create_server_to_client_message_stream, Response};
 | 
			
		||||
use crate::server::config::read_config_from_path_with_arg_overrides;
 | 
			
		||||
use crate::server::server_loop::listen_for_incoming_connections;
 | 
			
		||||
use crate::server::{
 | 
			
		||||
@@ -22,6 +27,9 @@ pub struct ServerArgs {
 | 
			
		||||
 | 
			
		||||
    #[command(flatten)]
 | 
			
		||||
    config_overrides: ServerConfigArgs,
 | 
			
		||||
 | 
			
		||||
    #[arg(long)]
 | 
			
		||||
    systemd: bool,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Parser, Debug, Clone)]
 | 
			
		||||
@@ -33,27 +41,148 @@ pub enum ServerCommand {
 | 
			
		||||
    SocketActivate,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const LOG_LEVEL_WARNING: &str = r#"
 | 
			
		||||
===================================================
 | 
			
		||||
== WARNING: LOG LEVEL IS SET TO 'TRACE'!         ==
 | 
			
		||||
== THIS WILL CAUSE THE SERVER TO LOG SQL QUERIES ==
 | 
			
		||||
== THAT MAY CONTAIN SENSITIVE INFORMATION LIKE   ==
 | 
			
		||||
== PASSWORDS AND AUTHENTICATION TOKENS.          ==
 | 
			
		||||
== THIS IS INTENDED FOR DEBUGGING PURPOSES ONLY  ==
 | 
			
		||||
== AND SHOULD *NEVER* BE USED IN PRODUCTION.     ==
 | 
			
		||||
===================================================
 | 
			
		||||
"#;
 | 
			
		||||
 | 
			
		||||
pub async fn handle_command(
 | 
			
		||||
    socket_path: Option<PathBuf>,
 | 
			
		||||
    config_path: Option<PathBuf>,
 | 
			
		||||
    verbosity: Verbosity,
 | 
			
		||||
    args: ServerArgs,
 | 
			
		||||
) -> anyhow::Result<()> {
 | 
			
		||||
    let mut auto_detected_systemd_mode = false;
 | 
			
		||||
    let systemd_mode = args.systemd || {
 | 
			
		||||
        if let Ok(true) = sd_notify::booted() {
 | 
			
		||||
            auto_detected_systemd_mode = true;
 | 
			
		||||
            true
 | 
			
		||||
        } else {
 | 
			
		||||
            false
 | 
			
		||||
        }
 | 
			
		||||
    };
 | 
			
		||||
    if systemd_mode {
 | 
			
		||||
        JournalLog::new()
 | 
			
		||||
            .context("Failed to initialize journald logging")?
 | 
			
		||||
            .install()
 | 
			
		||||
            .context("Failed to install journald logger")?;
 | 
			
		||||
 | 
			
		||||
        log::set_max_level(verbosity.log_level_filter());
 | 
			
		||||
 | 
			
		||||
        if verbosity.log_level_filter() >= log::LevelFilter::Trace {
 | 
			
		||||
            log::warn!("{}", LOG_LEVEL_WARNING.trim());
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if auto_detected_systemd_mode {
 | 
			
		||||
            log::info!("Running in systemd mode, auto-detected");
 | 
			
		||||
        } else {
 | 
			
		||||
            log::info!("Running in systemd mode");
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        start_watchdog_thread_if_enabled();
 | 
			
		||||
    } else {
 | 
			
		||||
        env_logger::Builder::new()
 | 
			
		||||
            .filter_level(verbosity.log_level_filter())
 | 
			
		||||
            .init();
 | 
			
		||||
 | 
			
		||||
        log::info!("Running in standalone mode");
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    let config = read_config_from_path_with_arg_overrides(config_path, args.config_overrides)?;
 | 
			
		||||
 | 
			
		||||
    match args.subcmd {
 | 
			
		||||
        ServerCommand::Listen => listen_for_incoming_connections(socket_path, config).await,
 | 
			
		||||
        ServerCommand::SocketActivate => socket_activate(config).await,
 | 
			
		||||
        ServerCommand::SocketActivate => {
 | 
			
		||||
            if !args.systemd {
 | 
			
		||||
                anyhow::bail!(concat!(
 | 
			
		||||
                    "The `--systemd` flag must be used with the `socket-activate` command.\n",
 | 
			
		||||
                    "This command currently only supports socket activation under systemd."
 | 
			
		||||
                ));
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            socket_activate(config).await
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn start_watchdog_thread_if_enabled() {
 | 
			
		||||
    let mut micro_seconds: u64 = 0;
 | 
			
		||||
    let watchdog_enabled = sd_notify::watchdog_enabled(true, &mut micro_seconds);
 | 
			
		||||
 | 
			
		||||
    if watchdog_enabled {
 | 
			
		||||
        micro_seconds = micro_seconds.max(2_000_000).div_ceil(2);
 | 
			
		||||
 | 
			
		||||
        tokio::spawn(async move {
 | 
			
		||||
            log::debug!(
 | 
			
		||||
                "Starting systemd watchdog thread with {} millisecond interval",
 | 
			
		||||
                micro_seconds.div_ceil(1000)
 | 
			
		||||
            );
 | 
			
		||||
            loop {
 | 
			
		||||
                tokio::time::sleep(tokio::time::Duration::from_micros(micro_seconds)).await;
 | 
			
		||||
                if let Err(err) = sd_notify::notify(false, &[sd_notify::NotifyState::Watchdog]) {
 | 
			
		||||
                    log::warn!("Failed to notify systemd watchdog: {}", err);
 | 
			
		||||
                } else {
 | 
			
		||||
                    log::trace!("Ping sent to systemd watchdog");
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        });
 | 
			
		||||
    } else {
 | 
			
		||||
        log::debug!("Systemd watchdog not enabled, skipping watchdog thread");
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn socket_activate(config: ServerConfig) -> anyhow::Result<()> {
 | 
			
		||||
    let conn = get_socket_from_systemd().await?;
 | 
			
		||||
    let uid = conn.peer_cred()?.uid();
 | 
			
		||||
    let unix_user = UnixUser::from_uid(uid)?;
 | 
			
		||||
 | 
			
		||||
    let uid = match conn.peer_cred() {
 | 
			
		||||
        Ok(cred) => cred.uid(),
 | 
			
		||||
        Err(e) => {
 | 
			
		||||
            log::error!("Failed to get peer credentials from socket: {}", e);
 | 
			
		||||
            let mut message_stream = create_server_to_client_message_stream(conn);
 | 
			
		||||
            message_stream
 | 
			
		||||
                .send(Response::Error(
 | 
			
		||||
                    (concatdoc! {
 | 
			
		||||
                        "Server failed to get peer credentials from socket\n",
 | 
			
		||||
                        "Please check the server logs or contact the system administrators"
 | 
			
		||||
                    })
 | 
			
		||||
                    .to_string(),
 | 
			
		||||
                ))
 | 
			
		||||
                .await
 | 
			
		||||
                .ok();
 | 
			
		||||
            anyhow::bail!("Failed to get peer credentials from socket");
 | 
			
		||||
        }
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    log::debug!("Accepted connection from uid {}", uid);
 | 
			
		||||
 | 
			
		||||
    let unix_user = match UnixUser::from_uid(uid) {
 | 
			
		||||
        Ok(user) => user,
 | 
			
		||||
        Err(e) => {
 | 
			
		||||
            log::error!("Failed to get username from uid: {}", e);
 | 
			
		||||
            let mut message_stream = create_server_to_client_message_stream(conn);
 | 
			
		||||
            message_stream
 | 
			
		||||
                .send(Response::Error(
 | 
			
		||||
                    (concatdoc! {
 | 
			
		||||
                        "Server failed to get user data from the system\n",
 | 
			
		||||
                        "Please check the server logs or contact the system administrators"
 | 
			
		||||
                    })
 | 
			
		||||
                    .to_string(),
 | 
			
		||||
                ))
 | 
			
		||||
                .await
 | 
			
		||||
                .ok();
 | 
			
		||||
            anyhow::bail!("Failed to get username from uid");
 | 
			
		||||
        }
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    log::info!("Accepted connection from {}", unix_user.username);
 | 
			
		||||
 | 
			
		||||
    sd_notify::notify(true, &[sd_notify::NotifyState::Ready]).ok();
 | 
			
		||||
    sd_notify::notify(false, &[sd_notify::NotifyState::Ready]).ok();
 | 
			
		||||
 | 
			
		||||
    handle_requests_for_single_session(conn, &unix_user, &config).await?;
 | 
			
		||||
 | 
			
		||||
@@ -66,7 +195,12 @@ async fn get_socket_from_systemd() -> anyhow::Result<TokioUnixStream> {
 | 
			
		||||
        .next()
 | 
			
		||||
        .context("No file descriptors received from systemd")?;
 | 
			
		||||
 | 
			
		||||
    log::debug!("Received file descriptor from systemd: {}", fd);
 | 
			
		||||
    debug_assert!(fd == 3, "Unexpected file descriptor from systemd: {}", fd);
 | 
			
		||||
 | 
			
		||||
    log::debug!(
 | 
			
		||||
        "Received file descriptor from systemd with id: '{}', assuming socket",
 | 
			
		||||
        fd
 | 
			
		||||
    );
 | 
			
		||||
 | 
			
		||||
    let std_unix_stream = unsafe { StdUnixStream::from_raw_fd(fd) };
 | 
			
		||||
    let socket = TokioUnixStream::from_std(std_unix_stream)?;
 | 
			
		||||
 
 | 
			
		||||
@@ -109,18 +109,12 @@ pub fn read_config_from_path_with_arg_overrides(
 | 
			
		||||
pub fn read_config_from_path(config_path: Option<PathBuf>) -> anyhow::Result<ServerConfig> {
 | 
			
		||||
    let config_path = config_path.unwrap_or_else(|| PathBuf::from(DEFAULT_CONFIG_PATH));
 | 
			
		||||
 | 
			
		||||
    log::debug!("Reading config from {:?}", &config_path);
 | 
			
		||||
    log::debug!("Reading config file at {:?}", &config_path);
 | 
			
		||||
 | 
			
		||||
    fs::read_to_string(&config_path)
 | 
			
		||||
        .context(format!(
 | 
			
		||||
            "Failed to read config file from {:?}",
 | 
			
		||||
            &config_path
 | 
			
		||||
        ))
 | 
			
		||||
        .context(format!("Failed to read config file at {:?}", &config_path))
 | 
			
		||||
        .and_then(|c| toml::from_str(&c).context("Failed to parse config file"))
 | 
			
		||||
        .context(format!(
 | 
			
		||||
            "Failed to parse config file from {:?}",
 | 
			
		||||
            &config_path
 | 
			
		||||
        ))
 | 
			
		||||
        .context(format!("Failed to parse config file at {:?}", &config_path))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn log_config(config: &MysqlConfig) {
 | 
			
		||||
@@ -141,7 +135,9 @@ pub async fn create_mysql_connection_from_config(
 | 
			
		||||
) -> anyhow::Result<MySqlConnection> {
 | 
			
		||||
    log_config(config);
 | 
			
		||||
 | 
			
		||||
    let mut mysql_options = MySqlConnectOptions::new().database("mysql");
 | 
			
		||||
    let mut mysql_options = MySqlConnectOptions::new()
 | 
			
		||||
        .database("mysql")
 | 
			
		||||
        .log_statements(log::LevelFilter::Trace);
 | 
			
		||||
 | 
			
		||||
    if let Some(username) = &config.username {
 | 
			
		||||
        mysql_options = mysql_options.username(username);
 | 
			
		||||
 
 | 
			
		||||
@@ -7,6 +7,7 @@ use tokio::net::{UnixListener, UnixStream};
 | 
			
		||||
use sqlx::prelude::*;
 | 
			
		||||
use sqlx::MySqlConnection;
 | 
			
		||||
 | 
			
		||||
use crate::core::protocol::SetPasswordError;
 | 
			
		||||
use crate::server::sql::database_operations::list_databases;
 | 
			
		||||
use crate::{
 | 
			
		||||
    core::{
 | 
			
		||||
@@ -56,7 +57,7 @@ pub async fn listen_for_incoming_connections(
 | 
			
		||||
 | 
			
		||||
    let listener = UnixListener::bind(socket_path)?;
 | 
			
		||||
 | 
			
		||||
    sd_notify::notify(true, &[sd_notify::NotifyState::Ready]).ok();
 | 
			
		||||
    sd_notify::notify(false, &[sd_notify::NotifyState::Ready]).ok();
 | 
			
		||||
 | 
			
		||||
    while let Ok((conn, _addr)) = listener.accept().await {
 | 
			
		||||
        let uid = match conn.peer_cred() {
 | 
			
		||||
@@ -78,7 +79,7 @@ pub async fn listen_for_incoming_connections(
 | 
			
		||||
            }
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        log::trace!("Accepted connection from uid {}", uid);
 | 
			
		||||
        log::debug!("Accepted connection from uid {}", uid);
 | 
			
		||||
 | 
			
		||||
        let unix_user = match UnixUser::from_uid(uid) {
 | 
			
		||||
            Ok(user) => user,
 | 
			
		||||
@@ -173,47 +174,47 @@ pub async fn handle_requests_for_single_session_with_db_connection(
 | 
			
		||||
            }
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        log::trace!("Received request: {:?}", request);
 | 
			
		||||
        // TODO: don't clone the request
 | 
			
		||||
        let request_to_display = match &request {
 | 
			
		||||
            Request::PasswdUser(db_user, _) => {
 | 
			
		||||
                Request::PasswdUser(db_user.to_owned(), "<REDACTED>".to_string())
 | 
			
		||||
            }
 | 
			
		||||
            request => request.to_owned(),
 | 
			
		||||
        };
 | 
			
		||||
        log::info!("Received request: {:#?}", request_to_display);
 | 
			
		||||
 | 
			
		||||
        match request {
 | 
			
		||||
        let response = match request {
 | 
			
		||||
            Request::CreateDatabases(databases_names) => {
 | 
			
		||||
                let result = create_databases(databases_names, unix_user, db_connection).await;
 | 
			
		||||
                stream.send(Response::CreateDatabases(result)).await?;
 | 
			
		||||
                Response::CreateDatabases(result)
 | 
			
		||||
            }
 | 
			
		||||
            Request::DropDatabases(databases_names) => {
 | 
			
		||||
                let result = drop_databases(databases_names, unix_user, db_connection).await;
 | 
			
		||||
                stream.send(Response::DropDatabases(result)).await?;
 | 
			
		||||
            }
 | 
			
		||||
            Request::ListDatabases(database_names) => {
 | 
			
		||||
                let response = match database_names {
 | 
			
		||||
                    Some(database_names) => {
 | 
			
		||||
                        let result = list_databases(database_names, unix_user, db_connection).await;
 | 
			
		||||
                        Response::ListDatabases(result)
 | 
			
		||||
                    }
 | 
			
		||||
                    None => {
 | 
			
		||||
                        let result = list_all_databases_for_user(unix_user, db_connection).await;
 | 
			
		||||
                        Response::ListAllDatabases(result)
 | 
			
		||||
                    }
 | 
			
		||||
                };
 | 
			
		||||
                stream.send(response).await?;
 | 
			
		||||
            }
 | 
			
		||||
            Request::ListPrivileges(database_names) => {
 | 
			
		||||
                let response = match database_names {
 | 
			
		||||
                    Some(database_names) => {
 | 
			
		||||
                        let privilege_data =
 | 
			
		||||
                            get_databases_privilege_data(database_names, unix_user, db_connection)
 | 
			
		||||
                                .await;
 | 
			
		||||
                        Response::ListPrivileges(privilege_data)
 | 
			
		||||
                    }
 | 
			
		||||
                    None => {
 | 
			
		||||
                        let privilege_data =
 | 
			
		||||
                            get_all_database_privileges(unix_user, db_connection).await;
 | 
			
		||||
                        Response::ListAllPrivileges(privilege_data)
 | 
			
		||||
                    }
 | 
			
		||||
                };
 | 
			
		||||
 | 
			
		||||
                stream.send(response).await?;
 | 
			
		||||
                Response::DropDatabases(result)
 | 
			
		||||
            }
 | 
			
		||||
            Request::ListDatabases(database_names) => match database_names {
 | 
			
		||||
                Some(database_names) => {
 | 
			
		||||
                    let result = list_databases(database_names, unix_user, db_connection).await;
 | 
			
		||||
                    Response::ListDatabases(result)
 | 
			
		||||
                }
 | 
			
		||||
                None => {
 | 
			
		||||
                    let result = list_all_databases_for_user(unix_user, db_connection).await;
 | 
			
		||||
                    Response::ListAllDatabases(result)
 | 
			
		||||
                }
 | 
			
		||||
            },
 | 
			
		||||
            Request::ListPrivileges(database_names) => match database_names {
 | 
			
		||||
                Some(database_names) => {
 | 
			
		||||
                    let privilege_data =
 | 
			
		||||
                        get_databases_privilege_data(database_names, unix_user, db_connection)
 | 
			
		||||
                            .await;
 | 
			
		||||
                    Response::ListPrivileges(privilege_data)
 | 
			
		||||
                }
 | 
			
		||||
                None => {
 | 
			
		||||
                    let privilege_data =
 | 
			
		||||
                        get_all_database_privileges(unix_user, db_connection).await;
 | 
			
		||||
                    Response::ListAllPrivileges(privilege_data)
 | 
			
		||||
                }
 | 
			
		||||
            },
 | 
			
		||||
            Request::ModifyPrivileges(database_privilege_diffs) => {
 | 
			
		||||
                let result = apply_privilege_diffs(
 | 
			
		||||
                    BTreeSet::from_iter(database_privilege_diffs),
 | 
			
		||||
@@ -221,50 +222,58 @@ pub async fn handle_requests_for_single_session_with_db_connection(
 | 
			
		||||
                    db_connection,
 | 
			
		||||
                )
 | 
			
		||||
                .await;
 | 
			
		||||
                stream.send(Response::ModifyPrivileges(result)).await?;
 | 
			
		||||
                Response::ModifyPrivileges(result)
 | 
			
		||||
            }
 | 
			
		||||
            Request::CreateUsers(db_users) => {
 | 
			
		||||
                let result = create_database_users(db_users, unix_user, db_connection).await;
 | 
			
		||||
                stream.send(Response::CreateUsers(result)).await?;
 | 
			
		||||
                Response::CreateUsers(result)
 | 
			
		||||
            }
 | 
			
		||||
            Request::DropUsers(db_users) => {
 | 
			
		||||
                let result = drop_database_users(db_users, unix_user, db_connection).await;
 | 
			
		||||
                stream.send(Response::DropUsers(result)).await?;
 | 
			
		||||
                Response::DropUsers(result)
 | 
			
		||||
            }
 | 
			
		||||
            Request::PasswdUser(db_user, password) => {
 | 
			
		||||
                let result =
 | 
			
		||||
                    set_password_for_database_user(&db_user, &password, unix_user, db_connection)
 | 
			
		||||
                        .await;
 | 
			
		||||
                stream.send(Response::PasswdUser(result)).await?;
 | 
			
		||||
            }
 | 
			
		||||
            Request::ListUsers(db_users) => {
 | 
			
		||||
                let response = match db_users {
 | 
			
		||||
                    Some(db_users) => {
 | 
			
		||||
                        let result = list_database_users(db_users, unix_user, db_connection).await;
 | 
			
		||||
                        Response::ListUsers(result)
 | 
			
		||||
                    }
 | 
			
		||||
                    None => {
 | 
			
		||||
                        let result =
 | 
			
		||||
                            list_all_database_users_for_unix_user(unix_user, db_connection).await;
 | 
			
		||||
                        Response::ListAllUsers(result)
 | 
			
		||||
                    }
 | 
			
		||||
                };
 | 
			
		||||
                stream.send(response).await?;
 | 
			
		||||
                Response::PasswdUser(result)
 | 
			
		||||
            }
 | 
			
		||||
            Request::ListUsers(db_users) => match db_users {
 | 
			
		||||
                Some(db_users) => {
 | 
			
		||||
                    let result = list_database_users(db_users, unix_user, db_connection).await;
 | 
			
		||||
                    Response::ListUsers(result)
 | 
			
		||||
                }
 | 
			
		||||
                None => {
 | 
			
		||||
                    let result =
 | 
			
		||||
                        list_all_database_users_for_unix_user(unix_user, db_connection).await;
 | 
			
		||||
                    Response::ListAllUsers(result)
 | 
			
		||||
                }
 | 
			
		||||
            },
 | 
			
		||||
            Request::LockUsers(db_users) => {
 | 
			
		||||
                let result = lock_database_users(db_users, unix_user, db_connection).await;
 | 
			
		||||
                stream.send(Response::LockUsers(result)).await?;
 | 
			
		||||
                Response::LockUsers(result)
 | 
			
		||||
            }
 | 
			
		||||
            Request::UnlockUsers(db_users) => {
 | 
			
		||||
                let result = unlock_database_users(db_users, unix_user, db_connection).await;
 | 
			
		||||
                stream.send(Response::UnlockUsers(result)).await?;
 | 
			
		||||
                Response::UnlockUsers(result)
 | 
			
		||||
            }
 | 
			
		||||
            Request::Exit => {
 | 
			
		||||
                break;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        // TODO: don't clone the response
 | 
			
		||||
        let response_to_display = match &response {
 | 
			
		||||
            Response::PasswdUser(Err(SetPasswordError::MySqlError(_))) => {
 | 
			
		||||
                Response::PasswdUser(Err(SetPasswordError::MySqlError("<REDACTED>".to_string())))
 | 
			
		||||
            }
 | 
			
		||||
            response => response.to_owned(),
 | 
			
		||||
        };
 | 
			
		||||
        log::info!("Response: {:#?}", response_to_display);
 | 
			
		||||
 | 
			
		||||
        stream.send(response).await?;
 | 
			
		||||
        stream.flush().await?;
 | 
			
		||||
        log::debug!("Successfully processed request");
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Ok(())
 | 
			
		||||
 
 | 
			
		||||
@@ -167,7 +167,7 @@ pub async fn set_password_for_database_user(
 | 
			
		||||
        format!(
 | 
			
		||||
            "ALTER USER {}@'%' IDENTIFIED BY {}",
 | 
			
		||||
            quote_literal(db_user),
 | 
			
		||||
            quote_literal(password).as_str()
 | 
			
		||||
            quote_literal(password).as_str(),
 | 
			
		||||
        )
 | 
			
		||||
        .as_str(),
 | 
			
		||||
    )
 | 
			
		||||
@@ -176,11 +176,10 @@ pub async fn set_password_for_database_user(
 | 
			
		||||
    .map(|_| ())
 | 
			
		||||
    .map_err(|err| SetPasswordError::MySqlError(err.to_string()));
 | 
			
		||||
 | 
			
		||||
    if let Err(err) = &result {
 | 
			
		||||
    if result.is_err() {
 | 
			
		||||
        log::error!(
 | 
			
		||||
            "Failed to set password for database user '{}': {:?}",
 | 
			
		||||
            "Failed to set password for database user '{}': <REDACTED>",
 | 
			
		||||
            &db_user,
 | 
			
		||||
            err
 | 
			
		||||
        );
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user