Add command check-auth

This commit is contained in:
2025-11-29 19:25:33 +09:00
parent 03ddf0ac8a
commit 865b24884e
8 changed files with 215 additions and 13 deletions

View File

@@ -1,3 +1,4 @@
mod check_auth;
mod create_db;
mod create_user;
mod drop_db;
@@ -10,6 +11,7 @@ mod show_privs;
mod show_user;
mod unlock_user;
pub use check_auth::*;
pub use create_db::*;
pub use create_user::*;
pub use drop_db::*;
@@ -28,6 +30,10 @@ use crate::core::protocol::{ClientToServerMessageStream, Response};
#[derive(Parser, Debug, Clone)]
pub enum ClientCommand {
/// Check whether you are authorized to manage the specified databases or users.
#[command()]
CheckAuth(CheckAuthArgs),
/// Create one or more databases
#[command()]
CreateDb(CreateDbArgs),
@@ -141,6 +147,7 @@ pub async fn handle_command(
server_connection: ClientToServerMessageStream,
) -> anyhow::Result<()> {
match command {
ClientCommand::CheckAuth(args) => check_authorization(args, server_connection).await,
ClientCommand::CreateDb(args) => create_databases(args, server_connection).await,
ClientCommand::DropDb(args) => drop_databases(args, server_connection).await,
ClientCommand::ShowDb(args) => show_databases(args, server_connection).await,

View File

@@ -0,0 +1,67 @@
use crate::{
client::commands::erroneous_server_response,
core::{
protocol::{
ClientToServerMessageStream, Request, Response,
print_check_authorization_output_status, print_check_authorization_output_status_json,
},
types::DbOrUser,
},
};
use clap::Parser;
use futures_util::SinkExt;
use tokio_stream::StreamExt;
#[derive(Parser, Debug, Clone)]
pub struct CheckAuthArgs {
/// The name of the database(s) or user(s) to check authorization for
#[arg(num_args = 1..)]
name: Vec<String>,
/// Assume the names are users, not databases
#[arg(short, long)]
users: bool,
/// Print the information as JSON
#[arg(short, long)]
json: bool,
}
pub async fn check_authorization(
args: CheckAuthArgs,
mut server_connection: ClientToServerMessageStream,
) -> anyhow::Result<()> {
if args.name.is_empty() {
anyhow::bail!("No database/user names provided");
}
let payload = args
.name
.into_iter()
.map(|name| {
if args.users {
DbOrUser::User(name.into())
} else {
DbOrUser::Database(name.into())
}
})
.collect::<Vec<_>>();
let message = Request::CheckAuthorization(payload);
server_connection.send(message).await?;
let result = match server_connection.next().await {
Some(Ok(Response::CheckAuthorization(response))) => response,
response => return erroneous_server_response(response),
};
server_connection.send(Request::Exit).await?;
if args.json {
print_check_authorization_output_status_json(&result);
} else {
print_check_authorization_output_status(&result);
}
Ok(())
}

View File

@@ -1,3 +1,4 @@
mod check_authorization;
mod create_databases;
mod create_users;
mod drop_databases;
@@ -13,6 +14,7 @@ mod modify_privileges;
mod passwd_user;
mod unlock_users;
pub use check_authorization::*;
pub use create_databases::*;
pub use create_users::*;
pub use drop_databases::*;
@@ -60,6 +62,8 @@ pub fn create_client_to_server_message_stream(socket: UnixStream) -> ClientToSer
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Request {
CheckAuthorization(CheckAuthorizationRequest),
CreateDatabases(CreateDatabasesRequest),
DropDatabases(DropDatabasesRequest),
ListDatabases(ListDatabasesRequest),
@@ -82,6 +86,8 @@ pub enum Request {
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Response {
CheckAuthorization(CheckAuthorizationResponse),
// Specific data for specific commands
CreateDatabases(CreateDatabasesResponse),
DropDatabases(DropDatabasesResponse),

View File

@@ -0,0 +1,80 @@
use std::collections::BTreeMap;
use serde::{Deserialize, Serialize};
use serde_json::json;
use crate::core::{
protocol::request_validation::{NameValidationError, OwnerValidationError},
types::DbOrUser,
};
pub type CheckAuthorizationRequest = Vec<DbOrUser>;
pub type CheckAuthorizationResponse = BTreeMap<DbOrUser, Result<(), CheckAuthorizationError>>;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum CheckAuthorizationError {
SanitizationError(NameValidationError),
OwnershipError(OwnerValidationError),
// AuthorizationHandlerError(String),
}
pub fn print_check_authorization_output_status(output: &CheckAuthorizationResponse) {
for (db_or_user, result) in output {
match result {
Ok(()) => {
println!("'{}': OK", db_or_user.name());
}
Err(err) => {
println!(
"'{}': {}",
db_or_user.name(),
err.to_error_message(db_or_user)
);
}
}
}
}
pub fn print_check_authorization_output_status_json(output: &CheckAuthorizationResponse) {
let value = output
.iter()
.map(|(db_or_user, result)| match result {
Ok(()) => (
db_or_user.name().to_string(),
json!({ "status": "success" }),
),
Err(err) => (
db_or_user.name().to_string(),
json!({
"status": "error",
"error": err.to_error_message(db_or_user),
}),
),
})
.collect::<serde_json::Map<_, _>>();
println!(
"{}",
serde_json::to_string_pretty(&value)
.unwrap_or("Failed to serialize result to JSON".to_string())
);
}
impl CheckAuthorizationError {
pub fn to_error_message(&self, db_or_user: &DbOrUser) -> String {
match self {
CheckAuthorizationError::SanitizationError(err) => {
err.to_error_message(db_or_user.clone())
}
CheckAuthorizationError::OwnershipError(err) => {
err.to_error_message(db_or_user.clone())
} // CheckAuthorizationError::AuthorizationHandlerError(msg) => {
// format!(
// "Authorization handler error for '{}': {}",
// db_or_user.name(),
// msg
// )
// }
}
}
}

View File

@@ -92,8 +92,7 @@ impl From<String> for MySQLDatabase {
}
}
/// This enum is used to differentiate between database and user operations.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum DbOrUser {
Database(MySQLDatabase),
User(MySQLUser),

View File

@@ -1,3 +1,4 @@
mod authorization;
pub mod command;
mod common;
pub mod config;

View File

@@ -0,0 +1,33 @@
use crate::{
core::{common::UnixUser, protocol::CheckAuthorizationError, types::DbOrUser},
server::input_sanitization::{validate_name, validate_ownership_by_unix_user},
};
pub async fn check_authorization(
dbs_or_users: Vec<DbOrUser>,
unix_user: &UnixUser,
) -> std::collections::BTreeMap<DbOrUser, Result<(), CheckAuthorizationError>> {
let mut results = std::collections::BTreeMap::new();
for db_or_user in dbs_or_users {
if let Err(err) = validate_name(db_or_user.name()) {
results.insert(
db_or_user.clone(),
Err(CheckAuthorizationError::SanitizationError(err)),
);
continue;
}
if let Err(err) = validate_ownership_by_unix_user(db_or_user.name(), unix_user) {
results.insert(
db_or_user.clone(),
Err(CheckAuthorizationError::OwnershipError(err)),
);
continue;
}
results.insert(db_or_user.clone(), Ok(()));
}
results
}

View File

@@ -13,17 +13,20 @@ use crate::{
create_server_to_client_message_stream,
},
},
server::sql::{
database_operations::{
create_databases, drop_databases, list_all_databases_for_user, list_databases,
},
database_privilege_operations::{
apply_privilege_diffs, get_all_database_privileges, get_databases_privilege_data,
},
user_operations::{
create_database_users, drop_database_users, list_all_database_users_for_unix_user,
list_database_users, lock_database_users, set_password_for_database_user,
unlock_database_users,
server::{
authorization::check_authorization,
sql::{
database_operations::{
create_databases, drop_databases, list_all_databases_for_user, list_databases,
},
database_privilege_operations::{
apply_privilege_diffs, get_all_database_privileges, get_databases_privilege_data,
},
user_operations::{
create_database_users, drop_database_users, list_all_database_users_for_unix_user,
list_database_users, lock_database_users, set_password_for_database_user,
unlock_database_users,
},
},
},
};
@@ -119,6 +122,8 @@ async fn session_handler_with_db_connection(
stream.send(Response::Ready).await?;
loop {
// TODO: better error handling
// TODO: timeout for receiving requests
// TODO: cancel on request by supervisor
let request = match stream.next().await {
Some(Ok(request)) => request,
Some(Err(e)) => return Err(e.into()),
@@ -138,6 +143,10 @@ async fn session_handler_with_db_connection(
log::info!("Received request: {:#?}", request_to_display);
let response = match request {
Request::CheckAuthorization(dbs_or_users) => {
let result = check_authorization(dbs_or_users, unix_user).await;
Response::CheckAuthorization(result)
}
Request::CreateDatabases(databases_names) => {
let result = create_databases(databases_names, unix_user, db_connection).await;
Response::CreateDatabases(result)