list-db -> show-db #79

Merged
oysteikt merged 1 commits from list-db-to-show-db into main 2024-08-19 18:59:09 +02:00
8 changed files with 170 additions and 48 deletions

View File

@ -34,13 +34,15 @@ pub enum DatabaseCommand {
#[command()] #[command()]
DropDb(DatabaseDropArgs), DropDb(DatabaseDropArgs),
/// List all databases you have access to /// Print information about one or more databases
#[command()]
ListDb(DatabaseListArgs),
/// List user privileges for one or more databases
/// ///
/// If no database names are provided, it will show privileges for all databases you have access to. /// If no database name is provided, all databases you have access will be shown.
#[command()]
ShowDb(DatabaseShowArgs),
/// Print user privileges for one or more databases
///
/// If no database names are provided, all databases you have access to will be shown.
#[command()] #[command()]
ShowDbPrivs(DatabaseShowPrivsArgs), ShowDbPrivs(DatabaseShowPrivsArgs),
@ -113,7 +115,11 @@ pub struct DatabaseDropArgs {
} }
#[derive(Parser, Debug, Clone)] #[derive(Parser, Debug, Clone)]
pub struct DatabaseListArgs { pub struct DatabaseShowArgs {
/// The name of the database(s) to show.
#[arg(num_args = 0..)]
name: Vec<String>,
/// Whether to output the information in JSON format. /// Whether to output the information in JSON format.
#[arg(short, long)] #[arg(short, long)]
json: bool, json: bool,
@ -158,7 +164,7 @@ pub async fn handle_command(
match command { match command {
DatabaseCommand::CreateDb(args) => create_databases(args, server_connection).await, DatabaseCommand::CreateDb(args) => create_databases(args, server_connection).await,
DatabaseCommand::DropDb(args) => drop_databases(args, server_connection).await, DatabaseCommand::DropDb(args) => drop_databases(args, server_connection).await,
DatabaseCommand::ListDb(args) => list_databases(args, server_connection).await, DatabaseCommand::ShowDb(args) => show_databases(args, server_connection).await,
DatabaseCommand::ShowDbPrivs(args) => { DatabaseCommand::ShowDbPrivs(args) => {
show_database_privileges(args, server_connection).await show_database_privileges(args, server_connection).await
} }
@ -214,35 +220,56 @@ async fn drop_databases(
Ok(()) Ok(())
} }
async fn list_databases( async fn show_databases(
args: DatabaseListArgs, args: DatabaseShowArgs,
mut server_connection: ClientToServerMessageStream, mut server_connection: ClientToServerMessageStream,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let message = Request::ListDatabases; let message = if args.name.is_empty() {
Request::ListDatabases(None)
} else {
Request::ListDatabases(Some(args.name.clone()))
};
server_connection.send(message).await?; server_connection.send(message).await?;
let result = match server_connection.next().await { let database_list = match server_connection.next().await {
Some(Ok(Response::ListAllDatabases(result))) => result, Some(Ok(Response::ListDatabases(databases))) => databases
.into_iter()
.filter_map(|(database_name, result)| match result {
Ok(database_row) => Some(database_row),
Err(err) => {
eprintln!("{}", err.to_error_message(&database_name));
eprintln!("Skipping...");
println!();
None
}
})
.collect::<Vec<_>>(),
Some(Ok(Response::ListAllDatabases(database_list))) => match database_list {
Ok(list) => list,
Err(err) => {
server_connection.send(Request::Exit).await?;
return Err(
anyhow::anyhow!(err.to_error_message()).context("Failed to list databases")
);
}
},
response => return erroneous_server_response(response), response => return erroneous_server_response(response),
}; };
server_connection.send(Request::Exit).await?; server_connection.send(Request::Exit).await?;
let database_list = match result {
Ok(list) => list,
Err(err) => {
return Err(anyhow::anyhow!(err.to_error_message()).context("Failed to list databases"))
}
};
if args.json { if args.json {
println!("{}", serde_json::to_string_pretty(&database_list)?); println!("{}", serde_json::to_string_pretty(&database_list)?);
} else if database_list.is_empty() { } else if database_list.is_empty() {
println!("No databases to show."); println!("No databases to show.");
} else { } else {
let mut table = Table::new();
table.add_row(Row::new(vec![Cell::new("Database")]));
for db in database_list { for db in database_list {
println!("{}", db); table.add_row(row![db.database]);
} }
table.printstd();
} }
Ok(()) Ok(())

View File

@ -269,7 +269,7 @@ async fn show_databases(
.collect(); .collect();
let message = if database_names.is_empty() { let message = if database_names.is_empty() {
let message = Request::ListDatabases; let message = Request::ListDatabases(None);
server_connection.send(message).await?; server_connection.send(message).await?;
let response = server_connection.next().await; let response = server_connection.next().await;
let databases = match response { let databases = match response {
@ -277,7 +277,9 @@ async fn show_databases(
response => return erroneous_server_response(response), response => return erroneous_server_response(response),
}; };
Request::ListPrivileges(Some(databases)) let database_names = databases.into_iter().map(|db| db.database).collect();
Request::ListPrivileges(Some(database_names))
} else { } else {
Request::ListPrivileges(Some(database_names)) Request::ListPrivileges(Some(database_names))
}; };

View File

@ -33,7 +33,7 @@ pub enum UserCommand {
#[command()] #[command()]
PasswdUser(UserPasswdArgs), PasswdUser(UserPasswdArgs),
/// Give information about one or more users /// Print information about one or more users
/// ///
/// If no username is provided, all users you have access will be shown. /// If no username is provided, all users you have access will be shown.
#[command()] #[command()]

View File

@ -36,7 +36,7 @@ pub fn create_client_to_server_message_stream(socket: UnixStream) -> ClientToSer
pub enum Request { pub enum Request {
CreateDatabases(Vec<String>), CreateDatabases(Vec<String>),
DropDatabases(Vec<String>), DropDatabases(Vec<String>),
ListDatabases, ListDatabases(Option<Vec<String>>),
ListPrivileges(Option<Vec<String>>), ListPrivileges(Option<Vec<String>>),
ModifyPrivileges(BTreeSet<DatabasePrivilegesDiff>), ModifyPrivileges(BTreeSet<DatabasePrivilegesDiff>),
@ -59,6 +59,7 @@ pub enum Response {
// Specific data for specific commands // Specific data for specific commands
CreateDatabases(CreateDatabasesOutput), CreateDatabases(CreateDatabasesOutput),
DropDatabases(DropDatabasesOutput), DropDatabases(DropDatabasesOutput),
ListDatabases(ListDatabasesOutput),
ListAllDatabases(ListAllDatabasesOutput), ListAllDatabases(ListAllDatabasesOutput),
ListPrivileges(GetDatabasesPrivilegeData), ListPrivileges(GetDatabasesPrivilegeData),
ListAllPrivileges(GetAllDatabasesPrivilegeData), ListAllPrivileges(GetAllDatabasesPrivilegeData),

View File

@ -7,7 +7,8 @@ use serde::{Deserialize, Serialize};
use crate::{ use crate::{
core::{common::UnixUser, database_privileges::DatabasePrivilegeRowDiff}, core::{common::UnixUser, database_privileges::DatabasePrivilegeRowDiff},
server::sql::{ server::sql::{
database_privilege_operations::DatabasePrivilegeRow, user_operations::DatabaseUser, database_operations::DatabaseRow, database_privilege_operations::DatabasePrivilegeRow,
user_operations::DatabaseUser,
}, },
}; };
@ -202,16 +203,44 @@ impl DropDatabaseError {
} }
} }
pub type ListAllDatabasesOutput = Result<Vec<String>, ListDatabasesError>; pub type ListDatabasesOutput = BTreeMap<String, Result<DatabaseRow, ListDatabasesError>>;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ListDatabasesError { pub enum ListDatabasesError {
SanitizationError(NameValidationError),
OwnershipError(OwnerValidationError),
DatabaseDoesNotExist,
MySqlError(String), MySqlError(String),
} }
impl ListDatabasesError { impl ListDatabasesError {
pub fn to_error_message(&self, database_name: &str) -> String {
match self {
ListDatabasesError::SanitizationError(err) => {
err.to_error_message(database_name, DbOrUser::Database)
}
ListDatabasesError::OwnershipError(err) => {
err.to_error_message(database_name, DbOrUser::Database)
}
ListDatabasesError::DatabaseDoesNotExist => {
format!("Database '{}' does not exist.", database_name)
}
ListDatabasesError::MySqlError(err) => {
format!("MySQL error: {}", err)
}
}
}
}
pub type ListAllDatabasesOutput = Result<Vec<DatabaseRow>, ListAllDatabasesError>;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ListAllDatabasesError {
MySqlError(String),
}
impl ListAllDatabasesError {
pub fn to_error_message(&self) -> String { pub fn to_error_message(&self) -> String {
match self { match self {
ListDatabasesError::MySqlError(err) => format!("MySQL error: {}", err), ListAllDatabasesError::MySqlError(err) => format!("MySQL error: {}", err),
} }
} }
} }

View File

@ -7,6 +7,7 @@ use tokio::net::{UnixListener, UnixStream};
use sqlx::prelude::*; use sqlx::prelude::*;
use sqlx::MySqlConnection; use sqlx::MySqlConnection;
use crate::server::sql::database_operations::list_databases;
use crate::{ use crate::{
core::{ core::{
common::{UnixUser, DEFAULT_SOCKET_PATH}, common::{UnixUser, DEFAULT_SOCKET_PATH},
@ -17,7 +18,7 @@ use crate::{
server::{ server::{
config::{create_mysql_connection_from_config, ServerConfig}, config::{create_mysql_connection_from_config, ServerConfig},
sql::{ sql::{
database_operations::{create_databases, drop_databases, list_databases_for_user}, database_operations::{create_databases, drop_databases, list_all_databases_for_user},
database_privilege_operations::{ database_privilege_operations::{
apply_privilege_diffs, get_all_database_privileges, get_databases_privilege_data, apply_privilege_diffs, get_all_database_privileges, get_databases_privilege_data,
}, },
@ -183,9 +184,18 @@ pub async fn handle_requests_for_single_session_with_db_connection(
let result = drop_databases(databases_names, unix_user, db_connection).await; let result = drop_databases(databases_names, unix_user, db_connection).await;
stream.send(Response::DropDatabases(result)).await?; stream.send(Response::DropDatabases(result)).await?;
} }
Request::ListDatabases => { Request::ListDatabases(database_names) => {
let result = list_databases_for_user(unix_user, db_connection).await; let response = match database_names {
stream.send(Response::ListAllDatabases(result)).await?; Some(database_names) => {
let result = list_databases(database_names, unix_user, db_connection).await;
Response::ListDatabases(result)
}
None => {
let result = list_all_databases_for_user(unix_user, db_connection).await;
Response::ListAllDatabases(result)
}
};
stream.send(response).await?;
} }
Request::ListPrivileges(database_names) => { Request::ListPrivileges(database_names) => {
let response = match database_names { let response = match database_names {

View File

@ -1,9 +1,16 @@
use std::collections::BTreeMap;
use sqlx::prelude::*;
use sqlx::MySqlConnection;
use serde::{Deserialize, Serialize};
use crate::{ use crate::{
core::{ core::{
common::UnixUser, common::UnixUser,
protocol::{ protocol::{
CreateDatabaseError, CreateDatabasesOutput, DropDatabaseError, DropDatabasesOutput, CreateDatabaseError, CreateDatabasesOutput, DropDatabaseError, DropDatabasesOutput,
ListDatabasesError, ListAllDatabasesError, ListAllDatabasesOutput, ListDatabasesError, ListDatabasesOutput,
}, },
}, },
server::{ server::{
@ -12,11 +19,6 @@ use crate::{
}, },
}; };
use sqlx::prelude::*;
use sqlx::MySqlConnection;
use std::collections::BTreeMap;
// NOTE: this function is unsafe because it does no input validation. // NOTE: this function is unsafe because it does no input validation.
pub(super) async fn unsafe_database_exists( pub(super) async fn unsafe_database_exists(
database_name: &str, database_name: &str,
@ -157,11 +159,67 @@ pub async fn drop_databases(
results results
} }
pub async fn list_databases_for_user( #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)]
pub struct DatabaseRow {
pub database: String,
}
pub async fn list_databases(
database_names: Vec<String>,
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
) -> Result<Vec<String>, ListDatabasesError> { ) -> ListDatabasesOutput {
let result = sqlx::query( let mut results = BTreeMap::new();
for database_name in database_names {
if let Err(err) = validate_name(&database_name) {
results.insert(
database_name.clone(),
Err(ListDatabasesError::SanitizationError(err)),
);
continue;
}
if let Err(err) = validate_ownership_by_unix_user(&database_name, unix_user) {
results.insert(
database_name.clone(),
Err(ListDatabasesError::OwnershipError(err)),
);
continue;
}
let result = sqlx::query_as::<_, DatabaseRow>(
r#"
SELECT `SCHEMA_NAME` AS `database`
FROM `information_schema`.`SCHEMATA`
WHERE `SCHEMA_NAME` = ?
"#,
)
.bind(&database_name)
.fetch_optional(&mut *connection)
.await
.map_err(|err| ListDatabasesError::MySqlError(err.to_string()))
.and_then(|database| {
database
.map(Ok)
.unwrap_or_else(|| Err(ListDatabasesError::DatabaseDoesNotExist))
});
if let Err(err) = &result {
log::error!("Failed to list database '{}': {:?}", &database_name, err);
}
results.insert(database_name, result);
}
results
}
pub async fn list_all_databases_for_user(
unix_user: &UnixUser,
connection: &mut MySqlConnection,
) -> ListAllDatabasesOutput {
let result = sqlx::query_as::<_, DatabaseRow>(
r#" r#"
SELECT `SCHEMA_NAME` AS `database` SELECT `SCHEMA_NAME` AS `database`
FROM `information_schema`.`SCHEMATA` FROM `information_schema`.`SCHEMATA`
@ -172,12 +230,7 @@ pub async fn list_databases_for_user(
.bind(create_user_group_matching_regex(unix_user)) .bind(create_user_group_matching_regex(unix_user))
.fetch_all(connection) .fetch_all(connection)
.await .await
.and_then(|rows| { .map_err(|err| ListAllDatabasesError::MySqlError(err.to_string()));
rows.into_iter()
.map(|row| row.try_get::<String, _>("database"))
.collect::<Result<Vec<String>, sqlx::Error>>()
})
.map_err(|err| ListDatabasesError::MySqlError(err.to_string()));
if let Err(err) = &result { if let Err(err) = &result {
log::error!( log::error!(