From e56c41cee6a07db5d6a47f839904440f7a22235c Mon Sep 17 00:00:00 2001 From: h7x4 Date: Sun, 14 Dec 2025 01:58:48 +0900 Subject: [PATCH] {client,server}/edit-privs: check for user existence --- src/client/commands/edit_privs.rs | 88 ++++++++++++++++++- .../protocol/commands/modify_privileges.rs | 7 ++ .../sql/database_privilege_operations.rs | 12 +++ src/server/sql/user_operations.rs | 2 +- 4 files changed, 105 insertions(+), 4 deletions(-) diff --git a/src/client/commands/edit_privs.rs b/src/client/commands/edit_privs.rs index 29c213b..46d1020 100644 --- a/src/client/commands/edit_privs.rs +++ b/src/client/commands/edit_privs.rs @@ -1,4 +1,4 @@ -use std::collections::BTreeSet; +use std::collections::{BTreeMap, BTreeSet}; use anyhow::Context; use clap::Parser; @@ -22,7 +22,7 @@ use crate::{ ClientToServerMessageStream, Request, Response, print_modify_database_privileges_output_status, }, - types::MySQLDatabase, + types::{MySQLDatabase, MySQLUser}, }, }; @@ -59,6 +59,64 @@ pub struct EditPrivsArgs { pub yes: bool, } +async fn users_exist( + server_connection: &mut ClientToServerMessageStream, + privilege_diff: &BTreeSet, +) -> anyhow::Result> { + let user_list = privilege_diff + .iter() + .map(|diff| diff.get_user_name().clone()) + .collect(); + + let message = Request::ListUsers(Some(user_list)); + server_connection.send(message).await?; + + let result = match server_connection.next().await { + Some(Ok(Response::ListUsers(user_map))) => user_map, + response => { + erroneous_server_response(response)?; + // Unreachable, but needed to satisfy the type checker + BTreeMap::new() + } + }; + + let result = result + .into_iter() + .map(|(user, user_result)| (user, user_result.is_ok())) + .collect(); + + Ok(result) +} + +async fn databases_exist( + server_connection: &mut ClientToServerMessageStream, + privilege_diff: &BTreeSet, +) -> anyhow::Result> { + let database_list = privilege_diff + .iter() + .map(|diff| diff.get_database_name().clone()) + .collect(); + + let message = Request::ListDatabases(Some(database_list)); + server_connection.send(message).await?; + + let result = match server_connection.next().await { + Some(Ok(Response::ListDatabases(database_map))) => database_map, + response => { + erroneous_server_response(response)?; + // Unreachable, but needed to satisfy the type checker + BTreeMap::new() + } + }; + + let result = result + .into_iter() + .map(|(database, db_result)| (database, db_result.is_ok())) + .collect(); + + Ok(result) +} + pub async fn edit_database_privileges( args: EditPrivsArgs, mut server_connection: ClientToServerMessageStream, @@ -100,7 +158,31 @@ pub async fn edit_database_privileges( edit_privileges_with_editor(&existing_privilege_rows, args.name.as_ref())?; diff_privileges(&existing_privilege_rows, &privileges_to_change) }; - let diffs = reduce_privilege_diffs(&existing_privilege_rows, diffs)?; + + let user_existence_map = users_exist(&mut server_connection, &diffs).await?; + let database_existence_map = databases_exist(&mut server_connection, &diffs).await?; + + let diffs = reduce_privilege_diffs(&existing_privilege_rows, diffs)? + .into_iter() + .filter(|diff| { + let database_name = diff.get_database_name(); + let username = diff.get_user_name(); + + if let Some(false) = database_existence_map.get(database_name) { + println!("Database '{}' does not exist.", database_name); + println!("Skipping..."); + return false; + } + + if let Some(false) = user_existence_map.get(username) { + println!("User '{}' does not exist.", username); + println!("Skipping..."); + return false; + } + + true + }) + .collect::>(); if diffs.is_empty() { println!("No changes to make."); diff --git a/src/core/protocol/commands/modify_privileges.rs b/src/core/protocol/commands/modify_privileges.rs index 2f5055f..7d9312a 100644 --- a/src/core/protocol/commands/modify_privileges.rs +++ b/src/core/protocol/commands/modify_privileges.rs @@ -20,6 +20,7 @@ pub enum ModifyDatabasePrivilegesError { UserSanitizationError(NameValidationError), UserOwnershipError(OwnerValidationError), DatabaseDoesNotExist, + UserDoesNotExist, DiffDoesNotApply(DiffDoesNotApplyError), MySqlError(String), } @@ -68,6 +69,9 @@ impl ModifyDatabasePrivilegesError { ModifyDatabasePrivilegesError::DatabaseDoesNotExist => { format!("Database '{}' does not exist.", database_name) } + ModifyDatabasePrivilegesError::UserDoesNotExist => { + format!("User '{}' does not exist.", username) + } ModifyDatabasePrivilegesError::DiffDoesNotApply(diff) => { format!( "Could not apply privilege change:\n{}", @@ -98,6 +102,9 @@ impl ModifyDatabasePrivilegesError { ModifyDatabasePrivilegesError::DatabaseDoesNotExist => { "database-does-not-exist".to_string() } + ModifyDatabasePrivilegesError::UserDoesNotExist => { + "user-does-not-exist".to_string() + } ModifyDatabasePrivilegesError::DiffDoesNotApply(err) => { format!("diff-does-not-apply/{}", err.error_type()) } diff --git a/src/server/sql/database_privilege_operations.rs b/src/server/sql/database_privilege_operations.rs index 80d5ac3..cc5b444 100644 --- a/src/server/sql/database_privilege_operations.rs +++ b/src/server/sql/database_privilege_operations.rs @@ -38,6 +38,7 @@ use crate::{ common::{create_user_group_matching_regex, try_get_with_binary_fallback}, input_sanitization::{quote_identifier, validate_name, validate_ownership_by_unix_user}, sql::database_operations::unsafe_database_exists, + sql::user_operations::unsafe_user_exists, }, }; @@ -446,6 +447,17 @@ pub async fn apply_privilege_diffs( continue; } + if !unsafe_user_exists(diff.get_user_name(), connection) + .await + .unwrap() + { + results.insert( + key, + Err(ModifyDatabasePrivilegesError::UserDoesNotExist), + ); + continue; + } + if let Err(err) = validate_diff(&diff, connection).await { results.insert(key, Err(err)); continue; diff --git a/src/server/sql/user_operations.rs b/src/server/sql/user_operations.rs index 8262d9c..584f96a 100644 --- a/src/server/sql/user_operations.rs +++ b/src/server/sql/user_operations.rs @@ -26,7 +26,7 @@ use crate::{ }; // NOTE: this function is unsafe because it does no input validation. -async fn unsafe_user_exists( +pub(super) async fn unsafe_user_exists( db_user: &str, connection: &mut MySqlConnection, ) -> Result {