diff --git a/example-config.toml b/example-config.toml index 2d3200e..c2f9e44 100644 --- a/example-config.toml +++ b/example-config.toml @@ -1,8 +1,22 @@ # This should go to `/etc/mysqladm/config.toml` +[server] +# Note that this gets ignored if you are using socket activation. +socket_path = "/var/run/mysqladm/mysqladm.sock" + [mysql] + +# if you use a socket, the host and port will be ignored +# socket_path = "/var/run/mysql/mysql.sock" + host = "localhost" port = 3306 + +# The username and password can be omitted if you are using +# socket based authentication. However, the vendored systemd +# service is running as DynamicUser, so by default you need +# to at least specify the username. username = "root" password = "secret" + timeout = 2 # seconds \ No newline at end of file diff --git a/src/core/bootstrap.rs b/src/core/bootstrap.rs index f64d89e..891b52a 100644 --- a/src/core/bootstrap.rs +++ b/src/core/bootstrap.rs @@ -7,7 +7,7 @@ use tokio::net::UnixStream as TokioUnixStream; use crate::{ core::common::{UnixUser, DEFAULT_CONFIG_PATH, DEFAULT_SOCKET_PATH}, - server::{config::read_config_form_path, server_loop::handle_requests_for_single_session}, + server::{config::read_config_from_path, server_loop::handle_requests_for_single_session}, }; // TODO: this function is security critical, it should be integration tested @@ -140,7 +140,7 @@ fn run_forked_server( server_socket: StdUnixStream, unix_user: UnixUser, ) -> anyhow::Result<()> { - let config = read_config_form_path(Some(config_path))?; + let config = read_config_from_path(Some(config_path))?; let result: anyhow::Result<()> = tokio::runtime::Builder::new_current_thread() .enable_all() diff --git a/src/server/config.rs b/src/server/config.rs index e3639ca..4989c07 100644 --- a/src/server/config.rs +++ b/src/server/config.rs @@ -21,21 +21,37 @@ pub struct ServerConfig { #[derive(Debug, Clone, Deserialize, Serialize)] #[serde(rename = "mysql")] pub struct MysqlConfig { - pub host: String, + pub socket_path: Option, + pub host: Option, pub port: Option, - pub username: String, - pub password: String, + pub username: Option, + pub password: Option, + pub password_file: Option, pub timeout: Option, } #[derive(Parser, Debug, Clone)] pub struct ServerConfigArgs { + /// Path to the socket of the MySQL server. + #[arg(long, value_name = "PATH", global = true)] + socket_path: Option, + /// Hostname of the MySQL server. - #[arg(long, value_name = "HOST", global = true)] + #[arg( + long, + value_name = "HOST", + global = true, + conflicts_with = "socket_path" + )] mysql_host: Option, /// Port of the MySQL server. - #[arg(long, value_name = "PORT", global = true)] + #[arg( + long, + value_name = "PORT", + global = true, + conflicts_with = "socket_path" + )] mysql_port: Option, /// Username to use for the MySQL connection. @@ -44,7 +60,7 @@ pub struct ServerConfigArgs { /// Path to a file containing the MySQL password. #[arg(long, value_name = "PATH", global = true)] - mysql_password_file: Option, + mysql_password_file: Option, /// Seconds to wait for the MySQL connection to be established. #[arg(long, value_name = "SECONDS", global = true)] @@ -57,30 +73,40 @@ pub fn read_config_from_path_with_arg_overrides( config_path: Option, args: ServerConfigArgs, ) -> anyhow::Result { - let config = read_config_form_path(config_path)?; + let config = read_config_from_path(config_path)?; - let mysql = &config.mysql; + let mysql = config.mysql; - let password = if let Some(path) = args.mysql_password_file { - fs::read_to_string(path) - .context("Failed to read MySQL password file") - .map(|s| s.trim().to_owned())? + let password = if let Some(path) = &args.mysql_password_file { + Some( + fs::read_to_string(path) + .context("Failed to read MySQL password file") + .map(|s| s.trim().to_owned())?, + ) + } else if let Some(path) = &mysql.password_file { + Some( + fs::read_to_string(path) + .context("Failed to read MySQL password file") + .map(|s| s.trim().to_owned())?, + ) } else { mysql.password.to_owned() }; Ok(ServerConfig { mysql: MysqlConfig { - host: args.mysql_host.unwrap_or(mysql.host.to_owned()), + socket_path: args.socket_path.or(mysql.socket_path), + host: args.mysql_host.or(mysql.host), port: args.mysql_port.or(mysql.port), - username: args.mysql_user.unwrap_or(mysql.username.to_owned()), + username: args.mysql_user.or(mysql.username.to_owned()), password, + password_file: args.mysql_password_file.or(mysql.password_file), timeout: args.mysql_connect_timeout.or(mysql.timeout), }, }) } -pub fn read_config_form_path(config_path: Option) -> anyhow::Result { +pub fn read_config_from_path(config_path: Option) -> anyhow::Result { let config_path = config_path.unwrap_or_else(|| PathBuf::from(DEFAULT_CONFIG_PATH)); log::debug!("Reading config from {:?}", &config_path); @@ -97,30 +123,51 @@ pub fn read_config_form_path(config_path: Option) -> anyhow::Result anyhow::Result { +fn log_config(config: &MysqlConfig) { let mut display_config = config.clone(); - "".clone_into(&mut display_config.password); + display_config.password = display_config + .password + .as_ref() + .map(|_| "".to_owned()); log::debug!( "Connecting to MySQL server with parameters: {:#?}", display_config ); +} + +/// Use the provided configuration to establish a connection to a MySQL server. +pub async fn create_mysql_connection_from_config( + config: &MysqlConfig, +) -> anyhow::Result { + log_config(config); + + let mut mysql_options = MySqlConnectOptions::new() + .database("mysql"); + + if let Some(username) = &config.username { + mysql_options = mysql_options.username(username); + } + + if let Some(password) = &config.password { + mysql_options = mysql_options.password(password); + } + + if let Some(socket_path) = &config.socket_path { + mysql_options = mysql_options.socket(socket_path); + } else if let Some(host) = &config.host { + mysql_options = mysql_options.host(host); + mysql_options = mysql_options.port(config.port.unwrap_or(DEFAULT_PORT)); + } else { + anyhow::bail!("No MySQL host or socket path provided"); + } match tokio::time::timeout( Duration::from_secs(config.timeout.unwrap_or(DEFAULT_TIMEOUT)), - MySqlConnectOptions::new() - .host(&config.host) - .username(&config.username) - .password(&config.password) - .port(config.port.unwrap_or(DEFAULT_PORT)) - .database("mysql") - .connect(), + mysql_options.connect(), ) .await { - Ok(connection) => connection.context("Failed to connect to MySQL"), - Err(_) => Err(anyhow!("Timed out after 2 seconds")).context("Failed to connect to MySQL"), + Ok(connection) => connection.context("Failed to connect to the database"), + Err(_) => Err(anyhow!("Timed out after 2 seconds")).context("Failed to connect to the database"), } }