diff --git a/Cargo.lock b/Cargo.lock index 230629b..6476623 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -418,6 +418,27 @@ dependencies = [ "zeroize", ] +[[package]] +name = "derive_more" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a9b99b9cbbe49445b21764dc0625032a89b145a2642e67603e1c936f5458d05" +dependencies = [ + "derive_more-impl", +] + +[[package]] +name = "derive_more-impl" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb7330aeadfbe296029522e6c40f315320aba36fc43a5b3632f3795348f3bd22" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.60", + "unicode-xid", +] + [[package]] name = "dialoguer" version = "0.11.0" @@ -993,6 +1014,7 @@ dependencies = [ "async-bincode", "bincode", "clap", + "derive_more", "dialoguer", "env_logger", "futures", @@ -2130,6 +2152,12 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e51733f11c9c4f72aa0c160008246859e340b00807569a0da0e7a1079b27ba85" +[[package]] +name = "unicode-xid" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c" + [[package]] name = "unicode_categories" version = "0.1.1" diff --git a/Cargo.toml b/Cargo.toml index 77d154f..f0603d5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,7 @@ anyhow = "1.0.82" async-bincode = "0.7.2" bincode = "1.3.3" clap = { version = "4.5.4", features = ["derive"] } +derive_more = { version = "1.0.0", features = ["display", "error"] } dialoguer = "0.11.0" env_logger = "0.11.3" futures = "0.3.30" diff --git a/src/cli.rs b/src/cli.rs index 7c80cfc..112755e 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -1,3 +1,4 @@ pub mod database_command; -pub mod mysql_admutils_compatibility; +// pub mod mysql_admutils_compatibility; +mod common; pub mod user_command; diff --git a/src/cli/common.rs b/src/cli/common.rs new file mode 100644 index 0000000..91fdd18 --- /dev/null +++ b/src/cli/common.rs @@ -0,0 +1,44 @@ +use crate::server::Response; + +/// This enum is used to differentiate between database and user operations. +/// Their output are very similar, but there are slight differences in the words used. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum DbOrUser { + Database, + User, +} + +impl DbOrUser { + pub fn lowercased(&self) -> String { + match self { + DbOrUser::Database => "database".to_string(), + DbOrUser::User => "user".to_string(), + } + } + + pub fn capitalized(&self) -> String { + match self { + DbOrUser::Database => "Database".to_string(), + DbOrUser::User => "User".to_string(), + } + } +} + +pub fn erroneous_server_response( + response: Option>, +) -> anyhow::Result<()> { + match response { + Some(Ok(Response::Error(e))) => { + anyhow::bail!("Error from server: {}", e); + } + Some(Err(e)) => { + anyhow::bail!(e); + } + Some(response) => { + anyhow::bail!("Unexpected response from server: {:?}", response); + } + None => { + anyhow::bail!("No response from server"); + } + } +} diff --git a/src/cli/database_command.rs b/src/cli/database_command.rs index b41ed8c..35568fa 100644 --- a/src/cli/database_command.rs +++ b/src/cli/database_command.rs @@ -1,17 +1,18 @@ -use anyhow::Context; use clap::Parser; -use dialoguer::{Confirm, Editor}; +use futures_util::{SinkExt, StreamExt}; use prettytable::{Cell, Row, Table}; -use sqlx::{Connection, MySqlConnection}; -use crate::core::{ - common::{close_database_connection, get_current_unix_user, yn, CommandStatus}, - database_operations::*, - database_privilege_operations::*, - user_operations::user_exists, +use crate::{ + core::{common::yn, database_privileges::db_priv_field_human_readable_name}, + server::{ + database_privilege_operations::DATABASE_PRIVILEGE_FIELDS, + protocol::ClientToServerMessageStream, Request, Response, + }, }; -#[derive(Parser)] +use super::common::erroneous_server_response; + +#[derive(Parser, Debug, Clone)] // #[command(next_help_heading = Some(DATABASE_COMMAND_HEADER))] pub enum DatabaseCommand { /// Create one or more databases @@ -86,28 +87,28 @@ pub enum DatabaseCommand { EditDbPrivs(DatabaseEditPrivsArgs), } -#[derive(Parser)] +#[derive(Parser, Debug, Clone)] pub struct DatabaseCreateArgs { /// The name of the database(s) to create. #[arg(num_args = 1..)] name: Vec, } -#[derive(Parser)] +#[derive(Parser, Debug, Clone)] pub struct DatabaseDropArgs { /// The name of the database(s) to drop. #[arg(num_args = 1..)] name: Vec, } -#[derive(Parser)] +#[derive(Parser, Debug, Clone)] pub struct DatabaseListArgs { /// Whether to output the information in JSON format. #[arg(short, long)] json: bool, } -#[derive(Parser)] +#[derive(Parser, Debug, Clone)] pub struct DatabaseShowPrivsArgs { /// The name of the database(s) to show. #[arg(num_args = 0..)] @@ -118,7 +119,7 @@ pub struct DatabaseShowPrivsArgs { json: bool, } -#[derive(Parser)] +#[derive(Parser, Debug, Clone)] pub struct DatabaseEditPrivsArgs { /// The name of the database to edit privileges for. pub name: Option, @@ -141,125 +142,156 @@ pub struct DatabaseEditPrivsArgs { pub async fn handle_command( command: DatabaseCommand, - mut connection: MySqlConnection, -) -> anyhow::Result { - let result = connection - .transaction(|txn| { - Box::pin(async move { - match command { - DatabaseCommand::CreateDb(args) => create_databases(args, txn).await, - DatabaseCommand::DropDb(args) => drop_databases(args, txn).await, - DatabaseCommand::ListDb(args) => list_databases(args, txn).await, - DatabaseCommand::ShowDbPrivs(args) => show_database_privileges(args, txn).await, - DatabaseCommand::EditDbPrivs(args) => edit_privileges(args, txn).await, - } - }) - }) - .await; - - close_database_connection(connection).await; - - result + server_connection: ClientToServerMessageStream, +) -> anyhow::Result<()> { + match command { + DatabaseCommand::CreateDb(args) => create_databases(args, server_connection).await, + DatabaseCommand::DropDb(args) => drop_databases(args, server_connection).await, + DatabaseCommand::ListDb(args) => list_databases(args, server_connection).await, + DatabaseCommand::ShowDbPrivs(args) => { + show_database_privileges(args, server_connection).await + } + DatabaseCommand::EditDbPrivs(_args) => todo!(), + } } async fn create_databases( args: DatabaseCreateArgs, - connection: &mut MySqlConnection, -) -> anyhow::Result { + mut server_connection: ClientToServerMessageStream, +) -> anyhow::Result<()> { if args.name.is_empty() { anyhow::bail!("No database names provided"); } - let mut result = CommandStatus::SuccessfullyModified; + let message = Request::CreateDatabases(args.name.clone()); + server_connection.send(message).await?; - for name in args.name { - // TODO: This can be optimized by fetching all the database privileges in one query. - if let Err(e) = create_database(&name, connection).await { - eprintln!("Failed to create database '{}': {}", name, e); - eprintln!("Skipping..."); - result = CommandStatus::PartiallySuccessfullyModified; - } else { - println!("Database '{}' created.", name); + let result = match server_connection.next().await { + Some(Ok(Response::CreateDatabases(result))) => result, + response => return erroneous_server_response(response), + }; + + server_connection.send(Request::Exit).await?; + + for (database_name, result) in result.iter() { + match result { + Ok(_) => println!("Database '{}' created.", database_name), + Err(err) => { + eprintln!("{:?}", err); + eprintln!("Skipping..."); + } } } - Ok(result) + Ok(()) } async fn drop_databases( args: DatabaseDropArgs, - connection: &mut MySqlConnection, -) -> anyhow::Result { + mut server_connection: ClientToServerMessageStream, +) -> anyhow::Result<()> { if args.name.is_empty() { anyhow::bail!("No database names provided"); } - let mut result = CommandStatus::SuccessfullyModified; + let message = Request::DropDatabases(args.name.clone()); + server_connection.send(message).await?; - for name in args.name { - // TODO: This can be optimized by fetching all the database privileges in one query. - if let Err(e) = drop_database(&name, connection).await { - eprintln!("Failed to drop database '{}': {}", name, e); - eprintln!("Skipping..."); - result = CommandStatus::PartiallySuccessfullyModified; - } else { - println!("Database '{}' dropped.", name); + let result = match server_connection.next().await { + Some(Ok(Response::DropDatabases(result))) => result, + response => return erroneous_server_response(response), + }; + + server_connection.send(Request::Exit).await?; + + for (database_name, result) in result.iter() { + match result { + Ok(_) => println!("Database '{}' dropped.", database_name), + Err(err) => { + eprintln!("{:?}", err); + eprintln!("Skipping..."); + } } } - Ok(result) + Ok(()) } async fn list_databases( args: DatabaseListArgs, - connection: &mut MySqlConnection, -) -> anyhow::Result { - let databases = get_database_list(connection).await?; + mut server_connection: ClientToServerMessageStream, +) -> anyhow::Result<()> { + let message = Request::ListDatabases; + server_connection.send(message).await?; + + let result = match server_connection.next().await { + Some(Ok(Response::ListAllDatabases(result))) => result, + response => return erroneous_server_response(response), + }; + + server_connection.send(Request::Exit).await?; + + let database_list = match result { + Ok(list) => list, + Err(err) => { + eprintln!("{:?}", err); + return Ok(()); + } + }; if args.json { - println!("{}", serde_json::to_string_pretty(&databases)?); - return Ok(CommandStatus::NoModificationsIntended); - } - - if databases.is_empty() { + println!("{}", serde_json::to_string_pretty(&database_list)?); + } else if database_list.is_empty() { println!("No databases to show."); } else { - for db in databases { + for db in database_list { println!("{}", db); } } - Ok(CommandStatus::NoModificationsIntended) + Ok(()) } async fn show_database_privileges( args: DatabaseShowPrivsArgs, - connection: &mut MySqlConnection, -) -> anyhow::Result { - let database_users_to_show = if args.name.is_empty() { - get_all_database_privileges(connection).await? + mut server_connection: ClientToServerMessageStream, +) -> anyhow::Result<()> { + let message = if args.name.is_empty() { + Request::ListPrivileges(None) } else { - // TODO: This can be optimized by fetching all the database privileges in one query. - let mut result = Vec::with_capacity(args.name.len()); - for name in args.name { - match get_database_privileges(&name, connection).await { - Ok(db) => result.extend(db), - Err(e) => { - eprintln!("Failed to show database '{}': {}", name, e); + Request::ListPrivileges(Some(args.name.clone())) + }; + server_connection.send(message).await?; + + let result = match server_connection.next().await { + Some(Ok(Response::ListPrivileges(databases))) => databases + .into_iter() + .filter_map(|(_db, result)| match result { + Ok(privileges) => Some(privileges), + Err(err) => { + eprintln!("{:?}", err); eprintln!("Skipping..."); + None } + }) + .flatten() + .collect::>(), + Some(Ok(Response::ListAllPrivileges(privilege_rows))) => match privilege_rows { + Ok(list) => list, + Err(err) => { + eprintln!("{:?}", err); + return Ok(()); } - } - result + }, + response => return erroneous_server_response(response), }; - if args.json { - println!("{}", serde_json::to_string_pretty(&database_users_to_show)?); - return Ok(CommandStatus::NoModificationsIntended); - } + server_connection.send(Request::Exit).await?; - if database_users_to_show.is_empty() { - println!("No database users to show."); + if args.json { + println!("{}", serde_json::to_string_pretty(&result)?); + } else if result.is_empty() { + println!("No database privileges to show."); } else { let mut table = Table::new(); table.add_row(Row::new( @@ -270,7 +302,7 @@ async fn show_database_privileges( .collect(), )); - for row in database_users_to_show { + for row in result { table.add_row(row![ row.db, row.user, @@ -290,101 +322,101 @@ async fn show_database_privileges( table.printstd(); } - Ok(CommandStatus::NoModificationsIntended) + Ok(()) } -pub async fn edit_privileges( - args: DatabaseEditPrivsArgs, - connection: &mut MySqlConnection, -) -> anyhow::Result { - let privilege_data = if let Some(name) = &args.name { - get_database_privileges(name, connection).await? - } else { - get_all_database_privileges(connection).await? - }; +// pub async fn edit_privileges( +// args: DatabaseEditPrivsArgs, +// connection: &mut MySqlConnection, +// ) -> anyhow::Result { +// let privilege_data = if let Some(name) = &args.name { +// get_database_privileges(name, connection).await? +// } else { +// get_all_database_privileges(connection).await? +// }; - // TODO: The data from args should not be absolute. - // In the current implementation, the user would need to - // provide all privileges for all users on all databases. - // The intended effect is to modify the privileges which have - // matching users and databases, as well as add any - // new db-user pairs. This makes it impossible to remove - // privileges, but that is an issue for another day. - let privileges_to_change = if !args.privs.is_empty() { - parse_privilege_tables_from_args(&args)? - } else { - edit_privileges_with_editor(&privilege_data)? - }; +// // TODO: The data from args should not be absolute. +// // In the current implementation, the user would need to +// // provide all privileges for all users on all databases. +// // The intended effect is to modify the privileges which have +// // matching users and databases, as well as add any +// // new db-user pairs. This makes it impossible to remove +// // privileges, but that is an issue for another day. +// let privileges_to_change = if !args.privs.is_empty() { +// parse_privilege_tables_from_args(&args)? +// } else { +// edit_privileges_with_editor(&privilege_data)? +// }; - for row in privileges_to_change.iter() { - if !user_exists(&row.user, connection).await? { - // TODO: allow user to return and correct their mistake - anyhow::bail!("User {} does not exist", row.user); - } - } +// for row in privileges_to_change.iter() { +// if !user_exists(&row.user, connection).await? { +// // TODO: allow user to return and correct their mistake +// anyhow::bail!("User {} does not exist", row.user); +// } +// } - let diffs = diff_privileges(&privilege_data, &privileges_to_change); +// let diffs = diff_privileges(&privilege_data, &privileges_to_change); - if diffs.is_empty() { - println!("No changes to make."); - return Ok(CommandStatus::NoModificationsNeeded); - } +// if diffs.is_empty() { +// println!("No changes to make."); +// return Ok(CommandStatus::NoModificationsNeeded); +// } - println!("The following changes will be made:\n"); - println!("{}", display_privilege_diffs(&diffs)); - if !args.yes - && !Confirm::new() - .with_prompt("Do you want to apply these changes?") - .default(false) - .show_default(true) - .interact()? - { - return Ok(CommandStatus::Cancelled); - } +// println!("The following changes will be made:\n"); +// println!("{}", display_privilege_diffs(&diffs)); +// if !args.yes +// && !Confirm::new() +// .with_prompt("Do you want to apply these changes?") +// .default(false) +// .show_default(true) +// .interact()? +// { +// return Ok(CommandStatus::Cancelled); +// } - apply_privilege_diffs(diffs, connection).await?; +// apply_privilege_diffs(diffs, connection).await?; - Ok(CommandStatus::SuccessfullyModified) -} +// Ok(CommandStatus::SuccessfullyModified) +// } -pub fn parse_privilege_tables_from_args( - args: &DatabaseEditPrivsArgs, -) -> anyhow::Result> { - debug_assert!(!args.privs.is_empty()); - let result = if let Some(name) = &args.name { - args.privs - .iter() - .map(|p| { - parse_privilege_table_cli_arg(&format!("{}:{}", name, &p)) - .context(format!("Failed parsing database privileges: `{}`", &p)) - }) - .collect::>>()? - } else { - args.privs - .iter() - .map(|p| { - parse_privilege_table_cli_arg(p) - .context(format!("Failed parsing database privileges: `{}`", &p)) - }) - .collect::>>()? - }; - Ok(result) -} +// pub fn parse_privilege_tables_from_args( +// args: &DatabaseEditPrivsArgs, +// ) -> anyhow::Result> { +// debug_assert!(!args.privs.is_empty()); +// let result = if let Some(name) = &args.name { +// args.privs +// .iter() +// .map(|p| { +// parse_privilege_table_cli_arg(&format!("{}:{}", name, &p)) +// .context(format!("Failed parsing database privileges: `{}`", &p)) +// }) +// .collect::>>()? +// } else { +// args.privs +// .iter() +// .map(|p| { +// parse_privilege_table_cli_arg(p) +// .context(format!("Failed parsing database privileges: `{}`", &p)) +// }) +// .collect::>>()? +// }; +// Ok(result) +// } -pub fn edit_privileges_with_editor( - privilege_data: &[DatabasePrivilegeRow], -) -> anyhow::Result> { - let unix_user = get_current_unix_user()?; +// pub fn edit_privileges_with_editor( +// privilege_data: &[DatabasePrivilegeRow], +// ) -> anyhow::Result> { +// let unix_user = get_current_unix_user()?; - let editor_content = - generate_editor_content_from_privilege_data(privilege_data, &unix_user.name); +// let editor_content = +// generate_editor_content_from_privilege_data(privilege_data, &unix_user.name); - // TODO: handle errors better here - let result = Editor::new() - .extension("tsv") - .edit(&editor_content)? - .unwrap(); +// // TODO: handle errors better here +// let result = Editor::new() +// .extension("tsv") +// .edit(&editor_content)? +// .unwrap(); - parse_privilege_data_from_editor_content(result) - .context("Could not parse privilege data from editor") -} +// parse_privilege_data_from_editor_content(result) +// .context("Could not parse privilege data from editor") +// } diff --git a/src/cli/user_command.rs b/src/cli/user_command.rs index fd79acc..b5cbe4f 100644 --- a/src/cli/user_command.rs +++ b/src/cli/user_command.rs @@ -1,27 +1,21 @@ -use std::collections::BTreeMap; -use std::vec; - use anyhow::Context; use clap::Parser; use dialoguer::{Confirm, Password}; -use prettytable::Table; -use serde_json::json; -use sqlx::{Connection, MySqlConnection}; +use futures_util::{SinkExt, StreamExt}; -use crate::core::{ - common::{close_database_connection, get_current_unix_user, CommandStatus}, - database_operations::*, - user_operations::*, -}; +use crate::server::protocol::ClientToServerMessageStream; +use crate::server::{Request, Response}; -#[derive(Parser)] +use super::common::erroneous_server_response; + +#[derive(Parser, Debug, Clone)] pub struct UserArgs { #[clap(subcommand)] subcmd: UserCommand, } #[allow(clippy::enum_variant_names)] -#[derive(Parser)] +#[derive(Parser, Debug, Clone)] pub enum UserCommand { /// Create one or more users #[command()] @@ -50,7 +44,7 @@ pub enum UserCommand { UnlockUser(UserUnlockArgs), } -#[derive(Parser)] +#[derive(Parser, Debug, Clone)] pub struct UserCreateArgs { #[arg(num_args = 1..)] username: Vec, @@ -60,13 +54,13 @@ pub struct UserCreateArgs { no_password: bool, } -#[derive(Parser)] +#[derive(Parser, Debug, Clone)] pub struct UserDeleteArgs { #[arg(num_args = 1..)] username: Vec, } -#[derive(Parser)] +#[derive(Parser, Debug, Clone)] pub struct UserPasswdArgs { username: String, @@ -74,7 +68,7 @@ pub struct UserPasswdArgs { password_file: Option, } -#[derive(Parser)] +#[derive(Parser, Debug, Clone)] pub struct UserShowArgs { #[arg(num_args = 0..)] username: Vec, @@ -83,13 +77,13 @@ pub struct UserShowArgs { json: bool, } -#[derive(Parser)] +#[derive(Parser, Debug, Clone)] pub struct UserLockArgs { #[arg(num_args = 1..)] username: Vec, } -#[derive(Parser)] +#[derive(Parser, Debug, Clone)] pub struct UserUnlockArgs { #[arg(num_args = 1..)] username: Vec, @@ -97,48 +91,47 @@ pub struct UserUnlockArgs { pub async fn handle_command( command: UserCommand, - mut connection: MySqlConnection, -) -> anyhow::Result { - let result = connection - .transaction(|txn| { - Box::pin(async move { - match command { - UserCommand::CreateUser(args) => create_users(args, txn).await, - UserCommand::DropUser(args) => drop_users(args, txn).await, - UserCommand::PasswdUser(args) => change_password_for_user(args, txn).await, - UserCommand::ShowUser(args) => show_users(args, txn).await, - UserCommand::LockUser(args) => lock_users(args, txn).await, - UserCommand::UnlockUser(args) => unlock_users(args, txn).await, - } - }) - }) - .await; - - close_database_connection(connection).await; - - result + server_connection: ClientToServerMessageStream, +) -> anyhow::Result<()> { + match command { + UserCommand::CreateUser(args) => create_users(args, server_connection).await, + UserCommand::DropUser(args) => drop_users(args, server_connection).await, + UserCommand::PasswdUser(args) => passwd_user(args, server_connection).await, + UserCommand::ShowUser(args) => show_users(args, server_connection).await, + UserCommand::LockUser(args) => lock_users(args, server_connection).await, + UserCommand::UnlockUser(args) => unlock_users(args, server_connection).await, + } } async fn create_users( args: UserCreateArgs, - connection: &mut MySqlConnection, -) -> anyhow::Result { + mut server_connection: ClientToServerMessageStream, +) -> anyhow::Result<()> { if args.username.is_empty() { anyhow::bail!("No usernames provided"); } - let mut result = CommandStatus::SuccessfullyModified; + let message = Request::CreateUsers(args.username.clone()); + // TODO: better error handling + if let Err(err) = server_connection.send(message).await { + server_connection.close().await.ok(); + anyhow::bail!(err); + } - for username in args.username { - if let Err(e) = create_database_user(&username, connection).await { - eprintln!("{}", e); - eprintln!("Skipping...\n"); - result = CommandStatus::PartiallySuccessfullyModified; - continue; - } else { - println!("User '{}' created.", username); - } + let result = match server_connection.next().await { + Some(Ok(Response::CreateUsers(result))) => result, + response => return erroneous_server_response(response), + }; + let successfully_created_users = result + .iter() + .filter_map(|(username, result)| match result { + Ok(_) => Some(username), + Err(_) => None, + }) + .collect::>(); + + for username in successfully_created_users { if !args.no_password && Confirm::new() .with_prompt(format!( @@ -147,41 +140,65 @@ async fn create_users( )) .interact()? { - change_password_for_user( - UserPasswdArgs { - username, - password_file: None, + let password = read_password_from_stdin_with_double_check(username)?; + let message = Request::PasswdUser(username.clone(), password); + + if let Err(err) = server_connection.send(message).await { + server_connection.close().await.ok(); + anyhow::bail!(err); + } + + match server_connection.next().await { + Some(Ok(Response::PasswdUser(result))) => match result { + Ok(_) => println!("Password set for user '{}'", username), + Err(e) => { + eprintln!("{:?}", e); + eprintln!("Skipping...\n"); + } }, - connection, - ) - .await?; + response => return erroneous_server_response(response), + } } - println!(); } - Ok(result) + + server_connection.send(Request::Exit).await?; + + Ok(()) } async fn drop_users( args: UserDeleteArgs, - connection: &mut MySqlConnection, -) -> anyhow::Result { + mut server_connection: ClientToServerMessageStream, +) -> anyhow::Result<()> { if args.username.is_empty() { anyhow::bail!("No usernames provided"); } - let mut result = CommandStatus::SuccessfullyModified; + let message = Request::DropUsers(args.username.clone()); - for username in args.username { - if let Err(e) = delete_database_user(&username, connection).await { - eprintln!("{}", e); - eprintln!("Skipping..."); - result = CommandStatus::PartiallySuccessfullyModified; - } else { - println!("User '{}' dropped.", username); + if let Err(err) = server_connection.send(message).await { + server_connection.close().await.ok(); + anyhow::bail!(err); + } + + let result = match server_connection.next().await { + Some(Ok(Response::DropUsers(result))) => result, + response => return erroneous_server_response(response), + }; + + server_connection.send(Request::Exit).await?; + + for (username, result) in result.iter() { + match result { + Ok(_) => println!("User '{}' dropped.", username), + Err(err) => { + eprintln!("{:?}", err); + eprintln!("Skipping..."); + } } } - Ok(result) + Ok(()) } pub fn read_password_from_stdin_with_double_check(username: &str) -> anyhow::Result { @@ -195,15 +212,10 @@ pub fn read_password_from_stdin_with_double_check(username: &str) -> anyhow::Res .map_err(Into::into) } -async fn change_password_for_user( +async fn passwd_user( args: UserPasswdArgs, - connection: &mut MySqlConnection, -) -> anyhow::Result { - // NOTE: although this also is checked in `set_password_for_database_user`, we check it here - // to provide a more natural order of error messages. - let unix_user = get_current_unix_user()?; - validate_user_name(&args.username, &unix_user)?; - + mut server_connection: ClientToServerMessageStream, +) -> anyhow::Result<()> { let password = if let Some(password_file) = args.password_file { std::fs::read_to_string(password_file) .context("Failed to read password file")? @@ -213,129 +225,170 @@ async fn change_password_for_user( read_password_from_stdin_with_double_check(&args.username)? }; - set_password_for_database_user(&args.username, &password, connection).await?; + let message = Request::PasswdUser(args.username.clone(), password); - Ok(CommandStatus::SuccessfullyModified) + if let Err(err) = server_connection.send(message).await { + server_connection.close().await.ok(); + anyhow::bail!(err); + } + + let result = match server_connection.next().await { + Some(Ok(Response::PasswdUser(result))) => result, + response => return erroneous_server_response(response), + }; + + server_connection.send(Request::Exit).await?; + + match result { + Ok(_) => println!("Password set for user '{}'", args.username), + Err(e) => { + eprintln!("{:?}", e); + eprintln!("Skipping...\n"); + } + }; + + Ok(()) } async fn show_users( args: UserShowArgs, - connection: &mut MySqlConnection, -) -> anyhow::Result { - let unix_user = get_current_unix_user()?; - - let users = if args.username.is_empty() { - get_all_database_users_for_unix_user(&unix_user, connection).await? + mut server_connection: ClientToServerMessageStream, +) -> anyhow::Result<()> { + let message = if args.username.is_empty() { + Request::ListUsers(None) } else { - let mut result = vec![]; - for username in args.username { - if let Err(e) = validate_user_name(&username, &unix_user) { - eprintln!("{}", e); - eprintln!("Skipping..."); - continue; - } - - let user = get_database_user_for_user(&username, connection).await?; - if let Some(user) = user { - result.push(user); - } else { - eprintln!("User not found: {}", username); - } - } - result + Request::ListUsers(Some(args.username.clone())) }; - let mut user_databases: BTreeMap> = BTreeMap::new(); - for user in users.iter() { - user_databases.insert( - user.user.clone(), - get_databases_where_user_has_privileges(&user.user, connection).await?, - ); + if let Err(err) = server_connection.send(message).await { + server_connection.close().await.ok(); + anyhow::bail!(err); } - if args.json { - let users_json = users + let users = match server_connection.next().await { + Some(Ok(Response::ListUsers(users))) => users .into_iter() - .map(|user| { - json!({ - "user": user.user, - "has_password": user.has_password, - "is_locked": user.is_locked, - "databases": user_databases.get(&user.user).unwrap_or(&vec![]), - }) + .filter_map(|(_username, result)| match result { + Ok(user) => Some(user), + Err(err) => { + eprintln!("{:?}", err); + eprintln!("Skipping..."); + None + } }) - .collect::(); + .collect::>(), + Some(Ok(Response::ListAllUsers(users))) => match users { + Ok(users) => users, + Err(err) => { + eprintln!("{:?}", err); + // TODO: close connection + return Ok(()); + } + }, + response => return erroneous_server_response(response), + }; + + server_connection.send(Request::Exit).await?; + + // TODO: print erroneous users + // for user in users.iter() + + // TODO: print databases where user has privileges + if args.json { println!( "{}", - serde_json::to_string_pretty(&users_json) - .context("Failed to serialize users to JSON")? + serde_json::to_string_pretty(&users).context("Failed to serialize users to JSON")? ); } else if users.is_empty() { - println!("No users found."); + println!("No users to show."); } else { - let mut table = Table::new(); + let mut table = prettytable::Table::new(); table.add_row(row![ "User", "Password is set", "Locked", - "Databases where user has privileges" + // "Databases where user has privileges" ]); for user in users { table.add_row(row![ user.user, user.has_password, user.is_locked, - user_databases.get(&user.user).unwrap_or(&vec![]).join("\n") + // user.databases.join("\n") ]); } table.printstd(); } - Ok(CommandStatus::NoModificationsIntended) + Ok(()) } async fn lock_users( args: UserLockArgs, - connection: &mut MySqlConnection, -) -> anyhow::Result { + mut server_connection: ClientToServerMessageStream, +) -> anyhow::Result<()> { if args.username.is_empty() { anyhow::bail!("No usernames provided"); } - let mut result = CommandStatus::SuccessfullyModified; + let message = Request::LockUsers(args.username.clone()); - for username in args.username { - if let Err(e) = lock_database_user(&username, connection).await { - eprintln!("{}", e); - eprintln!("Skipping..."); - result = CommandStatus::PartiallySuccessfullyModified; - } else { - println!("User '{}' locked.", username); + if let Err(err) = server_connection.send(message).await { + server_connection.close().await.ok(); + anyhow::bail!(err); + } + + let result = match server_connection.next().await { + Some(Ok(Response::LockUsers(result))) => result, + response => return erroneous_server_response(response), + }; + + server_connection.send(Request::Exit).await?; + + for (username, result) in result.iter() { + match result { + Ok(_) => println!("User '{}' locked.", username), + Err(err) => { + eprintln!("{:?}", err); + eprintln!("Skipping..."); + } } } - Ok(result) + Ok(()) } async fn unlock_users( args: UserUnlockArgs, - connection: &mut MySqlConnection, -) -> anyhow::Result { + mut server_connection: ClientToServerMessageStream, +) -> anyhow::Result<()> { if args.username.is_empty() { anyhow::bail!("No usernames provided"); } - let mut result = CommandStatus::SuccessfullyModified; + let message = Request::UnlockUsers(args.username.clone()); - for username in args.username { - if let Err(e) = unlock_database_user(&username, connection).await { - eprintln!("{}", e); - eprintln!("Skipping..."); - result = CommandStatus::PartiallySuccessfullyModified; - } else { - println!("User '{}' unlocked.", username); + if let Err(err) = server_connection.send(message).await { + server_connection.close().await.ok(); + anyhow::bail!(err); + } + + let result = match server_connection.next().await { + Some(Ok(Response::UnlockUsers(result))) => result, + response => return erroneous_server_response(response), + }; + + server_connection.send(Request::Exit).await?; + + for (username, result) in result.iter() { + match result { + Ok(_) => println!("User '{}' unlocked.", username), + Err(err) => { + eprintln!("{:?}", err); + eprintln!("Skipping..."); + } } } - Ok(result) + Ok(()) } diff --git a/src/core.rs b/src/core.rs index aa51dca..6ba17f0 100644 --- a/src/core.rs +++ b/src/core.rs @@ -1,5 +1,2 @@ pub mod common; -pub mod config; -pub mod database_operations; -pub mod database_privilege_operations; -pub mod user_operations; +pub mod database_privileges; diff --git a/src/core/common.rs b/src/core/common.rs index 9beb916..8f40f5c 100644 --- a/src/core/common.rs +++ b/src/core/common.rs @@ -1,272 +1,3 @@ -use anyhow::Context; -use indoc::indoc; -use itertools::Itertools; -use nix::unistd::{getuid, Group, User}; -use sqlx::{Connection, MySqlConnection}; - -#[cfg(not(target_os = "macos"))] -use std::ffi::CString; - -/// Report the result status of a command. -/// This is used to display a status message to the user. -pub enum CommandStatus { - /// The command was successful, - /// and made modification to the database. - SuccessfullyModified, - - /// The command was mostly successful, - /// and modifications have been made to the database. - /// However, some of the requested modifications failed. - PartiallySuccessfullyModified, - - /// The command was successful, - /// but no modifications were needed. - NoModificationsNeeded, - - /// The command was successful, - /// and made no modification to the database. - NoModificationsIntended, - - /// The command was cancelled, either through a dialog or a signal. - /// No modifications have been made to the database. - Cancelled, -} - -pub fn get_current_unix_user() -> anyhow::Result { - User::from_uid(getuid()) - .context("Failed to look up your UNIX username") - .and_then(|u| u.ok_or(anyhow::anyhow!("Failed to look up your UNIX username"))) -} - -#[cfg(target_os = "macos")] -pub fn get_unix_groups(_user: &User) -> anyhow::Result> { - // Return an empty list on macOS since there is no `getgrouplist` function - Ok(vec![]) -} - -#[cfg(not(target_os = "macos"))] -pub fn get_unix_groups(user: &User) -> anyhow::Result> { - let user_cstr = - CString::new(user.name.as_bytes()).context("Failed to convert username to CStr")?; - let groups = nix::unistd::getgrouplist(&user_cstr, user.gid)? - .iter() - .filter_map(|gid| match Group::from_gid(*gid) { - Ok(Some(group)) => Some(group), - Ok(None) => None, - Err(e) => { - log::warn!( - "Failed to look up group with GID {}: {}\nIgnoring...", - gid, - e - ); - None - } - }) - .collect::>(); - - Ok(groups) -} - -/// This function creates a regex that matches items (users, databases) -/// that belong to the user or any of the user's groups. -pub fn create_user_group_matching_regex(user: &User) -> String { - let groups = get_unix_groups(user).unwrap_or_default(); - - if groups.is_empty() { - format!("{}(_.+)?", user.name) - } else { - format!( - "({}|{})(_.+)?", - user.name, - groups - .iter() - .map(|g| g.name.as_str()) - .collect::>() - .join("|") - ) - } -} - -/// This enum is used to differentiate between database and user operations. -/// Their output are very similar, but there are slight differences in the words used. -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum DbOrUser { - Database, - User, -} - -impl DbOrUser { - pub fn lowercased(&self) -> String { - match self { - DbOrUser::Database => "database".to_string(), - DbOrUser::User => "user".to_string(), - } - } - - pub fn capitalized(&self) -> String { - match self { - DbOrUser::Database => "Database".to_string(), - DbOrUser::User => "User".to_string(), - } - } -} - -#[derive(Debug, PartialEq, Eq)] -pub enum NameValidationResult { - Valid, - EmptyString, - InvalidCharacters, - TooLong, -} - -pub fn validate_name(name: &str) -> NameValidationResult { - if name.is_empty() { - NameValidationResult::EmptyString - } else if name.len() > 64 { - NameValidationResult::TooLong - } else if !name - .chars() - .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-') - { - NameValidationResult::InvalidCharacters - } else { - NameValidationResult::Valid - } -} - -pub fn validate_name_or_error(name: &str, db_or_user: DbOrUser) -> anyhow::Result<()> { - match validate_name(name) { - NameValidationResult::Valid => Ok(()), - NameValidationResult::EmptyString => { - anyhow::bail!("{} name cannot be empty.", db_or_user.capitalized()) - } - NameValidationResult::TooLong => anyhow::bail!( - "{} is too long. Maximum length is 64 characters.", - db_or_user.capitalized() - ), - NameValidationResult::InvalidCharacters => anyhow::bail!( - indoc! {r#" - Invalid characters in {} name: '{}' - - Only A-Z, a-z, 0-9, _ (underscore) and - (dash) are permitted. - "#}, - db_or_user.lowercased(), - name - ), - } -} - -#[derive(Debug, PartialEq, Eq)] -pub enum OwnerValidationResult { - // The name is valid and matches one of the given prefixes - Match, - - // The name is valid, but none of the given prefixes matched the name - NoMatch, - - // The name is empty, which is invalid - StringEmpty, - - // The name is in the format "_", which is invalid - MissingPrefix, - - // The name is in the format "_", which is invalid - MissingPostfix, -} - -/// Core logic for validating the ownership of a database name. -/// This function checks if the given name matches any of the given prefixes. -/// These prefixes will in most cases be the user's unix username and any -/// unix groups the user is a member of. -pub fn validate_ownership_by_prefixes(name: &str, prefixes: &[String]) -> OwnerValidationResult { - if name.is_empty() { - return OwnerValidationResult::StringEmpty; - } - - if name.starts_with('_') { - return OwnerValidationResult::MissingPrefix; - } - - let (prefix, _) = match name.split_once('_') { - Some(pair) => pair, - None => return OwnerValidationResult::MissingPostfix, - }; - - if prefixes.iter().any(|g| g == prefix) { - OwnerValidationResult::Match - } else { - OwnerValidationResult::NoMatch - } -} - -/// Validate the ownership of a database name or database user name. -/// This function takes the name of a database or user and a unix user, -/// for which it fetches the user's groups. It then checks if the name -/// is prefixed with the user's username or any of the user's groups. -pub fn validate_ownership_or_error<'a>( - name: &'a str, - user: &User, - db_or_user: DbOrUser, -) -> anyhow::Result<&'a str> { - let user_groups = get_unix_groups(user)?; - let prefixes = std::iter::once(user.name.clone()) - .chain(user_groups.iter().map(|g| g.name.clone())) - .collect::>(); - - match validate_ownership_by_prefixes(name, &prefixes) { - OwnerValidationResult::Match => Ok(name), - OwnerValidationResult::NoMatch => { - anyhow::bail!( - indoc! {r#" - Invalid {} name prefix: '{}' does not match your username or any of your groups. - Are you sure you are allowed to create {} names with this prefix? - - Allowed prefixes: - - {} - {} - "#}, - db_or_user.lowercased(), - name, - db_or_user.lowercased(), - user.name, - user_groups - .iter() - .filter(|g| g.name != user.name) - .map(|g| format!(" - {}", g.name)) - .sorted() - .join("\n"), - ); - } - _ => anyhow::bail!( - "'{}' is not a valid {} name.", - name, - db_or_user.lowercased() - ), - } -} - -/// Gracefully close a MySQL connection. -pub async fn close_database_connection(connection: MySqlConnection) { - if let Err(e) = connection - .close() - .await - .context("Failed to close connection properly") - { - eprintln!("{}", e); - eprintln!("Ignoring..."); - } -} - -#[inline] -pub fn quote_literal(s: &str) -> String { - format!("'{}'", s.replace('\'', r"\'")) -} - -#[inline] -pub fn quote_identifier(s: &str) -> String { - format!("`{}`", s.replace('`', r"\`")) -} - #[inline] pub(crate) fn yn(b: bool) -> &'static str { if b { @@ -303,94 +34,4 @@ mod test { assert_eq!(rev_yn("n"), Some(false)); assert_eq!(rev_yn("X"), None); } - - #[test] - fn test_quote_literal() { - let payload = "' OR 1=1 --"; - assert_eq!(quote_literal(payload), r#"'\' OR 1=1 --'"#); - } - - #[test] - fn test_quote_identifier() { - let payload = "` OR 1=1 --"; - assert_eq!(quote_identifier(payload), r#"`\` OR 1=1 --`"#); - } - - #[test] - fn test_validate_name() { - assert_eq!(validate_name(""), NameValidationResult::EmptyString); - assert_eq!( - validate_name("abcdefghijklmnopqrstuvwxyz"), - NameValidationResult::Valid - ); - assert_eq!( - validate_name("ABCDEFGHIJKLMNOPQRSTUVWXYZ"), - NameValidationResult::Valid - ); - assert_eq!(validate_name("0123456789_-"), NameValidationResult::Valid); - - for c in "\n\t\r !@#$%^&*()+=[]{}|;:,.<>?/".chars() { - assert_eq!( - validate_name(&c.to_string()), - NameValidationResult::InvalidCharacters - ); - } - - assert_eq!(validate_name(&"a".repeat(64)), NameValidationResult::Valid); - - assert_eq!( - validate_name(&"a".repeat(65)), - NameValidationResult::TooLong - ); - } - - #[test] - fn test_validate_owner_by_prefixes() { - let prefixes = vec!["user".to_string(), "group".to_string()]; - - assert_eq!( - validate_ownership_by_prefixes("", &prefixes), - OwnerValidationResult::StringEmpty - ); - - assert_eq!( - validate_ownership_by_prefixes("user", &prefixes), - OwnerValidationResult::MissingPostfix - ); - assert_eq!( - validate_ownership_by_prefixes("something", &prefixes), - OwnerValidationResult::MissingPostfix - ); - assert_eq!( - validate_ownership_by_prefixes("user-testdb", &prefixes), - OwnerValidationResult::MissingPostfix - ); - - assert_eq!( - validate_ownership_by_prefixes("_testdb", &prefixes), - OwnerValidationResult::MissingPrefix - ); - - assert_eq!( - validate_ownership_by_prefixes("user_testdb", &prefixes), - OwnerValidationResult::Match - ); - assert_eq!( - validate_ownership_by_prefixes("group_testdb", &prefixes), - OwnerValidationResult::Match - ); - assert_eq!( - validate_ownership_by_prefixes("group_test_db", &prefixes), - OwnerValidationResult::Match - ); - assert_eq!( - validate_ownership_by_prefixes("group_test-db", &prefixes), - OwnerValidationResult::Match - ); - - assert_eq!( - validate_ownership_by_prefixes("nonexistent_testdb", &prefixes), - OwnerValidationResult::NoMatch - ); - } } diff --git a/src/core/database_operations.rs b/src/core/database_operations.rs deleted file mode 100644 index 9cbc3c6..0000000 --- a/src/core/database_operations.rs +++ /dev/null @@ -1,120 +0,0 @@ -use anyhow::Context; -use indoc::formatdoc; -use itertools::Itertools; -use nix::unistd::User; -use sqlx::{prelude::*, MySqlConnection}; - -use crate::core::{ - common::{ - create_user_group_matching_regex, get_current_unix_user, quote_identifier, - validate_name_or_error, validate_ownership_or_error, DbOrUser, - }, - database_privilege_operations::DATABASE_PRIVILEGE_FIELDS, -}; - -pub async fn create_database(name: &str, connection: &mut MySqlConnection) -> anyhow::Result<()> { - let user = get_current_unix_user()?; - validate_database_name(name, &user)?; - - // NOTE: see the note about SQL injections in `validate_owner_of_database_name` - sqlx::query(&format!("CREATE DATABASE {}", quote_identifier(name))) - .execute(connection) - .await - .map_err(|e| { - if e.to_string().contains("database exists") { - anyhow::anyhow!("Database '{}' already exists", name) - } else { - e.into() - } - })?; - - Ok(()) -} - -pub async fn drop_database(name: &str, connection: &mut MySqlConnection) -> anyhow::Result<()> { - let user = get_current_unix_user()?; - validate_database_name(name, &user)?; - - // NOTE: see the note about SQL injections in `validate_owner_of_database_name` - sqlx::query(&format!("DROP DATABASE {}", quote_identifier(name))) - .execute(connection) - .await - .map_err(|e| { - if e.to_string().contains("doesn't exist") { - anyhow::anyhow!("Database '{}' does not exist", name) - } else { - e.into() - } - })?; - - Ok(()) -} - -pub async fn get_database_list(connection: &mut MySqlConnection) -> anyhow::Result> { - let unix_user = get_current_unix_user()?; - - let databases: Vec = sqlx::query( - r#" - SELECT `SCHEMA_NAME` AS `database` - FROM `information_schema`.`SCHEMATA` - WHERE `SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys') - AND `SCHEMA_NAME` REGEXP ? - "#, - ) - .bind(create_user_group_matching_regex(&unix_user)) - .fetch_all(connection) - .await - .and_then(|row| { - row.into_iter() - .map(|row| row.try_get::("database")) - .collect::>() - }) - .context(format!( - "Failed to get databases for user '{}'", - unix_user.name - ))?; - - Ok(databases) -} - -pub async fn get_databases_where_user_has_privileges( - username: &str, - connection: &mut MySqlConnection, -) -> anyhow::Result> { - let result = sqlx::query( - formatdoc!( - r#" - SELECT `db` AS `database` - FROM `db` - WHERE `user` = ? - AND ({}) - "#, - DATABASE_PRIVILEGE_FIELDS - .iter() - .map(|field| format!("`{}` = 'Y'", field)) - .join(" OR "), - ) - .as_str(), - ) - .bind(username) - .fetch_all(connection) - .await? - .into_iter() - .map(|databases| databases.try_get::("database").unwrap()) - .collect(); - - Ok(result) -} - -/// NOTE: It is very critical that this function validates the database name -/// properly. MySQL does not seem to allow for prepared statements, binding -/// the database name as a parameter to the query. This means that we have -/// to validate the database name ourselves to prevent SQL injection. -pub fn validate_database_name(name: &str, user: &User) -> anyhow::Result<()> { - validate_name_or_error(name, DbOrUser::Database) - .context(format!("Invalid database name: '{}'", name))?; - validate_ownership_or_error(name, user, DbOrUser::Database) - .context(format!("Invalid database name: '{}'", name))?; - - Ok(()) -} diff --git a/src/core/database_privilege_operations.rs b/src/core/database_privileges.rs similarity index 68% rename from src/core/database_privilege_operations.rs rename to src/core/database_privileges.rs index d5cf850..d9d5111 100644 --- a/src/core/database_privilege_operations.rs +++ b/src/core/database_privileges.rs @@ -1,53 +1,14 @@ -//! Database privilege operations -//! -//! This module contains functions for querying, modifying, -//! displaying and comparing database privileges. -//! -//! A lot of the complexity comes from two core components: -//! -//! - The privilege editor that needs to be able to print -//! an editable table of privileges and reparse the content -//! after the user has made manual changes. -//! -//! - The comparison functionality that tells the user what -//! changes will be made when applying a set of changes -//! to the list of database privileges. - -use std::collections::{BTreeSet, HashMap}; - use anyhow::{anyhow, Context}; -use indoc::indoc; use itertools::Itertools; use prettytable::Table; use serde::{Deserialize, Serialize}; -use sqlx::{mysql::MySqlRow, prelude::*, MySqlConnection}; +use std::collections::{BTreeSet, HashMap}; -use crate::core::{ - common::{ - create_user_group_matching_regex, get_current_unix_user, quote_identifier, rev_yn, yn, - }, - database_operations::validate_database_name, +use super::common::{rev_yn, yn}; +use crate::server::database_privilege_operations::{ + DatabasePrivilegeRow, DATABASE_PRIVILEGE_FIELDS, }; -/// This is the list of fields that are used to fetch the db + user + privileges -/// from the `db` table in the database. If you need to add or remove privilege -/// fields, this is a good place to start. -pub const DATABASE_PRIVILEGE_FIELDS: [&str; 13] = [ - "db", - "user", - "select_priv", - "insert_priv", - "update_priv", - "delete_priv", - "create_priv", - "drop_priv", - "alter_priv", - "index_priv", - "create_tmp_table_priv", - "lock_tables_priv", - "references_priv", -]; - pub fn db_priv_field_human_readable_name(name: &str) -> String { match name { "db" => "Database".to_owned(), @@ -67,162 +28,24 @@ pub fn db_priv_field_human_readable_name(name: &str) -> String { } } -/// This struct represents the set of privileges for a single user on a single database. -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] -pub struct DatabasePrivilegeRow { - pub db: String, - pub user: String, - pub select_priv: bool, - pub insert_priv: bool, - pub update_priv: bool, - pub delete_priv: bool, - pub create_priv: bool, - pub drop_priv: bool, - pub alter_priv: bool, - pub index_priv: bool, - pub create_tmp_table_priv: bool, - pub lock_tables_priv: bool, - pub references_priv: bool, -} +pub fn diff(row1: &DatabasePrivilegeRow, row2: &DatabasePrivilegeRow) -> DatabasePrivilegeRowDiff { + debug_assert!(row1.db == row2.db && row1.user == row2.user); -impl DatabasePrivilegeRow { - pub fn empty(db: &str, user: &str) -> Self { - Self { - db: db.to_owned(), - user: user.to_owned(), - select_priv: false, - insert_priv: false, - update_priv: false, - delete_priv: false, - create_priv: false, - drop_priv: false, - alter_priv: false, - index_priv: false, - create_tmp_table_priv: false, - lock_tables_priv: false, - references_priv: false, - } + DatabasePrivilegeRowDiff { + db: row1.db.clone(), + user: row1.user.clone(), + diff: DATABASE_PRIVILEGE_FIELDS + .into_iter() + .skip(2) + .filter_map(|field| { + DatabasePrivilegeChange::new( + row1.get_privilege_by_name(field), + row2.get_privilege_by_name(field), + field, + ) + }) + .collect(), } - - pub fn get_privilege_by_name(&self, name: &str) -> bool { - match name { - "select_priv" => self.select_priv, - "insert_priv" => self.insert_priv, - "update_priv" => self.update_priv, - "delete_priv" => self.delete_priv, - "create_priv" => self.create_priv, - "drop_priv" => self.drop_priv, - "alter_priv" => self.alter_priv, - "index_priv" => self.index_priv, - "create_tmp_table_priv" => self.create_tmp_table_priv, - "lock_tables_priv" => self.lock_tables_priv, - "references_priv" => self.references_priv, - _ => false, - } - } - - pub fn diff(&self, other: &DatabasePrivilegeRow) -> DatabasePrivilegeRowDiff { - debug_assert!(self.db == other.db && self.user == other.user); - - DatabasePrivilegeRowDiff { - db: self.db.clone(), - user: self.user.clone(), - diff: DATABASE_PRIVILEGE_FIELDS - .into_iter() - .skip(2) - .filter_map(|field| { - DatabasePrivilegeChange::new( - self.get_privilege_by_name(field), - other.get_privilege_by_name(field), - field, - ) - }) - .collect(), - } - } -} - -#[inline] -fn get_mysql_row_priv_field(row: &MySqlRow, position: usize) -> Result { - let field = DATABASE_PRIVILEGE_FIELDS[position]; - let value = row.try_get(position)?; - match rev_yn(value) { - Some(val) => Ok(val), - _ => { - log::warn!(r#"Invalid value for privilege "{}": '{}'"#, field, value); - Ok(false) - } - } -} - -impl FromRow<'_, MySqlRow> for DatabasePrivilegeRow { - fn from_row(row: &MySqlRow) -> Result { - Ok(Self { - db: row.try_get("db")?, - user: row.try_get("user")?, - select_priv: get_mysql_row_priv_field(row, 2)?, - insert_priv: get_mysql_row_priv_field(row, 3)?, - update_priv: get_mysql_row_priv_field(row, 4)?, - delete_priv: get_mysql_row_priv_field(row, 5)?, - create_priv: get_mysql_row_priv_field(row, 6)?, - drop_priv: get_mysql_row_priv_field(row, 7)?, - alter_priv: get_mysql_row_priv_field(row, 8)?, - index_priv: get_mysql_row_priv_field(row, 9)?, - create_tmp_table_priv: get_mysql_row_priv_field(row, 10)?, - lock_tables_priv: get_mysql_row_priv_field(row, 11)?, - references_priv: get_mysql_row_priv_field(row, 12)?, - }) - } -} - -/// Get all users + privileges for a single database. -pub async fn get_database_privileges( - database_name: &str, - connection: &mut MySqlConnection, -) -> anyhow::Result> { - let unix_user = get_current_unix_user()?; - validate_database_name(database_name, &unix_user)?; - - let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!( - "SELECT {} FROM `db` WHERE `db` = ?", - DATABASE_PRIVILEGE_FIELDS - .iter() - .map(|field| quote_identifier(field)) - .join(","), - )) - .bind(database_name) - .fetch_all(connection) - .await - .context("Failed to show database")?; - - Ok(result) -} - -/// Get all database + user + privileges pairs that are owned by the current user. -pub async fn get_all_database_privileges( - connection: &mut MySqlConnection, -) -> anyhow::Result> { - let unix_user = get_current_unix_user()?; - - let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!( - indoc! {r#" - SELECT {} FROM `db` WHERE `db` IN - (SELECT DISTINCT `SCHEMA_NAME` AS `database` - FROM `information_schema`.`SCHEMATA` - WHERE `SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys') - AND `SCHEMA_NAME` REGEXP ?) - "#}, - DATABASE_PRIVILEGE_FIELDS - .iter() - .map(|field| format!("`{field}`")) - .join(","), - )) - .bind(create_user_group_matching_regex(&unix_user)) - .fetch_all(connection) - .await - .context("Failed to show databases")?; - - Ok(result) } /*************************/ @@ -578,7 +401,7 @@ pub fn parse_privilege_data_from_editor_content( /// instances of privilege sets for a single user on a single database. /// /// The `User` and `Database` are the same for both instances. -#[derive(Debug, Clone, PartialEq, Eq, Serialize, PartialOrd, Ord)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] pub struct DatabasePrivilegeRowDiff { pub db: String, pub user: String, @@ -586,7 +409,7 @@ pub struct DatabasePrivilegeRowDiff { } /// This enum represents a change for a single privilege. -#[derive(Debug, Clone, PartialEq, Eq, Serialize, PartialOrd, Ord)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] pub enum DatabasePrivilegeChange { YesToNo(String), NoToYes(String), @@ -603,13 +426,31 @@ impl DatabasePrivilegeChange { } /// This enum encapsulates whether a [`DatabasePrivilegeRow`] was intrduced, modified or deleted. -#[derive(Debug, Clone, PartialEq, Eq, Serialize, PartialOrd, Ord)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] pub enum DatabasePrivilegesDiff { New(DatabasePrivilegeRow), Modified(DatabasePrivilegeRowDiff), Deleted(DatabasePrivilegeRow), } +impl DatabasePrivilegesDiff { + pub fn get_database_name(&self) -> &str { + match self { + DatabasePrivilegesDiff::New(p) => &p.db, + DatabasePrivilegesDiff::Modified(p) => &p.db, + DatabasePrivilegesDiff::Deleted(p) => &p.db, + } + } + + pub fn get_user_name(&self) -> &str { + match self { + DatabasePrivilegesDiff::New(p) => &p.user, + DatabasePrivilegesDiff::Modified(p) => &p.user, + DatabasePrivilegesDiff::Deleted(p) => &p.user, + } + } +} + /// This function calculates the differences between two sets of database privileges. /// It returns a set of [`DatabasePrivilegesDiff`] that can be used to display or /// apply a set of privilege modifications to the database. @@ -633,7 +474,7 @@ pub fn diff_privileges( for p in to { if let Some(old_p) = from_lookup_table.get(&(p.db.clone(), p.user.clone())) { - let diff = old_p.diff(p); + let diff = diff(old_p, p); if !diff.diff.is_empty() { result.insert(DatabasePrivilegesDiff::Modified(diff)); } @@ -651,72 +492,6 @@ pub fn diff_privileges( result } -/// Uses the result of [`diff_privileges`] to modify privileges in the database. -pub async fn apply_privilege_diffs( - diffs: BTreeSet, - connection: &mut MySqlConnection, -) -> anyhow::Result<()> { - for diff in diffs { - match diff { - DatabasePrivilegesDiff::New(p) => { - let tables = DATABASE_PRIVILEGE_FIELDS - .iter() - .map(|field| format!("`{field}`")) - .join(","); - - let question_marks = std::iter::repeat("?") - .take(DATABASE_PRIVILEGE_FIELDS.len()) - .join(","); - - sqlx::query( - format!("INSERT INTO `db` ({}) VALUES ({})", tables, question_marks).as_str(), - ) - .bind(p.db) - .bind(p.user) - .bind(yn(p.select_priv)) - .bind(yn(p.insert_priv)) - .bind(yn(p.update_priv)) - .bind(yn(p.delete_priv)) - .bind(yn(p.create_priv)) - .bind(yn(p.drop_priv)) - .bind(yn(p.alter_priv)) - .bind(yn(p.index_priv)) - .bind(yn(p.create_tmp_table_priv)) - .bind(yn(p.lock_tables_priv)) - .bind(yn(p.references_priv)) - .execute(&mut *connection) - .await?; - } - DatabasePrivilegesDiff::Modified(p) => { - let tables = p - .diff - .iter() - .map(|diff| match diff { - DatabasePrivilegeChange::YesToNo(name) => format!("`{}` = 'N'", name), - DatabasePrivilegeChange::NoToYes(name) => format!("`{}` = 'Y'", name), - }) - .join(","); - - sqlx::query( - format!("UPDATE `db` SET {} WHERE `db` = ? AND `user` = ?", tables).as_str(), - ) - .bind(p.db) - .bind(p.user) - .execute(&mut *connection) - .await?; - } - DatabasePrivilegesDiff::Deleted(p) => { - sqlx::query("DELETE FROM `db` WHERE `db` = ? AND `user` = ?") - .bind(p.db) - .bind(p.user) - .execute(&mut *connection) - .await?; - } - } - } - Ok(()) -} - fn display_privilege_cell(diff: &DatabasePrivilegeRowDiff) -> String { diff.diff .iter() @@ -742,9 +517,10 @@ pub fn display_privilege_diffs(diffs: &BTreeSet) -> Stri p.db, p.user, "(New user)\n".to_string() - + &display_privilege_cell( - &DatabasePrivilegeRow::empty(&p.db, &p.user).diff(p) - ) + + &display_privilege_cell(&diff( + &DatabasePrivilegeRow::empty(&p.db, &p.user), + p + )) ]); } DatabasePrivilegesDiff::Modified(p) => { @@ -755,9 +531,10 @@ pub fn display_privilege_diffs(diffs: &BTreeSet) -> Stri p.db, p.user, "(All privileges removed)\n".to_string() - + &display_privilege_cell( - &p.diff(&DatabasePrivilegeRow::empty(&p.db, &p.user)) - ) + + &display_privilege_cell(&diff( + p, + &DatabasePrivilegeRow::empty(&p.db, &p.user) + )) ]); } } diff --git a/src/core/user_operations.rs b/src/core/user_operations.rs deleted file mode 100644 index 5d3ac8a..0000000 --- a/src/core/user_operations.rs +++ /dev/null @@ -1,249 +0,0 @@ -use anyhow::Context; -use nix::unistd::User; -use serde::{Deserialize, Serialize}; -use sqlx::{prelude::*, MySqlConnection}; - -use crate::core::common::{ - create_user_group_matching_regex, get_current_unix_user, quote_literal, validate_name_or_error, - validate_ownership_or_error, DbOrUser, -}; - -pub async fn user_exists(db_user: &str, connection: &mut MySqlConnection) -> anyhow::Result { - let unix_user = get_current_unix_user()?; - - validate_user_name(db_user, &unix_user)?; - - let user_exists = sqlx::query( - r#" - SELECT EXISTS( - SELECT 1 - FROM `mysql`.`user` - WHERE `User` = ? - ) - "#, - ) - .bind(db_user) - .fetch_one(connection) - .await? - .get::(0); - - Ok(user_exists) -} - -pub async fn create_database_user( - db_user: &str, - connection: &mut MySqlConnection, -) -> anyhow::Result<()> { - let unix_user = get_current_unix_user()?; - - validate_user_name(db_user, &unix_user)?; - - if user_exists(db_user, connection).await? { - anyhow::bail!("User '{}' already exists", db_user); - } - - // NOTE: see the note about SQL injections in `validate_ownership_of_user_name` - sqlx::query(format!("CREATE USER {}@'%'", quote_literal(db_user),).as_str()) - .execute(connection) - .await?; - - Ok(()) -} - -pub async fn delete_database_user( - db_user: &str, - connection: &mut MySqlConnection, -) -> anyhow::Result<()> { - let unix_user = get_current_unix_user()?; - - validate_user_name(db_user, &unix_user)?; - - if !user_exists(db_user, connection).await? { - anyhow::bail!("User '{}' does not exist", db_user); - } - - // NOTE: see the note about SQL injections in `validate_ownership_of_user_name` - sqlx::query(format!("DROP USER {}@'%'", quote_literal(db_user),).as_str()) - .execute(connection) - .await?; - - Ok(()) -} - -pub async fn set_password_for_database_user( - db_user: &str, - password: &str, - connection: &mut MySqlConnection, -) -> anyhow::Result<()> { - let unix_user = crate::core::common::get_current_unix_user()?; - validate_user_name(db_user, &unix_user)?; - - if !user_exists(db_user, connection).await? { - anyhow::bail!("User '{}' does not exist", db_user); - } - - // NOTE: see the note about SQL injections in `validate_ownership_of_user_name` - sqlx::query( - format!( - "ALTER USER {}@'%' IDENTIFIED BY {}", - quote_literal(db_user), - quote_literal(password).as_str() - ) - .as_str(), - ) - .execute(connection) - .await?; - - Ok(()) -} - -async fn user_is_locked(db_user: &str, connection: &mut MySqlConnection) -> anyhow::Result { - let unix_user = get_current_unix_user()?; - - validate_user_name(db_user, &unix_user)?; - - if !user_exists(db_user, connection).await? { - anyhow::bail!("User '{}' does not exist", db_user); - } - - let is_locked = sqlx::query( - r#" - SELECT JSON_EXTRACT(`mysql`.`global_priv`.`priv`, "$.account_locked") = 'true' - FROM `mysql`.`global_priv` - WHERE `User` = ? - AND `Host` = '%' - "#, - ) - .bind(db_user) - .fetch_one(connection) - .await? - .get::(0); - - Ok(is_locked) -} - -pub async fn lock_database_user( - db_user: &str, - connection: &mut MySqlConnection, -) -> anyhow::Result<()> { - let unix_user = get_current_unix_user()?; - - validate_user_name(db_user, &unix_user)?; - - if !user_exists(db_user, connection).await? { - anyhow::bail!("User '{}' does not exist", db_user); - } - - if user_is_locked(db_user, connection).await? { - anyhow::bail!("User '{}' is already locked", db_user); - } - - // NOTE: see the note about SQL injections in `validate_ownership_of_user_name` - sqlx::query(format!("ALTER USER {}@'%' ACCOUNT LOCK", quote_literal(db_user),).as_str()) - .execute(connection) - .await?; - - Ok(()) -} - -pub async fn unlock_database_user( - db_user: &str, - connection: &mut MySqlConnection, -) -> anyhow::Result<()> { - let unix_user = get_current_unix_user()?; - - validate_user_name(db_user, &unix_user)?; - - if !user_exists(db_user, connection).await? { - anyhow::bail!("User '{}' does not exist", db_user); - } - - if !user_is_locked(db_user, connection).await? { - anyhow::bail!("User '{}' is already unlocked", db_user); - } - - // NOTE: see the note about SQL injections in `validate_ownership_of_user_name` - sqlx::query(format!("ALTER USER {}@'%' ACCOUNT UNLOCK", quote_literal(db_user),).as_str()) - .execute(connection) - .await?; - - Ok(()) -} - -/// This struct contains information about a database user. -/// This can be extended if we need more information in the future. -#[derive(Debug, Clone, FromRow, Serialize, Deserialize)] -pub struct DatabaseUser { - #[sqlx(rename = "User")] - pub user: String, - - #[allow(dead_code)] - #[serde(skip)] - #[sqlx(rename = "Host")] - pub host: String, - - #[sqlx(rename = "has_password")] - pub has_password: bool, - - #[sqlx(rename = "is_locked")] - pub is_locked: bool, -} - -const DB_USER_SELECT_STATEMENT: &str = r#" -SELECT - `mysql`.`user`.`User`, - `mysql`.`user`.`Host`, - `mysql`.`user`.`Password` != '' OR `mysql`.`user`.`authentication_string` != '' AS `has_password`, - COALESCE( - JSON_EXTRACT(`mysql`.`global_priv`.`priv`, "$.account_locked"), - 'false' - ) != 'false' AS `is_locked` -FROM `mysql`.`user` -JOIN `mysql`.`global_priv` ON - `mysql`.`user`.`User` = `mysql`.`global_priv`.`User` - AND `mysql`.`user`.`Host` = `mysql`.`global_priv`.`Host` -"#; - -/// This function fetches all database users that have a prefix matching the -/// unix username and group names of the given unix user. -pub async fn get_all_database_users_for_unix_user( - unix_user: &User, - connection: &mut MySqlConnection, -) -> anyhow::Result> { - let users = sqlx::query_as::<_, DatabaseUser>( - &(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `mysql`.`user`.`User` REGEXP ?"), - ) - .bind(create_user_group_matching_regex(unix_user)) - .fetch_all(connection) - .await?; - - Ok(users) -} - -/// This function fetches a database user if it exists. -pub async fn get_database_user_for_user( - username: &str, - connection: &mut MySqlConnection, -) -> anyhow::Result> { - let user = sqlx::query_as::<_, DatabaseUser>( - &(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `mysql`.`user`.`User` = ?"), - ) - .bind(username) - .fetch_optional(connection) - .await?; - - Ok(user) -} - -/// NOTE: It is very critical that this function validates the database name -/// properly. MySQL does not seem to allow for prepared statements, binding -/// the database name as a parameter to the query. This means that we have -/// to validate the database name ourselves to prevent SQL injection. -pub fn validate_user_name(name: &str, user: &User) -> anyhow::Result<()> { - validate_name_or_error(name, DbOrUser::User) - .context(format!("Invalid username: '{}'", name))?; - validate_ownership_or_error(name, user, DbOrUser::User) - .context(format!("Invalid username: '{}'", name))?; - - Ok(()) -} diff --git a/src/main.rs b/src/main.rs index 2622c50..89ba01f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,16 +1,20 @@ #[macro_use] extern crate prettytable; -use core::common::CommandStatus; -#[cfg(feature = "mysql-admutils-compatibility")] use std::path::PathBuf; -#[cfg(feature = "mysql-admutils-compatibility")] -use crate::cli::mysql_admutils_compatibility::{mysql_dbadm, mysql_useradm}; - use clap::Parser; +use server::bootstrap::bootstrap_server_connection_and_drop_privileges; + +// use core::common::CommandStatus; +// #[cfg(feature = "mysql-admutils-compatibility")] +// use std::path::PathBuf; + +// #[cfg(feature = "mysql-admutils-compatibility")] +// use crate::cli::mysql_admutils_compatibility::{mysql_dbadm, mysql_useradm}; + +mod server; -mod authenticated_unix_socket; mod cli; mod core; @@ -22,21 +26,22 @@ struct Args { #[command(subcommand)] command: Command, - #[command(flatten)] - config_overrides: core::config::GlobalConfigArgs, + /// Path to the socket of the server, if it already exists. + #[arg(long, value_name = "HOST", global = true, hide_short_help = true)] + server_address: Option, #[cfg(feature = "tui")] #[arg(short, long, alias = "tui", global = true)] interactive: bool, } -/// Database administration tool for non-admin users to manage their own MySQL databases and users. -/// -/// This tool allows you to manage users and databases in MySQL. -/// -/// You are only allowed to manage databases and users that are prefixed with -/// either your username, or a group that you are a member of. -#[derive(Parser)] +// Database administration tool for non-admin users to manage their own MySQL databases and users. +// +// This tool allows you to manage users and databases in MySQL. +// +// You are only allowed to manage databases and users that are prefixed with +// either your username, or a group that you are a member of. +#[derive(Parser, Debug, Clone)] #[command(version, about, disable_help_subcommand = true)] enum Command { #[command(flatten)] @@ -44,57 +49,44 @@ enum Command { #[command(flatten)] User(cli::user_command::UserCommand), + + #[command(hide = true)] + Server(server::command::ServerArgs), } #[tokio::main(flavor = "current_thread")] async fn main() -> anyhow::Result<()> { env_logger::init(); - #[cfg(feature = "mysql-admutils-compatibility")] - { - let argv0 = std::env::args().next().and_then(|s| { - PathBuf::from(s) - .file_name() - .map(|s| s.to_string_lossy().to_string()) - }); + // #[cfg(feature = "mysql-admutils-compatibility")] + // { + // let argv0 = std::env::args().next().and_then(|s| { + // PathBuf::from(s) + // .file_name() + // .map(|s| s.to_string_lossy().to_string()) + // }); - match argv0.as_deref() { - Some("mysql-dbadm") => return mysql_dbadm::main().await, - Some("mysql-useradm") => return mysql_useradm::main().await, - _ => { /* fall through */ } - } - } + // match argv0.as_deref() { + // Some("mysql-dbadm") => return mysql_dbadm::main().await, + // Some("mysql-useradm") => return mysql_useradm::main().await, + // _ => { /* fall through */ } + // } + // } let args: Args = Args::parse(); - let config = core::config::get_config(args.config_overrides)?; - let connection = core::config::create_mysql_connection_from_config(config.mysql).await?; + match args.command { + Command::Server(ref command) => server::command::handle_command(command.to_owned()).await?, + _ => { /* fall through */ } + } - let result = match args.command { - Command::Db(command) => cli::database_command::handle_command(command, connection).await, - Command::User(user_args) => cli::user_command::handle_command(user_args, connection).await, - }; + let server_stream = + bootstrap_server_connection_and_drop_privileges(args.server_address).await?; - match result { - Ok(CommandStatus::SuccessfullyModified) => { - println!("Modifications committed successfully"); - Ok(()) + match args.command { + Command::User(user_args) => { + cli::user_command::handle_command(user_args, server_stream).await } - Ok(CommandStatus::PartiallySuccessfullyModified) => { - println!("Some modifications committed successfully"); - Ok(()) - } - Ok(CommandStatus::NoModificationsNeeded) => { - println!("No modifications made"); - Ok(()) - } - Ok(CommandStatus::NoModificationsIntended) => { - /* Don't report anything */ - Ok(()) - } - Ok(CommandStatus::Cancelled) => { - println!("Command cancelled successfully"); - Ok(()) - } - Err(e) => Err(e), + Command::Db(db_args) => cli::database_command::handle_command(db_args, server_stream).await, + Command::Server(_) => unreachable!(), } } diff --git a/src/server.rs b/src/server.rs index 2943954..a762b6d 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,6 +1,12 @@ +pub mod bootstrap; +pub mod command; mod common; +pub mod config; mod database_operations; +pub mod database_privilege_operations; mod entrypoint; mod input_sanitization; -mod protocol; +pub mod protocol; mod user_operations; + +pub use protocol::{Request, Response}; diff --git a/src/server/bootstrap.rs b/src/server/bootstrap.rs new file mode 100644 index 0000000..865aafa --- /dev/null +++ b/src/server/bootstrap.rs @@ -0,0 +1,43 @@ +use std::{fs, path::PathBuf}; + +use tokio::net::UnixStream; + +use crate::server::protocol::create_client_to_server_message_stream; + +use super::config::DEFAULT_SOCKET_PATH; +use super::protocol::ClientToServerMessageStream; + +pub mod authenticated_unix_socket; + +// TODO: allow overriding which way to connect to the server + +pub async fn bootstrap_server_connection_and_drop_privileges<'a>( + socket_path: Option, +) -> anyhow::Result { + // If socket path explicitly provided, use or error + if let Some(socket_path) = socket_path { + let mut socket = UnixStream::connect(socket_path).await?; + authenticated_unix_socket::client_authenticate(&mut socket, None).await?; + let message_stream = create_client_to_server_message_stream(socket); + return Ok(message_stream); + } + + // If not, check if the default socket path exists + if fs::metadata(DEFAULT_SOCKET_PATH).is_ok() { + let mut socket = UnixStream::connect(DEFAULT_SOCKET_PATH).await?; + authenticated_unix_socket::client_authenticate(&mut socket, None).await?; + let message_stream = create_client_to_server_message_stream(socket); + return Ok(message_stream); + } + + // If not, check if we are suid, guid and can read the config file + // if { + // if so, create anonymous socket pair, spawn server, and drop privileges + // } + + // If not, error + + // Upon non suid-guid invocation, we will need to authenticate the socket. + + todo!() +} diff --git a/src/authenticated_unix_socket.rs b/src/server/bootstrap/authenticated_unix_socket.rs similarity index 84% rename from src/authenticated_unix_socket.rs rename to src/server/bootstrap/authenticated_unix_socket.rs index 0092b5f..9b1a235 100644 --- a/src/authenticated_unix_socket.rs +++ b/src/server/bootstrap/authenticated_unix_socket.rs @@ -34,6 +34,7 @@ use std::os::unix::io::AsRawFd; use std::path::PathBuf; use async_bincode::{tokio::AsyncBincodeStream, AsyncDestination}; +use derive_more::derive::{Display, Error}; use futures::{SinkExt, StreamExt}; use nix::{sys::stat, unistd::Uid}; use rand::distributions::Alphanumeric; @@ -52,7 +53,7 @@ pub enum ClientRequest { Cancel, } -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Display, Error)] pub enum ServerResponse { Authenticated, ChallengeDidNotMatch, @@ -61,7 +62,7 @@ pub enum ServerResponse { // TODO: wrap more data into the errors -#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] +#[derive(Debug, Display, PartialEq, Serialize, Deserialize, Clone, Error)] pub enum ServerError { InvalidRequest, UnableToReadPermissionsFromAuthSocket, @@ -72,7 +73,7 @@ pub enum ServerError { InvalidChallenge, } -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Display, Error)] pub enum ClientError { UnableToConnectToServer, UnableToOpenAuthSocket, @@ -86,7 +87,7 @@ pub enum ClientError { ServerError(ServerError), } -async fn create_auth_socket(socket_addr: &str) -> Result { +async fn create_auth_socket(socket_addr: &PathBuf) -> Result { let auth_socket = UnixListener::bind(socket_addr).map_err(|_err| ClientError::UnableToOpenAuthSocket)?; @@ -109,11 +110,13 @@ type AuthStream<'a> = AsyncBincodeStream<&'a mut UnixStream, u64, u64, AsyncDest // TODO: add timeout +// TODO: respect $XDG_RUNTIME_DIR and $TMPDIR + const AUTH_SOCKET_NAME: &str = "mysqladm-rs-cli-auth.sock"; + pub async fn client_authenticate( normal_socket: &mut UnixStream, - #[cfg(not(test))] auth_socket_dir: Option, - #[cfg(test)] auth_socket_file: Option, + auth_socket_dir: Option, ) -> Result<(), ClientError> { let random_prefix: String = rand::thread_rng() .sample_iter(&Alphanumeric) @@ -123,32 +126,16 @@ pub async fn client_authenticate( let socket_name = format!("{}-{}", random_prefix, AUTH_SOCKET_NAME); - #[cfg(not(test))] - let auth_socket_address = match auth_socket_dir { - Some(dir) => dir.join(socket_name).to_str().unwrap().to_string(), - None => std::env::temp_dir() - .join(socket_name) - .to_str() - .unwrap() - .to_string(), - }; - - #[cfg(test)] - let auth_socket_address = match auth_socket_file { - Some(file) => file.to_str().unwrap().to_string(), - None => std::env::temp_dir() - .join(socket_name) - .to_str() - .unwrap() - .to_string(), - }; + let auth_socket_address = auth_socket_dir + .unwrap_or(std::env::temp_dir()) + .join(socket_name); client_authenticate_with_auth_socket_address(normal_socket, &auth_socket_address).await } async fn client_authenticate_with_auth_socket_address( normal_socket: &mut UnixStream, - auth_socket_address: &str, + auth_socket_address: &PathBuf, ) -> Result<(), ClientError> { let auth_socket = create_auth_socket(auth_socket_address).await?; @@ -164,7 +151,7 @@ async fn client_authenticate_with_auth_socket_address( async fn client_authenticate_with_auth_socket( normal_socket: &mut UnixStream, auth_socket: UnixListener, - auth_socket_address: &str, + auth_socket_address: &PathBuf, ) -> Result<(), ClientError> { let challenge = rand::random::(); let uid = nix::unistd::getuid(); @@ -199,7 +186,7 @@ async fn client_authenticate_with_auth_socket( let client_hello = ClientRequest::Initialize { uid: uid.into(), challenge, - auth_socket: auth_socket_address.to_string(), + auth_socket: auth_socket_address.to_str().unwrap().to_owned(), }; normal_socket @@ -239,9 +226,13 @@ macro_rules! report_server_error_and_return { }}; } -async fn server_authenticate( +pub async fn server_authenticate(normal_socket: &mut UnixStream) -> Result { + _server_authenticate(normal_socket, None).await +} + +pub async fn _server_authenticate( normal_socket: &mut UnixStream, - #[cfg(test)] unix_user_uid: Option, + unix_user_uid: Option, ) -> Result { let mut normal_socket: ServerToClientStream = AsyncBincodeStream::from(normal_socket).for_async(); @@ -256,22 +247,15 @@ async fn server_authenticate( _ => report_server_error_and_return!(normal_socket, ServerError::InvalidRequest), }; - #[cfg(test)] let auth_socket_uid = match unix_user_uid { Some(uid) => uid, - None => report_server_error_and_return!( - normal_socket, - ServerError::UnableToReadPermissionsFromAuthSocket - ), - }; - - #[cfg(not(test))] - let auth_socket_uid = match stat::stat(auth_socket.as_str()) { - Ok(stat) => stat.st_uid, - Err(_err) => report_server_error_and_return!( - normal_socket, - ServerError::UnableToReadPermissionsFromAuthSocket - ), + None => match stat::stat(auth_socket.as_str()) { + Ok(stat) => stat.st_uid, + Err(_err) => report_server_error_and_return!( + normal_socket, + ServerError::UnableToReadPermissionsFromAuthSocket + ), + }, }; if uid != auth_socket_uid { @@ -324,10 +308,7 @@ mod test { let client_handle = tokio::spawn(async move { client_authenticate(&mut client, None).await }); - let server_handle = tokio::spawn(async move { - let uid = nix::unistd::getuid().into(); - server_authenticate(&mut server, Some(uid)).await - }); + let server_handle = tokio::spawn(async move { server_authenticate(&mut server).await }); client_handle.await.unwrap().unwrap(); server_handle.await.unwrap().unwrap(); @@ -340,15 +321,12 @@ mod test { let client_handle = tokio::spawn(async move { client_authenticate_with_auth_socket_address( &mut client, - "/tmp/test_auth_socket_does_not_exist.sock", + &PathBuf::from("/tmp/test_auth_socket_does_not_exist.sock"), ) .await }); - let server_handle = tokio::spawn(async move { - let uid = nix::unistd::getuid().into(); - server_authenticate(&mut server, Some(uid)).await - }); + let server_handle = tokio::spawn(async move { server_authenticate(&mut server).await }); client_handle.await.unwrap().unwrap(); server_handle.await.unwrap().unwrap(); @@ -365,7 +343,7 @@ mod test { let server_handle = tokio::spawn(async move { let uid: u32 = nix::unistd::getuid().into(); - let err = server_authenticate(&mut server, Some(uid + 1)).await; + let err = _server_authenticate(&mut server, Some(uid + 1)).await; assert_eq!(err, Err(ServerError::UidMismatch)); }); @@ -379,13 +357,19 @@ mod test { let socket_path = std::env::temp_dir().join("socket_to_snoop.sock"); let socket_path_clone = socket_path.clone(); - let client_handle = - tokio::spawn( - async move { client_authenticate(&mut client, Some(socket_path_clone)).await }, - ); + let client_handle = tokio::spawn(async move { + client_authenticate_with_auth_socket_address(&mut client, &socket_path_clone).await + }); - while !socket_path.exists() { - sleep(std::time::Duration::from_millis(10)).await; + for i in 0..100 { + if socket_path.exists() { + break; + } + sleep(Duration::from_millis(10)).await; + + if i == 99 { + panic!("Socket not created after 1 second, assuming test failure"); + } } let mut snooper = UnixStream::connect(socket_path.clone()).await.unwrap(); @@ -409,10 +393,7 @@ mod test { sleep(Duration::from_millis(10)).await; - let server_handle = tokio::spawn(async move { - let uid: u32 = nix::unistd::getuid().into(); - server_authenticate(&mut server, Some(uid)).await - }); + let server_handle = tokio::spawn(async move { server_authenticate(&mut server).await }); client_handle.await.unwrap().unwrap(); server_handle.await.unwrap().unwrap(); diff --git a/src/server/command.rs b/src/server/command.rs new file mode 100644 index 0000000..dd49c27 --- /dev/null +++ b/src/server/command.rs @@ -0,0 +1,161 @@ +use std::fs; +use std::os::fd::FromRawFd; +use std::path::PathBuf; + +use anyhow::Context; +use clap::Parser; +use sqlx::prelude::*; +use sqlx::MySqlConnection; +use tokio::io::AsyncWriteExt; +use tokio::net::UnixListener; +use tokio::net::UnixStream; + +use crate::server::config::get_config; +use crate::server::config::DEFAULT_SOCKET_PATH; + +use super::{ + bootstrap::authenticated_unix_socket, + common::UnixUser, + config::{create_mysql_connection_from_config, ServerConfig, ServerConfigArgs}, + entrypoint::handle_request, + protocol::create_server_to_client_message_stream, +}; + +#[derive(Parser, Debug, Clone)] +pub struct ServerArgs { + #[command(subcommand)] + subcmd: ServerCommand, + + #[command(flatten)] + config_overrides: ServerConfigArgs, +} + +#[derive(Parser, Debug, Clone)] +pub enum ServerCommand { + #[command()] + Listen, + + #[command()] + SocketActivate, + + #[command()] + StickyBitActivate, +} + +pub async fn handle_command(args: ServerArgs) -> anyhow::Result<()> { + let config = get_config(args.config_overrides)?; + + let mut db_connection = create_mysql_connection_from_config(config.mysql.clone()).await?; + + let result = match args.subcmd { + ServerCommand::Listen => listen(config, &mut db_connection).await, + ServerCommand::SocketActivate => socket_activate(config, &mut db_connection).await, + ServerCommand::StickyBitActivate => sticky_bit_activate(config, &mut db_connection).await, + }; + + if let Err(e) = &result { + eprintln!("{}", e); + } + + close_database_connection(db_connection).await; + + result +} + +/// Gracefully close a MySQL connection. +async fn close_database_connection(connection: MySqlConnection) { + if let Err(e) = connection + .close() + .await + .context("Failed to close connection properly") + { + eprintln!("{}", e); + eprintln!("Ignoring..."); + } +} + +async fn listen(_config: ServerConfig, db_connection: &mut MySqlConnection) -> anyhow::Result<()> { + // let mut db_connection = create_mysql_connection_from_config(config.mysql).await?; + // TODO: fetch from arguments or config + + let default_config_path = PathBuf::from(DEFAULT_SOCKET_PATH); + let parent_directory = default_config_path.parent().unwrap(); + if !parent_directory.exists() { + println!("Creating directory {:?}", parent_directory); + fs::create_dir_all(parent_directory)?; + } + println!("Listening on {:?}", DEFAULT_SOCKET_PATH); + match fs::remove_file(DEFAULT_SOCKET_PATH) { + Ok(_) => {} + Err(e) if e.kind() == std::io::ErrorKind::NotFound => {} + Err(e) => return Err(e.into()), + } + + let listener = UnixListener::bind(DEFAULT_SOCKET_PATH)?; + + while let Ok((mut conn, _addr)) = listener.accept().await { + let uid = match authenticated_unix_socket::server_authenticate(&mut conn).await { + Ok(uid) => uid, + Err(e) => { + eprintln!("Failed to authenticate client: {}", e); + conn.shutdown().await?; + continue; + } + }; + let unix_user = match UnixUser::from_uid(uid.into()) { + Ok(user) => user, + Err(e) => { + eprintln!("Failed to get UnixUser from uid: {}", e); + conn.shutdown().await?; + continue; + } + }; + let stream = create_server_to_client_message_stream(conn); + match handle_request(stream, &unix_user, db_connection).await { + Ok(_) => {} + Err(e) => { + eprintln!("Failed to run server: {}", e); + } + } + } + + Ok(()) +} + +async fn get_socket_from_systemd() -> anyhow::Result { + let fd = std::env::var("LISTEN_FDS") + .context("LISTEN_FDS not set, not running under systemd?")? + .parse::() + .context("Failed to parse LISTEN_FDS")?; + + if fd != 1 { + return Err(anyhow::anyhow!("Unexpected LISTEN_FDS value: {}", fd)); + } + + let std_unix_stream = unsafe { std::os::unix::net::UnixStream::from_raw_fd(fd) }; + let socket = tokio::net::UnixStream::from_std(std_unix_stream)?; + Ok(socket) +} + +async fn socket_activate( + _config: ServerConfig, + db_connection: &mut MySqlConnection, +) -> anyhow::Result<()> { + // TODO: allow getting socket path from other socket activation sources + let mut conn = get_socket_from_systemd().await?; + let uid = authenticated_unix_socket::server_authenticate(&mut conn).await?; + let unix_user = UnixUser::from_uid(uid.into())?; + let stream = create_server_to_client_message_stream(conn); + handle_request(stream, &unix_user, db_connection).await?; + + Ok(()) +} + +async fn sticky_bit_activate( + _config: ServerConfig, + _db_connection: &mut MySqlConnection, +) -> anyhow::Result<()> { + // Is this a type of socket activation? + println!("Activating sticky bit"); + Ok(()) +} diff --git a/src/server/common.rs b/src/server/common.rs index bb74cd4..c0a47af 100644 --- a/src/server/common.rs +++ b/src/server/common.rs @@ -1,6 +1,5 @@ use anyhow::Context; use nix::unistd::{Group as LibcGroup, User as LibcUser}; -use sqlx::{Connection, MySqlConnection}; #[cfg(not(target_os = "macos"))] use std::ffi::CString; @@ -30,12 +29,6 @@ pub enum CommandStatus { Cancelled, } -// pub fn get_current_unix_user() -> anyhow::Result { -// User::from_uid(getuid()) -// .context("Failed to look up your UNIX username") -// .and_then(|u| u.ok_or(anyhow::anyhow!("Failed to look up your UNIX username"))) -// } - pub struct UnixUser { pub username: String, pub uid: u32, @@ -99,15 +92,3 @@ pub fn create_user_group_matching_regex(user: &UnixUser) -> String { format!("({}|{})(_.+)?", user.username, user.groups.join("|")) } } - -/// Gracefully close a MySQL connection. -pub async fn close_database_connection(connection: MySqlConnection) { - if let Err(e) = connection - .close() - .await - .context("Failed to close connection properly") - { - eprintln!("{}", e); - eprintln!("Ignoring..."); - } -} diff --git a/src/core/config.rs b/src/server/config.rs similarity index 72% rename from src/core/config.rs rename to src/server/config.rs index 83ada90..0e2d83e 100644 --- a/src/core/config.rs +++ b/src/server/config.rs @@ -5,11 +5,16 @@ use clap::Parser; use serde::{Deserialize, Serialize}; use sqlx::{mysql::MySqlConnectOptions, ConnectOptions, MySqlConnection}; +pub const DEFAULT_CONFIG_PATH: &str = "/etc/mysqladm/config.toml"; +pub const DEFAULT_SOCKET_PATH: &str = "/run/mysqladm/mysqladm.sock"; +pub const DEFAULT_PORT: u16 = 3306; +pub const DEFAULT_TIMEOUT: u64 = 2; + // NOTE: this might look empty now, and the extra wrapping for the mysql // config seems unnecessary, but it will be useful later when we // add more configuration options. #[derive(Debug, Clone, Deserialize, Serialize)] -pub struct Config { +pub struct ServerConfig { pub mysql: MysqlConfig, } @@ -23,49 +28,45 @@ pub struct MysqlConfig { pub timeout: Option, } -const DEFAULT_PORT: u16 = 3306; -const DEFAULT_TIMEOUT: u64 = 2; - -#[derive(Parser)] -pub struct GlobalConfigArgs { +#[derive(Parser, Debug, Clone)] +pub struct ServerConfigArgs { /// Path to the configuration file. #[arg( short, long, value_name = "PATH", global = true, - hide_short_help = true, - default_value = "/etc/mysqladm/config.toml" + default_value = DEFAULT_CONFIG_PATH, )] config_file: String, /// Hostname of the MySQL server. - #[arg(long, value_name = "HOST", global = true, hide_short_help = true)] + #[arg(long, value_name = "HOST", global = true)] mysql_host: Option, /// Port of the MySQL server. - #[arg(long, value_name = "PORT", global = true, hide_short_help = true)] + #[arg(long, value_name = "PORT", global = true)] mysql_port: Option, /// Username to use for the MySQL connection. - #[arg(long, value_name = "USER", global = true, hide_short_help = true)] + #[arg(long, value_name = "USER", global = true)] mysql_user: Option, /// Path to a file containing the MySQL password. - #[arg(long, value_name = "PATH", global = true, hide_short_help = true)] + #[arg(long, value_name = "PATH", global = true)] mysql_password_file: Option, /// Seconds to wait for the MySQL connection to be established. - #[arg(long, value_name = "SECONDS", global = true, hide_short_help = true)] + #[arg(long, value_name = "SECONDS", global = true)] mysql_connect_timeout: Option, } /// Use the arguments and whichever configuration file which might or might not /// be found and default values to determine the configuration for the program. -pub fn get_config(args: GlobalConfigArgs) -> anyhow::Result { +pub fn get_config(args: ServerConfigArgs) -> anyhow::Result { let config_path = PathBuf::from(args.config_file); - let config: Config = fs::read_to_string(&config_path) + let config: ServerConfig = fs::read_to_string(&config_path) .context(format!( "Failed to read config file from {:?}", &config_path @@ -86,16 +87,14 @@ pub fn get_config(args: GlobalConfigArgs) -> anyhow::Result { mysql.password.to_owned() }; - let mysql_config = MysqlConfig { - host: args.mysql_host.unwrap_or(mysql.host.to_owned()), - port: args.mysql_port.or(mysql.port), - username: args.mysql_user.unwrap_or(mysql.username.to_owned()), - password, - timeout: args.mysql_connect_timeout.or(mysql.timeout), - }; - - Ok(Config { - mysql: mysql_config, + Ok(ServerConfig { + mysql: MysqlConfig { + host: args.mysql_host.unwrap_or(mysql.host.to_owned()), + port: args.mysql_port.or(mysql.port), + username: args.mysql_user.unwrap_or(mysql.username.to_owned()), + password, + timeout: args.mysql_connect_timeout.or(mysql.timeout), + }, }) } diff --git a/src/server/database_operations.rs b/src/server/database_operations.rs index ada4189..00b96b8 100644 --- a/src/server/database_operations.rs +++ b/src/server/database_operations.rs @@ -12,13 +12,13 @@ use std::collections::BTreeMap; use super::common::create_user_group_matching_regex; // NOTE: this function is unsafe because it does no input validation. -async fn unsafe_database_exists( - db_name: &str, +pub(super) async fn unsafe_database_exists( + database_name: &str, connection: &mut MySqlConnection, ) -> Result { let result = sqlx::query("SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?") - .bind(db_name) + .bind(database_name) .fetch_optional(connection) .await?; @@ -153,7 +153,7 @@ pub async fn drop_databases( results } -pub type ListDatabasesOutput = Result, ListDatabasesError>; +pub type ListAllDatabasesOutput = Result, ListDatabasesError>; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum ListDatabasesError { MySqlError(String), @@ -180,4 +180,4 @@ pub async fn list_databases_for_user( .collect::, sqlx::Error>>() }) .map_err(|err| ListDatabasesError::MySqlError(err.to_string())) -} \ No newline at end of file +} diff --git a/src/server/database_privilege_operations.rs b/src/server/database_privilege_operations.rs index be21078..0c8a148 100644 --- a/src/server/database_privilege_operations.rs +++ b/src/server/database_privilege_operations.rs @@ -1,3 +1,4 @@ +// TODO: fix comment //! Database privilege operations //! //! This module contains functions for querying, modifying, @@ -13,21 +14,21 @@ //! changes will be made when applying a set of changes //! to the list of database privileges. -use std::collections::{BTreeSet, HashMap}; +use std::collections::{BTreeMap, BTreeSet}; -use anyhow::{anyhow, Context}; use indoc::indoc; use itertools::Itertools; -use prettytable::Table; use serde::{Deserialize, Serialize}; use sqlx::{mysql::MySqlRow, prelude::*, MySqlConnection}; -use crate::core::{ - common::{ - create_user_group_matching_regex, get_current_unix_user, quote_identifier, rev_yn, yn, - }, - database_operations::validate_database_name, +use super::common::{create_user_group_matching_regex, UnixUser}; +use super::database_operations::unsafe_database_exists; +use super::input_sanitization::{ + quote_identifier, validate_name, validate_ownership_by_unix_user, NameValidationError, + OwnerValidationError, }; +use crate::core::common::{rev_yn, yn}; +use crate::core::database_privileges::{DatabasePrivilegeChange, DatabasePrivilegesDiff}; /// This is the list of fields that are used to fetch the db + user + privileges /// from the `db` table in the database. If you need to add or remove privilege @@ -48,25 +49,6 @@ pub const DATABASE_PRIVILEGE_FIELDS: [&str; 13] = [ "references_priv", ]; -pub fn db_priv_field_human_readable_name(name: &str) -> String { - match name { - "db" => "Database".to_owned(), - "user" => "User".to_owned(), - "select_priv" => "Select".to_owned(), - "insert_priv" => "Insert".to_owned(), - "update_priv" => "Update".to_owned(), - "delete_priv" => "Delete".to_owned(), - "create_priv" => "Create".to_owned(), - "drop_priv" => "Drop".to_owned(), - "alter_priv" => "Alter".to_owned(), - "index_priv" => "Index".to_owned(), - "create_tmp_table_priv" => "Temp".to_owned(), - "lock_tables_priv" => "Lock".to_owned(), - "references_priv" => "References".to_owned(), - _ => format!("Unknown({})", name), - } -} - // NOTE: ord is needed for BTreeSet to accept the type, but it // doesn't have any natural implementation semantics. @@ -123,26 +105,6 @@ impl DatabasePrivilegeRow { _ => false, } } - - pub fn diff(&self, other: &DatabasePrivilegeRow) -> DatabasePrivilegeRowDiff { - debug_assert!(self.db == other.db && self.user == other.user); - - DatabasePrivilegeRowDiff { - db: self.db.clone(), - user: self.user.clone(), - diff: DATABASE_PRIVILEGE_FIELDS - .into_iter() - .skip(2) - .filter_map(|field| { - DatabasePrivilegeChange::new( - self.get_privilege_by_name(field), - other.get_privilege_by_name(field), - field, - ) - }) - .collect(), - } - } } #[inline] @@ -178,15 +140,13 @@ impl FromRow<'_, MySqlRow> for DatabasePrivilegeRow { } } +// NOTE: this function is unsafe because it does no input validation. /// Get all users + privileges for a single database. -pub async fn get_database_privileges( +async fn unsafe_get_database_privileges( database_name: &str, connection: &mut MySqlConnection, -) -> anyhow::Result> { - let unix_user = get_current_unix_user()?; - validate_database_name(database_name, &unix_user)?; - - let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!( +) -> Result, sqlx::Error> { + sqlx::query_as::<_, DatabasePrivilegeRow>(&format!( "SELECT {} FROM `db` WHERE `db` = ?", DATABASE_PRIVILEGE_FIELDS .iter() @@ -196,18 +156,82 @@ pub async fn get_database_privileges( .bind(database_name) .fetch_all(connection) .await - .context("Failed to show database")?; +} - Ok(result) +// TODO: merge all rows into a single collection. +// they already contain which database they belong to. +// no need to index by database name. + +pub type GetDatabasesPrivilegeData = + BTreeMap, GetDatabasesPrivilegeDataError>>; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum GetDatabasesPrivilegeDataError { + SanitizationError(NameValidationError), + OwnershipError(OwnerValidationError), + DatabaseDoesNotExist, + MySqlError(String), +} + +pub async fn get_databases_privilege_data( + database_names: Vec, + unix_user: &UnixUser, + connection: &mut MySqlConnection, +) -> GetDatabasesPrivilegeData { + let mut results = BTreeMap::new(); + + for database_name in database_names.iter() { + if let Err(err) = validate_name(database_name) { + results.insert( + database_name.clone(), + Err(GetDatabasesPrivilegeDataError::SanitizationError(err)), + ); + continue; + } + + if let Err(err) = validate_ownership_by_unix_user(database_name, unix_user) { + results.insert( + database_name.clone(), + Err(GetDatabasesPrivilegeDataError::OwnershipError(err)), + ); + continue; + } + + if !unsafe_database_exists(database_name, connection) + .await + .unwrap() + { + results.insert( + database_name.clone(), + Err(GetDatabasesPrivilegeDataError::DatabaseDoesNotExist), + ); + continue; + } + + let result = unsafe_get_database_privileges(database_name, connection) + .await + .map_err(|e| GetDatabasesPrivilegeDataError::MySqlError(e.to_string())); + + results.insert(database_name.clone(), result); + } + + debug_assert!(database_names.len() == results.len()); + + results +} + +pub type GetAllDatabasesPrivilegeData = + Result, GetAllDatabasesPrivilegeDataError>; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum GetAllDatabasesPrivilegeDataError { + MySqlError(String), } /// Get all database + user + privileges pairs that are owned by the current user. pub async fn get_all_database_privileges( + unix_user: &UnixUser, connection: &mut MySqlConnection, -) -> anyhow::Result> { - let unix_user = get_current_unix_user()?; - - let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!( +) -> GetAllDatabasesPrivilegeData { + sqlx::query_as::<_, DatabasePrivilegeRow>(&format!( indoc! {r#" SELECT {} FROM `db` WHERE `db` IN (SELECT DISTINCT `SCHEMA_NAME` AS `database` @@ -217,672 +241,159 @@ pub async fn get_all_database_privileges( "#}, DATABASE_PRIVILEGE_FIELDS .iter() - .map(|field| format!("`{field}`")) + .map(|field| quote_identifier(field)) .join(","), )) - .bind(create_user_group_matching_regex(&unix_user)) + .bind(create_user_group_matching_regex(unix_user)) .fetch_all(connection) .await - .context("Failed to show databases")?; - - Ok(result) + .map_err(|e| GetAllDatabasesPrivilegeDataError::MySqlError(e.to_string())) } -/*************************/ -/* CLI INTERFACE PARSING */ -/*************************/ - -/// See documentation for [`DatabaseCommand::EditDbPrivs`]. -pub fn parse_privilege_table_cli_arg(arg: &str) -> anyhow::Result { - let parts: Vec<&str> = arg.split(':').collect(); - if parts.len() != 3 { - anyhow::bail!("Invalid argument format. See `edit-db-privs --help` for more information."); - } - - let db = parts[0].to_string(); - let user = parts[1].to_string(); - let privs = parts[2].to_string(); - - let mut result = DatabasePrivilegeRow { - db, - user, - select_priv: false, - insert_priv: false, - update_priv: false, - delete_priv: false, - create_priv: false, - drop_priv: false, - alter_priv: false, - index_priv: false, - create_tmp_table_priv: false, - lock_tables_priv: false, - references_priv: false, - }; - - for char in privs.chars() { - match char { - 's' => result.select_priv = true, - 'i' => result.insert_priv = true, - 'u' => result.update_priv = true, - 'd' => result.delete_priv = true, - 'c' => result.create_priv = true, - 'D' => result.drop_priv = true, - 'a' => result.alter_priv = true, - 'I' => result.index_priv = true, - 't' => result.create_tmp_table_priv = true, - 'l' => result.lock_tables_priv = true, - 'r' => result.references_priv = true, - 'A' => { - result.select_priv = true; - result.insert_priv = true; - result.update_priv = true; - result.delete_priv = true; - result.create_priv = true; - result.drop_priv = true; - result.alter_priv = true; - result.index_priv = true; - result.create_tmp_table_priv = true; - result.lock_tables_priv = true; - result.references_priv = true; - } - _ => anyhow::bail!("Invalid privilege character: {}", char), - } - } - - Ok(result) -} - -/**********************************/ -/* EDITOR CONTENT DISPLAY/DISPLAY */ -/**********************************/ - -/// Generates a single row of the privileges table for the editor. -pub fn format_privileges_line_for_editor( - privs: &DatabasePrivilegeRow, - username_len: usize, - database_name_len: usize, -) -> String { - DATABASE_PRIVILEGE_FIELDS - .into_iter() - .map(|field| match field { - "db" => format!("{:width$}", privs.db, width = database_name_len), - "user" => format!("{:width$}", privs.user, width = username_len), - privilege => format!( - "{:width$}", - yn(privs.get_privilege_by_name(privilege)), - width = db_priv_field_human_readable_name(privilege).len() - ), - }) - .join(" ") - .trim() - .to_string() -} - -const EDITOR_COMMENT: &str = r#" -# Welcome to the privilege editor. -# Each line defines what privileges a single user has on a single database. -# The first two columns respectively represent the database name and the user, and the remaining columns are the privileges. -# If the user should have a certain privilege, write 'Y', otherwise write 'N'. -# -# Lines starting with '#' are comments and will be ignored. -"#; - -/// Generates the content for the privilege editor. -/// -/// The unix user is used in case there are no privileges to edit, -/// so that the user can see an example line based on their username. -pub fn generate_editor_content_from_privilege_data( - privilege_data: &[DatabasePrivilegeRow], - unix_user: &str, -) -> String { - let example_user = format!("{}_user", unix_user); - let example_db = format!("{}_db", unix_user); - - // NOTE: `.max()`` fails when the iterator is empty. - // In this case, we know that the only fields in the - // editor will be the example user and example db name. - // Hence, it's put as the fallback value, despite not really - // being a "fallback" in the normal sense. - let longest_username = privilege_data - .iter() - .map(|p| p.user.len()) - .max() - .unwrap_or(example_user.len()); - - let longest_database_name = privilege_data - .iter() - .map(|p| p.db.len()) - .max() - .unwrap_or(example_db.len()); - - let mut header: Vec<_> = DATABASE_PRIVILEGE_FIELDS - .into_iter() - .map(db_priv_field_human_readable_name) - .collect(); - - // Pad the first two columns with spaces to align the privileges. - header[0] = format!("{:width$}", header[0], width = longest_database_name); - header[1] = format!("{:width$}", header[1], width = longest_username); - - let example_line = format_privileges_line_for_editor( - &DatabasePrivilegeRow { - db: example_db, - user: example_user, - select_priv: true, - insert_priv: true, - update_priv: true, - delete_priv: true, - create_priv: false, - drop_priv: false, - alter_priv: false, - index_priv: false, - create_tmp_table_priv: false, - lock_tables_priv: false, - references_priv: false, - }, - longest_username, - longest_database_name, - ); - - format!( - "{}\n{}\n{}", - EDITOR_COMMENT, - header.join(" "), - if privilege_data.is_empty() { - format!("# {}", example_line) - } else { - privilege_data +async fn unsafe_apply_privilege_diff( + database_privilege_diff: &DatabasePrivilegesDiff, + connection: &mut MySqlConnection, +) -> Result<(), sqlx::Error> { + match database_privilege_diff { + DatabasePrivilegesDiff::New(p) => { + let tables = DATABASE_PRIVILEGE_FIELDS .iter() - .map(|privs| { - format_privileges_line_for_editor( - privs, - longest_username, - longest_database_name, - ) + .map(|field| quote_identifier(field)) + .join(","); + + let question_marks = std::iter::repeat("?") + .take(DATABASE_PRIVILEGE_FIELDS.len()) + .join(","); + + sqlx::query( + format!("INSERT INTO `db` ({}) VALUES ({})", tables, question_marks).as_str(), + ) + .bind(p.db.to_string()) + .bind(p.user.to_string()) + .bind(yn(p.select_priv)) + .bind(yn(p.insert_priv)) + .bind(yn(p.update_priv)) + .bind(yn(p.delete_priv)) + .bind(yn(p.create_priv)) + .bind(yn(p.drop_priv)) + .bind(yn(p.alter_priv)) + .bind(yn(p.index_priv)) + .bind(yn(p.create_tmp_table_priv)) + .bind(yn(p.lock_tables_priv)) + .bind(yn(p.references_priv)) + .execute(connection) + .await + .map(|_| ()) + } + DatabasePrivilegesDiff::Modified(p) => { + let changes = p + .diff + .iter() + .map(|diff| match diff { + DatabasePrivilegeChange::YesToNo(name) => { + format!("{} = 'N'", quote_identifier(name)) + } + DatabasePrivilegeChange::NoToYes(name) => { + format!("{} = 'Y'", quote_identifier(name)) + } }) - .join("\n") + .join(","); + + sqlx::query( + format!("UPDATE `db` SET {} WHERE `db` = ? AND `user` = ?", changes).as_str(), + ) + .bind(p.db.to_string()) + .bind(p.user.to_string()) + .execute(connection) + .await + .map(|_| ()) } - ) -} - -#[derive(Debug)] -enum PrivilegeRowParseResult { - PrivilegeRow(DatabasePrivilegeRow), - ParserError(anyhow::Error), - TooFewFields(usize), - TooManyFields(usize), - Header, - Comment, - Empty, -} - -#[inline] -fn parse_privilege_cell_from_editor(yn: &str, name: &str) -> anyhow::Result { - rev_yn(yn) - .ok_or_else(|| anyhow!("Expected Y or N, found {}", yn)) - .context(format!("Could not parse {} privilege", name)) -} - -#[inline] -fn editor_row_is_header(row: &str) -> bool { - row.split_ascii_whitespace() - .zip(DATABASE_PRIVILEGE_FIELDS.iter()) - .map(|(field, priv_name)| (field, db_priv_field_human_readable_name(priv_name))) - .all(|(field, header_field)| field == header_field) -} - -/// Parse a single row of the privileges table from the editor. -fn parse_privilege_row_from_editor(row: &str) -> PrivilegeRowParseResult { - if row.starts_with('#') || row.starts_with("//") { - return PrivilegeRowParseResult::Comment; - } - - if row.trim().is_empty() { - return PrivilegeRowParseResult::Empty; - } - - let parts: Vec<&str> = row.trim().split_ascii_whitespace().collect(); - - match parts.len() { - n if (n < DATABASE_PRIVILEGE_FIELDS.len()) => { - return PrivilegeRowParseResult::TooFewFields(n) - } - n if (n > DATABASE_PRIVILEGE_FIELDS.len()) => { - return PrivilegeRowParseResult::TooManyFields(n) - } - _ => {} - } - - if editor_row_is_header(row) { - return PrivilegeRowParseResult::Header; - } - - let row = DatabasePrivilegeRow { - db: (*parts.first().unwrap()).to_owned(), - user: (*parts.get(1).unwrap()).to_owned(), - select_priv: match parse_privilege_cell_from_editor( - parts.get(2).unwrap(), - DATABASE_PRIVILEGE_FIELDS[2], - ) { - Ok(p) => p, - Err(e) => return PrivilegeRowParseResult::ParserError(e), - }, - insert_priv: match parse_privilege_cell_from_editor( - parts.get(3).unwrap(), - DATABASE_PRIVILEGE_FIELDS[3], - ) { - Ok(p) => p, - Err(e) => return PrivilegeRowParseResult::ParserError(e), - }, - update_priv: match parse_privilege_cell_from_editor( - parts.get(4).unwrap(), - DATABASE_PRIVILEGE_FIELDS[4], - ) { - Ok(p) => p, - Err(e) => return PrivilegeRowParseResult::ParserError(e), - }, - delete_priv: match parse_privilege_cell_from_editor( - parts.get(5).unwrap(), - DATABASE_PRIVILEGE_FIELDS[5], - ) { - Ok(p) => p, - Err(e) => return PrivilegeRowParseResult::ParserError(e), - }, - create_priv: match parse_privilege_cell_from_editor( - parts.get(6).unwrap(), - DATABASE_PRIVILEGE_FIELDS[6], - ) { - Ok(p) => p, - Err(e) => return PrivilegeRowParseResult::ParserError(e), - }, - drop_priv: match parse_privilege_cell_from_editor( - parts.get(7).unwrap(), - DATABASE_PRIVILEGE_FIELDS[7], - ) { - Ok(p) => p, - Err(e) => return PrivilegeRowParseResult::ParserError(e), - }, - alter_priv: match parse_privilege_cell_from_editor( - parts.get(8).unwrap(), - DATABASE_PRIVILEGE_FIELDS[8], - ) { - Ok(p) => p, - Err(e) => return PrivilegeRowParseResult::ParserError(e), - }, - index_priv: match parse_privilege_cell_from_editor( - parts.get(9).unwrap(), - DATABASE_PRIVILEGE_FIELDS[9], - ) { - Ok(p) => p, - Err(e) => return PrivilegeRowParseResult::ParserError(e), - }, - create_tmp_table_priv: match parse_privilege_cell_from_editor( - parts.get(10).unwrap(), - DATABASE_PRIVILEGE_FIELDS[10], - ) { - Ok(p) => p, - Err(e) => return PrivilegeRowParseResult::ParserError(e), - }, - lock_tables_priv: match parse_privilege_cell_from_editor( - parts.get(11).unwrap(), - DATABASE_PRIVILEGE_FIELDS[11], - ) { - Ok(p) => p, - Err(e) => return PrivilegeRowParseResult::ParserError(e), - }, - references_priv: match parse_privilege_cell_from_editor( - parts.get(12).unwrap(), - DATABASE_PRIVILEGE_FIELDS[12], - ) { - Ok(p) => p, - Err(e) => return PrivilegeRowParseResult::ParserError(e), - }, - }; - - PrivilegeRowParseResult::PrivilegeRow(row) -} - -// TODO: return better errors - -pub fn parse_privilege_data_from_editor_content( - content: String, -) -> anyhow::Result> { - content - .trim() - .split('\n') - .map(|line| line.trim()) - .map(parse_privilege_row_from_editor) - .map(|result| match result { - PrivilegeRowParseResult::PrivilegeRow(row) => Ok(Some(row)), - PrivilegeRowParseResult::ParserError(e) => Err(e), - PrivilegeRowParseResult::TooFewFields(n) => Err(anyhow!( - "Too few fields in line. Expected to find {} fields, found {}", - DATABASE_PRIVILEGE_FIELDS.len(), - n - )), - PrivilegeRowParseResult::TooManyFields(n) => Err(anyhow!( - "Too many fields in line. Expected to find {} fields, found {}", - DATABASE_PRIVILEGE_FIELDS.len(), - n - )), - PrivilegeRowParseResult::Header => Ok(None), - PrivilegeRowParseResult::Comment => Ok(None), - PrivilegeRowParseResult::Empty => Ok(None), - }) - .filter_map(|result| result.transpose()) - .collect::>>() -} - -/*****************************/ -/* CALCULATE PRIVILEGE DIFFS */ -/*****************************/ - -/// This struct represents encapsulates the differences between two -/// instances of privilege sets for a single user on a single database. -/// -/// The `User` and `Database` are the same for both instances. -#[derive(Debug, Clone, PartialEq, Eq, Serialize, PartialOrd, Ord)] -pub struct DatabasePrivilegeRowDiff { - pub db: String, - pub user: String, - pub diff: BTreeSet, -} - -/// This enum represents a change for a single privilege. -#[derive(Debug, Clone, PartialEq, Eq, Serialize, PartialOrd, Ord)] -pub enum DatabasePrivilegeChange { - YesToNo(String), - NoToYes(String), -} - -impl DatabasePrivilegeChange { - pub fn new(p1: bool, p2: bool, name: &str) -> Option { - match (p1, p2) { - (true, false) => Some(DatabasePrivilegeChange::YesToNo(name.to_owned())), - (false, true) => Some(DatabasePrivilegeChange::NoToYes(name.to_owned())), - _ => None, + DatabasePrivilegesDiff::Deleted(p) => { + sqlx::query("DELETE FROM `db` WHERE `db` = ? AND `user` = ?") + .bind(p.db.to_string()) + .bind(p.user.to_string()) + .execute(connection) + .await + .map(|_| ()) } } } -/// This enum encapsulates whether a [`DatabasePrivilegeRow`] was intrduced, modified or deleted. -#[derive(Debug, Clone, PartialEq, Eq, Serialize, PartialOrd, Ord)] -pub enum DatabasePrivilegesDiff { - New(DatabasePrivilegeRow), - Modified(DatabasePrivilegeRowDiff), - Deleted(DatabasePrivilegeRow), -} - -/// This function calculates the differences between two sets of database privileges. -/// It returns a set of [`DatabasePrivilegesDiff`] that can be used to display or -/// apply a set of privilege modifications to the database. -pub fn diff_privileges( - from: &[DatabasePrivilegeRow], - to: &[DatabasePrivilegeRow], -) -> BTreeSet { - let from_lookup_table: HashMap<(String, String), DatabasePrivilegeRow> = HashMap::from_iter( - from.iter() - .cloned() - .map(|p| ((p.db.clone(), p.user.clone()), p)), - ); - - let to_lookup_table: HashMap<(String, String), DatabasePrivilegeRow> = HashMap::from_iter( - to.iter() - .cloned() - .map(|p| ((p.db.clone(), p.user.clone()), p)), - ); - - let mut result = BTreeSet::new(); - - for p in to { - if let Some(old_p) = from_lookup_table.get(&(p.db.clone(), p.user.clone())) { - let diff = old_p.diff(p); - if !diff.diff.is_empty() { - result.insert(DatabasePrivilegesDiff::Modified(diff)); - } - } else { - result.insert(DatabasePrivilegesDiff::New(p.clone())); - } - } - - for p in from { - if !to_lookup_table.contains_key(&(p.db.clone(), p.user.clone())) { - result.insert(DatabasePrivilegesDiff::Deleted(p.clone())); - } - } - - result +pub type ModifyDatabasePrivilegesOutput = + BTreeMap>; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum ModifyDatabasePrivilegesError { + DatabaseSanitizationError(NameValidationError), + DatabaseOwnershipError(OwnerValidationError), + UserSanitizationError(NameValidationError), + UserOwnershipError(OwnerValidationError), + DatabaseDoesNotExist, + DiffDoesNotApply(DatabasePrivilegeChange), + MySqlError(String), } /// Uses the result of [`diff_privileges`] to modify privileges in the database. pub async fn apply_privilege_diffs( - diffs: BTreeSet, + database_privilege_diffs: BTreeSet, + unix_user: &UnixUser, connection: &mut MySqlConnection, -) -> anyhow::Result<()> { - for diff in diffs { - match diff { - DatabasePrivilegesDiff::New(p) => { - let tables = DATABASE_PRIVILEGE_FIELDS - .iter() - .map(|field| format!("`{field}`")) - .join(","); +) -> ModifyDatabasePrivilegesOutput { + let mut results: BTreeMap = BTreeMap::new(); - let question_marks = std::iter::repeat("?") - .take(DATABASE_PRIVILEGE_FIELDS.len()) - .join(","); - - sqlx::query( - format!("INSERT INTO `db` ({}) VALUES ({})", tables, question_marks).as_str(), - ) - .bind(p.db) - .bind(p.user) - .bind(yn(p.select_priv)) - .bind(yn(p.insert_priv)) - .bind(yn(p.update_priv)) - .bind(yn(p.delete_priv)) - .bind(yn(p.create_priv)) - .bind(yn(p.drop_priv)) - .bind(yn(p.alter_priv)) - .bind(yn(p.index_priv)) - .bind(yn(p.create_tmp_table_priv)) - .bind(yn(p.lock_tables_priv)) - .bind(yn(p.references_priv)) - .execute(&mut *connection) - .await?; - } - DatabasePrivilegesDiff::Modified(p) => { - let tables = p - .diff - .iter() - .map(|diff| match diff { - DatabasePrivilegeChange::YesToNo(name) => format!("`{}` = 'N'", name), - DatabasePrivilegeChange::NoToYes(name) => format!("`{}` = 'Y'", name), - }) - .join(","); - - sqlx::query( - format!("UPDATE `db` SET {} WHERE `db` = ? AND `user` = ?", tables).as_str(), - ) - .bind(p.db) - .bind(p.user) - .execute(&mut *connection) - .await?; - } - DatabasePrivilegesDiff::Deleted(p) => { - sqlx::query("DELETE FROM `db` WHERE `db` = ? AND `user` = ?") - .bind(p.db) - .bind(p.user) - .execute(&mut *connection) - .await?; - } + for diff in database_privilege_diffs { + if let Err(err) = validate_name(diff.get_database_name()) { + results.insert( + diff.get_database_name().to_string(), + Err(ModifyDatabasePrivilegesError::DatabaseSanitizationError( + err, + )), + ); + continue; } - } - Ok(()) -} -fn display_privilege_cell(diff: &DatabasePrivilegeRowDiff) -> String { - diff.diff - .iter() - .map(|change| match change { - DatabasePrivilegeChange::YesToNo(name) => { - format!("{}: Y -> N", db_priv_field_human_readable_name(name)) - } - DatabasePrivilegeChange::NoToYes(name) => { - format!("{}: N -> Y", db_priv_field_human_readable_name(name)) - } - }) - .join("\n") -} - -/// Displays the difference between two sets of database privileges. -pub fn display_privilege_diffs(diffs: &BTreeSet) -> String { - let mut table = Table::new(); - table.set_titles(row!["Database", "User", "Privilege diff",]); - for row in diffs { - match row { - DatabasePrivilegesDiff::New(p) => { - table.add_row(row![ - p.db, - p.user, - "(New user)\n".to_string() - + &display_privilege_cell( - &DatabasePrivilegeRow::empty(&p.db, &p.user).diff(p) - ) - ]); - } - DatabasePrivilegesDiff::Modified(p) => { - table.add_row(row![p.db, p.user, display_privilege_cell(p),]); - } - DatabasePrivilegesDiff::Deleted(p) => { - table.add_row(row![ - p.db, - p.user, - "(All privileges removed)\n".to_string() - + &display_privilege_cell( - &p.diff(&DatabasePrivilegeRow::empty(&p.db, &p.user)) - ) - ]); - } + if let Err(err) = validate_ownership_by_unix_user(diff.get_database_name(), unix_user) { + results.insert( + diff.get_database_name().to_string(), + Err(ModifyDatabasePrivilegesError::DatabaseOwnershipError(err)), + ); + continue; } + + if let Err(err) = validate_name(diff.get_user_name()) { + results.insert( + diff.get_database_name().to_string(), + Err(ModifyDatabasePrivilegesError::UserSanitizationError(err)), + ); + continue; + } + + if let Err(err) = validate_ownership_by_unix_user(diff.get_user_name(), unix_user) { + results.insert( + diff.get_database_name().to_string(), + Err(ModifyDatabasePrivilegesError::UserOwnershipError(err)), + ); + continue; + } + + if !unsafe_database_exists(diff.get_database_name(), connection) + .await + .unwrap() + { + results.insert( + diff.get_database_name().to_string(), + Err(ModifyDatabasePrivilegesError::DatabaseDoesNotExist), + ); + continue; + } + + // TODO: validate that the diff actually applies to the database + + let result = unsafe_apply_privilege_diff(&diff, connection) + .await + .map_err(|e| ModifyDatabasePrivilegesError::MySqlError(e.to_string())); + + results.insert(diff.get_database_name().to_string(), result); } - table.to_string() -} - -/*********/ -/* TESTS */ -/*********/ - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_database_privilege_change_creation() { - assert_eq!( - DatabasePrivilegeChange::new(true, false, "test"), - Some(DatabasePrivilegeChange::YesToNo("test".to_owned())) - ); - assert_eq!( - DatabasePrivilegeChange::new(false, true, "test"), - Some(DatabasePrivilegeChange::NoToYes("test".to_owned())) - ); - assert_eq!(DatabasePrivilegeChange::new(true, true, "test"), None); - assert_eq!(DatabasePrivilegeChange::new(false, false, "test"), None); - } - - #[test] - fn test_diff_privileges() { - let row_to_be_modified = DatabasePrivilegeRow { - db: "db".to_owned(), - user: "user".to_owned(), - select_priv: true, - insert_priv: true, - update_priv: true, - delete_priv: true, - create_priv: true, - drop_priv: true, - alter_priv: true, - index_priv: false, - create_tmp_table_priv: true, - lock_tables_priv: true, - references_priv: false, - }; - - let mut row_to_be_deleted = row_to_be_modified.clone(); - "user2".clone_into(&mut row_to_be_deleted.user); - - let from = vec![row_to_be_modified.clone(), row_to_be_deleted.clone()]; - - let mut modified_row = row_to_be_modified.clone(); - modified_row.select_priv = false; - modified_row.insert_priv = false; - modified_row.index_priv = true; - - let mut new_row = row_to_be_modified.clone(); - "user3".clone_into(&mut new_row.user); - - let to = vec![modified_row.clone(), new_row.clone()]; - - let diffs = diff_privileges(&from, &to); - - assert_eq!( - diffs, - BTreeSet::from_iter(vec![ - DatabasePrivilegesDiff::Deleted(row_to_be_deleted), - DatabasePrivilegesDiff::Modified(DatabasePrivilegeRowDiff { - db: "db".to_owned(), - user: "user".to_owned(), - diff: BTreeSet::from_iter(vec![ - DatabasePrivilegeChange::YesToNo("select_priv".to_owned()), - DatabasePrivilegeChange::YesToNo("insert_priv".to_owned()), - DatabasePrivilegeChange::NoToYes("index_priv".to_owned()), - ]), - }), - DatabasePrivilegesDiff::New(new_row), - ]) - ); - } - - #[test] - fn ensure_generated_and_parsed_editor_content_is_equal() { - let permissions = vec![ - DatabasePrivilegeRow { - db: "db".to_owned(), - user: "user".to_owned(), - select_priv: true, - insert_priv: true, - update_priv: true, - delete_priv: true, - create_priv: true, - drop_priv: true, - alter_priv: true, - index_priv: true, - create_tmp_table_priv: true, - lock_tables_priv: true, - references_priv: true, - }, - DatabasePrivilegeRow { - db: "db2".to_owned(), - user: "user2".to_owned(), - select_priv: false, - insert_priv: false, - update_priv: false, - delete_priv: false, - create_priv: false, - drop_priv: false, - alter_priv: false, - index_priv: false, - create_tmp_table_priv: false, - lock_tables_priv: false, - references_priv: false, - }, - ]; - - let content = generate_editor_content_from_privilege_data(&permissions, "user"); - - let parsed_permissions = parse_privilege_data_from_editor_content(content).unwrap(); - - assert_eq!(permissions, parsed_permissions); - } + results } diff --git a/src/server/entrypoint.rs b/src/server/entrypoint.rs index 309e768..5f48374 100644 --- a/src/server/entrypoint.rs +++ b/src/server/entrypoint.rs @@ -1,87 +1,126 @@ use futures_util::{SinkExt, StreamExt}; use sqlx::MySqlConnection; -use tokio::net::UnixStream; -use tokio_serde::{formats::Bincode, Framed as SerdeFramed}; -use tokio_util::codec::{Framed, LengthDelimitedCodec}; - -// use crate::server:: - -use crate::server::protocol::{Request, Response}; +use std::collections::BTreeSet; use super::{ - common::UnixUser, database_operations::{create_databases, drop_databases}, user_operations::{create_database_users, drop_database_users, set_password_for_database_user} + common::UnixUser, + database_operations::{create_databases, drop_databases, list_databases_for_user}, + database_privilege_operations::{ + apply_privilege_diffs, get_all_database_privileges, get_databases_privilege_data, + }, + protocol::{Request, Response, ServerToClientMessageStream}, + user_operations::{ + create_database_users, drop_database_users, list_all_database_users_for_unix_user, + list_database_users, lock_database_users, set_password_for_database_user, + unlock_database_users, + }, }; -pub type ClientToServerMessageStream<'a> = SerdeFramed< - Framed<&'a mut UnixStream, LengthDelimitedCodec>, - Request, - Response, - Bincode, ->; - -pub async fn run_server( - socket: &mut UnixStream, +pub async fn handle_request( + mut stream: ServerToClientMessageStream, unix_user: &UnixUser, db_connection: &mut MySqlConnection, ) -> anyhow::Result<()> { - let length_delimited = Framed::new(socket, LengthDelimitedCodec::new()); - let mut stream: ClientToServerMessageStream = - tokio_serde::Framed::new(length_delimited, Bincode::default()); + loop { + // TODO: better error handling + let request = match stream.next().await { + Some(Ok(request)) => request, + Some(Err(e)) => return Err(e.into()), + None => { + log::warn!("Client disconnected without sending an exit message"); + break; + } + }; - // TODO: better error handling - let request = match stream.next().await { - Some(Ok(request)) => request, - Some(Err(e)) => return Err(e.into()), - None => return Err(anyhow::anyhow!("No request received")), - }; + match request { + Request::CreateDatabases(databases_names) => { + let result = create_databases(databases_names, unix_user, db_connection).await; + stream.send(Response::CreateDatabases(result)).await?; + stream.flush().await?; + } + Request::DropDatabases(databases_names) => { + let result = drop_databases(databases_names, unix_user, db_connection).await; + stream.send(Response::DropDatabases(result)).await?; + stream.flush().await?; + } + Request::ListDatabases => { + let result = list_databases_for_user(unix_user, db_connection).await; + stream.send(Response::ListAllDatabases(result)).await?; + stream.flush().await?; + } + Request::ListPrivileges(database_names) => { + let response = match database_names { + Some(database_names) => { + let privilege_data = + get_databases_privilege_data(database_names, unix_user, db_connection) + .await; + Response::ListPrivileges(privilege_data) + } + None => { + let privilege_data = + get_all_database_privileges(unix_user, db_connection).await; + Response::ListAllPrivileges(privilege_data) + } + }; - match request { - Request::CreateDatabases(databases) => { - let result = create_databases(databases, unix_user, db_connection).await; - stream.send(Response::CreateDatabases(result)).await?; - stream.flush().await?; - } - Request::DropDatabases(databases) => { - let result = drop_databases(databases, unix_user, db_connection).await; - stream.send(Response::DropDatabases(result)).await?; - stream.flush().await?; - } - Request::ListDatabases => { - println!("Listing databases"); - // let result = list_databases(unix_user, db_connection).await; - // stream.send(Response::ListDatabases(result)).await?; - // stream.flush().await?; - } - Request::ListPrivileges(users) => { - println!("Listing privileges for users: {:?}", users); - } - Request::ModifyPrivileges(()) => { - println!("Modifying privileges"); - } - Request::CreateUsers(db_users) => { - let result = create_database_users(db_users, unix_user, db_connection).await; - stream.send(Response::CreateUsers(result)).await?; - stream.flush().await?; - } - Request::DropUsers(db_users) => { - let result = drop_database_users(db_users, unix_user, db_connection).await; - stream.send(Response::DropUsers(result)).await?; - stream.flush().await?; - } - Request::PasswdUser(db_user, password) => { - let result = - set_password_for_database_user(&db_user, &password, unix_user, db_connection).await; - stream.send(Response::PasswdUser(result)).await?; - stream.flush().await?; - } - Request::ListUsers(db_users) => { - println!("Listing users: {:?}", db_users); - } - Request::LockUsers(db_users) => { - println!("Locking users: {:?}", db_users); - } - Request::UnlockUsers(db_users) => { - println!("Unlocking users: {:?}", db_users); + stream.send(response).await?; + stream.flush().await?; + } + Request::ModifyPrivileges(database_privilege_diffs) => { + let result = apply_privilege_diffs( + BTreeSet::from_iter(database_privilege_diffs), + unix_user, + db_connection, + ) + .await; + stream.send(Response::ModifyPrivileges(result)).await?; + stream.flush().await?; + } + Request::CreateUsers(db_users) => { + let result = create_database_users(db_users, unix_user, db_connection).await; + stream.send(Response::CreateUsers(result)).await?; + stream.flush().await?; + } + Request::DropUsers(db_users) => { + let result = drop_database_users(db_users, unix_user, db_connection).await; + stream.send(Response::DropUsers(result)).await?; + stream.flush().await?; + } + Request::PasswdUser(db_user, password) => { + let result = + set_password_for_database_user(&db_user, &password, unix_user, db_connection) + .await; + stream.send(Response::PasswdUser(result)).await?; + stream.flush().await?; + } + Request::ListUsers(db_users) => { + let response = match db_users { + Some(db_users) => { + let result = list_database_users(db_users, unix_user, db_connection).await; + Response::ListUsers(result) + } + None => { + let result = + list_all_database_users_for_unix_user(unix_user, db_connection).await; + Response::ListAllUsers(result) + } + }; + stream.send(response).await?; + stream.flush().await?; + } + Request::LockUsers(db_users) => { + let result = lock_database_users(db_users, unix_user, db_connection).await; + stream.send(Response::LockUsers(result)).await?; + stream.flush().await?; + } + Request::UnlockUsers(db_users) => { + let result = unlock_database_users(db_users, unix_user, db_connection).await; + stream.send(Response::UnlockUsers(result)).await?; + stream.flush().await?; + } + Request::Exit => { + break; + } } } diff --git a/src/server/protocol.rs b/src/server/protocol.rs index f0faeb0..5e28915 100644 --- a/src/server/protocol.rs +++ b/src/server/protocol.rs @@ -1,8 +1,44 @@ use serde::{Deserialize, Serialize}; +use tokio::net::UnixStream; +use tokio_serde::{formats::Bincode, Framed as SerdeFramed}; +use tokio_util::codec::{Framed, LengthDelimitedCodec}; -use super::{database_operations::{CreateDatabasesOutput, DropDatabasesOutput}, user_operations::{ - CreateUsersOutput, DropUsersOutput, LockUsersOutput, SetPasswordOutput, UnlockUsersOutput, -}}; +use crate::core::database_privileges::DatabasePrivilegesDiff; + +use super::{ + database_operations::{CreateDatabasesOutput, DropDatabasesOutput, ListAllDatabasesOutput}, + database_privilege_operations::{ + GetAllDatabasesPrivilegeData, GetDatabasesPrivilegeData, ModifyDatabasePrivilegesOutput, + }, + user_operations::{ + CreateUsersOutput, DropUsersOutput, ListAllUsersOutput, ListUsersOutput, LockUsersOutput, + SetPasswordOutput, UnlockUsersOutput, + }, +}; + +pub type ServerToClientMessageStream = SerdeFramed< + Framed, + Request, + Response, + Bincode, +>; + +pub type ClientToServerMessageStream = SerdeFramed< + Framed, + Response, + Request, + Bincode, +>; + +pub fn create_server_to_client_message_stream(socket: UnixStream) -> ServerToClientMessageStream { + let length_delimited = Framed::new(socket, LengthDelimitedCodec::new()); + tokio_serde::Framed::new(length_delimited, Bincode::default()) +} + +pub fn create_client_to_server_message_stream(socket: UnixStream) -> ClientToServerMessageStream { + let length_delimited = Framed::new(socket, LengthDelimitedCodec::new()); + tokio_serde::Framed::new(length_delimited, Bincode::default()) +} #[non_exhaustive] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -10,8 +46,8 @@ pub enum Request { CreateDatabases(Vec), DropDatabases(Vec), ListDatabases, - ListPrivileges(Vec), - ModifyPrivileges(()), // what data should be sent with this command? Who should calculate the diff? + ListPrivileges(Option>), + ModifyPrivileges(Vec), CreateUsers(Vec), DropUsers(Vec), @@ -19,6 +55,9 @@ pub enum Request { ListUsers(Option>), LockUsers(Vec), UnlockUsers(Vec), + + // Commit, + Exit, } // TODO: include a generic "message" that will display a message to the user? @@ -29,12 +68,16 @@ pub enum Response { // Specific data for specific commands CreateDatabases(CreateDatabasesOutput), DropDatabases(DropDatabasesOutput), - // ListDatabases(ListDatabasesOutput), - // ListPrivileges(ListPrivilegesOutput), + ListAllDatabases(ListAllDatabasesOutput), + ListPrivileges(GetDatabasesPrivilegeData), + ListAllPrivileges(GetAllDatabasesPrivilegeData), + ModifyPrivileges(ModifyDatabasePrivilegesOutput), + CreateUsers(CreateUsersOutput), DropUsers(DropUsersOutput), PasswdUser(SetPasswordOutput), - ListUsers(()), // what data should be sent with this response? + ListUsers(ListUsersOutput), + ListAllUsers(ListAllUsersOutput), LockUsers(LockUsersOutput), UnlockUsers(UnlockUsersOutput), diff --git a/src/server/user_operations.rs b/src/server/user_operations.rs index 51abf6c..32b5328 100644 --- a/src/server/user_operations.rs +++ b/src/server/user_operations.rs @@ -272,7 +272,7 @@ pub enum UnlockUserError { MySqlError(String), } -pub async fn unlock_database_user( +pub async fn unlock_database_users( db_users: Vec, unix_user: &UnixUser, connection: &mut MySqlConnection, @@ -362,34 +362,65 @@ JOIN `mysql`.`global_priv` ON AND `mysql`.`user`.`Host` = `mysql`.`global_priv`.`Host` "#; -pub async fn get_all_database_users_for_unix_user( +pub type ListUsersOutput = BTreeMap>; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum ListUsersError { + SanitizationError(NameValidationError), + OwnershipError(OwnerValidationError), + UserDoesNotExist, + MySqlError(String), +} + +pub async fn list_database_users( + db_users: Vec, unix_user: &UnixUser, connection: &mut MySqlConnection, -) -> Result, sqlx::Error> { +) -> ListUsersOutput { + let mut results = BTreeMap::new(); + + for db_user in db_users { + if let Err(err) = validate_name(&db_user) { + results.insert(db_user, Err(ListUsersError::SanitizationError(err))); + continue; + } + + if let Err(err) = validate_ownership_by_unix_user(&db_user, unix_user) { + results.insert(db_user, Err(ListUsersError::OwnershipError(err))); + continue; + } + + let result = sqlx::query_as::<_, DatabaseUser>( + &(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `mysql`.`user`.`User` = ?"), + ) + .bind(&db_user) + .fetch_optional(&mut *connection) + .await; + + match result { + Ok(Some(user)) => results.insert(db_user, Ok(user)), + Ok(None) => results.insert(db_user, Err(ListUsersError::UserDoesNotExist)), + Err(err) => results.insert(db_user, Err(ListUsersError::MySqlError(err.to_string()))), + }; + } + + results +} + +pub type ListAllUsersOutput = Result, ListAllUsersError>; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum ListAllUsersError { + MySqlError(String), +} + +pub async fn list_all_database_users_for_unix_user( + unix_user: &UnixUser, + connection: &mut MySqlConnection, +) -> ListAllUsersOutput { sqlx::query_as::<_, DatabaseUser>( &(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `mysql`.`user`.`User` REGEXP ?"), ) .bind(create_user_group_matching_regex(unix_user)) .fetch_all(connection) .await + .map_err(|err| ListAllUsersError::MySqlError(err.to_string())) } - -// /// This function fetches a database user if it exists. -// pub async fn get_database_user_for_user( -// username: &str, -// connection: &mut MySqlConnection, -// ) -> anyhow::Result> { -// let user = sqlx::query_as::<_, DatabaseUser>( -// &(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `mysql`.`user`.`User` = ?"), -// ) -// .bind(username) -// .fetch_optional(connection) -// .await?; - -// Ok(user) -// } - -// /// NOTE: It is very critical that this function validates the database name -// /// properly. MySQL does not seem to allow for prepared statements, binding -// /// the database name as a parameter to the query. This means that we have -// /// to validate the database name ourselves to prevent SQL injection.