This commit is contained in:
Oystein Kristoffer Tveit 2024-08-10 15:15:56 +02:00
parent 2c8c16e5bf
commit 7b2aa0f1e1
Signed by: oysteikt
GPG Key ID: 9F2F7D8250F35146
18 changed files with 1060 additions and 2313 deletions

1
src/bootstrap.rs Normal file
View File

@ -0,0 +1 @@
pub mod authenticated_unix_socket;

View File

View File

View File

@ -1,3 +1,4 @@
pub mod database_command; // pub mod database_command;
pub mod mysql_admutils_compatibility; // pub mod mysql_admutils_compatibility;
pub mod user_command; mod common;
pub mod user_command;

63
src/cli/common.rs Normal file
View File

@ -0,0 +1,63 @@
/// This enum is used to differentiate between database and user operations.
/// Their output are very similar, but there are slight differences in the words used.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum DbOrUser {
Database,
User,
}
impl DbOrUser {
pub fn lowercased(&self) -> String {
match self {
DbOrUser::Database => "database".to_string(),
DbOrUser::User => "user".to_string(),
}
}
pub fn capitalized(&self) -> String {
match self {
DbOrUser::Database => "Database".to_string(),
DbOrUser::User => "User".to_string(),
}
}
}
#[inline]
pub(crate) fn yn(b: bool) -> &'static str {
if b {
"Y"
} else {
"N"
}
}
#[inline]
pub(crate) fn rev_yn(s: &str) -> Option<bool> {
match s.to_lowercase().as_str() {
"y" => Some(true),
"n" => Some(false),
_ => None,
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_yn() {
assert_eq!(yn(true), "Y");
assert_eq!(yn(false), "N");
}
#[test]
fn test_rev_yn() {
assert_eq!(rev_yn("Y"), Some(true));
assert_eq!(rev_yn("y"), Some(true));
assert_eq!(rev_yn("N"), Some(false));
assert_eq!(rev_yn("n"), Some(false));
assert_eq!(rev_yn("X"), None);
}
}

View File

@ -1,390 +1,390 @@
use anyhow::Context; // use anyhow::Context;
use clap::Parser; // use clap::Parser;
use dialoguer::{Confirm, Editor}; // use dialoguer::{Confirm, Editor};
use prettytable::{Cell, Row, Table}; // use prettytable::{Cell, Row, Table};
use sqlx::{Connection, MySqlConnection}; // use sqlx::{Connection, MySqlConnection};
use crate::core::{ // use crate::core::{
common::{close_database_connection, get_current_unix_user, yn, CommandStatus}, // common::{close_database_connection, get_current_unix_user, yn, CommandStatus},
database_operations::*, // database_operations::*,
database_privilege_operations::*, // database_privilege_operations::*,
user_operations::user_exists, // user_operations::user_exists,
}; // };
#[derive(Parser)] // #[derive(Parser)]
// #[command(next_help_heading = Some(DATABASE_COMMAND_HEADER))] // // #[command(next_help_heading = Some(DATABASE_COMMAND_HEADER))]
pub enum DatabaseCommand { // pub enum DatabaseCommand {
/// Create one or more databases // /// Create one or more databases
#[command()] // #[command()]
CreateDb(DatabaseCreateArgs), // CreateDb(DatabaseCreateArgs),
/// Delete one or more databases // /// Delete one or more databases
#[command()] // #[command()]
DropDb(DatabaseDropArgs), // DropDb(DatabaseDropArgs),
/// List all databases you have access to // /// List all databases you have access to
#[command()] // #[command()]
ListDb(DatabaseListArgs), // ListDb(DatabaseListArgs),
/// List user privileges for one or more databases // /// List user privileges for one or more databases
/// // ///
/// If no database names are provided, it will show privileges for all databases you have access to. // /// If no database names are provided, it will show privileges for all databases you have access to.
#[command()] // #[command()]
ShowDbPrivs(DatabaseShowPrivsArgs), // ShowDbPrivs(DatabaseShowPrivsArgs),
/// Change user privileges for one or more databases. See `edit-db-privs --help` for details. // /// Change user privileges for one or more databases. See `edit-db-privs --help` for details.
/// // ///
/// This command has two modes of operation: // /// This command has two modes of operation:
/// // ///
/// 1. Interactive mode: If nothing else is specified, the user will be prompted to edit the privileges using a text editor. // /// 1. Interactive mode: If nothing else is specified, the user will be prompted to edit the privileges using a text editor.
/// // ///
/// You can configure your preferred text editor by setting the `VISUAL` or `EDITOR` environment variables. // /// You can configure your preferred text editor by setting the `VISUAL` or `EDITOR` environment variables.
/// // ///
/// Follow the instructions inside the editor for more information. // /// Follow the instructions inside the editor for more information.
/// // ///
/// 2. Non-interactive mode: If the `-p` flag is specified, the user can write privileges using arguments. // /// 2. Non-interactive mode: If the `-p` flag is specified, the user can write privileges using arguments.
/// // ///
/// The privilege arguments should be formatted as `<db>:<user>:<privileges>` // /// The privilege arguments should be formatted as `<db>:<user>:<privileges>`
/// where the privileges are a string of characters, each representing a single privilege. // /// where the privileges are a string of characters, each representing a single privilege.
/// The character `A` is an exception - it represents all privileges. // /// The character `A` is an exception - it represents all privileges.
/// // ///
/// The character-to-privilege mapping is defined as follows: // /// The character-to-privilege mapping is defined as follows:
/// // ///
/// - `s` - SELECT // /// - `s` - SELECT
/// - `i` - INSERT // /// - `i` - INSERT
/// - `u` - UPDATE // /// - `u` - UPDATE
/// - `d` - DELETE // /// - `d` - DELETE
/// - `c` - CREATE // /// - `c` - CREATE
/// - `D` - DROP // /// - `D` - DROP
/// - `a` - ALTER // /// - `a` - ALTER
/// - `I` - INDEX // /// - `I` - INDEX
/// - `t` - CREATE TEMPORARY TABLES // /// - `t` - CREATE TEMPORARY TABLES
/// - `l` - LOCK TABLES // /// - `l` - LOCK TABLES
/// - `r` - REFERENCES // /// - `r` - REFERENCES
/// - `A` - ALL PRIVILEGES // /// - `A` - ALL PRIVILEGES
/// // ///
/// If you provide a database name, you can omit it from the privilege string, // /// If you provide a database name, you can omit it from the privilege string,
/// e.g. `edit-db-privs my_db -p my_user:siu` is equivalent to `edit-db-privs -p my_db:my_user:siu`. // /// e.g. `edit-db-privs my_db -p my_user:siu` is equivalent to `edit-db-privs -p my_db:my_user:siu`.
/// While it doesn't make much of a difference for a single edit, it can be useful for editing multiple users // /// While it doesn't make much of a difference for a single edit, it can be useful for editing multiple users
/// on the same database at once. // /// on the same database at once.
/// // ///
/// Example usage of non-interactive mode: // /// Example usage of non-interactive mode:
/// // ///
/// Enable privileges `SELECT`, `INSERT`, and `UPDATE` for user `my_user` on database `my_db`: // /// Enable privileges `SELECT`, `INSERT`, and `UPDATE` for user `my_user` on database `my_db`:
/// // ///
/// `mysqladm edit-db-privs -p my_db:my_user:siu` // /// `mysqladm edit-db-privs -p my_db:my_user:siu`
/// // ///
/// Enable all privileges for user `my_other_user` on database `my_other_db`: // /// Enable all privileges for user `my_other_user` on database `my_other_db`:
/// // ///
/// `mysqladm edit-db-privs -p my_other_db:my_other_user:A` // /// `mysqladm edit-db-privs -p my_other_db:my_other_user:A`
/// // ///
/// Set miscellaneous privileges for multiple users on database `my_db`: // /// Set miscellaneous privileges for multiple users on database `my_db`:
/// // ///
/// `mysqladm edit-db-privs my_db -p my_user:siu my_other_user:ct`` // /// `mysqladm edit-db-privs my_db -p my_user:siu my_other_user:ct``
/// // ///
#[command(verbatim_doc_comment)] // #[command(verbatim_doc_comment)]
EditDbPrivs(DatabaseEditPrivsArgs), // EditDbPrivs(DatabaseEditPrivsArgs),
} // }
#[derive(Parser)] // #[derive(Parser)]
pub struct DatabaseCreateArgs { // pub struct DatabaseCreateArgs {
/// The name of the database(s) to create. // /// The name of the database(s) to create.
#[arg(num_args = 1..)] // #[arg(num_args = 1..)]
name: Vec<String>, // name: Vec<String>,
} // }
#[derive(Parser)] // #[derive(Parser)]
pub struct DatabaseDropArgs { // pub struct DatabaseDropArgs {
/// The name of the database(s) to drop. // /// The name of the database(s) to drop.
#[arg(num_args = 1..)] // #[arg(num_args = 1..)]
name: Vec<String>, // name: Vec<String>,
} // }
#[derive(Parser)] // #[derive(Parser)]
pub struct DatabaseListArgs { // pub struct DatabaseListArgs {
/// Whether to output the information in JSON format. // /// Whether to output the information in JSON format.
#[arg(short, long)] // #[arg(short, long)]
json: bool, // json: bool,
} // }
#[derive(Parser)] // #[derive(Parser)]
pub struct DatabaseShowPrivsArgs { // pub struct DatabaseShowPrivsArgs {
/// The name of the database(s) to show. // /// The name of the database(s) to show.
#[arg(num_args = 0..)] // #[arg(num_args = 0..)]
name: Vec<String>, // name: Vec<String>,
/// Whether to output the information in JSON format. // /// Whether to output the information in JSON format.
#[arg(short, long)] // #[arg(short, long)]
json: bool, // json: bool,
} // }
#[derive(Parser)] // #[derive(Parser)]
pub struct DatabaseEditPrivsArgs { // pub struct DatabaseEditPrivsArgs {
/// The name of the database to edit privileges for. // /// The name of the database to edit privileges for.
pub name: Option<String>, // pub name: Option<String>,
#[arg(short, long, value_name = "[DATABASE:]USER:PRIVILEGES", num_args = 0..)] // #[arg(short, long, value_name = "[DATABASE:]USER:PRIVILEGES", num_args = 0..)]
pub privs: Vec<String>, // pub privs: Vec<String>,
/// Whether to output the information in JSON format. // /// Whether to output the information in JSON format.
#[arg(short, long)] // #[arg(short, long)]
pub json: bool, // pub json: bool,
/// Specify the text editor to use for editing privileges // /// Specify the text editor to use for editing privileges
#[arg(short, long)] // #[arg(short, long)]
pub editor: Option<String>, // pub editor: Option<String>,
/// Disable interactive confirmation before saving changes. // /// Disable interactive confirmation before saving changes.
#[arg(short, long)] // #[arg(short, long)]
pub yes: bool, // pub yes: bool,
} // }
pub async fn handle_command( // pub async fn handle_command(
command: DatabaseCommand, // command: DatabaseCommand,
mut connection: MySqlConnection, // mut connection: MySqlConnection,
) -> anyhow::Result<CommandStatus> { // ) -> anyhow::Result<CommandStatus> {
let result = connection // let result = connection
.transaction(|txn| { // .transaction(|txn| {
Box::pin(async move { // Box::pin(async move {
match command { // match command {
DatabaseCommand::CreateDb(args) => create_databases(args, txn).await, // DatabaseCommand::CreateDb(args) => create_databases(args, txn).await,
DatabaseCommand::DropDb(args) => drop_databases(args, txn).await, // DatabaseCommand::DropDb(args) => drop_databases(args, txn).await,
DatabaseCommand::ListDb(args) => list_databases(args, txn).await, // DatabaseCommand::ListDb(args) => list_databases(args, txn).await,
DatabaseCommand::ShowDbPrivs(args) => show_database_privileges(args, txn).await, // DatabaseCommand::ShowDbPrivs(args) => show_database_privileges(args, txn).await,
DatabaseCommand::EditDbPrivs(args) => edit_privileges(args, txn).await, // DatabaseCommand::EditDbPrivs(args) => edit_privileges(args, txn).await,
} // }
}) // })
}) // })
.await; // .await;
close_database_connection(connection).await; // close_database_connection(connection).await;
result // result
} // }
async fn create_databases( // async fn create_databases(
args: DatabaseCreateArgs, // args: DatabaseCreateArgs,
connection: &mut MySqlConnection, // connection: &mut MySqlConnection,
) -> anyhow::Result<CommandStatus> { // ) -> anyhow::Result<CommandStatus> {
if args.name.is_empty() { // if args.name.is_empty() {
anyhow::bail!("No database names provided"); // anyhow::bail!("No database names provided");
} // }
let mut result = CommandStatus::SuccessfullyModified; // let mut result = CommandStatus::SuccessfullyModified;
for name in args.name { // for name in args.name {
// TODO: This can be optimized by fetching all the database privileges in one query. // // TODO: This can be optimized by fetching all the database privileges in one query.
if let Err(e) = create_database(&name, connection).await { // if let Err(e) = create_database(&name, connection).await {
eprintln!("Failed to create database '{}': {}", name, e); // eprintln!("Failed to create database '{}': {}", name, e);
eprintln!("Skipping..."); // eprintln!("Skipping...");
result = CommandStatus::PartiallySuccessfullyModified; // result = CommandStatus::PartiallySuccessfullyModified;
} else { // } else {
println!("Database '{}' created.", name); // println!("Database '{}' created.", name);
} // }
} // }
Ok(result) // Ok(result)
} // }
async fn drop_databases( // async fn drop_databases(
args: DatabaseDropArgs, // args: DatabaseDropArgs,
connection: &mut MySqlConnection, // connection: &mut MySqlConnection,
) -> anyhow::Result<CommandStatus> { // ) -> anyhow::Result<CommandStatus> {
if args.name.is_empty() { // if args.name.is_empty() {
anyhow::bail!("No database names provided"); // anyhow::bail!("No database names provided");
} // }
let mut result = CommandStatus::SuccessfullyModified; // let mut result = CommandStatus::SuccessfullyModified;
for name in args.name { // for name in args.name {
// TODO: This can be optimized by fetching all the database privileges in one query. // // TODO: This can be optimized by fetching all the database privileges in one query.
if let Err(e) = drop_database(&name, connection).await { // if let Err(e) = drop_database(&name, connection).await {
eprintln!("Failed to drop database '{}': {}", name, e); // eprintln!("Failed to drop database '{}': {}", name, e);
eprintln!("Skipping..."); // eprintln!("Skipping...");
result = CommandStatus::PartiallySuccessfullyModified; // result = CommandStatus::PartiallySuccessfullyModified;
} else { // } else {
println!("Database '{}' dropped.", name); // println!("Database '{}' dropped.", name);
} // }
} // }
Ok(result) // Ok(result)
} // }
async fn list_databases( // async fn list_databases(
args: DatabaseListArgs, // args: DatabaseListArgs,
connection: &mut MySqlConnection, // connection: &mut MySqlConnection,
) -> anyhow::Result<CommandStatus> { // ) -> anyhow::Result<CommandStatus> {
let databases = get_database_list(connection).await?; // let databases = get_database_list(connection).await?;
if args.json { // if args.json {
println!("{}", serde_json::to_string_pretty(&databases)?); // println!("{}", serde_json::to_string_pretty(&databases)?);
return Ok(CommandStatus::NoModificationsIntended); // return Ok(CommandStatus::NoModificationsIntended);
} // }
if databases.is_empty() { // if databases.is_empty() {
println!("No databases to show."); // println!("No databases to show.");
} else { // } else {
for db in databases { // for db in databases {
println!("{}", db); // println!("{}", db);
} // }
} // }
Ok(CommandStatus::NoModificationsIntended) // Ok(CommandStatus::NoModificationsIntended)
} // }
async fn show_database_privileges( // async fn show_database_privileges(
args: DatabaseShowPrivsArgs, // args: DatabaseShowPrivsArgs,
connection: &mut MySqlConnection, // connection: &mut MySqlConnection,
) -> anyhow::Result<CommandStatus> { // ) -> anyhow::Result<CommandStatus> {
let database_users_to_show = if args.name.is_empty() { // let database_users_to_show = if args.name.is_empty() {
get_all_database_privileges(connection).await? // get_all_database_privileges(connection).await?
} else { // } else {
// TODO: This can be optimized by fetching all the database privileges in one query. // // TODO: This can be optimized by fetching all the database privileges in one query.
let mut result = Vec::with_capacity(args.name.len()); // let mut result = Vec::with_capacity(args.name.len());
for name in args.name { // for name in args.name {
match get_database_privileges(&name, connection).await { // match get_database_privileges(&name, connection).await {
Ok(db) => result.extend(db), // Ok(db) => result.extend(db),
Err(e) => { // Err(e) => {
eprintln!("Failed to show database '{}': {}", name, e); // eprintln!("Failed to show database '{}': {}", name, e);
eprintln!("Skipping..."); // eprintln!("Skipping...");
} // }
} // }
} // }
result // result
}; // };
if args.json { // if args.json {
println!("{}", serde_json::to_string_pretty(&database_users_to_show)?); // println!("{}", serde_json::to_string_pretty(&database_users_to_show)?);
return Ok(CommandStatus::NoModificationsIntended); // return Ok(CommandStatus::NoModificationsIntended);
} // }
if database_users_to_show.is_empty() { // if database_users_to_show.is_empty() {
println!("No database users to show."); // println!("No database users to show.");
} else { // } else {
let mut table = Table::new(); // let mut table = Table::new();
table.add_row(Row::new( // table.add_row(Row::new(
DATABASE_PRIVILEGE_FIELDS // DATABASE_PRIVILEGE_FIELDS
.into_iter() // .into_iter()
.map(db_priv_field_human_readable_name) // .map(db_priv_field_human_readable_name)
.map(|name| Cell::new(&name)) // .map(|name| Cell::new(&name))
.collect(), // .collect(),
)); // ));
for row in database_users_to_show { // for row in database_users_to_show {
table.add_row(row![ // table.add_row(row![
row.db, // row.db,
row.user, // row.user,
c->yn(row.select_priv), // c->yn(row.select_priv),
c->yn(row.insert_priv), // c->yn(row.insert_priv),
c->yn(row.update_priv), // c->yn(row.update_priv),
c->yn(row.delete_priv), // c->yn(row.delete_priv),
c->yn(row.create_priv), // c->yn(row.create_priv),
c->yn(row.drop_priv), // c->yn(row.drop_priv),
c->yn(row.alter_priv), // c->yn(row.alter_priv),
c->yn(row.index_priv), // c->yn(row.index_priv),
c->yn(row.create_tmp_table_priv), // c->yn(row.create_tmp_table_priv),
c->yn(row.lock_tables_priv), // c->yn(row.lock_tables_priv),
c->yn(row.references_priv), // c->yn(row.references_priv),
]); // ]);
} // }
table.printstd(); // table.printstd();
} // }
Ok(CommandStatus::NoModificationsIntended) // Ok(CommandStatus::NoModificationsIntended)
} // }
pub async fn edit_privileges( // pub async fn edit_privileges(
args: DatabaseEditPrivsArgs, // args: DatabaseEditPrivsArgs,
connection: &mut MySqlConnection, // connection: &mut MySqlConnection,
) -> anyhow::Result<CommandStatus> { // ) -> anyhow::Result<CommandStatus> {
let privilege_data = if let Some(name) = &args.name { // let privilege_data = if let Some(name) = &args.name {
get_database_privileges(name, connection).await? // get_database_privileges(name, connection).await?
} else { // } else {
get_all_database_privileges(connection).await? // get_all_database_privileges(connection).await?
}; // };
// TODO: The data from args should not be absolute. // // TODO: The data from args should not be absolute.
// In the current implementation, the user would need to // // In the current implementation, the user would need to
// provide all privileges for all users on all databases. // // provide all privileges for all users on all databases.
// The intended effect is to modify the privileges which have // // The intended effect is to modify the privileges which have
// matching users and databases, as well as add any // // matching users and databases, as well as add any
// new db-user pairs. This makes it impossible to remove // // new db-user pairs. This makes it impossible to remove
// privileges, but that is an issue for another day. // // privileges, but that is an issue for another day.
let privileges_to_change = if !args.privs.is_empty() { // let privileges_to_change = if !args.privs.is_empty() {
parse_privilege_tables_from_args(&args)? // parse_privilege_tables_from_args(&args)?
} else { // } else {
edit_privileges_with_editor(&privilege_data)? // edit_privileges_with_editor(&privilege_data)?
}; // };
for row in privileges_to_change.iter() { // for row in privileges_to_change.iter() {
if !user_exists(&row.user, connection).await? { // if !user_exists(&row.user, connection).await? {
// TODO: allow user to return and correct their mistake // // TODO: allow user to return and correct their mistake
anyhow::bail!("User {} does not exist", row.user); // anyhow::bail!("User {} does not exist", row.user);
} // }
} // }
let diffs = diff_privileges(&privilege_data, &privileges_to_change); // let diffs = diff_privileges(&privilege_data, &privileges_to_change);
if diffs.is_empty() { // if diffs.is_empty() {
println!("No changes to make."); // println!("No changes to make.");
return Ok(CommandStatus::NoModificationsNeeded); // return Ok(CommandStatus::NoModificationsNeeded);
} // }
println!("The following changes will be made:\n"); // println!("The following changes will be made:\n");
println!("{}", display_privilege_diffs(&diffs)); // println!("{}", display_privilege_diffs(&diffs));
if !args.yes // if !args.yes
&& !Confirm::new() // && !Confirm::new()
.with_prompt("Do you want to apply these changes?") // .with_prompt("Do you want to apply these changes?")
.default(false) // .default(false)
.show_default(true) // .show_default(true)
.interact()? // .interact()?
{ // {
return Ok(CommandStatus::Cancelled); // return Ok(CommandStatus::Cancelled);
} // }
apply_privilege_diffs(diffs, connection).await?; // apply_privilege_diffs(diffs, connection).await?;
Ok(CommandStatus::SuccessfullyModified) // Ok(CommandStatus::SuccessfullyModified)
} // }
pub fn parse_privilege_tables_from_args( // pub fn parse_privilege_tables_from_args(
args: &DatabaseEditPrivsArgs, // args: &DatabaseEditPrivsArgs,
) -> anyhow::Result<Vec<DatabasePrivilegeRow>> { // ) -> anyhow::Result<Vec<DatabasePrivilegeRow>> {
debug_assert!(!args.privs.is_empty()); // debug_assert!(!args.privs.is_empty());
let result = if let Some(name) = &args.name { // let result = if let Some(name) = &args.name {
args.privs // args.privs
.iter() // .iter()
.map(|p| { // .map(|p| {
parse_privilege_table_cli_arg(&format!("{}:{}", name, &p)) // parse_privilege_table_cli_arg(&format!("{}:{}", name, &p))
.context(format!("Failed parsing database privileges: `{}`", &p)) // .context(format!("Failed parsing database privileges: `{}`", &p))
}) // })
.collect::<anyhow::Result<Vec<DatabasePrivilegeRow>>>()? // .collect::<anyhow::Result<Vec<DatabasePrivilegeRow>>>()?
} else { // } else {
args.privs // args.privs
.iter() // .iter()
.map(|p| { // .map(|p| {
parse_privilege_table_cli_arg(p) // parse_privilege_table_cli_arg(p)
.context(format!("Failed parsing database privileges: `{}`", &p)) // .context(format!("Failed parsing database privileges: `{}`", &p))
}) // })
.collect::<anyhow::Result<Vec<DatabasePrivilegeRow>>>()? // .collect::<anyhow::Result<Vec<DatabasePrivilegeRow>>>()?
}; // };
Ok(result) // Ok(result)
} // }
pub fn edit_privileges_with_editor( // pub fn edit_privileges_with_editor(
privilege_data: &[DatabasePrivilegeRow], // privilege_data: &[DatabasePrivilegeRow],
) -> anyhow::Result<Vec<DatabasePrivilegeRow>> { // ) -> anyhow::Result<Vec<DatabasePrivilegeRow>> {
let unix_user = get_current_unix_user()?; // let unix_user = get_current_unix_user()?;
let editor_content = // let editor_content =
generate_editor_content_from_privilege_data(privilege_data, &unix_user.name); // generate_editor_content_from_privilege_data(privilege_data, &unix_user.name);
// TODO: handle errors better here // // TODO: handle errors better here
let result = Editor::new() // let result = Editor::new()
.extension("tsv") // .extension("tsv")
.edit(&editor_content)? // .edit(&editor_content)?
.unwrap(); // .unwrap();
parse_privilege_data_from_editor_content(result) // parse_privilege_data_from_editor_content(result)
.context("Could not parse privilege data from editor") // .context("Could not parse privilege data from editor")
} // }

View File

@ -4,15 +4,26 @@ use std::vec;
use anyhow::Context; use anyhow::Context;
use clap::Parser; use clap::Parser;
use dialoguer::{Confirm, Password}; use dialoguer::{Confirm, Password};
use prettytable::Table;
use serde_json::json;
use sqlx::{Connection, MySqlConnection}; use sqlx::{Connection, MySqlConnection};
use tokio::net::UnixStream;
use tokio_util::codec::{Framed, LengthDelimitedCodec};
use tokio_serde::{formats::Bincode, Framed as SerdeFramed};
use futures_util::{SinkExt, StreamExt};
use crate::core::{ use crate::server::{Request, Response};
common::{close_database_connection, get_current_unix_user, CommandStatus},
database_operations::*, // use crate::core::{
user_operations::*, // common::{close_database_connection, get_current_unix_user, CommandStatus},
}; // database_operations::*,
// user_operations::*,
// };
pub type ServerToClientMessageStream<'a> = SerdeFramed<
Framed<&'a mut UnixStream, LengthDelimitedCodec>,
Response,
Request,
Bincode<Response, Request>,
>;
#[derive(Parser)] #[derive(Parser)]
pub struct UserArgs { pub struct UserArgs {
@ -26,28 +37,27 @@ pub enum UserCommand {
/// Create one or more users /// Create one or more users
#[command()] #[command()]
CreateUser(UserCreateArgs), CreateUser(UserCreateArgs),
// /// Delete one or more users
// #[command()]
// DropUser(UserDeleteArgs),
/// Delete one or more users // /// Change the MySQL password for a user
#[command()] // #[command()]
DropUser(UserDeleteArgs), // PasswdUser(UserPasswdArgs),
/// Change the MySQL password for a user // /// Give information about one or more users
#[command()] // ///
PasswdUser(UserPasswdArgs), // /// If no username is provided, all users you have access will be shown.
// #[command()]
// ShowUser(UserShowArgs),
/// Give information about one or more users // /// Lock account for one or more users
/// // #[command()]
/// If no username is provided, all users you have access will be shown. // LockUser(UserLockArgs),
#[command()]
ShowUser(UserShowArgs),
/// Lock account for one or more users // /// Unlock account for one or more users
#[command()] // #[command()]
LockUser(UserLockArgs), // UnlockUser(UserUnlockArgs),
/// Unlock account for one or more users
#[command()]
UnlockUser(UserUnlockArgs),
} }
#[derive(Parser)] #[derive(Parser)]
@ -95,50 +105,77 @@ pub struct UserUnlockArgs {
username: Vec<String>, username: Vec<String>,
} }
pub async fn handle_command( pub async fn handle_command<'a>(
command: UserCommand, command: UserCommand,
mut connection: MySqlConnection, server_connection: &mut ServerToClientMessageStream<'a>,
) -> anyhow::Result<CommandStatus> { // mut connection: MySqlConnection,
let result = connection ) -> anyhow::Result<()> {
.transaction(|txn| { match command {
Box::pin(async move { UserCommand::CreateUser(args) => create_users(args, server_connection).await,
match command { _ => todo!(),
UserCommand::CreateUser(args) => create_users(args, txn).await, // let result = connection
UserCommand::DropUser(args) => drop_users(args, txn).await, // .transaction(|txn| {
UserCommand::PasswdUser(args) => change_password_for_user(args, txn).await, // Box::pin(async move {
UserCommand::ShowUser(args) => show_users(args, txn).await, // match command {
UserCommand::LockUser(args) => lock_users(args, txn).await, // UserCommand::CreateUser(args) => create_users(args, txn).await,
UserCommand::UnlockUser(args) => unlock_users(args, txn).await, // // UserCommand::DropUser(args) => drop_users(args, txn).await,
} // // UserCommand::PasswdUser(args) => change_password_for_user(args, txn).await,
}) // // UserCommand::ShowUser(args) => show_users(args, txn).await,
}) // // UserCommand::LockUser(args) => lock_users(args, txn).await,
.await; // // UserCommand::UnlockUser(args) => unlock_users(args, txn).await,
// }
// })
// })
// .await;
close_database_connection(connection).await; // close_database_connection(connection).await;
}
result
} }
async fn create_users( async fn create_users<'a>(
args: UserCreateArgs, args: UserCreateArgs,
connection: &mut MySqlConnection, server_connection: &mut ServerToClientMessageStream<'a>,
) -> anyhow::Result<CommandStatus> { ) -> anyhow::Result<()> {
if args.username.is_empty() { if args.username.is_empty() {
anyhow::bail!("No usernames provided"); anyhow::bail!("No usernames provided");
} }
let mut result = CommandStatus::SuccessfullyModified; let message = Request::CreateUsers(args.username.clone());
// TODO: better error handling
if let Err(err) = server_connection.send(message).await {
server_connection.close().await.ok();
anyhow::bail!(err);
}
match server_connection.next().await {
Some(Ok(Response::CreateUsers(result))) => {
for (username, result) in result {
match result {
Ok(_) => println!("User '{}' created.", username),
// TODO: this should display custom error messages
// based on the error reported.
Err(err) => {
eprintln!("{:?}", err);
eprintln!("Skipping...\n");
}
}
}
}
Some(Ok(Response::Error(e))) => {
eprintln!("{}", e);
}
Some(Ok(_)) => {
eprintln!("Unexpected response from server");
}
Some(Err(e)) => {
eprintln!("{}", e);
}
None => {
eprintln!("No response from server");
}
}
for username in args.username { for username in args.username {
if let Err(e) = create_database_user(&username, connection).await {
eprintln!("{}", e);
eprintln!("Skipping...\n");
result = CommandStatus::PartiallySuccessfullyModified;
continue;
} else {
println!("User '{}' created.", username);
}
if !args.no_password if !args.no_password
&& Confirm::new() && Confirm::new()
.with_prompt(format!( .with_prompt(format!(
@ -147,43 +184,106 @@ async fn create_users(
)) ))
.interact()? .interact()?
{ {
change_password_for_user( let password = read_password_from_stdin_with_double_check(&username)?;
UserPasswdArgs { let message = Request::PasswdUser(username.clone(), password);
username,
password_file: None, if let Err(err) = server_connection.send(message).await {
server_connection.close().await.ok();
anyhow::bail!(err);
}
match server_connection.next().await {
Some(Ok(Response::PasswdUser(result))) => match result {
Ok(_) => println!("Password set for user '{}'", username),
Err(e) => {
eprintln!("{:?}", e);
eprintln!("Skipping...\n");
}
}, },
connection, Some(Ok(Response::Error(e))) => {
) eprintln!("{}", e);
.await?; }
} Some(Ok(_)) => {
println!(); eprintln!("Unexpected response from server");
} }
Ok(result) Some(Err(e)) => {
} eprintln!("{}", e);
}
async fn drop_users( None => {
args: UserDeleteArgs, eprintln!("No response from server");
connection: &mut MySqlConnection, }
) -> anyhow::Result<CommandStatus> { }
if args.username.is_empty() {
anyhow::bail!("No usernames provided");
}
let mut result = CommandStatus::SuccessfullyModified;
for username in args.username {
if let Err(e) = delete_database_user(&username, connection).await {
eprintln!("{}", e);
eprintln!("Skipping...");
result = CommandStatus::PartiallySuccessfullyModified;
} else {
println!("User '{}' dropped.", username);
} }
} }
Ok(result) Ok(())
} }
// async fn create_users(
// args: UserCreateArgs,
// connection: &mut MySqlConnection,
// ) -> anyhow::Result<CommandStatus> {
// if args.username.is_empty() {
// anyhow::bail!("No usernames provided");
// }
// let mut result = CommandStatus::SuccessfullyModified;
// for username in args.username {
// if let Err(e) = create_database_user(&username, connection).await {
// eprintln!("{}", e);
// eprintln!("Skipping...\n");
// result = CommandStatus::PartiallySuccessfullyModified;
// continue;
// } else {
// println!("User '{}' created.", username);
// }
// if !args.no_password
// && Confirm::new()
// .with_prompt(format!(
// "Do you want to set a password for user '{}'?",
// username
// ))
// .interact()?
// {
// change_password_for_user(
// UserPasswdArgs {
// username,
// password_file: None,
// },
// connection,
// )
// .await?;
// }
// println!();
// }
// Ok(result)
// }
// async fn drop_users(
// args: UserDeleteArgs,
// connection: &mut MySqlConnection,
// ) -> anyhow::Result<CommandStatus> {
// if args.username.is_empty() {
// anyhow::bail!("No usernames provided");
// }
// let mut result = CommandStatus::SuccessfullyModified;
// for username in args.username {
// if let Err(e) = delete_database_user(&username, connection).await {
// eprintln!("{}", e);
// eprintln!("Skipping...");
// result = CommandStatus::PartiallySuccessfullyModified;
// } else {
// println!("User '{}' dropped.", username);
// }
// }
// Ok(result)
// }
pub fn read_password_from_stdin_with_double_check(username: &str) -> anyhow::Result<String> { pub fn read_password_from_stdin_with_double_check(username: &str) -> anyhow::Result<String> {
Password::new() Password::new()
.with_prompt(format!("New MySQL password for user '{}'", username)) .with_prompt(format!("New MySQL password for user '{}'", username))
@ -195,147 +295,147 @@ pub fn read_password_from_stdin_with_double_check(username: &str) -> anyhow::Res
.map_err(Into::into) .map_err(Into::into)
} }
async fn change_password_for_user( // async fn change_password_for_user(
args: UserPasswdArgs, // args: UserPasswdArgs,
connection: &mut MySqlConnection, // connection: &mut MySqlConnection,
) -> anyhow::Result<CommandStatus> { // ) -> anyhow::Result<CommandStatus> {
// NOTE: although this also is checked in `set_password_for_database_user`, we check it here // // NOTE: although this also is checked in `set_password_for_database_user`, we check it here
// to provide a more natural order of error messages. // // to provide a more natural order of error messages.
let unix_user = get_current_unix_user()?; // let unix_user = get_current_unix_user()?;
validate_user_name(&args.username, &unix_user)?; // validate_user_name(&args.username, &unix_user)?;
let password = if let Some(password_file) = args.password_file { // let password = if let Some(password_file) = args.password_file {
std::fs::read_to_string(password_file) // std::fs::read_to_string(password_file)
.context("Failed to read password file")? // .context("Failed to read password file")?
.trim() // .trim()
.to_string() // .to_string()
} else { // } else {
read_password_from_stdin_with_double_check(&args.username)? // read_password_from_stdin_with_double_check(&args.username)?
}; // };
set_password_for_database_user(&args.username, &password, connection).await?; // set_password_for_database_user(&args.username, &password, connection).await?;
Ok(CommandStatus::SuccessfullyModified) // Ok(CommandStatus::SuccessfullyModified)
} // }
async fn show_users( // async fn show_users(
args: UserShowArgs, // args: UserShowArgs,
connection: &mut MySqlConnection, // connection: &mut MySqlConnection,
) -> anyhow::Result<CommandStatus> { // ) -> anyhow::Result<CommandStatus> {
let unix_user = get_current_unix_user()?; // let unix_user = get_current_unix_user()?;
let users = if args.username.is_empty() { // let users = if args.username.is_empty() {
get_all_database_users_for_unix_user(&unix_user, connection).await? // get_all_database_users_for_unix_user(&unix_user, connection).await?
} else { // } else {
let mut result = vec![]; // let mut result = vec![];
for username in args.username { // for username in args.username {
if let Err(e) = validate_user_name(&username, &unix_user) { // if let Err(e) = validate_user_name(&username, &unix_user) {
eprintln!("{}", e); // eprintln!("{}", e);
eprintln!("Skipping..."); // eprintln!("Skipping...");
continue; // continue;
} // }
let user = get_database_user_for_user(&username, connection).await?; // let user = get_database_user_for_user(&username, connection).await?;
if let Some(user) = user { // if let Some(user) = user {
result.push(user); // result.push(user);
} else { // } else {
eprintln!("User not found: {}", username); // eprintln!("User not found: {}", username);
} // }
} // }
result // result
}; // };
let mut user_databases: BTreeMap<String, Vec<String>> = BTreeMap::new(); // let mut user_databases: BTreeMap<String, Vec<String>> = BTreeMap::new();
for user in users.iter() { // for user in users.iter() {
user_databases.insert( // user_databases.insert(
user.user.clone(), // user.user.clone(),
get_databases_where_user_has_privileges(&user.user, connection).await?, // get_databases_where_user_has_privileges(&user.user, connection).await?,
); // );
} // }
if args.json { // if args.json {
let users_json = users // let users_json = users
.into_iter() // .into_iter()
.map(|user| { // .map(|user| {
json!({ // json!({
"user": user.user, // "user": user.user,
"has_password": user.has_password, // "has_password": user.has_password,
"is_locked": user.is_locked, // "is_locked": user.is_locked,
"databases": user_databases.get(&user.user).unwrap_or(&vec![]), // "databases": user_databases.get(&user.user).unwrap_or(&vec![]),
}) // })
}) // })
.collect::<serde_json::Value>(); // .collect::<serde_json::Value>();
println!( // println!(
"{}", // "{}",
serde_json::to_string_pretty(&users_json) // serde_json::to_string_pretty(&users_json)
.context("Failed to serialize users to JSON")? // .context("Failed to serialize users to JSON")?
); // );
} else if users.is_empty() { // } else if users.is_empty() {
println!("No users found."); // println!("No users found.");
} else { // } else {
let mut table = Table::new(); // let mut table = Table::new();
table.add_row(row![ // table.add_row(row![
"User", // "User",
"Password is set", // "Password is set",
"Locked", // "Locked",
"Databases where user has privileges" // "Databases where user has privileges"
]); // ]);
for user in users { // for user in users {
table.add_row(row![ // table.add_row(row![
user.user, // user.user,
user.has_password, // user.has_password,
user.is_locked, // user.is_locked,
user_databases.get(&user.user).unwrap_or(&vec![]).join("\n") // user_databases.get(&user.user).unwrap_or(&vec![]).join("\n")
]); // ]);
} // }
table.printstd(); // table.printstd();
} // }
Ok(CommandStatus::NoModificationsIntended) // Ok(CommandStatus::NoModificationsIntended)
} // }
async fn lock_users( // async fn lock_users(
args: UserLockArgs, // args: UserLockArgs,
connection: &mut MySqlConnection, // connection: &mut MySqlConnection,
) -> anyhow::Result<CommandStatus> { // ) -> anyhow::Result<CommandStatus> {
if args.username.is_empty() { // if args.username.is_empty() {
anyhow::bail!("No usernames provided"); // anyhow::bail!("No usernames provided");
} // }
let mut result = CommandStatus::SuccessfullyModified; // let mut result = CommandStatus::SuccessfullyModified;
for username in args.username { // for username in args.username {
if let Err(e) = lock_database_user(&username, connection).await { // if let Err(e) = lock_database_user(&username, connection).await {
eprintln!("{}", e); // eprintln!("{}", e);
eprintln!("Skipping..."); // eprintln!("Skipping...");
result = CommandStatus::PartiallySuccessfullyModified; // result = CommandStatus::PartiallySuccessfullyModified;
} else { // } else {
println!("User '{}' locked.", username); // println!("User '{}' locked.", username);
} // }
} // }
Ok(result) // Ok(result)
} // }
async fn unlock_users( // async fn unlock_users(
args: UserUnlockArgs, // args: UserUnlockArgs,
connection: &mut MySqlConnection, // connection: &mut MySqlConnection,
) -> anyhow::Result<CommandStatus> { // ) -> anyhow::Result<CommandStatus> {
if args.username.is_empty() { // if args.username.is_empty() {
anyhow::bail!("No usernames provided"); // anyhow::bail!("No usernames provided");
} // }
let mut result = CommandStatus::SuccessfullyModified; // let mut result = CommandStatus::SuccessfullyModified;
for username in args.username { // for username in args.username {
if let Err(e) = unlock_database_user(&username, connection).await { // if let Err(e) = unlock_database_user(&username, connection).await {
eprintln!("{}", e); // eprintln!("{}", e);
eprintln!("Skipping..."); // eprintln!("Skipping...");
result = CommandStatus::PartiallySuccessfullyModified; // result = CommandStatus::PartiallySuccessfullyModified;
} else { // } else {
println!("User '{}' unlocked.", username); // println!("User '{}' unlocked.", username);
} // }
} // }
Ok(result) // Ok(result)
} // }

