From 76134704f9bdd50649da9275ffe155b67d0dd19c Mon Sep 17 00:00:00 2001 From: h7x4 Date: Wed, 7 Aug 2024 20:50:39 +0200 Subject: [PATCH] treewide: move some code around, spring cleaning --- src/cli/database_command.rs | 23 +- .../mysql_dbadm.rs | 12 +- .../mysql_useradm.rs | 5 +- src/cli/user_command.rs | 22 +- src/core.rs | 1 + src/core/common.rs | 21 + src/core/database_operations.rs | 398 +----------------- src/core/database_privilege_operations.rs | 372 ++++++++++++++++ 8 files changed, 431 insertions(+), 423 deletions(-) create mode 100644 src/core/database_privilege_operations.rs diff --git a/src/cli/database_command.rs b/src/cli/database_command.rs index ccad5cd..f7395ff 100644 --- a/src/cli/database_command.rs +++ b/src/cli/database_command.rs @@ -7,12 +7,9 @@ use prettytable::{Cell, Row, Table}; use sqlx::{Connection, MySqlConnection}; use crate::core::{ - self, - common::{close_database_connection, get_current_unix_user}, - database_operations::{ - apply_permission_diffs, db_priv_field_human_readable_name, diff_permissions, yn, - DatabasePrivileges, DATABASE_PRIVILEGE_FIELDS, - }, + common::{close_database_connection, get_current_unix_user, yn}, + database_operations::*, + database_privilege_operations::*, user_operations::user_exists, }; @@ -174,7 +171,7 @@ async fn create_databases( for name in args.name { // TODO: This can be optimized by fetching all the database privileges in one query. - if let Err(e) = core::database_operations::create_database(&name, conn).await { + if let Err(e) = create_database(&name, conn).await { eprintln!("Failed to create database '{}': {}", name, e); eprintln!("Skipping..."); } @@ -190,7 +187,7 @@ async fn drop_databases(args: DatabaseDropArgs, conn: &mut MySqlConnection) -> a for name in args.name { // TODO: This can be optimized by fetching all the database privileges in one query. - if let Err(e) = core::database_operations::drop_database(&name, conn).await { + if let Err(e) = drop_database(&name, conn).await { eprintln!("Failed to drop database '{}': {}", name, e); eprintln!("Skipping..."); } @@ -200,7 +197,7 @@ async fn drop_databases(args: DatabaseDropArgs, conn: &mut MySqlConnection) -> a } async fn list_databases(args: DatabaseListArgs, conn: &mut MySqlConnection) -> anyhow::Result<()> { - let databases = core::database_operations::get_database_list(conn).await?; + let databases = get_database_list(conn).await?; if databases.is_empty() { println!("No databases to show."); @@ -223,12 +220,12 @@ async fn show_databases( conn: &mut MySqlConnection, ) -> anyhow::Result<()> { let database_users_to_show = if args.name.is_empty() { - core::database_operations::get_all_database_privileges(conn).await? + get_all_database_privileges(conn).await? } else { // TODO: This can be optimized by fetching all the database privileges in one query. let mut result = Vec::with_capacity(args.name.len()); for name in args.name { - match core::database_operations::get_database_privileges(&name, conn).await { + match get_database_privileges(&name, conn).await { Ok(db) => result.extend(db), Err(e) => { eprintln!("Failed to show database '{}': {}", name, e); @@ -420,9 +417,9 @@ pub async fn edit_permissions( conn: &mut MySqlConnection, ) -> anyhow::Result<()> { let permission_data = if let Some(name) = &args.name { - core::database_operations::get_database_privileges(name, conn).await? + get_database_privileges(name, conn).await? } else { - core::database_operations::get_all_database_privileges(conn).await? + get_all_database_privileges(conn).await? }; let permissions_to_change = if !args.perm.is_empty() { diff --git a/src/cli/mysql_admutils_compatibility/mysql_dbadm.rs b/src/cli/mysql_admutils_compatibility/mysql_dbadm.rs index 9a0081e..0136216 100644 --- a/src/cli/mysql_admutils_compatibility/mysql_dbadm.rs +++ b/src/cli/mysql_admutils_compatibility/mysql_dbadm.rs @@ -7,8 +7,10 @@ use crate::{ mysql_admutils_compatibility::common::{filter_db_or_user_names, DbOrUser}, }, core::{ + common::yn, config::{get_config, mysql_connection_from_config, GlobalConfigArgs}, - database_operations::{self, yn}, + database_operations::{create_database, drop_database, get_database_list}, + database_privilege_operations, }, }; @@ -129,20 +131,20 @@ pub async fn main() -> anyhow::Result<()> { Command::Create(args) => { let filtered_names = filter_db_or_user_names(args.name, DbOrUser::Database)?; for name in filtered_names { - database_operations::create_database(&name, &mut connection).await?; + create_database(&name, &mut connection).await?; println!("Database {} created.", name); } } Command::Drop(args) => { let filtered_names = filter_db_or_user_names(args.name, DbOrUser::Database)?; for name in filtered_names { - database_operations::drop_database(&name, &mut connection).await?; + drop_database(&name, &mut connection).await?; println!("Database {} dropped.", name); } } Command::Show(args) => { let names = if args.name.is_empty() { - database_operations::get_database_list(&mut connection).await? + get_database_list(&mut connection).await? } else { filter_db_or_user_names(args.name, DbOrUser::Database)? }; @@ -176,7 +178,7 @@ async fn show_db(name: &str, conn: &mut MySqlConnection) -> anyhow::Result<()> { // for non-existent databases will report with no users. // This function should *not* check for db existence, only // validate the names. - let permissions = database_operations::get_database_privileges(name, conn) + let permissions = database_privilege_operations::get_database_privileges(name, conn) .await .unwrap_or(vec![]); diff --git a/src/cli/mysql_admutils_compatibility/mysql_useradm.rs b/src/cli/mysql_admutils_compatibility/mysql_useradm.rs index a4b1381..a44a00f 100644 --- a/src/cli/mysql_admutils_compatibility/mysql_useradm.rs +++ b/src/cli/mysql_admutils_compatibility/mysql_useradm.rs @@ -9,10 +9,7 @@ use crate::{ core::{ common::{close_database_connection, get_current_unix_user}, config::{get_config, mysql_connection_from_config, GlobalConfigArgs}, - user_operations::{ - create_database_user, delete_database_user, get_all_database_users_for_unix_user, - get_database_user_for_user, set_password_for_database_user, user_exists, - }, + user_operations::*, }, }; diff --git a/src/cli/user_command.rs b/src/cli/user_command.rs index 7f8ccd4..79c57b2 100644 --- a/src/cli/user_command.rs +++ b/src/cli/user_command.rs @@ -9,9 +9,9 @@ use serde_json::json; use sqlx::{Connection, MySqlConnection}; use crate::core::{ - common::close_database_connection, - database_operations::get_databases_where_user_has_privileges, - user_operations::validate_user_name, + common::{close_database_connection, get_current_unix_user}, + database_operations::*, + user_operations::*, }; #[derive(Parser)] @@ -99,7 +99,7 @@ async fn create_users(args: UserCreateArgs, conn: &mut MySqlConnection) -> anyho } for username in args.username { - if let Err(e) = crate::core::user_operations::create_database_user(&username, conn).await { + if let Err(e) = create_database_user(&username, conn).await { eprintln!("{}", e); eprintln!("Skipping...\n"); continue; @@ -135,7 +135,7 @@ async fn drop_users(args: UserDeleteArgs, conn: &mut MySqlConnection) -> anyhow: } for username in args.username { - if let Err(e) = crate::core::user_operations::delete_database_user(&username, conn).await { + if let Err(e) = delete_database_user(&username, conn).await { eprintln!("{}", e); eprintln!("Skipping..."); } @@ -160,7 +160,7 @@ async fn change_password_for_user( ) -> anyhow::Result<()> { // NOTE: although this also is checked in `set_password_for_database_user`, we check it here // to provide a more natural order of error messages. - let unix_user = crate::core::common::get_current_unix_user()?; + let unix_user = get_current_unix_user()?; validate_user_name(&args.username, &unix_user)?; let password = if let Some(password_file) = args.password_file { @@ -172,17 +172,16 @@ async fn change_password_for_user( read_password_from_stdin_with_double_check(&args.username)? }; - crate::core::user_operations::set_password_for_database_user(&args.username, &password, conn) - .await?; + set_password_for_database_user(&args.username, &password, conn).await?; Ok(()) } async fn show_users(args: UserShowArgs, conn: &mut MySqlConnection) -> anyhow::Result<()> { - let unix_user = crate::core::common::get_current_unix_user()?; + let unix_user = get_current_unix_user()?; let users = if args.username.is_empty() { - crate::core::user_operations::get_all_database_users_for_unix_user(&unix_user, conn).await? + get_all_database_users_for_unix_user(&unix_user, conn).await? } else { let mut result = vec![]; for username in args.username { @@ -192,8 +191,7 @@ async fn show_users(args: UserShowArgs, conn: &mut MySqlConnection) -> anyhow::R continue; } - let user = - crate::core::user_operations::get_database_user_for_user(&username, conn).await?; + let user = get_database_user_for_user(&username, conn).await?; if let Some(user) = user { result.push(user); } else { diff --git a/src/core.rs b/src/core.rs index b83db7b..aa51dca 100644 --- a/src/core.rs +++ b/src/core.rs @@ -1,4 +1,5 @@ pub mod common; pub mod config; pub mod database_operations; +pub mod database_privilege_operations; pub mod user_operations; diff --git a/src/core/common.rs b/src/core/common.rs index d74cb0c..bc4c8a8 100644 --- a/src/core/common.rs +++ b/src/core/common.rs @@ -161,3 +161,24 @@ pub fn quote_literal(s: &str) -> String { pub fn quote_identifier(s: &str) -> String { format!("`{}`", s.replace('`', r"\`")) } + +#[inline] +pub(crate) fn yn(b: bool) -> &'static str { + if b { + "Y" + } else { + "N" + } +} + +#[inline] +pub(crate) fn rev_yn(s: &str) -> bool { + match s.to_lowercase().as_str() { + "y" => true, + "n" => false, + _ => { + log::warn!("Invalid value for privilege: {}", s); + false + } + } +} diff --git a/src/core/database_operations.rs b/src/core/database_operations.rs index f0084b9..f44a5f2 100644 --- a/src/core/database_operations.rs +++ b/src/core/database_operations.rs @@ -1,15 +1,16 @@ -use std::collections::HashMap; - use anyhow::Context; -use indoc::{formatdoc, indoc}; +use indoc::formatdoc; use itertools::Itertools; use nix::unistd::User; use serde::{Deserialize, Serialize}; -use sqlx::{mysql::MySqlRow, prelude::*, MySqlConnection}; +use sqlx::{prelude::*, MySqlConnection}; -use super::common::{ - create_user_group_matching_regex, get_current_unix_user, quote_identifier, validate_name_token, - validate_ownership_by_user_prefix, +use crate::core::{ + common::{ + create_user_group_matching_regex, get_current_unix_user, quote_identifier, + validate_name_token, validate_ownership_by_user_prefix, + }, + database_privilege_operations::DATABASE_PRIVILEGE_FIELDS, }; pub async fn create_database(name: &str, conn: &mut MySqlConnection) -> anyhow::Result<()> { @@ -100,324 +101,12 @@ pub async fn get_databases_where_user_has_privileges( .fetch_all(conn) .await? .into_iter() - .map(|databases| { - databases.try_get::("database").unwrap() - }) + .map(|databases| databases.try_get::("database").unwrap()) .collect(); Ok(result) } -pub const DATABASE_PRIVILEGE_FIELDS: [&str; 13] = [ - "db", - "user", - "select_priv", - "insert_priv", - "update_priv", - "delete_priv", - "create_priv", - "drop_priv", - "alter_priv", - "index_priv", - "create_tmp_table_priv", - "lock_tables_priv", - "references_priv", -]; - -pub fn db_priv_field_human_readable_name(name: &str) -> String { - match name { - "db" => "Database".to_owned(), - "user" => "User".to_owned(), - "select_priv" => "Select".to_owned(), - "insert_priv" => "Insert".to_owned(), - "update_priv" => "Update".to_owned(), - "delete_priv" => "Delete".to_owned(), - "create_priv" => "Create".to_owned(), - "drop_priv" => "Drop".to_owned(), - "alter_priv" => "Alter".to_owned(), - "index_priv" => "Index".to_owned(), - "create_tmp_table_priv" => "Temp".to_owned(), - "lock_tables_priv" => "Lock".to_owned(), - "references_priv" => "References".to_owned(), - _ => format!("Unknown({})", name), - } -} - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub struct DatabasePrivileges { - pub db: String, - pub user: String, - pub select_priv: bool, - pub insert_priv: bool, - pub update_priv: bool, - pub delete_priv: bool, - pub create_priv: bool, - pub drop_priv: bool, - pub alter_priv: bool, - pub index_priv: bool, - pub create_tmp_table_priv: bool, - pub lock_tables_priv: bool, - pub references_priv: bool, -} - -impl DatabasePrivileges { - pub fn get_privilege_by_name(&self, name: &str) -> bool { - match name { - "select_priv" => self.select_priv, - "insert_priv" => self.insert_priv, - "update_priv" => self.update_priv, - "delete_priv" => self.delete_priv, - "create_priv" => self.create_priv, - "drop_priv" => self.drop_priv, - "alter_priv" => self.alter_priv, - "index_priv" => self.index_priv, - "create_tmp_table_priv" => self.create_tmp_table_priv, - "lock_tables_priv" => self.lock_tables_priv, - "references_priv" => self.references_priv, - _ => false, - } - } - pub fn diff(&self, other: &DatabasePrivileges) -> DatabasePrivilegeDiffList { - debug_assert!(self.db == other.db && self.user == other.user); - - DatabasePrivilegeDiffList { - db: self.db.clone(), - user: self.user.clone(), - diff: DATABASE_PRIVILEGE_FIELDS - .into_iter() - .skip(2) - .filter_map(|field| { - diff_single_priv( - self.get_privilege_by_name(field), - other.get_privilege_by_name(field), - field, - ) - }) - .collect(), - } - } -} - -#[inline] -pub(crate) fn yn(b: bool) -> &'static str { - if b { - "Y" - } else { - "N" - } -} - -#[inline] -pub(crate) fn rev_yn(s: &str) -> bool { - match s.to_lowercase().as_str() { - "y" => true, - "n" => false, - _ => { - log::warn!("Invalid value for privilege: {}", s); - false - } - } -} - -impl FromRow<'_, MySqlRow> for DatabasePrivileges { - fn from_row(row: &MySqlRow) -> Result { - Ok(Self { - db: row.try_get("db")?, - user: row.try_get("user")?, - select_priv: row.try_get("select_priv").map(rev_yn)?, - insert_priv: row.try_get("insert_priv").map(rev_yn)?, - update_priv: row.try_get("update_priv").map(rev_yn)?, - delete_priv: row.try_get("delete_priv").map(rev_yn)?, - create_priv: row.try_get("create_priv").map(rev_yn)?, - drop_priv: row.try_get("drop_priv").map(rev_yn)?, - alter_priv: row.try_get("alter_priv").map(rev_yn)?, - index_priv: row.try_get("index_priv").map(rev_yn)?, - create_tmp_table_priv: row.try_get("create_tmp_table_priv").map(rev_yn)?, - lock_tables_priv: row.try_get("lock_tables_priv").map(rev_yn)?, - references_priv: row.try_get("references_priv").map(rev_yn)?, - }) - } -} - -pub async fn get_database_privileges( - database_name: &str, - conn: &mut MySqlConnection, -) -> anyhow::Result> { - let unix_user = get_current_unix_user()?; - validate_database_name(database_name, &unix_user)?; - - let result = sqlx::query_as::<_, DatabasePrivileges>(&format!( - "SELECT {} FROM `db` WHERE `db` = ?", - DATABASE_PRIVILEGE_FIELDS - .iter() - .map(|field| quote_identifier(field)) - .join(","), - )) - .bind(database_name) - .fetch_all(conn) - .await - .context("Failed to show database")?; - - Ok(result) -} - -pub async fn get_all_database_privileges( - conn: &mut MySqlConnection, -) -> anyhow::Result> { - let unix_user = get_current_unix_user()?; - - let result = sqlx::query_as::<_, DatabasePrivileges>(&format!( - indoc! {r#" - SELECT {} FROM `db` WHERE `db` IN - (SELECT DISTINCT `SCHEMA_NAME` AS `database` - FROM `information_schema`.`SCHEMATA` - WHERE `SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys') - AND `SCHEMA_NAME` REGEXP ?) - "#}, - DATABASE_PRIVILEGE_FIELDS - .iter() - .map(|field| format!("`{field}`")) - .join(","), - )) - .bind(create_user_group_matching_regex(&unix_user)) - .fetch_all(conn) - .await - .context("Failed to show databases")?; - - Ok(result) -} - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub struct DatabasePrivilegeDiffList { - pub db: String, - pub user: String, - pub diff: Vec, -} - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub enum DatabasePrivilegeDiff { - YesToNo(String), - NoToYes(String), -} - -fn diff_single_priv(p1: bool, p2: bool, name: &str) -> Option { - match (p1, p2) { - (true, false) => Some(DatabasePrivilegeDiff::YesToNo(name.to_owned())), - (false, true) => Some(DatabasePrivilegeDiff::NoToYes(name.to_owned())), - _ => None, - } -} - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub enum DatabasePrivilegesDiff { - New(DatabasePrivileges), - Modified(DatabasePrivilegeDiffList), - Deleted(DatabasePrivileges), -} - -pub async fn diff_permissions( - from: Vec, - to: &[DatabasePrivileges], -) -> Vec { - let from_lookup_table: HashMap<(String, String), DatabasePrivileges> = HashMap::from_iter( - from.iter() - .cloned() - .map(|p| ((p.db.clone(), p.user.clone()), p)), - ); - - let to_lookup_table: HashMap<(String, String), DatabasePrivileges> = HashMap::from_iter( - to.iter() - .cloned() - .map(|p| ((p.db.clone(), p.user.clone()), p)), - ); - - let mut result = vec![]; - - for p in to { - if let Some(old_p) = from_lookup_table.get(&(p.db.clone(), p.user.clone())) { - let diff = old_p.diff(p); - if !diff.diff.is_empty() { - result.push(DatabasePrivilegesDiff::Modified(diff)); - } - } else { - result.push(DatabasePrivilegesDiff::New(p.clone())); - } - } - - for p in from { - if !to_lookup_table.contains_key(&(p.db.clone(), p.user.clone())) { - result.push(DatabasePrivilegesDiff::Deleted(p)); - } - } - - result -} - -pub async fn apply_permission_diffs( - diffs: Vec, - conn: &mut MySqlConnection, -) -> anyhow::Result<()> { - for diff in diffs { - match diff { - DatabasePrivilegesDiff::New(p) => { - let tables = DATABASE_PRIVILEGE_FIELDS - .iter() - .map(|field| format!("`{field}`")) - .join(","); - - let question_marks = std::iter::repeat("?") - .take(DATABASE_PRIVILEGE_FIELDS.len()) - .join(","); - - sqlx::query( - format!("INSERT INTO `db` ({}) VALUES ({})", tables, question_marks).as_str(), - ) - .bind(p.db) - .bind(p.user) - .bind(yn(p.select_priv)) - .bind(yn(p.insert_priv)) - .bind(yn(p.update_priv)) - .bind(yn(p.delete_priv)) - .bind(yn(p.create_priv)) - .bind(yn(p.drop_priv)) - .bind(yn(p.alter_priv)) - .bind(yn(p.index_priv)) - .bind(yn(p.create_tmp_table_priv)) - .bind(yn(p.lock_tables_priv)) - .bind(yn(p.references_priv)) - .execute(&mut *conn) - .await?; - } - DatabasePrivilegesDiff::Modified(p) => { - let tables = p - .diff - .iter() - .map(|diff| match diff { - DatabasePrivilegeDiff::YesToNo(name) => format!("`{}` = 'N'", name), - DatabasePrivilegeDiff::NoToYes(name) => format!("`{}` = 'Y'", name), - }) - .join(","); - - sqlx::query( - format!("UPDATE `db` SET {} WHERE `db` = ? AND `user` = ?", tables).as_str(), - ) - .bind(p.db) - .bind(p.user) - .execute(&mut *conn) - .await?; - } - DatabasePrivilegesDiff::Deleted(p) => { - sqlx::query("DELETE FROM `db` WHERE `db` = ? AND `user` = ?") - .bind(p.db) - .bind(p.user) - .execute(&mut *conn) - .await?; - } - } - } - Ok(()) -} - /// NOTE: It is very critical that this function validates the database name /// properly. MySQL does not seem to allow for prepared statements, binding /// the database name as a parameter to the query. This means that we have @@ -428,72 +117,3 @@ pub fn validate_database_name(name: &str, user: &User) -> anyhow::Result<()> { Ok(()) } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_diff_single_priv() { - assert_eq!( - diff_single_priv(true, false, "test"), - Some(DatabasePrivilegeDiff::YesToNo("test".to_owned())) - ); - assert_eq!( - diff_single_priv(false, true, "test"), - Some(DatabasePrivilegeDiff::NoToYes("test".to_owned())) - ); - assert_eq!(diff_single_priv(true, true, "test"), None); - assert_eq!(diff_single_priv(false, false, "test"), None); - } - - #[tokio::test] - async fn test_diff_permissions() { - let from = vec![DatabasePrivileges { - db: "db".to_owned(), - user: "user".to_owned(), - select_priv: true, - insert_priv: true, - update_priv: true, - delete_priv: true, - create_priv: true, - drop_priv: true, - alter_priv: true, - index_priv: true, - create_tmp_table_priv: true, - lock_tables_priv: true, - references_priv: true, - }]; - - let to = vec![DatabasePrivileges { - db: "db".to_owned(), - user: "user".to_owned(), - select_priv: false, - insert_priv: true, - update_priv: true, - delete_priv: true, - create_priv: true, - drop_priv: true, - alter_priv: true, - index_priv: true, - create_tmp_table_priv: true, - lock_tables_priv: true, - references_priv: true, - }]; - - let diffs = diff_permissions(from, &to).await; - - assert_eq!( - diffs, - vec![DatabasePrivilegesDiff::Modified( - DatabasePrivilegeDiffList { - db: "db".to_owned(), - user: "user".to_owned(), - diff: vec![DatabasePrivilegeDiff::YesToNo("select_priv".to_owned())], - } - )] - ); - - assert!(matches!(&diffs[0], DatabasePrivilegesDiff::Modified(_))); - } -} diff --git a/src/core/database_privilege_operations.rs b/src/core/database_privilege_operations.rs new file mode 100644 index 0000000..bbd9c51 --- /dev/null +++ b/src/core/database_privilege_operations.rs @@ -0,0 +1,372 @@ +use std::collections::HashMap; + +use anyhow::Context; +use indoc::indoc; +use itertools::Itertools; +use serde::{Deserialize, Serialize}; +use sqlx::{mysql::MySqlRow, prelude::*, MySqlConnection}; + +use crate::core::{ + common::{ + create_user_group_matching_regex, get_current_unix_user, quote_identifier, rev_yn, yn, + }, + database_operations::validate_database_name, +}; + +pub const DATABASE_PRIVILEGE_FIELDS: [&str; 13] = [ + "db", + "user", + "select_priv", + "insert_priv", + "update_priv", + "delete_priv", + "create_priv", + "drop_priv", + "alter_priv", + "index_priv", + "create_tmp_table_priv", + "lock_tables_priv", + "references_priv", +]; + +pub fn db_priv_field_human_readable_name(name: &str) -> String { + match name { + "db" => "Database".to_owned(), + "user" => "User".to_owned(), + "select_priv" => "Select".to_owned(), + "insert_priv" => "Insert".to_owned(), + "update_priv" => "Update".to_owned(), + "delete_priv" => "Delete".to_owned(), + "create_priv" => "Create".to_owned(), + "drop_priv" => "Drop".to_owned(), + "alter_priv" => "Alter".to_owned(), + "index_priv" => "Index".to_owned(), + "create_tmp_table_priv" => "Temp".to_owned(), + "lock_tables_priv" => "Lock".to_owned(), + "references_priv" => "References".to_owned(), + _ => format!("Unknown({})", name), + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct DatabasePrivileges { + pub db: String, + pub user: String, + pub select_priv: bool, + pub insert_priv: bool, + pub update_priv: bool, + pub delete_priv: bool, + pub create_priv: bool, + pub drop_priv: bool, + pub alter_priv: bool, + pub index_priv: bool, + pub create_tmp_table_priv: bool, + pub lock_tables_priv: bool, + pub references_priv: bool, +} + +impl DatabasePrivileges { + pub fn get_privilege_by_name(&self, name: &str) -> bool { + match name { + "select_priv" => self.select_priv, + "insert_priv" => self.insert_priv, + "update_priv" => self.update_priv, + "delete_priv" => self.delete_priv, + "create_priv" => self.create_priv, + "drop_priv" => self.drop_priv, + "alter_priv" => self.alter_priv, + "index_priv" => self.index_priv, + "create_tmp_table_priv" => self.create_tmp_table_priv, + "lock_tables_priv" => self.lock_tables_priv, + "references_priv" => self.references_priv, + _ => false, + } + } + pub fn diff(&self, other: &DatabasePrivileges) -> DatabasePrivilegeDiffList { + debug_assert!(self.db == other.db && self.user == other.user); + + DatabasePrivilegeDiffList { + db: self.db.clone(), + user: self.user.clone(), + diff: DATABASE_PRIVILEGE_FIELDS + .into_iter() + .skip(2) + .filter_map(|field| { + diff_single_priv( + self.get_privilege_by_name(field), + other.get_privilege_by_name(field), + field, + ) + }) + .collect(), + } + } +} + +impl FromRow<'_, MySqlRow> for DatabasePrivileges { + fn from_row(row: &MySqlRow) -> Result { + Ok(Self { + db: row.try_get("db")?, + user: row.try_get("user")?, + select_priv: row.try_get("select_priv").map(rev_yn)?, + insert_priv: row.try_get("insert_priv").map(rev_yn)?, + update_priv: row.try_get("update_priv").map(rev_yn)?, + delete_priv: row.try_get("delete_priv").map(rev_yn)?, + create_priv: row.try_get("create_priv").map(rev_yn)?, + drop_priv: row.try_get("drop_priv").map(rev_yn)?, + alter_priv: row.try_get("alter_priv").map(rev_yn)?, + index_priv: row.try_get("index_priv").map(rev_yn)?, + create_tmp_table_priv: row.try_get("create_tmp_table_priv").map(rev_yn)?, + lock_tables_priv: row.try_get("lock_tables_priv").map(rev_yn)?, + references_priv: row.try_get("references_priv").map(rev_yn)?, + }) + } +} + +pub async fn get_database_privileges( + database_name: &str, + conn: &mut MySqlConnection, +) -> anyhow::Result> { + let unix_user = get_current_unix_user()?; + validate_database_name(database_name, &unix_user)?; + + let result = sqlx::query_as::<_, DatabasePrivileges>(&format!( + "SELECT {} FROM `db` WHERE `db` = ?", + DATABASE_PRIVILEGE_FIELDS + .iter() + .map(|field| quote_identifier(field)) + .join(","), + )) + .bind(database_name) + .fetch_all(conn) + .await + .context("Failed to show database")?; + + Ok(result) +} + +pub async fn get_all_database_privileges( + conn: &mut MySqlConnection, +) -> anyhow::Result> { + let unix_user = get_current_unix_user()?; + + let result = sqlx::query_as::<_, DatabasePrivileges>(&format!( + indoc! {r#" + SELECT {} FROM `db` WHERE `db` IN + (SELECT DISTINCT `SCHEMA_NAME` AS `database` + FROM `information_schema`.`SCHEMATA` + WHERE `SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys') + AND `SCHEMA_NAME` REGEXP ?) + "#}, + DATABASE_PRIVILEGE_FIELDS + .iter() + .map(|field| format!("`{field}`")) + .join(","), + )) + .bind(create_user_group_matching_regex(&unix_user)) + .fetch_all(conn) + .await + .context("Failed to show databases")?; + + Ok(result) +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct DatabasePrivilegeDiffList { + pub db: String, + pub user: String, + pub diff: Vec, +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum DatabasePrivilegeDiff { + YesToNo(String), + NoToYes(String), +} + +fn diff_single_priv(p1: bool, p2: bool, name: &str) -> Option { + match (p1, p2) { + (true, false) => Some(DatabasePrivilegeDiff::YesToNo(name.to_owned())), + (false, true) => Some(DatabasePrivilegeDiff::NoToYes(name.to_owned())), + _ => None, + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum DatabasePrivilegesDiff { + New(DatabasePrivileges), + Modified(DatabasePrivilegeDiffList), + Deleted(DatabasePrivileges), +} + +pub async fn diff_permissions( + from: Vec, + to: &[DatabasePrivileges], +) -> Vec { + let from_lookup_table: HashMap<(String, String), DatabasePrivileges> = HashMap::from_iter( + from.iter() + .cloned() + .map(|p| ((p.db.clone(), p.user.clone()), p)), + ); + + let to_lookup_table: HashMap<(String, String), DatabasePrivileges> = HashMap::from_iter( + to.iter() + .cloned() + .map(|p| ((p.db.clone(), p.user.clone()), p)), + ); + + let mut result = vec![]; + + for p in to { + if let Some(old_p) = from_lookup_table.get(&(p.db.clone(), p.user.clone())) { + let diff = old_p.diff(p); + if !diff.diff.is_empty() { + result.push(DatabasePrivilegesDiff::Modified(diff)); + } + } else { + result.push(DatabasePrivilegesDiff::New(p.clone())); + } + } + + for p in from { + if !to_lookup_table.contains_key(&(p.db.clone(), p.user.clone())) { + result.push(DatabasePrivilegesDiff::Deleted(p)); + } + } + + result +} + +pub async fn apply_permission_diffs( + diffs: Vec, + conn: &mut MySqlConnection, +) -> anyhow::Result<()> { + for diff in diffs { + match diff { + DatabasePrivilegesDiff::New(p) => { + let tables = DATABASE_PRIVILEGE_FIELDS + .iter() + .map(|field| format!("`{field}`")) + .join(","); + + let question_marks = std::iter::repeat("?") + .take(DATABASE_PRIVILEGE_FIELDS.len()) + .join(","); + + sqlx::query( + format!("INSERT INTO `db` ({}) VALUES ({})", tables, question_marks).as_str(), + ) + .bind(p.db) + .bind(p.user) + .bind(yn(p.select_priv)) + .bind(yn(p.insert_priv)) + .bind(yn(p.update_priv)) + .bind(yn(p.delete_priv)) + .bind(yn(p.create_priv)) + .bind(yn(p.drop_priv)) + .bind(yn(p.alter_priv)) + .bind(yn(p.index_priv)) + .bind(yn(p.create_tmp_table_priv)) + .bind(yn(p.lock_tables_priv)) + .bind(yn(p.references_priv)) + .execute(&mut *conn) + .await?; + } + DatabasePrivilegesDiff::Modified(p) => { + let tables = p + .diff + .iter() + .map(|diff| match diff { + DatabasePrivilegeDiff::YesToNo(name) => format!("`{}` = 'N'", name), + DatabasePrivilegeDiff::NoToYes(name) => format!("`{}` = 'Y'", name), + }) + .join(","); + + sqlx::query( + format!("UPDATE `db` SET {} WHERE `db` = ? AND `user` = ?", tables).as_str(), + ) + .bind(p.db) + .bind(p.user) + .execute(&mut *conn) + .await?; + } + DatabasePrivilegesDiff::Deleted(p) => { + sqlx::query("DELETE FROM `db` WHERE `db` = ? AND `user` = ?") + .bind(p.db) + .bind(p.user) + .execute(&mut *conn) + .await?; + } + } + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_diff_single_priv() { + assert_eq!( + diff_single_priv(true, false, "test"), + Some(DatabasePrivilegeDiff::YesToNo("test".to_owned())) + ); + assert_eq!( + diff_single_priv(false, true, "test"), + Some(DatabasePrivilegeDiff::NoToYes("test".to_owned())) + ); + assert_eq!(diff_single_priv(true, true, "test"), None); + assert_eq!(diff_single_priv(false, false, "test"), None); + } + + #[tokio::test] + async fn test_diff_permissions() { + let from = vec![DatabasePrivileges { + db: "db".to_owned(), + user: "user".to_owned(), + select_priv: true, + insert_priv: true, + update_priv: true, + delete_priv: true, + create_priv: true, + drop_priv: true, + alter_priv: true, + index_priv: true, + create_tmp_table_priv: true, + lock_tables_priv: true, + references_priv: true, + }]; + + let to = vec![DatabasePrivileges { + db: "db".to_owned(), + user: "user".to_owned(), + select_priv: false, + insert_priv: true, + update_priv: true, + delete_priv: true, + create_priv: true, + drop_priv: true, + alter_priv: true, + index_priv: true, + create_tmp_table_priv: true, + lock_tables_priv: true, + references_priv: true, + }]; + + let diffs = diff_permissions(from, &to).await; + + assert_eq!( + diffs, + vec![DatabasePrivilegesDiff::Modified( + DatabasePrivilegeDiffList { + db: "db".to_owned(), + user: "user".to_owned(), + diff: vec![DatabasePrivilegeDiff::YesToNo("select_priv".to_owned())], + } + )] + ); + + assert!(matches!(&diffs[0], DatabasePrivilegesDiff::Modified(_))); + } +}