Compare commits

..

4 Commits

Author SHA1 Message Date
Oystein Kristoffer Tveit a0be0d3b92
Wrap database users and database names in newtypes
Also, use less cloning where possible
2024-08-20 17:46:43 +02:00
Oystein Kristoffer Tveit 8c2754c9d7
cargo-deny: init 2024-08-20 17:46:43 +02:00
Oystein Kristoffer Tveit 338694a64e
Add more `--json` flags 2024-08-20 17:46:43 +02:00
Oystein Kristoffer Tveit cdb1fb4181
Integrate better with systemd + better logs and protocol usage
This commits adds the following:

- Better systemd integration and usage:
  - More hardening
  - A watchdog thread
  - Journald native logging

as well as

- Better logs
- Some protocol usage fixes
2024-08-20 17:46:40 +02:00
22 changed files with 945 additions and 379 deletions

31
Cargo.lock generated
View File

@ -253,6 +253,16 @@ dependencies = [
"clap_derive", "clap_derive",
] ]
[[package]]
name = "clap-verbosity-flag"
version = "2.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "63d19864d6b68464c59f7162c9914a0b569ddc2926b4a2d71afe62a9738eff53"
dependencies = [
"clap",
"log",
]
[[package]] [[package]]
name = "clap_builder" name = "clap_builder"
version = "4.5.15" version = "4.5.15"
@ -993,6 +1003,9 @@ name = "log"
version = "0.4.22" version = "0.4.22"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24"
dependencies = [
"value-bag",
]
[[package]] [[package]]
name = "lru" name = "lru"
@ -1064,6 +1077,7 @@ dependencies = [
"async-bincode", "async-bincode",
"bincode", "bincode",
"clap", "clap",
"clap-verbosity-flag",
"clap_complete", "clap_complete",
"derive_more", "derive_more",
"dialoguer", "dialoguer",
@ -1082,6 +1096,7 @@ dependencies = [
"serde", "serde",
"serde_json", "serde_json",
"sqlx", "sqlx",
"systemd-journal-logger",
"tokio", "tokio",
"tokio-serde", "tokio-serde",
"tokio-stream", "tokio-stream",
@ -1996,6 +2011,16 @@ dependencies = [
"unicode-ident", "unicode-ident",
] ]
[[package]]
name = "systemd-journal-logger"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5f3848dd723f2a54ac1d96da793b32923b52de8dfcced8722516dac312a5b2a"
dependencies = [
"log",
"rustix",
]
[[package]] [[package]]
name = "tempfile" name = "tempfile"
version = "3.12.0" version = "3.12.0"
@ -2287,6 +2312,12 @@ dependencies = [
"getrandom", "getrandom",
] ]
[[package]]
name = "value-bag"
version = "1.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a84c137d37ab0142f0f2ddfe332651fdbf252e7b7dbb4e67b6c1f1b2e925101"
[[package]] [[package]]
name = "vcpkg" name = "vcpkg"
version = "0.2.15" version = "0.2.15"

View File

@ -8,6 +8,7 @@ anyhow = "1.0.86"
async-bincode = "0.7.2" async-bincode = "0.7.2"
bincode = "1.3.3" bincode = "1.3.3"
clap = { version = "4.5.16", features = ["derive"] } clap = { version = "4.5.16", features = ["derive"] }
clap-verbosity-flag = "2.2.1"
clap_complete = "4.5.18" clap_complete = "4.5.18"
derive_more = { version = "1.0.0", features = ["display", "error"] } derive_more = { version = "1.0.0", features = ["display", "error"] }
dialoguer = "0.11.0" dialoguer = "0.11.0"
@ -25,6 +26,7 @@ sd-notify = "0.4.2"
serde = "1.0.208" serde = "1.0.208"
serde_json = { version = "1.0.125", features = ["preserve_order"] } serde_json = { version = "1.0.125", features = ["preserve_order"] }
sqlx = { version = "0.8.0", features = ["runtime-tokio", "mysql", "tls-rustls"] } sqlx = { version = "0.8.0", features = ["runtime-tokio", "mysql", "tls-rustls"] }
systemd-journal-logger = "2.1.1"
tokio = { version = "1.39.3", features = ["rt", "macros"] } tokio = { version = "1.39.3", features = ["rt", "macros"] }
tokio-serde = { version = "0.9.0", features = ["bincode"] } tokio-serde = { version = "0.9.0", features = ["bincode"] }
tokio-stream = "0.1.15" tokio-stream = "0.1.15"

78
deny.toml Normal file
View File

@ -0,0 +1,78 @@
[graph]
targets = [
"x86_64-unknown-linux-gnu",
"aarch64-unknown-linux-gnu",
"armv7-unknown-linux-gnueabihf",
"x86_64-unknown-freebsd",
"aarch64-unknown-freebsd",
"armv7-unknown-freebsd",
"x86_64-apple-darwin",
"aarch64-apple-darwin",
]
all-features = false
no-default-features = false
#features = []
[output]
feature-depth = 1
[advisories]
#db-path = "$CARGO_HOME/advisory-dbs"
#db-urls = ["https://github.com/rustsec/advisory-db"]
ignore = []
[licenses]
allow = [
"GPL-2.0",
"MIT",
"Apache-2.0",
"ISC",
"MPL-2.0",
"Unicode-DFS-2016",
"BSD-3-Clause",
"OpenSSL",
]
confidence-threshold = 0.8
exceptions = []
[[licenses.clarify]]
crate = "ring"
expression = "MIT AND ISC AND OpenSSL"
license-files = [
{ path = "LICENSE", hash = 0xbd0eed23 }
]
[licenses.private]
ignore = false
registries = []
[bans]
multiple-versions = "allow"
wildcards = "allow"
highlight = "all"
workspace-default-features = "allow"
external-default-features = "allow"
allow = []
deny = []
#[[bans.features]]
#crate = "reqwest"
#deny = ["json"]
#allow = []
#exact = true
skip = []
skip-tree = []
[sources]
unknown-registry = "warn"
unknown-git = "warn"
allow-registry = ["https://github.com/rust-lang/crates.io-index"]
allow-git = []
[sources.allow-org]

View File

@ -47,6 +47,7 @@
toolchain toolchain
mysql-client mysql-client
cargo-nextest cargo-nextest
cargo-deny
]; ];
RUST_SRC_PATH = "${toolchain}/lib/rustlib/src/rust/library"; RUST_SRC_PATH = "${toolchain}/lib/rustlib/src/rust/library";

View File

