WIP
This commit is contained in:
parent
2c8c16e5bf
commit
72eeb148b8
|
@ -418,6 +418,27 @@ dependencies = [
|
||||||
"zeroize",
|
"zeroize",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "derive_more"
|
||||||
|
version = "1.0.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "4a9b99b9cbbe49445b21764dc0625032a89b145a2642e67603e1c936f5458d05"
|
||||||
|
dependencies = [
|
||||||
|
"derive_more-impl",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "derive_more-impl"
|
||||||
|
version = "1.0.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "cb7330aeadfbe296029522e6c40f315320aba36fc43a5b3632f3795348f3bd22"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn 2.0.60",
|
||||||
|
"unicode-xid",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "dialoguer"
|
name = "dialoguer"
|
||||||
version = "0.11.0"
|
version = "0.11.0"
|
||||||
|
@ -993,6 +1014,7 @@ dependencies = [
|
||||||
"async-bincode",
|
"async-bincode",
|
||||||
"bincode",
|
"bincode",
|
||||||
"clap",
|
"clap",
|
||||||
|
"derive_more",
|
||||||
"dialoguer",
|
"dialoguer",
|
||||||
"env_logger",
|
"env_logger",
|
||||||
"futures",
|
"futures",
|
||||||
|
@ -2130,6 +2152,12 @@ version = "0.1.11"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "e51733f11c9c4f72aa0c160008246859e340b00807569a0da0e7a1079b27ba85"
|
checksum = "e51733f11c9c4f72aa0c160008246859e340b00807569a0da0e7a1079b27ba85"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "unicode-xid"
|
||||||
|
version = "0.2.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "unicode_categories"
|
name = "unicode_categories"
|
||||||
version = "0.1.1"
|
version = "0.1.1"
|
||||||
|
|
|
@ -8,6 +8,7 @@ anyhow = "1.0.82"
|
||||||
async-bincode = "0.7.2"
|
async-bincode = "0.7.2"
|
||||||
bincode = "1.3.3"
|
bincode = "1.3.3"
|
||||||
clap = { version = "4.5.4", features = ["derive"] }
|
clap = { version = "4.5.4", features = ["derive"] }
|
||||||
|
derive_more = { version = "1.0.0", features = ["display", "error"] }
|
||||||
dialoguer = "0.11.0"
|
dialoguer = "0.11.0"
|
||||||
env_logger = "0.11.3"
|
env_logger = "0.11.3"
|
||||||
futures = "0.3.30"
|
futures = "0.3.30"
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
pub mod database_command;
|
pub mod database_command;
|
||||||
pub mod mysql_admutils_compatibility;
|
// pub mod mysql_admutils_compatibility;
|
||||||
|
mod common;
|
||||||
pub mod user_command;
|
pub mod user_command;
|
||||||
|
|
|
@ -0,0 +1,44 @@
|
||||||
|
use crate::server::Response;
|
||||||
|
|
||||||
|
/// This enum is used to differentiate between database and user operations.
|
||||||
|
/// Their output are very similar, but there are slight differences in the words used.
|
||||||
|
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||||
|
pub enum DbOrUser {
|
||||||
|
Database,
|
||||||
|
User,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DbOrUser {
|
||||||
|
pub fn lowercased(&self) -> String {
|
||||||
|
match self {
|
||||||
|
DbOrUser::Database => "database".to_string(),
|
||||||
|
DbOrUser::User => "user".to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn capitalized(&self) -> String {
|
||||||
|
match self {
|
||||||
|
DbOrUser::Database => "Database".to_string(),
|
||||||
|
DbOrUser::User => "User".to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn erroneous_server_response(
|
||||||
|
response: Option<Result<Response, std::io::Error>>,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
match response {
|
||||||
|
Some(Ok(Response::Error(e))) => {
|
||||||
|
anyhow::bail!("Error from server: {}", e);
|
||||||
|
}
|
||||||
|
Some(Err(e)) => {
|
||||||
|
anyhow::bail!(e);
|
||||||
|
}
|
||||||
|
Some(response) => {
|
||||||
|
anyhow::bail!("Unexpected response from server: {:?}", response);
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
anyhow::bail!("No response from server");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,17 +1,18 @@
|
||||||
use anyhow::Context;
|
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use dialoguer::{Confirm, Editor};
|
use futures_util::{SinkExt, StreamExt};
|
||||||
use prettytable::{Cell, Row, Table};
|
use prettytable::{Cell, Row, Table};
|
||||||
use sqlx::{Connection, MySqlConnection};
|
|
||||||
|
|
||||||
use crate::core::{
|
use crate::{
|
||||||
common::{close_database_connection, get_current_unix_user, yn, CommandStatus},
|
core::{common::yn, database_privileges::db_priv_field_human_readable_name},
|
||||||
database_operations::*,
|
server::{
|
||||||
database_privilege_operations::*,
|
database_privilege_operations::DATABASE_PRIVILEGE_FIELDS,
|
||||||
user_operations::user_exists,
|
protocol::ClientToServerMessageStream, Request, Response,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(Parser)]
|
use super::common::erroneous_server_response;
|
||||||
|
|
||||||
|
#[derive(Parser, Debug, Clone)]
|
||||||
// #[command(next_help_heading = Some(DATABASE_COMMAND_HEADER))]
|
// #[command(next_help_heading = Some(DATABASE_COMMAND_HEADER))]
|
||||||
pub enum DatabaseCommand {
|
pub enum DatabaseCommand {
|
||||||
/// Create one or more databases
|
/// Create one or more databases
|
||||||
|
@ -86,28 +87,28 @@ pub enum DatabaseCommand {
|
||||||
EditDbPrivs(DatabaseEditPrivsArgs),
|
EditDbPrivs(DatabaseEditPrivsArgs),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser, Debug, Clone)]
|
||||||
pub struct DatabaseCreateArgs {
|
pub struct DatabaseCreateArgs {
|
||||||
/// The name of the database(s) to create.
|
/// The name of the database(s) to create.
|
||||||
#[arg(num_args = 1..)]
|
#[arg(num_args = 1..)]
|
||||||
name: Vec<String>,
|
name: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser, Debug, Clone)]
|
||||||
pub struct DatabaseDropArgs {
|
pub struct DatabaseDropArgs {
|
||||||
/// The name of the database(s) to drop.
|
/// The name of the database(s) to drop.
|
||||||
#[arg(num_args = 1..)]
|
#[arg(num_args = 1..)]
|
||||||
name: Vec<String>,
|
name: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser, Debug, Clone)]
|
||||||
pub struct DatabaseListArgs {
|
pub struct DatabaseListArgs {
|
||||||
/// Whether to output the information in JSON format.
|
/// Whether to output the information in JSON format.
|
||||||
#[arg(short, long)]
|
#[arg(short, long)]
|
||||||
json: bool,
|
json: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser, Debug, Clone)]
|
||||||
pub struct DatabaseShowPrivsArgs {
|
pub struct DatabaseShowPrivsArgs {
|
||||||
/// The name of the database(s) to show.
|
/// The name of the database(s) to show.
|
||||||
#[arg(num_args = 0..)]
|
#[arg(num_args = 0..)]
|
||||||
|
@ -118,7 +119,7 @@ pub struct DatabaseShowPrivsArgs {
|
||||||
json: bool,
|
json: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser, Debug, Clone)]
|
||||||
pub struct DatabaseEditPrivsArgs {
|
pub struct DatabaseEditPrivsArgs {
|
||||||
/// The name of the database to edit privileges for.
|
/// The name of the database to edit privileges for.
|
||||||
pub name: Option<String>,
|
pub name: Option<String>,
|
||||||
|
@ -141,125 +142,156 @@ pub struct DatabaseEditPrivsArgs {
|
||||||
|
|
||||||
pub async fn handle_command(
|
pub async fn handle_command(
|
||||||
command: DatabaseCommand,
|
command: DatabaseCommand,
|
||||||
mut connection: MySqlConnection,
|
server_connection: ClientToServerMessageStream,
|
||||||
) -> anyhow::Result<CommandStatus> {
|
) -> anyhow::Result<()> {
|
||||||
let result = connection
|
|
||||||
.transaction(|txn| {
|
|
||||||
Box::pin(async move {
|
|
||||||
match command {
|
match command {
|
||||||
DatabaseCommand::CreateDb(args) => create_databases(args, txn).await,
|
DatabaseCommand::CreateDb(args) => create_databases(args, server_connection).await,
|
||||||
DatabaseCommand::DropDb(args) => drop_databases(args, txn).await,
|
DatabaseCommand::DropDb(args) => drop_databases(args, server_connection).await,
|
||||||
DatabaseCommand::ListDb(args) => list_databases(args, txn).await,
|
DatabaseCommand::ListDb(args) => list_databases(args, server_connection).await,
|
||||||
DatabaseCommand::ShowDbPrivs(args) => show_database_privileges(args, txn).await,
|
DatabaseCommand::ShowDbPrivs(args) => {
|
||||||
DatabaseCommand::EditDbPrivs(args) => edit_privileges(args, txn).await,
|
show_database_privileges(args, server_connection).await
|
||||||
|
}
|
||||||
|
DatabaseCommand::EditDbPrivs(_args) => todo!(),
|
||||||
}
|
}
|
||||||
})
|
|
||||||
})
|
|
||||||
.await;
|
|
||||||
|
|
||||||
close_database_connection(connection).await;
|
|
||||||
|
|
||||||
result
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn create_databases(
|
async fn create_databases(
|
||||||
args: DatabaseCreateArgs,
|
args: DatabaseCreateArgs,
|
||||||
connection: &mut MySqlConnection,
|
mut server_connection: ClientToServerMessageStream,
|
||||||
) -> anyhow::Result<CommandStatus> {
|
) -> anyhow::Result<()> {
|
||||||
if args.name.is_empty() {
|
if args.name.is_empty() {
|
||||||
anyhow::bail!("No database names provided");
|
anyhow::bail!("No database names provided");
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut result = CommandStatus::SuccessfullyModified;
|
let message = Request::CreateDatabases(args.name.clone());
|
||||||
|
server_connection.send(message).await?;
|
||||||
|
|
||||||
for name in args.name {
|
let result = match server_connection.next().await {
|
||||||
// TODO: This can be optimized by fetching all the database privileges in one query.
|
Some(Ok(Response::CreateDatabases(result))) => result,
|
||||||
if let Err(e) = create_database(&name, connection).await {
|
response => return erroneous_server_response(response),
|
||||||
eprintln!("Failed to create database '{}': {}", name, e);
|
};
|
||||||
|
|
||||||
|
server_connection.send(Request::Exit).await?;
|
||||||
|
|
||||||
|
for (database_name, result) in result.iter() {
|
||||||
|
match result {
|
||||||
|
Ok(_) => println!("Database '{}' created.", database_name),
|
||||||
|
Err(err) => {
|
||||||
|
eprintln!("{:?}", err);
|
||||||
eprintln!("Skipping...");
|
eprintln!("Skipping...");
|
||||||
result = CommandStatus::PartiallySuccessfullyModified;
|
}
|
||||||
} else {
|
|
||||||
println!("Database '{}' created.", name);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(result)
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn drop_databases(
|
async fn drop_databases(
|
||||||
args: DatabaseDropArgs,
|
args: DatabaseDropArgs,
|
||||||
connection: &mut MySqlConnection,
|
mut server_connection: ClientToServerMessageStream,
|
||||||
) -> anyhow::Result<CommandStatus> {
|
) -> anyhow::Result<()> {
|
||||||
if args.name.is_empty() {
|
if args.name.is_empty() {
|
||||||
anyhow::bail!("No database names provided");
|
anyhow::bail!("No database names provided");
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut result = CommandStatus::SuccessfullyModified;
|
let message = Request::DropDatabases(args.name.clone());
|
||||||
|
server_connection.send(message).await?;
|
||||||
|
|
||||||
for name in args.name {
|
let result = match server_connection.next().await {
|
||||||
// TODO: This can be optimized by fetching all the database privileges in one query.
|
Some(Ok(Response::DropDatabases(result))) => result,
|
||||||
if let Err(e) = drop_database(&name, connection).await {
|
response => return erroneous_server_response(response),
|
||||||
eprintln!("Failed to drop database '{}': {}", name, e);
|
};
|
||||||
|
|
||||||
|
server_connection.send(Request::Exit).await?;
|
||||||
|
|
||||||
|
for (database_name, result) in result.iter() {
|
||||||
|
match result {
|
||||||
|
Ok(_) => println!("Database '{}' dropped.", database_name),
|
||||||
|
Err(err) => {
|
||||||
|
eprintln!("{:?}", err);
|
||||||
eprintln!("Skipping...");
|
eprintln!("Skipping...");
|
||||||
result = CommandStatus::PartiallySuccessfullyModified;
|
}
|
||||||
} else {
|
|
||||||
println!("Database '{}' dropped.", name);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(result)
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn list_databases(
|
async fn list_databases(
|
||||||
args: DatabaseListArgs,
|
args: DatabaseListArgs,
|
||||||
connection: &mut MySqlConnection,
|
mut server_connection: ClientToServerMessageStream,
|
||||||
) -> anyhow::Result<CommandStatus> {
|
) -> anyhow::Result<()> {
|
||||||
let databases = get_database_list(connection).await?;
|
let message = Request::ListDatabases;
|
||||||
|
server_connection.send(message).await?;
|
||||||
|
|
||||||
|
let result = match server_connection.next().await {
|
||||||
|
Some(Ok(Response::ListAllDatabases(result))) => result,
|
||||||
|
response => return erroneous_server_response(response),
|
||||||
|
};
|
||||||
|
|
||||||
|
server_connection.send(Request::Exit).await?;
|
||||||
|
|
||||||
|
let database_list = match result {
|
||||||
|
Ok(list) => list,
|
||||||
|
Err(err) => {
|
||||||
|
eprintln!("{:?}", err);
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
if args.json {
|
if args.json {
|
||||||
println!("{}", serde_json::to_string_pretty(&databases)?);
|
println!("{}", serde_json::to_string_pretty(&database_list)?);
|
||||||
return Ok(CommandStatus::NoModificationsIntended);
|
} else if database_list.is_empty() {
|
||||||
}
|
|
||||||
|
|
||||||
if databases.is_empty() {
|
|
||||||
println!("No databases to show.");
|
println!("No databases to show.");
|
||||||
} else {
|
} else {
|
||||||
for db in databases {
|
for db in database_list {
|
||||||
println!("{}", db);
|
println!("{}", db);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(CommandStatus::NoModificationsIntended)
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn show_database_privileges(
|
async fn show_database_privileges(
|
||||||
args: DatabaseShowPrivsArgs,
|
args: DatabaseShowPrivsArgs,
|
||||||
connection: &mut MySqlConnection,
|
mut server_connection: ClientToServerMessageStream,
|
||||||
) -> anyhow::Result<CommandStatus> {
|
) -> anyhow::Result<()> {
|
||||||
let database_users_to_show = if args.name.is_empty() {
|
let message = if args.name.is_empty() {
|
||||||
get_all_database_privileges(connection).await?
|
Request::ListPrivileges(None)
|
||||||
} else {
|
} else {
|
||||||
// TODO: This can be optimized by fetching all the database privileges in one query.
|
Request::ListPrivileges(Some(args.name.clone()))
|
||||||
let mut result = Vec::with_capacity(args.name.len());
|
};
|
||||||
for name in args.name {
|
server_connection.send(message).await?;
|
||||||
match get_database_privileges(&name, connection).await {
|
|
||||||
Ok(db) => result.extend(db),
|
let result = match server_connection.next().await {
|
||||||
Err(e) => {
|
Some(Ok(Response::ListPrivileges(databases))) => databases
|
||||||
eprintln!("Failed to show database '{}': {}", name, e);
|
.into_iter()
|
||||||
|
.filter_map(|(_db, result)| match result {
|
||||||
|
Ok(privileges) => Some(privileges),
|
||||||
|
Err(err) => {
|
||||||
|
eprintln!("{:?}", err);
|
||||||
eprintln!("Skipping...");
|
eprintln!("Skipping...");
|
||||||
|
None
|
||||||
}
|
}
|
||||||
|
})
|
||||||
|
.flatten()
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
Some(Ok(Response::ListAllPrivileges(privilege_rows))) => match privilege_rows {
|
||||||
|
Ok(list) => list,
|
||||||
|
Err(err) => {
|
||||||
|
eprintln!("{:?}", err);
|
||||||
|
return Ok(());
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
result
|
response => return erroneous_server_response(response),
|
||||||
};
|
};
|
||||||
|
|
||||||
if args.json {
|
server_connection.send(Request::Exit).await?;
|
||||||
println!("{}", serde_json::to_string_pretty(&database_users_to_show)?);
|
|
||||||
return Ok(CommandStatus::NoModificationsIntended);
|
|
||||||
}
|
|
||||||
|
|
||||||
if database_users_to_show.is_empty() {
|
if args.json {
|
||||||
println!("No database users to show.");
|
println!("{}", serde_json::to_string_pretty(&result)?);
|
||||||
|
} else if result.is_empty() {
|
||||||
|
println!("No database privileges to show.");
|
||||||
} else {
|
} else {
|
||||||
let mut table = Table::new();
|
let mut table = Table::new();
|
||||||
table.add_row(Row::new(
|
table.add_row(Row::new(
|
||||||
|
@ -270,7 +302,7 @@ async fn show_database_privileges(
|
||||||
.collect(),
|
.collect(),
|
||||||
));
|
));
|
||||||
|
|
||||||
for row in database_users_to_show {
|
for row in result {
|
||||||
table.add_row(row![
|
table.add_row(row![
|
||||||
row.db,
|
row.db,
|
||||||
row.user,
|
row.user,
|
||||||
|
@ -290,101 +322,101 @@ async fn show_database_privileges(
|
||||||
table.printstd();
|
table.printstd();
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(CommandStatus::NoModificationsIntended)
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn edit_privileges(
|
// pub async fn edit_privileges(
|
||||||
args: DatabaseEditPrivsArgs,
|
// args: DatabaseEditPrivsArgs,
|
||||||
connection: &mut MySqlConnection,
|
// connection: &mut MySqlConnection,
|
||||||
) -> anyhow::Result<CommandStatus> {
|
// ) -> anyhow::Result<CommandStatus> {
|
||||||
let privilege_data = if let Some(name) = &args.name {
|
// let privilege_data = if let Some(name) = &args.name {
|
||||||
get_database_privileges(name, connection).await?
|
// get_database_privileges(name, connection).await?
|
||||||
} else {
|
// } else {
|
||||||
get_all_database_privileges(connection).await?
|
// get_all_database_privileges(connection).await?
|
||||||
};
|
// };
|
||||||
|
|
||||||
// TODO: The data from args should not be absolute.
|
// // TODO: The data from args should not be absolute.
|
||||||
// In the current implementation, the user would need to
|
// // In the current implementation, the user would need to
|
||||||
// provide all privileges for all users on all databases.
|
// // provide all privileges for all users on all databases.
|
||||||
// The intended effect is to modify the privileges which have
|
// // The intended effect is to modify the privileges which have
|
||||||
// matching users and databases, as well as add any
|
// // matching users and databases, as well as add any
|
||||||
// new db-user pairs. This makes it impossible to remove
|
// // new db-user pairs. This makes it impossible to remove
|
||||||
// privileges, but that is an issue for another day.
|
// // privileges, but that is an issue for another day.
|
||||||
let privileges_to_change = if !args.privs.is_empty() {
|
// let privileges_to_change = if !args.privs.is_empty() {
|
||||||
parse_privilege_tables_from_args(&args)?
|
// parse_privilege_tables_from_args(&args)?
|
||||||
} else {
|
// } else {
|
||||||
edit_privileges_with_editor(&privilege_data)?
|
// edit_privileges_with_editor(&privilege_data)?
|
||||||
};
|
// };
|
||||||
|
|
||||||
for row in privileges_to_change.iter() {
|
// for row in privileges_to_change.iter() {
|
||||||
if !user_exists(&row.user, connection).await? {
|
// if !user_exists(&row.user, connection).await? {
|
||||||
// TODO: allow user to return and correct their mistake
|
// // TODO: allow user to return and correct their mistake
|
||||||
anyhow::bail!("User {} does not exist", row.user);
|
// anyhow::bail!("User {} does not exist", row.user);
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
||||||
let diffs = diff_privileges(&privilege_data, &privileges_to_change);
|
// let diffs = diff_privileges(&privilege_data, &privileges_to_change);
|
||||||
|
|
||||||
if diffs.is_empty() {
|
// if diffs.is_empty() {
|
||||||
println!("No changes to make.");
|
// println!("No changes to make.");
|
||||||
return Ok(CommandStatus::NoModificationsNeeded);
|
// return Ok(CommandStatus::NoModificationsNeeded);
|
||||||
}
|
// }
|
||||||
|
|
||||||
println!("The following changes will be made:\n");
|
// println!("The following changes will be made:\n");
|
||||||
println!("{}", display_privilege_diffs(&diffs));
|
// println!("{}", display_privilege_diffs(&diffs));
|
||||||
if !args.yes
|
// if !args.yes
|
||||||
&& !Confirm::new()
|
// && !Confirm::new()
|
||||||
.with_prompt("Do you want to apply these changes?")
|
// .with_prompt("Do you want to apply these changes?")
|
||||||
.default(false)
|
// .default(false)
|
||||||
.show_default(true)
|
// .show_default(true)
|
||||||
.interact()?
|
// .interact()?
|
||||||
{
|
// {
|
||||||
return Ok(CommandStatus::Cancelled);
|
// return Ok(CommandStatus::Cancelled);
|
||||||
}
|
// }
|
||||||
|
|
||||||
apply_privilege_diffs(diffs, connection).await?;
|
// apply_privilege_diffs(diffs, connection).await?;
|
||||||
|
|
||||||
Ok(CommandStatus::SuccessfullyModified)
|
// Ok(CommandStatus::SuccessfullyModified)
|
||||||
}
|
// }
|
||||||
|
|
||||||
pub fn parse_privilege_tables_from_args(
|
// pub fn parse_privilege_tables_from_args(
|
||||||
args: &DatabaseEditPrivsArgs,
|
// args: &DatabaseEditPrivsArgs,
|
||||||
) -> anyhow::Result<Vec<DatabasePrivilegeRow>> {
|
// ) -> anyhow::Result<Vec<DatabasePrivilegeRow>> {
|
||||||
debug_assert!(!args.privs.is_empty());
|
// debug_assert!(!args.privs.is_empty());
|
||||||
let result = if let Some(name) = &args.name {
|
// let result = if let Some(name) = &args.name {
|
||||||
args.privs
|
// args.privs
|
||||||
.iter()
|
// .iter()
|
||||||
.map(|p| {
|
// .map(|p| {
|
||||||
parse_privilege_table_cli_arg(&format!("{}:{}", name, &p))
|
// parse_privilege_table_cli_arg(&format!("{}:{}", name, &p))
|
||||||
.context(format!("Failed parsing database privileges: `{}`", &p))
|
// .context(format!("Failed parsing database privileges: `{}`", &p))
|
||||||
})
|
// })
|
||||||
.collect::<anyhow::Result<Vec<DatabasePrivilegeRow>>>()?
|
// .collect::<anyhow::Result<Vec<DatabasePrivilegeRow>>>()?
|
||||||
} else {
|
// } else {
|
||||||
args.privs
|
// args.privs
|
||||||
.iter()
|
// .iter()
|
||||||
.map(|p| {
|
// .map(|p| {
|
||||||
parse_privilege_table_cli_arg(p)
|
// parse_privilege_table_cli_arg(p)
|
||||||
.context(format!("Failed parsing database privileges: `{}`", &p))
|
// .context(format!("Failed parsing database privileges: `{}`", &p))
|
||||||
})
|
// })
|
||||||
.collect::<anyhow::Result<Vec<DatabasePrivilegeRow>>>()?
|
// .collect::<anyhow::Result<Vec<DatabasePrivilegeRow>>>()?
|
||||||
};
|
// };
|
||||||
Ok(result)
|
// Ok(result)
|
||||||
}
|
// }
|
||||||
|
|
||||||
pub fn edit_privileges_with_editor(
|
// pub fn edit_privileges_with_editor(
|
||||||
privilege_data: &[DatabasePrivilegeRow],
|
// privilege_data: &[DatabasePrivilegeRow],
|
||||||
) -> anyhow::Result<Vec<DatabasePrivilegeRow>> {
|
// ) -> anyhow::Result<Vec<DatabasePrivilegeRow>> {
|
||||||
let unix_user = get_current_unix_user()?;
|
// let unix_user = get_current_unix_user()?;
|
||||||
|
|
||||||
let editor_content =
|
// let editor_content =
|
||||||
generate_editor_content_from_privilege_data(privilege_data, &unix_user.name);
|
// generate_editor_content_from_privilege_data(privilege_data, &unix_user.name);
|
||||||
|
|
||||||
// TODO: handle errors better here
|
// // TODO: handle errors better here
|
||||||
let result = Editor::new()
|
// let result = Editor::new()
|
||||||
.extension("tsv")
|
// .extension("tsv")
|
||||||
.edit(&editor_content)?
|
// .edit(&editor_content)?
|
||||||
.unwrap();
|
// .unwrap();
|
||||||
|
|
||||||
parse_privilege_data_from_editor_content(result)
|
// parse_privilege_data_from_editor_content(result)
|
||||||
.context("Could not parse privilege data from editor")
|
// .context("Could not parse privilege data from editor")
|
||||||
}
|
// }
|
||||||
|
|
|
@ -1,27 +1,21 @@
|
||||||
use std::collections::BTreeMap;
|
|
||||||
use std::vec;
|
|
||||||
|
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use dialoguer::{Confirm, Password};
|
use dialoguer::{Confirm, Password};
|
||||||
use prettytable::Table;
|
use futures_util::{SinkExt, StreamExt};
|
||||||
use serde_json::json;
|
|
||||||
use sqlx::{Connection, MySqlConnection};
|
|
||||||
|
|
||||||
use crate::core::{
|
use crate::server::protocol::ClientToServerMessageStream;
|
||||||
common::{close_database_connection, get_current_unix_user, CommandStatus},
|
use crate::server::{Request, Response};
|
||||||
database_operations::*,
|
|
||||||
user_operations::*,
|
|
||||||
};
|
|
||||||
|
|
||||||
#[derive(Parser)]
|
use super::common::erroneous_server_response;
|
||||||
|
|
||||||
|
#[derive(Parser, Debug, Clone)]
|
||||||
pub struct UserArgs {
|
pub struct UserArgs {
|
||||||
#[clap(subcommand)]
|
#[clap(subcommand)]
|
||||||
subcmd: UserCommand,
|
subcmd: UserCommand,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::enum_variant_names)]
|
#[allow(clippy::enum_variant_names)]
|
||||||
#[derive(Parser)]
|
#[derive(Parser, Debug, Clone)]
|
||||||
pub enum UserCommand {
|
pub enum UserCommand {
|
||||||
/// Create one or more users
|
/// Create one or more users
|
||||||
#[command()]
|
#[command()]
|
||||||
|
@ -50,7 +44,7 @@ pub enum UserCommand {
|
||||||
UnlockUser(UserUnlockArgs),
|
UnlockUser(UserUnlockArgs),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser, Debug, Clone)]
|
||||||
pub struct UserCreateArgs {
|
pub struct UserCreateArgs {
|
||||||
#[arg(num_args = 1..)]
|
#[arg(num_args = 1..)]
|
||||||
username: Vec<String>,
|
username: Vec<String>,
|
||||||
|
@ -60,13 +54,13 @@ pub struct UserCreateArgs {
|
||||||
no_password: bool,
|
no_password: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser, Debug, Clone)]
|
||||||
pub struct UserDeleteArgs {
|
pub struct UserDeleteArgs {
|
||||||
#[arg(num_args = 1..)]
|
#[arg(num_args = 1..)]
|
||||||
username: Vec<String>,
|
username: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser, Debug, Clone)]
|
||||||
pub struct UserPasswdArgs {
|
pub struct UserPasswdArgs {
|
||||||
username: String,
|
username: String,
|
||||||
|
|
||||||
|
@ -74,7 +68,7 @@ pub struct UserPasswdArgs {
|
||||||
password_file: Option<String>,
|
password_file: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser, Debug, Clone)]
|
||||||
pub struct UserShowArgs {
|
pub struct UserShowArgs {
|
||||||
#[arg(num_args = 0..)]
|
#[arg(num_args = 0..)]
|
||||||
username: Vec<String>,
|
username: Vec<String>,
|
||||||
|
@ -83,13 +77,13 @@ pub struct UserShowArgs {
|
||||||
json: bool,
|
json: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser, Debug, Clone)]
|
||||||
pub struct UserLockArgs {
|
pub struct UserLockArgs {
|
||||||
#[arg(num_args = 1..)]
|
#[arg(num_args = 1..)]
|
||||||
username: Vec<String>,
|
username: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser, Debug, Clone)]
|
||||||
pub struct UserUnlockArgs {
|
pub struct UserUnlockArgs {
|
||||||
#[arg(num_args = 1..)]
|
#[arg(num_args = 1..)]
|
||||||
username: Vec<String>,
|
username: Vec<String>,
|
||||||
|
@ -97,48 +91,47 @@ pub struct UserUnlockArgs {
|
||||||
|
|
||||||
pub async fn handle_command(
|
pub async fn handle_command(
|
||||||
command: UserCommand,
|
command: UserCommand,
|
||||||
mut connection: MySqlConnection,
|
server_connection: ClientToServerMessageStream,
|
||||||
) -> anyhow::Result<CommandStatus> {
|
) -> anyhow::Result<()> {
|
||||||
let result = connection
|
|
||||||
.transaction(|txn| {
|
|
||||||
Box::pin(async move {
|
|
||||||
match command {
|
match command {
|
||||||
UserCommand::CreateUser(args) => create_users(args, txn).await,
|
UserCommand::CreateUser(args) => create_users(args, server_connection).await,
|
||||||
UserCommand::DropUser(args) => drop_users(args, txn).await,
|
UserCommand::DropUser(args) => drop_users(args, server_connection).await,
|
||||||
UserCommand::PasswdUser(args) => change_password_for_user(args, txn).await,
|
UserCommand::PasswdUser(args) => passwd_user(args, server_connection).await,
|
||||||
UserCommand::ShowUser(args) => show_users(args, txn).await,
|
UserCommand::ShowUser(args) => show_users(args, server_connection).await,
|
||||||
UserCommand::LockUser(args) => lock_users(args, txn).await,
|
UserCommand::LockUser(args) => lock_users(args, server_connection).await,
|
||||||
UserCommand::UnlockUser(args) => unlock_users(args, txn).await,
|
UserCommand::UnlockUser(args) => unlock_users(args, server_connection).await,
|
||||||
}
|
}
|
||||||
})
|
|
||||||
})
|
|
||||||
.await;
|
|
||||||
|
|
||||||
close_database_connection(connection).await;
|
|
||||||
|
|
||||||
result
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn create_users(
|
async fn create_users(
|
||||||
args: UserCreateArgs,
|
args: UserCreateArgs,
|
||||||
connection: &mut MySqlConnection,
|
mut server_connection: ClientToServerMessageStream,
|
||||||
) -> anyhow::Result<CommandStatus> {
|
) -> anyhow::Result<()> {
|
||||||
if args.username.is_empty() {
|
if args.username.is_empty() {
|
||||||
anyhow::bail!("No usernames provided");
|
anyhow::bail!("No usernames provided");
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut result = CommandStatus::SuccessfullyModified;
|
let message = Request::CreateUsers(args.username.clone());
|
||||||
|
// TODO: better error handling
|
||||||
for username in args.username {
|
if let Err(err) = server_connection.send(message).await {
|
||||||
if let Err(e) = create_database_user(&username, connection).await {
|
server_connection.close().await.ok();
|
||||||
eprintln!("{}", e);
|
anyhow::bail!(err);
|
||||||
eprintln!("Skipping...\n");
|
|
||||||
result = CommandStatus::PartiallySuccessfullyModified;
|
|
||||||
continue;
|
|
||||||
} else {
|
|
||||||
println!("User '{}' created.", username);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let result = match server_connection.next().await {
|
||||||
|
Some(Ok(Response::CreateUsers(result))) => result,
|
||||||
|
response => return erroneous_server_response(response),
|
||||||
|
};
|
||||||
|
|
||||||
|
let successfully_created_users = result
|
||||||
|
.iter()
|
||||||
|
.filter_map(|(username, result)| match result {
|
||||||
|
Ok(_) => Some(username),
|
||||||
|
Err(_) => None,
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
for username in successfully_created_users {
|
||||||
if !args.no_password
|
if !args.no_password
|
||||||
&& Confirm::new()
|
&& Confirm::new()
|
||||||
.with_prompt(format!(
|
.with_prompt(format!(
|
||||||
|
@ -147,41 +140,65 @@ async fn create_users(
|
||||||
))
|
))
|
||||||
.interact()?
|
.interact()?
|
||||||
{
|
{
|
||||||
change_password_for_user(
|
let password = read_password_from_stdin_with_double_check(username)?;
|
||||||
UserPasswdArgs {
|
let message = Request::PasswdUser(username.clone(), password);
|
||||||
username,
|
|
||||||
password_file: None,
|
if let Err(err) = server_connection.send(message).await {
|
||||||
|
server_connection.close().await.ok();
|
||||||
|
anyhow::bail!(err);
|
||||||
|
}
|
||||||
|
|
||||||
|
match server_connection.next().await {
|
||||||
|
Some(Ok(Response::PasswdUser(result))) => match result {
|
||||||
|
Ok(_) => println!("Password set for user '{}'", username),
|
||||||
|
Err(e) => {
|
||||||
|
eprintln!("{:?}", e);
|
||||||
|
eprintln!("Skipping...\n");
|
||||||
|
}
|
||||||
},
|
},
|
||||||
connection,
|
response => return erroneous_server_response(response),
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
}
|
}
|
||||||
println!();
|
|
||||||
}
|
}
|
||||||
Ok(result)
|
}
|
||||||
|
|
||||||
|
server_connection.send(Request::Exit).await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn drop_users(
|
async fn drop_users(
|
||||||
args: UserDeleteArgs,
|
args: UserDeleteArgs,
|
||||||
connection: &mut MySqlConnection,
|
mut server_connection: ClientToServerMessageStream,
|
||||||
) -> anyhow::Result<CommandStatus> {
|
) -> anyhow::Result<()> {
|
||||||
if args.username.is_empty() {
|
if args.username.is_empty() {
|
||||||
anyhow::bail!("No usernames provided");
|
anyhow::bail!("No usernames provided");
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut result = CommandStatus::SuccessfullyModified;
|
let message = Request::DropUsers(args.username.clone());
|
||||||
|
|
||||||
for username in args.username {
|
if let Err(err) = server_connection.send(message).await {
|
||||||
if let Err(e) = delete_database_user(&username, connection).await {
|
server_connection.close().await.ok();
|
||||||
eprintln!("{}", e);
|
anyhow::bail!(err);
|
||||||
|
}
|
||||||
|
|
||||||
|
let result = match server_connection.next().await {
|
||||||
|
Some(Ok(Response::DropUsers(result))) => result,
|
||||||
|
response => return erroneous_server_response(response),
|
||||||
|
};
|
||||||
|
|
||||||
|
server_connection.send(Request::Exit).await?;
|
||||||
|
|
||||||
|
for (username, result) in result.iter() {
|
||||||
|
match result {
|
||||||
|
Ok(_) => println!("User '{}' dropped.", username),
|
||||||
|
Err(err) => {
|
||||||
|
eprintln!("{:?}", err);
|
||||||
eprintln!("Skipping...");
|
eprintln!("Skipping...");
|
||||||
result = CommandStatus::PartiallySuccessfullyModified;
|
}
|
||||||
} else {
|
|
||||||
println!("User '{}' dropped.", username);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(result)
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn read_password_from_stdin_with_double_check(username: &str) -> anyhow::Result<String> {
|
pub fn read_password_from_stdin_with_double_check(username: &str) -> anyhow::Result<String> {
|
||||||
|
@ -195,15 +212,10 @@ pub fn read_password_from_stdin_with_double_check(username: &str) -> anyhow::Res
|
||||||
.map_err(Into::into)
|
.map_err(Into::into)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn change_password_for_user(
|
async fn passwd_user(
|
||||||
args: UserPasswdArgs,
|
args: UserPasswdArgs,
|
||||||
connection: &mut MySqlConnection,
|
mut server_connection: ClientToServerMessageStream,
|
||||||
) -> anyhow::Result<CommandStatus> {
|
) -> anyhow::Result<()> {
|
||||||
// NOTE: although this also is checked in `set_password_for_database_user`, we check it here
|
|
||||||
// to provide a more natural order of error messages.
|
|
||||||
let unix_user = get_current_unix_user()?;
|
|
||||||
validate_user_name(&args.username, &unix_user)?;
|
|
||||||
|
|
||||||
let password = if let Some(password_file) = args.password_file {
|
let password = if let Some(password_file) = args.password_file {
|
||||||
std::fs::read_to_string(password_file)
|
std::fs::read_to_string(password_file)
|
||||||
.context("Failed to read password file")?
|
.context("Failed to read password file")?
|
||||||
|
@ -213,129 +225,170 @@ async fn change_password_for_user(
|
||||||
read_password_from_stdin_with_double_check(&args.username)?
|
read_password_from_stdin_with_double_check(&args.username)?
|
||||||
};
|
};
|
||||||
|
|
||||||
set_password_for_database_user(&args.username, &password, connection).await?;
|
let message = Request::PasswdUser(args.username.clone(), password);
|
||||||
|
|
||||||
Ok(CommandStatus::SuccessfullyModified)
|
if let Err(err) = server_connection.send(message).await {
|
||||||
|
server_connection.close().await.ok();
|
||||||
|
anyhow::bail!(err);
|
||||||
|
}
|
||||||
|
|
||||||
|
let result = match server_connection.next().await {
|
||||||
|
Some(Ok(Response::PasswdUser(result))) => result,
|
||||||
|
response => return erroneous_server_response(response),
|
||||||
|
};
|
||||||
|
|
||||||
|
server_connection.send(Request::Exit).await?;
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(_) => println!("Password set for user '{}'", args.username),
|
||||||
|
Err(e) => {
|
||||||
|
eprintln!("{:?}", e);
|
||||||
|
eprintln!("Skipping...\n");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn show_users(
|
async fn show_users(
|
||||||
args: UserShowArgs,
|
args: UserShowArgs,
|
||||||
connection: &mut MySqlConnection,
|
mut server_connection: ClientToServerMessageStream,
|
||||||
) -> anyhow::Result<CommandStatus> {
|
) -> anyhow::Result<()> {
|
||||||
let unix_user = get_current_unix_user()?;
|
let message = if args.username.is_empty() {
|
||||||
|
Request::ListUsers(None)
|
||||||
let users = if args.username.is_empty() {
|
|
||||||
get_all_database_users_for_unix_user(&unix_user, connection).await?
|
|
||||||
} else {
|
} else {
|
||||||
let mut result = vec![];
|
Request::ListUsers(Some(args.username.clone()))
|
||||||
for username in args.username {
|
|
||||||
if let Err(e) = validate_user_name(&username, &unix_user) {
|
|
||||||
eprintln!("{}", e);
|
|
||||||
eprintln!("Skipping...");
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
let user = get_database_user_for_user(&username, connection).await?;
|
|
||||||
if let Some(user) = user {
|
|
||||||
result.push(user);
|
|
||||||
} else {
|
|
||||||
eprintln!("User not found: {}", username);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
result
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut user_databases: BTreeMap<String, Vec<String>> = BTreeMap::new();
|
if let Err(err) = server_connection.send(message).await {
|
||||||
for user in users.iter() {
|
server_connection.close().await.ok();
|
||||||
user_databases.insert(
|
anyhow::bail!(err);
|
||||||
user.user.clone(),
|
|
||||||
get_databases_where_user_has_privileges(&user.user, connection).await?,
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if args.json {
|
let users = match server_connection.next().await {
|
||||||
let users_json = users
|
Some(Ok(Response::ListUsers(users))) => users
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|user| {
|
.filter_map(|(_username, result)| match result {
|
||||||
json!({
|
Ok(user) => Some(user),
|
||||||
"user": user.user,
|
Err(err) => {
|
||||||
"has_password": user.has_password,
|
eprintln!("{:?}", err);
|
||||||
"is_locked": user.is_locked,
|
eprintln!("Skipping...");
|
||||||
"databases": user_databases.get(&user.user).unwrap_or(&vec![]),
|
None
|
||||||
|
}
|
||||||
})
|
})
|
||||||
})
|
.collect::<Vec<_>>(),
|
||||||
.collect::<serde_json::Value>();
|
Some(Ok(Response::ListAllUsers(users))) => match users {
|
||||||
|
Ok(users) => users,
|
||||||
|
Err(err) => {
|
||||||
|
eprintln!("{:?}", err);
|
||||||
|
// TODO: close connection
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
},
|
||||||
|
response => return erroneous_server_response(response),
|
||||||
|
};
|
||||||
|
|
||||||
|
server_connection.send(Request::Exit).await?;
|
||||||
|
|
||||||
|
// TODO: print erroneous users
|
||||||
|
// for user in users.iter()
|
||||||
|
|
||||||
|
// TODO: print databases where user has privileges
|
||||||
|
if args.json {
|
||||||
println!(
|
println!(
|
||||||
"{}",
|
"{}",
|
||||||
serde_json::to_string_pretty(&users_json)
|
serde_json::to_string_pretty(&users).context("Failed to serialize users to JSON")?
|
||||||
.context("Failed to serialize users to JSON")?
|
|
||||||
);
|
);
|
||||||
} else if users.is_empty() {
|
} else if users.is_empty() {
|
||||||
println!("No users found.");
|
println!("No users to show.");
|
||||||
} else {
|
} else {
|
||||||
let mut table = Table::new();
|
let mut table = prettytable::Table::new();
|
||||||
table.add_row(row![
|
table.add_row(row![
|
||||||
"User",
|
"User",
|
||||||
"Password is set",
|
"Password is set",
|
||||||
"Locked",
|
"Locked",
|
||||||
"Databases where user has privileges"
|
// "Databases where user has privileges"
|
||||||
]);
|
]);
|
||||||
for user in users {
|
for user in users {
|
||||||
table.add_row(row![
|
table.add_row(row![
|
||||||
user.user,
|
user.user,
|
||||||
user.has_password,
|
user.has_password,
|
||||||
user.is_locked,
|
user.is_locked,
|
||||||
user_databases.get(&user.user).unwrap_or(&vec![]).join("\n")
|
// user.databases.join("\n")
|
||||||
]);
|
]);
|
||||||
}
|
}
|
||||||
table.printstd();
|
table.printstd();
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(CommandStatus::NoModificationsIntended)
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn lock_users(
|
async fn lock_users(
|
||||||
args: UserLockArgs,
|
args: UserLockArgs,
|
||||||
connection: &mut MySqlConnection,
|
mut server_connection: ClientToServerMessageStream,
|
||||||
) -> anyhow::Result<CommandStatus> {
|
) -> anyhow::Result<()> {
|
||||||
if args.username.is_empty() {
|
if args.username.is_empty() {
|
||||||
anyhow::bail!("No usernames provided");
|
anyhow::bail!("No usernames provided");
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut result = CommandStatus::SuccessfullyModified;
|
let message = Request::LockUsers(args.username.clone());
|
||||||
|
|
||||||
for username in args.username {
|
if let Err(err) = server_connection.send(message).await {
|
||||||
if let Err(e) = lock_database_user(&username, connection).await {
|
server_connection.close().await.ok();
|
||||||
eprintln!("{}", e);
|
anyhow::bail!(err);
|
||||||
|
}
|
||||||
|
|
||||||
|
let result = match server_connection.next().await {
|
||||||
|
Some(Ok(Response::LockUsers(result))) => result,
|
||||||
|
response => return erroneous_server_response(response),
|
||||||
|
};
|
||||||
|
|
||||||
|
server_connection.send(Request::Exit).await?;
|
||||||
|
|
||||||
|
for (username, result) in result.iter() {
|
||||||
|
match result {
|
||||||
|
Ok(_) => println!("User '{}' locked.", username),
|
||||||
|
Err(err) => {
|
||||||
|
eprintln!("{:?}", err);
|
||||||
eprintln!("Skipping...");
|
eprintln!("Skipping...");
|
||||||
result = CommandStatus::PartiallySuccessfullyModified;
|
}
|
||||||
} else {
|
|
||||||
println!("User '{}' locked.", username);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(result)
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn unlock_users(
|
async fn unlock_users(
|
||||||
args: UserUnlockArgs,
|
args: UserUnlockArgs,
|
||||||
connection: &mut MySqlConnection,
|
mut server_connection: ClientToServerMessageStream,
|
||||||
) -> anyhow::Result<CommandStatus> {
|
) -> anyhow::Result<()> {
|
||||||
if args.username.is_empty() {
|
if args.username.is_empty() {
|
||||||
anyhow::bail!("No usernames provided");
|
anyhow::bail!("No usernames provided");
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut result = CommandStatus::SuccessfullyModified;
|
let message = Request::UnlockUsers(args.username.clone());
|
||||||
|
|
||||||
for username in args.username {
|
if let Err(err) = server_connection.send(message).await {
|
||||||
if let Err(e) = unlock_database_user(&username, connection).await {
|
server_connection.close().await.ok();
|
||||||
eprintln!("{}", e);
|
anyhow::bail!(err);
|
||||||
|
}
|
||||||
|
|
||||||
|
let result = match server_connection.next().await {
|
||||||
|
Some(Ok(Response::UnlockUsers(result))) => result,
|
||||||
|
response => return erroneous_server_response(response),
|
||||||
|
};
|
||||||
|
|
||||||
|
server_connection.send(Request::Exit).await?;
|
||||||
|
|
||||||
|
for (username, result) in result.iter() {
|
||||||
|
match result {
|
||||||
|
Ok(_) => println!("User '{}' unlocked.", username),
|
||||||
|
Err(err) => {
|
||||||
|
eprintln!("{:?}", err);
|
||||||
eprintln!("Skipping...");
|
eprintln!("Skipping...");
|
||||||
result = CommandStatus::PartiallySuccessfullyModified;
|
}
|
||||||
} else {
|
|
||||||
println!("User '{}' unlocked.", username);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(result)
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,2 @@
|
||||||
pub mod common;
|
pub mod common;
|
||||||
pub mod config;
|
pub mod database_privileges;
|
||||||
pub mod database_operations;
|
|
||||||
pub mod database_privilege_operations;
|
|
||||||
pub mod user_operations;
|
|
||||||
|
|
|
@ -1,272 +1,3 @@
|
||||||
use anyhow::Context;
|
|
||||||
use indoc::indoc;
|
|
||||||
use itertools::Itertools;
|
|
||||||
use nix::unistd::{getuid, Group, User};
|
|
||||||
use sqlx::{Connection, MySqlConnection};
|
|
||||||
|
|
||||||
#[cfg(not(target_os = "macos"))]
|
|
||||||
use std::ffi::CString;
|
|
||||||
|
|
||||||
/// Report the result status of a command.
|
|
||||||
/// This is used to display a status message to the user.
|
|
||||||
pub enum CommandStatus {
|
|
||||||
/// The command was successful,
|
|
||||||
/// and made modification to the database.
|
|
||||||
SuccessfullyModified,
|
|
||||||
|
|
||||||
/// The command was mostly successful,
|
|
||||||
/// and modifications have been made to the database.
|
|
||||||
/// However, some of the requested modifications failed.
|
|
||||||
PartiallySuccessfullyModified,
|
|
||||||
|
|
||||||
/// The command was successful,
|
|
||||||
/// but no modifications were needed.
|
|
||||||
NoModificationsNeeded,
|
|
||||||
|
|
||||||
/// The command was successful,
|
|
||||||
/// and made no modification to the database.
|
|
||||||
NoModificationsIntended,
|
|
||||||
|
|
||||||
/// The command was cancelled, either through a dialog or a signal.
|
|
||||||
/// No modifications have been made to the database.
|
|
||||||
Cancelled,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_current_unix_user() -> anyhow::Result<User> {
|
|
||||||
User::from_uid(getuid())
|
|
||||||
.context("Failed to look up your UNIX username")
|
|
||||||
.and_then(|u| u.ok_or(anyhow::anyhow!("Failed to look up your UNIX username")))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(target_os = "macos")]
|
|
||||||
pub fn get_unix_groups(_user: &User) -> anyhow::Result<Vec<Group>> {
|
|
||||||
// Return an empty list on macOS since there is no `getgrouplist` function
|
|
||||||
Ok(vec![])
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(not(target_os = "macos"))]
|
|
||||||
pub fn get_unix_groups(user: &User) -> anyhow::Result<Vec<Group>> {
|
|
||||||
let user_cstr =
|
|
||||||
CString::new(user.name.as_bytes()).context("Failed to convert username to CStr")?;
|
|
||||||
let groups = nix::unistd::getgrouplist(&user_cstr, user.gid)?
|
|
||||||
.iter()
|
|
||||||
.filter_map(|gid| match Group::from_gid(*gid) {
|
|
||||||
Ok(Some(group)) => Some(group),
|
|
||||||
Ok(None) => None,
|
|
||||||
Err(e) => {
|
|
||||||
log::warn!(
|
|
||||||
"Failed to look up group with GID {}: {}\nIgnoring...",
|
|
||||||
gid,
|
|
||||||
e
|
|
||||||
);
|
|
||||||
None
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect::<Vec<Group>>();
|
|
||||||
|
|
||||||
Ok(groups)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// This function creates a regex that matches items (users, databases)
|
|
||||||
/// that belong to the user or any of the user's groups.
|
|
||||||
pub fn create_user_group_matching_regex(user: &User) -> String {
|
|
||||||
let groups = get_unix_groups(user).unwrap_or_default();
|
|
||||||
|
|
||||||
if groups.is_empty() {
|
|
||||||
format!("{}(_.+)?", user.name)
|
|
||||||
} else {
|
|
||||||
format!(
|
|
||||||
"({}|{})(_.+)?",
|
|
||||||
user.name,
|
|
||||||
groups
|
|
||||||
.iter()
|
|
||||||
.map(|g| g.name.as_str())
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
.join("|")
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// This enum is used to differentiate between database and user operations.
|
|
||||||
/// Their output are very similar, but there are slight differences in the words used.
|
|
||||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
|
||||||
pub enum DbOrUser {
|
|
||||||
Database,
|
|
||||||
User,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl DbOrUser {
|
|
||||||
pub fn lowercased(&self) -> String {
|
|
||||||
match self {
|
|
||||||
DbOrUser::Database => "database".to_string(),
|
|
||||||
DbOrUser::User => "user".to_string(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn capitalized(&self) -> String {
|
|
||||||
match self {
|
|
||||||
DbOrUser::Database => "Database".to_string(),
|
|
||||||
DbOrUser::User => "User".to_string(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Eq)]
|
|
||||||
pub enum NameValidationResult {
|
|
||||||
Valid,
|
|
||||||
EmptyString,
|
|
||||||
InvalidCharacters,
|
|
||||||
TooLong,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn validate_name(name: &str) -> NameValidationResult {
|
|
||||||
if name.is_empty() {
|
|
||||||
NameValidationResult::EmptyString
|
|
||||||
} else if name.len() > 64 {
|
|
||||||
NameValidationResult::TooLong
|
|
||||||
} else if !name
|
|
||||||
.chars()
|
|
||||||
.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
|
|
||||||
{
|
|
||||||
NameValidationResult::InvalidCharacters
|
|
||||||
} else {
|
|
||||||
NameValidationResult::Valid
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn validate_name_or_error(name: &str, db_or_user: DbOrUser) -> anyhow::Result<()> {
|
|
||||||
match validate_name(name) {
|
|
||||||
NameValidationResult::Valid => Ok(()),
|
|
||||||
NameValidationResult::EmptyString => {
|
|
||||||
anyhow::bail!("{} name cannot be empty.", db_or_user.capitalized())
|
|
||||||
}
|
|
||||||
NameValidationResult::TooLong => anyhow::bail!(
|
|
||||||
"{} is too long. Maximum length is 64 characters.",
|
|
||||||
db_or_user.capitalized()
|
|
||||||
),
|
|
||||||
NameValidationResult::InvalidCharacters => anyhow::bail!(
|
|
||||||
indoc! {r#"
|
|
||||||
Invalid characters in {} name: '{}'
|
|
||||||
|
|
||||||
Only A-Z, a-z, 0-9, _ (underscore) and - (dash) are permitted.
|
|
||||||
"#},
|
|
||||||
db_or_user.lowercased(),
|
|
||||||
name
|
|
||||||
),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Eq)]
|
|
||||||
pub enum OwnerValidationResult {
|
|
||||||
// The name is valid and matches one of the given prefixes
|
|
||||||
Match,
|
|
||||||
|
|
||||||
// The name is valid, but none of the given prefixes matched the name
|
|
||||||
NoMatch,
|
|
||||||
|
|
||||||
// The name is empty, which is invalid
|
|
||||||
StringEmpty,
|
|
||||||
|
|
||||||
// The name is in the format "_<postfix>", which is invalid
|
|
||||||
MissingPrefix,
|
|
||||||
|
|
||||||
// The name is in the format "<prefix>_", which is invalid
|
|
||||||
MissingPostfix,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Core logic for validating the ownership of a database name.
|
|
||||||
/// This function checks if the given name matches any of the given prefixes.
|
|
||||||
/// These prefixes will in most cases be the user's unix username and any
|
|
||||||
/// unix groups the user is a member of.
|
|
||||||
pub fn validate_ownership_by_prefixes(name: &str, prefixes: &[String]) -> OwnerValidationResult {
|
|
||||||
if name.is_empty() {
|
|
||||||
return OwnerValidationResult::StringEmpty;
|
|
||||||
}
|
|
||||||
|
|
||||||
if name.starts_with('_') {
|
|
||||||
return OwnerValidationResult::MissingPrefix;
|
|
||||||
}
|
|
||||||
|
|
||||||
let (prefix, _) = match name.split_once('_') {
|
|
||||||
Some(pair) => pair,
|
|
||||||
None => return OwnerValidationResult::MissingPostfix,
|
|
||||||
};
|
|
||||||
|
|
||||||
if prefixes.iter().any(|g| g == prefix) {
|
|
||||||
OwnerValidationResult::Match
|
|
||||||
} else {
|
|
||||||
OwnerValidationResult::NoMatch
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Validate the ownership of a database name or database user name.
|
|
||||||
/// This function takes the name of a database or user and a unix user,
|
|
||||||
/// for which it fetches the user's groups. It then checks if the name
|
|
||||||
/// is prefixed with the user's username or any of the user's groups.
|
|
||||||
pub fn validate_ownership_or_error<'a>(
|
|
||||||
name: &'a str,
|
|
||||||
user: &User,
|
|
||||||
db_or_user: DbOrUser,
|
|
||||||
) -> anyhow::Result<&'a str> {
|
|
||||||
let user_groups = get_unix_groups(user)?;
|
|
||||||
let prefixes = std::iter::once(user.name.clone())
|
|
||||||
.chain(user_groups.iter().map(|g| g.name.clone()))
|
|
||||||
.collect::<Vec<String>>();
|
|
||||||
|
|
||||||
match validate_ownership_by_prefixes(name, &prefixes) {
|
|
||||||
OwnerValidationResult::Match => Ok(name),
|
|
||||||
OwnerValidationResult::NoMatch => {
|
|
||||||
anyhow::bail!(
|
|
||||||
indoc! {r#"
|
|
||||||
Invalid {} name prefix: '{}' does not match your username or any of your groups.
|
|
||||||
Are you sure you are allowed to create {} names with this prefix?
|
|
||||||
|
|
||||||
Allowed prefixes:
|
|
||||||
- {}
|
|
||||||
{}
|
|
||||||
"#},
|
|
||||||
db_or_user.lowercased(),
|
|
||||||
name,
|
|
||||||
db_or_user.lowercased(),
|
|
||||||
user.name,
|
|
||||||
user_groups
|
|
||||||
.iter()
|
|
||||||
.filter(|g| g.name != user.name)
|
|
||||||
.map(|g| format!(" - {}", g.name))
|
|
||||||
.sorted()
|
|
||||||
.join("\n"),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
_ => anyhow::bail!(
|
|
||||||
"'{}' is not a valid {} name.",
|
|
||||||
name,
|
|
||||||
db_or_user.lowercased()
|
|
||||||
),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Gracefully close a MySQL connection.
|
|
||||||
pub async fn close_database_connection(connection: MySqlConnection) {
|
|
||||||
if let Err(e) = connection
|
|
||||||
.close()
|
|
||||||
.await
|
|
||||||
.context("Failed to close connection properly")
|
|
||||||
{
|
|
||||||
eprintln!("{}", e);
|
|
||||||
eprintln!("Ignoring...");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[inline]
|
|
||||||
pub fn quote_literal(s: &str) -> String {
|
|
||||||
format!("'{}'", s.replace('\'', r"\'"))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[inline]
|
|
||||||
pub fn quote_identifier(s: &str) -> String {
|
|
||||||
format!("`{}`", s.replace('`', r"\`"))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
pub(crate) fn yn(b: bool) -> &'static str {
|
pub(crate) fn yn(b: bool) -> &'static str {
|
||||||
if b {
|
if b {
|
||||||
|
@ -303,94 +34,4 @@ mod test {
|
||||||
assert_eq!(rev_yn("n"), Some(false));
|
assert_eq!(rev_yn("n"), Some(false));
|
||||||
assert_eq!(rev_yn("X"), None);
|
assert_eq!(rev_yn("X"), None);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_quote_literal() {
|
|
||||||
let payload = "' OR 1=1 --";
|
|
||||||
assert_eq!(quote_literal(payload), r#"'\' OR 1=1 --'"#);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_quote_identifier() {
|
|
||||||
let payload = "` OR 1=1 --";
|
|
||||||
assert_eq!(quote_identifier(payload), r#"`\` OR 1=1 --`"#);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_validate_name() {
|
|
||||||
assert_eq!(validate_name(""), NameValidationResult::EmptyString);
|
|
||||||
assert_eq!(
|
|
||||||
validate_name("abcdefghijklmnopqrstuvwxyz"),
|
|
||||||
NameValidationResult::Valid
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
validate_name("ABCDEFGHIJKLMNOPQRSTUVWXYZ"),
|
|
||||||
NameValidationResult::Valid
|
|
||||||
);
|
|
||||||
assert_eq!(validate_name("0123456789_-"), NameValidationResult::Valid);
|
|
||||||
|
|
||||||
for c in "\n\t\r !@#$%^&*()+=[]{}|;:,.<>?/".chars() {
|
|
||||||
assert_eq!(
|
|
||||||
validate_name(&c.to_string()),
|
|
||||||
NameValidationResult::InvalidCharacters
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
assert_eq!(validate_name(&"a".repeat(64)), NameValidationResult::Valid);
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
validate_name(&"a".repeat(65)),
|
|
||||||
NameValidationResult::TooLong
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_validate_owner_by_prefixes() {
|
|
||||||
let prefixes = vec!["user".to_string(), "group".to_string()];
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
validate_ownership_by_prefixes("", &prefixes),
|
|
||||||
OwnerValidationResult::StringEmpty
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
validate_ownership_by_prefixes("user", &prefixes),
|
|
||||||
OwnerValidationResult::MissingPostfix
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
validate_ownership_by_prefixes("something", &prefixes),
|
|
||||||
OwnerValidationResult::MissingPostfix
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
validate_ownership_by_prefixes("user-testdb", &prefixes),
|
|
||||||
OwnerValidationResult::MissingPostfix
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
validate_ownership_by_prefixes("_testdb", &prefixes),
|
|
||||||
OwnerValidationResult::MissingPrefix
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
validate_ownership_by_prefixes("user_testdb", &prefixes),
|
|
||||||
OwnerValidationResult::Match
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
validate_ownership_by_prefixes("group_testdb", &prefixes),
|
|
||||||
OwnerValidationResult::Match
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
validate_ownership_by_prefixes("group_test_db", &prefixes),
|
|
||||||
OwnerValidationResult::Match
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
validate_ownership_by_prefixes("group_test-db", &prefixes),
|
|
||||||
OwnerValidationResult::Match
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
validate_ownership_by_prefixes("nonexistent_testdb", &prefixes),
|
|
||||||
OwnerValidationResult::NoMatch
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,120 +0,0 @@
|
||||||
use anyhow::Context;
|
|
||||||
use indoc::formatdoc;
|
|
||||||
use itertools::Itertools;
|
|
||||||
use nix::unistd::User;
|
|
||||||
use sqlx::{prelude::*, MySqlConnection};
|
|
||||||
|
|
||||||
use crate::core::{
|
|
||||||
common::{
|
|
||||||
create_user_group_matching_regex, get_current_unix_user, quote_identifier,
|
|
||||||
validate_name_or_error, validate_ownership_or_error, DbOrUser,
|
|
||||||
},
|
|
||||||
database_privilege_operations::DATABASE_PRIVILEGE_FIELDS,
|
|
||||||
};
|
|
||||||
|
|
||||||
pub async fn create_database(name: &str, connection: &mut MySqlConnection) -> anyhow::Result<()> {
|
|
||||||
let user = get_current_unix_user()?;
|
|
||||||
validate_database_name(name, &user)?;
|
|
||||||
|
|
||||||
// NOTE: see the note about SQL injections in `validate_owner_of_database_name`
|
|
||||||
sqlx::query(&format!("CREATE DATABASE {}", quote_identifier(name)))
|
|
||||||
.execute(connection)
|
|
||||||
.await
|
|
||||||
.map_err(|e| {
|
|
||||||
if e.to_string().contains("database exists") {
|
|
||||||
anyhow::anyhow!("Database '{}' already exists", name)
|
|
||||||
} else {
|
|
||||||
e.into()
|
|
||||||
}
|
|
||||||
})?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn drop_database(name: &str, connection: &mut MySqlConnection) -> anyhow::Result<()> {
|
|
||||||
let user = get_current_unix_user()?;
|
|
||||||
validate_database_name(name, &user)?;
|
|
||||||
|
|
||||||
// NOTE: see the note about SQL injections in `validate_owner_of_database_name`
|
|
||||||
sqlx::query(&format!("DROP DATABASE {}", quote_identifier(name)))
|
|
||||||
.execute(connection)
|
|
||||||
.await
|
|
||||||
.map_err(|e| {
|
|
||||||
if e.to_string().contains("doesn't exist") {
|
|
||||||
anyhow::anyhow!("Database '{}' does not exist", name)
|
|
||||||
} else {
|
|
||||||
e.into()
|
|
||||||
}
|
|
||||||
})?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn get_database_list(connection: &mut MySqlConnection) -> anyhow::Result<Vec<String>> {
|
|
||||||
let unix_user = get_current_unix_user()?;
|
|
||||||
|
|
||||||
let databases: Vec<String> = sqlx::query(
|
|
||||||
r#"
|
|
||||||
SELECT `SCHEMA_NAME` AS `database`
|
|
||||||
FROM `information_schema`.`SCHEMATA`
|
|
||||||
WHERE `SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys')
|
|
||||||
AND `SCHEMA_NAME` REGEXP ?
|
|
||||||
"#,
|
|
||||||
)
|
|
||||||
.bind(create_user_group_matching_regex(&unix_user))
|
|
||||||
.fetch_all(connection)
|
|
||||||
.await
|
|
||||||
.and_then(|row| {
|
|
||||||
row.into_iter()
|
|
||||||
.map(|row| row.try_get::<String, _>("database"))
|
|
||||||
.collect::<Result<_, _>>()
|
|
||||||
})
|
|
||||||
.context(format!(
|
|
||||||
"Failed to get databases for user '{}'",
|
|
||||||
unix_user.name
|
|
||||||
))?;
|
|
||||||
|
|
||||||
Ok(databases)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn get_databases_where_user_has_privileges(
|
|
||||||
username: &str,
|
|
||||||
connection: &mut MySqlConnection,
|
|
||||||
) -> anyhow::Result<Vec<String>> {
|
|
||||||
let result = sqlx::query(
|
|
||||||
formatdoc!(
|
|
||||||
r#"
|
|
||||||
SELECT `db` AS `database`
|
|
||||||
FROM `db`
|
|
||||||
WHERE `user` = ?
|
|
||||||
AND ({})
|
|
||||||
"#,
|
|
||||||
DATABASE_PRIVILEGE_FIELDS
|
|
||||||
.iter()
|
|
||||||
.map(|field| format!("`{}` = 'Y'", field))
|
|
||||||
.join(" OR "),
|
|
||||||
)
|
|
||||||
.as_str(),
|
|
||||||
)
|
|
||||||
.bind(username)
|
|
||||||
.fetch_all(connection)
|
|
||||||
.await?
|
|
||||||
.into_iter()
|
|
||||||
.map(|databases| databases.try_get::<String, _>("database").unwrap())
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
Ok(result)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// NOTE: It is very critical that this function validates the database name
|
|
||||||
/// properly. MySQL does not seem to allow for prepared statements, binding
|
|
||||||
/// the database name as a parameter to the query. This means that we have
|
|
||||||
/// to validate the database name ourselves to prevent SQL injection.
|
|
||||||
pub fn validate_database_name(name: &str, user: &User) -> anyhow::Result<()> {
|
|
||||||
validate_name_or_error(name, DbOrUser::Database)
|
|
||||||
.context(format!("Invalid database name: '{}'", name))?;
|
|
||||||
validate_ownership_or_error(name, user, DbOrUser::Database)
|
|
||||||
.context(format!("Invalid database name: '{}'", name))?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
|
@ -1,53 +1,14 @@
|
||||||
//! Database privilege operations
|
|
||||||
//!
|
|
||||||
//! This module contains functions for querying, modifying,
|
|
||||||
//! displaying and comparing database privileges.
|
|
||||||
//!
|
|
||||||
//! A lot of the complexity comes from two core components:
|
|
||||||
//!
|
|
||||||
//! - The privilege editor that needs to be able to print
|
|
||||||
//! an editable table of privileges and reparse the content
|
|
||||||
//! after the user has made manual changes.
|
|
||||||
//!
|
|
||||||
//! - The comparison functionality that tells the user what
|
|
||||||
//! changes will be made when applying a set of changes
|
|
||||||
//! to the list of database privileges.
|
|
||||||
|
|
||||||
use std::collections::{BTreeSet, HashMap};
|
|
||||||
|
|
||||||
use anyhow::{anyhow, Context};
|
use anyhow::{anyhow, Context};
|
||||||
use indoc::indoc;
|
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
use prettytable::Table;
|
use prettytable::Table;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use sqlx::{mysql::MySqlRow, prelude::*, MySqlConnection};
|
use std::collections::{BTreeSet, HashMap};
|
||||||
|
|
||||||
use crate::core::{
|
use super::common::{rev_yn, yn};
|
||||||
common::{
|
use crate::server::database_privilege_operations::{
|
||||||
create_user_group_matching_regex, get_current_unix_user, quote_identifier, rev_yn, yn,
|
DatabasePrivilegeRow, DATABASE_PRIVILEGE_FIELDS,
|
||||||
},
|
|
||||||
database_operations::validate_database_name,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/// This is the list of fields that are used to fetch the db + user + privileges
|
|
||||||
/// from the `db` table in the database. If you need to add or remove privilege
|
|
||||||
/// fields, this is a good place to start.
|
|
||||||
pub const DATABASE_PRIVILEGE_FIELDS: [&str; 13] = [
|
|
||||||
"db",
|
|
||||||
"user",
|
|
||||||
"select_priv",
|
|
||||||
"insert_priv",
|
|
||||||
"update_priv",
|
|
||||||
"delete_priv",
|
|
||||||
"create_priv",
|
|
||||||
"drop_priv",
|
|
||||||
"alter_priv",
|
|
||||||
"index_priv",
|
|
||||||
"create_tmp_table_priv",
|
|
||||||
"lock_tables_priv",
|
|
||||||
"references_priv",
|
|
||||||
];
|
|
||||||
|
|
||||||
pub fn db_priv_field_human_readable_name(name: &str) -> String {
|
pub fn db_priv_field_human_readable_name(name: &str) -> String {
|
||||||
match name {
|
match name {
|
||||||
"db" => "Database".to_owned(),
|
"db" => "Database".to_owned(),
|
||||||
|
@ -67,163 +28,25 @@ pub fn db_priv_field_human_readable_name(name: &str) -> String {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// This struct represents the set of privileges for a single user on a single database.
|
pub fn diff(row1: &DatabasePrivilegeRow, row2: &DatabasePrivilegeRow) -> DatabasePrivilegeRowDiff {
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
|
debug_assert!(row1.db == row2.db && row1.user == row2.user);
|
||||||
pub struct DatabasePrivilegeRow {
|
|
||||||
pub db: String,
|
|
||||||
pub user: String,
|
|
||||||
pub select_priv: bool,
|
|
||||||
pub insert_priv: bool,
|
|
||||||
pub update_priv: bool,
|
|
||||||
pub delete_priv: bool,
|
|
||||||
pub create_priv: bool,
|
|
||||||
pub drop_priv: bool,
|
|
||||||
pub alter_priv: bool,
|
|
||||||
pub index_priv: bool,
|
|
||||||
pub create_tmp_table_priv: bool,
|
|
||||||
pub lock_tables_priv: bool,
|
|
||||||
pub references_priv: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl DatabasePrivilegeRow {
|
|
||||||
pub fn empty(db: &str, user: &str) -> Self {
|
|
||||||
Self {
|
|
||||||
db: db.to_owned(),
|
|
||||||
user: user.to_owned(),
|
|
||||||
select_priv: false,
|
|
||||||
insert_priv: false,
|
|
||||||
update_priv: false,
|
|
||||||
delete_priv: false,
|
|
||||||
create_priv: false,
|
|
||||||
drop_priv: false,
|
|
||||||
alter_priv: false,
|
|
||||||
index_priv: false,
|
|
||||||
create_tmp_table_priv: false,
|
|
||||||
lock_tables_priv: false,
|
|
||||||
references_priv: false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_privilege_by_name(&self, name: &str) -> bool {
|
|
||||||
match name {
|
|
||||||
"select_priv" => self.select_priv,
|
|
||||||
"insert_priv" => self.insert_priv,
|
|
||||||
"update_priv" => self.update_priv,
|
|
||||||
"delete_priv" => self.delete_priv,
|
|
||||||
"create_priv" => self.create_priv,
|
|
||||||
"drop_priv" => self.drop_priv,
|
|
||||||
"alter_priv" => self.alter_priv,
|
|
||||||
"index_priv" => self.index_priv,
|
|
||||||
"create_tmp_table_priv" => self.create_tmp_table_priv,
|
|
||||||
"lock_tables_priv" => self.lock_tables_priv,
|
|
||||||
"references_priv" => self.references_priv,
|
|
||||||
_ => false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn diff(&self, other: &DatabasePrivilegeRow) -> DatabasePrivilegeRowDiff {
|
|
||||||
debug_assert!(self.db == other.db && self.user == other.user);
|
|
||||||
|
|
||||||
DatabasePrivilegeRowDiff {
|
DatabasePrivilegeRowDiff {
|
||||||
db: self.db.clone(),
|
db: row1.db.clone(),
|
||||||
user: self.user.clone(),
|
user: row1.user.clone(),
|
||||||
diff: DATABASE_PRIVILEGE_FIELDS
|
diff: DATABASE_PRIVILEGE_FIELDS
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.skip(2)
|
.skip(2)
|
||||||
.filter_map(|field| {
|
.filter_map(|field| {
|
||||||
DatabasePrivilegeChange::new(
|
DatabasePrivilegeChange::new(
|
||||||
self.get_privilege_by_name(field),
|
row1.get_privilege_by_name(field),
|
||||||
other.get_privilege_by_name(field),
|
row2.get_privilege_by_name(field),
|
||||||
field,
|
field,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
.collect(),
|
.collect(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
#[inline]
|
|
||||||
fn get_mysql_row_priv_field(row: &MySqlRow, position: usize) -> Result<bool, sqlx::Error> {
|
|
||||||
let field = DATABASE_PRIVILEGE_FIELDS[position];
|
|
||||||
let value = row.try_get(position)?;
|
|
||||||
match rev_yn(value) {
|
|
||||||
Some(val) => Ok(val),
|
|
||||||
_ => {
|
|
||||||
log::warn!(r#"Invalid value for privilege "{}": '{}'"#, field, value);
|
|
||||||
Ok(false)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl FromRow<'_, MySqlRow> for DatabasePrivilegeRow {
|
|
||||||
fn from_row(row: &MySqlRow) -> Result<Self, sqlx::Error> {
|
|
||||||
Ok(Self {
|
|
||||||
db: row.try_get("db")?,
|
|
||||||
user: row.try_get("user")?,
|
|
||||||
select_priv: get_mysql_row_priv_field(row, 2)?,
|
|
||||||
insert_priv: get_mysql_row_priv_field(row, 3)?,
|
|
||||||
update_priv: get_mysql_row_priv_field(row, 4)?,
|
|
||||||
delete_priv: get_mysql_row_priv_field(row, 5)?,
|
|
||||||
create_priv: get_mysql_row_priv_field(row, 6)?,
|
|
||||||
drop_priv: get_mysql_row_priv_field(row, 7)?,
|
|
||||||
alter_priv: get_mysql_row_priv_field(row, 8)?,
|
|
||||||
index_priv: get_mysql_row_priv_field(row, 9)?,
|
|
||||||
create_tmp_table_priv: get_mysql_row_priv_field(row, 10)?,
|
|
||||||
lock_tables_priv: get_mysql_row_priv_field(row, 11)?,
|
|
||||||
references_priv: get_mysql_row_priv_field(row, 12)?,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get all users + privileges for a single database.
|
|
||||||
pub async fn get_database_privileges(
|
|
||||||
database_name: &str,
|
|
||||||
connection: &mut MySqlConnection,
|
|
||||||
) -> anyhow::Result<Vec<DatabasePrivilegeRow>> {
|
|
||||||
let unix_user = get_current_unix_user()?;
|
|
||||||
validate_database_name(database_name, &unix_user)?;
|
|
||||||
|
|
||||||
let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
|
|
||||||
"SELECT {} FROM `db` WHERE `db` = ?",
|
|
||||||
DATABASE_PRIVILEGE_FIELDS
|
|
||||||
.iter()
|
|
||||||
.map(|field| quote_identifier(field))
|
|
||||||
.join(","),
|
|
||||||
))
|
|
||||||
.bind(database_name)
|
|
||||||
.fetch_all(connection)
|
|
||||||
.await
|
|
||||||
.context("Failed to show database")?;
|
|
||||||
|
|
||||||
Ok(result)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get all database + user + privileges pairs that are owned by the current user.
|
|
||||||
pub async fn get_all_database_privileges(
|
|
||||||
connection: &mut MySqlConnection,
|
|
||||||
) -> anyhow::Result<Vec<DatabasePrivilegeRow>> {
|
|
||||||
let unix_user = get_current_unix_user()?;
|
|
||||||
|
|
||||||
let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
|
|
||||||
indoc! {r#"
|
|
||||||
SELECT {} FROM `db` WHERE `db` IN
|
|
||||||
(SELECT DISTINCT `SCHEMA_NAME` AS `database`
|
|
||||||
FROM `information_schema`.`SCHEMATA`
|
|
||||||
WHERE `SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys')
|
|
||||||
AND `SCHEMA_NAME` REGEXP ?)
|
|
||||||
"#},
|
|
||||||
DATABASE_PRIVILEGE_FIELDS
|
|
||||||
.iter()
|
|
||||||
.map(|field| format!("`{field}`"))
|
|
||||||
.join(","),
|
|
||||||
))
|
|
||||||
.bind(create_user_group_matching_regex(&unix_user))
|
|
||||||
.fetch_all(connection)
|
|
||||||
.await
|
|
||||||
.context("Failed to show databases")?;
|
|
||||||
|
|
||||||
Ok(result)
|
|
||||||
}
|
|
||||||
|
|
||||||
/*************************/
|
/*************************/
|
||||||
/* CLI INTERFACE PARSING */
|
/* CLI INTERFACE PARSING */
|
||||||
|
@ -578,7 +401,7 @@ pub fn parse_privilege_data_from_editor_content(
|
||||||
/// instances of privilege sets for a single user on a single database.
|
/// instances of privilege sets for a single user on a single database.
|
||||||
///
|
///
|
||||||
/// The `User` and `Database` are the same for both instances.
|
/// The `User` and `Database` are the same for both instances.
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, PartialOrd, Ord)]
|
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
|
||||||
pub struct DatabasePrivilegeRowDiff {
|
pub struct DatabasePrivilegeRowDiff {
|
||||||
pub db: String,
|
pub db: String,
|
||||||
pub user: String,
|
pub user: String,
|
||||||
|
@ -586,7 +409,7 @@ pub struct DatabasePrivilegeRowDiff {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// This enum represents a change for a single privilege.
|
/// This enum represents a change for a single privilege.
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, PartialOrd, Ord)]
|
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
|
||||||
pub enum DatabasePrivilegeChange {
|
pub enum DatabasePrivilegeChange {
|
||||||
YesToNo(String),
|
YesToNo(String),
|
||||||
NoToYes(String),
|
NoToYes(String),
|
||||||
|
@ -603,13 +426,31 @@ impl DatabasePrivilegeChange {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// This enum encapsulates whether a [`DatabasePrivilegeRow`] was intrduced, modified or deleted.
|
/// This enum encapsulates whether a [`DatabasePrivilegeRow`] was intrduced, modified or deleted.
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, PartialOrd, Ord)]
|
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
|
||||||
pub enum DatabasePrivilegesDiff {
|
pub enum DatabasePrivilegesDiff {
|
||||||
New(DatabasePrivilegeRow),
|
New(DatabasePrivilegeRow),
|
||||||
Modified(DatabasePrivilegeRowDiff),
|
Modified(DatabasePrivilegeRowDiff),
|
||||||
Deleted(DatabasePrivilegeRow),
|
Deleted(DatabasePrivilegeRow),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl DatabasePrivilegesDiff {
|
||||||
|
pub fn get_database_name(&self) -> &str {
|
||||||
|
match self {
|
||||||
|
DatabasePrivilegesDiff::New(p) => &p.db,
|
||||||
|
DatabasePrivilegesDiff::Modified(p) => &p.db,
|
||||||
|
DatabasePrivilegesDiff::Deleted(p) => &p.db,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_user_name(&self) -> &str {
|
||||||
|
match self {
|
||||||
|
DatabasePrivilegesDiff::New(p) => &p.user,
|
||||||
|
DatabasePrivilegesDiff::Modified(p) => &p.user,
|
||||||
|
DatabasePrivilegesDiff::Deleted(p) => &p.user,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// This function calculates the differences between two sets of database privileges.
|
/// This function calculates the differences between two sets of database privileges.
|
||||||
/// It returns a set of [`DatabasePrivilegesDiff`] that can be used to display or
|
/// It returns a set of [`DatabasePrivilegesDiff`] that can be used to display or
|
||||||
/// apply a set of privilege modifications to the database.
|
/// apply a set of privilege modifications to the database.
|
||||||
|
@ -633,7 +474,7 @@ pub fn diff_privileges(
|
||||||
|
|
||||||
for p in to {
|
for p in to {
|
||||||
if let Some(old_p) = from_lookup_table.get(&(p.db.clone(), p.user.clone())) {
|
if let Some(old_p) = from_lookup_table.get(&(p.db.clone(), p.user.clone())) {
|
||||||
let diff = old_p.diff(p);
|
let diff = diff(old_p, p);
|
||||||
if !diff.diff.is_empty() {
|
if !diff.diff.is_empty() {
|
||||||
result.insert(DatabasePrivilegesDiff::Modified(diff));
|
result.insert(DatabasePrivilegesDiff::Modified(diff));
|
||||||
}
|
}
|
||||||
|
@ -651,72 +492,6 @@ pub fn diff_privileges(
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Uses the result of [`diff_privileges`] to modify privileges in the database.
|
|
||||||
pub async fn apply_privilege_diffs(
|
|
||||||
diffs: BTreeSet<DatabasePrivilegesDiff>,
|
|
||||||
connection: &mut MySqlConnection,
|
|
||||||
) -> anyhow::Result<()> {
|
|
||||||
for diff in diffs {
|
|
||||||
match diff {
|
|
||||||
DatabasePrivilegesDiff::New(p) => {
|
|
||||||
let tables = DATABASE_PRIVILEGE_FIELDS
|
|
||||||
.iter()
|
|
||||||
.map(|field| format!("`{field}`"))
|
|
||||||
.join(",");
|
|
||||||
|
|
||||||
let question_marks = std::iter::repeat("?")
|
|
||||||
.take(DATABASE_PRIVILEGE_FIELDS.len())
|
|
||||||
.join(",");
|
|
||||||
|
|
||||||
sqlx::query(
|
|
||||||
format!("INSERT INTO `db` ({}) VALUES ({})", tables, question_marks).as_str(),
|
|
||||||
)
|
|
||||||
.bind(p.db)
|
|
||||||
.bind(p.user)
|
|
||||||
.bind(yn(p.select_priv))
|
|
||||||
.bind(yn(p.insert_priv))
|
|
||||||
.bind(yn(p.update_priv))
|
|
||||||
.bind(yn(p.delete_priv))
|
|
||||||
.bind(yn(p.create_priv))
|
|
||||||
.bind(yn(p.drop_priv))
|
|
||||||
.bind(yn(p.alter_priv))
|
|
||||||
.bind(yn(p.index_priv))
|
|
||||||
.bind(yn(p.create_tmp_table_priv))
|
|
||||||
.bind(yn(p.lock_tables_priv))
|
|
||||||
.bind(yn(p.references_priv))
|
|
||||||
.execute(&mut *connection)
|
|
||||||
.await?;
|
|
||||||
}
|
|
||||||
DatabasePrivilegesDiff::Modified(p) => {
|
|
||||||
let tables = p
|
|
||||||
.diff
|
|
||||||
.iter()
|
|
||||||
.map(|diff| match diff {
|
|
||||||
DatabasePrivilegeChange::YesToNo(name) => format!("`{}` = 'N'", name),
|
|
||||||
DatabasePrivilegeChange::NoToYes(name) => format!("`{}` = 'Y'", name),
|
|
||||||
})
|
|
||||||
.join(",");
|
|
||||||
|
|
||||||
sqlx::query(
|
|
||||||
format!("UPDATE `db` SET {} WHERE `db` = ? AND `user` = ?", tables).as_str(),
|
|
||||||
)
|
|
||||||
.bind(p.db)
|
|
||||||
.bind(p.user)
|
|
||||||
.execute(&mut *connection)
|
|
||||||
.await?;
|
|
||||||
}
|
|
||||||
DatabasePrivilegesDiff::Deleted(p) => {
|
|
||||||
sqlx::query("DELETE FROM `db` WHERE `db` = ? AND `user` = ?")
|
|
||||||
.bind(p.db)
|
|
||||||
.bind(p.user)
|
|
||||||
.execute(&mut *connection)
|
|
||||||
.await?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn display_privilege_cell(diff: &DatabasePrivilegeRowDiff) -> String {
|
fn display_privilege_cell(diff: &DatabasePrivilegeRowDiff) -> String {
|
||||||
diff.diff
|
diff.diff
|
||||||
.iter()
|
.iter()
|
||||||
|
@ -742,9 +517,10 @@ pub fn display_privilege_diffs(diffs: &BTreeSet<DatabasePrivilegesDiff>) -> Stri
|
||||||
p.db,
|
p.db,
|
||||||
p.user,
|
p.user,
|
||||||
"(New user)\n".to_string()
|
"(New user)\n".to_string()
|
||||||
+ &display_privilege_cell(
|
+ &display_privilege_cell(&diff(
|
||||||
&DatabasePrivilegeRow::empty(&p.db, &p.user).diff(p)
|
&DatabasePrivilegeRow::empty(&p.db, &p.user),
|
||||||
)
|
p
|
||||||
|
))
|
||||||
]);
|
]);
|
||||||
}
|
}
|
||||||
DatabasePrivilegesDiff::Modified(p) => {
|
DatabasePrivilegesDiff::Modified(p) => {
|
||||||
|
@ -755,9 +531,10 @@ pub fn display_privilege_diffs(diffs: &BTreeSet<DatabasePrivilegesDiff>) -> Stri
|
||||||
p.db,
|
p.db,
|
||||||
p.user,
|
p.user,
|
||||||
"(All privileges removed)\n".to_string()
|
"(All privileges removed)\n".to_string()
|
||||||
+ &display_privilege_cell(
|
+ &display_privilege_cell(&diff(
|
||||||
&p.diff(&DatabasePrivilegeRow::empty(&p.db, &p.user))
|
p,
|
||||||
)
|
&DatabasePrivilegeRow::empty(&p.db, &p.user)
|
||||||
|
))
|
||||||
]);
|
]);
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -1,249 +0,0 @@
|
||||||
use anyhow::Context;
|
|
||||||
use nix::unistd::User;
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use sqlx::{prelude::*, MySqlConnection};
|
|
||||||
|
|
||||||
use crate::core::common::{
|
|
||||||
create_user_group_matching_regex, get_current_unix_user, quote_literal, validate_name_or_error,
|
|
||||||
validate_ownership_or_error, DbOrUser,
|
|
||||||
};
|
|
||||||
|
|
||||||
pub async fn user_exists(db_user: &str, connection: &mut MySqlConnection) -> anyhow::Result<bool> {
|
|
||||||
let unix_user = get_current_unix_user()?;
|
|
||||||
|
|
||||||
validate_user_name(db_user, &unix_user)?;
|
|
||||||
|
|
||||||
let user_exists = sqlx::query(
|
|
||||||
r#"
|
|
||||||
SELECT EXISTS(
|
|
||||||
SELECT 1
|
|
||||||
FROM `mysql`.`user`
|
|
||||||
WHERE `User` = ?
|
|
||||||
)
|
|
||||||
"#,
|
|
||||||
)
|
|
||||||
.bind(db_user)
|
|
||||||
.fetch_one(connection)
|
|
||||||
.await?
|
|
||||||
.get::<bool, _>(0);
|
|
||||||
|
|
||||||
Ok(user_exists)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn create_database_user(
|
|
||||||
db_user: &str,
|
|
||||||
connection: &mut MySqlConnection,
|
|
||||||
) -> anyhow::Result<()> {
|
|
||||||
let unix_user = get_current_unix_user()?;
|
|
||||||
|
|
||||||
validate_user_name(db_user, &unix_user)?;
|
|
||||||
|
|
||||||
if user_exists(db_user, connection).await? {
|
|
||||||
anyhow::bail!("User '{}' already exists", db_user);
|
|
||||||
}
|
|
||||||
|
|
||||||
// NOTE: see the note about SQL injections in `validate_ownership_of_user_name`
|
|
||||||
sqlx::query(format!("CREATE USER {}@'%'", quote_literal(db_user),).as_str())
|
|
||||||
.execute(connection)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn delete_database_user(
|
|
||||||
db_user: &str,
|
|
||||||
connection: &mut MySqlConnection,
|
|
||||||
) -> anyhow::Result<()> {
|
|
||||||
let unix_user = get_current_unix_user()?;
|
|
||||||
|
|
||||||
validate_user_name(db_user, &unix_user)?;
|
|
||||||
|
|
||||||
if !user_exists(db_user, connection).await? {
|
|
||||||
anyhow::bail!("User '{}' does not exist", db_user);
|
|
||||||
}
|
|
||||||
|
|
||||||
// NOTE: see the note about SQL injections in `validate_ownership_of_user_name`
|
|
||||||
sqlx::query(format!("DROP USER {}@'%'", quote_literal(db_user),).as_str())
|
|
||||||
.execute(connection)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn set_password_for_database_user(
|
|
||||||
db_user: &str,
|
|
||||||
password: &str,
|
|
||||||
connection: &mut MySqlConnection,
|
|
||||||
) -> anyhow::Result<()> {
|
|
||||||
let unix_user = crate::core::common::get_current_unix_user()?;
|
|
||||||
validate_user_name(db_user, &unix_user)?;
|
|
||||||
|
|
||||||
if !user_exists(db_user, connection).await? {
|
|
||||||
anyhow::bail!("User '{}' does not exist", db_user);
|
|
||||||
}
|
|
||||||
|
|
||||||
// NOTE: see the note about SQL injections in `validate_ownership_of_user_name`
|
|
||||||
sqlx::query(
|
|
||||||
format!(
|
|
||||||
"ALTER USER {}@'%' IDENTIFIED BY {}",
|
|
||||||
quote_literal(db_user),
|
|
||||||
quote_literal(password).as_str()
|
|
||||||
)
|
|
||||||
.as_str(),
|
|
||||||
)
|
|
||||||
.execute(connection)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn user_is_locked(db_user: &str, connection: &mut MySqlConnection) -> anyhow::Result<bool> {
|
|
||||||
let unix_user = get_current_unix_user()?;
|
|
||||||
|
|
||||||
validate_user_name(db_user, &unix_user)?;
|
|
||||||
|
|
||||||
if !user_exists(db_user, connection).await? {
|
|
||||||
anyhow::bail!("User '{}' does not exist", db_user);
|
|
||||||
}
|
|
||||||
|
|
||||||
let is_locked = sqlx::query(
|
|
||||||
r#"
|
|
||||||
SELECT JSON_EXTRACT(`mysql`.`global_priv`.`priv`, "$.account_locked") = 'true'
|
|
||||||
FROM `mysql`.`global_priv`
|
|
||||||
WHERE `User` = ?
|
|
||||||
AND `Host` = '%'
|
|
||||||
"#,
|
|
||||||
)
|
|
||||||
.bind(db_user)
|
|
||||||
.fetch_one(connection)
|
|
||||||
.await?
|
|
||||||
.get::<bool, _>(0);
|
|
||||||
|
|
||||||
Ok(is_locked)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn lock_database_user(
|
|
||||||
db_user: &str,
|
|
||||||
connection: &mut MySqlConnection,
|
|
||||||
) -> anyhow::Result<()> {
|
|
||||||
let unix_user = get_current_unix_user()?;
|
|
||||||
|
|
||||||
validate_user_name(db_user, &unix_user)?;
|
|
||||||
|
|
||||||
if !user_exists(db_user, connection).await? {
|
|
||||||
anyhow::bail!("User '{}' does not exist", db_user);
|
|
||||||
}
|
|
||||||
|
|
||||||
if user_is_locked(db_user, connection).await? {
|
|
||||||
anyhow::bail!("User '{}' is already locked", db_user);
|
|
||||||
}
|
|
||||||
|
|
||||||
// NOTE: see the note about SQL injections in `validate_ownership_of_user_name`
|
|
||||||
sqlx::query(format!("ALTER USER {}@'%' ACCOUNT LOCK", quote_literal(db_user),).as_str())
|
|
||||||
.execute(connection)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn unlock_database_user(
|
|
||||||
db_user: &str,
|
|
||||||
connection: &mut MySqlConnection,
|
|
||||||
) -> anyhow::Result<()> {
|
|
||||||
let unix_user = get_current_unix_user()?;
|
|
||||||
|
|
||||||
validate_user_name(db_user, &unix_user)?;
|
|
||||||
|
|
||||||
if !user_exists(db_user, connection).await? {
|
|
||||||
anyhow::bail!("User '{}' does not exist", db_user);
|
|
||||||
}
|
|
||||||
|
|
||||||
if !user_is_locked(db_user, connection).await? {
|
|
||||||
anyhow::bail!("User '{}' is already unlocked", db_user);
|
|
||||||
}
|
|
||||||
|
|
||||||
// NOTE: see the note about SQL injections in `validate_ownership_of_user_name`
|
|
||||||
sqlx::query(format!("ALTER USER {}@'%' ACCOUNT UNLOCK", quote_literal(db_user),).as_str())
|
|
||||||
.execute(connection)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// This struct contains information about a database user.
|
|
||||||
/// This can be extended if we need more information in the future.
|
|
||||||
#[derive(Debug, Clone, FromRow, Serialize, Deserialize)]
|
|
||||||
pub struct DatabaseUser {
|
|
||||||
#[sqlx(rename = "User")]
|
|
||||||
pub user: String,
|
|
||||||
|
|
||||||
#[allow(dead_code)]
|
|
||||||
#[serde(skip)]
|
|
||||||
#[sqlx(rename = "Host")]
|
|
||||||
pub host: String,
|
|
||||||
|
|
||||||
#[sqlx(rename = "has_password")]
|
|
||||||
pub has_password: bool,
|
|
||||||
|
|
||||||
#[sqlx(rename = "is_locked")]
|
|
||||||
pub is_locked: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
const DB_USER_SELECT_STATEMENT: &str = r#"
|
|
||||||
SELECT
|
|
||||||
`mysql`.`user`.`User`,
|
|
||||||
`mysql`.`user`.`Host`,
|
|
||||||
`mysql`.`user`.`Password` != '' OR `mysql`.`user`.`authentication_string` != '' AS `has_password`,
|
|
||||||
COALESCE(
|
|
||||||
JSON_EXTRACT(`mysql`.`global_priv`.`priv`, "$.account_locked"),
|
|
||||||
'false'
|
|
||||||
) != 'false' AS `is_locked`
|
|
||||||
FROM `mysql`.`user`
|
|
||||||
JOIN `mysql`.`global_priv` ON
|
|
||||||
`mysql`.`user`.`User` = `mysql`.`global_priv`.`User`
|
|
||||||
AND `mysql`.`user`.`Host` = `mysql`.`global_priv`.`Host`
|
|
||||||
"#;
|
|
||||||
|
|
||||||
/// This function fetches all database users that have a prefix matching the
|
|
||||||
/// unix username and group names of the given unix user.
|
|
||||||
pub async fn get_all_database_users_for_unix_user(
|
|
||||||
unix_user: &User,
|
|
||||||
connection: &mut MySqlConnection,
|
|
||||||
) -> anyhow::Result<Vec<DatabaseUser>> {
|
|
||||||
let users = sqlx::query_as::<_, DatabaseUser>(
|
|
||||||
&(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `mysql`.`user`.`User` REGEXP ?"),
|
|
||||||
)
|
|
||||||
.bind(create_user_group_matching_regex(unix_user))
|
|
||||||
.fetch_all(connection)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(users)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// This function fetches a database user if it exists.
|
|
||||||
pub async fn get_database_user_for_user(
|
|
||||||
username: &str,
|
|
||||||
connection: &mut MySqlConnection,
|
|
||||||
) -> anyhow::Result<Option<DatabaseUser>> {
|
|
||||||
let user = sqlx::query_as::<_, DatabaseUser>(
|
|
||||||
&(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `mysql`.`user`.`User` = ?"),
|
|
||||||
)
|
|
||||||
.bind(username)
|
|
||||||
.fetch_optional(connection)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(user)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// NOTE: It is very critical that this function validates the database name
|
|
||||||
/// properly. MySQL does not seem to allow for prepared statements, binding
|
|
||||||
/// the database name as a parameter to the query. This means that we have
|
|
||||||
/// to validate the database name ourselves to prevent SQL injection.
|
|
||||||
pub fn validate_user_name(name: &str, user: &User) -> anyhow::Result<()> {
|
|
||||||
validate_name_or_error(name, DbOrUser::User)
|
|
||||||
.context(format!("Invalid username: '{}'", name))?;
|
|
||||||
validate_ownership_or_error(name, user, DbOrUser::User)
|
|
||||||
.context(format!("Invalid username: '{}'", name))?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
102
src/main.rs
102
src/main.rs
|
@ -1,16 +1,20 @@
|
||||||
#[macro_use]
|
#[macro_use]
|
||||||
extern crate prettytable;
|
extern crate prettytable;
|
||||||
|
|
||||||
use core::common::CommandStatus;
|
|
||||||
#[cfg(feature = "mysql-admutils-compatibility")]
|
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
|
|
||||||
#[cfg(feature = "mysql-admutils-compatibility")]
|
|
||||||
use crate::cli::mysql_admutils_compatibility::{mysql_dbadm, mysql_useradm};
|
|
||||||
|
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
|
use server::bootstrap::bootstrap_server_connection_and_drop_privileges;
|
||||||
|
|
||||||
|
// use core::common::CommandStatus;
|
||||||
|
// #[cfg(feature = "mysql-admutils-compatibility")]
|
||||||
|
// use std::path::PathBuf;
|
||||||
|
|
||||||
|
// #[cfg(feature = "mysql-admutils-compatibility")]
|
||||||
|
// use crate::cli::mysql_admutils_compatibility::{mysql_dbadm, mysql_useradm};
|
||||||
|
|
||||||
|
mod server;
|
||||||
|
|
||||||
mod authenticated_unix_socket;
|
|
||||||
mod cli;
|
mod cli;
|
||||||
mod core;
|
mod core;
|
||||||
|
|
||||||
|
@ -22,21 +26,22 @@ struct Args {
|
||||||
#[command(subcommand)]
|
#[command(subcommand)]
|
||||||
command: Command,
|
command: Command,
|
||||||
|
|
||||||
#[command(flatten)]
|
/// Path to the socket of the server, if it already exists.
|
||||||
config_overrides: core::config::GlobalConfigArgs,
|
#[arg(long, value_name = "HOST", global = true, hide_short_help = true)]
|
||||||
|
server_address: Option<PathBuf>,
|
||||||
|
|
||||||
#[cfg(feature = "tui")]
|
#[cfg(feature = "tui")]
|
||||||
#[arg(short, long, alias = "tui", global = true)]
|
#[arg(short, long, alias = "tui", global = true)]
|
||||||
interactive: bool,
|
interactive: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Database administration tool for non-admin users to manage their own MySQL databases and users.
|
// Database administration tool for non-admin users to manage their own MySQL databases and users.
|
||||||
///
|
//
|
||||||
/// This tool allows you to manage users and databases in MySQL.
|
// This tool allows you to manage users and databases in MySQL.
|
||||||
///
|
//
|
||||||
/// You are only allowed to manage databases and users that are prefixed with
|
// You are only allowed to manage databases and users that are prefixed with
|
||||||
/// either your username, or a group that you are a member of.
|
// either your username, or a group that you are a member of.
|
||||||
#[derive(Parser)]
|
#[derive(Parser, Debug, Clone)]
|
||||||
#[command(version, about, disable_help_subcommand = true)]
|
#[command(version, about, disable_help_subcommand = true)]
|
||||||
enum Command {
|
enum Command {
|
||||||
#[command(flatten)]
|
#[command(flatten)]
|
||||||
|
@ -44,57 +49,44 @@ enum Command {
|
||||||
|
|
||||||
#[command(flatten)]
|
#[command(flatten)]
|
||||||
User(cli::user_command::UserCommand),
|
User(cli::user_command::UserCommand),
|
||||||
|
|
||||||
|
#[command(hide = true)]
|
||||||
|
Server(server::command::ServerArgs),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::main(flavor = "current_thread")]
|
#[tokio::main(flavor = "current_thread")]
|
||||||
async fn main() -> anyhow::Result<()> {
|
async fn main() -> anyhow::Result<()> {
|
||||||
env_logger::init();
|
env_logger::init();
|
||||||
|
|
||||||
#[cfg(feature = "mysql-admutils-compatibility")]
|
// #[cfg(feature = "mysql-admutils-compatibility")]
|
||||||
{
|
// {
|
||||||
let argv0 = std::env::args().next().and_then(|s| {
|
// let argv0 = std::env::args().next().and_then(|s| {
|
||||||
PathBuf::from(s)
|
// PathBuf::from(s)
|
||||||
.file_name()
|
// .file_name()
|
||||||
.map(|s| s.to_string_lossy().to_string())
|
// .map(|s| s.to_string_lossy().to_string())
|
||||||
});
|
// });
|
||||||
|
|
||||||
match argv0.as_deref() {
|
// match argv0.as_deref() {
|
||||||
Some("mysql-dbadm") => return mysql_dbadm::main().await,
|
// Some("mysql-dbadm") => return mysql_dbadm::main().await,
|
||||||
Some("mysql-useradm") => return mysql_useradm::main().await,
|
// Some("mysql-useradm") => return mysql_useradm::main().await,
|
||||||
_ => { /* fall through */ }
|
// _ => { /* fall through */ }
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
||||||
let args: Args = Args::parse();
|
let args: Args = Args::parse();
|
||||||
let config = core::config::get_config(args.config_overrides)?;
|
match args.command {
|
||||||
let connection = core::config::create_mysql_connection_from_config(config.mysql).await?;
|
Command::Server(ref command) => server::command::handle_command(command.to_owned()).await?,
|
||||||
|
_ => { /* fall through */ }
|
||||||
|
}
|
||||||
|
|
||||||
let result = match args.command {
|
let server_stream =
|
||||||
Command::Db(command) => cli::database_command::handle_command(command, connection).await,
|
bootstrap_server_connection_and_drop_privileges(args.server_address).await?;
|
||||||
Command::User(user_args) => cli::user_command::handle_command(user_args, connection).await,
|
|
||||||
};
|
|
||||||
|
|
||||||
match result {
|
match args.command {
|
||||||
Ok(CommandStatus::SuccessfullyModified) => {
|
Command::User(user_args) => {
|
||||||
println!("Modifications committed successfully");
|
cli::user_command::handle_command(user_args, server_stream).await
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
Ok(CommandStatus::PartiallySuccessfullyModified) => {
|
Command::Db(db_args) => cli::database_command::handle_command(db_args, server_stream).await,
|
||||||
println!("Some modifications committed successfully");
|
Command::Server(_) => unreachable!(),
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
Ok(CommandStatus::NoModificationsNeeded) => {
|
|
||||||
println!("No modifications made");
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
Ok(CommandStatus::NoModificationsIntended) => {
|
|
||||||
/* Don't report anything */
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
Ok(CommandStatus::Cancelled) => {
|
|
||||||
println!("Command cancelled successfully");
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
Err(e) => Err(e),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,12 @@
|
||||||
|
pub mod bootstrap;
|
||||||
|
pub mod command;
|
||||||
mod common;
|
mod common;
|
||||||
|
pub mod config;
|
||||||
mod database_operations;
|
mod database_operations;
|
||||||
|
pub mod database_privilege_operations;
|
||||||
mod entrypoint;
|
mod entrypoint;
|
||||||
mod input_sanitization;
|
mod input_sanitization;
|
||||||
mod protocol;
|
pub mod protocol;
|
||||||
mod user_operations;
|
mod user_operations;
|
||||||
|
|
||||||
|
pub use protocol::{Request, Response};
|
||||||
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
use std::{fs, path::PathBuf};
|
||||||
|
|
||||||
|
use tokio::net::UnixStream;
|
||||||
|
|
||||||
|
use crate::server::protocol::create_client_to_server_message_stream;
|
||||||
|
|
||||||
|
use super::config::DEFAULT_SOCKET_PATH;
|
||||||
|
use super::protocol::ClientToServerMessageStream;
|
||||||
|
|
||||||
|
pub mod authenticated_unix_socket;
|
||||||
|
|
||||||
|
// TODO: allow overriding which way to connect to the server
|
||||||
|
|
||||||
|
pub async fn bootstrap_server_connection_and_drop_privileges<'a>(
|
||||||
|
socket_path: Option<PathBuf>,
|
||||||
|
) -> anyhow::Result<ClientToServerMessageStream> {
|
||||||
|
// If socket path explicitly provided, use or error
|
||||||
|
if let Some(socket_path) = socket_path {
|
||||||
|
let mut socket = UnixStream::connect(socket_path).await?;
|
||||||
|
authenticated_unix_socket::client_authenticate(&mut socket, None).await?;
|
||||||
|
let message_stream = create_client_to_server_message_stream(socket);
|
||||||
|
return Ok(message_stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
// If not, check if the default socket path exists
|
||||||
|
if fs::metadata(DEFAULT_SOCKET_PATH).is_ok() {
|
||||||
|
let mut socket = UnixStream::connect(DEFAULT_SOCKET_PATH).await?;
|
||||||
|
authenticated_unix_socket::client_authenticate(&mut socket, None).await?;
|
||||||
|
let message_stream = create_client_to_server_message_stream(socket);
|
||||||
|
return Ok(message_stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
// If not, check if we are suid, guid and can read the config file
|
||||||
|
// if {
|
||||||
|
// if so, create anonymous socket pair, spawn server, and drop privileges
|
||||||
|
// }
|
||||||
|
|
||||||
|
// If not, error
|
||||||
|
|
||||||
|
// Upon non suid-guid invocation, we will need to authenticate the socket.
|
||||||
|
|
||||||
|
todo!()
|
||||||
|
}
|
|
@ -34,6 +34,7 @@ use std::os::unix::io::AsRawFd;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
|
|
||||||
use async_bincode::{tokio::AsyncBincodeStream, AsyncDestination};
|
use async_bincode::{tokio::AsyncBincodeStream, AsyncDestination};
|
||||||
|
use derive_more::derive::{Display, Error};
|
||||||
use futures::{SinkExt, StreamExt};
|
use futures::{SinkExt, StreamExt};
|
||||||
use nix::{sys::stat, unistd::Uid};
|
use nix::{sys::stat, unistd::Uid};
|
||||||
use rand::distributions::Alphanumeric;
|
use rand::distributions::Alphanumeric;
|
||||||
|
@ -52,7 +53,7 @@ pub enum ClientRequest {
|
||||||
Cancel,
|
Cancel,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Display, Error)]
|
||||||
pub enum ServerResponse {
|
pub enum ServerResponse {
|
||||||
Authenticated,
|
Authenticated,
|
||||||
ChallengeDidNotMatch,
|
ChallengeDidNotMatch,
|
||||||
|
@ -61,7 +62,7 @@ pub enum ServerResponse {
|
||||||
|
|
||||||
// TODO: wrap more data into the errors
|
// TODO: wrap more data into the errors
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
|
#[derive(Debug, Display, PartialEq, Serialize, Deserialize, Clone, Error)]
|
||||||
pub enum ServerError {
|
pub enum ServerError {
|
||||||
InvalidRequest,
|
InvalidRequest,
|
||||||
UnableToReadPermissionsFromAuthSocket,
|
UnableToReadPermissionsFromAuthSocket,
|
||||||
|
@ -72,7 +73,7 @@ pub enum ServerError {
|
||||||
InvalidChallenge,
|
InvalidChallenge,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, PartialEq)]
|
#[derive(Debug, PartialEq, Display, Error)]
|
||||||
pub enum ClientError {
|
pub enum ClientError {
|
||||||
UnableToConnectToServer,
|
UnableToConnectToServer,
|
||||||
UnableToOpenAuthSocket,
|
UnableToOpenAuthSocket,
|
||||||
|
@ -86,7 +87,7 @@ pub enum ClientError {
|
||||||
ServerError(ServerError),
|
ServerError(ServerError),
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn create_auth_socket(socket_addr: &str) -> Result<UnixListener, ClientError> {
|
async fn create_auth_socket(socket_addr: &PathBuf) -> Result<UnixListener, ClientError> {
|
||||||
let auth_socket =
|
let auth_socket =
|
||||||
UnixListener::bind(socket_addr).map_err(|_err| ClientError::UnableToOpenAuthSocket)?;
|
UnixListener::bind(socket_addr).map_err(|_err| ClientError::UnableToOpenAuthSocket)?;
|
||||||
|
|
||||||
|
@ -109,11 +110,13 @@ type AuthStream<'a> = AsyncBincodeStream<&'a mut UnixStream, u64, u64, AsyncDest
|
||||||
|
|
||||||
// TODO: add timeout
|
// TODO: add timeout
|
||||||
|
|
||||||
|
// TODO: respect $XDG_RUNTIME_DIR and $TMPDIR
|
||||||
|
|
||||||
const AUTH_SOCKET_NAME: &str = "mysqladm-rs-cli-auth.sock";
|
const AUTH_SOCKET_NAME: &str = "mysqladm-rs-cli-auth.sock";
|
||||||
|
|
||||||
pub async fn client_authenticate(
|
pub async fn client_authenticate(
|
||||||
normal_socket: &mut UnixStream,
|
normal_socket: &mut UnixStream,
|
||||||
#[cfg(not(test))] auth_socket_dir: Option<PathBuf>,
|
auth_socket_dir: Option<PathBuf>,
|
||||||
#[cfg(test)] auth_socket_file: Option<PathBuf>,
|
|
||||||
) -> Result<(), ClientError> {
|
) -> Result<(), ClientError> {
|
||||||
let random_prefix: String = rand::thread_rng()
|
let random_prefix: String = rand::thread_rng()
|
||||||
.sample_iter(&Alphanumeric)
|
.sample_iter(&Alphanumeric)
|
||||||
|
@ -123,32 +126,16 @@ pub async fn client_authenticate(
|
||||||
|
|
||||||
let socket_name = format!("{}-{}", random_prefix, AUTH_SOCKET_NAME);
|
let socket_name = format!("{}-{}", random_prefix, AUTH_SOCKET_NAME);
|
||||||
|
|
||||||
#[cfg(not(test))]
|
let auth_socket_address = auth_socket_dir
|
||||||
let auth_socket_address = match auth_socket_dir {
|
.unwrap_or(std::env::temp_dir())
|
||||||
Some(dir) => dir.join(socket_name).to_str().unwrap().to_string(),
|
.join(socket_name);
|
||||||
None => std::env::temp_dir()
|
|
||||||
.join(socket_name)
|
|
||||||
.to_str()
|
|
||||||
.unwrap()
|
|
||||||
.to_string(),
|
|
||||||
};
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
let auth_socket_address = match auth_socket_file {
|
|
||||||
Some(file) => file.to_str().unwrap().to_string(),
|
|
||||||
None => std::env::temp_dir()
|
|
||||||
.join(socket_name)
|
|
||||||
.to_str()
|
|
||||||
.unwrap()
|
|
||||||
.to_string(),
|
|
||||||
};
|
|
||||||
|
|
||||||
client_authenticate_with_auth_socket_address(normal_socket, &auth_socket_address).await
|
client_authenticate_with_auth_socket_address(normal_socket, &auth_socket_address).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn client_authenticate_with_auth_socket_address(
|
async fn client_authenticate_with_auth_socket_address(
|
||||||
normal_socket: &mut UnixStream,
|
normal_socket: &mut UnixStream,
|
||||||
auth_socket_address: &str,
|
auth_socket_address: &PathBuf,
|
||||||
) -> Result<(), ClientError> {
|
) -> Result<(), ClientError> {
|
||||||
let auth_socket = create_auth_socket(auth_socket_address).await?;
|
let auth_socket = create_auth_socket(auth_socket_address).await?;
|
||||||
|
|
||||||
|
@ -164,7 +151,7 @@ async fn client_authenticate_with_auth_socket_address(
|
||||||
async fn client_authenticate_with_auth_socket(
|
async fn client_authenticate_with_auth_socket(
|
||||||
normal_socket: &mut UnixStream,
|
normal_socket: &mut UnixStream,
|
||||||
auth_socket: UnixListener,
|
auth_socket: UnixListener,
|
||||||
auth_socket_address: &str,
|
auth_socket_address: &PathBuf,
|
||||||
) -> Result<(), ClientError> {
|
) -> Result<(), ClientError> {
|
||||||
let challenge = rand::random::<u64>();
|
let challenge = rand::random::<u64>();
|
||||||
let uid = nix::unistd::getuid();
|
let uid = nix::unistd::getuid();
|
||||||
|
@ -199,7 +186,7 @@ async fn client_authenticate_with_auth_socket(
|
||||||
let client_hello = ClientRequest::Initialize {
|
let client_hello = ClientRequest::Initialize {
|
||||||
uid: uid.into(),
|
uid: uid.into(),
|
||||||
challenge,
|
challenge,
|
||||||
auth_socket: auth_socket_address.to_string(),
|
auth_socket: auth_socket_address.to_str().unwrap().to_owned(),
|
||||||
};
|
};
|
||||||
|
|
||||||
normal_socket
|
normal_socket
|
||||||
|
@ -239,9 +226,13 @@ macro_rules! report_server_error_and_return {
|
||||||
}};
|
}};
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn server_authenticate(
|
pub async fn server_authenticate(normal_socket: &mut UnixStream) -> Result<Uid, ServerError> {
|
||||||
|
_server_authenticate(normal_socket, None).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn _server_authenticate(
|
||||||
normal_socket: &mut UnixStream,
|
normal_socket: &mut UnixStream,
|
||||||
#[cfg(test)] unix_user_uid: Option<u32>,
|
unix_user_uid: Option<u32>,
|
||||||
) -> Result<Uid, ServerError> {
|
) -> Result<Uid, ServerError> {
|
||||||
let mut normal_socket: ServerToClientStream =
|
let mut normal_socket: ServerToClientStream =
|
||||||
AsyncBincodeStream::from(normal_socket).for_async();
|
AsyncBincodeStream::from(normal_socket).for_async();
|
||||||
|
@ -256,22 +247,15 @@ async fn server_authenticate(
|
||||||
_ => report_server_error_and_return!(normal_socket, ServerError::InvalidRequest),
|
_ => report_server_error_and_return!(normal_socket, ServerError::InvalidRequest),
|
||||||
};
|
};
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
let auth_socket_uid = match unix_user_uid {
|
let auth_socket_uid = match unix_user_uid {
|
||||||
Some(uid) => uid,
|
Some(uid) => uid,
|
||||||
None => report_server_error_and_return!(
|
None => match stat::stat(auth_socket.as_str()) {
|
||||||
normal_socket,
|
|
||||||
ServerError::UnableToReadPermissionsFromAuthSocket
|
|
||||||
),
|
|
||||||
};
|
|
||||||
|
|
||||||
#[cfg(not(test))]
|
|
||||||
let auth_socket_uid = match stat::stat(auth_socket.as_str()) {
|
|
||||||
Ok(stat) => stat.st_uid,
|
Ok(stat) => stat.st_uid,
|
||||||
Err(_err) => report_server_error_and_return!(
|
Err(_err) => report_server_error_and_return!(
|
||||||
normal_socket,
|
normal_socket,
|
||||||
ServerError::UnableToReadPermissionsFromAuthSocket
|
ServerError::UnableToReadPermissionsFromAuthSocket
|
||||||
),
|
),
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
if uid != auth_socket_uid {
|
if uid != auth_socket_uid {
|
||||||
|
@ -324,10 +308,7 @@ mod test {
|
||||||
let client_handle =
|
let client_handle =
|
||||||
tokio::spawn(async move { client_authenticate(&mut client, None).await });
|
tokio::spawn(async move { client_authenticate(&mut client, None).await });
|
||||||
|
|
||||||
let server_handle = tokio::spawn(async move {
|
let server_handle = tokio::spawn(async move { server_authenticate(&mut server).await });
|
||||||
let uid = nix::unistd::getuid().into();
|
|
||||||
server_authenticate(&mut server, Some(uid)).await
|
|
||||||
});
|
|
||||||
|
|
||||||
client_handle.await.unwrap().unwrap();
|
client_handle.await.unwrap().unwrap();
|
||||||
server_handle.await.unwrap().unwrap();
|
server_handle.await.unwrap().unwrap();
|
||||||
|
@ -340,15 +321,12 @@ mod test {
|
||||||
let client_handle = tokio::spawn(async move {
|
let client_handle = tokio::spawn(async move {
|
||||||
client_authenticate_with_auth_socket_address(
|
client_authenticate_with_auth_socket_address(
|
||||||
&mut client,
|
&mut client,
|
||||||
"/tmp/test_auth_socket_does_not_exist.sock",
|
&PathBuf::from("/tmp/test_auth_socket_does_not_exist.sock"),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
});
|
});
|
||||||
|
|
||||||
let server_handle = tokio::spawn(async move {
|
let server_handle = tokio::spawn(async move { server_authenticate(&mut server).await });
|
||||||
let uid = nix::unistd::getuid().into();
|
|
||||||
server_authenticate(&mut server, Some(uid)).await
|
|
||||||
});
|
|
||||||
|
|
||||||
client_handle.await.unwrap().unwrap();
|
client_handle.await.unwrap().unwrap();
|
||||||
server_handle.await.unwrap().unwrap();
|
server_handle.await.unwrap().unwrap();
|
||||||
|
@ -365,7 +343,7 @@ mod test {
|
||||||
|
|
||||||
let server_handle = tokio::spawn(async move {
|
let server_handle = tokio::spawn(async move {
|
||||||
let uid: u32 = nix::unistd::getuid().into();
|
let uid: u32 = nix::unistd::getuid().into();
|
||||||
let err = server_authenticate(&mut server, Some(uid + 1)).await;
|
let err = _server_authenticate(&mut server, Some(uid + 1)).await;
|
||||||
assert_eq!(err, Err(ServerError::UidMismatch));
|
assert_eq!(err, Err(ServerError::UidMismatch));
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -379,13 +357,19 @@ mod test {
|
||||||
|
|
||||||
let socket_path = std::env::temp_dir().join("socket_to_snoop.sock");
|
let socket_path = std::env::temp_dir().join("socket_to_snoop.sock");
|
||||||
let socket_path_clone = socket_path.clone();
|
let socket_path_clone = socket_path.clone();
|
||||||
let client_handle =
|
let client_handle = tokio::spawn(async move {
|
||||||
tokio::spawn(
|
client_authenticate_with_auth_socket_address(&mut client, &socket_path_clone).await
|
||||||
async move { client_authenticate(&mut client, Some(socket_path_clone)).await },
|
});
|
||||||
);
|
|
||||||
|
|
||||||
while !socket_path.exists() {
|
for i in 0..100 {
|
||||||
sleep(std::time::Duration::from_millis(10)).await;
|
if socket_path.exists() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
sleep(Duration::from_millis(10)).await;
|
||||||
|
|
||||||
|
if i == 99 {
|
||||||
|
panic!("Socket not created after 1 second, assuming test failure");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut snooper = UnixStream::connect(socket_path.clone()).await.unwrap();
|
let mut snooper = UnixStream::connect(socket_path.clone()).await.unwrap();
|
||||||
|
@ -409,10 +393,7 @@ mod test {
|
||||||
|
|
||||||
sleep(Duration::from_millis(10)).await;
|
sleep(Duration::from_millis(10)).await;
|
||||||
|
|
||||||
let server_handle = tokio::spawn(async move {
|
let server_handle = tokio::spawn(async move { server_authenticate(&mut server).await });
|
||||||
let uid: u32 = nix::unistd::getuid().into();
|
|
||||||
server_authenticate(&mut server, Some(uid)).await
|
|
||||||
});
|
|
||||||
|
|
||||||
client_handle.await.unwrap().unwrap();
|
client_handle.await.unwrap().unwrap();
|
||||||
server_handle.await.unwrap().unwrap();
|
server_handle.await.unwrap().unwrap();
|
|
@ -0,0 +1,161 @@
|
||||||
|
use std::fs;
|
||||||
|
use std::os::fd::FromRawFd;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
use anyhow::Context;
|
||||||
|
use clap::Parser;
|
||||||
|
use sqlx::prelude::*;
|
||||||
|
use sqlx::MySqlConnection;
|
||||||
|
use tokio::io::AsyncWriteExt;
|
||||||
|
use tokio::net::UnixListener;
|
||||||
|
use tokio::net::UnixStream;
|
||||||
|
|
||||||
|
use crate::server::config::get_config;
|
||||||
|
use crate::server::config::DEFAULT_SOCKET_PATH;
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
bootstrap::authenticated_unix_socket,
|
||||||
|
common::UnixUser,
|
||||||
|
config::{create_mysql_connection_from_config, ServerConfig, ServerConfigArgs},
|
||||||
|
entrypoint::handle_request,
|
||||||
|
protocol::create_server_to_client_message_stream,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Parser, Debug, Clone)]
|
||||||
|
pub struct ServerArgs {
|
||||||
|
#[command(subcommand)]
|
||||||
|
subcmd: ServerCommand,
|
||||||
|
|
||||||
|
#[command(flatten)]
|
||||||
|
config_overrides: ServerConfigArgs,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Parser, Debug, Clone)]
|
||||||
|
pub enum ServerCommand {
|
||||||
|
#[command()]
|
||||||
|
Listen,
|
||||||
|
|
||||||
|
#[command()]
|
||||||
|
SocketActivate,
|
||||||
|
|
||||||
|
#[command()]
|
||||||
|
StickyBitActivate,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn handle_command(args: ServerArgs) -> anyhow::Result<()> {
|
||||||
|
let config = get_config(args.config_overrides)?;
|
||||||
|
|
||||||
|
let mut db_connection = create_mysql_connection_from_config(config.mysql.clone()).await?;
|
||||||
|
|
||||||
|
let result = match args.subcmd {
|
||||||
|
ServerCommand::Listen => listen(config, &mut db_connection).await,
|
||||||
|
ServerCommand::SocketActivate => socket_activate(config, &mut db_connection).await,
|
||||||
|
ServerCommand::StickyBitActivate => sticky_bit_activate(config, &mut db_connection).await,
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Err(e) = &result {
|
||||||
|
eprintln!("{}", e);
|
||||||
|
}
|
||||||
|
|
||||||
|
close_database_connection(db_connection).await;
|
||||||
|
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Gracefully close a MySQL connection.
|
||||||
|
async fn close_database_connection(connection: MySqlConnection) {
|
||||||
|
if let Err(e) = connection
|
||||||
|
.close()
|
||||||
|
.await
|
||||||
|
.context("Failed to close connection properly")
|
||||||
|
{
|
||||||
|
eprintln!("{}", e);
|
||||||
|
eprintln!("Ignoring...");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn listen(_config: ServerConfig, db_connection: &mut MySqlConnection) -> anyhow::Result<()> {
|
||||||
|
// let mut db_connection = create_mysql_connection_from_config(config.mysql).await?;
|
||||||
|
// TODO: fetch from arguments or config
|
||||||
|
|
||||||
|
let default_config_path = PathBuf::from(DEFAULT_SOCKET_PATH);
|
||||||
|
let parent_directory = default_config_path.parent().unwrap();
|
||||||
|
if !parent_directory.exists() {
|
||||||
|
println!("Creating directory {:?}", parent_directory);
|
||||||
|
fs::create_dir_all(parent_directory)?;
|
||||||
|
}
|
||||||
|
println!("Listening on {:?}", DEFAULT_SOCKET_PATH);
|
||||||
|
match fs::remove_file(DEFAULT_SOCKET_PATH) {
|
||||||
|
Ok(_) => {}
|
||||||
|
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
|
||||||
|
Err(e) => return Err(e.into()),
|
||||||
|
}
|
||||||
|
|
||||||
|
let listener = UnixListener::bind(DEFAULT_SOCKET_PATH)?;
|
||||||
|
|
||||||
|
while let Ok((mut conn, _addr)) = listener.accept().await {
|
||||||
|
let uid = match authenticated_unix_socket::server_authenticate(&mut conn).await {
|
||||||
|
Ok(uid) => uid,
|
||||||
|
Err(e) => {
|
||||||
|
eprintln!("Failed to authenticate client: {}", e);
|
||||||
|
conn.shutdown().await?;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let unix_user = match UnixUser::from_uid(uid.into()) {
|
||||||
|
Ok(user) => user,
|
||||||
|
Err(e) => {
|
||||||
|
eprintln!("Failed to get UnixUser from uid: {}", e);
|
||||||
|
conn.shutdown().await?;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let stream = create_server_to_client_message_stream(conn);
|
||||||
|
match handle_request(stream, &unix_user, db_connection).await {
|
||||||
|
Ok(_) => {}
|
||||||
|
Err(e) => {
|
||||||
|
eprintln!("Failed to run server: {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_socket_from_systemd() -> anyhow::Result<UnixStream> {
|
||||||
|
let fd = std::env::var("LISTEN_FDS")
|
||||||
|
.context("LISTEN_FDS not set, not running under systemd?")?
|
||||||
|
.parse::<i32>()
|
||||||
|
.context("Failed to parse LISTEN_FDS")?;
|
||||||
|
|
||||||
|
if fd != 1 {
|
||||||
|
return Err(anyhow::anyhow!("Unexpected LISTEN_FDS value: {}", fd));
|
||||||
|
}
|
||||||
|
|
||||||
|
let std_unix_stream = unsafe { std::os::unix::net::UnixStream::from_raw_fd(fd) };
|
||||||
|
let socket = tokio::net::UnixStream::from_std(std_unix_stream)?;
|
||||||
|
Ok(socket)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn socket_activate(
|
||||||
|
_config: ServerConfig,
|
||||||
|
db_connection: &mut MySqlConnection,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
// TODO: allow getting socket path from other socket activation sources
|
||||||
|
let mut conn = get_socket_from_systemd().await?;
|
||||||
|
let uid = authenticated_unix_socket::server_authenticate(&mut conn).await?;
|
||||||
|
let unix_user = UnixUser::from_uid(uid.into())?;
|
||||||
|
let stream = create_server_to_client_message_stream(conn);
|
||||||
|
handle_request(stream, &unix_user, db_connection).await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn sticky_bit_activate(
|
||||||
|
_config: ServerConfig,
|
||||||
|
_db_connection: &mut MySqlConnection,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
// Is this a type of socket activation?
|
||||||
|
println!("Activating sticky bit");
|
||||||
|
Ok(())
|
||||||
|
}
|
|
@ -1,6 +1,5 @@
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use nix::unistd::{Group as LibcGroup, User as LibcUser};
|
use nix::unistd::{Group as LibcGroup, User as LibcUser};
|
||||||
use sqlx::{Connection, MySqlConnection};
|
|
||||||
|
|
||||||
#[cfg(not(target_os = "macos"))]
|
#[cfg(not(target_os = "macos"))]
|
||||||
use std::ffi::CString;
|
use std::ffi::CString;
|
||||||
|
@ -30,12 +29,6 @@ pub enum CommandStatus {
|
||||||
Cancelled,
|
Cancelled,
|
||||||
}
|
}
|
||||||
|
|
||||||
// pub fn get_current_unix_user() -> anyhow::Result<User> {
|
|
||||||
// User::from_uid(getuid())
|
|
||||||
// .context("Failed to look up your UNIX username")
|
|
||||||
// .and_then(|u| u.ok_or(anyhow::anyhow!("Failed to look up your UNIX username")))
|
|
||||||
// }
|
|
||||||
|
|
||||||
pub struct UnixUser {
|
pub struct UnixUser {
|
||||||
pub username: String,
|
pub username: String,
|
||||||
pub uid: u32,
|
pub uid: u32,
|
||||||
|
@ -99,15 +92,3 @@ pub fn create_user_group_matching_regex(user: &UnixUser) -> String {
|
||||||
format!("({}|{})(_.+)?", user.username, user.groups.join("|"))
|
format!("({}|{})(_.+)?", user.username, user.groups.join("|"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Gracefully close a MySQL connection.
|
|
||||||
pub async fn close_database_connection(connection: MySqlConnection) {
|
|
||||||
if let Err(e) = connection
|
|
||||||
.close()
|
|
||||||
.await
|
|
||||||
.context("Failed to close connection properly")
|
|
||||||
{
|
|
||||||
eprintln!("{}", e);
|
|
||||||
eprintln!("Ignoring...");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -5,11 +5,16 @@ use clap::Parser;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use sqlx::{mysql::MySqlConnectOptions, ConnectOptions, MySqlConnection};
|
use sqlx::{mysql::MySqlConnectOptions, ConnectOptions, MySqlConnection};
|
||||||
|
|
||||||
|
pub const DEFAULT_CONFIG_PATH: &str = "/etc/mysqladm/config.toml";
|
||||||
|
pub const DEFAULT_SOCKET_PATH: &str = "/run/mysqladm/mysqladm.sock";
|
||||||
|
pub const DEFAULT_PORT: u16 = 3306;
|
||||||
|
pub const DEFAULT_TIMEOUT: u64 = 2;
|
||||||
|
|
||||||
// NOTE: this might look empty now, and the extra wrapping for the mysql
|
// NOTE: this might look empty now, and the extra wrapping for the mysql
|
||||||
// config seems unnecessary, but it will be useful later when we
|
// config seems unnecessary, but it will be useful later when we
|
||||||
// add more configuration options.
|
// add more configuration options.
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
pub struct Config {
|
pub struct ServerConfig {
|
||||||
pub mysql: MysqlConfig,
|
pub mysql: MysqlConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -23,49 +28,45 @@ pub struct MysqlConfig {
|
||||||
pub timeout: Option<u64>,
|
pub timeout: Option<u64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
const DEFAULT_PORT: u16 = 3306;
|
#[derive(Parser, Debug, Clone)]
|
||||||
const DEFAULT_TIMEOUT: u64 = 2;
|
pub struct ServerConfigArgs {
|
||||||
|
|
||||||
#[derive(Parser)]
|
|
||||||
pub struct GlobalConfigArgs {
|
|
||||||
/// Path to the configuration file.
|
/// Path to the configuration file.
|
||||||
#[arg(
|
#[arg(
|
||||||
short,
|
short,
|
||||||
long,
|
long,
|
||||||
value_name = "PATH",
|
value_name = "PATH",
|
||||||
global = true,
|
global = true,
|
||||||
hide_short_help = true,
|
default_value = DEFAULT_CONFIG_PATH,
|
||||||
default_value = "/etc/mysqladm/config.toml"
|
|
||||||
)]
|
)]
|
||||||
config_file: String,
|
config_file: String,
|
||||||
|
|
||||||
/// Hostname of the MySQL server.
|
/// Hostname of the MySQL server.
|
||||||
#[arg(long, value_name = "HOST", global = true, hide_short_help = true)]
|
#[arg(long, value_name = "HOST", global = true)]
|
||||||
mysql_host: Option<String>,
|
mysql_host: Option<String>,
|
||||||
|
|
||||||
/// Port of the MySQL server.
|
/// Port of the MySQL server.
|
||||||
#[arg(long, value_name = "PORT", global = true, hide_short_help = true)]
|
#[arg(long, value_name = "PORT", global = true)]
|
||||||
mysql_port: Option<u16>,
|
mysql_port: Option<u16>,
|
||||||
|
|
||||||
/// Username to use for the MySQL connection.
|
/// Username to use for the MySQL connection.
|
||||||
#[arg(long, value_name = "USER", global = true, hide_short_help = true)]
|
#[arg(long, value_name = "USER", global = true)]
|
||||||
mysql_user: Option<String>,
|
mysql_user: Option<String>,
|
||||||
|
|
||||||
/// Path to a file containing the MySQL password.
|
/// Path to a file containing the MySQL password.
|
||||||
#[arg(long, value_name = "PATH", global = true, hide_short_help = true)]
|
#[arg(long, value_name = "PATH", global = true)]
|
||||||
mysql_password_file: Option<String>,
|
mysql_password_file: Option<String>,
|
||||||
|
|
||||||
/// Seconds to wait for the MySQL connection to be established.
|
/// Seconds to wait for the MySQL connection to be established.
|
||||||
#[arg(long, value_name = "SECONDS", global = true, hide_short_help = true)]
|
#[arg(long, value_name = "SECONDS", global = true)]
|
||||||
mysql_connect_timeout: Option<u64>,
|
mysql_connect_timeout: Option<u64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Use the arguments and whichever configuration file which might or might not
|
/// Use the arguments and whichever configuration file which might or might not
|
||||||
/// be found and default values to determine the configuration for the program.
|
/// be found and default values to determine the configuration for the program.
|
||||||
pub fn get_config(args: GlobalConfigArgs) -> anyhow::Result<Config> {
|
pub fn get_config(args: ServerConfigArgs) -> anyhow::Result<ServerConfig> {
|
||||||
let config_path = PathBuf::from(args.config_file);
|
let config_path = PathBuf::from(args.config_file);
|
||||||
|
|
||||||
let config: Config = fs::read_to_string(&config_path)
|
let config: ServerConfig = fs::read_to_string(&config_path)
|
||||||
.context(format!(
|
.context(format!(
|
||||||
"Failed to read config file from {:?}",
|
"Failed to read config file from {:?}",
|
||||||
&config_path
|
&config_path
|
||||||
|
@ -86,16 +87,14 @@ pub fn get_config(args: GlobalConfigArgs) -> anyhow::Result<Config> {
|
||||||
mysql.password.to_owned()
|
mysql.password.to_owned()
|
||||||
};
|
};
|
||||||
|
|
||||||
let mysql_config = MysqlConfig {
|
Ok(ServerConfig {
|
||||||
|
mysql: MysqlConfig {
|
||||||
host: args.mysql_host.unwrap_or(mysql.host.to_owned()),
|
host: args.mysql_host.unwrap_or(mysql.host.to_owned()),
|
||||||
port: args.mysql_port.or(mysql.port),
|
port: args.mysql_port.or(mysql.port),
|
||||||
username: args.mysql_user.unwrap_or(mysql.username.to_owned()),
|
username: args.mysql_user.unwrap_or(mysql.username.to_owned()),
|
||||||
password,
|
password,
|
||||||
timeout: args.mysql_connect_timeout.or(mysql.timeout),
|
timeout: args.mysql_connect_timeout.or(mysql.timeout),
|
||||||
};
|
},
|
||||||
|
|
||||||
Ok(Config {
|
|
||||||
mysql: mysql_config,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,13 +12,13 @@ use std::collections::BTreeMap;
|
||||||
use super::common::create_user_group_matching_regex;
|
use super::common::create_user_group_matching_regex;
|
||||||
|
|
||||||
// NOTE: this function is unsafe because it does no input validation.
|
// NOTE: this function is unsafe because it does no input validation.
|
||||||
async fn unsafe_database_exists(
|
pub(super) async fn unsafe_database_exists(
|
||||||
db_name: &str,
|
database_name: &str,
|
||||||
connection: &mut MySqlConnection,
|
connection: &mut MySqlConnection,
|
||||||
) -> Result<bool, sqlx::Error> {
|
) -> Result<bool, sqlx::Error> {
|
||||||
let result =
|
let result =
|
||||||
sqlx::query("SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?")
|
sqlx::query("SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?")
|
||||||
.bind(db_name)
|
.bind(database_name)
|
||||||
.fetch_optional(connection)
|
.fetch_optional(connection)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
@ -153,7 +153,7 @@ pub async fn drop_databases(
|
||||||
results
|
results
|
||||||
}
|
}
|
||||||
|
|
||||||
pub type ListDatabasesOutput = Result<Vec<String>, ListDatabasesError>;
|
pub type ListAllDatabasesOutput = Result<Vec<String>, ListDatabasesError>;
|
||||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||||
pub enum ListDatabasesError {
|
pub enum ListDatabasesError {
|
||||||
MySqlError(String),
|
MySqlError(String),
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
// TODO: fix comment
|
||||||
//! Database privilege operations
|
//! Database privilege operations
|
||||||
//!
|
//!
|
||||||
//! This module contains functions for querying, modifying,
|
//! This module contains functions for querying, modifying,
|
||||||
|
@ -13,21 +14,21 @@
|
||||||
//! changes will be made when applying a set of changes
|
//! changes will be made when applying a set of changes
|
||||||
//! to the list of database privileges.
|
//! to the list of database privileges.
|
||||||
|
|
||||||
use std::collections::{BTreeSet, HashMap};
|
use std::collections::{BTreeMap, BTreeSet};
|
||||||
|
|
||||||
use anyhow::{anyhow, Context};
|
|
||||||
use indoc::indoc;
|
use indoc::indoc;
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
use prettytable::Table;
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use sqlx::{mysql::MySqlRow, prelude::*, MySqlConnection};
|
use sqlx::{mysql::MySqlRow, prelude::*, MySqlConnection};
|
||||||
|
|
||||||
use crate::core::{
|
use super::common::{create_user_group_matching_regex, UnixUser};
|
||||||
common::{
|
use super::database_operations::unsafe_database_exists;
|
||||||
create_user_group_matching_regex, get_current_unix_user, quote_identifier, rev_yn, yn,
|
use super::input_sanitization::{
|
||||||
},
|
quote_identifier, validate_name, validate_ownership_by_unix_user, NameValidationError,
|
||||||
database_operations::validate_database_name,
|
OwnerValidationError,
|
||||||
};
|
};
|
||||||
|
use crate::core::common::{rev_yn, yn};
|
||||||
|
use crate::core::database_privileges::{DatabasePrivilegeChange, DatabasePrivilegesDiff};
|
||||||
|
|
||||||
/// This is the list of fields that are used to fetch the db + user + privileges
|
/// This is the list of fields that are used to fetch the db + user + privileges
|
||||||
/// from the `db` table in the database. If you need to add or remove privilege
|
/// from the `db` table in the database. If you need to add or remove privilege
|
||||||
|
@ -48,25 +49,6 @@ pub const DATABASE_PRIVILEGE_FIELDS: [&str; 13] = [
|
||||||
"references_priv",
|
"references_priv",
|
||||||
];
|
];
|
||||||
|
|
||||||
pub fn db_priv_field_human_readable_name(name: &str) -> String {
|
|
||||||
match name {
|
|
||||||
"db" => "Database".to_owned(),
|
|
||||||
"user" => "User".to_owned(),
|
|
||||||
"select_priv" => "Select".to_owned(),
|
|
||||||
"insert_priv" => "Insert".to_owned(),
|
|
||||||
"update_priv" => "Update".to_owned(),
|
|
||||||
"delete_priv" => "Delete".to_owned(),
|
|
||||||
"create_priv" => "Create".to_owned(),
|
|
||||||
"drop_priv" => "Drop".to_owned(),
|
|
||||||
"alter_priv" => "Alter".to_owned(),
|
|
||||||
"index_priv" => "Index".to_owned(),
|
|
||||||
"create_tmp_table_priv" => "Temp".to_owned(),
|
|
||||||
"lock_tables_priv" => "Lock".to_owned(),
|
|
||||||
"references_priv" => "References".to_owned(),
|
|
||||||
_ => format!("Unknown({})", name),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NOTE: ord is needed for BTreeSet to accept the type, but it
|
// NOTE: ord is needed for BTreeSet to accept the type, but it
|
||||||
// doesn't have any natural implementation semantics.
|
// doesn't have any natural implementation semantics.
|
||||||
|
|
||||||
|
@ -123,26 +105,6 @@ impl DatabasePrivilegeRow {
|
||||||
_ => false,
|
_ => false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn diff(&self, other: &DatabasePrivilegeRow) -> DatabasePrivilegeRowDiff {
|
|
||||||
debug_assert!(self.db == other.db && self.user == other.user);
|
|
||||||
|
|
||||||
DatabasePrivilegeRowDiff {
|
|
||||||
db: self.db.clone(),
|
|
||||||
user: self.user.clone(),
|
|
||||||
diff: DATABASE_PRIVILEGE_FIELDS
|
|
||||||
.into_iter()
|
|
||||||
.skip(2)
|
|
||||||
.filter_map(|field| {
|
|
||||||
DatabasePrivilegeChange::new(
|
|
||||||
self.get_privilege_by_name(field),
|
|
||||||
other.get_privilege_by_name(field),
|
|
||||||
field,
|
|
||||||
)
|
|
||||||
})
|
|
||||||
.collect(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
|
@ -178,15 +140,13 @@ impl FromRow<'_, MySqlRow> for DatabasePrivilegeRow {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NOTE: this function is unsafe because it does no input validation.
|
||||||
/// Get all users + privileges for a single database.
|
/// Get all users + privileges for a single database.
|
||||||
pub async fn get_database_privileges(
|
async fn unsafe_get_database_privileges(
|
||||||
database_name: &str,
|
database_name: &str,
|
||||||
connection: &mut MySqlConnection,
|
connection: &mut MySqlConnection,
|
||||||
) -> anyhow::Result<Vec<DatabasePrivilegeRow>> {
|
) -> Result<Vec<DatabasePrivilegeRow>, sqlx::Error> {
|
||||||
let unix_user = get_current_unix_user()?;
|
sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
|
||||||
validate_database_name(database_name, &unix_user)?;
|
|
||||||
|
|
||||||
let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
|
|
||||||
"SELECT {} FROM `db` WHERE `db` = ?",
|
"SELECT {} FROM `db` WHERE `db` = ?",
|
||||||
DATABASE_PRIVILEGE_FIELDS
|
DATABASE_PRIVILEGE_FIELDS
|
||||||
.iter()
|
.iter()
|
||||||
|
@ -196,18 +156,82 @@ pub async fn get_database_privileges(
|
||||||
.bind(database_name)
|
.bind(database_name)
|
||||||
.fetch_all(connection)
|
.fetch_all(connection)
|
||||||
.await
|
.await
|
||||||
.context("Failed to show database")?;
|
}
|
||||||
|
|
||||||
Ok(result)
|
// TODO: merge all rows into a single collection.
|
||||||
|
// they already contain which database they belong to.
|
||||||
|
// no need to index by database name.
|
||||||
|
|
||||||
|
pub type GetDatabasesPrivilegeData =
|
||||||
|
BTreeMap<String, Result<Vec<DatabasePrivilegeRow>, GetDatabasesPrivilegeDataError>>;
|
||||||
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||||
|
pub enum GetDatabasesPrivilegeDataError {
|
||||||
|
SanitizationError(NameValidationError),
|
||||||
|
OwnershipError(OwnerValidationError),
|
||||||
|
DatabaseDoesNotExist,
|
||||||
|
MySqlError(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_databases_privilege_data(
|
||||||
|
database_names: Vec<String>,
|
||||||
|
unix_user: &UnixUser,
|
||||||
|
connection: &mut MySqlConnection,
|
||||||
|
) -> GetDatabasesPrivilegeData {
|
||||||
|
let mut results = BTreeMap::new();
|
||||||
|
|
||||||
|
for database_name in database_names.iter() {
|
||||||
|
if let Err(err) = validate_name(database_name) {
|
||||||
|
results.insert(
|
||||||
|
database_name.clone(),
|
||||||
|
Err(GetDatabasesPrivilegeDataError::SanitizationError(err)),
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Err(err) = validate_ownership_by_unix_user(database_name, unix_user) {
|
||||||
|
results.insert(
|
||||||
|
database_name.clone(),
|
||||||
|
Err(GetDatabasesPrivilegeDataError::OwnershipError(err)),
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if !unsafe_database_exists(database_name, connection)
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
{
|
||||||
|
results.insert(
|
||||||
|
database_name.clone(),
|
||||||
|
Err(GetDatabasesPrivilegeDataError::DatabaseDoesNotExist),
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let result = unsafe_get_database_privileges(database_name, connection)
|
||||||
|
.await
|
||||||
|
.map_err(|e| GetDatabasesPrivilegeDataError::MySqlError(e.to_string()));
|
||||||
|
|
||||||
|
results.insert(database_name.clone(), result);
|
||||||
|
}
|
||||||
|
|
||||||
|
debug_assert!(database_names.len() == results.len());
|
||||||
|
|
||||||
|
results
|
||||||
|
}
|
||||||
|
|
||||||
|
pub type GetAllDatabasesPrivilegeData =
|
||||||
|
Result<Vec<DatabasePrivilegeRow>, GetAllDatabasesPrivilegeDataError>;
|
||||||
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||||
|
pub enum GetAllDatabasesPrivilegeDataError {
|
||||||
|
MySqlError(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get all database + user + privileges pairs that are owned by the current user.
|
/// Get all database + user + privileges pairs that are owned by the current user.
|
||||||
pub async fn get_all_database_privileges(
|
pub async fn get_all_database_privileges(
|
||||||
|
unix_user: &UnixUser,
|
||||||
connection: &mut MySqlConnection,
|
connection: &mut MySqlConnection,
|
||||||
) -> anyhow::Result<Vec<DatabasePrivilegeRow>> {
|
) -> GetAllDatabasesPrivilegeData {
|
||||||
let unix_user = get_current_unix_user()?;
|
sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
|
||||||
|
|
||||||
let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
|
|
||||||
indoc! {r#"
|
indoc! {r#"
|
||||||
SELECT {} FROM `db` WHERE `db` IN
|
SELECT {} FROM `db` WHERE `db` IN
|
||||||
(SELECT DISTINCT `SCHEMA_NAME` AS `database`
|
(SELECT DISTINCT `SCHEMA_NAME` AS `database`
|
||||||
|
@ -217,454 +241,24 @@ pub async fn get_all_database_privileges(
|
||||||
"#},
|
"#},
|
||||||
DATABASE_PRIVILEGE_FIELDS
|
DATABASE_PRIVILEGE_FIELDS
|
||||||
.iter()
|
.iter()
|
||||||
.map(|field| format!("`{field}`"))
|
.map(|field| quote_identifier(field))
|
||||||
.join(","),
|
.join(","),
|
||||||
))
|
))
|
||||||
.bind(create_user_group_matching_regex(&unix_user))
|
.bind(create_user_group_matching_regex(unix_user))
|
||||||
.fetch_all(connection)
|
.fetch_all(connection)
|
||||||
.await
|
.await
|
||||||
.context("Failed to show databases")?;
|
.map_err(|e| GetAllDatabasesPrivilegeDataError::MySqlError(e.to_string()))
|
||||||
|
|
||||||
Ok(result)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/*************************/
|
async fn unsafe_apply_privilege_diff(
|
||||||
/* CLI INTERFACE PARSING */
|
database_privilege_diff: &DatabasePrivilegesDiff,
|
||||||
/*************************/
|
|
||||||
|
|
||||||
/// See documentation for [`DatabaseCommand::EditDbPrivs`].
|
|
||||||
pub fn parse_privilege_table_cli_arg(arg: &str) -> anyhow::Result<DatabasePrivilegeRow> {
|
|
||||||
let parts: Vec<&str> = arg.split(':').collect();
|
|
||||||
if parts.len() != 3 {
|
|
||||||
anyhow::bail!("Invalid argument format. See `edit-db-privs --help` for more information.");
|
|
||||||
}
|
|
||||||
|
|
||||||
let db = parts[0].to_string();
|
|
||||||
let user = parts[1].to_string();
|
|
||||||
let privs = parts[2].to_string();
|
|
||||||
|
|
||||||
let mut result = DatabasePrivilegeRow {
|
|
||||||
db,
|
|
||||||
user,
|
|
||||||
select_priv: false,
|
|
||||||
insert_priv: false,
|
|
||||||
update_priv: false,
|
|
||||||
delete_priv: false,
|
|
||||||
create_priv: false,
|
|
||||||
drop_priv: false,
|
|
||||||
alter_priv: false,
|
|
||||||
index_priv: false,
|
|
||||||
create_tmp_table_priv: false,
|
|
||||||
lock_tables_priv: false,
|
|
||||||
references_priv: false,
|
|
||||||
};
|
|
||||||
|
|
||||||
for char in privs.chars() {
|
|
||||||
match char {
|
|
||||||
's' => result.select_priv = true,
|
|
||||||
'i' => result.insert_priv = true,
|
|
||||||
'u' => result.update_priv = true,
|
|
||||||
'd' => result.delete_priv = true,
|
|
||||||
'c' => result.create_priv = true,
|
|
||||||
'D' => result.drop_priv = true,
|
|
||||||
'a' => result.alter_priv = true,
|
|
||||||
'I' => result.index_priv = true,
|
|
||||||
't' => result.create_tmp_table_priv = true,
|
|
||||||
'l' => result.lock_tables_priv = true,
|
|
||||||
'r' => result.references_priv = true,
|
|
||||||
'A' => {
|
|
||||||
result.select_priv = true;
|
|
||||||
result.insert_priv = true;
|
|
||||||
result.update_priv = true;
|
|
||||||
result.delete_priv = true;
|
|
||||||
result.create_priv = true;
|
|
||||||
result.drop_priv = true;
|
|
||||||
result.alter_priv = true;
|
|
||||||
result.index_priv = true;
|
|
||||||
result.create_tmp_table_priv = true;
|
|
||||||
result.lock_tables_priv = true;
|
|
||||||
result.references_priv = true;
|
|
||||||
}
|
|
||||||
_ => anyhow::bail!("Invalid privilege character: {}", char),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(result)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**********************************/
|
|
||||||
/* EDITOR CONTENT DISPLAY/DISPLAY */
|
|
||||||
/**********************************/
|
|
||||||
|
|
||||||
/// Generates a single row of the privileges table for the editor.
|
|
||||||
pub fn format_privileges_line_for_editor(
|
|
||||||
privs: &DatabasePrivilegeRow,
|
|
||||||
username_len: usize,
|
|
||||||
database_name_len: usize,
|
|
||||||
) -> String {
|
|
||||||
DATABASE_PRIVILEGE_FIELDS
|
|
||||||
.into_iter()
|
|
||||||
.map(|field| match field {
|
|
||||||
"db" => format!("{:width$}", privs.db, width = database_name_len),
|
|
||||||
"user" => format!("{:width$}", privs.user, width = username_len),
|
|
||||||
privilege => format!(
|
|
||||||
"{:width$}",
|
|
||||||
yn(privs.get_privilege_by_name(privilege)),
|
|
||||||
width = db_priv_field_human_readable_name(privilege).len()
|
|
||||||
),
|
|
||||||
})
|
|
||||||
.join(" ")
|
|
||||||
.trim()
|
|
||||||
.to_string()
|
|
||||||
}
|
|
||||||
|
|
||||||
const EDITOR_COMMENT: &str = r#"
|
|
||||||
# Welcome to the privilege editor.
|
|
||||||
# Each line defines what privileges a single user has on a single database.
|
|
||||||
# The first two columns respectively represent the database name and the user, and the remaining columns are the privileges.
|
|
||||||
# If the user should have a certain privilege, write 'Y', otherwise write 'N'.
|
|
||||||
#
|
|
||||||
# Lines starting with '#' are comments and will be ignored.
|
|
||||||
"#;
|
|
||||||
|
|
||||||
/// Generates the content for the privilege editor.
|
|
||||||
///
|
|
||||||
/// The unix user is used in case there are no privileges to edit,
|
|
||||||
/// so that the user can see an example line based on their username.
|
|
||||||
pub fn generate_editor_content_from_privilege_data(
|
|
||||||
privilege_data: &[DatabasePrivilegeRow],
|
|
||||||
unix_user: &str,
|
|
||||||
) -> String {
|
|
||||||
let example_user = format!("{}_user", unix_user);
|
|
||||||
let example_db = format!("{}_db", unix_user);
|
|
||||||
|
|
||||||
// NOTE: `.max()`` fails when the iterator is empty.
|
|
||||||
// In this case, we know that the only fields in the
|
|
||||||
// editor will be the example user and example db name.
|
|
||||||
// Hence, it's put as the fallback value, despite not really
|
|
||||||
// being a "fallback" in the normal sense.
|
|
||||||
let longest_username = privilege_data
|
|
||||||
.iter()
|
|
||||||
.map(|p| p.user.len())
|
|
||||||
.max()
|
|
||||||
.unwrap_or(example_user.len());
|
|
||||||
|
|
||||||
let longest_database_name = privilege_data
|
|
||||||
.iter()
|
|
||||||
.map(|p| p.db.len())
|
|
||||||
.max()
|
|
||||||
.unwrap_or(example_db.len());
|
|
||||||
|
|
||||||
let mut header: Vec<_> = DATABASE_PRIVILEGE_FIELDS
|
|
||||||
.into_iter()
|
|
||||||
.map(db_priv_field_human_readable_name)
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
// Pad the first two columns with spaces to align the privileges.
|
|
||||||
header[0] = format!("{:width$}", header[0], width = longest_database_name);
|
|
||||||
header[1] = format!("{:width$}", header[1], width = longest_username);
|
|
||||||
|
|
||||||
let example_line = format_privileges_line_for_editor(
|
|
||||||
&DatabasePrivilegeRow {
|
|
||||||
db: example_db,
|
|
||||||
user: example_user,
|
|
||||||
select_priv: true,
|
|
||||||
insert_priv: true,
|
|
||||||
update_priv: true,
|
|
||||||
delete_priv: true,
|
|
||||||
create_priv: false,
|
|
||||||
drop_priv: false,
|
|
||||||
alter_priv: false,
|
|
||||||
index_priv: false,
|
|
||||||
create_tmp_table_priv: false,
|
|
||||||
lock_tables_priv: false,
|
|
||||||
references_priv: false,
|
|
||||||
},
|
|
||||||
longest_username,
|
|
||||||
longest_database_name,
|
|
||||||
);
|
|
||||||
|
|
||||||
format!(
|
|
||||||
"{}\n{}\n{}",
|
|
||||||
EDITOR_COMMENT,
|
|
||||||
header.join(" "),
|
|
||||||
if privilege_data.is_empty() {
|
|
||||||
format!("# {}", example_line)
|
|
||||||
} else {
|
|
||||||
privilege_data
|
|
||||||
.iter()
|
|
||||||
.map(|privs| {
|
|
||||||
format_privileges_line_for_editor(
|
|
||||||
privs,
|
|
||||||
longest_username,
|
|
||||||
longest_database_name,
|
|
||||||
)
|
|
||||||
})
|
|
||||||
.join("\n")
|
|
||||||
}
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
enum PrivilegeRowParseResult {
|
|
||||||
PrivilegeRow(DatabasePrivilegeRow),
|
|
||||||
ParserError(anyhow::Error),
|
|
||||||
TooFewFields(usize),
|
|
||||||
TooManyFields(usize),
|
|
||||||
Header,
|
|
||||||
Comment,
|
|
||||||
Empty,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[inline]
|
|
||||||
fn parse_privilege_cell_from_editor(yn: &str, name: &str) -> anyhow::Result<bool> {
|
|
||||||
rev_yn(yn)
|
|
||||||
.ok_or_else(|| anyhow!("Expected Y or N, found {}", yn))
|
|
||||||
.context(format!("Could not parse {} privilege", name))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[inline]
|
|
||||||
fn editor_row_is_header(row: &str) -> bool {
|
|
||||||
row.split_ascii_whitespace()
|
|
||||||
.zip(DATABASE_PRIVILEGE_FIELDS.iter())
|
|
||||||
.map(|(field, priv_name)| (field, db_priv_field_human_readable_name(priv_name)))
|
|
||||||
.all(|(field, header_field)| field == header_field)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Parse a single row of the privileges table from the editor.
|
|
||||||
fn parse_privilege_row_from_editor(row: &str) -> PrivilegeRowParseResult {
|
|
||||||
if row.starts_with('#') || row.starts_with("//") {
|
|
||||||
return PrivilegeRowParseResult::Comment;
|
|
||||||
}
|
|
||||||
|
|
||||||
if row.trim().is_empty() {
|
|
||||||
return PrivilegeRowParseResult::Empty;
|
|
||||||
}
|
|
||||||
|
|
||||||
let parts: Vec<&str> = row.trim().split_ascii_whitespace().collect();
|
|
||||||
|
|
||||||
match parts.len() {
|
|
||||||
n if (n < DATABASE_PRIVILEGE_FIELDS.len()) => {
|
|
||||||
return PrivilegeRowParseResult::TooFewFields(n)
|
|
||||||
}
|
|
||||||
n if (n > DATABASE_PRIVILEGE_FIELDS.len()) => {
|
|
||||||
return PrivilegeRowParseResult::TooManyFields(n)
|
|
||||||
}
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
|
|
||||||
if editor_row_is_header(row) {
|
|
||||||
return PrivilegeRowParseResult::Header;
|
|
||||||
}
|
|
||||||
|
|
||||||
let row = DatabasePrivilegeRow {
|
|
||||||
db: (*parts.first().unwrap()).to_owned(),
|
|
||||||
user: (*parts.get(1).unwrap()).to_owned(),
|
|
||||||
select_priv: match parse_privilege_cell_from_editor(
|
|
||||||
parts.get(2).unwrap(),
|
|
||||||
DATABASE_PRIVILEGE_FIELDS[2],
|
|
||||||
) {
|
|
||||||
Ok(p) => p,
|
|
||||||
Err(e) => return PrivilegeRowParseResult::ParserError(e),
|
|
||||||
},
|
|
||||||
insert_priv: match parse_privilege_cell_from_editor(
|
|
||||||
parts.get(3).unwrap(),
|
|
||||||
DATABASE_PRIVILEGE_FIELDS[3],
|
|
||||||
) {
|
|
||||||
Ok(p) => p,
|
|
||||||
Err(e) => return PrivilegeRowParseResult::ParserError(e),
|
|
||||||
},
|
|
||||||
update_priv: match parse_privilege_cell_from_editor(
|
|
||||||
parts.get(4).unwrap(),
|
|
||||||
DATABASE_PRIVILEGE_FIELDS[4],
|
|
||||||
) {
|
|
||||||
Ok(p) => p,
|
|
||||||
Err(e) => return PrivilegeRowParseResult::ParserError(e),
|
|
||||||
},
|
|
||||||
delete_priv: match parse_privilege_cell_from_editor(
|
|
||||||
parts.get(5).unwrap(),
|
|
||||||
DATABASE_PRIVILEGE_FIELDS[5],
|
|
||||||
) {
|
|
||||||
Ok(p) => p,
|
|
||||||
Err(e) => return PrivilegeRowParseResult::ParserError(e),
|
|
||||||
},
|
|
||||||
create_priv: match parse_privilege_cell_from_editor(
|
|
||||||
parts.get(6).unwrap(),
|
|
||||||
DATABASE_PRIVILEGE_FIELDS[6],
|
|
||||||
) {
|
|
||||||
Ok(p) => p,
|
|
||||||
Err(e) => return PrivilegeRowParseResult::ParserError(e),
|
|
||||||
},
|
|
||||||
drop_priv: match parse_privilege_cell_from_editor(
|
|
||||||
parts.get(7).unwrap(),
|
|
||||||
DATABASE_PRIVILEGE_FIELDS[7],
|
|
||||||
) {
|
|
||||||
Ok(p) => p,
|
|
||||||
Err(e) => return PrivilegeRowParseResult::ParserError(e),
|
|
||||||
},
|
|
||||||
alter_priv: match parse_privilege_cell_from_editor(
|
|
||||||
parts.get(8).unwrap(),
|
|
||||||
DATABASE_PRIVILEGE_FIELDS[8],
|
|
||||||
) {
|
|
||||||
Ok(p) => p,
|
|
||||||
Err(e) => return PrivilegeRowParseResult::ParserError(e),
|
|
||||||
},
|
|
||||||
index_priv: match parse_privilege_cell_from_editor(
|
|
||||||
parts.get(9).unwrap(),
|
|
||||||
DATABASE_PRIVILEGE_FIELDS[9],
|
|
||||||
) {
|
|
||||||
Ok(p) => p,
|
|
||||||
Err(e) => return PrivilegeRowParseResult::ParserError(e),
|
|
||||||
},
|
|
||||||
create_tmp_table_priv: match parse_privilege_cell_from_editor(
|
|
||||||
parts.get(10).unwrap(),
|
|
||||||
DATABASE_PRIVILEGE_FIELDS[10],
|
|
||||||
) {
|
|
||||||
Ok(p) => p,
|
|
||||||
Err(e) => return PrivilegeRowParseResult::ParserError(e),
|
|
||||||
},
|
|
||||||
lock_tables_priv: match parse_privilege_cell_from_editor(
|
|
||||||
parts.get(11).unwrap(),
|
|
||||||
DATABASE_PRIVILEGE_FIELDS[11],
|
|
||||||
) {
|
|
||||||
Ok(p) => p,
|
|
||||||
Err(e) => return PrivilegeRowParseResult::ParserError(e),
|
|
||||||
},
|
|
||||||
references_priv: match parse_privilege_cell_from_editor(
|
|
||||||
parts.get(12).unwrap(),
|
|
||||||
DATABASE_PRIVILEGE_FIELDS[12],
|
|
||||||
) {
|
|
||||||
Ok(p) => p,
|
|
||||||
Err(e) => return PrivilegeRowParseResult::ParserError(e),
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
PrivilegeRowParseResult::PrivilegeRow(row)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: return better errors
|
|
||||||
|
|
||||||
pub fn parse_privilege_data_from_editor_content(
|
|
||||||
content: String,
|
|
||||||
) -> anyhow::Result<Vec<DatabasePrivilegeRow>> {
|
|
||||||
content
|
|
||||||
.trim()
|
|
||||||
.split('\n')
|
|
||||||
.map(|line| line.trim())
|
|
||||||
.map(parse_privilege_row_from_editor)
|
|
||||||
.map(|result| match result {
|
|
||||||
PrivilegeRowParseResult::PrivilegeRow(row) => Ok(Some(row)),
|
|
||||||
PrivilegeRowParseResult::ParserError(e) => Err(e),
|
|
||||||
PrivilegeRowParseResult::TooFewFields(n) => Err(anyhow!(
|
|
||||||
"Too few fields in line. Expected to find {} fields, found {}",
|
|
||||||
DATABASE_PRIVILEGE_FIELDS.len(),
|
|
||||||
n
|
|
||||||
)),
|
|
||||||
PrivilegeRowParseResult::TooManyFields(n) => Err(anyhow!(
|
|
||||||
"Too many fields in line. Expected to find {} fields, found {}",
|
|
||||||
DATABASE_PRIVILEGE_FIELDS.len(),
|
|
||||||
n
|
|
||||||
)),
|
|
||||||
PrivilegeRowParseResult::Header => Ok(None),
|
|
||||||
PrivilegeRowParseResult::Comment => Ok(None),
|
|
||||||
PrivilegeRowParseResult::Empty => Ok(None),
|
|
||||||
})
|
|
||||||
.filter_map(|result| result.transpose())
|
|
||||||
.collect::<anyhow::Result<Vec<DatabasePrivilegeRow>>>()
|
|
||||||
}
|
|
||||||
|
|
||||||
/*****************************/
|
|
||||||
/* CALCULATE PRIVILEGE DIFFS */
|
|
||||||
/*****************************/
|
|
||||||
|
|
||||||
/// This struct represents encapsulates the differences between two
|
|
||||||
/// instances of privilege sets for a single user on a single database.
|
|
||||||
///
|
|
||||||
/// The `User` and `Database` are the same for both instances.
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, PartialOrd, Ord)]
|
|
||||||
pub struct DatabasePrivilegeRowDiff {
|
|
||||||
pub db: String,
|
|
||||||
pub user: String,
|
|
||||||
pub diff: BTreeSet<DatabasePrivilegeChange>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// This enum represents a change for a single privilege.
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, PartialOrd, Ord)]
|
|
||||||
pub enum DatabasePrivilegeChange {
|
|
||||||
YesToNo(String),
|
|
||||||
NoToYes(String),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl DatabasePrivilegeChange {
|
|
||||||
pub fn new(p1: bool, p2: bool, name: &str) -> Option<DatabasePrivilegeChange> {
|
|
||||||
match (p1, p2) {
|
|
||||||
(true, false) => Some(DatabasePrivilegeChange::YesToNo(name.to_owned())),
|
|
||||||
(false, true) => Some(DatabasePrivilegeChange::NoToYes(name.to_owned())),
|
|
||||||
_ => None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// This enum encapsulates whether a [`DatabasePrivilegeRow`] was intrduced, modified or deleted.
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, PartialOrd, Ord)]
|
|
||||||
pub enum DatabasePrivilegesDiff {
|
|
||||||
New(DatabasePrivilegeRow),
|
|
||||||
Modified(DatabasePrivilegeRowDiff),
|
|
||||||
Deleted(DatabasePrivilegeRow),
|
|
||||||
}
|
|
||||||
|
|
||||||
/// This function calculates the differences between two sets of database privileges.
|
|
||||||
/// It returns a set of [`DatabasePrivilegesDiff`] that can be used to display or
|
|
||||||
/// apply a set of privilege modifications to the database.
|
|
||||||
pub fn diff_privileges(
|
|
||||||
from: &[DatabasePrivilegeRow],
|
|
||||||
to: &[DatabasePrivilegeRow],
|
|
||||||
) -> BTreeSet<DatabasePrivilegesDiff> {
|
|
||||||
let from_lookup_table: HashMap<(String, String), DatabasePrivilegeRow> = HashMap::from_iter(
|
|
||||||
from.iter()
|
|
||||||
.cloned()
|
|
||||||
.map(|p| ((p.db.clone(), p.user.clone()), p)),
|
|
||||||
);
|
|
||||||
|
|
||||||
let to_lookup_table: HashMap<(String, String), DatabasePrivilegeRow> = HashMap::from_iter(
|
|
||||||
to.iter()
|
|
||||||
.cloned()
|
|
||||||
.map(|p| ((p.db.clone(), p.user.clone()), p)),
|
|
||||||
);
|
|
||||||
|
|
||||||
let mut result = BTreeSet::new();
|
|
||||||
|
|
||||||
for p in to {
|
|
||||||
if let Some(old_p) = from_lookup_table.get(&(p.db.clone(), p.user.clone())) {
|
|
||||||
let diff = old_p.diff(p);
|
|
||||||
if !diff.diff.is_empty() {
|
|
||||||
result.insert(DatabasePrivilegesDiff::Modified(diff));
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
result.insert(DatabasePrivilegesDiff::New(p.clone()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for p in from {
|
|
||||||
if !to_lookup_table.contains_key(&(p.db.clone(), p.user.clone())) {
|
|
||||||
result.insert(DatabasePrivilegesDiff::Deleted(p.clone()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
result
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Uses the result of [`diff_privileges`] to modify privileges in the database.
|
|
||||||
pub async fn apply_privilege_diffs(
|
|
||||||
diffs: BTreeSet<DatabasePrivilegesDiff>,
|
|
||||||
connection: &mut MySqlConnection,
|
connection: &mut MySqlConnection,
|
||||||
) -> anyhow::Result<()> {
|
) -> Result<(), sqlx::Error> {
|
||||||
for diff in diffs {
|
match database_privilege_diff {
|
||||||
match diff {
|
|
||||||
DatabasePrivilegesDiff::New(p) => {
|
DatabasePrivilegesDiff::New(p) => {
|
||||||
let tables = DATABASE_PRIVILEGE_FIELDS
|
let tables = DATABASE_PRIVILEGE_FIELDS
|
||||||
.iter()
|
.iter()
|
||||||
.map(|field| format!("`{field}`"))
|
.map(|field| quote_identifier(field))
|
||||||
.join(",");
|
.join(",");
|
||||||
|
|
||||||
let question_marks = std::iter::repeat("?")
|
let question_marks = std::iter::repeat("?")
|
||||||
|
@ -674,8 +268,8 @@ pub async fn apply_privilege_diffs(
|
||||||
sqlx::query(
|
sqlx::query(
|
||||||
format!("INSERT INTO `db` ({}) VALUES ({})", tables, question_marks).as_str(),
|
format!("INSERT INTO `db` ({}) VALUES ({})", tables, question_marks).as_str(),
|
||||||
)
|
)
|
||||||
.bind(p.db)
|
.bind(p.db.to_string())
|
||||||
.bind(p.user)
|
.bind(p.user.to_string())
|
||||||
.bind(yn(p.select_priv))
|
.bind(yn(p.select_priv))
|
||||||
.bind(yn(p.insert_priv))
|
.bind(yn(p.insert_priv))
|
||||||
.bind(yn(p.update_priv))
|
.bind(yn(p.update_priv))
|
||||||
|
@ -687,202 +281,119 @@ pub async fn apply_privilege_diffs(
|
||||||
.bind(yn(p.create_tmp_table_priv))
|
.bind(yn(p.create_tmp_table_priv))
|
||||||
.bind(yn(p.lock_tables_priv))
|
.bind(yn(p.lock_tables_priv))
|
||||||
.bind(yn(p.references_priv))
|
.bind(yn(p.references_priv))
|
||||||
.execute(&mut *connection)
|
.execute(connection)
|
||||||
.await?;
|
.await
|
||||||
|
.map(|_| ())
|
||||||
}
|
}
|
||||||
DatabasePrivilegesDiff::Modified(p) => {
|
DatabasePrivilegesDiff::Modified(p) => {
|
||||||
let tables = p
|
let changes = p
|
||||||
.diff
|
.diff
|
||||||
.iter()
|
.iter()
|
||||||
.map(|diff| match diff {
|
.map(|diff| match diff {
|
||||||
DatabasePrivilegeChange::YesToNo(name) => format!("`{}` = 'N'", name),
|
DatabasePrivilegeChange::YesToNo(name) => {
|
||||||
DatabasePrivilegeChange::NoToYes(name) => format!("`{}` = 'Y'", name),
|
format!("{} = 'N'", quote_identifier(name))
|
||||||
|
}
|
||||||
|
DatabasePrivilegeChange::NoToYes(name) => {
|
||||||
|
format!("{} = 'Y'", quote_identifier(name))
|
||||||
|
}
|
||||||
})
|
})
|
||||||
.join(",");
|
.join(",");
|
||||||
|
|
||||||
sqlx::query(
|
sqlx::query(
|
||||||
format!("UPDATE `db` SET {} WHERE `db` = ? AND `user` = ?", tables).as_str(),
|
format!("UPDATE `db` SET {} WHERE `db` = ? AND `user` = ?", changes).as_str(),
|
||||||
)
|
)
|
||||||
.bind(p.db)
|
.bind(p.db.to_string())
|
||||||
.bind(p.user)
|
.bind(p.user.to_string())
|
||||||
.execute(&mut *connection)
|
.execute(connection)
|
||||||
.await?;
|
.await
|
||||||
|
.map(|_| ())
|
||||||
}
|
}
|
||||||
DatabasePrivilegesDiff::Deleted(p) => {
|
DatabasePrivilegesDiff::Deleted(p) => {
|
||||||
sqlx::query("DELETE FROM `db` WHERE `db` = ? AND `user` = ?")
|
sqlx::query("DELETE FROM `db` WHERE `db` = ? AND `user` = ?")
|
||||||
.bind(p.db)
|
.bind(p.db.to_string())
|
||||||
.bind(p.user)
|
.bind(p.user.to_string())
|
||||||
.execute(&mut *connection)
|
.execute(connection)
|
||||||
.await?;
|
.await
|
||||||
}
|
.map(|_| ())
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn display_privilege_cell(diff: &DatabasePrivilegeRowDiff) -> String {
|
|
||||||
diff.diff
|
|
||||||
.iter()
|
|
||||||
.map(|change| match change {
|
|
||||||
DatabasePrivilegeChange::YesToNo(name) => {
|
|
||||||
format!("{}: Y -> N", db_priv_field_human_readable_name(name))
|
|
||||||
}
|
|
||||||
DatabasePrivilegeChange::NoToYes(name) => {
|
|
||||||
format!("{}: N -> Y", db_priv_field_human_readable_name(name))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.join("\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Displays the difference between two sets of database privileges.
|
|
||||||
pub fn display_privilege_diffs(diffs: &BTreeSet<DatabasePrivilegesDiff>) -> String {
|
|
||||||
let mut table = Table::new();
|
|
||||||
table.set_titles(row!["Database", "User", "Privilege diff",]);
|
|
||||||
for row in diffs {
|
|
||||||
match row {
|
|
||||||
DatabasePrivilegesDiff::New(p) => {
|
|
||||||
table.add_row(row![
|
|
||||||
p.db,
|
|
||||||
p.user,
|
|
||||||
"(New user)\n".to_string()
|
|
||||||
+ &display_privilege_cell(
|
|
||||||
&DatabasePrivilegeRow::empty(&p.db, &p.user).diff(p)
|
|
||||||
)
|
|
||||||
]);
|
|
||||||
}
|
|
||||||
DatabasePrivilegesDiff::Modified(p) => {
|
|
||||||
table.add_row(row![p.db, p.user, display_privilege_cell(p),]);
|
|
||||||
}
|
|
||||||
DatabasePrivilegesDiff::Deleted(p) => {
|
|
||||||
table.add_row(row![
|
|
||||||
p.db,
|
|
||||||
p.user,
|
|
||||||
"(All privileges removed)\n".to_string()
|
|
||||||
+ &display_privilege_cell(
|
|
||||||
&p.diff(&DatabasePrivilegeRow::empty(&p.db, &p.user))
|
|
||||||
)
|
|
||||||
]);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
table.to_string()
|
pub type ModifyDatabasePrivilegesOutput =
|
||||||
|
BTreeMap<String, Result<(), ModifyDatabasePrivilegesError>>;
|
||||||
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||||
|
pub enum ModifyDatabasePrivilegesError {
|
||||||
|
DatabaseSanitizationError(NameValidationError),
|
||||||
|
DatabaseOwnershipError(OwnerValidationError),
|
||||||
|
UserSanitizationError(NameValidationError),
|
||||||
|
UserOwnershipError(OwnerValidationError),
|
||||||
|
DatabaseDoesNotExist,
|
||||||
|
DiffDoesNotApply(DatabasePrivilegeChange),
|
||||||
|
MySqlError(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********/
|
/// Uses the result of [`diff_privileges`] to modify privileges in the database.
|
||||||
/* TESTS */
|
pub async fn apply_privilege_diffs(
|
||||||
/*********/
|
database_privilege_diffs: BTreeSet<DatabasePrivilegesDiff>,
|
||||||
|
unix_user: &UnixUser,
|
||||||
|
connection: &mut MySqlConnection,
|
||||||
|
) -> ModifyDatabasePrivilegesOutput {
|
||||||
|
let mut results: BTreeMap<String, _> = BTreeMap::new();
|
||||||
|
|
||||||
#[cfg(test)]
|
for diff in database_privilege_diffs {
|
||||||
mod tests {
|
if let Err(err) = validate_name(diff.get_database_name()) {
|
||||||
use super::*;
|
results.insert(
|
||||||
|
diff.get_database_name().to_string(),
|
||||||
#[test]
|
Err(ModifyDatabasePrivilegesError::DatabaseSanitizationError(
|
||||||
fn test_database_privilege_change_creation() {
|
err,
|
||||||
assert_eq!(
|
)),
|
||||||
DatabasePrivilegeChange::new(true, false, "test"),
|
|
||||||
Some(DatabasePrivilegeChange::YesToNo("test".to_owned()))
|
|
||||||
);
|
);
|
||||||
assert_eq!(
|
continue;
|
||||||
DatabasePrivilegeChange::new(false, true, "test"),
|
}
|
||||||
Some(DatabasePrivilegeChange::NoToYes("test".to_owned()))
|
|
||||||
|
if let Err(err) = validate_ownership_by_unix_user(diff.get_database_name(), unix_user) {
|
||||||
|
results.insert(
|
||||||
|
diff.get_database_name().to_string(),
|
||||||
|
Err(ModifyDatabasePrivilegesError::DatabaseOwnershipError(err)),
|
||||||
);
|
);
|
||||||
assert_eq!(DatabasePrivilegeChange::new(true, true, "test"), None);
|
continue;
|
||||||
assert_eq!(DatabasePrivilegeChange::new(false, false, "test"), None);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
if let Err(err) = validate_name(diff.get_user_name()) {
|
||||||
fn test_diff_privileges() {
|
results.insert(
|
||||||
let row_to_be_modified = DatabasePrivilegeRow {
|
diff.get_database_name().to_string(),
|
||||||
db: "db".to_owned(),
|
Err(ModifyDatabasePrivilegesError::UserSanitizationError(err)),
|
||||||
user: "user".to_owned(),
|
|
||||||
select_priv: true,
|
|
||||||
insert_priv: true,
|
|
||||||
update_priv: true,
|
|
||||||
delete_priv: true,
|
|
||||||
create_priv: true,
|
|
||||||
drop_priv: true,
|
|
||||||
alter_priv: true,
|
|
||||||
index_priv: false,
|
|
||||||
create_tmp_table_priv: true,
|
|
||||||
lock_tables_priv: true,
|
|
||||||
references_priv: false,
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut row_to_be_deleted = row_to_be_modified.clone();
|
|
||||||
"user2".clone_into(&mut row_to_be_deleted.user);
|
|
||||||
|
|
||||||
let from = vec![row_to_be_modified.clone(), row_to_be_deleted.clone()];
|
|
||||||
|
|
||||||
let mut modified_row = row_to_be_modified.clone();
|
|
||||||
modified_row.select_priv = false;
|
|
||||||
modified_row.insert_priv = false;
|
|
||||||
modified_row.index_priv = true;
|
|
||||||
|
|
||||||
let mut new_row = row_to_be_modified.clone();
|
|
||||||
"user3".clone_into(&mut new_row.user);
|
|
||||||
|
|
||||||
let to = vec![modified_row.clone(), new_row.clone()];
|
|
||||||
|
|
||||||
let diffs = diff_privileges(&from, &to);
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
diffs,
|
|
||||||
BTreeSet::from_iter(vec![
|
|
||||||
DatabasePrivilegesDiff::Deleted(row_to_be_deleted),
|
|
||||||
DatabasePrivilegesDiff::Modified(DatabasePrivilegeRowDiff {
|
|
||||||
db: "db".to_owned(),
|
|
||||||
user: "user".to_owned(),
|
|
||||||
diff: BTreeSet::from_iter(vec![
|
|
||||||
DatabasePrivilegeChange::YesToNo("select_priv".to_owned()),
|
|
||||||
DatabasePrivilegeChange::YesToNo("insert_priv".to_owned()),
|
|
||||||
DatabasePrivilegeChange::NoToYes("index_priv".to_owned()),
|
|
||||||
]),
|
|
||||||
}),
|
|
||||||
DatabasePrivilegesDiff::New(new_row),
|
|
||||||
])
|
|
||||||
);
|
);
|
||||||
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
if let Err(err) = validate_ownership_by_unix_user(diff.get_user_name(), unix_user) {
|
||||||
fn ensure_generated_and_parsed_editor_content_is_equal() {
|
results.insert(
|
||||||
let permissions = vec![
|
diff.get_database_name().to_string(),
|
||||||
DatabasePrivilegeRow {
|
Err(ModifyDatabasePrivilegesError::UserOwnershipError(err)),
|
||||||
db: "db".to_owned(),
|
);
|
||||||
user: "user".to_owned(),
|
continue;
|
||||||
select_priv: true,
|
|
||||||
insert_priv: true,
|
|
||||||
update_priv: true,
|
|
||||||
delete_priv: true,
|
|
||||||
create_priv: true,
|
|
||||||
drop_priv: true,
|
|
||||||
alter_priv: true,
|
|
||||||
index_priv: true,
|
|
||||||
create_tmp_table_priv: true,
|
|
||||||
lock_tables_priv: true,
|
|
||||||
references_priv: true,
|
|
||||||
},
|
|
||||||
DatabasePrivilegeRow {
|
|
||||||
db: "db2".to_owned(),
|
|
||||||
user: "user2".to_owned(),
|
|
||||||
select_priv: false,
|
|
||||||
insert_priv: false,
|
|
||||||
update_priv: false,
|
|
||||||
delete_priv: false,
|
|
||||||
create_priv: false,
|
|
||||||
drop_priv: false,
|
|
||||||
alter_priv: false,
|
|
||||||
index_priv: false,
|
|
||||||
create_tmp_table_priv: false,
|
|
||||||
lock_tables_priv: false,
|
|
||||||
references_priv: false,
|
|
||||||
},
|
|
||||||
];
|
|
||||||
|
|
||||||
let content = generate_editor_content_from_privilege_data(&permissions, "user");
|
|
||||||
|
|
||||||
let parsed_permissions = parse_privilege_data_from_editor_content(content).unwrap();
|
|
||||||
|
|
||||||
assert_eq!(permissions, parsed_permissions);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !unsafe_database_exists(diff.get_database_name(), connection)
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
{
|
||||||
|
results.insert(
|
||||||
|
diff.get_database_name().to_string(),
|
||||||
|
Err(ModifyDatabasePrivilegesError::DatabaseDoesNotExist),
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: validate that the diff actually applies to the database
|
||||||
|
|
||||||
|
let result = unsafe_apply_privilege_diff(&diff, connection)
|
||||||
|
.await
|
||||||
|
.map_err(|e| ModifyDatabasePrivilegesError::MySqlError(e.to_string()));
|
||||||
|
|
||||||
|
results.insert(diff.get_database_name().to_string(), result);
|
||||||
|
}
|
||||||
|
|
||||||
|
results
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,62 +1,80 @@
|
||||||
use futures_util::{SinkExt, StreamExt};
|
use futures_util::{SinkExt, StreamExt};
|
||||||
use sqlx::MySqlConnection;
|
use sqlx::MySqlConnection;
|
||||||
use tokio::net::UnixStream;
|
use std::collections::BTreeSet;
|
||||||
use tokio_serde::{formats::Bincode, Framed as SerdeFramed};
|
|
||||||
use tokio_util::codec::{Framed, LengthDelimitedCodec};
|
|
||||||
|
|
||||||
// use crate::server::
|
|
||||||
|
|
||||||
use crate::server::protocol::{Request, Response};
|
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
common::UnixUser, database_operations::{create_databases, drop_databases}, user_operations::{create_database_users, drop_database_users, set_password_for_database_user}
|
common::UnixUser,
|
||||||
|
database_operations::{create_databases, drop_databases, list_databases_for_user},
|
||||||
|
database_privilege_operations::{
|
||||||
|
apply_privilege_diffs, get_all_database_privileges, get_databases_privilege_data,
|
||||||
|
},
|
||||||
|
protocol::{Request, Response, ServerToClientMessageStream},
|
||||||
|
user_operations::{
|
||||||
|
create_database_users, drop_database_users, list_all_database_users_for_unix_user,
|
||||||
|
list_database_users, lock_database_users, set_password_for_database_user,
|
||||||
|
unlock_database_users,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
pub type ClientToServerMessageStream<'a> = SerdeFramed<
|
pub async fn handle_request(
|
||||||
Framed<&'a mut UnixStream, LengthDelimitedCodec>,
|
mut stream: ServerToClientMessageStream,
|
||||||
Request,
|
|
||||||
Response,
|
|
||||||
Bincode<Request, Response>,
|
|
||||||
>;
|
|
||||||
|
|
||||||
pub async fn run_server(
|
|
||||||
socket: &mut UnixStream,
|
|
||||||
unix_user: &UnixUser,
|
unix_user: &UnixUser,
|
||||||
db_connection: &mut MySqlConnection,
|
db_connection: &mut MySqlConnection,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
let length_delimited = Framed::new(socket, LengthDelimitedCodec::new());
|
loop {
|
||||||
let mut stream: ClientToServerMessageStream =
|
|
||||||
tokio_serde::Framed::new(length_delimited, Bincode::default());
|
|
||||||
|
|
||||||
// TODO: better error handling
|
// TODO: better error handling
|
||||||
let request = match stream.next().await {
|
let request = match stream.next().await {
|
||||||
Some(Ok(request)) => request,
|
Some(Ok(request)) => request,
|
||||||
Some(Err(e)) => return Err(e.into()),
|
Some(Err(e)) => return Err(e.into()),
|
||||||
None => return Err(anyhow::anyhow!("No request received")),
|
None => {
|
||||||
|
log::warn!("Client disconnected without sending an exit message");
|
||||||
|
break;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
match request {
|
match request {
|
||||||
Request::CreateDatabases(databases) => {
|
Request::CreateDatabases(databases_names) => {
|
||||||
let result = create_databases(databases, unix_user, db_connection).await;
|
let result = create_databases(databases_names, unix_user, db_connection).await;
|
||||||
stream.send(Response::CreateDatabases(result)).await?;
|
stream.send(Response::CreateDatabases(result)).await?;
|
||||||
stream.flush().await?;
|
stream.flush().await?;
|
||||||
}
|
}
|
||||||
Request::DropDatabases(databases) => {
|
Request::DropDatabases(databases_names) => {
|
||||||
let result = drop_databases(databases, unix_user, db_connection).await;
|
let result = drop_databases(databases_names, unix_user, db_connection).await;
|
||||||
stream.send(Response::DropDatabases(result)).await?;
|
stream.send(Response::DropDatabases(result)).await?;
|
||||||
stream.flush().await?;
|
stream.flush().await?;
|
||||||
}
|
}
|
||||||
Request::ListDatabases => {
|
Request::ListDatabases => {
|
||||||
println!("Listing databases");
|
let result = list_databases_for_user(unix_user, db_connection).await;
|
||||||
// let result = list_databases(unix_user, db_connection).await;
|
stream.send(Response::ListAllDatabases(result)).await?;
|
||||||
// stream.send(Response::ListDatabases(result)).await?;
|
stream.flush().await?;
|
||||||
// stream.flush().await?;
|
|
||||||
}
|
}
|
||||||
Request::ListPrivileges(users) => {
|
Request::ListPrivileges(database_names) => {
|
||||||
println!("Listing privileges for users: {:?}", users);
|
let response = match database_names {
|
||||||
|
Some(database_names) => {
|
||||||
|
let privilege_data =
|
||||||
|
get_databases_privilege_data(database_names, unix_user, db_connection)
|
||||||
|
.await;
|
||||||
|
Response::ListPrivileges(privilege_data)
|
||||||
}
|
}
|
||||||
Request::ModifyPrivileges(()) => {
|
None => {
|
||||||
println!("Modifying privileges");
|
let privilege_data =
|
||||||
|
get_all_database_privileges(unix_user, db_connection).await;
|
||||||
|
Response::ListAllPrivileges(privilege_data)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
stream.send(response).await?;
|
||||||
|
stream.flush().await?;
|
||||||
|
}
|
||||||
|
Request::ModifyPrivileges(database_privilege_diffs) => {
|
||||||
|
let result = apply_privilege_diffs(
|
||||||
|
BTreeSet::from_iter(database_privilege_diffs),
|
||||||
|
unix_user,
|
||||||
|
db_connection,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
stream.send(Response::ModifyPrivileges(result)).await?;
|
||||||
|
stream.flush().await?;
|
||||||
}
|
}
|
||||||
Request::CreateUsers(db_users) => {
|
Request::CreateUsers(db_users) => {
|
||||||
let result = create_database_users(db_users, unix_user, db_connection).await;
|
let result = create_database_users(db_users, unix_user, db_connection).await;
|
||||||
|
@ -70,18 +88,39 @@ pub async fn run_server(
|
||||||
}
|
}
|
||||||
Request::PasswdUser(db_user, password) => {
|
Request::PasswdUser(db_user, password) => {
|
||||||
let result =
|
let result =
|
||||||
set_password_for_database_user(&db_user, &password, unix_user, db_connection).await;
|
set_password_for_database_user(&db_user, &password, unix_user, db_connection)
|
||||||
|
.await;
|
||||||
stream.send(Response::PasswdUser(result)).await?;
|
stream.send(Response::PasswdUser(result)).await?;
|
||||||
stream.flush().await?;
|
stream.flush().await?;
|
||||||
}
|
}
|
||||||
Request::ListUsers(db_users) => {
|
Request::ListUsers(db_users) => {
|
||||||
println!("Listing users: {:?}", db_users);
|
let response = match db_users {
|
||||||
|
Some(db_users) => {
|
||||||
|
let result = list_database_users(db_users, unix_user, db_connection).await;
|
||||||
|
Response::ListUsers(result)
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
let result =
|
||||||
|
list_all_database_users_for_unix_user(unix_user, db_connection).await;
|
||||||
|
Response::ListAllUsers(result)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
stream.send(response).await?;
|
||||||
|
stream.flush().await?;
|
||||||
}
|
}
|
||||||
Request::LockUsers(db_users) => {
|
Request::LockUsers(db_users) => {
|
||||||
println!("Locking users: {:?}", db_users);
|
let result = lock_database_users(db_users, unix_user, db_connection).await;
|
||||||
|
stream.send(Response::LockUsers(result)).await?;
|
||||||
|
stream.flush().await?;
|
||||||
}
|
}
|
||||||
Request::UnlockUsers(db_users) => {
|
Request::UnlockUsers(db_users) => {
|
||||||
println!("Unlocking users: {:?}", db_users);
|
let result = unlock_database_users(db_users, unix_user, db_connection).await;
|
||||||
|
stream.send(Response::UnlockUsers(result)).await?;
|
||||||
|
stream.flush().await?;
|
||||||
|
}
|
||||||
|
Request::Exit => {
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,44 @@
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use tokio::net::UnixStream;
|
||||||
|
use tokio_serde::{formats::Bincode, Framed as SerdeFramed};
|
||||||
|
use tokio_util::codec::{Framed, LengthDelimitedCodec};
|
||||||
|
|
||||||
use super::{database_operations::{CreateDatabasesOutput, DropDatabasesOutput}, user_operations::{
|
use crate::core::database_privileges::DatabasePrivilegesDiff;
|
||||||
CreateUsersOutput, DropUsersOutput, LockUsersOutput, SetPasswordOutput, UnlockUsersOutput,
|
|
||||||
}};
|
use super::{
|
||||||
|
database_operations::{CreateDatabasesOutput, DropDatabasesOutput, ListAllDatabasesOutput},
|
||||||
|
database_privilege_operations::{
|
||||||
|
GetAllDatabasesPrivilegeData, GetDatabasesPrivilegeData, ModifyDatabasePrivilegesOutput,
|
||||||
|
},
|
||||||
|
user_operations::{
|
||||||
|
CreateUsersOutput, DropUsersOutput, ListAllUsersOutput, ListUsersOutput, LockUsersOutput,
|
||||||
|
SetPasswordOutput, UnlockUsersOutput,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
pub type ServerToClientMessageStream = SerdeFramed<
|
||||||
|
Framed<UnixStream, LengthDelimitedCodec>,
|
||||||
|
Request,
|
||||||
|
Response,
|
||||||
|
Bincode<Request, Response>,
|
||||||
|
>;
|
||||||
|
|
||||||
|
pub type ClientToServerMessageStream = SerdeFramed<
|
||||||
|
Framed<UnixStream, LengthDelimitedCodec>,
|
||||||
|
Response,
|
||||||
|
Request,
|
||||||
|
Bincode<Response, Request>,
|
||||||
|
>;
|
||||||
|
|
||||||
|
pub fn create_server_to_client_message_stream(socket: UnixStream) -> ServerToClientMessageStream {
|
||||||
|
let length_delimited = Framed::new(socket, LengthDelimitedCodec::new());
|
||||||
|
tokio_serde::Framed::new(length_delimited, Bincode::default())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn create_client_to_server_message_stream(socket: UnixStream) -> ClientToServerMessageStream {
|
||||||
|
let length_delimited = Framed::new(socket, LengthDelimitedCodec::new());
|
||||||
|
tokio_serde::Framed::new(length_delimited, Bincode::default())
|
||||||
|
}
|
||||||
|
|
||||||
#[non_exhaustive]
|
#[non_exhaustive]
|
||||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||||
|
@ -10,8 +46,8 @@ pub enum Request {
|
||||||
CreateDatabases(Vec<String>),
|
CreateDatabases(Vec<String>),
|
||||||
DropDatabases(Vec<String>),
|
DropDatabases(Vec<String>),
|
||||||
ListDatabases,
|
ListDatabases,
|
||||||
ListPrivileges(Vec<String>),
|
ListPrivileges(Option<Vec<String>>),
|
||||||
ModifyPrivileges(()), // what data should be sent with this command? Who should calculate the diff?
|
ModifyPrivileges(Vec<DatabasePrivilegesDiff>),
|
||||||
|
|
||||||
CreateUsers(Vec<String>),
|
CreateUsers(Vec<String>),
|
||||||
DropUsers(Vec<String>),
|
DropUsers(Vec<String>),
|
||||||
|
@ -19,6 +55,9 @@ pub enum Request {
|
||||||
ListUsers(Option<Vec<String>>),
|
ListUsers(Option<Vec<String>>),
|
||||||
LockUsers(Vec<String>),
|
LockUsers(Vec<String>),
|
||||||
UnlockUsers(Vec<String>),
|
UnlockUsers(Vec<String>),
|
||||||
|
|
||||||
|
// Commit,
|
||||||
|
Exit,
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: include a generic "message" that will display a message to the user?
|
// TODO: include a generic "message" that will display a message to the user?
|
||||||
|
@ -29,12 +68,16 @@ pub enum Response {
|
||||||
// Specific data for specific commands
|
// Specific data for specific commands
|
||||||
CreateDatabases(CreateDatabasesOutput),
|
CreateDatabases(CreateDatabasesOutput),
|
||||||
DropDatabases(DropDatabasesOutput),
|
DropDatabases(DropDatabasesOutput),
|
||||||
// ListDatabases(ListDatabasesOutput),
|
ListAllDatabases(ListAllDatabasesOutput),
|
||||||
// ListPrivileges(ListPrivilegesOutput),
|
ListPrivileges(GetDatabasesPrivilegeData),
|
||||||
|
ListAllPrivileges(GetAllDatabasesPrivilegeData),
|
||||||
|
ModifyPrivileges(ModifyDatabasePrivilegesOutput),
|
||||||
|
|
||||||
CreateUsers(CreateUsersOutput),
|
CreateUsers(CreateUsersOutput),
|
||||||
DropUsers(DropUsersOutput),
|
DropUsers(DropUsersOutput),
|
||||||
PasswdUser(SetPasswordOutput),
|
PasswdUser(SetPasswordOutput),
|
||||||
ListUsers(()), // what data should be sent with this response?
|
ListUsers(ListUsersOutput),
|
||||||
|
ListAllUsers(ListAllUsersOutput),
|
||||||
LockUsers(LockUsersOutput),
|
LockUsers(LockUsersOutput),
|
||||||
UnlockUsers(UnlockUsersOutput),
|
UnlockUsers(UnlockUsersOutput),
|
||||||
|
|
||||||
|
|
|
@ -272,7 +272,7 @@ pub enum UnlockUserError {
|
||||||
MySqlError(String),
|
MySqlError(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn unlock_database_user(
|
pub async fn unlock_database_users(
|
||||||
db_users: Vec<String>,
|
db_users: Vec<String>,
|
||||||
unix_user: &UnixUser,
|
unix_user: &UnixUser,
|
||||||
connection: &mut MySqlConnection,
|
connection: &mut MySqlConnection,
|
||||||
|
@ -362,34 +362,65 @@ JOIN `mysql`.`global_priv` ON
|
||||||
AND `mysql`.`user`.`Host` = `mysql`.`global_priv`.`Host`
|
AND `mysql`.`user`.`Host` = `mysql`.`global_priv`.`Host`
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
pub async fn get_all_database_users_for_unix_user(
|
pub type ListUsersOutput = BTreeMap<String, Result<DatabaseUser, ListUsersError>>;
|
||||||
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||||
|
pub enum ListUsersError {
|
||||||
|
SanitizationError(NameValidationError),
|
||||||
|
OwnershipError(OwnerValidationError),
|
||||||
|
UserDoesNotExist,
|
||||||
|
MySqlError(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn list_database_users(
|
||||||
|
db_users: Vec<String>,
|
||||||
unix_user: &UnixUser,
|
unix_user: &UnixUser,
|
||||||
connection: &mut MySqlConnection,
|
connection: &mut MySqlConnection,
|
||||||
) -> Result<Vec<DatabaseUser>, sqlx::Error> {
|
) -> ListUsersOutput {
|
||||||
|
let mut results = BTreeMap::new();
|
||||||
|
|
||||||
|
for db_user in db_users {
|
||||||
|
if let Err(err) = validate_name(&db_user) {
|
||||||
|
results.insert(db_user, Err(ListUsersError::SanitizationError(err)));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Err(err) = validate_ownership_by_unix_user(&db_user, unix_user) {
|
||||||
|
results.insert(db_user, Err(ListUsersError::OwnershipError(err)));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let result = sqlx::query_as::<_, DatabaseUser>(
|
||||||
|
&(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `mysql`.`user`.`User` = ?"),
|
||||||
|
)
|
||||||
|
.bind(&db_user)
|
||||||
|
.fetch_optional(&mut *connection)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(Some(user)) => results.insert(db_user, Ok(user)),
|
||||||
|
Ok(None) => results.insert(db_user, Err(ListUsersError::UserDoesNotExist)),
|
||||||
|
Err(err) => results.insert(db_user, Err(ListUsersError::MySqlError(err.to_string()))),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
results
|
||||||
|
}
|
||||||
|
|
||||||
|
pub type ListAllUsersOutput = Result<Vec<DatabaseUser>, ListAllUsersError>;
|
||||||
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||||
|
pub enum ListAllUsersError {
|
||||||
|
MySqlError(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn list_all_database_users_for_unix_user(
|
||||||
|
unix_user: &UnixUser,
|
||||||
|
connection: &mut MySqlConnection,
|
||||||
|
) -> ListAllUsersOutput {
|
||||||
sqlx::query_as::<_, DatabaseUser>(
|
sqlx::query_as::<_, DatabaseUser>(
|
||||||
&(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `mysql`.`user`.`User` REGEXP ?"),
|
&(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `mysql`.`user`.`User` REGEXP ?"),
|
||||||
)
|
)
|
||||||
.bind(create_user_group_matching_regex(unix_user))
|
.bind(create_user_group_matching_regex(unix_user))
|
||||||
.fetch_all(connection)
|
.fetch_all(connection)
|
||||||
.await
|
.await
|
||||||
|
.map_err(|err| ListAllUsersError::MySqlError(err.to_string()))
|
||||||
}
|
}
|
||||||
|
|
||||||
// /// This function fetches a database user if it exists.
|
|
||||||
// pub async fn get_database_user_for_user(
|
|
||||||
// username: &str,
|
|
||||||
// connection: &mut MySqlConnection,
|
|
||||||
// ) -> anyhow::Result<Option<DatabaseUser>> {
|
|
||||||
// let user = sqlx::query_as::<_, DatabaseUser>(
|
|
||||||
// &(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `mysql`.`user`.`User` = ?"),
|
|
||||||
// )
|
|
||||||
// .bind(username)
|
|
||||||
// .fetch_optional(connection)
|
|
||||||
// .await?;
|
|
||||||
|
|
||||||
// Ok(user)
|
|
||||||
// }
|
|
||||||
|
|
||||||
// /// NOTE: It is very critical that this function validates the database name
|
|
||||||
// /// properly. MySQL does not seem to allow for prepared statements, binding
|
|
||||||
// /// the database name as a parameter to the query. This means that we have
|
|
||||||
// /// to validate the database name ourselves to prevent SQL injection.
|
|
||||||
|
|
Loading…
Reference in New Issue