diff --git a/src/client/commands.rs b/src/client/commands.rs index 8bc3d58..77b9f8e 100644 --- a/src/client/commands.rs +++ b/src/client/commands.rs @@ -1,3 +1,4 @@ +mod check_auth; mod create_db; mod create_user; mod drop_db; @@ -10,6 +11,7 @@ mod show_privs; mod show_user; mod unlock_user; +pub use check_auth::*; pub use create_db::*; pub use create_user::*; pub use drop_db::*; @@ -28,6 +30,10 @@ use crate::core::protocol::{ClientToServerMessageStream, Response}; #[derive(Parser, Debug, Clone)] pub enum ClientCommand { + /// Check whether you are authorized to manage the specified databases or users. + #[command()] + CheckAuth(CheckAuthArgs), + /// Create one or more databases #[command()] CreateDb(CreateDbArgs), @@ -141,6 +147,7 @@ pub async fn handle_command( server_connection: ClientToServerMessageStream, ) -> anyhow::Result<()> { match command { + ClientCommand::CheckAuth(args) => check_authorization(args, server_connection).await, ClientCommand::CreateDb(args) => create_databases(args, server_connection).await, ClientCommand::DropDb(args) => drop_databases(args, server_connection).await, ClientCommand::ShowDb(args) => show_databases(args, server_connection).await, diff --git a/src/client/commands/check_auth.rs b/src/client/commands/check_auth.rs new file mode 100644 index 0000000..b5a7f3e --- /dev/null +++ b/src/client/commands/check_auth.rs @@ -0,0 +1,67 @@ +use crate::{ + client::commands::erroneous_server_response, + core::{ + protocol::{ + ClientToServerMessageStream, Request, Response, + print_check_authorization_output_status, print_check_authorization_output_status_json, + }, + types::DbOrUser, + }, +}; +use clap::Parser; +use futures_util::SinkExt; +use tokio_stream::StreamExt; + +#[derive(Parser, Debug, Clone)] +pub struct CheckAuthArgs { + /// The name of the database(s) or user(s) to check authorization for + #[arg(num_args = 1..)] + name: Vec, + + /// Assume the names are users, not databases + #[arg(short, long)] + users: bool, + + /// Print the information as JSON + #[arg(short, long)] + json: bool, +} + +pub async fn check_authorization( + args: CheckAuthArgs, + mut server_connection: ClientToServerMessageStream, +) -> anyhow::Result<()> { + if args.name.is_empty() { + anyhow::bail!("No database/user names provided"); + } + + let payload = args + .name + .into_iter() + .map(|name| { + if args.users { + DbOrUser::User(name.into()) + } else { + DbOrUser::Database(name.into()) + } + }) + .collect::>(); + + let message = Request::CheckAuthorization(payload); + server_connection.send(message).await?; + + let result = match server_connection.next().await { + Some(Ok(Response::CheckAuthorization(response))) => response, + response => return erroneous_server_response(response), + }; + + server_connection.send(Request::Exit).await?; + + if args.json { + print_check_authorization_output_status_json(&result); + } else { + print_check_authorization_output_status(&result); + } + + Ok(()) +} diff --git a/src/core/protocol/commands.rs b/src/core/protocol/commands.rs index 880938e..719eaa9 100644 --- a/src/core/protocol/commands.rs +++ b/src/core/protocol/commands.rs @@ -1,3 +1,4 @@ +mod check_authorization; mod create_databases; mod create_users; mod drop_databases; @@ -13,6 +14,7 @@ mod modify_privileges; mod passwd_user; mod unlock_users; +pub use check_authorization::*; pub use create_databases::*; pub use create_users::*; pub use drop_databases::*; @@ -60,6 +62,8 @@ pub fn create_client_to_server_message_stream(socket: UnixStream) -> ClientToSer #[non_exhaustive] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum Request { + CheckAuthorization(CheckAuthorizationRequest), + CreateDatabases(CreateDatabasesRequest), DropDatabases(DropDatabasesRequest), ListDatabases(ListDatabasesRequest), @@ -82,6 +86,8 @@ pub enum Request { #[non_exhaustive] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum Response { + CheckAuthorization(CheckAuthorizationResponse), + // Specific data for specific commands CreateDatabases(CreateDatabasesResponse), DropDatabases(DropDatabasesResponse), diff --git a/src/core/protocol/commands/check_authorization.rs b/src/core/protocol/commands/check_authorization.rs new file mode 100644 index 0000000..61aacd5 --- /dev/null +++ b/src/core/protocol/commands/check_authorization.rs @@ -0,0 +1,80 @@ +use std::collections::BTreeMap; + +use serde::{Deserialize, Serialize}; +use serde_json::json; + +use crate::core::{ + protocol::request_validation::{NameValidationError, OwnerValidationError}, + types::DbOrUser, +}; + +pub type CheckAuthorizationRequest = Vec; + +pub type CheckAuthorizationResponse = BTreeMap>; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum CheckAuthorizationError { + SanitizationError(NameValidationError), + OwnershipError(OwnerValidationError), + // AuthorizationHandlerError(String), +} + +pub fn print_check_authorization_output_status(output: &CheckAuthorizationResponse) { + for (db_or_user, result) in output { + match result { + Ok(()) => { + println!("'{}': OK", db_or_user.name()); + } + Err(err) => { + println!( + "'{}': {}", + db_or_user.name(), + err.to_error_message(db_or_user) + ); + } + } + } +} + +pub fn print_check_authorization_output_status_json(output: &CheckAuthorizationResponse) { + let value = output + .iter() + .map(|(db_or_user, result)| match result { + Ok(()) => ( + db_or_user.name().to_string(), + json!({ "status": "success" }), + ), + Err(err) => ( + db_or_user.name().to_string(), + json!({ + "status": "error", + "error": err.to_error_message(db_or_user), + }), + ), + }) + .collect::>(); + println!( + "{}", + serde_json::to_string_pretty(&value) + .unwrap_or("Failed to serialize result to JSON".to_string()) + ); +} + +impl CheckAuthorizationError { + pub fn to_error_message(&self, db_or_user: &DbOrUser) -> String { + match self { + CheckAuthorizationError::SanitizationError(err) => { + err.to_error_message(db_or_user.clone()) + } + CheckAuthorizationError::OwnershipError(err) => { + err.to_error_message(db_or_user.clone()) + } // CheckAuthorizationError::AuthorizationHandlerError(msg) => { + // format!( + // "Authorization handler error for '{}': {}", + // db_or_user.name(), + // msg + // ) + // } + } + } +} diff --git a/src/core/types.rs b/src/core/types.rs index 13ab86e..d1308e3 100644 --- a/src/core/types.rs +++ b/src/core/types.rs @@ -92,8 +92,7 @@ impl From for MySQLDatabase { } } -/// This enum is used to differentiate between database and user operations. -#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] pub enum DbOrUser { Database(MySQLDatabase), User(MySQLUser), diff --git a/src/server.rs b/src/server.rs index b357745..7470b45 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,3 +1,4 @@ +mod authorization; pub mod command; mod common; pub mod config; diff --git a/src/server/authorization.rs b/src/server/authorization.rs new file mode 100644 index 0000000..a32baa3 --- /dev/null +++ b/src/server/authorization.rs @@ -0,0 +1,33 @@ +use crate::{ + core::{common::UnixUser, protocol::CheckAuthorizationError, types::DbOrUser}, + server::input_sanitization::{validate_name, validate_ownership_by_unix_user}, +}; + +pub async fn check_authorization( + dbs_or_users: Vec, + unix_user: &UnixUser, +) -> std::collections::BTreeMap> { + let mut results = std::collections::BTreeMap::new(); + + for db_or_user in dbs_or_users { + if let Err(err) = validate_name(db_or_user.name()) { + results.insert( + db_or_user.clone(), + Err(CheckAuthorizationError::SanitizationError(err)), + ); + continue; + } + + if let Err(err) = validate_ownership_by_unix_user(db_or_user.name(), unix_user) { + results.insert( + db_or_user.clone(), + Err(CheckAuthorizationError::OwnershipError(err)), + ); + continue; + } + + results.insert(db_or_user.clone(), Ok(())); + } + + results +} diff --git a/src/server/session_handler.rs b/src/server/session_handler.rs index 1121067..370c5e1 100644 --- a/src/server/session_handler.rs +++ b/src/server/session_handler.rs @@ -13,17 +13,20 @@ use crate::{ create_server_to_client_message_stream, }, }, - server::sql::{ - database_operations::{ - create_databases, drop_databases, list_all_databases_for_user, list_databases, - }, - database_privilege_operations::{ - apply_privilege_diffs, get_all_database_privileges, get_databases_privilege_data, - }, - user_operations::{ - create_database_users, drop_database_users, list_all_database_users_for_unix_user, - list_database_users, lock_database_users, set_password_for_database_user, - unlock_database_users, + server::{ + authorization::check_authorization, + sql::{ + database_operations::{ + create_databases, drop_databases, list_all_databases_for_user, list_databases, + }, + database_privilege_operations::{ + apply_privilege_diffs, get_all_database_privileges, get_databases_privilege_data, + }, + user_operations::{ + create_database_users, drop_database_users, list_all_database_users_for_unix_user, + list_database_users, lock_database_users, set_password_for_database_user, + unlock_database_users, + }, }, }, }; @@ -119,6 +122,8 @@ async fn session_handler_with_db_connection( stream.send(Response::Ready).await?; loop { // TODO: better error handling + // TODO: timeout for receiving requests + // TODO: cancel on request by supervisor let request = match stream.next().await { Some(Ok(request)) => request, Some(Err(e)) => return Err(e.into()), @@ -138,6 +143,10 @@ async fn session_handler_with_db_connection( log::info!("Received request: {:#?}", request_to_display); let response = match request { + Request::CheckAuthorization(dbs_or_users) => { + let result = check_authorization(dbs_or_users, unix_user).await; + Response::CheckAuthorization(result) + } Request::CreateDatabases(databases_names) => { let result = create_databases(databases_names, unix_user, db_connection).await; Response::CreateDatabases(result)