diff --git a/Cargo.toml b/Cargo.toml index 85cc300..2a7ff9d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -132,6 +132,11 @@ assets = [ "etc/muscl/config.toml", "644", ], + [ + "assets/debian/group_denylist.txt", + "etc/muscl/group_denylist.txt", + "644", + ], [ "assets/completions/_*", "usr/share/zsh/site-functions/completions/", diff --git a/assets/debian/config.toml b/assets/debian/config.toml index e26c9b9..aafec6f 100644 --- a/assets/debian/config.toml +++ b/assets/debian/config.toml @@ -22,3 +22,6 @@ password_file = "/run/credentials/muscl.service/muscl_mysql_password" # Database connection timeout in seconds timeout = 2 + +[authorization] +group_denylist_file = "/etc/muscl/group_denylist.txt" diff --git a/assets/debian/group_denylist.txt b/assets/debian/group_denylist.txt new file mode 100644 index 0000000..9d37e4b --- /dev/null +++ b/assets/debian/group_denylist.txt @@ -0,0 +1,58 @@ +# These are the default system groups on debian. +# You can alos add groups by gid by prefixing the line with 'gid:'. + +group:adm +group:audio +group:avahi +group:backup +group:bin +group:cdrom +group:crontab +group:daemon +group:dialout +group:dip +group:disk +group:fax +group:floppy +group:games +group:gnats +group:input +group:irc +group:kmem +group:kvm +group:list +group:lp +group:mail +group:man +group:mlocate +group:netdev +group:news +group:nogroup +group:openldap +group:operator +group:plocate +group:plugdev +group:polkitd +group:postgres +group:proxy +group:render +group:root +group:sasl +group:shadow +group:src +group:staff +group:sudo +group:sync +group:sys +group:systemd-journal +group:systemd-network +group:systemd-resolve +group:systemd-timesync +group:tape +group:tty +group:users +group:utmp +group:uucp +group:video +group:voice +group:www-data diff --git a/nix/module.nix b/nix/module.nix index 9b13c6c..723e978 100644 --- a/nix/module.nix +++ b/nix/module.nix @@ -40,6 +40,14 @@ in }; }; + authorization = { + group_denylist = lib.mkOption { + type = with lib.types; nullOr (listOf str); + default = [ "wheel" ]; + description = "List of groups that are denied access"; + }; + }; + mysql = { socket_path = lib.mkOption { type = with lib.types; nullOr path; @@ -81,6 +89,12 @@ in environment.systemPackages = [ cfg.package ]; environment.etc."muscl/config.toml".source = lib.pipe cfg.settings [ + # Handle group_denylist_file + (conf: lib.recursiveUpdate conf { + authorization.group_denylist_file = if (conf.authorization.group_denylist != [ ]) then "/etc/muscl/group-denylist" else null; + authorization.group_denylist = null; + }) + # Remove nulls (lib.filterAttrsRecursive (_: v: v != null)) @@ -95,6 +109,10 @@ in (format.generate "muscl.conf") ]; + environment.etc."muscl/group-denylist" = lib.mkIf (cfg.settings.authorization.group_denylist != [ ]) { + text = lib.concatMapStringsSep "\n" (group: "group:${group}") cfg.settings.authorization.group_denylist; + }; + services.mysql.ensureUsers = lib.mkIf cfg.createLocalDatabaseUser [ { name = cfg.settings.mysql.username; diff --git a/src/client/commands.rs b/src/client/commands.rs index a97f767..7e2041c 100644 --- a/src/client/commands.rs +++ b/src/client/commands.rs @@ -17,16 +17,19 @@ pub use create_user::*; pub use drop_db::*; pub use drop_user::*; pub use edit_privs::*; +use futures_util::SinkExt; +use itertools::Itertools; pub use lock_user::*; pub use passwd_user::*; pub use show_db::*; pub use show_privs::*; pub use show_user::*; +use tokio_stream::StreamExt; pub use unlock_user::*; use clap::Subcommand; -use crate::core::protocol::{ClientToServerMessageStream, Response}; +use crate::core::protocol::{ClientToServerMessageStream, Request, Response}; #[derive(Subcommand, Debug, Clone)] #[command(subcommand_required = true)] @@ -183,3 +186,23 @@ pub fn erroneous_server_response( } } } + +pub async fn print_authorization_owner_hint( + server_connection: &mut ClientToServerMessageStream, +) -> anyhow::Result<()> { + server_connection + .send(Request::ListValidNamePrefixes) + .await?; + + let response = match server_connection.next().await { + Some(Ok(Response::ListValidNamePrefixes(prefixes))) => prefixes, + response => return erroneous_server_response(response), + }; + + println!( + "Note: You are allowed to manage databases and users with the following prefixes:\n{}", + response.into_iter().map(|p| format!(" - {}", p)).join("\n") + ); + + Ok(()) +} diff --git a/src/client/commands/create_db.rs b/src/client/commands/create_db.rs index 91d7294..bba186e 100644 --- a/src/client/commands/create_db.rs +++ b/src/client/commands/create_db.rs @@ -3,11 +3,12 @@ use futures_util::SinkExt; use tokio_stream::StreamExt; use crate::{ - client::commands::erroneous_server_response, + client::commands::{erroneous_server_response, print_authorization_owner_hint}, core::{ protocol::{ - ClientToServerMessageStream, Request, Response, print_create_databases_output_status, - print_create_databases_output_status_json, + ClientToServerMessageStream, CreateDatabaseError, Request, Response, + print_create_databases_output_status, print_create_databases_output_status_json, + request_validation::ValidationError, }, types::MySQLDatabase, }, @@ -40,13 +41,24 @@ pub async fn create_databases( response => return erroneous_server_response(response), }; - server_connection.send(Request::Exit).await?; - if args.json { print_create_databases_output_status_json(&result); } else { print_create_databases_output_status(&result); + + if result.iter().any(|(_, res)| { + matches!( + res, + Err(CreateDatabaseError::ValidationError( + ValidationError::AuthorizationError(_) + )) + ) + }) { + print_authorization_owner_hint(&mut server_connection).await? + } } + server_connection.send(Request::Exit).await?; + Ok(()) } diff --git a/src/client/commands/create_user.rs b/src/client/commands/create_user.rs index 613b913..b4f7461 100644 --- a/src/client/commands/create_user.rs +++ b/src/client/commands/create_user.rs @@ -4,11 +4,15 @@ use futures_util::SinkExt; use tokio_stream::StreamExt; use crate::{ - client::commands::{erroneous_server_response, read_password_from_stdin_with_double_check}, + client::commands::{ + erroneous_server_response, print_authorization_owner_hint, + read_password_from_stdin_with_double_check, + }, core::{ protocol::{ - ClientToServerMessageStream, Request, Response, print_create_users_output_status, - print_create_users_output_status_json, print_set_password_output_status, + ClientToServerMessageStream, CreateUserError, Request, Response, + print_create_users_output_status, print_create_users_output_status_json, + print_set_password_output_status, request_validation::ValidationError, }, types::MySQLUser, }, @@ -55,6 +59,17 @@ pub async fn create_users( } else { print_create_users_output_status(&result); + if result.iter().any(|(_, res)| { + matches!( + res, + Err(CreateUserError::ValidationError( + ValidationError::AuthorizationError(_) + )) + ) + }) { + print_authorization_owner_hint(&mut server_connection).await? + } + let successfully_created_users = result .iter() .filter_map(|(username, result)| result.as_ref().ok().map(|_| username)) diff --git a/src/client/commands/drop_db.rs b/src/client/commands/drop_db.rs index a80abfe..f30619b 100644 --- a/src/client/commands/drop_db.rs +++ b/src/client/commands/drop_db.rs @@ -5,12 +5,13 @@ use futures_util::SinkExt; use tokio_stream::StreamExt; use crate::{ - client::commands::erroneous_server_response, + client::commands::{erroneous_server_response, print_authorization_owner_hint}, core::{ completion::mysql_database_completer, protocol::{ - ClientToServerMessageStream, Request, Response, print_drop_databases_output_status, - print_drop_databases_output_status_json, + ClientToServerMessageStream, DropDatabaseError, Request, Response, + print_drop_databases_output_status, print_drop_databases_output_status_json, + request_validation::ValidationError, }, types::MySQLDatabase, }, @@ -66,13 +67,24 @@ pub async fn drop_databases( response => return erroneous_server_response(response), }; - server_connection.send(Request::Exit).await?; - if args.json { print_drop_databases_output_status_json(&result); } else { print_drop_databases_output_status(&result); + + if result.iter().any(|(_, res)| { + matches!( + res, + Err(DropDatabaseError::ValidationError( + ValidationError::AuthorizationError(_) + )) + ) + }) { + print_authorization_owner_hint(&mut server_connection).await? + } }; + server_connection.send(Request::Exit).await?; + Ok(()) } diff --git a/src/client/commands/drop_user.rs b/src/client/commands/drop_user.rs index fdb8d4d..db70610 100644 --- a/src/client/commands/drop_user.rs +++ b/src/client/commands/drop_user.rs @@ -5,12 +5,13 @@ use futures_util::SinkExt; use tokio_stream::StreamExt; use crate::{ - client::commands::erroneous_server_response, + client::commands::{erroneous_server_response, print_authorization_owner_hint}, core::{ completion::mysql_user_completer, protocol::{ - ClientToServerMessageStream, Request, Response, print_drop_users_output_status, - print_drop_users_output_status_json, + ClientToServerMessageStream, DropUserError, Request, Response, + print_drop_users_output_status, print_drop_users_output_status_json, + request_validation::ValidationError, }, types::MySQLUser, }, @@ -70,13 +71,24 @@ pub async fn drop_users( response => return erroneous_server_response(response), }; - server_connection.send(Request::Exit).await?; - if args.json { print_drop_users_output_status_json(&result); } else { print_drop_users_output_status(&result); + + if result.iter().any(|(_, res)| { + matches!( + res, + Err(DropUserError::ValidationError( + ValidationError::AuthorizationError(_) + )) + ) + }) { + print_authorization_owner_hint(&mut server_connection).await? + } } + server_connection.send(Request::Exit).await?; + Ok(()) } diff --git a/src/client/commands/edit_privs.rs b/src/client/commands/edit_privs.rs index ac88d08..8924abf 100644 --- a/src/client/commands/edit_privs.rs +++ b/src/client/commands/edit_privs.rs @@ -9,7 +9,7 @@ use nix::unistd::{User, getuid}; use tokio_stream::StreamExt; use crate::{ - client::commands::erroneous_server_response, + client::commands::{erroneous_server_response, print_authorization_owner_hint}, core::{ completion::{mysql_database_completer, mysql_user_completer}, database_privileges::{ @@ -19,8 +19,8 @@ use crate::{ parse_privilege_data_from_editor_content, reduce_privilege_diffs, }, protocol::{ - ClientToServerMessageStream, Request, Response, - print_modify_database_privileges_output_status, + ClientToServerMessageStream, ModifyDatabasePrivilegesError, Request, Response, + print_modify_database_privileges_output_status, request_validation::ValidationError, }, types::{MySQLDatabase, MySQLUser}, }, @@ -219,6 +219,8 @@ pub async fn edit_database_privileges( diff_privileges(&existing_privilege_rows, &privileges_to_change) }; + // TODO: validate authorization before existence + let user_existence_map = users_exist(&mut server_connection, &diffs).await?; let database_existence_map = databases_exist(&mut server_connection, &diffs).await?; @@ -274,6 +276,19 @@ pub async fn edit_database_privileges( print_modify_database_privileges_output_status(&result); + if result.iter().any(|(_, res)| { + matches!( + res, + Err(ModifyDatabasePrivilegesError::UserValidationError( + ValidationError::AuthorizationError(_) + ) | ModifyDatabasePrivilegesError::DatabaseValidationError( + ValidationError::AuthorizationError(_) + )) + ) + }) { + print_authorization_owner_hint(&mut server_connection).await? + } + server_connection.send(Request::Exit).await?; Ok(()) diff --git a/src/client/commands/lock_user.rs b/src/client/commands/lock_user.rs index 4d663f6..2e8af16 100644 --- a/src/client/commands/lock_user.rs +++ b/src/client/commands/lock_user.rs @@ -4,12 +4,13 @@ use futures_util::SinkExt; use tokio_stream::StreamExt; use crate::{ - client::commands::erroneous_server_response, + client::commands::{erroneous_server_response, print_authorization_owner_hint}, core::{ completion::mysql_user_completer, protocol::{ - ClientToServerMessageStream, Request, Response, print_lock_users_output_status, - print_lock_users_output_status_json, + ClientToServerMessageStream, LockUserError, Request, Response, + print_lock_users_output_status, print_lock_users_output_status_json, + request_validation::ValidationError, }, types::MySQLUser, }, @@ -47,13 +48,24 @@ pub async fn lock_users( response => return erroneous_server_response(response), }; - server_connection.send(Request::Exit).await?; - if args.json { print_lock_users_output_status_json(&result); } else { print_lock_users_output_status(&result); + + if result.iter().any(|(_, res)| { + matches!( + res, + Err(LockUserError::ValidationError( + ValidationError::AuthorizationError(_) + )) + ) + }) { + print_authorization_owner_hint(&mut server_connection).await? + } } + server_connection.send(Request::Exit).await?; + Ok(()) } diff --git a/src/client/commands/passwd_user.rs b/src/client/commands/passwd_user.rs index 1144df9..48b2365 100644 --- a/src/client/commands/passwd_user.rs +++ b/src/client/commands/passwd_user.rs @@ -8,12 +8,12 @@ use futures_util::SinkExt; use tokio_stream::StreamExt; use crate::{ - client::commands::erroneous_server_response, + client::commands::{erroneous_server_response, print_authorization_owner_hint}, core::{ completion::mysql_user_completer, protocol::{ - ClientToServerMessageStream, ListUsersError, Request, Response, - print_set_password_output_status, + ClientToServerMessageStream, ListUsersError, Request, Response, SetPasswordError, + print_set_password_output_status, request_validation::ValidationError, }, types::MySQLUser, }, @@ -103,9 +103,18 @@ pub async fn passwd_user( response => return erroneous_server_response(response), }; - server_connection.send(Request::Exit).await?; - print_set_password_output_status(&result, &args.username); + if matches!( + result, + Err(SetPasswordError::ValidationError( + ValidationError::AuthorizationError(_) + )) + ) { + print_authorization_owner_hint(&mut server_connection).await? + } + + server_connection.send(Request::Exit).await?; + Ok(()) } diff --git a/src/client/commands/show_db.rs b/src/client/commands/show_db.rs index 3405c40..5712bad 100644 --- a/src/client/commands/show_db.rs +++ b/src/client/commands/show_db.rs @@ -4,12 +4,13 @@ use futures_util::SinkExt; use tokio_stream::StreamExt; use crate::{ - client::commands::erroneous_server_response, + client::commands::{erroneous_server_response, print_authorization_owner_hint}, core::{ completion::mysql_database_completer, protocol::{ - ClientToServerMessageStream, Request, Response, print_list_databases_output_status, - print_list_databases_output_status_json, + ClientToServerMessageStream, ListDatabasesError, Request, Response, + print_list_databases_output_status, print_list_databases_output_status_json, + request_validation::ValidationError, }, types::MySQLDatabase, }, @@ -60,14 +61,25 @@ pub async fn show_databases( response => return erroneous_server_response(response), }; - server_connection.send(Request::Exit).await?; - if args.json { print_list_databases_output_status_json(&databases); } else { print_list_databases_output_status(&databases); + + if databases.iter().any(|(_, res)| { + matches!( + res, + Err(ListDatabasesError::ValidationError( + ValidationError::AuthorizationError(_) + )) + ) + }) { + print_authorization_owner_hint(&mut server_connection).await? + } } + server_connection.send(Request::Exit).await?; + if args.fail && databases.values().any(|res| res.is_err()) { std::process::exit(1); } diff --git a/src/client/commands/show_privs.rs b/src/client/commands/show_privs.rs index 89dcd2c..db945a2 100644 --- a/src/client/commands/show_privs.rs +++ b/src/client/commands/show_privs.rs @@ -5,12 +5,13 @@ use itertools::Itertools; use tokio_stream::StreamExt; use crate::{ - client::commands::erroneous_server_response, + client::commands::{erroneous_server_response, print_authorization_owner_hint}, core::{ completion::mysql_database_completer, protocol::{ - ClientToServerMessageStream, Request, Response, print_list_privileges_output_status, - print_list_privileges_output_status_json, + ClientToServerMessageStream, GetDatabasesPrivilegeDataError, Request, Response, + print_list_privileges_output_status, print_list_privileges_output_status_json, + request_validation::ValidationError, }, types::MySQLDatabase, }, @@ -68,14 +69,25 @@ pub async fn show_database_privileges( response => return erroneous_server_response(response), }; - server_connection.send(Request::Exit).await?; - if args.json { print_list_privileges_output_status_json(&privilege_data); } else { print_list_privileges_output_status(&privilege_data, args.long); + + if privilege_data.iter().any(|(_, res)| { + matches!( + res, + Err(GetDatabasesPrivilegeDataError::ValidationError( + ValidationError::AuthorizationError(_) + )) + ) + }) { + print_authorization_owner_hint(&mut server_connection).await? + } } + server_connection.send(Request::Exit).await?; + if args.fail && privilege_data.values().any(|res| res.is_err()) { std::process::exit(1); } diff --git a/src/client/commands/show_user.rs b/src/client/commands/show_user.rs index 4a5656e..b723597 100644 --- a/src/client/commands/show_user.rs +++ b/src/client/commands/show_user.rs @@ -4,12 +4,13 @@ use futures_util::SinkExt; use tokio_stream::StreamExt; use crate::{ - client::commands::erroneous_server_response, + client::commands::{erroneous_server_response, print_authorization_owner_hint}, core::{ completion::mysql_user_completer, protocol::{ - ClientToServerMessageStream, Request, Response, print_list_users_output_status, - print_list_users_output_status_json, + ClientToServerMessageStream, ListUsersError, Request, Response, + print_list_users_output_status, print_list_users_output_status_json, + request_validation::ValidationError, }, types::MySQLUser, }, @@ -63,14 +64,25 @@ pub async fn show_users( response => return erroneous_server_response(response), }; - server_connection.send(Request::Exit).await?; - if args.json { print_list_users_output_status_json(&users); } else { print_list_users_output_status(&users); + + if users.iter().any(|(_, res)| { + matches!( + res, + Err(ListUsersError::ValidationError( + ValidationError::AuthorizationError(_) + )) + ) + }) { + print_authorization_owner_hint(&mut server_connection).await? + } } + server_connection.send(Request::Exit).await?; + if args.fail && users.values().any(|result| result.is_err()) { std::process::exit(1); } diff --git a/src/client/commands/unlock_user.rs b/src/client/commands/unlock_user.rs index 0be189c..fe5ead0 100644 --- a/src/client/commands/unlock_user.rs +++ b/src/client/commands/unlock_user.rs @@ -4,12 +4,13 @@ use futures_util::SinkExt; use tokio_stream::StreamExt; use crate::{ - client::commands::erroneous_server_response, + client::commands::{erroneous_server_response, print_authorization_owner_hint}, core::{ completion::mysql_user_completer, protocol::{ - ClientToServerMessageStream, Request, Response, print_unlock_users_output_status, - print_unlock_users_output_status_json, + ClientToServerMessageStream, Request, Response, UnlockUserError, + print_unlock_users_output_status, print_unlock_users_output_status_json, + request_validation::ValidationError, }, types::MySQLUser, }, @@ -47,13 +48,24 @@ pub async fn unlock_users( response => return erroneous_server_response(response), }; - server_connection.send(Request::Exit).await?; - if args.json { print_unlock_users_output_status_json(&result); } else { print_unlock_users_output_status(&result); + + if result.iter().any(|(_, res)| { + matches!( + res, + Err(UnlockUserError::ValidationError( + ValidationError::AuthorizationError(_) + )) + ) + }) { + print_authorization_owner_hint(&mut server_connection).await? + } } + server_connection.send(Request::Exit).await?; + Ok(()) } diff --git a/src/core/bootstrap.rs b/src/core/bootstrap.rs index 9b9fee4..9189da8 100644 --- a/src/core/bootstrap.rs +++ b/src/core/bootstrap.rs @@ -9,10 +9,12 @@ use tokio::{net::UnixStream as TokioUnixStream, sync::RwLock}; use tracing_subscriber::prelude::*; use crate::{ - core::common::{ - DEFAULT_CONFIG_PATH, DEFAULT_SOCKET_PATH, UnixUser, executing_in_suid_sgid_mode, + core::{ + common::{DEFAULT_CONFIG_PATH, DEFAULT_SOCKET_PATH, UnixUser, executing_in_suid_sgid_mode}, + protocol::request_validation::GroupDenylist, }, server::{ + authorization::read_and_parse_group_denylist, config::{MysqlConfig, ServerConfig}, landlock::landlock_restrict_server, session_handler, @@ -270,6 +272,13 @@ fn run_forked_server( let config = ServerConfig::read_config_from_path(&config_path) .context("Failed to read server config in forked process")?; + let group_denylist = if let Some(denylist_path) = &config.authorization.group_denylist_file { + read_and_parse_group_denylist(denylist_path) + .context("Failed to read and parse group denylist")? + } else { + GroupDenylist::new() + }; + let result: anyhow::Result<()> = tokio::runtime::Builder::new_current_thread() .enable_all() .build() @@ -292,6 +301,7 @@ fn run_forked_server( &unix_user, db_pool, db_is_mariadb, + &group_denylist, ) .await?; Ok(()) diff --git a/src/core/common.rs b/src/core/common.rs index 90f217a..7e1a3f2 100644 --- a/src/core/common.rs +++ b/src/core/common.rs @@ -99,10 +99,10 @@ impl UnixUser { }) } - pub fn from_enviroment() -> anyhow::Result { - let libc_uid = nix::unistd::getuid(); - UnixUser::from_uid(libc_uid.as_raw()) - } + // pub fn from_enviroment() -> anyhow::Result { + // let libc_uid = nix::unistd::getuid(); + // UnixUser::from_uid(libc_uid.as_raw()) + // } } #[inline] diff --git a/src/core/protocol/request_validation.rs b/src/core/protocol/request_validation.rs index fae3dd7..acc37e5 100644 --- a/src/core/protocol/request_validation.rs +++ b/src/core/protocol/request_validation.rs @@ -1,5 +1,7 @@ +use std::collections::HashSet; + use indoc::indoc; -use itertools::Itertools; +use nix::{libc::gid_t, unistd::Group}; use serde::{Deserialize, Serialize}; use thiserror::Error; @@ -23,23 +25,19 @@ impl NameValidationError { pub fn to_error_message(self, db_or_user: DbOrUser) -> String { match self { NameValidationError::EmptyString => { - format!("{} name cannot be empty.", db_or_user.capitalized_noun()).to_owned() + format!("{} name can not be empty.", db_or_user.capitalized_noun()) } NameValidationError::TooLong => format!( - "{} is too long. Maximum length is 64 characters.", + "{} is too long, maximum length is 64 characters.", db_or_user.capitalized_noun() - ) - .to_owned(), + ), NameValidationError::InvalidCharacters => format!( indoc! {r#" - Invalid characters in {} name: '{}' - - Only A-Z, a-z, 0-9, _ (underscore) and - (dash) are permitted. + Invalid characters in {} name: '{}', only A-Z, a-z, 0-9, _ (underscore) and - (dash) are permitted. "#}, db_or_user.lowercased_noun(), db_or_user.name(), - ) - .to_owned(), + ), } } @@ -54,64 +52,41 @@ impl NameValidationError { #[derive(Error, Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)] pub enum AuthorizationError { - #[error("No matching owner prefix found")] - NoMatch, + #[error("Illegal prefix, user is not authorized to manage this resource")] + IllegalPrefix, // TODO: I don't think this should ever happen? #[error("Name cannot be empty")] StringEmpty, + + #[error("Group was found in denylist")] + DenylistError, } impl AuthorizationError { pub fn to_error_message(self, db_or_user: DbOrUser) -> String { - let user = UnixUser::from_enviroment(); - - let UnixUser { - username, - mut groups, - } = user.unwrap_or(UnixUser { - username: "???".to_string(), - groups: vec![], - }); - - groups.sort(); - match self { - AuthorizationError::NoMatch => format!( - indoc! {r#" - Invalid {} name prefix: '{}' does not match your username or any of your groups. - Are you sure you are allowed to create {} names with this prefix? - The format should be: _<{} name> - - Allowed prefixes: - - {} - {} - "#}, + AuthorizationError::IllegalPrefix => format!( + "Illegal {} name prefix: you are not allowed to manage databases or users prefixed with '{}'", db_or_user.lowercased_noun(), - db_or_user.name(), - db_or_user.lowercased_noun(), - db_or_user.lowercased_noun(), - username, - groups - .into_iter() - .filter(|g| g != &username) - .map(|g| format!(" - {}", g)) - .join("\n"), + db_or_user.prefix(), ) .to_owned(), - AuthorizationError::StringEmpty => format!( - "'{}' is not a valid {} name.", - db_or_user.name(), - db_or_user.lowercased_noun() - ) - .to_string(), + // TODO: This error message could be clearer + AuthorizationError::StringEmpty => { + format!("{} name can not be empty.", db_or_user.capitalized_noun()) + } + AuthorizationError::DenylistError => { + format!("'{}' is denied by the group denylist", db_or_user.name()) + } } } pub fn error_type(&self) -> &'static str { match self { - AuthorizationError::NoMatch => "no-match", + AuthorizationError::IllegalPrefix => "illegal-prefix", AuthorizationError::StringEmpty => "string-empty", + AuthorizationError::DenylistError => "denylist-error", } } } @@ -155,6 +130,8 @@ impl ValidationError { } } +pub type GroupDenylist = HashSet; + const MAX_NAME_LENGTH: usize = 64; pub fn validate_name(name: &str) -> Result<(), NameValidationError> { @@ -201,21 +178,49 @@ pub fn validate_authorization_by_prefixes( .collect::>() .is_empty() { - return Err(AuthorizationError::NoMatch); + return Err(AuthorizationError::IllegalPrefix); }; Ok(()) } +pub fn validate_authorization_by_group_denylist( + name: &str, + user: &UnixUser, + group_denylist: &GroupDenylist, +) -> Result<(), AuthorizationError> { + // NOTE: if the username matches, we allow it regardless of denylist + if user.username == name { + return Ok(()); + } + + let user_group = Group::from_name(name) + .ok() + .flatten() + .map(|g| g.gid.as_raw()); + + if let Some(gid) = user_group + && group_denylist.contains(&gid) + { + Err(AuthorizationError::DenylistError) + } else { + Ok(()) + } +} + pub fn validate_db_or_user_request( db_or_user: &DbOrUser, unix_user: &UnixUser, + group_denylist: &GroupDenylist, ) -> Result<(), ValidationError> { validate_name(db_or_user.name()).map_err(ValidationError::NameValidationError)?; validate_authorization_by_unix_user(db_or_user.name(), unix_user) .map_err(ValidationError::AuthorizationError)?; + validate_authorization_by_group_denylist(db_or_user.name(), unix_user, group_denylist) + .map_err(ValidationError::AuthorizationError)?; + Ok(()) } @@ -273,7 +278,7 @@ mod tests { assert_eq!( validate_authorization_by_prefixes("nonexistent_testdb", &prefixes), - Err(AuthorizationError::NoMatch) + Err(AuthorizationError::IllegalPrefix) ); } } diff --git a/src/core/types.rs b/src/core/types.rs index a466881..fd8a2ca 100644 --- a/src/core/types.rs +++ b/src/core/types.rs @@ -132,4 +132,11 @@ impl DbOrUser { DbOrUser::User(user) => user.as_str(), } } + + pub fn prefix(&self) -> &str { + match self { + DbOrUser::Database(db) => db.split('_').next().unwrap_or("?"), + DbOrUser::User(user) => user.split('_').next().unwrap_or("?"), + } + } } diff --git a/src/server.rs b/src/server.rs index 248c22c..aa7d712 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,4 +1,4 @@ -mod authorization; +pub mod authorization; pub mod command; mod common; pub mod config; diff --git a/src/server/authorization.rs b/src/server/authorization.rs index 58864aa..0d4fe14 100644 --- a/src/server/authorization.rs +++ b/src/server/authorization.rs @@ -1,25 +1,127 @@ +use std::{collections::HashSet, path::Path}; + +use anyhow::Context; +use nix::unistd::Group; + use crate::core::{ common::UnixUser, - protocol::{CheckAuthorizationError, request_validation::validate_db_or_user_request}, + protocol::{ + CheckAuthorizationError, + request_validation::{GroupDenylist, validate_db_or_user_request}, + }, types::DbOrUser, }; pub async fn check_authorization( dbs_or_users: Vec, unix_user: &UnixUser, + group_denylist: &GroupDenylist, ) -> std::collections::BTreeMap> { let mut results = std::collections::BTreeMap::new(); for db_or_user in dbs_or_users { - if let Err(err) = - validate_db_or_user_request(&db_or_user, unix_user).map_err(CheckAuthorizationError) + if let Err(err) = validate_db_or_user_request(&db_or_user, unix_user, group_denylist) + .map_err(CheckAuthorizationError) { results.insert(db_or_user.clone(), Err(err)); continue; } - results.insert(db_or_user.clone(), Ok(())); } results } + +/// Reads and parses a group denylist file, returning a set of GUIDs +/// +/// The format of the denylist file is expected to be one group name or GID per line. +/// Lines starting with '#' are treated as comments and ignored. +/// Empty lines are also ignored. +/// +/// Each line looks like one of the following: +/// - `gid:1001` +/// - `group:admins` +pub fn read_and_parse_group_denylist(denylist_path: &Path) -> anyhow::Result { + let content = std::fs::read_to_string(denylist_path).context(format!( + "Failed to read denylist file at {:?}", + denylist_path + ))?; + + let mut groups = HashSet::with_capacity(content.lines().count()); + + for (line_number, line) in content.lines().enumerate() { + let trimmed_line = line.trim(); + + if trimmed_line.is_empty() || trimmed_line.starts_with('#') { + continue; + } + + let parts: Vec<&str> = trimmed_line.splitn(2, ':').collect(); + if parts.len() != 2 { + anyhow::bail!( + "Invalid format in denylist file at {:?} on line {}: {}", + denylist_path, + line_number + 1, + line + ); + } + + match parts[0] { + "gid" => { + let gid: u32 = parts[1].parse().with_context(|| { + format!( + "Invalid GID in denylist file at {:?} on line {}: {}", + denylist_path, + line_number + 1, + parts[1] + ) + })?; + let group = Group::from_gid(nix::unistd::Gid::from_raw(gid)) + .context(format!( + "Failed to get group for GID {} in denylist file at {:?} on line {}", + gid, + denylist_path, + line_number + 1 + ))? + .ok_or_else(|| { + anyhow::anyhow!( + "No group found for GID {} in denylist file at {:?} on line {}", + gid, + denylist_path, + line_number + 1 + ) + })?; + groups.insert(group.gid.as_raw()); + } + "group" => { + let group = Group::from_name(parts[1]) + .context(format!( + "Failed to get group for name '{}' in denylist file at {:?} on line {}", + parts[1], + denylist_path, + line_number + 1 + ))? + .ok_or_else(|| { + anyhow::anyhow!( + "No group found for name '{}' in denylist file at {:?} on line {}", + parts[1], + denylist_path, + line_number + 1 + ) + })?; + groups.insert(group.gid.as_raw()); + } + _ => { + anyhow::bail!( + "Invalid prefix '{}' in denylist file at {:?} on line {}: {}", + parts[0], + denylist_path, + line_number + 1, + line + ); + } + } + } + + Ok(groups) +} diff --git a/src/server/common.rs b/src/server/common.rs index 0ddf1e5..946e1b5 100644 --- a/src/server/common.rs +++ b/src/server/common.rs @@ -1,13 +1,37 @@ -use crate::core::common::UnixUser; +use crate::core::{common::UnixUser, protocol::request_validation::GroupDenylist}; +use nix::unistd::Group; use sqlx::prelude::*; +/// This function retrieves the groups of a user, filtering out any groups +/// that are present in the provided denylist. +pub fn get_user_filtered_groups(user: &UnixUser, group_denylist: &GroupDenylist) -> Vec { + user.groups + .iter() + .cloned() + .filter_map(|group_name| { + match Group::from_name(&group_name) { + Ok(Some(group)) => { + if group_denylist.contains(&group.gid.as_raw()) { + None + } else { + Some(group.name) + } + } + // NOTE: allow non-existing groups to pass through the filter + _ => Some(group_name), + } + }) + .collect() +} + /// This function creates a regex that matches items (users, databases) /// that belong to the user or any of the user's groups. -pub fn create_user_group_matching_regex(user: &UnixUser) -> String { - if user.groups.is_empty() { +pub fn create_user_group_matching_regex(user: &UnixUser, group_denylist: &GroupDenylist) -> String { + let filtered_groups = get_user_filtered_groups(user, group_denylist); + if filtered_groups.is_empty() { format!("{}_.+", user.username) } else { - format!("({}|{})_.+", user.username, user.groups.join("|")) + format!("({}|{})_.+", user.username, filtered_groups.join("|")) } } @@ -37,7 +61,8 @@ mod tests { groups: vec!["group1".to_owned(), "group2".to_owned()], }; - let regex = create_user_group_matching_regex(&user); + let regex = create_user_group_matching_regex(&user, &GroupDenylist::new()); + println!("Generated regex: {}", regex); let re = Regex::new(®ex).unwrap(); assert!(re.is_match("user_something")); diff --git a/src/server/config.rs b/src/server/config.rs index 6ed0f20..22cc351 100644 --- a/src/server/config.rs +++ b/src/server/config.rs @@ -78,9 +78,15 @@ impl MysqlConfig { } } +#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] +pub struct AuthorizationConfig { + pub group_denylist_file: Option, +} + #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] pub struct ServerConfig { pub socket_path: Option, + pub authorization: AuthorizationConfig, pub mysql: MysqlConfig, } diff --git a/src/server/session_handler.rs b/src/server/session_handler.rs index 8302d03..ea15a85 100644 --- a/src/server/session_handler.rs +++ b/src/server/session_handler.rs @@ -11,11 +11,12 @@ use crate::{ common::UnixUser, protocol::{ Request, Response, ServerToClientMessageStream, SetPasswordError, - create_server_to_client_message_stream, + create_server_to_client_message_stream, request_validation::GroupDenylist, }, }, server::{ authorization::check_authorization, + common::get_user_filtered_groups, sql::{ database_operations::{ complete_database_name, create_databases, drop_databases, @@ -39,6 +40,7 @@ pub async fn session_handler( socket: UnixStream, db_pool: Arc>, db_is_mariadb: bool, + group_denylist: &GroupDenylist, ) -> anyhow::Result<()> { let uid = match socket.peer_cred() { Ok(cred) => cred.uid(), @@ -85,8 +87,14 @@ pub async fn session_handler( (async move { tracing::info!("Accepted connection from user: {}", unix_user); - let result = - session_handler_with_unix_user(socket, &unix_user, db_pool, db_is_mariadb).await; + let result = session_handler_with_unix_user( + socket, + &unix_user, + db_pool, + db_is_mariadb, + group_denylist, + ) + .await; tracing::info!( "Finished handling requests for connection from user: {}", @@ -104,6 +112,7 @@ pub async fn session_handler_with_unix_user( unix_user: &UnixUser, db_pool: Arc>, db_is_mariadb: bool, + group_denylist: &GroupDenylist, ) -> anyhow::Result<()> { let mut message_stream = create_server_to_client_message_stream(socket); @@ -131,6 +140,7 @@ pub async fn session_handler_with_unix_user( unix_user, &mut db_connection, db_is_mariadb, + group_denylist, ) .await; @@ -147,6 +157,7 @@ async fn session_handler_with_db_connection( unix_user: &UnixUser, db_connection: &mut MySqlConnection, db_is_mariadb: bool, + group_denylist: &GroupDenylist, ) -> anyhow::Result<()> { stream.send(Response::Ready).await?; loop { @@ -178,18 +189,14 @@ async fn session_handler_with_db_connection( let response = match request { Request::CheckAuthorization(dbs_or_users) => { - let result = check_authorization(dbs_or_users, unix_user).await; + let result = check_authorization(dbs_or_users, unix_user, group_denylist).await; Response::CheckAuthorization(result) } Request::ListValidNamePrefixes => { let mut result = Vec::with_capacity(unix_user.groups.len() + 1); result.push(unix_user.username.to_owned()); - for group in unix_user - .groups - .iter() - .filter(|x| *x != &unix_user.username) - { + for group in get_user_filtered_groups(unix_user, group_denylist) { result.push(group.to_owned()); } @@ -208,6 +215,7 @@ async fn session_handler_with_db_connection( unix_user, db_connection, db_is_mariadb, + group_denylist, ) .await; Response::CompleteDatabaseName(result) @@ -226,32 +234,54 @@ async fn session_handler_with_db_connection( unix_user, db_connection, db_is_mariadb, + group_denylist, ) .await; Response::CompleteUserName(result) } } Request::CreateDatabases(databases_names) => { - let result = - create_databases(databases_names, unix_user, db_connection, db_is_mariadb) - .await; + let result = create_databases( + databases_names, + unix_user, + db_connection, + db_is_mariadb, + group_denylist, + ) + .await; Response::CreateDatabases(result) } Request::DropDatabases(databases_names) => { - let result = - drop_databases(databases_names, unix_user, db_connection, db_is_mariadb).await; + let result = drop_databases( + databases_names, + unix_user, + db_connection, + db_is_mariadb, + group_denylist, + ) + .await; Response::DropDatabases(result) } Request::ListDatabases(database_names) => match database_names { Some(database_names) => { - let result = - list_databases(database_names, unix_user, db_connection, db_is_mariadb) - .await; + let result = list_databases( + database_names, + unix_user, + db_connection, + db_is_mariadb, + group_denylist, + ) + .await; Response::ListDatabases(result) } None => { - let result = - list_all_databases_for_user(unix_user, db_connection, db_is_mariadb).await; + let result = list_all_databases_for_user( + unix_user, + db_connection, + db_is_mariadb, + group_denylist, + ) + .await; Response::ListAllDatabases(result) } }, @@ -262,13 +292,19 @@ async fn session_handler_with_db_connection( unix_user, db_connection, db_is_mariadb, + group_denylist, ) .await; Response::ListPrivileges(privilege_data) } None => { - let privilege_data = - get_all_database_privileges(unix_user, db_connection, db_is_mariadb).await; + let privilege_data = get_all_database_privileges( + unix_user, + db_connection, + db_is_mariadb, + group_denylist, + ) + .await; Response::ListAllPrivileges(privilege_data) } }, @@ -278,18 +314,31 @@ async fn session_handler_with_db_connection( unix_user, db_connection, db_is_mariadb, + group_denylist, ) .await; Response::ModifyPrivileges(result) } Request::CreateUsers(db_users) => { - let result = - create_database_users(db_users, unix_user, db_connection, db_is_mariadb).await; + let result = create_database_users( + db_users, + unix_user, + db_connection, + db_is_mariadb, + group_denylist, + ) + .await; Response::CreateUsers(result) } Request::DropUsers(db_users) => { - let result = - drop_database_users(db_users, unix_user, db_connection, db_is_mariadb).await; + let result = drop_database_users( + db_users, + unix_user, + db_connection, + db_is_mariadb, + group_denylist, + ) + .await; Response::DropUsers(result) } Request::PasswdUser((db_user, password)) => { @@ -299,15 +348,21 @@ async fn session_handler_with_db_connection( unix_user, db_connection, db_is_mariadb, + group_denylist, ) .await; Response::SetUserPassword(result) } Request::ListUsers(db_users) => match db_users { Some(db_users) => { - let result = - list_database_users(db_users, unix_user, db_connection, db_is_mariadb) - .await; + let result = list_database_users( + db_users, + unix_user, + db_connection, + db_is_mariadb, + group_denylist, + ) + .await; Response::ListUsers(result) } None => { @@ -315,19 +370,32 @@ async fn session_handler_with_db_connection( unix_user, db_connection, db_is_mariadb, + group_denylist, ) .await; Response::ListAllUsers(result) } }, Request::LockUsers(db_users) => { - let result = - lock_database_users(db_users, unix_user, db_connection, db_is_mariadb).await; + let result = lock_database_users( + db_users, + unix_user, + db_connection, + db_is_mariadb, + group_denylist, + ) + .await; Response::LockUsers(result) } Request::UnlockUsers(db_users) => { - let result = - unlock_database_users(db_users, unix_user, db_connection, db_is_mariadb).await; + let result = unlock_database_users( + db_users, + unix_user, + db_connection, + db_is_mariadb, + group_denylist, + ) + .await; Response::UnlockUsers(result) } Request::Exit => { diff --git a/src/server/sql/database_operations.rs b/src/server/sql/database_operations.rs index b38574e..3e3ea49 100644 --- a/src/server/sql/database_operations.rs +++ b/src/server/sql/database_operations.rs @@ -6,6 +6,7 @@ use sqlx::prelude::*; use serde::{Deserialize, Serialize}; use crate::core::protocol::CompleteDatabaseNameResponse; +use crate::core::protocol::request_validation::GroupDenylist; use crate::core::protocol::request_validation::validate_db_or_user_request; use crate::core::types::DbOrUser; use crate::core::types::MySQLDatabase; @@ -49,6 +50,7 @@ pub async fn complete_database_name( unix_user: &UnixUser, connection: &mut MySqlConnection, _db_is_mariadb: bool, + group_denylist: &GroupDenylist, ) -> CompleteDatabaseNameResponse { let result = sqlx::query( r#" @@ -59,7 +61,7 @@ pub async fn complete_database_name( AND `SCHEMA_NAME` LIKE ? "#, ) - .bind(create_user_group_matching_regex(unix_user)) + .bind(create_user_group_matching_regex(unix_user, group_denylist)) .bind(format!("{}%", database_prefix)) .fetch_all(connection) .await; @@ -89,13 +91,17 @@ pub async fn create_databases( unix_user: &UnixUser, connection: &mut MySqlConnection, _db_is_mariadb: bool, + group_denylist: &GroupDenylist, ) -> CreateDatabasesResponse { let mut results = BTreeMap::new(); for database_name in database_names { - if let Err(err) = - validate_db_or_user_request(&DbOrUser::Database(database_name.clone()), unix_user) - .map_err(CreateDatabaseError::ValidationError) + if let Err(err) = validate_db_or_user_request( + &DbOrUser::Database(database_name.clone()), + unix_user, + group_denylist, + ) + .map_err(CreateDatabaseError::ValidationError) { results.insert(database_name.to_owned(), Err(err)); continue; @@ -141,13 +147,17 @@ pub async fn drop_databases( unix_user: &UnixUser, connection: &mut MySqlConnection, _db_is_mariadb: bool, + group_denylist: &GroupDenylist, ) -> DropDatabasesResponse { let mut results = BTreeMap::new(); for database_name in database_names { - if let Err(err) = - validate_db_or_user_request(&DbOrUser::Database(database_name.clone()), unix_user) - .map_err(DropDatabaseError::ValidationError) + if let Err(err) = validate_db_or_user_request( + &DbOrUser::Database(database_name.clone()), + unix_user, + group_denylist, + ) + .map_err(DropDatabaseError::ValidationError) { results.insert(database_name.to_owned(), Err(err)); continue; @@ -236,13 +246,17 @@ pub async fn list_databases( unix_user: &UnixUser, connection: &mut MySqlConnection, _db_is_mariadb: bool, + group_denylist: &GroupDenylist, ) -> ListDatabasesResponse { let mut results = BTreeMap::new(); for database_name in database_names { - if let Err(err) = - validate_db_or_user_request(&DbOrUser::Database(database_name.clone()), unix_user) - .map_err(ListDatabasesError::ValidationError) + if let Err(err) = validate_db_or_user_request( + &DbOrUser::Database(database_name.clone()), + unix_user, + group_denylist, + ) + .map_err(ListDatabasesError::ValidationError) { results.insert(database_name.to_owned(), Err(err)); continue; @@ -296,6 +310,7 @@ pub async fn list_all_databases_for_user( unix_user: &UnixUser, connection: &mut MySqlConnection, _db_is_mariadb: bool, + group_denylist: &GroupDenylist, ) -> ListAllDatabasesResponse { let result = sqlx::query_as::<_, DatabaseRow>( r#" @@ -319,7 +334,7 @@ pub async fn list_all_databases_for_user( GROUP BY `information_schema`.`SCHEMATA`.`SCHEMA_NAME` "#, ) - .bind(create_user_group_matching_regex(unix_user)) + .bind(create_user_group_matching_regex(unix_user, group_denylist)) .fetch_all(connection) .await .map_err(|err| ListAllDatabasesError::MySqlError(err.to_string())); diff --git a/src/server/sql/database_privilege_operations.rs b/src/server/sql/database_privilege_operations.rs index d1fb49b..ed2e030 100644 --- a/src/server/sql/database_privilege_operations.rs +++ b/src/server/sql/database_privilege_operations.rs @@ -31,7 +31,7 @@ use crate::{ DiffDoesNotApplyError, GetAllDatabasesPrivilegeDataError, GetDatabasesPrivilegeDataError, ListAllPrivilegesResponse, ListPrivilegesResponse, ModifyDatabasePrivilegesError, ModifyPrivilegesResponse, - request_validation::validate_db_or_user_request, + request_validation::{GroupDenylist, validate_db_or_user_request}, }, types::{DbOrUser, MySQLDatabase, MySQLUser}, }, @@ -143,13 +143,17 @@ pub async fn get_databases_privilege_data( unix_user: &UnixUser, connection: &mut MySqlConnection, _db_is_mariadb: bool, + group_denylist: &GroupDenylist, ) -> ListPrivilegesResponse { let mut results = BTreeMap::new(); for database_name in database_names.iter() { - if let Err(err) = - validate_db_or_user_request(&DbOrUser::Database(database_name.clone()), unix_user) - .map_err(GetDatabasesPrivilegeDataError::ValidationError) + if let Err(err) = validate_db_or_user_request( + &DbOrUser::Database(database_name.clone()), + unix_user, + group_denylist, + ) + .map_err(GetDatabasesPrivilegeDataError::ValidationError) { results.insert(database_name.to_owned(), Err(err)); continue; @@ -200,9 +204,10 @@ pub async fn get_all_database_privileges( unix_user: &UnixUser, connection: &mut MySqlConnection, _db_is_mariadb: bool, + group_denylist: &GroupDenylist, ) -> ListAllPrivilegesResponse { let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&get_all_db_privs_query()) - .bind(create_user_group_matching_regex(unix_user)) + .bind(create_user_group_matching_regex(unix_user, group_denylist)) .fetch_all(connection) .await .map_err(|e| GetAllDatabasesPrivilegeDataError::MySqlError(e.to_string())); @@ -397,6 +402,7 @@ pub async fn apply_privilege_diffs( unix_user: &UnixUser, connection: &mut MySqlConnection, _db_is_mariadb: bool, + group_denylist: &GroupDenylist, ) -> ModifyPrivilegesResponse { let mut results: BTreeMap<(MySQLDatabase, MySQLUser), _> = BTreeMap::new(); @@ -408,6 +414,7 @@ pub async fn apply_privilege_diffs( if let Err(err) = validate_db_or_user_request( &DbOrUser::Database(diff.get_database_name().to_owned()), unix_user, + group_denylist, ) .map_err(ModifyDatabasePrivilegesError::UserValidationError) { @@ -415,9 +422,12 @@ pub async fn apply_privilege_diffs( continue; } - if let Err(err) = - validate_db_or_user_request(&DbOrUser::User(diff.get_user_name().to_owned()), unix_user) - .map_err(ModifyDatabasePrivilegesError::UserValidationError) + if let Err(err) = validate_db_or_user_request( + &DbOrUser::User(diff.get_user_name().to_owned()), + unix_user, + group_denylist, + ) + .map_err(ModifyDatabasePrivilegesError::UserValidationError) { results.insert(key, Err(err)); continue; diff --git a/src/server/sql/user_operations.rs b/src/server/sql/user_operations.rs index 0ef2cf8..b7c298c 100644 --- a/src/server/sql/user_operations.rs +++ b/src/server/sql/user_operations.rs @@ -7,6 +7,7 @@ use serde::{Deserialize, Serialize}; use sqlx::MySqlConnection; use sqlx::prelude::*; +use crate::core::protocol::request_validation::GroupDenylist; use crate::core::protocol::request_validation::validate_db_or_user_request; use crate::core::types::DbOrUser; use crate::{ @@ -58,6 +59,7 @@ pub async fn complete_user_name( unix_user: &UnixUser, connection: &mut MySqlConnection, _db_is_mariadb: bool, + group_denylist: &GroupDenylist, ) -> Vec { let result = sqlx::query( r#" @@ -67,7 +69,7 @@ pub async fn complete_user_name( AND `User` LIKE ? "#, ) - .bind(create_user_group_matching_regex(unix_user)) + .bind(create_user_group_matching_regex(unix_user, group_denylist)) .bind(format!("{}%", user_prefix)) .fetch_all(connection) .await; @@ -97,12 +99,14 @@ pub async fn create_database_users( unix_user: &UnixUser, connection: &mut MySqlConnection, _db_is_mariadb: bool, + group_denylist: &GroupDenylist, ) -> CreateUsersResponse { let mut results = BTreeMap::new(); for db_user in db_users { - if let Err(err) = validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user) - .map_err(CreateUserError::ValidationError) + if let Err(err) = + validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist) + .map_err(CreateUserError::ValidationError) { results.insert(db_user, Err(err)); continue; @@ -141,12 +145,14 @@ pub async fn drop_database_users( unix_user: &UnixUser, connection: &mut MySqlConnection, _db_is_mariadb: bool, + group_denylist: &GroupDenylist, ) -> DropUsersResponse { let mut results = BTreeMap::new(); for db_user in db_users { - if let Err(err) = validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user) - .map_err(DropUserError::ValidationError) + if let Err(err) = + validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist) + .map_err(DropUserError::ValidationError) { results.insert(db_user, Err(err)); continue; @@ -186,8 +192,9 @@ pub async fn set_password_for_database_user( unix_user: &UnixUser, connection: &mut MySqlConnection, _db_is_mariadb: bool, + group_denylist: &GroupDenylist, ) -> SetUserPasswordResponse { - validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user) + validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist) .map_err(SetPasswordError::ValidationError)?; match unsafe_user_exists(db_user, &mut *connection).await { @@ -269,12 +276,14 @@ pub async fn lock_database_users( unix_user: &UnixUser, connection: &mut MySqlConnection, db_is_mariadb: bool, + group_denylist: &GroupDenylist, ) -> LockUsersResponse { let mut results = BTreeMap::new(); for db_user in db_users { - if let Err(err) = validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user) - .map_err(LockUserError::ValidationError) + if let Err(err) = + validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist) + .map_err(LockUserError::ValidationError) { results.insert(db_user, Err(err)); continue; @@ -327,12 +336,14 @@ pub async fn unlock_database_users( unix_user: &UnixUser, connection: &mut MySqlConnection, db_is_mariadb: bool, + group_denylist: &GroupDenylist, ) -> UnlockUsersResponse { let mut results = BTreeMap::new(); for db_user in db_users { - if let Err(err) = validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user) - .map_err(UnlockUserError::ValidationError) + if let Err(err) = + validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist) + .map_err(UnlockUserError::ValidationError) { results.insert(db_user, Err(err)); continue; @@ -433,12 +444,14 @@ pub async fn list_database_users( unix_user: &UnixUser, connection: &mut MySqlConnection, db_is_mariadb: bool, + group_denylist: &GroupDenylist, ) -> ListUsersResponse { let mut results = BTreeMap::new(); for db_user in db_users { - if let Err(err) = validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user) - .map_err(ListUsersError::ValidationError) + if let Err(err) = + validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist) + .map_err(ListUsersError::ValidationError) { results.insert(db_user, Err(err)); continue; @@ -477,6 +490,7 @@ pub async fn list_all_database_users_for_unix_user( unix_user: &UnixUser, connection: &mut MySqlConnection, db_is_mariadb: bool, + group_denylist: &GroupDenylist, ) -> ListAllUsersResponse { let mut result = sqlx::query_as::<_, DatabaseUser>( &(if db_is_mariadb { @@ -485,7 +499,7 @@ pub async fn list_all_database_users_for_unix_user( DB_USER_SELECT_STATEMENT_MYSQL.to_string() } + "WHERE `user`.`User` REGEXP ?"), ) - .bind(create_user_group_matching_regex(unix_user)) + .bind(create_user_group_matching_regex(unix_user, group_denylist)) .fetch_all(&mut *connection) .await .map_err(|err| ListAllUsersError::MySqlError(err.to_string())); diff --git a/src/server/supervisor.rs b/src/server/supervisor.rs index e117159..fdc39b2 100644 --- a/src/server/supervisor.rs +++ b/src/server/supervisor.rs @@ -17,9 +17,13 @@ use tokio::{ }; use tokio_util::{sync::CancellationToken, task::TaskTracker}; -use crate::server::{ - config::{MysqlConfig, ServerConfig}, - session_handler::session_handler, +use crate::{ + core::protocol::request_validation::GroupDenylist, + server::{ + authorization::read_and_parse_group_denylist, + config::{MysqlConfig, ServerConfig}, + session_handler::session_handler, + }, }; #[derive(Clone, Debug)] @@ -36,6 +40,7 @@ pub struct ReloadEvent; pub struct Supervisor { config_path: PathBuf, config: Arc>, + group_deny_list: Arc>, systemd_mode: bool, shutdown_cancel_token: CancellationToken, @@ -66,6 +71,23 @@ impl Supervisor { let config = ServerConfig::read_config_from_path(&config_path) .context("Failed to read server configuration")?; + let group_deny_list = match &config.authorization.group_denylist_file { + Some(denylist_path) => { + let denylist = read_and_parse_group_denylist(denylist_path) + .context("Failed to read group denylist file")?; + tracing::debug!( + "Loaded group denylist with {} entries from {:?}", + denylist.len(), + denylist_path + ); + Arc::new(RwLock::new(denylist)) + } + None => { + tracing::debug!("No group denylist file specified, proceeding without a denylist"); + Arc::new(RwLock::new(GroupDenylist::new())) + } + }; + let mut watchdog_duration = None; let mut watchdog_micro_seconds = 0; #[cfg(target_os = "linux")] @@ -148,12 +170,14 @@ impl Supervisor { db_connection_pool.clone(), rx, db_is_mariadb.clone(), + group_deny_list.clone(), )) }; Ok(Self { config_path, config: Arc::new(Mutex::new(config)), + group_deny_list, systemd_mode, reload_message_receiver: reload_rx, shutdown_cancel_token, @@ -196,6 +220,26 @@ impl Supervisor { .context("Failed to read server configuration")?; let mut config = self.config.clone().lock_owned().await; *config = new_config; + + let group_deny_list = match &config.authorization.group_denylist_file { + Some(denylist_path) => { + let denylist = read_and_parse_group_denylist(denylist_path) + .context("Failed to read group denylist file")?; + + tracing::debug!( + "Loaded group denylist with {} entries from {:?}", + denylist.len(), + denylist_path + ); + denylist + } + None => { + tracing::debug!("No group denylist file specified, proceeding without a denylist"); + GroupDenylist::new() + } + }; + let mut group_deny_list_lock = self.group_deny_list.write().await; + *group_deny_list_lock = group_deny_list; Ok(()) } @@ -502,6 +546,7 @@ async fn listener_task( db_pool: Arc>, mut supervisor_message_receiver: broadcast::Receiver, db_is_mariadb: Arc>, + group_denylist: Arc>, ) -> anyhow::Result<()> { #[cfg(target_os = "linux")] sd_notify::notify(false, &[sd_notify::NotifyState::Ready])?; @@ -539,8 +584,14 @@ async fn listener_task( let db_pool_clone = db_pool.clone(); let db_is_mariadb_clone = *db_is_mariadb.read().await; + let group_denylist_arc_clone = group_denylist.clone(); task_tracker.spawn(async move { - match session_handler(conn, db_pool_clone, db_is_mariadb_clone).await { + match session_handler( + conn, + db_pool_clone, + db_is_mariadb_clone, + &*group_denylist_arc_clone.read().await, + ).await { Ok(()) => {} Err(e) => { tracing::error!("Failed to run server: {}", e);