Compare commits

..

9 Commits

18 changed files with 407 additions and 113 deletions

1
Cargo.lock generated
View File

@ -1077,6 +1077,7 @@ dependencies = [
"prettytable", "prettytable",
"rand", "rand",
"ratatui", "ratatui",
"regex",
"sd-notify", "sd-notify",
"serde", "serde",
"serde_json", "serde_json",

View File

@ -49,3 +49,6 @@ codegen-units = 1
[build-dependencies] [build-dependencies]
anyhow = "1.0.82" anyhow = "1.0.82"
[dev-dependencies]
regex = "1.10.6"

View File

@ -1,8 +1,22 @@
# This should go to `/etc/mysqladm/config.toml` # This should go to `/etc/mysqladm/config.toml`
[server]
# Note that this gets ignored if you are using socket activation.
socket_path = "/var/run/mysqladm/mysqladm.sock"
[mysql] [mysql]
# if you use a socket, the host and port will be ignored
# socket_path = "/var/run/mysql/mysql.sock"
host = "localhost" host = "localhost"
port = 3306 port = 3306
# The username and password can be omitted if you are using
# socket based authentication. However, the vendored systemd
# service is running as DynamicUser, so by default you need
# to at least specify the username.
username = "root" username = "root"
password = "secret" password = "secret"
timeout = 2 # seconds timeout = 2 # seconds

View File

@ -2,16 +2,16 @@
"nodes": { "nodes": {
"nixpkgs": { "nixpkgs": {
"locked": { "locked": {
"lastModified": 1723637854, "lastModified": 1723938990,
"narHash": "sha256-med8+5DSWa2UnOqtdICndjDAEjxr5D7zaIiK4pn0Q7c=", "narHash": "sha256-9tUadhnZQbWIiYVXH8ncfGXGvkNq3Hag4RCBEMUk7MI=",
"owner": "NixOS", "owner": "NixOS",
"repo": "nixpkgs", "repo": "nixpkgs",
"rev": "c3aa7b8938b17aebd2deecf7be0636000d62a2b9", "rev": "c42fcfbdfeae23e68fc520f9182dde9f38ad1890",
"type": "github" "type": "github"
}, },
"original": { "original": {
"owner": "NixOS", "owner": "NixOS",
"ref": "nixos-unstable", "ref": "nixos-24.05",
"repo": "nixpkgs", "repo": "nixpkgs",
"type": "github" "type": "github"
} }
@ -29,11 +29,11 @@
] ]
}, },
"locked": { "locked": {
"lastModified": 1723947704, "lastModified": 1724034091,
"narHash": "sha256-TcVf66N2NgGhxORFytzgqWcg0XJ+kk8uNLNsTRI5sYM=", "narHash": "sha256-b1g7w0sw+MDAhUAeCoX1vlTghsqcDZkxr+k9OZmxPa8=",
"owner": "oxalica", "owner": "oxalica",
"repo": "rust-overlay", "repo": "rust-overlay",
"rev": "456e78a55feade2c3bc6d7bc0bf5e710c9d86120", "rev": "c7d36e0947826e0751a5214ffe82533fbc909bc0",
"type": "github" "type": "github"
}, },
"original": { "original": {

View File

@ -1,6 +1,6 @@
{ {
inputs = { inputs = {
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable"; nixpkgs.url = "github:NixOS/nixpkgs/nixos-24.05";
rust-overlay.url = "github:oxalica/rust-overlay"; rust-overlay.url = "github:oxalica/rust-overlay";
rust-overlay.inputs.nixpkgs.follows = "nixpkgs"; rust-overlay.inputs.nixpkgs.follows = "nixpkgs";
@ -52,11 +52,16 @@
overlays = { overlays = {
default = self.overlays.mysqladm-rs; default = self.overlays.mysqladm-rs;
greg-ng = final: prev: { mysqladm-rs = final: prev: {
inherit (self.packages.${prev.system}) mysqladm-rs; inherit (self.packages.${prev.system}) mysqladm-rs;
}; };
}; };
nixosModules = {
default = self.nixosModules.mysqladm-rs;
mysqladm-rs = import ./nix/module.nix;
};
packages = let packages = let
cargoToml = builtins.fromTOML (builtins.readFile ./Cargo.toml); cargoToml = builtins.fromTOML (builtins.readFile ./Cargo.toml);
cargoLock = ./Cargo.lock; cargoLock = ./Cargo.lock;

141
nix/module.nix Normal file
View File

@ -0,0 +1,141 @@
{ config, pkgs, lib, ... }:
let
cfg = config.services.mysqladm-rs;
format = pkgs.formats.toml { };
in
{
options.services.mysqladm-rs = {
enable = lib.mkEnableOption "Enable mysqladm-rs";
package = lib.mkPackageOption pkgs "mysqladm-rs" { };
createLocalDatabaseUser = lib.mkOption {
type = lib.types.bool;
default = false;
description = "Create a local database user for mysqladm-rs";
};
settings = lib.mkOption {
default = { };
type = lib.types.submodule {
freeformType = format.type;
options = {
server = {
socket_path = lib.mkOption {
type = lib.types.path;
default = "/var/run/mysqladm/mysqladm.sock";
description = "Path to the MySQL socket";
};
};
mysql = {
socket_path = lib.mkOption {
type = with lib.types; nullOr path;
default = "/var/run/mysqld/mysqld.sock";
description = "Path to the MySQL socket";
};
host = lib.mkOption {
type = with lib.types; nullOr str;
default = null;
description = "MySQL host";
};
port = lib.mkOption {
type = with lib.types; nullOr port;
default = 3306;
description = "MySQL port";
};
username = lib.mkOption {
type = lib.types.str;
default = "mysqladm";
description = "MySQL username";
};
passwordFile = lib.mkOption {
type = with lib.types; nullOr path;
default = null;
description = "Path to a file containing the MySQL password";
};
timeout = lib.mkOption {
type = lib.types.ints.positive;
default = 2;
description = "Number of seconds to wait for a response from the MySQL server";
};
};
};
};
};
};
config = let
nullStrippedConfig = lib.filterAttrsRecursive (_: v: v != null) cfg.settings;
configFile = format.generate "mysqladm-rs.conf" nullStrippedConfig;
in lib.mkIf config.services.mysqladm-rs.enable {
environment.systemPackages = [ cfg.package ];
services.mysql.ensureUsers = lib.mkIf cfg.createLocalDatabaseUser [
{
name = cfg.settings.mysql.username;
ensurePermissions = {
"mysql.*" = "SELECT, INSERT, UPDATE, DELETE";
"*.*" = "CREATE USER, GRANT OPTION";
};
}
];
systemd.services."mysqladm@" = {
description = "MySQL administration tool for non-admin users";
environment.RUST_LOG = "debug";
serviceConfig = {
Type = "notify";
ExecStart = "${lib.getExe cfg.package} server socket-activate --config ${configFile}";
User = "mysqladm";
Group = "mysqladm";
DynamicUser = true;
# This is required to read unix user/group details.
PrivateUsers = false;
CapabilityBoundingSet = "";
LockPersonality = true;
MemoryDenyWriteExecute = true;
NoNewPrivileges = true;
PrivateDevices = true;
PrivateMounts = true;
PrivateTmp = "yes";
ProcSubset = "pid";
ProtectClock = true;
ProtectControlGroups = true;
ProtectHome = true;
ProtectHostname = true;
ProtectKernelLogs = true;
ProtectKernelModules = true;
ProtectKernelTunables = true;
ProtectProc = "invisible";
ProtectSystem = "strict";
RemoveIPC = true;
UMask = "0000";
RestrictAddressFamilies = [ "AF_UNIX" "AF_INET" "AF_INET6" ];
RestrictNamespaces = true;
RestrictRealtime = true;
RestrictSUIDSGID = true;
SystemCallArchitectures = "native";
SystemCallFilter = [
"@system-service"
"~@privileged"
"~@resources"
];
};
};
systemd.sockets."mysqladm" = {
description = "MySQL administration tool for non-admin users";
wantedBy = [ "sockets.target" ];
restartTriggers = [ configFile ];
socketConfig = {
ListenStream = cfg.settings.server.socket_path;
Accept = "yes";
PassCredentials = true;
};
};
};
}

View File

@ -56,11 +56,11 @@ The Y/N-values corresponds to the following mysql privileges:
/// Please consider using the newer mysqladm command instead. /// Please consider using the newer mysqladm command instead.
#[derive(Parser)] #[derive(Parser)]
#[command( #[command(
bin_name = "mysql-dbadm", bin_name = "mysql-dbadm",
version, version,
about, about,
disable_help_subcommand = true, disable_help_subcommand = true,
verbatim_doc_comment, verbatim_doc_comment
)] )]
pub struct Args { pub struct Args {
#[command(subcommand)] #[command(subcommand)]

View File

@ -32,11 +32,11 @@ use crate::{
/// Please consider using the newer mysqladm command instead. /// Please consider using the newer mysqladm command instead.
#[derive(Parser)] #[derive(Parser)]
#[command( #[command(
bin_name = "mysql-useradm", bin_name = "mysql-useradm",
version, version,
about, about,
disable_help_subcommand = true, disable_help_subcommand = true,
verbatim_doc_comment, verbatim_doc_comment
)] )]
pub struct Args { pub struct Args {
#[command(subcommand)] #[command(subcommand)]

View File

@ -140,6 +140,7 @@ async fn create_users(
"Do you want to set a password for user '{}'?", "Do you want to set a password for user '{}'?",
username username
)) ))
.default(false)
.interact()? .interact()?
{ {
let password = read_password_from_stdin_with_double_check(username)?; let password = read_password_from_stdin_with_double_check(username)?;

View File

@ -7,7 +7,7 @@ use tokio::net::UnixStream as TokioUnixStream;
use crate::{ use crate::{
core::common::{UnixUser, DEFAULT_CONFIG_PATH, DEFAULT_SOCKET_PATH}, core::common::{UnixUser, DEFAULT_CONFIG_PATH, DEFAULT_SOCKET_PATH},
server::{config::read_config_form_path, server_loop::handle_requests_for_single_session}, server::{config::read_config_from_path, server_loop::handle_requests_for_single_session},
}; };
// TODO: this function is security critical, it should be integration tested // TODO: this function is security critical, it should be integration tested
@ -32,15 +32,18 @@ pub fn drop_privs() -> anyhow::Result<()> {
/// This function is used to bootstrap the connection to the server. /// This function is used to bootstrap the connection to the server.
/// This can happen in two ways: /// This can happen in two ways:
///
/// 1. If a socket path is provided, or exists in the default location, /// 1. If a socket path is provided, or exists in the default location,
/// the function will connect to the socket and authenticate with the /// the function will connect to the socket and authenticate with the
/// server to ensure that the server knows the uid of the client. /// server to ensure that the server knows the uid of the client.
///
/// 2. If a config path is provided, or exists in the default location, /// 2. If a config path is provided, or exists in the default location,
/// and the config is readable, the function will assume it is either /// and the config is readable, the function will assume it is either
/// setuid or setgid, and will fork a child process to run the server /// setuid or setgid, and will fork a child process to run the server
/// with the provided config. The server will exit silently by itself /// with the provided config. The server will exit silently by itself
/// when it is done, and this function will only return for the client /// when it is done, and this function will only return for the client
/// with the socket for the server. /// with the socket for the server.
///
/// If neither of these options are available, the function will fail. /// If neither of these options are available, the function will fail.
pub fn bootstrap_server_connection_and_drop_privileges( pub fn bootstrap_server_connection_and_drop_privileges(
server_socket_path: Option<PathBuf>, server_socket_path: Option<PathBuf>,
@ -140,7 +143,7 @@ fn run_forked_server(
server_socket: StdUnixStream, server_socket: StdUnixStream,
unix_user: UnixUser, unix_user: UnixUser,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let config = read_config_form_path(Some(config_path))?; let config = read_config_from_path(Some(config_path))?;
let result: anyhow::Result<()> = tokio::runtime::Builder::new_current_thread() let result: anyhow::Result<()> = tokio::runtime::Builder::new_current_thread()
.enable_all() .enable_all()

View File

@ -14,8 +14,8 @@ use crate::server::sql::database_privilege_operations::{
pub fn db_priv_field_human_readable_name(name: &str) -> String { pub fn db_priv_field_human_readable_name(name: &str) -> String {
match name { match name {
"db" => "Database".to_owned(), "Db" => "Database".to_owned(),
"user" => "User".to_owned(), "User" => "User".to_owned(),
"select_priv" => "Select".to_owned(), "select_priv" => "Select".to_owned(),
"insert_priv" => "Insert".to_owned(), "insert_priv" => "Insert".to_owned(),
"update_priv" => "Update".to_owned(), "update_priv" => "Update".to_owned(),
@ -128,8 +128,8 @@ pub fn format_privileges_line_for_editor(
DATABASE_PRIVILEGE_FIELDS DATABASE_PRIVILEGE_FIELDS
.into_iter() .into_iter()
.map(|field| match field { .map(|field| match field {
"db" => format!("{:width$}", privs.db, width = database_name_len), "Db" => format!("{:width$}", privs.db, width = database_name_len),
"user" => format!("{:width$}", privs.user, width = username_len), "User" => format!("{:width$}", privs.user, width = username_len),
privilege => format!( privilege => format!(
"{:width$}", "{:width$}",
yn(privs.get_privilege_by_name(privilege)), yn(privs.get_privilege_by_name(privilege)),

View File

@ -73,7 +73,6 @@ pub enum Response {
UnlockUsers(UnlockUsersOutput), UnlockUsers(UnlockUsersOutput),
// Generic responses // Generic responses
OperationAborted, Ready,
Error(String), Error(String),
Exit,
} }

View File

@ -9,10 +9,12 @@ use std::path::PathBuf;
use std::os::unix::net::UnixStream as StdUnixStream; use std::os::unix::net::UnixStream as StdUnixStream;
use tokio::net::UnixStream as TokioUnixStream; use tokio::net::UnixStream as TokioUnixStream;
use futures::StreamExt;
use crate::{ use crate::{
core::{ core::{
bootstrap::{bootstrap_server_connection_and_drop_privileges, drop_privs}, bootstrap::bootstrap_server_connection_and_drop_privileges,
protocol::create_client_to_server_message_stream, protocol::{create_client_to_server_message_stream, Response},
}, },
server::command::ServerArgs, server::command::ServerArgs,
}; };
@ -107,17 +109,17 @@ fn main() -> anyhow::Result<()> {
env_logger::init(); env_logger::init();
#[cfg(feature = "mysql-admutils-compatibility")] #[cfg(feature = "mysql-admutils-compatibility")]
if let Some(_) = handle_mysql_admutils_command()? { if handle_mysql_admutils_command()?.is_some() {
return Ok(()); return Ok(());
} }
let args: Args = Args::parse(); let args: Args = Args::parse();
if let Some(_) = handle_server_command(&args)? { if handle_server_command(&args)?.is_some() {
return Ok(()); return Ok(());
} }
if let Some(_) = handle_generate_completions_command(&args)? { if handle_generate_completions_command(&args)?.is_some() {
return Ok(()); return Ok(());
} }
@ -137,8 +139,8 @@ fn handle_mysql_admutils_command() -> anyhow::Result<Option<()>> {
}); });
match argv0.as_deref() { match argv0.as_deref() {
Some("mysql-dbadm") => mysql_dbadm::main().map(|result| Some(result)), Some("mysql-dbadm") => mysql_dbadm::main().map(Some),
Some("mysql-useradm") => mysql_useradm::main().map(|result| Some(result)), Some("mysql-useradm") => mysql_useradm::main().map(Some),
_ => Ok(None), _ => Ok(None),
} }
} }
@ -146,7 +148,6 @@ fn handle_mysql_admutils_command() -> anyhow::Result<Option<()>> {
fn handle_server_command(args: &Args) -> anyhow::Result<Option<()>> { fn handle_server_command(args: &Args) -> anyhow::Result<Option<()>> {
match args.command { match args.command {
Command::Server(ref command) => { Command::Server(ref command) => {
drop_privs()?;
tokio_start_server( tokio_start_server(
args.server_socket_path.clone(), args.server_socket_path.clone(),
args.config.clone(), args.config.clone(),
@ -205,7 +206,20 @@ fn tokio_run_command(command: Command, server_connection: StdUnixStream) -> anyh
.unwrap() .unwrap()
.block_on(async { .block_on(async {
let tokio_socket = TokioUnixStream::from_std(server_connection)?; let tokio_socket = TokioUnixStream::from_std(server_connection)?;
let message_stream = create_client_to_server_message_stream(tokio_socket); let mut message_stream = create_client_to_server_message_stream(tokio_socket);
while let Some(Ok(message)) = message_stream.next().await {
match message {
Response::Error(err) => {
anyhow::bail!("{}", err);
}
Response::Ready => break,
message => {
eprintln!("Unexpected message from server: {:?}", message);
}
}
}
match command { match command {
Command::User(user_args) => { Command::User(user_args) => {
cli::user_command::handle_command(user_args, message_stream).await cli::user_command::handle_command(user_args, message_stream).await

View File

@ -1,11 +1,51 @@
use crate::core::common::UnixUser; use crate::core::common::UnixUser;
use sqlx::prelude::*;
/// This function creates a regex that matches items (users, databases) /// This function creates a regex that matches items (users, databases)
/// that belong to the user or any of the user's groups. /// that belong to the user or any of the user's groups.
pub fn create_user_group_matching_regex(user: &UnixUser) -> String { pub fn create_user_group_matching_regex(user: &UnixUser) -> String {
if user.groups.is_empty() { if user.groups.is_empty() {
format!("{}(_.+)?", user.username) format!("{}_.+", user.username)
} else { } else {
format!("({}|{})(_.+)?", user.username, user.groups.join("|")) format!("({}|{})_.+", user.username, user.groups.join("|"))
}
}
/// Some mysql versions with some collations mark some columns as binary fields,
/// which in the current version of sqlx is not parsable as string.
/// See: https://github.com/launchbadge/sqlx/issues/3387
#[inline]
pub fn try_get_with_binary_fallback(
row: &sqlx::mysql::MySqlRow,
column: &str,
) -> Result<String, sqlx::Error> {
row.try_get(column).or_else(|_| {
row.try_get::<Vec<u8>, _>(column)
.map(|v| String::from_utf8_lossy(&v).to_string())
})
}
#[cfg(test)]
mod tests {
use super::*;
use regex::Regex;
#[test]
fn test_create_user_group_matching_regex() {
let user = UnixUser {
username: "user".to_owned(),
groups: vec!["group1".to_owned(), "group2".to_owned()],
};
let regex = create_user_group_matching_regex(&user);
let re = Regex::new(&regex).unwrap();
assert!(re.is_match("user_something"));
assert!(re.is_match("group1_something"));
assert!(re.is_match("group2_something"));
assert!(!re.is_match("other_something"));
assert!(!re.is_match("user"));
assert!(!re.is_match("usersomething"));
} }
} }

View File

@ -21,21 +21,37 @@ pub struct ServerConfig {
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename = "mysql")] #[serde(rename = "mysql")]
pub struct MysqlConfig { pub struct MysqlConfig {
pub host: String, pub socket_path: Option<PathBuf>,
pub host: Option<String>,
pub port: Option<u16>, pub port: Option<u16>,
pub username: String, pub username: Option<String>,
pub password: String, pub password: Option<String>,
pub password_file: Option<PathBuf>,
pub timeout: Option<u64>, pub timeout: Option<u64>,
} }
#[derive(Parser, Debug, Clone)] #[derive(Parser, Debug, Clone)]
pub struct ServerConfigArgs { pub struct ServerConfigArgs {
/// Path to the socket of the MySQL server.
#[arg(long, value_name = "PATH", global = true)]
socket_path: Option<PathBuf>,
/// Hostname of the MySQL server. /// Hostname of the MySQL server.
#[arg(long, value_name = "HOST", global = true)] #[arg(
long,
value_name = "HOST",
global = true,
conflicts_with = "socket_path"
)]
mysql_host: Option<String>, mysql_host: Option<String>,
/// Port of the MySQL server. /// Port of the MySQL server.
#[arg(long, value_name = "PORT", global = true)] #[arg(
long,
value_name = "PORT",
global = true,
conflicts_with = "socket_path"
)]
mysql_port: Option<u16>, mysql_port: Option<u16>,
/// Username to use for the MySQL connection. /// Username to use for the MySQL connection.
@ -44,7 +60,7 @@ pub struct ServerConfigArgs {
/// Path to a file containing the MySQL password. /// Path to a file containing the MySQL password.
#[arg(long, value_name = "PATH", global = true)] #[arg(long, value_name = "PATH", global = true)]
mysql_password_file: Option<String>, mysql_password_file: Option<PathBuf>,
/// Seconds to wait for the MySQL connection to be established. /// Seconds to wait for the MySQL connection to be established.
#[arg(long, value_name = "SECONDS", global = true)] #[arg(long, value_name = "SECONDS", global = true)]
@ -57,30 +73,40 @@ pub fn read_config_from_path_with_arg_overrides(
config_path: Option<PathBuf>, config_path: Option<PathBuf>,
args: ServerConfigArgs, args: ServerConfigArgs,
) -> anyhow::Result<ServerConfig> { ) -> anyhow::Result<ServerConfig> {
let config = read_config_form_path(config_path)?; let config = read_config_from_path(config_path)?;
let mysql = &config.mysql; let mysql = config.mysql;
let password = if let Some(path) = args.mysql_password_file { let password = if let Some(path) = &args.mysql_password_file {
fs::read_to_string(path) Some(
.context("Failed to read MySQL password file") fs::read_to_string(path)
.map(|s| s.trim().to_owned())? .context("Failed to read MySQL password file")
.map(|s| s.trim().to_owned())?,
)
} else if let Some(path) = &mysql.password_file {
Some(
fs::read_to_string(path)
.context("Failed to read MySQL password file")
.map(|s| s.trim().to_owned())?,
)
} else { } else {
mysql.password.to_owned() mysql.password.to_owned()
}; };
Ok(ServerConfig { Ok(ServerConfig {
mysql: MysqlConfig { mysql: MysqlConfig {
host: args.mysql_host.unwrap_or(mysql.host.to_owned()), socket_path: args.socket_path.or(mysql.socket_path),
host: args.mysql_host.or(mysql.host),
port: args.mysql_port.or(mysql.port), port: args.mysql_port.or(mysql.port),
username: args.mysql_user.unwrap_or(mysql.username.to_owned()), username: args.mysql_user.or(mysql.username.to_owned()),
password, password,
password_file: args.mysql_password_file.or(mysql.password_file),
timeout: args.mysql_connect_timeout.or(mysql.timeout), timeout: args.mysql_connect_timeout.or(mysql.timeout),
}, },
}) })
} }
pub fn read_config_form_path(config_path: Option<PathBuf>) -> anyhow::Result<ServerConfig> { pub fn read_config_from_path(config_path: Option<PathBuf>) -> anyhow::Result<ServerConfig> {
let config_path = config_path.unwrap_or_else(|| PathBuf::from(DEFAULT_CONFIG_PATH)); let config_path = config_path.unwrap_or_else(|| PathBuf::from(DEFAULT_CONFIG_PATH));
log::debug!("Reading config from {:?}", &config_path); log::debug!("Reading config from {:?}", &config_path);
@ -97,30 +123,52 @@ pub fn read_config_form_path(config_path: Option<PathBuf>) -> anyhow::Result<Ser
)) ))
} }
/// Use the provided configuration to establish a connection to a MySQL server. fn log_config(config: &MysqlConfig) {
pub async fn create_mysql_connection_from_config(
config: &MysqlConfig,
) -> anyhow::Result<MySqlConnection> {
let mut display_config = config.clone(); let mut display_config = config.clone();
"<REDACTED>".clone_into(&mut display_config.password); display_config.password = display_config
.password
.as_ref()
.map(|_| "<REDACTED>".to_owned());
log::debug!( log::debug!(
"Connecting to MySQL server with parameters: {:#?}", "Connecting to MySQL server with parameters: {:#?}",
display_config display_config
); );
}
/// Use the provided configuration to establish a connection to a MySQL server.
pub async fn create_mysql_connection_from_config(
config: &MysqlConfig,
) -> anyhow::Result<MySqlConnection> {
log_config(config);
let mut mysql_options = MySqlConnectOptions::new().database("mysql");
if let Some(username) = &config.username {
mysql_options = mysql_options.username(username);
}
if let Some(password) = &config.password {
mysql_options = mysql_options.password(password);
}
if let Some(socket_path) = &config.socket_path {
mysql_options = mysql_options.socket(socket_path);
} else if let Some(host) = &config.host {
mysql_options = mysql_options.host(host);
mysql_options = mysql_options.port(config.port.unwrap_or(DEFAULT_PORT));
} else {
anyhow::bail!("No MySQL host or socket path provided");
}
match tokio::time::timeout( match tokio::time::timeout(
Duration::from_secs(config.timeout.unwrap_or(DEFAULT_TIMEOUT)), Duration::from_secs(config.timeout.unwrap_or(DEFAULT_TIMEOUT)),
MySqlConnectOptions::new() mysql_options.connect(),
.host(&config.host)
.username(&config.username)
.password(&config.password)
.port(config.port.unwrap_or(DEFAULT_PORT))
.database("mysql")
.connect(),
) )
.await .await
{ {
Ok(connection) => connection.context("Failed to connect to MySQL"), Ok(connection) => connection.context("Failed to connect to the database"),
Err(_) => Err(anyhow!("Timed out after 2 seconds")).context("Failed to connect to MySQL"), Err(_) => {
Err(anyhow!("Timed out after 2 seconds")).context("Failed to connect to the database")
}
} }
} }

View File

@ -1,7 +1,7 @@
use std::{collections::BTreeSet, fs, path::PathBuf}; use std::{collections::BTreeSet, fs, path::PathBuf};
use futures_util::{SinkExt, StreamExt}; use futures_util::{SinkExt, StreamExt};
use tokio::io::AsyncWriteExt; use indoc::concatdoc;
use tokio::net::{UnixListener, UnixStream}; use tokio::net::{UnixListener, UnixStream};
use sqlx::prelude::*; use sqlx::prelude::*;
@ -57,15 +57,43 @@ pub async fn listen_for_incoming_connections(
sd_notify::notify(true, &[sd_notify::NotifyState::Ready]).ok(); sd_notify::notify(true, &[sd_notify::NotifyState::Ready]).ok();
while let Ok((mut conn, _addr)) = listener.accept().await { while let Ok((conn, _addr)) = listener.accept().await {
let uid = conn.peer_cred()?.uid(); let uid = match conn.peer_cred() {
Ok(cred) => cred.uid(),
Err(e) => {
log::error!("Failed to get peer credentials from socket: {}", e);
let mut message_stream = create_server_to_client_message_stream(conn);
message_stream
.send(Response::Error(
(concatdoc! {
"Server failed to get peer credentials from socket\n",
"Please check the server logs or contact the system administrators"
})
.to_string(),
))
.await
.ok();
continue;
}
};
log::trace!("Accepted connection from uid {}", uid); log::trace!("Accepted connection from uid {}", uid);
let unix_user = match UnixUser::from_uid(uid) { let unix_user = match UnixUser::from_uid(uid) {
Ok(user) => user, Ok(user) => user,
Err(e) => { Err(e) => {
eprintln!("Failed to get UnixUser from uid: {}", e); log::error!("Failed to get username from uid: {}", e);
conn.shutdown().await?; let mut message_stream = create_server_to_client_message_stream(conn);
message_stream
.send(Response::Error(
(concatdoc! {
"Server failed to get user data from the system\n",
"Please check the server logs or contact the system administrators"
})
.to_string(),
))
.await
.ok();
continue; continue;
} }
}; };
@ -73,9 +101,9 @@ pub async fn listen_for_incoming_connections(
log::info!("Accepted connection from {}", unix_user.username); log::info!("Accepted connection from {}", unix_user.username);
match handle_requests_for_single_session(conn, &unix_user, &config).await { match handle_requests_for_single_session(conn, &unix_user, &config).await {
Ok(_) => {} Ok(()) => {}
Err(e) => { Err(e) => {
eprintln!("Failed to run server: {}", e); log::error!("Failed to run server: {}", e);
} }
} }
} }
@ -88,8 +116,24 @@ pub async fn handle_requests_for_single_session(
unix_user: &UnixUser, unix_user: &UnixUser,
config: &ServerConfig, config: &ServerConfig,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let message_stream = create_server_to_client_message_stream(socket); let mut message_stream = create_server_to_client_message_stream(socket);
let mut db_connection = create_mysql_connection_from_config(&config.mysql).await?; let mut db_connection = match create_mysql_connection_from_config(&config.mysql).await {
Ok(connection) => connection,
Err(err) => {
message_stream
.send(Response::Error(
(concatdoc! {
"Server failed to connect to database\n",
"Please check the server logs or contact the system administrators"
})
.to_string(),
))
.await?;
message_stream.flush().await?;
return Err(err);
}
};
log::debug!("Successfully connected to database"); log::debug!("Successfully connected to database");
let result = handle_requests_for_single_session_with_db_connection( let result = handle_requests_for_single_session_with_db_connection(
@ -100,9 +144,9 @@ pub async fn handle_requests_for_single_session(
.await; .await;
if let Err(e) = db_connection.close().await { if let Err(e) = db_connection.close().await {
eprintln!("Failed to close database connection: {}", e); log::error!("Failed to close database connection: {}", e);
eprintln!("{}", e); log::error!("{}", e);
eprintln!("Ignoring..."); log::error!("Ignoring...");
} }
result result
@ -116,6 +160,7 @@ pub async fn handle_requests_for_single_session_with_db_connection(
unix_user: &UnixUser, unix_user: &UnixUser,
db_connection: &mut MySqlConnection, db_connection: &mut MySqlConnection,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
stream.send(Response::Ready).await?;
loop { loop {
// TODO: better error handling // TODO: better error handling
let request = match stream.next().await { let request = match stream.next().await {
@ -133,17 +178,14 @@ pub async fn handle_requests_for_single_session_with_db_connection(
Request::CreateDatabases(databases_names) => { Request::CreateDatabases(databases_names) => {
let result = create_databases(databases_names, unix_user, db_connection).await; let result = create_databases(databases_names, unix_user, db_connection).await;
stream.send(Response::CreateDatabases(result)).await?; stream.send(Response::CreateDatabases(result)).await?;
stream.flush().await?;
} }
Request::DropDatabases(databases_names) => { Request::DropDatabases(databases_names) => {
let result = drop_databases(databases_names, unix_user, db_connection).await; let result = drop_databases(databases_names, unix_user, db_connection).await;
stream.send(Response::DropDatabases(result)).await?; stream.send(Response::DropDatabases(result)).await?;
stream.flush().await?;
} }
Request::ListDatabases => { Request::ListDatabases => {
let result = list_databases_for_user(unix_user, db_connection).await; let result = list_databases_for_user(unix_user, db_connection).await;
stream.send(Response::ListAllDatabases(result)).await?; stream.send(Response::ListAllDatabases(result)).await?;
stream.flush().await?;
} }
Request::ListPrivileges(database_names) => { Request::ListPrivileges(database_names) => {
let response = match database_names { let response = match database_names {
@ -161,7 +203,6 @@ pub async fn handle_requests_for_single_session_with_db_connection(
}; };
stream.send(response).await?; stream.send(response).await?;
stream.flush().await?;
} }
Request::ModifyPrivileges(database_privilege_diffs) => { Request::ModifyPrivileges(database_privilege_diffs) => {
let result = apply_privilege_diffs( let result = apply_privilege_diffs(
@ -171,24 +212,20 @@ pub async fn handle_requests_for_single_session_with_db_connection(
) )
.await; .await;
stream.send(Response::ModifyPrivileges(result)).await?; stream.send(Response::ModifyPrivileges(result)).await?;
stream.flush().await?;
} }
Request::CreateUsers(db_users) => { Request::CreateUsers(db_users) => {
let result = create_database_users(db_users, unix_user, db_connection).await; let result = create_database_users(db_users, unix_user, db_connection).await;
stream.send(Response::CreateUsers(result)).await?; stream.send(Response::CreateUsers(result)).await?;
stream.flush().await?;
} }
Request::DropUsers(db_users) => { Request::DropUsers(db_users) => {
let result = drop_database_users(db_users, unix_user, db_connection).await; let result = drop_database_users(db_users, unix_user, db_connection).await;
stream.send(Response::DropUsers(result)).await?; stream.send(Response::DropUsers(result)).await?;
stream.flush().await?;
} }
Request::PasswdUser(db_user, password) => { Request::PasswdUser(db_user, password) => {
let result = let result =
set_password_for_database_user(&db_user, &password, unix_user, db_connection) set_password_for_database_user(&db_user, &password, unix_user, db_connection)
.await; .await;
stream.send(Response::PasswdUser(result)).await?; stream.send(Response::PasswdUser(result)).await?;
stream.flush().await?;
} }
Request::ListUsers(db_users) => { Request::ListUsers(db_users) => {
let response = match db_users { let response = match db_users {
@ -203,22 +240,21 @@ pub async fn handle_requests_for_single_session_with_db_connection(
} }
}; };
stream.send(response).await?; stream.send(response).await?;
stream.flush().await?;
} }
Request::LockUsers(db_users) => { Request::LockUsers(db_users) => {
let result = lock_database_users(db_users, unix_user, db_connection).await; let result = lock_database_users(db_users, unix_user, db_connection).await;
stream.send(Response::LockUsers(result)).await?; stream.send(Response::LockUsers(result)).await?;
stream.flush().await?;
} }
Request::UnlockUsers(db_users) => { Request::UnlockUsers(db_users) => {
let result = unlock_database_users(db_users, unix_user, db_connection).await; let result = unlock_database_users(db_users, unix_user, db_connection).await;
stream.send(Response::UnlockUsers(result)).await?; stream.send(Response::UnlockUsers(result)).await?;
stream.flush().await?;
} }
Request::Exit => { Request::Exit => {
break; break;
} }
} }
stream.flush().await?;
} }
Ok(()) Ok(())

View File

@ -32,7 +32,7 @@ use crate::{
}, },
}, },
server::{ server::{
common::create_user_group_matching_regex, common::{create_user_group_matching_regex, try_get_with_binary_fallback},
input_sanitization::{quote_identifier, validate_name, validate_ownership_by_unix_user}, input_sanitization::{quote_identifier, validate_name, validate_ownership_by_unix_user},
sql::database_operations::unsafe_database_exists, sql::database_operations::unsafe_database_exists,
}, },
@ -42,8 +42,8 @@ use crate::{
/// from the `db` table in the database. If you need to add or remove privilege /// from the `db` table in the database. If you need to add or remove privilege
/// fields, this is a good place to start. /// fields, this is a good place to start.
pub const DATABASE_PRIVILEGE_FIELDS: [&str; 13] = [ pub const DATABASE_PRIVILEGE_FIELDS: [&str; 13] = [
"db", "Db",
"user", "User",
"select_priv", "select_priv",
"insert_priv", "insert_priv",
"update_priv", "update_priv",
@ -97,6 +97,8 @@ impl DatabasePrivilegeRow {
} }
} }
// TODO: get by name instead of row tuple position
#[inline] #[inline]
fn get_mysql_row_priv_field(row: &MySqlRow, position: usize) -> Result<bool, sqlx::Error> { fn get_mysql_row_priv_field(row: &MySqlRow, position: usize) -> Result<bool, sqlx::Error> {
let field = DATABASE_PRIVILEGE_FIELDS[position]; let field = DATABASE_PRIVILEGE_FIELDS[position];
@ -113,8 +115,8 @@ fn get_mysql_row_priv_field(row: &MySqlRow, position: usize) -> Result<bool, sql
impl FromRow<'_, MySqlRow> for DatabasePrivilegeRow { impl FromRow<'_, MySqlRow> for DatabasePrivilegeRow {
fn from_row(row: &MySqlRow) -> Result<Self, sqlx::Error> { fn from_row(row: &MySqlRow) -> Result<Self, sqlx::Error> {
Ok(Self { Ok(Self {
db: row.try_get("db")?, db: try_get_with_binary_fallback(row, "Db")?,
user: row.try_get("user")?, user: try_get_with_binary_fallback(row, "User")?,
select_priv: get_mysql_row_priv_field(row, 2)?, select_priv: get_mysql_row_priv_field(row, 2)?,
insert_priv: get_mysql_row_priv_field(row, 3)?, insert_priv: get_mysql_row_priv_field(row, 3)?,
update_priv: get_mysql_row_priv_field(row, 4)?, update_priv: get_mysql_row_priv_field(row, 4)?,
@ -137,7 +139,7 @@ async fn unsafe_get_database_privileges(
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
) -> Result<Vec<DatabasePrivilegeRow>, sqlx::Error> { ) -> Result<Vec<DatabasePrivilegeRow>, sqlx::Error> {
let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!( let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
"SELECT {} FROM `db` WHERE `db` = ?", "SELECT {} FROM `db` WHERE `Db` = ?",
DATABASE_PRIVILEGE_FIELDS DATABASE_PRIVILEGE_FIELDS
.iter() .iter()
.map(|field| quote_identifier(field)) .map(|field| quote_identifier(field))
@ -166,7 +168,7 @@ pub async fn unsafe_get_database_privileges_for_db_user_pair(
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
) -> Result<Option<DatabasePrivilegeRow>, sqlx::Error> { ) -> Result<Option<DatabasePrivilegeRow>, sqlx::Error> {
let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!( let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
"SELECT {} FROM `db` WHERE `db` = ? AND `user` = ?", "SELECT {} FROM `db` WHERE `Db` = ? AND `User` = ?",
DATABASE_PRIVILEGE_FIELDS DATABASE_PRIVILEGE_FIELDS
.iter() .iter()
.map(|field| quote_identifier(field)) .map(|field| quote_identifier(field))
@ -316,7 +318,7 @@ async fn unsafe_apply_privilege_diff(
.join(","); .join(",");
sqlx::query( sqlx::query(
format!("UPDATE `db` SET {} WHERE `db` = ? AND `user` = ?", changes).as_str(), format!("UPDATE `db` SET {} WHERE `Db` = ? AND `User` = ?", changes).as_str(),
) )
.bind(p.db.to_string()) .bind(p.db.to_string())
.bind(p.user.to_string()) .bind(p.user.to_string())
@ -325,7 +327,7 @@ async fn unsafe_apply_privilege_diff(
.map(|_| ()) .map(|_| ())
} }
DatabasePrivilegesDiff::Deleted(p) => { DatabasePrivilegesDiff::Deleted(p) => {
sqlx::query("DELETE FROM `db` WHERE `db` = ? AND `user` = ?") sqlx::query("DELETE FROM `db` WHERE `Db` = ? AND `User` = ?")
.bind(p.db.to_string()) .bind(p.db.to_string())
.bind(p.user.to_string()) .bind(p.user.to_string())
.execute(connection) .execute(connection)

View File

@ -7,6 +7,7 @@ use serde::{Deserialize, Serialize};
use sqlx::prelude::*; use sqlx::prelude::*;
use sqlx::MySqlConnection; use sqlx::MySqlConnection;
use crate::server::common::try_get_with_binary_fallback;
use crate::{ use crate::{
core::{ core::{
common::UnixUser, common::UnixUser,
@ -350,20 +351,6 @@ pub struct DatabaseUser {
pub databases: Vec<String>, pub databases: Vec<String>,
} }
/// Some mysql versions with some collations mark some columns as binary fields,
/// which in the current version of sqlx is not parsable as string.
/// See: https://github.com/launchbadge/sqlx/issues/3387
#[inline]
fn try_get_with_binary_fallback(
row: &sqlx::mysql::MySqlRow,
column: &str,
) -> Result<String, sqlx::Error> {
row.try_get(column).or_else(|_| {
row.try_get::<Vec<u8>, _>(column)
.map(|v| String::from_utf8_lossy(&v).to_string())
})
}
impl FromRow<'_, sqlx::mysql::MySqlRow> for DatabaseUser { impl FromRow<'_, sqlx::mysql::MySqlRow> for DatabaseUser {
fn from_row(row: &sqlx::mysql::MySqlRow) -> Result<Self, sqlx::Error> { fn from_row(row: &sqlx::mysql::MySqlRow) -> Result<Self, sqlx::Error> {
Ok(Self { Ok(Self {