Add large parts of the permission editor

This commit is contained in:
Oystein Kristoffer Tveit 2024-04-26 00:30:32 +02:00
parent 0837ac9fc7
commit b0bffc45ee
Signed by: oysteikt
GPG Key ID: 9F2F7D8250F35146
9 changed files with 546 additions and 128 deletions

1
Cargo.lock generated
View File

@ -892,6 +892,7 @@ dependencies = [
"log", "log",
"nix", "nix",
"prettytable", "prettytable",
"rand",
"ratatui", "ratatui",
"rpassword", "rpassword",
"serde", "serde",

View File

@ -13,6 +13,7 @@ itertools = "0.12.1"
log = "0.4.21" log = "0.4.21"
nix = { version = "0.28.0", features = ["user"] } nix = { version = "0.28.0", features = ["user"] }
prettytable = "0.10.0" prettytable = "0.10.0"
rand = "0.8.5"
ratatui = { version = "0.26.2", optional = true } ratatui = { version = "0.26.2", optional = true }
rpassword = "7.3.1" rpassword = "7.3.1"
serde = "1.0.198" serde = "1.0.198"
@ -32,4 +33,4 @@ path = "src/main.rs"
[profile.release] [profile.release]
strip = true strip = true
lto = true lto = true
codegen-units = 1 codegen-units = 1

View File

@ -1,9 +1,19 @@
use anyhow::Context; use anyhow::{anyhow, Context};
use clap::Parser; use clap::Parser;
use indoc::indoc;
use itertools::Itertools;
use prettytable::{Cell, Row, Table}; use prettytable::{Cell, Row, Table};
use rand::prelude::*;
use sqlx::{Connection, MySqlConnection}; use sqlx::{Connection, MySqlConnection};
use crate::core::{self, database_operations::DatabasePrivileges}; use crate::core::{
self,
common::get_current_unix_user,
database_operations::{
apply_permission_diffs, db_priv_field_human_readable_name, diff_permissions, yn,
DatabasePrivileges, DATABASE_PRIVILEGE_FIELDS,
},
};
#[derive(Parser)] #[derive(Parser)]
pub struct DatabaseArgs { pub struct DatabaseArgs {
@ -36,32 +46,31 @@ enum DatabaseCommand {
/// Change permissions for the DATABASE(S). Run `edit-perm --help` for more information. /// Change permissions for the DATABASE(S). Run `edit-perm --help` for more information.
/// ///
/// TODO: fix this help message.
///
/// This command has two modes of operation: /// This command has two modes of operation:
/// 1. Interactive mode: If the `-t` flag is used, the user will be prompted to edit the permissions using a text editor. /// 1. Interactive mode: If nothing else is specified, the user will be prompted to edit the permissions using a text editor.
/// 2. Non-interactive mode: If the `-t` flag is not used, the user can specify the permissions to change using the `-p` flag.
/// ///
/// In non-interactive mode, the `-p` flag should be followed by strings, each representing a single permission change. /// Follow the instructions inside the editor for more information.
/// ///
/// The permission arguments should be a string, formatted as `db:user:privileges` /// 2. Non-interactive mode: If the `-p` flag is specified, the user can write permissions using arguments.
/// where privs are a string of characters, each representing a single permissions,
/// with the exception of `A` which represents all permissions.
/// ///
/// The permission to character mapping is as follows: /// The permission arguments should be formatted as `<db>:<user>:<privileges>`
/// where the privileges are a string of characters, each representing a single permissions.
/// The character `A` is an exception, because it represents all permissions.
/// ///
/// - `s` - SELECT /// The permission to character mapping is as follows:
/// - `i` - INSERT ///
/// - `u` - UPDATE /// - `s` - SELECT
/// - `d` - DELETE /// - `i` - INSERT
/// - `c` - CREATE /// - `u` - UPDATE
/// - `D` - DROP /// - `d` - DELETE
/// - `a` - ALTER /// - `c` - CREATE
/// - `I` - INDEX /// - `D` - DROP
/// - `t` - CREATE TEMPORARY TABLES /// - `a` - ALTER
/// - `l` - LOCK TABLES /// - `I` - INDEX
/// - `r` - REFERENCES /// - `t` - CREATE TEMPORARY TABLES
/// - `A` - ALL PRIVILEGES /// - `l` - LOCK TABLES
/// - `r` - REFERENCES
/// - `A` - ALL PRIVILEGES
/// ///
#[command(display_name = "edit-perm", alias = "e", verbatim_doc_comment)] #[command(display_name = "edit-perm", alias = "e", verbatim_doc_comment)]
EditPerm(DatabaseEditPermArgs), EditPerm(DatabaseEditPermArgs),
@ -111,10 +120,6 @@ struct DatabaseEditPermArgs {
#[arg(short, long)] #[arg(short, long)]
json: bool, json: bool,
/// Whether to edit the permissions using a text editor.
#[arg(short, long)]
text: bool,
/// Specify the text editor to use for editing permissions. /// Specify the text editor to use for editing permissions.
#[arg(short, long)] #[arg(short, long)]
editor: Option<String>, editor: Option<String>,
@ -184,9 +189,9 @@ async fn list_databases(args: DatabaseListArgs, conn: &mut MySqlConnection) -> a
if args.json { if args.json {
println!("{}", serde_json::to_string_pretty(&databases)?); println!("{}", serde_json::to_string_pretty(&databases)?);
} else { } else {
for db in databases { for db in databases {
println!("{}", db); println!("{}", db);
} }
} }
Ok(()) Ok(())
@ -223,9 +228,10 @@ async fn show_databases(
} else { } else {
let mut table = Table::new(); let mut table = Table::new();
table.add_row(Row::new( table.add_row(Row::new(
core::database_operations::HUMAN_READABLE_DATABASE_PRIVILEGE_NAMES DATABASE_PRIVILEGE_FIELDS
.iter() .into_iter()
.map(|(name, _)| Cell::new(name)) .map(db_priv_field_human_readable_name)
.map(|name| Cell::new(&name))
.collect(), .collect(),
)); ));
@ -266,44 +272,44 @@ fn parse_permission_table_cli_arg(arg: &str) -> anyhow::Result<DatabasePrivilege
let mut result = DatabasePrivileges { let mut result = DatabasePrivileges {
db, db,
user, user,
select_priv: "N".to_string(), select_priv: false,
insert_priv: "N".to_string(), insert_priv: false,
update_priv: "N".to_string(), update_priv: false,
delete_priv: "N".to_string(), delete_priv: false,
create_priv: "N".to_string(), create_priv: false,
drop_priv: "N".to_string(), drop_priv: false,
alter_priv: "N".to_string(), alter_priv: false,
index_priv: "N".to_string(), index_priv: false,
create_tmp_table_priv: "N".to_string(), create_tmp_table_priv: false,
lock_tables_priv: "N".to_string(), lock_tables_priv: false,
references_priv: "N".to_string(), references_priv: false,
}; };
for char in privs.chars() { for char in privs.chars() {
match char { match char {
's' => result.select_priv = "Y".to_string(), 's' => result.select_priv = true,
'i' => result.insert_priv = "Y".to_string(), 'i' => result.insert_priv = true,
'u' => result.update_priv = "Y".to_string(), 'u' => result.update_priv = true,
'd' => result.delete_priv = "Y".to_string(), 'd' => result.delete_priv = true,
'c' => result.create_priv = "Y".to_string(), 'c' => result.create_priv = true,
'D' => result.drop_priv = "Y".to_string(), 'D' => result.drop_priv = true,
'a' => result.alter_priv = "Y".to_string(), 'a' => result.alter_priv = true,
'I' => result.index_priv = "Y".to_string(), 'I' => result.index_priv = true,
't' => result.create_tmp_table_priv = "Y".to_string(), 't' => result.create_tmp_table_priv = true,
'l' => result.lock_tables_priv = "Y".to_string(), 'l' => result.lock_tables_priv = true,
'r' => result.references_priv = "Y".to_string(), 'r' => result.references_priv = true,
'A' => { 'A' => {
result.select_priv = "Y".to_string(); result.select_priv = true;
result.insert_priv = "Y".to_string(); result.insert_priv = true;
result.update_priv = "Y".to_string(); result.update_priv = true;
result.delete_priv = "Y".to_string(); result.delete_priv = true;
result.create_priv = "Y".to_string(); result.create_priv = true;
result.drop_priv = "Y".to_string(); result.drop_priv = true;
result.alter_priv = "Y".to_string(); result.alter_priv = true;
result.index_priv = "Y".to_string(); result.index_priv = true;
result.create_tmp_table_priv = "Y".to_string(); result.create_tmp_table_priv = true;
result.lock_tables_priv = "Y".to_string(); result.lock_tables_priv = true;
result.references_priv = "Y".to_string(); result.references_priv = true;
} }
_ => anyhow::bail!("Invalid permission character: {}", char), _ => anyhow::bail!("Invalid permission character: {}", char),
} }
@ -312,18 +318,88 @@ fn parse_permission_table_cli_arg(arg: &str) -> anyhow::Result<DatabasePrivilege
Ok(result) Ok(result)
} }
fn parse_permission(yn: &str) -> anyhow::Result<bool> {
match yn.to_ascii_lowercase().as_str() {
"y" => Ok(true),
"n" => Ok(false),
_ => Err(anyhow!("Expected Y or N, found {}", yn)),
}
}
fn parse_permission_data_from_editor(content: String) -> anyhow::Result<Vec<DatabasePrivileges>> {
content
.trim()
.split('\n')
.map(|line| line.trim())
.filter(|line| !(line.starts_with('#') || line.starts_with("//") || line == &""))
.skip(1)
.map(|line| {
let line_parts: Vec<&str> = line.trim().split_ascii_whitespace().collect();
if line_parts.len() != DATABASE_PRIVILEGE_FIELDS.len() {
anyhow::bail!("")
}
Ok(DatabasePrivileges {
db: (*line_parts.get(0).unwrap()).to_owned(),
user: (*line_parts.get(1).unwrap()).to_owned(),
select_priv: parse_permission(*line_parts.get(2).unwrap())
.context("Could not parse SELECT privilege")?,
insert_priv: parse_permission(*line_parts.get(3).unwrap())
.context("Could not parse INSERT privilege")?,
update_priv: parse_permission(*line_parts.get(4).unwrap())
.context("Could not parse UPDATE privilege")?,
delete_priv: parse_permission(*line_parts.get(5).unwrap())
.context("Could not parse DELETE privilege")?,
create_priv: parse_permission(*line_parts.get(6).unwrap())
.context("Could not parse CREATE privilege")?,
drop_priv: parse_permission(*line_parts.get(7).unwrap())
.context("Could not parse DROP privilege")?,
alter_priv: parse_permission(*line_parts.get(8).unwrap())
.context("Could not parse ALTER privilege")?,
index_priv: parse_permission(*line_parts.get(9).unwrap())
.context("Could not parse INDEX privilege")?,
create_tmp_table_priv: parse_permission(*line_parts.get(10).unwrap())
.context("Could not parse CREATE TEMPORARY TABLE privilege")?,
lock_tables_priv: parse_permission(*line_parts.get(11).unwrap())
.context("Could not parse LOCK TABLES privilege")?,
references_priv: parse_permission(*line_parts.get(12).unwrap())
.context("Could not parse REFERENCES privilege")?,
})
})
.collect::<anyhow::Result<Vec<DatabasePrivileges>>>()
}
fn display_permissions_as_editor_line(privs: &DatabasePrivileges) -> String {
vec![
privs.db.as_str(),
privs.user.as_str(),
yn(privs.select_priv),
yn(privs.insert_priv),
yn(privs.update_priv),
yn(privs.delete_priv),
yn(privs.create_priv),
yn(privs.drop_priv),
yn(privs.alter_priv),
yn(privs.index_priv),
yn(privs.create_tmp_table_priv),
yn(privs.lock_tables_priv),
yn(privs.references_priv),
]
.join("\t")
}
async fn edit_permissions( async fn edit_permissions(
args: DatabaseEditPermArgs, args: DatabaseEditPermArgs,
conn: &mut MySqlConnection, conn: &mut MySqlConnection,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let _data = if let Some(name) = &args.name { let permission_data = if let Some(name) = &args.name {
core::database_operations::get_database_privileges(name, conn).await? core::database_operations::get_database_privileges(name, conn).await?
} else { } else {
core::database_operations::get_all_database_privileges(conn).await? core::database_operations::get_all_database_privileges(conn).await?
}; };
if !args.text { let permissions_to_change = if !args.perm.is_empty() {
let permissions_to_change: Vec<DatabasePrivileges> = if let Some(name) = args.name { if let Some(name) = args.name {
args.perm args.perm
.iter() .iter()
.map(|perm| { .map(|perm| {
@ -339,15 +415,61 @@ async fn edit_permissions(
.context(format!("Failed parsing database permissions: `{}`", &perm)) .context(format!("Failed parsing database permissions: `{}`", &perm))
}) })
.collect::<anyhow::Result<Vec<DatabasePrivileges>>>()? .collect::<anyhow::Result<Vec<DatabasePrivileges>>>()?
}
} else {
let comment = indoc! {r#"
# Welcome to the permission editor.
# To add permissions
"#};
let header = DATABASE_PRIVILEGE_FIELDS
.map(db_priv_field_human_readable_name)
.join("\t");
let example_line = {
let unix_user = get_current_unix_user()?;
let mut rng = thread_rng();
let random_yes_nos = (0..(DATABASE_PRIVILEGE_FIELDS.len() - 2))
.map(|_| ['Y', 'N'].choose(&mut rng).unwrap())
.join("\t");
format!(
"# {}_db\t{}_user\t{}",
unix_user.name, unix_user.name, random_yes_nos
)
}; };
println!("{:#?}", permissions_to_change); let result = edit::edit_with_builder(
} else { format!(
// TODO: debug assert that -p is not used with -t "{}\n{}\n{}",
comment,
header,
if permission_data.is_empty() {
example_line
} else {
permission_data
.iter()
.map(display_permissions_as_editor_line)
.join("\n")
}
),
edit::Builder::new()
.prefix("database-permissions")
.suffix(".tsv")
.rand_bytes(10),
)?;
parse_permission_data_from_editor(result)
.context("Could not parse permission data from editor")?
};
let diffs = diff_permissions(permission_data, &permissions_to_change).await;
if diffs.is_empty() {
println!("No changes to make.");
return Ok(());
} }
// TODO: find the difference between the two vectors, and ask for confirmation before applying the changes. // TODO: Add confirmation prompt.
// TODO: apply the changes to the database. apply_permission_diffs(diffs, conn).await?;
unimplemented!();
Ok(())
} }

View File

@ -77,9 +77,7 @@ async fn create_users(args: UserCreateArgs, conn: &mut MySqlConnection) -> anyho
} }
for username in args.username { for username in args.username {
if let Err(e) = if let Err(e) = crate::core::user_operations::create_database_user(&username, conn).await {
crate::core::user_operations::create_database_user(&username, conn).await
{
eprintln!("{}", e); eprintln!("{}", e);
eprintln!("Skipping..."); eprintln!("Skipping...");
} }
@ -94,9 +92,7 @@ async fn drop_users(args: UserDeleteArgs, conn: &mut MySqlConnection) -> anyhow:
} }
for username in args.username { for username in args.username {
if let Err(e) = if let Err(e) = crate::core::user_operations::delete_database_user(&username, conn).await {
crate::core::user_operations::delete_database_user(&username, conn).await
{
eprintln!("{}", e); eprintln!("{}", e);
eprintln!("Skipping..."); eprintln!("Skipping...");
} }
@ -132,12 +128,8 @@ async fn change_password_for_user(
pass1 pass1
}; };
crate::core::user_operations::set_password_for_database_user( crate::core::user_operations::set_password_for_database_user(&args.username, &password, conn)
&args.username, .await?;
&password,
conn,
)
.await?;
Ok(()) Ok(())
} }
@ -157,8 +149,7 @@ async fn show_users(args: UserShowArgs, conn: &mut MySqlConnection) -> anyhow::R
} }
let user = let user =
crate::core::user_operations::get_database_user_for_user(&username, conn) crate::core::user_operations::get_database_user_for_user(&username, conn).await?;
.await?;
if let Some(user) = user { if let Some(user) = user {
result.push(user); result.push(user);
} else { } else {

View File

@ -88,4 +88,4 @@ pub fn quote_literal(s: &str) -> String {
pub fn quote_identifier(s: &str) -> String { pub fn quote_identifier(s: &str) -> String {
format!("`{}`", s.replace('`', r"\`")) format!("`{}`", s.replace('`', r"\`"))
} }

View File

@ -119,10 +119,10 @@ pub async fn mysql_connection_from_config(config: Config) -> anyhow::Result<MySq
.port(config.mysql.port.unwrap_or(3306)) .port(config.mysql.port.unwrap_or(3306))
.database("mysql") .database("mysql")
.connect(), .connect(),
).await { )
.await
{
Ok(conn) => conn.context("Failed to connect to MySQL"), Ok(conn) => conn.context("Failed to connect to MySQL"),
Err(_) => { Err(_) => Err(anyhow!("Timed out after 2 seconds")).context("Failed to connect to MySQL"),
Err(anyhow!("Timed out after 2 seconds")).context("Failed to connect to MySQL")
}
} }
} }

View File

@ -1,9 +1,11 @@
use std::collections::HashMap;
use anyhow::Context; use anyhow::Context;
use indoc::indoc; use indoc::indoc;
use itertools::Itertools; use itertools::Itertools;
use nix::unistd::User; use nix::unistd::User;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sqlx::{prelude::*, MySqlConnection}; use sqlx::{mysql::MySqlRow, prelude::*, MySqlConnection};
use super::common::{ use super::common::{
get_current_unix_user, get_unix_groups, quote_identifier, validate_prefix_for_user, get_current_unix_user, get_unix_groups, quote_identifier, validate_prefix_for_user,
@ -82,38 +84,137 @@ pub async fn get_database_list(conn: &mut MySqlConnection) -> anyhow::Result<Vec
Ok(databases.into_iter().map(|d| d.database).collect()) Ok(databases.into_iter().map(|d| d.database).collect())
} }
#[derive(Debug, Clone, FromRow, Serialize, Deserialize)] pub const DATABASE_PRIVILEGE_FIELDS: [&str; 13] = [
"db",
"user",
"select_priv",
"insert_priv",
"update_priv",
"delete_priv",
"create_priv",
"drop_priv",
"alter_priv",
"index_priv",
"create_tmp_table_priv",
"lock_tables_priv",
"references_priv",
];
pub fn db_priv_field_human_readable_name(name: &str) -> String {
match name {
"db" => "Database".to_owned(),
"user" => "User".to_owned(),
"select_priv" => "Select".to_owned(),
"insert_priv" => "Insert".to_owned(),
"update_priv" => "Update".to_owned(),
"delete_priv" => "Delete".to_owned(),
"create_priv" => "Create".to_owned(),
"drop_priv" => "Drop".to_owned(),
"alter_priv" => "Alter".to_owned(),
"index_priv" => "Index".to_owned(),
"create_tmp_table_priv" => "Temp".to_owned(),
"lock_tables_priv" => "Lock".to_owned(),
"references_priv" => "References".to_owned(),
_ => format!("Unknown({})", name),
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DatabasePrivileges { pub struct DatabasePrivileges {
pub db: String, pub db: String,
pub user: String, pub user: String,
pub select_priv: String, pub select_priv: bool,
pub insert_priv: String, pub insert_priv: bool,
pub update_priv: String, pub update_priv: bool,
pub delete_priv: String, pub delete_priv: bool,
pub create_priv: String, pub create_priv: bool,
pub drop_priv: String, pub drop_priv: bool,
pub alter_priv: String, pub alter_priv: bool,
pub index_priv: String, pub index_priv: bool,
pub create_tmp_table_priv: String, pub create_tmp_table_priv: bool,
pub lock_tables_priv: String, pub lock_tables_priv: bool,
pub references_priv: String, pub references_priv: bool,
} }
pub const HUMAN_READABLE_DATABASE_PRIVILEGE_NAMES: [(&str, &str); 13] = [ impl DatabasePrivileges {
("Database", "db"), pub fn get_privilege_by_name(&self, name: &str) -> bool {
("User", "user"), match name {
("Select", "select_priv"), "select_priv" => self.select_priv,
("Insert", "insert_priv"), "insert_priv" => self.insert_priv,
("Update", "update_priv"), "update_priv" => self.update_priv,
("Delete", "delete_priv"), "delete_priv" => self.delete_priv,
("Create", "create_priv"), "create_priv" => self.create_priv,
("Drop", "drop_priv"), "drop_priv" => self.drop_priv,
("Alter", "alter_priv"), "alter_priv" => self.alter_priv,
("Index", "index_priv"), "index_priv" => self.index_priv,
("Temp", "create_tmp_table_priv"), "create_tmp_table_priv" => self.create_tmp_table_priv,
("Lock", "lock_tables_priv"), "lock_tables_priv" => self.lock_tables_priv,
("References", "references_priv"), "references_priv" => self.references_priv,
]; _ => false,
}
}
pub fn diff(&self, other: &DatabasePrivileges) -> DatabasePrivilegeDiffList {
debug_assert!(self.db == other.db && self.user == other.user);
DatabasePrivilegeDiffList {
db: self.db.clone(),
user: self.user.clone(),
diff: DATABASE_PRIVILEGE_FIELDS
.into_iter()
.skip(2)
.map(|field| {
diff_single_priv(
self.get_privilege_by_name(field),
other.get_privilege_by_name(field),
field,
)
})
.flatten()
.collect(),
}
}
}
#[inline]
pub(crate) fn yn(b: bool) -> &'static str {
if b {
"Y"
} else {
"N"
}
}
#[inline]
pub(crate) fn rev_yn(s: &str) -> bool {
match s.to_lowercase().as_str() {
"y" => true,
"n" => false,
_ => {
log::warn!("Invalid value for privilege: {}", s);
false
}
}
}
impl FromRow<'_, MySqlRow> for DatabasePrivileges {
fn from_row(row: &MySqlRow) -> Result<Self, sqlx::Error> {
Ok(Self {
db: row.try_get("db")?,
user: row.try_get("user")?,
select_priv: row.try_get("select_priv").map(rev_yn)?,
insert_priv: row.try_get("insert_priv").map(rev_yn)?,
update_priv: row.try_get("update_priv").map(rev_yn)?,
delete_priv: row.try_get("delete_priv").map(rev_yn)?,
create_priv: row.try_get("create_priv").map(rev_yn)?,
drop_priv: row.try_get("drop_priv").map(rev_yn)?,
alter_priv: row.try_get("alter_priv").map(rev_yn)?,
index_priv: row.try_get("index_priv").map(rev_yn)?,
create_tmp_table_priv: row.try_get("create_tmp_table_priv").map(rev_yn)?,
lock_tables_priv: row.try_get("lock_tables_priv").map(rev_yn)?,
references_priv: row.try_get("references_priv").map(rev_yn)?,
})
}
}
pub async fn get_database_privileges( pub async fn get_database_privileges(
database_name: &str, database_name: &str,
@ -124,9 +225,9 @@ pub async fn get_database_privileges(
let result = sqlx::query_as::<_, DatabasePrivileges>(&format!( let result = sqlx::query_as::<_, DatabasePrivileges>(&format!(
"SELECT {} FROM `db` WHERE `db` = ?", "SELECT {} FROM `db` WHERE `db` = ?",
HUMAN_READABLE_DATABASE_PRIVILEGE_NAMES DATABASE_PRIVILEGE_FIELDS
.iter() .iter()
.map(|(_, prop)| quote_identifier(prop)) .map(|field| quote_identifier(field))
.join(","), .join(","),
)) ))
.bind(database_name) .bind(database_name)
@ -154,9 +255,9 @@ pub async fn get_all_database_privileges(
WHERE `SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys') WHERE `SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys')
AND `SCHEMA_NAME` REGEXP ?) AND `SCHEMA_NAME` REGEXP ?)
"#}, "#},
HUMAN_READABLE_DATABASE_PRIVILEGE_NAMES DATABASE_PRIVILEGE_FIELDS
.iter() .iter()
.map(|(_, prop)| format!("`{}`", prop)) .map(|field| format!("`{field}`"))
.join(","), .join(","),
)) ))
.bind(format!( .bind(format!(
@ -167,9 +268,141 @@ pub async fn get_all_database_privileges(
.fetch_all(conn) .fetch_all(conn)
.await .await
.context("Failed to show databases")?; .context("Failed to show databases")?;
Ok(result) Ok(result)
} }
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DatabasePrivilegeDiffList {
pub db: String,
pub user: String,
pub diff: Vec<DatabasePrivilegeDiff>,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum DatabasePrivilegeDiff {
YesToNo(String),
NoToYes(String),
}
fn diff_single_priv(p1: bool, p2: bool, name: &str) -> Option<DatabasePrivilegeDiff> {
match (p1, p2) {
(true, false) => Some(DatabasePrivilegeDiff::YesToNo(name.to_owned())),
(false, true) => Some(DatabasePrivilegeDiff::NoToYes(name.to_owned())),
_ => None,
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum DatabasePrivilegesDiff {
New(DatabasePrivileges),
Modified(DatabasePrivilegeDiffList),
Deleted(DatabasePrivileges),
}
pub async fn diff_permissions(
from: Vec<DatabasePrivileges>,
to: &[DatabasePrivileges],
) -> Vec<DatabasePrivilegesDiff> {
let from_lookup_table: HashMap<(String, String), DatabasePrivileges> = HashMap::from_iter(
from.iter()
.cloned()
.map(|p| ((p.db.clone(), p.user.clone()), p)),
);
let to_lookup_table: HashMap<(String, String), DatabasePrivileges> = HashMap::from_iter(
to.iter()
.cloned()
.map(|p| ((p.db.clone(), p.user.clone()), p)),
);
let mut result = vec![];
for p in to {
if let Some(old_p) = from_lookup_table.get(&(p.db.clone(), p.user.clone())) {
let diff = old_p.diff(p);
if !diff.diff.is_empty() {
result.push(DatabasePrivilegesDiff::Modified(diff));
}
} else {
result.push(DatabasePrivilegesDiff::New(p.clone()));
}
}
for p in from {
if !to_lookup_table.contains_key(&(p.db.clone(), p.user.clone())) {
result.push(DatabasePrivilegesDiff::Deleted(p));
}
}
result
}
pub async fn apply_permission_diffs(
diffs: Vec<DatabasePrivilegesDiff>,
conn: &mut MySqlConnection,
) -> anyhow::Result<()> {
for diff in diffs {
match diff {
DatabasePrivilegesDiff::New(p) => {
let tables = DATABASE_PRIVILEGE_FIELDS
.iter()
.map(|field| format!("`{field}`"))
.join(",");
let question_marks = std::iter::repeat("?")
.take(DATABASE_PRIVILEGE_FIELDS.len())
.join(",");
sqlx::query(
format!("INSERT INTO `db` ({}) VALUES ({})", tables, question_marks).as_str(),
)
.bind(p.db)
.bind(p.user)
.bind(yn(p.select_priv))
.bind(yn(p.insert_priv))
.bind(yn(p.update_priv))
.bind(yn(p.delete_priv))
.bind(yn(p.create_priv))
.bind(yn(p.drop_priv))
.bind(yn(p.alter_priv))
.bind(yn(p.index_priv))
.bind(yn(p.create_tmp_table_priv))
.bind(yn(p.lock_tables_priv))
.bind(yn(p.references_priv))
.execute(&mut *conn)
.await?;
}
DatabasePrivilegesDiff::Modified(p) => {
let tables = p
.diff
.iter()
.map(|diff| match diff {
DatabasePrivilegeDiff::YesToNo(name) => format!("`{}` = 'N'", name),
DatabasePrivilegeDiff::NoToYes(name) => format!("`{}` = 'Y'", name),
})
.join(",");
sqlx::query(
format!("UPDATE `db` SET {} WHERE `db` = ? AND `user` = ?", tables).as_str(),
)
.bind(p.db)
.bind(p.user)
.execute(&mut *conn)
.await?;
}
DatabasePrivilegesDiff::Deleted(p) => {
sqlx::query("DELETE FROM `db` WHERE `db` = ? AND `user` = ?")
.bind(p.db)
.bind(p.user)
.execute(&mut *conn)
.await?;
}
}
}
Ok(())
}
/// NOTE: It is very critical that this function validates the database name /// NOTE: It is very critical that this function validates the database name
/// properly. MySQL does not seem to allow for prepared statements, binding /// properly. MySQL does not seem to allow for prepared statements, binding
/// the database name as a parameter to the query. This means that we have /// the database name as a parameter to the query. This means that we have
@ -199,3 +432,72 @@ pub fn validate_ownership_of_database_name(name: &str, user: &User) -> anyhow::R
Ok(()) Ok(())
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_diff_single_priv() {
assert_eq!(
diff_single_priv(true, false, "test"),
Some(DatabasePrivilegeDiff::YesToNo("test".to_owned()))
);
assert_eq!(
diff_single_priv(false, true, "test"),
Some(DatabasePrivilegeDiff::NoToYes("test".to_owned()))
);
assert_eq!(diff_single_priv(true, true, "test"), None);
assert_eq!(diff_single_priv(false, false, "test"), None);
}
#[tokio::test]
async fn test_diff_permissions() {
let from = vec![DatabasePrivileges {
db: "db".to_owned(),
user: "user".to_owned(),
select_priv: true,
insert_priv: true,
update_priv: true,
delete_priv: true,
create_priv: true,
drop_priv: true,
alter_priv: true,
index_priv: true,
create_tmp_table_priv: true,
lock_tables_priv: true,
references_priv: true,
}];
let to = vec![DatabasePrivileges {
db: "db".to_owned(),
user: "user".to_owned(),
select_priv: false,
insert_priv: true,
update_priv: true,
delete_priv: true,
create_priv: true,
drop_priv: true,
alter_priv: true,
index_priv: true,
create_tmp_table_priv: true,
lock_tables_priv: true,
references_priv: true,
}];
let diffs = diff_permissions(from, &to).await;
assert_eq!(
diffs,
vec![DatabasePrivilegesDiff::Modified(
DatabasePrivilegeDiffList {
db: "db".to_owned(),
user: "user".to_owned(),
diff: vec![DatabasePrivilegeDiff::YesToNo("select_priv".to_owned())],
}
)]
);
assert!(matches!(&diffs[0], DatabasePrivilegesDiff::Modified(_)));
}
}

View File

@ -52,4 +52,4 @@ async fn main() -> anyhow::Result<()> {
} }
Command::User(user_args) => cli::user_command::handle_command(user_args, connection).await, Command::User(user_args) => cli::user_command::handle_command(user_args, connection).await,
} }
} }

View File

@ -0,0 +1 @@