Refactor privilege handling
All checks were successful
Build / check (push) Successful in 2m41s
Build / build (push) Successful in 3m5s
Build / docs (push) Successful in 5m37s

This commit is contained in:
2025-11-14 00:49:29 +09:00
parent 7760b001d8
commit 03a761a0ff
12 changed files with 1567 additions and 886 deletions

View File

@@ -18,13 +18,15 @@ use std::collections::{BTreeMap, BTreeSet};
use indoc::indoc;
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use sqlx::{MySqlConnection, mysql::MySqlRow, prelude::*};
use crate::{
core::{
common::{UnixUser, rev_yn, yn},
database_privileges::{DatabasePrivilegeChange, DatabasePrivilegesDiff},
database_privileges::{
DATABASE_PRIVILEGE_FIELDS, DatabasePrivilegeChange, DatabasePrivilegeRow,
DatabasePrivilegesDiff,
},
protocol::{
DiffDoesNotApplyError, GetAllDatabasesPrivilegeData, GetAllDatabasesPrivilegeDataError,
GetDatabasesPrivilegeData, GetDatabasesPrivilegeDataError,
@@ -39,65 +41,6 @@ use crate::{
},
};
/// This is the list of fields that are used to fetch the db + user + privileges
/// from the `db` table in the database. If you need to add or remove privilege
/// fields, this is a good place to start.
pub const DATABASE_PRIVILEGE_FIELDS: [&str; 13] = [
"Db",
"User",
"select_priv",
"insert_priv",
"update_priv",
"delete_priv",
"create_priv",
"drop_priv",
"alter_priv",
"index_priv",
"create_tmp_table_priv",
"lock_tables_priv",
"references_priv",
];
// NOTE: ord is needed for BTreeSet to accept the type, but it
// doesn't have any natural implementation semantics.
/// This struct represents the set of privileges for a single user on a single database.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
pub struct DatabasePrivilegeRow {
pub db: MySQLDatabase,
pub user: MySQLUser,
pub select_priv: bool,
pub insert_priv: bool,
pub update_priv: bool,
pub delete_priv: bool,
pub create_priv: bool,
pub drop_priv: bool,
pub alter_priv: bool,
pub index_priv: bool,
pub create_tmp_table_priv: bool,
pub lock_tables_priv: bool,
pub references_priv: bool,
}
impl DatabasePrivilegeRow {
pub fn get_privilege_by_name(&self, name: &str) -> bool {
match name {
"select_priv" => self.select_priv,
"insert_priv" => self.insert_priv,
"update_priv" => self.update_priv,
"delete_priv" => self.delete_priv,
"create_priv" => self.create_priv,
"drop_priv" => self.drop_priv,
"alter_priv" => self.alter_priv,
"index_priv" => self.index_priv,
"create_tmp_table_priv" => self.create_tmp_table_priv,
"lock_tables_priv" => self.lock_tables_priv,
"references_priv" => self.references_priv,
_ => false,
}
}
}
// TODO: get by name instead of row tuple position
#[inline]
@@ -304,22 +247,39 @@ async fn unsafe_apply_privilege_diff(
.map(|_| ())
}
DatabasePrivilegesDiff::Modified(p) => {
let changes = p
.diff
let changes = DATABASE_PRIVILEGE_FIELDS
.iter()
.map(|diff| match diff {
DatabasePrivilegeChange::YesToNo(name) => {
format!("{} = 'N'", quote_identifier(name))
}
DatabasePrivilegeChange::NoToYes(name) => {
format!("{} = 'Y'", quote_identifier(name))
}
.skip(2) // Skip Db and User fields
.map(|field| {
format!(
"{} = COALESCE(?, {})",
quote_identifier(field),
quote_identifier(field)
)
})
.join(",");
fn change_to_yn(change: DatabasePrivilegeChange) -> &'static str {
match change {
DatabasePrivilegeChange::YesToNo => "N",
DatabasePrivilegeChange::NoToYes => "Y",
}
}
sqlx::query(
format!("UPDATE `db` SET {} WHERE `Db` = ? AND `User` = ?", changes).as_str(),
)
.bind(p.select_priv.map(change_to_yn))
.bind(p.insert_priv.map(change_to_yn))
.bind(p.update_priv.map(change_to_yn))
.bind(p.delete_priv.map(change_to_yn))
.bind(p.create_priv.map(change_to_yn))
.bind(p.drop_priv.map(change_to_yn))
.bind(p.alter_priv.map(change_to_yn))
.bind(p.index_priv.map(change_to_yn))
.bind(p.create_tmp_table_priv.map(change_to_yn))
.bind(p.lock_tables_priv.map(change_to_yn))
.bind(p.references_priv.map(change_to_yn))
.bind(p.db.to_string())
.bind(p.user.to_string())
.execute(connection)
@@ -334,6 +294,7 @@ async fn unsafe_apply_privilege_diff(
.await
.map(|_| ())
}
DatabasePrivilegesDiff::Noop { .. } => Ok(()),
};
if let Err(e) = &result {
@@ -359,7 +320,7 @@ async fn validate_diff(
Err(e) => return Err(ModifyDatabasePrivilegesError::MySqlError(e.to_string())),
};
let result = match diff {
match diff {
DatabasePrivilegesDiff::New(_) => {
if privilege_row.is_some() {
Err(ModifyDatabasePrivilegesError::DiffDoesNotApply(
@@ -383,10 +344,20 @@ async fn validate_diff(
DatabasePrivilegesDiff::Modified(row_diff) => {
let row = privilege_row.unwrap();
let error_exists = row_diff.diff.iter().any(|change| match change {
DatabasePrivilegeChange::YesToNo(name) => !row.get_privilege_by_name(name),
DatabasePrivilegeChange::NoToYes(name) => row.get_privilege_by_name(name),
});
let error_exists = DATABASE_PRIVILEGE_FIELDS
.iter()
.skip(2) // Skip Db and User fields
.any(
|field| match row_diff.get_privilege_change_by_name(field).unwrap() {
Some(DatabasePrivilegeChange::YesToNo) => {
!row.get_privilege_by_name(field).unwrap()
}
Some(DatabasePrivilegeChange::NoToYes) => {
row.get_privilege_by_name(field).unwrap()
}
None => false,
},
);
if error_exists {
Err(ModifyDatabasePrivilegesError::DiffDoesNotApply(
@@ -408,9 +379,13 @@ async fn validate_diff(
Ok(())
}
}
};
result
DatabasePrivilegesDiff::Noop { .. } => {
log::warn!(
"Server got sent a noop database privilege diff to validate, is the client buggy?"
);
Ok(())
}
}
}
/// Uses the result of [`diff_privileges`] to modify privileges in the database.