Rewrite entire codebase to split into client and server
This commit is contained in:
		
							
								
								
									
										99
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										99
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							| @@ -418,6 +418,27 @@ dependencies = [ | ||||
|  "zeroize", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "derive_more" | ||||
| version = "1.0.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "4a9b99b9cbbe49445b21764dc0625032a89b145a2642e67603e1c936f5458d05" | ||||
| dependencies = [ | ||||
|  "derive_more-impl", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "derive_more-impl" | ||||
| version = "1.0.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "cb7330aeadfbe296029522e6c40f315320aba36fc43a5b3632f3795348f3bd22" | ||||
| dependencies = [ | ||||
|  "proc-macro2", | ||||
|  "quote", | ||||
|  "syn 2.0.60", | ||||
|  "unicode-xid", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "dialoguer" | ||||
| version = "0.11.0" | ||||
| @@ -470,6 +491,18 @@ version = "0.15.7" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b" | ||||
|  | ||||
| [[package]] | ||||
| name = "educe" | ||||
| version = "0.5.11" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "e4bd92664bf78c4d3dba9b7cdafce6fa15b13ed3ed16175218196942e99168a8" | ||||
| dependencies = [ | ||||
|  "enum-ordinalize", | ||||
|  "proc-macro2", | ||||
|  "quote", | ||||
|  "syn 2.0.60", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "either" | ||||
| version = "1.11.0" | ||||
| @@ -491,6 +524,26 @@ version = "1.0.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" | ||||
|  | ||||
| [[package]] | ||||
| name = "enum-ordinalize" | ||||
| version = "4.3.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "fea0dcfa4e54eeb516fe454635a95753ddd39acda650ce703031c6973e315dd5" | ||||
| dependencies = [ | ||||
|  "enum-ordinalize-derive", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "enum-ordinalize-derive" | ||||
| version = "4.3.1" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "0d28318a75d4aead5c4db25382e8ef717932d0346600cacae6357eb5941bc5ff" | ||||
| dependencies = [ | ||||
|  "proc-macro2", | ||||
|  "quote", | ||||
|  "syn 2.0.60", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "env_filter" | ||||
| version = "0.1.0" | ||||
| @@ -961,9 +1014,11 @@ dependencies = [ | ||||
|  "async-bincode", | ||||
|  "bincode", | ||||
|  "clap", | ||||
|  "derive_more", | ||||
|  "dialoguer", | ||||
|  "env_logger", | ||||
|  "futures", | ||||
|  "futures-util", | ||||
|  "indoc", | ||||
|  "itertools", | ||||
|  "log", | ||||
| @@ -974,8 +1029,9 @@ dependencies = [ | ||||
|  "serde", | ||||
|  "serde_json", | ||||
|  "sqlx", | ||||
|  "thiserror", | ||||
|  "tokio", | ||||
|  "tokio-serde", | ||||
|  "tokio-stream", | ||||
|  "tokio-util", | ||||
|  "toml", | ||||
|  "uuid", | ||||
| @@ -1109,6 +1165,26 @@ version = "2.3.1" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" | ||||
|  | ||||
| [[package]] | ||||
| name = "pin-project" | ||||
| version = "1.1.5" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "b6bf43b791c5b9e34c3d182969b4abb522f9343702850a2e57f460d00d09b4b3" | ||||
| dependencies = [ | ||||
|  "pin-project-internal", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "pin-project-internal" | ||||
| version = "1.1.5" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" | ||||
| dependencies = [ | ||||
|  "proc-macro2", | ||||
|  "quote", | ||||
|  "syn 2.0.60", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "pin-project-lite" | ||||
| version = "0.2.14" | ||||
| @@ -1931,6 +2007,21 @@ dependencies = [ | ||||
|  "syn 2.0.60", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "tokio-serde" | ||||
| version = "0.9.0" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "caf600e7036b17782571dd44fa0a5cea3c82f60db5137f774a325a76a0d6852b" | ||||
| dependencies = [ | ||||
|  "bincode", | ||||
|  "bytes", | ||||
|  "educe", | ||||
|  "futures-core", | ||||
|  "futures-sink", | ||||
|  "pin-project", | ||||
|  "serde", | ||||
| ] | ||||
|  | ||||
| [[package]] | ||||
| name = "tokio-stream" | ||||
| version = "0.1.15" | ||||
| @@ -2060,6 +2151,12 @@ version = "0.1.11" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "e51733f11c9c4f72aa0c160008246859e340b00807569a0da0e7a1079b27ba85" | ||||
|  | ||||
| [[package]] | ||||
| name = "unicode-xid" | ||||
| version = "0.2.4" | ||||
| source = "registry+https://github.com/rust-lang/crates.io-index" | ||||
| checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c" | ||||
|  | ||||
| [[package]] | ||||
| name = "unicode_categories" | ||||
| version = "0.1.1" | ||||
|   | ||||
| @@ -8,22 +8,25 @@ anyhow = "1.0.82" | ||||
| async-bincode = "0.7.2" | ||||
| bincode = "1.3.3" | ||||
| clap = { version = "4.5.4", features = ["derive"] } | ||||
| derive_more = { version = "1.0.0", features = ["display", "error"] } | ||||
| dialoguer = "0.11.0" | ||||
| env_logger = "0.11.3" | ||||
| futures = "0.3.30" | ||||
| futures-util = "0.3.30" | ||||
| indoc = "2.0.5" | ||||
| itertools = "0.12.1" | ||||
| log = "0.4.21" | ||||
| nix = { version = "0.28.0", features = ["fs", "user"] } | ||||
| nix = { version = "0.28.0", features = ["fs", "process", "user"] } | ||||
| prettytable = "0.10.0" | ||||
| rand = "0.8.5" | ||||
| ratatui = { version = "0.26.2", optional = true } | ||||
| serde = "1.0.198" | ||||
| serde_json = { version = "1.0.116", features = ["preserve_order"] } | ||||
| sqlx = { version = "0.7.4", features = ["runtime-tokio", "mysql", "tls-rustls"] } | ||||
| thiserror = "1.0.63" | ||||
| tokio = { version = "1.37.0", features = ["rt", "macros"] } | ||||
| tokio-util = "0.7.11" | ||||
| tokio-serde = { version = "0.9.0", features = ["bincode"] } | ||||
| tokio-stream = "0.1.15" | ||||
| tokio-util = { version = "0.7.11", features = ["codec"] } | ||||
| toml = "0.8.12" | ||||
| uuid = { version = "1.10.0", features = ["v4"] } | ||||
|  | ||||
|   | ||||
| @@ -1,3 +1,6 @@ | ||||
| mod common; | ||||
| pub mod database_command; | ||||
| pub mod mysql_admutils_compatibility; | ||||
| pub mod user_command; | ||||
|  | ||||
| #[cfg(feature = "mysql-admutils-compatibility")] | ||||
| pub mod mysql_admutils_compatibility; | ||||
|   | ||||
							
								
								
									
										20
									
								
								src/cli/common.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								src/cli/common.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,20 @@ | ||||
| use crate::core::protocol::Response; | ||||
|  | ||||
| pub fn erroneous_server_response( | ||||
|     response: Option<Result<Response, std::io::Error>>, | ||||
| ) -> anyhow::Result<()> { | ||||
|     match response { | ||||
|         Some(Ok(Response::Error(e))) => { | ||||
|             anyhow::bail!("Server returned error: {}", e); | ||||
|         } | ||||
|         Some(Err(e)) => { | ||||
|             anyhow::bail!(e); | ||||
|         } | ||||
|         Some(response) => { | ||||
|             anyhow::bail!("Unexpected response from server: {:?}", response); | ||||
|         } | ||||
|         None => { | ||||
|             anyhow::bail!("No response from server"); | ||||
|         } | ||||
|     } | ||||
| } | ||||
| @@ -1,17 +1,29 @@ | ||||
| use anyhow::Context; | ||||
| use clap::Parser; | ||||
| use dialoguer::{Confirm, Editor}; | ||||
| use futures_util::{SinkExt, StreamExt}; | ||||
| use nix::unistd::{getuid, User}; | ||||
| use prettytable::{Cell, Row, Table}; | ||||
| use sqlx::{Connection, MySqlConnection}; | ||||
|  | ||||
| use crate::core::{ | ||||
|     common::{close_database_connection, get_current_unix_user, yn, CommandStatus}, | ||||
|     database_operations::*, | ||||
|     database_privilege_operations::*, | ||||
|     user_operations::user_exists, | ||||
| use crate::{ | ||||
|     cli::common::erroneous_server_response, | ||||
|     core::{ | ||||
|         common::yn, | ||||
|         database_privileges::{ | ||||
|             db_priv_field_human_readable_name, diff_privileges, display_privilege_diffs, | ||||
|             generate_editor_content_from_privilege_data, parse_privilege_data_from_editor_content, | ||||
|             parse_privilege_table_cli_arg, | ||||
|         }, | ||||
|         protocol::{ | ||||
|             print_create_databases_output_status, print_drop_databases_output_status, | ||||
|             print_modify_database_privileges_output_status, ClientToServerMessageStream, Request, | ||||
|             Response, | ||||
|         }, | ||||
|     }, | ||||
|     server::sql::database_privilege_operations::{DatabasePrivilegeRow, DATABASE_PRIVILEGE_FIELDS}, | ||||
| }; | ||||
|  | ||||
| #[derive(Parser)] | ||||
| #[derive(Parser, Debug, Clone)] | ||||
| // #[command(next_help_heading = Some(DATABASE_COMMAND_HEADER))] | ||||
| pub enum DatabaseCommand { | ||||
|     /// Create one or more databases | ||||
| @@ -86,28 +98,28 @@ pub enum DatabaseCommand { | ||||
|     EditDbPrivs(DatabaseEditPrivsArgs), | ||||
| } | ||||
|  | ||||
| #[derive(Parser)] | ||||
| #[derive(Parser, Debug, Clone)] | ||||
| pub struct DatabaseCreateArgs { | ||||
|     /// The name of the database(s) to create. | ||||
|     #[arg(num_args = 1..)] | ||||
|     name: Vec<String>, | ||||
| } | ||||
|  | ||||
| #[derive(Parser)] | ||||
| #[derive(Parser, Debug, Clone)] | ||||
| pub struct DatabaseDropArgs { | ||||
|     /// The name of the database(s) to drop. | ||||
|     #[arg(num_args = 1..)] | ||||
|     name: Vec<String>, | ||||
| } | ||||
|  | ||||
| #[derive(Parser)] | ||||
| #[derive(Parser, Debug, Clone)] | ||||
| pub struct DatabaseListArgs { | ||||
|     /// Whether to output the information in JSON format. | ||||
|     #[arg(short, long)] | ||||
|     json: bool, | ||||
| } | ||||
|  | ||||
| #[derive(Parser)] | ||||
| #[derive(Parser, Debug, Clone)] | ||||
| pub struct DatabaseShowPrivsArgs { | ||||
|     /// The name of the database(s) to show. | ||||
|     #[arg(num_args = 0..)] | ||||
| @@ -118,7 +130,7 @@ pub struct DatabaseShowPrivsArgs { | ||||
|     json: bool, | ||||
| } | ||||
|  | ||||
| #[derive(Parser)] | ||||
| #[derive(Parser, Debug, Clone)] | ||||
| pub struct DatabaseEditPrivsArgs { | ||||
|     /// The name of the database to edit privileges for. | ||||
|     pub name: Option<String>, | ||||
| @@ -141,125 +153,143 @@ pub struct DatabaseEditPrivsArgs { | ||||
|  | ||||
| pub async fn handle_command( | ||||
|     command: DatabaseCommand, | ||||
|     mut connection: MySqlConnection, | ||||
| ) -> anyhow::Result<CommandStatus> { | ||||
|     let result = connection | ||||
|         .transaction(|txn| { | ||||
|             Box::pin(async move { | ||||
|                 match command { | ||||
|                     DatabaseCommand::CreateDb(args) => create_databases(args, txn).await, | ||||
|                     DatabaseCommand::DropDb(args) => drop_databases(args, txn).await, | ||||
|                     DatabaseCommand::ListDb(args) => list_databases(args, txn).await, | ||||
|                     DatabaseCommand::ShowDbPrivs(args) => show_database_privileges(args, txn).await, | ||||
|                     DatabaseCommand::EditDbPrivs(args) => edit_privileges(args, txn).await, | ||||
|                 } | ||||
|             }) | ||||
|         }) | ||||
|         .await; | ||||
|  | ||||
|     close_database_connection(connection).await; | ||||
|  | ||||
|     result | ||||
|     server_connection: ClientToServerMessageStream, | ||||
| ) -> anyhow::Result<()> { | ||||
|     match command { | ||||
|         DatabaseCommand::CreateDb(args) => create_databases(args, server_connection).await, | ||||
|         DatabaseCommand::DropDb(args) => drop_databases(args, server_connection).await, | ||||
|         DatabaseCommand::ListDb(args) => list_databases(args, server_connection).await, | ||||
|         DatabaseCommand::ShowDbPrivs(args) => { | ||||
|             show_database_privileges(args, server_connection).await | ||||
|         } | ||||
|         DatabaseCommand::EditDbPrivs(args) => { | ||||
|             edit_database_privileges(args, server_connection).await | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| async fn create_databases( | ||||
|     args: DatabaseCreateArgs, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> anyhow::Result<CommandStatus> { | ||||
|     mut server_connection: ClientToServerMessageStream, | ||||
| ) -> anyhow::Result<()> { | ||||
|     if args.name.is_empty() { | ||||
|         anyhow::bail!("No database names provided"); | ||||
|     } | ||||
|  | ||||
|     let mut result = CommandStatus::SuccessfullyModified; | ||||
|     let message = Request::CreateDatabases(args.name.clone()); | ||||
|     server_connection.send(message).await?; | ||||
|  | ||||
|     for name in args.name { | ||||
|         // TODO: This can be optimized by fetching all the database privileges in one query. | ||||
|         if let Err(e) = create_database(&name, connection).await { | ||||
|             eprintln!("Failed to create database '{}': {}", name, e); | ||||
|             eprintln!("Skipping..."); | ||||
|             result = CommandStatus::PartiallySuccessfullyModified; | ||||
|         } else { | ||||
|             println!("Database '{}' created.", name); | ||||
|         } | ||||
|     } | ||||
|     let result = match server_connection.next().await { | ||||
|         Some(Ok(Response::CreateDatabases(result))) => result, | ||||
|         response => return erroneous_server_response(response), | ||||
|     }; | ||||
|  | ||||
|     Ok(result) | ||||
|     server_connection.send(Request::Exit).await?; | ||||
|  | ||||
|     print_create_databases_output_status(&result); | ||||
|  | ||||
|     Ok(()) | ||||
| } | ||||
|  | ||||
| async fn drop_databases( | ||||
|     args: DatabaseDropArgs, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> anyhow::Result<CommandStatus> { | ||||
|     mut server_connection: ClientToServerMessageStream, | ||||
| ) -> anyhow::Result<()> { | ||||
|     if args.name.is_empty() { | ||||
|         anyhow::bail!("No database names provided"); | ||||
|     } | ||||
|  | ||||
|     let mut result = CommandStatus::SuccessfullyModified; | ||||
|     let message = Request::DropDatabases(args.name.clone()); | ||||
|     server_connection.send(message).await?; | ||||
|  | ||||
|     for name in args.name { | ||||
|         // TODO: This can be optimized by fetching all the database privileges in one query. | ||||
|         if let Err(e) = drop_database(&name, connection).await { | ||||
|             eprintln!("Failed to drop database '{}': {}", name, e); | ||||
|             eprintln!("Skipping..."); | ||||
|             result = CommandStatus::PartiallySuccessfullyModified; | ||||
|         } else { | ||||
|             println!("Database '{}' dropped.", name); | ||||
|         } | ||||
|     } | ||||
|     let result = match server_connection.next().await { | ||||
|         Some(Ok(Response::DropDatabases(result))) => result, | ||||
|         response => return erroneous_server_response(response), | ||||
|     }; | ||||
|  | ||||
|     Ok(result) | ||||
|     server_connection.send(Request::Exit).await?; | ||||
|  | ||||
|     print_drop_databases_output_status(&result); | ||||
|  | ||||
|     Ok(()) | ||||
| } | ||||
|  | ||||
| async fn list_databases( | ||||
|     args: DatabaseListArgs, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> anyhow::Result<CommandStatus> { | ||||
|     let databases = get_database_list(connection).await?; | ||||
|     mut server_connection: ClientToServerMessageStream, | ||||
| ) -> anyhow::Result<()> { | ||||
|     let message = Request::ListDatabases; | ||||
|     server_connection.send(message).await?; | ||||
|  | ||||
|     let result = match server_connection.next().await { | ||||
|         Some(Ok(Response::ListAllDatabases(result))) => result, | ||||
|         response => return erroneous_server_response(response), | ||||
|     }; | ||||
|  | ||||
|     server_connection.send(Request::Exit).await?; | ||||
|  | ||||
|     let database_list = match result { | ||||
|         Ok(list) => list, | ||||
|         Err(err) => { | ||||
|             return Err(anyhow::anyhow!(err.to_error_message()).context("Failed to list databases")) | ||||
|         } | ||||
|     }; | ||||
|  | ||||
|     if args.json { | ||||
|         println!("{}", serde_json::to_string_pretty(&databases)?); | ||||
|         return Ok(CommandStatus::NoModificationsIntended); | ||||
|     } | ||||
|  | ||||
|     if databases.is_empty() { | ||||
|         println!("{}", serde_json::to_string_pretty(&database_list)?); | ||||
|     } else if database_list.is_empty() { | ||||
|         println!("No databases to show."); | ||||
|     } else { | ||||
|         for db in databases { | ||||
|         for db in database_list { | ||||
|             println!("{}", db); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     Ok(CommandStatus::NoModificationsIntended) | ||||
|     Ok(()) | ||||
| } | ||||
|  | ||||
| async fn show_database_privileges( | ||||
|     args: DatabaseShowPrivsArgs, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> anyhow::Result<CommandStatus> { | ||||
|     let database_users_to_show = if args.name.is_empty() { | ||||
|         get_all_database_privileges(connection).await? | ||||
|     mut server_connection: ClientToServerMessageStream, | ||||
| ) -> anyhow::Result<()> { | ||||
|     let message = if args.name.is_empty() { | ||||
|         Request::ListPrivileges(None) | ||||
|     } else { | ||||
|         // TODO: This can be optimized by fetching all the database privileges in one query. | ||||
|         let mut result = Vec::with_capacity(args.name.len()); | ||||
|         for name in args.name { | ||||
|             match get_database_privileges(&name, connection).await { | ||||
|                 Ok(db) => result.extend(db), | ||||
|                 Err(e) => { | ||||
|                     eprintln!("Failed to show database '{}': {}", name, e); | ||||
|         Request::ListPrivileges(Some(args.name.clone())) | ||||
|     }; | ||||
|     server_connection.send(message).await?; | ||||
|  | ||||
|     let privilege_data = match server_connection.next().await { | ||||
|         Some(Ok(Response::ListPrivileges(databases))) => databases | ||||
|             .into_iter() | ||||
|             .filter_map(|(database_name, result)| match result { | ||||
|                 Ok(privileges) => Some(privileges), | ||||
|                 Err(err) => { | ||||
|                     eprintln!("{}", err.to_error_message(&database_name)); | ||||
|                     eprintln!("Skipping..."); | ||||
|                     println!(); | ||||
|                     None | ||||
|                 } | ||||
|             }) | ||||
|             .flatten() | ||||
|             .collect::<Vec<_>>(), | ||||
|         Some(Ok(Response::ListAllPrivileges(privilege_rows))) => match privilege_rows { | ||||
|             Ok(list) => list, | ||||
|             Err(err) => { | ||||
|                 server_connection.send(Request::Exit).await?; | ||||
|                 return Err(anyhow::anyhow!(err.to_error_message()) | ||||
|                     .context("Failed to list database privileges")); | ||||
|             } | ||||
|         } | ||||
|         result | ||||
|         }, | ||||
|         response => return erroneous_server_response(response), | ||||
|     }; | ||||
|  | ||||
|     if args.json { | ||||
|         println!("{}", serde_json::to_string_pretty(&database_users_to_show)?); | ||||
|         return Ok(CommandStatus::NoModificationsIntended); | ||||
|     } | ||||
|     server_connection.send(Request::Exit).await?; | ||||
|  | ||||
|     if database_users_to_show.is_empty() { | ||||
|         println!("No database users to show."); | ||||
|     if args.json { | ||||
|         println!("{}", serde_json::to_string_pretty(&privilege_data)?); | ||||
|     } else if privilege_data.is_empty() { | ||||
|         println!("No database privileges to show."); | ||||
|     } else { | ||||
|         let mut table = Table::new(); | ||||
|         table.add_row(Row::new( | ||||
| @@ -270,7 +300,7 @@ async fn show_database_privileges( | ||||
|                 .collect(), | ||||
|         )); | ||||
|  | ||||
|         for row in database_users_to_show { | ||||
|         for row in privilege_data { | ||||
|             table.add_row(row![ | ||||
|                 row.db, | ||||
|                 row.user, | ||||
| @@ -290,17 +320,40 @@ async fn show_database_privileges( | ||||
|         table.printstd(); | ||||
|     } | ||||
|  | ||||
|     Ok(CommandStatus::NoModificationsIntended) | ||||
|     Ok(()) | ||||
| } | ||||
|  | ||||
| pub async fn edit_privileges( | ||||
| pub async fn edit_database_privileges( | ||||
|     args: DatabaseEditPrivsArgs, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> anyhow::Result<CommandStatus> { | ||||
|     let privilege_data = if let Some(name) = &args.name { | ||||
|         get_database_privileges(name, connection).await? | ||||
|     } else { | ||||
|         get_all_database_privileges(connection).await? | ||||
|     mut server_connection: ClientToServerMessageStream, | ||||
| ) -> anyhow::Result<()> { | ||||
|     let message = Request::ListPrivileges(args.name.clone().map(|name| vec![name])); | ||||
|  | ||||
|     server_connection.send(message).await?; | ||||
|  | ||||
|     let privilege_data = match server_connection.next().await { | ||||
|         Some(Ok(Response::ListPrivileges(databases))) => databases | ||||
|             .into_iter() | ||||
|             .filter_map(|(database_name, result)| match result { | ||||
|                 Ok(privileges) => Some(privileges), | ||||
|                 Err(err) => { | ||||
|                     eprintln!("{}", err.to_error_message(&database_name)); | ||||
|                     eprintln!("Skipping..."); | ||||
|                     println!(); | ||||
|                     None | ||||
|                 } | ||||
|             }) | ||||
|             .flatten() | ||||
|             .collect::<Vec<_>>(), | ||||
|         Some(Ok(Response::ListAllPrivileges(privilege_rows))) => match privilege_rows { | ||||
|             Ok(list) => list, | ||||
|             Err(err) => { | ||||
|                 server_connection.send(Request::Exit).await?; | ||||
|                 return Err(anyhow::anyhow!(err.to_error_message()) | ||||
|                     .context("Failed to list database privileges")); | ||||
|             } | ||||
|         }, | ||||
|         response => return erroneous_server_response(response), | ||||
|     }; | ||||
|  | ||||
|     // TODO: The data from args should not be absolute. | ||||
| @@ -316,22 +369,16 @@ pub async fn edit_privileges( | ||||
|         edit_privileges_with_editor(&privilege_data)? | ||||
|     }; | ||||
|  | ||||
|     for row in privileges_to_change.iter() { | ||||
|         if !user_exists(&row.user, connection).await? { | ||||
|             // TODO: allow user to return and correct their mistake | ||||
|             anyhow::bail!("User {} does not exist", row.user); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     let diffs = diff_privileges(&privilege_data, &privileges_to_change); | ||||
|  | ||||
|     if diffs.is_empty() { | ||||
|         println!("No changes to make."); | ||||
|         return Ok(CommandStatus::NoModificationsNeeded); | ||||
|         return Ok(()); | ||||
|     } | ||||
|  | ||||
|     println!("The following changes will be made:\n"); | ||||
|     println!("{}", display_privilege_diffs(&diffs)); | ||||
|  | ||||
|     if !args.yes | ||||
|         && !Confirm::new() | ||||
|             .with_prompt("Do you want to apply these changes?") | ||||
| @@ -339,15 +386,27 @@ pub async fn edit_privileges( | ||||
|             .show_default(true) | ||||
|             .interact()? | ||||
|     { | ||||
|         return Ok(CommandStatus::Cancelled); | ||||
|         server_connection.send(Request::Exit).await?; | ||||
|         return Ok(()); | ||||
|     } | ||||
|  | ||||
|     apply_privilege_diffs(diffs, connection).await?; | ||||
|     let message = Request::ModifyPrivileges(diffs); | ||||
|     server_connection.send(message).await?; | ||||
|  | ||||
|     Ok(CommandStatus::SuccessfullyModified) | ||||
|     let result = match server_connection.next().await { | ||||
|         Some(Ok(Response::ModifyPrivileges(result))) => result, | ||||
|         response => return erroneous_server_response(response), | ||||
|     }; | ||||
|  | ||||
|     // TODO: allow user to return and correct their mistake | ||||
|     print_modify_database_privileges_output_status(&result); | ||||
|  | ||||
|     server_connection.send(Request::Exit).await?; | ||||
|  | ||||
|     Ok(()) | ||||
| } | ||||
|  | ||||
| pub fn parse_privilege_tables_from_args( | ||||
| fn parse_privilege_tables_from_args( | ||||
|     args: &DatabaseEditPrivsArgs, | ||||
| ) -> anyhow::Result<Vec<DatabasePrivilegeRow>> { | ||||
|     debug_assert!(!args.privs.is_empty()); | ||||
| @@ -371,20 +430,22 @@ pub fn parse_privilege_tables_from_args( | ||||
|     Ok(result) | ||||
| } | ||||
|  | ||||
| pub fn edit_privileges_with_editor( | ||||
| fn edit_privileges_with_editor( | ||||
|     privilege_data: &[DatabasePrivilegeRow], | ||||
| ) -> anyhow::Result<Vec<DatabasePrivilegeRow>> { | ||||
|     let unix_user = get_current_unix_user()?; | ||||
|     let unix_user = User::from_uid(getuid()) | ||||
|         .context("Failed to look up your UNIX username") | ||||
|         .and_then(|u| u.ok_or(anyhow::anyhow!("Failed to look up your UNIX username")))?; | ||||
|  | ||||
|     let editor_content = | ||||
|         generate_editor_content_from_privilege_data(privilege_data, &unix_user.name); | ||||
|  | ||||
|     // TODO: handle errors better here | ||||
|     let result = Editor::new() | ||||
|         .extension("tsv") | ||||
|         .edit(&editor_content)? | ||||
|         .unwrap(); | ||||
|     let result = Editor::new().extension("tsv").edit(&editor_content)?; | ||||
|  | ||||
|     parse_privilege_data_from_editor_content(result) | ||||
|         .context("Could not parse privilege data from editor") | ||||
|     match result { | ||||
|         None => Ok(privilege_data.to_vec()), | ||||
|         Some(result) => parse_privilege_data_from_editor_content(result) | ||||
|             .context("Could not parse privilege data from editor"), | ||||
|     } | ||||
| } | ||||
|   | ||||
| @@ -1,3 +1,4 @@ | ||||
| pub mod common; | ||||
| mod error_messages; | ||||
| pub mod mysql_dbadm; | ||||
| pub mod mysql_useradm; | ||||
|   | ||||
| @@ -1,57 +1,4 @@ | ||||
| use crate::core::common::{ | ||||
|     get_current_unix_user, validate_name_or_error, validate_ownership_or_error, DbOrUser, | ||||
| }; | ||||
|  | ||||
| /// In contrast to the new implementation which reports errors on any invalid name | ||||
| /// for any reason, mysql-admutils would only log the error and skip that particular | ||||
| /// name. This function replicates that behavior. | ||||
| pub fn filter_db_or_user_names( | ||||
|     names: Vec<String>, | ||||
|     db_or_user: DbOrUser, | ||||
| ) -> anyhow::Result<Vec<String>> { | ||||
|     let unix_user = get_current_unix_user()?; | ||||
|     let argv0 = std::env::args().next().unwrap_or_else(|| match db_or_user { | ||||
|         DbOrUser::Database => "mysql-dbadm".to_string(), | ||||
|         DbOrUser::User => "mysql-useradm".to_string(), | ||||
|     }); | ||||
|  | ||||
|     let filtered_names = names | ||||
|         .into_iter() | ||||
|         // NOTE: The original implementation would only copy the first 32 characters | ||||
|         //       of the argument into it's internal buffer. We replicate that behavior | ||||
|         //       here. | ||||
|         .map(|name| name.chars().take(32).collect::<String>()) | ||||
|         .filter(|name| { | ||||
|             if let Err(_err) = validate_ownership_or_error(name, &unix_user, db_or_user) { | ||||
|                 println!( | ||||
|                     "You are not in charge of mysql-{}: '{}'.  Skipping.", | ||||
|                     db_or_user.lowercased(), | ||||
|                     name | ||||
|                 ); | ||||
|                 return false; | ||||
|             } | ||||
|             true | ||||
|         }) | ||||
|         .filter(|name| { | ||||
|             // NOTE: while this also checks for the length of the name, | ||||
|             //       the name is already truncated to 32 characters. So | ||||
|             //       if there is an error, it's guaranteed to be due to | ||||
|             //       invalid characters. | ||||
|             if let Err(_err) = validate_name_or_error(name, db_or_user) { | ||||
|                 println!( | ||||
|                     concat!( | ||||
|                         "{}: {} name '{}' contains invalid characters.\n", | ||||
|                         "Only A-Z, a-z, 0-9, _ (underscore) and - (dash) permitted. Skipping.", | ||||
|                     ), | ||||
|                     argv0, | ||||
|                     db_or_user.capitalized(), | ||||
|                     name | ||||
|                 ); | ||||
|                 return false; | ||||
|             } | ||||
|             true | ||||
|         }) | ||||
|         .collect(); | ||||
|  | ||||
|     Ok(filtered_names) | ||||
| #[inline] | ||||
| pub fn trim_to_32_chars(name: &str) -> String { | ||||
|     name.chars().take(32).collect() | ||||
| } | ||||
|   | ||||
							
								
								
									
										176
									
								
								src/cli/mysql_admutils_compatibility/error_messages.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										176
									
								
								src/cli/mysql_admutils_compatibility/error_messages.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,176 @@ | ||||
| use crate::core::protocol::{ | ||||
|     CreateDatabaseError, CreateUserError, DbOrUser, DropDatabaseError, DropUserError, | ||||
|     GetDatabasesPrivilegeDataError, ListUsersError, | ||||
| }; | ||||
|  | ||||
| pub fn name_validation_error_to_error_message(name: &str, db_or_user: DbOrUser) -> String { | ||||
|     let argv0 = std::env::args().next().unwrap_or_else(|| match db_or_user { | ||||
|         DbOrUser::Database => "mysql-dbadm".to_string(), | ||||
|         DbOrUser::User => "mysql-useradm".to_string(), | ||||
|     }); | ||||
|  | ||||
|     format!( | ||||
|         concat!( | ||||
|             "{}: {} name '{}' contains invalid characters.\n", | ||||
|             "Only A-Z, a-z, 0-9, _ (underscore) and - (dash) permitted. Skipping.", | ||||
|         ), | ||||
|         argv0, | ||||
|         db_or_user.capitalized(), | ||||
|         name, | ||||
|     ) | ||||
| } | ||||
|  | ||||
| pub fn owner_validation_error_message(name: &str, db_or_user: DbOrUser) -> String { | ||||
|     format!( | ||||
|         "You are not in charge of mysql-{}: '{}'.  Skipping.", | ||||
|         db_or_user.lowercased(), | ||||
|         name | ||||
|     ) | ||||
| } | ||||
|  | ||||
| pub fn handle_create_user_error(error: CreateUserError, name: &str) { | ||||
|     let argv0 = std::env::args() | ||||
|         .next() | ||||
|         .unwrap_or_else(|| "mysql-useradm".to_string()); | ||||
|     match error { | ||||
|         CreateUserError::SanitizationError(_) => { | ||||
|             eprintln!( | ||||
|                 "{}", | ||||
|                 name_validation_error_to_error_message(name, DbOrUser::User) | ||||
|             ); | ||||
|         } | ||||
|         CreateUserError::OwnershipError(_) => { | ||||
|             eprintln!("{}", owner_validation_error_message(name, DbOrUser::User)); | ||||
|         } | ||||
|         CreateUserError::MySqlError(_) | CreateUserError::UserAlreadyExists => { | ||||
|             eprintln!("{}: Failed to create user '{}'.", argv0, name); | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| pub fn handle_drop_user_error(error: DropUserError, name: &str) { | ||||
|     let argv0 = std::env::args() | ||||
|         .next() | ||||
|         .unwrap_or_else(|| "mysql-useradm".to_string()); | ||||
|     match error { | ||||
|         DropUserError::SanitizationError(_) => { | ||||
|             eprintln!( | ||||
|                 "{}", | ||||
|                 name_validation_error_to_error_message(name, DbOrUser::User) | ||||
|             ); | ||||
|         } | ||||
|         DropUserError::OwnershipError(_) => { | ||||
|             eprintln!("{}", owner_validation_error_message(name, DbOrUser::User)); | ||||
|         } | ||||
|         DropUserError::MySqlError(_) | DropUserError::UserDoesNotExist => { | ||||
|             eprintln!("{}: Failed to delete user '{}'.", argv0, name); | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| pub fn handle_list_users_error(error: ListUsersError, name: &str) { | ||||
|     let argv0 = std::env::args() | ||||
|         .next() | ||||
|         .unwrap_or_else(|| "mysql-useradm".to_string()); | ||||
|     match error { | ||||
|         ListUsersError::SanitizationError(_) => { | ||||
|             eprintln!( | ||||
|                 "{}", | ||||
|                 name_validation_error_to_error_message(name, DbOrUser::User) | ||||
|             ); | ||||
|         } | ||||
|         ListUsersError::OwnershipError(_) => { | ||||
|             eprintln!("{}", owner_validation_error_message(name, DbOrUser::User)); | ||||
|         } | ||||
|         ListUsersError::UserDoesNotExist => { | ||||
|             eprintln!( | ||||
|                 "{}: User '{}' does not exist. You must create it first.", | ||||
|                 argv0, name, | ||||
|             ); | ||||
|         } | ||||
|         ListUsersError::MySqlError(_) => { | ||||
|             eprintln!("{}: Failed to look up password for user '{}'", argv0, name); | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| pub fn handle_create_database_error(error: CreateDatabaseError, name: &str) { | ||||
|     let argv0 = std::env::args() | ||||
|         .next() | ||||
|         .unwrap_or_else(|| "mysql-dbadm".to_string()); | ||||
|     match error { | ||||
|         CreateDatabaseError::SanitizationError(_) => { | ||||
|             eprintln!( | ||||
|                 "{}", | ||||
|                 name_validation_error_to_error_message(name, DbOrUser::Database) | ||||
|             ); | ||||
|         } | ||||
|         CreateDatabaseError::OwnershipError(_) => { | ||||
|             eprintln!( | ||||
|                 "{}", | ||||
|                 owner_validation_error_message(name, DbOrUser::Database) | ||||
|             ); | ||||
|         } | ||||
|         CreateDatabaseError::MySqlError(_) => { | ||||
|             eprintln!("{}: Cannot create database '{}'.", argv0, name); | ||||
|         } | ||||
|         CreateDatabaseError::DatabaseAlreadyExists => { | ||||
|             eprintln!("{}: Database '{}' already exists.", argv0, name); | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| pub fn handle_drop_database_error(error: DropDatabaseError, name: &str) { | ||||
|     let argv0 = std::env::args() | ||||
|         .next() | ||||
|         .unwrap_or_else(|| "mysql-dbadm".to_string()); | ||||
|     match error { | ||||
|         DropDatabaseError::SanitizationError(_) => { | ||||
|             eprintln!( | ||||
|                 "{}", | ||||
|                 name_validation_error_to_error_message(name, DbOrUser::Database) | ||||
|             ); | ||||
|         } | ||||
|         DropDatabaseError::OwnershipError(_) => { | ||||
|             eprintln!( | ||||
|                 "{}", | ||||
|                 owner_validation_error_message(name, DbOrUser::Database) | ||||
|             ); | ||||
|         } | ||||
|         DropDatabaseError::MySqlError(_) => { | ||||
|             eprintln!("{}: Cannot drop database '{}'.", argv0, name); | ||||
|         } | ||||
|         DropDatabaseError::DatabaseDoesNotExist => { | ||||
|             eprintln!("{}: Database '{}' doesn't exist.", argv0, name); | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| pub fn format_show_database_error_message( | ||||
|     error: GetDatabasesPrivilegeDataError, | ||||
|     name: &str, | ||||
| ) -> String { | ||||
|     let argv0 = std::env::args() | ||||
|         .next() | ||||
|         .unwrap_or_else(|| "mysql-dbadm".to_string()); | ||||
|  | ||||
|     match error { | ||||
|         GetDatabasesPrivilegeDataError::SanitizationError(_) => { | ||||
|             name_validation_error_to_error_message(name, DbOrUser::Database) | ||||
|         } | ||||
|         GetDatabasesPrivilegeDataError::OwnershipError(_) => { | ||||
|             owner_validation_error_message(name, DbOrUser::Database) | ||||
|         } | ||||
|         GetDatabasesPrivilegeDataError::MySqlError(err) => { | ||||
|             format!( | ||||
|                 "{}: Failed to look up privileges for database '{}': {}", | ||||
|                 argv0, name, err | ||||
|             ) | ||||
|         } | ||||
|         GetDatabasesPrivilegeDataError::DatabaseDoesNotExist => { | ||||
|             format!("{}: Database '{}' doesn't exist.", argv0, name) | ||||
|         } | ||||
|     } | ||||
| } | ||||
| @@ -1,14 +1,29 @@ | ||||
| use clap::Parser; | ||||
| use sqlx::MySqlConnection; | ||||
| use futures_util::{SinkExt, StreamExt}; | ||||
| use std::os::unix::net::UnixStream as StdUnixStream; | ||||
| use std::path::PathBuf; | ||||
| use tokio::net::UnixStream as TokioUnixStream; | ||||
|  | ||||
| use crate::{ | ||||
|     cli::{database_command, mysql_admutils_compatibility::common::filter_db_or_user_names}, | ||||
|     core::{ | ||||
|         common::{yn, DbOrUser}, | ||||
|         config::{create_mysql_connection_from_config, get_config, GlobalConfigArgs}, | ||||
|         database_operations::{create_database, drop_database, get_database_list}, | ||||
|         database_privilege_operations, | ||||
|     cli::{ | ||||
|         common::erroneous_server_response, | ||||
|         database_command, | ||||
|         mysql_admutils_compatibility::{ | ||||
|             common::trim_to_32_chars, | ||||
|             error_messages::{ | ||||
|                 format_show_database_error_message, handle_create_database_error, | ||||
|                 handle_drop_database_error, | ||||
|             }, | ||||
|         }, | ||||
|     }, | ||||
|     core::{ | ||||
|         bootstrap::bootstrap_server_connection_and_drop_privileges, | ||||
|         protocol::{ | ||||
|             create_client_to_server_message_stream, ClientToServerMessageStream, | ||||
|             GetDatabasesPrivilegeDataError, Request, Response, | ||||
|         }, | ||||
|     }, | ||||
|     server::sql::database_privilege_operations::DatabasePrivilegeRow, | ||||
| }; | ||||
|  | ||||
| const HELP_DB_PERM: &str = r#" | ||||
| @@ -39,8 +54,25 @@ pub struct Args { | ||||
|     #[command(subcommand)] | ||||
|     pub command: Option<Command>, | ||||
|  | ||||
|     #[command(flatten)] | ||||
|     config_overrides: GlobalConfigArgs, | ||||
|     /// Path to the socket of the server, if it already exists. | ||||
|     #[arg( | ||||
|         short, | ||||
|         long, | ||||
|         value_name = "PATH", | ||||
|         global = true, | ||||
|         hide_short_help = true | ||||
|     )] | ||||
|     server_socket_path: Option<PathBuf>, | ||||
|  | ||||
|     /// Config file to use for the server. | ||||
|     #[arg( | ||||
|         short, | ||||
|         long, | ||||
|         value_name = "PATH", | ||||
|         global = true, | ||||
|         hide_short_help = true | ||||
|     )] | ||||
|     config: Option<PathBuf>, | ||||
|  | ||||
|     /// Print help for the 'editperm' subcommand. | ||||
|     #[arg(long, global = true)] | ||||
| @@ -76,7 +108,7 @@ pub enum Command { | ||||
|     /// to make changes to the permission table. | ||||
|     /// Run 'mysql-dbadm --help-editperm' for more | ||||
|     /// information. | ||||
|     EditPerm(EditPermArgs), | ||||
|     Editperm(EditPermArgs), | ||||
| } | ||||
|  | ||||
| #[derive(Parser)] | ||||
| @@ -106,7 +138,7 @@ pub struct EditPermArgs { | ||||
|     pub database: String, | ||||
| } | ||||
|  | ||||
| pub async fn main() -> anyhow::Result<()> { | ||||
| pub fn main() -> anyhow::Result<()> { | ||||
|     let args: Args = Args::parse(); | ||||
|  | ||||
|     if args.help_editperm { | ||||
| @@ -114,6 +146,9 @@ pub async fn main() -> anyhow::Result<()> { | ||||
|         return Ok(()); | ||||
|     } | ||||
|  | ||||
|     let server_connection = | ||||
|         bootstrap_server_connection_and_drop_privileges(args.server_socket_path, args.config)?; | ||||
|  | ||||
|     let command = match args.command { | ||||
|         Some(command) => command, | ||||
|         None => { | ||||
| @@ -125,64 +160,164 @@ pub async fn main() -> anyhow::Result<()> { | ||||
|         } | ||||
|     }; | ||||
|  | ||||
|     let config = get_config(args.config_overrides)?; | ||||
|     let mut connection = create_mysql_connection_from_config(config.mysql).await?; | ||||
|     tokio_run_command(command, server_connection)?; | ||||
|  | ||||
|     match command { | ||||
|         Command::Create(args) => { | ||||
|             let filtered_names = filter_db_or_user_names(args.name, DbOrUser::Database)?; | ||||
|             for name in filtered_names { | ||||
|                 create_database(&name, &mut connection).await?; | ||||
|                 println!("Database {} created.", name); | ||||
|             } | ||||
|         } | ||||
|         Command::Drop(args) => { | ||||
|             let filtered_names = filter_db_or_user_names(args.name, DbOrUser::Database)?; | ||||
|             for name in filtered_names { | ||||
|                 drop_database(&name, &mut connection).await?; | ||||
|                 println!("Database {} dropped.", name); | ||||
|             } | ||||
|         } | ||||
|         Command::Show(args) => { | ||||
|             let names = if args.name.is_empty() { | ||||
|                 get_database_list(&mut connection).await? | ||||
|             } else { | ||||
|                 filter_db_or_user_names(args.name, DbOrUser::Database)? | ||||
|             }; | ||||
|     Ok(()) | ||||
| } | ||||
|  | ||||
|             for name in names { | ||||
|                 show_db(&name, &mut connection).await?; | ||||
|             } | ||||
|         } | ||||
|         Command::EditPerm(args) => { | ||||
|             // TODO: This does not accurately replicate the behavior of the old implementation. | ||||
|             //       Hopefully, not many people rely on this in an automated fashion, as it | ||||
|             //       is made to be interactive in nature. However, we should still try to | ||||
|             //        replicate the old behavior as closely as possible. | ||||
|             let edit_privileges_args = database_command::DatabaseEditPrivsArgs { | ||||
|                 name: Some(args.database), | ||||
|                 privs: vec![], | ||||
|                 json: false, | ||||
|                 editor: None, | ||||
|                 yes: false, | ||||
|             }; | ||||
| fn tokio_run_command(command: Command, server_connection: StdUnixStream) -> anyhow::Result<()> { | ||||
|     tokio::runtime::Builder::new_current_thread() | ||||
|         .enable_all() | ||||
|         .build() | ||||
|         .unwrap() | ||||
|         .block_on(async { | ||||
|             let tokio_socket = TokioUnixStream::from_std(server_connection)?; | ||||
|             let message_stream = create_client_to_server_message_stream(tokio_socket); | ||||
|             match command { | ||||
|                 Command::Create(args) => create_databases(args, message_stream).await, | ||||
|                 Command::Drop(args) => drop_databases(args, message_stream).await, | ||||
|                 Command::Show(args) => show_databases(args, message_stream).await, | ||||
|                 Command::Editperm(args) => { | ||||
|                     let edit_privileges_args = database_command::DatabaseEditPrivsArgs { | ||||
|                         name: Some(args.database), | ||||
|                         privs: vec![], | ||||
|                         json: false, | ||||
|                         // TODO: use this to mimic the old editor-finding logic | ||||
|                         editor: None, | ||||
|                         yes: false, | ||||
|                     }; | ||||
|  | ||||
|             database_command::edit_privileges(edit_privileges_args, &mut connection).await?; | ||||
|                     database_command::edit_database_privileges(edit_privileges_args, message_stream) | ||||
|                         .await | ||||
|                 } | ||||
|             } | ||||
|         }) | ||||
| } | ||||
|  | ||||
| async fn create_databases( | ||||
|     args: CreateArgs, | ||||
|     mut server_connection: ClientToServerMessageStream, | ||||
| ) -> anyhow::Result<()> { | ||||
|     let database_names = args | ||||
|         .name | ||||
|         .iter() | ||||
|         .map(|name| trim_to_32_chars(name)) | ||||
|         .collect(); | ||||
|  | ||||
|     let message = Request::CreateDatabases(database_names); | ||||
|     server_connection.send(message).await?; | ||||
|  | ||||
|     let result = match server_connection.next().await { | ||||
|         Some(Ok(Response::CreateDatabases(result))) => result, | ||||
|         response => return erroneous_server_response(response), | ||||
|     }; | ||||
|  | ||||
|     server_connection.send(Request::Exit).await?; | ||||
|  | ||||
|     for (name, result) in result { | ||||
|         match result { | ||||
|             Ok(()) => println!("Database {} created.", name), | ||||
|             Err(err) => handle_create_database_error(err, &name), | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     Ok(()) | ||||
| } | ||||
|  | ||||
| async fn show_db(name: &str, connection: &mut MySqlConnection) -> anyhow::Result<()> { | ||||
| async fn drop_databases( | ||||
|     args: DatabaseDropArgs, | ||||
|     mut server_connection: ClientToServerMessageStream, | ||||
| ) -> anyhow::Result<()> { | ||||
|     let database_names = args | ||||
|         .name | ||||
|         .iter() | ||||
|         .map(|name| trim_to_32_chars(name)) | ||||
|         .collect(); | ||||
|  | ||||
|     let message = Request::DropDatabases(database_names); | ||||
|     server_connection.send(message).await?; | ||||
|  | ||||
|     let result = match server_connection.next().await { | ||||
|         Some(Ok(Response::DropDatabases(result))) => result, | ||||
|         response => return erroneous_server_response(response), | ||||
|     }; | ||||
|  | ||||
|     server_connection.send(Request::Exit).await?; | ||||
|  | ||||
|     for (name, result) in result { | ||||
|         match result { | ||||
|             Ok(()) => println!("Database {} dropped.", name), | ||||
|             Err(err) => handle_drop_database_error(err, &name), | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     Ok(()) | ||||
| } | ||||
|  | ||||
| async fn show_databases( | ||||
|     args: DatabaseShowArgs, | ||||
|     mut server_connection: ClientToServerMessageStream, | ||||
| ) -> anyhow::Result<()> { | ||||
|     let database_names: Vec<String> = args | ||||
|         .name | ||||
|         .iter() | ||||
|         .map(|name| trim_to_32_chars(name)) | ||||
|         .collect(); | ||||
|  | ||||
|     let message = if database_names.is_empty() { | ||||
|         let message = Request::ListDatabases; | ||||
|         server_connection.send(message).await?; | ||||
|         let response = server_connection.next().await; | ||||
|         let databases = match response { | ||||
|             Some(Ok(Response::ListAllDatabases(databases))) => databases.unwrap_or(vec![]), | ||||
|             response => return erroneous_server_response(response), | ||||
|         }; | ||||
|  | ||||
|         Request::ListPrivileges(Some(databases)) | ||||
|     } else { | ||||
|         Request::ListPrivileges(Some(database_names)) | ||||
|     }; | ||||
|     server_connection.send(message).await?; | ||||
|  | ||||
|     let response = server_connection.next().await; | ||||
|  | ||||
|     server_connection.send(Request::Exit).await?; | ||||
|  | ||||
|     // NOTE: mysql-dbadm show has a quirk where valid database names | ||||
|     //       for non-existent databases will report with no users. | ||||
|     //       This function should *not* check for db existence, only | ||||
|     //       validate the names. | ||||
|     let privileges = database_privilege_operations::get_database_privileges(name, connection) | ||||
|         .await | ||||
|         .unwrap_or(vec![]); | ||||
|     let results: Vec<Result<(String, Vec<DatabasePrivilegeRow>), String>> = match response { | ||||
|         Some(Ok(Response::ListPrivileges(result))) => result | ||||
|             .into_iter() | ||||
|             .map(|(name, rows)| match rows.map(|rows| (name.clone(), rows)) { | ||||
|                 Ok(rows) => Ok(rows), | ||||
|                 Err(GetDatabasesPrivilegeDataError::DatabaseDoesNotExist) => Ok((name, vec![])), | ||||
|                 Err(err) => Err(format_show_database_error_message(err, &name)), | ||||
|             }) | ||||
|             .collect(), | ||||
|         response => return erroneous_server_response(response), | ||||
|     }; | ||||
|  | ||||
|     results.into_iter().try_for_each(|result| match result { | ||||
|         Ok((name, rows)) => print_db_privs(&name, rows), | ||||
|         Err(err) => { | ||||
|             eprintln!("{}", err); | ||||
|             Ok(()) | ||||
|         } | ||||
|     })?; | ||||
|  | ||||
|     Ok(()) | ||||
| } | ||||
|  | ||||
| #[inline] | ||||
| fn yn(value: bool) -> &'static str { | ||||
|     if value { | ||||
|         "Y" | ||||
|     } else { | ||||
|         "N" | ||||
|     } | ||||
| } | ||||
|  | ||||
| fn print_db_privs(name: &str, rows: Vec<DatabasePrivilegeRow>) -> anyhow::Result<()> { | ||||
|     println!( | ||||
|         concat!( | ||||
|             "Database '{}':\n", | ||||
| @@ -191,10 +326,10 @@ async fn show_db(name: &str, connection: &mut MySqlConnection) -> anyhow::Result | ||||
|         ), | ||||
|         name, | ||||
|     ); | ||||
|     if privileges.is_empty() { | ||||
|     if rows.is_empty() { | ||||
|         println!("# (no permissions currently granted to any users)"); | ||||
|     } else { | ||||
|         for privilege in privileges { | ||||
|         for privilege in rows { | ||||
|             println!( | ||||
|                 "  {:<16}      {:<7} {:<7} {:<7} {:<7} {:<7} {:<7} {:<7} {:<7} {:<7} {:<7} {}", | ||||
|                 privilege.user, | ||||
|   | ||||
| @@ -1,13 +1,28 @@ | ||||
| use clap::Parser; | ||||
| use sqlx::MySqlConnection; | ||||
| use futures_util::{SinkExt, StreamExt}; | ||||
| use std::path::PathBuf; | ||||
|  | ||||
| use std::os::unix::net::UnixStream as StdUnixStream; | ||||
| use tokio::net::UnixStream as TokioUnixStream; | ||||
|  | ||||
| use crate::{ | ||||
|     cli::{mysql_admutils_compatibility::common::filter_db_or_user_names, user_command}, | ||||
|     core::{ | ||||
|         common::{close_database_connection, get_current_unix_user, DbOrUser}, | ||||
|         config::{create_mysql_connection_from_config, get_config, GlobalConfigArgs}, | ||||
|         user_operations::*, | ||||
|     cli::{ | ||||
|         common::erroneous_server_response, | ||||
|         mysql_admutils_compatibility::{ | ||||
|             common::trim_to_32_chars, | ||||
|             error_messages::{ | ||||
|                 handle_create_user_error, handle_drop_user_error, handle_list_users_error, | ||||
|             }, | ||||
|         }, | ||||
|         user_command::read_password_from_stdin_with_double_check, | ||||
|     }, | ||||
|     core::{ | ||||
|         bootstrap::bootstrap_server_connection_and_drop_privileges, | ||||
|         protocol::{ | ||||
|             create_client_to_server_message_stream, ClientToServerMessageStream, Request, Response, | ||||
|         }, | ||||
|     }, | ||||
|     server::sql::user_operations::DatabaseUser, | ||||
| }; | ||||
|  | ||||
| #[derive(Parser)] | ||||
| @@ -15,8 +30,25 @@ pub struct Args { | ||||
|     #[command(subcommand)] | ||||
|     pub command: Option<Command>, | ||||
|  | ||||
|     #[command(flatten)] | ||||
|     config_overrides: GlobalConfigArgs, | ||||
|     /// Path to the socket of the server, if it already exists. | ||||
|     #[arg( | ||||
|         short, | ||||
|         long, | ||||
|         value_name = "PATH", | ||||
|         global = true, | ||||
|         hide_short_help = true | ||||
|     )] | ||||
|     server_socket_path: Option<PathBuf>, | ||||
|  | ||||
|     /// Config file to use for the server. | ||||
|     #[arg( | ||||
|         short, | ||||
|         long, | ||||
|         value_name = "PATH", | ||||
|         global = true, | ||||
|         hide_short_help = true | ||||
|     )] | ||||
|     config: Option<PathBuf>, | ||||
| } | ||||
|  | ||||
| /// Create, delete or change password for the USER(s), | ||||
| @@ -69,7 +101,7 @@ pub struct ShowArgs { | ||||
|     name: Vec<String>, | ||||
| } | ||||
|  | ||||
| pub async fn main() -> anyhow::Result<()> { | ||||
| pub fn main() -> anyhow::Result<()> { | ||||
|     let args: Args = Args::parse(); | ||||
|  | ||||
|     let command = match args.command { | ||||
| @@ -85,78 +117,185 @@ pub async fn main() -> anyhow::Result<()> { | ||||
|         } | ||||
|     }; | ||||
|  | ||||
|     let config = get_config(args.config_overrides)?; | ||||
|     let mut connection = create_mysql_connection_from_config(config.mysql).await?; | ||||
|     let server_connection = | ||||
|         bootstrap_server_connection_and_drop_privileges(args.server_socket_path, args.config)?; | ||||
|  | ||||
|     match command { | ||||
|         Command::Create(args) => { | ||||
|             let filtered_names = filter_db_or_user_names(args.name, DbOrUser::User)?; | ||||
|             for name in filtered_names { | ||||
|                 create_database_user(&name, &mut connection).await?; | ||||
|             } | ||||
|         } | ||||
|         Command::Delete(args) => { | ||||
|             let filtered_names = filter_db_or_user_names(args.name, DbOrUser::User)?; | ||||
|             for name in filtered_names { | ||||
|                 delete_database_user(&name, &mut connection).await?; | ||||
|             } | ||||
|         } | ||||
|         Command::Passwd(args) => passwd(args, &mut connection).await?, | ||||
|         Command::Show(args) => show(args, &mut connection).await?, | ||||
|     } | ||||
|  | ||||
|     close_database_connection(connection).await; | ||||
|     tokio_run_command(command, server_connection)?; | ||||
|  | ||||
|     Ok(()) | ||||
| } | ||||
|  | ||||
| async fn passwd(args: PasswdArgs, connection: &mut MySqlConnection) -> anyhow::Result<()> { | ||||
|     let filtered_names = filter_db_or_user_names(args.name, DbOrUser::User)?; | ||||
|  | ||||
|     // NOTE: this gets doubly checked during the call to `set_password_for_database_user`. | ||||
|     //       This is moving the check before asking the user for the password, | ||||
|     //       to avoid having them figure out that the user does not exist after they | ||||
|     //       have entered the password twice. | ||||
|     let mut better_filtered_names = Vec::with_capacity(filtered_names.len()); | ||||
|     for name in filtered_names.into_iter() { | ||||
|         if !user_exists(&name, connection).await? { | ||||
|             println!( | ||||
|                 "{}: User '{}' does not exist. You must create it first.", | ||||
|                 std::env::args() | ||||
|                     .next() | ||||
|                     .unwrap_or("mysql-useradm".to_string()), | ||||
|                 name, | ||||
|             ); | ||||
|         } else { | ||||
|             better_filtered_names.push(name); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     for name in better_filtered_names { | ||||
|         let password = user_command::read_password_from_stdin_with_double_check(&name)?; | ||||
|         set_password_for_database_user(&name, &password, connection).await?; | ||||
|         println!("Password updated for user '{}'.", name); | ||||
|     } | ||||
|  | ||||
|     Ok(()) | ||||
| fn tokio_run_command(command: Command, server_connection: StdUnixStream) -> anyhow::Result<()> { | ||||
|     tokio::runtime::Builder::new_current_thread() | ||||
|         .enable_all() | ||||
|         .build() | ||||
|         .unwrap() | ||||
|         .block_on(async { | ||||
|             let tokio_socket = TokioUnixStream::from_std(server_connection)?; | ||||
|             let message_stream = create_client_to_server_message_stream(tokio_socket); | ||||
|             match command { | ||||
|                 Command::Create(args) => create_user(args, message_stream).await, | ||||
|                 Command::Delete(args) => drop_users(args, message_stream).await, | ||||
|                 Command::Passwd(args) => passwd_users(args, message_stream).await, | ||||
|                 Command::Show(args) => show_users(args, message_stream).await, | ||||
|             } | ||||
|         }) | ||||
| } | ||||
|  | ||||
| async fn show(args: ShowArgs, connection: &mut MySqlConnection) -> anyhow::Result<()> { | ||||
|     let users = if args.name.is_empty() { | ||||
|         let unix_user = get_current_unix_user()?; | ||||
|         get_all_database_users_for_unix_user(&unix_user, connection).await? | ||||
|     } else { | ||||
|         let filtered_usernames = filter_db_or_user_names(args.name, DbOrUser::User)?; | ||||
|         let mut result = Vec::with_capacity(filtered_usernames.len()); | ||||
|         for username in filtered_usernames.iter() { | ||||
|             // TODO: fetch all users in one query | ||||
|             if let Some(user) = get_database_user_for_user(username, connection).await? { | ||||
|                 result.push(user) | ||||
|             } | ||||
|         } | ||||
|         result | ||||
| async fn create_user( | ||||
|     args: CreateArgs, | ||||
|     mut server_connection: ClientToServerMessageStream, | ||||
| ) -> anyhow::Result<()> { | ||||
|     let usernames = args | ||||
|         .name | ||||
|         .iter() | ||||
|         .map(|name| trim_to_32_chars(name)) | ||||
|         .collect(); | ||||
|  | ||||
|     let message = Request::CreateUsers(usernames); | ||||
|     server_connection.send(message).await?; | ||||
|  | ||||
|     let result = match server_connection.next().await { | ||||
|         Some(Ok(Response::CreateUsers(result))) => result, | ||||
|         response => return erroneous_server_response(response), | ||||
|     }; | ||||
|  | ||||
|     server_connection.send(Request::Exit).await?; | ||||
|  | ||||
|     for (name, result) in result { | ||||
|         match result { | ||||
|             Ok(()) => println!("User '{}' created.", name), | ||||
|             Err(err) => handle_create_user_error(err, &name), | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     Ok(()) | ||||
| } | ||||
|  | ||||
| async fn drop_users( | ||||
|     args: DeleteArgs, | ||||
|     mut server_connection: ClientToServerMessageStream, | ||||
| ) -> anyhow::Result<()> { | ||||
|     let usernames = args | ||||
|         .name | ||||
|         .iter() | ||||
|         .map(|name| trim_to_32_chars(name)) | ||||
|         .collect(); | ||||
|  | ||||
|     let message = Request::DropUsers(usernames); | ||||
|     server_connection.send(message).await?; | ||||
|  | ||||
|     let result = match server_connection.next().await { | ||||
|         Some(Ok(Response::DropUsers(result))) => result, | ||||
|         response => return erroneous_server_response(response), | ||||
|     }; | ||||
|  | ||||
|     server_connection.send(Request::Exit).await?; | ||||
|  | ||||
|     for (name, result) in result { | ||||
|         match result { | ||||
|             Ok(()) => println!("User '{}' deleted.", name), | ||||
|             Err(err) => handle_drop_user_error(err, &name), | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     Ok(()) | ||||
| } | ||||
|  | ||||
| async fn passwd_users( | ||||
|     args: PasswdArgs, | ||||
|     mut server_connection: ClientToServerMessageStream, | ||||
| ) -> anyhow::Result<()> { | ||||
|     let usernames = args | ||||
|         .name | ||||
|         .iter() | ||||
|         .map(|name| trim_to_32_chars(name)) | ||||
|         .collect(); | ||||
|  | ||||
|     let message = Request::ListUsers(Some(usernames)); | ||||
|     server_connection.send(message).await?; | ||||
|  | ||||
|     let response = match server_connection.next().await { | ||||
|         Some(Ok(Response::ListUsers(result))) => result, | ||||
|         response => return erroneous_server_response(response), | ||||
|     }; | ||||
|  | ||||
|     let argv0 = std::env::args() | ||||
|         .next() | ||||
|         .unwrap_or("mysql-useradm".to_string()); | ||||
|  | ||||
|     let users = response | ||||
|         .into_iter() | ||||
|         .filter_map(|(name, result)| match result { | ||||
|             Ok(user) => Some(user), | ||||
|             Err(err) => { | ||||
|                 handle_list_users_error(err, &name); | ||||
|                 None | ||||
|             } | ||||
|         }) | ||||
|         .collect::<Vec<_>>(); | ||||
|  | ||||
|     for user in users { | ||||
|         let password = read_password_from_stdin_with_double_check(&user.user)?; | ||||
|         let message = Request::PasswdUser(user.user.clone(), password); | ||||
|         server_connection.send(message).await?; | ||||
|         match server_connection.next().await { | ||||
|             Some(Ok(Response::PasswdUser(result))) => match result { | ||||
|                 Ok(()) => println!("Password updated for user '{}'.", user.user), | ||||
|                 Err(_) => eprintln!( | ||||
|                     "{}: Failed to update password for user '{}'.", | ||||
|                     argv0, user.user, | ||||
|                 ), | ||||
|             }, | ||||
|             response => return erroneous_server_response(response), | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     server_connection.send(Request::Exit).await?; | ||||
|  | ||||
|     Ok(()) | ||||
| } | ||||
|  | ||||
| async fn show_users( | ||||
|     args: ShowArgs, | ||||
|     mut server_connection: ClientToServerMessageStream, | ||||
| ) -> anyhow::Result<()> { | ||||
|     let usernames: Vec<_> = args | ||||
|         .name | ||||
|         .iter() | ||||
|         .map(|name| trim_to_32_chars(name)) | ||||
|         .collect(); | ||||
|  | ||||
|     let message = if usernames.is_empty() { | ||||
|         Request::ListUsers(None) | ||||
|     } else { | ||||
|         Request::ListUsers(Some(usernames)) | ||||
|     }; | ||||
|     server_connection.send(message).await?; | ||||
|  | ||||
|     let users: Vec<DatabaseUser> = match server_connection.next().await { | ||||
|         Some(Ok(Response::ListAllUsers(result))) => match result { | ||||
|             Ok(users) => users, | ||||
|             Err(err) => { | ||||
|                 println!("Failed to list users: {:?}", err); | ||||
|                 return Ok(()); | ||||
|             } | ||||
|         }, | ||||
|         Some(Ok(Response::ListUsers(result))) => result | ||||
|             .into_iter() | ||||
|             .filter_map(|(name, result)| match result { | ||||
|                 Ok(user) => Some(user), | ||||
|                 Err(err) => { | ||||
|                     handle_list_users_error(err, &name); | ||||
|                     None | ||||
|                 } | ||||
|             }) | ||||
|             .collect(), | ||||
|         response => return erroneous_server_response(response), | ||||
|     }; | ||||
|  | ||||
|     server_connection.send(Request::Exit).await?; | ||||
|  | ||||
|     for user in users { | ||||
|         if user.has_password { | ||||
|             println!("User '{}': password set.", user.user); | ||||
|   | ||||
| @@ -1,27 +1,24 @@ | ||||
| use std::collections::BTreeMap; | ||||
| use std::vec; | ||||
|  | ||||
| use anyhow::Context; | ||||
| use clap::Parser; | ||||
| use dialoguer::{Confirm, Password}; | ||||
| use prettytable::Table; | ||||
| use serde_json::json; | ||||
| use sqlx::{Connection, MySqlConnection}; | ||||
| use futures_util::{SinkExt, StreamExt}; | ||||
|  | ||||
| use crate::core::{ | ||||
|     common::{close_database_connection, get_current_unix_user, CommandStatus}, | ||||
|     database_operations::*, | ||||
|     user_operations::*, | ||||
| use crate::core::protocol::{ | ||||
|     print_create_users_output_status, print_drop_users_output_status, | ||||
|     print_lock_users_output_status, print_set_password_output_status, | ||||
|     print_unlock_users_output_status, ClientToServerMessageStream, Request, Response, | ||||
| }; | ||||
|  | ||||
| #[derive(Parser)] | ||||
| use super::common::erroneous_server_response; | ||||
|  | ||||
| #[derive(Parser, Debug, Clone)] | ||||
| pub struct UserArgs { | ||||
|     #[clap(subcommand)] | ||||
|     subcmd: UserCommand, | ||||
| } | ||||
|  | ||||
| #[allow(clippy::enum_variant_names)] | ||||
| #[derive(Parser)] | ||||
| #[derive(Parser, Debug, Clone)] | ||||
| pub enum UserCommand { | ||||
|     /// Create one or more users | ||||
|     #[command()] | ||||
| @@ -50,7 +47,7 @@ pub enum UserCommand { | ||||
|     UnlockUser(UserUnlockArgs), | ||||
| } | ||||
|  | ||||
| #[derive(Parser)] | ||||
| #[derive(Parser, Debug, Clone)] | ||||
| pub struct UserCreateArgs { | ||||
|     #[arg(num_args = 1..)] | ||||
|     username: Vec<String>, | ||||
| @@ -60,13 +57,13 @@ pub struct UserCreateArgs { | ||||
|     no_password: bool, | ||||
| } | ||||
|  | ||||
| #[derive(Parser)] | ||||
| #[derive(Parser, Debug, Clone)] | ||||
| pub struct UserDeleteArgs { | ||||
|     #[arg(num_args = 1..)] | ||||
|     username: Vec<String>, | ||||
| } | ||||
|  | ||||
| #[derive(Parser)] | ||||
| #[derive(Parser, Debug, Clone)] | ||||
| pub struct UserPasswdArgs { | ||||
|     username: String, | ||||
|  | ||||
| @@ -74,7 +71,7 @@ pub struct UserPasswdArgs { | ||||
|     password_file: Option<String>, | ||||
| } | ||||
|  | ||||
| #[derive(Parser)] | ||||
| #[derive(Parser, Debug, Clone)] | ||||
| pub struct UserShowArgs { | ||||
|     #[arg(num_args = 0..)] | ||||
|     username: Vec<String>, | ||||
| @@ -83,13 +80,13 @@ pub struct UserShowArgs { | ||||
|     json: bool, | ||||
| } | ||||
|  | ||||
| #[derive(Parser)] | ||||
| #[derive(Parser, Debug, Clone)] | ||||
| pub struct UserLockArgs { | ||||
|     #[arg(num_args = 1..)] | ||||
|     username: Vec<String>, | ||||
| } | ||||
|  | ||||
| #[derive(Parser)] | ||||
| #[derive(Parser, Debug, Clone)] | ||||
| pub struct UserUnlockArgs { | ||||
|     #[arg(num_args = 1..)] | ||||
|     username: Vec<String>, | ||||
| @@ -97,48 +94,45 @@ pub struct UserUnlockArgs { | ||||
|  | ||||
| pub async fn handle_command( | ||||
|     command: UserCommand, | ||||
|     mut connection: MySqlConnection, | ||||
| ) -> anyhow::Result<CommandStatus> { | ||||
|     let result = connection | ||||
|         .transaction(|txn| { | ||||
|             Box::pin(async move { | ||||
|                 match command { | ||||
|                     UserCommand::CreateUser(args) => create_users(args, txn).await, | ||||
|                     UserCommand::DropUser(args) => drop_users(args, txn).await, | ||||
|                     UserCommand::PasswdUser(args) => change_password_for_user(args, txn).await, | ||||
|                     UserCommand::ShowUser(args) => show_users(args, txn).await, | ||||
|                     UserCommand::LockUser(args) => lock_users(args, txn).await, | ||||
|                     UserCommand::UnlockUser(args) => unlock_users(args, txn).await, | ||||
|                 } | ||||
|             }) | ||||
|         }) | ||||
|         .await; | ||||
|  | ||||
|     close_database_connection(connection).await; | ||||
|  | ||||
|     result | ||||
|     server_connection: ClientToServerMessageStream, | ||||
| ) -> anyhow::Result<()> { | ||||
|     match command { | ||||
|         UserCommand::CreateUser(args) => create_users(args, server_connection).await, | ||||
|         UserCommand::DropUser(args) => drop_users(args, server_connection).await, | ||||
|         UserCommand::PasswdUser(args) => passwd_user(args, server_connection).await, | ||||
|         UserCommand::ShowUser(args) => show_users(args, server_connection).await, | ||||
|         UserCommand::LockUser(args) => lock_users(args, server_connection).await, | ||||
|         UserCommand::UnlockUser(args) => unlock_users(args, server_connection).await, | ||||
|     } | ||||
| } | ||||
|  | ||||
| async fn create_users( | ||||
|     args: UserCreateArgs, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> anyhow::Result<CommandStatus> { | ||||
|     mut server_connection: ClientToServerMessageStream, | ||||
| ) -> anyhow::Result<()> { | ||||
|     if args.username.is_empty() { | ||||
|         anyhow::bail!("No usernames provided"); | ||||
|     } | ||||
|  | ||||
|     let mut result = CommandStatus::SuccessfullyModified; | ||||
|     let message = Request::CreateUsers(args.username.clone()); | ||||
|     if let Err(err) = server_connection.send(message).await { | ||||
|         server_connection.close().await.ok(); | ||||
|         anyhow::bail!(anyhow::Error::from(err).context("Failed to communicate with server")); | ||||
|     } | ||||
|  | ||||
|     for username in args.username { | ||||
|         if let Err(e) = create_database_user(&username, connection).await { | ||||
|             eprintln!("{}", e); | ||||
|             eprintln!("Skipping...\n"); | ||||
|             result = CommandStatus::PartiallySuccessfullyModified; | ||||
|             continue; | ||||
|         } else { | ||||
|             println!("User '{}' created.", username); | ||||
|         } | ||||
|     let result = match server_connection.next().await { | ||||
|         Some(Ok(Response::CreateUsers(result))) => result, | ||||
|         response => return erroneous_server_response(response), | ||||
|     }; | ||||
|  | ||||
|     print_create_users_output_status(&result); | ||||
|  | ||||
|     let successfully_created_users = result | ||||
|         .iter() | ||||
|         .filter_map(|(username, result)| result.as_ref().ok().map(|_| username)) | ||||
|         .collect::<Vec<_>>(); | ||||
|  | ||||
|     for username in successfully_created_users { | ||||
|         if !args.no_password | ||||
|             && Confirm::new() | ||||
|                 .with_prompt(format!( | ||||
| @@ -147,41 +141,55 @@ async fn create_users( | ||||
|                 )) | ||||
|                 .interact()? | ||||
|         { | ||||
|             change_password_for_user( | ||||
|                 UserPasswdArgs { | ||||
|                     username, | ||||
|                     password_file: None, | ||||
|                 }, | ||||
|                 connection, | ||||
|             ) | ||||
|             .await?; | ||||
|             let password = read_password_from_stdin_with_double_check(username)?; | ||||
|             let message = Request::PasswdUser(username.clone(), password); | ||||
|  | ||||
|             if let Err(err) = server_connection.send(message).await { | ||||
|                 server_connection.close().await.ok(); | ||||
|                 anyhow::bail!(err); | ||||
|             } | ||||
|  | ||||
|             match server_connection.next().await { | ||||
|                 Some(Ok(Response::PasswdUser(result))) => { | ||||
|                     print_set_password_output_status(&result, username) | ||||
|                 } | ||||
|                 response => return erroneous_server_response(response), | ||||
|             } | ||||
|  | ||||
|             println!(); | ||||
|         } | ||||
|         println!(); | ||||
|     } | ||||
|     Ok(result) | ||||
|  | ||||
|     server_connection.send(Request::Exit).await?; | ||||
|  | ||||
|     Ok(()) | ||||
| } | ||||
|  | ||||
| async fn drop_users( | ||||
|     args: UserDeleteArgs, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> anyhow::Result<CommandStatus> { | ||||
|     mut server_connection: ClientToServerMessageStream, | ||||
| ) -> anyhow::Result<()> { | ||||
|     if args.username.is_empty() { | ||||
|         anyhow::bail!("No usernames provided"); | ||||
|     } | ||||
|  | ||||
|     let mut result = CommandStatus::SuccessfullyModified; | ||||
|     let message = Request::DropUsers(args.username.clone()); | ||||
|  | ||||
|     for username in args.username { | ||||
|         if let Err(e) = delete_database_user(&username, connection).await { | ||||
|             eprintln!("{}", e); | ||||
|             eprintln!("Skipping..."); | ||||
|             result = CommandStatus::PartiallySuccessfullyModified; | ||||
|         } else { | ||||
|             println!("User '{}' dropped.", username); | ||||
|         } | ||||
|     if let Err(err) = server_connection.send(message).await { | ||||
|         server_connection.close().await.ok(); | ||||
|         anyhow::bail!(err); | ||||
|     } | ||||
|  | ||||
|     Ok(result) | ||||
|     let result = match server_connection.next().await { | ||||
|         Some(Ok(Response::DropUsers(result))) => result, | ||||
|         response => return erroneous_server_response(response), | ||||
|     }; | ||||
|  | ||||
|     server_connection.send(Request::Exit).await?; | ||||
|  | ||||
|     print_drop_users_output_status(&result); | ||||
|  | ||||
|     Ok(()) | ||||
| } | ||||
|  | ||||
| pub fn read_password_from_stdin_with_double_check(username: &str) -> anyhow::Result<String> { | ||||
| @@ -195,15 +203,10 @@ pub fn read_password_from_stdin_with_double_check(username: &str) -> anyhow::Res | ||||
|         .map_err(Into::into) | ||||
| } | ||||
|  | ||||
| async fn change_password_for_user( | ||||
| async fn passwd_user( | ||||
|     args: UserPasswdArgs, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> anyhow::Result<CommandStatus> { | ||||
|     // NOTE: although this also is checked in `set_password_for_database_user`, we check it here | ||||
|     //       to provide a more natural order of error messages. | ||||
|     let unix_user = get_current_unix_user()?; | ||||
|     validate_user_name(&args.username, &unix_user)?; | ||||
|  | ||||
|     mut server_connection: ClientToServerMessageStream, | ||||
| ) -> anyhow::Result<()> { | ||||
|     let password = if let Some(password_file) = args.password_file { | ||||
|         std::fs::read_to_string(password_file) | ||||
|             .context("Failed to read password file")? | ||||
| @@ -213,129 +216,146 @@ async fn change_password_for_user( | ||||
|         read_password_from_stdin_with_double_check(&args.username)? | ||||
|     }; | ||||
|  | ||||
|     set_password_for_database_user(&args.username, &password, connection).await?; | ||||
|     let message = Request::PasswdUser(args.username.clone(), password); | ||||
|  | ||||
|     Ok(CommandStatus::SuccessfullyModified) | ||||
|     if let Err(err) = server_connection.send(message).await { | ||||
|         server_connection.close().await.ok(); | ||||
|         anyhow::bail!(err); | ||||
|     } | ||||
|  | ||||
|     let result = match server_connection.next().await { | ||||
|         Some(Ok(Response::PasswdUser(result))) => result, | ||||
|         response => return erroneous_server_response(response), | ||||
|     }; | ||||
|  | ||||
|     server_connection.send(Request::Exit).await?; | ||||
|  | ||||
|     print_set_password_output_status(&result, &args.username); | ||||
|  | ||||
|     Ok(()) | ||||
| } | ||||
|  | ||||
| async fn show_users( | ||||
|     args: UserShowArgs, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> anyhow::Result<CommandStatus> { | ||||
|     let unix_user = get_current_unix_user()?; | ||||
|  | ||||
|     let users = if args.username.is_empty() { | ||||
|         get_all_database_users_for_unix_user(&unix_user, connection).await? | ||||
|     mut server_connection: ClientToServerMessageStream, | ||||
| ) -> anyhow::Result<()> { | ||||
|     let message = if args.username.is_empty() { | ||||
|         Request::ListUsers(None) | ||||
|     } else { | ||||
|         let mut result = vec![]; | ||||
|         for username in args.username { | ||||
|             if let Err(e) = validate_user_name(&username, &unix_user) { | ||||
|                 eprintln!("{}", e); | ||||
|                 eprintln!("Skipping..."); | ||||
|                 continue; | ||||
|             } | ||||
|  | ||||
|             let user = get_database_user_for_user(&username, connection).await?; | ||||
|             if let Some(user) = user { | ||||
|                 result.push(user); | ||||
|             } else { | ||||
|                 eprintln!("User not found: {}", username); | ||||
|             } | ||||
|         } | ||||
|         result | ||||
|         Request::ListUsers(Some(args.username.clone())) | ||||
|     }; | ||||
|  | ||||
|     let mut user_databases: BTreeMap<String, Vec<String>> = BTreeMap::new(); | ||||
|     for user in users.iter() { | ||||
|         user_databases.insert( | ||||
|             user.user.clone(), | ||||
|             get_databases_where_user_has_privileges(&user.user, connection).await?, | ||||
|         ); | ||||
|     if let Err(err) = server_connection.send(message).await { | ||||
|         server_connection.close().await.ok(); | ||||
|         anyhow::bail!(err); | ||||
|     } | ||||
|  | ||||
|     if args.json { | ||||
|         let users_json = users | ||||
|     let users = match server_connection.next().await { | ||||
|         Some(Ok(Response::ListUsers(users))) => users | ||||
|             .into_iter() | ||||
|             .map(|user| { | ||||
|                 json!({ | ||||
|                     "user": user.user, | ||||
|                     "has_password": user.has_password, | ||||
|                     "is_locked": user.is_locked, | ||||
|                     "databases": user_databases.get(&user.user).unwrap_or(&vec![]), | ||||
|                 }) | ||||
|             .filter_map(|(username, result)| match result { | ||||
|                 Ok(user) => Some(user), | ||||
|                 Err(err) => { | ||||
|                     eprintln!("{}", err.to_error_message(&username)); | ||||
|                     eprintln!("Skipping..."); | ||||
|                     None | ||||
|                 } | ||||
|             }) | ||||
|             .collect::<serde_json::Value>(); | ||||
|             .collect::<Vec<_>>(), | ||||
|         Some(Ok(Response::ListAllUsers(users))) => match users { | ||||
|             Ok(users) => users, | ||||
|             Err(err) => { | ||||
|                 server_connection.send(Request::Exit).await?; | ||||
|                 return Err( | ||||
|                     anyhow::anyhow!(err.to_error_message()).context("Failed to list all users") | ||||
|                 ); | ||||
|             } | ||||
|         }, | ||||
|         response => return erroneous_server_response(response), | ||||
|     }; | ||||
|  | ||||
|     server_connection.send(Request::Exit).await?; | ||||
|  | ||||
|     // TODO: print databases where user has privileges | ||||
|     if args.json { | ||||
|         println!( | ||||
|             "{}", | ||||
|             serde_json::to_string_pretty(&users_json) | ||||
|                 .context("Failed to serialize users to JSON")? | ||||
|             serde_json::to_string_pretty(&users).context("Failed to serialize users to JSON")? | ||||
|         ); | ||||
|     } else if users.is_empty() { | ||||
|         println!("No users found."); | ||||
|         println!("No users to show."); | ||||
|     } else { | ||||
|         let mut table = Table::new(); | ||||
|         let mut table = prettytable::Table::new(); | ||||
|         table.add_row(row![ | ||||
|             "User", | ||||
|             "Password is set", | ||||
|             "Locked", | ||||
|             "Databases where user has privileges" | ||||
|             // "Databases where user has privileges" | ||||
|         ]); | ||||
|         for user in users { | ||||
|             table.add_row(row![ | ||||
|                 user.user, | ||||
|                 user.has_password, | ||||
|                 user.is_locked, | ||||
|                 user_databases.get(&user.user).unwrap_or(&vec![]).join("\n") | ||||
|                 // user.databases.join("\n") | ||||
|             ]); | ||||
|         } | ||||
|         table.printstd(); | ||||
|     } | ||||
|  | ||||
|     Ok(CommandStatus::NoModificationsIntended) | ||||
|     Ok(()) | ||||
| } | ||||
|  | ||||
| async fn lock_users( | ||||
|     args: UserLockArgs, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> anyhow::Result<CommandStatus> { | ||||
|     mut server_connection: ClientToServerMessageStream, | ||||
| ) -> anyhow::Result<()> { | ||||
|     if args.username.is_empty() { | ||||
|         anyhow::bail!("No usernames provided"); | ||||
|     } | ||||
|  | ||||
|     let mut result = CommandStatus::SuccessfullyModified; | ||||
|     let message = Request::LockUsers(args.username.clone()); | ||||
|  | ||||
|     for username in args.username { | ||||
|         if let Err(e) = lock_database_user(&username, connection).await { | ||||
|             eprintln!("{}", e); | ||||
|             eprintln!("Skipping..."); | ||||
|             result = CommandStatus::PartiallySuccessfullyModified; | ||||
|         } else { | ||||
|             println!("User '{}' locked.", username); | ||||
|         } | ||||
|     if let Err(err) = server_connection.send(message).await { | ||||
|         server_connection.close().await.ok(); | ||||
|         anyhow::bail!(err); | ||||
|     } | ||||
|  | ||||
|     Ok(result) | ||||
|     let result = match server_connection.next().await { | ||||
|         Some(Ok(Response::LockUsers(result))) => result, | ||||
|         response => return erroneous_server_response(response), | ||||
|     }; | ||||
|  | ||||
|     server_connection.send(Request::Exit).await?; | ||||
|  | ||||
|     print_lock_users_output_status(&result); | ||||
|  | ||||
|     Ok(()) | ||||
| } | ||||
|  | ||||
| async fn unlock_users( | ||||
|     args: UserUnlockArgs, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> anyhow::Result<CommandStatus> { | ||||
|     mut server_connection: ClientToServerMessageStream, | ||||
| ) -> anyhow::Result<()> { | ||||
|     if args.username.is_empty() { | ||||
|         anyhow::bail!("No usernames provided"); | ||||
|     } | ||||
|  | ||||
|     let mut result = CommandStatus::SuccessfullyModified; | ||||
|     let message = Request::UnlockUsers(args.username.clone()); | ||||
|  | ||||
|     for username in args.username { | ||||
|         if let Err(e) = unlock_database_user(&username, connection).await { | ||||
|             eprintln!("{}", e); | ||||
|             eprintln!("Skipping..."); | ||||
|             result = CommandStatus::PartiallySuccessfullyModified; | ||||
|         } else { | ||||
|             println!("User '{}' unlocked.", username); | ||||
|         } | ||||
|     if let Err(err) = server_connection.send(message).await { | ||||
|         server_connection.close().await.ok(); | ||||
|         anyhow::bail!(err); | ||||
|     } | ||||
|  | ||||
|     Ok(result) | ||||
|     let result = match server_connection.next().await { | ||||
|         Some(Ok(Response::UnlockUsers(result))) => result, | ||||
|         response => return erroneous_server_response(response), | ||||
|     }; | ||||
|  | ||||
|     server_connection.send(Request::Exit).await?; | ||||
|  | ||||
|     print_unlock_users_output_status(&result); | ||||
|  | ||||
|     Ok(()) | ||||
| } | ||||
|   | ||||
| @@ -1,5 +1,4 @@ | ||||
| pub mod bootstrap; | ||||
| pub mod common; | ||||
| pub mod config; | ||||
| pub mod database_operations; | ||||
| pub mod database_privilege_operations; | ||||
| pub mod user_operations; | ||||
| pub mod database_privileges; | ||||
| pub mod protocol; | ||||
|   | ||||
							
								
								
									
										177
									
								
								src/core/bootstrap.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										177
									
								
								src/core/bootstrap.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,177 @@ | ||||
| use std::{fs, path::PathBuf}; | ||||
|  | ||||
| use anyhow::Context; | ||||
| use nix::libc::{exit, EXIT_SUCCESS}; | ||||
| use std::os::unix::net::UnixStream as StdUnixStream; | ||||
| use tokio::net::UnixStream as TokioUnixStream; | ||||
|  | ||||
| use crate::{ | ||||
|     core::{ | ||||
|         bootstrap::authenticated_unix_socket::client_authenticate, | ||||
|         common::{UnixUser, DEFAULT_CONFIG_PATH, DEFAULT_SOCKET_PATH}, | ||||
|     }, | ||||
|     server::{config::read_config_form_path, server_loop::handle_requests_for_single_session}, | ||||
| }; | ||||
|  | ||||
| pub mod authenticated_unix_socket; | ||||
|  | ||||
| // TODO: this function is security critical, it should be integration tested | ||||
| //       in isolation. | ||||
| /// Drop privileges to the real user and group of the process. | ||||
| /// If the process is not running with elevated privileges, this function | ||||
| /// is a no-op. | ||||
| pub fn drop_privs() -> anyhow::Result<()> { | ||||
|     log::debug!("Dropping privileges"); | ||||
|     let real_uid = nix::unistd::getuid(); | ||||
|     let real_gid = nix::unistd::getgid(); | ||||
|  | ||||
|     nix::unistd::setuid(real_uid).context("Failed to drop privileges")?; | ||||
|     nix::unistd::setgid(real_gid).context("Failed to drop privileges")?; | ||||
|  | ||||
|     debug_assert_eq!(nix::unistd::getuid(), real_uid); | ||||
|     debug_assert_eq!(nix::unistd::getgid(), real_gid); | ||||
|  | ||||
|     log::debug!("Privileges dropped successfully"); | ||||
|     Ok(()) | ||||
| } | ||||
|  | ||||
| /// This function is used to bootstrap the connection to the server. | ||||
| /// This can happen in two ways: | ||||
| /// 1. If a socket path is provided, or exists in the default location, | ||||
| ///    the function will connect to the socket and authenticate with the | ||||
| ///    server to ensure that the server knows the uid of the client. | ||||
| /// 2. If a config path is provided, or exists in the default location, | ||||
| ///    and the config is readable, the function will assume it is either | ||||
| ///    setuid or setgid, and will fork a child process to run the server | ||||
| ///    with the provided config. The server will exit silently by itself | ||||
| ///    when it is done, and this function will only return for the client | ||||
| ///    with the socket for the server. | ||||
| /// If neither of these options are available, the function will fail. | ||||
| pub fn bootstrap_server_connection_and_drop_privileges( | ||||
|     server_socket_path: Option<PathBuf>, | ||||
|     config_path: Option<PathBuf>, | ||||
| ) -> anyhow::Result<StdUnixStream> { | ||||
|     if server_socket_path.is_some() && config_path.is_some() { | ||||
|         anyhow::bail!("Cannot provide both a socket path and a config path"); | ||||
|     } | ||||
|  | ||||
|     log::debug!("Starting the server connection bootstrap process"); | ||||
|  | ||||
|     let (socket, do_authenticate) = bootstrap_server_connection(server_socket_path, config_path)?; | ||||
|  | ||||
|     drop_privs()?; | ||||
|  | ||||
|     let result: anyhow::Result<StdUnixStream> = if do_authenticate { | ||||
|         tokio::runtime::Builder::new_current_thread() | ||||
|             .enable_all() | ||||
|             .build() | ||||
|             .unwrap() | ||||
|             .block_on(async { | ||||
|                 let mut socket = TokioUnixStream::from_std(socket)?; | ||||
|                 client_authenticate(&mut socket, None).await?; | ||||
|                 Ok(socket.into_std()?) | ||||
|             }) | ||||
|     } else { | ||||
|         Ok(socket) | ||||
|     }; | ||||
|  | ||||
|     result | ||||
| } | ||||
|  | ||||
| /// Inner function for [`bootstrap_server_connection_and_drop_privileges`]. | ||||
| /// See that function for more information. | ||||
| fn bootstrap_server_connection( | ||||
|     socket_path: Option<PathBuf>, | ||||
|     config_path: Option<PathBuf>, | ||||
| ) -> anyhow::Result<(StdUnixStream, bool)> { | ||||
|     // TODO: ensure this is both readable and writable | ||||
|     if let Some(socket_path) = socket_path { | ||||
|         log::debug!("Connecting to socket at {:?}", socket_path); | ||||
|         return match StdUnixStream::connect(socket_path) { | ||||
|             Ok(socket) => Ok((socket, true)), | ||||
|             Err(e) => match e.kind() { | ||||
|                 std::io::ErrorKind::NotFound => Err(anyhow::anyhow!("Socket not found")), | ||||
|                 std::io::ErrorKind::PermissionDenied => Err(anyhow::anyhow!("Permission denied")), | ||||
|                 _ => Err(anyhow::anyhow!("Failed to connect to socket: {}", e)), | ||||
|             }, | ||||
|         }; | ||||
|     } | ||||
|     if let Some(config_path) = config_path { | ||||
|         // ensure config exists and is readable | ||||
|         if fs::metadata(&config_path).is_err() { | ||||
|             return Err(anyhow::anyhow!("Config file not found or not readable")); | ||||
|         } | ||||
|  | ||||
|         log::debug!("Starting server with config at {:?}", config_path); | ||||
|         return invoke_server_with_config(config_path).map(|socket| (socket, false)); | ||||
|     } | ||||
|  | ||||
|     if fs::metadata(DEFAULT_SOCKET_PATH).is_ok() { | ||||
|         return match StdUnixStream::connect(DEFAULT_SOCKET_PATH) { | ||||
|             Ok(socket) => Ok((socket, true)), | ||||
|             Err(e) => match e.kind() { | ||||
|                 std::io::ErrorKind::NotFound => Err(anyhow::anyhow!("Socket not found")), | ||||
|                 std::io::ErrorKind::PermissionDenied => Err(anyhow::anyhow!("Permission denied")), | ||||
|                 _ => Err(anyhow::anyhow!("Failed to connect to socket: {}", e)), | ||||
|             }, | ||||
|         }; | ||||
|     } | ||||
|  | ||||
|     let config_path = PathBuf::from(DEFAULT_CONFIG_PATH); | ||||
|     if fs::metadata(&config_path).is_ok() { | ||||
|         return invoke_server_with_config(config_path).map(|socket| (socket, false)); | ||||
|     } | ||||
|  | ||||
|     anyhow::bail!("No socket path or config path provided, and no default socket or config found"); | ||||
| } | ||||
|  | ||||
| // TODO: we should somehow ensure that the forked process is killed on completion, | ||||
| //       just in case the client does not behave properly. | ||||
| /// Fork a child process to run the server with the provided config. | ||||
| /// The server will exit silently by itself when it is done, and this function | ||||
| /// will only return for the client with the socket for the server. | ||||
| fn invoke_server_with_config(config_path: PathBuf) -> anyhow::Result<StdUnixStream> { | ||||
|     let (server_socket, client_socket) = StdUnixStream::pair()?; | ||||
|     let unix_user = UnixUser::from_uid(nix::unistd::getuid().as_raw())?; | ||||
|  | ||||
|     match (unsafe { nix::unistd::fork() }).context("Failed to fork")? { | ||||
|         nix::unistd::ForkResult::Parent { child } => { | ||||
|             log::debug!("Forked child process with PID {}", child); | ||||
|             Ok(client_socket) | ||||
|         } | ||||
|         nix::unistd::ForkResult::Child => { | ||||
|             log::debug!("Running server in child process"); | ||||
|  | ||||
|             match run_forked_server(config_path, server_socket, unix_user) { | ||||
|                 Err(e) => Err(e), | ||||
|                 Ok(_) => unreachable!(), | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// Run the server in the forked child process. | ||||
| /// This function will not return, but will exit the process with a success code. | ||||
| fn run_forked_server( | ||||
|     config_path: PathBuf, | ||||
|     server_socket: StdUnixStream, | ||||
|     unix_user: UnixUser, | ||||
| ) -> anyhow::Result<()> { | ||||
|     let config = read_config_form_path(Some(config_path))?; | ||||
|  | ||||
|     let result: anyhow::Result<()> = tokio::runtime::Builder::new_current_thread() | ||||
|         .enable_all() | ||||
|         .build() | ||||
|         .unwrap() | ||||
|         .block_on(async { | ||||
|             let socket = TokioUnixStream::from_std(server_socket)?; | ||||
|             handle_requests_for_single_session(socket, &unix_user, &config).await?; | ||||
|             Ok(()) | ||||
|         }); | ||||
|  | ||||
|     result?; | ||||
|  | ||||
|     unsafe { | ||||
|         exit(EXIT_SUCCESS); | ||||
|     } | ||||
| } | ||||
| @@ -30,10 +30,13 @@ | ||||
| //! Also note that it is essential that the client does not send any sensitive information
 | ||||
| //! over it's authentication socket, since it is readable by any user on the system.
 | ||||
| 
 | ||||
| // TODO: rewrite this so that it can be used with a normal std::os::unix::net::UnixStream
 | ||||
| 
 | ||||
| use std::os::unix::io::AsRawFd; | ||||
| use std::path::PathBuf; | ||||
| use std::path::{Path, PathBuf}; | ||||
| 
 | ||||
| use async_bincode::{tokio::AsyncBincodeStream, AsyncDestination}; | ||||
| use derive_more::derive::{Display, Error}; | ||||
| use futures::{SinkExt, StreamExt}; | ||||
| use nix::{sys::stat, unistd::Uid}; | ||||
| use rand::distributions::Alphanumeric; | ||||
| @@ -52,7 +55,7 @@ pub enum ClientRequest { | ||||
|     Cancel, | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] | ||||
| #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Display, Error)] | ||||
| pub enum ServerResponse { | ||||
|     Authenticated, | ||||
|     ChallengeDidNotMatch, | ||||
| @@ -61,7 +64,7 @@ pub enum ServerResponse { | ||||
| 
 | ||||
| // TODO: wrap more data into the errors
 | ||||
| 
 | ||||
| #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] | ||||
| #[derive(Debug, Display, PartialEq, Serialize, Deserialize, Clone, Error)] | ||||
| pub enum ServerError { | ||||
|     InvalidRequest, | ||||
|     UnableToReadPermissionsFromAuthSocket, | ||||
| @@ -72,7 +75,7 @@ pub enum ServerError { | ||||
|     InvalidChallenge, | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, PartialEq)] | ||||
| #[derive(Debug, PartialEq, Display, Error)] | ||||
| pub enum ClientError { | ||||
|     UnableToConnectToServer, | ||||
|     UnableToOpenAuthSocket, | ||||
| @@ -80,13 +83,12 @@ pub enum ClientError { | ||||
|     AuthSocketClosedEarly, | ||||
|     UnableToCloseAuthSocket, | ||||
|     AuthenticationError, | ||||
|     InvalidServerResponse(ServerResponse), | ||||
|     UnableToParseServerResponse, | ||||
|     NoServerResponse, | ||||
|     ServerError(ServerError), | ||||
| } | ||||
| 
 | ||||
| async fn create_auth_socket(socket_addr: &str) -> Result<UnixListener, ClientError> { | ||||
| async fn create_auth_socket(socket_addr: &PathBuf) -> Result<UnixListener, ClientError> { | ||||
|     let auth_socket = | ||||
|         UnixListener::bind(socket_addr).map_err(|_err| ClientError::UnableToOpenAuthSocket)?; | ||||
| 
 | ||||
| @@ -109,11 +111,13 @@ type AuthStream<'a> = AsyncBincodeStream<&'a mut UnixStream, u64, u64, AsyncDest | ||||
| 
 | ||||
| // TODO: add timeout
 | ||||
| 
 | ||||
| // TODO: respect $XDG_RUNTIME_DIR and $TMPDIR
 | ||||
| 
 | ||||
| const AUTH_SOCKET_NAME: &str = "mysqladm-rs-cli-auth.sock"; | ||||
| 
 | ||||
| pub async fn client_authenticate( | ||||
|     normal_socket: &mut UnixStream, | ||||
|     #[cfg(not(test))] auth_socket_dir: Option<PathBuf>, | ||||
|     #[cfg(test)] auth_socket_file: Option<PathBuf>, | ||||
|     auth_socket_dir: Option<PathBuf>, | ||||
| ) -> Result<(), ClientError> { | ||||
|     let random_prefix: String = rand::thread_rng() | ||||
|         .sample_iter(&Alphanumeric) | ||||
| @@ -123,32 +127,16 @@ pub async fn client_authenticate( | ||||
| 
 | ||||
|     let socket_name = format!("{}-{}", random_prefix, AUTH_SOCKET_NAME); | ||||
| 
 | ||||
|     #[cfg(not(test))] | ||||
|     let auth_socket_address = match auth_socket_dir { | ||||
|         Some(dir) => dir.join(socket_name).to_str().unwrap().to_string(), | ||||
|         None => std::env::temp_dir() | ||||
|             .join(socket_name) | ||||
|             .to_str() | ||||
|             .unwrap() | ||||
|             .to_string(), | ||||
|     }; | ||||
| 
 | ||||
|     #[cfg(test)] | ||||
|     let auth_socket_address = match auth_socket_file { | ||||
|         Some(file) => file.to_str().unwrap().to_string(), | ||||
|         None => std::env::temp_dir() | ||||
|             .join(socket_name) | ||||
|             .to_str() | ||||
|             .unwrap() | ||||
|             .to_string(), | ||||
|     }; | ||||
|     let auth_socket_address = auth_socket_dir | ||||
|         .unwrap_or(std::env::temp_dir()) | ||||
|         .join(socket_name); | ||||
| 
 | ||||
|     client_authenticate_with_auth_socket_address(normal_socket, &auth_socket_address).await | ||||
| } | ||||
| 
 | ||||
| async fn client_authenticate_with_auth_socket_address( | ||||
|     normal_socket: &mut UnixStream, | ||||
|     auth_socket_address: &str, | ||||
|     auth_socket_address: &PathBuf, | ||||
| ) -> Result<(), ClientError> { | ||||
|     let auth_socket = create_auth_socket(auth_socket_address).await?; | ||||
| 
 | ||||
| @@ -164,7 +152,7 @@ async fn client_authenticate_with_auth_socket_address( | ||||
| async fn client_authenticate_with_auth_socket( | ||||
|     normal_socket: &mut UnixStream, | ||||
|     auth_socket: UnixListener, | ||||
|     auth_socket_address: &str, | ||||
|     auth_socket_address: &Path, | ||||
| ) -> Result<(), ClientError> { | ||||
|     let challenge = rand::random::<u64>(); | ||||
|     let uid = nix::unistd::getuid(); | ||||
| @@ -199,7 +187,10 @@ async fn client_authenticate_with_auth_socket( | ||||
|     let client_hello = ClientRequest::Initialize { | ||||
|         uid: uid.into(), | ||||
|         challenge, | ||||
|         auth_socket: auth_socket_address.to_string(), | ||||
|         auth_socket: auth_socket_address | ||||
|             .to_str() | ||||
|             .ok_or(ClientError::UnableToConfigureAuthSocket)? | ||||
|             .to_owned(), | ||||
|     }; | ||||
| 
 | ||||
|     normal_socket | ||||
| @@ -239,9 +230,13 @@ macro_rules! report_server_error_and_return { | ||||
|     }}; | ||||
| } | ||||
| 
 | ||||
| async fn server_authenticate( | ||||
| pub async fn server_authenticate(normal_socket: &mut UnixStream) -> Result<Uid, ServerError> { | ||||
|     _server_authenticate(normal_socket, None).await | ||||
| } | ||||
| 
 | ||||
| pub async fn _server_authenticate( | ||||
|     normal_socket: &mut UnixStream, | ||||
|     #[cfg(test)] unix_user_uid: Option<u32>, | ||||
|     unix_user_uid: Option<u32>, | ||||
| ) -> Result<Uid, ServerError> { | ||||
|     let mut normal_socket: ServerToClientStream = | ||||
|         AsyncBincodeStream::from(normal_socket).for_async(); | ||||
| @@ -256,22 +251,15 @@ async fn server_authenticate( | ||||
|         _ => report_server_error_and_return!(normal_socket, ServerError::InvalidRequest), | ||||
|     }; | ||||
| 
 | ||||
|     #[cfg(test)] | ||||
|     let auth_socket_uid = match unix_user_uid { | ||||
|         Some(uid) => uid, | ||||
|         None => report_server_error_and_return!( | ||||
|             normal_socket, | ||||
|             ServerError::UnableToReadPermissionsFromAuthSocket | ||||
|         ), | ||||
|     }; | ||||
| 
 | ||||
|     #[cfg(not(test))] | ||||
|     let auth_socket_uid = match stat::stat(auth_socket.as_str()) { | ||||
|         Ok(stat) => stat.st_uid, | ||||
|         Err(_err) => report_server_error_and_return!( | ||||
|             normal_socket, | ||||
|             ServerError::UnableToReadPermissionsFromAuthSocket | ||||
|         ), | ||||
|         None => match stat::stat(auth_socket.as_str()) { | ||||
|             Ok(stat) => stat.st_uid, | ||||
|             Err(_err) => report_server_error_and_return!( | ||||
|                 normal_socket, | ||||
|                 ServerError::UnableToReadPermissionsFromAuthSocket | ||||
|             ), | ||||
|         }, | ||||
|     }; | ||||
| 
 | ||||
|     if uid != auth_socket_uid { | ||||
| @@ -324,10 +312,7 @@ mod test { | ||||
|         let client_handle = | ||||
|             tokio::spawn(async move { client_authenticate(&mut client, None).await }); | ||||
| 
 | ||||
|         let server_handle = tokio::spawn(async move { | ||||
|             let uid = nix::unistd::getuid().into(); | ||||
|             server_authenticate(&mut server, Some(uid)).await | ||||
|         }); | ||||
|         let server_handle = tokio::spawn(async move { server_authenticate(&mut server).await }); | ||||
| 
 | ||||
|         client_handle.await.unwrap().unwrap(); | ||||
|         server_handle.await.unwrap().unwrap(); | ||||
| @@ -340,15 +325,12 @@ mod test { | ||||
|         let client_handle = tokio::spawn(async move { | ||||
|             client_authenticate_with_auth_socket_address( | ||||
|                 &mut client, | ||||
|                 "/tmp/test_auth_socket_does_not_exist.sock", | ||||
|                 &PathBuf::from("/tmp/test_auth_socket_does_not_exist.sock"), | ||||
|             ) | ||||
|             .await | ||||
|         }); | ||||
| 
 | ||||
|         let server_handle = tokio::spawn(async move { | ||||
|             let uid = nix::unistd::getuid().into(); | ||||
|             server_authenticate(&mut server, Some(uid)).await | ||||
|         }); | ||||
|         let server_handle = tokio::spawn(async move { server_authenticate(&mut server).await }); | ||||
| 
 | ||||
|         client_handle.await.unwrap().unwrap(); | ||||
|         server_handle.await.unwrap().unwrap(); | ||||
| @@ -365,7 +347,7 @@ mod test { | ||||
| 
 | ||||
|         let server_handle = tokio::spawn(async move { | ||||
|             let uid: u32 = nix::unistd::getuid().into(); | ||||
|             let err = server_authenticate(&mut server, Some(uid + 1)).await; | ||||
|             let err = _server_authenticate(&mut server, Some(uid + 1)).await; | ||||
|             assert_eq!(err, Err(ServerError::UidMismatch)); | ||||
|         }); | ||||
| 
 | ||||
| @@ -379,13 +361,19 @@ mod test { | ||||
| 
 | ||||
|         let socket_path = std::env::temp_dir().join("socket_to_snoop.sock"); | ||||
|         let socket_path_clone = socket_path.clone(); | ||||
|         let client_handle = | ||||
|             tokio::spawn( | ||||
|                 async move { client_authenticate(&mut client, Some(socket_path_clone)).await }, | ||||
|             ); | ||||
|         let client_handle = tokio::spawn(async move { | ||||
|             client_authenticate_with_auth_socket_address(&mut client, &socket_path_clone).await | ||||
|         }); | ||||
| 
 | ||||
|         while !socket_path.exists() { | ||||
|             sleep(std::time::Duration::from_millis(10)).await; | ||||
|         for i in 0..100 { | ||||
|             if socket_path.exists() { | ||||
|                 break; | ||||
|             } | ||||
|             sleep(Duration::from_millis(10)).await; | ||||
| 
 | ||||
|             if i == 99 { | ||||
|                 panic!("Socket not created after 1 second, assuming test failure"); | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         let mut snooper = UnixStream::connect(socket_path.clone()).await.unwrap(); | ||||
| @@ -409,10 +397,7 @@ mod test { | ||||
| 
 | ||||
|         sleep(Duration::from_millis(10)).await; | ||||
| 
 | ||||
|         let server_handle = tokio::spawn(async move { | ||||
|             let uid: u32 = nix::unistd::getuid().into(); | ||||
|             server_authenticate(&mut server, Some(uid)).await | ||||
|         }); | ||||
|         let server_handle = tokio::spawn(async move { server_authenticate(&mut server).await }); | ||||
| 
 | ||||
|         client_handle.await.unwrap().unwrap(); | ||||
|         server_handle.await.unwrap().unwrap(); | ||||
| @@ -1,56 +1,32 @@ | ||||
| use anyhow::Context; | ||||
| use indoc::indoc; | ||||
| use itertools::Itertools; | ||||
| use nix::unistd::{getuid, Group, User}; | ||||
| use sqlx::{Connection, MySqlConnection}; | ||||
| use nix::unistd::{Group as LibcGroup, User as LibcUser}; | ||||
|  | ||||
| #[cfg(not(target_os = "macos"))] | ||||
| use std::ffi::CString; | ||||
|  | ||||
| /// Report the result status of a command. | ||||
| /// This is used to display a status message to the user. | ||||
| pub enum CommandStatus { | ||||
|     /// The command was successful, | ||||
|     /// and made modification to the database. | ||||
|     SuccessfullyModified, | ||||
| pub const DEFAULT_CONFIG_PATH: &str = "/etc/mysqladm/config.toml"; | ||||
| pub const DEFAULT_SOCKET_PATH: &str = "/run/mysqladm/mysqladm.sock"; | ||||
|  | ||||
|     /// The command was mostly successful, | ||||
|     /// and modifications have been made to the database. | ||||
|     /// However, some of the requested modifications failed. | ||||
|     PartiallySuccessfullyModified, | ||||
|  | ||||
|     /// The command was successful, | ||||
|     /// but no modifications were needed. | ||||
|     NoModificationsNeeded, | ||||
|  | ||||
|     /// The command was successful, | ||||
|     /// and made no modification to the database. | ||||
|     NoModificationsIntended, | ||||
|  | ||||
|     /// The command was cancelled, either through a dialog or a signal. | ||||
|     /// No modifications have been made to the database. | ||||
|     Cancelled, | ||||
| pub struct UnixUser { | ||||
|     pub username: String, | ||||
|     pub groups: Vec<String>, | ||||
| } | ||||
|  | ||||
| pub fn get_current_unix_user() -> anyhow::Result<User> { | ||||
|     User::from_uid(getuid()) | ||||
|         .context("Failed to look up your UNIX username") | ||||
|         .and_then(|u| u.ok_or(anyhow::anyhow!("Failed to look up your UNIX username"))) | ||||
| } | ||||
| // TODO: these functions are somewhat critical, and should have integration tests | ||||
|  | ||||
| #[cfg(target_os = "macos")] | ||||
| pub fn get_unix_groups(_user: &User) -> anyhow::Result<Vec<Group>> { | ||||
| fn get_unix_groups(_user: &LibcUser) -> anyhow::Result<Vec<LibcGroup>> { | ||||
|     // Return an empty list on macOS since there is no `getgrouplist` function | ||||
|     Ok(vec![]) | ||||
| } | ||||
|  | ||||
| #[cfg(not(target_os = "macos"))] | ||||
| pub fn get_unix_groups(user: &User) -> anyhow::Result<Vec<Group>> { | ||||
| fn get_unix_groups(user: &LibcUser) -> anyhow::Result<Vec<LibcGroup>> { | ||||
|     let user_cstr = | ||||
|         CString::new(user.name.as_bytes()).context("Failed to convert username to CStr")?; | ||||
|     let groups = nix::unistd::getgrouplist(&user_cstr, user.gid)? | ||||
|         .iter() | ||||
|         .filter_map(|gid| match Group::from_gid(*gid) { | ||||
|         .filter_map(|gid| match LibcGroup::from_gid(*gid) { | ||||
|             Ok(Some(group)) => Some(group), | ||||
|             Ok(None) => None, | ||||
|             Err(e) => { | ||||
| @@ -62,211 +38,32 @@ pub fn get_unix_groups(user: &User) -> anyhow::Result<Vec<Group>> { | ||||
|                 None | ||||
|             } | ||||
|         }) | ||||
|         .collect::<Vec<Group>>(); | ||||
|         .collect::<Vec<LibcGroup>>(); | ||||
|  | ||||
|     Ok(groups) | ||||
| } | ||||
|  | ||||
| /// This function creates a regex that matches items (users, databases) | ||||
| /// that belong to the user or any of the user's groups. | ||||
| pub fn create_user_group_matching_regex(user: &User) -> String { | ||||
|     let groups = get_unix_groups(user).unwrap_or_default(); | ||||
| impl UnixUser { | ||||
|     pub fn from_uid(uid: u32) -> anyhow::Result<Self> { | ||||
|         let libc_uid = nix::unistd::Uid::from_raw(uid); | ||||
|         let libc_user = LibcUser::from_uid(libc_uid) | ||||
|             .context("Failed to look up your UNIX username")? | ||||
|             .ok_or(anyhow::anyhow!("Failed to look up your UNIX username"))?; | ||||
|  | ||||
|     if groups.is_empty() { | ||||
|         format!("{}(_.+)?", user.name) | ||||
|     } else { | ||||
|         format!( | ||||
|             "({}|{})(_.+)?", | ||||
|             user.name, | ||||
|             groups | ||||
|                 .iter() | ||||
|                 .map(|g| g.name.as_str()) | ||||
|                 .collect::<Vec<_>>() | ||||
|                 .join("|") | ||||
|         ) | ||||
|     } | ||||
| } | ||||
|         let groups = get_unix_groups(&libc_user)?; | ||||
|  | ||||
| /// This enum is used to differentiate between database and user operations. | ||||
| /// Their output are very similar, but there are slight differences in the words used. | ||||
| #[derive(Debug, PartialEq, Eq, Clone, Copy)] | ||||
| pub enum DbOrUser { | ||||
|     Database, | ||||
|     User, | ||||
| } | ||||
|  | ||||
| impl DbOrUser { | ||||
|     pub fn lowercased(&self) -> String { | ||||
|         match self { | ||||
|             DbOrUser::Database => "database".to_string(), | ||||
|             DbOrUser::User => "user".to_string(), | ||||
|         } | ||||
|         Ok(UnixUser { | ||||
|             username: libc_user.name, | ||||
|             groups: groups.iter().map(|g| g.name.clone()).collect(), | ||||
|         }) | ||||
|     } | ||||
|  | ||||
|     pub fn capitalized(&self) -> String { | ||||
|         match self { | ||||
|             DbOrUser::Database => "Database".to_string(), | ||||
|             DbOrUser::User => "User".to_string(), | ||||
|         } | ||||
|     pub fn from_enviroment() -> anyhow::Result<Self> { | ||||
|         let libc_uid = nix::unistd::getuid(); | ||||
|         UnixUser::from_uid(libc_uid.as_raw()) | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[derive(Debug, PartialEq, Eq)] | ||||
| pub enum NameValidationResult { | ||||
|     Valid, | ||||
|     EmptyString, | ||||
|     InvalidCharacters, | ||||
|     TooLong, | ||||
| } | ||||
|  | ||||
| pub fn validate_name(name: &str) -> NameValidationResult { | ||||
|     if name.is_empty() { | ||||
|         NameValidationResult::EmptyString | ||||
|     } else if name.len() > 64 { | ||||
|         NameValidationResult::TooLong | ||||
|     } else if !name | ||||
|         .chars() | ||||
|         .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-') | ||||
|     { | ||||
|         NameValidationResult::InvalidCharacters | ||||
|     } else { | ||||
|         NameValidationResult::Valid | ||||
|     } | ||||
| } | ||||
|  | ||||
| pub fn validate_name_or_error(name: &str, db_or_user: DbOrUser) -> anyhow::Result<()> { | ||||
|     match validate_name(name) { | ||||
|         NameValidationResult::Valid => Ok(()), | ||||
|         NameValidationResult::EmptyString => { | ||||
|             anyhow::bail!("{} name cannot be empty.", db_or_user.capitalized()) | ||||
|         } | ||||
|         NameValidationResult::TooLong => anyhow::bail!( | ||||
|             "{} is too long. Maximum length is 64 characters.", | ||||
|             db_or_user.capitalized() | ||||
|         ), | ||||
|         NameValidationResult::InvalidCharacters => anyhow::bail!( | ||||
|             indoc! {r#" | ||||
|               Invalid characters in {} name: '{}' | ||||
|  | ||||
|               Only A-Z, a-z, 0-9, _ (underscore) and - (dash) are permitted. | ||||
|             "#}, | ||||
|             db_or_user.lowercased(), | ||||
|             name | ||||
|         ), | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[derive(Debug, PartialEq, Eq)] | ||||
| pub enum OwnerValidationResult { | ||||
|     // The name is valid and matches one of the given prefixes | ||||
|     Match, | ||||
|  | ||||
|     // The name is valid, but none of the given prefixes matched the name | ||||
|     NoMatch, | ||||
|  | ||||
|     // The name is empty, which is invalid | ||||
|     StringEmpty, | ||||
|  | ||||
|     // The name is in the format "_<postfix>", which is invalid | ||||
|     MissingPrefix, | ||||
|  | ||||
|     // The name is in the format "<prefix>_", which is invalid | ||||
|     MissingPostfix, | ||||
| } | ||||
|  | ||||
| /// Core logic for validating the ownership of a database name. | ||||
| /// This function checks if the given name matches any of the given prefixes. | ||||
| /// These prefixes will in most cases be the user's unix username and any | ||||
| /// unix groups the user is a member of. | ||||
| pub fn validate_ownership_by_prefixes(name: &str, prefixes: &[String]) -> OwnerValidationResult { | ||||
|     if name.is_empty() { | ||||
|         return OwnerValidationResult::StringEmpty; | ||||
|     } | ||||
|  | ||||
|     if name.starts_with('_') { | ||||
|         return OwnerValidationResult::MissingPrefix; | ||||
|     } | ||||
|  | ||||
|     let (prefix, _) = match name.split_once('_') { | ||||
|         Some(pair) => pair, | ||||
|         None => return OwnerValidationResult::MissingPostfix, | ||||
|     }; | ||||
|  | ||||
|     if prefixes.iter().any(|g| g == prefix) { | ||||
|         OwnerValidationResult::Match | ||||
|     } else { | ||||
|         OwnerValidationResult::NoMatch | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// Validate the ownership of a database name or database user name. | ||||
| /// This function takes the name of a database or user and a unix user, | ||||
| /// for which it fetches the user's groups. It then checks if the name | ||||
| /// is prefixed with the user's username or any of the user's groups. | ||||
| pub fn validate_ownership_or_error<'a>( | ||||
|     name: &'a str, | ||||
|     user: &User, | ||||
|     db_or_user: DbOrUser, | ||||
| ) -> anyhow::Result<&'a str> { | ||||
|     let user_groups = get_unix_groups(user)?; | ||||
|     let prefixes = std::iter::once(user.name.clone()) | ||||
|         .chain(user_groups.iter().map(|g| g.name.clone())) | ||||
|         .collect::<Vec<String>>(); | ||||
|  | ||||
|     match validate_ownership_by_prefixes(name, &prefixes) { | ||||
|         OwnerValidationResult::Match => Ok(name), | ||||
|         OwnerValidationResult::NoMatch => { | ||||
|             anyhow::bail!( | ||||
|                 indoc! {r#" | ||||
|                   Invalid {} name prefix: '{}' does not match your username or any of your groups. | ||||
|                   Are you sure you are allowed to create {} names with this prefix? | ||||
|  | ||||
|                   Allowed prefixes: | ||||
|                     - {} | ||||
|                   {} | ||||
|                 "#}, | ||||
|                 db_or_user.lowercased(), | ||||
|                 name, | ||||
|                 db_or_user.lowercased(), | ||||
|                 user.name, | ||||
|                 user_groups | ||||
|                     .iter() | ||||
|                     .filter(|g| g.name != user.name) | ||||
|                     .map(|g| format!("  - {}", g.name)) | ||||
|                     .sorted() | ||||
|                     .join("\n"), | ||||
|             ); | ||||
|         } | ||||
|         _ => anyhow::bail!( | ||||
|             "'{}' is not a valid {} name.", | ||||
|             name, | ||||
|             db_or_user.lowercased() | ||||
|         ), | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// Gracefully close a MySQL connection. | ||||
| pub async fn close_database_connection(connection: MySqlConnection) { | ||||
|     if let Err(e) = connection | ||||
|         .close() | ||||
|         .await | ||||
|         .context("Failed to close connection properly") | ||||
|     { | ||||
|         eprintln!("{}", e); | ||||
|         eprintln!("Ignoring..."); | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[inline] | ||||
| pub fn quote_literal(s: &str) -> String { | ||||
|     format!("'{}'", s.replace('\'', r"\'")) | ||||
| } | ||||
|  | ||||
| #[inline] | ||||
| pub fn quote_identifier(s: &str) -> String { | ||||
|     format!("`{}`", s.replace('`', r"\`")) | ||||
| } | ||||
|  | ||||
| #[inline] | ||||
| pub(crate) fn yn(b: bool) -> &'static str { | ||||
|     if b { | ||||
| @@ -303,94 +100,4 @@ mod test { | ||||
|         assert_eq!(rev_yn("n"), Some(false)); | ||||
|         assert_eq!(rev_yn("X"), None); | ||||
|     } | ||||
|  | ||||
|     #[test] | ||||
|     fn test_quote_literal() { | ||||
|         let payload = "' OR 1=1 --"; | ||||
|         assert_eq!(quote_literal(payload), r#"'\' OR 1=1 --'"#); | ||||
|     } | ||||
|  | ||||
|     #[test] | ||||
|     fn test_quote_identifier() { | ||||
|         let payload = "` OR 1=1 --"; | ||||
|         assert_eq!(quote_identifier(payload), r#"`\` OR 1=1 --`"#); | ||||
|     } | ||||
|  | ||||
|     #[test] | ||||
|     fn test_validate_name() { | ||||
|         assert_eq!(validate_name(""), NameValidationResult::EmptyString); | ||||
|         assert_eq!( | ||||
|             validate_name("abcdefghijklmnopqrstuvwxyz"), | ||||
|             NameValidationResult::Valid | ||||
|         ); | ||||
|         assert_eq!( | ||||
|             validate_name("ABCDEFGHIJKLMNOPQRSTUVWXYZ"), | ||||
|             NameValidationResult::Valid | ||||
|         ); | ||||
|         assert_eq!(validate_name("0123456789_-"), NameValidationResult::Valid); | ||||
|  | ||||
|         for c in "\n\t\r !@#$%^&*()+=[]{}|;:,.<>?/".chars() { | ||||
|             assert_eq!( | ||||
|                 validate_name(&c.to_string()), | ||||
|                 NameValidationResult::InvalidCharacters | ||||
|             ); | ||||
|         } | ||||
|  | ||||
|         assert_eq!(validate_name(&"a".repeat(64)), NameValidationResult::Valid); | ||||
|  | ||||
|         assert_eq!( | ||||
|             validate_name(&"a".repeat(65)), | ||||
|             NameValidationResult::TooLong | ||||
|         ); | ||||
|     } | ||||
|  | ||||
|     #[test] | ||||
|     fn test_validate_owner_by_prefixes() { | ||||
|         let prefixes = vec!["user".to_string(), "group".to_string()]; | ||||
|  | ||||
|         assert_eq!( | ||||
|             validate_ownership_by_prefixes("", &prefixes), | ||||
|             OwnerValidationResult::StringEmpty | ||||
|         ); | ||||
|  | ||||
|         assert_eq!( | ||||
|             validate_ownership_by_prefixes("user", &prefixes), | ||||
|             OwnerValidationResult::MissingPostfix | ||||
|         ); | ||||
|         assert_eq!( | ||||
|             validate_ownership_by_prefixes("something", &prefixes), | ||||
|             OwnerValidationResult::MissingPostfix | ||||
|         ); | ||||
|         assert_eq!( | ||||
|             validate_ownership_by_prefixes("user-testdb", &prefixes), | ||||
|             OwnerValidationResult::MissingPostfix | ||||
|         ); | ||||
|  | ||||
|         assert_eq!( | ||||
|             validate_ownership_by_prefixes("_testdb", &prefixes), | ||||
|             OwnerValidationResult::MissingPrefix | ||||
|         ); | ||||
|  | ||||
|         assert_eq!( | ||||
|             validate_ownership_by_prefixes("user_testdb", &prefixes), | ||||
|             OwnerValidationResult::Match | ||||
|         ); | ||||
|         assert_eq!( | ||||
|             validate_ownership_by_prefixes("group_testdb", &prefixes), | ||||
|             OwnerValidationResult::Match | ||||
|         ); | ||||
|         assert_eq!( | ||||
|             validate_ownership_by_prefixes("group_test_db", &prefixes), | ||||
|             OwnerValidationResult::Match | ||||
|         ); | ||||
|         assert_eq!( | ||||
|             validate_ownership_by_prefixes("group_test-db", &prefixes), | ||||
|             OwnerValidationResult::Match | ||||
|         ); | ||||
|  | ||||
|         assert_eq!( | ||||
|             validate_ownership_by_prefixes("nonexistent_testdb", &prefixes), | ||||
|             OwnerValidationResult::NoMatch | ||||
|         ); | ||||
|     } | ||||
| } | ||||
|   | ||||
| @@ -1,120 +0,0 @@ | ||||
| use anyhow::Context; | ||||
| use indoc::formatdoc; | ||||
| use itertools::Itertools; | ||||
| use nix::unistd::User; | ||||
| use sqlx::{prelude::*, MySqlConnection}; | ||||
|  | ||||
| use crate::core::{ | ||||
|     common::{ | ||||
|         create_user_group_matching_regex, get_current_unix_user, quote_identifier, | ||||
|         validate_name_or_error, validate_ownership_or_error, DbOrUser, | ||||
|     }, | ||||
|     database_privilege_operations::DATABASE_PRIVILEGE_FIELDS, | ||||
| }; | ||||
|  | ||||
| pub async fn create_database(name: &str, connection: &mut MySqlConnection) -> anyhow::Result<()> { | ||||
|     let user = get_current_unix_user()?; | ||||
|     validate_database_name(name, &user)?; | ||||
|  | ||||
|     // NOTE: see the note about SQL injections in `validate_owner_of_database_name` | ||||
|     sqlx::query(&format!("CREATE DATABASE {}", quote_identifier(name))) | ||||
|         .execute(connection) | ||||
|         .await | ||||
|         .map_err(|e| { | ||||
|             if e.to_string().contains("database exists") { | ||||
|                 anyhow::anyhow!("Database '{}' already exists", name) | ||||
|             } else { | ||||
|                 e.into() | ||||
|             } | ||||
|         })?; | ||||
|  | ||||
|     Ok(()) | ||||
| } | ||||
|  | ||||
| pub async fn drop_database(name: &str, connection: &mut MySqlConnection) -> anyhow::Result<()> { | ||||
|     let user = get_current_unix_user()?; | ||||
|     validate_database_name(name, &user)?; | ||||
|  | ||||
|     // NOTE: see the note about SQL injections in `validate_owner_of_database_name` | ||||
|     sqlx::query(&format!("DROP DATABASE {}", quote_identifier(name))) | ||||
|         .execute(connection) | ||||
|         .await | ||||
|         .map_err(|e| { | ||||
|             if e.to_string().contains("doesn't exist") { | ||||
|                 anyhow::anyhow!("Database '{}' does not exist", name) | ||||
|             } else { | ||||
|                 e.into() | ||||
|             } | ||||
|         })?; | ||||
|  | ||||
|     Ok(()) | ||||
| } | ||||
|  | ||||
| pub async fn get_database_list(connection: &mut MySqlConnection) -> anyhow::Result<Vec<String>> { | ||||
|     let unix_user = get_current_unix_user()?; | ||||
|  | ||||
|     let databases: Vec<String> = sqlx::query( | ||||
|         r#" | ||||
|           SELECT `SCHEMA_NAME` AS `database` | ||||
|           FROM `information_schema`.`SCHEMATA` | ||||
|           WHERE `SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys') | ||||
|             AND `SCHEMA_NAME` REGEXP ? | ||||
|         "#, | ||||
|     ) | ||||
|     .bind(create_user_group_matching_regex(&unix_user)) | ||||
|     .fetch_all(connection) | ||||
|     .await | ||||
|     .and_then(|row| { | ||||
|         row.into_iter() | ||||
|             .map(|row| row.try_get::<String, _>("database")) | ||||
|             .collect::<Result<_, _>>() | ||||
|     }) | ||||
|     .context(format!( | ||||
|         "Failed to get databases for user '{}'", | ||||
|         unix_user.name | ||||
|     ))?; | ||||
|  | ||||
|     Ok(databases) | ||||
| } | ||||
|  | ||||
| pub async fn get_databases_where_user_has_privileges( | ||||
|     username: &str, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> anyhow::Result<Vec<String>> { | ||||
|     let result = sqlx::query( | ||||
|         formatdoc!( | ||||
|             r#" | ||||
|             SELECT `db` AS `database` | ||||
|             FROM `db` | ||||
|             WHERE `user` = ? | ||||
|               AND ({}) | ||||
|         "#, | ||||
|             DATABASE_PRIVILEGE_FIELDS | ||||
|                 .iter() | ||||
|                 .map(|field| format!("`{}` = 'Y'", field)) | ||||
|                 .join(" OR "), | ||||
|         ) | ||||
|         .as_str(), | ||||
|     ) | ||||
|     .bind(username) | ||||
|     .fetch_all(connection) | ||||
|     .await? | ||||
|     .into_iter() | ||||
|     .map(|databases| databases.try_get::<String, _>("database").unwrap()) | ||||
|     .collect(); | ||||
|  | ||||
|     Ok(result) | ||||
| } | ||||
|  | ||||
| /// NOTE: It is very critical that this function validates the database name | ||||
| ///       properly. MySQL does not seem to allow for prepared statements, binding | ||||
| ///       the database name as a parameter to the query. This means that we have | ||||
| ///       to validate the database name ourselves to prevent SQL injection. | ||||
| pub fn validate_database_name(name: &str, user: &User) -> anyhow::Result<()> { | ||||
|     validate_name_or_error(name, DbOrUser::Database) | ||||
|         .context(format!("Invalid database name: '{}'", name))?; | ||||
|     validate_ownership_or_error(name, user, DbOrUser::Database) | ||||
|         .context(format!("Invalid database name: '{}'", name))?; | ||||
|  | ||||
|     Ok(()) | ||||
| } | ||||
| @@ -1,52 +1,16 @@ | ||||
| //! Database privilege operations
 | ||||
| //!
 | ||||
| //! This module contains functions for querying, modifying,
 | ||||
| //! displaying and comparing database privileges.
 | ||||
| //!
 | ||||
| //! A lot of the complexity comes from two core components:
 | ||||
| //!
 | ||||
| //! - The privilege editor that needs to be able to print
 | ||||
| //!   an editable table of privileges and reparse the content
 | ||||
| //!   after the user has made manual changes.
 | ||||
| //!
 | ||||
| //! - The comparison functionality that tells the user what
 | ||||
| //!   changes will be made when applying a set of changes
 | ||||
| //!   to the list of database privileges.
 | ||||
| 
 | ||||
| use std::collections::{BTreeSet, HashMap}; | ||||
| 
 | ||||
| use anyhow::{anyhow, Context}; | ||||
| use indoc::indoc; | ||||
| use itertools::Itertools; | ||||
| use prettytable::Table; | ||||
| use serde::{Deserialize, Serialize}; | ||||
| use sqlx::{mysql::MySqlRow, prelude::*, MySqlConnection}; | ||||
| 
 | ||||
| use crate::core::{ | ||||
|     common::{ | ||||
|         create_user_group_matching_regex, get_current_unix_user, quote_identifier, rev_yn, yn, | ||||
|     }, | ||||
|     database_operations::validate_database_name, | ||||
| use std::{ | ||||
|     cmp::max, | ||||
|     collections::{BTreeSet, HashMap}, | ||||
| }; | ||||
| 
 | ||||
| /// This is the list of fields that are used to fetch the db + user + privileges
 | ||||
| /// from the `db` table in the database. If you need to add or remove privilege
 | ||||
| /// fields, this is a good place to start.
 | ||||
| pub const DATABASE_PRIVILEGE_FIELDS: [&str; 13] = [ | ||||
|     "db", | ||||
|     "user", | ||||
|     "select_priv", | ||||
|     "insert_priv", | ||||
|     "update_priv", | ||||
|     "delete_priv", | ||||
|     "create_priv", | ||||
|     "drop_priv", | ||||
|     "alter_priv", | ||||
|     "index_priv", | ||||
|     "create_tmp_table_priv", | ||||
|     "lock_tables_priv", | ||||
|     "references_priv", | ||||
| ]; | ||||
| use super::common::{rev_yn, yn}; | ||||
| use crate::server::sql::database_privilege_operations::{ | ||||
|     DatabasePrivilegeRow, DATABASE_PRIVILEGE_FIELDS, | ||||
| }; | ||||
| 
 | ||||
| pub fn db_priv_field_human_readable_name(name: &str) -> String { | ||||
|     match name { | ||||
| @@ -67,162 +31,24 @@ pub fn db_priv_field_human_readable_name(name: &str) -> String { | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| /// This struct represents the set of privileges for a single user on a single database.
 | ||||
| #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] | ||||
| pub struct DatabasePrivilegeRow { | ||||
|     pub db: String, | ||||
|     pub user: String, | ||||
|     pub select_priv: bool, | ||||
|     pub insert_priv: bool, | ||||
|     pub update_priv: bool, | ||||
|     pub delete_priv: bool, | ||||
|     pub create_priv: bool, | ||||
|     pub drop_priv: bool, | ||||
|     pub alter_priv: bool, | ||||
|     pub index_priv: bool, | ||||
|     pub create_tmp_table_priv: bool, | ||||
|     pub lock_tables_priv: bool, | ||||
|     pub references_priv: bool, | ||||
| } | ||||
| pub fn diff(row1: &DatabasePrivilegeRow, row2: &DatabasePrivilegeRow) -> DatabasePrivilegeRowDiff { | ||||
|     debug_assert!(row1.db == row2.db && row1.user == row2.user); | ||||
| 
 | ||||
| impl DatabasePrivilegeRow { | ||||
|     pub fn empty(db: &str, user: &str) -> Self { | ||||
|         Self { | ||||
|             db: db.to_owned(), | ||||
|             user: user.to_owned(), | ||||
|             select_priv: false, | ||||
|             insert_priv: false, | ||||
|             update_priv: false, | ||||
|             delete_priv: false, | ||||
|             create_priv: false, | ||||
|             drop_priv: false, | ||||
|             alter_priv: false, | ||||
|             index_priv: false, | ||||
|             create_tmp_table_priv: false, | ||||
|             lock_tables_priv: false, | ||||
|             references_priv: false, | ||||
|         } | ||||
|     DatabasePrivilegeRowDiff { | ||||
|         db: row1.db.clone(), | ||||
|         user: row1.user.clone(), | ||||
|         diff: DATABASE_PRIVILEGE_FIELDS | ||||
|             .into_iter() | ||||
|             .skip(2) | ||||
|             .filter_map(|field| { | ||||
|                 DatabasePrivilegeChange::new( | ||||
|                     row1.get_privilege_by_name(field), | ||||
|                     row2.get_privilege_by_name(field), | ||||
|                     field, | ||||
|                 ) | ||||
|             }) | ||||
|             .collect(), | ||||
|     } | ||||
| 
 | ||||
|     pub fn get_privilege_by_name(&self, name: &str) -> bool { | ||||
|         match name { | ||||
|             "select_priv" => self.select_priv, | ||||
|             "insert_priv" => self.insert_priv, | ||||
|             "update_priv" => self.update_priv, | ||||
|             "delete_priv" => self.delete_priv, | ||||
|             "create_priv" => self.create_priv, | ||||
|             "drop_priv" => self.drop_priv, | ||||
|             "alter_priv" => self.alter_priv, | ||||
|             "index_priv" => self.index_priv, | ||||
|             "create_tmp_table_priv" => self.create_tmp_table_priv, | ||||
|             "lock_tables_priv" => self.lock_tables_priv, | ||||
|             "references_priv" => self.references_priv, | ||||
|             _ => false, | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     pub fn diff(&self, other: &DatabasePrivilegeRow) -> DatabasePrivilegeRowDiff { | ||||
|         debug_assert!(self.db == other.db && self.user == other.user); | ||||
| 
 | ||||
|         DatabasePrivilegeRowDiff { | ||||
|             db: self.db.clone(), | ||||
|             user: self.user.clone(), | ||||
|             diff: DATABASE_PRIVILEGE_FIELDS | ||||
|                 .into_iter() | ||||
|                 .skip(2) | ||||
|                 .filter_map(|field| { | ||||
|                     DatabasePrivilegeChange::new( | ||||
|                         self.get_privilege_by_name(field), | ||||
|                         other.get_privilege_by_name(field), | ||||
|                         field, | ||||
|                     ) | ||||
|                 }) | ||||
|                 .collect(), | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| #[inline] | ||||
| fn get_mysql_row_priv_field(row: &MySqlRow, position: usize) -> Result<bool, sqlx::Error> { | ||||
|     let field = DATABASE_PRIVILEGE_FIELDS[position]; | ||||
|     let value = row.try_get(position)?; | ||||
|     match rev_yn(value) { | ||||
|         Some(val) => Ok(val), | ||||
|         _ => { | ||||
|             log::warn!(r#"Invalid value for privilege "{}": '{}'"#, field, value); | ||||
|             Ok(false) | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl FromRow<'_, MySqlRow> for DatabasePrivilegeRow { | ||||
|     fn from_row(row: &MySqlRow) -> Result<Self, sqlx::Error> { | ||||
|         Ok(Self { | ||||
|             db: row.try_get("db")?, | ||||
|             user: row.try_get("user")?, | ||||
|             select_priv: get_mysql_row_priv_field(row, 2)?, | ||||
|             insert_priv: get_mysql_row_priv_field(row, 3)?, | ||||
|             update_priv: get_mysql_row_priv_field(row, 4)?, | ||||
|             delete_priv: get_mysql_row_priv_field(row, 5)?, | ||||
|             create_priv: get_mysql_row_priv_field(row, 6)?, | ||||
|             drop_priv: get_mysql_row_priv_field(row, 7)?, | ||||
|             alter_priv: get_mysql_row_priv_field(row, 8)?, | ||||
|             index_priv: get_mysql_row_priv_field(row, 9)?, | ||||
|             create_tmp_table_priv: get_mysql_row_priv_field(row, 10)?, | ||||
|             lock_tables_priv: get_mysql_row_priv_field(row, 11)?, | ||||
|             references_priv: get_mysql_row_priv_field(row, 12)?, | ||||
|         }) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| /// Get all users + privileges for a single database.
 | ||||
| pub async fn get_database_privileges( | ||||
|     database_name: &str, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> anyhow::Result<Vec<DatabasePrivilegeRow>> { | ||||
|     let unix_user = get_current_unix_user()?; | ||||
|     validate_database_name(database_name, &unix_user)?; | ||||
| 
 | ||||
|     let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!( | ||||
|         "SELECT {} FROM `db` WHERE `db` = ?", | ||||
|         DATABASE_PRIVILEGE_FIELDS | ||||
|             .iter() | ||||
|             .map(|field| quote_identifier(field)) | ||||
|             .join(","), | ||||
|     )) | ||||
|     .bind(database_name) | ||||
|     .fetch_all(connection) | ||||
|     .await | ||||
|     .context("Failed to show database")?; | ||||
| 
 | ||||
|     Ok(result) | ||||
| } | ||||
| 
 | ||||
| /// Get all database + user + privileges pairs that are owned by the current user.
 | ||||
| pub async fn get_all_database_privileges( | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> anyhow::Result<Vec<DatabasePrivilegeRow>> { | ||||
|     let unix_user = get_current_unix_user()?; | ||||
| 
 | ||||
|     let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!( | ||||
|         indoc! {r#" | ||||
|           SELECT {} FROM `db` WHERE `db` IN | ||||
|           (SELECT DISTINCT `SCHEMA_NAME` AS `database` | ||||
|             FROM `information_schema`.`SCHEMATA` | ||||
|             WHERE `SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys') | ||||
|               AND `SCHEMA_NAME` REGEXP ?) | ||||
|         "#},
 | ||||
|         DATABASE_PRIVILEGE_FIELDS | ||||
|             .iter() | ||||
|             .map(|field| format!("`{field}`")) | ||||
|             .join(","), | ||||
|     )) | ||||
|     .bind(create_user_group_matching_regex(&unix_user)) | ||||
|     .fetch_all(connection) | ||||
|     .await | ||||
|     .context("Failed to show databases")?; | ||||
| 
 | ||||
|     Ok(result) | ||||
| } | ||||
| 
 | ||||
| /*************************/ | ||||
| @@ -340,17 +166,23 @@ pub fn generate_editor_content_from_privilege_data( | ||||
|     //       editor will be the example user and example db name.
 | ||||
|     //       Hence, it's put as the fallback value, despite not really
 | ||||
|     //       being a "fallback" in the normal sense.
 | ||||
|     let longest_username = privilege_data | ||||
|         .iter() | ||||
|         .map(|p| p.user.len()) | ||||
|         .max() | ||||
|         .unwrap_or(example_user.len()); | ||||
|     let longest_username = max( | ||||
|         privilege_data | ||||
|             .iter() | ||||
|             .map(|p| p.user.len()) | ||||
|             .max() | ||||
|             .unwrap_or(example_user.len()), | ||||
|         "User".len(), | ||||
|     ); | ||||
| 
 | ||||
|     let longest_database_name = privilege_data | ||||
|         .iter() | ||||
|         .map(|p| p.db.len()) | ||||
|         .max() | ||||
|         .unwrap_or(example_db.len()); | ||||
|     let longest_database_name = max( | ||||
|         privilege_data | ||||
|             .iter() | ||||
|             .map(|p| p.db.len()) | ||||
|             .max() | ||||
|             .unwrap_or(example_db.len()), | ||||
|         "Database".len(), | ||||
|     ); | ||||
| 
 | ||||
|     let mut header: Vec<_> = DATABASE_PRIVILEGE_FIELDS | ||||
|         .into_iter() | ||||
| @@ -578,7 +410,7 @@ pub fn parse_privilege_data_from_editor_content( | ||||
| /// instances of privilege sets for a single user on a single database.
 | ||||
| ///
 | ||||
| /// The `User` and `Database` are the same for both instances.
 | ||||
| #[derive(Debug, Clone, PartialEq, Eq, Serialize, PartialOrd, Ord)] | ||||
| #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] | ||||
| pub struct DatabasePrivilegeRowDiff { | ||||
|     pub db: String, | ||||
|     pub user: String, | ||||
| @@ -586,7 +418,7 @@ pub struct DatabasePrivilegeRowDiff { | ||||
| } | ||||
| 
 | ||||
| /// This enum represents a change for a single privilege.
 | ||||
| #[derive(Debug, Clone, PartialEq, Eq, Serialize, PartialOrd, Ord)] | ||||
| #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] | ||||
| pub enum DatabasePrivilegeChange { | ||||
|     YesToNo(String), | ||||
|     NoToYes(String), | ||||
| @@ -603,13 +435,31 @@ impl DatabasePrivilegeChange { | ||||
| } | ||||
| 
 | ||||
| /// This enum encapsulates whether a [`DatabasePrivilegeRow`] was intrduced, modified or deleted.
 | ||||
| #[derive(Debug, Clone, PartialEq, Eq, Serialize, PartialOrd, Ord)] | ||||
| #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] | ||||
| pub enum DatabasePrivilegesDiff { | ||||
|     New(DatabasePrivilegeRow), | ||||
|     Modified(DatabasePrivilegeRowDiff), | ||||
|     Deleted(DatabasePrivilegeRow), | ||||
| } | ||||
| 
 | ||||
| impl DatabasePrivilegesDiff { | ||||
|     pub fn get_database_name(&self) -> &str { | ||||
|         match self { | ||||
|             DatabasePrivilegesDiff::New(p) => &p.db, | ||||
|             DatabasePrivilegesDiff::Modified(p) => &p.db, | ||||
|             DatabasePrivilegesDiff::Deleted(p) => &p.db, | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     pub fn get_user_name(&self) -> &str { | ||||
|         match self { | ||||
|             DatabasePrivilegesDiff::New(p) => &p.user, | ||||
|             DatabasePrivilegesDiff::Modified(p) => &p.user, | ||||
|             DatabasePrivilegesDiff::Deleted(p) => &p.user, | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| /// This function calculates the differences between two sets of database privileges.
 | ||||
| /// It returns a set of [`DatabasePrivilegesDiff`] that can be used to display or
 | ||||
| /// apply a set of privilege modifications to the database.
 | ||||
| @@ -633,7 +483,7 @@ pub fn diff_privileges( | ||||
| 
 | ||||
|     for p in to { | ||||
|         if let Some(old_p) = from_lookup_table.get(&(p.db.clone(), p.user.clone())) { | ||||
|             let diff = old_p.diff(p); | ||||
|             let diff = diff(old_p, p); | ||||
|             if !diff.diff.is_empty() { | ||||
|                 result.insert(DatabasePrivilegesDiff::Modified(diff)); | ||||
|             } | ||||
| @@ -651,72 +501,6 @@ pub fn diff_privileges( | ||||
|     result | ||||
| } | ||||
| 
 | ||||
| /// Uses the result of [`diff_privileges`] to modify privileges in the database.
 | ||||
| pub async fn apply_privilege_diffs( | ||||
|     diffs: BTreeSet<DatabasePrivilegesDiff>, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> anyhow::Result<()> { | ||||
|     for diff in diffs { | ||||
|         match diff { | ||||
|             DatabasePrivilegesDiff::New(p) => { | ||||
|                 let tables = DATABASE_PRIVILEGE_FIELDS | ||||
|                     .iter() | ||||
|                     .map(|field| format!("`{field}`")) | ||||
|                     .join(","); | ||||
| 
 | ||||
|                 let question_marks = std::iter::repeat("?") | ||||
|                     .take(DATABASE_PRIVILEGE_FIELDS.len()) | ||||
|                     .join(","); | ||||
| 
 | ||||
|                 sqlx::query( | ||||
|                     format!("INSERT INTO `db` ({}) VALUES ({})", tables, question_marks).as_str(), | ||||
|                 ) | ||||
|                 .bind(p.db) | ||||
|                 .bind(p.user) | ||||
|                 .bind(yn(p.select_priv)) | ||||
|                 .bind(yn(p.insert_priv)) | ||||
|                 .bind(yn(p.update_priv)) | ||||
|                 .bind(yn(p.delete_priv)) | ||||
|                 .bind(yn(p.create_priv)) | ||||
|                 .bind(yn(p.drop_priv)) | ||||
|                 .bind(yn(p.alter_priv)) | ||||
|                 .bind(yn(p.index_priv)) | ||||
|                 .bind(yn(p.create_tmp_table_priv)) | ||||
|                 .bind(yn(p.lock_tables_priv)) | ||||
|                 .bind(yn(p.references_priv)) | ||||
|                 .execute(&mut *connection) | ||||
|                 .await?; | ||||
|             } | ||||
|             DatabasePrivilegesDiff::Modified(p) => { | ||||
|                 let tables = p | ||||
|                     .diff | ||||
|                     .iter() | ||||
|                     .map(|diff| match diff { | ||||
|                         DatabasePrivilegeChange::YesToNo(name) => format!("`{}` = 'N'", name), | ||||
|                         DatabasePrivilegeChange::NoToYes(name) => format!("`{}` = 'Y'", name), | ||||
|                     }) | ||||
|                     .join(","); | ||||
| 
 | ||||
|                 sqlx::query( | ||||
|                     format!("UPDATE `db` SET {} WHERE `db` = ? AND `user` = ?", tables).as_str(), | ||||
|                 ) | ||||
|                 .bind(p.db) | ||||
|                 .bind(p.user) | ||||
|                 .execute(&mut *connection) | ||||
|                 .await?; | ||||
|             } | ||||
|             DatabasePrivilegesDiff::Deleted(p) => { | ||||
|                 sqlx::query("DELETE FROM `db` WHERE `db` = ? AND `user` = ?") | ||||
|                     .bind(p.db) | ||||
|                     .bind(p.user) | ||||
|                     .execute(&mut *connection) | ||||
|                     .await?; | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|     Ok(()) | ||||
| } | ||||
| 
 | ||||
| fn display_privilege_cell(diff: &DatabasePrivilegeRowDiff) -> String { | ||||
|     diff.diff | ||||
|         .iter() | ||||
| @@ -731,6 +515,20 @@ fn display_privilege_cell(diff: &DatabasePrivilegeRowDiff) -> String { | ||||
|         .join("\n") | ||||
| } | ||||
| 
 | ||||
| fn display_new_privileges_list(row: &DatabasePrivilegeRow) -> String { | ||||
|     DATABASE_PRIVILEGE_FIELDS | ||||
|         .into_iter() | ||||
|         .skip(2) | ||||
|         .map(|field| { | ||||
|             if row.get_privilege_by_name(field) { | ||||
|                 format!("{}: Y", db_priv_field_human_readable_name(field)) | ||||
|             } else { | ||||
|                 format!("{}: N", db_priv_field_human_readable_name(field)) | ||||
|             } | ||||
|         }) | ||||
|         .join("\n") | ||||
| } | ||||
| 
 | ||||
| /// Displays the difference between two sets of database privileges.
 | ||||
| pub fn display_privilege_diffs(diffs: &BTreeSet<DatabasePrivilegesDiff>) -> String { | ||||
|     let mut table = Table::new(); | ||||
| @@ -741,24 +539,14 @@ pub fn display_privilege_diffs(diffs: &BTreeSet<DatabasePrivilegesDiff>) -> Stri | ||||
|                 table.add_row(row![ | ||||
|                     p.db, | ||||
|                     p.user, | ||||
|                     "(New user)\n".to_string() | ||||
|                         + &display_privilege_cell( | ||||
|                             &DatabasePrivilegeRow::empty(&p.db, &p.user).diff(p) | ||||
|                         ) | ||||
|                     "(New user)\n".to_string() + &display_new_privileges_list(p) | ||||
|                 ]); | ||||
|             } | ||||
|             DatabasePrivilegesDiff::Modified(p) => { | ||||
|                 table.add_row(row![p.db, p.user, display_privilege_cell(p),]); | ||||
|             } | ||||
|             DatabasePrivilegesDiff::Deleted(p) => { | ||||
|                 table.add_row(row![ | ||||
|                     p.db, | ||||
|                     p.user, | ||||
|                     "(All privileges removed)\n".to_string() | ||||
|                         + &display_privilege_cell( | ||||
|                             &p.diff(&DatabasePrivilegeRow::empty(&p.db, &p.user)) | ||||
|                         ) | ||||
|                 ]); | ||||
|                 table.add_row(row![p.db, p.user, "Removed".to_string()]); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
							
								
								
									
										5
									
								
								src/core/protocol.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								src/core/protocol.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | ||||
| pub mod request_response; | ||||
| pub mod server_responses; | ||||
|  | ||||
| pub use request_response::*; | ||||
| pub use server_responses::*; | ||||
							
								
								
									
										79
									
								
								src/core/protocol/request_response.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										79
									
								
								src/core/protocol/request_response.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,79 @@ | ||||
| use std::collections::BTreeSet; | ||||
|  | ||||
| use serde::{Deserialize, Serialize}; | ||||
| use tokio::net::UnixStream; | ||||
| use tokio_serde::{formats::Bincode, Framed as SerdeFramed}; | ||||
| use tokio_util::codec::{Framed, LengthDelimitedCodec}; | ||||
|  | ||||
| use crate::core::{database_privileges::DatabasePrivilegesDiff, protocol::*}; | ||||
|  | ||||
| pub type ServerToClientMessageStream = SerdeFramed< | ||||
|     Framed<UnixStream, LengthDelimitedCodec>, | ||||
|     Request, | ||||
|     Response, | ||||
|     Bincode<Request, Response>, | ||||
| >; | ||||
|  | ||||
| pub type ClientToServerMessageStream = SerdeFramed< | ||||
|     Framed<UnixStream, LengthDelimitedCodec>, | ||||
|     Response, | ||||
|     Request, | ||||
|     Bincode<Response, Request>, | ||||
| >; | ||||
|  | ||||
| pub fn create_server_to_client_message_stream(socket: UnixStream) -> ServerToClientMessageStream { | ||||
|     let length_delimited = Framed::new(socket, LengthDelimitedCodec::new()); | ||||
|     tokio_serde::Framed::new(length_delimited, Bincode::default()) | ||||
| } | ||||
|  | ||||
| pub fn create_client_to_server_message_stream(socket: UnixStream) -> ClientToServerMessageStream { | ||||
|     let length_delimited = Framed::new(socket, LengthDelimitedCodec::new()); | ||||
|     tokio_serde::Framed::new(length_delimited, Bincode::default()) | ||||
| } | ||||
|  | ||||
| #[non_exhaustive] | ||||
| #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] | ||||
| pub enum Request { | ||||
|     CreateDatabases(Vec<String>), | ||||
|     DropDatabases(Vec<String>), | ||||
|     ListDatabases, | ||||
|     ListPrivileges(Option<Vec<String>>), | ||||
|     ModifyPrivileges(BTreeSet<DatabasePrivilegesDiff>), | ||||
|  | ||||
|     CreateUsers(Vec<String>), | ||||
|     DropUsers(Vec<String>), | ||||
|     PasswdUser(String, String), | ||||
|     ListUsers(Option<Vec<String>>), | ||||
|     LockUsers(Vec<String>), | ||||
|     UnlockUsers(Vec<String>), | ||||
|  | ||||
|     // Commit, | ||||
|     Exit, | ||||
| } | ||||
|  | ||||
| // TODO: include a generic "message" that will display a message to the user? | ||||
|  | ||||
| #[non_exhaustive] | ||||
| #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] | ||||
| pub enum Response { | ||||
|     // Specific data for specific commands | ||||
|     CreateDatabases(CreateDatabasesOutput), | ||||
|     DropDatabases(DropDatabasesOutput), | ||||
|     ListAllDatabases(ListAllDatabasesOutput), | ||||
|     ListPrivileges(GetDatabasesPrivilegeData), | ||||
|     ListAllPrivileges(GetAllDatabasesPrivilegeData), | ||||
|     ModifyPrivileges(ModifyDatabasePrivilegesOutput), | ||||
|  | ||||
|     CreateUsers(CreateUsersOutput), | ||||
|     DropUsers(DropUsersOutput), | ||||
|     PasswdUser(SetPasswordOutput), | ||||
|     ListUsers(ListUsersOutput), | ||||
|     ListAllUsers(ListAllUsersOutput), | ||||
|     LockUsers(LockUsersOutput), | ||||
|     UnlockUsers(UnlockUsersOutput), | ||||
|  | ||||
|     // Generic responses | ||||
|     OperationAborted, | ||||
|     Error(String), | ||||
|     Exit, | ||||
| } | ||||
							
								
								
									
										611
									
								
								src/core/protocol/server_responses.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										611
									
								
								src/core/protocol/server_responses.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,611 @@ | ||||
| use std::collections::BTreeMap; | ||||
|  | ||||
| use indoc::indoc; | ||||
| use itertools::Itertools; | ||||
| use serde::{Deserialize, Serialize}; | ||||
|  | ||||
| use crate::{ | ||||
|     core::{common::UnixUser, database_privileges::DatabasePrivilegeRowDiff}, | ||||
|     server::sql::{ | ||||
|         database_privilege_operations::DatabasePrivilegeRow, user_operations::DatabaseUser, | ||||
|     }, | ||||
| }; | ||||
|  | ||||
| /// This enum is used to differentiate between database and user operations. | ||||
| /// Their output are very similar, but there are slight differences in the words used. | ||||
| #[derive(Debug, PartialEq, Eq, Clone, Copy)] | ||||
| pub enum DbOrUser { | ||||
|     Database, | ||||
|     User, | ||||
| } | ||||
|  | ||||
| impl DbOrUser { | ||||
|     pub fn lowercased(&self) -> String { | ||||
|         match self { | ||||
|             DbOrUser::Database => "database".to_string(), | ||||
|             DbOrUser::User => "user".to_string(), | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     pub fn capitalized(&self) -> String { | ||||
|         match self { | ||||
|             DbOrUser::Database => "Database".to_string(), | ||||
|             DbOrUser::User => "User".to_string(), | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)] | ||||
| pub enum NameValidationError { | ||||
|     EmptyString, | ||||
|     InvalidCharacters, | ||||
|     TooLong, | ||||
| } | ||||
|  | ||||
| impl NameValidationError { | ||||
|     pub fn to_error_message(self, name: &str, db_or_user: DbOrUser) -> String { | ||||
|         match self { | ||||
|             NameValidationError::EmptyString => { | ||||
|                 format!("{} name cannot be empty.", db_or_user.capitalized()).to_owned() | ||||
|             } | ||||
|             NameValidationError::TooLong => format!( | ||||
|                 "{} is too long. Maximum length is 64 characters.", | ||||
|                 db_or_user.capitalized() | ||||
|             ) | ||||
|             .to_owned(), | ||||
|             NameValidationError::InvalidCharacters => format!( | ||||
|                 indoc! {r#" | ||||
|                   Invalid characters in {} name: '{}' | ||||
|  | ||||
|                   Only A-Z, a-z, 0-9, _ (underscore) and - (dash) are permitted. | ||||
|                 "#}, | ||||
|                 db_or_user.lowercased(), | ||||
|                 name | ||||
|             ) | ||||
|             .to_owned(), | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl OwnerValidationError { | ||||
|     pub fn to_error_message(self, name: &str, db_or_user: DbOrUser) -> String { | ||||
|         let user = UnixUser::from_enviroment(); | ||||
|  | ||||
|         match self { | ||||
|             OwnerValidationError::NoMatch => format!( | ||||
|                 indoc! {r#" | ||||
|                   Invalid {} name prefix: '{}' does not match your username or any of your groups. | ||||
|                   Are you sure you are allowed to create {} names with this prefix? | ||||
|  | ||||
|                   Allowed prefixes: | ||||
|                     - {} | ||||
|                   {} | ||||
|                 "#}, | ||||
|                 db_or_user.lowercased(), | ||||
|                 name, | ||||
|                 db_or_user.lowercased(), | ||||
|                 user.as_ref() | ||||
|                     .map(|u| u.username.clone()) | ||||
|                     .unwrap_or("???".to_string()), | ||||
|                 user.map(|u| u.groups) | ||||
|                     .unwrap_or_default() | ||||
|                     .iter() | ||||
|                     .map(|g| format!("  - {}", g)) | ||||
|                     .sorted() | ||||
|                     .join("\n"), | ||||
|             ) | ||||
|             .to_owned(), | ||||
|  | ||||
|             _ => format!( | ||||
|                 "'{}' is not a valid {} name.", | ||||
|                 name, | ||||
|                 db_or_user.lowercased() | ||||
|             ) | ||||
|             .to_string(), | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)] | ||||
| pub enum OwnerValidationError { | ||||
|     // The name is valid, but none of the given prefixes matched the name | ||||
|     NoMatch, | ||||
|  | ||||
|     // The name is empty, which is invalid | ||||
|     StringEmpty, | ||||
|  | ||||
|     // The name is in the format "_<postfix>", which is invalid | ||||
|     MissingPrefix, | ||||
|  | ||||
|     // The name is in the format "<prefix>_", which is invalid | ||||
|     MissingPostfix, | ||||
| } | ||||
|  | ||||
| pub type CreateDatabasesOutput = BTreeMap<String, Result<(), CreateDatabaseError>>; | ||||
| #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] | ||||
| pub enum CreateDatabaseError { | ||||
|     SanitizationError(NameValidationError), | ||||
|     OwnershipError(OwnerValidationError), | ||||
|     DatabaseAlreadyExists, | ||||
|     MySqlError(String), | ||||
| } | ||||
|  | ||||
| pub fn print_create_databases_output_status(output: &CreateDatabasesOutput) { | ||||
|     for (database_name, result) in output { | ||||
|         match result { | ||||
|             Ok(()) => { | ||||
|                 println!("Database '{}' created successfully.", database_name); | ||||
|             } | ||||
|             Err(err) => { | ||||
|                 println!("{}", err.to_error_message(database_name)); | ||||
|                 println!("Skipping..."); | ||||
|             } | ||||
|         } | ||||
|         println!(); | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl CreateDatabaseError { | ||||
|     pub fn to_error_message(&self, database_name: &str) -> String { | ||||
|         match self { | ||||
|             CreateDatabaseError::SanitizationError(err) => { | ||||
|                 err.to_error_message(database_name, DbOrUser::Database) | ||||
|             } | ||||
|             CreateDatabaseError::OwnershipError(err) => { | ||||
|                 err.to_error_message(database_name, DbOrUser::Database) | ||||
|             } | ||||
|             CreateDatabaseError::DatabaseAlreadyExists => { | ||||
|                 format!("Database {} already exists.", database_name) | ||||
|             } | ||||
|             CreateDatabaseError::MySqlError(err) => { | ||||
|                 format!("MySQL error: {}", err) | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| pub type DropDatabasesOutput = BTreeMap<String, Result<(), DropDatabaseError>>; | ||||
| #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] | ||||
| pub enum DropDatabaseError { | ||||
|     SanitizationError(NameValidationError), | ||||
|     OwnershipError(OwnerValidationError), | ||||
|     DatabaseDoesNotExist, | ||||
|     MySqlError(String), | ||||
| } | ||||
|  | ||||
| pub fn print_drop_databases_output_status(output: &DropDatabasesOutput) { | ||||
|     for (database_name, result) in output { | ||||
|         match result { | ||||
|             Ok(()) => { | ||||
|                 println!("Database '{}' dropped successfully.", database_name); | ||||
|             } | ||||
|             Err(err) => { | ||||
|                 println!("{}", err.to_error_message(database_name)); | ||||
|                 println!("Skipping..."); | ||||
|             } | ||||
|         } | ||||
|         println!(); | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl DropDatabaseError { | ||||
|     pub fn to_error_message(&self, database_name: &str) -> String { | ||||
|         match self { | ||||
|             DropDatabaseError::SanitizationError(err) => { | ||||
|                 err.to_error_message(database_name, DbOrUser::Database) | ||||
|             } | ||||
|             DropDatabaseError::OwnershipError(err) => { | ||||
|                 err.to_error_message(database_name, DbOrUser::Database) | ||||
|             } | ||||
|             DropDatabaseError::DatabaseDoesNotExist => { | ||||
|                 format!("Database {} does not exist.", database_name) | ||||
|             } | ||||
|             DropDatabaseError::MySqlError(err) => { | ||||
|                 format!("MySQL error: {}", err) | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| pub type ListAllDatabasesOutput = Result<Vec<String>, ListDatabasesError>; | ||||
| #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] | ||||
| pub enum ListDatabasesError { | ||||
|     MySqlError(String), | ||||
| } | ||||
|  | ||||
| impl ListDatabasesError { | ||||
|     pub fn to_error_message(&self) -> String { | ||||
|         match self { | ||||
|             ListDatabasesError::MySqlError(err) => format!("MySQL error: {}", err), | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| // TODO: merge all rows into a single collection. | ||||
| //       they already contain which database they belong to. | ||||
| //       no need to index by database name. | ||||
|  | ||||
| pub type GetDatabasesPrivilegeData = | ||||
|     BTreeMap<String, Result<Vec<DatabasePrivilegeRow>, GetDatabasesPrivilegeDataError>>; | ||||
| #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] | ||||
| pub enum GetDatabasesPrivilegeDataError { | ||||
|     SanitizationError(NameValidationError), | ||||
|     OwnershipError(OwnerValidationError), | ||||
|     DatabaseDoesNotExist, | ||||
|     MySqlError(String), | ||||
| } | ||||
|  | ||||
| impl GetDatabasesPrivilegeDataError { | ||||
|     pub fn to_error_message(&self, database_name: &str) -> String { | ||||
|         match self { | ||||
|             GetDatabasesPrivilegeDataError::SanitizationError(err) => { | ||||
|                 err.to_error_message(database_name, DbOrUser::Database) | ||||
|             } | ||||
|             GetDatabasesPrivilegeDataError::OwnershipError(err) => { | ||||
|                 err.to_error_message(database_name, DbOrUser::Database) | ||||
|             } | ||||
|             GetDatabasesPrivilegeDataError::DatabaseDoesNotExist => { | ||||
|                 format!("Database '{}' does not exist.", database_name) | ||||
|             } | ||||
|             GetDatabasesPrivilegeDataError::MySqlError(err) => { | ||||
|                 format!("MySQL error: {}", err) | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| pub type GetAllDatabasesPrivilegeData = | ||||
|     Result<Vec<DatabasePrivilegeRow>, GetAllDatabasesPrivilegeDataError>; | ||||
| #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] | ||||
| pub enum GetAllDatabasesPrivilegeDataError { | ||||
|     MySqlError(String), | ||||
| } | ||||
|  | ||||
| impl GetAllDatabasesPrivilegeDataError { | ||||
|     pub fn to_error_message(&self) -> String { | ||||
|         match self { | ||||
|             GetAllDatabasesPrivilegeDataError::MySqlError(err) => format!("MySQL error: {}", err), | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| pub type ModifyDatabasePrivilegesOutput = | ||||
|     BTreeMap<(String, String), Result<(), ModifyDatabasePrivilegesError>>; | ||||
| #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] | ||||
| pub enum ModifyDatabasePrivilegesError { | ||||
|     DatabaseSanitizationError(NameValidationError), | ||||
|     DatabaseOwnershipError(OwnerValidationError), | ||||
|     UserSanitizationError(NameValidationError), | ||||
|     UserOwnershipError(OwnerValidationError), | ||||
|     DatabaseDoesNotExist, | ||||
|     DiffDoesNotApply(DiffDoesNotApplyError), | ||||
|     MySqlError(String), | ||||
| } | ||||
|  | ||||
| #[allow(clippy::enum_variant_names)] | ||||
| #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] | ||||
| pub enum DiffDoesNotApplyError { | ||||
|     RowAlreadyExists(String, String), | ||||
|     RowDoesNotExist(String, String), | ||||
|     RowPrivilegeChangeDoesNotApply(DatabasePrivilegeRowDiff, DatabasePrivilegeRow), | ||||
| } | ||||
|  | ||||
| pub fn print_modify_database_privileges_output_status(output: &ModifyDatabasePrivilegesOutput) { | ||||
|     for ((database_name, username), result) in output { | ||||
|         match result { | ||||
|             Ok(()) => { | ||||
|                 println!( | ||||
|                     "Privileges for user '{}' on database '{}' modified successfully.", | ||||
|                     username, database_name | ||||
|                 ); | ||||
|             } | ||||
|             Err(err) => { | ||||
|                 println!("{}", err.to_error_message(database_name, username)); | ||||
|                 println!("Skipping..."); | ||||
|             } | ||||
|         } | ||||
|         println!(); | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl ModifyDatabasePrivilegesError { | ||||
|     pub fn to_error_message(&self, database_name: &str, username: &str) -> String { | ||||
|         match self { | ||||
|             ModifyDatabasePrivilegesError::DatabaseSanitizationError(err) => { | ||||
|                 err.to_error_message(database_name, DbOrUser::Database) | ||||
|             } | ||||
|             ModifyDatabasePrivilegesError::DatabaseOwnershipError(err) => { | ||||
|                 err.to_error_message(database_name, DbOrUser::Database) | ||||
|             } | ||||
|             ModifyDatabasePrivilegesError::UserSanitizationError(err) => { | ||||
|                 err.to_error_message(username, DbOrUser::User) | ||||
|             } | ||||
|             ModifyDatabasePrivilegesError::UserOwnershipError(err) => { | ||||
|                 err.to_error_message(username, DbOrUser::User) | ||||
|             } | ||||
|             ModifyDatabasePrivilegesError::DatabaseDoesNotExist => { | ||||
|                 format!("Database '{}' does not exist.", database_name) | ||||
|             } | ||||
|             ModifyDatabasePrivilegesError::DiffDoesNotApply(diff) => { | ||||
|                 format!( | ||||
|                     "Could not apply privilege change:\n{}", | ||||
|                     diff.to_error_message() | ||||
|                 ) | ||||
|             } | ||||
|             ModifyDatabasePrivilegesError::MySqlError(err) => { | ||||
|                 format!("MySQL error: {}", err) | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl DiffDoesNotApplyError { | ||||
|     pub fn to_error_message(&self) -> String { | ||||
|         match self { | ||||
|             DiffDoesNotApplyError::RowAlreadyExists(database_name, username) => { | ||||
|                 format!( | ||||
|                     "Privileges for user '{}' on database '{}' already exist.", | ||||
|                     username, database_name | ||||
|                 ) | ||||
|             } | ||||
|             DiffDoesNotApplyError::RowDoesNotExist(database_name, username) => { | ||||
|                 format!( | ||||
|                     "Privileges for user '{}' on database '{}' do not exist.", | ||||
|                     username, database_name | ||||
|                 ) | ||||
|             } | ||||
|             DiffDoesNotApplyError::RowPrivilegeChangeDoesNotApply(diff, row) => { | ||||
|                 format!( | ||||
|                     "Could not apply privilege change {:?} to row {:?}", | ||||
|                     diff, row | ||||
|                 ) | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| pub type CreateUsersOutput = BTreeMap<String, Result<(), CreateUserError>>; | ||||
| #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] | ||||
| pub enum CreateUserError { | ||||
|     SanitizationError(NameValidationError), | ||||
|     OwnershipError(OwnerValidationError), | ||||
|     UserAlreadyExists, | ||||
|     MySqlError(String), | ||||
| } | ||||
|  | ||||
| pub fn print_create_users_output_status(output: &CreateUsersOutput) { | ||||
|     for (username, result) in output { | ||||
|         match result { | ||||
|             Ok(()) => { | ||||
|                 println!("User '{}' created successfully.", username); | ||||
|             } | ||||
|             Err(err) => { | ||||
|                 println!("{}", err.to_error_message(username)); | ||||
|                 println!("Skipping..."); | ||||
|             } | ||||
|         } | ||||
|         println!(); | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl CreateUserError { | ||||
|     pub fn to_error_message(&self, username: &str) -> String { | ||||
|         match self { | ||||
|             CreateUserError::SanitizationError(err) => { | ||||
|                 err.to_error_message(username, DbOrUser::User) | ||||
|             } | ||||
|             CreateUserError::OwnershipError(err) => err.to_error_message(username, DbOrUser::User), | ||||
|             CreateUserError::UserAlreadyExists => { | ||||
|                 format!("User '{}' already exists.", username) | ||||
|             } | ||||
|             CreateUserError::MySqlError(err) => { | ||||
|                 format!("MySQL error: {}", err) | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| pub type DropUsersOutput = BTreeMap<String, Result<(), DropUserError>>; | ||||
| #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] | ||||
| pub enum DropUserError { | ||||
|     SanitizationError(NameValidationError), | ||||
|     OwnershipError(OwnerValidationError), | ||||
|     UserDoesNotExist, | ||||
|     MySqlError(String), | ||||
| } | ||||
|  | ||||
| pub fn print_drop_users_output_status(output: &DropUsersOutput) { | ||||
|     for (username, result) in output { | ||||
|         match result { | ||||
|             Ok(()) => { | ||||
|                 println!("User '{}' dropped successfully.", username); | ||||
|             } | ||||
|             Err(err) => { | ||||
|                 println!("{}", err.to_error_message(username)); | ||||
|                 println!("Skipping..."); | ||||
|             } | ||||
|         } | ||||
|         println!(); | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl DropUserError { | ||||
|     pub fn to_error_message(&self, username: &str) -> String { | ||||
|         match self { | ||||
|             DropUserError::SanitizationError(err) => err.to_error_message(username, DbOrUser::User), | ||||
|             DropUserError::OwnershipError(err) => err.to_error_message(username, DbOrUser::User), | ||||
|             DropUserError::UserDoesNotExist => { | ||||
|                 format!("User '{}' does not exist.", username) | ||||
|             } | ||||
|             DropUserError::MySqlError(err) => { | ||||
|                 format!("MySQL error: {}", err) | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| pub type SetPasswordOutput = Result<(), SetPasswordError>; | ||||
| #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] | ||||
| pub enum SetPasswordError { | ||||
|     SanitizationError(NameValidationError), | ||||
|     OwnershipError(OwnerValidationError), | ||||
|     UserDoesNotExist, | ||||
|     MySqlError(String), | ||||
| } | ||||
|  | ||||
| pub fn print_set_password_output_status(output: &SetPasswordOutput, username: &str) { | ||||
|     match output { | ||||
|         Ok(()) => { | ||||
|             println!("Password for user '{}' set successfully.", username); | ||||
|         } | ||||
|         Err(err) => { | ||||
|             println!("{}", err.to_error_message(username)); | ||||
|             println!("Skipping..."); | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl SetPasswordError { | ||||
|     pub fn to_error_message(&self, username: &str) -> String { | ||||
|         match self { | ||||
|             SetPasswordError::SanitizationError(err) => { | ||||
|                 err.to_error_message(username, DbOrUser::User) | ||||
|             } | ||||
|             SetPasswordError::OwnershipError(err) => err.to_error_message(username, DbOrUser::User), | ||||
|             SetPasswordError::UserDoesNotExist => { | ||||
|                 format!("User '{}' does not exist.", username) | ||||
|             } | ||||
|             SetPasswordError::MySqlError(err) => { | ||||
|                 format!("MySQL error: {}", err) | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| pub type LockUsersOutput = BTreeMap<String, Result<(), LockUserError>>; | ||||
| #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] | ||||
| pub enum LockUserError { | ||||
|     SanitizationError(NameValidationError), | ||||
|     OwnershipError(OwnerValidationError), | ||||
|     UserDoesNotExist, | ||||
|     UserIsAlreadyLocked, | ||||
|     MySqlError(String), | ||||
| } | ||||
|  | ||||
| pub fn print_lock_users_output_status(output: &LockUsersOutput) { | ||||
|     for (username, result) in output { | ||||
|         match result { | ||||
|             Ok(()) => { | ||||
|                 println!("User '{}' locked successfully.", username); | ||||
|             } | ||||
|             Err(err) => { | ||||
|                 println!("{}", err.to_error_message(username)); | ||||
|                 println!("Skipping..."); | ||||
|             } | ||||
|         } | ||||
|         println!(); | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl LockUserError { | ||||
|     pub fn to_error_message(&self, username: &str) -> String { | ||||
|         match self { | ||||
|             LockUserError::SanitizationError(err) => err.to_error_message(username, DbOrUser::User), | ||||
|             LockUserError::OwnershipError(err) => err.to_error_message(username, DbOrUser::User), | ||||
|             LockUserError::UserDoesNotExist => { | ||||
|                 format!("User '{}' does not exist.", username) | ||||
|             } | ||||
|             LockUserError::UserIsAlreadyLocked => { | ||||
|                 format!("User '{}' is already locked.", username) | ||||
|             } | ||||
|             LockUserError::MySqlError(err) => { | ||||
|                 format!("MySQL error: {}", err) | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| pub type UnlockUsersOutput = BTreeMap<String, Result<(), UnlockUserError>>; | ||||
| #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] | ||||
| pub enum UnlockUserError { | ||||
|     SanitizationError(NameValidationError), | ||||
|     OwnershipError(OwnerValidationError), | ||||
|     UserDoesNotExist, | ||||
|     UserIsAlreadyUnlocked, | ||||
|     MySqlError(String), | ||||
| } | ||||
|  | ||||
| pub fn print_unlock_users_output_status(output: &UnlockUsersOutput) { | ||||
|     for (username, result) in output { | ||||
|         match result { | ||||
|             Ok(()) => { | ||||
|                 println!("User '{}' unlocked successfully.", username); | ||||
|             } | ||||
|             Err(err) => { | ||||
|                 println!("{}", err.to_error_message(username)); | ||||
|                 println!("Skipping..."); | ||||
|             } | ||||
|         } | ||||
|         println!(); | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl UnlockUserError { | ||||
|     pub fn to_error_message(&self, username: &str) -> String { | ||||
|         match self { | ||||
|             UnlockUserError::SanitizationError(err) => { | ||||
|                 err.to_error_message(username, DbOrUser::User) | ||||
|             } | ||||
|             UnlockUserError::OwnershipError(err) => err.to_error_message(username, DbOrUser::User), | ||||
|             UnlockUserError::UserDoesNotExist => { | ||||
|                 format!("User '{}' does not exist.", username) | ||||
|             } | ||||
|             UnlockUserError::UserIsAlreadyUnlocked => { | ||||
|                 format!("User '{}' is already unlocked.", username) | ||||
|             } | ||||
|             UnlockUserError::MySqlError(err) => { | ||||
|                 format!("MySQL error: {}", err) | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| pub type ListUsersOutput = BTreeMap<String, Result<DatabaseUser, ListUsersError>>; | ||||
| #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] | ||||
| pub enum ListUsersError { | ||||
|     SanitizationError(NameValidationError), | ||||
|     OwnershipError(OwnerValidationError), | ||||
|     UserDoesNotExist, | ||||
|     MySqlError(String), | ||||
| } | ||||
|  | ||||
| impl ListUsersError { | ||||
|     pub fn to_error_message(&self, username: &str) -> String { | ||||
|         match self { | ||||
|             ListUsersError::SanitizationError(err) => { | ||||
|                 err.to_error_message(username, DbOrUser::User) | ||||
|             } | ||||
|             ListUsersError::OwnershipError(err) => err.to_error_message(username, DbOrUser::User), | ||||
|             ListUsersError::UserDoesNotExist => { | ||||
|                 format!("User '{}' does not exist.", username) | ||||
|             } | ||||
|             ListUsersError::MySqlError(err) => { | ||||
|                 format!("MySQL error: {}", err) | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| pub type ListAllUsersOutput = Result<Vec<DatabaseUser>, ListAllUsersError>; | ||||
| #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] | ||||
| pub enum ListAllUsersError { | ||||
|     MySqlError(String), | ||||
| } | ||||
|  | ||||
| impl ListAllUsersError { | ||||
|     pub fn to_error_message(&self) -> String { | ||||
|         match self { | ||||
|             ListAllUsersError::MySqlError(err) => format!("MySQL error: {}", err), | ||||
|         } | ||||
|     } | ||||
| } | ||||
| @@ -1,249 +0,0 @@ | ||||
| use anyhow::Context; | ||||
| use nix::unistd::User; | ||||
| use serde::{Deserialize, Serialize}; | ||||
| use sqlx::{prelude::*, MySqlConnection}; | ||||
|  | ||||
| use crate::core::common::{ | ||||
|     create_user_group_matching_regex, get_current_unix_user, quote_literal, validate_name_or_error, | ||||
|     validate_ownership_or_error, DbOrUser, | ||||
| }; | ||||
|  | ||||
| pub async fn user_exists(db_user: &str, connection: &mut MySqlConnection) -> anyhow::Result<bool> { | ||||
|     let unix_user = get_current_unix_user()?; | ||||
|  | ||||
|     validate_user_name(db_user, &unix_user)?; | ||||
|  | ||||
|     let user_exists = sqlx::query( | ||||
|         r#" | ||||
|           SELECT EXISTS( | ||||
|             SELECT 1 | ||||
|             FROM `mysql`.`user` | ||||
|             WHERE `User` = ? | ||||
|           ) | ||||
|         "#, | ||||
|     ) | ||||
|     .bind(db_user) | ||||
|     .fetch_one(connection) | ||||
|     .await? | ||||
|     .get::<bool, _>(0); | ||||
|  | ||||
|     Ok(user_exists) | ||||
| } | ||||
|  | ||||
| pub async fn create_database_user( | ||||
|     db_user: &str, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> anyhow::Result<()> { | ||||
|     let unix_user = get_current_unix_user()?; | ||||
|  | ||||
|     validate_user_name(db_user, &unix_user)?; | ||||
|  | ||||
|     if user_exists(db_user, connection).await? { | ||||
|         anyhow::bail!("User '{}' already exists", db_user); | ||||
|     } | ||||
|  | ||||
|     // NOTE: see the note about SQL injections in `validate_ownership_of_user_name` | ||||
|     sqlx::query(format!("CREATE USER {}@'%'", quote_literal(db_user),).as_str()) | ||||
|         .execute(connection) | ||||
|         .await?; | ||||
|  | ||||
|     Ok(()) | ||||
| } | ||||
|  | ||||
| pub async fn delete_database_user( | ||||
|     db_user: &str, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> anyhow::Result<()> { | ||||
|     let unix_user = get_current_unix_user()?; | ||||
|  | ||||
|     validate_user_name(db_user, &unix_user)?; | ||||
|  | ||||
|     if !user_exists(db_user, connection).await? { | ||||
|         anyhow::bail!("User '{}' does not exist", db_user); | ||||
|     } | ||||
|  | ||||
|     // NOTE: see the note about SQL injections in `validate_ownership_of_user_name` | ||||
|     sqlx::query(format!("DROP USER {}@'%'", quote_literal(db_user),).as_str()) | ||||
|         .execute(connection) | ||||
|         .await?; | ||||
|  | ||||
|     Ok(()) | ||||
| } | ||||
|  | ||||
| pub async fn set_password_for_database_user( | ||||
|     db_user: &str, | ||||
|     password: &str, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> anyhow::Result<()> { | ||||
|     let unix_user = crate::core::common::get_current_unix_user()?; | ||||
|     validate_user_name(db_user, &unix_user)?; | ||||
|  | ||||
|     if !user_exists(db_user, connection).await? { | ||||
|         anyhow::bail!("User '{}' does not exist", db_user); | ||||
|     } | ||||
|  | ||||
|     // NOTE: see the note about SQL injections in `validate_ownership_of_user_name` | ||||
|     sqlx::query( | ||||
|         format!( | ||||
|             "ALTER USER {}@'%' IDENTIFIED BY {}", | ||||
|             quote_literal(db_user), | ||||
|             quote_literal(password).as_str() | ||||
|         ) | ||||
|         .as_str(), | ||||
|     ) | ||||
|     .execute(connection) | ||||
|     .await?; | ||||
|  | ||||
|     Ok(()) | ||||
| } | ||||
|  | ||||
| async fn user_is_locked(db_user: &str, connection: &mut MySqlConnection) -> anyhow::Result<bool> { | ||||
|     let unix_user = get_current_unix_user()?; | ||||
|  | ||||
|     validate_user_name(db_user, &unix_user)?; | ||||
|  | ||||
|     if !user_exists(db_user, connection).await? { | ||||
|         anyhow::bail!("User '{}' does not exist", db_user); | ||||
|     } | ||||
|  | ||||
|     let is_locked = sqlx::query( | ||||
|         r#" | ||||
|           SELECT JSON_EXTRACT(`mysql`.`global_priv`.`priv`, "$.account_locked") = 'true' | ||||
|           FROM `mysql`.`global_priv` | ||||
|           WHERE `User` = ? | ||||
|           AND `Host` = '%' | ||||
|         "#, | ||||
|     ) | ||||
|     .bind(db_user) | ||||
|     .fetch_one(connection) | ||||
|     .await? | ||||
|     .get::<bool, _>(0); | ||||
|  | ||||
|     Ok(is_locked) | ||||
| } | ||||
|  | ||||
| pub async fn lock_database_user( | ||||
|     db_user: &str, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> anyhow::Result<()> { | ||||
|     let unix_user = get_current_unix_user()?; | ||||
|  | ||||
|     validate_user_name(db_user, &unix_user)?; | ||||
|  | ||||
|     if !user_exists(db_user, connection).await? { | ||||
|         anyhow::bail!("User '{}' does not exist", db_user); | ||||
|     } | ||||
|  | ||||
|     if user_is_locked(db_user, connection).await? { | ||||
|         anyhow::bail!("User '{}' is already locked", db_user); | ||||
|     } | ||||
|  | ||||
|     // NOTE: see the note about SQL injections in `validate_ownership_of_user_name` | ||||
|     sqlx::query(format!("ALTER USER {}@'%' ACCOUNT LOCK", quote_literal(db_user),).as_str()) | ||||
|         .execute(connection) | ||||
|         .await?; | ||||
|  | ||||
|     Ok(()) | ||||
| } | ||||
|  | ||||
| pub async fn unlock_database_user( | ||||
|     db_user: &str, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> anyhow::Result<()> { | ||||
|     let unix_user = get_current_unix_user()?; | ||||
|  | ||||
|     validate_user_name(db_user, &unix_user)?; | ||||
|  | ||||
|     if !user_exists(db_user, connection).await? { | ||||
|         anyhow::bail!("User '{}' does not exist", db_user); | ||||
|     } | ||||
|  | ||||
|     if !user_is_locked(db_user, connection).await? { | ||||
|         anyhow::bail!("User '{}' is already unlocked", db_user); | ||||
|     } | ||||
|  | ||||
|     // NOTE: see the note about SQL injections in `validate_ownership_of_user_name` | ||||
|     sqlx::query(format!("ALTER USER {}@'%' ACCOUNT UNLOCK", quote_literal(db_user),).as_str()) | ||||
|         .execute(connection) | ||||
|         .await?; | ||||
|  | ||||
|     Ok(()) | ||||
| } | ||||
|  | ||||
| /// This struct contains information about a database user. | ||||
| /// This can be extended if we need more information in the future. | ||||
| #[derive(Debug, Clone, FromRow, Serialize, Deserialize)] | ||||
| pub struct DatabaseUser { | ||||
|     #[sqlx(rename = "User")] | ||||
|     pub user: String, | ||||
|  | ||||
|     #[allow(dead_code)] | ||||
|     #[serde(skip)] | ||||
|     #[sqlx(rename = "Host")] | ||||
|     pub host: String, | ||||
|  | ||||
|     #[sqlx(rename = "has_password")] | ||||
|     pub has_password: bool, | ||||
|  | ||||
|     #[sqlx(rename = "is_locked")] | ||||
|     pub is_locked: bool, | ||||
| } | ||||
|  | ||||
| const DB_USER_SELECT_STATEMENT: &str = r#" | ||||
| SELECT | ||||
|   `mysql`.`user`.`User`, | ||||
|   `mysql`.`user`.`Host`, | ||||
|   `mysql`.`user`.`Password` != '' OR `mysql`.`user`.`authentication_string` != '' AS `has_password`, | ||||
|   COALESCE( | ||||
|     JSON_EXTRACT(`mysql`.`global_priv`.`priv`, "$.account_locked"), | ||||
|     'false' | ||||
|   ) != 'false' AS `is_locked` | ||||
| FROM `mysql`.`user` | ||||
| JOIN `mysql`.`global_priv` ON | ||||
|   `mysql`.`user`.`User` = `mysql`.`global_priv`.`User` | ||||
|   AND `mysql`.`user`.`Host` = `mysql`.`global_priv`.`Host` | ||||
| "#; | ||||
|  | ||||
| /// This function fetches all database users that have a prefix matching the | ||||
| /// unix username and group names of the given unix user. | ||||
| pub async fn get_all_database_users_for_unix_user( | ||||
|     unix_user: &User, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> anyhow::Result<Vec<DatabaseUser>> { | ||||
|     let users = sqlx::query_as::<_, DatabaseUser>( | ||||
|         &(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `mysql`.`user`.`User` REGEXP ?"), | ||||
|     ) | ||||
|     .bind(create_user_group_matching_regex(unix_user)) | ||||
|     .fetch_all(connection) | ||||
|     .await?; | ||||
|  | ||||
|     Ok(users) | ||||
| } | ||||
|  | ||||
| /// This function fetches a database user if it exists. | ||||
| pub async fn get_database_user_for_user( | ||||
|     username: &str, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> anyhow::Result<Option<DatabaseUser>> { | ||||
|     let user = sqlx::query_as::<_, DatabaseUser>( | ||||
|         &(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `mysql`.`user`.`User` = ?"), | ||||
|     ) | ||||
|     .bind(username) | ||||
|     .fetch_optional(connection) | ||||
|     .await?; | ||||
|  | ||||
|     Ok(user) | ||||
| } | ||||
|  | ||||
| /// NOTE: It is very critical that this function validates the database name | ||||
| ///       properly. MySQL does not seem to allow for prepared statements, binding | ||||
| ///       the database name as a parameter to the query. This means that we have | ||||
| ///       to validate the database name ourselves to prevent SQL injection. | ||||
| pub fn validate_user_name(name: &str, user: &User) -> anyhow::Result<()> { | ||||
|     validate_name_or_error(name, DbOrUser::User) | ||||
|         .context(format!("Invalid username: '{}'", name))?; | ||||
|     validate_ownership_or_error(name, user, DbOrUser::User) | ||||
|         .context(format!("Invalid username: '{}'", name))?; | ||||
|  | ||||
|     Ok(()) | ||||
| } | ||||
							
								
								
									
										147
									
								
								src/main.rs
									
									
									
									
									
								
							
							
						
						
									
										147
									
								
								src/main.rs
									
									
									
									
									
								
							| @@ -1,42 +1,69 @@ | ||||
| #[macro_use] | ||||
| extern crate prettytable; | ||||
|  | ||||
| use core::common::CommandStatus; | ||||
| #[cfg(feature = "mysql-admutils-compatibility")] | ||||
| use clap::Parser; | ||||
|  | ||||
| use std::path::PathBuf; | ||||
|  | ||||
| use std::os::unix::net::UnixStream as StdUnixStream; | ||||
| use tokio::net::UnixStream as TokioUnixStream; | ||||
|  | ||||
| use crate::{ | ||||
|     core::{ | ||||
|         bootstrap::{bootstrap_server_connection_and_drop_privileges, drop_privs}, | ||||
|         protocol::create_client_to_server_message_stream, | ||||
|     }, | ||||
|     server::command::ServerArgs, | ||||
| }; | ||||
|  | ||||
| #[cfg(feature = "mysql-admutils-compatibility")] | ||||
| use crate::cli::mysql_admutils_compatibility::{mysql_dbadm, mysql_useradm}; | ||||
|  | ||||
| use clap::Parser; | ||||
| mod server; | ||||
|  | ||||
| mod authenticated_unix_socket; | ||||
| mod cli; | ||||
| mod core; | ||||
|  | ||||
| #[cfg(feature = "tui")] | ||||
| mod tui; | ||||
|  | ||||
| #[derive(Parser)] | ||||
| #[derive(Parser, Debug)] | ||||
| struct Args { | ||||
|     #[command(subcommand)] | ||||
|     command: Command, | ||||
|  | ||||
|     #[command(flatten)] | ||||
|     config_overrides: core::config::GlobalConfigArgs, | ||||
|     /// Path to the socket of the server, if it already exists. | ||||
|     #[arg( | ||||
|         short, | ||||
|         long, | ||||
|         value_name = "PATH", | ||||
|         global = true, | ||||
|         hide_short_help = true | ||||
|     )] | ||||
|     server_socket_path: Option<PathBuf>, | ||||
|  | ||||
|     /// Config file to use for the server. | ||||
|     #[arg( | ||||
|         short, | ||||
|         long, | ||||
|         value_name = "PATH", | ||||
|         global = true, | ||||
|         hide_short_help = true | ||||
|     )] | ||||
|     config: Option<PathBuf>, | ||||
|  | ||||
|     #[cfg(feature = "tui")] | ||||
|     #[arg(short, long, alias = "tui", global = true)] | ||||
|     interactive: bool, | ||||
| } | ||||
|  | ||||
| /// Database administration tool for non-admin users to manage their own MySQL databases and users. | ||||
| /// | ||||
| /// This tool allows you to manage users and databases in MySQL. | ||||
| /// | ||||
| /// You are only allowed to manage databases and users that are prefixed with | ||||
| /// either your username, or a group that you are a member of. | ||||
| #[derive(Parser)] | ||||
| // Database administration tool for non-admin users to manage their own MySQL databases and users. | ||||
| // | ||||
| // This tool allows you to manage users and databases in MySQL. | ||||
| // | ||||
| // You are only allowed to manage databases and users that are prefixed with | ||||
| // either your username, or a group that you are a member of. | ||||
| #[derive(Parser, Debug, Clone)] | ||||
| #[command(version, about, disable_help_subcommand = true)] | ||||
| enum Command { | ||||
|     #[command(flatten)] | ||||
| @@ -44,10 +71,18 @@ enum Command { | ||||
|  | ||||
|     #[command(flatten)] | ||||
|     User(cli::user_command::UserCommand), | ||||
|  | ||||
|     #[command(hide = true)] | ||||
|     Server(server::command::ServerArgs), | ||||
| } | ||||
|  | ||||
| #[tokio::main(flavor = "current_thread")] | ||||
| async fn main() -> anyhow::Result<()> { | ||||
| // TODO: tag all functions that are run with elevated privileges with | ||||
| //       comments emphasizing the need for caution. | ||||
|  | ||||
| fn main() -> anyhow::Result<()> { | ||||
|     // TODO: find out if there are any security risks of running | ||||
|     //       env_logger and clap with elevated privileges. | ||||
|  | ||||
|     env_logger::init(); | ||||
|  | ||||
|     #[cfg(feature = "mysql-admutils-compatibility")] | ||||
| @@ -59,42 +94,60 @@ async fn main() -> anyhow::Result<()> { | ||||
|         }); | ||||
|  | ||||
|         match argv0.as_deref() { | ||||
|             Some("mysql-dbadm") => return mysql_dbadm::main().await, | ||||
|             Some("mysql-useradm") => return mysql_useradm::main().await, | ||||
|             Some("mysql-dbadm") => return mysql_dbadm::main(), | ||||
|             Some("mysql-useradm") => return mysql_useradm::main(), | ||||
|             _ => { /* fall through */ } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     let args: Args = Args::parse(); | ||||
|     let config = core::config::get_config(args.config_overrides)?; | ||||
|     let connection = core::config::create_mysql_connection_from_config(config.mysql).await?; | ||||
|  | ||||
|     let result = match args.command { | ||||
|         Command::Db(command) => cli::database_command::handle_command(command, connection).await, | ||||
|         Command::User(user_args) => cli::user_command::handle_command(user_args, connection).await, | ||||
|     }; | ||||
|  | ||||
|     match result { | ||||
|         Ok(CommandStatus::SuccessfullyModified) => { | ||||
|             println!("Modifications committed successfully"); | ||||
|             Ok(()) | ||||
|     match args.command { | ||||
|         Command::Server(ref command) => { | ||||
|             drop_privs()?; | ||||
|             tokio_start_server(args.server_socket_path, args.config, command.clone())?; | ||||
|             return Ok(()); | ||||
|         } | ||||
|         Ok(CommandStatus::PartiallySuccessfullyModified) => { | ||||
|             println!("Some modifications committed successfully"); | ||||
|             Ok(()) | ||||
|         } | ||||
|         Ok(CommandStatus::NoModificationsNeeded) => { | ||||
|             println!("No modifications made"); | ||||
|             Ok(()) | ||||
|         } | ||||
|         Ok(CommandStatus::NoModificationsIntended) => { | ||||
|             /* Don't report anything */ | ||||
|             Ok(()) | ||||
|         } | ||||
|         Ok(CommandStatus::Cancelled) => { | ||||
|             println!("Command cancelled successfully"); | ||||
|             Ok(()) | ||||
|         } | ||||
|         Err(e) => Err(e), | ||||
|         _ => { /* fall through */ } | ||||
|     } | ||||
|  | ||||
|     let server_connection = | ||||
|         bootstrap_server_connection_and_drop_privileges(args.server_socket_path, args.config)?; | ||||
|  | ||||
|     tokio_run_command(args.command, server_connection)?; | ||||
|  | ||||
|     Ok(()) | ||||
| } | ||||
|  | ||||
| fn tokio_start_server( | ||||
|     server_socket_path: Option<PathBuf>, | ||||
|     config_path: Option<PathBuf>, | ||||
|     args: ServerArgs, | ||||
| ) -> anyhow::Result<()> { | ||||
|     tokio::runtime::Builder::new_current_thread() | ||||
|         .enable_all() | ||||
|         .build() | ||||
|         .unwrap() | ||||
|         .block_on(async { | ||||
|             server::command::handle_command(server_socket_path, config_path, args).await | ||||
|         }) | ||||
| } | ||||
|  | ||||
| fn tokio_run_command(command: Command, server_connection: StdUnixStream) -> anyhow::Result<()> { | ||||
|     tokio::runtime::Builder::new_current_thread() | ||||
|         .enable_all() | ||||
|         .build() | ||||
|         .unwrap() | ||||
|         .block_on(async { | ||||
|             let tokio_socket = TokioUnixStream::from_std(server_connection)?; | ||||
|             let message_stream = create_client_to_server_message_stream(tokio_socket); | ||||
|             match command { | ||||
|                 Command::User(user_args) => { | ||||
|                     cli::user_command::handle_command(user_args, message_stream).await | ||||
|                 } | ||||
|                 Command::Db(db_args) => { | ||||
|                     cli::database_command::handle_command(db_args, message_stream).await | ||||
|                 } | ||||
|                 Command::Server(_) => unreachable!(), | ||||
|             } | ||||
|         }) | ||||
| } | ||||
|   | ||||
							
								
								
									
										6
									
								
								src/server.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								src/server.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,6 @@ | ||||
| pub mod command; | ||||
| mod common; | ||||
| pub mod config; | ||||
| pub mod input_sanitization; | ||||
| pub mod server_loop; | ||||
| pub mod sql; | ||||
							
								
								
									
										77
									
								
								src/server/command.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										77
									
								
								src/server/command.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,77 @@ | ||||
| use std::os::fd::FromRawFd; | ||||
| use std::path::PathBuf; | ||||
|  | ||||
| use anyhow::Context; | ||||
| use clap::Parser; | ||||
|  | ||||
| use std::os::unix::net::UnixStream as StdUnixStream; | ||||
| use tokio::net::UnixStream as TokioUnixStream; | ||||
|  | ||||
| use crate::core::bootstrap::authenticated_unix_socket; | ||||
| use crate::core::common::UnixUser; | ||||
| use crate::server::config::read_config_from_path_with_arg_overrides; | ||||
| use crate::server::server_loop::listen_for_incoming_connections; | ||||
| use crate::server::{ | ||||
|     config::{ServerConfig, ServerConfigArgs}, | ||||
|     server_loop::handle_requests_for_single_session, | ||||
| }; | ||||
|  | ||||
| #[derive(Parser, Debug, Clone)] | ||||
| pub struct ServerArgs { | ||||
|     #[command(subcommand)] | ||||
|     subcmd: ServerCommand, | ||||
|  | ||||
|     #[command(flatten)] | ||||
|     config_overrides: ServerConfigArgs, | ||||
| } | ||||
|  | ||||
| #[derive(Parser, Debug, Clone)] | ||||
| pub enum ServerCommand { | ||||
|     #[command()] | ||||
|     Listen, | ||||
|  | ||||
|     #[command()] | ||||
|     SocketActivate, | ||||
| } | ||||
|  | ||||
| pub async fn handle_command( | ||||
|     socket_path: Option<PathBuf>, | ||||
|     config_path: Option<PathBuf>, | ||||
|     args: ServerArgs, | ||||
| ) -> anyhow::Result<()> { | ||||
|     let config = read_config_from_path_with_arg_overrides(config_path, args.config_overrides)?; | ||||
|  | ||||
|     // if let Err(e) = &result { | ||||
|     //     eprintln!("{}", e); | ||||
|     // } | ||||
|  | ||||
|     match args.subcmd { | ||||
|         ServerCommand::Listen => listen_for_incoming_connections(socket_path, config).await, | ||||
|         ServerCommand::SocketActivate => socket_activate(config).await, | ||||
|     } | ||||
| } | ||||
|  | ||||
| async fn socket_activate(config: ServerConfig) -> anyhow::Result<()> { | ||||
|     // TODO: allow getting socket path from other socket activation sources | ||||
|     let mut conn = get_socket_from_systemd().await?; | ||||
|     let uid = authenticated_unix_socket::server_authenticate(&mut conn).await?; | ||||
|     let unix_user = UnixUser::from_uid(uid.into())?; | ||||
|     handle_requests_for_single_session(conn, &unix_user, &config).await?; | ||||
|  | ||||
|     Ok(()) | ||||
| } | ||||
|  | ||||
| async fn get_socket_from_systemd() -> anyhow::Result<TokioUnixStream> { | ||||
|     let fd = std::env::var("LISTEN_FDS") | ||||
|         .context("LISTEN_FDS not set, not running under systemd?")? | ||||
|         .parse::<i32>() | ||||
|         .context("Failed to parse LISTEN_FDS")?; | ||||
|  | ||||
|     if fd != 1 { | ||||
|         return Err(anyhow::anyhow!("Unexpected LISTEN_FDS value: {}", fd)); | ||||
|     } | ||||
|  | ||||
|     let std_unix_stream = unsafe { StdUnixStream::from_raw_fd(fd) }; | ||||
|     let socket = TokioUnixStream::from_std(std_unix_stream)?; | ||||
|     Ok(socket) | ||||
| } | ||||
							
								
								
									
										11
									
								
								src/server/common.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										11
									
								
								src/server/common.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,11 @@ | ||||
| use crate::core::common::UnixUser; | ||||
|  | ||||
| /// This function creates a regex that matches items (users, databases) | ||||
| /// that belong to the user or any of the user's groups. | ||||
| pub fn create_user_group_matching_regex(user: &UnixUser) -> String { | ||||
|     if user.groups.is_empty() { | ||||
|         format!("{}(_.+)?", user.username) | ||||
|     } else { | ||||
|         format!("({}|{})(_.+)?", user.username, user.groups.join("|")) | ||||
|     } | ||||
| } | ||||
| @@ -5,11 +5,16 @@ use clap::Parser; | ||||
| use serde::{Deserialize, Serialize}; | ||||
| use sqlx::{mysql::MySqlConnectOptions, ConnectOptions, MySqlConnection}; | ||||
| 
 | ||||
| use crate::core::common::DEFAULT_CONFIG_PATH; | ||||
| 
 | ||||
| pub const DEFAULT_PORT: u16 = 3306; | ||||
| pub const DEFAULT_TIMEOUT: u64 = 2; | ||||
| 
 | ||||
| // NOTE: this might look empty now, and the extra wrapping for the mysql
 | ||||
| //       config seems unnecessary, but it will be useful later when we
 | ||||
| //       add more configuration options.
 | ||||
| #[derive(Debug, Clone, Deserialize, Serialize)] | ||||
| pub struct Config { | ||||
| pub struct ServerConfig { | ||||
|     pub mysql: MysqlConfig, | ||||
| } | ||||
| 
 | ||||
| @@ -23,58 +28,36 @@ pub struct MysqlConfig { | ||||
|     pub timeout: Option<u64>, | ||||
| } | ||||
| 
 | ||||
| const DEFAULT_PORT: u16 = 3306; | ||||
| const DEFAULT_TIMEOUT: u64 = 2; | ||||
| 
 | ||||
| #[derive(Parser)] | ||||
| pub struct GlobalConfigArgs { | ||||
|     /// Path to the configuration file.
 | ||||
|     #[arg(
 | ||||
|         short, | ||||
|         long, | ||||
|         value_name = "PATH", | ||||
|         global = true, | ||||
|         hide_short_help = true, | ||||
|         default_value = "/etc/mysqladm/config.toml" | ||||
|     )] | ||||
|     config_file: String, | ||||
| 
 | ||||
| #[derive(Parser, Debug, Clone)] | ||||
| pub struct ServerConfigArgs { | ||||
|     /// Hostname of the MySQL server.
 | ||||
|     #[arg(long, value_name = "HOST", global = true, hide_short_help = true)] | ||||
|     #[arg(long, value_name = "HOST", global = true)] | ||||
|     mysql_host: Option<String>, | ||||
| 
 | ||||
|     /// Port of the MySQL server.
 | ||||
|     #[arg(long, value_name = "PORT", global = true, hide_short_help = true)] | ||||
|     #[arg(long, value_name = "PORT", global = true)] | ||||
|     mysql_port: Option<u16>, | ||||
| 
 | ||||
|     /// Username to use for the MySQL connection.
 | ||||
|     #[arg(long, value_name = "USER", global = true, hide_short_help = true)] | ||||
|     #[arg(long, value_name = "USER", global = true)] | ||||
|     mysql_user: Option<String>, | ||||
| 
 | ||||
|     /// Path to a file containing the MySQL password.
 | ||||
|     #[arg(long, value_name = "PATH", global = true, hide_short_help = true)] | ||||
|     #[arg(long, value_name = "PATH", global = true)] | ||||
|     mysql_password_file: Option<String>, | ||||
| 
 | ||||
|     /// Seconds to wait for the MySQL connection to be established.
 | ||||
|     #[arg(long, value_name = "SECONDS", global = true, hide_short_help = true)] | ||||
|     #[arg(long, value_name = "SECONDS", global = true)] | ||||
|     mysql_connect_timeout: Option<u64>, | ||||
| } | ||||
| 
 | ||||
| /// Use the arguments and whichever configuration file which might or might not
 | ||||
| /// be found and default values to determine the configuration for the program.
 | ||||
| pub fn get_config(args: GlobalConfigArgs) -> anyhow::Result<Config> { | ||||
|     let config_path = PathBuf::from(args.config_file); | ||||
| 
 | ||||
|     let config: Config = fs::read_to_string(&config_path) | ||||
|         .context(format!( | ||||
|             "Failed to read config file from {:?}", | ||||
|             &config_path | ||||
|         )) | ||||
|         .and_then(|c| toml::from_str(&c).context("Failed to parse config file")) | ||||
|         .context(format!( | ||||
|             "Failed to parse config file from {:?}", | ||||
|             &config_path | ||||
|         ))?; | ||||
| pub fn read_config_from_path_with_arg_overrides( | ||||
|     config_path: Option<PathBuf>, | ||||
|     args: ServerConfigArgs, | ||||
| ) -> anyhow::Result<ServerConfig> { | ||||
|     let config = read_config_form_path(config_path)?; | ||||
| 
 | ||||
|     let mysql = &config.mysql; | ||||
| 
 | ||||
| @@ -86,22 +69,35 @@ pub fn get_config(args: GlobalConfigArgs) -> anyhow::Result<Config> { | ||||
|         mysql.password.to_owned() | ||||
|     }; | ||||
| 
 | ||||
|     let mysql_config = MysqlConfig { | ||||
|         host: args.mysql_host.unwrap_or(mysql.host.to_owned()), | ||||
|         port: args.mysql_port.or(mysql.port), | ||||
|         username: args.mysql_user.unwrap_or(mysql.username.to_owned()), | ||||
|         password, | ||||
|         timeout: args.mysql_connect_timeout.or(mysql.timeout), | ||||
|     }; | ||||
| 
 | ||||
|     Ok(Config { | ||||
|         mysql: mysql_config, | ||||
|     Ok(ServerConfig { | ||||
|         mysql: MysqlConfig { | ||||
|             host: args.mysql_host.unwrap_or(mysql.host.to_owned()), | ||||
|             port: args.mysql_port.or(mysql.port), | ||||
|             username: args.mysql_user.unwrap_or(mysql.username.to_owned()), | ||||
|             password, | ||||
|             timeout: args.mysql_connect_timeout.or(mysql.timeout), | ||||
|         }, | ||||
|     }) | ||||
| } | ||||
| 
 | ||||
| pub fn read_config_form_path(config_path: Option<PathBuf>) -> anyhow::Result<ServerConfig> { | ||||
|     let config_path = config_path.unwrap_or_else(|| PathBuf::from(DEFAULT_CONFIG_PATH)); | ||||
| 
 | ||||
|     fs::read_to_string(&config_path) | ||||
|         .context(format!( | ||||
|             "Failed to read config file from {:?}", | ||||
|             &config_path | ||||
|         )) | ||||
|         .and_then(|c| toml::from_str(&c).context("Failed to parse config file")) | ||||
|         .context(format!( | ||||
|             "Failed to parse config file from {:?}", | ||||
|             &config_path | ||||
|         )) | ||||
| } | ||||
| 
 | ||||
| /// Use the provided configuration to establish a connection to a MySQL server.
 | ||||
| pub async fn create_mysql_connection_from_config( | ||||
|     config: MysqlConfig, | ||||
|     config: &MysqlConfig, | ||||
| ) -> anyhow::Result<MySqlConnection> { | ||||
|     match tokio::time::timeout( | ||||
|         Duration::from_secs(config.timeout.unwrap_or(DEFAULT_TIMEOUT)), | ||||
							
								
								
									
										158
									
								
								src/server/input_sanitization.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										158
									
								
								src/server/input_sanitization.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,158 @@ | ||||
| use crate::core::{ | ||||
|     common::UnixUser, | ||||
|     protocol::server_responses::{NameValidationError, OwnerValidationError}, | ||||
| }; | ||||
|  | ||||
| const MAX_NAME_LENGTH: usize = 64; | ||||
|  | ||||
| pub fn validate_name(name: &str) -> Result<(), NameValidationError> { | ||||
|     if name.is_empty() { | ||||
|         Err(NameValidationError::EmptyString) | ||||
|     } else if name.len() > MAX_NAME_LENGTH { | ||||
|         Err(NameValidationError::TooLong) | ||||
|     } else if !name | ||||
|         .chars() | ||||
|         .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-') | ||||
|     { | ||||
|         Err(NameValidationError::InvalidCharacters) | ||||
|     } else { | ||||
|         Ok(()) | ||||
|     } | ||||
| } | ||||
|  | ||||
| pub fn validate_ownership_by_unix_user( | ||||
|     name: &str, | ||||
|     user: &UnixUser, | ||||
| ) -> Result<(), OwnerValidationError> { | ||||
|     let prefixes = std::iter::once(user.username.clone()) | ||||
|         .chain(user.groups.iter().cloned()) | ||||
|         .collect::<Vec<String>>(); | ||||
|  | ||||
|     validate_ownership_by_prefixes(name, &prefixes) | ||||
| } | ||||
|  | ||||
| /// Core logic for validating the ownership of a database name. | ||||
| /// This function checks if the given name matches any of the given prefixes. | ||||
| /// These prefixes will in most cases be the user's unix username and any | ||||
| /// unix groups the user is a member of. | ||||
| pub fn validate_ownership_by_prefixes( | ||||
|     name: &str, | ||||
|     prefixes: &[String], | ||||
| ) -> Result<(), OwnerValidationError> { | ||||
|     if name.is_empty() { | ||||
|         return Err(OwnerValidationError::StringEmpty); | ||||
|     } | ||||
|  | ||||
|     if name.starts_with('_') { | ||||
|         return Err(OwnerValidationError::MissingPrefix); | ||||
|     } | ||||
|  | ||||
|     let (prefix, _) = match name.split_once('_') { | ||||
|         Some(pair) => pair, | ||||
|         None => return Err(OwnerValidationError::MissingPostfix), | ||||
|     }; | ||||
|  | ||||
|     if !prefixes.iter().any(|g| g == prefix) { | ||||
|         return Err(OwnerValidationError::NoMatch); | ||||
|     } | ||||
|  | ||||
|     Ok(()) | ||||
| } | ||||
|  | ||||
| #[inline] | ||||
| pub fn quote_literal(s: &str) -> String { | ||||
|     format!("'{}'", s.replace('\'', r"\'")) | ||||
| } | ||||
|  | ||||
| #[inline] | ||||
| pub fn quote_identifier(s: &str) -> String { | ||||
|     format!("`{}`", s.replace('`', r"\`")) | ||||
| } | ||||
|  | ||||
| #[cfg(test)] | ||||
| mod tests { | ||||
|     use super::*; | ||||
|     #[test] | ||||
|     fn test_quote_literal() { | ||||
|         let payload = "' OR 1=1 --"; | ||||
|         assert_eq!(quote_literal(payload), r#"'\' OR 1=1 --'"#); | ||||
|     } | ||||
|  | ||||
|     #[test] | ||||
|     fn test_quote_identifier() { | ||||
|         let payload = "` OR 1=1 --"; | ||||
|         assert_eq!(quote_identifier(payload), r#"`\` OR 1=1 --`"#); | ||||
|     } | ||||
|  | ||||
|     #[test] | ||||
|     fn test_validate_name() { | ||||
|         assert_eq!(validate_name(""), Err(NameValidationError::EmptyString)); | ||||
|         assert_eq!(validate_name("abcdefghijklmnopqrstuvwxyz"), Ok(())); | ||||
|         assert_eq!(validate_name("ABCDEFGHIJKLMNOPQRSTUVWXYZ"), Ok(())); | ||||
|         assert_eq!(validate_name("0123456789_-"), Ok(())); | ||||
|  | ||||
|         for c in "\n\t\r !@#$%^&*()+=[]{}|;:,.<>?/".chars() { | ||||
|             assert_eq!( | ||||
|                 validate_name(&c.to_string()), | ||||
|                 Err(NameValidationError::InvalidCharacters) | ||||
|             ); | ||||
|         } | ||||
|  | ||||
|         assert_eq!(validate_name(&"a".repeat(MAX_NAME_LENGTH)), Ok(())); | ||||
|  | ||||
|         assert_eq!( | ||||
|             validate_name(&"a".repeat(MAX_NAME_LENGTH + 1)), | ||||
|             Err(NameValidationError::TooLong) | ||||
|         ); | ||||
|     } | ||||
|  | ||||
|     #[test] | ||||
|     fn test_validate_owner_by_prefixes() { | ||||
|         let prefixes = vec!["user".to_string(), "group".to_string()]; | ||||
|  | ||||
|         assert_eq!( | ||||
|             validate_ownership_by_prefixes("", &prefixes), | ||||
|             Err(OwnerValidationError::StringEmpty) | ||||
|         ); | ||||
|  | ||||
|         assert_eq!( | ||||
|             validate_ownership_by_prefixes("user", &prefixes), | ||||
|             Err(OwnerValidationError::MissingPostfix) | ||||
|         ); | ||||
|         assert_eq!( | ||||
|             validate_ownership_by_prefixes("something", &prefixes), | ||||
|             Err(OwnerValidationError::MissingPostfix) | ||||
|         ); | ||||
|         assert_eq!( | ||||
|             validate_ownership_by_prefixes("user-testdb", &prefixes), | ||||
|             Err(OwnerValidationError::MissingPostfix) | ||||
|         ); | ||||
|  | ||||
|         assert_eq!( | ||||
|             validate_ownership_by_prefixes("_testdb", &prefixes), | ||||
|             Err(OwnerValidationError::MissingPrefix) | ||||
|         ); | ||||
|  | ||||
|         assert_eq!( | ||||
|             validate_ownership_by_prefixes("user_testdb", &prefixes), | ||||
|             Ok(()) | ||||
|         ); | ||||
|         assert_eq!( | ||||
|             validate_ownership_by_prefixes("group_testdb", &prefixes), | ||||
|             Ok(()) | ||||
|         ); | ||||
|         assert_eq!( | ||||
|             validate_ownership_by_prefixes("group_test_db", &prefixes), | ||||
|             Ok(()) | ||||
|         ); | ||||
|         assert_eq!( | ||||
|             validate_ownership_by_prefixes("group_test-db", &prefixes), | ||||
|             Ok(()) | ||||
|         ); | ||||
|  | ||||
|         assert_eq!( | ||||
|             validate_ownership_by_prefixes("nonexistent_testdb", &prefixes), | ||||
|             Err(OwnerValidationError::NoMatch) | ||||
|         ); | ||||
|     } | ||||
| } | ||||
							
								
								
									
										229
									
								
								src/server/server_loop.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										229
									
								
								src/server/server_loop.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,229 @@ | ||||
| use std::{collections::BTreeSet, fs, path::PathBuf}; | ||||
|  | ||||
| use anyhow::Context; | ||||
|  | ||||
| use futures_util::{SinkExt, StreamExt}; | ||||
| use tokio::io::AsyncWriteExt; | ||||
| use tokio::net::{UnixListener, UnixStream}; | ||||
|  | ||||
| use sqlx::prelude::*; | ||||
| use sqlx::MySqlConnection; | ||||
|  | ||||
| use crate::{ | ||||
|     core::{ | ||||
|         bootstrap::authenticated_unix_socket, | ||||
|         common::{UnixUser, DEFAULT_SOCKET_PATH}, | ||||
|         protocol::request_response::{ | ||||
|             create_server_to_client_message_stream, Request, Response, ServerToClientMessageStream, | ||||
|         }, | ||||
|     }, | ||||
|     server::{ | ||||
|         config::{create_mysql_connection_from_config, ServerConfig}, | ||||
|         sql::{ | ||||
|             database_operations::{create_databases, drop_databases, list_databases_for_user}, | ||||
|             database_privilege_operations::{ | ||||
|                 apply_privilege_diffs, get_all_database_privileges, get_databases_privilege_data, | ||||
|             }, | ||||
|             user_operations::{ | ||||
|                 create_database_users, drop_database_users, list_all_database_users_for_unix_user, | ||||
|                 list_database_users, lock_database_users, set_password_for_database_user, | ||||
|                 unlock_database_users, | ||||
|             }, | ||||
|         }, | ||||
|     }, | ||||
| }; | ||||
|  | ||||
| // TODO: consider using a connection pool | ||||
|  | ||||
| // TODO: use tracing for login, so we can scope the log messages per incoming connection | ||||
|  | ||||
| pub async fn listen_for_incoming_connections( | ||||
|     socket_path: Option<PathBuf>, | ||||
|     config: ServerConfig, | ||||
|     // db_connection: &mut MySqlConnection, | ||||
| ) -> anyhow::Result<()> { | ||||
|     let socket_path = socket_path.unwrap_or(PathBuf::from(DEFAULT_SOCKET_PATH)); | ||||
|  | ||||
|     let parent_directory = socket_path.parent().unwrap(); | ||||
|     if !parent_directory.exists() { | ||||
|         println!("Creating directory {:?}", parent_directory); | ||||
|         fs::create_dir_all(parent_directory)?; | ||||
|     } | ||||
|  | ||||
|     println!("Listening on {:?}", socket_path); | ||||
|     match fs::remove_file(socket_path.as_path()) { | ||||
|         Ok(_) => {} | ||||
|         Err(e) if e.kind() == std::io::ErrorKind::NotFound => {} | ||||
|         Err(e) => return Err(e.into()), | ||||
|     } | ||||
|  | ||||
|     let listener = UnixListener::bind(socket_path)?; | ||||
|  | ||||
|     while let Ok((mut conn, _addr)) = listener.accept().await { | ||||
|         let uid = match authenticated_unix_socket::server_authenticate(&mut conn).await { | ||||
|             Ok(uid) => uid, | ||||
|             Err(e) => { | ||||
|                 eprintln!("Failed to authenticate client: {}", e); | ||||
|                 conn.shutdown().await?; | ||||
|                 continue; | ||||
|             } | ||||
|         }; | ||||
|         let unix_user = match UnixUser::from_uid(uid.into()) { | ||||
|             Ok(user) => user, | ||||
|             Err(e) => { | ||||
|                 eprintln!("Failed to get UnixUser from uid: {}", e); | ||||
|                 conn.shutdown().await?; | ||||
|                 continue; | ||||
|             } | ||||
|         }; | ||||
|         match handle_requests_for_single_session(conn, &unix_user, &config).await { | ||||
|             Ok(_) => {} | ||||
|             Err(e) => { | ||||
|                 eprintln!("Failed to run server: {}", e); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     Ok(()) | ||||
| } | ||||
|  | ||||
| pub async fn handle_requests_for_single_session( | ||||
|     socket: UnixStream, | ||||
|     unix_user: &UnixUser, | ||||
|     config: &ServerConfig, | ||||
| ) -> anyhow::Result<()> { | ||||
|     let message_stream = create_server_to_client_message_stream(socket); | ||||
|     let mut db_connection = create_mysql_connection_from_config(&config.mysql).await?; | ||||
|  | ||||
|     let result = handle_requests_for_single_session_with_db_connection( | ||||
|         message_stream, | ||||
|         unix_user, | ||||
|         &mut db_connection, | ||||
|     ) | ||||
|     .await; | ||||
|  | ||||
|     if let Err(e) = db_connection | ||||
|         .close() | ||||
|         .await | ||||
|         .context("Failed to close connection properly") | ||||
|     { | ||||
|         eprintln!("{}", e); | ||||
|         eprintln!("Ignoring..."); | ||||
|     } | ||||
|  | ||||
|     result | ||||
| } | ||||
|  | ||||
| // TODO: ensure proper db_connection hygiene for functions that invoke | ||||
| //       this function | ||||
|  | ||||
| pub async fn handle_requests_for_single_session_with_db_connection( | ||||
|     mut stream: ServerToClientMessageStream, | ||||
|     unix_user: &UnixUser, | ||||
|     db_connection: &mut MySqlConnection, | ||||
| ) -> anyhow::Result<()> { | ||||
|     loop { | ||||
|         // TODO: better error handling | ||||
|         let request = match stream.next().await { | ||||
|             Some(Ok(request)) => request, | ||||
|             Some(Err(e)) => return Err(e.into()), | ||||
|             None => { | ||||
|                 log::warn!("Client disconnected without sending an exit message"); | ||||
|                 break; | ||||
|             } | ||||
|         }; | ||||
|  | ||||
|         match request { | ||||
|             Request::CreateDatabases(databases_names) => { | ||||
|                 let result = create_databases(databases_names, unix_user, db_connection).await; | ||||
|                 stream.send(Response::CreateDatabases(result)).await?; | ||||
|                 stream.flush().await?; | ||||
|             } | ||||
|             Request::DropDatabases(databases_names) => { | ||||
|                 let result = drop_databases(databases_names, unix_user, db_connection).await; | ||||
|                 stream.send(Response::DropDatabases(result)).await?; | ||||
|                 stream.flush().await?; | ||||
|             } | ||||
|             Request::ListDatabases => { | ||||
|                 let result = list_databases_for_user(unix_user, db_connection).await; | ||||
|                 stream.send(Response::ListAllDatabases(result)).await?; | ||||
|                 stream.flush().await?; | ||||
|             } | ||||
|             Request::ListPrivileges(database_names) => { | ||||
|                 let response = match database_names { | ||||
|                     Some(database_names) => { | ||||
|                         let privilege_data = | ||||
|                             get_databases_privilege_data(database_names, unix_user, db_connection) | ||||
|                                 .await; | ||||
|                         Response::ListPrivileges(privilege_data) | ||||
|                     } | ||||
|                     None => { | ||||
|                         let privilege_data = | ||||
|                             get_all_database_privileges(unix_user, db_connection).await; | ||||
|                         Response::ListAllPrivileges(privilege_data) | ||||
|                     } | ||||
|                 }; | ||||
|  | ||||
|                 stream.send(response).await?; | ||||
|                 stream.flush().await?; | ||||
|             } | ||||
|             Request::ModifyPrivileges(database_privilege_diffs) => { | ||||
|                 let result = apply_privilege_diffs( | ||||
|                     BTreeSet::from_iter(database_privilege_diffs), | ||||
|                     unix_user, | ||||
|                     db_connection, | ||||
|                 ) | ||||
|                 .await; | ||||
|                 stream.send(Response::ModifyPrivileges(result)).await?; | ||||
|                 stream.flush().await?; | ||||
|             } | ||||
|             Request::CreateUsers(db_users) => { | ||||
|                 let result = create_database_users(db_users, unix_user, db_connection).await; | ||||
|                 stream.send(Response::CreateUsers(result)).await?; | ||||
|                 stream.flush().await?; | ||||
|             } | ||||
|             Request::DropUsers(db_users) => { | ||||
|                 let result = drop_database_users(db_users, unix_user, db_connection).await; | ||||
|                 stream.send(Response::DropUsers(result)).await?; | ||||
|                 stream.flush().await?; | ||||
|             } | ||||
|             Request::PasswdUser(db_user, password) => { | ||||
|                 let result = | ||||
|                     set_password_for_database_user(&db_user, &password, unix_user, db_connection) | ||||
|                         .await; | ||||
|                 stream.send(Response::PasswdUser(result)).await?; | ||||
|                 stream.flush().await?; | ||||
|             } | ||||
|             Request::ListUsers(db_users) => { | ||||
|                 let response = match db_users { | ||||
|                     Some(db_users) => { | ||||
|                         let result = list_database_users(db_users, unix_user, db_connection).await; | ||||
|                         Response::ListUsers(result) | ||||
|                     } | ||||
|                     None => { | ||||
|                         let result = | ||||
|                             list_all_database_users_for_unix_user(unix_user, db_connection).await; | ||||
|                         Response::ListAllUsers(result) | ||||
|                     } | ||||
|                 }; | ||||
|                 stream.send(response).await?; | ||||
|                 stream.flush().await?; | ||||
|             } | ||||
|             Request::LockUsers(db_users) => { | ||||
|                 let result = lock_database_users(db_users, unix_user, db_connection).await; | ||||
|                 stream.send(Response::LockUsers(result)).await?; | ||||
|                 stream.flush().await?; | ||||
|             } | ||||
|             Request::UnlockUsers(db_users) => { | ||||
|                 let result = unlock_database_users(db_users, unix_user, db_connection).await; | ||||
|                 stream.send(Response::UnlockUsers(result)).await?; | ||||
|                 stream.flush().await?; | ||||
|             } | ||||
|             Request::Exit => { | ||||
|                 break; | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     Ok(()) | ||||
| } | ||||
							
								
								
									
										3
									
								
								src/server/sql.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								src/server/sql.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | ||||
| pub mod database_operations; | ||||
| pub mod database_privilege_operations; | ||||
| pub mod user_operations; | ||||
							
								
								
									
										165
									
								
								src/server/sql/database_operations.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										165
									
								
								src/server/sql/database_operations.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,165 @@ | ||||
| use crate::{ | ||||
|     core::{ | ||||
|         common::UnixUser, | ||||
|         protocol::{ | ||||
|             CreateDatabaseError, CreateDatabasesOutput, DropDatabaseError, DropDatabasesOutput, | ||||
|             ListDatabasesError, | ||||
|         }, | ||||
|     }, | ||||
|     server::{ | ||||
|         common::create_user_group_matching_regex, | ||||
|         input_sanitization::{quote_identifier, validate_name, validate_ownership_by_unix_user}, | ||||
|     }, | ||||
| }; | ||||
|  | ||||
| use sqlx::prelude::*; | ||||
|  | ||||
| use sqlx::MySqlConnection; | ||||
| use std::collections::BTreeMap; | ||||
|  | ||||
| // NOTE: this function is unsafe because it does no input validation. | ||||
| pub(super) async fn unsafe_database_exists( | ||||
|     database_name: &str, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> Result<bool, sqlx::Error> { | ||||
|     let result = | ||||
|         sqlx::query("SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?") | ||||
|             .bind(database_name) | ||||
|             .fetch_optional(connection) | ||||
|             .await?; | ||||
|  | ||||
|     Ok(result.is_some()) | ||||
| } | ||||
|  | ||||
| pub async fn create_databases( | ||||
|     database_names: Vec<String>, | ||||
|     unix_user: &UnixUser, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> CreateDatabasesOutput { | ||||
|     let mut results = BTreeMap::new(); | ||||
|  | ||||
|     for database_name in database_names { | ||||
|         if let Err(err) = validate_name(&database_name) { | ||||
|             results.insert( | ||||
|                 database_name.clone(), | ||||
|                 Err(CreateDatabaseError::SanitizationError(err)), | ||||
|             ); | ||||
|             continue; | ||||
|         } | ||||
|  | ||||
|         if let Err(err) = validate_ownership_by_unix_user(&database_name, unix_user) { | ||||
|             results.insert( | ||||
|                 database_name.clone(), | ||||
|                 Err(CreateDatabaseError::OwnershipError(err)), | ||||
|             ); | ||||
|             continue; | ||||
|         } | ||||
|  | ||||
|         match unsafe_database_exists(&database_name, &mut *connection).await { | ||||
|             Ok(true) => { | ||||
|                 results.insert( | ||||
|                     database_name.clone(), | ||||
|                     Err(CreateDatabaseError::DatabaseAlreadyExists), | ||||
|                 ); | ||||
|                 continue; | ||||
|             } | ||||
|             Err(err) => { | ||||
|                 results.insert( | ||||
|                     database_name.clone(), | ||||
|                     Err(CreateDatabaseError::MySqlError(err.to_string())), | ||||
|                 ); | ||||
|                 continue; | ||||
|             } | ||||
|             _ => {} | ||||
|         } | ||||
|  | ||||
|         let result = | ||||
|             sqlx::query(format!("CREATE DATABASE {}", quote_identifier(&database_name)).as_str()) | ||||
|                 .execute(&mut *connection) | ||||
|                 .await | ||||
|                 .map(|_| ()) | ||||
|                 .map_err(|err| CreateDatabaseError::MySqlError(err.to_string())); | ||||
|  | ||||
|         results.insert(database_name, result); | ||||
|     } | ||||
|  | ||||
|     results | ||||
| } | ||||
|  | ||||
| pub async fn drop_databases( | ||||
|     database_names: Vec<String>, | ||||
|     unix_user: &UnixUser, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> DropDatabasesOutput { | ||||
|     let mut results = BTreeMap::new(); | ||||
|  | ||||
|     for database_name in database_names { | ||||
|         if let Err(err) = validate_name(&database_name) { | ||||
|             results.insert( | ||||
|                 database_name.clone(), | ||||
|                 Err(DropDatabaseError::SanitizationError(err)), | ||||
|             ); | ||||
|             continue; | ||||
|         } | ||||
|  | ||||
|         if let Err(err) = validate_ownership_by_unix_user(&database_name, unix_user) { | ||||
|             results.insert( | ||||
|                 database_name.clone(), | ||||
|                 Err(DropDatabaseError::OwnershipError(err)), | ||||
|             ); | ||||
|             continue; | ||||
|         } | ||||
|  | ||||
|         match unsafe_database_exists(&database_name, &mut *connection).await { | ||||
|             Ok(false) => { | ||||
|                 results.insert( | ||||
|                     database_name.clone(), | ||||
|                     Err(DropDatabaseError::DatabaseDoesNotExist), | ||||
|                 ); | ||||
|                 continue; | ||||
|             } | ||||
|             Err(err) => { | ||||
|                 results.insert( | ||||
|                     database_name.clone(), | ||||
|                     Err(DropDatabaseError::MySqlError(err.to_string())), | ||||
|                 ); | ||||
|                 continue; | ||||
|             } | ||||
|             _ => {} | ||||
|         } | ||||
|  | ||||
|         let result = | ||||
|             sqlx::query(format!("DROP DATABASE {}", quote_identifier(&database_name)).as_str()) | ||||
|                 .execute(&mut *connection) | ||||
|                 .await | ||||
|                 .map(|_| ()) | ||||
|                 .map_err(|err| DropDatabaseError::MySqlError(err.to_string())); | ||||
|  | ||||
|         results.insert(database_name, result); | ||||
|     } | ||||
|  | ||||
|     results | ||||
| } | ||||
|  | ||||
| pub async fn list_databases_for_user( | ||||
|     unix_user: &UnixUser, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> Result<Vec<String>, ListDatabasesError> { | ||||
|     sqlx::query( | ||||
|         r#" | ||||
|           SELECT `SCHEMA_NAME` AS `database` | ||||
|           FROM `information_schema`.`SCHEMATA` | ||||
|           WHERE `SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys') | ||||
|             AND `SCHEMA_NAME` REGEXP ? | ||||
|         "#, | ||||
|     ) | ||||
|     .bind(create_user_group_matching_regex(unix_user)) | ||||
|     .fetch_all(connection) | ||||
|     .await | ||||
|     .and_then(|rows| { | ||||
|         rows.into_iter() | ||||
|             .map(|row| row.try_get::<String, _>("database")) | ||||
|             .collect::<Result<Vec<String>, sqlx::Error>>() | ||||
|     }) | ||||
|     .map_err(|err| ListDatabasesError::MySqlError(err.to_string())) | ||||
| } | ||||
							
								
								
									
										452
									
								
								src/server/sql/database_privilege_operations.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										452
									
								
								src/server/sql/database_privilege_operations.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,452 @@ | ||||
| // TODO: fix comment | ||||
| //! Database privilege operations | ||||
| //! | ||||
| //! This module contains functions for querying, modifying, | ||||
| //! displaying and comparing database privileges. | ||||
| //! | ||||
| //! A lot of the complexity comes from two core components: | ||||
| //! | ||||
| //! - The privilege editor that needs to be able to print | ||||
| //!   an editable table of privileges and reparse the content | ||||
| //!   after the user has made manual changes. | ||||
| //! | ||||
| //! - The comparison functionality that tells the user what | ||||
| //!   changes will be made when applying a set of changes | ||||
| //!   to the list of database privileges. | ||||
|  | ||||
| use std::collections::{BTreeMap, BTreeSet}; | ||||
|  | ||||
| use indoc::indoc; | ||||
| use itertools::Itertools; | ||||
| use serde::{Deserialize, Serialize}; | ||||
| use sqlx::{mysql::MySqlRow, prelude::*, MySqlConnection}; | ||||
|  | ||||
| use crate::{ | ||||
|     core::{ | ||||
|         common::{rev_yn, yn, UnixUser}, | ||||
|         database_privileges::{DatabasePrivilegeChange, DatabasePrivilegesDiff}, | ||||
|         protocol::{ | ||||
|             DiffDoesNotApplyError, GetAllDatabasesPrivilegeData, GetAllDatabasesPrivilegeDataError, | ||||
|             GetDatabasesPrivilegeData, GetDatabasesPrivilegeDataError, | ||||
|             ModifyDatabasePrivilegesError, ModifyDatabasePrivilegesOutput, | ||||
|         }, | ||||
|     }, | ||||
|     server::{ | ||||
|         common::create_user_group_matching_regex, | ||||
|         input_sanitization::{quote_identifier, validate_name, validate_ownership_by_unix_user}, | ||||
|         sql::database_operations::unsafe_database_exists, | ||||
|     }, | ||||
| }; | ||||
|  | ||||
| /// This is the list of fields that are used to fetch the db + user + privileges | ||||
| /// from the `db` table in the database. If you need to add or remove privilege | ||||
| /// fields, this is a good place to start. | ||||
| pub const DATABASE_PRIVILEGE_FIELDS: [&str; 13] = [ | ||||
|     "db", | ||||
|     "user", | ||||
|     "select_priv", | ||||
|     "insert_priv", | ||||
|     "update_priv", | ||||
|     "delete_priv", | ||||
|     "create_priv", | ||||
|     "drop_priv", | ||||
|     "alter_priv", | ||||
|     "index_priv", | ||||
|     "create_tmp_table_priv", | ||||
|     "lock_tables_priv", | ||||
|     "references_priv", | ||||
| ]; | ||||
|  | ||||
| // NOTE: ord is needed for BTreeSet to accept the type, but it | ||||
| //       doesn't have any natural implementation semantics. | ||||
|  | ||||
| /// This struct represents the set of privileges for a single user on a single database. | ||||
| #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] | ||||
| pub struct DatabasePrivilegeRow { | ||||
|     pub db: String, | ||||
|     pub user: String, | ||||
|     pub select_priv: bool, | ||||
|     pub insert_priv: bool, | ||||
|     pub update_priv: bool, | ||||
|     pub delete_priv: bool, | ||||
|     pub create_priv: bool, | ||||
|     pub drop_priv: bool, | ||||
|     pub alter_priv: bool, | ||||
|     pub index_priv: bool, | ||||
|     pub create_tmp_table_priv: bool, | ||||
|     pub lock_tables_priv: bool, | ||||
|     pub references_priv: bool, | ||||
| } | ||||
|  | ||||
| impl DatabasePrivilegeRow { | ||||
|     pub fn get_privilege_by_name(&self, name: &str) -> bool { | ||||
|         match name { | ||||
|             "select_priv" => self.select_priv, | ||||
|             "insert_priv" => self.insert_priv, | ||||
|             "update_priv" => self.update_priv, | ||||
|             "delete_priv" => self.delete_priv, | ||||
|             "create_priv" => self.create_priv, | ||||
|             "drop_priv" => self.drop_priv, | ||||
|             "alter_priv" => self.alter_priv, | ||||
|             "index_priv" => self.index_priv, | ||||
|             "create_tmp_table_priv" => self.create_tmp_table_priv, | ||||
|             "lock_tables_priv" => self.lock_tables_priv, | ||||
|             "references_priv" => self.references_priv, | ||||
|             _ => false, | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[inline] | ||||
| fn get_mysql_row_priv_field(row: &MySqlRow, position: usize) -> Result<bool, sqlx::Error> { | ||||
|     let field = DATABASE_PRIVILEGE_FIELDS[position]; | ||||
|     let value = row.try_get(position)?; | ||||
|     match rev_yn(value) { | ||||
|         Some(val) => Ok(val), | ||||
|         _ => { | ||||
|             log::warn!(r#"Invalid value for privilege "{}": '{}'"#, field, value); | ||||
|             Ok(false) | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl FromRow<'_, MySqlRow> for DatabasePrivilegeRow { | ||||
|     fn from_row(row: &MySqlRow) -> Result<Self, sqlx::Error> { | ||||
|         Ok(Self { | ||||
|             db: row.try_get("db")?, | ||||
|             user: row.try_get("user")?, | ||||
|             select_priv: get_mysql_row_priv_field(row, 2)?, | ||||
|             insert_priv: get_mysql_row_priv_field(row, 3)?, | ||||
|             update_priv: get_mysql_row_priv_field(row, 4)?, | ||||
|             delete_priv: get_mysql_row_priv_field(row, 5)?, | ||||
|             create_priv: get_mysql_row_priv_field(row, 6)?, | ||||
|             drop_priv: get_mysql_row_priv_field(row, 7)?, | ||||
|             alter_priv: get_mysql_row_priv_field(row, 8)?, | ||||
|             index_priv: get_mysql_row_priv_field(row, 9)?, | ||||
|             create_tmp_table_priv: get_mysql_row_priv_field(row, 10)?, | ||||
|             lock_tables_priv: get_mysql_row_priv_field(row, 11)?, | ||||
|             references_priv: get_mysql_row_priv_field(row, 12)?, | ||||
|         }) | ||||
|     } | ||||
| } | ||||
|  | ||||
| // NOTE: this function is unsafe because it does no input validation. | ||||
| /// Get all users + privileges for a single database. | ||||
| async fn unsafe_get_database_privileges( | ||||
|     database_name: &str, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> Result<Vec<DatabasePrivilegeRow>, sqlx::Error> { | ||||
|     sqlx::query_as::<_, DatabasePrivilegeRow>(&format!( | ||||
|         "SELECT {} FROM `db` WHERE `db` = ?", | ||||
|         DATABASE_PRIVILEGE_FIELDS | ||||
|             .iter() | ||||
|             .map(|field| quote_identifier(field)) | ||||
|             .join(","), | ||||
|     )) | ||||
|     .bind(database_name) | ||||
|     .fetch_all(connection) | ||||
|     .await | ||||
| } | ||||
|  | ||||
| // NOTE: this function is unsafe because it does no input validation. | ||||
| /// Get all users + privileges for a single database-user pair. | ||||
| pub async fn unsafe_get_database_privileges_for_db_user_pair( | ||||
|     database_name: &str, | ||||
|     user_name: &str, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> Result<Option<DatabasePrivilegeRow>, sqlx::Error> { | ||||
|     sqlx::query_as::<_, DatabasePrivilegeRow>(&format!( | ||||
|         "SELECT {} FROM `db` WHERE `db` = ? AND `user` = ?", | ||||
|         DATABASE_PRIVILEGE_FIELDS | ||||
|             .iter() | ||||
|             .map(|field| quote_identifier(field)) | ||||
|             .join(","), | ||||
|     )) | ||||
|     .bind(database_name) | ||||
|     .bind(user_name) | ||||
|     .fetch_optional(connection) | ||||
|     .await | ||||
| } | ||||
|  | ||||
| pub async fn get_databases_privilege_data( | ||||
|     database_names: Vec<String>, | ||||
|     unix_user: &UnixUser, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> GetDatabasesPrivilegeData { | ||||
|     let mut results = BTreeMap::new(); | ||||
|  | ||||
|     for database_name in database_names.iter() { | ||||
|         if let Err(err) = validate_name(database_name) { | ||||
|             results.insert( | ||||
|                 database_name.clone(), | ||||
|                 Err(GetDatabasesPrivilegeDataError::SanitizationError(err)), | ||||
|             ); | ||||
|             continue; | ||||
|         } | ||||
|  | ||||
|         if let Err(err) = validate_ownership_by_unix_user(database_name, unix_user) { | ||||
|             results.insert( | ||||
|                 database_name.clone(), | ||||
|                 Err(GetDatabasesPrivilegeDataError::OwnershipError(err)), | ||||
|             ); | ||||
|             continue; | ||||
|         } | ||||
|  | ||||
|         if !unsafe_database_exists(database_name, connection) | ||||
|             .await | ||||
|             .unwrap() | ||||
|         { | ||||
|             results.insert( | ||||
|                 database_name.clone(), | ||||
|                 Err(GetDatabasesPrivilegeDataError::DatabaseDoesNotExist), | ||||
|             ); | ||||
|             continue; | ||||
|         } | ||||
|  | ||||
|         let result = unsafe_get_database_privileges(database_name, connection) | ||||
|             .await | ||||
|             .map_err(|e| GetDatabasesPrivilegeDataError::MySqlError(e.to_string())); | ||||
|  | ||||
|         results.insert(database_name.clone(), result); | ||||
|     } | ||||
|  | ||||
|     debug_assert!(database_names.len() == results.len()); | ||||
|  | ||||
|     results | ||||
| } | ||||
|  | ||||
| /// Get all database + user + privileges pairs that are owned by the current user. | ||||
| pub async fn get_all_database_privileges( | ||||
|     unix_user: &UnixUser, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> GetAllDatabasesPrivilegeData { | ||||
|     sqlx::query_as::<_, DatabasePrivilegeRow>(&format!( | ||||
|         indoc! {r#" | ||||
|           SELECT {} FROM `db` WHERE `db` IN | ||||
|           (SELECT DISTINCT `SCHEMA_NAME` AS `database` | ||||
|             FROM `information_schema`.`SCHEMATA` | ||||
|             WHERE `SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys') | ||||
|               AND `SCHEMA_NAME` REGEXP ?) | ||||
|         "#}, | ||||
|         DATABASE_PRIVILEGE_FIELDS | ||||
|             .iter() | ||||
|             .map(|field| quote_identifier(field)) | ||||
|             .join(","), | ||||
|     )) | ||||
|     .bind(create_user_group_matching_regex(unix_user)) | ||||
|     .fetch_all(connection) | ||||
|     .await | ||||
|     .map_err(|e| GetAllDatabasesPrivilegeDataError::MySqlError(e.to_string())) | ||||
| } | ||||
|  | ||||
| async fn unsafe_apply_privilege_diff( | ||||
|     database_privilege_diff: &DatabasePrivilegesDiff, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> Result<(), sqlx::Error> { | ||||
|     match database_privilege_diff { | ||||
|         DatabasePrivilegesDiff::New(p) => { | ||||
|             let tables = DATABASE_PRIVILEGE_FIELDS | ||||
|                 .iter() | ||||
|                 .map(|field| quote_identifier(field)) | ||||
|                 .join(","); | ||||
|  | ||||
|             let question_marks = std::iter::repeat("?") | ||||
|                 .take(DATABASE_PRIVILEGE_FIELDS.len()) | ||||
|                 .join(","); | ||||
|  | ||||
|             sqlx::query( | ||||
|                 format!("INSERT INTO `db` ({}) VALUES ({})", tables, question_marks).as_str(), | ||||
|             ) | ||||
|             .bind(p.db.to_string()) | ||||
|             .bind(p.user.to_string()) | ||||
|             .bind(yn(p.select_priv)) | ||||
|             .bind(yn(p.insert_priv)) | ||||
|             .bind(yn(p.update_priv)) | ||||
|             .bind(yn(p.delete_priv)) | ||||
|             .bind(yn(p.create_priv)) | ||||
|             .bind(yn(p.drop_priv)) | ||||
|             .bind(yn(p.alter_priv)) | ||||
|             .bind(yn(p.index_priv)) | ||||
|             .bind(yn(p.create_tmp_table_priv)) | ||||
|             .bind(yn(p.lock_tables_priv)) | ||||
|             .bind(yn(p.references_priv)) | ||||
|             .execute(connection) | ||||
|             .await | ||||
|             .map(|_| ()) | ||||
|         } | ||||
|         DatabasePrivilegesDiff::Modified(p) => { | ||||
|             let changes = p | ||||
|                 .diff | ||||
|                 .iter() | ||||
|                 .map(|diff| match diff { | ||||
|                     DatabasePrivilegeChange::YesToNo(name) => { | ||||
|                         format!("{} = 'N'", quote_identifier(name)) | ||||
|                     } | ||||
|                     DatabasePrivilegeChange::NoToYes(name) => { | ||||
|                         format!("{} = 'Y'", quote_identifier(name)) | ||||
|                     } | ||||
|                 }) | ||||
|                 .join(","); | ||||
|  | ||||
|             sqlx::query( | ||||
|                 format!("UPDATE `db` SET {} WHERE `db` = ? AND `user` = ?", changes).as_str(), | ||||
|             ) | ||||
|             .bind(p.db.to_string()) | ||||
|             .bind(p.user.to_string()) | ||||
|             .execute(connection) | ||||
|             .await | ||||
|             .map(|_| ()) | ||||
|         } | ||||
|         DatabasePrivilegesDiff::Deleted(p) => { | ||||
|             sqlx::query("DELETE FROM `db` WHERE `db` = ? AND `user` = ?") | ||||
|                 .bind(p.db.to_string()) | ||||
|                 .bind(p.user.to_string()) | ||||
|                 .execute(connection) | ||||
|                 .await | ||||
|                 .map(|_| ()) | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| async fn validate_diff( | ||||
|     diff: &DatabasePrivilegesDiff, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> Result<(), ModifyDatabasePrivilegesError> { | ||||
|     let privilege_row = unsafe_get_database_privileges_for_db_user_pair( | ||||
|         diff.get_database_name(), | ||||
|         diff.get_user_name(), | ||||
|         connection, | ||||
|     ) | ||||
|     .await; | ||||
|  | ||||
|     let privilege_row = match privilege_row { | ||||
|         Ok(privilege_row) => privilege_row, | ||||
|         Err(e) => return Err(ModifyDatabasePrivilegesError::MySqlError(e.to_string())), | ||||
|     }; | ||||
|  | ||||
|     let result = match diff { | ||||
|         DatabasePrivilegesDiff::New(_) => { | ||||
|             if privilege_row.is_some() { | ||||
|                 Err(ModifyDatabasePrivilegesError::DiffDoesNotApply( | ||||
|                     DiffDoesNotApplyError::RowAlreadyExists( | ||||
|                         diff.get_user_name().to_string(), | ||||
|                         diff.get_database_name().to_string(), | ||||
|                     ), | ||||
|                 )) | ||||
|             } else { | ||||
|                 Ok(()) | ||||
|             } | ||||
|         } | ||||
|         DatabasePrivilegesDiff::Modified(_) if privilege_row.is_none() => { | ||||
|             Err(ModifyDatabasePrivilegesError::DiffDoesNotApply( | ||||
|                 DiffDoesNotApplyError::RowDoesNotExist( | ||||
|                     diff.get_user_name().to_string(), | ||||
|                     diff.get_database_name().to_string(), | ||||
|                 ), | ||||
|             )) | ||||
|         } | ||||
|         DatabasePrivilegesDiff::Modified(row_diff) => { | ||||
|             let row = privilege_row.unwrap(); | ||||
|  | ||||
|             let error_exists = row_diff.diff.iter().any(|change| match change { | ||||
|                 DatabasePrivilegeChange::YesToNo(name) => !row.get_privilege_by_name(name), | ||||
|                 DatabasePrivilegeChange::NoToYes(name) => row.get_privilege_by_name(name), | ||||
|             }); | ||||
|  | ||||
|             if error_exists { | ||||
|                 Err(ModifyDatabasePrivilegesError::DiffDoesNotApply( | ||||
|                     DiffDoesNotApplyError::RowPrivilegeChangeDoesNotApply(row_diff.clone(), row), | ||||
|                 )) | ||||
|             } else { | ||||
|                 Ok(()) | ||||
|             } | ||||
|         } | ||||
|         DatabasePrivilegesDiff::Deleted(_) => { | ||||
|             if privilege_row.is_none() { | ||||
|                 Err(ModifyDatabasePrivilegesError::DiffDoesNotApply( | ||||
|                     DiffDoesNotApplyError::RowDoesNotExist( | ||||
|                         diff.get_user_name().to_string(), | ||||
|                         diff.get_database_name().to_string(), | ||||
|                     ), | ||||
|                 )) | ||||
|             } else { | ||||
|                 Ok(()) | ||||
|             } | ||||
|         } | ||||
|     }; | ||||
|  | ||||
|     result | ||||
| } | ||||
|  | ||||
| /// Uses the result of [`diff_privileges`] to modify privileges in the database. | ||||
| pub async fn apply_privilege_diffs( | ||||
|     database_privilege_diffs: BTreeSet<DatabasePrivilegesDiff>, | ||||
|     unix_user: &UnixUser, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> ModifyDatabasePrivilegesOutput { | ||||
|     let mut results: BTreeMap<(String, String), _> = BTreeMap::new(); | ||||
|  | ||||
|     for diff in database_privilege_diffs { | ||||
|         let key = ( | ||||
|             diff.get_database_name().to_string(), | ||||
|             diff.get_user_name().to_string(), | ||||
|         ); | ||||
|         if let Err(err) = validate_name(diff.get_database_name()) { | ||||
|             results.insert( | ||||
|                 key, | ||||
|                 Err(ModifyDatabasePrivilegesError::DatabaseSanitizationError( | ||||
|                     err, | ||||
|                 )), | ||||
|             ); | ||||
|             continue; | ||||
|         } | ||||
|  | ||||
|         if let Err(err) = validate_ownership_by_unix_user(diff.get_database_name(), unix_user) { | ||||
|             results.insert( | ||||
|                 key, | ||||
|                 Err(ModifyDatabasePrivilegesError::DatabaseOwnershipError(err)), | ||||
|             ); | ||||
|             continue; | ||||
|         } | ||||
|  | ||||
|         if let Err(err) = validate_name(diff.get_user_name()) { | ||||
|             results.insert( | ||||
|                 key, | ||||
|                 Err(ModifyDatabasePrivilegesError::UserSanitizationError(err)), | ||||
|             ); | ||||
|             continue; | ||||
|         } | ||||
|  | ||||
|         if let Err(err) = validate_ownership_by_unix_user(diff.get_user_name(), unix_user) { | ||||
|             results.insert( | ||||
|                 key, | ||||
|                 Err(ModifyDatabasePrivilegesError::UserOwnershipError(err)), | ||||
|             ); | ||||
|             continue; | ||||
|         } | ||||
|  | ||||
|         if !unsafe_database_exists(diff.get_database_name(), connection) | ||||
|             .await | ||||
|             .unwrap() | ||||
|         { | ||||
|             results.insert( | ||||
|                 key, | ||||
|                 Err(ModifyDatabasePrivilegesError::DatabaseDoesNotExist), | ||||
|             ); | ||||
|             continue; | ||||
|         } | ||||
|  | ||||
|         if let Err(err) = validate_diff(&diff, connection).await { | ||||
|             results.insert(key, Err(err)); | ||||
|             continue; | ||||
|         } | ||||
|  | ||||
|         let result = unsafe_apply_privilege_diff(&diff, connection) | ||||
|             .await | ||||
|             .map_err(|e| ModifyDatabasePrivilegesError::MySqlError(e.to_string())); | ||||
|  | ||||
|         results.insert(key, result); | ||||
|     } | ||||
|  | ||||
|     results | ||||
| } | ||||
							
								
								
									
										375
									
								
								src/server/sql/user_operations.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										375
									
								
								src/server/sql/user_operations.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,375 @@ | ||||
| use std::collections::BTreeMap; | ||||
|  | ||||
| use serde::{Deserialize, Serialize}; | ||||
|  | ||||
| use sqlx::prelude::*; | ||||
| use sqlx::MySqlConnection; | ||||
|  | ||||
| use crate::{ | ||||
|     core::{ | ||||
|         common::UnixUser, | ||||
|         protocol::{ | ||||
|             CreateUserError, CreateUsersOutput, DropUserError, DropUsersOutput, ListAllUsersError, | ||||
|             ListAllUsersOutput, ListUsersError, ListUsersOutput, LockUserError, LockUsersOutput, | ||||
|             SetPasswordError, SetPasswordOutput, UnlockUserError, UnlockUsersOutput, | ||||
|         }, | ||||
|     }, | ||||
|     server::{ | ||||
|         common::create_user_group_matching_regex, | ||||
|         input_sanitization::{quote_literal, validate_name, validate_ownership_by_unix_user}, | ||||
|     }, | ||||
| }; | ||||
|  | ||||
| // NOTE: this function is unsafe because it does no input validation. | ||||
| async fn unsafe_user_exists( | ||||
|     db_user: &str, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> Result<bool, sqlx::Error> { | ||||
|     sqlx::query( | ||||
|         r#" | ||||
|           SELECT EXISTS( | ||||
|             SELECT 1 | ||||
|             FROM `mysql`.`user` | ||||
|             WHERE `User` = ? | ||||
|           ) | ||||
|         "#, | ||||
|     ) | ||||
|     .bind(db_user) | ||||
|     .fetch_one(connection) | ||||
|     .await | ||||
|     .map(|row| row.get::<bool, _>(0)) | ||||
| } | ||||
|  | ||||
| pub async fn create_database_users( | ||||
|     db_users: Vec<String>, | ||||
|     unix_user: &UnixUser, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> CreateUsersOutput { | ||||
|     let mut results = BTreeMap::new(); | ||||
|  | ||||
|     for db_user in db_users { | ||||
|         if let Err(err) = validate_name(&db_user) { | ||||
|             results.insert(db_user, Err(CreateUserError::SanitizationError(err))); | ||||
|             continue; | ||||
|         } | ||||
|  | ||||
|         if let Err(err) = validate_ownership_by_unix_user(&db_user, unix_user) { | ||||
|             results.insert(db_user, Err(CreateUserError::OwnershipError(err))); | ||||
|             continue; | ||||
|         } | ||||
|  | ||||
|         match unsafe_user_exists(&db_user, &mut *connection).await { | ||||
|             Ok(true) => { | ||||
|                 results.insert(db_user, Err(CreateUserError::UserAlreadyExists)); | ||||
|                 continue; | ||||
|             } | ||||
|             Err(err) => { | ||||
|                 results.insert(db_user, Err(CreateUserError::MySqlError(err.to_string()))); | ||||
|                 continue; | ||||
|             } | ||||
|             _ => {} | ||||
|         } | ||||
|  | ||||
|         let result = sqlx::query(format!("CREATE USER {}@'%'", quote_literal(&db_user),).as_str()) | ||||
|             .execute(&mut *connection) | ||||
|             .await | ||||
|             .map(|_| ()) | ||||
|             .map_err(|err| CreateUserError::MySqlError(err.to_string())); | ||||
|  | ||||
|         results.insert(db_user, result); | ||||
|     } | ||||
|  | ||||
|     results | ||||
| } | ||||
|  | ||||
| pub async fn drop_database_users( | ||||
|     db_users: Vec<String>, | ||||
|     unix_user: &UnixUser, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> DropUsersOutput { | ||||
|     let mut results = BTreeMap::new(); | ||||
|  | ||||
|     for db_user in db_users { | ||||
|         if let Err(err) = validate_name(&db_user) { | ||||
|             results.insert(db_user, Err(DropUserError::SanitizationError(err))); | ||||
|             continue; | ||||
|         } | ||||
|  | ||||
|         if let Err(err) = validate_ownership_by_unix_user(&db_user, unix_user) { | ||||
|             results.insert(db_user, Err(DropUserError::OwnershipError(err))); | ||||
|             continue; | ||||
|         } | ||||
|  | ||||
|         match unsafe_user_exists(&db_user, &mut *connection).await { | ||||
|             Ok(false) => { | ||||
|                 results.insert(db_user, Err(DropUserError::UserDoesNotExist)); | ||||
|                 continue; | ||||
|             } | ||||
|             Err(err) => { | ||||
|                 results.insert(db_user, Err(DropUserError::MySqlError(err.to_string()))); | ||||
|                 continue; | ||||
|             } | ||||
|             _ => {} | ||||
|         } | ||||
|  | ||||
|         let result = sqlx::query(format!("DROP USER {}@'%'", quote_literal(&db_user),).as_str()) | ||||
|             .execute(&mut *connection) | ||||
|             .await | ||||
|             .map(|_| ()) | ||||
|             .map_err(|err| DropUserError::MySqlError(err.to_string())); | ||||
|  | ||||
|         results.insert(db_user, result); | ||||
|     } | ||||
|  | ||||
|     results | ||||
| } | ||||
|  | ||||
| pub async fn set_password_for_database_user( | ||||
|     db_user: &str, | ||||
|     password: &str, | ||||
|     unix_user: &UnixUser, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> SetPasswordOutput { | ||||
|     if let Err(err) = validate_name(db_user) { | ||||
|         return Err(SetPasswordError::SanitizationError(err)); | ||||
|     } | ||||
|  | ||||
|     if let Err(err) = validate_ownership_by_unix_user(db_user, unix_user) { | ||||
|         return Err(SetPasswordError::OwnershipError(err)); | ||||
|     } | ||||
|  | ||||
|     match unsafe_user_exists(db_user, &mut *connection).await { | ||||
|         Ok(false) => return Err(SetPasswordError::UserDoesNotExist), | ||||
|         Err(err) => return Err(SetPasswordError::MySqlError(err.to_string())), | ||||
|         _ => {} | ||||
|     } | ||||
|  | ||||
|     sqlx::query( | ||||
|         format!( | ||||
|             "ALTER USER {}@'%' IDENTIFIED BY {}", | ||||
|             quote_literal(db_user), | ||||
|             quote_literal(password).as_str() | ||||
|         ) | ||||
|         .as_str(), | ||||
|     ) | ||||
|     .execute(&mut *connection) | ||||
|     .await | ||||
|     .map(|_| ()) | ||||
|     .map_err(|err| SetPasswordError::MySqlError(err.to_string())) | ||||
| } | ||||
|  | ||||
| // NOTE: this function is unsafe because it does no input validation. | ||||
| async fn database_user_is_locked_unsafe( | ||||
|     db_user: &str, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> Result<bool, sqlx::Error> { | ||||
|     sqlx::query( | ||||
|         r#" | ||||
|           SELECT COALESCE( | ||||
|             JSON_EXTRACT(`mysql`.`global_priv`.`priv`, "$.account_locked"), | ||||
|             'false' | ||||
|           ) != 'false' | ||||
|           FROM `mysql`.`global_priv` | ||||
|           WHERE `User` = ? | ||||
|           AND `Host` = '%' | ||||
|         "#, | ||||
|     ) | ||||
|     .bind(db_user) | ||||
|     .fetch_one(connection) | ||||
|     .await | ||||
|     .map(|row| row.get::<bool, _>(0)) | ||||
| } | ||||
|  | ||||
| pub async fn lock_database_users( | ||||
|     db_users: Vec<String>, | ||||
|     unix_user: &UnixUser, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> LockUsersOutput { | ||||
|     let mut results = BTreeMap::new(); | ||||
|  | ||||
|     for db_user in db_users { | ||||
|         if let Err(err) = validate_name(&db_user) { | ||||
|             results.insert(db_user, Err(LockUserError::SanitizationError(err))); | ||||
|             continue; | ||||
|         } | ||||
|  | ||||
|         if let Err(err) = validate_ownership_by_unix_user(&db_user, unix_user) { | ||||
|             results.insert(db_user, Err(LockUserError::OwnershipError(err))); | ||||
|             continue; | ||||
|         } | ||||
|  | ||||
|         match unsafe_user_exists(&db_user, &mut *connection).await { | ||||
|             Ok(true) => {} | ||||
|             Ok(false) => { | ||||
|                 results.insert(db_user, Err(LockUserError::UserDoesNotExist)); | ||||
|                 continue; | ||||
|             } | ||||
|             Err(err) => { | ||||
|                 results.insert(db_user, Err(LockUserError::MySqlError(err.to_string()))); | ||||
|                 continue; | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         match database_user_is_locked_unsafe(&db_user, &mut *connection).await { | ||||
|             Ok(false) => {} | ||||
|             Ok(true) => { | ||||
|                 results.insert(db_user, Err(LockUserError::UserIsAlreadyLocked)); | ||||
|                 continue; | ||||
|             } | ||||
|             Err(err) => { | ||||
|                 results.insert(db_user, Err(LockUserError::MySqlError(err.to_string()))); | ||||
|                 continue; | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         let result = sqlx::query( | ||||
|             format!("ALTER USER {}@'%' ACCOUNT LOCK", quote_literal(&db_user),).as_str(), | ||||
|         ) | ||||
|         .execute(&mut *connection) | ||||
|         .await | ||||
|         .map(|_| ()) | ||||
|         .map_err(|err| LockUserError::MySqlError(err.to_string())); | ||||
|  | ||||
|         results.insert(db_user, result); | ||||
|     } | ||||
|  | ||||
|     results | ||||
| } | ||||
|  | ||||
| pub async fn unlock_database_users( | ||||
|     db_users: Vec<String>, | ||||
|     unix_user: &UnixUser, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> UnlockUsersOutput { | ||||
|     let mut results = BTreeMap::new(); | ||||
|  | ||||
|     for db_user in db_users { | ||||
|         if let Err(err) = validate_name(&db_user) { | ||||
|             results.insert(db_user, Err(UnlockUserError::SanitizationError(err))); | ||||
|             continue; | ||||
|         } | ||||
|  | ||||
|         if let Err(err) = validate_ownership_by_unix_user(&db_user, unix_user) { | ||||
|             results.insert(db_user, Err(UnlockUserError::OwnershipError(err))); | ||||
|             continue; | ||||
|         } | ||||
|  | ||||
|         match unsafe_user_exists(&db_user, &mut *connection).await { | ||||
|             Ok(false) => { | ||||
|                 results.insert(db_user, Err(UnlockUserError::UserDoesNotExist)); | ||||
|                 continue; | ||||
|             } | ||||
|             Err(err) => { | ||||
|                 results.insert(db_user, Err(UnlockUserError::MySqlError(err.to_string()))); | ||||
|                 continue; | ||||
|             } | ||||
|             _ => {} | ||||
|         } | ||||
|  | ||||
|         match database_user_is_locked_unsafe(&db_user, &mut *connection).await { | ||||
|             Ok(false) => { | ||||
|                 results.insert(db_user, Err(UnlockUserError::UserIsAlreadyUnlocked)); | ||||
|                 continue; | ||||
|             } | ||||
|             Err(err) => { | ||||
|                 results.insert(db_user, Err(UnlockUserError::MySqlError(err.to_string()))); | ||||
|                 continue; | ||||
|             } | ||||
|             _ => {} | ||||
|         } | ||||
|  | ||||
|         let result = sqlx::query( | ||||
|             format!("ALTER USER {}@'%' ACCOUNT UNLOCK", quote_literal(&db_user),).as_str(), | ||||
|         ) | ||||
|         .execute(&mut *connection) | ||||
|         .await | ||||
|         .map(|_| ()) | ||||
|         .map_err(|err| UnlockUserError::MySqlError(err.to_string())); | ||||
|  | ||||
|         results.insert(db_user, result); | ||||
|     } | ||||
|  | ||||
|     results | ||||
| } | ||||
|  | ||||
| /// This struct contains information about a database user. | ||||
| /// This can be extended if we need more information in the future. | ||||
| #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] | ||||
| pub struct DatabaseUser { | ||||
|     #[sqlx(rename = "User")] | ||||
|     pub user: String, | ||||
|  | ||||
|     #[allow(dead_code)] | ||||
|     #[serde(skip)] | ||||
|     #[sqlx(rename = "Host")] | ||||
|     pub host: String, | ||||
|  | ||||
|     #[sqlx(rename = "has_password")] | ||||
|     pub has_password: bool, | ||||
|  | ||||
|     #[sqlx(rename = "is_locked")] | ||||
|     pub is_locked: bool, | ||||
| } | ||||
|  | ||||
| const DB_USER_SELECT_STATEMENT: &str = r#" | ||||
| SELECT | ||||
|   `mysql`.`user`.`User`, | ||||
|   `mysql`.`user`.`Host`, | ||||
|   `mysql`.`user`.`Password` != '' OR `mysql`.`user`.`authentication_string` != '' AS `has_password`, | ||||
|   COALESCE( | ||||
|     JSON_EXTRACT(`mysql`.`global_priv`.`priv`, "$.account_locked"), | ||||
|     'false' | ||||
|   ) != 'false' AS `is_locked` | ||||
| FROM `mysql`.`user` | ||||
| JOIN `mysql`.`global_priv` ON | ||||
|   `mysql`.`user`.`User` = `mysql`.`global_priv`.`User` | ||||
|   AND `mysql`.`user`.`Host` = `mysql`.`global_priv`.`Host` | ||||
| "#; | ||||
|  | ||||
| pub async fn list_database_users( | ||||
|     db_users: Vec<String>, | ||||
|     unix_user: &UnixUser, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> ListUsersOutput { | ||||
|     let mut results = BTreeMap::new(); | ||||
|  | ||||
|     for db_user in db_users { | ||||
|         if let Err(err) = validate_name(&db_user) { | ||||
|             results.insert(db_user, Err(ListUsersError::SanitizationError(err))); | ||||
|             continue; | ||||
|         } | ||||
|  | ||||
|         if let Err(err) = validate_ownership_by_unix_user(&db_user, unix_user) { | ||||
|             results.insert(db_user, Err(ListUsersError::OwnershipError(err))); | ||||
|             continue; | ||||
|         } | ||||
|  | ||||
|         let result = sqlx::query_as::<_, DatabaseUser>( | ||||
|             &(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `mysql`.`user`.`User` = ?"), | ||||
|         ) | ||||
|         .bind(&db_user) | ||||
|         .fetch_optional(&mut *connection) | ||||
|         .await; | ||||
|  | ||||
|         match result { | ||||
|             Ok(Some(user)) => results.insert(db_user, Ok(user)), | ||||
|             Ok(None) => results.insert(db_user, Err(ListUsersError::UserDoesNotExist)), | ||||
|             Err(err) => results.insert(db_user, Err(ListUsersError::MySqlError(err.to_string()))), | ||||
|         }; | ||||
|     } | ||||
|  | ||||
|     results | ||||
| } | ||||
|  | ||||
| pub async fn list_all_database_users_for_unix_user( | ||||
|     unix_user: &UnixUser, | ||||
|     connection: &mut MySqlConnection, | ||||
| ) -> ListAllUsersOutput { | ||||
|     sqlx::query_as::<_, DatabaseUser>( | ||||
|         &(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `mysql`.`user`.`User` REGEXP ?"), | ||||
|     ) | ||||
|     .bind(create_user_group_matching_regex(unix_user)) | ||||
|     .fetch_all(connection) | ||||
|     .await | ||||
|     .map_err(|err| ListAllUsersError::MySqlError(err.to_string())) | ||||
| } | ||||
		Reference in New Issue
	
	Block a user