diff --git a/src/client/commands/edit_privs.rs b/src/client/commands/edit_privs.rs index ee689aa..400ff93 100644 --- a/src/client/commands/edit_privs.rs +++ b/src/client/commands/edit_privs.rs @@ -305,7 +305,7 @@ pub async fn edit_database_privileges( print_modify_database_privileges_output_status(&result); - if result.iter().any(|(_, res)| { + if result.values().flatten().any(|(_, res)| { matches!( res, Err(ModifyDatabasePrivilegesError::UserValidationError( @@ -320,7 +320,7 @@ pub async fn edit_database_privileges( server_connection.send(Request::Exit).await?; - if result.values().any(std::result::Result::is_err) { + if result.values().flatten().any(|(_, res)| res.is_err()) { std::process::exit(1); } diff --git a/src/core/protocol/commands.rs b/src/core/protocol/commands.rs index 35ff9a1..b76da3f 100644 --- a/src/core/protocol/commands.rs +++ b/src/core/protocol/commands.rs @@ -324,9 +324,12 @@ impl Response { ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count()) } Response::ListAllPrivileges(res) => ResponseOkStatus::from_bool(res.is_ok()), - Response::ModifyPrivileges(res) => { - ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count()) - } + Response::ModifyPrivileges(res) => ResponseOkStatus::from_counts( + res.len(), + res.values() + .map(|user_map| user_map.values().filter(|v| v.is_ok()).count()) + .sum(), + ), Response::CreateUsers(res) => { ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count()) diff --git a/src/core/protocol/commands/modify_privileges.rs b/src/core/protocol/commands/modify_privileges.rs index 35fbd11..4e7f3f1 100644 --- a/src/core/protocol/commands/modify_privileges.rs +++ b/src/core/protocol/commands/modify_privileges.rs @@ -12,7 +12,7 @@ use crate::core::{ pub type ModifyPrivilegesRequest = BTreeSet; pub type ModifyPrivilegesResponse = - BTreeMap<(MySQLDatabase, MySQLUser), Result<(), ModifyDatabasePrivilegesError>>; + BTreeMap>>; #[derive(Error, Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum ModifyDatabasePrivilegesError { @@ -49,7 +49,11 @@ pub enum DiffDoesNotApplyError { } pub fn print_modify_database_privileges_output_status(output: &ModifyPrivilegesResponse) { - for ((database_name, username), result) in output { + for ((database_name, username), result) in output.iter().flat_map(|(db, user_map)| { + user_map + .iter() + .map(move |(user, result)| ((db, user), result)) + }) { match result { Ok(()) => { println!( @@ -169,13 +173,16 @@ mod tests { #[test] fn test_serialize_deserialize_response() { - let response: ModifyPrivilegesResponse = BTreeMap::from([ - (("test_db".into(), "test_user".into()), Ok(())), - ( - ("test_db".into(), "invalid_user".into()), - Err(ModifyDatabasePrivilegesError::UserDoesNotExist), - ), - ]); + let response: ModifyPrivilegesResponse = BTreeMap::from([( + "test_db".into(), + BTreeMap::from([ + ("test_user".into(), Ok(())), + ( + "invalid_user".into(), + Err(ModifyDatabasePrivilegesError::UserDoesNotExist), + ), + ]), + )]); let json = serde_json::to_string_pretty(&response).unwrap(); println!("Serialized response:\n{}", json); diff --git a/src/server/sql/database_privilege_operations.rs b/src/server/sql/database_privilege_operations.rs index 0c4c8b7..7fb41f5 100644 --- a/src/server/sql/database_privilege_operations.rs +++ b/src/server/sql/database_privilege_operations.rs @@ -488,4 +488,13 @@ pub async fn apply_privilege_diffs( } results + .into_iter() + .map(|((k1, k2), v)| (k1, (k2, v))) + .into_group_map() + .into_iter() + .map(|(k1, pairs)| { + let inner = pairs.into_iter().collect::>(); + (k1, inner) + }) + .collect() }