View File

@ -1,5 +1,3 @@
pub mod common; pub mod common;
pub mod config; pub mod config;
pub mod database_operations; pub mod database_privileges;
pub mod database_privilege_operations;
pub mod user_operations;

View File

@ -1,92 +1,3 @@
use anyhow::Context;
use indoc::indoc;
use itertools::Itertools;
use nix::unistd::{getuid, Group, User};
use sqlx::{Connection, MySqlConnection};
#[cfg(not(target_os = "macos"))]
use std::ffi::CString;
/// Report the result status of a command.
/// This is used to display a status message to the user.
pub enum CommandStatus {
/// The command was successful,
/// and made modification to the database.
SuccessfullyModified,
/// The command was mostly successful,
/// and modifications have been made to the database.
/// However, some of the requested modifications failed.
PartiallySuccessfullyModified,
/// The command was successful,
/// but no modifications were needed.
NoModificationsNeeded,
/// The command was successful,
/// and made no modification to the database.
NoModificationsIntended,
/// The command was cancelled, either through a dialog or a signal.
/// No modifications have been made to the database.
Cancelled,
}
pub fn get_current_unix_user() -> anyhow::Result<User> {
User::from_uid(getuid())
.context("Failed to look up your UNIX username")
.and_then(|u| u.ok_or(anyhow::anyhow!("Failed to look up your UNIX username")))
}
#[cfg(target_os = "macos")]
pub fn get_unix_groups(_user: &User) -> anyhow::Result<Vec<Group>> {
// Return an empty list on macOS since there is no `getgrouplist` function
Ok(vec![])
}
#[cfg(not(target_os = "macos"))]
pub fn get_unix_groups(user: &User) -> anyhow::Result<Vec<Group>> {
let user_cstr =
CString::new(user.name.as_bytes()).context("Failed to convert username to CStr")?;
let groups = nix::unistd::getgrouplist(&user_cstr, user.gid)?
.iter()
.filter_map(|gid| match Group::from_gid(*gid) {
Ok(Some(group)) => Some(group),
Ok(None) => None,
Err(e) => {
log::warn!(
"Failed to look up group with GID {}: {}\nIgnoring...",
gid,
e
);
None
}
})
.collect::<Vec<Group>>();
Ok(groups)
}
/// This function creates a regex that matches items (users, databases)
/// that belong to the user or any of the user's groups.
pub fn create_user_group_matching_regex(user: &User) -> String {
let groups = get_unix_groups(user).unwrap_or_default();
if groups.is_empty() {
format!("{}(_.+)?", user.name)
} else {
format!(
"({}|{})(_.+)?",
user.name,
groups
.iter()
.map(|g| g.name.as_str())
.collect::<Vec<_>>()
.join("|")
)
}
}
/// This enum is used to differentiate between database and user operations. /// This enum is used to differentiate between database and user operations.
/// Their output are very similar, but there are slight differences in the words used. /// Their output are very similar, but there are slight differences in the words used.
#[derive(Debug, PartialEq, Eq, Clone, Copy)] #[derive(Debug, PartialEq, Eq, Clone, Copy)]
@ -111,161 +22,6 @@ impl DbOrUser {
} }
} }
#[derive(Debug, PartialEq, Eq)]
pub enum NameValidationResult {
Valid,
EmptyString,
InvalidCharacters,
TooLong,
}
pub fn validate_name(name: &str) -> NameValidationResult {
if name.is_empty() {
NameValidationResult::EmptyString
} else if name.len() > 64 {
NameValidationResult::TooLong
} else if !name
.chars()
.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
{
NameValidationResult::InvalidCharacters
} else {
NameValidationResult::Valid
}
}
pub fn validate_name_or_error(name: &str, db_or_user: DbOrUser) -> anyhow::Result<()> {
match validate_name(name) {
NameValidationResult::Valid => Ok(()),
NameValidationResult::EmptyString => {
anyhow::bail!("{} name cannot be empty.", db_or_user.capitalized())
}
NameValidationResult::TooLong => anyhow::bail!(
"{} is too long. Maximum length is 64 characters.",
db_or_user.capitalized()
),
NameValidationResult::InvalidCharacters => anyhow::bail!(
indoc! {r#"
Invalid characters in {} name: '{}'
Only A-Z, a-z, 0-9, _ (underscore) and - (dash) are permitted.
"#},
db_or_user.lowercased(),
name
),
}
}
#[derive(Debug, PartialEq, Eq)]
pub enum OwnerValidationResult {
// The name is valid and matches one of the given prefixes
Match,
// The name is valid, but none of the given prefixes matched the name
NoMatch,
// The name is empty, which is invalid
StringEmpty,
// The name is in the format "_<postfix>", which is invalid
MissingPrefix,
// The name is in the format "<prefix>_", which is invalid
MissingPostfix,
}
/// Core logic for validating the ownership of a database name.
/// This function checks if the given name matches any of the given prefixes.
/// These prefixes will in most cases be the user's unix username and any
/// unix groups the user is a member of.
pub fn validate_ownership_by_prefixes(name: &str, prefixes: &[String]) -> OwnerValidationResult {
if name.is_empty() {
return OwnerValidationResult::StringEmpty;
}
if name.starts_with('_') {
return OwnerValidationResult::MissingPrefix;
}
let (prefix, _) = match name.split_once('_') {
Some(pair) => pair,
None => return OwnerValidationResult::MissingPostfix,
};
if prefixes.iter().any(|g| g == prefix) {
OwnerValidationResult::Match
} else {
OwnerValidationResult::NoMatch
}
}
/// Validate the ownership of a database name or database user name.
/// This function takes the name of a database or user and a unix user,
/// for which it fetches the user's groups. It then checks if the name
/// is prefixed with the user's username or any of the user's groups.
pub fn validate_ownership_or_error<'a>(
name: &'a str,
user: &User,
db_or_user: DbOrUser,
) -> anyhow::Result<&'a str> {
let user_groups = get_unix_groups(user)?;
let prefixes = std::iter::once(user.name.clone())
.chain(user_groups.iter().map(|g| g.name.clone()))
.collect::<Vec<String>>();
match validate_ownership_by_prefixes(name, &prefixes) {
OwnerValidationResult::Match => Ok(name),
OwnerValidationResult::NoMatch => {
anyhow::bail!(
indoc! {r#"
Invalid {} name prefix: '{}' does not match your username or any of your groups.
Are you sure you are allowed to create {} names with this prefix?
Allowed prefixes:
- {}
{}
"#},
db_or_user.lowercased(),
name,
db_or_user.lowercased(),
user.name,
user_groups
.iter()
.filter(|g| g.name != user.name)
.map(|g| format!(" - {}", g.name))
.sorted()
.join("\n"),
);
}
_ => anyhow::bail!(
"'{}' is not a valid {} name.",
name,
db_or_user.lowercased()
),
}
}
/// Gracefully close a MySQL connection.
pub async fn close_database_connection(connection: MySqlConnection) {
if let Err(e) = connection
.close()
.await
.context("Failed to close connection properly")
{
eprintln!("{}", e);
eprintln!("Ignoring...");
}
}
#[inline]
pub fn quote_literal(s: &str) -> String {
format!("'{}'", s.replace('\'', r"\'"))
}
#[inline]
pub fn quote_identifier(s: &str) -> String {
format!("`{}`", s.replace('`', r"\`"))
}
#[inline] #[inline]
pub(crate) fn yn(b: bool) -> &'static str { pub(crate) fn yn(b: bool) -> &'static str {
@ -303,94 +59,4 @@ mod test {
assert_eq!(rev_yn("n"), Some(false)); assert_eq!(rev_yn("n"), Some(false));
assert_eq!(rev_yn("X"), None); assert_eq!(rev_yn("X"), None);
} }
}
#[test]
fn test_quote_literal() {
let payload = "' OR 1=1 --";
assert_eq!(quote_literal(payload), r#"'\' OR 1=1 --'"#);
}
#[test]
fn test_quote_identifier() {
let payload = "` OR 1=1 --";
assert_eq!(quote_identifier(payload), r#"`\` OR 1=1 --`"#);
}
#[test]
fn test_validate_name() {
assert_eq!(validate_name(""), NameValidationResult::EmptyString);
assert_eq!(
validate_name("abcdefghijklmnopqrstuvwxyz"),
NameValidationResult::Valid
);
assert_eq!(
validate_name("ABCDEFGHIJKLMNOPQRSTUVWXYZ"),
NameValidationResult::Valid
);
assert_eq!(validate_name("0123456789_-"), NameValidationResult::Valid);
for c in "\n\t\r !@#$%^&*()+=[]{}|;:,.<>?/".chars() {
assert_eq!(
validate_name(&c.to_string()),
NameValidationResult::InvalidCharacters
);
}
assert_eq!(validate_name(&"a".repeat(64)), NameValidationResult::Valid);
assert_eq!(
validate_name(&"a".repeat(65)),
NameValidationResult::TooLong
);
}
#[test]
fn test_validate_owner_by_prefixes() {
let prefixes = vec!["user".to_string(), "group".to_string()];
assert_eq!(
validate_ownership_by_prefixes("", &prefixes),
OwnerValidationResult::StringEmpty
);
assert_eq!(
validate_ownership_by_prefixes("user", &prefixes),
OwnerValidationResult::MissingPostfix
);
assert_eq!(
validate_ownership_by_prefixes("something", &prefixes),
OwnerValidationResult::MissingPostfix
);
assert_eq!(
validate_ownership_by_prefixes("user-testdb", &prefixes),
OwnerValidationResult::MissingPostfix
);
assert_eq!(
validate_ownership_by_prefixes("_testdb", &prefixes),
OwnerValidationResult::MissingPrefix
);
assert_eq!(
validate_ownership_by_prefixes("user_testdb", &prefixes),
OwnerValidationResult::Match
);
assert_eq!(
validate_ownership_by_prefixes("group_testdb", &prefixes),
OwnerValidationResult::Match
);
assert_eq!(
validate_ownership_by_prefixes("group_test_db", &prefixes),
OwnerValidationResult::Match
);
assert_eq!(
validate_ownership_by_prefixes("group_test-db", &prefixes),
OwnerValidationResult::Match
);
assert_eq!(
validate_ownership_by_prefixes("nonexistent_testdb", &prefixes),
OwnerValidationResult::NoMatch
);
}
}

View File

@ -1,120 +0,0 @@
use anyhow::Context;
use indoc::formatdoc;
use itertools::Itertools;
use nix::unistd::User;
use sqlx::{prelude::*, MySqlConnection};
use crate::core::{
common::{
create_user_group_matching_regex, get_current_unix_user, quote_identifier,
validate_name_or_error, validate_ownership_or_error, DbOrUser,
},
database_privilege_operations::DATABASE_PRIVILEGE_FIELDS,
};
pub async fn create_database(name: &str, connection: &mut MySqlConnection) -> anyhow::Result<()> {
let user = get_current_unix_user()?;
validate_database_name(name, &user)?;
// NOTE: see the note about SQL injections in `validate_owner_of_database_name`
sqlx::query(&format!("CREATE DATABASE {}", quote_identifier(name)))
.execute(connection)
.await
.map_err(|e| {
if e.to_string().contains("database exists") {
anyhow::anyhow!("Database '{}' already exists", name)
} else {
e.into()
}
})?;
Ok(())
}
pub async fn drop_database(name: &str, connection: &mut MySqlConnection) -> anyhow::Result<()> {
let user = get_current_unix_user()?;
validate_database_name(name, &user)?;
// NOTE: see the note about SQL injections in `validate_owner_of_database_name`
sqlx::query(&format!("DROP DATABASE {}", quote_identifier(name)))
.execute(connection)
.await
.map_err(|e| {
if e.to_string().contains("doesn't exist") {
anyhow::anyhow!("Database '{}' does not exist", name)
} else {
e.into()
}
})?;
Ok(())
}
pub async fn get_database_list(connection: &mut MySqlConnection) -> anyhow::Result<Vec<String>> {
let unix_user = get_current_unix_user()?;
let databases: Vec<String> = sqlx::query(
r#"
SELECT `SCHEMA_NAME` AS `database`
FROM `information_schema`.`SCHEMATA`
WHERE `SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys')
AND `SCHEMA_NAME` REGEXP ?
"#,
)
.bind(create_user_group_matching_regex(&unix_user))
.fetch_all(connection)
.await
.and_then(|row| {
row.into_iter()
.map(|row| row.try_get::<String, _>("database"))
.collect::<Result<_, _>>()
})
.context(format!(
"Failed to get databases for user '{}'",
unix_user.name
))?;
Ok(databases)
}
pub async fn get_databases_where_user_has_privileges(
username: &str,
connection: &mut MySqlConnection,
) -> anyhow::Result<Vec<String>> {
let result = sqlx::query(
formatdoc!(
r#"
SELECT `db` AS `database`
FROM `db`
WHERE `user` = ?
AND ({})
"#,
DATABASE_PRIVILEGE_FIELDS
.iter()
.map(|field| format!("`{}` = 'Y'", field))
.join(" OR "),
)
.as_str(),
)
.bind(username)
.fetch_all(connection)
.await?
.into_iter()
.map(|databases| databases.try_get::<String, _>("database").unwrap())
.collect();
Ok(result)
}
/// NOTE: It is very critical that this function validates the database name
/// properly. MySQL does not seem to allow for prepared statements, binding
/// the database name as a parameter to the query. This means that we have
/// to validate the database name ourselves to prevent SQL injection.
pub fn validate_database_name(name: &str, user: &User) -> anyhow::Result<()> {
validate_name_or_error(name, DbOrUser::Database)
.context(format!("Invalid database name: '{}'", name))?;
validate_ownership_or_error(name, user, DbOrUser::Database)
.context(format!("Invalid database name: '{}'", name))?;
Ok(())
}

View File

@ -1,52 +1,11 @@
//! Database privilege operations
//!
//! This module contains functions for querying, modifying,
//! displaying and comparing database privileges.
//!
//! A lot of the complexity comes from two core components:
//!
//! - The privilege editor that needs to be able to print
//! an editable table of privileges and reparse the content
//! after the user has made manual changes.
//!
//! - The comparison functionality that tells the user what
//! changes will be made when applying a set of changes
//! to the list of database privileges.
use std::collections::{BTreeSet, HashMap}; use std::collections::{BTreeSet, HashMap};
use anyhow::{anyhow, Context};
use indoc::indoc;
use itertools::Itertools;
use prettytable::Table; use prettytable::Table;
use serde::{Deserialize, Serialize}; use itertools::Itertools;
use sqlx::{mysql::MySqlRow, prelude::*, MySqlConnection}; use serde::{Serialize, Deserialize};
use anyhow::{anyhow, Context};
use crate::core::{ use crate::server::database_privilege_operations::{DatabasePrivilegeRow, DATABASE_PRIVILEGE_FIELDS};
common::{ use super::common::{rev_yn, yn};
create_user_group_matching_regex, get_current_unix_user, quote_identifier, rev_yn, yn,
},
database_operations::validate_database_name,
};
/// This is the list of fields that are used to fetch the db + user + privileges
/// from the `db` table in the database. If you need to add or remove privilege
/// fields, this is a good place to start.
pub const DATABASE_PRIVILEGE_FIELDS: [&str; 13] = [
"db",
"user",
"select_priv",
"insert_priv",
"update_priv",
"delete_priv",
"create_priv",
"drop_priv",
"alter_priv",
"index_priv",
"create_tmp_table_priv",
"lock_tables_priv",
"references_priv",
];
pub fn db_priv_field_human_readable_name(name: &str) -> String { pub fn db_priv_field_human_readable_name(name: &str) -> String {
match name { match name {
@ -67,162 +26,24 @@ pub fn db_priv_field_human_readable_name(name: &str) -> String {
} }
} }
/// This struct represents the set of privileges for a single user on a single database. pub fn diff(row1: &DatabasePrivilegeRow, row2: &DatabasePrivilegeRow) -> DatabasePrivilegeRowDiff {
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] debug_assert!(row1.db == row2.db && row1.user == row2.user);
pub struct DatabasePrivilegeRow {
pub db: String,
pub user: String,
pub select_priv: bool,
pub insert_priv: bool,
pub update_priv: bool,
pub delete_priv: bool,
pub create_priv: bool,
pub drop_priv: bool,
pub alter_priv: bool,
pub index_priv: bool,
pub create_tmp_table_priv: bool,
pub lock_tables_priv: bool,
pub references_priv: bool,
}
impl DatabasePrivilegeRow { DatabasePrivilegeRowDiff {
pub fn empty(db: &str, user: &str) -> Self { db: row1.db.clone(),
Self { user: row1.user.clone(),
db: db.to_owned(), diff: DATABASE_PRIVILEGE_FIELDS
user: user.to_owned(), .into_iter()
select_priv: false, .skip(2)
insert_priv: false, .filter_map(|field| {
update_priv: false, DatabasePrivilegeChange::new(
delete_priv: false, row1.get_privilege_by_name(field),
create_priv: false, row2.get_privilege_by_name(field),
drop_priv: false, field,
alter_priv: false, )
index_priv: false, })
create_tmp_table_priv: false, .collect(),
lock_tables_priv: false,
references_priv: false,
}
} }
pub fn get_privilege_by_name(&self, name: &str) -> bool {
match name {
"select_priv" => self.select_priv,
"insert_priv" => self.insert_priv,
"update_priv" => self.update_priv,
"delete_priv" => self.delete_priv,
"create_priv" => self.create_priv,
"drop_priv" => self.drop_priv,
"alter_priv" => self.alter_priv,
"index_priv" => self.index_priv,
"create_tmp_table_priv" => self.create_tmp_table_priv,
"lock_tables_priv" => self.lock_tables_priv,
"references_priv" => self.references_priv,
_ => false,
}
}
pub fn diff(&self, other: &DatabasePrivilegeRow) -> DatabasePrivilegeRowDiff {
debug_assert!(self.db == other.db && self.user == other.user);
DatabasePrivilegeRowDiff {
db: self.db.clone(),
user: self.user.clone(),
diff: DATABASE_PRIVILEGE_FIELDS
.into_iter()
.skip(2)
.filter_map(|field| {
DatabasePrivilegeChange::new(
self.get_privilege_by_name(field),
other.get_privilege_by_name(field),
field,
)
})
.collect(),
}
}
}
#[inline]
fn get_mysql_row_priv_field(row: &MySqlRow, position: usize) -> Result<bool, sqlx::Error> {
let field = DATABASE_PRIVILEGE_FIELDS[position];
let value = row.try_get(position)?;
match rev_yn(value) {
Some(val) => Ok(val),
_ => {
log::warn!(r#"Invalid value for privilege "{}": '{}'"#, field, value);
Ok(false)
}
}
}
impl FromRow<'_, MySqlRow> for DatabasePrivilegeRow {
fn from_row(row: &MySqlRow) -> Result<Self, sqlx::Error> {
Ok(Self {
db: row.try_get("db")?,
user: row.try_get("user")?,
select_priv: get_mysql_row_priv_field(row, 2)?,
insert_priv: get_mysql_row_priv_field(row, 3)?,
update_priv: get_mysql_row_priv_field(row, 4)?,
delete_priv: get_mysql_row_priv_field(row, 5)?,
create_priv: get_mysql_row_priv_field(row, 6)?,
drop_priv: get_mysql_row_priv_field(row, 7)?,
alter_priv: get_mysql_row_priv_field(row, 8)?,
index_priv: get_mysql_row_priv_field(row, 9)?,
create_tmp_table_priv: get_mysql_row_priv_field(row, 10)?,
lock_tables_priv: get_mysql_row_priv_field(row, 11)?,
references_priv: get_mysql_row_priv_field(row, 12)?,
})
}
}
/// Get all users + privileges for a single database.
pub async fn get_database_privileges(
database_name: &str,
connection: &mut MySqlConnection,
) -> anyhow::Result<Vec<DatabasePrivilegeRow>> {
let unix_user = get_current_unix_user()?;
validate_database_name(database_name, &unix_user)?;
let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
"SELECT {} FROM `db` WHERE `db` = ?",
DATABASE_PRIVILEGE_FIELDS
.iter()
.map(|field| quote_identifier(field))
.join(","),
))
.bind(database_name)
.fetch_all(connection)
.await
.context("Failed to show database")?;
Ok(result)
}
/// Get all database + user + privileges pairs that are owned by the current user.
pub async fn get_all_database_privileges(
connection: &mut MySqlConnection,
) -> anyhow::Result<Vec<DatabasePrivilegeRow>> {
let unix_user = get_current_unix_user()?;
let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
indoc! {r#"
SELECT {} FROM `db` WHERE `db` IN
(SELECT DISTINCT `SCHEMA_NAME` AS `database`
FROM `information_schema`.`SCHEMATA`
WHERE `SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys')
AND `SCHEMA_NAME` REGEXP ?)
"#},
DATABASE_PRIVILEGE_FIELDS
.iter()
.map(|field| format!("`{field}`"))
.join(","),
))
.bind(create_user_group_matching_regex(&unix_user))
.fetch_all(connection)
.await
.context("Failed to show databases")?;
Ok(result)
} }
/*************************/ /*************************/
@ -586,7 +407,7 @@ pub struct DatabasePrivilegeRowDiff {
} }
/// This enum represents a change for a single privilege. /// This enum represents a change for a single privilege.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, PartialOrd, Ord)] #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
pub enum DatabasePrivilegeChange { pub enum DatabasePrivilegeChange {
YesToNo(String), YesToNo(String),
NoToYes(String), NoToYes(String),
@ -610,6 +431,24 @@ pub enum DatabasePrivilegesDiff {
Deleted(DatabasePrivilegeRow), Deleted(DatabasePrivilegeRow),
} }
impl DatabasePrivilegesDiff {
pub fn get_database_name(&self) -> &str {
match self {
DatabasePrivilegesDiff::New(p) => &p.db,
DatabasePrivilegesDiff::Modified(p) => &p.db,
DatabasePrivilegesDiff::Deleted(p) => &p.db,
}
}
pub fn get_user_name(&self) -> &str {
match self {
DatabasePrivilegesDiff::New(p) => &p.user,
DatabasePrivilegesDiff::Modified(p) => &p.user,
DatabasePrivilegesDiff::Deleted(p) => &p.user,
}
}
}
/// This function calculates the differences between two sets of database privileges. /// This function calculates the differences between two sets of database privileges.
/// It returns a set of [`DatabasePrivilegesDiff`] that can be used to display or /// It returns a set of [`DatabasePrivilegesDiff`] that can be used to display or
/// apply a set of privilege modifications to the database. /// apply a set of privilege modifications to the database.
@ -633,7 +472,7 @@ pub fn diff_privileges(
for p in to { for p in to {
if let Some(old_p) = from_lookup_table.get(&(p.db.clone(), p.user.clone())) { if let Some(old_p) = from_lookup_table.get(&(p.db.clone(), p.user.clone())) {
let diff = old_p.diff(p); let diff = diff(old_p, p);
if !diff.diff.is_empty() { if !diff.diff.is_empty() {
result.insert(DatabasePrivilegesDiff::Modified(diff)); result.insert(DatabasePrivilegesDiff::Modified(diff));
} }
@ -651,72 +490,6 @@ pub fn diff_privileges(
result result
} }
/// Uses the result of [`diff_privileges`] to modify privileges in the database.
pub async fn apply_privilege_diffs(
diffs: BTreeSet<DatabasePrivilegesDiff>,
connection: &mut MySqlConnection,
) -> anyhow::Result<()> {
for diff in diffs {
match diff {
DatabasePrivilegesDiff::New(p) => {
let tables = DATABASE_PRIVILEGE_FIELDS
.iter()
.map(|field| format!("`{field}`"))
.join(",");
let question_marks = std::iter::repeat("?")
.take(DATABASE_PRIVILEGE_FIELDS.len())
.join(",");
sqlx::query(
format!("INSERT INTO `db` ({}) VALUES ({})", tables, question_marks).as_str(),
)
.bind(p.db)
.bind(p.user)
.bind(yn(p.select_priv))
.bind(yn(p.insert_priv))
.bind(yn(p.update_priv))
.bind(yn(p.delete_priv))
.bind(yn(p.create_priv))
.bind(yn(p.drop_priv))
.bind(yn(p.alter_priv))
.bind(yn(p.index_priv))
.bind(yn(p.create_tmp_table_priv))
.bind(yn(p.lock_tables_priv))
.bind(yn(p.references_priv))
.execute(&mut *connection)
.await?;
}
DatabasePrivilegesDiff::Modified(p) => {
let tables = p
.diff
.iter()
.map(|diff| match diff {
DatabasePrivilegeChange::YesToNo(name) => format!("`{}` = 'N'", name),
DatabasePrivilegeChange::NoToYes(name) => format!("`{}` = 'Y'", name),
})
.join(",");
sqlx::query(
format!("UPDATE `db` SET {} WHERE `db` = ? AND `user` = ?", tables).as_str(),
)
.bind(p.db)
.bind(p.user)
.execute(&mut *connection)
.await?;
}
DatabasePrivilegesDiff::Deleted(p) => {
sqlx::query("DELETE FROM `db` WHERE `db` = ? AND `user` = ?")
.bind(p.db)
.bind(p.user)
.execute(&mut *connection)
.await?;
}
}
}
Ok(())
}
fn display_privilege_cell(diff: &DatabasePrivilegeRowDiff) -> String { fn display_privilege_cell(diff: &DatabasePrivilegeRowDiff) -> String {
diff.diff diff.diff
.iter() .iter()
@ -743,7 +516,7 @@ pub fn display_privilege_diffs(diffs: &BTreeSet<DatabasePrivilegesDiff>) -> Stri
p.user, p.user,
"(New user)\n".to_string() "(New user)\n".to_string()
+ &display_privilege_cell( + &display_privilege_cell(
&DatabasePrivilegeRow::empty(&p.db, &p.user).diff(p) &diff(&DatabasePrivilegeRow::empty(&p.db, &p.user), &p)
) )
]); ]);
} }
@ -756,7 +529,7 @@ pub fn display_privilege_diffs(diffs: &BTreeSet<DatabasePrivilegesDiff>) -> Stri
p.user, p.user,
"(All privileges removed)\n".to_string() "(All privileges removed)\n".to_string()
+ &display_privilege_cell( + &display_privilege_cell(
&p.diff(&DatabasePrivilegeRow::empty(&p.db, &p.user)) &diff(&p, &DatabasePrivilegeRow::empty(&p.db, &p.user))
) )
]); ]);
} }
@ -882,4 +655,4 @@ mod tests {
assert_eq!(permissions, parsed_permissions); assert_eq!(permissions, parsed_permissions);
} }
} }

