Implement denylists
All checks were successful
Build and test / check-license (push) Successful in 1m38s
Build and test / check (push) Successful in 1m51s
Build and test / build (push) Successful in 2m40s
Build and test / test (push) Successful in 4m25s
Build and test / docs (push) Successful in 6m1s

This commit is contained in:
2025-12-15 15:17:37 +09:00
parent 45cefb8af4
commit 8b4d549e18
29 changed files with 743 additions and 188 deletions

View File

@@ -132,6 +132,11 @@ assets = [
"etc/muscl/config.toml",
"644",
],
[
"assets/debian/group_denylist.txt",
"etc/muscl/group_denylist.txt",
"644",
],
[
"assets/completions/_*",
"usr/share/zsh/site-functions/completions/",

View File

@@ -22,3 +22,6 @@ password_file = "/run/credentials/muscl.service/muscl_mysql_password"
# Database connection timeout in seconds
timeout = 2
[authorization]
group_denylist_file = "/etc/muscl/group_denylist.txt"

View File

@@ -0,0 +1,58 @@
# These are the default system groups on debian.
# You can alos add groups by gid by prefixing the line with 'gid:'.
group:adm
group:audio
group:avahi
group:backup
group:bin
group:cdrom
group:crontab
group:daemon
group:dialout
group:dip
group:disk
group:fax
group:floppy
group:games
group:gnats
group:input
group:irc
group:kmem
group:kvm
group:list
group:lp
group:mail
group:man
group:mlocate
group:netdev
group:news
group:nogroup
group:openldap
group:operator
group:plocate
group:plugdev
group:polkitd
group:postgres
group:proxy
group:render
group:root
group:sasl
group:shadow
group:src
group:staff
group:sudo
group:sync
group:sys
group:systemd-journal
group:systemd-network
group:systemd-resolve
group:systemd-timesync
group:tape
group:tty
group:users
group:utmp
group:uucp
group:video
group:voice
group:www-data

View File

@@ -40,6 +40,14 @@ in
};
};
authorization = {
group_denylist = lib.mkOption {
type = with lib.types; nullOr (listOf str);
default = [ "wheel" ];
description = "List of groups that are denied access";
};
};
mysql = {
socket_path = lib.mkOption {
type = with lib.types; nullOr path;
@@ -81,6 +89,12 @@ in
environment.systemPackages = [ cfg.package ];
environment.etc."muscl/config.toml".source = lib.pipe cfg.settings [
# Handle group_denylist_file
(conf: lib.recursiveUpdate conf {
authorization.group_denylist_file = if (conf.authorization.group_denylist != [ ]) then "/etc/muscl/group-denylist" else null;
authorization.group_denylist = null;
})
# Remove nulls
(lib.filterAttrsRecursive (_: v: v != null))
@@ -95,6 +109,10 @@ in
(format.generate "muscl.conf")
];
environment.etc."muscl/group-denylist" = lib.mkIf (cfg.settings.authorization.group_denylist != [ ]) {
text = lib.concatMapStringsSep "\n" (group: "group:${group}") cfg.settings.authorization.group_denylist;
};
services.mysql.ensureUsers = lib.mkIf cfg.createLocalDatabaseUser [
{
name = cfg.settings.mysql.username;

View File

@@ -17,16 +17,19 @@ pub use create_user::*;
pub use drop_db::*;
pub use drop_user::*;
pub use edit_privs::*;
use futures_util::SinkExt;
use itertools::Itertools;
pub use lock_user::*;
pub use passwd_user::*;
pub use show_db::*;
pub use show_privs::*;
pub use show_user::*;
use tokio_stream::StreamExt;
pub use unlock_user::*;
use clap::Subcommand;
use crate::core::protocol::{ClientToServerMessageStream, Response};
use crate::core::protocol::{ClientToServerMessageStream, Request, Response};
#[derive(Subcommand, Debug, Clone)]
#[command(subcommand_required = true)]
@@ -183,3 +186,23 @@ pub fn erroneous_server_response(
}
}
}
pub async fn print_authorization_owner_hint(
server_connection: &mut ClientToServerMessageStream,
) -> anyhow::Result<()> {
server_connection
.send(Request::ListValidNamePrefixes)
.await?;
let response = match server_connection.next().await {
Some(Ok(Response::ListValidNamePrefixes(prefixes))) => prefixes,
response => return erroneous_server_response(response),
};
println!(
"Note: You are allowed to manage databases and users with the following prefixes:\n{}",
response.into_iter().map(|p| format!(" - {}", p)).join("\n")
);
Ok(())
}

View File

@@ -3,11 +3,12 @@ use futures_util::SinkExt;
use tokio_stream::StreamExt;
use crate::{
client::commands::erroneous_server_response,
client::commands::{erroneous_server_response, print_authorization_owner_hint},
core::{
protocol::{
ClientToServerMessageStream, Request, Response, print_create_databases_output_status,
print_create_databases_output_status_json,
ClientToServerMessageStream, CreateDatabaseError, Request, Response,
print_create_databases_output_status, print_create_databases_output_status_json,
request_validation::ValidationError,
},
types::MySQLDatabase,
},
@@ -40,13 +41,24 @@ pub async fn create_databases(
response => return erroneous_server_response(response),
};
server_connection.send(Request::Exit).await?;
if args.json {
print_create_databases_output_status_json(&result);
} else {
print_create_databases_output_status(&result);
if result.iter().any(|(_, res)| {
matches!(
res,
Err(CreateDatabaseError::ValidationError(
ValidationError::AuthorizationError(_)
))
)
}) {
print_authorization_owner_hint(&mut server_connection).await?
}
}
server_connection.send(Request::Exit).await?;
Ok(())
}

View File

@@ -4,11 +4,15 @@ use futures_util::SinkExt;
use tokio_stream::StreamExt;
use crate::{
client::commands::{erroneous_server_response, read_password_from_stdin_with_double_check},
client::commands::{
erroneous_server_response, print_authorization_owner_hint,
read_password_from_stdin_with_double_check,
},
core::{
protocol::{
ClientToServerMessageStream, Request, Response, print_create_users_output_status,
print_create_users_output_status_json, print_set_password_output_status,
ClientToServerMessageStream, CreateUserError, Request, Response,
print_create_users_output_status, print_create_users_output_status_json,
print_set_password_output_status, request_validation::ValidationError,
},
types::MySQLUser,
},
@@ -55,6 +59,17 @@ pub async fn create_users(
} else {
print_create_users_output_status(&result);
if result.iter().any(|(_, res)| {
matches!(
res,
Err(CreateUserError::ValidationError(
ValidationError::AuthorizationError(_)
))
)
}) {
print_authorization_owner_hint(&mut server_connection).await?
}
let successfully_created_users = result
.iter()
.filter_map(|(username, result)| result.as_ref().ok().map(|_| username))

View File

@@ -5,12 +5,13 @@ use futures_util::SinkExt;
use tokio_stream::StreamExt;
use crate::{
client::commands::erroneous_server_response,
client::commands::{erroneous_server_response, print_authorization_owner_hint},
core::{
completion::mysql_database_completer,
protocol::{
ClientToServerMessageStream, Request, Response, print_drop_databases_output_status,
print_drop_databases_output_status_json,
ClientToServerMessageStream, DropDatabaseError, Request, Response,
print_drop_databases_output_status, print_drop_databases_output_status_json,
request_validation::ValidationError,
},
types::MySQLDatabase,
},
@@ -66,13 +67,24 @@ pub async fn drop_databases(
response => return erroneous_server_response(response),
};
server_connection.send(Request::Exit).await?;
if args.json {
print_drop_databases_output_status_json(&result);
} else {
print_drop_databases_output_status(&result);
if result.iter().any(|(_, res)| {
matches!(
res,
Err(DropDatabaseError::ValidationError(
ValidationError::AuthorizationError(_)
))
)
}) {
print_authorization_owner_hint(&mut server_connection).await?
}
};
server_connection.send(Request::Exit).await?;
Ok(())
}

View File

@@ -5,12 +5,13 @@ use futures_util::SinkExt;
use tokio_stream::StreamExt;
use crate::{
client::commands::erroneous_server_response,
client::commands::{erroneous_server_response, print_authorization_owner_hint},
core::{
completion::mysql_user_completer,
protocol::{
ClientToServerMessageStream, Request, Response, print_drop_users_output_status,
print_drop_users_output_status_json,
ClientToServerMessageStream, DropUserError, Request, Response,
print_drop_users_output_status, print_drop_users_output_status_json,
request_validation::ValidationError,
},
types::MySQLUser,
},
@@ -70,13 +71,24 @@ pub async fn drop_users(
response => return erroneous_server_response(response),
};
server_connection.send(Request::Exit).await?;
if args.json {
print_drop_users_output_status_json(&result);
} else {
print_drop_users_output_status(&result);
if result.iter().any(|(_, res)| {
matches!(
res,
Err(DropUserError::ValidationError(
ValidationError::AuthorizationError(_)
))
)
}) {
print_authorization_owner_hint(&mut server_connection).await?
}
}
server_connection.send(Request::Exit).await?;
Ok(())
}

View File

@@ -9,7 +9,7 @@ use nix::unistd::{User, getuid};
use tokio_stream::StreamExt;
use crate::{
client::commands::erroneous_server_response,
client::commands::{erroneous_server_response, print_authorization_owner_hint},
core::{
completion::{mysql_database_completer, mysql_user_completer},
database_privileges::{
@@ -19,8 +19,8 @@ use crate::{
parse_privilege_data_from_editor_content, reduce_privilege_diffs,
},
protocol::{
ClientToServerMessageStream, Request, Response,
print_modify_database_privileges_output_status,
ClientToServerMessageStream, ModifyDatabasePrivilegesError, Request, Response,
print_modify_database_privileges_output_status, request_validation::ValidationError,
},
types::{MySQLDatabase, MySQLUser},
},
@@ -219,6 +219,8 @@ pub async fn edit_database_privileges(
diff_privileges(&existing_privilege_rows, &privileges_to_change)
};
// TODO: validate authorization before existence
let user_existence_map = users_exist(&mut server_connection, &diffs).await?;
let database_existence_map = databases_exist(&mut server_connection, &diffs).await?;
@@ -274,6 +276,19 @@ pub async fn edit_database_privileges(
print_modify_database_privileges_output_status(&result);
if result.iter().any(|(_, res)| {
matches!(
res,
Err(ModifyDatabasePrivilegesError::UserValidationError(
ValidationError::AuthorizationError(_)
) | ModifyDatabasePrivilegesError::DatabaseValidationError(
ValidationError::AuthorizationError(_)
))
)
}) {
print_authorization_owner_hint(&mut server_connection).await?
}
server_connection.send(Request::Exit).await?;
Ok(())

View File

@@ -4,12 +4,13 @@ use futures_util::SinkExt;
use tokio_stream::StreamExt;
use crate::{
client::commands::erroneous_server_response,
client::commands::{erroneous_server_response, print_authorization_owner_hint},
core::{
completion::mysql_user_completer,
protocol::{
ClientToServerMessageStream, Request, Response, print_lock_users_output_status,
print_lock_users_output_status_json,
ClientToServerMessageStream, LockUserError, Request, Response,
print_lock_users_output_status, print_lock_users_output_status_json,
request_validation::ValidationError,
},
types::MySQLUser,
},
@@ -47,13 +48,24 @@ pub async fn lock_users(
response => return erroneous_server_response(response),
};
server_connection.send(Request::Exit).await?;
if args.json {
print_lock_users_output_status_json(&result);
} else {
print_lock_users_output_status(&result);
if result.iter().any(|(_, res)| {
matches!(
res,
Err(LockUserError::ValidationError(
ValidationError::AuthorizationError(_)
))
)
}) {
print_authorization_owner_hint(&mut server_connection).await?
}
}
server_connection.send(Request::Exit).await?;
Ok(())
}

View File

@@ -8,12 +8,12 @@ use futures_util::SinkExt;
use tokio_stream::StreamExt;
use crate::{
client::commands::erroneous_server_response,
client::commands::{erroneous_server_response, print_authorization_owner_hint},
core::{
completion::mysql_user_completer,
protocol::{
ClientToServerMessageStream, ListUsersError, Request, Response,
print_set_password_output_status,
ClientToServerMessageStream, ListUsersError, Request, Response, SetPasswordError,
print_set_password_output_status, request_validation::ValidationError,
},
types::MySQLUser,
},
@@ -103,9 +103,18 @@ pub async fn passwd_user(
response => return erroneous_server_response(response),
};
server_connection.send(Request::Exit).await?;
print_set_password_output_status(&result, &args.username);
if matches!(
result,
Err(SetPasswordError::ValidationError(
ValidationError::AuthorizationError(_)
))
) {
print_authorization_owner_hint(&mut server_connection).await?
}
server_connection.send(Request::Exit).await?;
Ok(())
}

View File

@@ -4,12 +4,13 @@ use futures_util::SinkExt;
use tokio_stream::StreamExt;
use crate::{
client::commands::erroneous_server_response,
client::commands::{erroneous_server_response, print_authorization_owner_hint},
core::{
completion::mysql_database_completer,
protocol::{
ClientToServerMessageStream, Request, Response, print_list_databases_output_status,
print_list_databases_output_status_json,
ClientToServerMessageStream, ListDatabasesError, Request, Response,
print_list_databases_output_status, print_list_databases_output_status_json,
request_validation::ValidationError,
},
types::MySQLDatabase,
},
@@ -60,14 +61,25 @@ pub async fn show_databases(
response => return erroneous_server_response(response),
};
server_connection.send(Request::Exit).await?;
if args.json {
print_list_databases_output_status_json(&databases);
} else {
print_list_databases_output_status(&databases);
if databases.iter().any(|(_, res)| {
matches!(
res,
Err(ListDatabasesError::ValidationError(
ValidationError::AuthorizationError(_)
))
)
}) {
print_authorization_owner_hint(&mut server_connection).await?
}
}
server_connection.send(Request::Exit).await?;
if args.fail && databases.values().any(|res| res.is_err()) {
std::process::exit(1);
}

View File

@@ -5,12 +5,13 @@ use itertools::Itertools;
use tokio_stream::StreamExt;
use crate::{
client::commands::erroneous_server_response,
client::commands::{erroneous_server_response, print_authorization_owner_hint},
core::{
completion::mysql_database_completer,
protocol::{
ClientToServerMessageStream, Request, Response, print_list_privileges_output_status,
print_list_privileges_output_status_json,
ClientToServerMessageStream, GetDatabasesPrivilegeDataError, Request, Response,
print_list_privileges_output_status, print_list_privileges_output_status_json,
request_validation::ValidationError,
},
types::MySQLDatabase,
},
@@ -68,14 +69,25 @@ pub async fn show_database_privileges(
response => return erroneous_server_response(response),
};
server_connection.send(Request::Exit).await?;
if args.json {
print_list_privileges_output_status_json(&privilege_data);
} else {
print_list_privileges_output_status(&privilege_data, args.long);
if privilege_data.iter().any(|(_, res)| {
matches!(
res,
Err(GetDatabasesPrivilegeDataError::ValidationError(
ValidationError::AuthorizationError(_)
))
)
}) {
print_authorization_owner_hint(&mut server_connection).await?
}
}
server_connection.send(Request::Exit).await?;
if args.fail && privilege_data.values().any(|res| res.is_err()) {
std::process::exit(1);
}

View File

@@ -4,12 +4,13 @@ use futures_util::SinkExt;
use tokio_stream::StreamExt;
use crate::{
client::commands::erroneous_server_response,
client::commands::{erroneous_server_response, print_authorization_owner_hint},
core::{
completion::mysql_user_completer,
protocol::{
ClientToServerMessageStream, Request, Response, print_list_users_output_status,
print_list_users_output_status_json,
ClientToServerMessageStream, ListUsersError, Request, Response,
print_list_users_output_status, print_list_users_output_status_json,
request_validation::ValidationError,
},
types::MySQLUser,
},
@@ -63,14 +64,25 @@ pub async fn show_users(
response => return erroneous_server_response(response),
};
server_connection.send(Request::Exit).await?;
if args.json {
print_list_users_output_status_json(&users);
} else {
print_list_users_output_status(&users);
if users.iter().any(|(_, res)| {
matches!(
res,
Err(ListUsersError::ValidationError(
ValidationError::AuthorizationError(_)
))
)
}) {
print_authorization_owner_hint(&mut server_connection).await?
}
}
server_connection.send(Request::Exit).await?;
if args.fail && users.values().any(|result| result.is_err()) {
std::process::exit(1);
}

View File

@@ -4,12 +4,13 @@ use futures_util::SinkExt;
use tokio_stream::StreamExt;
use crate::{
client::commands::erroneous_server_response,
client::commands::{erroneous_server_response, print_authorization_owner_hint},
core::{
completion::mysql_user_completer,
protocol::{
ClientToServerMessageStream, Request, Response, print_unlock_users_output_status,
print_unlock_users_output_status_json,
ClientToServerMessageStream, Request, Response, UnlockUserError,
print_unlock_users_output_status, print_unlock_users_output_status_json,
request_validation::ValidationError,
},
types::MySQLUser,
},
@@ -47,13 +48,24 @@ pub async fn unlock_users(
response => return erroneous_server_response(response),
};
server_connection.send(Request::Exit).await?;
if args.json {
print_unlock_users_output_status_json(&result);
} else {
print_unlock_users_output_status(&result);
if result.iter().any(|(_, res)| {
matches!(
res,
Err(UnlockUserError::ValidationError(
ValidationError::AuthorizationError(_)
))
)
}) {
print_authorization_owner_hint(&mut server_connection).await?
}
}
server_connection.send(Request::Exit).await?;
Ok(())
}

View File

@@ -9,10 +9,12 @@ use tokio::{net::UnixStream as TokioUnixStream, sync::RwLock};
use tracing_subscriber::prelude::*;
use crate::{
core::common::{
DEFAULT_CONFIG_PATH, DEFAULT_SOCKET_PATH, UnixUser, executing_in_suid_sgid_mode,
core::{
common::{DEFAULT_CONFIG_PATH, DEFAULT_SOCKET_PATH, UnixUser, executing_in_suid_sgid_mode},
protocol::request_validation::GroupDenylist,
},
server::{
authorization::read_and_parse_group_denylist,
config::{MysqlConfig, ServerConfig},
landlock::landlock_restrict_server,
session_handler,
@@ -270,6 +272,13 @@ fn run_forked_server(
let config = ServerConfig::read_config_from_path(&config_path)
.context("Failed to read server config in forked process")?;
let group_denylist = if let Some(denylist_path) = &config.authorization.group_denylist_file {
read_and_parse_group_denylist(denylist_path)
.context("Failed to read and parse group denylist")?
} else {
GroupDenylist::new()
};
let result: anyhow::Result<()> = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
@@ -292,6 +301,7 @@ fn run_forked_server(
&unix_user,
db_pool,
db_is_mariadb,
&group_denylist,
)
.await?;
Ok(())

View File

@@ -99,10 +99,10 @@ impl UnixUser {
})
}
pub fn from_enviroment() -> anyhow::Result<Self> {
let libc_uid = nix::unistd::getuid();
UnixUser::from_uid(libc_uid.as_raw())
}
// pub fn from_enviroment() -> anyhow::Result<Self> {
// let libc_uid = nix::unistd::getuid();
// UnixUser::from_uid(libc_uid.as_raw())
// }
}
#[inline]

View File

@@ -1,5 +1,7 @@
use std::collections::HashSet;
use indoc::indoc;
use itertools::Itertools;
use nix::{libc::gid_t, unistd::Group};
use serde::{Deserialize, Serialize};
use thiserror::Error;
@@ -23,23 +25,19 @@ impl NameValidationError {
pub fn to_error_message(self, db_or_user: DbOrUser) -> String {
match self {
NameValidationError::EmptyString => {
format!("{} name cannot be empty.", db_or_user.capitalized_noun()).to_owned()
format!("{} name can not be empty.", db_or_user.capitalized_noun())
}
NameValidationError::TooLong => format!(
"{} is too long. Maximum length is 64 characters.",
"{} is too long, maximum length is 64 characters.",
db_or_user.capitalized_noun()
)
.to_owned(),
),
NameValidationError::InvalidCharacters => format!(
indoc! {r#"
Invalid characters in {} name: '{}'
Only A-Z, a-z, 0-9, _ (underscore) and - (dash) are permitted.
Invalid characters in {} name: '{}', only A-Z, a-z, 0-9, _ (underscore) and - (dash) are permitted.
"#},
db_or_user.lowercased_noun(),
db_or_user.name(),
)
.to_owned(),
),
}
}
@@ -54,64 +52,41 @@ impl NameValidationError {
#[derive(Error, Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum AuthorizationError {
#[error("No matching owner prefix found")]
NoMatch,
#[error("Illegal prefix, user is not authorized to manage this resource")]
IllegalPrefix,
// TODO: I don't think this should ever happen?
#[error("Name cannot be empty")]
StringEmpty,
#[error("Group was found in denylist")]
DenylistError,
}
impl AuthorizationError {
pub fn to_error_message(self, db_or_user: DbOrUser) -> String {
let user = UnixUser::from_enviroment();
let UnixUser {
username,
mut groups,
} = user.unwrap_or(UnixUser {
username: "???".to_string(),
groups: vec![],
});
groups.sort();
match self {
AuthorizationError::NoMatch => format!(
indoc! {r#"
Invalid {} name prefix: '{}' does not match your username or any of your groups.
Are you sure you are allowed to create {} names with this prefix?
The format should be: <prefix>_<{} name>
Allowed prefixes:
- {}
{}
"#},
AuthorizationError::IllegalPrefix => format!(
"Illegal {} name prefix: you are not allowed to manage databases or users prefixed with '{}'",
db_or_user.lowercased_noun(),
db_or_user.name(),
db_or_user.lowercased_noun(),
db_or_user.lowercased_noun(),
username,
groups
.into_iter()
.filter(|g| g != &username)
.map(|g| format!(" - {}", g))
.join("\n"),
db_or_user.prefix(),
)
.to_owned(),
AuthorizationError::StringEmpty => format!(
"'{}' is not a valid {} name.",
db_or_user.name(),
db_or_user.lowercased_noun()
)
.to_string(),
// TODO: This error message could be clearer
AuthorizationError::StringEmpty => {
format!("{} name can not be empty.", db_or_user.capitalized_noun())
}
AuthorizationError::DenylistError => {
format!("'{}' is denied by the group denylist", db_or_user.name())
}
}
}
pub fn error_type(&self) -> &'static str {
match self {
AuthorizationError::NoMatch => "no-match",
AuthorizationError::IllegalPrefix => "illegal-prefix",
AuthorizationError::StringEmpty => "string-empty",
AuthorizationError::DenylistError => "denylist-error",
}
}
}
@@ -155,6 +130,8 @@ impl ValidationError {
}
}
pub type GroupDenylist = HashSet<gid_t>;
const MAX_NAME_LENGTH: usize = 64;
pub fn validate_name(name: &str) -> Result<(), NameValidationError> {
@@ -201,21 +178,49 @@ pub fn validate_authorization_by_prefixes(
.collect::<Vec<_>>()
.is_empty()
{
return Err(AuthorizationError::NoMatch);
return Err(AuthorizationError::IllegalPrefix);
};
Ok(())
}
pub fn validate_authorization_by_group_denylist(
name: &str,
user: &UnixUser,
group_denylist: &GroupDenylist,
) -> Result<(), AuthorizationError> {
// NOTE: if the username matches, we allow it regardless of denylist
if user.username == name {
return Ok(());
}
let user_group = Group::from_name(name)
.ok()
.flatten()
.map(|g| g.gid.as_raw());
if let Some(gid) = user_group
&& group_denylist.contains(&gid)
{
Err(AuthorizationError::DenylistError)
} else {
Ok(())
}
}
pub fn validate_db_or_user_request(
db_or_user: &DbOrUser,
unix_user: &UnixUser,
group_denylist: &GroupDenylist,
) -> Result<(), ValidationError> {
validate_name(db_or_user.name()).map_err(ValidationError::NameValidationError)?;
validate_authorization_by_unix_user(db_or_user.name(), unix_user)
.map_err(ValidationError::AuthorizationError)?;
validate_authorization_by_group_denylist(db_or_user.name(), unix_user, group_denylist)
.map_err(ValidationError::AuthorizationError)?;
Ok(())
}
@@ -273,7 +278,7 @@ mod tests {
assert_eq!(
validate_authorization_by_prefixes("nonexistent_testdb", &prefixes),
Err(AuthorizationError::NoMatch)
Err(AuthorizationError::IllegalPrefix)
);
}
}

View File

@@ -132,4 +132,11 @@ impl DbOrUser {
DbOrUser::User(user) => user.as_str(),
}
}
pub fn prefix(&self) -> &str {
match self {
DbOrUser::Database(db) => db.split('_').next().unwrap_or("?"),
DbOrUser::User(user) => user.split('_').next().unwrap_or("?"),
}
}
}

View File

@@ -1,4 +1,4 @@
mod authorization;
pub mod authorization;
pub mod command;
mod common;
pub mod config;

View File

@@ -1,25 +1,127 @@
use std::{collections::HashSet, path::Path};
use anyhow::Context;
use nix::unistd::Group;
use crate::core::{
common::UnixUser,
protocol::{CheckAuthorizationError, request_validation::validate_db_or_user_request},
protocol::{
CheckAuthorizationError,
request_validation::{GroupDenylist, validate_db_or_user_request},
},
types::DbOrUser,
};
pub async fn check_authorization(
dbs_or_users: Vec<DbOrUser>,
unix_user: &UnixUser,
group_denylist: &GroupDenylist,
) -> std::collections::BTreeMap<DbOrUser, Result<(), CheckAuthorizationError>> {
let mut results = std::collections::BTreeMap::new();
for db_or_user in dbs_or_users {
if let Err(err) =
validate_db_or_user_request(&db_or_user, unix_user).map_err(CheckAuthorizationError)
if let Err(err) = validate_db_or_user_request(&db_or_user, unix_user, group_denylist)
.map_err(CheckAuthorizationError)
{
results.insert(db_or_user.clone(), Err(err));
continue;
}
results.insert(db_or_user.clone(), Ok(()));
}
results
}
/// Reads and parses a group denylist file, returning a set of GUIDs
///
/// The format of the denylist file is expected to be one group name or GID per line.
/// Lines starting with '#' are treated as comments and ignored.
/// Empty lines are also ignored.
///
/// Each line looks like one of the following:
/// - `gid:1001`
/// - `group:admins`
pub fn read_and_parse_group_denylist(denylist_path: &Path) -> anyhow::Result<GroupDenylist> {
let content = std::fs::read_to_string(denylist_path).context(format!(
"Failed to read denylist file at {:?}",
denylist_path
))?;
let mut groups = HashSet::with_capacity(content.lines().count());
for (line_number, line) in content.lines().enumerate() {
let trimmed_line = line.trim();
if trimmed_line.is_empty() || trimmed_line.starts_with('#') {
continue;
}
let parts: Vec<&str> = trimmed_line.splitn(2, ':').collect();
if parts.len() != 2 {
anyhow::bail!(
"Invalid format in denylist file at {:?} on line {}: {}",
denylist_path,
line_number + 1,
line
);
}
match parts[0] {
"gid" => {
let gid: u32 = parts[1].parse().with_context(|| {
format!(
"Invalid GID in denylist file at {:?} on line {}: {}",
denylist_path,
line_number + 1,
parts[1]
)
})?;
let group = Group::from_gid(nix::unistd::Gid::from_raw(gid))
.context(format!(
"Failed to get group for GID {} in denylist file at {:?} on line {}",
gid,
denylist_path,
line_number + 1
))?
.ok_or_else(|| {
anyhow::anyhow!(
"No group found for GID {} in denylist file at {:?} on line {}",
gid,
denylist_path,
line_number + 1
)
})?;
groups.insert(group.gid.as_raw());
}
"group" => {
let group = Group::from_name(parts[1])
.context(format!(
"Failed to get group for name '{}' in denylist file at {:?} on line {}",
parts[1],
denylist_path,
line_number + 1
))?
.ok_or_else(|| {
anyhow::anyhow!(
"No group found for name '{}' in denylist file at {:?} on line {}",
parts[1],
denylist_path,
line_number + 1
)
})?;
groups.insert(group.gid.as_raw());
}
_ => {
anyhow::bail!(
"Invalid prefix '{}' in denylist file at {:?} on line {}: {}",
parts[0],
denylist_path,
line_number + 1,
line
);
}
}
}
Ok(groups)
}

View File

@@ -1,13 +1,37 @@
use crate::core::common::UnixUser;
use crate::core::{common::UnixUser, protocol::request_validation::GroupDenylist};
use nix::unistd::Group;
use sqlx::prelude::*;
/// This function retrieves the groups of a user, filtering out any groups
/// that are present in the provided denylist.
pub fn get_user_filtered_groups(user: &UnixUser, group_denylist: &GroupDenylist) -> Vec<String> {
user.groups
.iter()
.cloned()
.filter_map(|group_name| {
match Group::from_name(&group_name) {
Ok(Some(group)) => {
if group_denylist.contains(&group.gid.as_raw()) {
None
} else {
Some(group.name)
}
}
// NOTE: allow non-existing groups to pass through the filter
_ => Some(group_name),
}
})
.collect()
}
/// This function creates a regex that matches items (users, databases)
/// that belong to the user or any of the user's groups.
pub fn create_user_group_matching_regex(user: &UnixUser) -> String {
if user.groups.is_empty() {
pub fn create_user_group_matching_regex(user: &UnixUser, group_denylist: &GroupDenylist) -> String {
let filtered_groups = get_user_filtered_groups(user, group_denylist);
if filtered_groups.is_empty() {
format!("{}_.+", user.username)
} else {
format!("({}|{})_.+", user.username, user.groups.join("|"))
format!("({}|{})_.+", user.username, filtered_groups.join("|"))
}
}
@@ -37,7 +61,8 @@ mod tests {
groups: vec!["group1".to_owned(), "group2".to_owned()],
};
let regex = create_user_group_matching_regex(&user);
let regex = create_user_group_matching_regex(&user, &GroupDenylist::new());
println!("Generated regex: {}", regex);
let re = Regex::new(&regex).unwrap();
assert!(re.is_match("user_something"));

View File

@@ -78,9 +78,15 @@ impl MysqlConfig {
}
}
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
pub struct AuthorizationConfig {
pub group_denylist_file: Option<PathBuf>,
}
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
pub struct ServerConfig {
pub socket_path: Option<PathBuf>,
pub authorization: AuthorizationConfig,
pub mysql: MysqlConfig,
}

View File

@@ -11,11 +11,12 @@ use crate::{
common::UnixUser,
protocol::{
Request, Response, ServerToClientMessageStream, SetPasswordError,
create_server_to_client_message_stream,
create_server_to_client_message_stream, request_validation::GroupDenylist,
},
},
server::{
authorization::check_authorization,
common::get_user_filtered_groups,
sql::{
database_operations::{
complete_database_name, create_databases, drop_databases,
@@ -39,6 +40,7 @@ pub async fn session_handler(
socket: UnixStream,
db_pool: Arc<RwLock<MySqlPool>>,
db_is_mariadb: bool,
group_denylist: &GroupDenylist,
) -> anyhow::Result<()> {
let uid = match socket.peer_cred() {
Ok(cred) => cred.uid(),
@@ -85,8 +87,14 @@ pub async fn session_handler(
(async move {
tracing::info!("Accepted connection from user: {}", unix_user);
let result =
session_handler_with_unix_user(socket, &unix_user, db_pool, db_is_mariadb).await;
let result = session_handler_with_unix_user(
socket,
&unix_user,
db_pool,
db_is_mariadb,
group_denylist,
)
.await;
tracing::info!(
"Finished handling requests for connection from user: {}",
@@ -104,6 +112,7 @@ pub async fn session_handler_with_unix_user(
unix_user: &UnixUser,
db_pool: Arc<RwLock<MySqlPool>>,
db_is_mariadb: bool,
group_denylist: &GroupDenylist,
) -> anyhow::Result<()> {
let mut message_stream = create_server_to_client_message_stream(socket);
@@ -131,6 +140,7 @@ pub async fn session_handler_with_unix_user(
unix_user,
&mut db_connection,
db_is_mariadb,
group_denylist,
)
.await;
@@ -147,6 +157,7 @@ async fn session_handler_with_db_connection(
unix_user: &UnixUser,
db_connection: &mut MySqlConnection,
db_is_mariadb: bool,
group_denylist: &GroupDenylist,
) -> anyhow::Result<()> {
stream.send(Response::Ready).await?;
loop {
@@ -178,18 +189,14 @@ async fn session_handler_with_db_connection(
let response = match request {
Request::CheckAuthorization(dbs_or_users) => {
let result = check_authorization(dbs_or_users, unix_user).await;
let result = check_authorization(dbs_or_users, unix_user, group_denylist).await;
Response::CheckAuthorization(result)
}
Request::ListValidNamePrefixes => {
let mut result = Vec::with_capacity(unix_user.groups.len() + 1);
result.push(unix_user.username.to_owned());
for group in unix_user
.groups
.iter()
.filter(|x| *x != &unix_user.username)
{
for group in get_user_filtered_groups(unix_user, group_denylist) {
result.push(group.to_owned());
}
@@ -208,6 +215,7 @@ async fn session_handler_with_db_connection(
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::CompleteDatabaseName(result)
@@ -226,32 +234,54 @@ async fn session_handler_with_db_connection(
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::CompleteUserName(result)
}
}
Request::CreateDatabases(databases_names) => {
let result =
create_databases(databases_names, unix_user, db_connection, db_is_mariadb)
.await;
let result = create_databases(
databases_names,
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::CreateDatabases(result)
}
Request::DropDatabases(databases_names) => {
let result =
drop_databases(databases_names, unix_user, db_connection, db_is_mariadb).await;
let result = drop_databases(
databases_names,
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::DropDatabases(result)
}
Request::ListDatabases(database_names) => match database_names {
Some(database_names) => {
let result =
list_databases(database_names, unix_user, db_connection, db_is_mariadb)
.await;
let result = list_databases(
database_names,
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::ListDatabases(result)
}
None => {
let result =
list_all_databases_for_user(unix_user, db_connection, db_is_mariadb).await;
let result = list_all_databases_for_user(
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::ListAllDatabases(result)
}
},
@@ -262,13 +292,19 @@ async fn session_handler_with_db_connection(
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::ListPrivileges(privilege_data)
}
None => {
let privilege_data =
get_all_database_privileges(unix_user, db_connection, db_is_mariadb).await;
let privilege_data = get_all_database_privileges(
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::ListAllPrivileges(privilege_data)
}
},
@@ -278,18 +314,31 @@ async fn session_handler_with_db_connection(
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::ModifyPrivileges(result)
}
Request::CreateUsers(db_users) => {
let result =
create_database_users(db_users, unix_user, db_connection, db_is_mariadb).await;
let result = create_database_users(
db_users,
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::CreateUsers(result)
}
Request::DropUsers(db_users) => {
let result =
drop_database_users(db_users, unix_user, db_connection, db_is_mariadb).await;
let result = drop_database_users(
db_users,
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::DropUsers(result)
}
Request::PasswdUser((db_user, password)) => {
@@ -299,15 +348,21 @@ async fn session_handler_with_db_connection(
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::SetUserPassword(result)
}
Request::ListUsers(db_users) => match db_users {
Some(db_users) => {
let result =
list_database_users(db_users, unix_user, db_connection, db_is_mariadb)
.await;
let result = list_database_users(
db_users,
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::ListUsers(result)
}
None => {
@@ -315,19 +370,32 @@ async fn session_handler_with_db_connection(
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::ListAllUsers(result)
}
},
Request::LockUsers(db_users) => {
let result =
lock_database_users(db_users, unix_user, db_connection, db_is_mariadb).await;
let result = lock_database_users(
db_users,
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::LockUsers(result)
}
Request::UnlockUsers(db_users) => {
let result =
unlock_database_users(db_users, unix_user, db_connection, db_is_mariadb).await;
let result = unlock_database_users(
db_users,
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::UnlockUsers(result)
}
Request::Exit => {

View File

@@ -6,6 +6,7 @@ use sqlx::prelude::*;
use serde::{Deserialize, Serialize};
use crate::core::protocol::CompleteDatabaseNameResponse;
use crate::core::protocol::request_validation::GroupDenylist;
use crate::core::protocol::request_validation::validate_db_or_user_request;
use crate::core::types::DbOrUser;
use crate::core::types::MySQLDatabase;
@@ -49,6 +50,7 @@ pub async fn complete_database_name(
unix_user: &UnixUser,
connection: &mut MySqlConnection,
_db_is_mariadb: bool,
group_denylist: &GroupDenylist,
) -> CompleteDatabaseNameResponse {
let result = sqlx::query(
r#"
@@ -59,7 +61,7 @@ pub async fn complete_database_name(
AND `SCHEMA_NAME` LIKE ?
"#,
)
.bind(create_user_group_matching_regex(unix_user))
.bind(create_user_group_matching_regex(unix_user, group_denylist))
.bind(format!("{}%", database_prefix))
.fetch_all(connection)
.await;
@@ -89,13 +91,17 @@ pub async fn create_databases(
unix_user: &UnixUser,
connection: &mut MySqlConnection,
_db_is_mariadb: bool,
group_denylist: &GroupDenylist,
) -> CreateDatabasesResponse {
let mut results = BTreeMap::new();
for database_name in database_names {
if let Err(err) =
validate_db_or_user_request(&DbOrUser::Database(database_name.clone()), unix_user)
.map_err(CreateDatabaseError::ValidationError)
if let Err(err) = validate_db_or_user_request(
&DbOrUser::Database(database_name.clone()),
unix_user,
group_denylist,
)
.map_err(CreateDatabaseError::ValidationError)
{
results.insert(database_name.to_owned(), Err(err));
continue;
@@ -141,13 +147,17 @@ pub async fn drop_databases(
unix_user: &UnixUser,
connection: &mut MySqlConnection,
_db_is_mariadb: bool,
group_denylist: &GroupDenylist,
) -> DropDatabasesResponse {
let mut results = BTreeMap::new();
for database_name in database_names {
if let Err(err) =
validate_db_or_user_request(&DbOrUser::Database(database_name.clone()), unix_user)
.map_err(DropDatabaseError::ValidationError)
if let Err(err) = validate_db_or_user_request(
&DbOrUser::Database(database_name.clone()),
unix_user,
group_denylist,
)
.map_err(DropDatabaseError::ValidationError)
{
results.insert(database_name.to_owned(), Err(err));
continue;
@@ -236,13 +246,17 @@ pub async fn list_databases(
unix_user: &UnixUser,
connection: &mut MySqlConnection,
_db_is_mariadb: bool,
group_denylist: &GroupDenylist,
) -> ListDatabasesResponse {
let mut results = BTreeMap::new();
for database_name in database_names {
if let Err(err) =
validate_db_or_user_request(&DbOrUser::Database(database_name.clone()), unix_user)
.map_err(ListDatabasesError::ValidationError)
if let Err(err) = validate_db_or_user_request(
&DbOrUser::Database(database_name.clone()),
unix_user,
group_denylist,
)
.map_err(ListDatabasesError::ValidationError)
{
results.insert(database_name.to_owned(), Err(err));
continue;
@@ -296,6 +310,7 @@ pub async fn list_all_databases_for_user(
unix_user: &UnixUser,
connection: &mut MySqlConnection,
_db_is_mariadb: bool,
group_denylist: &GroupDenylist,
) -> ListAllDatabasesResponse {
let result = sqlx::query_as::<_, DatabaseRow>(
r#"
@@ -319,7 +334,7 @@ pub async fn list_all_databases_for_user(
GROUP BY `information_schema`.`SCHEMATA`.`SCHEMA_NAME`
"#,
)
.bind(create_user_group_matching_regex(unix_user))
.bind(create_user_group_matching_regex(unix_user, group_denylist))
.fetch_all(connection)
.await
.map_err(|err| ListAllDatabasesError::MySqlError(err.to_string()));

View File

@@ -31,7 +31,7 @@ use crate::{
DiffDoesNotApplyError, GetAllDatabasesPrivilegeDataError,
GetDatabasesPrivilegeDataError, ListAllPrivilegesResponse, ListPrivilegesResponse,
ModifyDatabasePrivilegesError, ModifyPrivilegesResponse,
request_validation::validate_db_or_user_request,
request_validation::{GroupDenylist, validate_db_or_user_request},
},
types::{DbOrUser, MySQLDatabase, MySQLUser},
},
@@ -143,13 +143,17 @@ pub async fn get_databases_privilege_data(
unix_user: &UnixUser,
connection: &mut MySqlConnection,
_db_is_mariadb: bool,
group_denylist: &GroupDenylist,
) -> ListPrivilegesResponse {
let mut results = BTreeMap::new();
for database_name in database_names.iter() {
if let Err(err) =
validate_db_or_user_request(&DbOrUser::Database(database_name.clone()), unix_user)
.map_err(GetDatabasesPrivilegeDataError::ValidationError)
if let Err(err) = validate_db_or_user_request(
&DbOrUser::Database(database_name.clone()),
unix_user,
group_denylist,
)
.map_err(GetDatabasesPrivilegeDataError::ValidationError)
{
results.insert(database_name.to_owned(), Err(err));
continue;
@@ -200,9 +204,10 @@ pub async fn get_all_database_privileges(
unix_user: &UnixUser,
connection: &mut MySqlConnection,
_db_is_mariadb: bool,
group_denylist: &GroupDenylist,
) -> ListAllPrivilegesResponse {
let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&get_all_db_privs_query())
.bind(create_user_group_matching_regex(unix_user))
.bind(create_user_group_matching_regex(unix_user, group_denylist))
.fetch_all(connection)
.await
.map_err(|e| GetAllDatabasesPrivilegeDataError::MySqlError(e.to_string()));
@@ -397,6 +402,7 @@ pub async fn apply_privilege_diffs(
unix_user: &UnixUser,
connection: &mut MySqlConnection,
_db_is_mariadb: bool,
group_denylist: &GroupDenylist,
) -> ModifyPrivilegesResponse {
let mut results: BTreeMap<(MySQLDatabase, MySQLUser), _> = BTreeMap::new();
@@ -408,6 +414,7 @@ pub async fn apply_privilege_diffs(
if let Err(err) = validate_db_or_user_request(
&DbOrUser::Database(diff.get_database_name().to_owned()),
unix_user,
group_denylist,
)
.map_err(ModifyDatabasePrivilegesError::UserValidationError)
{
@@ -415,9 +422,12 @@ pub async fn apply_privilege_diffs(
continue;
}
if let Err(err) =
validate_db_or_user_request(&DbOrUser::User(diff.get_user_name().to_owned()), unix_user)
.map_err(ModifyDatabasePrivilegesError::UserValidationError)
if let Err(err) = validate_db_or_user_request(
&DbOrUser::User(diff.get_user_name().to_owned()),
unix_user,
group_denylist,
)
.map_err(ModifyDatabasePrivilegesError::UserValidationError)
{
results.insert(key, Err(err));
continue;

View File

@@ -7,6 +7,7 @@ use serde::{Deserialize, Serialize};
use sqlx::MySqlConnection;
use sqlx::prelude::*;
use crate::core::protocol::request_validation::GroupDenylist;
use crate::core::protocol::request_validation::validate_db_or_user_request;
use crate::core::types::DbOrUser;
use crate::{
@@ -58,6 +59,7 @@ pub async fn complete_user_name(
unix_user: &UnixUser,
connection: &mut MySqlConnection,
_db_is_mariadb: bool,
group_denylist: &GroupDenylist,
) -> Vec<MySQLUser> {
let result = sqlx::query(
r#"
@@ -67,7 +69,7 @@ pub async fn complete_user_name(
AND `User` LIKE ?
"#,
)
.bind(create_user_group_matching_regex(unix_user))
.bind(create_user_group_matching_regex(unix_user, group_denylist))
.bind(format!("{}%", user_prefix))
.fetch_all(connection)
.await;
@@ -97,12 +99,14 @@ pub async fn create_database_users(
unix_user: &UnixUser,
connection: &mut MySqlConnection,
_db_is_mariadb: bool,
group_denylist: &GroupDenylist,
) -> CreateUsersResponse {
let mut results = BTreeMap::new();
for db_user in db_users {
if let Err(err) = validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user)
.map_err(CreateUserError::ValidationError)
if let Err(err) =
validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
.map_err(CreateUserError::ValidationError)
{
results.insert(db_user, Err(err));
continue;
@@ -141,12 +145,14 @@ pub async fn drop_database_users(
unix_user: &UnixUser,
connection: &mut MySqlConnection,
_db_is_mariadb: bool,
group_denylist: &GroupDenylist,
) -> DropUsersResponse {
let mut results = BTreeMap::new();
for db_user in db_users {
if let Err(err) = validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user)
.map_err(DropUserError::ValidationError)
if let Err(err) =
validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
.map_err(DropUserError::ValidationError)
{
results.insert(db_user, Err(err));
continue;
@@ -186,8 +192,9 @@ pub async fn set_password_for_database_user(
unix_user: &UnixUser,
connection: &mut MySqlConnection,
_db_is_mariadb: bool,
group_denylist: &GroupDenylist,
) -> SetUserPasswordResponse {
validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user)
validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
.map_err(SetPasswordError::ValidationError)?;
match unsafe_user_exists(db_user, &mut *connection).await {
@@ -269,12 +276,14 @@ pub async fn lock_database_users(
unix_user: &UnixUser,
connection: &mut MySqlConnection,
db_is_mariadb: bool,
group_denylist: &GroupDenylist,
) -> LockUsersResponse {
let mut results = BTreeMap::new();
for db_user in db_users {
if let Err(err) = validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user)
.map_err(LockUserError::ValidationError)
if let Err(err) =
validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
.map_err(LockUserError::ValidationError)
{
results.insert(db_user, Err(err));
continue;
@@ -327,12 +336,14 @@ pub async fn unlock_database_users(
unix_user: &UnixUser,
connection: &mut MySqlConnection,
db_is_mariadb: bool,
group_denylist: &GroupDenylist,
) -> UnlockUsersResponse {
let mut results = BTreeMap::new();
for db_user in db_users {
if let Err(err) = validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user)
.map_err(UnlockUserError::ValidationError)
if let Err(err) =
validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
.map_err(UnlockUserError::ValidationError)
{
results.insert(db_user, Err(err));
continue;
@@ -433,12 +444,14 @@ pub async fn list_database_users(
unix_user: &UnixUser,
connection: &mut MySqlConnection,
db_is_mariadb: bool,
group_denylist: &GroupDenylist,
) -> ListUsersResponse {
let mut results = BTreeMap::new();
for db_user in db_users {
if let Err(err) = validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user)
.map_err(ListUsersError::ValidationError)
if let Err(err) =
validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
.map_err(ListUsersError::ValidationError)
{
results.insert(db_user, Err(err));
continue;
@@ -477,6 +490,7 @@ pub async fn list_all_database_users_for_unix_user(
unix_user: &UnixUser,
connection: &mut MySqlConnection,
db_is_mariadb: bool,
group_denylist: &GroupDenylist,
) -> ListAllUsersResponse {
let mut result = sqlx::query_as::<_, DatabaseUser>(
&(if db_is_mariadb {
@@ -485,7 +499,7 @@ pub async fn list_all_database_users_for_unix_user(
DB_USER_SELECT_STATEMENT_MYSQL.to_string()
} + "WHERE `user`.`User` REGEXP ?"),
)
.bind(create_user_group_matching_regex(unix_user))
.bind(create_user_group_matching_regex(unix_user, group_denylist))
.fetch_all(&mut *connection)
.await
.map_err(|err| ListAllUsersError::MySqlError(err.to_string()));

View File

@@ -17,9 +17,13 @@ use tokio::{
};
use tokio_util::{sync::CancellationToken, task::TaskTracker};
use crate::server::{
config::{MysqlConfig, ServerConfig},
session_handler::session_handler,
use crate::{
core::protocol::request_validation::GroupDenylist,
server::{
authorization::read_and_parse_group_denylist,
config::{MysqlConfig, ServerConfig},
session_handler::session_handler,
},
};
#[derive(Clone, Debug)]
@@ -36,6 +40,7 @@ pub struct ReloadEvent;
pub struct Supervisor {
config_path: PathBuf,
config: Arc<Mutex<ServerConfig>>,
group_deny_list: Arc<RwLock<GroupDenylist>>,
systemd_mode: bool,
shutdown_cancel_token: CancellationToken,
@@ -66,6 +71,23 @@ impl Supervisor {
let config = ServerConfig::read_config_from_path(&config_path)
.context("Failed to read server configuration")?;
let group_deny_list = match &config.authorization.group_denylist_file {
Some(denylist_path) => {
let denylist = read_and_parse_group_denylist(denylist_path)
.context("Failed to read group denylist file")?;
tracing::debug!(
"Loaded group denylist with {} entries from {:?}",
denylist.len(),
denylist_path
);
Arc::new(RwLock::new(denylist))
}
None => {
tracing::debug!("No group denylist file specified, proceeding without a denylist");
Arc::new(RwLock::new(GroupDenylist::new()))
}
};
let mut watchdog_duration = None;
let mut watchdog_micro_seconds = 0;
#[cfg(target_os = "linux")]
@@ -148,12 +170,14 @@ impl Supervisor {
db_connection_pool.clone(),
rx,
db_is_mariadb.clone(),
group_deny_list.clone(),
))
};
Ok(Self {
config_path,
config: Arc::new(Mutex::new(config)),
group_deny_list,
systemd_mode,
reload_message_receiver: reload_rx,
shutdown_cancel_token,
@@ -196,6 +220,26 @@ impl Supervisor {
.context("Failed to read server configuration")?;
let mut config = self.config.clone().lock_owned().await;
*config = new_config;
let group_deny_list = match &config.authorization.group_denylist_file {
Some(denylist_path) => {
let denylist = read_and_parse_group_denylist(denylist_path)
.context("Failed to read group denylist file")?;
tracing::debug!(
"Loaded group denylist with {} entries from {:?}",
denylist.len(),
denylist_path
);
denylist
}
None => {
tracing::debug!("No group denylist file specified, proceeding without a denylist");
GroupDenylist::new()
}
};
let mut group_deny_list_lock = self.group_deny_list.write().await;
*group_deny_list_lock = group_deny_list;
Ok(())
}
@@ -502,6 +546,7 @@ async fn listener_task(
db_pool: Arc<RwLock<MySqlPool>>,
mut supervisor_message_receiver: broadcast::Receiver<SupervisorMessage>,
db_is_mariadb: Arc<RwLock<bool>>,
group_denylist: Arc<RwLock<GroupDenylist>>,
) -> anyhow::Result<()> {
#[cfg(target_os = "linux")]
sd_notify::notify(false, &[sd_notify::NotifyState::Ready])?;
@@ -539,8 +584,14 @@ async fn listener_task(
let db_pool_clone = db_pool.clone();
let db_is_mariadb_clone = *db_is_mariadb.read().await;
let group_denylist_arc_clone = group_denylist.clone();
task_tracker.spawn(async move {
match session_handler(conn, db_pool_clone, db_is_mariadb_clone).await {
match session_handler(
conn,
db_pool_clone,
db_is_mariadb_clone,
&*group_denylist_arc_clone.read().await,
).await {
Ok(()) => {}
Err(e) => {
tracing::error!("Failed to run server: {}", e);