Misc fixups to priv diff logic, add tests

This commit is contained in:
Oystein Kristoffer Tveit 2024-08-08 21:02:25 +02:00
parent 8a91e9a3d0
commit 7ee60dacdc
Signed by: oysteikt
GPG Key ID: 9F2F7D8250F35146
2 changed files with 195 additions and 89 deletions

View File

@ -323,7 +323,7 @@ pub async fn edit_privileges(
} }
} }
let diffs = diff_privileges(privilege_data, &privileges_to_change); let diffs = diff_privileges(&privilege_data, &privileges_to_change);
if diffs.is_empty() { if diffs.is_empty() {
println!("No changes to make."); println!("No changes to make.");

View File

@ -13,7 +13,7 @@
//! changes will be made when applying a set of changes //! changes will be made when applying a set of changes
//! to the list of database privileges. //! to the list of database privileges.
use std::collections::HashMap; use std::collections::{BTreeSet, HashMap};
use anyhow::{anyhow, Context}; use anyhow::{anyhow, Context};
use indoc::indoc; use indoc::indoc;
@ -67,7 +67,7 @@ pub fn db_priv_field_human_readable_name(name: &str) -> String {
} }
/// This struct represents the set of privileges for a single user on a single database. /// This struct represents the set of privileges for a single user on a single database.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
pub struct DatabasePrivilegeRow { pub struct DatabasePrivilegeRow {
pub db: String, pub db: String,
pub user: String, pub user: String,
@ -124,11 +124,13 @@ impl DatabasePrivilegeRow {
} }
#[inline] #[inline]
fn get_row_priv_field(row: &MySqlRow, field: &str) -> Result<bool, sqlx::Error> { fn get_mysql_row_priv_field(row: &MySqlRow, position: usize) -> Result<bool, sqlx::Error> {
match rev_yn(row.try_get(field)?) { let field = DATABASE_PRIVILEGE_FIELDS[position];
let value = row.try_get(position)?;
match rev_yn(value) {
Some(val) => Ok(val), Some(val) => Ok(val),
_ => { _ => {
log::warn!("Invalid value for privilege: {}", field); log::warn!(r#"Invalid value for privilege "{}": '{}'"#, field, value);
Ok(false) Ok(false)
} }
} }
@ -139,21 +141,22 @@ impl FromRow<'_, MySqlRow> for DatabasePrivilegeRow {
Ok(Self { Ok(Self {
db: row.try_get("db")?, db: row.try_get("db")?,
user: row.try_get("user")?, user: row.try_get("user")?,
select_priv: get_row_priv_field(row, "select_priv")?, select_priv: get_mysql_row_priv_field(row, 2)?,
insert_priv: get_row_priv_field(row, "insert_priv")?, insert_priv: get_mysql_row_priv_field(row, 3)?,
update_priv: get_row_priv_field(row, "update_priv")?, update_priv: get_mysql_row_priv_field(row, 4)?,
delete_priv: get_row_priv_field(row, "delete_priv")?, delete_priv: get_mysql_row_priv_field(row, 5)?,
create_priv: get_row_priv_field(row, "create_priv")?, create_priv: get_mysql_row_priv_field(row, 6)?,
drop_priv: get_row_priv_field(row, "drop_priv")?, drop_priv: get_mysql_row_priv_field(row, 7)?,
alter_priv: get_row_priv_field(row, "alter_priv")?, alter_priv: get_mysql_row_priv_field(row, 8)?,
index_priv: get_row_priv_field(row, "index_priv")?, index_priv: get_mysql_row_priv_field(row, 9)?,
create_tmp_table_priv: get_row_priv_field(row, "create_tmp_table_priv")?, create_tmp_table_priv: get_mysql_row_priv_field(row, 10)?,
lock_tables_priv: get_row_priv_field(row, "lock_tables_priv")?, lock_tables_priv: get_mysql_row_priv_field(row, 11)?,
references_priv: get_row_priv_field(row, "references_priv")?, references_priv: get_mysql_row_priv_field(row, 12)?,
}) })
} }
} }
/// Get all users + privileges for a single database.
pub async fn get_database_privileges( pub async fn get_database_privileges(
database_name: &str, database_name: &str,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
@ -176,6 +179,7 @@ pub async fn get_database_privileges(
Ok(result) Ok(result)
} }
/// Get all database + user + privileges pairs that are owned by the current user.
pub async fn get_all_database_privileges( pub async fn get_all_database_privileges(
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
) -> anyhow::Result<Vec<DatabasePrivilegeRow>> { ) -> anyhow::Result<Vec<DatabasePrivilegeRow>> {
@ -206,7 +210,7 @@ pub async fn get_all_database_privileges(
/* CLI INTERFACE PARSING */ /* CLI INTERFACE PARSING */
/*************************/ /*************************/
/// See documentation for `DatabaseCommand::EditDbPrivs`. /// See documentation for [`DatabaseCommand::EditDbPrivs`].
pub fn parse_privilege_table_cli_arg(arg: &str) -> anyhow::Result<DatabasePrivilegeRow> { pub fn parse_privilege_table_cli_arg(arg: &str) -> anyhow::Result<DatabasePrivilegeRow> {
let parts: Vec<&str> = arg.split(':').collect(); let parts: Vec<&str> = arg.split(':').collect();
if parts.len() != 3 { if parts.len() != 3 {
@ -270,16 +274,121 @@ pub fn parse_privilege_table_cli_arg(arg: &str) -> anyhow::Result<DatabasePrivil
/* EDITOR CONTENT PARSING/DISPLAY */ /* EDITOR CONTENT PARSING/DISPLAY */
/**********************************/ /**********************************/
// TODO: merge with `rev_yn` in `common.rs` #[inline]
fn parse_privilege(yn: &str, name: &str) -> anyhow::Result<bool> {
rev_yn(yn)
.ok_or_else(|| anyhow!("Expected Y or N, found {}", yn))
.context(format!("Could not parse {} privilege", name))
}
fn parse_privilege(yn: &str) -> anyhow::Result<bool> { #[derive(Debug)]
match yn.to_ascii_lowercase().as_str() { enum PrivilegeRowParseResult {
"y" => Ok(true), PrivilegeRow(DatabasePrivilegeRow),
"n" => Ok(false), ParserError(anyhow::Error),
_ => Err(anyhow!("Expected Y or N, found {}", yn)), TooFewFields(usize),
TooManyFields(usize),
Header,
Comment,
Empty,
} }
#[inline]
fn row_is_header(row: &str) -> bool {
row.split_ascii_whitespace()
.zip(DATABASE_PRIVILEGE_FIELDS.iter())
.map(|(field, priv_name)| (field, db_priv_field_human_readable_name(priv_name)))
.all(|(field, header_field)| field == header_field)
} }
/// Parse a single row of the privileges table from the editor.
fn parse_privilege_row_from_editor(row: &str) -> PrivilegeRowParseResult {
if row.starts_with('#') || row.starts_with("//") {
return PrivilegeRowParseResult::Comment;
}
if row.trim().is_empty() {
return PrivilegeRowParseResult::Empty;
}
let parts: Vec<&str> = row.trim().split_ascii_whitespace().collect();
match parts.len() {
n if (n < DATABASE_PRIVILEGE_FIELDS.len()) => {
return PrivilegeRowParseResult::TooFewFields(n)
}
n if (n > DATABASE_PRIVILEGE_FIELDS.len()) => {
return PrivilegeRowParseResult::TooManyFields(n)
}
_ => {}
}
if row_is_header(row) {
return PrivilegeRowParseResult::Header;
}
let row = DatabasePrivilegeRow {
db: (*parts.first().unwrap()).to_owned(),
user: (*parts.get(1).unwrap()).to_owned(),
select_priv: match parse_privilege(parts.get(2).unwrap(), DATABASE_PRIVILEGE_FIELDS[2]) {
Ok(p) => p,
Err(e) => return PrivilegeRowParseResult::ParserError(e),
},
insert_priv: match parse_privilege(parts.get(3).unwrap(), DATABASE_PRIVILEGE_FIELDS[3]) {
Ok(p) => p,
Err(e) => return PrivilegeRowParseResult::ParserError(e),
},
update_priv: match parse_privilege(parts.get(4).unwrap(), DATABASE_PRIVILEGE_FIELDS[4]) {
Ok(p) => p,
Err(e) => return PrivilegeRowParseResult::ParserError(e),
},
delete_priv: match parse_privilege(parts.get(5).unwrap(), DATABASE_PRIVILEGE_FIELDS[5]) {
Ok(p) => p,
Err(e) => return PrivilegeRowParseResult::ParserError(e),
},
create_priv: match parse_privilege(parts.get(6).unwrap(), DATABASE_PRIVILEGE_FIELDS[6]) {
Ok(p) => p,
Err(e) => return PrivilegeRowParseResult::ParserError(e),
},
drop_priv: match parse_privilege(parts.get(7).unwrap(), DATABASE_PRIVILEGE_FIELDS[7]) {
Ok(p) => p,
Err(e) => return PrivilegeRowParseResult::ParserError(e),
},
alter_priv: match parse_privilege(parts.get(8).unwrap(), DATABASE_PRIVILEGE_FIELDS[8]) {
Ok(p) => p,
Err(e) => return PrivilegeRowParseResult::ParserError(e),
},
index_priv: match parse_privilege(parts.get(9).unwrap(), DATABASE_PRIVILEGE_FIELDS[9]) {
Ok(p) => p,
Err(e) => return PrivilegeRowParseResult::ParserError(e),
},
create_tmp_table_priv: match parse_privilege(
parts.get(10).unwrap(),
DATABASE_PRIVILEGE_FIELDS[10],
) {
Ok(p) => p,
Err(e) => return PrivilegeRowParseResult::ParserError(e),
},
lock_tables_priv: match parse_privilege(
parts.get(11).unwrap(),
DATABASE_PRIVILEGE_FIELDS[11],
) {
Ok(p) => p,
Err(e) => return PrivilegeRowParseResult::ParserError(e),
},
references_priv: match parse_privilege(
parts.get(12).unwrap(),
DATABASE_PRIVILEGE_FIELDS[12],
) {
Ok(p) => p,
Err(e) => return PrivilegeRowParseResult::ParserError(e),
},
};
PrivilegeRowParseResult::PrivilegeRow(row)
}
// TODO: return better errors
pub fn parse_privilege_data_from_editor_content( pub fn parse_privilege_data_from_editor_content(
content: String, content: String,
) -> anyhow::Result<Vec<DatabasePrivilegeRow>> { ) -> anyhow::Result<Vec<DatabasePrivilegeRow>> {
@ -287,41 +396,25 @@ pub fn parse_privilege_data_from_editor_content(
.trim() .trim()
.split('\n') .split('\n')
.map(|line| line.trim()) .map(|line| line.trim())
.filter(|line| !(line.starts_with('#') || line.starts_with("//") || line == &"")) .map(parse_privilege_row_from_editor)
.skip(1) .map(|result| match result {
.map(|line| { PrivilegeRowParseResult::PrivilegeRow(row) => Ok(Some(row)),
let line_parts: Vec<&str> = line.trim().split_ascii_whitespace().collect(); PrivilegeRowParseResult::ParserError(e) => Err(e),
if line_parts.len() != DATABASE_PRIVILEGE_FIELDS.len() { PrivilegeRowParseResult::TooFewFields(n) => Err(anyhow!(
anyhow::bail!("") "Too few fields in line. Expected to find {} fields, found {}",
} DATABASE_PRIVILEGE_FIELDS.len(),
n
Ok(DatabasePrivilegeRow { )),
db: (*line_parts.first().unwrap()).to_owned(), PrivilegeRowParseResult::TooManyFields(n) => Err(anyhow!(
user: (*line_parts.get(1).unwrap()).to_owned(), "Too many fields in line. Expected to find {} fields, found {}",
select_priv: parse_privilege(line_parts.get(2).unwrap()) DATABASE_PRIVILEGE_FIELDS.len(),
.context("Could not parse SELECT privilege")?, n
insert_priv: parse_privilege(line_parts.get(3).unwrap()) )),
.context("Could not parse INSERT privilege")?, PrivilegeRowParseResult::Header => Ok(None),
update_priv: parse_privilege(line_parts.get(4).unwrap()) PrivilegeRowParseResult::Comment => Ok(None),
.context("Could not parse UPDATE privilege")?, PrivilegeRowParseResult::Empty => Ok(None),
delete_priv: parse_privilege(line_parts.get(5).unwrap())
.context("Could not parse DELETE privilege")?,
create_priv: parse_privilege(line_parts.get(6).unwrap())
.context("Could not parse CREATE privilege")?,
drop_priv: parse_privilege(line_parts.get(7).unwrap())
.context("Could not parse DROP privilege")?,
alter_priv: parse_privilege(line_parts.get(8).unwrap())
.context("Could not parse ALTER privilege")?,
index_priv: parse_privilege(line_parts.get(9).unwrap())
.context("Could not parse INDEX privilege")?,
create_tmp_table_priv: parse_privilege(line_parts.get(10).unwrap())
.context("Could not parse CREATE TEMPORARY TABLE privilege")?,
lock_tables_priv: parse_privilege(line_parts.get(11).unwrap())
.context("Could not parse LOCK TABLES privilege")?,
references_priv: parse_privilege(line_parts.get(12).unwrap())
.context("Could not parse REFERENCES privilege")?,
})
}) })
.filter_map(|result| result.transpose())
.collect::<anyhow::Result<Vec<DatabasePrivilegeRow>>>() .collect::<anyhow::Result<Vec<DatabasePrivilegeRow>>>()
} }
@ -440,15 +533,15 @@ pub fn generate_editor_content_from_privilege_data(
/// instances of privilege sets for a single user on a single database. /// instances of privilege sets for a single user on a single database.
/// ///
/// The `User` and `Database` are the same for both instances. /// The `User` and `Database` are the same for both instances.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Eq, Serialize, PartialOrd, Ord)]
pub struct DatabasePrivilegeRowDiff { pub struct DatabasePrivilegeRowDiff {
pub db: String, pub db: String,
pub user: String, pub user: String,
pub diff: Vec<DatabasePrivilegeChange>, pub diff: BTreeSet<DatabasePrivilegeChange>,
} }
/// This enum represents a change for a single privilege. /// This enum represents a change for a single privilege.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Eq, Serialize, PartialOrd, Ord)]
pub enum DatabasePrivilegeChange { pub enum DatabasePrivilegeChange {
YesToNo(String), YesToNo(String),
NoToYes(String), NoToYes(String),
@ -465,17 +558,18 @@ impl DatabasePrivilegeChange {
} }
/// This enum encapsulates whether a [`DatabasePrivilegeRow`] was intrduced, modified or deleted. /// This enum encapsulates whether a [`DatabasePrivilegeRow`] was intrduced, modified or deleted.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Eq, Serialize, PartialOrd, Ord)]
pub enum DatabasePrivilegesDiff { pub enum DatabasePrivilegesDiff {
New(DatabasePrivilegeRow), New(DatabasePrivilegeRow),
Modified(DatabasePrivilegeRowDiff), Modified(DatabasePrivilegeRowDiff),
Deleted(DatabasePrivilegeRow), Deleted(DatabasePrivilegeRow),
} }
/// T
pub fn diff_privileges( pub fn diff_privileges(
from: Vec<DatabasePrivilegeRow>, from: &[DatabasePrivilegeRow],
to: &[DatabasePrivilegeRow], to: &[DatabasePrivilegeRow],
) -> Vec<DatabasePrivilegesDiff> { ) -> BTreeSet<DatabasePrivilegesDiff> {
let from_lookup_table: HashMap<(String, String), DatabasePrivilegeRow> = HashMap::from_iter( let from_lookup_table: HashMap<(String, String), DatabasePrivilegeRow> = HashMap::from_iter(
from.iter() from.iter()
.cloned() .cloned()
@ -488,22 +582,22 @@ pub fn diff_privileges(
.map(|p| ((p.db.clone(), p.user.clone()), p)), .map(|p| ((p.db.clone(), p.user.clone()), p)),
); );
let mut result = vec![]; let mut result = BTreeSet::new();
for p in to { for p in to {
if let Some(old_p) = from_lookup_table.get(&(p.db.clone(), p.user.clone())) { if let Some(old_p) = from_lookup_table.get(&(p.db.clone(), p.user.clone())) {
let diff = old_p.diff(p); let diff = old_p.diff(p);
if !diff.diff.is_empty() { if !diff.diff.is_empty() {
result.push(DatabasePrivilegesDiff::Modified(diff)); result.insert(DatabasePrivilegesDiff::Modified(diff));
} }
} else { } else {
result.push(DatabasePrivilegesDiff::New(p.clone())); result.insert(DatabasePrivilegesDiff::New(p.clone()));
} }
} }
for p in from { for p in from {
if !to_lookup_table.contains_key(&(p.db.clone(), p.user.clone())) { if !to_lookup_table.contains_key(&(p.db.clone(), p.user.clone())) {
result.push(DatabasePrivilegesDiff::Deleted(p)); result.insert(DatabasePrivilegesDiff::Deleted(p.clone()));
} }
} }
@ -512,7 +606,7 @@ pub fn diff_privileges(
/// Uses the resulting diffs to make modifications to the database. /// Uses the resulting diffs to make modifications to the database.
pub async fn apply_privilege_diffs( pub async fn apply_privilege_diffs(
diffs: Vec<DatabasePrivilegesDiff>, diffs: BTreeSet<DatabasePrivilegesDiff>,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
for diff in diffs { for diff in diffs {
@ -600,7 +694,7 @@ mod tests {
#[test] #[test]
fn test_diff_privileges() { fn test_diff_privileges() {
let from = vec![DatabasePrivilegeRow { let row_to_be_modified = DatabasePrivilegeRow {
db: "db".to_owned(), db: "db".to_owned(),
user: "user".to_owned(), user: "user".to_owned(),
select_priv: true, select_priv: true,
@ -614,29 +708,41 @@ mod tests {
create_tmp_table_priv: true, create_tmp_table_priv: true,
lock_tables_priv: true, lock_tables_priv: true,
references_priv: false, references_priv: false,
}]; };
let mut to = from.clone(); let mut row_to_be_deleted = row_to_be_modified.clone();
to[0].select_priv = false; "user2".clone_into(&mut row_to_be_deleted.user);
to[0].insert_priv = false;
to[0].index_priv = true;
let diffs = diff_privileges(from, &to); let from = vec![row_to_be_modified.clone(), row_to_be_deleted.clone()];
let mut modified_row = row_to_be_modified.clone();
modified_row.select_priv = false;
modified_row.insert_priv = false;
modified_row.index_priv = true;
let mut new_row = row_to_be_modified.clone();
"user3".clone_into(&mut new_row.user);
let to = vec![modified_row.clone(), new_row.clone()];
let diffs = diff_privileges(&from, &to);
assert_eq!( assert_eq!(
diffs, diffs,
vec![DatabasePrivilegesDiff::Modified(DatabasePrivilegeRowDiff { BTreeSet::from_iter(vec![
DatabasePrivilegesDiff::Deleted(row_to_be_deleted),
DatabasePrivilegesDiff::Modified(DatabasePrivilegeRowDiff {
db: "db".to_owned(), db: "db".to_owned(),
user: "user".to_owned(), user: "user".to_owned(),
diff: vec![ diff: BTreeSet::from_iter(vec![
DatabasePrivilegeChange::YesToNo("select_priv".to_owned()), DatabasePrivilegeChange::YesToNo("select_priv".to_owned()),
DatabasePrivilegeChange::YesToNo("insert_priv".to_owned()), DatabasePrivilegeChange::YesToNo("insert_priv".to_owned()),
DatabasePrivilegeChange::NoToYes("index_priv".to_owned()), DatabasePrivilegeChange::NoToYes("index_priv".to_owned()),
], ]),
})] }),
DatabasePrivilegesDiff::New(new_row),
])
); );
assert!(matches!(&diffs[0], DatabasePrivilegesDiff::Modified(_)));
} }
#[test] #[test]