View File

@ -1,249 +0,0 @@
use anyhow::Context;
use nix::unistd::User;
use serde::{Deserialize, Serialize};
use sqlx::{prelude::*, MySqlConnection};
use crate::core::common::{
create_user_group_matching_regex, get_current_unix_user, quote_literal, validate_name_or_error,
validate_ownership_or_error, DbOrUser,
};
pub async fn user_exists(db_user: &str, connection: &mut MySqlConnection) -> anyhow::Result<bool> {
let unix_user = get_current_unix_user()?;
validate_user_name(db_user, &unix_user)?;
let user_exists = sqlx::query(
r#"
SELECT EXISTS(
SELECT 1
FROM `mysql`.`user`
WHERE `User` = ?
)
"#,
)
.bind(db_user)
.fetch_one(connection)
.await?
.get::<bool, _>(0);
Ok(user_exists)
}
pub async fn create_database_user(
db_user: &str,
connection: &mut MySqlConnection,
) -> anyhow::Result<()> {
let unix_user = get_current_unix_user()?;
validate_user_name(db_user, &unix_user)?;
if user_exists(db_user, connection).await? {
anyhow::bail!("User '{}' already exists", db_user);
}
// NOTE: see the note about SQL injections in `validate_ownership_of_user_name`
sqlx::query(format!("CREATE USER {}@'%'", quote_literal(db_user),).as_str())
.execute(connection)
.await?;
Ok(())
}
pub async fn delete_database_user(
db_user: &str,
connection: &mut MySqlConnection,
) -> anyhow::Result<()> {
let unix_user = get_current_unix_user()?;
validate_user_name(db_user, &unix_user)?;
if !user_exists(db_user, connection).await? {
anyhow::bail!("User '{}' does not exist", db_user);
}
// NOTE: see the note about SQL injections in `validate_ownership_of_user_name`
sqlx::query(format!("DROP USER {}@'%'", quote_literal(db_user),).as_str())
.execute(connection)
.await?;
Ok(())
}
pub async fn set_password_for_database_user(
db_user: &str,
password: &str,
connection: &mut MySqlConnection,
) -> anyhow::Result<()> {
let unix_user = crate::core::common::get_current_unix_user()?;
validate_user_name(db_user, &unix_user)?;
if !user_exists(db_user, connection).await? {
anyhow::bail!("User '{}' does not exist", db_user);
}
// NOTE: see the note about SQL injections in `validate_ownership_of_user_name`
sqlx::query(
format!(
"ALTER USER {}@'%' IDENTIFIED BY {}",
quote_literal(db_user),
quote_literal(password).as_str()
)
.as_str(),
)
.execute(connection)
.await?;
Ok(())
}
async fn user_is_locked(db_user: &str, connection: &mut MySqlConnection) -> anyhow::Result<bool> {
let unix_user = get_current_unix_user()?;
validate_user_name(db_user, &unix_user)?;
if !user_exists(db_user, connection).await? {
anyhow::bail!("User '{}' does not exist", db_user);
}
let is_locked = sqlx::query(
r#"
SELECT JSON_EXTRACT(`mysql`.`global_priv`.`priv`, "$.account_locked") = 'true'
FROM `mysql`.`global_priv`
WHERE `User` = ?
AND `Host` = '%'
"#,
)
.bind(db_user)
.fetch_one(connection)
.await?
.get::<bool, _>(0);
Ok(is_locked)
}
pub async fn lock_database_user(
db_user: &str,
connection: &mut MySqlConnection,
) -> anyhow::Result<()> {
let unix_user = get_current_unix_user()?;
validate_user_name(db_user, &unix_user)?;
if !user_exists(db_user, connection).await? {
anyhow::bail!("User '{}' does not exist", db_user);
}
if user_is_locked(db_user, connection).await? {
anyhow::bail!("User '{}' is already locked", db_user);
}
// NOTE: see the note about SQL injections in `validate_ownership_of_user_name`
sqlx::query(format!("ALTER USER {}@'%' ACCOUNT LOCK", quote_literal(db_user),).as_str())
.execute(connection)
.await?;
Ok(())
}
pub async fn unlock_database_user(
db_user: &str,
connection: &mut MySqlConnection,
) -> anyhow::Result<()> {
let unix_user = get_current_unix_user()?;
validate_user_name(db_user, &unix_user)?;
if !user_exists(db_user, connection).await? {
anyhow::bail!("User '{}' does not exist", db_user);
}
if !user_is_locked(db_user, connection).await? {
anyhow::bail!("User '{}' is already unlocked", db_user);
}
// NOTE: see the note about SQL injections in `validate_ownership_of_user_name`
sqlx::query(format!("ALTER USER {}@'%' ACCOUNT UNLOCK", quote_literal(db_user),).as_str())
.execute(connection)
.await?;
Ok(())
}
/// This struct contains information about a database user.
/// This can be extended if we need more information in the future.
#[derive(Debug, Clone, FromRow, Serialize, Deserialize)]
pub struct DatabaseUser {
#[sqlx(rename = "User")]
pub user: String,
#[allow(dead_code)]
#[serde(skip)]
#[sqlx(rename = "Host")]
pub host: String,
#[sqlx(rename = "has_password")]
pub has_password: bool,
#[sqlx(rename = "is_locked")]
pub is_locked: bool,
}
const DB_USER_SELECT_STATEMENT: &str = r#"
SELECT
`mysql`.`user`.`User`,
`mysql`.`user`.`Host`,
`mysql`.`user`.`Password` != '' OR `mysql`.`user`.`authentication_string` != '' AS `has_password`,
COALESCE(
JSON_EXTRACT(`mysql`.`global_priv`.`priv`, "$.account_locked"),
'false'
) != 'false' AS `is_locked`
FROM `mysql`.`user`
JOIN `mysql`.`global_priv` ON
`mysql`.`user`.`User` = `mysql`.`global_priv`.`User`
AND `mysql`.`user`.`Host` = `mysql`.`global_priv`.`Host`
"#;
/// This function fetches all database users that have a prefix matching the
/// unix username and group names of the given unix user.
pub async fn get_all_database_users_for_unix_user(
unix_user: &User,
connection: &mut MySqlConnection,
) -> anyhow::Result<Vec<DatabaseUser>> {
let users = sqlx::query_as::<_, DatabaseUser>(
&(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `mysql`.`user`.`User` REGEXP ?"),
)
.bind(create_user_group_matching_regex(unix_user))
.fetch_all(connection)
.await?;
Ok(users)
}
/// This function fetches a database user if it exists.
pub async fn get_database_user_for_user(
username: &str,
connection: &mut MySqlConnection,
) -> anyhow::Result<Option<DatabaseUser>> {
let user = sqlx::query_as::<_, DatabaseUser>(
&(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `mysql`.`user`.`User` = ?"),
)
.bind(username)
.fetch_optional(connection)
.await?;
Ok(user)
}
/// NOTE: It is very critical that this function validates the database name
/// properly. MySQL does not seem to allow for prepared statements, binding
/// the database name as a parameter to the query. This means that we have
/// to validate the database name ourselves to prevent SQL injection.
pub fn validate_user_name(name: &str, user: &User) -> anyhow::Result<()> {
validate_name_or_error(name, DbOrUser::User)
.context(format!("Invalid username: '{}'", name))?;
validate_ownership_or_error(name, user, DbOrUser::User)
.context(format!("Invalid username: '{}'", name))?;
Ok(())
}

View File

@ -1,34 +1,34 @@
#[macro_use] #[macro_use]
extern crate prettytable; extern crate prettytable;
use core::common::CommandStatus; // use core::common::CommandStatus;
#[cfg(feature = "mysql-admutils-compatibility")] // #[cfg(feature = "mysql-admutils-compatibility")]
use std::path::PathBuf; // use std::path::PathBuf;
#[cfg(feature = "mysql-admutils-compatibility")] // #[cfg(feature = "mysql-admutils-compatibility")]
use crate::cli::mysql_admutils_compatibility::{mysql_dbadm, mysql_useradm}; // use crate::cli::mysql_admutils_compatibility::{mysql_dbadm, mysql_useradm};
use clap::Parser; mod server;
mod authenticated_unix_socket; mod bootstrap;
mod cli; mod cli;
mod core; mod core;
#[cfg(feature = "tui")] #[cfg(feature = "tui")]
mod tui; mod tui;
#[derive(Parser)] // #[derive(Parser)]
struct Args { // struct Args {
#[command(subcommand)] // #[command(subcommand)]
command: Command, // command: Command,
#[command(flatten)] // #[command(flatten)]
config_overrides: core::config::GlobalConfigArgs, // config_overrides: core::config::GlobalConfigArgs,
#[cfg(feature = "tui")] // #[cfg(feature = "tui")]
#[arg(short, long, alias = "tui", global = true)] // #[arg(short, long, alias = "tui", global = true)]
interactive: bool, // interactive: bool,
} // }
/// Database administration tool for non-admin users to manage their own MySQL databases and users. /// Database administration tool for non-admin users to manage their own MySQL databases and users.
/// ///
@ -36,65 +36,67 @@ struct Args {
/// ///
/// You are only allowed to manage databases and users that are prefixed with /// You are only allowed to manage databases and users that are prefixed with
/// either your username, or a group that you are a member of. /// either your username, or a group that you are a member of.
#[derive(Parser)] // #[derive(Parser)]
#[command(version, about, disable_help_subcommand = true)] // #[command(version, about, disable_help_subcommand = true)]
enum Command { // enum Command {
#[command(flatten)] // #[command(flatten)]
Db(cli::database_command::DatabaseCommand), // Db(cli::database_command::DatabaseCommand),
#[command(flatten)] // #[command(flatten)]
User(cli::user_command::UserCommand), // User(cli::user_command::UserCommand),
} // }
#[tokio::main(flavor = "current_thread")] #[tokio::main(flavor = "current_thread")]
async fn main() -> anyhow::Result<()> { async fn main() -> anyhow::Result<()> {
env_logger::init(); env_logger::init();
// boostrap_server_connection_and_drop_privileges().await?;
#[cfg(feature = "mysql-admutils-compatibility")] // #[cfg(feature = "mysql-admutils-compatibility")]
{ // {
let argv0 = std::env::args().next().and_then(|s| { // let argv0 = std::env::args().next().and_then(|s| {
PathBuf::from(s) // PathBuf::from(s)
.file_name() // .file_name()
.map(|s| s.to_string_lossy().to_string()) // .map(|s| s.to_string_lossy().to_string())
}); // });
match argv0.as_deref() { // match argv0.as_deref() {
Some("mysql-dbadm") => return mysql_dbadm::main().await, // Some("mysql-dbadm") => return mysql_dbadm::main().await,
Some("mysql-useradm") => return mysql_useradm::main().await, // Some("mysql-useradm") => return mysql_useradm::main().await,
_ => { /* fall through */ } // _ => { /* fall through */ }
} // }
} // }
let args: Args = Args::parse(); // let args: Args = Args::parse();
let config = core::config::get_config(args.config_overrides)?; // let config = core::config::get_config(args.config_overrides)?;
let connection = core::config::create_mysql_connection_from_config(config.mysql).await?; // let connection = core::config::create_mysql_connection_from_config(config.mysql).await?;
let result = match args.command { // let result = match args.command {
Command::Db(command) => cli::database_command::handle_command(command, connection).await, // Command::Db(command) => cli::database_command::handle_command(command, connection).await,
Command::User(user_args) => cli::user_command::handle_command(user_args, connection).await, // Command::User(user_args) => cli::user_command::handle_command(user_args, connection).await,
}; // };
match result { // match result {
Ok(CommandStatus::SuccessfullyModified) => { // Ok(CommandStatus::SuccessfullyModified) => {
println!("Modifications committed successfully"); // println!("Modifications committed successfully");
Ok(()) // Ok(())
} // }
Ok(CommandStatus::PartiallySuccessfullyModified) => { // Ok(CommandStatus::PartiallySuccessfullyModified) => {
println!("Some modifications committed successfully"); // println!("Some modifications committed successfully");
Ok(()) // Ok(())
} // }
Ok(CommandStatus::NoModificationsNeeded) => { // Ok(CommandStatus::NoModificationsNeeded) => {
println!("No modifications made"); // println!("No modifications made");
Ok(()) // Ok(())
} // }
Ok(CommandStatus::NoModificationsIntended) => { // Ok(CommandStatus::NoModificationsIntended) => {
/* Don't report anything */ // /* Don't report anything */
Ok(()) // Ok(())
} // }
Ok(CommandStatus::Cancelled) => { // Ok(CommandStatus::Cancelled) => {
println!("Command cancelled successfully"); // println!("Command cancelled successfully");
Ok(()) // Ok(())
} // }
Err(e) => Err(e), // Err(e) => Err(e),
} // }
Ok(())
} }

View File

@ -1,6 +1,9 @@
mod common; mod common;
mod database_operations; mod database_operations;
pub mod database_privilege_operations;
mod entrypoint; mod entrypoint;
mod input_sanitization; mod input_sanitization;
mod protocol; mod protocol;
mod user_operations; mod user_operations;
pub use protocol::{Request, Response};

View File

@ -12,13 +12,13 @@ use std::collections::BTreeMap;
use super::common::create_user_group_matching_regex; use super::common::create_user_group_matching_regex;
// NOTE: this function is unsafe because it does no input validation. // NOTE: this function is unsafe because it does no input validation.
async fn unsafe_database_exists( pub(super) async fn unsafe_database_exists(
db_name: &str, database_name: &str,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
) -> Result<bool, sqlx::Error> { ) -> Result<bool, sqlx::Error> {
let result = let result =
sqlx::query("SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?") sqlx::query("SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?")
.bind(db_name) .bind(database_name)
.fetch_optional(connection) .fetch_optional(connection)
.await?; .await?;

View File

@ -1,3 +1,4 @@
// TODO: fix comment
//! Database privilege operations //! Database privilege operations
//! //!
//! This module contains functions for querying, modifying, //! This module contains functions for querying, modifying,
@ -13,21 +14,18 @@
//! changes will be made when applying a set of changes //! changes will be made when applying a set of changes
//! to the list of database privileges. //! to the list of database privileges.
use std::collections::{BTreeSet, HashMap}; use std::collections::{BTreeMap, BTreeSet};
use anyhow::{anyhow, Context};
use indoc::indoc; use indoc::indoc;
use itertools::Itertools; use itertools::Itertools;
use prettytable::Table;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sqlx::{mysql::MySqlRow, prelude::*, MySqlConnection}; use sqlx::{mysql::MySqlRow, prelude::*, MySqlConnection};
use crate::core::{ use super::common::{create_user_group_matching_regex, UnixUser};
common::{ use super::database_operations::unsafe_database_exists;
create_user_group_matching_regex, get_current_unix_user, quote_identifier, rev_yn, yn, use super::input_sanitization::{quote_identifier, validate_name, validate_ownership_by_unix_user, NameValidationError, OwnerValidationError};
}, use crate::core::common::{rev_yn, yn};
database_operations::validate_database_name, use crate::core::database_privileges::{DatabasePrivilegeChange, DatabasePrivilegesDiff};
};
/// This is the list of fields that are used to fetch the db + user + privileges /// This is the list of fields that are used to fetch the db + user + privileges
/// from the `db` table in the database. If you need to add or remove privilege /// from the `db` table in the database. If you need to add or remove privilege
@ -48,25 +46,6 @@ pub const DATABASE_PRIVILEGE_FIELDS: [&str; 13] = [
"references_priv", "references_priv",
]; ];
pub fn db_priv_field_human_readable_name(name: &str) -> String {
match name {
"db" => "Database".to_owned(),
"user" => "User".to_owned(),
"select_priv" => "Select".to_owned(),
"insert_priv" => "Insert".to_owned(),
"update_priv" => "Update".to_owned(),
"delete_priv" => "Delete".to_owned(),
"create_priv" => "Create".to_owned(),
"drop_priv" => "Drop".to_owned(),
"alter_priv" => "Alter".to_owned(),
"index_priv" => "Index".to_owned(),
"create_tmp_table_priv" => "Temp".to_owned(),
"lock_tables_priv" => "Lock".to_owned(),
"references_priv" => "References".to_owned(),
_ => format!("Unknown({})", name),
}
}
// NOTE: ord is needed for BTreeSet to accept the type, but it // NOTE: ord is needed for BTreeSet to accept the type, but it
// doesn't have any natural implementation semantics. // doesn't have any natural implementation semantics.
@ -123,26 +102,6 @@ impl DatabasePrivilegeRow {
_ => false, _ => false,
} }
} }
pub fn diff(&self, other: &DatabasePrivilegeRow) -> DatabasePrivilegeRowDiff {
debug_assert!(self.db == other.db && self.user == other.user);
DatabasePrivilegeRowDiff {
db: self.db.clone(),
user: self.user.clone(),
diff: DATABASE_PRIVILEGE_FIELDS
.into_iter()
.skip(2)
.filter_map(|field| {
DatabasePrivilegeChange::new(
self.get_privilege_by_name(field),
other.get_privilege_by_name(field),
field,
)
})
.collect(),
}
}
} }
#[inline] #[inline]
@ -178,15 +137,13 @@ impl FromRow<'_, MySqlRow> for DatabasePrivilegeRow {
} }
} }
// NOTE: this function is unsafe because it does no input validation.
/// Get all users + privileges for a single database. /// Get all users + privileges for a single database.
pub async fn get_database_privileges( async fn unsafe_get_database_privileges(
database_name: &str, database_name: &str,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
) -> anyhow::Result<Vec<DatabasePrivilegeRow>> { ) -> Result<Vec<DatabasePrivilegeRow>, sqlx::Error> {
let unix_user = get_current_unix_user()?; sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
validate_database_name(database_name, &unix_user)?;
let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
"SELECT {} FROM `db` WHERE `db` = ?", "SELECT {} FROM `db` WHERE `db` = ?",
DATABASE_PRIVILEGE_FIELDS DATABASE_PRIVILEGE_FIELDS
.iter() .iter()
@ -196,18 +153,82 @@ pub async fn get_database_privileges(
.bind(database_name) .bind(database_name)
.fetch_all(connection) .fetch_all(connection)
.await .await
.context("Failed to show database")?; }
Ok(result) // TODO: merge all rows into a single collection.
// they already contain which database they belong to.
// no need to index by database name.
pub type GetDatabasesPrivilegeData =
BTreeMap<String, Result<Vec<DatabasePrivilegeRow>, GetDatabasesPrivilegeDataError>>;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum GetDatabasesPrivilegeDataError {
SanitizationError(NameValidationError),
OwnershipError(OwnerValidationError),
DatabaseDoesNotExist,
MySqlError(String),
}
pub async fn get_databases_privilege_data(
database_names: Vec<String>,
unix_user: &UnixUser,
connection: &mut MySqlConnection,
) -> GetDatabasesPrivilegeData {
let mut results = BTreeMap::new();
for database_name in database_names.iter() {
if let Err(err) = validate_name(database_name) {
results.insert(
database_name.clone(),
Err(GetDatabasesPrivilegeDataError::SanitizationError(err)),
);
continue;
}
if let Err(err) = validate_ownership_by_unix_user(database_name, unix_user) {
results.insert(
database_name.clone(),
Err(GetDatabasesPrivilegeDataError::OwnershipError(err)),
);
continue;
}
if !unsafe_database_exists(database_name, connection)
.await
.unwrap()
{
results.insert(
database_name.clone(),
Err(GetDatabasesPrivilegeDataError::DatabaseDoesNotExist),
);
continue;
}
let result = unsafe_get_database_privileges(database_name, connection)
.await
.map_err(|e| GetDatabasesPrivilegeDataError::MySqlError(e.to_string()));
results.insert(database_name.clone(), result);
}
debug_assert!(database_names.len() == results.len());
results
}
pub type GetAllDatabasesPrivilegeData =
Result<Vec<DatabasePrivilegeRow>, GetAllDatabasesPrivilegeDataError>;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum GetAllDatabasesPrivilegeDataError {
MySqlError(String),
} }
/// Get all database + user + privileges pairs that are owned by the current user. /// Get all database + user + privileges pairs that are owned by the current user.
pub async fn get_all_database_privileges( pub async fn get_all_database_privileges(
unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
) -> anyhow::Result<Vec<DatabasePrivilegeRow>> { ) -> GetAllDatabasesPrivilegeData {
let unix_user = get_current_unix_user()?; sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
indoc! {r#" indoc! {r#"
SELECT {} FROM `db` WHERE `db` IN SELECT {} FROM `db` WHERE `db` IN
(SELECT DISTINCT `SCHEMA_NAME` AS `database` (SELECT DISTINCT `SCHEMA_NAME` AS `database`
@ -217,672 +238,157 @@ pub async fn get_all_database_privileges(
"#}, "#},
DATABASE_PRIVILEGE_FIELDS DATABASE_PRIVILEGE_FIELDS
.iter() .iter()
.map(|field| format!("`{field}`")) .map(|field| quote_identifier(field))
.join(","), .join(","),
)) ))
.bind(create_user_group_matching_regex(&unix_user)) .bind(create_user_group_matching_regex(&unix_user))
.fetch_all(connection) .fetch_all(connection)
.await .await
.context("Failed to show databases")?; .map_err(|e| GetAllDatabasesPrivilegeDataError::MySqlError(e.to_string()))
Ok(result)
} }
/*************************/ async fn unsafe_apply_privilege_diff(
/* CLI INTERFACE PARSING */ database_privilege_diff: &DatabasePrivilegesDiff,
/*************************/ connection: &mut MySqlConnection,
) -> Result<(), sqlx::Error> {
/// See documentation for [`DatabaseCommand::EditDbPrivs`]. match database_privilege_diff {
pub fn parse_privilege_table_cli_arg(arg: &str) -> anyhow::Result<DatabasePrivilegeRow> { DatabasePrivilegesDiff::New(p) => {
let parts: Vec<&str> = arg.split(':').collect(); let tables = DATABASE_PRIVILEGE_FIELDS
if parts.len() != 3 {
anyhow::bail!("Invalid argument format. See `edit-db-privs --help` for more information.");
}
let db = parts[0].to_string();
let user = parts[1].to_string();
let privs = parts[2].to_string();
let mut result = DatabasePrivilegeRow {
db,
user,
select_priv: false,
insert_priv: false,
update_priv: false,
delete_priv: false,
create_priv: false,
drop_priv: false,
alter_priv: false,
index_priv: false,
create_tmp_table_priv: false,
lock_tables_priv: false,
references_priv: false,
};
for char in privs.chars() {
match char {
's' => result.select_priv = true,
'i' => result.insert_priv = true,
'u' => result.update_priv = true,
'd' => result.delete_priv = true,
'c' => result.create_priv = true,
'D' => result.drop_priv = true,
'a' => result.alter_priv = true,
'I' => result.index_priv = true,
't' => result.create_tmp_table_priv = true,
'l' => result.lock_tables_priv = true,
'r' => result.references_priv = true,
'A' => {
result.select_priv = true;
result.insert_priv = true;
result.update_priv = true;
result.delete_priv = true;
result.create_priv = true;
result.drop_priv = true;
result.alter_priv = true;
result.index_priv = true;
result.create_tmp_table_priv = true;
result.lock_tables_priv = true;
result.references_priv = true;
}
_ => anyhow::bail!("Invalid privilege character: {}", char),
}
}
Ok(result)
}
/**********************************/
/* EDITOR CONTENT DISPLAY/DISPLAY */
/**********************************/
/// Generates a single row of the privileges table for the editor.
pub fn format_privileges_line_for_editor(
privs: &DatabasePrivilegeRow,
username_len: usize,
database_name_len: usize,
) -> String {
DATABASE_PRIVILEGE_FIELDS
.into_iter()
.map(|field| match field {
"db" => format!("{:width$}", privs.db, width = database_name_len),
"user" => format!("{:width$}", privs.user, width = username_len),
privilege => format!(
"{:width$}",
yn(privs.get_privilege_by_name(privilege)),
width = db_priv_field_human_readable_name(privilege).len()
),
})
.join(" ")
.trim()
.to_string()
}
const EDITOR_COMMENT: &str = r#"
# Welcome to the privilege editor.
# Each line defines what privileges a single user has on a single database.
# The first two columns respectively represent the database name and the user, and the remaining columns are the privileges.
# If the user should have a certain privilege, write 'Y', otherwise write 'N'.
#
# Lines starting with '#' are comments and will be ignored.
"#;
/// Generates the content for the privilege editor.
///
/// The unix user is used in case there are no privileges to edit,
/// so that the user can see an example line based on their username.
pub fn generate_editor_content_from_privilege_data(
privilege_data: &[DatabasePrivilegeRow],
unix_user: &str,
) -> String {
let example_user = format!("{}_user", unix_user);
let example_db = format!("{}_db", unix_user);
// NOTE: `.max()`` fails when the iterator is empty.
// In this case, we know that the only fields in the
// editor will be the example user and example db name.
// Hence, it's put as the fallback value, despite not really
// being a "fallback" in the normal sense.
let longest_username = privilege_data
.iter()
.map(|p| p.user.len())
.max()
.unwrap_or(example_user.len());
let longest_database_name = privilege_data
.iter()
.map(|p| p.db.len())
.max()
.unwrap_or(example_db.len());
let mut header: Vec<_> = DATABASE_PRIVILEGE_FIELDS
.into_iter()
.map(db_priv_field_human_readable_name)
.collect();
// Pad the first two columns with spaces to align the privileges.
header[0] = format!("{:width$}", header[0], width = longest_database_name);
header[1] = format!("{:width$}", header[1], width = longest_username);
let example_line = format_privileges_line_for_editor(
&DatabasePrivilegeRow {
db: example_db,
user: example_user,
select_priv: true,
insert_priv: true,
update_priv: true,
delete_priv: true,
create_priv: false,
drop_priv: false,
alter_priv: false,
index_priv: false,
create_tmp_table_priv: false,
lock_tables_priv: false,
references_priv: false,
},
longest_username,
longest_database_name,
);
format!(
"{}\n{}\n{}",
EDITOR_COMMENT,
header.join(" "),
if privilege_data.is_empty() {
format!("# {}", example_line)
} else {
privilege_data
.iter() .iter()
.map(|privs| { .map(|field| quote_identifier(field))
format_privileges_line_for_editor( .join(",");
privs,
longest_username, let question_marks = std::iter::repeat("?")
longest_database_name, .take(DATABASE_PRIVILEGE_FIELDS.len())
) .join(",");
sqlx::query(
format!("INSERT INTO `db` ({}) VALUES ({})", tables, question_marks).as_str(),
)
.bind(p.db.to_string())
.bind(p.user.to_string())
.bind(yn(p.select_priv))
.bind(yn(p.insert_priv))
.bind(yn(p.update_priv))
.bind(yn(p.delete_priv))
.bind(yn(p.create_priv))
.bind(yn(p.drop_priv))
.bind(yn(p.alter_priv))
.bind(yn(p.index_priv))
.bind(yn(p.create_tmp_table_priv))
.bind(yn(p.lock_tables_priv))
.bind(yn(p.references_priv))
.execute(connection)
.await
.map(|_| ())
}
DatabasePrivilegesDiff::Modified(p) => {
let changes = p
.diff
.iter()
.map(|diff| match diff {
DatabasePrivilegeChange::YesToNo(name) => {
format!("{} = 'N'", quote_identifier(name))
}
DatabasePrivilegeChange::NoToYes(name) => {
format!("{} = 'Y'", quote_identifier(name))
}
}) })
.join("\n") .join(",");
sqlx::query(
format!("UPDATE `db` SET {} WHERE `db` = ? AND `user` = ?", changes).as_str(),
)
.bind(p.db.to_string())
.bind(p.user.to_string())
.execute(connection)
.await
.map(|_| ())
} }
) DatabasePrivilegesDiff::Deleted(p) => {
} sqlx::query("DELETE FROM `db` WHERE `db` = ? AND `user` = ?")
.bind(p.db.to_string())
#[derive(Debug)] .bind(p.user.to_string())
enum PrivilegeRowParseResult { .execute(connection)
PrivilegeRow(DatabasePrivilegeRow), .await
ParserError(anyhow::Error), .map(|_| ())
TooFewFields(usize),
TooManyFields(usize),
Header,
Comment,
Empty,
}
#[inline]
fn parse_privilege_cell_from_editor(yn: &str, name: &str) -> anyhow::Result<bool> {
rev_yn(yn)
.ok_or_else(|| anyhow!("Expected Y or N, found {}", yn))
.context(format!("Could not parse {} privilege", name))
}
#[inline]
fn editor_row_is_header(row: &str) -> bool {
row.split_ascii_whitespace()
.zip(DATABASE_PRIVILEGE_FIELDS.iter())
.map(|(field, priv_name)| (field, db_priv_field_human_readable_name(priv_name)))
.all(|(field, header_field)| field == header_field)
}
/// Parse a single row of the privileges table from the editor.
fn parse_privilege_row_from_editor(row: &str) -> PrivilegeRowParseResult {
if row.starts_with('#') || row.starts_with("//") {
return PrivilegeRowParseResult::Comment;
}
if row.trim().is_empty() {
return PrivilegeRowParseResult::Empty;
}
let parts: Vec<&str> = row.trim().split_ascii_whitespace().collect();
match parts.len() {
n if (n < DATABASE_PRIVILEGE_FIELDS.len()) => {
return PrivilegeRowParseResult::TooFewFields(n)
}
n if (n > DATABASE_PRIVILEGE_FIELDS.len()) => {
return PrivilegeRowParseResult::TooManyFields(n)
}
_ => {}
}
if editor_row_is_header(row) {
return PrivilegeRowParseResult::Header;
}
let row = DatabasePrivilegeRow {
db: (*parts.first().unwrap()).to_owned(),
user: (*parts.get(1).unwrap()).to_owned(),
select_priv: match parse_privilege_cell_from_editor(
parts.get(2).unwrap(),
DATABASE_PRIVILEGE_FIELDS[2],
) {
Ok(p) => p,
Err(e) => return PrivilegeRowParseResult::ParserError(e),
},
insert_priv: match parse_privilege_cell_from_editor(
parts.get(3).unwrap(),
DATABASE_PRIVILEGE_FIELDS[3],
) {
Ok(p) => p,
Err(e) => return PrivilegeRowParseResult::ParserError(e),
},
update_priv: match parse_privilege_cell_from_editor(
parts.get(4).unwrap(),
DATABASE_PRIVILEGE_FIELDS[4],
) {
Ok(p) => p,
Err(e) => return PrivilegeRowParseResult::ParserError(e),
},
delete_priv: match parse_privilege_cell_from_editor(
parts.get(5).unwrap(),
DATABASE_PRIVILEGE_FIELDS[5],
) {
Ok(p) => p,
Err(e) => return PrivilegeRowParseResult::ParserError(e),
},
create_priv: match parse_privilege_cell_from_editor(
parts.get(6).unwrap(),
DATABASE_PRIVILEGE_FIELDS[6],
) {
Ok(p) => p,
Err(e) => return PrivilegeRowParseResult::ParserError(e),
},
drop_priv: match parse_privilege_cell_from_editor(
parts.get(7).unwrap(),
DATABASE_PRIVILEGE_FIELDS[7],
) {
Ok(p) => p,
Err(e) => return PrivilegeRowParseResult::ParserError(e),
},
alter_priv: match parse_privilege_cell_from_editor(
parts.get(8).unwrap(),
DATABASE_PRIVILEGE_FIELDS[8],
) {
Ok(p) => p,
Err(e) => return PrivilegeRowParseResult::ParserError(e),
},
index_priv: match parse_privilege_cell_from_editor(
parts.get(9).unwrap(),
DATABASE_PRIVILEGE_FIELDS[9],
) {
Ok(p) => p,
Err(e) => return PrivilegeRowParseResult::ParserError(e),
},
create_tmp_table_priv: match parse_privilege_cell_from_editor(
parts.get(10).unwrap(),
DATABASE_PRIVILEGE_FIELDS[10],
) {
Ok(p) => p,
Err(e) => return PrivilegeRowParseResult::ParserError(e),
},
lock_tables_priv: match parse_privilege_cell_from_editor(
parts.get(11).unwrap(),
DATABASE_PRIVILEGE_FIELDS[11],
) {
Ok(p) => p,
Err(e) => return PrivilegeRowParseResult::ParserError(e),
},
references_priv: match parse_privilege_cell_from_editor(
parts.get(12).unwrap(),
DATABASE_PRIVILEGE_FIELDS[12],
) {
Ok(p) => p,
Err(e) => return PrivilegeRowParseResult::ParserError(e),
},
};
PrivilegeRowParseResult::PrivilegeRow(row)
}
// TODO: return better errors
pub fn parse_privilege_data_from_editor_content(
content: String,
) -> anyhow::Result<Vec<DatabasePrivilegeRow>> {
content
.trim()
.split('\n')
.map(|line| line.trim())
.map(parse_privilege_row_from_editor)
.map(|result| match result {
PrivilegeRowParseResult::PrivilegeRow(row) => Ok(Some(row)),
PrivilegeRowParseResult::ParserError(e) => Err(e),
PrivilegeRowParseResult::TooFewFields(n) => Err(anyhow!(
"Too few fields in line. Expected to find {} fields, found {}",
DATABASE_PRIVILEGE_FIELDS.len(),
n
)),
PrivilegeRowParseResult::TooManyFields(n) => Err(anyhow!(
"Too many fields in line. Expected to find {} fields, found {}",
DATABASE_PRIVILEGE_FIELDS.len(),
n
)),
PrivilegeRowParseResult::Header => Ok(None),
PrivilegeRowParseResult::Comment => Ok(None),
PrivilegeRowParseResult::Empty => Ok(None),
})
.filter_map(|result| result.transpose())
.collect::<anyhow::Result<Vec<DatabasePrivilegeRow>>>()
}
/*****************************/
/* CALCULATE PRIVILEGE DIFFS */
/*****************************/
/// This struct represents encapsulates the differences between two
/// instances of privilege sets for a single user on a single database.
///
/// The `User` and `Database` are the same for both instances.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, PartialOrd, Ord)]
pub struct DatabasePrivilegeRowDiff {
pub db: String,
pub user: String,
pub diff: BTreeSet<DatabasePrivilegeChange>,
}
/// This enum represents a change for a single privilege.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, PartialOrd, Ord)]
pub enum DatabasePrivilegeChange {
YesToNo(String),
NoToYes(String),
}
impl DatabasePrivilegeChange {
pub fn new(p1: bool, p2: bool, name: &str) -> Option<DatabasePrivilegeChange> {
match (p1, p2) {
(true, false) => Some(DatabasePrivilegeChange::YesToNo(name.to_owned())),
(false, true) => Some(DatabasePrivilegeChange::NoToYes(name.to_owned())),
_ => None,
} }
} }
} }
/// This enum encapsulates whether a [`DatabasePrivilegeRow`] was intrduced, modified or deleted. pub type ApplyDatabasePrivilegeChangesOutput =
#[derive(Debug, Clone, PartialEq, Eq, Serialize, PartialOrd, Ord)] BTreeMap<String, Result<(), ApplyDatabasePrivilegeChangesError>>;
pub enum DatabasePrivilegesDiff { #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
New(DatabasePrivilegeRow), pub enum ApplyDatabasePrivilegeChangesError {
Modified(DatabasePrivilegeRowDiff), DatabaseSanitizationError(NameValidationError),
Deleted(DatabasePrivilegeRow), DatabaseOwnershipError(OwnerValidationError),
} UserSanitizationError(NameValidationError),
UserOwnershipError(OwnerValidationError),
/// This function calculates the differences between two sets of database privileges. DatabaseDoesNotExist,
/// It returns a set of [`DatabasePrivilegesDiff`] that can be used to display or DiffDoesNotApply(DatabasePrivilegeChange),
/// apply a set of privilege modifications to the database. MySqlError(String),
pub fn diff_privileges(
from: &[DatabasePrivilegeRow],
to: &[DatabasePrivilegeRow],
) -> BTreeSet<DatabasePrivilegesDiff> {
let from_lookup_table: HashMap<(String, String), DatabasePrivilegeRow> = HashMap::from_iter(
from.iter()
.cloned()
.map(|p| ((p.db.clone(), p.user.clone()), p)),
);
let to_lookup_table: HashMap<(String, String), DatabasePrivilegeRow> = HashMap::from_iter(
to.iter()
.cloned()
.map(|p| ((p.db.clone(), p.user.clone()), p)),
);
let mut result = BTreeSet::new();
for p in to {
if let Some(old_p) = from_lookup_table.get(&(p.db.clone(), p.user.clone())) {
let diff = old_p.diff(p);
if !diff.diff.is_empty() {
result.insert(DatabasePrivilegesDiff::Modified(diff));
}
} else {
result.insert(DatabasePrivilegesDiff::New(p.clone()));
}
}
for p in from {
if !to_lookup_table.contains_key(&(p.db.clone(), p.user.clone())) {
result.insert(DatabasePrivilegesDiff::Deleted(p.clone()));
}
}
result
} }
/// Uses the result of [`diff_privileges`] to modify privileges in the database. /// Uses the result of [`diff_privileges`] to modify privileges in the database.
pub async fn apply_privilege_diffs( pub async fn apply_privilege_diffs(
diffs: BTreeSet<DatabasePrivilegesDiff>, database_privilege_diffs: BTreeSet<DatabasePrivilegesDiff>,
unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
) -> anyhow::Result<()> { ) -> ApplyDatabasePrivilegeChangesOutput {
for diff in diffs { let mut results: BTreeMap<String, _> = BTreeMap::new();
match diff {
DatabasePrivilegesDiff::New(p) => {
let tables = DATABASE_PRIVILEGE_FIELDS
.iter()
.map(|field| format!("`{field}`"))
.join(",");
let question_marks = std::iter::repeat("?") for diff in database_privilege_diffs {
.take(DATABASE_PRIVILEGE_FIELDS.len()) if let Err(err) = validate_name(diff.get_database_name()) {
.join(","); results.insert(
diff.get_database_name().to_string(),
sqlx::query( Err(ApplyDatabasePrivilegeChangesError::DatabaseSanitizationError(err)),
format!("INSERT INTO `db` ({}) VALUES ({})", tables, question_marks).as_str(), );
) continue;
.bind(p.db)
.bind(p.user)
.bind(yn(p.select_priv))
.bind(yn(p.insert_priv))
.bind(yn(p.update_priv))
.bind(yn(p.delete_priv))
.bind(yn(p.create_priv))
.bind(yn(p.drop_priv))
.bind(yn(p.alter_priv))
.bind(yn(p.index_priv))
.bind(yn(p.create_tmp_table_priv))
.bind(yn(p.lock_tables_priv))
.bind(yn(p.references_priv))
.execute(&mut *connection)
.await?;
}
DatabasePrivilegesDiff::Modified(p) => {
let tables = p
.diff
.iter()
.map(|diff| match diff {
DatabasePrivilegeChange::YesToNo(name) => format!("`{}` = 'N'", name),
DatabasePrivilegeChange::NoToYes(name) => format!("`{}` = 'Y'", name),
})
.join(",");
sqlx::query(
format!("UPDATE `db` SET {} WHERE `db` = ? AND `user` = ?", tables).as_str(),
)
.bind(p.db)
.bind(p.user)
.execute(&mut *connection)
.await?;
}
DatabasePrivilegesDiff::Deleted(p) => {
sqlx::query("DELETE FROM `db` WHERE `db` = ? AND `user` = ?")
.bind(p.db)
.bind(p.user)
.execute(&mut *connection)
.await?;
}
} }
}
Ok(())
}
fn display_privilege_cell(diff: &DatabasePrivilegeRowDiff) -> String { if let Err(err) = validate_ownership_by_unix_user(diff.get_database_name(), unix_user) {
diff.diff results.insert(
.iter() diff.get_database_name().to_string(),
.map(|change| match change { Err(ApplyDatabasePrivilegeChangesError::DatabaseOwnershipError(err)),
DatabasePrivilegeChange::YesToNo(name) => { );
format!("{}: Y -> N", db_priv_field_human_readable_name(name)) continue;
}
DatabasePrivilegeChange::NoToYes(name) => {
format!("{}: N -> Y", db_priv_field_human_readable_name(name))
}
})
.join("\n")
}
/// Displays the difference between two sets of database privileges.
pub fn display_privilege_diffs(diffs: &BTreeSet<DatabasePrivilegesDiff>) -> String {
let mut table = Table::new();
table.set_titles(row!["Database", "User", "Privilege diff",]);
for row in diffs {
match row {
DatabasePrivilegesDiff::New(p) => {
table.add_row(row![
p.db,
p.user,
"(New user)\n".to_string()
+ &display_privilege_cell(
&DatabasePrivilegeRow::empty(&p.db, &p.user).diff(p)
)
]);
}
DatabasePrivilegesDiff::Modified(p) => {
table.add_row(row![p.db, p.user, display_privilege_cell(p),]);
}
DatabasePrivilegesDiff::Deleted(p) => {
table.add_row(row![
p.db,
p.user,
"(All privileges removed)\n".to_string()
+ &display_privilege_cell(
&p.diff(&DatabasePrivilegeRow::empty(&p.db, &p.user))
)
]);
}
} }
if let Err(err) = validate_name(diff.get_user_name()) {
results.insert(
diff.get_database_name().to_string(),
Err(ApplyDatabasePrivilegeChangesError::UserSanitizationError(err)),
);
continue;
}
if let Err(err) = validate_ownership_by_unix_user(diff.get_user_name(), unix_user) {
results.insert(
diff.get_database_name().to_string(),
Err(ApplyDatabasePrivilegeChangesError::UserOwnershipError(err)),
);
continue;
}
if !unsafe_database_exists(diff.get_database_name(), connection)
.await
.unwrap()
{
results.insert(
diff.get_database_name().to_string(),
Err(ApplyDatabasePrivilegeChangesError::DatabaseDoesNotExist),
);
continue;
}
// TODO: validate that the diff actually applies to the database
let result = unsafe_apply_privilege_diff(&diff, connection)
.await
.map_err(|e| ApplyDatabasePrivilegeChangesError::MySqlError(e.to_string()));
results.insert(diff.get_database_name().to_string(), result);
} }
table.to_string() results
}
/*********/
/* TESTS */
/*********/
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_database_privilege_change_creation() {
assert_eq!(
DatabasePrivilegeChange::new(true, false, "test"),
Some(DatabasePrivilegeChange::YesToNo("test".to_owned()))
);
assert_eq!(
DatabasePrivilegeChange::new(false, true, "test"),
Some(DatabasePrivilegeChange::NoToYes("test".to_owned()))
);
assert_eq!(DatabasePrivilegeChange::new(true, true, "test"), None);
assert_eq!(DatabasePrivilegeChange::new(false, false, "test"), None);
}
#[test]
fn test_diff_privileges() {
let row_to_be_modified = DatabasePrivilegeRow {
db: "db".to_owned(),
user: "user".to_owned(),
select_priv: true,
insert_priv: true,
update_priv: true,
delete_priv: true,
create_priv: true,
drop_priv: true,
alter_priv: true,
index_priv: false,
create_tmp_table_priv: true,
lock_tables_priv: true,
references_priv: false,
};
let mut row_to_be_deleted = row_to_be_modified.clone();
"user2".clone_into(&mut row_to_be_deleted.user);
let from = vec![row_to_be_modified.clone(), row_to_be_deleted.clone()];
let mut modified_row = row_to_be_modified.clone();
modified_row.select_priv = false;
modified_row.insert_priv = false;
modified_row.index_priv = true;
let mut new_row = row_to_be_modified.clone();
"user3".clone_into(&mut new_row.user);
let to = vec![modified_row.clone(), new_row.clone()];
let diffs = diff_privileges(&from, &to);
assert_eq!(
diffs,
BTreeSet::from_iter(vec![
DatabasePrivilegesDiff::Deleted(row_to_be_deleted),
DatabasePrivilegesDiff::Modified(DatabasePrivilegeRowDiff {
db: "db".to_owned(),
user: "user".to_owned(),
diff: BTreeSet::from_iter(vec![
DatabasePrivilegeChange::YesToNo("select_priv".to_owned()),
DatabasePrivilegeChange::YesToNo("insert_priv".to_owned()),
DatabasePrivilegeChange::NoToYes("index_priv".to_owned()),
]),
}),
DatabasePrivilegesDiff::New(new_row),
])
);
}
#[test]
fn ensure_generated_and_parsed_editor_content_is_equal() {
let permissions = vec![
DatabasePrivilegeRow {
db: "db".to_owned(),
user: "user".to_owned(),
select_priv: true,
insert_priv: true,
update_priv: true,
delete_priv: true,
create_priv: true,
drop_priv: true,
alter_priv: true,
index_priv: true,
create_tmp_table_priv: true,
lock_tables_priv: true,
references_priv: true,
},
DatabasePrivilegeRow {
db: "db2".to_owned(),
user: "user2".to_owned(),
select_priv: false,
insert_priv: false,
update_priv: false,
delete_priv: false,
create_priv: false,
drop_priv: false,
alter_priv: false,
index_priv: false,
create_tmp_table_priv: false,
lock_tables_priv: false,
references_priv: false,
},
];
let content = generate_editor_content_from_privilege_data(&permissions, "user");
let parsed_permissions = parse_privilege_data_from_editor_content(content).unwrap();
assert_eq!(permissions, parsed_permissions);
}
} }

View File

@ -65,7 +65,10 @@ pub async fn create_database_users(
continue; continue;
} }
Err(err) => { Err(err) => {
results.insert(db_user, Err(CreateUserError::MySqlError(err.to_string()))); results.insert(
db_user,
Err(CreateUserError::MySqlError(err.to_string())),
);
continue; continue;
} }
_ => {} _ => {}