Rewrite entire codebase to split into client and server

This commit is contained in:
Oystein Kristoffer Tveit 2024-08-10 02:16:38 +02:00
parent 20e60ca5c7
commit af86893acf
Signed by: oysteikt
GPG Key ID: 9F2F7D8250F35146
32 changed files with 3708 additions and 1599 deletions

99
Cargo.lock generated
View File

@ -418,6 +418,27 @@ dependencies = [
"zeroize",
]
[[package]]
name = "derive_more"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a9b99b9cbbe49445b21764dc0625032a89b145a2642e67603e1c936f5458d05"
dependencies = [
"derive_more-impl",
]
[[package]]
name = "derive_more-impl"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cb7330aeadfbe296029522e6c40f315320aba36fc43a5b3632f3795348f3bd22"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.60",
"unicode-xid",
]
[[package]]
name = "dialoguer"
version = "0.11.0"
@ -470,6 +491,18 @@ version = "0.15.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b"
[[package]]
name = "educe"
version = "0.5.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e4bd92664bf78c4d3dba9b7cdafce6fa15b13ed3ed16175218196942e99168a8"
dependencies = [
"enum-ordinalize",
"proc-macro2",
"quote",
"syn 2.0.60",
]
[[package]]
name = "either"
version = "1.11.0"
@ -491,6 +524,26 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0"
[[package]]
name = "enum-ordinalize"
version = "4.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fea0dcfa4e54eeb516fe454635a95753ddd39acda650ce703031c6973e315dd5"
dependencies = [
"enum-ordinalize-derive",
]
[[package]]
name = "enum-ordinalize-derive"
version = "4.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d28318a75d4aead5c4db25382e8ef717932d0346600cacae6357eb5941bc5ff"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.60",
]
[[package]]
name = "env_filter"
version = "0.1.0"
@ -961,9 +1014,11 @@ dependencies = [
"async-bincode",
"bincode",
"clap",
"derive_more",
"dialoguer",
"env_logger",
"futures",
"futures-util",
"indoc",
"itertools",
"log",
@ -974,8 +1029,9 @@ dependencies = [
"serde",
"serde_json",
"sqlx",
"thiserror",
"tokio",
"tokio-serde",
"tokio-stream",
"tokio-util",
"toml",
"uuid",
@ -1109,6 +1165,26 @@ version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
[[package]]
name = "pin-project"
version = "1.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6bf43b791c5b9e34c3d182969b4abb522f9343702850a2e57f460d00d09b4b3"
dependencies = [
"pin-project-internal",
]
[[package]]
name = "pin-project-internal"
version = "1.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.60",
]
[[package]]
name = "pin-project-lite"
version = "0.2.14"
@ -1931,6 +2007,21 @@ dependencies = [
"syn 2.0.60",
]
[[package]]
name = "tokio-serde"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "caf600e7036b17782571dd44fa0a5cea3c82f60db5137f774a325a76a0d6852b"
dependencies = [
"bincode",
"bytes",
"educe",
"futures-core",
"futures-sink",
"pin-project",
"serde",
]
[[package]]
name = "tokio-stream"
version = "0.1.15"
@ -2060,6 +2151,12 @@ version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e51733f11c9c4f72aa0c160008246859e340b00807569a0da0e7a1079b27ba85"
[[package]]
name = "unicode-xid"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c"
[[package]]
name = "unicode_categories"
version = "0.1.1"

View File

@ -8,22 +8,25 @@ anyhow = "1.0.82"
async-bincode = "0.7.2"
bincode = "1.3.3"
clap = { version = "4.5.4", features = ["derive"] }
derive_more = { version = "1.0.0", features = ["display", "error"] }
dialoguer = "0.11.0"
env_logger = "0.11.3"
futures = "0.3.30"
futures-util = "0.3.30"
indoc = "2.0.5"
itertools = "0.12.1"
log = "0.4.21"
nix = { version = "0.28.0", features = ["fs", "user"] }
nix = { version = "0.28.0", features = ["fs", "process", "user"] }
prettytable = "0.10.0"
rand = "0.8.5"
ratatui = { version = "0.26.2", optional = true }
serde = "1.0.198"
serde_json = { version = "1.0.116", features = ["preserve_order"] }
sqlx = { version = "0.7.4", features = ["runtime-tokio", "mysql", "tls-rustls"] }
thiserror = "1.0.63"
tokio = { version = "1.37.0", features = ["rt", "macros"] }
tokio-util = "0.7.11"
tokio-serde = { version = "0.9.0", features = ["bincode"] }
tokio-stream = "0.1.15"
tokio-util = { version = "0.7.11", features = ["codec"] }
toml = "0.8.12"
uuid = { version = "1.10.0", features = ["v4"] }

View File

@ -1,3 +1,6 @@
mod common;
pub mod database_command;
pub mod mysql_admutils_compatibility;
pub mod user_command;
#[cfg(feature = "mysql-admutils-compatibility")]
pub mod mysql_admutils_compatibility;

20
src/cli/common.rs Normal file
View File

@ -0,0 +1,20 @@
use crate::core::protocol::Response;
pub fn erroneous_server_response(
response: Option<Result<Response, std::io::Error>>,
) -> anyhow::Result<()> {
match response {
Some(Ok(Response::Error(e))) => {
anyhow::bail!("Server returned error: {}", e);
}
Some(Err(e)) => {
anyhow::bail!(e);
}
Some(response) => {
anyhow::bail!("Unexpected response from server: {:?}", response);
}
None => {
anyhow::bail!("No response from server");
}
}
}

View File

@ -1,17 +1,29 @@
use anyhow::Context;
use clap::Parser;
use dialoguer::{Confirm, Editor};
use futures_util::{SinkExt, StreamExt};
use nix::unistd::{getuid, User};
use prettytable::{Cell, Row, Table};
use sqlx::{Connection, MySqlConnection};
use crate::core::{
common::{close_database_connection, get_current_unix_user, yn, CommandStatus},
database_operations::*,
database_privilege_operations::*,
user_operations::user_exists,
use crate::{
cli::common::erroneous_server_response,
core::{
common::yn,
database_privileges::{
db_priv_field_human_readable_name, diff_privileges, display_privilege_diffs,
generate_editor_content_from_privilege_data, parse_privilege_data_from_editor_content,
parse_privilege_table_cli_arg,
},
protocol::{
print_create_databases_output_status, print_drop_databases_output_status,
print_modify_database_privileges_output_status, ClientToServerMessageStream, Request,
Response,
},
},
server::sql::database_privilege_operations::{DatabasePrivilegeRow, DATABASE_PRIVILEGE_FIELDS},
};
#[derive(Parser)]
#[derive(Parser, Debug, Clone)]
// #[command(next_help_heading = Some(DATABASE_COMMAND_HEADER))]
pub enum DatabaseCommand {
/// Create one or more databases
@ -86,28 +98,28 @@ pub enum DatabaseCommand {
EditDbPrivs(DatabaseEditPrivsArgs),
}
#[derive(Parser)]
#[derive(Parser, Debug, Clone)]
pub struct DatabaseCreateArgs {
/// The name of the database(s) to create.
#[arg(num_args = 1..)]
name: Vec<String>,
}
#[derive(Parser)]
#[derive(Parser, Debug, Clone)]
pub struct DatabaseDropArgs {
/// The name of the database(s) to drop.
#[arg(num_args = 1..)]
name: Vec<String>,
}
#[derive(Parser)]
#[derive(Parser, Debug, Clone)]
pub struct DatabaseListArgs {
/// Whether to output the information in JSON format.
#[arg(short, long)]
json: bool,
}
#[derive(Parser)]
#[derive(Parser, Debug, Clone)]
pub struct DatabaseShowPrivsArgs {
/// The name of the database(s) to show.
#[arg(num_args = 0..)]
@ -118,7 +130,7 @@ pub struct DatabaseShowPrivsArgs {
json: bool,
}
#[derive(Parser)]
#[derive(Parser, Debug, Clone)]
pub struct DatabaseEditPrivsArgs {
/// The name of the database to edit privileges for.
pub name: Option<String>,
@ -141,125 +153,143 @@ pub struct DatabaseEditPrivsArgs {
pub async fn handle_command(
command: DatabaseCommand,
mut connection: MySqlConnection,
) -> anyhow::Result<CommandStatus> {
let result = connection
.transaction(|txn| {
Box::pin(async move {
match command {
DatabaseCommand::CreateDb(args) => create_databases(args, txn).await,
DatabaseCommand::DropDb(args) => drop_databases(args, txn).await,
DatabaseCommand::ListDb(args) => list_databases(args, txn).await,
DatabaseCommand::ShowDbPrivs(args) => show_database_privileges(args, txn).await,
DatabaseCommand::EditDbPrivs(args) => edit_privileges(args, txn).await,
}
})
})
.await;
close_database_connection(connection).await;
result
server_connection: ClientToServerMessageStream,
) -> anyhow::Result<()> {
match command {
DatabaseCommand::CreateDb(args) => create_databases(args, server_connection).await,
DatabaseCommand::DropDb(args) => drop_databases(args, server_connection).await,
DatabaseCommand::ListDb(args) => list_databases(args, server_connection).await,
DatabaseCommand::ShowDbPrivs(args) => {
show_database_privileges(args, server_connection).await
}
DatabaseCommand::EditDbPrivs(args) => {
edit_database_privileges(args, server_connection).await
}
}
}
async fn create_databases(
args: DatabaseCreateArgs,
connection: &mut MySqlConnection,
) -> anyhow::Result<CommandStatus> {
mut server_connection: ClientToServerMessageStream,
) -> anyhow::Result<()> {
if args.name.is_empty() {
anyhow::bail!("No database names provided");
}
let mut result = CommandStatus::SuccessfullyModified;
let message = Request::CreateDatabases(args.name.clone());
server_connection.send(message).await?;
for name in args.name {
// TODO: This can be optimized by fetching all the database privileges in one query.
if let Err(e) = create_database(&name, connection).await {
eprintln!("Failed to create database '{}': {}", name, e);
eprintln!("Skipping...");
result = CommandStatus::PartiallySuccessfullyModified;
} else {
println!("Database '{}' created.", name);
}
}
let result = match server_connection.next().await {
Some(Ok(Response::CreateDatabases(result))) => result,
response => return erroneous_server_response(response),
};
Ok(result)
server_connection.send(Request::Exit).await?;
print_create_databases_output_status(&result);
Ok(())
}
async fn drop_databases(
args: DatabaseDropArgs,
connection: &mut MySqlConnection,
) -> anyhow::Result<CommandStatus> {
mut server_connection: ClientToServerMessageStream,
) -> anyhow::Result<()> {
if args.name.is_empty() {
anyhow::bail!("No database names provided");
}
let mut result = CommandStatus::SuccessfullyModified;
let message = Request::DropDatabases(args.name.clone());
server_connection.send(message).await?;
for name in args.name {
// TODO: This can be optimized by fetching all the database privileges in one query.
if let Err(e) = drop_database(&name, connection).await {
eprintln!("Failed to drop database '{}': {}", name, e);
eprintln!("Skipping...");
result = CommandStatus::PartiallySuccessfullyModified;
} else {
println!("Database '{}' dropped.", name);
}
}
let result = match server_connection.next().await {
Some(Ok(Response::DropDatabases(result))) => result,
response => return erroneous_server_response(response),
};
Ok(result)
server_connection.send(Request::Exit).await?;
print_drop_databases_output_status(&result);
Ok(())
}
async fn list_databases(
args: DatabaseListArgs,
connection: &mut MySqlConnection,
) -> anyhow::Result<CommandStatus> {
let databases = get_database_list(connection).await?;
mut server_connection: ClientToServerMessageStream,
) -> anyhow::Result<()> {
let message = Request::ListDatabases;
server_connection.send(message).await?;
let result = match server_connection.next().await {
Some(Ok(Response::ListAllDatabases(result))) => result,
response => return erroneous_server_response(response),
};
server_connection.send(Request::Exit).await?;
let database_list = match result {
Ok(list) => list,
Err(err) => {
return Err(anyhow::anyhow!(err.to_error_message()).context("Failed to list databases"))
}
};
if args.json {
println!("{}", serde_json::to_string_pretty(&databases)?);
return Ok(CommandStatus::NoModificationsIntended);
}
if databases.is_empty() {
println!("{}", serde_json::to_string_pretty(&database_list)?);
} else if database_list.is_empty() {
println!("No databases to show.");
} else {
for db in databases {
for db in database_list {
println!("{}", db);
}
}
Ok(CommandStatus::NoModificationsIntended)
Ok(())
}
async fn show_database_privileges(
args: DatabaseShowPrivsArgs,
connection: &mut MySqlConnection,
) -> anyhow::Result<CommandStatus> {
let database_users_to_show = if args.name.is_empty() {
get_all_database_privileges(connection).await?
mut server_connection: ClientToServerMessageStream,
) -> anyhow::Result<()> {
let message = if args.name.is_empty() {
Request::ListPrivileges(None)
} else {
// TODO: This can be optimized by fetching all the database privileges in one query.
let mut result = Vec::with_capacity(args.name.len());
for name in args.name {
match get_database_privileges(&name, connection).await {
Ok(db) => result.extend(db),
Err(e) => {
eprintln!("Failed to show database '{}': {}", name, e);
Request::ListPrivileges(Some(args.name.clone()))
};
server_connection.send(message).await?;
let privilege_data = match server_connection.next().await {
Some(Ok(Response::ListPrivileges(databases))) => databases
.into_iter()
.filter_map(|(database_name, result)| match result {
Ok(privileges) => Some(privileges),
Err(err) => {
eprintln!("{}", err.to_error_message(&database_name));
eprintln!("Skipping...");
println!();
None
}
})
.flatten()
.collect::<Vec<_>>(),
Some(Ok(Response::ListAllPrivileges(privilege_rows))) => match privilege_rows {
Ok(list) => list,
Err(err) => {
server_connection.send(Request::Exit).await?;
return Err(anyhow::anyhow!(err.to_error_message())
.context("Failed to list database privileges"));
}
}
result
},
response => return erroneous_server_response(response),
};
if args.json {
println!("{}", serde_json::to_string_pretty(&database_users_to_show)?);
return Ok(CommandStatus::NoModificationsIntended);
}
server_connection.send(Request::Exit).await?;
if database_users_to_show.is_empty() {
println!("No database users to show.");
if args.json {
println!("{}", serde_json::to_string_pretty(&privilege_data)?);
} else if privilege_data.is_empty() {
println!("No database privileges to show.");
} else {
let mut table = Table::new();
table.add_row(Row::new(
@ -270,7 +300,7 @@ async fn show_database_privileges(
.collect(),
));
for row in database_users_to_show {
for row in privilege_data {
table.add_row(row![
row.db,
row.user,
@ -290,17 +320,40 @@ async fn show_database_privileges(
table.printstd();
}
Ok(CommandStatus::NoModificationsIntended)
Ok(())
}
pub async fn edit_privileges(
pub async fn edit_database_privileges(
args: DatabaseEditPrivsArgs,
connection: &mut MySqlConnection,
) -> anyhow::Result<CommandStatus> {
let privilege_data = if let Some(name) = &args.name {
get_database_privileges(name, connection).await?
} else {
get_all_database_privileges(connection).await?
mut server_connection: ClientToServerMessageStream,
) -> anyhow::Result<()> {
let message = Request::ListPrivileges(args.name.clone().map(|name| vec![name]));
server_connection.send(message).await?;
let privilege_data = match server_connection.next().await {
Some(Ok(Response::ListPrivileges(databases))) => databases
.into_iter()
.filter_map(|(database_name, result)| match result {
Ok(privileges) => Some(privileges),
Err(err) => {
eprintln!("{}", err.to_error_message(&database_name));
eprintln!("Skipping...");
println!();
None
}
})
.flatten()
.collect::<Vec<_>>(),
Some(Ok(Response::ListAllPrivileges(privilege_rows))) => match privilege_rows {
Ok(list) => list,
Err(err) => {
server_connection.send(Request::Exit).await?;
return Err(anyhow::anyhow!(err.to_error_message())
.context("Failed to list database privileges"));
}
},
response => return erroneous_server_response(response),
};
// TODO: The data from args should not be absolute.
@ -316,22 +369,16 @@ pub async fn edit_privileges(
edit_privileges_with_editor(&privilege_data)?
};
for row in privileges_to_change.iter() {
if !user_exists(&row.user, connection).await? {
// TODO: allow user to return and correct their mistake
anyhow::bail!("User {} does not exist", row.user);
}
}
let diffs = diff_privileges(&privilege_data, &privileges_to_change);
if diffs.is_empty() {
println!("No changes to make.");
return Ok(CommandStatus::NoModificationsNeeded);
return Ok(());
}
println!("The following changes will be made:\n");
println!("{}", display_privilege_diffs(&diffs));
if !args.yes
&& !Confirm::new()
.with_prompt("Do you want to apply these changes?")
@ -339,15 +386,27 @@ pub async fn edit_privileges(
.show_default(true)
.interact()?
{
return Ok(CommandStatus::Cancelled);
server_connection.send(Request::Exit).await?;
return Ok(());
}
apply_privilege_diffs(diffs, connection).await?;
let message = Request::ModifyPrivileges(diffs);
server_connection.send(message).await?;
Ok(CommandStatus::SuccessfullyModified)
let result = match server_connection.next().await {
Some(Ok(Response::ModifyPrivileges(result))) => result,
response => return erroneous_server_response(response),
};
// TODO: allow user to return and correct their mistake
print_modify_database_privileges_output_status(&result);
server_connection.send(Request::Exit).await?;
Ok(())
}
pub fn parse_privilege_tables_from_args(
fn parse_privilege_tables_from_args(
args: &DatabaseEditPrivsArgs,
) -> anyhow::Result<Vec<DatabasePrivilegeRow>> {
debug_assert!(!args.privs.is_empty());
@ -371,20 +430,22 @@ pub fn parse_privilege_tables_from_args(
Ok(result)
}
pub fn edit_privileges_with_editor(
fn edit_privileges_with_editor(
privilege_data: &[DatabasePrivilegeRow],
) -> anyhow::Result<Vec<DatabasePrivilegeRow>> {
let unix_user = get_current_unix_user()?;
let unix_user = User::from_uid(getuid())
.context("Failed to look up your UNIX username")
.and_then(|u| u.ok_or(anyhow::anyhow!("Failed to look up your UNIX username")))?;
let editor_content =
generate_editor_content_from_privilege_data(privilege_data, &unix_user.name);
// TODO: handle errors better here
let result = Editor::new()
.extension("tsv")
.edit(&editor_content)?
.unwrap();
let result = Editor::new().extension("tsv").edit(&editor_content)?;
parse_privilege_data_from_editor_content(result)
.context("Could not parse privilege data from editor")
match result {
None => Ok(privilege_data.to_vec()),
Some(result) => parse_privilege_data_from_editor_content(result)
.context("Could not parse privilege data from editor"),
}
}

View File

@ -1,3 +1,4 @@
pub mod common;
mod error_messages;
pub mod mysql_dbadm;
pub mod mysql_useradm;

View File

@ -1,57 +1,4 @@
use crate::core::common::{
get_current_unix_user, validate_name_or_error, validate_ownership_or_error, DbOrUser,
};
/// In contrast to the new implementation which reports errors on any invalid name
/// for any reason, mysql-admutils would only log the error and skip that particular
/// name. This function replicates that behavior.
pub fn filter_db_or_user_names(
names: Vec<String>,
db_or_user: DbOrUser,
) -> anyhow::Result<Vec<String>> {
let unix_user = get_current_unix_user()?;
let argv0 = std::env::args().next().unwrap_or_else(|| match db_or_user {
DbOrUser::Database => "mysql-dbadm".to_string(),
DbOrUser::User => "mysql-useradm".to_string(),
});
let filtered_names = names
.into_iter()
// NOTE: The original implementation would only copy the first 32 characters
// of the argument into it's internal buffer. We replicate that behavior
// here.
.map(|name| name.chars().take(32).collect::<String>())
.filter(|name| {
if let Err(_err) = validate_ownership_or_error(name, &unix_user, db_or_user) {
println!(
"You are not in charge of mysql-{}: '{}'. Skipping.",
db_or_user.lowercased(),
name
);
return false;
}
true
})
.filter(|name| {
// NOTE: while this also checks for the length of the name,
// the name is already truncated to 32 characters. So
// if there is an error, it's guaranteed to be due to
// invalid characters.
if let Err(_err) = validate_name_or_error(name, db_or_user) {
println!(
concat!(
"{}: {} name '{}' contains invalid characters.\n",
"Only A-Z, a-z, 0-9, _ (underscore) and - (dash) permitted. Skipping.",
),
argv0,
db_or_user.capitalized(),
name
);
return false;
}
true
})
.collect();
Ok(filtered_names)
#[inline]
pub fn trim_to_32_chars(name: &str) -> String {
name.chars().take(32).collect()
}

View File

@ -0,0 +1,176 @@
use crate::core::protocol::{
CreateDatabaseError, CreateUserError, DbOrUser, DropDatabaseError, DropUserError,
GetDatabasesPrivilegeDataError, ListUsersError,
};
pub fn name_validation_error_to_error_message(name: &str, db_or_user: DbOrUser) -> String {
let argv0 = std::env::args().next().unwrap_or_else(|| match db_or_user {
DbOrUser::Database => "mysql-dbadm".to_string(),
DbOrUser::User => "mysql-useradm".to_string(),
});
format!(
concat!(
"{}: {} name '{}' contains invalid characters.\n",
"Only A-Z, a-z, 0-9, _ (underscore) and - (dash) permitted. Skipping.",
),
argv0,
db_or_user.capitalized(),
name,
)
}
pub fn owner_validation_error_message(name: &str, db_or_user: DbOrUser) -> String {
format!(
"You are not in charge of mysql-{}: '{}'. Skipping.",
db_or_user.lowercased(),
name
)
}
pub fn handle_create_user_error(error: CreateUserError, name: &str) {
let argv0 = std::env::args()
.next()
.unwrap_or_else(|| "mysql-useradm".to_string());
match error {
CreateUserError::SanitizationError(_) => {
eprintln!(
"{}",
name_validation_error_to_error_message(name, DbOrUser::User)
);
}
CreateUserError::OwnershipError(_) => {
eprintln!("{}", owner_validation_error_message(name, DbOrUser::User));
}
CreateUserError::MySqlError(_) | CreateUserError::UserAlreadyExists => {
eprintln!("{}: Failed to create user '{}'.", argv0, name);
}
}
}
pub fn handle_drop_user_error(error: DropUserError, name: &str) {
let argv0 = std::env::args()
.next()
.unwrap_or_else(|| "mysql-useradm".to_string());
match error {
DropUserError::SanitizationError(_) => {
eprintln!(
"{}",
name_validation_error_to_error_message(name, DbOrUser::User)
);
}
DropUserError::OwnershipError(_) => {
eprintln!("{}", owner_validation_error_message(name, DbOrUser::User));
}
DropUserError::MySqlError(_) | DropUserError::UserDoesNotExist => {
eprintln!("{}: Failed to delete user '{}'.", argv0, name);
}
}
}
pub fn handle_list_users_error(error: ListUsersError, name: &str) {
let argv0 = std::env::args()
.next()
.unwrap_or_else(|| "mysql-useradm".to_string());
match error {
ListUsersError::SanitizationError(_) => {
eprintln!(
"{}",
name_validation_error_to_error_message(name, DbOrUser::User)
);
}
ListUsersError::OwnershipError(_) => {
eprintln!("{}", owner_validation_error_message(name, DbOrUser::User));
}
ListUsersError::UserDoesNotExist => {
eprintln!(
"{}: User '{}' does not exist. You must create it first.",
argv0, name,
);
}
ListUsersError::MySqlError(_) => {
eprintln!("{}: Failed to look up password for user '{}'", argv0, name);
}
}
}
// ----------------------------------------------------------------------------
pub fn handle_create_database_error(error: CreateDatabaseError, name: &str) {
let argv0 = std::env::args()
.next()
.unwrap_or_else(|| "mysql-dbadm".to_string());
match error {
CreateDatabaseError::SanitizationError(_) => {
eprintln!(
"{}",
name_validation_error_to_error_message(name, DbOrUser::Database)
);
}
CreateDatabaseError::OwnershipError(_) => {
eprintln!(
"{}",
owner_validation_error_message(name, DbOrUser::Database)
);
}
CreateDatabaseError::MySqlError(_) => {
eprintln!("{}: Cannot create database '{}'.", argv0, name);
}
CreateDatabaseError::DatabaseAlreadyExists => {
eprintln!("{}: Database '{}' already exists.", argv0, name);
}
}
}
pub fn handle_drop_database_error(error: DropDatabaseError, name: &str) {
let argv0 = std::env::args()
.next()
.unwrap_or_else(|| "mysql-dbadm".to_string());
match error {
DropDatabaseError::SanitizationError(_) => {
eprintln!(
"{}",
name_validation_error_to_error_message(name, DbOrUser::Database)
);
}
DropDatabaseError::OwnershipError(_) => {
eprintln!(
"{}",
owner_validation_error_message(name, DbOrUser::Database)
);
}
DropDatabaseError::MySqlError(_) => {
eprintln!("{}: Cannot drop database '{}'.", argv0, name);
}
DropDatabaseError::DatabaseDoesNotExist => {
eprintln!("{}: Database '{}' doesn't exist.", argv0, name);
}
}
}
pub fn format_show_database_error_message(
error: GetDatabasesPrivilegeDataError,
name: &str,
) -> String {
let argv0 = std::env::args()
.next()
.unwrap_or_else(|| "mysql-dbadm".to_string());
match error {
GetDatabasesPrivilegeDataError::SanitizationError(_) => {
name_validation_error_to_error_message(name, DbOrUser::Database)
}
GetDatabasesPrivilegeDataError::OwnershipError(_) => {
owner_validation_error_message(name, DbOrUser::Database)
}
GetDatabasesPrivilegeDataError::MySqlError(err) => {
format!(
"{}: Failed to look up privileges for database '{}': {}",
argv0, name, err
)
}
GetDatabasesPrivilegeDataError::DatabaseDoesNotExist => {
format!("{}: Database '{}' doesn't exist.", argv0, name)
}
}
}

View File

@ -1,14 +1,29 @@
use clap::Parser;
use sqlx::MySqlConnection;
use futures_util::{SinkExt, StreamExt};
use std::os::unix::net::UnixStream as StdUnixStream;
use std::path::PathBuf;
use tokio::net::UnixStream as TokioUnixStream;
use crate::{
cli::{database_command, mysql_admutils_compatibility::common::filter_db_or_user_names},
core::{
common::{yn, DbOrUser},
config::{create_mysql_connection_from_config, get_config, GlobalConfigArgs},
database_operations::{create_database, drop_database, get_database_list},
database_privilege_operations,
cli::{
common::erroneous_server_response,
database_command,
mysql_admutils_compatibility::{
common::trim_to_32_chars,
error_messages::{
format_show_database_error_message, handle_create_database_error,
handle_drop_database_error,
},
},
},
core::{
bootstrap::bootstrap_server_connection_and_drop_privileges,
protocol::{
create_client_to_server_message_stream, ClientToServerMessageStream,
GetDatabasesPrivilegeDataError, Request, Response,
},
},
server::sql::database_privilege_operations::DatabasePrivilegeRow,
};
const HELP_DB_PERM: &str = r#"
@ -39,8 +54,25 @@ pub struct Args {
#[command(subcommand)]
pub command: Option<Command>,
#[command(flatten)]
config_overrides: GlobalConfigArgs,
/// Path to the socket of the server, if it already exists.
#[arg(
short,
long,
value_name = "PATH",
global = true,
hide_short_help = true
)]
server_socket_path: Option<PathBuf>,
/// Config file to use for the server.
#[arg(
short,
long,
value_name = "PATH",
global = true,
hide_short_help = true
)]
config: Option<PathBuf>,
/// Print help for the 'editperm' subcommand.
#[arg(long, global = true)]
@ -76,7 +108,7 @@ pub enum Command {
/// to make changes to the permission table.
/// Run 'mysql-dbadm --help-editperm' for more
/// information.
EditPerm(EditPermArgs),
Editperm(EditPermArgs),
}
#[derive(Parser)]
@ -106,7 +138,7 @@ pub struct EditPermArgs {
pub database: String,
}
pub async fn main() -> anyhow::Result<()> {
pub fn main() -> anyhow::Result<()> {
let args: Args = Args::parse();
if args.help_editperm {
@ -114,6 +146,9 @@ pub async fn main() -> anyhow::Result<()> {
return Ok(());
}
let server_connection =
bootstrap_server_connection_and_drop_privileges(args.server_socket_path, args.config)?;
let command = match args.command {
Some(command) => command,
None => {
@ -125,64 +160,164 @@ pub async fn main() -> anyhow::Result<()> {
}
};
let config = get_config(args.config_overrides)?;
let mut connection = create_mysql_connection_from_config(config.mysql).await?;
tokio_run_command(command, server_connection)?;
match command {
Command::Create(args) => {
let filtered_names = filter_db_or_user_names(args.name, DbOrUser::Database)?;
for name in filtered_names {
create_database(&name, &mut connection).await?;
println!("Database {} created.", name);
}
}
Command::Drop(args) => {
let filtered_names = filter_db_or_user_names(args.name, DbOrUser::Database)?;
for name in filtered_names {
drop_database(&name, &mut connection).await?;
println!("Database {} dropped.", name);
}
}
Command::Show(args) => {
let names = if args.name.is_empty() {
get_database_list(&mut connection).await?
} else {
filter_db_or_user_names(args.name, DbOrUser::Database)?
};
Ok(())
}
for name in names {
show_db(&name, &mut connection).await?;
}
}
Command::EditPerm(args) => {
// TODO: This does not accurately replicate the behavior of the old implementation.
// Hopefully, not many people rely on this in an automated fashion, as it
// is made to be interactive in nature. However, we should still try to
// replicate the old behavior as closely as possible.
let edit_privileges_args = database_command::DatabaseEditPrivsArgs {
name: Some(args.database),
privs: vec![],
json: false,
editor: None,
yes: false,
};
fn tokio_run_command(command: Command, server_connection: StdUnixStream) -> anyhow::Result<()> {
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap()
.block_on(async {
let tokio_socket = TokioUnixStream::from_std(server_connection)?;
let message_stream = create_client_to_server_message_stream(tokio_socket);
match command {
Command::Create(args) => create_databases(args, message_stream).await,
Command::Drop(args) => drop_databases(args, message_stream).await,
Command::Show(args) => show_databases(args, message_stream).await,
Command::Editperm(args) => {
let edit_privileges_args = database_command::DatabaseEditPrivsArgs {
name: Some(args.database),
privs: vec![],
json: false,
// TODO: use this to mimic the old editor-finding logic
editor: None,
yes: false,
};
database_command::edit_privileges(edit_privileges_args, &mut connection).await?;
database_command::edit_database_privileges(edit_privileges_args, message_stream)
.await
}
}
})
}
async fn create_databases(
args: CreateArgs,
mut server_connection: ClientToServerMessageStream,
) -> anyhow::Result<()> {
let database_names = args
.name
.iter()
.map(|name| trim_to_32_chars(name))
.collect();
let message = Request::CreateDatabases(database_names);
server_connection.send(message).await?;
let result = match server_connection.next().await {
Some(Ok(Response::CreateDatabases(result))) => result,
response => return erroneous_server_response(response),
};
server_connection.send(Request::Exit).await?;
for (name, result) in result {
match result {
Ok(()) => println!("Database {} created.", name),
Err(err) => handle_create_database_error(err, &name),
}
}
Ok(())
}
async fn show_db(name: &str, connection: &mut MySqlConnection) -> anyhow::Result<()> {
async fn drop_databases(
args: DatabaseDropArgs,
mut server_connection: ClientToServerMessageStream,
) -> anyhow::Result<()> {
let database_names = args
.name
.iter()
.map(|name| trim_to_32_chars(name))
.collect();
let message = Request::DropDatabases(database_names);
server_connection.send(message).await?;
let result = match server_connection.next().await {
Some(Ok(Response::DropDatabases(result))) => result,
response => return erroneous_server_response(response),
};
server_connection.send(Request::Exit).await?;
for (name, result) in result {
match result {
Ok(()) => println!("Database {} dropped.", name),
Err(err) => handle_drop_database_error(err, &name),
}
}
Ok(())
}
async fn show_databases(
args: DatabaseShowArgs,
mut server_connection: ClientToServerMessageStream,
) -> anyhow::Result<()> {
let database_names: Vec<String> = args
.name
.iter()
.map(|name| trim_to_32_chars(name))
.collect();
let message = if database_names.is_empty() {
let message = Request::ListDatabases;
server_connection.send(message).await?;
let response = server_connection.next().await;
let databases = match response {
Some(Ok(Response::ListAllDatabases(databases))) => databases.unwrap_or(vec![]),
response => return erroneous_server_response(response),
};
Request::ListPrivileges(Some(databases))
} else {
Request::ListPrivileges(Some(database_names))
};
server_connection.send(message).await?;
let response = server_connection.next().await;
server_connection.send(Request::Exit).await?;
// NOTE: mysql-dbadm show has a quirk where valid database names
// for non-existent databases will report with no users.
// This function should *not* check for db existence, only
// validate the names.
let privileges = database_privilege_operations::get_database_privileges(name, connection)
.await
.unwrap_or(vec![]);
let results: Vec<Result<(String, Vec<DatabasePrivilegeRow>), String>> = match response {
Some(Ok(Response::ListPrivileges(result))) => result
.into_iter()
.map(|(name, rows)| match rows.map(|rows| (name.clone(), rows)) {
Ok(rows) => Ok(rows),
Err(GetDatabasesPrivilegeDataError::DatabaseDoesNotExist) => Ok((name, vec![])),
Err(err) => Err(format_show_database_error_message(err, &name)),
})
.collect(),
response => return erroneous_server_response(response),
};
results.into_iter().try_for_each(|result| match result {
Ok((name, rows)) => print_db_privs(&name, rows),
Err(err) => {
eprintln!("{}", err);
Ok(())
}
})?;
Ok(())
}
#[inline]
fn yn(value: bool) -> &'static str {
if value {
"Y"
} else {
"N"
}
}
fn print_db_privs(name: &str, rows: Vec<DatabasePrivilegeRow>) -> anyhow::Result<()> {
println!(
concat!(
"Database '{}':\n",
@ -191,10 +326,10 @@ async fn show_db(name: &str, connection: &mut MySqlConnection) -> anyhow::Result
),
name,
);
if privileges.is_empty() {
if rows.is_empty() {
println!("# (no permissions currently granted to any users)");
} else {
for privilege in privileges {
for privilege in rows {
println!(
" {:<16} {:<7} {:<7} {:<7} {:<7} {:<7} {:<7} {:<7} {:<7} {:<7} {:<7} {}",
privilege.user,

View File

@ -1,13 +1,28 @@
use clap::Parser;
use sqlx::MySqlConnection;
use futures_util::{SinkExt, StreamExt};
use std::path::PathBuf;
use std::os::unix::net::UnixStream as StdUnixStream;
use tokio::net::UnixStream as TokioUnixStream;
use crate::{
cli::{mysql_admutils_compatibility::common::filter_db_or_user_names, user_command},
core::{
common::{close_database_connection, get_current_unix_user, DbOrUser},
config::{create_mysql_connection_from_config, get_config, GlobalConfigArgs},
user_operations::*,
cli::{
common::erroneous_server_response,
mysql_admutils_compatibility::{
common::trim_to_32_chars,
error_messages::{
handle_create_user_error, handle_drop_user_error, handle_list_users_error,
},
},
user_command::read_password_from_stdin_with_double_check,
},
core::{
bootstrap::bootstrap_server_connection_and_drop_privileges,
protocol::{
create_client_to_server_message_stream, ClientToServerMessageStream, Request, Response,
},
},
server::sql::user_operations::DatabaseUser,
};
#[derive(Parser)]
@ -15,8 +30,25 @@ pub struct Args {
#[command(subcommand)]
pub command: Option<Command>,
#[command(flatten)]
config_overrides: GlobalConfigArgs,
/// Path to the socket of the server, if it already exists.
#[arg(
short,
long,
value_name = "PATH",
global = true,
hide_short_help = true
)]
server_socket_path: Option<PathBuf>,
/// Config file to use for the server.
#[arg(
short,
long,
value_name = "PATH",
global = true,
hide_short_help = true
)]
config: Option<PathBuf>,
}
/// Create, delete or change password for the USER(s),
@ -69,7 +101,7 @@ pub struct ShowArgs {
name: Vec<String>,
}
pub async fn main() -> anyhow::Result<()> {
pub fn main() -> anyhow::Result<()> {
let args: Args = Args::parse();
let command = match args.command {
@ -85,78 +117,185 @@ pub async fn main() -> anyhow::Result<()> {
}
};
let config = get_config(args.config_overrides)?;
let mut connection = create_mysql_connection_from_config(config.mysql).await?;
let server_connection =
bootstrap_server_connection_and_drop_privileges(args.server_socket_path, args.config)?;
match command {
Command::Create(args) => {
let filtered_names = filter_db_or_user_names(args.name, DbOrUser::User)?;
for name in filtered_names {
create_database_user(&name, &mut connection).await?;
}
}
Command::Delete(args) => {
let filtered_names = filter_db_or_user_names(args.name, DbOrUser::User)?;
for name in filtered_names {
delete_database_user(&name, &mut connection).await?;
}
}
Command::Passwd(args) => passwd(args, &mut connection).await?,
Command::Show(args) => show(args, &mut connection).await?,
}
close_database_connection(connection).await;
tokio_run_command(command, server_connection)?;
Ok(())
}
async fn passwd(args: PasswdArgs, connection: &mut MySqlConnection) -> anyhow::Result<()> {
let filtered_names = filter_db_or_user_names(args.name, DbOrUser::User)?;
// NOTE: this gets doubly checked during the call to `set_password_for_database_user`.
// This is moving the check before asking the user for the password,
// to avoid having them figure out that the user does not exist after they
// have entered the password twice.
let mut better_filtered_names = Vec::with_capacity(filtered_names.len());
for name in filtered_names.into_iter() {
if !user_exists(&name, connection).await? {
println!(
"{}: User '{}' does not exist. You must create it first.",
std::env::args()
.next()
.unwrap_or("mysql-useradm".to_string()),
name,
);
} else {
better_filtered_names.push(name);
}
}
for name in better_filtered_names {
let password = user_command::read_password_from_stdin_with_double_check(&name)?;
set_password_for_database_user(&name, &password, connection).await?;
println!("Password updated for user '{}'.", name);
}
Ok(())
fn tokio_run_command(command: Command, server_connection: StdUnixStream) -> anyhow::Result<()> {
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap()
.block_on(async {
let tokio_socket = TokioUnixStream::from_std(server_connection)?;
let message_stream = create_client_to_server_message_stream(tokio_socket);
match command {
Command::Create(args) => create_user(args, message_stream).await,
Command::Delete(args) => drop_users(args, message_stream).await,
Command::Passwd(args) => passwd_users(args, message_stream).await,
Command::Show(args) => show_users(args, message_stream).await,
}
})
}
async fn show(args: ShowArgs, connection: &mut MySqlConnection) -> anyhow::Result<()> {
let users = if args.name.is_empty() {
let unix_user = get_current_unix_user()?;
get_all_database_users_for_unix_user(&unix_user, connection).await?
} else {
let filtered_usernames = filter_db_or_user_names(args.name, DbOrUser::User)?;
let mut result = Vec::with_capacity(filtered_usernames.len());
for username in filtered_usernames.iter() {
// TODO: fetch all users in one query
if let Some(user) = get_database_user_for_user(username, connection).await? {
result.push(user)
}
}
result
async fn create_user(
args: CreateArgs,
mut server_connection: ClientToServerMessageStream,
) -> anyhow::Result<()> {
let usernames = args
.name
.iter()
.map(|name| trim_to_32_chars(name))
.collect();
let message = Request::CreateUsers(usernames);
server_connection.send(message).await?;
let result = match server_connection.next().await {
Some(Ok(Response::CreateUsers(result))) => result,
response => return erroneous_server_response(response),
};
server_connection.send(Request::Exit).await?;
for (name, result) in result {
match result {
Ok(()) => println!("User '{}' created.", name),
Err(err) => handle_create_user_error(err, &name),
}
}
Ok(())
}
async fn drop_users(
args: DeleteArgs,
mut server_connection: ClientToServerMessageStream,
) -> anyhow::Result<()> {
let usernames = args
.name
.iter()
.map(|name| trim_to_32_chars(name))
.collect();
let message = Request::DropUsers(usernames);
server_connection.send(message).await?;
let result = match server_connection.next().await {
Some(Ok(Response::DropUsers(result))) => result,
response => return erroneous_server_response(response),
};
server_connection.send(Request::Exit).await?;
for (name, result) in result {
match result {
Ok(()) => println!("User '{}' deleted.", name),
Err(err) => handle_drop_user_error(err, &name),
}
}
Ok(())
}
async fn passwd_users(
args: PasswdArgs,
mut server_connection: ClientToServerMessageStream,
) -> anyhow::Result<()> {
let usernames = args
.name
.iter()
.map(|name| trim_to_32_chars(name))
.collect();
let message = Request::ListUsers(Some(usernames));
server_connection.send(message).await?;
let response = match server_connection.next().await {
Some(Ok(Response::ListUsers(result))) => result,
response => return erroneous_server_response(response),
};
let argv0 = std::env::args()
.next()
.unwrap_or("mysql-useradm".to_string());
let users = response
.into_iter()
.filter_map(|(name, result)| match result {
Ok(user) => Some(user),
Err(err) => {
handle_list_users_error(err, &name);
None
}
})
.collect::<Vec<_>>();
for user in users {
let password = read_password_from_stdin_with_double_check(&user.user)?;
let message = Request::PasswdUser(user.user.clone(), password);
server_connection.send(message).await?;
match server_connection.next().await {
Some(Ok(Response::PasswdUser(result))) => match result {
Ok(()) => println!("Password updated for user '{}'.", user.user),
Err(_) => eprintln!(
"{}: Failed to update password for user '{}'.",
argv0, user.user,
),
},
response => return erroneous_server_response(response),
}
}
server_connection.send(Request::Exit).await?;
Ok(())
}
async fn show_users(
args: ShowArgs,
mut server_connection: ClientToServerMessageStream,
) -> anyhow::Result<()> {
let usernames: Vec<_> = args
.name
.iter()
.map(|name| trim_to_32_chars(name))
.collect();
let message = if usernames.is_empty() {
Request::ListUsers(None)
} else {
Request::ListUsers(Some(usernames))
};
server_connection.send(message).await?;
let users: Vec<DatabaseUser> = match server_connection.next().await {
Some(Ok(Response::ListAllUsers(result))) => match result {
Ok(users) => users,
Err(err) => {
println!("Failed to list users: {:?}", err);
return Ok(());
}
},
Some(Ok(Response::ListUsers(result))) => result
.into_iter()
.filter_map(|(name, result)| match result {
Ok(user) => Some(user),
Err(err) => {
handle_list_users_error(err, &name);
None
}
})
.collect(),
response => return erroneous_server_response(response),
};
server_connection.send(Request::Exit).await?;
for user in users {
if user.has_password {
println!("User '{}': password set.", user.user);

View File

@ -1,27 +1,24 @@
use std::collections::BTreeMap;
use std::vec;
use anyhow::Context;
use clap::Parser;
use dialoguer::{Confirm, Password};
use prettytable::Table;
use serde_json::json;
use sqlx::{Connection, MySqlConnection};
use futures_util::{SinkExt, StreamExt};
use crate::core::{
common::{close_database_connection, get_current_unix_user, CommandStatus},
database_operations::*,
user_operations::*,
use crate::core::protocol::{
print_create_users_output_status, print_drop_users_output_status,
print_lock_users_output_status, print_set_password_output_status,
print_unlock_users_output_status, ClientToServerMessageStream, Request, Response,
};
#[derive(Parser)]
use super::common::erroneous_server_response;
#[derive(Parser, Debug, Clone)]
pub struct UserArgs {
#[clap(subcommand)]
subcmd: UserCommand,
}
#[allow(clippy::enum_variant_names)]
#[derive(Parser)]
#[derive(Parser, Debug, Clone)]
pub enum UserCommand {
/// Create one or more users
#[command()]
@ -50,7 +47,7 @@ pub enum UserCommand {
UnlockUser(UserUnlockArgs),
}
#[derive(Parser)]
#[derive(Parser, Debug, Clone)]
pub struct UserCreateArgs {
#[arg(num_args = 1..)]
username: Vec<String>,
@ -60,13 +57,13 @@ pub struct UserCreateArgs {
no_password: bool,
}
#[derive(Parser)]
#[derive(Parser, Debug, Clone)]
pub struct UserDeleteArgs {
#[arg(num_args = 1..)]
username: Vec<String>,
}
#[derive(Parser)]
#[derive(Parser, Debug, Clone)]
pub struct UserPasswdArgs {
username: String,
@ -74,7 +71,7 @@ pub struct UserPasswdArgs {
password_file: Option<String>,
}
#[derive(Parser)]
#[derive(Parser, Debug, Clone)]
pub struct UserShowArgs {
#[arg(num_args = 0..)]
username: Vec<String>,
@ -83,13 +80,13 @@ pub struct UserShowArgs {
json: bool,
}
#[derive(Parser)]
#[derive(Parser, Debug, Clone)]
pub struct UserLockArgs {
#[arg(num_args = 1..)]
username: Vec<String>,
}
#[derive(Parser)]
#[derive(Parser, Debug, Clone)]
pub struct UserUnlockArgs {
#[arg(num_args = 1..)]
username: Vec<String>,
@ -97,48 +94,45 @@ pub struct UserUnlockArgs {
pub async fn handle_command(
command: UserCommand,
mut connection: MySqlConnection,
) -> anyhow::Result<CommandStatus> {
let result = connection
.transaction(|txn| {
Box::pin(async move {
match command {
UserCommand::CreateUser(args) => create_users(args, txn).await,
UserCommand::DropUser(args) => drop_users(args, txn).await,
UserCommand::PasswdUser(args) => change_password_for_user(args, txn).await,
UserCommand::ShowUser(args) => show_users(args, txn).await,
UserCommand::LockUser(args) => lock_users(args, txn).await,
UserCommand::UnlockUser(args) => unlock_users(args, txn).await,
}
})
})
.await;
close_database_connection(connection).await;
result
server_connection: ClientToServerMessageStream,
) -> anyhow::Result<()> {
match command {
UserCommand::CreateUser(args) => create_users(args, server_connection).await,
UserCommand::DropUser(args) => drop_users(args, server_connection).await,
UserCommand::PasswdUser(args) => passwd_user(args, server_connection).await,
UserCommand::ShowUser(args) => show_users(args, server_connection).await,
UserCommand::LockUser(args) => lock_users(args, server_connection).await,
UserCommand::UnlockUser(args) => unlock_users(args, server_connection).await,
}
}
async fn create_users(
args: UserCreateArgs,
connection: &mut MySqlConnection,
) -> anyhow::Result<CommandStatus> {
mut server_connection: ClientToServerMessageStream,
) -> anyhow::Result<()> {
if args.username.is_empty() {
anyhow::bail!("No usernames provided");
}
let mut result = CommandStatus::SuccessfullyModified;
let message = Request::CreateUsers(args.username.clone());
if let Err(err) = server_connection.send(message).await {
server_connection.close().await.ok();
anyhow::bail!(anyhow::Error::from(err).context("Failed to communicate with server"));
}
for username in args.username {
if let Err(e) = create_database_user(&username, connection).await {
eprintln!("{}", e);
eprintln!("Skipping...\n");
result = CommandStatus::PartiallySuccessfullyModified;
continue;
} else {
println!("User '{}' created.", username);
}
let result = match server_connection.next().await {
Some(Ok(Response::CreateUsers(result))) => result,
response => return erroneous_server_response(response),
};
print_create_users_output_status(&result);
let successfully_created_users = result
.iter()
.filter_map(|(username, result)| result.as_ref().ok().map(|_| username))
.collect::<Vec<_>>();
for username in successfully_created_users {
if !args.no_password
&& Confirm::new()
.with_prompt(format!(
@ -147,41 +141,55 @@ async fn create_users(
))
.interact()?
{
change_password_for_user(
UserPasswdArgs {
username,
password_file: None,
},
connection,
)
.await?;
let password = read_password_from_stdin_with_double_check(username)?;
let message = Request::PasswdUser(username.clone(), password);
if let Err(err) = server_connection.send(message).await {
server_connection.close().await.ok();
anyhow::bail!(err);
}
match server_connection.next().await {
Some(Ok(Response::PasswdUser(result))) => {
print_set_password_output_status(&result, username)
}
response => return erroneous_server_response(response),
}
println!();
}
println!();
}
Ok(result)
server_connection.send(Request::Exit).await?;
Ok(())
}
async fn drop_users(
args: UserDeleteArgs,
connection: &mut MySqlConnection,
) -> anyhow::Result<CommandStatus> {
mut server_connection: ClientToServerMessageStream,
) -> anyhow::Result<()> {
if args.username.is_empty() {
anyhow::bail!("No usernames provided");
}
let mut result = CommandStatus::SuccessfullyModified;
let message = Request::DropUsers(args.username.clone());
for username in args.username {
if let Err(e) = delete_database_user(&username, connection).await {
eprintln!("{}", e);
eprintln!("Skipping...");
result = CommandStatus::PartiallySuccessfullyModified;
} else {
println!("User '{}' dropped.", username);
}
if let Err(err) = server_connection.send(message).await {
server_connection.close().await.ok();
anyhow::bail!(err);
}
Ok(result)
let result = match server_connection.next().await {
Some(Ok(Response::DropUsers(result))) => result,
response => return erroneous_server_response(response),
};
server_connection.send(Request::Exit).await?;
print_drop_users_output_status(&result);
Ok(())
}
pub fn read_password_from_stdin_with_double_check(username: &str) -> anyhow::Result<String> {
@ -195,15 +203,10 @@ pub fn read_password_from_stdin_with_double_check(username: &str) -> anyhow::Res
.map_err(Into::into)
}
async fn change_password_for_user(
async fn passwd_user(
args: UserPasswdArgs,
connection: &mut MySqlConnection,
) -> anyhow::Result<CommandStatus> {
// NOTE: although this also is checked in `set_password_for_database_user`, we check it here
// to provide a more natural order of error messages.
let unix_user = get_current_unix_user()?;
validate_user_name(&args.username, &unix_user)?;
mut server_connection: ClientToServerMessageStream,
) -> anyhow::Result<()> {
let password = if let Some(password_file) = args.password_file {
std::fs::read_to_string(password_file)
.context("Failed to read password file")?
@ -213,129 +216,146 @@ async fn change_password_for_user(
read_password_from_stdin_with_double_check(&args.username)?
};
set_password_for_database_user(&args.username, &password, connection).await?;
let message = Request::PasswdUser(args.username.clone(), password);
Ok(CommandStatus::SuccessfullyModified)
if let Err(err) = server_connection.send(message).await {
server_connection.close().await.ok();
anyhow::bail!(err);
}
let result = match server_connection.next().await {
Some(Ok(Response::PasswdUser(result))) => result,
response => return erroneous_server_response(response),
};
server_connection.send(Request::Exit).await?;
print_set_password_output_status(&result, &args.username);
Ok(())
}
async fn show_users(
args: UserShowArgs,
connection: &mut MySqlConnection,
) -> anyhow::Result<CommandStatus> {
let unix_user = get_current_unix_user()?;
let users = if args.username.is_empty() {
get_all_database_users_for_unix_user(&unix_user, connection).await?
mut server_connection: ClientToServerMessageStream,
) -> anyhow::Result<()> {
let message = if args.username.is_empty() {
Request::ListUsers(None)
} else {
let mut result = vec![];
for username in args.username {
if let Err(e) = validate_user_name(&username, &unix_user) {
eprintln!("{}", e);
eprintln!("Skipping...");
continue;
}
let user = get_database_user_for_user(&username, connection).await?;
if let Some(user) = user {
result.push(user);
} else {
eprintln!("User not found: {}", username);
}
}
result
Request::ListUsers(Some(args.username.clone()))
};
let mut user_databases: BTreeMap<String, Vec<String>> = BTreeMap::new();
for user in users.iter() {
user_databases.insert(
user.user.clone(),
get_databases_where_user_has_privileges(&user.user, connection).await?,
);
if let Err(err) = server_connection.send(message).await {
server_connection.close().await.ok();
anyhow::bail!(err);
}
if args.json {
let users_json = users
let users = match server_connection.next().await {
Some(Ok(Response::ListUsers(users))) => users
.into_iter()
.map(|user| {
json!({
"user": user.user,
"has_password": user.has_password,
"is_locked": user.is_locked,
"databases": user_databases.get(&user.user).unwrap_or(&vec![]),
})
.filter_map(|(username, result)| match result {
Ok(user) => Some(user),
Err(err) => {
eprintln!("{}", err.to_error_message(&username));
eprintln!("Skipping...");
None
}
})
.collect::<serde_json::Value>();
.collect::<Vec<_>>(),
Some(Ok(Response::ListAllUsers(users))) => match users {
Ok(users) => users,
Err(err) => {
server_connection.send(Request::Exit).await?;
return Err(
anyhow::anyhow!(err.to_error_message()).context("Failed to list all users")
);
}
},
response => return erroneous_server_response(response),
};
server_connection.send(Request::Exit).await?;
// TODO: print databases where user has privileges
if args.json {
println!(
"{}",
serde_json::to_string_pretty(&users_json)
.context("Failed to serialize users to JSON")?
serde_json::to_string_pretty(&users).context("Failed to serialize users to JSON")?
);
} else if users.is_empty() {
println!("No users found.");
println!("No users to show.");
} else {
let mut table = Table::new();
let mut table = prettytable::Table::new();
table.add_row(row![
"User",
"Password is set",
"Locked",
"Databases where user has privileges"
// "Databases where user has privileges"
]);
for user in users {
table.add_row(row![
user.user,
user.has_password,
user.is_locked,
user_databases.get(&user.user).unwrap_or(&vec![]).join("\n")
// user.databases.join("\n")
]);
}
table.printstd();
}
Ok(CommandStatus::NoModificationsIntended)
Ok(())
}
async fn lock_users(
args: UserLockArgs,
connection: &mut MySqlConnection,
) -> anyhow::Result<CommandStatus> {
mut server_connection: ClientToServerMessageStream,
) -> anyhow::Result<()> {
if args.username.is_empty() {
anyhow::bail!("No usernames provided");
}
let mut result = CommandStatus::SuccessfullyModified;
let message = Request::LockUsers(args.username.clone());
for username in args.username {
if let Err(e) = lock_database_user(&username, connection).await {
eprintln!("{}", e);
eprintln!("Skipping...");
result = CommandStatus::PartiallySuccessfullyModified;
} else {
println!("User '{}' locked.", username);
}
if let Err(err) = server_connection.send(message).await {
server_connection.close().await.ok();
anyhow::bail!(err);
}
Ok(result)
let result = match server_connection.next().await {
Some(Ok(Response::LockUsers(result))) => result,
response => return erroneous_server_response(response),
};
server_connection.send(Request::Exit).await?;
print_lock_users_output_status(&result);
Ok(())
}
async fn unlock_users(
args: UserUnlockArgs,
connection: &mut MySqlConnection,
) -> anyhow::Result<CommandStatus> {
mut server_connection: ClientToServerMessageStream,
) -> anyhow::Result<()> {
if args.username.is_empty() {
anyhow::bail!("No usernames provided");
}
let mut result = CommandStatus::SuccessfullyModified;
let message = Request::UnlockUsers(args.username.clone());
for username in args.username {
if let Err(e) = unlock_database_user(&username, connection).await {
eprintln!("{}", e);
eprintln!("Skipping...");
result = CommandStatus::PartiallySuccessfullyModified;
} else {
println!("User '{}' unlocked.", username);
}
if let Err(err) = server_connection.send(message).await {
server_connection.close().await.ok();
anyhow::bail!(err);
}
Ok(result)
let result = match server_connection.next().await {
Some(Ok(Response::UnlockUsers(result))) => result,
response => return erroneous_server_response(response),
};
server_connection.send(Request::Exit).await?;
print_unlock_users_output_status(&result);
Ok(())
}

View File

@ -1,5 +1,4 @@
pub mod bootstrap;
pub mod common;
pub mod config;
pub mod database_operations;
pub mod database_privilege_operations;
pub mod user_operations;
pub mod database_privileges;
pub mod protocol;

177
src/core/bootstrap.rs Normal file
View File

@ -0,0 +1,177 @@
use std::{fs, path::PathBuf};
use anyhow::Context;
use nix::libc::{exit, EXIT_SUCCESS};
use std::os::unix::net::UnixStream as StdUnixStream;
use tokio::net::UnixStream as TokioUnixStream;
use crate::{
core::{
bootstrap::authenticated_unix_socket::client_authenticate,
common::{UnixUser, DEFAULT_CONFIG_PATH, DEFAULT_SOCKET_PATH},
},
server::{config::read_config_form_path, server_loop::handle_requests_for_single_session},
};
pub mod authenticated_unix_socket;
// TODO: this function is security critical, it should be integration tested
// in isolation.
/// Drop privileges to the real user and group of the process.
/// If the process is not running with elevated privileges, this function
/// is a no-op.
pub fn drop_privs() -> anyhow::Result<()> {
log::debug!("Dropping privileges");
let real_uid = nix::unistd::getuid();
let real_gid = nix::unistd::getgid();
nix::unistd::setuid(real_uid).context("Failed to drop privileges")?;
nix::unistd::setgid(real_gid).context("Failed to drop privileges")?;
debug_assert_eq!(nix::unistd::getuid(), real_uid);
debug_assert_eq!(nix::unistd::getgid(), real_gid);
log::debug!("Privileges dropped successfully");
Ok(())
}
/// This function is used to bootstrap the connection to the server.
/// This can happen in two ways:
/// 1. If a socket path is provided, or exists in the default location,
/// the function will connect to the socket and authenticate with the
/// server to ensure that the server knows the uid of the client.
/// 2. If a config path is provided, or exists in the default location,
/// and the config is readable, the function will assume it is either
/// setuid or setgid, and will fork a child process to run the server
/// with the provided config. The server will exit silently by itself
/// when it is done, and this function will only return for the client
/// with the socket for the server.
/// If neither of these options are available, the function will fail.
pub fn bootstrap_server_connection_and_drop_privileges(
server_socket_path: Option<PathBuf>,
config_path: Option<PathBuf>,
) -> anyhow::Result<StdUnixStream> {
if server_socket_path.is_some() && config_path.is_some() {
anyhow::bail!("Cannot provide both a socket path and a config path");
}
log::debug!("Starting the server connection bootstrap process");
let (socket, do_authenticate) = bootstrap_server_connection(server_socket_path, config_path)?;
drop_privs()?;
let result: anyhow::Result<StdUnixStream> = if do_authenticate {
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap()
.block_on(async {
let mut socket = TokioUnixStream::from_std(socket)?;
client_authenticate(&mut socket, None).await?;
Ok(socket.into_std()?)
})
} else {
Ok(socket)
};
result
}
/// Inner function for [`bootstrap_server_connection_and_drop_privileges`].
/// See that function for more information.
fn bootstrap_server_connection(
socket_path: Option<PathBuf>,
config_path: Option<PathBuf>,
) -> anyhow::Result<(StdUnixStream, bool)> {
// TODO: ensure this is both readable and writable
if let Some(socket_path) = socket_path {
log::debug!("Connecting to socket at {:?}", socket_path);
return match StdUnixStream::connect(socket_path) {
Ok(socket) => Ok((socket, true)),
Err(e) => match e.kind() {
std::io::ErrorKind::NotFound => Err(anyhow::anyhow!("Socket not found")),
std::io::ErrorKind::PermissionDenied => Err(anyhow::anyhow!("Permission denied")),
_ => Err(anyhow::anyhow!("Failed to connect to socket: {}", e)),
},
};
}
if let Some(config_path) = config_path {
// ensure config exists and is readable
if fs::metadata(&config_path).is_err() {
return Err(anyhow::anyhow!("Config file not found or not readable"));
}
log::debug!("Starting server with config at {:?}", config_path);
return invoke_server_with_config(config_path).map(|socket| (socket, false));
}
if fs::metadata(DEFAULT_SOCKET_PATH).is_ok() {
return match StdUnixStream::connect(DEFAULT_SOCKET_PATH) {
Ok(socket) => Ok((socket, true)),
Err(e) => match e.kind() {
std::io::ErrorKind::NotFound => Err(anyhow::anyhow!("Socket not found")),
std::io::ErrorKind::PermissionDenied => Err(anyhow::anyhow!("Permission denied")),
_ => Err(anyhow::anyhow!("Failed to connect to socket: {}", e)),
},
};
}
let config_path = PathBuf::from(DEFAULT_CONFIG_PATH);
if fs::metadata(&config_path).is_ok() {
return invoke_server_with_config(config_path).map(|socket| (socket, false));
}
anyhow::bail!("No socket path or config path provided, and no default socket or config found");
}
// TODO: we should somehow ensure that the forked process is killed on completion,
// just in case the client does not behave properly.
/// Fork a child process to run the server with the provided config.
/// The server will exit silently by itself when it is done, and this function
/// will only return for the client with the socket for the server.
fn invoke_server_with_config(config_path: PathBuf) -> anyhow::Result<StdUnixStream> {
let (server_socket, client_socket) = StdUnixStream::pair()?;
let unix_user = UnixUser::from_uid(nix::unistd::getuid().as_raw())?;
match (unsafe { nix::unistd::fork() }).context("Failed to fork")? {
nix::unistd::ForkResult::Parent { child } => {
log::debug!("Forked child process with PID {}", child);
Ok(client_socket)
}
nix::unistd::ForkResult::Child => {
log::debug!("Running server in child process");
match run_forked_server(config_path, server_socket, unix_user) {
Err(e) => Err(e),
Ok(_) => unreachable!(),
}
}
}
}
/// Run the server in the forked child process.
/// This function will not return, but will exit the process with a success code.
fn run_forked_server(
config_path: PathBuf,
server_socket: StdUnixStream,
unix_user: UnixUser,
) -> anyhow::Result<()> {
let config = read_config_form_path(Some(config_path))?;
let result: anyhow::Result<()> = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap()
.block_on(async {
let socket = TokioUnixStream::from_std(server_socket)?;
handle_requests_for_single_session(socket, &unix_user, &config).await?;
Ok(())
});
result?;
unsafe {
exit(EXIT_SUCCESS);
}
}

View File

@ -30,10 +30,13 @@
//! Also note that it is essential that the client does not send any sensitive information
//! over it's authentication socket, since it is readable by any user on the system.
// TODO: rewrite this so that it can be used with a normal std::os::unix::net::UnixStream
use std::os::unix::io::AsRawFd;
use std::path::PathBuf;
use std::path::{Path, PathBuf};
use async_bincode::{tokio::AsyncBincodeStream, AsyncDestination};
use derive_more::derive::{Display, Error};
use futures::{SinkExt, StreamExt};
use nix::{sys::stat, unistd::Uid};
use rand::distributions::Alphanumeric;
@ -52,7 +55,7 @@ pub enum ClientRequest {
Cancel,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Display, Error)]
pub enum ServerResponse {
Authenticated,
ChallengeDidNotMatch,
@ -61,7 +64,7 @@ pub enum ServerResponse {
// TODO: wrap more data into the errors
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
#[derive(Debug, Display, PartialEq, Serialize, Deserialize, Clone, Error)]
pub enum ServerError {
InvalidRequest,
UnableToReadPermissionsFromAuthSocket,
@ -72,7 +75,7 @@ pub enum ServerError {
InvalidChallenge,
}
#[derive(Debug, PartialEq)]
#[derive(Debug, PartialEq, Display, Error)]
pub enum ClientError {
UnableToConnectToServer,
UnableToOpenAuthSocket,
@ -80,13 +83,12 @@ pub enum ClientError {
AuthSocketClosedEarly,
UnableToCloseAuthSocket,
AuthenticationError,
InvalidServerResponse(ServerResponse),
UnableToParseServerResponse,
NoServerResponse,
ServerError(ServerError),
}
async fn create_auth_socket(socket_addr: &str) -> Result<UnixListener, ClientError> {
async fn create_auth_socket(socket_addr: &PathBuf) -> Result<UnixListener, ClientError> {
let auth_socket =
UnixListener::bind(socket_addr).map_err(|_err| ClientError::UnableToOpenAuthSocket)?;
@ -109,11 +111,13 @@ type AuthStream<'a> = AsyncBincodeStream<&'a mut UnixStream, u64, u64, AsyncDest
// TODO: add timeout
// TODO: respect $XDG_RUNTIME_DIR and $TMPDIR
const AUTH_SOCKET_NAME: &str = "mysqladm-rs-cli-auth.sock";
pub async fn client_authenticate(
normal_socket: &mut UnixStream,
#[cfg(not(test))] auth_socket_dir: Option<PathBuf>,
#[cfg(test)] auth_socket_file: Option<PathBuf>,
auth_socket_dir: Option<PathBuf>,
) -> Result<(), ClientError> {
let random_prefix: String = rand::thread_rng()
.sample_iter(&Alphanumeric)
@ -123,32 +127,16 @@ pub async fn client_authenticate(
let socket_name = format!("{}-{}", random_prefix, AUTH_SOCKET_NAME);
#[cfg(not(test))]
let auth_socket_address = match auth_socket_dir {
Some(dir) => dir.join(socket_name).to_str().unwrap().to_string(),
None => std::env::temp_dir()
.join(socket_name)
.to_str()
.unwrap()
.to_string(),
};
#[cfg(test)]
let auth_socket_address = match auth_socket_file {
Some(file) => file.to_str().unwrap().to_string(),
None => std::env::temp_dir()
.join(socket_name)
.to_str()
.unwrap()
.to_string(),
};
let auth_socket_address = auth_socket_dir
.unwrap_or(std::env::temp_dir())
.join(socket_name);
client_authenticate_with_auth_socket_address(normal_socket, &auth_socket_address).await
}
async fn client_authenticate_with_auth_socket_address(
normal_socket: &mut UnixStream,
auth_socket_address: &str,
auth_socket_address: &PathBuf,
) -> Result<(), ClientError> {
let auth_socket = create_auth_socket(auth_socket_address).await?;
@ -164,7 +152,7 @@ async fn client_authenticate_with_auth_socket_address(
async fn client_authenticate_with_auth_socket(
normal_socket: &mut UnixStream,
auth_socket: UnixListener,
auth_socket_address: &str,
auth_socket_address: &Path,
) -> Result<(), ClientError> {
let challenge = rand::random::<u64>();
let uid = nix::unistd::getuid();
@ -199,7 +187,10 @@ async fn client_authenticate_with_auth_socket(
let client_hello = ClientRequest::Initialize {
uid: uid.into(),
challenge,
auth_socket: auth_socket_address.to_string(),
auth_socket: auth_socket_address
.to_str()
.ok_or(ClientError::UnableToConfigureAuthSocket)?
.to_owned(),
};
normal_socket
@ -239,9 +230,13 @@ macro_rules! report_server_error_and_return {
}};
}
async fn server_authenticate(
pub async fn server_authenticate(normal_socket: &mut UnixStream) -> Result<Uid, ServerError> {
_server_authenticate(normal_socket, None).await
}
pub async fn _server_authenticate(
normal_socket: &mut UnixStream,
#[cfg(test)] unix_user_uid: Option<u32>,
unix_user_uid: Option<u32>,
) -> Result<Uid, ServerError> {
let mut normal_socket: ServerToClientStream =
AsyncBincodeStream::from(normal_socket).for_async();
@ -256,22 +251,15 @@ async fn server_authenticate(
_ => report_server_error_and_return!(normal_socket, ServerError::InvalidRequest),
};
#[cfg(test)]
let auth_socket_uid = match unix_user_uid {
Some(uid) => uid,
None => report_server_error_and_return!(
normal_socket,
ServerError::UnableToReadPermissionsFromAuthSocket
),
};
#[cfg(not(test))]
let auth_socket_uid = match stat::stat(auth_socket.as_str()) {
Ok(stat) => stat.st_uid,
Err(_err) => report_server_error_and_return!(
normal_socket,
ServerError::UnableToReadPermissionsFromAuthSocket
),
None => match stat::stat(auth_socket.as_str()) {
Ok(stat) => stat.st_uid,
Err(_err) => report_server_error_and_return!(
normal_socket,
ServerError::UnableToReadPermissionsFromAuthSocket
),
},
};
if uid != auth_socket_uid {
@ -324,10 +312,7 @@ mod test {
let client_handle =
tokio::spawn(async move { client_authenticate(&mut client, None).await });
let server_handle = tokio::spawn(async move {
let uid = nix::unistd::getuid().into();
server_authenticate(&mut server, Some(uid)).await
});
let server_handle = tokio::spawn(async move { server_authenticate(&mut server).await });
client_handle.await.unwrap().unwrap();
server_handle.await.unwrap().unwrap();
@ -340,15 +325,12 @@ mod test {
let client_handle = tokio::spawn(async move {
client_authenticate_with_auth_socket_address(
&mut client,
"/tmp/test_auth_socket_does_not_exist.sock",
&PathBuf::from("/tmp/test_auth_socket_does_not_exist.sock"),
)
.await
});
let server_handle = tokio::spawn(async move {
let uid = nix::unistd::getuid().into();
server_authenticate(&mut server, Some(uid)).await
});
let server_handle = tokio::spawn(async move { server_authenticate(&mut server).await });
client_handle.await.unwrap().unwrap();
server_handle.await.unwrap().unwrap();
@ -365,7 +347,7 @@ mod test {
let server_handle = tokio::spawn(async move {
let uid: u32 = nix::unistd::getuid().into();
let err = server_authenticate(&mut server, Some(uid + 1)).await;
let err = _server_authenticate(&mut server, Some(uid + 1)).await;
assert_eq!(err, Err(ServerError::UidMismatch));
});
@ -379,13 +361,19 @@ mod test {
let socket_path = std::env::temp_dir().join("socket_to_snoop.sock");
let socket_path_clone = socket_path.clone();
let client_handle =
tokio::spawn(
async move { client_authenticate(&mut client, Some(socket_path_clone)).await },
);
let client_handle = tokio::spawn(async move {
client_authenticate_with_auth_socket_address(&mut client, &socket_path_clone).await
});
while !socket_path.exists() {
sleep(std::time::Duration::from_millis(10)).await;
for i in 0..100 {
if socket_path.exists() {
break;
}
sleep(Duration::from_millis(10)).await;
if i == 99 {
panic!("Socket not created after 1 second, assuming test failure");
}
}
let mut snooper = UnixStream::connect(socket_path.clone()).await.unwrap();
@ -409,10 +397,7 @@ mod test {
sleep(Duration::from_millis(10)).await;
let server_handle = tokio::spawn(async move {
let uid: u32 = nix::unistd::getuid().into();
server_authenticate(&mut server, Some(uid)).await
});
let server_handle = tokio::spawn(async move { server_authenticate(&mut server).await });
client_handle.await.unwrap().unwrap();
server_handle.await.unwrap().unwrap();

View File

@ -1,56 +1,32 @@
use anyhow::Context;
use indoc::indoc;
use itertools::Itertools;
use nix::unistd::{getuid, Group, User};
use sqlx::{Connection, MySqlConnection};
use nix::unistd::{Group as LibcGroup, User as LibcUser};
#[cfg(not(target_os = "macos"))]
use std::ffi::CString;
/// Report the result status of a command.
/// This is used to display a status message to the user.
pub enum CommandStatus {
/// The command was successful,
/// and made modification to the database.
SuccessfullyModified,
pub const DEFAULT_CONFIG_PATH: &str = "/etc/mysqladm/config.toml";
pub const DEFAULT_SOCKET_PATH: &str = "/run/mysqladm/mysqladm.sock";
/// The command was mostly successful,
/// and modifications have been made to the database.
/// However, some of the requested modifications failed.
PartiallySuccessfullyModified,
/// The command was successful,
/// but no modifications were needed.
NoModificationsNeeded,
/// The command was successful,
/// and made no modification to the database.
NoModificationsIntended,
/// The command was cancelled, either through a dialog or a signal.
/// No modifications have been made to the database.
Cancelled,
pub struct UnixUser {
pub username: String,
pub groups: Vec<String>,
}
pub fn get_current_unix_user() -> anyhow::Result<User> {
User::from_uid(getuid())
.context("Failed to look up your UNIX username")
.and_then(|u| u.ok_or(anyhow::anyhow!("Failed to look up your UNIX username")))
}
// TODO: these functions are somewhat critical, and should have integration tests
#[cfg(target_os = "macos")]
pub fn get_unix_groups(_user: &User) -> anyhow::Result<Vec<Group>> {
fn get_unix_groups(_user: &LibcUser) -> anyhow::Result<Vec<LibcGroup>> {
// Return an empty list on macOS since there is no `getgrouplist` function
Ok(vec![])
}
#[cfg(not(target_os = "macos"))]
pub fn get_unix_groups(user: &User) -> anyhow::Result<Vec<Group>> {
fn get_unix_groups(user: &LibcUser) -> anyhow::Result<Vec<LibcGroup>> {
let user_cstr =
CString::new(user.name.as_bytes()).context("Failed to convert username to CStr")?;
let groups = nix::unistd::getgrouplist(&user_cstr, user.gid)?
.iter()
.filter_map(|gid| match Group::from_gid(*gid) {
.filter_map(|gid| match LibcGroup::from_gid(*gid) {
Ok(Some(group)) => Some(group),
Ok(None) => None,
Err(e) => {
@ -62,211 +38,32 @@ pub fn get_unix_groups(user: &User) -> anyhow::Result<Vec<Group>> {
None
}
})
.collect::<Vec<Group>>();
.collect::<Vec<LibcGroup>>();
Ok(groups)
}
/// This function creates a regex that matches items (users, databases)
/// that belong to the user or any of the user's groups.
pub fn create_user_group_matching_regex(user: &User) -> String {
let groups = get_unix_groups(user).unwrap_or_default();
impl UnixUser {
pub fn from_uid(uid: u32) -> anyhow::Result<Self> {
let libc_uid = nix::unistd::Uid::from_raw(uid);
let libc_user = LibcUser::from_uid(libc_uid)
.context("Failed to look up your UNIX username")?
.ok_or(anyhow::anyhow!("Failed to look up your UNIX username"))?;
if groups.is_empty() {
format!("{}(_.+)?", user.name)
} else {
format!(
"({}|{})(_.+)?",
user.name,
groups
.iter()
.map(|g| g.name.as_str())
.collect::<Vec<_>>()
.join("|")
)
}
}
let groups = get_unix_groups(&libc_user)?;
/// This enum is used to differentiate between database and user operations.
/// Their output are very similar, but there are slight differences in the words used.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum DbOrUser {
Database,
User,
}
impl DbOrUser {
pub fn lowercased(&self) -> String {
match self {
DbOrUser::Database => "database".to_string(),
DbOrUser::User => "user".to_string(),
}
Ok(UnixUser {
username: libc_user.name,
groups: groups.iter().map(|g| g.name.clone()).collect(),
})
}
pub fn capitalized(&self) -> String {
match self {
DbOrUser::Database => "Database".to_string(),
DbOrUser::User => "User".to_string(),
}
pub fn from_enviroment() -> anyhow::Result<Self> {
let libc_uid = nix::unistd::getuid();
UnixUser::from_uid(libc_uid.as_raw())
}
}
#[derive(Debug, PartialEq, Eq)]
pub enum NameValidationResult {
Valid,
EmptyString,
InvalidCharacters,
TooLong,
}
pub fn validate_name(name: &str) -> NameValidationResult {
if name.is_empty() {
NameValidationResult::EmptyString
} else if name.len() > 64 {
NameValidationResult::TooLong
} else if !name
.chars()
.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
{
NameValidationResult::InvalidCharacters
} else {
NameValidationResult::Valid
}
}
pub fn validate_name_or_error(name: &str, db_or_user: DbOrUser) -> anyhow::Result<()> {
match validate_name(name) {
NameValidationResult::Valid => Ok(()),
NameValidationResult::EmptyString => {
anyhow::bail!("{} name cannot be empty.", db_or_user.capitalized())
}
NameValidationResult::TooLong => anyhow::bail!(
"{} is too long. Maximum length is 64 characters.",
db_or_user.capitalized()
),
NameValidationResult::InvalidCharacters => anyhow::bail!(
indoc! {r#"
Invalid characters in {} name: '{}'
Only A-Z, a-z, 0-9, _ (underscore) and - (dash) are permitted.
"#},
db_or_user.lowercased(),
name
),
}
}
#[derive(Debug, PartialEq, Eq)]
pub enum OwnerValidationResult {
// The name is valid and matches one of the given prefixes
Match,
// The name is valid, but none of the given prefixes matched the name
NoMatch,
// The name is empty, which is invalid
StringEmpty,
// The name is in the format "_<postfix>", which is invalid
MissingPrefix,
// The name is in the format "<prefix>_", which is invalid
MissingPostfix,
}
/// Core logic for validating the ownership of a database name.
/// This function checks if the given name matches any of the given prefixes.
/// These prefixes will in most cases be the user's unix username and any
/// unix groups the user is a member of.
pub fn validate_ownership_by_prefixes(name: &str, prefixes: &[String]) -> OwnerValidationResult {
if name.is_empty() {
return OwnerValidationResult::StringEmpty;
}
if name.starts_with('_') {
return OwnerValidationResult::MissingPrefix;
}
let (prefix, _) = match name.split_once('_') {
Some(pair) => pair,
None => return OwnerValidationResult::MissingPostfix,
};
if prefixes.iter().any(|g| g == prefix) {
OwnerValidationResult::Match
} else {
OwnerValidationResult::NoMatch
}
}
/// Validate the ownership of a database name or database user name.
/// This function takes the name of a database or user and a unix user,
/// for which it fetches the user's groups. It then checks if the name
/// is prefixed with the user's username or any of the user's groups.
pub fn validate_ownership_or_error<'a>(
name: &'a str,
user: &User,
db_or_user: DbOrUser,
) -> anyhow::Result<&'a str> {
let user_groups = get_unix_groups(user)?;
let prefixes = std::iter::once(user.name.clone())
.chain(user_groups.iter().map(|g| g.name.clone()))
.collect::<Vec<String>>();
match validate_ownership_by_prefixes(name, &prefixes) {
OwnerValidationResult::Match => Ok(name),
OwnerValidationResult::NoMatch => {
anyhow::bail!(
indoc! {r#"
Invalid {} name prefix: '{}' does not match your username or any of your groups.
Are you sure you are allowed to create {} names with this prefix?
Allowed prefixes:
- {}
{}
"#},
db_or_user.lowercased(),
name,
db_or_user.lowercased(),
user.name,
user_groups
.iter()
.filter(|g| g.name != user.name)
.map(|g| format!(" - {}", g.name))
.sorted()
.join("\n"),
);
}
_ => anyhow::bail!(
"'{}' is not a valid {} name.",
name,
db_or_user.lowercased()
),
}
}
/// Gracefully close a MySQL connection.
pub async fn close_database_connection(connection: MySqlConnection) {
if let Err(e) = connection
.close()
.await
.context("Failed to close connection properly")
{
eprintln!("{}", e);
eprintln!("Ignoring...");
}
}
#[inline]
pub fn quote_literal(s: &str) -> String {
format!("'{}'", s.replace('\'', r"\'"))
}
#[inline]
pub fn quote_identifier(s: &str) -> String {
format!("`{}`", s.replace('`', r"\`"))
}
#[inline]
pub(crate) fn yn(b: bool) -> &'static str {
if b {
@ -303,94 +100,4 @@ mod test {
assert_eq!(rev_yn("n"), Some(false));
assert_eq!(rev_yn("X"), None);
}
#[test]
fn test_quote_literal() {
let payload = "' OR 1=1 --";
assert_eq!(quote_literal(payload), r#"'\' OR 1=1 --'"#);
}
#[test]
fn test_quote_identifier() {
let payload = "` OR 1=1 --";
assert_eq!(quote_identifier(payload), r#"`\` OR 1=1 --`"#);
}
#[test]
fn test_validate_name() {
assert_eq!(validate_name(""), NameValidationResult::EmptyString);
assert_eq!(
validate_name("abcdefghijklmnopqrstuvwxyz"),
NameValidationResult::Valid
);
assert_eq!(
validate_name("ABCDEFGHIJKLMNOPQRSTUVWXYZ"),
NameValidationResult::Valid
);
assert_eq!(validate_name("0123456789_-"), NameValidationResult::Valid);
for c in "\n\t\r !@#$%^&*()+=[]{}|;:,.<>?/".chars() {
assert_eq!(
validate_name(&c.to_string()),
NameValidationResult::InvalidCharacters
);
}
assert_eq!(validate_name(&"a".repeat(64)), NameValidationResult::Valid);
assert_eq!(
validate_name(&"a".repeat(65)),
NameValidationResult::TooLong
);
}
#[test]
fn test_validate_owner_by_prefixes() {
let prefixes = vec!["user".to_string(), "group".to_string()];
assert_eq!(
validate_ownership_by_prefixes("", &prefixes),
OwnerValidationResult::StringEmpty
);
assert_eq!(
validate_ownership_by_prefixes("user", &prefixes),
OwnerValidationResult::MissingPostfix
);
assert_eq!(
validate_ownership_by_prefixes("something", &prefixes),
OwnerValidationResult::MissingPostfix
);
assert_eq!(
validate_ownership_by_prefixes("user-testdb", &prefixes),
OwnerValidationResult::MissingPostfix
);
assert_eq!(
validate_ownership_by_prefixes("_testdb", &prefixes),
OwnerValidationResult::MissingPrefix
);
assert_eq!(
validate_ownership_by_prefixes("user_testdb", &prefixes),
OwnerValidationResult::Match
);
assert_eq!(
validate_ownership_by_prefixes("group_testdb", &prefixes),
OwnerValidationResult::Match
);
assert_eq!(
validate_ownership_by_prefixes("group_test_db", &prefixes),
OwnerValidationResult::Match
);
assert_eq!(
validate_ownership_by_prefixes("group_test-db", &prefixes),
OwnerValidationResult::Match
);
assert_eq!(
validate_ownership_by_prefixes("nonexistent_testdb", &prefixes),
OwnerValidationResult::NoMatch
);
}
}

View File

@ -1,120 +0,0 @@
use anyhow::Context;
use indoc::formatdoc;
use itertools::Itertools;
use nix::unistd::User;
use sqlx::{prelude::*, MySqlConnection};
use crate::core::{
common::{
create_user_group_matching_regex, get_current_unix_user, quote_identifier,
validate_name_or_error, validate_ownership_or_error, DbOrUser,
},
database_privilege_operations::DATABASE_PRIVILEGE_FIELDS,
};
pub async fn create_database(name: &str, connection: &mut MySqlConnection) -> anyhow::Result<()> {
let user = get_current_unix_user()?;
validate_database_name(name, &user)?;
// NOTE: see the note about SQL injections in `validate_owner_of_database_name`
sqlx::query(&format!("CREATE DATABASE {}", quote_identifier(name)))
.execute(connection)
.await
.map_err(|e| {
if e.to_string().contains("database exists") {
anyhow::anyhow!("Database '{}' already exists", name)
} else {
e.into()
}
})?;
Ok(())
}
pub async fn drop_database(name: &str, connection: &mut MySqlConnection) -> anyhow::Result<()> {
let user = get_current_unix_user()?;
validate_database_name(name, &user)?;
// NOTE: see the note about SQL injections in `validate_owner_of_database_name`
sqlx::query(&format!("DROP DATABASE {}", quote_identifier(name)))
.execute(connection)
.await
.map_err(|e| {
if e.to_string().contains("doesn't exist") {
anyhow::anyhow!("Database '{}' does not exist", name)
} else {
e.into()
}
})?;
Ok(())
}
pub async fn get_database_list(connection: &mut MySqlConnection) -> anyhow::Result<Vec<String>> {
let unix_user = get_current_unix_user()?;
let databases: Vec<String> = sqlx::query(
r#"
SELECT `SCHEMA_NAME` AS `database`
FROM `information_schema`.`SCHEMATA`
WHERE `SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys')
AND `SCHEMA_NAME` REGEXP ?
"#,
)
.bind(create_user_group_matching_regex(&unix_user))
.fetch_all(connection)
.await
.and_then(|row| {
row.into_iter()
.map(|row| row.try_get::<String, _>("database"))
.collect::<Result<_, _>>()
})
.context(format!(
"Failed to get databases for user '{}'",
unix_user.name
))?;
Ok(databases)
}
pub async fn get_databases_where_user_has_privileges(
username: &str,
connection: &mut MySqlConnection,
) -> anyhow::Result<Vec<String>> {
let result = sqlx::query(
formatdoc!(
r#"
SELECT `db` AS `database`
FROM `db`
WHERE `user` = ?
AND ({})
"#,
DATABASE_PRIVILEGE_FIELDS
.iter()
.map(|field| format!("`{}` = 'Y'", field))
.join(" OR "),
)
.as_str(),
)
.bind(username)
.fetch_all(connection)
.await?
.into_iter()
.map(|databases| databases.try_get::<String, _>("database").unwrap())
.collect();
Ok(result)
}
/// NOTE: It is very critical that this function validates the database name
/// properly. MySQL does not seem to allow for prepared statements, binding
/// the database name as a parameter to the query. This means that we have
/// to validate the database name ourselves to prevent SQL injection.
pub fn validate_database_name(name: &str, user: &User) -> anyhow::Result<()> {
validate_name_or_error(name, DbOrUser::Database)
.context(format!("Invalid database name: '{}'", name))?;
validate_ownership_or_error(name, user, DbOrUser::Database)
.context(format!("Invalid database name: '{}'", name))?;
Ok(())
}

View File

@ -1,52 +1,16 @@
//! Database privilege operations
//!
//! This module contains functions for querying, modifying,
//! displaying and comparing database privileges.
//!
//! A lot of the complexity comes from two core components:
//!
//! - The privilege editor that needs to be able to print
//! an editable table of privileges and reparse the content
//! after the user has made manual changes.
//!
//! - The comparison functionality that tells the user what
//! changes will be made when applying a set of changes
//! to the list of database privileges.
use std::collections::{BTreeSet, HashMap};
use anyhow::{anyhow, Context};
use indoc::indoc;
use itertools::Itertools;
use prettytable::Table;
use serde::{Deserialize, Serialize};
use sqlx::{mysql::MySqlRow, prelude::*, MySqlConnection};
use crate::core::{
common::{
create_user_group_matching_regex, get_current_unix_user, quote_identifier, rev_yn, yn,
},
database_operations::validate_database_name,
use std::{
cmp::max,
collections::{BTreeSet, HashMap},
};
/// This is the list of fields that are used to fetch the db + user + privileges
/// from the `db` table in the database. If you need to add or remove privilege
/// fields, this is a good place to start.
pub const DATABASE_PRIVILEGE_FIELDS: [&str; 13] = [
"db",
"user",
"select_priv",
"insert_priv",
"update_priv",
"delete_priv",
"create_priv",
"drop_priv",
"alter_priv",
"index_priv",
"create_tmp_table_priv",
"lock_tables_priv",
"references_priv",
];
use super::common::{rev_yn, yn};
use crate::server::sql::database_privilege_operations::{
DatabasePrivilegeRow, DATABASE_PRIVILEGE_FIELDS,
};
pub fn db_priv_field_human_readable_name(name: &str) -> String {
match name {
@ -67,162 +31,24 @@ pub fn db_priv_field_human_readable_name(name: &str) -> String {
}
}
/// This struct represents the set of privileges for a single user on a single database.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
pub struct DatabasePrivilegeRow {
pub db: String,
pub user: String,
pub select_priv: bool,
pub insert_priv: bool,
pub update_priv: bool,
pub delete_priv: bool,
pub create_priv: bool,
pub drop_priv: bool,
pub alter_priv: bool,
pub index_priv: bool,
pub create_tmp_table_priv: bool,
pub lock_tables_priv: bool,
pub references_priv: bool,
}
pub fn diff(row1: &DatabasePrivilegeRow, row2: &DatabasePrivilegeRow) -> DatabasePrivilegeRowDiff {
debug_assert!(row1.db == row2.db && row1.user == row2.user);
impl DatabasePrivilegeRow {
pub fn empty(db: &str, user: &str) -> Self {
Self {
db: db.to_owned(),
user: user.to_owned(),
select_priv: false,
insert_priv: false,
update_priv: false,
delete_priv: false,
create_priv: false,
drop_priv: false,
alter_priv: false,
index_priv: false,
create_tmp_table_priv: false,
lock_tables_priv: false,
references_priv: false,
}
DatabasePrivilegeRowDiff {
db: row1.db.clone(),
user: row1.user.clone(),
diff: DATABASE_PRIVILEGE_FIELDS
.into_iter()
.skip(2)
.filter_map(|field| {
DatabasePrivilegeChange::new(
row1.get_privilege_by_name(field),
row2.get_privilege_by_name(field),
field,
)
})
.collect(),
}
pub fn get_privilege_by_name(&self, name: &str) -> bool {
match name {
"select_priv" => self.select_priv,
"insert_priv" => self.insert_priv,
"update_priv" => self.update_priv,
"delete_priv" => self.delete_priv,
"create_priv" => self.create_priv,
"drop_priv" => self.drop_priv,
"alter_priv" => self.alter_priv,
"index_priv" => self.index_priv,
"create_tmp_table_priv" => self.create_tmp_table_priv,
"lock_tables_priv" => self.lock_tables_priv,
"references_priv" => self.references_priv,
_ => false,
}
}
pub fn diff(&self, other: &DatabasePrivilegeRow) -> DatabasePrivilegeRowDiff {
debug_assert!(self.db == other.db && self.user == other.user);
DatabasePrivilegeRowDiff {
db: self.db.clone(),
user: self.user.clone(),
diff: DATABASE_PRIVILEGE_FIELDS
.into_iter()
.skip(2)
.filter_map(|field| {
DatabasePrivilegeChange::new(
self.get_privilege_by_name(field),
other.get_privilege_by_name(field),
field,
)
})
.collect(),
}
}
}
#[inline]
fn get_mysql_row_priv_field(row: &MySqlRow, position: usize) -> Result<bool, sqlx::Error> {
let field = DATABASE_PRIVILEGE_FIELDS[position];
let value = row.try_get(position)?;
match rev_yn(value) {
Some(val) => Ok(val),
_ => {
log::warn!(r#"Invalid value for privilege "{}": '{}'"#, field, value);
Ok(false)
}
}
}
impl FromRow<'_, MySqlRow> for DatabasePrivilegeRow {
fn from_row(row: &MySqlRow) -> Result<Self, sqlx::Error> {
Ok(Self {
db: row.try_get("db")?,
user: row.try_get("user")?,
select_priv: get_mysql_row_priv_field(row, 2)?,
insert_priv: get_mysql_row_priv_field(row, 3)?,
update_priv: get_mysql_row_priv_field(row, 4)?,
delete_priv: get_mysql_row_priv_field(row, 5)?,
create_priv: get_mysql_row_priv_field(row, 6)?,
drop_priv: get_mysql_row_priv_field(row, 7)?,
alter_priv: get_mysql_row_priv_field(row, 8)?,
index_priv: get_mysql_row_priv_field(row, 9)?,
create_tmp_table_priv: get_mysql_row_priv_field(row, 10)?,
lock_tables_priv: get_mysql_row_priv_field(row, 11)?,
references_priv: get_mysql_row_priv_field(row, 12)?,
})
}
}
/// Get all users + privileges for a single database.
pub async fn get_database_privileges(
database_name: &str,
connection: &mut MySqlConnection,
) -> anyhow::Result<Vec<DatabasePrivilegeRow>> {
let unix_user = get_current_unix_user()?;
validate_database_name(database_name, &unix_user)?;
let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
"SELECT {} FROM `db` WHERE `db` = ?",
DATABASE_PRIVILEGE_FIELDS
.iter()
.map(|field| quote_identifier(field))
.join(","),
))
.bind(database_name)
.fetch_all(connection)
.await
.context("Failed to show database")?;
Ok(result)
}
/// Get all database + user + privileges pairs that are owned by the current user.
pub async fn get_all_database_privileges(
connection: &mut MySqlConnection,
) -> anyhow::Result<Vec<DatabasePrivilegeRow>> {
let unix_user = get_current_unix_user()?;
let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
indoc! {r#"
SELECT {} FROM `db` WHERE `db` IN
(SELECT DISTINCT `SCHEMA_NAME` AS `database`
FROM `information_schema`.`SCHEMATA`
WHERE `SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys')
AND `SCHEMA_NAME` REGEXP ?)
"#},
DATABASE_PRIVILEGE_FIELDS
.iter()
.map(|field| format!("`{field}`"))
.join(","),
))
.bind(create_user_group_matching_regex(&unix_user))
.fetch_all(connection)
.await
.context("Failed to show databases")?;
Ok(result)
}
/*************************/
@ -340,17 +166,23 @@ pub fn generate_editor_content_from_privilege_data(
// editor will be the example user and example db name.
// Hence, it's put as the fallback value, despite not really
// being a "fallback" in the normal sense.
let longest_username = privilege_data
.iter()
.map(|p| p.user.len())
.max()
.unwrap_or(example_user.len());
let longest_username = max(
privilege_data
.iter()
.map(|p| p.user.len())
.max()
.unwrap_or(example_user.len()),
"User".len(),
);
let longest_database_name = privilege_data
.iter()
.map(|p| p.db.len())
.max()
.unwrap_or(example_db.len());
let longest_database_name = max(
privilege_data
.iter()
.map(|p| p.db.len())
.max()
.unwrap_or(example_db.len()),
"Database".len(),
);
let mut header: Vec<_> = DATABASE_PRIVILEGE_FIELDS
.into_iter()
@ -578,7 +410,7 @@ pub fn parse_privilege_data_from_editor_content(
/// instances of privilege sets for a single user on a single database.
///
/// The `User` and `Database` are the same for both instances.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, PartialOrd, Ord)]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
pub struct DatabasePrivilegeRowDiff {
pub db: String,
pub user: String,
@ -586,7 +418,7 @@ pub struct DatabasePrivilegeRowDiff {
}
/// This enum represents a change for a single privilege.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, PartialOrd, Ord)]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
pub enum DatabasePrivilegeChange {
YesToNo(String),
NoToYes(String),
@ -603,13 +435,31 @@ impl DatabasePrivilegeChange {
}
/// This enum encapsulates whether a [`DatabasePrivilegeRow`] was intrduced, modified or deleted.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, PartialOrd, Ord)]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
pub enum DatabasePrivilegesDiff {
New(DatabasePrivilegeRow),
Modified(DatabasePrivilegeRowDiff),
Deleted(DatabasePrivilegeRow),
}
impl DatabasePrivilegesDiff {
pub fn get_database_name(&self) -> &str {
match self {
DatabasePrivilegesDiff::New(p) => &p.db,
DatabasePrivilegesDiff::Modified(p) => &p.db,
DatabasePrivilegesDiff::Deleted(p) => &p.db,
}
}
pub fn get_user_name(&self) -> &str {
match self {
DatabasePrivilegesDiff::New(p) => &p.user,
DatabasePrivilegesDiff::Modified(p) => &p.user,
DatabasePrivilegesDiff::Deleted(p) => &p.user,
}
}
}
/// This function calculates the differences between two sets of database privileges.
/// It returns a set of [`DatabasePrivilegesDiff`] that can be used to display or
/// apply a set of privilege modifications to the database.
@ -633,7 +483,7 @@ pub fn diff_privileges(
for p in to {
if let Some(old_p) = from_lookup_table.get(&(p.db.clone(), p.user.clone())) {
let diff = old_p.diff(p);
let diff = diff(old_p, p);
if !diff.diff.is_empty() {
result.insert(DatabasePrivilegesDiff::Modified(diff));
}
@ -651,72 +501,6 @@ pub fn diff_privileges(
result
}
/// Uses the result of [`diff_privileges`] to modify privileges in the database.
pub async fn apply_privilege_diffs(
diffs: BTreeSet<DatabasePrivilegesDiff>,
connection: &mut MySqlConnection,
) -> anyhow::Result<()> {
for diff in diffs {
match diff {
DatabasePrivilegesDiff::New(p) => {
let tables = DATABASE_PRIVILEGE_FIELDS
.iter()
.map(|field| format!("`{field}`"))
.join(",");
let question_marks = std::iter::repeat("?")
.take(DATABASE_PRIVILEGE_FIELDS.len())
.join(",");
sqlx::query(
format!("INSERT INTO `db` ({}) VALUES ({})", tables, question_marks).as_str(),
)
.bind(p.db)
.bind(p.user)
.bind(yn(p.select_priv))
.bind(yn(p.insert_priv))
.bind(yn(p.update_priv))
.bind(yn(p.delete_priv))
.bind(yn(p.create_priv))
.bind(yn(p.drop_priv))
.bind(yn(p.alter_priv))
.bind(yn(p.index_priv))
.bind(yn(p.create_tmp_table_priv))
.bind(yn(p.lock_tables_priv))
.bind(yn(p.references_priv))
.execute(&mut *connection)
.await?;
}
DatabasePrivilegesDiff::Modified(p) => {
let tables = p
.diff
.iter()
.map(|diff| match diff {
DatabasePrivilegeChange::YesToNo(name) => format!("`{}` = 'N'", name),
DatabasePrivilegeChange::NoToYes(name) => format!("`{}` = 'Y'", name),
})
.join(",");
sqlx::query(
format!("UPDATE `db` SET {} WHERE `db` = ? AND `user` = ?", tables).as_str(),
)
.bind(p.db)
.bind(p.user)
.execute(&mut *connection)
.await?;
}
DatabasePrivilegesDiff::Deleted(p) => {
sqlx::query("DELETE FROM `db` WHERE `db` = ? AND `user` = ?")
.bind(p.db)
.bind(p.user)
.execute(&mut *connection)
.await?;
}
}
}
Ok(())
}
fn display_privilege_cell(diff: &DatabasePrivilegeRowDiff) -> String {
diff.diff
.iter()
@ -731,6 +515,20 @@ fn display_privilege_cell(diff: &DatabasePrivilegeRowDiff) -> String {
.join("\n")
}
fn display_new_privileges_list(row: &DatabasePrivilegeRow) -> String {
DATABASE_PRIVILEGE_FIELDS
.into_iter()
.skip(2)
.map(|field| {
if row.get_privilege_by_name(field) {
format!("{}: Y", db_priv_field_human_readable_name(field))
} else {
format!("{}: N", db_priv_field_human_readable_name(field))
}
})
.join("\n")
}
/// Displays the difference between two sets of database privileges.
pub fn display_privilege_diffs(diffs: &BTreeSet<DatabasePrivilegesDiff>) -> String {
let mut table = Table::new();
@ -741,24 +539,14 @@ pub fn display_privilege_diffs(diffs: &BTreeSet<DatabasePrivilegesDiff>) -> Stri
table.add_row(row![
p.db,
p.user,
"(New user)\n".to_string()
+ &display_privilege_cell(
&DatabasePrivilegeRow::empty(&p.db, &p.user).diff(p)
)
"(New user)\n".to_string() + &display_new_privileges_list(p)
]);
}
DatabasePrivilegesDiff::Modified(p) => {
table.add_row(row![p.db, p.user, display_privilege_cell(p),]);
}
DatabasePrivilegesDiff::Deleted(p) => {
table.add_row(row![
p.db,
p.user,
"(All privileges removed)\n".to_string()
+ &display_privilege_cell(
&p.diff(&DatabasePrivilegeRow::empty(&p.db, &p.user))
)
]);
table.add_row(row![p.db, p.user, "Removed".to_string()]);
}
}
}

5
src/core/protocol.rs Normal file
View File

@ -0,0 +1,5 @@
pub mod request_response;
pub mod server_responses;
pub use request_response::*;
pub use server_responses::*;

View File

@ -0,0 +1,79 @@
use std::collections::BTreeSet;
use serde::{Deserialize, Serialize};
use tokio::net::UnixStream;
use tokio_serde::{formats::Bincode, Framed as SerdeFramed};
use tokio_util::codec::{Framed, LengthDelimitedCodec};
use crate::core::{database_privileges::DatabasePrivilegesDiff, protocol::*};
pub type ServerToClientMessageStream = SerdeFramed<
Framed<UnixStream, LengthDelimitedCodec>,
Request,
Response,
Bincode<Request, Response>,
>;
pub type ClientToServerMessageStream = SerdeFramed<
Framed<UnixStream, LengthDelimitedCodec>,
Response,
Request,
Bincode<Response, Request>,
>;
pub fn create_server_to_client_message_stream(socket: UnixStream) -> ServerToClientMessageStream {
let length_delimited = Framed::new(socket, LengthDelimitedCodec::new());
tokio_serde::Framed::new(length_delimited, Bincode::default())
}
pub fn create_client_to_server_message_stream(socket: UnixStream) -> ClientToServerMessageStream {
let length_delimited = Framed::new(socket, LengthDelimitedCodec::new());
tokio_serde::Framed::new(length_delimited, Bincode::default())
}
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Request {
CreateDatabases(Vec<String>),
DropDatabases(Vec<String>),
ListDatabases,
ListPrivileges(Option<Vec<String>>),
ModifyPrivileges(BTreeSet<DatabasePrivilegesDiff>),
CreateUsers(Vec<String>),
DropUsers(Vec<String>),
PasswdUser(String, String),
ListUsers(Option<Vec<String>>),
LockUsers(Vec<String>),
UnlockUsers(Vec<String>),
// Commit,
Exit,
}
// TODO: include a generic "message" that will display a message to the user?
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Response {
// Specific data for specific commands
CreateDatabases(CreateDatabasesOutput),
DropDatabases(DropDatabasesOutput),
ListAllDatabases(ListAllDatabasesOutput),
ListPrivileges(GetDatabasesPrivilegeData),
ListAllPrivileges(GetAllDatabasesPrivilegeData),
ModifyPrivileges(ModifyDatabasePrivilegesOutput),
CreateUsers(CreateUsersOutput),
DropUsers(DropUsersOutput),
PasswdUser(SetPasswordOutput),
ListUsers(ListUsersOutput),
ListAllUsers(ListAllUsersOutput),
LockUsers(LockUsersOutput),
UnlockUsers(UnlockUsersOutput),
// Generic responses
OperationAborted,
Error(String),
Exit,
}

View File

@ -0,0 +1,611 @@
use std::collections::BTreeMap;
use indoc::indoc;
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use crate::{
core::{common::UnixUser, database_privileges::DatabasePrivilegeRowDiff},
server::sql::{
database_privilege_operations::DatabasePrivilegeRow, user_operations::DatabaseUser,
},
};
/// This enum is used to differentiate between database and user operations.
/// Their output are very similar, but there are slight differences in the words used.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum DbOrUser {
Database,
User,
}
impl DbOrUser {
pub fn lowercased(&self) -> String {
match self {
DbOrUser::Database => "database".to_string(),
DbOrUser::User => "user".to_string(),
}
}
pub fn capitalized(&self) -> String {
match self {
DbOrUser::Database => "Database".to_string(),
DbOrUser::User => "User".to_string(),
}
}
}
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum NameValidationError {
EmptyString,
InvalidCharacters,
TooLong,
}
impl NameValidationError {
pub fn to_error_message(self, name: &str, db_or_user: DbOrUser) -> String {
match self {
NameValidationError::EmptyString => {
format!("{} name cannot be empty.", db_or_user.capitalized()).to_owned()
}
NameValidationError::TooLong => format!(
"{} is too long. Maximum length is 64 characters.",
db_or_user.capitalized()
)
.to_owned(),
NameValidationError::InvalidCharacters => format!(
indoc! {r#"
Invalid characters in {} name: '{}'
Only A-Z, a-z, 0-9, _ (underscore) and - (dash) are permitted.
"#},
db_or_user.lowercased(),
name
)
.to_owned(),
}
}
}
impl OwnerValidationError {
pub fn to_error_message(self, name: &str, db_or_user: DbOrUser) -> String {
let user = UnixUser::from_enviroment();
match self {
OwnerValidationError::NoMatch => format!(
indoc! {r#"
Invalid {} name prefix: '{}' does not match your username or any of your groups.
Are you sure you are allowed to create {} names with this prefix?
Allowed prefixes:
- {}
{}
"#},
db_or_user.lowercased(),
name,
db_or_user.lowercased(),
user.as_ref()
.map(|u| u.username.clone())
.unwrap_or("???".to_string()),
user.map(|u| u.groups)
.unwrap_or_default()
.iter()
.map(|g| format!(" - {}", g))
.sorted()
.join("\n"),
)
.to_owned(),
_ => format!(
"'{}' is not a valid {} name.",
name,
db_or_user.lowercased()
)
.to_string(),
}
}
}
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum OwnerValidationError {
// The name is valid, but none of the given prefixes matched the name
NoMatch,
// The name is empty, which is invalid
StringEmpty,
// The name is in the format "_<postfix>", which is invalid
MissingPrefix,
// The name is in the format "<prefix>_", which is invalid
MissingPostfix,
}
pub type CreateDatabasesOutput = BTreeMap<String, Result<(), CreateDatabaseError>>;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum CreateDatabaseError {
SanitizationError(NameValidationError),
OwnershipError(OwnerValidationError),
DatabaseAlreadyExists,
MySqlError(String),
}
pub fn print_create_databases_output_status(output: &CreateDatabasesOutput) {
for (database_name, result) in output {
match result {
Ok(()) => {
println!("Database '{}' created successfully.", database_name);
}
Err(err) => {
println!("{}", err.to_error_message(database_name));
println!("Skipping...");
}
}
println!();
}
}
impl CreateDatabaseError {
pub fn to_error_message(&self, database_name: &str) -> String {
match self {
CreateDatabaseError::SanitizationError(err) => {
err.to_error_message(database_name, DbOrUser::Database)
}
CreateDatabaseError::OwnershipError(err) => {
err.to_error_message(database_name, DbOrUser::Database)
}
CreateDatabaseError::DatabaseAlreadyExists => {
format!("Database {} already exists.", database_name)
}
CreateDatabaseError::MySqlError(err) => {
format!("MySQL error: {}", err)
}
}
}
}
pub type DropDatabasesOutput = BTreeMap<String, Result<(), DropDatabaseError>>;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum DropDatabaseError {
SanitizationError(NameValidationError),
OwnershipError(OwnerValidationError),
DatabaseDoesNotExist,
MySqlError(String),
}
pub fn print_drop_databases_output_status(output: &DropDatabasesOutput) {
for (database_name, result) in output {
match result {
Ok(()) => {
println!("Database '{}' dropped successfully.", database_name);
}
Err(err) => {
println!("{}", err.to_error_message(database_name));
println!("Skipping...");
}
}
println!();
}
}
impl DropDatabaseError {
pub fn to_error_message(&self, database_name: &str) -> String {
match self {
DropDatabaseError::SanitizationError(err) => {
err.to_error_message(database_name, DbOrUser::Database)
}
DropDatabaseError::OwnershipError(err) => {
err.to_error_message(database_name, DbOrUser::Database)
}
DropDatabaseError::DatabaseDoesNotExist => {
format!("Database {} does not exist.", database_name)
}
DropDatabaseError::MySqlError(err) => {
format!("MySQL error: {}", err)
}
}
}
}
pub type ListAllDatabasesOutput = Result<Vec<String>, ListDatabasesError>;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ListDatabasesError {
MySqlError(String),
}
impl ListDatabasesError {
pub fn to_error_message(&self) -> String {
match self {
ListDatabasesError::MySqlError(err) => format!("MySQL error: {}", err),
}
}
}
// TODO: merge all rows into a single collection.
// they already contain which database they belong to.
// no need to index by database name.
pub type GetDatabasesPrivilegeData =
BTreeMap<String, Result<Vec<DatabasePrivilegeRow>, GetDatabasesPrivilegeDataError>>;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum GetDatabasesPrivilegeDataError {
SanitizationError(NameValidationError),
OwnershipError(OwnerValidationError),
DatabaseDoesNotExist,
MySqlError(String),
}
impl GetDatabasesPrivilegeDataError {
pub fn to_error_message(&self, database_name: &str) -> String {
match self {
GetDatabasesPrivilegeDataError::SanitizationError(err) => {
err.to_error_message(database_name, DbOrUser::Database)
}
GetDatabasesPrivilegeDataError::OwnershipError(err) => {
err.to_error_message(database_name, DbOrUser::Database)
}
GetDatabasesPrivilegeDataError::DatabaseDoesNotExist => {
format!("Database '{}' does not exist.", database_name)
}
GetDatabasesPrivilegeDataError::MySqlError(err) => {
format!("MySQL error: {}", err)
}
}
}
}
pub type GetAllDatabasesPrivilegeData =
Result<Vec<DatabasePrivilegeRow>, GetAllDatabasesPrivilegeDataError>;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum GetAllDatabasesPrivilegeDataError {
MySqlError(String),
}
impl GetAllDatabasesPrivilegeDataError {
pub fn to_error_message(&self) -> String {
match self {
GetAllDatabasesPrivilegeDataError::MySqlError(err) => format!("MySQL error: {}", err),
}
}
}
pub type ModifyDatabasePrivilegesOutput =
BTreeMap<(String, String), Result<(), ModifyDatabasePrivilegesError>>;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ModifyDatabasePrivilegesError {
DatabaseSanitizationError(NameValidationError),
DatabaseOwnershipError(OwnerValidationError),
UserSanitizationError(NameValidationError),
UserOwnershipError(OwnerValidationError),
DatabaseDoesNotExist,
DiffDoesNotApply(DiffDoesNotApplyError),
MySqlError(String),
}
#[allow(clippy::enum_variant_names)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum DiffDoesNotApplyError {
RowAlreadyExists(String, String),
RowDoesNotExist(String, String),
RowPrivilegeChangeDoesNotApply(DatabasePrivilegeRowDiff, DatabasePrivilegeRow),
}
pub fn print_modify_database_privileges_output_status(output: &ModifyDatabasePrivilegesOutput) {
for ((database_name, username), result) in output {
match result {
Ok(()) => {
println!(
"Privileges for user '{}' on database '{}' modified successfully.",
username, database_name
);
}
Err(err) => {
println!("{}", err.to_error_message(database_name, username));
println!("Skipping...");
}
}
println!();
}
}
impl ModifyDatabasePrivilegesError {
pub fn to_error_message(&self, database_name: &str, username: &str) -> String {
match self {
ModifyDatabasePrivilegesError::DatabaseSanitizationError(err) => {
err.to_error_message(database_name, DbOrUser::Database)
}
ModifyDatabasePrivilegesError::DatabaseOwnershipError(err) => {
err.to_error_message(database_name, DbOrUser::Database)
}
ModifyDatabasePrivilegesError::UserSanitizationError(err) => {
err.to_error_message(username, DbOrUser::User)
}
ModifyDatabasePrivilegesError::UserOwnershipError(err) => {
err.to_error_message(username, DbOrUser::User)
}
ModifyDatabasePrivilegesError::DatabaseDoesNotExist => {
format!("Database '{}' does not exist.", database_name)
}
ModifyDatabasePrivilegesError::DiffDoesNotApply(diff) => {
format!(
"Could not apply privilege change:\n{}",
diff.to_error_message()
)
}
ModifyDatabasePrivilegesError::MySqlError(err) => {
format!("MySQL error: {}", err)
}
}
}
}
impl DiffDoesNotApplyError {
pub fn to_error_message(&self) -> String {
match self {
DiffDoesNotApplyError::RowAlreadyExists(database_name, username) => {
format!(
"Privileges for user '{}' on database '{}' already exist.",
username, database_name
)
}
DiffDoesNotApplyError::RowDoesNotExist(database_name, username) => {
format!(
"Privileges for user '{}' on database '{}' do not exist.",
username, database_name
)
}
DiffDoesNotApplyError::RowPrivilegeChangeDoesNotApply(diff, row) => {
format!(
"Could not apply privilege change {:?} to row {:?}",
diff, row
)
}
}
}
}
pub type CreateUsersOutput = BTreeMap<String, Result<(), CreateUserError>>;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum CreateUserError {
SanitizationError(NameValidationError),
OwnershipError(OwnerValidationError),
UserAlreadyExists,
MySqlError(String),
}
pub fn print_create_users_output_status(output: &CreateUsersOutput) {
for (username, result) in output {
match result {
Ok(()) => {
println!("User '{}' created successfully.", username);
}
Err(err) => {
println!("{}", err.to_error_message(username));
println!("Skipping...");
}
}
println!();
}
}
impl CreateUserError {
pub fn to_error_message(&self, username: &str) -> String {
match self {
CreateUserError::SanitizationError(err) => {
err.to_error_message(username, DbOrUser::User)
}
CreateUserError::OwnershipError(err) => err.to_error_message(username, DbOrUser::User),
CreateUserError::UserAlreadyExists => {
format!("User '{}' already exists.", username)
}
CreateUserError::MySqlError(err) => {
format!("MySQL error: {}", err)
}
}
}
}
pub type DropUsersOutput = BTreeMap<String, Result<(), DropUserError>>;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum DropUserError {
SanitizationError(NameValidationError),
OwnershipError(OwnerValidationError),
UserDoesNotExist,
MySqlError(String),
}
pub fn print_drop_users_output_status(output: &DropUsersOutput) {
for (username, result) in output {
match result {
Ok(()) => {
println!("User '{}' dropped successfully.", username);
}
Err(err) => {
println!("{}", err.to_error_message(username));
println!("Skipping...");
}
}
println!();
}
}
impl DropUserError {
pub fn to_error_message(&self, username: &str) -> String {
match self {
DropUserError::SanitizationError(err) => err.to_error_message(username, DbOrUser::User),
DropUserError::OwnershipError(err) => err.to_error_message(username, DbOrUser::User),
DropUserError::UserDoesNotExist => {
format!("User '{}' does not exist.", username)
}
DropUserError::MySqlError(err) => {
format!("MySQL error: {}", err)
}
}
}
}
pub type SetPasswordOutput = Result<(), SetPasswordError>;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum SetPasswordError {
SanitizationError(NameValidationError),
OwnershipError(OwnerValidationError),
UserDoesNotExist,
MySqlError(String),
}
pub fn print_set_password_output_status(output: &SetPasswordOutput, username: &str) {
match output {
Ok(()) => {
println!("Password for user '{}' set successfully.", username);
}
Err(err) => {
println!("{}", err.to_error_message(username));
println!("Skipping...");
}
}
}
impl SetPasswordError {
pub fn to_error_message(&self, username: &str) -> String {
match self {
SetPasswordError::SanitizationError(err) => {
err.to_error_message(username, DbOrUser::User)
}
SetPasswordError::OwnershipError(err) => err.to_error_message(username, DbOrUser::User),
SetPasswordError::UserDoesNotExist => {
format!("User '{}' does not exist.", username)
}
SetPasswordError::MySqlError(err) => {
format!("MySQL error: {}", err)
}
}
}
}
pub type LockUsersOutput = BTreeMap<String, Result<(), LockUserError>>;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum LockUserError {
SanitizationError(NameValidationError),
OwnershipError(OwnerValidationError),
UserDoesNotExist,
UserIsAlreadyLocked,
MySqlError(String),
}
pub fn print_lock_users_output_status(output: &LockUsersOutput) {
for (username, result) in output {
match result {
Ok(()) => {
println!("User '{}' locked successfully.", username);
}
Err(err) => {
println!("{}", err.to_error_message(username));
println!("Skipping...");
}
}
println!();
}
}
impl LockUserError {
pub fn to_error_message(&self, username: &str) -> String {
match self {
LockUserError::SanitizationError(err) => err.to_error_message(username, DbOrUser::User),
LockUserError::OwnershipError(err) => err.to_error_message(username, DbOrUser::User),
LockUserError::UserDoesNotExist => {
format!("User '{}' does not exist.", username)
}
LockUserError::UserIsAlreadyLocked => {
format!("User '{}' is already locked.", username)
}
LockUserError::MySqlError(err) => {
format!("MySQL error: {}", err)
}
}
}
}
pub type UnlockUsersOutput = BTreeMap<String, Result<(), UnlockUserError>>;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum UnlockUserError {
SanitizationError(NameValidationError),
OwnershipError(OwnerValidationError),
UserDoesNotExist,
UserIsAlreadyUnlocked,
MySqlError(String),
}
pub fn print_unlock_users_output_status(output: &UnlockUsersOutput) {
for (username, result) in output {
match result {
Ok(()) => {
println!("User '{}' unlocked successfully.", username);
}
Err(err) => {
println!("{}", err.to_error_message(username));
println!("Skipping...");
}
}
println!();
}
}
impl UnlockUserError {
pub fn to_error_message(&self, username: &str) -> String {
match self {
UnlockUserError::SanitizationError(err) => {
err.to_error_message(username, DbOrUser::User)
}
UnlockUserError::OwnershipError(err) => err.to_error_message(username, DbOrUser::User),
UnlockUserError::UserDoesNotExist => {
format!("User '{}' does not exist.", username)
}
UnlockUserError::UserIsAlreadyUnlocked => {
format!("User '{}' is already unlocked.", username)
}
UnlockUserError::MySqlError(err) => {
format!("MySQL error: {}", err)
}
}
}
}
pub type ListUsersOutput = BTreeMap<String, Result<DatabaseUser, ListUsersError>>;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ListUsersError {
SanitizationError(NameValidationError),
OwnershipError(OwnerValidationError),
UserDoesNotExist,
MySqlError(String),
}
impl ListUsersError {
pub fn to_error_message(&self, username: &str) -> String {
match self {
ListUsersError::SanitizationError(err) => {
err.to_error_message(username, DbOrUser::User)
}
ListUsersError::OwnershipError(err) => err.to_error_message(username, DbOrUser::User),
ListUsersError::UserDoesNotExist => {
format!("User '{}' does not exist.", username)
}
ListUsersError::MySqlError(err) => {
format!("MySQL error: {}", err)
}
}
}
}
pub type ListAllUsersOutput = Result<Vec<DatabaseUser>, ListAllUsersError>;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ListAllUsersError {
MySqlError(String),
}
impl ListAllUsersError {
pub fn to_error_message(&self) -> String {
match self {
ListAllUsersError::MySqlError(err) => format!("MySQL error: {}", err),
}
}
}

View File

@ -1,249 +0,0 @@
use anyhow::Context;
use nix::unistd::User;
use serde::{Deserialize, Serialize};
use sqlx::{prelude::*, MySqlConnection};
use crate::core::common::{
create_user_group_matching_regex, get_current_unix_user, quote_literal, validate_name_or_error,
validate_ownership_or_error, DbOrUser,
};
pub async fn user_exists(db_user: &str, connection: &mut MySqlConnection) -> anyhow::Result<bool> {
let unix_user = get_current_unix_user()?;
validate_user_name(db_user, &unix_user)?;
let user_exists = sqlx::query(
r#"
SELECT EXISTS(
SELECT 1
FROM `mysql`.`user`
WHERE `User` = ?
)
"#,
)
.bind(db_user)
.fetch_one(connection)
.await?
.get::<bool, _>(0);
Ok(user_exists)
}
pub async fn create_database_user(
db_user: &str,
connection: &mut MySqlConnection,
) -> anyhow::Result<()> {
let unix_user = get_current_unix_user()?;
validate_user_name(db_user, &unix_user)?;
if user_exists(db_user, connection).await? {
anyhow::bail!("User '{}' already exists", db_user);
}
// NOTE: see the note about SQL injections in `validate_ownership_of_user_name`
sqlx::query(format!("CREATE USER {}@'%'", quote_literal(db_user),).as_str())
.execute(connection)
.await?;
Ok(())
}
pub async fn delete_database_user(
db_user: &str,
connection: &mut MySqlConnection,
) -> anyhow::Result<()> {
let unix_user = get_current_unix_user()?;
validate_user_name(db_user, &unix_user)?;
if !user_exists(db_user, connection).await? {
anyhow::bail!("User '{}' does not exist", db_user);
}
// NOTE: see the note about SQL injections in `validate_ownership_of_user_name`
sqlx::query(format!("DROP USER {}@'%'", quote_literal(db_user),).as_str())
.execute(connection)
.await?;
Ok(())
}
pub async fn set_password_for_database_user(
db_user: &str,
password: &str,
connection: &mut MySqlConnection,
) -> anyhow::Result<()> {
let unix_user = crate::core::common::get_current_unix_user()?;
validate_user_name(db_user, &unix_user)?;
if !user_exists(db_user, connection).await? {
anyhow::bail!("User '{}' does not exist", db_user);
}
// NOTE: see the note about SQL injections in `validate_ownership_of_user_name`
sqlx::query(
format!(
"ALTER USER {}@'%' IDENTIFIED BY {}",
quote_literal(db_user),
quote_literal(password).as_str()
)
.as_str(),
)
.execute(connection)
.await?;
Ok(())
}
async fn user_is_locked(db_user: &str, connection: &mut MySqlConnection) -> anyhow::Result<bool> {
let unix_user = get_current_unix_user()?;
validate_user_name(db_user, &unix_user)?;
if !user_exists(db_user, connection).await? {
anyhow::bail!("User '{}' does not exist", db_user);
}
let is_locked = sqlx::query(
r#"
SELECT JSON_EXTRACT(`mysql`.`global_priv`.`priv`, "$.account_locked") = 'true'
FROM `mysql`.`global_priv`
WHERE `User` = ?
AND `Host` = '%'
"#,
)
.bind(db_user)
.fetch_one(connection)
.await?
.get::<bool, _>(0);
Ok(is_locked)
}
pub async fn lock_database_user(
db_user: &str,
connection: &mut MySqlConnection,
) -> anyhow::Result<()> {
let unix_user = get_current_unix_user()?;
validate_user_name(db_user, &unix_user)?;
if !user_exists(db_user, connection).await? {
anyhow::bail!("User '{}' does not exist", db_user);
}
if user_is_locked(db_user, connection).await? {
anyhow::bail!("User '{}' is already locked", db_user);
}
// NOTE: see the note about SQL injections in `validate_ownership_of_user_name`
sqlx::query(format!("ALTER USER {}@'%' ACCOUNT LOCK", quote_literal(db_user),).as_str())
.execute(connection)
.await?;
Ok(())
}
pub async fn unlock_database_user(
db_user: &str,
connection: &mut MySqlConnection,
) -> anyhow::Result<()> {
let unix_user = get_current_unix_user()?;
validate_user_name(db_user, &unix_user)?;
if !user_exists(db_user, connection).await? {
anyhow::bail!("User '{}' does not exist", db_user);
}
if !user_is_locked(db_user, connection).await? {
anyhow::bail!("User '{}' is already unlocked", db_user);
}
// NOTE: see the note about SQL injections in `validate_ownership_of_user_name`
sqlx::query(format!("ALTER USER {}@'%' ACCOUNT UNLOCK", quote_literal(db_user),).as_str())
.execute(connection)
.await?;
Ok(())
}
/// This struct contains information about a database user.
/// This can be extended if we need more information in the future.
#[derive(Debug, Clone, FromRow, Serialize, Deserialize)]
pub struct DatabaseUser {
#[sqlx(rename = "User")]
pub user: String,
#[allow(dead_code)]
#[serde(skip)]
#[sqlx(rename = "Host")]
pub host: String,
#[sqlx(rename = "has_password")]
pub has_password: bool,
#[sqlx(rename = "is_locked")]
pub is_locked: bool,
}
const DB_USER_SELECT_STATEMENT: &str = r#"
SELECT
`mysql`.`user`.`User`,
`mysql`.`user`.`Host`,
`mysql`.`user`.`Password` != '' OR `mysql`.`user`.`authentication_string` != '' AS `has_password`,
COALESCE(
JSON_EXTRACT(`mysql`.`global_priv`.`priv`, "$.account_locked"),
'false'
) != 'false' AS `is_locked`
FROM `mysql`.`user`
JOIN `mysql`.`global_priv` ON
`mysql`.`user`.`User` = `mysql`.`global_priv`.`User`
AND `mysql`.`user`.`Host` = `mysql`.`global_priv`.`Host`
"#;
/// This function fetches all database users that have a prefix matching the
/// unix username and group names of the given unix user.
pub async fn get_all_database_users_for_unix_user(
unix_user: &User,
connection: &mut MySqlConnection,
) -> anyhow::Result<Vec<DatabaseUser>> {
let users = sqlx::query_as::<_, DatabaseUser>(
&(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `mysql`.`user`.`User` REGEXP ?"),
)
.bind(create_user_group_matching_regex(unix_user))
.fetch_all(connection)
.await?;
Ok(users)
}
/// This function fetches a database user if it exists.
pub async fn get_database_user_for_user(
username: &str,
connection: &mut MySqlConnection,
) -> anyhow::Result<Option<DatabaseUser>> {
let user = sqlx::query_as::<_, DatabaseUser>(
&(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `mysql`.`user`.`User` = ?"),
)
.bind(username)
.fetch_optional(connection)
.await?;
Ok(user)
}
/// NOTE: It is very critical that this function validates the database name
/// properly. MySQL does not seem to allow for prepared statements, binding
/// the database name as a parameter to the query. This means that we have
/// to validate the database name ourselves to prevent SQL injection.
pub fn validate_user_name(name: &str, user: &User) -> anyhow::Result<()> {
validate_name_or_error(name, DbOrUser::User)
.context(format!("Invalid username: '{}'", name))?;
validate_ownership_or_error(name, user, DbOrUser::User)
.context(format!("Invalid username: '{}'", name))?;
Ok(())
}

View File

@ -1,42 +1,69 @@
#[macro_use]
extern crate prettytable;
use core::common::CommandStatus;
#[cfg(feature = "mysql-admutils-compatibility")]
use clap::Parser;
use std::path::PathBuf;
use std::os::unix::net::UnixStream as StdUnixStream;
use tokio::net::UnixStream as TokioUnixStream;
use crate::{
core::{
bootstrap::{bootstrap_server_connection_and_drop_privileges, drop_privs},
protocol::create_client_to_server_message_stream,
},
server::command::ServerArgs,
};
#[cfg(feature = "mysql-admutils-compatibility")]
use crate::cli::mysql_admutils_compatibility::{mysql_dbadm, mysql_useradm};
use clap::Parser;
mod server;
mod authenticated_unix_socket;
mod cli;
mod core;
#[cfg(feature = "tui")]
mod tui;
#[derive(Parser)]
#[derive(Parser, Debug)]
struct Args {
#[command(subcommand)]
command: Command,
#[command(flatten)]
config_overrides: core::config::GlobalConfigArgs,
/// Path to the socket of the server, if it already exists.
#[arg(
short,
long,
value_name = "PATH",
global = true,
hide_short_help = true
)]
server_socket_path: Option<PathBuf>,
/// Config file to use for the server.
#[arg(
short,
long,
value_name = "PATH",
global = true,
hide_short_help = true
)]
config: Option<PathBuf>,
#[cfg(feature = "tui")]
#[arg(short, long, alias = "tui", global = true)]
interactive: bool,
}
/// Database administration tool for non-admin users to manage their own MySQL databases and users.
///
/// This tool allows you to manage users and databases in MySQL.
///
/// You are only allowed to manage databases and users that are prefixed with
/// either your username, or a group that you are a member of.
#[derive(Parser)]
// Database administration tool for non-admin users to manage their own MySQL databases and users.
//
// This tool allows you to manage users and databases in MySQL.
//
// You are only allowed to manage databases and users that are prefixed with
// either your username, or a group that you are a member of.
#[derive(Parser, Debug, Clone)]
#[command(version, about, disable_help_subcommand = true)]
enum Command {
#[command(flatten)]
@ -44,10 +71,18 @@ enum Command {
#[command(flatten)]
User(cli::user_command::UserCommand),
#[command(hide = true)]
Server(server::command::ServerArgs),
}
#[tokio::main(flavor = "current_thread")]
async fn main() -> anyhow::Result<()> {
// TODO: tag all functions that are run with elevated privileges with
// comments emphasizing the need for caution.
fn main() -> anyhow::Result<()> {
// TODO: find out if there are any security risks of running
// env_logger and clap with elevated privileges.
env_logger::init();
#[cfg(feature = "mysql-admutils-compatibility")]
@ -59,42 +94,60 @@ async fn main() -> anyhow::Result<()> {
});
match argv0.as_deref() {
Some("mysql-dbadm") => return mysql_dbadm::main().await,
Some("mysql-useradm") => return mysql_useradm::main().await,
Some("mysql-dbadm") => return mysql_dbadm::main(),
Some("mysql-useradm") => return mysql_useradm::main(),
_ => { /* fall through */ }
}
}
let args: Args = Args::parse();
let config = core::config::get_config(args.config_overrides)?;
let connection = core::config::create_mysql_connection_from_config(config.mysql).await?;
let result = match args.command {
Command::Db(command) => cli::database_command::handle_command(command, connection).await,
Command::User(user_args) => cli::user_command::handle_command(user_args, connection).await,
};
match result {
Ok(CommandStatus::SuccessfullyModified) => {
println!("Modifications committed successfully");
Ok(())
match args.command {
Command::Server(ref command) => {
drop_privs()?;
tokio_start_server(args.server_socket_path, args.config, command.clone())?;
return Ok(());
}
Ok(CommandStatus::PartiallySuccessfullyModified) => {
println!("Some modifications committed successfully");
Ok(())
}
Ok(CommandStatus::NoModificationsNeeded) => {
println!("No modifications made");
Ok(())
}
Ok(CommandStatus::NoModificationsIntended) => {
/* Don't report anything */
Ok(())
}
Ok(CommandStatus::Cancelled) => {
println!("Command cancelled successfully");
Ok(())
}
Err(e) => Err(e),
_ => { /* fall through */ }
}
let server_connection =
bootstrap_server_connection_and_drop_privileges(args.server_socket_path, args.config)?;
tokio_run_command(args.command, server_connection)?;
Ok(())
}
fn tokio_start_server(
server_socket_path: Option<PathBuf>,
config_path: Option<PathBuf>,
args: ServerArgs,
) -> anyhow::Result<()> {
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap()
.block_on(async {
server::command::handle_command(server_socket_path, config_path, args).await
})
}
fn tokio_run_command(command: Command, server_connection: StdUnixStream) -> anyhow::Result<()> {
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap()
.block_on(async {
let tokio_socket = TokioUnixStream::from_std(server_connection)?;
let message_stream = create_client_to_server_message_stream(tokio_socket);
match command {
Command::User(user_args) => {
cli::user_command::handle_command(user_args, message_stream).await
}
Command::Db(db_args) => {
cli::database_command::handle_command(db_args, message_stream).await
}
Command::Server(_) => unreachable!(),
}
})
}

6
src/server.rs Normal file
View File

@ -0,0 +1,6 @@
pub mod command;
mod common;
pub mod config;
pub mod input_sanitization;
pub mod server_loop;
pub mod sql;

77
src/server/command.rs Normal file
View File

@ -0,0 +1,77 @@
use std::os::fd::FromRawFd;
use std::path::PathBuf;
use anyhow::Context;
use clap::Parser;
use std::os::unix::net::UnixStream as StdUnixStream;
use tokio::net::UnixStream as TokioUnixStream;
use crate::core::bootstrap::authenticated_unix_socket;
use crate::core::common::UnixUser;
use crate::server::config::read_config_from_path_with_arg_overrides;
use crate::server::server_loop::listen_for_incoming_connections;
use crate::server::{
config::{ServerConfig, ServerConfigArgs},
server_loop::handle_requests_for_single_session,
};
#[derive(Parser, Debug, Clone)]
pub struct ServerArgs {
#[command(subcommand)]
subcmd: ServerCommand,
#[command(flatten)]
config_overrides: ServerConfigArgs,
}
#[derive(Parser, Debug, Clone)]
pub enum ServerCommand {
#[command()]
Listen,
#[command()]
SocketActivate,
}
pub async fn handle_command(
socket_path: Option<PathBuf>,
config_path: Option<PathBuf>,
args: ServerArgs,
) -> anyhow::Result<()> {
let config = read_config_from_path_with_arg_overrides(config_path, args.config_overrides)?;
// if let Err(e) = &result {
// eprintln!("{}", e);
// }
match args.subcmd {
ServerCommand::Listen => listen_for_incoming_connections(socket_path, config).await,
ServerCommand::SocketActivate => socket_activate(config).await,
}
}
async fn socket_activate(config: ServerConfig) -> anyhow::Result<()> {
// TODO: allow getting socket path from other socket activation sources
let mut conn = get_socket_from_systemd().await?;
let uid = authenticated_unix_socket::server_authenticate(&mut conn).await?;
let unix_user = UnixUser::from_uid(uid.into())?;
handle_requests_for_single_session(conn, &unix_user, &config).await?;
Ok(())
}
async fn get_socket_from_systemd() -> anyhow::Result<TokioUnixStream> {
let fd = std::env::var("LISTEN_FDS")
.context("LISTEN_FDS not set, not running under systemd?")?
.parse::<i32>()
.context("Failed to parse LISTEN_FDS")?;
if fd != 1 {
return Err(anyhow::anyhow!("Unexpected LISTEN_FDS value: {}", fd));
}
let std_unix_stream = unsafe { StdUnixStream::from_raw_fd(fd) };
let socket = TokioUnixStream::from_std(std_unix_stream)?;
Ok(socket)
}

11
src/server/common.rs Normal file
View File

@ -0,0 +1,11 @@
use crate::core::common::UnixUser;
/// This function creates a regex that matches items (users, databases)
/// that belong to the user or any of the user's groups.
pub fn create_user_group_matching_regex(user: &UnixUser) -> String {
if user.groups.is_empty() {
format!("{}(_.+)?", user.username)
} else {
format!("({}|{})(_.+)?", user.username, user.groups.join("|"))
}
}

View File

@ -5,11 +5,16 @@ use clap::Parser;
use serde::{Deserialize, Serialize};
use sqlx::{mysql::MySqlConnectOptions, ConnectOptions, MySqlConnection};
use crate::core::common::DEFAULT_CONFIG_PATH;
pub const DEFAULT_PORT: u16 = 3306;
pub const DEFAULT_TIMEOUT: u64 = 2;
// NOTE: this might look empty now, and the extra wrapping for the mysql
// config seems unnecessary, but it will be useful later when we
// add more configuration options.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Config {
pub struct ServerConfig {
pub mysql: MysqlConfig,
}
@ -23,58 +28,36 @@ pub struct MysqlConfig {
pub timeout: Option<u64>,
}
const DEFAULT_PORT: u16 = 3306;
const DEFAULT_TIMEOUT: u64 = 2;
#[derive(Parser)]
pub struct GlobalConfigArgs {
/// Path to the configuration file.
#[arg(
short,
long,
value_name = "PATH",
global = true,
hide_short_help = true,
default_value = "/etc/mysqladm/config.toml"
)]
config_file: String,
#[derive(Parser, Debug, Clone)]
pub struct ServerConfigArgs {
/// Hostname of the MySQL server.
#[arg(long, value_name = "HOST", global = true, hide_short_help = true)]
#[arg(long, value_name = "HOST", global = true)]
mysql_host: Option<String>,
/// Port of the MySQL server.
#[arg(long, value_name = "PORT", global = true, hide_short_help = true)]
#[arg(long, value_name = "PORT", global = true)]
mysql_port: Option<u16>,
/// Username to use for the MySQL connection.
#[arg(long, value_name = "USER", global = true, hide_short_help = true)]
#[arg(long, value_name = "USER", global = true)]
mysql_user: Option<String>,
/// Path to a file containing the MySQL password.
#[arg(long, value_name = "PATH", global = true, hide_short_help = true)]
#[arg(long, value_name = "PATH", global = true)]
mysql_password_file: Option<String>,
/// Seconds to wait for the MySQL connection to be established.
#[arg(long, value_name = "SECONDS", global = true, hide_short_help = true)]
#[arg(long, value_name = "SECONDS", global = true)]
mysql_connect_timeout: Option<u64>,
}
/// Use the arguments and whichever configuration file which might or might not
/// be found and default values to determine the configuration for the program.
pub fn get_config(args: GlobalConfigArgs) -> anyhow::Result<Config> {
let config_path = PathBuf::from(args.config_file);
let config: Config = fs::read_to_string(&config_path)
.context(format!(
"Failed to read config file from {:?}",
&config_path
))
.and_then(|c| toml::from_str(&c).context("Failed to parse config file"))
.context(format!(
"Failed to parse config file from {:?}",
&config_path
))?;
pub fn read_config_from_path_with_arg_overrides(
config_path: Option<PathBuf>,
args: ServerConfigArgs,
) -> anyhow::Result<ServerConfig> {
let config = read_config_form_path(config_path)?;
let mysql = &config.mysql;
@ -86,22 +69,35 @@ pub fn get_config(args: GlobalConfigArgs) -> anyhow::Result<Config> {
mysql.password.to_owned()
};
let mysql_config = MysqlConfig {
host: args.mysql_host.unwrap_or(mysql.host.to_owned()),
port: args.mysql_port.or(mysql.port),
username: args.mysql_user.unwrap_or(mysql.username.to_owned()),
password,
timeout: args.mysql_connect_timeout.or(mysql.timeout),
};
Ok(Config {
mysql: mysql_config,
Ok(ServerConfig {
mysql: MysqlConfig {
host: args.mysql_host.unwrap_or(mysql.host.to_owned()),
port: args.mysql_port.or(mysql.port),
username: args.mysql_user.unwrap_or(mysql.username.to_owned()),
password,
timeout: args.mysql_connect_timeout.or(mysql.timeout),
},
})
}
pub fn read_config_form_path(config_path: Option<PathBuf>) -> anyhow::Result<ServerConfig> {
let config_path = config_path.unwrap_or_else(|| PathBuf::from(DEFAULT_CONFIG_PATH));
fs::read_to_string(&config_path)
.context(format!(
"Failed to read config file from {:?}",
&config_path
))
.and_then(|c| toml::from_str(&c).context("Failed to parse config file"))
.context(format!(
"Failed to parse config file from {:?}",
&config_path
))
}
/// Use the provided configuration to establish a connection to a MySQL server.
pub async fn create_mysql_connection_from_config(
config: MysqlConfig,
config: &MysqlConfig,
) -> anyhow::Result<MySqlConnection> {
match tokio::time::timeout(
Duration::from_secs(config.timeout.unwrap_or(DEFAULT_TIMEOUT)),

View File

@ -0,0 +1,158 @@
use crate::core::{
common::UnixUser,
protocol::server_responses::{NameValidationError, OwnerValidationError},
};
const MAX_NAME_LENGTH: usize = 64;
pub fn validate_name(name: &str) -> Result<(), NameValidationError> {
if name.is_empty() {
Err(NameValidationError::EmptyString)
} else if name.len() > MAX_NAME_LENGTH {
Err(NameValidationError::TooLong)
} else if !name
.chars()
.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
{
Err(NameValidationError::InvalidCharacters)
} else {
Ok(())
}
}
pub fn validate_ownership_by_unix_user(
name: &str,
user: &UnixUser,
) -> Result<(), OwnerValidationError> {
let prefixes = std::iter::once(user.username.clone())
.chain(user.groups.iter().cloned())
.collect::<Vec<String>>();
validate_ownership_by_prefixes(name, &prefixes)
}
/// Core logic for validating the ownership of a database name.
/// This function checks if the given name matches any of the given prefixes.
/// These prefixes will in most cases be the user's unix username and any
/// unix groups the user is a member of.
pub fn validate_ownership_by_prefixes(
name: &str,
prefixes: &[String],
) -> Result<(), OwnerValidationError> {
if name.is_empty() {
return Err(OwnerValidationError::StringEmpty);
}
if name.starts_with('_') {
return Err(OwnerValidationError::MissingPrefix);
}
let (prefix, _) = match name.split_once('_') {
Some(pair) => pair,
None => return Err(OwnerValidationError::MissingPostfix),
};
if !prefixes.iter().any(|g| g == prefix) {
return Err(OwnerValidationError::NoMatch);
}
Ok(())
}
#[inline]
pub fn quote_literal(s: &str) -> String {
format!("'{}'", s.replace('\'', r"\'"))
}
#[inline]
pub fn quote_identifier(s: &str) -> String {
format!("`{}`", s.replace('`', r"\`"))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_quote_literal() {
let payload = "' OR 1=1 --";
assert_eq!(quote_literal(payload), r#"'\' OR 1=1 --'"#);
}
#[test]
fn test_quote_identifier() {
let payload = "` OR 1=1 --";
assert_eq!(quote_identifier(payload), r#"`\` OR 1=1 --`"#);
}
#[test]
fn test_validate_name() {
assert_eq!(validate_name(""), Err(NameValidationError::EmptyString));
assert_eq!(validate_name("abcdefghijklmnopqrstuvwxyz"), Ok(()));
assert_eq!(validate_name("ABCDEFGHIJKLMNOPQRSTUVWXYZ"), Ok(()));
assert_eq!(validate_name("0123456789_-"), Ok(()));
for c in "\n\t\r !@#$%^&*()+=[]{}|;:,.<>?/".chars() {
assert_eq!(
validate_name(&c.to_string()),
Err(NameValidationError::InvalidCharacters)
);
}
assert_eq!(validate_name(&"a".repeat(MAX_NAME_LENGTH)), Ok(()));
assert_eq!(
validate_name(&"a".repeat(MAX_NAME_LENGTH + 1)),
Err(NameValidationError::TooLong)
);
}
#[test]
fn test_validate_owner_by_prefixes() {
let prefixes = vec!["user".to_string(), "group".to_string()];
assert_eq!(
validate_ownership_by_prefixes("", &prefixes),
Err(OwnerValidationError::StringEmpty)
);
assert_eq!(
validate_ownership_by_prefixes("user", &prefixes),
Err(OwnerValidationError::MissingPostfix)
);
assert_eq!(
validate_ownership_by_prefixes("something", &prefixes),
Err(OwnerValidationError::MissingPostfix)
);
assert_eq!(
validate_ownership_by_prefixes("user-testdb", &prefixes),
Err(OwnerValidationError::MissingPostfix)
);
assert_eq!(
validate_ownership_by_prefixes("_testdb", &prefixes),
Err(OwnerValidationError::MissingPrefix)
);
assert_eq!(
validate_ownership_by_prefixes("user_testdb", &prefixes),
Ok(())
);
assert_eq!(
validate_ownership_by_prefixes("group_testdb", &prefixes),
Ok(())
);
assert_eq!(
validate_ownership_by_prefixes("group_test_db", &prefixes),
Ok(())
);
assert_eq!(
validate_ownership_by_prefixes("group_test-db", &prefixes),
Ok(())
);
assert_eq!(
validate_ownership_by_prefixes("nonexistent_testdb", &prefixes),
Err(OwnerValidationError::NoMatch)
);
}
}

229
src/server/server_loop.rs Normal file
View File

@ -0,0 +1,229 @@
use std::{collections::BTreeSet, fs, path::PathBuf};
use anyhow::Context;
use futures_util::{SinkExt, StreamExt};
use tokio::io::AsyncWriteExt;
use tokio::net::{UnixListener, UnixStream};
use sqlx::prelude::*;
use sqlx::MySqlConnection;
use crate::{
core::{
bootstrap::authenticated_unix_socket,
common::{UnixUser, DEFAULT_SOCKET_PATH},
protocol::request_response::{
create_server_to_client_message_stream, Request, Response, ServerToClientMessageStream,
},
},
server::{
config::{create_mysql_connection_from_config, ServerConfig},
sql::{
database_operations::{create_databases, drop_databases, list_databases_for_user},
database_privilege_operations::{
apply_privilege_diffs, get_all_database_privileges, get_databases_privilege_data,
},
user_operations::{
create_database_users, drop_database_users, list_all_database_users_for_unix_user,
list_database_users, lock_database_users, set_password_for_database_user,
unlock_database_users,
},
},
},
};
// TODO: consider using a connection pool
// TODO: use tracing for login, so we can scope the log messages per incoming connection
pub async fn listen_for_incoming_connections(
socket_path: Option<PathBuf>,
config: ServerConfig,
// db_connection: &mut MySqlConnection,
) -> anyhow::Result<()> {
let socket_path = socket_path.unwrap_or(PathBuf::from(DEFAULT_SOCKET_PATH));
let parent_directory = socket_path.parent().unwrap();
if !parent_directory.exists() {
println!("Creating directory {:?}", parent_directory);
fs::create_dir_all(parent_directory)?;
}
println!("Listening on {:?}", socket_path);
match fs::remove_file(socket_path.as_path()) {
Ok(_) => {}
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
Err(e) => return Err(e.into()),
}
let listener = UnixListener::bind(socket_path)?;
while let Ok((mut conn, _addr)) = listener.accept().await {
let uid = match authenticated_unix_socket::server_authenticate(&mut conn).await {
Ok(uid) => uid,
Err(e) => {
eprintln!("Failed to authenticate client: {}", e);
conn.shutdown().await?;
continue;
}
};
let unix_user = match UnixUser::from_uid(uid.into()) {
Ok(user) => user,
Err(e) => {
eprintln!("Failed to get UnixUser from uid: {}", e);
conn.shutdown().await?;
continue;
}
};
match handle_requests_for_single_session(conn, &unix_user, &config).await {
Ok(_) => {}
Err(e) => {
eprintln!("Failed to run server: {}", e);
}
}
}
Ok(())
}
pub async fn handle_requests_for_single_session(
socket: UnixStream,
unix_user: &UnixUser,
config: &ServerConfig,
) -> anyhow::Result<()> {
let message_stream = create_server_to_client_message_stream(socket);
let mut db_connection = create_mysql_connection_from_config(&config.mysql).await?;
let result = handle_requests_for_single_session_with_db_connection(
message_stream,
unix_user,
&mut db_connection,
)
.await;
if let Err(e) = db_connection
.close()
.await
.context("Failed to close connection properly")
{
eprintln!("{}", e);
eprintln!("Ignoring...");
}
result
}
// TODO: ensure proper db_connection hygiene for functions that invoke
// this function
pub async fn handle_requests_for_single_session_with_db_connection(
mut stream: ServerToClientMessageStream,
unix_user: &UnixUser,
db_connection: &mut MySqlConnection,
) -> anyhow::Result<()> {
loop {
// TODO: better error handling
let request = match stream.next().await {
Some(Ok(request)) => request,
Some(Err(e)) => return Err(e.into()),
None => {
log::warn!("Client disconnected without sending an exit message");
break;
}
};
match request {
Request::CreateDatabases(databases_names) => {
let result = create_databases(databases_names, unix_user, db_connection).await;
stream.send(Response::CreateDatabases(result)).await?;
stream.flush().await?;
}
Request::DropDatabases(databases_names) => {
let result = drop_databases(databases_names, unix_user, db_connection).await;
stream.send(Response::DropDatabases(result)).await?;
stream.flush().await?;
}
Request::ListDatabases => {
let result = list_databases_for_user(unix_user, db_connection).await;
stream.send(Response::ListAllDatabases(result)).await?;
stream.flush().await?;
}
Request::ListPrivileges(database_names) => {
let response = match database_names {
Some(database_names) => {
let privilege_data =
get_databases_privilege_data(database_names, unix_user, db_connection)
.await;
Response::ListPrivileges(privilege_data)
}
None => {
let privilege_data =
get_all_database_privileges(unix_user, db_connection).await;
Response::ListAllPrivileges(privilege_data)
}
};
stream.send(response).await?;
stream.flush().await?;
}
Request::ModifyPrivileges(database_privilege_diffs) => {
let result = apply_privilege_diffs(
BTreeSet::from_iter(database_privilege_diffs),
unix_user,
db_connection,
)
.await;
stream.send(Response::ModifyPrivileges(result)).await?;
stream.flush().await?;
}
Request::CreateUsers(db_users) => {
let result = create_database_users(db_users, unix_user, db_connection).await;
stream.send(Response::CreateUsers(result)).await?;
stream.flush().await?;
}
Request::DropUsers(db_users) => {
let result = drop_database_users(db_users, unix_user, db_connection).await;
stream.send(Response::DropUsers(result)).await?;
stream.flush().await?;
}
Request::PasswdUser(db_user, password) => {
let result =
set_password_for_database_user(&db_user, &password, unix_user, db_connection)
.await;
stream.send(Response::PasswdUser(result)).await?;
stream.flush().await?;
}
Request::ListUsers(db_users) => {
let response = match db_users {
Some(db_users) => {
let result = list_database_users(db_users, unix_user, db_connection).await;
Response::ListUsers(result)
}
None => {
let result =
list_all_database_users_for_unix_user(unix_user, db_connection).await;
Response::ListAllUsers(result)
}
};
stream.send(response).await?;
stream.flush().await?;
}
Request::LockUsers(db_users) => {
let result = lock_database_users(db_users, unix_user, db_connection).await;
stream.send(Response::LockUsers(result)).await?;
stream.flush().await?;
}
Request::UnlockUsers(db_users) => {
let result = unlock_database_users(db_users, unix_user, db_connection).await;
stream.send(Response::UnlockUsers(result)).await?;
stream.flush().await?;
}
Request::Exit => {
break;
}
}
}
Ok(())
}

3
src/server/sql.rs Normal file
View File

@ -0,0 +1,3 @@
pub mod database_operations;
pub mod database_privilege_operations;
pub mod user_operations;

View File

@ -0,0 +1,165 @@
use crate::{
core::{
common::UnixUser,
protocol::{
CreateDatabaseError, CreateDatabasesOutput, DropDatabaseError, DropDatabasesOutput,
ListDatabasesError,
},
},
server::{
common::create_user_group_matching_regex,
input_sanitization::{quote_identifier, validate_name, validate_ownership_by_unix_user},
},
};
use sqlx::prelude::*;
use sqlx::MySqlConnection;
use std::collections::BTreeMap;
// NOTE: this function is unsafe because it does no input validation.
pub(super) async fn unsafe_database_exists(
database_name: &str,
connection: &mut MySqlConnection,
) -> Result<bool, sqlx::Error> {
let result =
sqlx::query("SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?")
.bind(database_name)
.fetch_optional(connection)
.await?;
Ok(result.is_some())
}
pub async fn create_databases(
database_names: Vec<String>,
unix_user: &UnixUser,
connection: &mut MySqlConnection,
) -> CreateDatabasesOutput {
let mut results = BTreeMap::new();
for database_name in database_names {
if let Err(err) = validate_name(&database_name) {
results.insert(
database_name.clone(),
Err(CreateDatabaseError::SanitizationError(err)),
);
continue;
}
if let Err(err) = validate_ownership_by_unix_user(&database_name, unix_user) {
results.insert(
database_name.clone(),
Err(CreateDatabaseError::OwnershipError(err)),
);
continue;
}
match unsafe_database_exists(&database_name, &mut *connection).await {
Ok(true) => {
results.insert(
database_name.clone(),
Err(CreateDatabaseError::DatabaseAlreadyExists),
);
continue;
}
Err(err) => {
results.insert(
database_name.clone(),
Err(CreateDatabaseError::MySqlError(err.to_string())),
);
continue;
}
_ => {}
}
let result =
sqlx::query(format!("CREATE DATABASE {}", quote_identifier(&database_name)).as_str())
.execute(&mut *connection)
.await
.map(|_| ())
.map_err(|err| CreateDatabaseError::MySqlError(err.to_string()));
results.insert(database_name, result);
}
results
}
pub async fn drop_databases(
database_names: Vec<String>,
unix_user: &UnixUser,
connection: &mut MySqlConnection,
) -> DropDatabasesOutput {
let mut results = BTreeMap::new();
for database_name in database_names {
if let Err(err) = validate_name(&database_name) {
results.insert(
database_name.clone(),
Err(DropDatabaseError::SanitizationError(err)),
);
continue;
}
if let Err(err) = validate_ownership_by_unix_user(&database_name, unix_user) {
results.insert(
database_name.clone(),
Err(DropDatabaseError::OwnershipError(err)),
);
continue;
}
match unsafe_database_exists(&database_name, &mut *connection).await {
Ok(false) => {
results.insert(
database_name.clone(),
Err(DropDatabaseError::DatabaseDoesNotExist),
);
continue;
}
Err(err) => {
results.insert(
database_name.clone(),
Err(DropDatabaseError::MySqlError(err.to_string())),
);
continue;
}
_ => {}
}
let result =
sqlx::query(format!("DROP DATABASE {}", quote_identifier(&database_name)).as_str())
.execute(&mut *connection)
.await
.map(|_| ())
.map_err(|err| DropDatabaseError::MySqlError(err.to_string()));
results.insert(database_name, result);
}
results
}
pub async fn list_databases_for_user(
unix_user: &UnixUser,
connection: &mut MySqlConnection,
) -> Result<Vec<String>, ListDatabasesError> {
sqlx::query(
r#"
SELECT `SCHEMA_NAME` AS `database`
FROM `information_schema`.`SCHEMATA`
WHERE `SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys')
AND `SCHEMA_NAME` REGEXP ?
"#,
)
.bind(create_user_group_matching_regex(unix_user))
.fetch_all(connection)
.await
.and_then(|rows| {
rows.into_iter()
.map(|row| row.try_get::<String, _>("database"))
.collect::<Result<Vec<String>, sqlx::Error>>()
})
.map_err(|err| ListDatabasesError::MySqlError(err.to_string()))
}

View File

@ -0,0 +1,452 @@
// TODO: fix comment
//! Database privilege operations
//!
//! This module contains functions for querying, modifying,
//! displaying and comparing database privileges.
//!
//! A lot of the complexity comes from two core components:
//!
//! - The privilege editor that needs to be able to print
//! an editable table of privileges and reparse the content
//! after the user has made manual changes.
//!
//! - The comparison functionality that tells the user what
//! changes will be made when applying a set of changes
//! to the list of database privileges.
use std::collections::{BTreeMap, BTreeSet};
use indoc::indoc;
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use sqlx::{mysql::MySqlRow, prelude::*, MySqlConnection};
use crate::{
core::{
common::{rev_yn, yn, UnixUser},
database_privileges::{DatabasePrivilegeChange, DatabasePrivilegesDiff},
protocol::{
DiffDoesNotApplyError, GetAllDatabasesPrivilegeData, GetAllDatabasesPrivilegeDataError,
GetDatabasesPrivilegeData, GetDatabasesPrivilegeDataError,
ModifyDatabasePrivilegesError, ModifyDatabasePrivilegesOutput,
},
},
server::{
common::create_user_group_matching_regex,
input_sanitization::{quote_identifier, validate_name, validate_ownership_by_unix_user},
sql::database_operations::unsafe_database_exists,
},
};
/// This is the list of fields that are used to fetch the db + user + privileges
/// from the `db` table in the database. If you need to add or remove privilege
/// fields, this is a good place to start.
pub const DATABASE_PRIVILEGE_FIELDS: [&str; 13] = [
"db",
"user",
"select_priv",
"insert_priv",
"update_priv",
"delete_priv",
"create_priv",
"drop_priv",
"alter_priv",
"index_priv",
"create_tmp_table_priv",
"lock_tables_priv",
"references_priv",
];
// NOTE: ord is needed for BTreeSet to accept the type, but it
// doesn't have any natural implementation semantics.
/// This struct represents the set of privileges for a single user on a single database.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
pub struct DatabasePrivilegeRow {
pub db: String,
pub user: String,
pub select_priv: bool,
pub insert_priv: bool,
pub update_priv: bool,
pub delete_priv: bool,
pub create_priv: bool,
pub drop_priv: bool,
pub alter_priv: bool,
pub index_priv: bool,
pub create_tmp_table_priv: bool,
pub lock_tables_priv: bool,
pub references_priv: bool,
}
impl DatabasePrivilegeRow {
pub fn get_privilege_by_name(&self, name: &str) -> bool {
match name {
"select_priv" => self.select_priv,
"insert_priv" => self.insert_priv,
"update_priv" => self.update_priv,
"delete_priv" => self.delete_priv,
"create_priv" => self.create_priv,
"drop_priv" => self.drop_priv,
"alter_priv" => self.alter_priv,
"index_priv" => self.index_priv,
"create_tmp_table_priv" => self.create_tmp_table_priv,
"lock_tables_priv" => self.lock_tables_priv,
"references_priv" => self.references_priv,
_ => false,
}
}
}
#[inline]
fn get_mysql_row_priv_field(row: &MySqlRow, position: usize) -> Result<bool, sqlx::Error> {
let field = DATABASE_PRIVILEGE_FIELDS[position];
let value = row.try_get(position)?;
match rev_yn(value) {
Some(val) => Ok(val),
_ => {
log::warn!(r#"Invalid value for privilege "{}": '{}'"#, field, value);
Ok(false)
}
}
}
impl FromRow<'_, MySqlRow> for DatabasePrivilegeRow {
fn from_row(row: &MySqlRow) -> Result<Self, sqlx::Error> {
Ok(Self {
db: row.try_get("db")?,
user: row.try_get("user")?,
select_priv: get_mysql_row_priv_field(row, 2)?,
insert_priv: get_mysql_row_priv_field(row, 3)?,
update_priv: get_mysql_row_priv_field(row, 4)?,
delete_priv: get_mysql_row_priv_field(row, 5)?,
create_priv: get_mysql_row_priv_field(row, 6)?,
drop_priv: get_mysql_row_priv_field(row, 7)?,
alter_priv: get_mysql_row_priv_field(row, 8)?,
index_priv: get_mysql_row_priv_field(row, 9)?,
create_tmp_table_priv: get_mysql_row_priv_field(row, 10)?,
lock_tables_priv: get_mysql_row_priv_field(row, 11)?,
references_priv: get_mysql_row_priv_field(row, 12)?,
})
}
}
// NOTE: this function is unsafe because it does no input validation.
/// Get all users + privileges for a single database.
async fn unsafe_get_database_privileges(
database_name: &str,
connection: &mut MySqlConnection,
) -> Result<Vec<DatabasePrivilegeRow>, sqlx::Error> {
sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
"SELECT {} FROM `db` WHERE `db` = ?",
DATABASE_PRIVILEGE_FIELDS
.iter()
.map(|field| quote_identifier(field))
.join(","),
))
.bind(database_name)
.fetch_all(connection)
.await
}
// NOTE: this function is unsafe because it does no input validation.
/// Get all users + privileges for a single database-user pair.
pub async fn unsafe_get_database_privileges_for_db_user_pair(
database_name: &str,
user_name: &str,
connection: &mut MySqlConnection,
) -> Result<Option<DatabasePrivilegeRow>, sqlx::Error> {
sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
"SELECT {} FROM `db` WHERE `db` = ? AND `user` = ?",
DATABASE_PRIVILEGE_FIELDS
.iter()
.map(|field| quote_identifier(field))
.join(","),
))
.bind(database_name)
.bind(user_name)
.fetch_optional(connection)
.await
}
pub async fn get_databases_privilege_data(
database_names: Vec<String>,
unix_user: &UnixUser,
connection: &mut MySqlConnection,
) -> GetDatabasesPrivilegeData {
let mut results = BTreeMap::new();
for database_name in database_names.iter() {
if let Err(err) = validate_name(database_name) {
results.insert(
database_name.clone(),
Err(GetDatabasesPrivilegeDataError::SanitizationError(err)),
);
continue;
}
if let Err(err) = validate_ownership_by_unix_user(database_name, unix_user) {
results.insert(
database_name.clone(),
Err(GetDatabasesPrivilegeDataError::OwnershipError(err)),
);
continue;
}
if !unsafe_database_exists(database_name, connection)
.await
.unwrap()
{
results.insert(
database_name.clone(),
Err(GetDatabasesPrivilegeDataError::DatabaseDoesNotExist),
);
continue;
}
let result = unsafe_get_database_privileges(database_name, connection)
.await
.map_err(|e| GetDatabasesPrivilegeDataError::MySqlError(e.to_string()));
results.insert(database_name.clone(), result);
}
debug_assert!(database_names.len() == results.len());
results
}
/// Get all database + user + privileges pairs that are owned by the current user.
pub async fn get_all_database_privileges(
unix_user: &UnixUser,
connection: &mut MySqlConnection,
) -> GetAllDatabasesPrivilegeData {
sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
indoc! {r#"
SELECT {} FROM `db` WHERE `db` IN
(SELECT DISTINCT `SCHEMA_NAME` AS `database`
FROM `information_schema`.`SCHEMATA`
WHERE `SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys')
AND `SCHEMA_NAME` REGEXP ?)
"#},
DATABASE_PRIVILEGE_FIELDS
.iter()
.map(|field| quote_identifier(field))
.join(","),
))
.bind(create_user_group_matching_regex(unix_user))
.fetch_all(connection)
.await
.map_err(|e| GetAllDatabasesPrivilegeDataError::MySqlError(e.to_string()))
}
async fn unsafe_apply_privilege_diff(
database_privilege_diff: &DatabasePrivilegesDiff,
connection: &mut MySqlConnection,
) -> Result<(), sqlx::Error> {
match database_privilege_diff {
DatabasePrivilegesDiff::New(p) => {
let tables = DATABASE_PRIVILEGE_FIELDS
.iter()
.map(|field| quote_identifier(field))
.join(",");
let question_marks = std::iter::repeat("?")
.take(DATABASE_PRIVILEGE_FIELDS.len())
.join(",");
sqlx::query(
format!("INSERT INTO `db` ({}) VALUES ({})", tables, question_marks).as_str(),
)
.bind(p.db.to_string())
.bind(p.user.to_string())
.bind(yn(p.select_priv))
.bind(yn(p.insert_priv))
.bind(yn(p.update_priv))
.bind(yn(p.delete_priv))
.bind(yn(p.create_priv))
.bind(yn(p.drop_priv))
.bind(yn(p.alter_priv))
.bind(yn(p.index_priv))
.bind(yn(p.create_tmp_table_priv))
.bind(yn(p.lock_tables_priv))
.bind(yn(p.references_priv))
.execute(connection)
.await
.map(|_| ())
}
DatabasePrivilegesDiff::Modified(p) => {
let changes = p
.diff
.iter()
.map(|diff| match diff {
DatabasePrivilegeChange::YesToNo(name) => {
format!("{} = 'N'", quote_identifier(name))
}
DatabasePrivilegeChange::NoToYes(name) => {
format!("{} = 'Y'", quote_identifier(name))
}
})
.join(",");
sqlx::query(
format!("UPDATE `db` SET {} WHERE `db` = ? AND `user` = ?", changes).as_str(),
)
.bind(p.db.to_string())
.bind(p.user.to_string())
.execute(connection)
.await
.map(|_| ())
}
DatabasePrivilegesDiff::Deleted(p) => {
sqlx::query("DELETE FROM `db` WHERE `db` = ? AND `user` = ?")
.bind(p.db.to_string())
.bind(p.user.to_string())
.execute(connection)
.await
.map(|_| ())
}
}
}
async fn validate_diff(
diff: &DatabasePrivilegesDiff,
connection: &mut MySqlConnection,
) -> Result<(), ModifyDatabasePrivilegesError> {
let privilege_row = unsafe_get_database_privileges_for_db_user_pair(
diff.get_database_name(),
diff.get_user_name(),
connection,
)
.await;
let privilege_row = match privilege_row {
Ok(privilege_row) => privilege_row,
Err(e) => return Err(ModifyDatabasePrivilegesError::MySqlError(e.to_string())),
};
let result = match diff {
DatabasePrivilegesDiff::New(_) => {
if privilege_row.is_some() {
Err(ModifyDatabasePrivilegesError::DiffDoesNotApply(
DiffDoesNotApplyError::RowAlreadyExists(
diff.get_user_name().to_string(),
diff.get_database_name().to_string(),
),
))
} else {
Ok(())
}
}
DatabasePrivilegesDiff::Modified(_) if privilege_row.is_none() => {
Err(ModifyDatabasePrivilegesError::DiffDoesNotApply(
DiffDoesNotApplyError::RowDoesNotExist(
diff.get_user_name().to_string(),
diff.get_database_name().to_string(),
),
))
}
DatabasePrivilegesDiff::Modified(row_diff) => {
let row = privilege_row.unwrap();
let error_exists = row_diff.diff.iter().any(|change| match change {
DatabasePrivilegeChange::YesToNo(name) => !row.get_privilege_by_name(name),
DatabasePrivilegeChange::NoToYes(name) => row.get_privilege_by_name(name),
});
if error_exists {
Err(ModifyDatabasePrivilegesError::DiffDoesNotApply(
DiffDoesNotApplyError::RowPrivilegeChangeDoesNotApply(row_diff.clone(), row),
))
} else {
Ok(())
}
}
DatabasePrivilegesDiff::Deleted(_) => {
if privilege_row.is_none() {
Err(ModifyDatabasePrivilegesError::DiffDoesNotApply(
DiffDoesNotApplyError::RowDoesNotExist(
diff.get_user_name().to_string(),
diff.get_database_name().to_string(),
),
))
} else {
Ok(())
}
}
};
result
}
/// Uses the result of [`diff_privileges`] to modify privileges in the database.
pub async fn apply_privilege_diffs(
database_privilege_diffs: BTreeSet<DatabasePrivilegesDiff>,
unix_user: &UnixUser,
connection: &mut MySqlConnection,
) -> ModifyDatabasePrivilegesOutput {
let mut results: BTreeMap<(String, String), _> = BTreeMap::new();
for diff in database_privilege_diffs {
let key = (
diff.get_database_name().to_string(),
diff.get_user_name().to_string(),
);
if let Err(err) = validate_name(diff.get_database_name()) {
results.insert(
key,
Err(ModifyDatabasePrivilegesError::DatabaseSanitizationError(
err,
)),
);
continue;
}
if let Err(err) = validate_ownership_by_unix_user(diff.get_database_name(), unix_user) {
results.insert(
key,
Err(ModifyDatabasePrivilegesError::DatabaseOwnershipError(err)),
);
continue;
}
if let Err(err) = validate_name(diff.get_user_name()) {
results.insert(
key,
Err(ModifyDatabasePrivilegesError::UserSanitizationError(err)),
);
continue;
}
if let Err(err) = validate_ownership_by_unix_user(diff.get_user_name(), unix_user) {
results.insert(
key,
Err(ModifyDatabasePrivilegesError::UserOwnershipError(err)),
);
continue;
}
if !unsafe_database_exists(diff.get_database_name(), connection)
.await
.unwrap()
{
results.insert(
key,
Err(ModifyDatabasePrivilegesError::DatabaseDoesNotExist),
);
continue;
}
if let Err(err) = validate_diff(&diff, connection).await {
results.insert(key, Err(err));
continue;
}
let result = unsafe_apply_privilege_diff(&diff, connection)
.await
.map_err(|e| ModifyDatabasePrivilegesError::MySqlError(e.to_string()));
results.insert(key, result);
}
results
}

View File

@ -0,0 +1,375 @@
use std::collections::BTreeMap;
use serde::{Deserialize, Serialize};
use sqlx::prelude::*;
use sqlx::MySqlConnection;
use crate::{
core::{
common::UnixUser,
protocol::{
CreateUserError, CreateUsersOutput, DropUserError, DropUsersOutput, ListAllUsersError,
ListAllUsersOutput, ListUsersError, ListUsersOutput, LockUserError, LockUsersOutput,
SetPasswordError, SetPasswordOutput, UnlockUserError, UnlockUsersOutput,
},
},
server::{
common::create_user_group_matching_regex,
input_sanitization::{quote_literal, validate_name, validate_ownership_by_unix_user},
},
};
// NOTE: this function is unsafe because it does no input validation.
async fn unsafe_user_exists(
db_user: &str,
connection: &mut MySqlConnection,
) -> Result<bool, sqlx::Error> {
sqlx::query(
r#"
SELECT EXISTS(
SELECT 1
FROM `mysql`.`user`
WHERE `User` = ?
)
"#,
)
.bind(db_user)
.fetch_one(connection)
.await
.map(|row| row.get::<bool, _>(0))
}
pub async fn create_database_users(
db_users: Vec<String>,
unix_user: &UnixUser,
connection: &mut MySqlConnection,
) -> CreateUsersOutput {
let mut results = BTreeMap::new();
for db_user in db_users {
if let Err(err) = validate_name(&db_user) {
results.insert(db_user, Err(CreateUserError::SanitizationError(err)));
continue;
}
if let Err(err) = validate_ownership_by_unix_user(&db_user, unix_user) {
results.insert(db_user, Err(CreateUserError::OwnershipError(err)));
continue;
}
match unsafe_user_exists(&db_user, &mut *connection).await {
Ok(true) => {
results.insert(db_user, Err(CreateUserError::UserAlreadyExists));
continue;
}
Err(err) => {
results.insert(db_user, Err(CreateUserError::MySqlError(err.to_string())));
continue;
}
_ => {}
}
let result = sqlx::query(format!("CREATE USER {}@'%'", quote_literal(&db_user),).as_str())
.execute(&mut *connection)
.await
.map(|_| ())
.map_err(|err| CreateUserError::MySqlError(err.to_string()));
results.insert(db_user, result);
}
results
}
pub async fn drop_database_users(
db_users: Vec<String>,
unix_user: &UnixUser,
connection: &mut MySqlConnection,
) -> DropUsersOutput {
let mut results = BTreeMap::new();
for db_user in db_users {
if let Err(err) = validate_name(&db_user) {
results.insert(db_user, Err(DropUserError::SanitizationError(err)));
continue;
}
if let Err(err) = validate_ownership_by_unix_user(&db_user, unix_user) {
results.insert(db_user, Err(DropUserError::OwnershipError(err)));
continue;
}
match unsafe_user_exists(&db_user, &mut *connection).await {
Ok(false) => {
results.insert(db_user, Err(DropUserError::UserDoesNotExist));
continue;
}
Err(err) => {
results.insert(db_user, Err(DropUserError::MySqlError(err.to_string())));
continue;
}
_ => {}
}
let result = sqlx::query(format!("DROP USER {}@'%'", quote_literal(&db_user),).as_str())
.execute(&mut *connection)
.await
.map(|_| ())
.map_err(|err| DropUserError::MySqlError(err.to_string()));
results.insert(db_user, result);
}
results
}
pub async fn set_password_for_database_user(
db_user: &str,
password: &str,
unix_user: &UnixUser,
connection: &mut MySqlConnection,
) -> SetPasswordOutput {
if let Err(err) = validate_name(db_user) {
return Err(SetPasswordError::SanitizationError(err));
}
if let Err(err) = validate_ownership_by_unix_user(db_user, unix_user) {
return Err(SetPasswordError::OwnershipError(err));
}
match unsafe_user_exists(db_user, &mut *connection).await {
Ok(false) => return Err(SetPasswordError::UserDoesNotExist),
Err(err) => return Err(SetPasswordError::MySqlError(err.to_string())),
_ => {}
}
sqlx::query(
format!(
"ALTER USER {}@'%' IDENTIFIED BY {}",
quote_literal(db_user),
quote_literal(password).as_str()
)
.as_str(),
)
.execute(&mut *connection)
.await
.map(|_| ())
.map_err(|err| SetPasswordError::MySqlError(err.to_string()))
}
// NOTE: this function is unsafe because it does no input validation.
async fn database_user_is_locked_unsafe(
db_user: &str,
connection: &mut MySqlConnection,
) -> Result<bool, sqlx::Error> {
sqlx::query(
r#"
SELECT COALESCE(
JSON_EXTRACT(`mysql`.`global_priv`.`priv`, "$.account_locked"),
'false'
) != 'false'
FROM `mysql`.`global_priv`
WHERE `User` = ?
AND `Host` = '%'
"#,
)
.bind(db_user)
.fetch_one(connection)
.await
.map(|row| row.get::<bool, _>(0))
}
pub async fn lock_database_users(
db_users: Vec<String>,
unix_user: &UnixUser,
connection: &mut MySqlConnection,
) -> LockUsersOutput {
let mut results = BTreeMap::new();
for db_user in db_users {
if let Err(err) = validate_name(&db_user) {
results.insert(db_user, Err(LockUserError::SanitizationError(err)));
continue;
}
if let Err(err) = validate_ownership_by_unix_user(&db_user, unix_user) {
results.insert(db_user, Err(LockUserError::OwnershipError(err)));
continue;
}
match unsafe_user_exists(&db_user, &mut *connection).await {
Ok(true) => {}
Ok(false) => {
results.insert(db_user, Err(LockUserError::UserDoesNotExist));
continue;
}
Err(err) => {
results.insert(db_user, Err(LockUserError::MySqlError(err.to_string())));
continue;
}
}
match database_user_is_locked_unsafe(&db_user, &mut *connection).await {
Ok(false) => {}
Ok(true) => {
results.insert(db_user, Err(LockUserError::UserIsAlreadyLocked));
continue;
}
Err(err) => {
results.insert(db_user, Err(LockUserError::MySqlError(err.to_string())));
continue;
}
}
let result = sqlx::query(
format!("ALTER USER {}@'%' ACCOUNT LOCK", quote_literal(&db_user),).as_str(),
)
.execute(&mut *connection)
.await
.map(|_| ())
.map_err(|err| LockUserError::MySqlError(err.to_string()));
results.insert(db_user, result);
}
results
}
pub async fn unlock_database_users(
db_users: Vec<String>,
unix_user: &UnixUser,
connection: &mut MySqlConnection,
) -> UnlockUsersOutput {
let mut results = BTreeMap::new();
for db_user in db_users {
if let Err(err) = validate_name(&db_user) {
results.insert(db_user, Err(UnlockUserError::SanitizationError(err)));
continue;
}
if let Err(err) = validate_ownership_by_unix_user(&db_user, unix_user) {
results.insert(db_user, Err(UnlockUserError::OwnershipError(err)));
continue;
}
match unsafe_user_exists(&db_user, &mut *connection).await {
Ok(false) => {
results.insert(db_user, Err(UnlockUserError::UserDoesNotExist));
continue;
}
Err(err) => {
results.insert(db_user, Err(UnlockUserError::MySqlError(err.to_string())));
continue;
}
_ => {}
}
match database_user_is_locked_unsafe(&db_user, &mut *connection).await {
Ok(false) => {
results.insert(db_user, Err(UnlockUserError::UserIsAlreadyUnlocked));
continue;
}
Err(err) => {
results.insert(db_user, Err(UnlockUserError::MySqlError(err.to_string())));
continue;
}
_ => {}
}
let result = sqlx::query(
format!("ALTER USER {}@'%' ACCOUNT UNLOCK", quote_literal(&db_user),).as_str(),
)
.execute(&mut *connection)
.await
.map(|_| ())
.map_err(|err| UnlockUserError::MySqlError(err.to_string()));
results.insert(db_user, result);
}
results
}
/// This struct contains information about a database user.
/// This can be extended if we need more information in the future.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)]
pub struct DatabaseUser {
#[sqlx(rename = "User")]
pub user: String,
#[allow(dead_code)]
#[serde(skip)]
#[sqlx(rename = "Host")]
pub host: String,
#[sqlx(rename = "has_password")]
pub has_password: bool,
#[sqlx(rename = "is_locked")]
pub is_locked: bool,
}
const DB_USER_SELECT_STATEMENT: &str = r#"
SELECT
`mysql`.`user`.`User`,
`mysql`.`user`.`Host`,
`mysql`.`user`.`Password` != '' OR `mysql`.`user`.`authentication_string` != '' AS `has_password`,
COALESCE(
JSON_EXTRACT(`mysql`.`global_priv`.`priv`, "$.account_locked"),
'false'
) != 'false' AS `is_locked`
FROM `mysql`.`user`
JOIN `mysql`.`global_priv` ON
`mysql`.`user`.`User` = `mysql`.`global_priv`.`User`
AND `mysql`.`user`.`Host` = `mysql`.`global_priv`.`Host`
"#;
pub async fn list_database_users(
db_users: Vec<String>,
unix_user: &UnixUser,
connection: &mut MySqlConnection,
) -> ListUsersOutput {
let mut results = BTreeMap::new();
for db_user in db_users {
if let Err(err) = validate_name(&db_user) {
results.insert(db_user, Err(ListUsersError::SanitizationError(err)));
continue;
}
if let Err(err) = validate_ownership_by_unix_user(&db_user, unix_user) {
results.insert(db_user, Err(ListUsersError::OwnershipError(err)));
continue;
}
let result = sqlx::query_as::<_, DatabaseUser>(
&(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `mysql`.`user`.`User` = ?"),
)
.bind(&db_user)
.fetch_optional(&mut *connection)
.await;
match result {
Ok(Some(user)) => results.insert(db_user, Ok(user)),
Ok(None) => results.insert(db_user, Err(ListUsersError::UserDoesNotExist)),
Err(err) => results.insert(db_user, Err(ListUsersError::MySqlError(err.to_string()))),
};
}
results
}
pub async fn list_all_database_users_for_unix_user(
unix_user: &UnixUser,
connection: &mut MySqlConnection,
) -> ListAllUsersOutput {
sqlx::query_as::<_, DatabaseUser>(
&(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `mysql`.`user`.`User` REGEXP ?"),
)
.bind(create_user_group_matching_regex(unix_user))
.fetch_all(connection)
.await
.map_err(|err| ListAllUsersError::MySqlError(err.to_string()))
}