treewide: move some code around, spring cleaning

This commit is contained in:
Oystein Kristoffer Tveit 2024-08-07 20:50:39 +02:00
parent 833251a1a2
commit 76134704f9
Signed by: oysteikt
GPG Key ID: 9F2F7D8250F35146
8 changed files with 431 additions and 423 deletions

View File

@ -7,12 +7,9 @@ use prettytable::{Cell, Row, Table};
use sqlx::{Connection, MySqlConnection}; use sqlx::{Connection, MySqlConnection};
use crate::core::{ use crate::core::{
self, common::{close_database_connection, get_current_unix_user, yn},
common::{close_database_connection, get_current_unix_user}, database_operations::*,
database_operations::{ database_privilege_operations::*,
apply_permission_diffs, db_priv_field_human_readable_name, diff_permissions, yn,
DatabasePrivileges, DATABASE_PRIVILEGE_FIELDS,
},
user_operations::user_exists, user_operations::user_exists,
}; };
@ -174,7 +171,7 @@ async fn create_databases(
for name in args.name { for name in args.name {
// TODO: This can be optimized by fetching all the database privileges in one query. // TODO: This can be optimized by fetching all the database privileges in one query.
if let Err(e) = core::database_operations::create_database(&name, conn).await { if let Err(e) = create_database(&name, conn).await {
eprintln!("Failed to create database '{}': {}", name, e); eprintln!("Failed to create database '{}': {}", name, e);
eprintln!("Skipping..."); eprintln!("Skipping...");
} }
@ -190,7 +187,7 @@ async fn drop_databases(args: DatabaseDropArgs, conn: &mut MySqlConnection) -> a
for name in args.name { for name in args.name {
// TODO: This can be optimized by fetching all the database privileges in one query. // TODO: This can be optimized by fetching all the database privileges in one query.
if let Err(e) = core::database_operations::drop_database(&name, conn).await { if let Err(e) = drop_database(&name, conn).await {
eprintln!("Failed to drop database '{}': {}", name, e); eprintln!("Failed to drop database '{}': {}", name, e);
eprintln!("Skipping..."); eprintln!("Skipping...");
} }
@ -200,7 +197,7 @@ async fn drop_databases(args: DatabaseDropArgs, conn: &mut MySqlConnection) -> a
} }
async fn list_databases(args: DatabaseListArgs, conn: &mut MySqlConnection) -> anyhow::Result<()> { async fn list_databases(args: DatabaseListArgs, conn: &mut MySqlConnection) -> anyhow::Result<()> {
let databases = core::database_operations::get_database_list(conn).await?; let databases = get_database_list(conn).await?;
if databases.is_empty() { if databases.is_empty() {
println!("No databases to show."); println!("No databases to show.");
@ -223,12 +220,12 @@ async fn show_databases(
conn: &mut MySqlConnection, conn: &mut MySqlConnection,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let database_users_to_show = if args.name.is_empty() { let database_users_to_show = if args.name.is_empty() {
core::database_operations::get_all_database_privileges(conn).await? get_all_database_privileges(conn).await?
} else { } else {
// TODO: This can be optimized by fetching all the database privileges in one query. // TODO: This can be optimized by fetching all the database privileges in one query.
let mut result = Vec::with_capacity(args.name.len()); let mut result = Vec::with_capacity(args.name.len());
for name in args.name { for name in args.name {
match core::database_operations::get_database_privileges(&name, conn).await { match get_database_privileges(&name, conn).await {
Ok(db) => result.extend(db), Ok(db) => result.extend(db),
Err(e) => { Err(e) => {
eprintln!("Failed to show database '{}': {}", name, e); eprintln!("Failed to show database '{}': {}", name, e);
@ -420,9 +417,9 @@ pub async fn edit_permissions(
conn: &mut MySqlConnection, conn: &mut MySqlConnection,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let permission_data = if let Some(name) = &args.name { let permission_data = if let Some(name) = &args.name {
core::database_operations::get_database_privileges(name, conn).await? get_database_privileges(name, conn).await?
} else { } else {
core::database_operations::get_all_database_privileges(conn).await? get_all_database_privileges(conn).await?
}; };
let permissions_to_change = if !args.perm.is_empty() { let permissions_to_change = if !args.perm.is_empty() {

View File

@ -7,8 +7,10 @@ use crate::{
mysql_admutils_compatibility::common::{filter_db_or_user_names, DbOrUser}, mysql_admutils_compatibility::common::{filter_db_or_user_names, DbOrUser},
}, },
core::{ core::{
common::yn,
config::{get_config, mysql_connection_from_config, GlobalConfigArgs}, config::{get_config, mysql_connection_from_config, GlobalConfigArgs},
database_operations::{self, yn}, database_operations::{create_database, drop_database, get_database_list},
database_privilege_operations,
}, },
}; };
@ -129,20 +131,20 @@ pub async fn main() -> anyhow::Result<()> {
Command::Create(args) => { Command::Create(args) => {
let filtered_names = filter_db_or_user_names(args.name, DbOrUser::Database)?; let filtered_names = filter_db_or_user_names(args.name, DbOrUser::Database)?;
for name in filtered_names { for name in filtered_names {
database_operations::create_database(&name, &mut connection).await?; create_database(&name, &mut connection).await?;
println!("Database {} created.", name); println!("Database {} created.", name);
} }
} }
Command::Drop(args) => { Command::Drop(args) => {
let filtered_names = filter_db_or_user_names(args.name, DbOrUser::Database)?; let filtered_names = filter_db_or_user_names(args.name, DbOrUser::Database)?;
for name in filtered_names { for name in filtered_names {
database_operations::drop_database(&name, &mut connection).await?; drop_database(&name, &mut connection).await?;
println!("Database {} dropped.", name); println!("Database {} dropped.", name);
} }
} }
Command::Show(args) => { Command::Show(args) => {
let names = if args.name.is_empty() { let names = if args.name.is_empty() {
database_operations::get_database_list(&mut connection).await? get_database_list(&mut connection).await?
} else { } else {
filter_db_or_user_names(args.name, DbOrUser::Database)? filter_db_or_user_names(args.name, DbOrUser::Database)?
}; };
@ -176,7 +178,7 @@ async fn show_db(name: &str, conn: &mut MySqlConnection) -> anyhow::Result<()> {
// for non-existent databases will report with no users. // for non-existent databases will report with no users.
// This function should *not* check for db existence, only // This function should *not* check for db existence, only
// validate the names. // validate the names.
let permissions = database_operations::get_database_privileges(name, conn) let permissions = database_privilege_operations::get_database_privileges(name, conn)
.await .await
.unwrap_or(vec![]); .unwrap_or(vec![]);

View File

@ -9,10 +9,7 @@ use crate::{
core::{ core::{
common::{close_database_connection, get_current_unix_user}, common::{close_database_connection, get_current_unix_user},
config::{get_config, mysql_connection_from_config, GlobalConfigArgs}, config::{get_config, mysql_connection_from_config, GlobalConfigArgs},
user_operations::{ user_operations::*,
create_database_user, delete_database_user, get_all_database_users_for_unix_user,
get_database_user_for_user, set_password_for_database_user, user_exists,
},
}, },
}; };

View File

@ -9,9 +9,9 @@ use serde_json::json;
use sqlx::{Connection, MySqlConnection}; use sqlx::{Connection, MySqlConnection};
use crate::core::{ use crate::core::{
common::close_database_connection, common::{close_database_connection, get_current_unix_user},
database_operations::get_databases_where_user_has_privileges, database_operations::*,
user_operations::validate_user_name, user_operations::*,
}; };
#[derive(Parser)] #[derive(Parser)]
@ -99,7 +99,7 @@ async fn create_users(args: UserCreateArgs, conn: &mut MySqlConnection) -> anyho
} }
for username in args.username { for username in args.username {
if let Err(e) = crate::core::user_operations::create_database_user(&username, conn).await { if let Err(e) = create_database_user(&username, conn).await {
eprintln!("{}", e); eprintln!("{}", e);
eprintln!("Skipping...\n"); eprintln!("Skipping...\n");
continue; continue;
@ -135,7 +135,7 @@ async fn drop_users(args: UserDeleteArgs, conn: &mut MySqlConnection) -> anyhow:
} }
for username in args.username { for username in args.username {
if let Err(e) = crate::core::user_operations::delete_database_user(&username, conn).await { if let Err(e) = delete_database_user(&username, conn).await {
eprintln!("{}", e); eprintln!("{}", e);
eprintln!("Skipping..."); eprintln!("Skipping...");
} }
@ -160,7 +160,7 @@ async fn change_password_for_user(
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
// NOTE: although this also is checked in `set_password_for_database_user`, we check it here // NOTE: although this also is checked in `set_password_for_database_user`, we check it here
// to provide a more natural order of error messages. // to provide a more natural order of error messages.
let unix_user = crate::core::common::get_current_unix_user()?; let unix_user = get_current_unix_user()?;
validate_user_name(&args.username, &unix_user)?; validate_user_name(&args.username, &unix_user)?;
let password = if let Some(password_file) = args.password_file { let password = if let Some(password_file) = args.password_file {
@ -172,17 +172,16 @@ async fn change_password_for_user(
read_password_from_stdin_with_double_check(&args.username)? read_password_from_stdin_with_double_check(&args.username)?
}; };
crate::core::user_operations::set_password_for_database_user(&args.username, &password, conn) set_password_for_database_user(&args.username, &password, conn).await?;
.await?;
Ok(()) Ok(())
} }
async fn show_users(args: UserShowArgs, conn: &mut MySqlConnection) -> anyhow::Result<()> { async fn show_users(args: UserShowArgs, conn: &mut MySqlConnection) -> anyhow::Result<()> {
let unix_user = crate::core::common::get_current_unix_user()?; let unix_user = get_current_unix_user()?;
let users = if args.username.is_empty() { let users = if args.username.is_empty() {
crate::core::user_operations::get_all_database_users_for_unix_user(&unix_user, conn).await? get_all_database_users_for_unix_user(&unix_user, conn).await?
} else { } else {
let mut result = vec![]; let mut result = vec![];
for username in args.username { for username in args.username {
@ -192,8 +191,7 @@ async fn show_users(args: UserShowArgs, conn: &mut MySqlConnection) -> anyhow::R
continue; continue;
} }
let user = let user = get_database_user_for_user(&username, conn).await?;
crate::core::user_operations::get_database_user_for_user(&username, conn).await?;
if let Some(user) = user { if let Some(user) = user {
result.push(user); result.push(user);
} else { } else {

View File

@ -1,4 +1,5 @@
pub mod common; pub mod common;
pub mod config; pub mod config;
pub mod database_operations; pub mod database_operations;
pub mod database_privilege_operations;
pub mod user_operations; pub mod user_operations;

View File

@ -161,3 +161,24 @@ pub fn quote_literal(s: &str) -> String {
pub fn quote_identifier(s: &str) -> String { pub fn quote_identifier(s: &str) -> String {
format!("`{}`", s.replace('`', r"\`")) format!("`{}`", s.replace('`', r"\`"))
} }
#[inline]
pub(crate) fn yn(b: bool) -> &'static str {
if b {
"Y"
} else {
"N"
}
}
#[inline]
pub(crate) fn rev_yn(s: &str) -> bool {
match s.to_lowercase().as_str() {
"y" => true,
"n" => false,
_ => {
log::warn!("Invalid value for privilege: {}", s);
false
}
}
}

View File

@ -1,15 +1,16 @@
use std::collections::HashMap;
use anyhow::Context; use anyhow::Context;
use indoc::{formatdoc, indoc}; use indoc::formatdoc;
use itertools::Itertools; use itertools::Itertools;
use nix::unistd::User; use nix::unistd::User;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sqlx::{mysql::MySqlRow, prelude::*, MySqlConnection}; use sqlx::{prelude::*, MySqlConnection};
use super::common::{ use crate::core::{
create_user_group_matching_regex, get_current_unix_user, quote_identifier, validate_name_token, common::{
validate_ownership_by_user_prefix, create_user_group_matching_regex, get_current_unix_user, quote_identifier,
validate_name_token, validate_ownership_by_user_prefix,
},
database_privilege_operations::DATABASE_PRIVILEGE_FIELDS,
}; };
pub async fn create_database(name: &str, conn: &mut MySqlConnection) -> anyhow::Result<()> { pub async fn create_database(name: &str, conn: &mut MySqlConnection) -> anyhow::Result<()> {
@ -100,324 +101,12 @@ pub async fn get_databases_where_user_has_privileges(
.fetch_all(conn) .fetch_all(conn)
.await? .await?
.into_iter() .into_iter()
.map(|databases| { .map(|databases| databases.try_get::<String, _>("database").unwrap())
databases.try_get::<String, _>("database").unwrap()
})
.collect(); .collect();
Ok(result) Ok(result)
} }
pub const DATABASE_PRIVILEGE_FIELDS: [&str; 13] = [
"db",
"user",
"select_priv",
"insert_priv",
"update_priv",
"delete_priv",
"create_priv",
"drop_priv",
"alter_priv",
"index_priv",
"create_tmp_table_priv",
"lock_tables_priv",
"references_priv",
];
pub fn db_priv_field_human_readable_name(name: &str) -> String {
match name {
"db" => "Database".to_owned(),
"user" => "User".to_owned(),
"select_priv" => "Select".to_owned(),
"insert_priv" => "Insert".to_owned(),
"update_priv" => "Update".to_owned(),
"delete_priv" => "Delete".to_owned(),
"create_priv" => "Create".to_owned(),
"drop_priv" => "Drop".to_owned(),
"alter_priv" => "Alter".to_owned(),
"index_priv" => "Index".to_owned(),
"create_tmp_table_priv" => "Temp".to_owned(),
"lock_tables_priv" => "Lock".to_owned(),
"references_priv" => "References".to_owned(),
_ => format!("Unknown({})", name),
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DatabasePrivileges {
pub db: String,
pub user: String,
pub select_priv: bool,
pub insert_priv: bool,
pub update_priv: bool,
pub delete_priv: bool,
pub create_priv: bool,
pub drop_priv: bool,
pub alter_priv: bool,
pub index_priv: bool,
pub create_tmp_table_priv: bool,
pub lock_tables_priv: bool,
pub references_priv: bool,
}
impl DatabasePrivileges {
pub fn get_privilege_by_name(&self, name: &str) -> bool {
match name {
"select_priv" => self.select_priv,
"insert_priv" => self.insert_priv,
"update_priv" => self.update_priv,
"delete_priv" => self.delete_priv,
"create_priv" => self.create_priv,
"drop_priv" => self.drop_priv,
"alter_priv" => self.alter_priv,
"index_priv" => self.index_priv,
"create_tmp_table_priv" => self.create_tmp_table_priv,
"lock_tables_priv" => self.lock_tables_priv,
"references_priv" => self.references_priv,
_ => false,
}
}
pub fn diff(&self, other: &DatabasePrivileges) -> DatabasePrivilegeDiffList {
debug_assert!(self.db == other.db && self.user == other.user);
DatabasePrivilegeDiffList {
db: self.db.clone(),
user: self.user.clone(),
diff: DATABASE_PRIVILEGE_FIELDS
.into_iter()
.skip(2)
.filter_map(|field| {
diff_single_priv(
self.get_privilege_by_name(field),
other.get_privilege_by_name(field),
field,
)
})
.collect(),
}
}
}
#[inline]
pub(crate) fn yn(b: bool) -> &'static str {
if b {
"Y"
} else {
"N"
}
}
#[inline]
pub(crate) fn rev_yn(s: &str) -> bool {
match s.to_lowercase().as_str() {
"y" => true,
"n" => false,
_ => {
log::warn!("Invalid value for privilege: {}", s);
false
}
}
}
impl FromRow<'_, MySqlRow> for DatabasePrivileges {
fn from_row(row: &MySqlRow) -> Result<Self, sqlx::Error> {
Ok(Self {
db: row.try_get("db")?,
user: row.try_get("user")?,
select_priv: row.try_get("select_priv").map(rev_yn)?,
insert_priv: row.try_get("insert_priv").map(rev_yn)?,
update_priv: row.try_get("update_priv").map(rev_yn)?,
delete_priv: row.try_get("delete_priv").map(rev_yn)?,
create_priv: row.try_get("create_priv").map(rev_yn)?,
drop_priv: row.try_get("drop_priv").map(rev_yn)?,
alter_priv: row.try_get("alter_priv").map(rev_yn)?,
index_priv: row.try_get("index_priv").map(rev_yn)?,
create_tmp_table_priv: row.try_get("create_tmp_table_priv").map(rev_yn)?,
lock_tables_priv: row.try_get("lock_tables_priv").map(rev_yn)?,
references_priv: row.try_get("references_priv").map(rev_yn)?,
})
}
}
pub async fn get_database_privileges(
database_name: &str,
conn: &mut MySqlConnection,
) -> anyhow::Result<Vec<DatabasePrivileges>> {
let unix_user = get_current_unix_user()?;
validate_database_name(database_name, &unix_user)?;
let result = sqlx::query_as::<_, DatabasePrivileges>(&format!(
"SELECT {} FROM `db` WHERE `db` = ?",
DATABASE_PRIVILEGE_FIELDS
.iter()
.map(|field| quote_identifier(field))
.join(","),
))
.bind(database_name)
.fetch_all(conn)
.await
.context("Failed to show database")?;
Ok(result)
}
pub async fn get_all_database_privileges(
conn: &mut MySqlConnection,
) -> anyhow::Result<Vec<DatabasePrivileges>> {
let unix_user = get_current_unix_user()?;
let result = sqlx::query_as::<_, DatabasePrivileges>(&format!(
indoc! {r#"
SELECT {} FROM `db` WHERE `db` IN
(SELECT DISTINCT `SCHEMA_NAME` AS `database`
FROM `information_schema`.`SCHEMATA`
WHERE `SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys')
AND `SCHEMA_NAME` REGEXP ?)
"#},
DATABASE_PRIVILEGE_FIELDS
.iter()
.map(|field| format!("`{field}`"))
.join(","),
))
.bind(create_user_group_matching_regex(&unix_user))
.fetch_all(conn)
.await
.context("Failed to show databases")?;
Ok(result)
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DatabasePrivilegeDiffList {
pub db: String,
pub user: String,
pub diff: Vec<DatabasePrivilegeDiff>,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum DatabasePrivilegeDiff {
YesToNo(String),
NoToYes(String),
}
fn diff_single_priv(p1: bool, p2: bool, name: &str) -> Option<DatabasePrivilegeDiff> {
match (p1, p2) {
(true, false) => Some(DatabasePrivilegeDiff::YesToNo(name.to_owned())),
(false, true) => Some(DatabasePrivilegeDiff::NoToYes(name.to_owned())),
_ => None,
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum DatabasePrivilegesDiff {
New(DatabasePrivileges),
Modified(DatabasePrivilegeDiffList),
Deleted(DatabasePrivileges),
}
pub async fn diff_permissions(
from: Vec<DatabasePrivileges>,
to: &[DatabasePrivileges],
) -> Vec<DatabasePrivilegesDiff> {
let from_lookup_table: HashMap<(String, String), DatabasePrivileges> = HashMap::from_iter(
from.iter()
.cloned()
.map(|p| ((p.db.clone(), p.user.clone()), p)),
);
let to_lookup_table: HashMap<(String, String), DatabasePrivileges> = HashMap::from_iter(
to.iter()
.cloned()
.map(|p| ((p.db.clone(), p.user.clone()), p)),
);
let mut result = vec![];
for p in to {
if let Some(old_p) = from_lookup_table.get(&(p.db.clone(), p.user.clone())) {
let diff = old_p.diff(p);
if !diff.diff.is_empty() {
result.push(DatabasePrivilegesDiff::Modified(diff));
}
} else {
result.push(DatabasePrivilegesDiff::New(p.clone()));
}
}
for p in from {
if !to_lookup_table.contains_key(&(p.db.clone(), p.user.clone())) {
result.push(DatabasePrivilegesDiff::Deleted(p));
}
}
result
}
pub async fn apply_permission_diffs(
diffs: Vec<DatabasePrivilegesDiff>,
conn: &mut MySqlConnection,
) -> anyhow::Result<()> {
for diff in diffs {
match diff {
DatabasePrivilegesDiff::New(p) => {
let tables = DATABASE_PRIVILEGE_FIELDS
.iter()
.map(|field| format!("`{field}`"))
.join(",");
let question_marks = std::iter::repeat("?")
.take(DATABASE_PRIVILEGE_FIELDS.len())
.join(",");
sqlx::query(
format!("INSERT INTO `db` ({}) VALUES ({})", tables, question_marks).as_str(),
)
.bind(p.db)
.bind(p.user)
.bind(yn(p.select_priv))
.bind(yn(p.insert_priv))
.bind(yn(p.update_priv))
.bind(yn(p.delete_priv))
.bind(yn(p.create_priv))
.bind(yn(p.drop_priv))
.bind(yn(p.alter_priv))
.bind(yn(p.index_priv))
.bind(yn(p.create_tmp_table_priv))
.bind(yn(p.lock_tables_priv))
.bind(yn(p.references_priv))
.execute(&mut *conn)
.await?;
}
DatabasePrivilegesDiff::Modified(p) => {
let tables = p
.diff
.iter()
.map(|diff| match diff {
DatabasePrivilegeDiff::YesToNo(name) => format!("`{}` = 'N'", name),
DatabasePrivilegeDiff::NoToYes(name) => format!("`{}` = 'Y'", name),
})
.join(",");
sqlx::query(
format!("UPDATE `db` SET {} WHERE `db` = ? AND `user` = ?", tables).as_str(),
)
.bind(p.db)
.bind(p.user)
.execute(&mut *conn)
.await?;
}
DatabasePrivilegesDiff::Deleted(p) => {
sqlx::query("DELETE FROM `db` WHERE `db` = ? AND `user` = ?")
.bind(p.db)
.bind(p.user)
.execute(&mut *conn)
.await?;
}
}
}
Ok(())
}
/// NOTE: It is very critical that this function validates the database name /// NOTE: It is very critical that this function validates the database name
/// properly. MySQL does not seem to allow for prepared statements, binding /// properly. MySQL does not seem to allow for prepared statements, binding
/// the database name as a parameter to the query. This means that we have /// the database name as a parameter to the query. This means that we have
@ -428,72 +117,3 @@ pub fn validate_database_name(name: &str, user: &User) -> anyhow::Result<()> {
Ok(()) Ok(())
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_diff_single_priv() {
assert_eq!(
diff_single_priv(true, false, "test"),
Some(DatabasePrivilegeDiff::YesToNo("test".to_owned()))
);
assert_eq!(
diff_single_priv(false, true, "test"),
Some(DatabasePrivilegeDiff::NoToYes("test".to_owned()))
);
assert_eq!(diff_single_priv(true, true, "test"), None);
assert_eq!(diff_single_priv(false, false, "test"), None);
}
#[tokio::test]
async fn test_diff_permissions() {
let from = vec![DatabasePrivileges {
db: "db".to_owned(),
user: "user".to_owned(),
select_priv: true,
insert_priv: true,
update_priv: true,
delete_priv: true,
create_priv: true,
drop_priv: true,
alter_priv: true,
index_priv: true,
create_tmp_table_priv: true,
lock_tables_priv: true,
references_priv: true,
}];
let to = vec![DatabasePrivileges {
db: "db".to_owned(),
user: "user".to_owned(),
select_priv: false,
insert_priv: true,
update_priv: true,
delete_priv: true,
create_priv: true,
drop_priv: true,
alter_priv: true,
index_priv: true,
create_tmp_table_priv: true,
lock_tables_priv: true,
references_priv: true,
}];
let diffs = diff_permissions(from, &to).await;
assert_eq!(
diffs,
vec![DatabasePrivilegesDiff::Modified(
DatabasePrivilegeDiffList {
db: "db".to_owned(),
user: "user".to_owned(),
diff: vec![DatabasePrivilegeDiff::YesToNo("select_priv".to_owned())],
}
)]
);
assert!(matches!(&diffs[0], DatabasePrivilegesDiff::Modified(_)));
}
}

View File

@ -0,0 +1,372 @@
use std::collections::HashMap;
use anyhow::Context;
use indoc::indoc;
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use sqlx::{mysql::MySqlRow, prelude::*, MySqlConnection};
use crate::core::{
common::{
create_user_group_matching_regex, get_current_unix_user, quote_identifier, rev_yn, yn,
},
database_operations::validate_database_name,
};
pub const DATABASE_PRIVILEGE_FIELDS: [&str; 13] = [
"db",
"user",
"select_priv",
"insert_priv",
"update_priv",
"delete_priv",
"create_priv",
"drop_priv",
"alter_priv",
"index_priv",
"create_tmp_table_priv",
"lock_tables_priv",
"references_priv",
];
pub fn db_priv_field_human_readable_name(name: &str) -> String {
match name {
"db" => "Database".to_owned(),
"user" => "User".to_owned(),
"select_priv" => "Select".to_owned(),
"insert_priv" => "Insert".to_owned(),
"update_priv" => "Update".to_owned(),
"delete_priv" => "Delete".to_owned(),
"create_priv" => "Create".to_owned(),
"drop_priv" => "Drop".to_owned(),
"alter_priv" => "Alter".to_owned(),
"index_priv" => "Index".to_owned(),
"create_tmp_table_priv" => "Temp".to_owned(),
"lock_tables_priv" => "Lock".to_owned(),
"references_priv" => "References".to_owned(),
_ => format!("Unknown({})", name),
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DatabasePrivileges {
pub db: String,
pub user: String,
pub select_priv: bool,
pub insert_priv: bool,
pub update_priv: bool,
pub delete_priv: bool,
pub create_priv: bool,
pub drop_priv: bool,
pub alter_priv: bool,
pub index_priv: bool,
pub create_tmp_table_priv: bool,
pub lock_tables_priv: bool,
pub references_priv: bool,
}
impl DatabasePrivileges {
pub fn get_privilege_by_name(&self, name: &str) -> bool {
match name {
"select_priv" => self.select_priv,
"insert_priv" => self.insert_priv,
"update_priv" => self.update_priv,
"delete_priv" => self.delete_priv,
"create_priv" => self.create_priv,
"drop_priv" => self.drop_priv,
"alter_priv" => self.alter_priv,
"index_priv" => self.index_priv,
"create_tmp_table_priv" => self.create_tmp_table_priv,
"lock_tables_priv" => self.lock_tables_priv,
"references_priv" => self.references_priv,
_ => false,
}
}
pub fn diff(&self, other: &DatabasePrivileges) -> DatabasePrivilegeDiffList {
debug_assert!(self.db == other.db && self.user == other.user);
DatabasePrivilegeDiffList {
db: self.db.clone(),
user: self.user.clone(),
diff: DATABASE_PRIVILEGE_FIELDS
.into_iter()
.skip(2)
.filter_map(|field| {
diff_single_priv(
self.get_privilege_by_name(field),
other.get_privilege_by_name(field),
field,
)
})
.collect(),
}
}
}
impl FromRow<'_, MySqlRow> for DatabasePrivileges {
fn from_row(row: &MySqlRow) -> Result<Self, sqlx::Error> {
Ok(Self {
db: row.try_get("db")?,
user: row.try_get("user")?,
select_priv: row.try_get("select_priv").map(rev_yn)?,
insert_priv: row.try_get("insert_priv").map(rev_yn)?,
update_priv: row.try_get("update_priv").map(rev_yn)?,
delete_priv: row.try_get("delete_priv").map(rev_yn)?,
create_priv: row.try_get("create_priv").map(rev_yn)?,
drop_priv: row.try_get("drop_priv").map(rev_yn)?,
alter_priv: row.try_get("alter_priv").map(rev_yn)?,
index_priv: row.try_get("index_priv").map(rev_yn)?,
create_tmp_table_priv: row.try_get("create_tmp_table_priv").map(rev_yn)?,
lock_tables_priv: row.try_get("lock_tables_priv").map(rev_yn)?,
references_priv: row.try_get("references_priv").map(rev_yn)?,
})
}
}
pub async fn get_database_privileges(
database_name: &str,
conn: &mut MySqlConnection,
) -> anyhow::Result<Vec<DatabasePrivileges>> {
let unix_user = get_current_unix_user()?;
validate_database_name(database_name, &unix_user)?;
let result = sqlx::query_as::<_, DatabasePrivileges>(&format!(
"SELECT {} FROM `db` WHERE `db` = ?",
DATABASE_PRIVILEGE_FIELDS
.iter()
.map(|field| quote_identifier(field))
.join(","),
))
.bind(database_name)
.fetch_all(conn)
.await
.context("Failed to show database")?;
Ok(result)
}
pub async fn get_all_database_privileges(
conn: &mut MySqlConnection,
) -> anyhow::Result<Vec<DatabasePrivileges>> {
let unix_user = get_current_unix_user()?;
let result = sqlx::query_as::<_, DatabasePrivileges>(&format!(
indoc! {r#"
SELECT {} FROM `db` WHERE `db` IN
(SELECT DISTINCT `SCHEMA_NAME` AS `database`
FROM `information_schema`.`SCHEMATA`
WHERE `SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys')
AND `SCHEMA_NAME` REGEXP ?)
"#},
DATABASE_PRIVILEGE_FIELDS
.iter()
.map(|field| format!("`{field}`"))
.join(","),
))
.bind(create_user_group_matching_regex(&unix_user))
.fetch_all(conn)
.await
.context("Failed to show databases")?;
Ok(result)
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DatabasePrivilegeDiffList {
pub db: String,
pub user: String,
pub diff: Vec<DatabasePrivilegeDiff>,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum DatabasePrivilegeDiff {
YesToNo(String),
NoToYes(String),
}
fn diff_single_priv(p1: bool, p2: bool, name: &str) -> Option<DatabasePrivilegeDiff> {
match (p1, p2) {
(true, false) => Some(DatabasePrivilegeDiff::YesToNo(name.to_owned())),
(false, true) => Some(DatabasePrivilegeDiff::NoToYes(name.to_owned())),
_ => None,
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum DatabasePrivilegesDiff {
New(DatabasePrivileges),
Modified(DatabasePrivilegeDiffList),
Deleted(DatabasePrivileges),
}
pub async fn diff_permissions(
from: Vec<DatabasePrivileges>,
to: &[DatabasePrivileges],
) -> Vec<DatabasePrivilegesDiff> {
let from_lookup_table: HashMap<(String, String), DatabasePrivileges> = HashMap::from_iter(
from.iter()
.cloned()
.map(|p| ((p.db.clone(), p.user.clone()), p)),
);
let to_lookup_table: HashMap<(String, String), DatabasePrivileges> = HashMap::from_iter(
to.iter()
.cloned()
.map(|p| ((p.db.clone(), p.user.clone()), p)),
);
let mut result = vec![];
for p in to {
if let Some(old_p) = from_lookup_table.get(&(p.db.clone(), p.user.clone())) {
let diff = old_p.diff(p);
if !diff.diff.is_empty() {
result.push(DatabasePrivilegesDiff::Modified(diff));
}
} else {
result.push(DatabasePrivilegesDiff::New(p.clone()));
}
}
for p in from {
if !to_lookup_table.contains_key(&(p.db.clone(), p.user.clone())) {
result.push(DatabasePrivilegesDiff::Deleted(p));
}
}
result
}
pub async fn apply_permission_diffs(
diffs: Vec<DatabasePrivilegesDiff>,
conn: &mut MySqlConnection,
) -> anyhow::Result<()> {
for diff in diffs {
match diff {
DatabasePrivilegesDiff::New(p) => {
let tables = DATABASE_PRIVILEGE_FIELDS
.iter()
.map(|field| format!("`{field}`"))
.join(",");
let question_marks = std::iter::repeat("?")
.take(DATABASE_PRIVILEGE_FIELDS.len())
.join(",");
sqlx::query(
format!("INSERT INTO `db` ({}) VALUES ({})", tables, question_marks).as_str(),
)
.bind(p.db)
.bind(p.user)
.bind(yn(p.select_priv))
.bind(yn(p.insert_priv))
.bind(yn(p.update_priv))
.bind(yn(p.delete_priv))
.bind(yn(p.create_priv))
.bind(yn(p.drop_priv))
.bind(yn(p.alter_priv))
.bind(yn(p.index_priv))
.bind(yn(p.create_tmp_table_priv))
.bind(yn(p.lock_tables_priv))
.bind(yn(p.references_priv))
.execute(&mut *conn)
.await?;
}
DatabasePrivilegesDiff::Modified(p) => {
let tables = p
.diff
.iter()
.map(|diff| match diff {
DatabasePrivilegeDiff::YesToNo(name) => format!("`{}` = 'N'", name),
DatabasePrivilegeDiff::NoToYes(name) => format!("`{}` = 'Y'", name),
})
.join(",");
sqlx::query(
format!("UPDATE `db` SET {} WHERE `db` = ? AND `user` = ?", tables).as_str(),
)
.bind(p.db)
.bind(p.user)
.execute(&mut *conn)
.await?;
}
DatabasePrivilegesDiff::Deleted(p) => {
sqlx::query("DELETE FROM `db` WHERE `db` = ? AND `user` = ?")
.bind(p.db)
.bind(p.user)
.execute(&mut *conn)
.await?;
}
}
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_diff_single_priv() {
assert_eq!(
diff_single_priv(true, false, "test"),
Some(DatabasePrivilegeDiff::YesToNo("test".to_owned()))
);
assert_eq!(
diff_single_priv(false, true, "test"),
Some(DatabasePrivilegeDiff::NoToYes("test".to_owned()))
);
assert_eq!(diff_single_priv(true, true, "test"), None);
assert_eq!(diff_single_priv(false, false, "test"), None);
}
#[tokio::test]
async fn test_diff_permissions() {
let from = vec![DatabasePrivileges {
db: "db".to_owned(),
user: "user".to_owned(),
select_priv: true,
insert_priv: true,
update_priv: true,
delete_priv: true,
create_priv: true,
drop_priv: true,
alter_priv: true,
index_priv: true,
create_tmp_table_priv: true,
lock_tables_priv: true,
references_priv: true,
}];
let to = vec![DatabasePrivileges {
db: "db".to_owned(),
user: "user".to_owned(),
select_priv: false,
insert_priv: true,
update_priv: true,
delete_priv: true,
create_priv: true,
drop_priv: true,
alter_priv: true,
index_priv: true,
create_tmp_table_priv: true,
lock_tables_priv: true,
references_priv: true,
}];
let diffs = diff_permissions(from, &to).await;
assert_eq!(
diffs,
vec![DatabasePrivilegesDiff::Modified(
DatabasePrivilegeDiffList {
db: "db".to_owned(),
user: "user".to_owned(),
diff: vec![DatabasePrivilegeDiff::YesToNo("select_priv".to_owned())],
}
)]
);
assert!(matches!(&diffs[0], DatabasePrivilegesDiff::Modified(_)));
}
}