diff --git a/src/core/protocol/commands.rs b/src/core/protocol/commands.rs index 5c0540f..e5bb77e 100644 --- a/src/core/protocol/commands.rs +++ b/src/core/protocol/commands.rs @@ -17,6 +17,8 @@ mod modify_privileges; mod passwd_user; mod unlock_users; +use std::collections::BTreeSet; + pub use check_authorization::*; pub use complete_database_name::*; pub use complete_user_name::*; @@ -41,6 +43,8 @@ use tokio::net::UnixStream; use tokio_serde::{Framed as SerdeFramed, formats::Bincode}; use tokio_util::codec::{Framed, LengthDelimitedCodec}; +use crate::core::types::{MySQLDatabase, MySQLUser}; + pub type ServerToClientMessageStream = SerdeFramed< Framed, Request, @@ -104,6 +108,85 @@ pub enum Request { Exit, } +impl Request { + pub fn command_name(&self) -> &str { + match self { + Request::CheckAuthorization(_) => "check-authorization", + Request::ListValidNamePrefixes => "list-valid-name-prefixes", + Request::CompleteDatabaseName(_) => "complete-database-name", + Request::CompleteUserName(_) => "complete-user-name", + Request::CreateDatabases(_) => "create-databases", + Request::DropDatabases(_) => "drop-databases", + Request::ListDatabases(_) => "list-databases", + Request::ListPrivileges(_) => "list-privileges", + Request::ModifyPrivileges(_) => "modify-privileges", + Request::CreateUsers(_) => "create-users", + Request::DropUsers(_) => "drop-users", + Request::PasswdUser(_) => "passwd-user", + Request::ListUsers(_) => "list-users", + Request::LockUsers(_) => "lock-users", + Request::UnlockUsers(_) => "unlock-users", + Request::Exit => "exit", + } + } + + pub fn affected_users(&self) -> BTreeSet { + match self { + Request::CheckAuthorization(_) => Default::default(), + Request::ListValidNamePrefixes => Default::default(), + Request::CompleteDatabaseName(_) => Default::default(), + Request::CompleteUserName(_) => Default::default(), + Request::CreateDatabases(_) => Default::default(), + Request::DropDatabases(_) => Default::default(), + Request::ListDatabases(_) => Default::default(), + Request::ListPrivileges(_) => Default::default(), + Request::ModifyPrivileges(priv_diffs) => priv_diffs + .iter() + .map(|priv_diff| priv_diff.get_user_name().clone()) + .collect(), + Request::CreateUsers(users) => users.iter().cloned().collect(), + Request::DropUsers(users) => users.iter().cloned().collect(), + Request::PasswdUser(user_passwd_req) => { + let mut result = BTreeSet::new(); + result.insert(user_passwd_req.0.clone()); + result + } + Request::ListUsers(users) => users.clone().unwrap_or_default().into_iter().collect(), + Request::LockUsers(users) => users.iter().cloned().collect(), + Request::UnlockUsers(users) => users.iter().cloned().collect(), + Request::Exit => Default::default(), + } + } + + pub fn affected_databases(&self) -> BTreeSet { + match self { + Request::CheckAuthorization(_) => Default::default(), + Request::ListValidNamePrefixes => Default::default(), + Request::CompleteDatabaseName(_) => Default::default(), + Request::CompleteUserName(_) => Default::default(), + Request::CreateDatabases(databases) => databases.iter().cloned().collect(), + Request::DropDatabases(databases) => databases.iter().cloned().collect(), + Request::ListDatabases(databases) => { + databases.clone().unwrap_or_default().into_iter().collect() + } + Request::ListPrivileges(databases) => { + databases.clone().unwrap_or_default().into_iter().collect() + } + Request::ModifyPrivileges(priv_diffs) => priv_diffs + .iter() + .map(|priv_diff| priv_diff.get_database_name().clone()) + .collect(), + Request::CreateUsers(_) => Default::default(), + Request::DropUsers(_) => Default::default(), + Request::PasswdUser(_) => Default::default(), + Request::ListUsers(_) => Default::default(), + Request::LockUsers(_) => Default::default(), + Request::UnlockUsers(_) => Default::default(), + Request::Exit => Default::default(), + } + } +} + // TODO: include a generic "message" that will display a message to the user? #[non_exhaustive] diff --git a/src/server/session_handler.rs b/src/server/session_handler.rs index 892ec29..a31f739 100644 --- a/src/server/session_handler.rs +++ b/src/server/session_handler.rs @@ -2,6 +2,7 @@ use std::{collections::BTreeSet, sync::Arc}; use futures_util::{SinkExt, StreamExt}; use indoc::concatdoc; +use itertools::Itertools; use sqlx::{MySqlConnection, MySqlPool}; use tokio::{net::UnixStream, sync::RwLock}; use tracing::Instrument; @@ -173,242 +174,286 @@ async fn session_handler_with_db_connection( } }; - match &request { - Request::Exit => tracing::debug!("Received request: {:#?}", request), - Request::PasswdUser((db_user, _)) => tracing::info!( - "Received request: {:#?}", - Request::PasswdUser((db_user.to_owned(), "".to_string())) - ), - request => tracing::info!("Received request: {:#?}", request), + let request_span = tracing::info_span!("request", command = request.command_name()); + + if !handle_request( + request, + unix_user, + db_connection, + db_is_mariadb, + group_denylist, + &mut stream, + ) + .instrument(request_span) + .await? + { + break; } - - let response = match request { - Request::CheckAuthorization(dbs_or_users) => { - let result = check_authorization(dbs_or_users, unix_user, group_denylist).await; - Response::CheckAuthorization(result) - } - Request::ListValidNamePrefixes => { - let mut result = Vec::with_capacity(unix_user.groups.len() + 1); - result.push(unix_user.username.clone()); - - for group in get_user_filtered_groups(unix_user, group_denylist) { - result.push(group.clone()); - } - - Response::ListValidNamePrefixes(result) - } - Request::CompleteDatabaseName(partial_database_name) => { - // TODO: more correct validation here - if partial_database_name - .chars() - .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-') - { - let result = complete_database_name( - partial_database_name, - unix_user, - db_connection, - db_is_mariadb, - group_denylist, - ) - .await; - Response::CompleteDatabaseName(result) - } else { - Response::CompleteDatabaseName(vec![]) - } - } - Request::CompleteUserName(partial_user_name) => { - // TODO: more correct validation here - if partial_user_name - .chars() - .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-') - { - let result = complete_user_name( - partial_user_name, - unix_user, - db_connection, - db_is_mariadb, - group_denylist, - ) - .await; - Response::CompleteUserName(result) - } else { - Response::CompleteUserName(vec![]) - } - } - Request::CreateDatabases(databases_names) => { - let result = create_databases( - databases_names, - unix_user, - db_connection, - db_is_mariadb, - group_denylist, - ) - .await; - Response::CreateDatabases(result) - } - Request::DropDatabases(databases_names) => { - let result = drop_databases( - databases_names, - unix_user, - db_connection, - db_is_mariadb, - group_denylist, - ) - .await; - Response::DropDatabases(result) - } - Request::ListDatabases(database_names) => { - if let Some(database_names) = database_names { - let result = list_databases( - database_names, - unix_user, - db_connection, - db_is_mariadb, - group_denylist, - ) - .await; - Response::ListDatabases(result) - } else { - let result = list_all_databases_for_user( - unix_user, - db_connection, - db_is_mariadb, - group_denylist, - ) - .await; - Response::ListAllDatabases(result) - } - } - Request::ListPrivileges(database_names) => { - if let Some(database_names) = database_names { - let privilege_data = get_databases_privilege_data( - database_names, - unix_user, - db_connection, - db_is_mariadb, - group_denylist, - ) - .await; - Response::ListPrivileges(privilege_data) - } else { - let privilege_data = get_all_database_privileges( - unix_user, - db_connection, - db_is_mariadb, - group_denylist, - ) - .await; - Response::ListAllPrivileges(privilege_data) - } - } - Request::ModifyPrivileges(database_privilege_diffs) => { - let result = apply_privilege_diffs( - BTreeSet::from_iter(database_privilege_diffs), - unix_user, - db_connection, - db_is_mariadb, - group_denylist, - ) - .await; - Response::ModifyPrivileges(result) - } - Request::CreateUsers(db_users) => { - let result = create_database_users( - db_users, - unix_user, - db_connection, - db_is_mariadb, - group_denylist, - ) - .await; - Response::CreateUsers(result) - } - Request::DropUsers(db_users) => { - let result = drop_database_users( - db_users, - unix_user, - db_connection, - db_is_mariadb, - group_denylist, - ) - .await; - Response::DropUsers(result) - } - Request::PasswdUser((db_user, password)) => { - let result = set_password_for_database_user( - &db_user, - &password, - unix_user, - db_connection, - db_is_mariadb, - group_denylist, - ) - .await; - Response::SetUserPassword(result) - } - Request::ListUsers(db_users) => { - if let Some(db_users) = db_users { - let result = list_database_users( - db_users, - unix_user, - db_connection, - db_is_mariadb, - group_denylist, - ) - .await; - Response::ListUsers(result) - } else { - let result = list_all_database_users_for_unix_user( - unix_user, - db_connection, - db_is_mariadb, - group_denylist, - ) - .await; - Response::ListAllUsers(result) - } - } - Request::LockUsers(db_users) => { - let result = lock_database_users( - db_users, - unix_user, - db_connection, - db_is_mariadb, - group_denylist, - ) - .await; - Response::LockUsers(result) - } - Request::UnlockUsers(db_users) => { - let result = unlock_database_users( - db_users, - unix_user, - db_connection, - db_is_mariadb, - group_denylist, - ) - .await; - Response::UnlockUsers(result) - } - Request::Exit => { - break; - } - }; - - let response_to_display = match &response { - Response::SetUserPassword(Err(SetPasswordError::MySqlError(_))) => { - &Response::SetUserPassword(Err(SetPasswordError::MySqlError( - "".to_string(), - ))) - } - response => response, - }; - tracing::debug!("Response: {:#?}", response_to_display); - - stream.send(response).await?; - stream.flush().await?; - tracing::debug!("Successfully processed request"); } Ok(()) } + +/// Handle a single request from a client. +/// +/// If the function returns `true`, the session should continue. +async fn handle_request( + request: Request, + unix_user: &UnixUser, + db_connection: &mut MySqlConnection, + db_is_mariadb: bool, + group_denylist: &GroupDenylist, + stream: &mut ServerToClientMessageStream, +) -> anyhow::Result { + match &request { + Request::Exit => tracing::debug!("Received request: {:#?}", request), + Request::PasswdUser((db_user, _)) => tracing::info!( + "Received request: {:#?}", + Request::PasswdUser((db_user.to_owned(), "".to_string())) + ), + request => tracing::info!("Received request: {:#?}", request), + } + + let affected_dbs = request.affected_databases(); + if !affected_dbs.is_empty() { + tracing::debug!( + "Affected databases: {}", + affected_dbs.into_iter().map(|db| db.to_string()).join(", ") + ); + } + + let affected_users = request.affected_users(); + if !affected_users.is_empty() { + tracing::debug!( + "Affected users: {}", + affected_users.into_iter().map(|u| u.to_string()).join(", "), + ); + } + + let response = match request { + Request::CheckAuthorization(dbs_or_users) => { + let result = check_authorization(dbs_or_users, unix_user, group_denylist).await; + Response::CheckAuthorization(result) + } + Request::ListValidNamePrefixes => { + let mut result = Vec::with_capacity(unix_user.groups.len() + 1); + result.push(unix_user.username.clone()); + + for group in get_user_filtered_groups(unix_user, group_denylist) { + result.push(group.clone()); + } + + Response::ListValidNamePrefixes(result) + } + Request::CompleteDatabaseName(partial_database_name) => { + // TODO: more correct validation here + if partial_database_name + .chars() + .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-') + { + let result = complete_database_name( + partial_database_name, + unix_user, + db_connection, + db_is_mariadb, + group_denylist, + ) + .await; + Response::CompleteDatabaseName(result) + } else { + Response::CompleteDatabaseName(vec![]) + } + } + Request::CompleteUserName(partial_user_name) => { + // TODO: more correct validation here + if partial_user_name + .chars() + .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-') + { + let result = complete_user_name( + partial_user_name, + unix_user, + db_connection, + db_is_mariadb, + group_denylist, + ) + .await; + Response::CompleteUserName(result) + } else { + Response::CompleteUserName(vec![]) + } + } + Request::CreateDatabases(databases_names) => { + let result = create_databases( + databases_names, + unix_user, + db_connection, + db_is_mariadb, + group_denylist, + ) + .await; + Response::CreateDatabases(result) + } + Request::DropDatabases(databases_names) => { + let result = drop_databases( + databases_names, + unix_user, + db_connection, + db_is_mariadb, + group_denylist, + ) + .await; + Response::DropDatabases(result) + } + Request::ListDatabases(database_names) => { + if let Some(database_names) = database_names { + let result = list_databases( + database_names, + unix_user, + db_connection, + db_is_mariadb, + group_denylist, + ) + .await; + Response::ListDatabases(result) + } else { + let result = list_all_databases_for_user( + unix_user, + db_connection, + db_is_mariadb, + group_denylist, + ) + .await; + Response::ListAllDatabases(result) + } + } + Request::ListPrivileges(database_names) => { + if let Some(database_names) = database_names { + let privilege_data = get_databases_privilege_data( + database_names, + unix_user, + db_connection, + db_is_mariadb, + group_denylist, + ) + .await; + Response::ListPrivileges(privilege_data) + } else { + let privilege_data = get_all_database_privileges( + unix_user, + db_connection, + db_is_mariadb, + group_denylist, + ) + .await; + Response::ListAllPrivileges(privilege_data) + } + } + Request::ModifyPrivileges(database_privilege_diffs) => { + let result = apply_privilege_diffs( + BTreeSet::from_iter(database_privilege_diffs), + unix_user, + db_connection, + db_is_mariadb, + group_denylist, + ) + .await; + Response::ModifyPrivileges(result) + } + Request::CreateUsers(db_users) => { + let result = create_database_users( + db_users, + unix_user, + db_connection, + db_is_mariadb, + group_denylist, + ) + .await; + Response::CreateUsers(result) + } + Request::DropUsers(db_users) => { + let result = drop_database_users( + db_users, + unix_user, + db_connection, + db_is_mariadb, + group_denylist, + ) + .await; + Response::DropUsers(result) + } + Request::PasswdUser((db_user, password)) => { + let result = set_password_for_database_user( + &db_user, + &password, + unix_user, + db_connection, + db_is_mariadb, + group_denylist, + ) + .await; + Response::SetUserPassword(result) + } + Request::ListUsers(db_users) => { + if let Some(db_users) = db_users { + let result = list_database_users( + db_users, + unix_user, + db_connection, + db_is_mariadb, + group_denylist, + ) + .await; + Response::ListUsers(result) + } else { + let result = list_all_database_users_for_unix_user( + unix_user, + db_connection, + db_is_mariadb, + group_denylist, + ) + .await; + Response::ListAllUsers(result) + } + } + Request::LockUsers(db_users) => { + let result = lock_database_users( + db_users, + unix_user, + db_connection, + db_is_mariadb, + group_denylist, + ) + .await; + Response::LockUsers(result) + } + Request::UnlockUsers(db_users) => { + let result = unlock_database_users( + db_users, + unix_user, + db_connection, + db_is_mariadb, + group_denylist, + ) + .await; + Response::UnlockUsers(result) + } + Request::Exit => { + return Ok(false); + } + }; + + let response_to_display = match &response { + Response::SetUserPassword(Err(SetPasswordError::MySqlError(_))) => { + &Response::SetUserPassword(Err(SetPasswordError::MySqlError("".to_string()))) + } + response => response, + }; + tracing::debug!("Response: {:#?}", response_to_display); + + stream.send(response).await?; + stream.flush().await?; + tracing::debug!("Successfully processed request"); + + Ok(true) +}