@ -15,6 +15,20 @@ in
description = "Create a local database user for mysqladm-rs"; description = "Create a local database user for mysqladm-rs";
}; };
logLevel = lib.mkOption {
type = lib.types.enum [ "quiet" "error" "warn" "info" "debug" "trace" ];
default = "debug";
description = "Log level for mysqladm-rs";
apply = level: {
"quiet" = "-q";
"error" = "";
"warn" = "-v";
"info" = "-vv";
"debug" = "-vvv";
"trace" = "-vvvv";
}.${level};
};
settings = lib.mkOption { settings = lib.mkOption {
default = { }; default = { };
type = lib.types.submodule { type = lib.types.submodule {
@ -76,7 +90,7 @@ in
name = cfg.settings.mysql.username; name = cfg.settings.mysql.username;
ensurePermissions = { ensurePermissions = {
"mysql.*" = "SELECT, INSERT, UPDATE, DELETE"; "mysql.*" = "SELECT, INSERT, UPDATE, DELETE";
"*.*" = "CREATE USER, GRANT OPTION"; "*.*" = "GRANT OPTION, CREATE, DROP";
}; };
} }
]; ];
@ -86,7 +100,9 @@ in
environment.RUST_LOG = "debug"; environment.RUST_LOG = "debug";
serviceConfig = { serviceConfig = {
Type = "notify"; Type = "notify";
ExecStart = "${lib.getExe cfg.package} server socket-activate --config ${configFile}"; ExecStart = "${lib.getExe cfg.package} ${cfg.logLevel} server --systemd socket-activate --config ${configFile}";
WatchdogSec = 15;
User = "mysqladm"; User = "mysqladm";
Group = "mysqladm"; Group = "mysqladm";
@ -95,7 +111,18 @@ in
# This is required to read unix user/group details. # This is required to read unix user/group details.
PrivateUsers = false; PrivateUsers = false;
CapabilityBoundingSet = ""; # Needed to communicate with MySQL.
PrivateNetwork = false;
IPAddressDeny =
lib.optionals (lib.elem cfg.settings.mysql.host [ null "localhost" "127.0.0.1" ]) [ "any" ];
RestrictAddressFamilies = [ "AF_UNIX" ]
++ (lib.optionals (cfg.settings.mysql.host != null) [ "AF_INET" "AF_INET6" ]);
AmbientCapabilities = [ "" ];
CapabilityBoundingSet = [ "" ];
DeviceAllow = [ "" ];
LockPersonality = true; LockPersonality = true;
MemoryDenyWriteExecute = true; MemoryDenyWriteExecute = true;
NoNewPrivileges = true; NoNewPrivileges = true;
@ -113,12 +140,12 @@ in
ProtectProc = "invisible"; ProtectProc = "invisible";
ProtectSystem = "strict"; ProtectSystem = "strict";
RemoveIPC = true; RemoveIPC = true;
UMask = "0000"; UMask = "0777";
RestrictAddressFamilies = [ "AF_UNIX" "AF_INET" "AF_INET6" ];
RestrictNamespaces = true; RestrictNamespaces = true;
RestrictRealtime = true; RestrictRealtime = true;
RestrictSUIDSGID = true; RestrictSUIDSGID = true;
SystemCallArchitectures = "native"; SystemCallArchitectures = "native";
SocketBindDeny = [ "any" ];
SystemCallFilter = [ SystemCallFilter = [
"@system-service" "@system-service"
"~@privileged" "~@privileged"

View File

@ -15,9 +15,10 @@ use crate::{
parse_privilege_table_cli_arg, parse_privilege_table_cli_arg,
}, },
protocol::{ protocol::{
print_create_databases_output_status, print_drop_databases_output_status, print_create_databases_output_status, print_create_databases_output_status_json,
print_modify_database_privileges_output_status, ClientToServerMessageStream, Request, print_drop_databases_output_status, print_drop_databases_output_status_json,
Response, print_modify_database_privileges_output_status, ClientToServerMessageStream,
MySQLDatabase, Request, Response,
}, },
}, },
server::sql::database_privilege_operations::{DatabasePrivilegeRow, DATABASE_PRIVILEGE_FIELDS}, server::sql::database_privilege_operations::{DatabasePrivilegeRow, DATABASE_PRIVILEGE_FIELDS},
@ -102,49 +103,57 @@ pub enum DatabaseCommand {
#[derive(Parser, Debug, Clone)] #[derive(Parser, Debug, Clone)]
pub struct DatabaseCreateArgs { pub struct DatabaseCreateArgs {
/// The name of the database(s) to create. /// The name of the database(s) to create
#[arg(num_args = 1..)] #[arg(num_args = 1..)]
name: Vec<String>, name: Vec<MySQLDatabase>,
/// Print the information as JSON
#[arg(short, long)]
json: bool,
} }
#[derive(Parser, Debug, Clone)] #[derive(Parser, Debug, Clone)]
pub struct DatabaseDropArgs { pub struct DatabaseDropArgs {
/// The name of the database(s) to drop. /// The name of the database(s) to drop
#[arg(num_args = 1..)] #[arg(num_args = 1..)]
name: Vec<String>, name: Vec<MySQLDatabase>,
/// Print the information as JSON
#[arg(short, long)]
json: bool,
} }
#[derive(Parser, Debug, Clone)] #[derive(Parser, Debug, Clone)]
pub struct DatabaseShowArgs { pub struct DatabaseShowArgs {
/// The name of the database(s) to show. /// The name of the database(s) to show
#[arg(num_args = 0..)] #[arg(num_args = 0..)]
name: Vec<String>, name: Vec<MySQLDatabase>,
/// Whether to output the information in JSON format. /// Print the information as JSON
#[arg(short, long)] #[arg(short, long)]
json: bool, json: bool,
} }
#[derive(Parser, Debug, Clone)] #[derive(Parser, Debug, Clone)]
pub struct DatabaseShowPrivsArgs { pub struct DatabaseShowPrivsArgs {
/// The name of the database(s) to show. /// The name of the database(s) to show
#[arg(num_args = 0..)] #[arg(num_args = 0..)]
name: Vec<String>, name: Vec<MySQLDatabase>,
/// Whether to output the information in JSON format. /// Print the information as JSON
#[arg(short, long)] #[arg(short, long)]
json: bool, json: bool,
} }
#[derive(Parser, Debug, Clone)] #[derive(Parser, Debug, Clone)]
pub struct DatabaseEditPrivsArgs { pub struct DatabaseEditPrivsArgs {
/// The name of the database to edit privileges for. /// The name of the database to edit privileges for
pub name: Option<String>, pub name: Option<MySQLDatabase>,
#[arg(short, long, value_name = "[DATABASE:]USER:PRIVILEGES", num_args = 0..)] #[arg(short, long, value_name = "[DATABASE:]USER:PRIVILEGES", num_args = 0..)]
pub privs: Vec<String>, pub privs: Vec<String>,
/// Whether to output the information in JSON format. /// Print the information as JSON
#[arg(short, long)] #[arg(short, long)]
pub json: bool, pub json: bool,
@ -152,7 +161,7 @@ pub struct DatabaseEditPrivsArgs {
#[arg(short, long)] #[arg(short, long)]
pub editor: Option<String>, pub editor: Option<String>,
/// Disable interactive confirmation before saving changes. /// Disable interactive confirmation before saving changes
#[arg(short, long)] #[arg(short, long)]
pub yes: bool, pub yes: bool,
} }
@ -182,7 +191,7 @@ async fn create_databases(
anyhow::bail!("No database names provided"); anyhow::bail!("No database names provided");
} }
let message = Request::CreateDatabases(args.name.clone()); let message = Request::CreateDatabases(args.name.to_owned());
server_connection.send(message).await?; server_connection.send(message).await?;
let result = match server_connection.next().await { let result = match server_connection.next().await {
@ -192,7 +201,11 @@ async fn create_databases(
server_connection.send(Request::Exit).await?; server_connection.send(Request::Exit).await?;
if args.json {
print_create_databases_output_status_json(&result);
} else {
print_create_databases_output_status(&result); print_create_databases_output_status(&result);
}
Ok(()) Ok(())
} }
@ -205,7 +218,7 @@ async fn drop_databases(
anyhow::bail!("No database names provided"); anyhow::bail!("No database names provided");
} }
let message = Request::DropDatabases(args.name.clone()); let message = Request::DropDatabases(args.name.to_owned());
server_connection.send(message).await?; server_connection.send(message).await?;
let result = match server_connection.next().await { let result = match server_connection.next().await {
@ -215,7 +228,11 @@ async fn drop_databases(
server_connection.send(Request::Exit).await?; server_connection.send(Request::Exit).await?;
if args.json {
print_drop_databases_output_status_json(&result);
} else {
print_drop_databases_output_status(&result); print_drop_databases_output_status(&result);
};
Ok(()) Ok(())
} }
@ -227,11 +244,13 @@ async fn show_databases(
let message = if args.name.is_empty() { let message = if args.name.is_empty() {
Request::ListDatabases(None) Request::ListDatabases(None)
} else { } else {
Request::ListDatabases(Some(args.name.clone())) Request::ListDatabases(Some(args.name.to_owned()))
}; };
server_connection.send(message).await?; server_connection.send(message).await?;
// TODO: collect errors for json output.
let database_list = match server_connection.next().await { let database_list = match server_connection.next().await {
Some(Ok(Response::ListDatabases(databases))) => databases Some(Ok(Response::ListDatabases(databases))) => databases
.into_iter() .into_iter()
@ -282,7 +301,7 @@ async fn show_database_privileges(
let message = if args.name.is_empty() { let message = if args.name.is_empty() {
Request::ListPrivileges(None) Request::ListPrivileges(None)
} else { } else {
Request::ListPrivileges(Some(args.name.clone())) Request::ListPrivileges(Some(args.name.to_owned()))
}; };
server_connection.send(message).await?; server_connection.send(message).await?;
@ -354,7 +373,7 @@ pub async fn edit_database_privileges(
args: DatabaseEditPrivsArgs, args: DatabaseEditPrivsArgs,
mut server_connection: ClientToServerMessageStream, mut server_connection: ClientToServerMessageStream,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let message = Request::ListPrivileges(args.name.clone().map(|name| vec![name])); let message = Request::ListPrivileges(args.name.to_owned().map(|name| vec![name]));
server_connection.send(message).await?; server_connection.send(message).await?;
@ -386,13 +405,14 @@ pub async fn edit_database_privileges(
let privileges_to_change = if !args.privs.is_empty() { let privileges_to_change = if !args.privs.is_empty() {
parse_privilege_tables_from_args(&args)? parse_privilege_tables_from_args(&args)?
} else { } else {
edit_privileges_with_editor(&privilege_data, args.name.as_deref())? edit_privileges_with_editor(&privilege_data, args.name.as_ref())?
}; };
let diffs = diff_privileges(&privilege_data, &privileges_to_change); let diffs = diff_privileges(&privilege_data, &privileges_to_change);
if diffs.is_empty() { if diffs.is_empty() {
println!("No changes to make."); println!("No changes to make.");
server_connection.send(Request::Exit).await?;
return Ok(()); return Ok(());
} }
@ -451,7 +471,7 @@ fn parse_privilege_tables_from_args(
fn edit_privileges_with_editor( fn edit_privileges_with_editor(
privilege_data: &[DatabasePrivilegeRow], privilege_data: &[DatabasePrivilegeRow],
database_name: Option<&str>, database_name: Option<&MySQLDatabase>,
) -> anyhow::Result<Vec<DatabasePrivilegeRow>> { ) -> anyhow::Result<Vec<DatabasePrivilegeRow>> {
let unix_user = User::from_uid(getuid()) let unix_user = User::from_uid(getuid())
.context("Failed to look up your UNIX username") .context("Failed to look up your UNIX username")

View File

@ -1,4 +1,11 @@
use crate::core::protocol::{MySQLDatabase, MySQLUser};
#[inline] #[inline]
pub fn trim_to_32_chars(name: &str) -> String { pub fn trim_db_name_to_32_chars(db_name: &MySQLDatabase) -> MySQLDatabase {
name.chars().take(32).collect() db_name.chars().take(32).collect::<String>().into()
}
#[inline]
pub fn trim_user_name_to_32_chars(user_name: &MySQLUser) -> MySQLUser {
user_name.chars().take(32).collect::<String>().into()
} }

View File

@ -9,7 +9,7 @@ use crate::{
common::erroneous_server_response, common::erroneous_server_response,
database_command, database_command,
mysql_admutils_compatibility::{ mysql_admutils_compatibility::{
common::trim_to_32_chars, common::trim_db_name_to_32_chars,
error_messages::{ error_messages::{
format_show_database_error_message, handle_create_database_error, format_show_database_error_message, handle_create_database_error,
handle_drop_database_error, handle_drop_database_error,
@ -20,7 +20,7 @@ use crate::{
bootstrap::bootstrap_server_connection_and_drop_privileges, bootstrap::bootstrap_server_connection_and_drop_privileges,
protocol::{ protocol::{
create_client_to_server_message_stream, ClientToServerMessageStream, create_client_to_server_message_stream, ClientToServerMessageStream,
GetDatabasesPrivilegeDataError, Request, Response, GetDatabasesPrivilegeDataError, MySQLDatabase, Request, Response,
}, },
}, },
server::sql::database_privilege_operations::DatabasePrivilegeRow, server::sql::database_privilege_operations::DatabasePrivilegeRow,
@ -120,27 +120,27 @@ pub enum Command {
pub struct CreateArgs { pub struct CreateArgs {
/// The name of the DATABASE(s) to create. /// The name of the DATABASE(s) to create.
#[arg(num_args = 1..)] #[arg(num_args = 1..)]
name: Vec<String>, name: Vec<MySQLDatabase>,
} }
#[derive(Parser)] #[derive(Parser)]
pub struct DatabaseDropArgs { pub struct DatabaseDropArgs {
/// The name of the DATABASE(s) to drop. /// The name of the DATABASE(s) to drop.
#[arg(num_args = 1..)] #[arg(num_args = 1..)]
name: Vec<String>, name: Vec<MySQLDatabase>,
} }
#[derive(Parser)] #[derive(Parser)]
pub struct DatabaseShowArgs { pub struct DatabaseShowArgs {
/// The name of the DATABASE(s) to show. /// The name of the DATABASE(s) to show.
#[arg(num_args = 0..)] #[arg(num_args = 0..)]
name: Vec<String>, name: Vec<MySQLDatabase>,
} }
#[derive(Parser)] #[derive(Parser)]
pub struct EditPermArgs { pub struct EditPermArgs {
/// The name of the DATABASE to edit permissions for. /// The name of the DATABASE to edit permissions for.
pub database: String, pub database: MySQLDatabase,
} }
pub fn main() -> anyhow::Result<()> { pub fn main() -> anyhow::Result<()> {
@ -202,11 +202,7 @@ async fn create_databases(
args: CreateArgs, args: CreateArgs,
mut server_connection: ClientToServerMessageStream, mut server_connection: ClientToServerMessageStream,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let database_names = args let database_names = args.name.iter().map(trim_db_name_to_32_chars).collect();
.name
.iter()
.map(|name| trim_to_32_chars(name))
.collect();
let message = Request::CreateDatabases(database_names); let message = Request::CreateDatabases(database_names);
server_connection.send(message).await?; server_connection.send(message).await?;
@ -232,11 +228,7 @@ async fn drop_databases(
args: DatabaseDropArgs, args: DatabaseDropArgs,
mut server_connection: ClientToServerMessageStream, mut server_connection: ClientToServerMessageStream,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let database_names = args let database_names = args.name.iter().map(trim_db_name_to_32_chars).collect();
.name
.iter()
.map(|name| trim_to_32_chars(name))
.collect();
let message = Request::DropDatabases(database_names); let message = Request::DropDatabases(database_names);
server_connection.send(message).await?; server_connection.send(message).await?;
@ -262,11 +254,8 @@ async fn show_databases(
args: DatabaseShowArgs, args: DatabaseShowArgs,
mut server_connection: ClientToServerMessageStream, mut server_connection: ClientToServerMessageStream,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let database_names: Vec<String> = args let database_names: Vec<MySQLDatabase> =
.name args.name.iter().map(trim_db_name_to_32_chars).collect();
.iter()
.map(|name| trim_to_32_chars(name))
.collect();
let message = if database_names.is_empty() { let message = if database_names.is_empty() {
let message = Request::ListDatabases(None); let message = Request::ListDatabases(None);
@ -291,14 +280,16 @@ async fn show_databases(
// NOTE: mysql-dbadm show has a quirk where valid database names // NOTE: mysql-dbadm show has a quirk where valid database names
// for non-existent databases will report with no users. // for non-existent databases will report with no users.
let results: Vec<Result<(String, Vec<DatabasePrivilegeRow>), String>> = match response { let results: Vec<Result<(MySQLDatabase, Vec<DatabasePrivilegeRow>), String>> = match response {
Some(Ok(Response::ListPrivileges(result))) => result Some(Ok(Response::ListPrivileges(result))) => result
.into_iter() .into_iter()
.map(|(name, rows)| match rows.map(|rows| (name.clone(), rows)) { .map(
|(name, rows)| match rows.map(|rows| (name.to_owned(), rows)) {
Ok(rows) => Ok(rows), Ok(rows) => Ok(rows),
Err(GetDatabasesPrivilegeDataError::DatabaseDoesNotExist) => Ok((name, vec![])), Err(GetDatabasesPrivilegeDataError::DatabaseDoesNotExist) => Ok((name, vec![])),
Err(err) => Err(format_show_database_error_message(err, &name)), Err(err) => Err(format_show_database_error_message(err, &name)),
}) },
)
.collect(), .collect(),
response => return erroneous_server_response(response), response => return erroneous_server_response(response),
}; };

View File

@ -9,7 +9,7 @@ use crate::{
cli::{ cli::{
common::erroneous_server_response, common::erroneous_server_response,
mysql_admutils_compatibility::{ mysql_admutils_compatibility::{
common::trim_to_32_chars, common::trim_user_name_to_32_chars,
error_messages::{ error_messages::{
handle_create_user_error, handle_drop_user_error, handle_list_users_error, handle_create_user_error, handle_drop_user_error, handle_list_users_error,
}, },
@ -19,7 +19,8 @@ use crate::{
core::{ core::{
bootstrap::bootstrap_server_connection_and_drop_privileges, bootstrap::bootstrap_server_connection_and_drop_privileges,
protocol::{ protocol::{
create_client_to_server_message_stream, ClientToServerMessageStream, Request, Response, create_client_to_server_message_stream, ClientToServerMessageStream, MySQLUser,
Request, Response,
}, },
}, },
server::sql::user_operations::DatabaseUser, server::sql::user_operations::DatabaseUser,
@ -83,28 +84,28 @@ pub enum Command {
pub struct CreateArgs { pub struct CreateArgs {
/// The name of the USER(s) to create. /// The name of the USER(s) to create.
#[arg(num_args = 1..)] #[arg(num_args = 1..)]
name: Vec<String>, name: Vec<MySQLUser>,
} }
#[derive(Parser)] #[derive(Parser)]
pub struct DeleteArgs { pub struct DeleteArgs {
/// The name of the USER(s) to delete. /// The name of the USER(s) to delete.
#[arg(num_args = 1..)] #[arg(num_args = 1..)]
name: Vec<String>, name: Vec<MySQLUser>,
} }
#[derive(Parser)] #[derive(Parser)]
pub struct PasswdArgs { pub struct PasswdArgs {
/// The name of the USER(s) to change the password for. /// The name of the USER(s) to change the password for.
#[arg(num_args = 1..)] #[arg(num_args = 1..)]
name: Vec<String>, name: Vec<MySQLUser>,
} }
#[derive(Parser)] #[derive(Parser)]
pub struct ShowArgs { pub struct ShowArgs {
/// The name of the USER(s) to show. /// The name of the USER(s) to show.
#[arg(num_args = 0..)] #[arg(num_args = 0..)]
name: Vec<String>, name: Vec<MySQLUser>,
} }
pub fn main() -> anyhow::Result<()> { pub fn main() -> anyhow::Result<()> {
@ -152,13 +153,9 @@ async fn create_user(
args: CreateArgs, args: CreateArgs,
mut server_connection: ClientToServerMessageStream, mut server_connection: ClientToServerMessageStream,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let usernames = args let db_users = args.name.iter().map(trim_user_name_to_32_chars).collect();
.name
.iter()
.map(|name| trim_to_32_chars(name))
.collect();
let message = Request::CreateUsers(usernames); let message = Request::CreateUsers(db_users);
server_connection.send(message).await?; server_connection.send(message).await?;
let result = match server_connection.next().await { let result = match server_connection.next().await {
@ -182,13 +179,9 @@ async fn drop_users(
args: DeleteArgs, args: DeleteArgs,
mut server_connection: ClientToServerMessageStream, mut server_connection: ClientToServerMessageStream,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let usernames = args let db_users = args.name.iter().map(trim_user_name_to_32_chars).collect();
.name
.iter()
.map(|name| trim_to_32_chars(name))
.collect();
let message = Request::DropUsers(usernames); let message = Request::DropUsers(db_users);
server_connection.send(message).await?; server_connection.send(message).await?;
let result = match server_connection.next().await { let result = match server_connection.next().await {
@ -212,13 +205,9 @@ async fn passwd_users(
args: PasswdArgs, args: PasswdArgs,
mut server_connection: ClientToServerMessageStream, mut server_connection: ClientToServerMessageStream,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let usernames = args let db_users = args.name.iter().map(trim_user_name_to_32_chars).collect();
.name
.iter()
.map(|name| trim_to_32_chars(name))
.collect();
let message = Request::ListUsers(Some(usernames)); let message = Request::ListUsers(Some(db_users));
server_connection.send(message).await?; server_connection.send(message).await?;
let response = match server_connection.next().await { let response = match server_connection.next().await {
@ -243,11 +232,11 @@ async fn passwd_users(
for user in users { for user in users {
let password = read_password_from_stdin_with_double_check(&user.user)?; let password = read_password_from_stdin_with_double_check(&user.user)?;
let message = Request::PasswdUser(user.user.clone(), password); let message = Request::PasswdUser(user.user.to_owned(), password);
server_connection.send(message).await?; server_connection.send(message).await?;
match server_connection.next().await { match server_connection.next().await {
Some(Ok(Response::PasswdUser(result))) => match result { Some(Ok(Response::PasswdUser(result))) => match result {
Ok(()) => println!("Password updated for user '{}'.", user.user), Ok(()) => println!("Password updated for user '{}'.", &user.user),
Err(_) => eprintln!( Err(_) => eprintln!(
"{}: Failed to update password for user '{}'.", "{}: Failed to update password for user '{}'.",
argv0, user.user, argv0, user.user,
@ -266,16 +255,12 @@ async fn show_users(
args: ShowArgs, args: ShowArgs,
mut server_connection: ClientToServerMessageStream, mut server_connection: ClientToServerMessageStream,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let usernames: Vec<_> = args let db_users: Vec<_> = args.name.iter().map(trim_user_name_to_32_chars).collect();
.name
.iter()
.map(|name| trim_to_32_chars(name))
.collect();
let message = if usernames.is_empty() { let message = if db_users.is_empty() {
Request::ListUsers(None) Request::ListUsers(None)
} else { } else {
Request::ListUsers(Some(usernames)) Request::ListUsers(Some(db_users))
}; };
server_connection.send(message).await?; server_connection.send(message).await?;

View File

@ -4,10 +4,12 @@ use dialoguer::{Confirm, Password};
use futures_util::{SinkExt, StreamExt}; use futures_util::{SinkExt, StreamExt};
use crate::core::protocol::{ use crate::core::protocol::{
print_create_users_output_status, print_drop_users_output_status, print_create_users_output_status, print_create_users_output_status_json,
print_lock_users_output_status, print_set_password_output_status, print_drop_users_output_status, print_drop_users_output_status_json,
print_unlock_users_output_status, ClientToServerMessageStream, ListUsersError, Request, print_lock_users_output_status, print_lock_users_output_status_json,
Response, print_set_password_output_status, print_unlock_users_output_status,
print_unlock_users_output_status_json, ClientToServerMessageStream, ListUsersError, MySQLUser,
Request, Response,
}; };
use super::common::erroneous_server_response; use super::common::erroneous_server_response;
@ -51,46 +53,69 @@ pub enum UserCommand {
#[derive(Parser, Debug, Clone)] #[derive(Parser, Debug, Clone)]
pub struct UserCreateArgs { pub struct UserCreateArgs {
#[arg(num_args = 1..)] #[arg(num_args = 1..)]
username: Vec<String>, username: Vec<MySQLUser>,
/// Do not ask for a password, leave it unset /// Do not ask for a password, leave it unset
#[clap(long)] #[clap(long)]
no_password: bool, no_password: bool,
/// Print the information as JSON
///
/// Note that this implies `--no-password`, since the command will become non-interactive.
#[arg(short, long)]
json: bool,
} }
#[derive(Parser, Debug, Clone)] #[derive(Parser, Debug, Clone)]
pub struct UserDeleteArgs { pub struct UserDeleteArgs {
#[arg(num_args = 1..)] #[arg(num_args = 1..)]
username: Vec<String>, username: Vec<MySQLUser>,
/// Print the information as JSON
#[arg(short, long)]
json: bool,
} }
#[derive(Parser, Debug, Clone)] #[derive(Parser, Debug, Clone)]
pub struct UserPasswdArgs { pub struct UserPasswdArgs {
username: String, username: MySQLUser,
#[clap(short, long)] #[clap(short, long)]
password_file: Option<String>, password_file: Option<String>,
/// Print the information as JSON
#[arg(short, long)]
json: bool,
} }
#[derive(Parser, Debug, Clone)] #[derive(Parser, Debug, Clone)]
pub struct UserShowArgs { pub struct UserShowArgs {
#[arg(num_args = 0..)] #[arg(num_args = 0..)]
username: Vec<String>, username: Vec<MySQLUser>,
#[clap(short, long)] /// Print the information as JSON
#[arg(short, long)]
json: bool, json: bool,
} }
#[derive(Parser, Debug, Clone)] #[derive(Parser, Debug, Clone)]
pub struct UserLockArgs { pub struct UserLockArgs {
#[arg(num_args = 1..)] #[arg(num_args = 1..)]
username: Vec<String>, username: Vec<MySQLUser>,
/// Print the information as JSON
#[arg(short, long)]
json: bool,
} }
#[derive(Parser, Debug, Clone)] #[derive(Parser, Debug, Clone)]
pub struct UserUnlockArgs { pub struct UserUnlockArgs {
#[arg(num_args = 1..)] #[arg(num_args = 1..)]
username: Vec<String>, username: Vec<MySQLUser>,
/// Print the information as JSON
#[arg(short, long)]
json: bool,
} }
pub async fn handle_command( pub async fn handle_command(
@ -115,7 +140,7 @@ async fn create_users(
anyhow::bail!("No usernames provided"); anyhow::bail!("No usernames provided");
} }
let message = Request::CreateUsers(args.username.clone()); let message = Request::CreateUsers(args.username.to_owned());
if let Err(err) = server_connection.send(message).await { if let Err(err) = server_connection.send(message).await {
server_connection.close().await.ok(); server_connection.close().await.ok();
anyhow::bail!(anyhow::Error::from(err).context("Failed to communicate with server")); anyhow::bail!(anyhow::Error::from(err).context("Failed to communicate with server"));
@ -126,6 +151,9 @@ async fn create_users(
response => return erroneous_server_response(response), response => return erroneous_server_response(response),
}; };
if args.json {
print_create_users_output_status_json(&result);
} else {
print_create_users_output_status(&result); print_create_users_output_status(&result);
let successfully_created_users = result let successfully_created_users = result
@ -144,7 +172,7 @@ async fn create_users(
.interact()? .interact()?
{ {
let password = read_password_from_stdin_with_double_check(username)?; let password = read_password_from_stdin_with_double_check(username)?;
let message = Request::PasswdUser(username.clone(), password); let message = Request::PasswdUser(username.to_owned(), password);
if let Err(err) = server_connection.send(message).await { if let Err(err) = server_connection.send(message).await {
server_connection.close().await.ok(); server_connection.close().await.ok();
@ -161,6 +189,7 @@ async fn create_users(
println!(); println!();
} }
} }
}
server_connection.send(Request::Exit).await?; server_connection.send(Request::Exit).await?;
@ -175,7 +204,7 @@ async fn drop_users(
anyhow::bail!("No usernames provided"); anyhow::bail!("No usernames provided");
} }
let message = Request::DropUsers(args.username.clone()); let message = Request::DropUsers(args.username.to_owned());
if let Err(err) = server_connection.send(message).await { if let Err(err) = server_connection.send(message).await {
server_connection.close().await.ok(); server_connection.close().await.ok();
@ -189,12 +218,16 @@ async fn drop_users(
server_connection.send(Request::Exit).await?; server_connection.send(Request::Exit).await?;
if args.json {
print_drop_users_output_status_json(&result);
} else {
print_drop_users_output_status(&result); print_drop_users_output_status(&result);
}
Ok(()) Ok(())
} }
pub fn read_password_from_stdin_with_double_check(username: &str) -> anyhow::Result<String> { pub fn read_password_from_stdin_with_double_check(username: &MySQLUser) -> anyhow::Result<String> {
Password::new() Password::new()
.with_prompt(format!("New MySQL password for user '{}'", username)) .with_prompt(format!("New MySQL password for user '{}'", username))
.with_confirmation( .with_confirmation(
@ -210,7 +243,7 @@ async fn passwd_user(
mut server_connection: ClientToServerMessageStream, mut server_connection: ClientToServerMessageStream,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
// TODO: create a "user" exists check" command // TODO: create a "user" exists check" command
let message = Request::ListUsers(Some(vec![args.username.clone()])); let message = Request::ListUsers(Some(vec![args.username.to_owned()]));
if let Err(err) = server_connection.send(message).await { if let Err(err) = server_connection.send(message).await {
server_connection.close().await.ok(); server_connection.close().await.ok();
anyhow::bail!(err); anyhow::bail!(err);
@ -240,7 +273,7 @@ async fn passwd_user(
read_password_from_stdin_with_double_check(&args.username)? read_password_from_stdin_with_double_check(&args.username)?
}; };
let message = Request::PasswdUser(args.username.clone(), password); let message = Request::PasswdUser(args.username.to_owned(), password);
if let Err(err) = server_connection.send(message).await { if let Err(err) = server_connection.send(message).await {
server_connection.close().await.ok(); server_connection.close().await.ok();
@ -266,7 +299,7 @@ async fn show_users(
let message = if args.username.is_empty() { let message = if args.username.is_empty() {
Request::ListUsers(None) Request::ListUsers(None)
} else { } else {
Request::ListUsers(Some(args.username.clone())) Request::ListUsers(Some(args.username.to_owned()))
}; };
if let Err(err) = server_connection.send(message).await { if let Err(err) = server_connection.send(message).await {
@ -337,7 +370,7 @@ async fn lock_users(
anyhow::bail!("No usernames provided"); anyhow::bail!("No usernames provided");
} }
let message = Request::LockUsers(args.username.clone()); let message = Request::LockUsers(args.username.to_owned());
if let Err(err) = server_connection.send(message).await { if let Err(err) = server_connection.send(message).await {
server_connection.close().await.ok(); server_connection.close().await.ok();
@ -351,7 +384,11 @@ async fn lock_users(
server_connection.send(Request::Exit).await?; server_connection.send(Request::Exit).await?;
if args.json {
print_lock_users_output_status_json(&result);
} else {
print_lock_users_output_status(&result); print_lock_users_output_status(&result);
}
Ok(()) Ok(())
} }
@ -364,7 +401,7 @@ async fn unlock_users(
anyhow::bail!("No usernames provided"); anyhow::bail!("No usernames provided");
} }
let message = Request::UnlockUsers(args.username.clone()); let message = Request::UnlockUsers(args.username.to_owned());
if let Err(err) = server_connection.send(message).await { if let Err(err) = server_connection.send(message).await {
server_connection.close().await.ok(); server_connection.close().await.ok();
@ -378,7 +415,11 @@ async fn unlock_users(
server_connection.send(Request::Exit).await?; server_connection.send(Request::Exit).await?;
if args.json {
print_unlock_users_output_status_json(&result);
} else {
print_unlock_users_output_status(&result); print_unlock_users_output_status(&result);
}
Ok(()) Ok(())
} }

View File

@ -54,7 +54,7 @@ impl UnixUser {
Ok(UnixUser { Ok(UnixUser {
username: libc_user.name, username: libc_user.name,
groups: groups.iter().map(|g| g.name.clone()).collect(), groups: groups.iter().map(|g| g.name.to_owned()).collect(),
}) })
} }

View File

@ -7,7 +7,10 @@ use std::{
collections::{BTreeSet, HashMap}, collections::{BTreeSet, HashMap},
}; };
use super::common::{rev_yn, yn}; use super::{
common::{rev_yn, yn},
protocol::{MySQLDatabase, MySQLUser},
};
use crate::server::sql::database_privilege_operations::{ use crate::server::sql::database_privilege_operations::{
DatabasePrivilegeRow, DATABASE_PRIVILEGE_FIELDS, DatabasePrivilegeRow, DATABASE_PRIVILEGE_FIELDS,
}; };
@ -35,8 +38,8 @@ pub fn diff(row1: &DatabasePrivilegeRow, row2: &DatabasePrivilegeRow) -> Databas
debug_assert!(row1.db == row2.db && row1.user == row2.user); debug_assert!(row1.db == row2.db && row1.user == row2.user);
DatabasePrivilegeRowDiff { DatabasePrivilegeRowDiff {
db: row1.db.clone(), db: row1.db.to_owned(),
user: row1.user.clone(), user: row1.user.to_owned(),
diff: DATABASE_PRIVILEGE_FIELDS diff: DATABASE_PRIVILEGE_FIELDS
.into_iter() .into_iter()
.skip(2) .skip(2)
@ -70,8 +73,8 @@ pub fn parse_privilege_table_cli_arg(arg: &str) -> anyhow::Result<DatabasePrivil
anyhow::bail!("Username cannot be empty."); anyhow::bail!("Username cannot be empty.");
} }
let db = parts[0].to_string(); let db = parts[0].into();
let user = parts[1].to_string(); let user = parts[1].into();
let privs = parts[2].to_string(); let privs = parts[2].to_string();
let mut result = DatabasePrivilegeRow { let mut result = DatabasePrivilegeRow {
@ -165,11 +168,11 @@ const EDITOR_COMMENT: &str = r#"
pub fn generate_editor_content_from_privilege_data( pub fn generate_editor_content_from_privilege_data(
privilege_data: &[DatabasePrivilegeRow], privilege_data: &[DatabasePrivilegeRow],
unix_user: &str, unix_user: &str,
database_name: Option<&str>, database_name: Option<&MySQLDatabase>,
) -> String { ) -> String {
let example_user = format!("{}_user", unix_user); let example_user = format!("{}_user", unix_user);
let example_db = database_name let example_db = database_name
.unwrap_or(&format!("{}_db", unix_user)) .unwrap_or(&format!("{}_db", unix_user).into())
.to_string(); .to_string();
// NOTE: `.max()`` fails when the iterator is empty. // NOTE: `.max()`` fails when the iterator is empty.
@ -206,8 +209,8 @@ pub fn generate_editor_content_from_privilege_data(
let example_line = format_privileges_line_for_editor( let example_line = format_privileges_line_for_editor(
&DatabasePrivilegeRow { &DatabasePrivilegeRow {
db: example_db, db: example_db.into(),
user: example_user, user: example_user.into(),
select_priv: true, select_priv: true,
insert_priv: true, insert_priv: true,
update_priv: true, update_priv: true,
@ -298,8 +301,8 @@ fn parse_privilege_row_from_editor(row: &str) -> PrivilegeRowParseResult {
} }
let row = DatabasePrivilegeRow { let row = DatabasePrivilegeRow {
db: (*parts.first().unwrap()).to_owned(), db: (*parts.first().unwrap()).into(),
user: (*parts.get(1).unwrap()).to_owned(), user: (*parts.get(1).unwrap()).into(),
select_priv: match parse_privilege_cell_from_editor( select_priv: match parse_privilege_cell_from_editor(
parts.get(2).unwrap(), parts.get(2).unwrap(),
DATABASE_PRIVILEGE_FIELDS[2], DATABASE_PRIVILEGE_FIELDS[2],
@ -423,8 +426,8 @@ pub fn parse_privilege_data_from_editor_content(
/// The `User` and `Database` are the same for both instances. /// The `User` and `Database` are the same for both instances.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
pub struct DatabasePrivilegeRowDiff { pub struct DatabasePrivilegeRowDiff {
pub db: String, pub db: MySQLDatabase,
pub user: String, pub user: MySQLUser,
pub diff: BTreeSet<DatabasePrivilegeChange>, pub diff: BTreeSet<DatabasePrivilegeChange>,
} }
@ -454,7 +457,7 @@ pub enum DatabasePrivilegesDiff {
} }
impl DatabasePrivilegesDiff { impl DatabasePrivilegesDiff {
pub fn get_database_name(&self) -> &str { pub fn get_database_name(&self) -> &MySQLDatabase {
match self { match self {
DatabasePrivilegesDiff::New(p) => &p.db, DatabasePrivilegesDiff::New(p) => &p.db,
DatabasePrivilegesDiff::Modified(p) => &p.db, DatabasePrivilegesDiff::Modified(p) => &p.db,
@ -462,7 +465,7 @@ impl DatabasePrivilegesDiff {
} }
} }
pub fn get_user_name(&self) -> &str { pub fn get_user_name(&self) -> &MySQLUser {
match self { match self {
DatabasePrivilegesDiff::New(p) => &p.user, DatabasePrivilegesDiff::New(p) => &p.user,
DatabasePrivilegesDiff::Modified(p) => &p.user, DatabasePrivilegesDiff::Modified(p) => &p.user,
@ -478,34 +481,36 @@ pub fn diff_privileges(
from: &[DatabasePrivilegeRow], from: &[DatabasePrivilegeRow],
to: &[DatabasePrivilegeRow], to: &[DatabasePrivilegeRow],
) -> BTreeSet<DatabasePrivilegesDiff> { ) -> BTreeSet<DatabasePrivilegesDiff> {
let from_lookup_table: HashMap<(String, String), DatabasePrivilegeRow> = HashMap::from_iter( let from_lookup_table: HashMap<(MySQLDatabase, MySQLUser), DatabasePrivilegeRow> =
HashMap::from_iter(
from.iter() from.iter()
.cloned() .cloned()
.map(|p| ((p.db.clone(), p.user.clone()), p)), .map(|p| ((p.db.to_owned(), p.user.to_owned()), p)),
); );
let to_lookup_table: HashMap<(String, String), DatabasePrivilegeRow> = HashMap::from_iter( let to_lookup_table: HashMap<(MySQLDatabase, MySQLUser), DatabasePrivilegeRow> =
HashMap::from_iter(
to.iter() to.iter()
.cloned() .cloned()
.map(|p| ((p.db.clone(), p.user.clone()), p)), .map(|p| ((p.db.to_owned(), p.user.to_owned()), p)),
); );
let mut result = BTreeSet::new(); let mut result = BTreeSet::new();
for p in to { for p in to {
if let Some(old_p) = from_lookup_table.get(&(p.db.clone(), p.user.clone())) { if let Some(old_p) = from_lookup_table.get(&(p.db.to_owned(), p.user.to_owned())) {
let diff = diff(old_p, p); let diff = diff(old_p, p);
if !diff.diff.is_empty() { if !diff.diff.is_empty() {
result.insert(DatabasePrivilegesDiff::Modified(diff)); result.insert(DatabasePrivilegesDiff::Modified(diff));
} }
} else { } else {
result.insert(DatabasePrivilegesDiff::New(p.clone())); result.insert(DatabasePrivilegesDiff::New(p.to_owned()));
} }
} }
for p in from { for p in from {
if !to_lookup_table.contains_key(&(p.db.clone(), p.user.clone())) { if !to_lookup_table.contains_key(&(p.db.to_owned(), p.user.to_owned())) {
result.insert(DatabasePrivilegesDiff::Deleted(p.clone())); result.insert(DatabasePrivilegesDiff::Deleted(p.to_owned()));
} }
} }
@ -593,8 +598,8 @@ mod tests {
assert_eq!( assert_eq!(
result.ok(), result.ok(),
Some(DatabasePrivilegeRow { Some(DatabasePrivilegeRow {
db: "db".to_owned(), db: "db".into(),
user: "user".to_owned(), user: "user".into(),
select_priv: true, select_priv: true,
insert_priv: true, insert_priv: true,
update_priv: true, update_priv: true,
@ -613,8 +618,8 @@ mod tests {
assert_eq!( assert_eq!(
result.ok(), result.ok(),
Some(DatabasePrivilegeRow { Some(DatabasePrivilegeRow {
db: "db".to_owned(), db: "db".into(),
user: "user".to_owned(), user: "user".into(),
select_priv: false, select_priv: false,
insert_priv: false, insert_priv: false,
update_priv: false, update_priv: false,
@ -633,8 +638,8 @@ mod tests {
assert_eq!( assert_eq!(
result.ok(), result.ok(),
Some(DatabasePrivilegeRow { Some(DatabasePrivilegeRow {
db: "db".to_owned(), db: "db".into(),
user: "user".to_owned(), user: "user".into(),
select_priv: true, select_priv: true,
insert_priv: true, insert_priv: true,
update_priv: true, update_priv: true,
@ -668,8 +673,8 @@ mod tests {
#[test] #[test]
fn test_diff_privileges() { fn test_diff_privileges() {
let row_to_be_modified = DatabasePrivilegeRow { let row_to_be_modified = DatabasePrivilegeRow {
db: "db".to_owned(), db: "db".into(),
user: "user".to_owned(), user: "user".into(),
select_priv: true, select_priv: true,
insert_priv: true, insert_priv: true,
update_priv: true, update_priv: true,
@ -683,20 +688,20 @@ mod tests {
references_priv: false, references_priv: false,
}; };
let mut row_to_be_deleted = row_to_be_modified.clone(); let mut row_to_be_deleted = row_to_be_modified.to_owned();
"user2".clone_into(&mut row_to_be_deleted.user); "user2".clone_into(&mut row_to_be_deleted.user);
let from = vec![row_to_be_modified.clone(), row_to_be_deleted.clone()]; let from = vec![row_to_be_modified.to_owned(), row_to_be_deleted.to_owned()];
let mut modified_row = row_to_be_modified.clone(); let mut modified_row = row_to_be_modified.to_owned();
modified_row.select_priv = false; modified_row.select_priv = false;
modified_row.insert_priv = false; modified_row.insert_priv = false;
modified_row.index_priv = true; modified_row.index_priv = true;
let mut new_row = row_to_be_modified.clone(); let mut new_row = row_to_be_modified.to_owned();
"user3".clone_into(&mut new_row.user); "user3".clone_into(&mut new_row.user);
let to = vec![modified_row.clone(), new_row.clone()]; let to = vec![modified_row.to_owned(), new_row.to_owned()];
let diffs = diff_privileges(&from, &to); let diffs = diff_privileges(&from, &to);
@ -705,8 +710,8 @@ mod tests {
BTreeSet::from_iter(vec![ BTreeSet::from_iter(vec![
DatabasePrivilegesDiff::Deleted(row_to_be_deleted), DatabasePrivilegesDiff::Deleted(row_to_be_deleted),
DatabasePrivilegesDiff::Modified(DatabasePrivilegeRowDiff { DatabasePrivilegesDiff::Modified(DatabasePrivilegeRowDiff {
db: "db".to_owned(), db: "db".into(),
user: "user".to_owned(), user: "user".into(),
diff: BTreeSet::from_iter(vec![ diff: BTreeSet::from_iter(vec![
DatabasePrivilegeChange::YesToNo("select_priv".to_owned()), DatabasePrivilegeChange::YesToNo("select_priv".to_owned()),
DatabasePrivilegeChange::YesToNo("insert_priv".to_owned()), DatabasePrivilegeChange::YesToNo("insert_priv".to_owned()),
@ -722,8 +727,8 @@ mod tests {
fn ensure_generated_and_parsed_editor_content_is_equal() { fn ensure_generated_and_parsed_editor_content_is_equal() {
let permissions = vec![ let permissions = vec![
DatabasePrivilegeRow { DatabasePrivilegeRow {
db: "db".to_owned(), db: "db".into(),
user: "user".to_owned(), user: "user".into(),
select_priv: true, select_priv: true,
insert_priv: true, insert_priv: true,
update_priv: true, update_priv: true,
@ -737,8 +742,8 @@ mod tests {
references_priv: true, references_priv: true,
}, },
DatabasePrivilegeRow { DatabasePrivilegeRow {
db: "db2".to_owned(), db: "db".into(),
user: "user2".to_owned(), user: "user".into(),
select_priv: false, select_priv: false,
insert_priv: false, insert_priv: false,
update_priv: false, update_priv: false,

View File

@ -1,4 +1,9 @@
use std::collections::BTreeSet; use std::{
collections::BTreeSet,
fmt::{Display, Formatter},
ops::{Deref, DerefMut},
str::FromStr,
};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::net::UnixStream; use tokio::net::UnixStream;
@ -31,21 +36,107 @@ pub fn create_client_to_server_message_stream(socket: UnixStream) -> ClientToSer
tokio_serde::Framed::new(length_delimited, Bincode::default()) tokio_serde::Framed::new(length_delimited, Bincode::default())
} }
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct MySQLUser(String);
impl FromStr for MySQLUser {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(MySQLUser(s.to_string()))
}
}
impl Deref for MySQLUser {
type Target = String;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for MySQLUser {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl Display for MySQLUser {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
impl From<&str> for MySQLUser {
fn from(s: &str) -> Self {
MySQLUser(s.to_string())
}
}
impl From<String> for MySQLUser {
fn from(s: String) -> Self {
MySQLUser(s)
}
}
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct MySQLDatabase(String);
impl FromStr for MySQLDatabase {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(MySQLDatabase(s.to_string()))
}
}
impl Deref for MySQLDatabase {
type Target = String;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for MySQLDatabase {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl Display for MySQLDatabase {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
impl From<&str> for MySQLDatabase {
fn from(s: &str) -> Self {
MySQLDatabase(s.to_string())
}
}
impl From<String> for MySQLDatabase {
fn from(s: String) -> Self {
MySQLDatabase(s)
}
}
#[non_exhaustive] #[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Request { pub enum Request {
CreateDatabases(Vec<String>), CreateDatabases(Vec<MySQLDatabase>),
DropDatabases(Vec<String>), DropDatabases(Vec<MySQLDatabase>),
ListDatabases(Option<Vec<String>>), ListDatabases(Option<Vec<MySQLDatabase>>),
ListPrivileges(Option<Vec<String>>), ListPrivileges(Option<Vec<MySQLDatabase>>),
ModifyPrivileges(BTreeSet<DatabasePrivilegesDiff>), ModifyPrivileges(BTreeSet<DatabasePrivilegesDiff>),
CreateUsers(Vec<String>), CreateUsers(Vec<MySQLUser>),
DropUsers(Vec<String>), DropUsers(Vec<MySQLUser>),
PasswdUser(String, String), PasswdUser(MySQLUser, String),
ListUsers(Option<Vec<String>>), ListUsers(Option<Vec<MySQLUser>>),
LockUsers(Vec<String>), LockUsers(Vec<MySQLUser>),
UnlockUsers(Vec<String>), UnlockUsers(Vec<MySQLUser>),
// Commit, // Commit,
Exit, Exit,

View File

@ -3,6 +3,7 @@ use std::collections::BTreeMap;
use indoc::indoc; use indoc::indoc;
use itertools::Itertools; use itertools::Itertools;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::json;
use crate::{ use crate::{
core::{common::UnixUser, database_privileges::DatabasePrivilegeRowDiff}, core::{common::UnixUser, database_privileges::DatabasePrivilegeRowDiff},
@ -12,6 +13,8 @@ use crate::{
}, },
}; };
use super::{MySQLDatabase, MySQLUser};
/// This enum is used to differentiate between database and user operations. /// This enum is used to differentiate between database and user operations.
/// Their output are very similar, but there are slight differences in the words used. /// Their output are very similar, but there are slight differences in the words used.
#[derive(Debug, PartialEq, Eq, Clone, Copy)] #[derive(Debug, PartialEq, Eq, Clone, Copy)]
@ -21,17 +24,17 @@ pub enum DbOrUser {
} }
impl DbOrUser { impl DbOrUser {
pub fn lowercased(&self) -> String { pub fn lowercased(&self) -> &'static str {
match self { match self {
DbOrUser::Database => "database".to_string(), DbOrUser::Database => "database",
DbOrUser::User => "user".to_string(), DbOrUser::User => "user",
} }
} }
pub fn capitalized(&self) -> String { pub fn capitalized(&self) -> &'static str {
match self { match self {
DbOrUser::Database => "Database".to_string(), DbOrUser::Database => "Database",
DbOrUser::User => "User".to_string(), DbOrUser::User => "User",
} }
} }
} }
@ -72,6 +75,11 @@ impl OwnerValidationError {
pub fn to_error_message(self, name: &str, db_or_user: DbOrUser) -> String { pub fn to_error_message(self, name: &str, db_or_user: DbOrUser) -> String {
let user = UnixUser::from_enviroment(); let user = UnixUser::from_enviroment();
let UnixUser { username, groups } = user.unwrap_or(UnixUser {
username: "???".to_string(),
groups: vec![],
});
match self { match self {
OwnerValidationError::NoMatch => format!( OwnerValidationError::NoMatch => format!(
indoc! {r#" indoc! {r#"
@ -87,11 +95,8 @@ impl OwnerValidationError {
name, name,
db_or_user.lowercased(), db_or_user.lowercased(),
db_or_user.lowercased(), db_or_user.lowercased(),
user.as_ref() username,
.map(|u| u.username.clone()) groups
.unwrap_or("???".to_string()),
user.map(|u| u.groups)
.unwrap_or_default()
.iter() .iter()
.map(|g| format!(" - {}", g)) .map(|g| format!(" - {}", g))
.sorted() .sorted()
@ -117,7 +122,7 @@ pub enum OwnerValidationError {
StringEmpty, StringEmpty,
} }
pub type CreateDatabasesOutput = BTreeMap<String, Result<(), CreateDatabaseError>>; pub type CreateDatabasesOutput = BTreeMap<MySQLDatabase, Result<(), CreateDatabaseError>>;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum CreateDatabaseError { pub enum CreateDatabaseError {
SanitizationError(NameValidationError), SanitizationError(NameValidationError),
@ -141,8 +146,29 @@ pub fn print_create_databases_output_status(output: &CreateDatabasesOutput) {
} }
} }
pub fn print_create_databases_output_status_json(output: &CreateDatabasesOutput) {
let value = output
.iter()
.map(|(name, result)| match result {
Ok(()) => (name.to_string(), json!({ "status": "success" })),
Err(err) => (
name.to_string(),
json!({
"status": "error",
"error": err.to_error_message(name),
}),
),
})
.collect::<serde_json::Map<_, _>>();
println!(
"{}",
serde_json::to_string_pretty(&value)
.unwrap_or("Failed to serialize result to JSON".to_string())
);
}
impl CreateDatabaseError { impl CreateDatabaseError {
pub fn to_error_message(&self, database_name: &str) -> String { pub fn to_error_message(&self, database_name: &MySQLDatabase) -> String {
match self { match self {
CreateDatabaseError::SanitizationError(err) => { CreateDatabaseError::SanitizationError(err) => {
err.to_error_message(database_name, DbOrUser::Database) err.to_error_message(database_name, DbOrUser::Database)
@ -160,7 +186,7 @@ impl CreateDatabaseError {
} }
} }
pub type DropDatabasesOutput = BTreeMap<String, Result<(), DropDatabaseError>>; pub type DropDatabasesOutput = BTreeMap<MySQLDatabase, Result<(), DropDatabaseError>>;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum DropDatabaseError { pub enum DropDatabaseError {
SanitizationError(NameValidationError), SanitizationError(NameValidationError),
@ -173,7 +199,10 @@ pub fn print_drop_databases_output_status(output: &DropDatabasesOutput) {
for (database_name, result) in output { for (database_name, result) in output {
match result { match result {
Ok(()) => { Ok(()) => {
println!("Database '{}' dropped successfully.", database_name); println!(
"Database '{}' dropped successfully.",
database_name.as_str()
);
} }
Err(err) => { Err(err) => {
println!("{}", err.to_error_message(database_name)); println!("{}", err.to_error_message(database_name));
@ -184,8 +213,29 @@ pub fn print_drop_databases_output_status(output: &DropDatabasesOutput) {
} }
} }
pub fn print_drop_databases_output_status_json(output: &DropDatabasesOutput) {
let value = output
.iter()
.map(|(name, result)| match result {
Ok(()) => (name.to_string(), json!({ "status": "success" })),
Err(err) => (
name.to_string(),
json!({
"status": "error",
"error": err.to_error_message(name),
}),
),
})
.collect::<serde_json::Map<_, _>>();
println!(
"{}",
serde_json::to_string_pretty(&value)
.unwrap_or("Failed to serialize result to JSON".to_string())
);
}
impl DropDatabaseError { impl DropDatabaseError {
pub fn to_error_message(&self, database_name: &str) -> String { pub fn to_error_message(&self, database_name: &MySQLDatabase) -> String {
match self { match self {
DropDatabaseError::SanitizationError(err) => { DropDatabaseError::SanitizationError(err) => {
err.to_error_message(database_name, DbOrUser::Database) err.to_error_message(database_name, DbOrUser::Database)
@ -203,7 +253,7 @@ impl DropDatabaseError {
} }
} }
pub type ListDatabasesOutput = BTreeMap<String, Result<DatabaseRow, ListDatabasesError>>; pub type ListDatabasesOutput = BTreeMap<MySQLDatabase, Result<DatabaseRow, ListDatabasesError>>;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ListDatabasesError { pub enum ListDatabasesError {
SanitizationError(NameValidationError), SanitizationError(NameValidationError),
@ -213,7 +263,7 @@ pub enum ListDatabasesError {
} }
impl ListDatabasesError { impl ListDatabasesError {
pub fn to_error_message(&self, database_name: &str) -> String { pub fn to_error_message(&self, database_name: &MySQLDatabase) -> String {
match self { match self {
ListDatabasesError::SanitizationError(err) => { ListDatabasesError::SanitizationError(err) => {
err.to_error_message(database_name, DbOrUser::Database) err.to_error_message(database_name, DbOrUser::Database)
@ -250,7 +300,7 @@ impl ListAllDatabasesError {
// no need to index by database name. // no need to index by database name.
pub type GetDatabasesPrivilegeData = pub type GetDatabasesPrivilegeData =
BTreeMap<String, Result<Vec<DatabasePrivilegeRow>, GetDatabasesPrivilegeDataError>>; BTreeMap<MySQLDatabase, Result<Vec<DatabasePrivilegeRow>, GetDatabasesPrivilegeDataError>>;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum GetDatabasesPrivilegeDataError { pub enum GetDatabasesPrivilegeDataError {
SanitizationError(NameValidationError), SanitizationError(NameValidationError),
@ -260,7 +310,7 @@ pub enum GetDatabasesPrivilegeDataError {
} }
impl GetDatabasesPrivilegeDataError { impl GetDatabasesPrivilegeDataError {
pub fn to_error_message(&self, database_name: &str) -> String { pub fn to_error_message(&self, database_name: &MySQLDatabase) -> String {
match self { match self {
GetDatabasesPrivilegeDataError::SanitizationError(err) => { GetDatabasesPrivilegeDataError::SanitizationError(err) => {
err.to_error_message(database_name, DbOrUser::Database) err.to_error_message(database_name, DbOrUser::Database)
@ -294,7 +344,7 @@ impl GetAllDatabasesPrivilegeDataError {
} }
pub type ModifyDatabasePrivilegesOutput = pub type ModifyDatabasePrivilegesOutput =
BTreeMap<(String, String), Result<(), ModifyDatabasePrivilegesError>>; BTreeMap<(MySQLDatabase, MySQLUser), Result<(), ModifyDatabasePrivilegesError>>;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ModifyDatabasePrivilegesError { pub enum ModifyDatabasePrivilegesError {
DatabaseSanitizationError(NameValidationError), DatabaseSanitizationError(NameValidationError),
@ -309,8 +359,8 @@ pub enum ModifyDatabasePrivilegesError {
#[allow(clippy::enum_variant_names)] #[allow(clippy::enum_variant_names)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum DiffDoesNotApplyError { pub enum DiffDoesNotApplyError {
RowAlreadyExists(String, String), RowAlreadyExists(MySQLDatabase, MySQLUser),
RowDoesNotExist(String, String), RowDoesNotExist(MySQLDatabase, MySQLUser),
RowPrivilegeChangeDoesNotApply(DatabasePrivilegeRowDiff, DatabasePrivilegeRow), RowPrivilegeChangeDoesNotApply(DatabasePrivilegeRowDiff, DatabasePrivilegeRow),
} }
@ -333,7 +383,7 @@ pub fn print_modify_database_privileges_output_status(output: &ModifyDatabasePri
} }
impl ModifyDatabasePrivilegesError { impl ModifyDatabasePrivilegesError {
pub fn to_error_message(&self, database_name: &str, username: &str) -> String { pub fn to_error_message(&self, database_name: &MySQLDatabase, username: &MySQLUser) -> String {
match self { match self {
ModifyDatabasePrivilegesError::DatabaseSanitizationError(err) => { ModifyDatabasePrivilegesError::DatabaseSanitizationError(err) => {
err.to_error_message(database_name, DbOrUser::Database) err.to_error_message(database_name, DbOrUser::Database)
@ -388,7 +438,7 @@ impl DiffDoesNotApplyError {
} }
} }
pub type CreateUsersOutput = BTreeMap<String, Result<(), CreateUserError>>; pub type CreateUsersOutput = BTreeMap<MySQLUser, Result<(), CreateUserError>>;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum CreateUserError { pub enum CreateUserError {
SanitizationError(NameValidationError), SanitizationError(NameValidationError),
@ -412,8 +462,29 @@ pub fn print_create_users_output_status(output: &CreateUsersOutput) {
} }
} }
pub fn print_create_users_output_status_json(output: &CreateUsersOutput) {
let value = output
.iter()
.map(|(name, result)| match result {
Ok(()) => (name.to_string(), json!({ "status": "success" })),
Err(err) => (
name.to_string(),
json!({
"status": "error",
"error": err.to_error_message(name),
}),
),
})
.collect::<serde_json::Map<_, _>>();
println!(
"{}",
serde_json::to_string_pretty(&value)
.unwrap_or("Failed to serialize result to JSON".to_string())
);
}
impl CreateUserError { impl CreateUserError {
pub fn to_error_message(&self, username: &str) -> String { pub fn to_error_message(&self, username: &MySQLUser) -> String {
match self { match self {
CreateUserError::SanitizationError(err) => { CreateUserError::SanitizationError(err) => {
err.to_error_message(username, DbOrUser::User) err.to_error_message(username, DbOrUser::User)
@ -429,7 +500,7 @@ impl CreateUserError {
} }
} }
pub type DropUsersOutput = BTreeMap<String, Result<(), DropUserError>>; pub type DropUsersOutput = BTreeMap<MySQLUser, Result<(), DropUserError>>;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum DropUserError { pub enum DropUserError {
SanitizationError(NameValidationError), SanitizationError(NameValidationError),
@ -453,8 +524,29 @@ pub fn print_drop_users_output_status(output: &DropUsersOutput) {
} }
} }
pub fn print_drop_users_output_status_json(output: &DropUsersOutput) {
let value = output
.iter()
.map(|(name, result)| match result {
Ok(()) => (name.to_string(), json!({ "status": "success" })),
Err(err) => (
name.to_string(),
json!({
"status": "error",
"error": err.to_error_message(name),
}),
),
})
.collect::<serde_json::Map<_, _>>();
println!(
"{}",
serde_json::to_string_pretty(&value)
.unwrap_or("Failed to serialize result to JSON".to_string())
);
}
impl DropUserError { impl DropUserError {
pub fn to_error_message(&self, username: &str) -> String { pub fn to_error_message(&self, username: &MySQLUser) -> String {
match self { match self {
DropUserError::SanitizationError(err) => err.to_error_message(username, DbOrUser::User), DropUserError::SanitizationError(err) => err.to_error_message(username, DbOrUser::User),
DropUserError::OwnershipError(err) => err.to_error_message(username, DbOrUser::User), DropUserError::OwnershipError(err) => err.to_error_message(username, DbOrUser::User),
@ -477,7 +569,7 @@ pub enum SetPasswordError {
MySqlError(String), MySqlError(String),
} }
pub fn print_set_password_output_status(output: &SetPasswordOutput, username: &str) { pub fn print_set_password_output_status(output: &SetPasswordOutput, username: &MySQLUser) {
match output { match output {
Ok(()) => { Ok(()) => {
println!("Password for user '{}' set successfully.", username); println!("Password for user '{}' set successfully.", username);
@ -490,7 +582,7 @@ pub fn print_set_password_output_status(output: &SetPasswordOutput, username: &s
} }
impl SetPasswordError { impl SetPasswordError {
pub fn to_error_message(&self, username: &str) -> String { pub fn to_error_message(&self, username: &MySQLUser) -> String {
match self { match self {
SetPasswordError::SanitizationError(err) => { SetPasswordError::SanitizationError(err) => {
err.to_error_message(username, DbOrUser::User) err.to_error_message(username, DbOrUser::User)
@ -506,7 +598,7 @@ impl SetPasswordError {
} }
} }
pub type LockUsersOutput = BTreeMap<String, Result<(), LockUserError>>; pub type LockUsersOutput = BTreeMap<MySQLUser, Result<(), LockUserError>>;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum LockUserError { pub enum LockUserError {
SanitizationError(NameValidationError), SanitizationError(NameValidationError),
@ -531,8 +623,29 @@ pub fn print_lock_users_output_status(output: &LockUsersOutput) {
} }
} }
pub fn print_lock_users_output_status_json(output: &LockUsersOutput) {
let value = output
.iter()
.map(|(name, result)| match result {
Ok(()) => (name.to_string(), json!({ "status": "success" })),
Err(err) => (
name.to_string(),
json!({
"status": "error",
"error": err.to_error_message(name),
}),
),
})
.collect::<serde_json::Map<_, _>>();
println!(
"{}",
serde_json::to_string_pretty(&value)
.unwrap_or("Failed to serialize result to JSON".to_string())
);
}
impl LockUserError { impl LockUserError {
pub fn to_error_message(&self, username: &str) -> String { pub fn to_error_message(&self, username: &MySQLUser) -> String {
match self { match self {
LockUserError::SanitizationError(err) => err.to_error_message(username, DbOrUser::User), LockUserError::SanitizationError(err) => err.to_error_message(username, DbOrUser::User),
LockUserError::OwnershipError(err) => err.to_error_message(username, DbOrUser::User), LockUserError::OwnershipError(err) => err.to_error_message(username, DbOrUser::User),
@ -549,7 +662,7 @@ impl LockUserError {
} }
} }
pub type UnlockUsersOutput = BTreeMap<String, Result<(), UnlockUserError>>; pub type UnlockUsersOutput = BTreeMap<MySQLUser, Result<(), UnlockUserError>>;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum UnlockUserError { pub enum UnlockUserError {
SanitizationError(NameValidationError), SanitizationError(NameValidationError),
@ -574,8 +687,29 @@ pub fn print_unlock_users_output_status(output: &UnlockUsersOutput) {
} }
} }
pub fn print_unlock_users_output_status_json(output: &UnlockUsersOutput) {
let value = output
.iter()
.map(|(name, result)| match result {
Ok(()) => (name.to_string(), json!({ "status": "success" })),
Err(err) => (
name.to_string(),
json!({
"status": "error",
"error": err.to_error_message(name),
}),
),
})
.collect::<serde_json::Map<_, _>>();
println!(
"{}",
serde_json::to_string_pretty(&value)
.unwrap_or("Failed to serialize result to JSON".to_string())
);
}
impl UnlockUserError { impl UnlockUserError {
pub fn to_error_message(&self, username: &str) -> String { pub fn to_error_message(&self, username: &MySQLUser) -> String {
match self { match self {
UnlockUserError::SanitizationError(err) => { UnlockUserError::SanitizationError(err) => {
err.to_error_message(username, DbOrUser::User) err.to_error_message(username, DbOrUser::User)
@ -594,7 +728,7 @@ impl UnlockUserError {
} }
} }
pub type ListUsersOutput = BTreeMap<String, Result<DatabaseUser, ListUsersError>>; pub type ListUsersOutput = BTreeMap<MySQLUser, Result<DatabaseUser, ListUsersError>>;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ListUsersError { pub enum ListUsersError {
SanitizationError(NameValidationError), SanitizationError(NameValidationError),
@ -604,7 +738,7 @@ pub enum ListUsersError {
} }
impl ListUsersError { impl ListUsersError {
pub fn to_error_message(&self, username: &str) -> String { pub fn to_error_message(&self, username: &MySQLUser) -> String {
match self { match self {
ListUsersError::SanitizationError(err) => { ListUsersError::SanitizationError(err) => {
err.to_error_message(username, DbOrUser::User) err.to_error_message(username, DbOrUser::User)

View File

@ -3,6 +3,7 @@ extern crate prettytable;
use clap::{CommandFactory, Parser, ValueEnum}; use clap::{CommandFactory, Parser, ValueEnum};
use clap_complete::{generate, Shell}; use clap_complete::{generate, Shell};
use clap_verbosity_flag::Verbosity;
use std::path::PathBuf; use std::path::PathBuf;
@ -62,6 +63,10 @@ struct Args {
)] )]
config: Option<PathBuf>, config: Option<PathBuf>,
#[command(flatten)]
verbose: Verbosity,
/// Run in TUI mode.
#[cfg(feature = "tui")] #[cfg(feature = "tui")]
#[arg(short, long, alias = "tui", global = true)] #[arg(short, long, alias = "tui", global = true)]
interactive: bool, interactive: bool,
@ -103,11 +108,6 @@ enum ToplevelCommands {
// comments emphasizing the need for caution. // comments emphasizing the need for caution.
fn main() -> anyhow::Result<()> { fn main() -> anyhow::Result<()> {
// TODO: find out if there are any security risks of running
// env_logger and clap with elevated privileges.
env_logger::init();
#[cfg(feature = "mysql-admutils-compatibility")] #[cfg(feature = "mysql-admutils-compatibility")]
if handle_mysql_admutils_command()?.is_some() { if handle_mysql_admutils_command()?.is_some() {
return Ok(()); return Ok(());
@ -126,6 +126,10 @@ fn main() -> anyhow::Result<()> {
let server_connection = let server_connection =
bootstrap_server_connection_and_drop_privileges(args.server_socket_path, args.config)?; bootstrap_server_connection_and_drop_privileges(args.server_socket_path, args.config)?;
env_logger::Builder::new()
.filter_level(args.verbose.log_level_filter())
.init();
tokio_run_command(args.command, server_connection)?; tokio_run_command(args.command, server_connection)?;
Ok(()) Ok(())
@ -149,9 +153,10 @@ fn handle_server_command(args: &Args) -> anyhow::Result<Option<()>> {
match args.command { match args.command {
Command::Server(ref command) => { Command::Server(ref command) => {
tokio_start_server( tokio_start_server(
args.server_socket_path.clone(), args.server_socket_path.to_owned(),
args.config.clone(), args.config.to_owned(),
command.clone(), args.verbose.to_owned(),
command.to_owned(),
)?; )?;
Ok(Some(())) Ok(Some(()))
} }
@ -188,6 +193,7 @@ fn handle_generate_completions_command(args: &Args) -> anyhow::Result<Option<()>
fn tokio_start_server( fn tokio_start_server(
server_socket_path: Option<PathBuf>, server_socket_path: Option<PathBuf>,
config_path: Option<PathBuf>, config_path: Option<PathBuf>,
verbosity: Verbosity,
args: ServerArgs, args: ServerArgs,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
tokio::runtime::Builder::new_current_thread() tokio::runtime::Builder::new_current_thread()
@ -195,7 +201,7 @@ fn tokio_start_server(
.build() .build()
.unwrap() .unwrap()
.block_on(async { .block_on(async {
server::command::handle_command(server_socket_path, config_path, args).await server::command::handle_command(server_socket_path, config_path, verbosity, args).await
}) })
} }

View File

@ -3,11 +3,16 @@ use std::path::PathBuf;
use anyhow::Context; use anyhow::Context;
use clap::Parser; use clap::Parser;
use clap_verbosity_flag::Verbosity;
use futures::SinkExt;
use indoc::concatdoc;
use systemd_journal_logger::JournalLog;
use std::os::unix::net::UnixStream as StdUnixStream; use std::os::unix::net::UnixStream as StdUnixStream;
use tokio::net::UnixStream as TokioUnixStream; use tokio::net::UnixStream as TokioUnixStream;
use crate::core::common::UnixUser; use crate::core::common::UnixUser;
use crate::core::protocol::{create_server_to_client_message_stream, Response};
use crate::server::config::read_config_from_path_with_arg_overrides; use crate::server::config::read_config_from_path_with_arg_overrides;
use crate::server::server_loop::listen_for_incoming_connections; use crate::server::server_loop::listen_for_incoming_connections;
use crate::server::{ use crate::server::{
@ -22,6 +27,9 @@ pub struct ServerArgs {
#[command(flatten)] #[command(flatten)]
config_overrides: ServerConfigArgs, config_overrides: ServerConfigArgs,
#[arg(long)]
systemd: bool,
} }
#[derive(Parser, Debug, Clone)] #[derive(Parser, Debug, Clone)]
@ -33,27 +41,148 @@ pub enum ServerCommand {
SocketActivate, SocketActivate,
} }
const LOG_LEVEL_WARNING: &str = r#"
===================================================
== WARNING: LOG LEVEL IS SET TO 'TRACE'! ==
== THIS WILL CAUSE THE SERVER TO LOG SQL QUERIES ==
== THAT MAY CONTAIN SENSITIVE INFORMATION LIKE ==
== PASSWORDS AND AUTHENTICATION TOKENS. ==
== THIS IS INTENDED FOR DEBUGGING PURPOSES ONLY ==
== AND SHOULD *NEVER* BE USED IN PRODUCTION. ==
===================================================
"#;
pub async fn handle_command( pub async fn handle_command(
socket_path: Option<PathBuf>, socket_path: Option<PathBuf>,
config_path: Option<PathBuf>, config_path: Option<PathBuf>,
verbosity: Verbosity,
args: ServerArgs, args: ServerArgs,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let mut auto_detected_systemd_mode = false;
let systemd_mode = args.systemd || {
if let Ok(true) = sd_notify::booted() {
auto_detected_systemd_mode = true;
true
} else {
false
}
};
if systemd_mode {
JournalLog::new()
.context("Failed to initialize journald logging")?
.install()
.context("Failed to install journald logger")?;
log::set_max_level(verbosity.log_level_filter());
if verbosity.log_level_filter() >= log::LevelFilter::Trace {
log::warn!("{}", LOG_LEVEL_WARNING.trim());
}
if auto_detected_systemd_mode {
log::info!("Running in systemd mode, auto-detected");
} else {
log::info!("Running in systemd mode");
}
start_watchdog_thread_if_enabled();
} else {
env_logger::Builder::new()
.filter_level(verbosity.log_level_filter())
.init();
log::info!("Running in standalone mode");
}
let config = read_config_from_path_with_arg_overrides(config_path, args.config_overrides)?; let config = read_config_from_path_with_arg_overrides(config_path, args.config_overrides)?;
match args.subcmd { match args.subcmd {
ServerCommand::Listen => listen_for_incoming_connections(socket_path, config).await, ServerCommand::Listen => listen_for_incoming_connections(socket_path, config).await,
ServerCommand::SocketActivate => socket_activate(config).await, ServerCommand::SocketActivate => {
if !args.systemd {
anyhow::bail!(concat!(
"The `--systemd` flag must be used with the `socket-activate` command.\n",
"This command currently only supports socket activation under systemd."
));
}
socket_activate(config).await
}
}
}
fn start_watchdog_thread_if_enabled() {
let mut micro_seconds: u64 = 0;
let watchdog_enabled = sd_notify::watchdog_enabled(true, &mut micro_seconds);
if watchdog_enabled {
micro_seconds = micro_seconds.max(2_000_000).div_ceil(2);
tokio::spawn(async move {
log::debug!(
"Starting systemd watchdog thread with {} millisecond interval",
micro_seconds.div_ceil(1000)
);
loop {
tokio::time::sleep(tokio::time::Duration::from_micros(micro_seconds)).await;
if let Err(err) = sd_notify::notify(false, &[sd_notify::NotifyState::Watchdog]) {
log::warn!("Failed to notify systemd watchdog: {}", err);
} else {
log::trace!("Ping sent to systemd watchdog");
}
}
});
} else {
log::debug!("Systemd watchdog not enabled, skipping watchdog thread");
} }
} }
async fn socket_activate(config: ServerConfig) -> anyhow::Result<()> { async fn socket_activate(config: ServerConfig) -> anyhow::Result<()> {
let conn = get_socket_from_systemd().await?; let conn = get_socket_from_systemd().await?;
let uid = conn.peer_cred()?.uid();
let unix_user = UnixUser::from_uid(uid)?; let uid = match conn.peer_cred() {
Ok(cred) => cred.uid(),
Err(e) => {
log::error!("Failed to get peer credentials from socket: {}", e);
let mut message_stream = create_server_to_client_message_stream(conn);
message_stream
.send(Response::Error(
(concatdoc! {
"Server failed to get peer credentials from socket\n",
"Please check the server logs or contact the system administrators"
})
.to_string(),
))
.await
.ok();
anyhow::bail!("Failed to get peer credentials from socket");
}
};
log::debug!("Accepted connection from uid {}", uid);
let unix_user = match UnixUser::from_uid(uid) {
Ok(user) => user,
Err(e) => {
log::error!("Failed to get username from uid: {}", e);
let mut message_stream = create_server_to_client_message_stream(conn);
message_stream
.send(Response::Error(
(concatdoc! {
"Server failed to get user data from the system\n",
"Please check the server logs or contact the system administrators"
})
.to_string(),
))
.await
.ok();
anyhow::bail!("Failed to get username from uid");
}
};
log::info!("Accepted connection from {}", unix_user.username); log::info!("Accepted connection from {}", unix_user.username);
sd_notify::notify(true, &[sd_notify::NotifyState::Ready]).ok(); sd_notify::notify(false, &[sd_notify::NotifyState::Ready]).ok();
handle_requests_for_single_session(conn, &unix_user, &config).await?; handle_requests_for_single_session(conn, &unix_user, &config).await?;
@ -66,7 +195,12 @@ async fn get_socket_from_systemd() -> anyhow::Result<TokioUnixStream> {
.next() .next()
.context("No file descriptors received from systemd")?; .context("No file descriptors received from systemd")?;
log::debug!("Received file descriptor from systemd: {}", fd); debug_assert!(fd == 3, "Unexpected file descriptor from systemd: {}", fd);
log::debug!(
"Received file descriptor from systemd with id: '{}', assuming socket",
fd
);
let std_unix_stream = unsafe { StdUnixStream::from_raw_fd(fd) }; let std_unix_stream = unsafe { StdUnixStream::from_raw_fd(fd) };
let socket = TokioUnixStream::from_std(std_unix_stream)?; let socket = TokioUnixStream::from_std(std_unix_stream)?;

View File

@ -109,22 +109,16 @@ pub fn read_config_from_path_with_arg_overrides(
pub fn read_config_from_path(config_path: Option<PathBuf>) -> anyhow::Result<ServerConfig> { pub fn read_config_from_path(config_path: Option<PathBuf>) -> anyhow::Result<ServerConfig> {
let config_path = config_path.unwrap_or_else(|| PathBuf::from(DEFAULT_CONFIG_PATH)); let config_path = config_path.unwrap_or_else(|| PathBuf::from(DEFAULT_CONFIG_PATH));
log::debug!("Reading config from {:?}", &config_path); log::debug!("Reading config file at {:?}", &config_path);
fs::read_to_string(&config_path) fs::read_to_string(&config_path)
.context(format!( .context(format!("Failed to read config file at {:?}", &config_path))
"Failed to read config file from {:?}",
&config_path
))
.and_then(|c| toml::from_str(&c).context("Failed to parse config file")) .and_then(|c| toml::from_str(&c).context("Failed to parse config file"))
.context(format!( .context(format!("Failed to parse config file at {:?}", &config_path))
"Failed to parse config file from {:?}",
&config_path
))
} }
fn log_config(config: &MysqlConfig) { fn log_config(config: &MysqlConfig) {
let mut display_config = config.clone(); let mut display_config = config.to_owned();
display_config.password = display_config display_config.password = display_config
.password .password
.as_ref() .as_ref()
@ -141,7 +135,9 @@ pub async fn create_mysql_connection_from_config(
) -> anyhow::Result<MySqlConnection> { ) -> anyhow::Result<MySqlConnection> {
log_config(config); log_config(config);
let mut mysql_options = MySqlConnectOptions::new().database("mysql"); let mut mysql_options = MySqlConnectOptions::new()
.database("mysql")
.log_statements(log::LevelFilter::Trace);
if let Some(username) = &config.username { if let Some(username) = &config.username {
mysql_options = mysql_options.username(username); mysql_options = mysql_options.username(username);

View File

@ -24,7 +24,7 @@ pub fn validate_ownership_by_unix_user(
name: &str, name: &str,
user: &UnixUser, user: &UnixUser,
) -> Result<(), OwnerValidationError> { ) -> Result<(), OwnerValidationError> {
let prefixes = std::iter::once(user.username.clone()) let prefixes = std::iter::once(user.username.to_owned())
.chain(user.groups.iter().cloned()) .chain(user.groups.iter().cloned())
.collect::<Vec<String>>(); .collect::<Vec<String>>();

View File

@ -7,6 +7,7 @@ use tokio::net::{UnixListener, UnixStream};
use sqlx::prelude::*; use sqlx::prelude::*;
use sqlx::MySqlConnection; use sqlx::MySqlConnection;
use crate::core::protocol::SetPasswordError;
use crate::server::sql::database_operations::list_databases; use crate::server::sql::database_operations::list_databases;
use crate::{ use crate::{
core::{ core::{
@ -56,7 +57,7 @@ pub async fn listen_for_incoming_connections(
let listener = UnixListener::bind(socket_path)?; let listener = UnixListener::bind(socket_path)?;
sd_notify::notify(true, &[sd_notify::NotifyState::Ready]).ok(); sd_notify::notify(false, &[sd_notify::NotifyState::Ready]).ok();
while let Ok((conn, _addr)) = listener.accept().await { while let Ok((conn, _addr)) = listener.accept().await {
let uid = match conn.peer_cred() { let uid = match conn.peer_cred() {
@ -78,7 +79,7 @@ pub async fn listen_for_incoming_connections(
} }
}; };
log::trace!("Accepted connection from uid {}", uid); log::debug!("Accepted connection from uid {}", uid);
let unix_user = match UnixUser::from_uid(uid) { let unix_user = match UnixUser::from_uid(uid) {
Ok(user) => user, Ok(user) => user,
@ -173,19 +174,25 @@ pub async fn handle_requests_for_single_session_with_db_connection(
} }
}; };
log::trace!("Received request: {:?}", request); // TODO: don't clone the request
let request_to_display = match &request {
Request::PasswdUser(db_user, _) => {
Request::PasswdUser(db_user.to_owned(), "<REDACTED>".to_string())
}
request => request.to_owned(),
};
log::info!("Received request: {:#?}", request_to_display);
match request { let response = match request {
Request::CreateDatabases(databases_names) => { Request::CreateDatabases(databases_names) => {
let result = create_databases(databases_names, unix_user, db_connection).await; let result = create_databases(databases_names, unix_user, db_connection).await;
stream.send(Response::CreateDatabases(result)).await?; Response::CreateDatabases(result)
} }
Request::DropDatabases(databases_names) => { Request::DropDatabases(databases_names) => {
let result = drop_databases(databases_names, unix_user, db_connection).await; let result = drop_databases(databases_names, unix_user, db_connection).await;
stream.send(Response::DropDatabases(result)).await?; Response::DropDatabases(result)
} }
Request::ListDatabases(database_names) => { Request::ListDatabases(database_names) => match database_names {
let response = match database_names {
Some(database_names) => { Some(database_names) => {
let result = list_databases(database_names, unix_user, db_connection).await; let result = list_databases(database_names, unix_user, db_connection).await;
Response::ListDatabases(result) Response::ListDatabases(result)
@ -194,11 +201,8 @@ pub async fn handle_requests_for_single_session_with_db_connection(
let result = list_all_databases_for_user(unix_user, db_connection).await; let result = list_all_databases_for_user(unix_user, db_connection).await;
Response::ListAllDatabases(result) Response::ListAllDatabases(result)
} }
}; },
stream.send(response).await?; Request::ListPrivileges(database_names) => match database_names {
}
Request::ListPrivileges(database_names) => {
let response = match database_names {
Some(database_names) => { Some(database_names) => {
let privilege_data = let privilege_data =
get_databases_privilege_data(database_names, unix_user, db_connection) get_databases_privilege_data(database_names, unix_user, db_connection)
@ -210,10 +214,7 @@ pub async fn handle_requests_for_single_session_with_db_connection(
get_all_database_privileges(unix_user, db_connection).await; get_all_database_privileges(unix_user, db_connection).await;
Response::ListAllPrivileges(privilege_data) Response::ListAllPrivileges(privilege_data)
} }
}; },
stream.send(response).await?;
}
Request::ModifyPrivileges(database_privilege_diffs) => { Request::ModifyPrivileges(database_privilege_diffs) => {
let result = apply_privilege_diffs( let result = apply_privilege_diffs(
BTreeSet::from_iter(database_privilege_diffs), BTreeSet::from_iter(database_privilege_diffs),
@ -221,24 +222,23 @@ pub async fn handle_requests_for_single_session_with_db_connection(
db_connection, db_connection,
) )
.await; .await;
stream.send(Response::ModifyPrivileges(result)).await?; Response::ModifyPrivileges(result)
} }
Request::CreateUsers(db_users) => { Request::CreateUsers(db_users) => {
let result = create_database_users(db_users, unix_user, db_connection).await; let result = create_database_users(db_users, unix_user, db_connection).await;
stream.send(Response::CreateUsers(result)).await?; Response::CreateUsers(result)
} }
Request::DropUsers(db_users) => { Request::DropUsers(db_users) => {
let result = drop_database_users(db_users, unix_user, db_connection).await; let result = drop_database_users(db_users, unix_user, db_connection).await;
stream.send(Response::DropUsers(result)).await?; Response::DropUsers(result)
} }
Request::PasswdUser(db_user, password) => { Request::PasswdUser(db_user, password) => {
let result = let result =
set_password_for_database_user(&db_user, &password, unix_user, db_connection) set_password_for_database_user(&db_user, &password, unix_user, db_connection)
.await; .await;
stream.send(Response::PasswdUser(result)).await?; Response::PasswdUser(result)
} }
Request::ListUsers(db_users) => { Request::ListUsers(db_users) => match db_users {
let response = match db_users {
Some(db_users) => { Some(db_users) => {
let result = list_database_users(db_users, unix_user, db_connection).await; let result = list_database_users(db_users, unix_user, db_connection).await;
Response::ListUsers(result) Response::ListUsers(result)
@ -248,23 +248,32 @@ pub async fn handle_requests_for_single_session_with_db_connection(
list_all_database_users_for_unix_user(unix_user, db_connection).await; list_all_database_users_for_unix_user(unix_user, db_connection).await;
Response::ListAllUsers(result) Response::ListAllUsers(result)
} }
}; },
stream.send(response).await?;
}
Request::LockUsers(db_users) => { Request::LockUsers(db_users) => {
let result = lock_database_users(db_users, unix_user, db_connection).await; let result = lock_database_users(db_users, unix_user, db_connection).await;
stream.send(Response::LockUsers(result)).await?; Response::LockUsers(result)
} }
Request::UnlockUsers(db_users) => { Request::UnlockUsers(db_users) => {
let result = unlock_database_users(db_users, unix_user, db_connection).await; let result = unlock_database_users(db_users, unix_user, db_connection).await;
stream.send(Response::UnlockUsers(result)).await?; Response::UnlockUsers(result)
} }
Request::Exit => { Request::Exit => {
break; break;
} }
} };
// TODO: don't clone the response
let response_to_display = match &response {
Response::PasswdUser(Err(SetPasswordError::MySqlError(_))) => {
Response::PasswdUser(Err(SetPasswordError::MySqlError("<REDACTED>".to_string())))
}
response => response.to_owned(),
};
log::info!("Response: {:#?}", response_to_display);
stream.send(response).await?;
stream.flush().await?; stream.flush().await?;
log::debug!("Successfully processed request");
} }
Ok(()) Ok(())

View File

@ -5,6 +5,7 @@ use sqlx::MySqlConnection;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::core::protocol::MySQLDatabase;
use crate::{ use crate::{
core::{ core::{
common::UnixUser, common::UnixUser,
@ -42,7 +43,7 @@ pub(super) async fn unsafe_database_exists(
} }
pub async fn create_databases( pub async fn create_databases(
database_names: Vec<String>, database_names: Vec<MySQLDatabase>,
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
) -> CreateDatabasesOutput { ) -> CreateDatabasesOutput {
@ -51,7 +52,7 @@ pub async fn create_databases(
for database_name in database_names { for database_name in database_names {
if let Err(err) = validate_name(&database_name) { if let Err(err) = validate_name(&database_name) {
results.insert( results.insert(
database_name.clone(), database_name.to_owned(),
Err(CreateDatabaseError::SanitizationError(err)), Err(CreateDatabaseError::SanitizationError(err)),
); );
continue; continue;
@ -59,7 +60,7 @@ pub async fn create_databases(
if let Err(err) = validate_ownership_by_unix_user(&database_name, unix_user) { if let Err(err) = validate_ownership_by_unix_user(&database_name, unix_user) {
results.insert( results.insert(
database_name.clone(), database_name.to_owned(),
Err(CreateDatabaseError::OwnershipError(err)), Err(CreateDatabaseError::OwnershipError(err)),
); );
continue; continue;
@ -68,14 +69,14 @@ pub async fn create_databases(
match unsafe_database_exists(&database_name, &mut *connection).await { match unsafe_database_exists(&database_name, &mut *connection).await {
Ok(true) => { Ok(true) => {
results.insert( results.insert(
database_name.clone(), database_name.to_owned(),
Err(CreateDatabaseError::DatabaseAlreadyExists), Err(CreateDatabaseError::DatabaseAlreadyExists),
); );
continue; continue;
} }
Err(err) => { Err(err) => {
results.insert( results.insert(
database_name.clone(), database_name.to_owned(),
Err(CreateDatabaseError::MySqlError(err.to_string())), Err(CreateDatabaseError::MySqlError(err.to_string())),
); );
continue; continue;
@ -101,7 +102,7 @@ pub async fn create_databases(
} }
pub async fn drop_databases( pub async fn drop_databases(
database_names: Vec<String>, database_names: Vec<MySQLDatabase>,
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
) -> DropDatabasesOutput { ) -> DropDatabasesOutput {
@ -110,7 +111,7 @@ pub async fn drop_databases(
for database_name in database_names { for database_name in database_names {
if let Err(err) = validate_name(&database_name) { if let Err(err) = validate_name(&database_name) {
results.insert( results.insert(
database_name.clone(), database_name.to_owned(),
Err(DropDatabaseError::SanitizationError(err)), Err(DropDatabaseError::SanitizationError(err)),
); );
continue; continue;
@ -118,7 +119,7 @@ pub async fn drop_databases(
if let Err(err) = validate_ownership_by_unix_user(&database_name, unix_user) { if let Err(err) = validate_ownership_by_unix_user(&database_name, unix_user) {
results.insert( results.insert(
database_name.clone(), database_name.to_owned(),
Err(DropDatabaseError::OwnershipError(err)), Err(DropDatabaseError::OwnershipError(err)),
); );
continue; continue;
@ -127,14 +128,14 @@ pub async fn drop_databases(
match unsafe_database_exists(&database_name, &mut *connection).await { match unsafe_database_exists(&database_name, &mut *connection).await {
Ok(false) => { Ok(false) => {
results.insert( results.insert(
database_name.clone(), database_name.to_owned(),
Err(DropDatabaseError::DatabaseDoesNotExist), Err(DropDatabaseError::DatabaseDoesNotExist),
); );
continue; continue;
} }
Err(err) => { Err(err) => {
results.insert( results.insert(
database_name.clone(), database_name.to_owned(),
Err(DropDatabaseError::MySqlError(err.to_string())), Err(DropDatabaseError::MySqlError(err.to_string())),
); );
continue; continue;
@ -159,13 +160,21 @@ pub async fn drop_databases(
results results
} }
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DatabaseRow { pub struct DatabaseRow {
pub database: String, pub database: MySQLDatabase,
}
impl FromRow<'_, sqlx::mysql::MySqlRow> for DatabaseRow {
fn from_row(row: &sqlx::mysql::MySqlRow) -> Result<Self, sqlx::Error> {
Ok(DatabaseRow {
database: row.try_get::<String, _>("database")?.into(),
})
}
} }
pub async fn list_databases( pub async fn list_databases(
database_names: Vec<String>, database_names: Vec<MySQLDatabase>,
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
) -> ListDatabasesOutput { ) -> ListDatabasesOutput {
@ -174,7 +183,7 @@ pub async fn list_databases(
for database_name in database_names { for database_name in database_names {
if let Err(err) = validate_name(&database_name) { if let Err(err) = validate_name(&database_name) {
results.insert( results.insert(
database_name.clone(), database_name.to_owned(),
Err(ListDatabasesError::SanitizationError(err)), Err(ListDatabasesError::SanitizationError(err)),
); );
continue; continue;
@ -182,7 +191,7 @@ pub async fn list_databases(
if let Err(err) = validate_ownership_by_unix_user(&database_name, unix_user) { if let Err(err) = validate_ownership_by_unix_user(&database_name, unix_user) {
results.insert( results.insert(
database_name.clone(), database_name.to_owned(),
Err(ListDatabasesError::OwnershipError(err)), Err(ListDatabasesError::OwnershipError(err)),
); );
continue; continue;
@ -195,7 +204,7 @@ pub async fn list_databases(
WHERE `SCHEMA_NAME` = ? WHERE `SCHEMA_NAME` = ?
"#, "#,
) )
.bind(&database_name) .bind(database_name.to_string())
.fetch_optional(&mut *connection) .fetch_optional(&mut *connection)
.await .await
.map_err(|err| ListDatabasesError::MySqlError(err.to_string())) .map_err(|err| ListDatabasesError::MySqlError(err.to_string()))

View File

@ -28,7 +28,8 @@ use crate::{
protocol::{ protocol::{
DiffDoesNotApplyError, GetAllDatabasesPrivilegeData, GetAllDatabasesPrivilegeDataError, DiffDoesNotApplyError, GetAllDatabasesPrivilegeData, GetAllDatabasesPrivilegeDataError,
GetDatabasesPrivilegeData, GetDatabasesPrivilegeDataError, GetDatabasesPrivilegeData, GetDatabasesPrivilegeDataError,
ModifyDatabasePrivilegesError, ModifyDatabasePrivilegesOutput, ModifyDatabasePrivilegesError, ModifyDatabasePrivilegesOutput, MySQLDatabase,
MySQLUser,
}, },
}, },
server::{ server::{
@ -63,8 +64,8 @@ pub const DATABASE_PRIVILEGE_FIELDS: [&str; 13] = [
/// This struct represents the set of privileges for a single user on a single database. /// This struct represents the set of privileges for a single user on a single database.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
pub struct DatabasePrivilegeRow { pub struct DatabasePrivilegeRow {
pub db: String, pub db: MySQLDatabase,
pub user: String, pub user: MySQLUser,
pub select_priv: bool, pub select_priv: bool,
pub insert_priv: bool, pub insert_priv: bool,
pub update_priv: bool, pub update_priv: bool,
@ -115,8 +116,8 @@ fn get_mysql_row_priv_field(row: &MySqlRow, position: usize) -> Result<bool, sql
impl FromRow<'_, MySqlRow> for DatabasePrivilegeRow { impl FromRow<'_, MySqlRow> for DatabasePrivilegeRow {
fn from_row(row: &MySqlRow) -> Result<Self, sqlx::Error> { fn from_row(row: &MySqlRow) -> Result<Self, sqlx::Error> {
Ok(Self { Ok(Self {
db: try_get_with_binary_fallback(row, "Db")?, db: try_get_with_binary_fallback(row, "Db")?.into(),
user: try_get_with_binary_fallback(row, "User")?, user: try_get_with_binary_fallback(row, "User")?.into(),
select_priv: get_mysql_row_priv_field(row, 2)?, select_priv: get_mysql_row_priv_field(row, 2)?,
insert_priv: get_mysql_row_priv_field(row, 3)?, insert_priv: get_mysql_row_priv_field(row, 3)?,
update_priv: get_mysql_row_priv_field(row, 4)?, update_priv: get_mysql_row_priv_field(row, 4)?,
@ -163,8 +164,8 @@ async fn unsafe_get_database_privileges(
// NOTE: this function is unsafe because it does no input validation. // NOTE: this function is unsafe because it does no input validation.
/// Get all users + privileges for a single database-user pair. /// Get all users + privileges for a single database-user pair.
pub async fn unsafe_get_database_privileges_for_db_user_pair( pub async fn unsafe_get_database_privileges_for_db_user_pair(
database_name: &str, database_name: &MySQLDatabase,
user_name: &str, user_name: &MySQLUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
) -> Result<Option<DatabasePrivilegeRow>, sqlx::Error> { ) -> Result<Option<DatabasePrivilegeRow>, sqlx::Error> {
let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!( let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
@ -174,8 +175,8 @@ pub async fn unsafe_get_database_privileges_for_db_user_pair(
.map(|field| quote_identifier(field)) .map(|field| quote_identifier(field))
.join(","), .join(","),
)) ))
.bind(database_name) .bind(database_name.as_str())
.bind(user_name) .bind(user_name.as_str())
.fetch_optional(connection) .fetch_optional(connection)
.await; .await;
@ -192,7 +193,7 @@ pub async fn unsafe_get_database_privileges_for_db_user_pair(
} }
pub async fn get_databases_privilege_data( pub async fn get_databases_privilege_data(
database_names: Vec<String>, database_names: Vec<MySQLDatabase>,
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
) -> GetDatabasesPrivilegeData { ) -> GetDatabasesPrivilegeData {
@ -201,7 +202,7 @@ pub async fn get_databases_privilege_data(
for database_name in database_names.iter() { for database_name in database_names.iter() {
if let Err(err) = validate_name(database_name) { if let Err(err) = validate_name(database_name) {
results.insert( results.insert(
database_name.clone(), database_name.to_owned(),
Err(GetDatabasesPrivilegeDataError::SanitizationError(err)), Err(GetDatabasesPrivilegeDataError::SanitizationError(err)),
); );
continue; continue;
@ -209,7 +210,7 @@ pub async fn get_databases_privilege_data(
if let Err(err) = validate_ownership_by_unix_user(database_name, unix_user) { if let Err(err) = validate_ownership_by_unix_user(database_name, unix_user) {
results.insert( results.insert(
database_name.clone(), database_name.to_owned(),
Err(GetDatabasesPrivilegeDataError::OwnershipError(err)), Err(GetDatabasesPrivilegeDataError::OwnershipError(err)),
); );
continue; continue;
@ -220,7 +221,7 @@ pub async fn get_databases_privilege_data(
.unwrap() .unwrap()
{ {
results.insert( results.insert(
database_name.clone(), database_name.to_owned(),
Err(GetDatabasesPrivilegeDataError::DatabaseDoesNotExist), Err(GetDatabasesPrivilegeDataError::DatabaseDoesNotExist),
); );
continue; continue;
@ -230,7 +231,7 @@ pub async fn get_databases_privilege_data(
.await .await
.map_err(|e| GetDatabasesPrivilegeDataError::MySqlError(e.to_string())); .map_err(|e| GetDatabasesPrivilegeDataError::MySqlError(e.to_string()));
results.insert(database_name.clone(), result); results.insert(database_name.to_owned(), result);
} }
debug_assert!(database_names.len() == results.len()); debug_assert!(database_names.len() == results.len());
@ -364,8 +365,8 @@ async fn validate_diff(
if privilege_row.is_some() { if privilege_row.is_some() {
Err(ModifyDatabasePrivilegesError::DiffDoesNotApply( Err(ModifyDatabasePrivilegesError::DiffDoesNotApply(
DiffDoesNotApplyError::RowAlreadyExists( DiffDoesNotApplyError::RowAlreadyExists(
diff.get_user_name().to_string(), diff.get_database_name().to_owned(),
diff.get_database_name().to_string(), diff.get_user_name().to_owned(),
), ),
)) ))
} else { } else {
@ -375,8 +376,8 @@ async fn validate_diff(
DatabasePrivilegesDiff::Modified(_) if privilege_row.is_none() => { DatabasePrivilegesDiff::Modified(_) if privilege_row.is_none() => {
Err(ModifyDatabasePrivilegesError::DiffDoesNotApply( Err(ModifyDatabasePrivilegesError::DiffDoesNotApply(
DiffDoesNotApplyError::RowDoesNotExist( DiffDoesNotApplyError::RowDoesNotExist(
diff.get_user_name().to_string(), diff.get_database_name().to_owned(),
diff.get_database_name().to_string(), diff.get_user_name().to_owned(),
), ),
)) ))
} }
@ -390,7 +391,7 @@ async fn validate_diff(
if error_exists { if error_exists {
Err(ModifyDatabasePrivilegesError::DiffDoesNotApply( Err(ModifyDatabasePrivilegesError::DiffDoesNotApply(
DiffDoesNotApplyError::RowPrivilegeChangeDoesNotApply(row_diff.clone(), row), DiffDoesNotApplyError::RowPrivilegeChangeDoesNotApply(row_diff.to_owned(), row),
)) ))
} else { } else {
Ok(()) Ok(())
@ -400,8 +401,8 @@ async fn validate_diff(
if privilege_row.is_none() { if privilege_row.is_none() {
Err(ModifyDatabasePrivilegesError::DiffDoesNotApply( Err(ModifyDatabasePrivilegesError::DiffDoesNotApply(
DiffDoesNotApplyError::RowDoesNotExist( DiffDoesNotApplyError::RowDoesNotExist(
diff.get_user_name().to_string(), diff.get_database_name().to_owned(),
diff.get_database_name().to_string(), diff.get_user_name().to_owned(),
), ),
)) ))
} else { } else {
@ -419,12 +420,12 @@ pub async fn apply_privilege_diffs(
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
) -> ModifyDatabasePrivilegesOutput { ) -> ModifyDatabasePrivilegesOutput {
let mut results: BTreeMap<(String, String), _> = BTreeMap::new(); let mut results: BTreeMap<(MySQLDatabase, MySQLUser), _> = BTreeMap::new();
for diff in database_privilege_diffs { for diff in database_privilege_diffs {
let key = ( let key = (
diff.get_database_name().to_string(), diff.get_database_name().to_owned(),
diff.get_user_name().to_string(), diff.get_user_name().to_owned(),
); );
if let Err(err) = validate_name(diff.get_database_name()) { if let Err(err) = validate_name(diff.get_database_name()) {
results.insert( results.insert(

View File

@ -7,18 +7,17 @@ use serde::{Deserialize, Serialize};
use sqlx::prelude::*; use sqlx::prelude::*;
use sqlx::MySqlConnection; use sqlx::MySqlConnection;
use crate::server::common::try_get_with_binary_fallback;
use crate::{ use crate::{
core::{ core::{
common::UnixUser, common::UnixUser,
protocol::{ protocol::{
CreateUserError, CreateUsersOutput, DropUserError, DropUsersOutput, ListAllUsersError, CreateUserError, CreateUsersOutput, DropUserError, DropUsersOutput, ListAllUsersError,
ListAllUsersOutput, ListUsersError, ListUsersOutput, LockUserError, LockUsersOutput, ListAllUsersOutput, ListUsersError, ListUsersOutput, LockUserError, LockUsersOutput,
SetPasswordError, SetPasswordOutput, UnlockUserError, UnlockUsersOutput, MySQLUser, SetPasswordError, SetPasswordOutput, UnlockUserError, UnlockUsersOutput,
}, },
}, },
server::{ server::{
common::create_user_group_matching_regex, common::{create_user_group_matching_regex, try_get_with_binary_fallback},
input_sanitization::{quote_literal, validate_name, validate_ownership_by_unix_user}, input_sanitization::{quote_literal, validate_name, validate_ownership_by_unix_user},
}, },
}; };
@ -52,7 +51,7 @@ async fn unsafe_user_exists(
} }
pub async fn create_database_users( pub async fn create_database_users(
db_users: Vec<String>, db_users: Vec<MySQLUser>,
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
) -> CreateUsersOutput { ) -> CreateUsersOutput {
@ -98,7 +97,7 @@ pub async fn create_database_users(
} }
pub async fn drop_database_users( pub async fn drop_database_users(
db_users: Vec<String>, db_users: Vec<MySQLUser>,
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
) -> DropUsersOutput { ) -> DropUsersOutput {
@ -144,7 +143,7 @@ pub async fn drop_database_users(
} }
pub async fn set_password_for_database_user( pub async fn set_password_for_database_user(
db_user: &str, db_user: &MySQLUser,
password: &str, password: &str,
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
@ -167,7 +166,7 @@ pub async fn set_password_for_database_user(
format!( format!(
"ALTER USER {}@'%' IDENTIFIED BY {}", "ALTER USER {}@'%' IDENTIFIED BY {}",
quote_literal(db_user), quote_literal(db_user),
quote_literal(password).as_str() quote_literal(password).as_str(),
) )
.as_str(), .as_str(),
) )
@ -176,11 +175,10 @@ pub async fn set_password_for_database_user(
.map(|_| ()) .map(|_| ())
.map_err(|err| SetPasswordError::MySqlError(err.to_string())); .map_err(|err| SetPasswordError::MySqlError(err.to_string()));
if let Err(err) = &result { if result.is_err() {
log::error!( log::error!(
"Failed to set password for database user '{}': {:?}", "Failed to set password for database user '{}': <REDACTED>",
&db_user, &db_user,
err
); );
} }
@ -220,7 +218,7 @@ async fn database_user_is_locked_unsafe(
} }
pub async fn lock_database_users( pub async fn lock_database_users(
db_users: Vec<String>, db_users: Vec<MySQLUser>,
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
) -> LockUsersOutput { ) -> LockUsersOutput {
@ -280,7 +278,7 @@ pub async fn lock_database_users(
} }
pub async fn unlock_database_users( pub async fn unlock_database_users(
db_users: Vec<String>, db_users: Vec<MySQLUser>,
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
) -> UnlockUsersOutput { ) -> UnlockUsersOutput {
@ -343,7 +341,7 @@ pub async fn unlock_database_users(
/// This can be extended if we need more information in the future. /// This can be extended if we need more information in the future.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DatabaseUser { pub struct DatabaseUser {
pub user: String, pub user: MySQLUser,
#[serde(skip)] #[serde(skip)]
pub host: String, pub host: String,
pub has_password: bool, pub has_password: bool,
@ -354,7 +352,7 @@ pub struct DatabaseUser {
impl FromRow<'_, sqlx::mysql::MySqlRow> for DatabaseUser { impl FromRow<'_, sqlx::mysql::MySqlRow> for DatabaseUser {
fn from_row(row: &sqlx::mysql::MySqlRow) -> Result<Self, sqlx::Error> { fn from_row(row: &sqlx::mysql::MySqlRow) -> Result<Self, sqlx::Error> {
Ok(Self { Ok(Self {
user: try_get_with_binary_fallback(row, "User")?, user: try_get_with_binary_fallback(row, "User")?.into(),
host: try_get_with_binary_fallback(row, "Host")?, host: try_get_with_binary_fallback(row, "Host")?,
has_password: row.try_get("has_password")?, has_password: row.try_get("has_password")?,
is_locked: row.try_get("is_locked")?, is_locked: row.try_get("is_locked")?,
@ -379,7 +377,7 @@ JOIN `global_priv` ON
"#; "#;
pub async fn list_database_users( pub async fn list_database_users(
db_users: Vec<String>, db_users: Vec<MySQLUser>,
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
) -> ListUsersOutput { ) -> ListUsersOutput {
@ -399,7 +397,7 @@ pub async fn list_database_users(
let mut result = sqlx::query_as::<_, DatabaseUser>( let mut result = sqlx::query_as::<_, DatabaseUser>(
&(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `mysql`.`user`.`User` = ?"), &(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `mysql`.`user`.`User` = ?"),
) )
.bind(&db_user) .bind(db_user.as_str())
.fetch_optional(&mut *connection) .fetch_optional(&mut *connection)
.await; .await;
@ -464,7 +462,7 @@ pub async fn append_databases_where_user_has_privileges(
) )
.as_str(), .as_str(),
) )
.bind(db_user.user.clone()) .bind(db_user.user.as_str())
.fetch_all(&mut *connection) .fetch_all(&mut *connection)
.await; .await;