diff --git a/src/cli/database_command.rs b/src/cli/database_command.rs index a2baa1b..3d2f9ab 100644 --- a/src/cli/database_command.rs +++ b/src/cli/database_command.rs @@ -34,13 +34,15 @@ pub enum DatabaseCommand { #[command()] DropDb(DatabaseDropArgs), - /// List all databases you have access to - #[command()] - ListDb(DatabaseListArgs), - - /// List user privileges for one or more databases + /// Print information about one or more databases /// - /// If no database names are provided, it will show privileges for all databases you have access to. + /// If no database name is provided, all databases you have access will be shown. + #[command()] + ShowDb(DatabaseShowArgs), + + /// Print user privileges for one or more databases + /// + /// If no database names are provided, all databases you have access to will be shown. #[command()] ShowDbPrivs(DatabaseShowPrivsArgs), @@ -113,7 +115,11 @@ pub struct DatabaseDropArgs { } #[derive(Parser, Debug, Clone)] -pub struct DatabaseListArgs { +pub struct DatabaseShowArgs { + /// The name of the database(s) to show. + #[arg(num_args = 0..)] + name: Vec, + /// Whether to output the information in JSON format. #[arg(short, long)] json: bool, @@ -158,7 +164,7 @@ pub async fn handle_command( match command { DatabaseCommand::CreateDb(args) => create_databases(args, server_connection).await, DatabaseCommand::DropDb(args) => drop_databases(args, server_connection).await, - DatabaseCommand::ListDb(args) => list_databases(args, server_connection).await, + DatabaseCommand::ShowDb(args) => show_databases(args, server_connection).await, DatabaseCommand::ShowDbPrivs(args) => { show_database_privileges(args, server_connection).await } @@ -214,35 +220,56 @@ async fn drop_databases( Ok(()) } -async fn list_databases( - args: DatabaseListArgs, +async fn show_databases( + args: DatabaseShowArgs, mut server_connection: ClientToServerMessageStream, ) -> anyhow::Result<()> { - let message = Request::ListDatabases; + let message = if args.name.is_empty() { + Request::ListDatabases(None) + } else { + Request::ListDatabases(Some(args.name.clone())) + }; + server_connection.send(message).await?; - let result = match server_connection.next().await { - Some(Ok(Response::ListAllDatabases(result))) => result, + let database_list = match server_connection.next().await { + Some(Ok(Response::ListDatabases(databases))) => databases + .into_iter() + .filter_map(|(database_name, result)| match result { + Ok(database_row) => Some(database_row), + Err(err) => { + eprintln!("{}", err.to_error_message(&database_name)); + eprintln!("Skipping..."); + println!(); + None + } + }) + .collect::>(), + Some(Ok(Response::ListAllDatabases(database_list))) => match database_list { + Ok(list) => list, + Err(err) => { + server_connection.send(Request::Exit).await?; + return Err( + anyhow::anyhow!(err.to_error_message()).context("Failed to list databases") + ); + } + }, response => return erroneous_server_response(response), }; server_connection.send(Request::Exit).await?; - let database_list = match result { - Ok(list) => list, - Err(err) => { - return Err(anyhow::anyhow!(err.to_error_message()).context("Failed to list databases")) - } - }; - if args.json { println!("{}", serde_json::to_string_pretty(&database_list)?); } else if database_list.is_empty() { println!("No databases to show."); } else { + let mut table = Table::new(); + table.add_row(Row::new(vec![Cell::new("Database")])); for db in database_list { - println!("{}", db); + table.add_row(row![db.database]); } + table.printstd(); } Ok(()) diff --git a/src/cli/mysql_admutils_compatibility/mysql_dbadm.rs b/src/cli/mysql_admutils_compatibility/mysql_dbadm.rs index 0e881c1..c028791 100644 --- a/src/cli/mysql_admutils_compatibility/mysql_dbadm.rs +++ b/src/cli/mysql_admutils_compatibility/mysql_dbadm.rs @@ -269,7 +269,7 @@ async fn show_databases( .collect(); let message = if database_names.is_empty() { - let message = Request::ListDatabases; + let message = Request::ListDatabases(None); server_connection.send(message).await?; let response = server_connection.next().await; let databases = match response { @@ -277,7 +277,9 @@ async fn show_databases( response => return erroneous_server_response(response), }; - Request::ListPrivileges(Some(databases)) + let database_names = databases.into_iter().map(|db| db.database).collect(); + + Request::ListPrivileges(Some(database_names)) } else { Request::ListPrivileges(Some(database_names)) }; diff --git a/src/cli/user_command.rs b/src/cli/user_command.rs index c49bff1..a99e596 100644 --- a/src/cli/user_command.rs +++ b/src/cli/user_command.rs @@ -33,7 +33,7 @@ pub enum UserCommand { #[command()] PasswdUser(UserPasswdArgs), - /// Give information about one or more users + /// Print information about one or more users /// /// If no username is provided, all users you have access will be shown. #[command()] diff --git a/src/core/protocol/request_response.rs b/src/core/protocol/request_response.rs index 94725cb..2f194af 100644 --- a/src/core/protocol/request_response.rs +++ b/src/core/protocol/request_response.rs @@ -36,7 +36,7 @@ pub fn create_client_to_server_message_stream(socket: UnixStream) -> ClientToSer pub enum Request { CreateDatabases(Vec), DropDatabases(Vec), - ListDatabases, + ListDatabases(Option>), ListPrivileges(Option>), ModifyPrivileges(BTreeSet), @@ -59,6 +59,7 @@ pub enum Response { // Specific data for specific commands CreateDatabases(CreateDatabasesOutput), DropDatabases(DropDatabasesOutput), + ListDatabases(ListDatabasesOutput), ListAllDatabases(ListAllDatabasesOutput), ListPrivileges(GetDatabasesPrivilegeData), ListAllPrivileges(GetAllDatabasesPrivilegeData), diff --git a/src/core/protocol/server_responses.rs b/src/core/protocol/server_responses.rs index e883ad4..4bb857d 100644 --- a/src/core/protocol/server_responses.rs +++ b/src/core/protocol/server_responses.rs @@ -7,7 +7,8 @@ use serde::{Deserialize, Serialize}; use crate::{ core::{common::UnixUser, database_privileges::DatabasePrivilegeRowDiff}, server::sql::{ - database_privilege_operations::DatabasePrivilegeRow, user_operations::DatabaseUser, + database_operations::DatabaseRow, database_privilege_operations::DatabasePrivilegeRow, + user_operations::DatabaseUser, }, }; @@ -202,16 +203,44 @@ impl DropDatabaseError { } } -pub type ListAllDatabasesOutput = Result, ListDatabasesError>; +pub type ListDatabasesOutput = BTreeMap>; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum ListDatabasesError { + SanitizationError(NameValidationError), + OwnershipError(OwnerValidationError), + DatabaseDoesNotExist, MySqlError(String), } impl ListDatabasesError { + pub fn to_error_message(&self, database_name: &str) -> String { + match self { + ListDatabasesError::SanitizationError(err) => { + err.to_error_message(database_name, DbOrUser::Database) + } + ListDatabasesError::OwnershipError(err) => { + err.to_error_message(database_name, DbOrUser::Database) + } + ListDatabasesError::DatabaseDoesNotExist => { + format!("Database '{}' does not exist.", database_name) + } + ListDatabasesError::MySqlError(err) => { + format!("MySQL error: {}", err) + } + } + } +} + +pub type ListAllDatabasesOutput = Result, ListAllDatabasesError>; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum ListAllDatabasesError { + MySqlError(String), +} + +impl ListAllDatabasesError { pub fn to_error_message(&self) -> String { match self { - ListDatabasesError::MySqlError(err) => format!("MySQL error: {}", err), + ListAllDatabasesError::MySqlError(err) => format!("MySQL error: {}", err), } } } diff --git a/src/server/common.rs b/src/server/common.rs index 454dcbe..0ddf1e5 100644 --- a/src/server/common.rs +++ b/src/server/common.rs @@ -48,4 +48,4 @@ mod tests { assert!(!re.is_match("user")); assert!(!re.is_match("usersomething")); } -} \ No newline at end of file +} diff --git a/src/server/server_loop.rs b/src/server/server_loop.rs index 1fe6f30..f1dcf1a 100644 --- a/src/server/server_loop.rs +++ b/src/server/server_loop.rs @@ -7,6 +7,7 @@ use tokio::net::{UnixListener, UnixStream}; use sqlx::prelude::*; use sqlx::MySqlConnection; +use crate::server::sql::database_operations::list_databases; use crate::{ core::{ common::{UnixUser, DEFAULT_SOCKET_PATH}, @@ -17,7 +18,7 @@ use crate::{ server::{ config::{create_mysql_connection_from_config, ServerConfig}, sql::{ - database_operations::{create_databases, drop_databases, list_databases_for_user}, + database_operations::{create_databases, drop_databases, list_all_databases_for_user}, database_privilege_operations::{ apply_privilege_diffs, get_all_database_privileges, get_databases_privilege_data, }, @@ -183,9 +184,18 @@ pub async fn handle_requests_for_single_session_with_db_connection( let result = drop_databases(databases_names, unix_user, db_connection).await; stream.send(Response::DropDatabases(result)).await?; } - Request::ListDatabases => { - let result = list_databases_for_user(unix_user, db_connection).await; - stream.send(Response::ListAllDatabases(result)).await?; + Request::ListDatabases(database_names) => { + let response = match database_names { + Some(database_names) => { + let result = list_databases(database_names, unix_user, db_connection).await; + Response::ListDatabases(result) + } + None => { + let result = list_all_databases_for_user(unix_user, db_connection).await; + Response::ListAllDatabases(result) + } + }; + stream.send(response).await?; } Request::ListPrivileges(database_names) => { let response = match database_names { diff --git a/src/server/sql/database_operations.rs b/src/server/sql/database_operations.rs index 2af32f5..3c87725 100644 --- a/src/server/sql/database_operations.rs +++ b/src/server/sql/database_operations.rs @@ -1,9 +1,16 @@ +use std::collections::BTreeMap; + +use sqlx::prelude::*; +use sqlx::MySqlConnection; + +use serde::{Deserialize, Serialize}; + use crate::{ core::{ common::UnixUser, protocol::{ CreateDatabaseError, CreateDatabasesOutput, DropDatabaseError, DropDatabasesOutput, - ListDatabasesError, + ListAllDatabasesError, ListAllDatabasesOutput, ListDatabasesError, ListDatabasesOutput, }, }, server::{ @@ -12,11 +19,6 @@ use crate::{ }, }; -use sqlx::prelude::*; - -use sqlx::MySqlConnection; -use std::collections::BTreeMap; - // NOTE: this function is unsafe because it does no input validation. pub(super) async fn unsafe_database_exists( database_name: &str, @@ -157,11 +159,67 @@ pub async fn drop_databases( results } -pub async fn list_databases_for_user( +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct DatabaseRow { + pub database: String, +} + +pub async fn list_databases( + database_names: Vec, unix_user: &UnixUser, connection: &mut MySqlConnection, -) -> Result, ListDatabasesError> { - let result = sqlx::query( +) -> ListDatabasesOutput { + let mut results = BTreeMap::new(); + + for database_name in database_names { + if let Err(err) = validate_name(&database_name) { + results.insert( + database_name.clone(), + Err(ListDatabasesError::SanitizationError(err)), + ); + continue; + } + + if let Err(err) = validate_ownership_by_unix_user(&database_name, unix_user) { + results.insert( + database_name.clone(), + Err(ListDatabasesError::OwnershipError(err)), + ); + continue; + } + + let result = sqlx::query_as::<_, DatabaseRow>( + r#" + SELECT `SCHEMA_NAME` AS `database` + FROM `information_schema`.`SCHEMATA` + WHERE `SCHEMA_NAME` = ? + "#, + ) + .bind(&database_name) + .fetch_optional(&mut *connection) + .await + .map_err(|err| ListDatabasesError::MySqlError(err.to_string())) + .and_then(|database| { + database + .map(Ok) + .unwrap_or_else(|| Err(ListDatabasesError::DatabaseDoesNotExist)) + }); + + if let Err(err) = &result { + log::error!("Failed to list database '{}': {:?}", &database_name, err); + } + + results.insert(database_name, result); + } + + results +} + +pub async fn list_all_databases_for_user( + unix_user: &UnixUser, + connection: &mut MySqlConnection, +) -> ListAllDatabasesOutput { + let result = sqlx::query_as::<_, DatabaseRow>( r#" SELECT `SCHEMA_NAME` AS `database` FROM `information_schema`.`SCHEMATA` @@ -172,12 +230,7 @@ pub async fn list_databases_for_user( .bind(create_user_group_matching_regex(unix_user)) .fetch_all(connection) .await - .and_then(|rows| { - rows.into_iter() - .map(|row| row.try_get::("database")) - .collect::, sqlx::Error>>() - }) - .map_err(|err| ListDatabasesError::MySqlError(err.to_string())); + .map_err(|err| ListAllDatabasesError::MySqlError(err.to_string())); if let Err(err) = &result { log::error!(