diff --git a/Cargo.lock b/Cargo.lock index 02ca9cf..230629b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -470,6 +470,18 @@ version = "0.15.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b" +[[package]] +name = "educe" +version = "0.5.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4bd92664bf78c4d3dba9b7cdafce6fa15b13ed3ed16175218196942e99168a8" +dependencies = [ + "enum-ordinalize", + "proc-macro2", + "quote", + "syn 2.0.60", +] + [[package]] name = "either" version = "1.11.0" @@ -491,6 +503,26 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" +[[package]] +name = "enum-ordinalize" +version = "4.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fea0dcfa4e54eeb516fe454635a95753ddd39acda650ce703031c6973e315dd5" +dependencies = [ + "enum-ordinalize-derive", +] + +[[package]] +name = "enum-ordinalize-derive" +version = "4.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d28318a75d4aead5c4db25382e8ef717932d0346600cacae6357eb5941bc5ff" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.60", +] + [[package]] name = "env_filter" version = "0.1.0" @@ -964,6 +996,7 @@ dependencies = [ "dialoguer", "env_logger", "futures", + "futures-util", "indoc", "itertools", "log", @@ -976,6 +1009,8 @@ dependencies = [ "sqlx", "thiserror", "tokio", + "tokio-serde", + "tokio-stream", "tokio-util", "toml", "uuid", @@ -1109,6 +1144,26 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +[[package]] +name = "pin-project" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6bf43b791c5b9e34c3d182969b4abb522f9343702850a2e57f460d00d09b4b3" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.60", +] + [[package]] name = "pin-project-lite" version = "0.2.14" @@ -1931,6 +1986,21 @@ dependencies = [ "syn 2.0.60", ] +[[package]] +name = "tokio-serde" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "caf600e7036b17782571dd44fa0a5cea3c82f60db5137f774a325a76a0d6852b" +dependencies = [ + "bincode", + "bytes", + "educe", + "futures-core", + "futures-sink", + "pin-project", + "serde", +] + [[package]] name = "tokio-stream" version = "0.1.15" diff --git a/Cargo.toml b/Cargo.toml index 003066a..77d154f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ clap = { version = "4.5.4", features = ["derive"] } dialoguer = "0.11.0" env_logger = "0.11.3" futures = "0.3.30" +futures-util = "0.3.30" indoc = "2.0.5" itertools = "0.12.1" log = "0.4.21" @@ -23,7 +24,9 @@ serde_json = { version = "1.0.116", features = ["preserve_order"] } sqlx = { version = "0.7.4", features = ["runtime-tokio", "mysql", "tls-rustls"] } thiserror = "1.0.63" tokio = { version = "1.37.0", features = ["rt", "macros"] } -tokio-util = "0.7.11" +tokio-serde = { version = "0.9.0", features = ["bincode"] } +tokio-stream = "0.1.15" +tokio-util = { version = "0.7.11", features = ["codec"] } toml = "0.8.12" uuid = { version = "1.10.0", features = ["v4"] } diff --git a/src/server.rs b/src/server.rs new file mode 100644 index 0000000..2943954 --- /dev/null +++ b/src/server.rs @@ -0,0 +1,6 @@ +mod common; +mod database_operations; +mod entrypoint; +mod input_sanitization; +mod protocol; +mod user_operations; diff --git a/src/server/common.rs b/src/server/common.rs new file mode 100644 index 0000000..bb74cd4 --- /dev/null +++ b/src/server/common.rs @@ -0,0 +1,113 @@ +use anyhow::Context; +use nix::unistd::{Group as LibcGroup, User as LibcUser}; +use sqlx::{Connection, MySqlConnection}; + +#[cfg(not(target_os = "macos"))] +use std::ffi::CString; + +/// Report the result status of a command. +/// This is used to display a status message to the user. +pub enum CommandStatus { + /// The command was successful, + /// and made modification to the database. + SuccessfullyModified, + + /// The command was mostly successful, + /// and modifications have been made to the database. + /// However, some of the requested modifications failed. + PartiallySuccessfullyModified, + + /// The command was successful, + /// but no modifications were needed. + NoModificationsNeeded, + + /// The command was successful, + /// and made no modification to the database. + NoModificationsIntended, + + /// The command was cancelled, either through a dialog or a signal. + /// No modifications have been made to the database. + Cancelled, +} + +// pub fn get_current_unix_user() -> anyhow::Result { +// User::from_uid(getuid()) +// .context("Failed to look up your UNIX username") +// .and_then(|u| u.ok_or(anyhow::anyhow!("Failed to look up your UNIX username"))) +// } + +pub struct UnixUser { + pub username: String, + pub uid: u32, + pub gid: u32, + pub groups: Vec, +} + +#[cfg(target_os = "macos")] +fn get_unix_groups(_user: &User) -> anyhow::Result> { + // Return an empty list on macOS since there is no `getgrouplist` function + Ok(vec![]) +} + +#[cfg(not(target_os = "macos"))] +fn get_unix_groups(user: &LibcUser) -> anyhow::Result> { + let user_cstr = + CString::new(user.name.as_bytes()).context("Failed to convert username to CStr")?; + let groups = nix::unistd::getgrouplist(&user_cstr, user.gid)? + .iter() + .filter_map(|gid| match LibcGroup::from_gid(*gid) { + Ok(Some(group)) => Some(group), + Ok(None) => None, + Err(e) => { + log::warn!( + "Failed to look up group with GID {}: {}\nIgnoring...", + gid, + e + ); + None + } + }) + .collect::>(); + + Ok(groups) +} + +impl UnixUser { + pub fn from_uid(uid: u32) -> anyhow::Result { + let libc_uid = nix::unistd::Uid::from_raw(uid); + let libc_user = LibcUser::from_uid(libc_uid) + .context("Failed to look up your UNIX username")? + .ok_or(anyhow::anyhow!("Failed to look up your UNIX username"))?; + + let groups = get_unix_groups(&libc_user)?; + + Ok(UnixUser { + username: libc_user.name, + uid, + gid: libc_user.gid.into(), + groups: groups.iter().map(|g| g.name.clone()).collect(), + }) + } +} + +/// This function creates a regex that matches items (users, databases) +/// that belong to the user or any of the user's groups. +pub fn create_user_group_matching_regex(user: &UnixUser) -> String { + if user.groups.is_empty() { + format!("{}(_.+)?", user.username) + } else { + format!("({}|{})(_.+)?", user.username, user.groups.join("|")) + } +} + +/// Gracefully close a MySQL connection. +pub async fn close_database_connection(connection: MySqlConnection) { + if let Err(e) = connection + .close() + .await + .context("Failed to close connection properly") + { + eprintln!("{}", e); + eprintln!("Ignoring..."); + } +} diff --git a/src/server/database_operations.rs b/src/server/database_operations.rs new file mode 100644 index 0000000..ada4189 --- /dev/null +++ b/src/server/database_operations.rs @@ -0,0 +1,183 @@ +use crate::server::common::UnixUser; +use crate::server::input_sanitization::quote_identifier; +use crate::server::input_sanitization::{validate_name, validate_ownership_by_unix_user}; +use crate::server::input_sanitization::{NameValidationError, OwnerValidationError}; + +use serde::{Deserialize, Serialize}; +use sqlx::prelude::*; + +use sqlx::MySqlConnection; +use std::collections::BTreeMap; + +use super::common::create_user_group_matching_regex; + +// NOTE: this function is unsafe because it does no input validation. +async fn unsafe_database_exists( + db_name: &str, + connection: &mut MySqlConnection, +) -> Result { + let result = + sqlx::query("SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?") + .bind(db_name) + .fetch_optional(connection) + .await?; + + Ok(result.is_some()) +} + +pub type CreateDatabasesOutput = BTreeMap>; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum CreateDatabaseError { + SanitizationError(NameValidationError), + OwnershipError(OwnerValidationError), + DatabaseAlreadyExists, + MySqlError(String), +} + +pub async fn create_databases( + database_names: Vec, + unix_user: &UnixUser, + connection: &mut MySqlConnection, +) -> CreateDatabasesOutput { + let mut results = BTreeMap::new(); + + for database_name in database_names { + if let Err(err) = validate_name(&database_name) { + results.insert( + database_name.clone(), + Err(CreateDatabaseError::SanitizationError(err)), + ); + continue; + } + + if let Err(err) = validate_ownership_by_unix_user(&database_name, unix_user) { + results.insert( + database_name.clone(), + Err(CreateDatabaseError::OwnershipError(err)), + ); + continue; + } + + match unsafe_database_exists(&database_name, &mut *connection).await { + Ok(true) => { + results.insert( + database_name.clone(), + Err(CreateDatabaseError::DatabaseAlreadyExists), + ); + continue; + } + Err(err) => { + results.insert( + database_name.clone(), + Err(CreateDatabaseError::MySqlError(err.to_string())), + ); + continue; + } + _ => {} + } + + let result = + sqlx::query(format!("CREATE DATABASE {}", quote_identifier(&database_name)).as_str()) + .execute(&mut *connection) + .await + .map(|_| ()) + .map_err(|err| CreateDatabaseError::MySqlError(err.to_string())); + + results.insert(database_name, result); + } + + results +} + +pub type DropDatabasesOutput = BTreeMap>; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum DropDatabaseError { + SanitizationError(NameValidationError), + OwnershipError(OwnerValidationError), + DatabaseDoesNotExist, + MySqlError(String), +} + +pub async fn drop_databases( + database_names: Vec, + unix_user: &UnixUser, + connection: &mut MySqlConnection, +) -> DropDatabasesOutput { + let mut results = BTreeMap::new(); + + for database_name in database_names { + if let Err(err) = validate_name(&database_name) { + results.insert( + database_name.clone(), + Err(DropDatabaseError::SanitizationError(err)), + ); + continue; + } + + if let Err(err) = validate_ownership_by_unix_user(&database_name, unix_user) { + results.insert( + database_name.clone(), + Err(DropDatabaseError::OwnershipError(err)), + ); + continue; + } + + match unsafe_database_exists(&database_name, &mut *connection).await { + Ok(false) => { + results.insert( + database_name.clone(), + Err(DropDatabaseError::DatabaseDoesNotExist), + ); + continue; + } + Err(err) => { + results.insert( + database_name.clone(), + Err(DropDatabaseError::MySqlError(err.to_string())), + ); + continue; + } + _ => {} + } + + let result = + sqlx::query(format!("DROP DATABASE {}", quote_identifier(&database_name)).as_str()) + .execute(&mut *connection) + .await + .map(|_| ()) + .map_err(|err| DropDatabaseError::MySqlError(err.to_string())); + + results.insert(database_name, result); + } + + results +} + +pub type ListDatabasesOutput = Result, ListDatabasesError>; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum ListDatabasesError { + MySqlError(String), +} + +pub async fn list_databases_for_user( + unix_user: &UnixUser, + connection: &mut MySqlConnection, +) -> Result, ListDatabasesError> { + sqlx::query( + r#" + SELECT `SCHEMA_NAME` AS `database` + FROM `information_schema`.`SCHEMATA` + WHERE `SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys') + AND `SCHEMA_NAME` REGEXP ? + "#, + ) + .bind(create_user_group_matching_regex(unix_user)) + .fetch_all(connection) + .await + .and_then(|rows| { + rows.into_iter() + .map(|row| row.try_get::("database")) + .collect::, sqlx::Error>>() + }) + .map_err(|err| ListDatabasesError::MySqlError(err.to_string())) +} \ No newline at end of file diff --git a/src/server/database_privilege_operations.rs b/src/server/database_privilege_operations.rs new file mode 100644 index 0000000..be21078 --- /dev/null +++ b/src/server/database_privilege_operations.rs @@ -0,0 +1,888 @@ +//! Database privilege operations +//! +//! This module contains functions for querying, modifying, +//! displaying and comparing database privileges. +//! +//! A lot of the complexity comes from two core components: +//! +//! - The privilege editor that needs to be able to print +//! an editable table of privileges and reparse the content +//! after the user has made manual changes. +//! +//! - The comparison functionality that tells the user what +//! changes will be made when applying a set of changes +//! to the list of database privileges. + +use std::collections::{BTreeSet, HashMap}; + +use anyhow::{anyhow, Context}; +use indoc::indoc; +use itertools::Itertools; +use prettytable::Table; +use serde::{Deserialize, Serialize}; +use sqlx::{mysql::MySqlRow, prelude::*, MySqlConnection}; + +use crate::core::{ + common::{ + create_user_group_matching_regex, get_current_unix_user, quote_identifier, rev_yn, yn, + }, + database_operations::validate_database_name, +}; + +/// This is the list of fields that are used to fetch the db + user + privileges +/// from the `db` table in the database. If you need to add or remove privilege +/// fields, this is a good place to start. +pub const DATABASE_PRIVILEGE_FIELDS: [&str; 13] = [ + "db", + "user", + "select_priv", + "insert_priv", + "update_priv", + "delete_priv", + "create_priv", + "drop_priv", + "alter_priv", + "index_priv", + "create_tmp_table_priv", + "lock_tables_priv", + "references_priv", +]; + +pub fn db_priv_field_human_readable_name(name: &str) -> String { + match name { + "db" => "Database".to_owned(), + "user" => "User".to_owned(), + "select_priv" => "Select".to_owned(), + "insert_priv" => "Insert".to_owned(), + "update_priv" => "Update".to_owned(), + "delete_priv" => "Delete".to_owned(), + "create_priv" => "Create".to_owned(), + "drop_priv" => "Drop".to_owned(), + "alter_priv" => "Alter".to_owned(), + "index_priv" => "Index".to_owned(), + "create_tmp_table_priv" => "Temp".to_owned(), + "lock_tables_priv" => "Lock".to_owned(), + "references_priv" => "References".to_owned(), + _ => format!("Unknown({})", name), + } +} + +// NOTE: ord is needed for BTreeSet to accept the type, but it +// doesn't have any natural implementation semantics. + +/// This struct represents the set of privileges for a single user on a single database. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] +pub struct DatabasePrivilegeRow { + pub db: String, + pub user: String, + pub select_priv: bool, + pub insert_priv: bool, + pub update_priv: bool, + pub delete_priv: bool, + pub create_priv: bool, + pub drop_priv: bool, + pub alter_priv: bool, + pub index_priv: bool, + pub create_tmp_table_priv: bool, + pub lock_tables_priv: bool, + pub references_priv: bool, +} + +impl DatabasePrivilegeRow { + pub fn empty(db: &str, user: &str) -> Self { + Self { + db: db.to_owned(), + user: user.to_owned(), + select_priv: false, + insert_priv: false, + update_priv: false, + delete_priv: false, + create_priv: false, + drop_priv: false, + alter_priv: false, + index_priv: false, + create_tmp_table_priv: false, + lock_tables_priv: false, + references_priv: false, + } + } + + pub fn get_privilege_by_name(&self, name: &str) -> bool { + match name { + "select_priv" => self.select_priv, + "insert_priv" => self.insert_priv, + "update_priv" => self.update_priv, + "delete_priv" => self.delete_priv, + "create_priv" => self.create_priv, + "drop_priv" => self.drop_priv, + "alter_priv" => self.alter_priv, + "index_priv" => self.index_priv, + "create_tmp_table_priv" => self.create_tmp_table_priv, + "lock_tables_priv" => self.lock_tables_priv, + "references_priv" => self.references_priv, + _ => false, + } + } + + pub fn diff(&self, other: &DatabasePrivilegeRow) -> DatabasePrivilegeRowDiff { + debug_assert!(self.db == other.db && self.user == other.user); + + DatabasePrivilegeRowDiff { + db: self.db.clone(), + user: self.user.clone(), + diff: DATABASE_PRIVILEGE_FIELDS + .into_iter() + .skip(2) + .filter_map(|field| { + DatabasePrivilegeChange::new( + self.get_privilege_by_name(field), + other.get_privilege_by_name(field), + field, + ) + }) + .collect(), + } + } +} + +#[inline] +fn get_mysql_row_priv_field(row: &MySqlRow, position: usize) -> Result { + let field = DATABASE_PRIVILEGE_FIELDS[position]; + let value = row.try_get(position)?; + match rev_yn(value) { + Some(val) => Ok(val), + _ => { + log::warn!(r#"Invalid value for privilege "{}": '{}'"#, field, value); + Ok(false) + } + } +} + +impl FromRow<'_, MySqlRow> for DatabasePrivilegeRow { + fn from_row(row: &MySqlRow) -> Result { + Ok(Self { + db: row.try_get("db")?, + user: row.try_get("user")?, + select_priv: get_mysql_row_priv_field(row, 2)?, + insert_priv: get_mysql_row_priv_field(row, 3)?, + update_priv: get_mysql_row_priv_field(row, 4)?, + delete_priv: get_mysql_row_priv_field(row, 5)?, + create_priv: get_mysql_row_priv_field(row, 6)?, + drop_priv: get_mysql_row_priv_field(row, 7)?, + alter_priv: get_mysql_row_priv_field(row, 8)?, + index_priv: get_mysql_row_priv_field(row, 9)?, + create_tmp_table_priv: get_mysql_row_priv_field(row, 10)?, + lock_tables_priv: get_mysql_row_priv_field(row, 11)?, + references_priv: get_mysql_row_priv_field(row, 12)?, + }) + } +} + +/// Get all users + privileges for a single database. +pub async fn get_database_privileges( + database_name: &str, + connection: &mut MySqlConnection, +) -> anyhow::Result> { + let unix_user = get_current_unix_user()?; + validate_database_name(database_name, &unix_user)?; + + let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!( + "SELECT {} FROM `db` WHERE `db` = ?", + DATABASE_PRIVILEGE_FIELDS + .iter() + .map(|field| quote_identifier(field)) + .join(","), + )) + .bind(database_name) + .fetch_all(connection) + .await + .context("Failed to show database")?; + + Ok(result) +} + +/// Get all database + user + privileges pairs that are owned by the current user. +pub async fn get_all_database_privileges( + connection: &mut MySqlConnection, +) -> anyhow::Result> { + let unix_user = get_current_unix_user()?; + + let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!( + indoc! {r#" + SELECT {} FROM `db` WHERE `db` IN + (SELECT DISTINCT `SCHEMA_NAME` AS `database` + FROM `information_schema`.`SCHEMATA` + WHERE `SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys') + AND `SCHEMA_NAME` REGEXP ?) + "#}, + DATABASE_PRIVILEGE_FIELDS + .iter() + .map(|field| format!("`{field}`")) + .join(","), + )) + .bind(create_user_group_matching_regex(&unix_user)) + .fetch_all(connection) + .await + .context("Failed to show databases")?; + + Ok(result) +} + +/*************************/ +/* CLI INTERFACE PARSING */ +/*************************/ + +/// See documentation for [`DatabaseCommand::EditDbPrivs`]. +pub fn parse_privilege_table_cli_arg(arg: &str) -> anyhow::Result { + let parts: Vec<&str> = arg.split(':').collect(); + if parts.len() != 3 { + anyhow::bail!("Invalid argument format. See `edit-db-privs --help` for more information."); + } + + let db = parts[0].to_string(); + let user = parts[1].to_string(); + let privs = parts[2].to_string(); + + let mut result = DatabasePrivilegeRow { + db, + user, + select_priv: false, + insert_priv: false, + update_priv: false, + delete_priv: false, + create_priv: false, + drop_priv: false, + alter_priv: false, + index_priv: false, + create_tmp_table_priv: false, + lock_tables_priv: false, + references_priv: false, + }; + + for char in privs.chars() { + match char { + 's' => result.select_priv = true, + 'i' => result.insert_priv = true, + 'u' => result.update_priv = true, + 'd' => result.delete_priv = true, + 'c' => result.create_priv = true, + 'D' => result.drop_priv = true, + 'a' => result.alter_priv = true, + 'I' => result.index_priv = true, + 't' => result.create_tmp_table_priv = true, + 'l' => result.lock_tables_priv = true, + 'r' => result.references_priv = true, + 'A' => { + result.select_priv = true; + result.insert_priv = true; + result.update_priv = true; + result.delete_priv = true; + result.create_priv = true; + result.drop_priv = true; + result.alter_priv = true; + result.index_priv = true; + result.create_tmp_table_priv = true; + result.lock_tables_priv = true; + result.references_priv = true; + } + _ => anyhow::bail!("Invalid privilege character: {}", char), + } + } + + Ok(result) +} + +/**********************************/ +/* EDITOR CONTENT DISPLAY/DISPLAY */ +/**********************************/ + +/// Generates a single row of the privileges table for the editor. +pub fn format_privileges_line_for_editor( + privs: &DatabasePrivilegeRow, + username_len: usize, + database_name_len: usize, +) -> String { + DATABASE_PRIVILEGE_FIELDS + .into_iter() + .map(|field| match field { + "db" => format!("{:width$}", privs.db, width = database_name_len), + "user" => format!("{:width$}", privs.user, width = username_len), + privilege => format!( + "{:width$}", + yn(privs.get_privilege_by_name(privilege)), + width = db_priv_field_human_readable_name(privilege).len() + ), + }) + .join(" ") + .trim() + .to_string() +} + +const EDITOR_COMMENT: &str = r#" +# Welcome to the privilege editor. +# Each line defines what privileges a single user has on a single database. +# The first two columns respectively represent the database name and the user, and the remaining columns are the privileges. +# If the user should have a certain privilege, write 'Y', otherwise write 'N'. +# +# Lines starting with '#' are comments and will be ignored. +"#; + +/// Generates the content for the privilege editor. +/// +/// The unix user is used in case there are no privileges to edit, +/// so that the user can see an example line based on their username. +pub fn generate_editor_content_from_privilege_data( + privilege_data: &[DatabasePrivilegeRow], + unix_user: &str, +) -> String { + let example_user = format!("{}_user", unix_user); + let example_db = format!("{}_db", unix_user); + + // NOTE: `.max()`` fails when the iterator is empty. + // In this case, we know that the only fields in the + // editor will be the example user and example db name. + // Hence, it's put as the fallback value, despite not really + // being a "fallback" in the normal sense. + let longest_username = privilege_data + .iter() + .map(|p| p.user.len()) + .max() + .unwrap_or(example_user.len()); + + let longest_database_name = privilege_data + .iter() + .map(|p| p.db.len()) + .max() + .unwrap_or(example_db.len()); + + let mut header: Vec<_> = DATABASE_PRIVILEGE_FIELDS + .into_iter() + .map(db_priv_field_human_readable_name) + .collect(); + + // Pad the first two columns with spaces to align the privileges. + header[0] = format!("{:width$}", header[0], width = longest_database_name); + header[1] = format!("{:width$}", header[1], width = longest_username); + + let example_line = format_privileges_line_for_editor( + &DatabasePrivilegeRow { + db: example_db, + user: example_user, + select_priv: true, + insert_priv: true, + update_priv: true, + delete_priv: true, + create_priv: false, + drop_priv: false, + alter_priv: false, + index_priv: false, + create_tmp_table_priv: false, + lock_tables_priv: false, + references_priv: false, + }, + longest_username, + longest_database_name, + ); + + format!( + "{}\n{}\n{}", + EDITOR_COMMENT, + header.join(" "), + if privilege_data.is_empty() { + format!("# {}", example_line) + } else { + privilege_data + .iter() + .map(|privs| { + format_privileges_line_for_editor( + privs, + longest_username, + longest_database_name, + ) + }) + .join("\n") + } + ) +} + +#[derive(Debug)] +enum PrivilegeRowParseResult { + PrivilegeRow(DatabasePrivilegeRow), + ParserError(anyhow::Error), + TooFewFields(usize), + TooManyFields(usize), + Header, + Comment, + Empty, +} + +#[inline] +fn parse_privilege_cell_from_editor(yn: &str, name: &str) -> anyhow::Result { + rev_yn(yn) + .ok_or_else(|| anyhow!("Expected Y or N, found {}", yn)) + .context(format!("Could not parse {} privilege", name)) +} + +#[inline] +fn editor_row_is_header(row: &str) -> bool { + row.split_ascii_whitespace() + .zip(DATABASE_PRIVILEGE_FIELDS.iter()) + .map(|(field, priv_name)| (field, db_priv_field_human_readable_name(priv_name))) + .all(|(field, header_field)| field == header_field) +} + +/// Parse a single row of the privileges table from the editor. +fn parse_privilege_row_from_editor(row: &str) -> PrivilegeRowParseResult { + if row.starts_with('#') || row.starts_with("//") { + return PrivilegeRowParseResult::Comment; + } + + if row.trim().is_empty() { + return PrivilegeRowParseResult::Empty; + } + + let parts: Vec<&str> = row.trim().split_ascii_whitespace().collect(); + + match parts.len() { + n if (n < DATABASE_PRIVILEGE_FIELDS.len()) => { + return PrivilegeRowParseResult::TooFewFields(n) + } + n if (n > DATABASE_PRIVILEGE_FIELDS.len()) => { + return PrivilegeRowParseResult::TooManyFields(n) + } + _ => {} + } + + if editor_row_is_header(row) { + return PrivilegeRowParseResult::Header; + } + + let row = DatabasePrivilegeRow { + db: (*parts.first().unwrap()).to_owned(), + user: (*parts.get(1).unwrap()).to_owned(), + select_priv: match parse_privilege_cell_from_editor( + parts.get(2).unwrap(), + DATABASE_PRIVILEGE_FIELDS[2], + ) { + Ok(p) => p, + Err(e) => return PrivilegeRowParseResult::ParserError(e), + }, + insert_priv: match parse_privilege_cell_from_editor( + parts.get(3).unwrap(), + DATABASE_PRIVILEGE_FIELDS[3], + ) { + Ok(p) => p, + Err(e) => return PrivilegeRowParseResult::ParserError(e), + }, + update_priv: match parse_privilege_cell_from_editor( + parts.get(4).unwrap(), + DATABASE_PRIVILEGE_FIELDS[4], + ) { + Ok(p) => p, + Err(e) => return PrivilegeRowParseResult::ParserError(e), + }, + delete_priv: match parse_privilege_cell_from_editor( + parts.get(5).unwrap(), + DATABASE_PRIVILEGE_FIELDS[5], + ) { + Ok(p) => p, + Err(e) => return PrivilegeRowParseResult::ParserError(e), + }, + create_priv: match parse_privilege_cell_from_editor( + parts.get(6).unwrap(), + DATABASE_PRIVILEGE_FIELDS[6], + ) { + Ok(p) => p, + Err(e) => return PrivilegeRowParseResult::ParserError(e), + }, + drop_priv: match parse_privilege_cell_from_editor( + parts.get(7).unwrap(), + DATABASE_PRIVILEGE_FIELDS[7], + ) { + Ok(p) => p, + Err(e) => return PrivilegeRowParseResult::ParserError(e), + }, + alter_priv: match parse_privilege_cell_from_editor( + parts.get(8).unwrap(), + DATABASE_PRIVILEGE_FIELDS[8], + ) { + Ok(p) => p, + Err(e) => return PrivilegeRowParseResult::ParserError(e), + }, + index_priv: match parse_privilege_cell_from_editor( + parts.get(9).unwrap(), + DATABASE_PRIVILEGE_FIELDS[9], + ) { + Ok(p) => p, + Err(e) => return PrivilegeRowParseResult::ParserError(e), + }, + create_tmp_table_priv: match parse_privilege_cell_from_editor( + parts.get(10).unwrap(), + DATABASE_PRIVILEGE_FIELDS[10], + ) { + Ok(p) => p, + Err(e) => return PrivilegeRowParseResult::ParserError(e), + }, + lock_tables_priv: match parse_privilege_cell_from_editor( + parts.get(11).unwrap(), + DATABASE_PRIVILEGE_FIELDS[11], + ) { + Ok(p) => p, + Err(e) => return PrivilegeRowParseResult::ParserError(e), + }, + references_priv: match parse_privilege_cell_from_editor( + parts.get(12).unwrap(), + DATABASE_PRIVILEGE_FIELDS[12], + ) { + Ok(p) => p, + Err(e) => return PrivilegeRowParseResult::ParserError(e), + }, + }; + + PrivilegeRowParseResult::PrivilegeRow(row) +} + +// TODO: return better errors + +pub fn parse_privilege_data_from_editor_content( + content: String, +) -> anyhow::Result> { + content + .trim() + .split('\n') + .map(|line| line.trim()) + .map(parse_privilege_row_from_editor) + .map(|result| match result { + PrivilegeRowParseResult::PrivilegeRow(row) => Ok(Some(row)), + PrivilegeRowParseResult::ParserError(e) => Err(e), + PrivilegeRowParseResult::TooFewFields(n) => Err(anyhow!( + "Too few fields in line. Expected to find {} fields, found {}", + DATABASE_PRIVILEGE_FIELDS.len(), + n + )), + PrivilegeRowParseResult::TooManyFields(n) => Err(anyhow!( + "Too many fields in line. Expected to find {} fields, found {}", + DATABASE_PRIVILEGE_FIELDS.len(), + n + )), + PrivilegeRowParseResult::Header => Ok(None), + PrivilegeRowParseResult::Comment => Ok(None), + PrivilegeRowParseResult::Empty => Ok(None), + }) + .filter_map(|result| result.transpose()) + .collect::>>() +} + +/*****************************/ +/* CALCULATE PRIVILEGE DIFFS */ +/*****************************/ + +/// This struct represents encapsulates the differences between two +/// instances of privilege sets for a single user on a single database. +/// +/// The `User` and `Database` are the same for both instances. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, PartialOrd, Ord)] +pub struct DatabasePrivilegeRowDiff { + pub db: String, + pub user: String, + pub diff: BTreeSet, +} + +/// This enum represents a change for a single privilege. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, PartialOrd, Ord)] +pub enum DatabasePrivilegeChange { + YesToNo(String), + NoToYes(String), +} + +impl DatabasePrivilegeChange { + pub fn new(p1: bool, p2: bool, name: &str) -> Option { + match (p1, p2) { + (true, false) => Some(DatabasePrivilegeChange::YesToNo(name.to_owned())), + (false, true) => Some(DatabasePrivilegeChange::NoToYes(name.to_owned())), + _ => None, + } + } +} + +/// This enum encapsulates whether a [`DatabasePrivilegeRow`] was intrduced, modified or deleted. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, PartialOrd, Ord)] +pub enum DatabasePrivilegesDiff { + New(DatabasePrivilegeRow), + Modified(DatabasePrivilegeRowDiff), + Deleted(DatabasePrivilegeRow), +} + +/// This function calculates the differences between two sets of database privileges. +/// It returns a set of [`DatabasePrivilegesDiff`] that can be used to display or +/// apply a set of privilege modifications to the database. +pub fn diff_privileges( + from: &[DatabasePrivilegeRow], + to: &[DatabasePrivilegeRow], +) -> BTreeSet { + let from_lookup_table: HashMap<(String, String), DatabasePrivilegeRow> = HashMap::from_iter( + from.iter() + .cloned() + .map(|p| ((p.db.clone(), p.user.clone()), p)), + ); + + let to_lookup_table: HashMap<(String, String), DatabasePrivilegeRow> = HashMap::from_iter( + to.iter() + .cloned() + .map(|p| ((p.db.clone(), p.user.clone()), p)), + ); + + let mut result = BTreeSet::new(); + + for p in to { + if let Some(old_p) = from_lookup_table.get(&(p.db.clone(), p.user.clone())) { + let diff = old_p.diff(p); + if !diff.diff.is_empty() { + result.insert(DatabasePrivilegesDiff::Modified(diff)); + } + } else { + result.insert(DatabasePrivilegesDiff::New(p.clone())); + } + } + + for p in from { + if !to_lookup_table.contains_key(&(p.db.clone(), p.user.clone())) { + result.insert(DatabasePrivilegesDiff::Deleted(p.clone())); + } + } + + result +} + +/// Uses the result of [`diff_privileges`] to modify privileges in the database. +pub async fn apply_privilege_diffs( + diffs: BTreeSet, + connection: &mut MySqlConnection, +) -> anyhow::Result<()> { + for diff in diffs { + match diff { + DatabasePrivilegesDiff::New(p) => { + let tables = DATABASE_PRIVILEGE_FIELDS + .iter() + .map(|field| format!("`{field}`")) + .join(","); + + let question_marks = std::iter::repeat("?") + .take(DATABASE_PRIVILEGE_FIELDS.len()) + .join(","); + + sqlx::query( + format!("INSERT INTO `db` ({}) VALUES ({})", tables, question_marks).as_str(), + ) + .bind(p.db) + .bind(p.user) + .bind(yn(p.select_priv)) + .bind(yn(p.insert_priv)) + .bind(yn(p.update_priv)) + .bind(yn(p.delete_priv)) + .bind(yn(p.create_priv)) + .bind(yn(p.drop_priv)) + .bind(yn(p.alter_priv)) + .bind(yn(p.index_priv)) + .bind(yn(p.create_tmp_table_priv)) + .bind(yn(p.lock_tables_priv)) + .bind(yn(p.references_priv)) + .execute(&mut *connection) + .await?; + } + DatabasePrivilegesDiff::Modified(p) => { + let tables = p + .diff + .iter() + .map(|diff| match diff { + DatabasePrivilegeChange::YesToNo(name) => format!("`{}` = 'N'", name), + DatabasePrivilegeChange::NoToYes(name) => format!("`{}` = 'Y'", name), + }) + .join(","); + + sqlx::query( + format!("UPDATE `db` SET {} WHERE `db` = ? AND `user` = ?", tables).as_str(), + ) + .bind(p.db) + .bind(p.user) + .execute(&mut *connection) + .await?; + } + DatabasePrivilegesDiff::Deleted(p) => { + sqlx::query("DELETE FROM `db` WHERE `db` = ? AND `user` = ?") + .bind(p.db) + .bind(p.user) + .execute(&mut *connection) + .await?; + } + } + } + Ok(()) +} + +fn display_privilege_cell(diff: &DatabasePrivilegeRowDiff) -> String { + diff.diff + .iter() + .map(|change| match change { + DatabasePrivilegeChange::YesToNo(name) => { + format!("{}: Y -> N", db_priv_field_human_readable_name(name)) + } + DatabasePrivilegeChange::NoToYes(name) => { + format!("{}: N -> Y", db_priv_field_human_readable_name(name)) + } + }) + .join("\n") +} + +/// Displays the difference between two sets of database privileges. +pub fn display_privilege_diffs(diffs: &BTreeSet) -> String { + let mut table = Table::new(); + table.set_titles(row!["Database", "User", "Privilege diff",]); + for row in diffs { + match row { + DatabasePrivilegesDiff::New(p) => { + table.add_row(row![ + p.db, + p.user, + "(New user)\n".to_string() + + &display_privilege_cell( + &DatabasePrivilegeRow::empty(&p.db, &p.user).diff(p) + ) + ]); + } + DatabasePrivilegesDiff::Modified(p) => { + table.add_row(row![p.db, p.user, display_privilege_cell(p),]); + } + DatabasePrivilegesDiff::Deleted(p) => { + table.add_row(row![ + p.db, + p.user, + "(All privileges removed)\n".to_string() + + &display_privilege_cell( + &p.diff(&DatabasePrivilegeRow::empty(&p.db, &p.user)) + ) + ]); + } + } + } + + table.to_string() +} + +/*********/ +/* TESTS */ +/*********/ + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_database_privilege_change_creation() { + assert_eq!( + DatabasePrivilegeChange::new(true, false, "test"), + Some(DatabasePrivilegeChange::YesToNo("test".to_owned())) + ); + assert_eq!( + DatabasePrivilegeChange::new(false, true, "test"), + Some(DatabasePrivilegeChange::NoToYes("test".to_owned())) + ); + assert_eq!(DatabasePrivilegeChange::new(true, true, "test"), None); + assert_eq!(DatabasePrivilegeChange::new(false, false, "test"), None); + } + + #[test] + fn test_diff_privileges() { + let row_to_be_modified = DatabasePrivilegeRow { + db: "db".to_owned(), + user: "user".to_owned(), + select_priv: true, + insert_priv: true, + update_priv: true, + delete_priv: true, + create_priv: true, + drop_priv: true, + alter_priv: true, + index_priv: false, + create_tmp_table_priv: true, + lock_tables_priv: true, + references_priv: false, + }; + + let mut row_to_be_deleted = row_to_be_modified.clone(); + "user2".clone_into(&mut row_to_be_deleted.user); + + let from = vec![row_to_be_modified.clone(), row_to_be_deleted.clone()]; + + let mut modified_row = row_to_be_modified.clone(); + modified_row.select_priv = false; + modified_row.insert_priv = false; + modified_row.index_priv = true; + + let mut new_row = row_to_be_modified.clone(); + "user3".clone_into(&mut new_row.user); + + let to = vec![modified_row.clone(), new_row.clone()]; + + let diffs = diff_privileges(&from, &to); + + assert_eq!( + diffs, + BTreeSet::from_iter(vec![ + DatabasePrivilegesDiff::Deleted(row_to_be_deleted), + DatabasePrivilegesDiff::Modified(DatabasePrivilegeRowDiff { + db: "db".to_owned(), + user: "user".to_owned(), + diff: BTreeSet::from_iter(vec![ + DatabasePrivilegeChange::YesToNo("select_priv".to_owned()), + DatabasePrivilegeChange::YesToNo("insert_priv".to_owned()), + DatabasePrivilegeChange::NoToYes("index_priv".to_owned()), + ]), + }), + DatabasePrivilegesDiff::New(new_row), + ]) + ); + } + + #[test] + fn ensure_generated_and_parsed_editor_content_is_equal() { + let permissions = vec![ + DatabasePrivilegeRow { + db: "db".to_owned(), + user: "user".to_owned(), + select_priv: true, + insert_priv: true, + update_priv: true, + delete_priv: true, + create_priv: true, + drop_priv: true, + alter_priv: true, + index_priv: true, + create_tmp_table_priv: true, + lock_tables_priv: true, + references_priv: true, + }, + DatabasePrivilegeRow { + db: "db2".to_owned(), + user: "user2".to_owned(), + select_priv: false, + insert_priv: false, + update_priv: false, + delete_priv: false, + create_priv: false, + drop_priv: false, + alter_priv: false, + index_priv: false, + create_tmp_table_priv: false, + lock_tables_priv: false, + references_priv: false, + }, + ]; + + let content = generate_editor_content_from_privilege_data(&permissions, "user"); + + let parsed_permissions = parse_privilege_data_from_editor_content(content).unwrap(); + + assert_eq!(permissions, parsed_permissions); + } +} diff --git a/src/server/entrypoint.rs b/src/server/entrypoint.rs new file mode 100644 index 0000000..309e768 --- /dev/null +++ b/src/server/entrypoint.rs @@ -0,0 +1,89 @@ +use futures_util::{SinkExt, StreamExt}; +use sqlx::MySqlConnection; +use tokio::net::UnixStream; +use tokio_serde::{formats::Bincode, Framed as SerdeFramed}; +use tokio_util::codec::{Framed, LengthDelimitedCodec}; + +// use crate::server:: + +use crate::server::protocol::{Request, Response}; + +use super::{ + common::UnixUser, database_operations::{create_databases, drop_databases}, user_operations::{create_database_users, drop_database_users, set_password_for_database_user} +}; + +pub type ClientToServerMessageStream<'a> = SerdeFramed< + Framed<&'a mut UnixStream, LengthDelimitedCodec>, + Request, + Response, + Bincode, +>; + +pub async fn run_server( + socket: &mut UnixStream, + unix_user: &UnixUser, + db_connection: &mut MySqlConnection, +) -> anyhow::Result<()> { + let length_delimited = Framed::new(socket, LengthDelimitedCodec::new()); + let mut stream: ClientToServerMessageStream = + tokio_serde::Framed::new(length_delimited, Bincode::default()); + + // TODO: better error handling + let request = match stream.next().await { + Some(Ok(request)) => request, + Some(Err(e)) => return Err(e.into()), + None => return Err(anyhow::anyhow!("No request received")), + }; + + match request { + Request::CreateDatabases(databases) => { + let result = create_databases(databases, unix_user, db_connection).await; + stream.send(Response::CreateDatabases(result)).await?; + stream.flush().await?; + } + Request::DropDatabases(databases) => { + let result = drop_databases(databases, unix_user, db_connection).await; + stream.send(Response::DropDatabases(result)).await?; + stream.flush().await?; + } + Request::ListDatabases => { + println!("Listing databases"); + // let result = list_databases(unix_user, db_connection).await; + // stream.send(Response::ListDatabases(result)).await?; + // stream.flush().await?; + } + Request::ListPrivileges(users) => { + println!("Listing privileges for users: {:?}", users); + } + Request::ModifyPrivileges(()) => { + println!("Modifying privileges"); + } + Request::CreateUsers(db_users) => { + let result = create_database_users(db_users, unix_user, db_connection).await; + stream.send(Response::CreateUsers(result)).await?; + stream.flush().await?; + } + Request::DropUsers(db_users) => { + let result = drop_database_users(db_users, unix_user, db_connection).await; + stream.send(Response::DropUsers(result)).await?; + stream.flush().await?; + } + Request::PasswdUser(db_user, password) => { + let result = + set_password_for_database_user(&db_user, &password, unix_user, db_connection).await; + stream.send(Response::PasswdUser(result)).await?; + stream.flush().await?; + } + Request::ListUsers(db_users) => { + println!("Listing users: {:?}", db_users); + } + Request::LockUsers(db_users) => { + println!("Locking users: {:?}", db_users); + } + Request::UnlockUsers(db_users) => { + println!("Unlocking users: {:?}", db_users); + } + } + + Ok(()) +} diff --git a/src/server/input_sanitization.rs b/src/server/input_sanitization.rs new file mode 100644 index 0000000..d06eff5 --- /dev/null +++ b/src/server/input_sanitization.rs @@ -0,0 +1,251 @@ +use super::common::UnixUser; + +use serde::{Deserialize, Serialize}; + +const MAX_NAME_LENGTH: usize = 64; + +#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)] +pub enum NameValidationError { + EmptyString, + InvalidCharacters, + TooLong, +} + +pub fn validate_name(name: &str) -> Result<(), NameValidationError> { + if name.is_empty() { + Err(NameValidationError::EmptyString) + } else if name.len() > MAX_NAME_LENGTH { + Err(NameValidationError::TooLong) + } else if !name + .chars() + .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-') + { + Err(NameValidationError::InvalidCharacters) + } else { + Ok(()) + } +} + +// TODO: move to cli + +// pub fn validate_name_or_error(name: &str, db_or_user: DbOrUser) -> anyhow::Result<()> { +// match validate_name(name) { +// NameValidationError::Valid => Ok(()), +// NameValidationError::EmptyString => { +// anyhow::bail!("{} name cannot be empty.", db_or_user.capitalized()) +// } +// NameValidationError::TooLong => anyhow::bail!( +// "{} is too long. Maximum length is 64 characters.", +// db_or_user.capitalized() +// ), +// NameValidationError::InvalidCharacters => anyhow::bail!( +// indoc! {r#" +// Invalid characters in {} name: '{}' + +// Only A-Z, a-z, 0-9, _ (underscore) and - (dash) are permitted. +// "#}, +// db_or_user.lowercased(), +// name +// ), +// } +// } + +#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)] +pub enum OwnerValidationError { + // The name is valid, but none of the given prefixes matched the name + NoMatch, + + // The name is empty, which is invalid + StringEmpty, + + // The name is in the format "_", which is invalid + MissingPrefix, + + // The name is in the format "_", which is invalid + MissingPostfix, +} + +pub fn validate_ownership_by_unix_user( + name: &str, + user: &UnixUser, +) -> Result<(), OwnerValidationError> { + let prefixes = std::iter::once(user.username.clone()) + .chain(user.groups.iter().cloned()) + .collect::>(); + + validate_ownership_by_prefixes(name, &prefixes) +} + +/// Core logic for validating the ownership of a database name. +/// This function checks if the given name matches any of the given prefixes. +/// These prefixes will in most cases be the user's unix username and any +/// unix groups the user is a member of. +pub fn validate_ownership_by_prefixes( + name: &str, + prefixes: &[String], +) -> Result<(), OwnerValidationError> { + if name.is_empty() { + return Err(OwnerValidationError::StringEmpty); + } + + if name.starts_with('_') { + return Err(OwnerValidationError::MissingPrefix); + } + + let (prefix, _) = match name.split_once('_') { + Some(pair) => pair, + None => return Err(OwnerValidationError::MissingPostfix), + }; + + if !prefixes.iter().any(|g| g == prefix) { + return Err(OwnerValidationError::NoMatch); + } + + Ok(()) +} + +// TODO: move to cli + +/// Validate the ownership of a database name or database user name. +/// This function takes the name of a database or user and a unix user, +/// for which it fetches the user's groups. It then checks if the name +/// is prefixed with the user's username or any of the user's groups. +// pub fn validate_ownership_or_error<'a>( +// name: &'a str, +// user: &User, +// db_or_user: DbOrUser, +// ) -> anyhow::Result<&'a str> { +// let user_groups = get_unix_groups(user)?; +// let prefixes = std::iter::once(user.name.clone()) +// .chain(user_groups.iter().map(|g| g.name.clone())) +// .collect::>(); + +// match validate_ownership_by_prefixes(name, &prefixes) { +// OwnerValidationResult::Match => Ok(name), +// OwnerValidationResult::NoMatch => { +// anyhow::bail!( +// indoc! {r#" +// Invalid {} name prefix: '{}' does not match your username or any of your groups. +// Are you sure you are allowed to create {} names with this prefix? + +// Allowed prefixes: +// - {} +// {} +// "#}, +// db_or_user.lowercased(), +// name, +// db_or_user.lowercased(), +// user.name, +// user_groups +// .iter() +// .filter(|g| g.name != user.name) +// .map(|g| format!(" - {}", g.name)) +// .sorted() +// .join("\n"), +// ); +// } +// _ => anyhow::bail!( +// "'{}' is not a valid {} name.", +// name, +// db_or_user.lowercased() +// ), +// } +// } + +#[inline] +pub fn quote_literal(s: &str) -> String { + format!("'{}'", s.replace('\'', r"\'")) +} + +#[inline] +pub fn quote_identifier(s: &str) -> String { + format!("`{}`", s.replace('`', r"\`")) +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_quote_literal() { + let payload = "' OR 1=1 --"; + assert_eq!(quote_literal(payload), r#"'\' OR 1=1 --'"#); + } + + #[test] + fn test_quote_identifier() { + let payload = "` OR 1=1 --"; + assert_eq!(quote_identifier(payload), r#"`\` OR 1=1 --`"#); + } + + #[test] + fn test_validate_name() { + assert_eq!(validate_name(""), Err(NameValidationError::EmptyString)); + assert_eq!(validate_name("abcdefghijklmnopqrstuvwxyz"), Ok(())); + assert_eq!(validate_name("ABCDEFGHIJKLMNOPQRSTUVWXYZ"), Ok(())); + assert_eq!(validate_name("0123456789_-"), Ok(())); + + for c in "\n\t\r !@#$%^&*()+=[]{}|;:,.<>?/".chars() { + assert_eq!( + validate_name(&c.to_string()), + Err(NameValidationError::InvalidCharacters) + ); + } + + assert_eq!(validate_name(&"a".repeat(MAX_NAME_LENGTH)), Ok(())); + + assert_eq!( + validate_name(&"a".repeat(MAX_NAME_LENGTH + 1)), + Err(NameValidationError::TooLong) + ); + } + + #[test] + fn test_validate_owner_by_prefixes() { + let prefixes = vec!["user".to_string(), "group".to_string()]; + + assert_eq!( + validate_ownership_by_prefixes("", &prefixes), + Err(OwnerValidationError::StringEmpty) + ); + + assert_eq!( + validate_ownership_by_prefixes("user", &prefixes), + Err(OwnerValidationError::MissingPostfix) + ); + assert_eq!( + validate_ownership_by_prefixes("something", &prefixes), + Err(OwnerValidationError::MissingPostfix) + ); + assert_eq!( + validate_ownership_by_prefixes("user-testdb", &prefixes), + Err(OwnerValidationError::MissingPostfix) + ); + + assert_eq!( + validate_ownership_by_prefixes("_testdb", &prefixes), + Err(OwnerValidationError::MissingPrefix) + ); + + assert_eq!( + validate_ownership_by_prefixes("user_testdb", &prefixes), + Ok(()) + ); + assert_eq!( + validate_ownership_by_prefixes("group_testdb", &prefixes), + Ok(()) + ); + assert_eq!( + validate_ownership_by_prefixes("group_test_db", &prefixes), + Ok(()) + ); + assert_eq!( + validate_ownership_by_prefixes("group_test-db", &prefixes), + Ok(()) + ); + + assert_eq!( + validate_ownership_by_prefixes("nonexistent_testdb", &prefixes), + Err(OwnerValidationError::NoMatch) + ); + } +} diff --git a/src/server/protocol.rs b/src/server/protocol.rs new file mode 100644 index 0000000..f0faeb0 --- /dev/null +++ b/src/server/protocol.rs @@ -0,0 +1,45 @@ +use serde::{Deserialize, Serialize}; + +use super::{database_operations::{CreateDatabasesOutput, DropDatabasesOutput}, user_operations::{ + CreateUsersOutput, DropUsersOutput, LockUsersOutput, SetPasswordOutput, UnlockUsersOutput, +}}; + +#[non_exhaustive] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum Request { + CreateDatabases(Vec), + DropDatabases(Vec), + ListDatabases, + ListPrivileges(Vec), + ModifyPrivileges(()), // what data should be sent with this command? Who should calculate the diff? + + CreateUsers(Vec), + DropUsers(Vec), + PasswdUser(String, String), + ListUsers(Option>), + LockUsers(Vec), + UnlockUsers(Vec), +} + +// TODO: include a generic "message" that will display a message to the user? + +#[non_exhaustive] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum Response { + // Specific data for specific commands + CreateDatabases(CreateDatabasesOutput), + DropDatabases(DropDatabasesOutput), + // ListDatabases(ListDatabasesOutput), + // ListPrivileges(ListPrivilegesOutput), + CreateUsers(CreateUsersOutput), + DropUsers(DropUsersOutput), + PasswdUser(SetPasswordOutput), + ListUsers(()), // what data should be sent with this response? + LockUsers(LockUsersOutput), + UnlockUsers(UnlockUsersOutput), + + // Generic responses + OperationAborted, + Error(String), + Exit, +} diff --git a/src/server/user_operations.rs b/src/server/user_operations.rs new file mode 100644 index 0000000..51abf6c --- /dev/null +++ b/src/server/user_operations.rs @@ -0,0 +1,395 @@ +use std::collections::BTreeMap; + +use serde::{Deserialize, Serialize}; + +use sqlx::prelude::*; +use sqlx::MySqlConnection; + +use super::common::create_user_group_matching_regex; +use super::common::UnixUser; +use super::input_sanitization::{ + quote_literal, validate_name, validate_ownership_by_unix_user, NameValidationError, + OwnerValidationError, +}; + +// NOTE: this function is unsafe because it does no input validation. +async fn unsafe_user_exists( + db_user: &str, + connection: &mut MySqlConnection, +) -> Result { + sqlx::query( + r#" + SELECT EXISTS( + SELECT 1 + FROM `mysql`.`user` + WHERE `User` = ? + ) + "#, + ) + .bind(db_user) + .fetch_one(connection) + .await + .map(|row| row.get::(0)) +} + +pub type CreateUsersOutput = BTreeMap>; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum CreateUserError { + SanitizationError(NameValidationError), + OwnershipError(OwnerValidationError), + UserAlreadyExists, + MySqlError(String), +} + +pub async fn create_database_users( + db_users: Vec, + unix_user: &UnixUser, + connection: &mut MySqlConnection, +) -> CreateUsersOutput { + let mut results = BTreeMap::new(); + + for db_user in db_users { + if let Err(err) = validate_name(&db_user) { + results.insert(db_user, Err(CreateUserError::SanitizationError(err))); + continue; + } + + if let Err(err) = validate_ownership_by_unix_user(&db_user, unix_user) { + results.insert(db_user, Err(CreateUserError::OwnershipError(err))); + continue; + } + + match unsafe_user_exists(&db_user, &mut *connection).await { + Ok(true) => { + results.insert(db_user, Err(CreateUserError::UserAlreadyExists)); + continue; + } + Err(err) => { + results.insert(db_user, Err(CreateUserError::MySqlError(err.to_string()))); + continue; + } + _ => {} + } + + let result = sqlx::query(format!("CREATE USER {}@'%'", quote_literal(&db_user),).as_str()) + .execute(&mut *connection) + .await + .map(|_| ()) + .map_err(|err| CreateUserError::MySqlError(err.to_string())); + + results.insert(db_user, result); + } + + results +} + +pub type DropUsersOutput = BTreeMap>; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum DropUserError { + SanitizationError(NameValidationError), + OwnershipError(OwnerValidationError), + UserDoesNotExist, + MySqlError(String), +} + +pub async fn drop_database_users( + db_users: Vec, + unix_user: &UnixUser, + connection: &mut MySqlConnection, +) -> DropUsersOutput { + let mut results = BTreeMap::new(); + + for db_user in db_users { + if let Err(err) = validate_name(&db_user) { + results.insert(db_user, Err(DropUserError::SanitizationError(err))); + continue; + } + + if let Err(err) = validate_ownership_by_unix_user(&db_user, unix_user) { + results.insert(db_user, Err(DropUserError::OwnershipError(err))); + continue; + } + + match unsafe_user_exists(&db_user, &mut *connection).await { + Ok(false) => { + results.insert(db_user, Err(DropUserError::UserDoesNotExist)); + continue; + } + Err(err) => { + results.insert(db_user, Err(DropUserError::MySqlError(err.to_string()))); + continue; + } + _ => {} + } + + let result = sqlx::query(format!("DROP USER {}@'%'", quote_literal(&db_user),).as_str()) + .execute(&mut *connection) + .await + .map(|_| ()) + .map_err(|err| DropUserError::MySqlError(err.to_string())); + + results.insert(db_user, result); + } + + results +} + +pub type SetPasswordOutput = Result<(), SetPasswordError>; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum SetPasswordError { + SanitizationError(NameValidationError), + OwnershipError(OwnerValidationError), + UserDoesNotExist, + MySqlError(String), +} + +pub async fn set_password_for_database_user( + db_user: &str, + password: &str, + unix_user: &UnixUser, + connection: &mut MySqlConnection, +) -> SetPasswordOutput { + if let Err(err) = validate_name(db_user) { + return Err(SetPasswordError::SanitizationError(err)); + } + + if let Err(err) = validate_ownership_by_unix_user(db_user, unix_user) { + return Err(SetPasswordError::OwnershipError(err)); + } + + match unsafe_user_exists(db_user, &mut *connection).await { + Ok(false) => return Err(SetPasswordError::UserDoesNotExist), + Err(err) => return Err(SetPasswordError::MySqlError(err.to_string())), + _ => {} + } + + sqlx::query( + format!( + "ALTER USER {}@'%' IDENTIFIED BY {}", + quote_literal(db_user), + quote_literal(password).as_str() + ) + .as_str(), + ) + .execute(&mut *connection) + .await + .map(|_| ()) + .map_err(|err| SetPasswordError::MySqlError(err.to_string())) +} + +// NOTE: this function is unsafe because it does no input validation. +async fn database_user_is_locked_unsafe( + db_user: &str, + connection: &mut MySqlConnection, +) -> Result { + sqlx::query( + r#" + SELECT JSON_EXTRACT(`mysql`.`global_priv`.`priv`, "$.account_locked") = 'true' + FROM `mysql`.`global_priv` + WHERE `User` = ? + AND `Host` = '%' + "#, + ) + .bind(db_user) + .fetch_one(connection) + .await + .map(|row| row.get::(0)) +} + +pub type LockUsersOutput = BTreeMap>; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum LockUserError { + SanitizationError(NameValidationError), + OwnershipError(OwnerValidationError), + UserDoesNotExist, + UserIsAlreadyLocked, + MySqlError(String), +} + +pub async fn lock_database_users( + db_users: Vec, + unix_user: &UnixUser, + connection: &mut MySqlConnection, +) -> LockUsersOutput { + let mut results = BTreeMap::new(); + + for db_user in db_users { + if let Err(err) = validate_name(&db_user) { + results.insert(db_user, Err(LockUserError::SanitizationError(err))); + continue; + } + + if let Err(err) = validate_ownership_by_unix_user(&db_user, unix_user) { + results.insert(db_user, Err(LockUserError::OwnershipError(err))); + continue; + } + + match unsafe_user_exists(&db_user, &mut *connection).await { + Ok(false) => { + results.insert(db_user, Err(LockUserError::UserDoesNotExist)); + continue; + } + Err(err) => { + results.insert(db_user, Err(LockUserError::MySqlError(err.to_string()))); + continue; + } + _ => {} + } + + match database_user_is_locked_unsafe(&db_user, &mut *connection).await { + Ok(true) => { + results.insert(db_user, Err(LockUserError::UserIsAlreadyLocked)); + continue; + } + Err(err) => { + results.insert(db_user, Err(LockUserError::MySqlError(err.to_string()))); + continue; + } + _ => {} + } + + let result = sqlx::query( + format!("ALTER USER {}@'%' ACCOUNT LOCK", quote_literal(&db_user),).as_str(), + ) + .execute(&mut *connection) + .await + .map(|_| ()) + .map_err(|err| LockUserError::MySqlError(err.to_string())); + + results.insert(db_user, result); + } + + results +} + +pub type UnlockUsersOutput = BTreeMap>; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum UnlockUserError { + SanitizationError(NameValidationError), + OwnershipError(OwnerValidationError), + UserDoesNotExist, + UserIsAlreadyUnlocked, + MySqlError(String), +} + +pub async fn unlock_database_user( + db_users: Vec, + unix_user: &UnixUser, + connection: &mut MySqlConnection, +) -> UnlockUsersOutput { + let mut results = BTreeMap::new(); + + for db_user in db_users { + if let Err(err) = validate_name(&db_user) { + results.insert(db_user, Err(UnlockUserError::SanitizationError(err))); + continue; + } + + if let Err(err) = validate_ownership_by_unix_user(&db_user, unix_user) { + results.insert(db_user, Err(UnlockUserError::OwnershipError(err))); + continue; + } + + match unsafe_user_exists(&db_user, &mut *connection).await { + Ok(false) => { + results.insert(db_user, Err(UnlockUserError::UserDoesNotExist)); + continue; + } + Err(err) => { + results.insert(db_user, Err(UnlockUserError::MySqlError(err.to_string()))); + continue; + } + _ => {} + } + + match database_user_is_locked_unsafe(&db_user, &mut *connection).await { + Ok(false) => { + results.insert(db_user, Err(UnlockUserError::UserIsAlreadyUnlocked)); + continue; + } + Err(err) => { + results.insert(db_user, Err(UnlockUserError::MySqlError(err.to_string()))); + continue; + } + _ => {} + } + + let result = sqlx::query( + format!("ALTER USER {}@'%' ACCOUNT UNLOCK", quote_literal(&db_user),).as_str(), + ) + .execute(&mut *connection) + .await + .map(|_| ()) + .map_err(|err| UnlockUserError::MySqlError(err.to_string())); + + results.insert(db_user, result); + } + + results +} + +/// This struct contains information about a database user. +/// This can be extended if we need more information in the future. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)] +pub struct DatabaseUser { + #[sqlx(rename = "User")] + pub user: String, + + #[allow(dead_code)] + #[serde(skip)] + #[sqlx(rename = "Host")] + pub host: String, + + #[sqlx(rename = "has_password")] + pub has_password: bool, + + #[sqlx(rename = "is_locked")] + pub is_locked: bool, +} + +const DB_USER_SELECT_STATEMENT: &str = r#" +SELECT + `mysql`.`user`.`User`, + `mysql`.`user`.`Host`, + `mysql`.`user`.`Password` != '' OR `mysql`.`user`.`authentication_string` != '' AS `has_password`, + COALESCE( + JSON_EXTRACT(`mysql`.`global_priv`.`priv`, "$.account_locked"), + 'false' + ) != 'false' AS `is_locked` +FROM `mysql`.`user` +JOIN `mysql`.`global_priv` ON + `mysql`.`user`.`User` = `mysql`.`global_priv`.`User` + AND `mysql`.`user`.`Host` = `mysql`.`global_priv`.`Host` +"#; + +pub async fn get_all_database_users_for_unix_user( + unix_user: &UnixUser, + connection: &mut MySqlConnection, +) -> Result, sqlx::Error> { + sqlx::query_as::<_, DatabaseUser>( + &(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `mysql`.`user`.`User` REGEXP ?"), + ) + .bind(create_user_group_matching_regex(unix_user)) + .fetch_all(connection) + .await +} + +// /// This function fetches a database user if it exists. +// pub async fn get_database_user_for_user( +// username: &str, +// connection: &mut MySqlConnection, +// ) -> anyhow::Result> { +// let user = sqlx::query_as::<_, DatabaseUser>( +// &(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `mysql`.`user`.`User` = ?"), +// ) +// .bind(username) +// .fetch_optional(connection) +// .await?; + +// Ok(user) +// } + +// /// NOTE: It is very critical that this function validates the database name +// /// properly. MySQL does not seem to allow for prepared statements, binding +// /// the database name as a parameter to the query. This means that we have +// /// to validate the database name ourselves to prevent SQL injection.