Implement denylists
All checks were successful
Build and test / check-license (push) Successful in 1m38s
Build and test / check (push) Successful in 1m51s
Build and test / build (push) Successful in 2m40s
Build and test / test (push) Successful in 4m25s
Build and test / docs (push) Successful in 6m1s

This commit is contained in:
2025-12-15 15:17:37 +09:00
parent 45cefb8af4
commit 8b4d549e18
29 changed files with 743 additions and 188 deletions

View File

@@ -132,6 +132,11 @@ assets = [
"etc/muscl/config.toml", "etc/muscl/config.toml",
"644", "644",
], ],
[
"assets/debian/group_denylist.txt",
"etc/muscl/group_denylist.txt",
"644",
],
[ [
"assets/completions/_*", "assets/completions/_*",
"usr/share/zsh/site-functions/completions/", "usr/share/zsh/site-functions/completions/",

View File

@@ -22,3 +22,6 @@ password_file = "/run/credentials/muscl.service/muscl_mysql_password"
# Database connection timeout in seconds # Database connection timeout in seconds
timeout = 2 timeout = 2
[authorization]
group_denylist_file = "/etc/muscl/group_denylist.txt"

View File

@@ -0,0 +1,58 @@
# These are the default system groups on debian.
# You can alos add groups by gid by prefixing the line with 'gid:'.
group:adm
group:audio
group:avahi
group:backup
group:bin
group:cdrom
group:crontab
group:daemon
group:dialout
group:dip
group:disk
group:fax
group:floppy
group:games
group:gnats
group:input
group:irc
group:kmem
group:kvm
group:list
group:lp
group:mail
group:man
group:mlocate
group:netdev
group:news
group:nogroup
group:openldap
group:operator
group:plocate
group:plugdev
group:polkitd
group:postgres
group:proxy
group:render
group:root
group:sasl
group:shadow
group:src
group:staff
group:sudo
group:sync
group:sys
group:systemd-journal
group:systemd-network
group:systemd-resolve
group:systemd-timesync
group:tape
group:tty
group:users
group:utmp
group:uucp
group:video
group:voice
group:www-data

View File

@@ -40,6 +40,14 @@ in
}; };
}; };
authorization = {
group_denylist = lib.mkOption {
type = with lib.types; nullOr (listOf str);
default = [ "wheel" ];
description = "List of groups that are denied access";
};
};
mysql = { mysql = {
socket_path = lib.mkOption { socket_path = lib.mkOption {
type = with lib.types; nullOr path; type = with lib.types; nullOr path;
@@ -81,6 +89,12 @@ in
environment.systemPackages = [ cfg.package ]; environment.systemPackages = [ cfg.package ];
environment.etc."muscl/config.toml".source = lib.pipe cfg.settings [ environment.etc."muscl/config.toml".source = lib.pipe cfg.settings [
# Handle group_denylist_file
(conf: lib.recursiveUpdate conf {
authorization.group_denylist_file = if (conf.authorization.group_denylist != [ ]) then "/etc/muscl/group-denylist" else null;
authorization.group_denylist = null;
})
# Remove nulls # Remove nulls
(lib.filterAttrsRecursive (_: v: v != null)) (lib.filterAttrsRecursive (_: v: v != null))
@@ -95,6 +109,10 @@ in
(format.generate "muscl.conf") (format.generate "muscl.conf")
]; ];
environment.etc."muscl/group-denylist" = lib.mkIf (cfg.settings.authorization.group_denylist != [ ]) {
text = lib.concatMapStringsSep "\n" (group: "group:${group}") cfg.settings.authorization.group_denylist;
};
services.mysql.ensureUsers = lib.mkIf cfg.createLocalDatabaseUser [ services.mysql.ensureUsers = lib.mkIf cfg.createLocalDatabaseUser [
{ {
name = cfg.settings.mysql.username; name = cfg.settings.mysql.username;

View File

@@ -17,16 +17,19 @@ pub use create_user::*;
pub use drop_db::*; pub use drop_db::*;
pub use drop_user::*; pub use drop_user::*;
pub use edit_privs::*; pub use edit_privs::*;
use futures_util::SinkExt;
use itertools::Itertools;
pub use lock_user::*; pub use lock_user::*;
pub use passwd_user::*; pub use passwd_user::*;
pub use show_db::*; pub use show_db::*;
pub use show_privs::*; pub use show_privs::*;
pub use show_user::*; pub use show_user::*;
use tokio_stream::StreamExt;
pub use unlock_user::*; pub use unlock_user::*;
use clap::Subcommand; use clap::Subcommand;
use crate::core::protocol::{ClientToServerMessageStream, Response}; use crate::core::protocol::{ClientToServerMessageStream, Request, Response};
#[derive(Subcommand, Debug, Clone)] #[derive(Subcommand, Debug, Clone)]
#[command(subcommand_required = true)] #[command(subcommand_required = true)]
@@ -183,3 +186,23 @@ pub fn erroneous_server_response(
} }
} }
} }
pub async fn print_authorization_owner_hint(
server_connection: &mut ClientToServerMessageStream,
) -> anyhow::Result<()> {
server_connection
.send(Request::ListValidNamePrefixes)
.await?;
let response = match server_connection.next().await {
Some(Ok(Response::ListValidNamePrefixes(prefixes))) => prefixes,
response => return erroneous_server_response(response),
};
println!(
"Note: You are allowed to manage databases and users with the following prefixes:\n{}",
response.into_iter().map(|p| format!(" - {}", p)).join("\n")
);
Ok(())
}

View File

@@ -3,11 +3,12 @@ use futures_util::SinkExt;
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
use crate::{ use crate::{
client::commands::erroneous_server_response, client::commands::{erroneous_server_response, print_authorization_owner_hint},
core::{ core::{
protocol::{ protocol::{
ClientToServerMessageStream, Request, Response, print_create_databases_output_status, ClientToServerMessageStream, CreateDatabaseError, Request, Response,
print_create_databases_output_status_json, print_create_databases_output_status, print_create_databases_output_status_json,
request_validation::ValidationError,
}, },
types::MySQLDatabase, types::MySQLDatabase,
}, },
@@ -40,13 +41,24 @@ pub async fn create_databases(
response => return erroneous_server_response(response), response => return erroneous_server_response(response),
}; };
server_connection.send(Request::Exit).await?;
if args.json { if args.json {
print_create_databases_output_status_json(&result); print_create_databases_output_status_json(&result);
} else { } else {
print_create_databases_output_status(&result); print_create_databases_output_status(&result);
if result.iter().any(|(_, res)| {
matches!(
res,
Err(CreateDatabaseError::ValidationError(
ValidationError::AuthorizationError(_)
))
)
}) {
print_authorization_owner_hint(&mut server_connection).await?
}
} }
server_connection.send(Request::Exit).await?;
Ok(()) Ok(())
} }

View File

@@ -4,11 +4,15 @@ use futures_util::SinkExt;
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
use crate::{ use crate::{
client::commands::{erroneous_server_response, read_password_from_stdin_with_double_check}, client::commands::{
erroneous_server_response, print_authorization_owner_hint,
read_password_from_stdin_with_double_check,
},
core::{ core::{
protocol::{ protocol::{
ClientToServerMessageStream, Request, Response, print_create_users_output_status, ClientToServerMessageStream, CreateUserError, Request, Response,
print_create_users_output_status_json, print_set_password_output_status, print_create_users_output_status, print_create_users_output_status_json,
print_set_password_output_status, request_validation::ValidationError,
}, },
types::MySQLUser, types::MySQLUser,
}, },
@@ -55,6 +59,17 @@ pub async fn create_users(
} else { } else {
print_create_users_output_status(&result); print_create_users_output_status(&result);
if result.iter().any(|(_, res)| {
matches!(
res,
Err(CreateUserError::ValidationError(
ValidationError::AuthorizationError(_)
))
)
}) {
print_authorization_owner_hint(&mut server_connection).await?
}
let successfully_created_users = result let successfully_created_users = result
.iter() .iter()
.filter_map(|(username, result)| result.as_ref().ok().map(|_| username)) .filter_map(|(username, result)| result.as_ref().ok().map(|_| username))

View File

@@ -5,12 +5,13 @@ use futures_util::SinkExt;
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
use crate::{ use crate::{
client::commands::erroneous_server_response, client::commands::{erroneous_server_response, print_authorization_owner_hint},
core::{ core::{
completion::mysql_database_completer, completion::mysql_database_completer,
protocol::{ protocol::{
ClientToServerMessageStream, Request, Response, print_drop_databases_output_status, ClientToServerMessageStream, DropDatabaseError, Request, Response,
print_drop_databases_output_status_json, print_drop_databases_output_status, print_drop_databases_output_status_json,
request_validation::ValidationError,
}, },
types::MySQLDatabase, types::MySQLDatabase,
}, },
@@ -66,13 +67,24 @@ pub async fn drop_databases(
response => return erroneous_server_response(response), response => return erroneous_server_response(response),
}; };
server_connection.send(Request::Exit).await?;
if args.json { if args.json {
print_drop_databases_output_status_json(&result); print_drop_databases_output_status_json(&result);
} else { } else {
print_drop_databases_output_status(&result); print_drop_databases_output_status(&result);
if result.iter().any(|(_, res)| {
matches!(
res,
Err(DropDatabaseError::ValidationError(
ValidationError::AuthorizationError(_)
))
)
}) {
print_authorization_owner_hint(&mut server_connection).await?
}
}; };
server_connection.send(Request::Exit).await?;
Ok(()) Ok(())
} }

View File

@@ -5,12 +5,13 @@ use futures_util::SinkExt;
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
use crate::{ use crate::{
client::commands::erroneous_server_response, client::commands::{erroneous_server_response, print_authorization_owner_hint},
core::{ core::{
completion::mysql_user_completer, completion::mysql_user_completer,
protocol::{ protocol::{
ClientToServerMessageStream, Request, Response, print_drop_users_output_status, ClientToServerMessageStream, DropUserError, Request, Response,
print_drop_users_output_status_json, print_drop_users_output_status, print_drop_users_output_status_json,
request_validation::ValidationError,
}, },
types::MySQLUser, types::MySQLUser,
}, },
@@ -70,13 +71,24 @@ pub async fn drop_users(
response => return erroneous_server_response(response), response => return erroneous_server_response(response),
}; };
server_connection.send(Request::Exit).await?;
if args.json { if args.json {
print_drop_users_output_status_json(&result); print_drop_users_output_status_json(&result);
} else { } else {
print_drop_users_output_status(&result); print_drop_users_output_status(&result);
if result.iter().any(|(_, res)| {
matches!(
res,
Err(DropUserError::ValidationError(
ValidationError::AuthorizationError(_)
))
)
}) {
print_authorization_owner_hint(&mut server_connection).await?
}
} }
server_connection.send(Request::Exit).await?;
Ok(()) Ok(())
} }

View File

@@ -9,7 +9,7 @@ use nix::unistd::{User, getuid};
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
use crate::{ use crate::{
client::commands::erroneous_server_response, client::commands::{erroneous_server_response, print_authorization_owner_hint},
core::{ core::{
completion::{mysql_database_completer, mysql_user_completer}, completion::{mysql_database_completer, mysql_user_completer},
database_privileges::{ database_privileges::{
@@ -19,8 +19,8 @@ use crate::{
parse_privilege_data_from_editor_content, reduce_privilege_diffs, parse_privilege_data_from_editor_content, reduce_privilege_diffs,
}, },
protocol::{ protocol::{
ClientToServerMessageStream, Request, Response, ClientToServerMessageStream, ModifyDatabasePrivilegesError, Request, Response,
print_modify_database_privileges_output_status, print_modify_database_privileges_output_status, request_validation::ValidationError,
}, },
types::{MySQLDatabase, MySQLUser}, types::{MySQLDatabase, MySQLUser},
}, },
@@ -219,6 +219,8 @@ pub async fn edit_database_privileges(
diff_privileges(&existing_privilege_rows, &privileges_to_change) diff_privileges(&existing_privilege_rows, &privileges_to_change)
}; };
// TODO: validate authorization before existence
let user_existence_map = users_exist(&mut server_connection, &diffs).await?; let user_existence_map = users_exist(&mut server_connection, &diffs).await?;
let database_existence_map = databases_exist(&mut server_connection, &diffs).await?; let database_existence_map = databases_exist(&mut server_connection, &diffs).await?;
@@ -274,6 +276,19 @@ pub async fn edit_database_privileges(
print_modify_database_privileges_output_status(&result); print_modify_database_privileges_output_status(&result);
if result.iter().any(|(_, res)| {
matches!(
res,
Err(ModifyDatabasePrivilegesError::UserValidationError(
ValidationError::AuthorizationError(_)
) | ModifyDatabasePrivilegesError::DatabaseValidationError(
ValidationError::AuthorizationError(_)
))
)
}) {
print_authorization_owner_hint(&mut server_connection).await?
}
server_connection.send(Request::Exit).await?; server_connection.send(Request::Exit).await?;
Ok(()) Ok(())

View File

@@ -4,12 +4,13 @@ use futures_util::SinkExt;
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
use crate::{ use crate::{
client::commands::erroneous_server_response, client::commands::{erroneous_server_response, print_authorization_owner_hint},
core::{ core::{
completion::mysql_user_completer, completion::mysql_user_completer,
protocol::{ protocol::{
ClientToServerMessageStream, Request, Response, print_lock_users_output_status, ClientToServerMessageStream, LockUserError, Request, Response,
print_lock_users_output_status_json, print_lock_users_output_status, print_lock_users_output_status_json,
request_validation::ValidationError,
}, },
types::MySQLUser, types::MySQLUser,
}, },
@@ -47,13 +48,24 @@ pub async fn lock_users(
response => return erroneous_server_response(response), response => return erroneous_server_response(response),
}; };
server_connection.send(Request::Exit).await?;
if args.json { if args.json {
print_lock_users_output_status_json(&result); print_lock_users_output_status_json(&result);
} else { } else {
print_lock_users_output_status(&result); print_lock_users_output_status(&result);
if result.iter().any(|(_, res)| {
matches!(
res,
Err(LockUserError::ValidationError(
ValidationError::AuthorizationError(_)
))
)
}) {
print_authorization_owner_hint(&mut server_connection).await?
}
} }
server_connection.send(Request::Exit).await?;
Ok(()) Ok(())
} }

View File

@@ -8,12 +8,12 @@ use futures_util::SinkExt;
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
use crate::{ use crate::{
client::commands::erroneous_server_response, client::commands::{erroneous_server_response, print_authorization_owner_hint},
core::{ core::{
completion::mysql_user_completer, completion::mysql_user_completer,
protocol::{ protocol::{
ClientToServerMessageStream, ListUsersError, Request, Response, ClientToServerMessageStream, ListUsersError, Request, Response, SetPasswordError,
print_set_password_output_status, print_set_password_output_status, request_validation::ValidationError,
}, },
types::MySQLUser, types::MySQLUser,
}, },
@@ -103,9 +103,18 @@ pub async fn passwd_user(
response => return erroneous_server_response(response), response => return erroneous_server_response(response),
}; };
server_connection.send(Request::Exit).await?;
print_set_password_output_status(&result, &args.username); print_set_password_output_status(&result, &args.username);
if matches!(
result,
Err(SetPasswordError::ValidationError(
ValidationError::AuthorizationError(_)
))
) {
print_authorization_owner_hint(&mut server_connection).await?
}
server_connection.send(Request::Exit).await?;
Ok(()) Ok(())
} }

View File

@@ -4,12 +4,13 @@ use futures_util::SinkExt;
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
use crate::{ use crate::{
client::commands::erroneous_server_response, client::commands::{erroneous_server_response, print_authorization_owner_hint},
core::{ core::{
completion::mysql_database_completer, completion::mysql_database_completer,
protocol::{ protocol::{
ClientToServerMessageStream, Request, Response, print_list_databases_output_status, ClientToServerMessageStream, ListDatabasesError, Request, Response,
print_list_databases_output_status_json, print_list_databases_output_status, print_list_databases_output_status_json,
request_validation::ValidationError,
}, },
types::MySQLDatabase, types::MySQLDatabase,
}, },
@@ -60,14 +61,25 @@ pub async fn show_databases(
response => return erroneous_server_response(response), response => return erroneous_server_response(response),
}; };
server_connection.send(Request::Exit).await?;
if args.json { if args.json {
print_list_databases_output_status_json(&databases); print_list_databases_output_status_json(&databases);
} else { } else {
print_list_databases_output_status(&databases); print_list_databases_output_status(&databases);
if databases.iter().any(|(_, res)| {
matches!(
res,
Err(ListDatabasesError::ValidationError(
ValidationError::AuthorizationError(_)
))
)
}) {
print_authorization_owner_hint(&mut server_connection).await?
}
} }
server_connection.send(Request::Exit).await?;
if args.fail && databases.values().any(|res| res.is_err()) { if args.fail && databases.values().any(|res| res.is_err()) {
std::process::exit(1); std::process::exit(1);
} }

View File

@@ -5,12 +5,13 @@ use itertools::Itertools;
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
use crate::{ use crate::{
client::commands::erroneous_server_response, client::commands::{erroneous_server_response, print_authorization_owner_hint},
core::{ core::{
completion::mysql_database_completer, completion::mysql_database_completer,
protocol::{ protocol::{
ClientToServerMessageStream, Request, Response, print_list_privileges_output_status, ClientToServerMessageStream, GetDatabasesPrivilegeDataError, Request, Response,
print_list_privileges_output_status_json, print_list_privileges_output_status, print_list_privileges_output_status_json,
request_validation::ValidationError,
}, },
types::MySQLDatabase, types::MySQLDatabase,
}, },
@@ -68,14 +69,25 @@ pub async fn show_database_privileges(
response => return erroneous_server_response(response), response => return erroneous_server_response(response),
}; };
server_connection.send(Request::Exit).await?;
if args.json { if args.json {
print_list_privileges_output_status_json(&privilege_data); print_list_privileges_output_status_json(&privilege_data);
} else { } else {
print_list_privileges_output_status(&privilege_data, args.long); print_list_privileges_output_status(&privilege_data, args.long);
if privilege_data.iter().any(|(_, res)| {
matches!(
res,
Err(GetDatabasesPrivilegeDataError::ValidationError(
ValidationError::AuthorizationError(_)
))
)
}) {
print_authorization_owner_hint(&mut server_connection).await?
}
} }
server_connection.send(Request::Exit).await?;
if args.fail && privilege_data.values().any(|res| res.is_err()) { if args.fail && privilege_data.values().any(|res| res.is_err()) {
std::process::exit(1); std::process::exit(1);
} }

View File

@@ -4,12 +4,13 @@ use futures_util::SinkExt;
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
use crate::{ use crate::{
client::commands::erroneous_server_response, client::commands::{erroneous_server_response, print_authorization_owner_hint},
core::{ core::{
completion::mysql_user_completer, completion::mysql_user_completer,
protocol::{ protocol::{
ClientToServerMessageStream, Request, Response, print_list_users_output_status, ClientToServerMessageStream, ListUsersError, Request, Response,
print_list_users_output_status_json, print_list_users_output_status, print_list_users_output_status_json,
request_validation::ValidationError,
}, },
types::MySQLUser, types::MySQLUser,
}, },
@@ -63,14 +64,25 @@ pub async fn show_users(
response => return erroneous_server_response(response), response => return erroneous_server_response(response),
}; };
server_connection.send(Request::Exit).await?;
if args.json { if args.json {
print_list_users_output_status_json(&users); print_list_users_output_status_json(&users);
} else { } else {
print_list_users_output_status(&users); print_list_users_output_status(&users);
if users.iter().any(|(_, res)| {
matches!(
res,
Err(ListUsersError::ValidationError(
ValidationError::AuthorizationError(_)
))
)
}) {
print_authorization_owner_hint(&mut server_connection).await?
}
} }
server_connection.send(Request::Exit).await?;
if args.fail && users.values().any(|result| result.is_err()) { if args.fail && users.values().any(|result| result.is_err()) {
std::process::exit(1); std::process::exit(1);
} }

View File

@@ -4,12 +4,13 @@ use futures_util::SinkExt;
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
use crate::{ use crate::{
client::commands::erroneous_server_response, client::commands::{erroneous_server_response, print_authorization_owner_hint},
core::{ core::{
completion::mysql_user_completer, completion::mysql_user_completer,
protocol::{ protocol::{
ClientToServerMessageStream, Request, Response, print_unlock_users_output_status, ClientToServerMessageStream, Request, Response, UnlockUserError,
print_unlock_users_output_status_json, print_unlock_users_output_status, print_unlock_users_output_status_json,
request_validation::ValidationError,
}, },
types::MySQLUser, types::MySQLUser,
}, },
@@ -47,13 +48,24 @@ pub async fn unlock_users(
response => return erroneous_server_response(response), response => return erroneous_server_response(response),
}; };
server_connection.send(Request::Exit).await?;
if args.json { if args.json {
print_unlock_users_output_status_json(&result); print_unlock_users_output_status_json(&result);
} else { } else {
print_unlock_users_output_status(&result); print_unlock_users_output_status(&result);
if result.iter().any(|(_, res)| {
matches!(
res,
Err(UnlockUserError::ValidationError(
ValidationError::AuthorizationError(_)
))
)
}) {
print_authorization_owner_hint(&mut server_connection).await?
}
} }
server_connection.send(Request::Exit).await?;
Ok(()) Ok(())
} }

View File

@@ -9,10 +9,12 @@ use tokio::{net::UnixStream as TokioUnixStream, sync::RwLock};
use tracing_subscriber::prelude::*; use tracing_subscriber::prelude::*;
use crate::{ use crate::{
core::common::{ core::{
DEFAULT_CONFIG_PATH, DEFAULT_SOCKET_PATH, UnixUser, executing_in_suid_sgid_mode, common::{DEFAULT_CONFIG_PATH, DEFAULT_SOCKET_PATH, UnixUser, executing_in_suid_sgid_mode},
protocol::request_validation::GroupDenylist,
}, },
server::{ server::{
authorization::read_and_parse_group_denylist,
config::{MysqlConfig, ServerConfig}, config::{MysqlConfig, ServerConfig},
landlock::landlock_restrict_server, landlock::landlock_restrict_server,
session_handler, session_handler,
@@ -270,6 +272,13 @@ fn run_forked_server(
let config = ServerConfig::read_config_from_path(&config_path) let config = ServerConfig::read_config_from_path(&config_path)
.context("Failed to read server config in forked process")?; .context("Failed to read server config in forked process")?;
let group_denylist = if let Some(denylist_path) = &config.authorization.group_denylist_file {
read_and_parse_group_denylist(denylist_path)
.context("Failed to read and parse group denylist")?
} else {
GroupDenylist::new()
};
let result: anyhow::Result<()> = tokio::runtime::Builder::new_current_thread() let result: anyhow::Result<()> = tokio::runtime::Builder::new_current_thread()
.enable_all() .enable_all()
.build() .build()
@@ -292,6 +301,7 @@ fn run_forked_server(
&unix_user, &unix_user,
db_pool, db_pool,
db_is_mariadb, db_is_mariadb,
&group_denylist,
) )
.await?; .await?;
Ok(()) Ok(())

View File

@@ -99,10 +99,10 @@ impl UnixUser {
}) })
} }
pub fn from_enviroment() -> anyhow::Result<Self> { // pub fn from_enviroment() -> anyhow::Result<Self> {
let libc_uid = nix::unistd::getuid(); // let libc_uid = nix::unistd::getuid();
UnixUser::from_uid(libc_uid.as_raw()) // UnixUser::from_uid(libc_uid.as_raw())
} // }
} }
#[inline] #[inline]

View File

@@ -1,5 +1,7 @@
use std::collections::HashSet;
use indoc::indoc; use indoc::indoc;
use itertools::Itertools; use nix::{libc::gid_t, unistd::Group};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use thiserror::Error; use thiserror::Error;
@@ -23,23 +25,19 @@ impl NameValidationError {
pub fn to_error_message(self, db_or_user: DbOrUser) -> String { pub fn to_error_message(self, db_or_user: DbOrUser) -> String {
match self { match self {
NameValidationError::EmptyString => { NameValidationError::EmptyString => {
format!("{} name cannot be empty.", db_or_user.capitalized_noun()).to_owned() format!("{} name can not be empty.", db_or_user.capitalized_noun())
} }
NameValidationError::TooLong => format!( NameValidationError::TooLong => format!(
"{} is too long. Maximum length is 64 characters.", "{} is too long, maximum length is 64 characters.",
db_or_user.capitalized_noun() db_or_user.capitalized_noun()
) ),
.to_owned(),
NameValidationError::InvalidCharacters => format!( NameValidationError::InvalidCharacters => format!(
indoc! {r#" indoc! {r#"
Invalid characters in {} name: '{}' Invalid characters in {} name: '{}', only A-Z, a-z, 0-9, _ (underscore) and - (dash) are permitted.
Only A-Z, a-z, 0-9, _ (underscore) and - (dash) are permitted.
"#}, "#},
db_or_user.lowercased_noun(), db_or_user.lowercased_noun(),
db_or_user.name(), db_or_user.name(),
) ),
.to_owned(),
} }
} }
@@ -54,64 +52,41 @@ impl NameValidationError {
#[derive(Error, Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)] #[derive(Error, Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum AuthorizationError { pub enum AuthorizationError {
#[error("No matching owner prefix found")] #[error("Illegal prefix, user is not authorized to manage this resource")]
NoMatch, IllegalPrefix,
// TODO: I don't think this should ever happen? // TODO: I don't think this should ever happen?
#[error("Name cannot be empty")] #[error("Name cannot be empty")]
StringEmpty, StringEmpty,
#[error("Group was found in denylist")]
DenylistError,
} }
impl AuthorizationError { impl AuthorizationError {
pub fn to_error_message(self, db_or_user: DbOrUser) -> String { pub fn to_error_message(self, db_or_user: DbOrUser) -> String {
let user = UnixUser::from_enviroment();
let UnixUser {
username,
mut groups,
} = user.unwrap_or(UnixUser {
username: "???".to_string(),
groups: vec![],
});
groups.sort();
match self { match self {
AuthorizationError::NoMatch => format!( AuthorizationError::IllegalPrefix => format!(
indoc! {r#" "Illegal {} name prefix: you are not allowed to manage databases or users prefixed with '{}'",
Invalid {} name prefix: '{}' does not match your username or any of your groups.
Are you sure you are allowed to create {} names with this prefix?
The format should be: <prefix>_<{} name>
Allowed prefixes:
- {}
{}
"#},
db_or_user.lowercased_noun(), db_or_user.lowercased_noun(),
db_or_user.name(), db_or_user.prefix(),
db_or_user.lowercased_noun(),
db_or_user.lowercased_noun(),
username,
groups
.into_iter()
.filter(|g| g != &username)
.map(|g| format!(" - {}", g))
.join("\n"),
) )
.to_owned(), .to_owned(),
AuthorizationError::StringEmpty => format!( // TODO: This error message could be clearer
"'{}' is not a valid {} name.", AuthorizationError::StringEmpty => {
db_or_user.name(), format!("{} name can not be empty.", db_or_user.capitalized_noun())
db_or_user.lowercased_noun() }
) AuthorizationError::DenylistError => {
.to_string(), format!("'{}' is denied by the group denylist", db_or_user.name())
}
} }
} }
pub fn error_type(&self) -> &'static str { pub fn error_type(&self) -> &'static str {
match self { match self {
AuthorizationError::NoMatch => "no-match", AuthorizationError::IllegalPrefix => "illegal-prefix",
AuthorizationError::StringEmpty => "string-empty", AuthorizationError::StringEmpty => "string-empty",
AuthorizationError::DenylistError => "denylist-error",
} }
} }
} }
@@ -155,6 +130,8 @@ impl ValidationError {
} }
} }
pub type GroupDenylist = HashSet<gid_t>;
const MAX_NAME_LENGTH: usize = 64; const MAX_NAME_LENGTH: usize = 64;
pub fn validate_name(name: &str) -> Result<(), NameValidationError> { pub fn validate_name(name: &str) -> Result<(), NameValidationError> {
@@ -201,21 +178,49 @@ pub fn validate_authorization_by_prefixes(
.collect::<Vec<_>>() .collect::<Vec<_>>()
.is_empty() .is_empty()
{ {
return Err(AuthorizationError::NoMatch); return Err(AuthorizationError::IllegalPrefix);
}; };
Ok(()) Ok(())
} }
pub fn validate_authorization_by_group_denylist(
name: &str,
user: &UnixUser,
group_denylist: &GroupDenylist,
) -> Result<(), AuthorizationError> {
// NOTE: if the username matches, we allow it regardless of denylist
if user.username == name {
return Ok(());
}
let user_group = Group::from_name(name)
.ok()
.flatten()
.map(|g| g.gid.as_raw());
if let Some(gid) = user_group
&& group_denylist.contains(&gid)
{
Err(AuthorizationError::DenylistError)
} else {
Ok(())
}
}
pub fn validate_db_or_user_request( pub fn validate_db_or_user_request(
db_or_user: &DbOrUser, db_or_user: &DbOrUser,
unix_user: &UnixUser, unix_user: &UnixUser,
group_denylist: &GroupDenylist,
) -> Result<(), ValidationError> { ) -> Result<(), ValidationError> {
validate_name(db_or_user.name()).map_err(ValidationError::NameValidationError)?; validate_name(db_or_user.name()).map_err(ValidationError::NameValidationError)?;
validate_authorization_by_unix_user(db_or_user.name(), unix_user) validate_authorization_by_unix_user(db_or_user.name(), unix_user)
.map_err(ValidationError::AuthorizationError)?; .map_err(ValidationError::AuthorizationError)?;
validate_authorization_by_group_denylist(db_or_user.name(), unix_user, group_denylist)
.map_err(ValidationError::AuthorizationError)?;
Ok(()) Ok(())
} }
@@ -273,7 +278,7 @@ mod tests {
assert_eq!( assert_eq!(
validate_authorization_by_prefixes("nonexistent_testdb", &prefixes), validate_authorization_by_prefixes("nonexistent_testdb", &prefixes),
Err(AuthorizationError::NoMatch) Err(AuthorizationError::IllegalPrefix)
); );
} }
} }

View File

@@ -132,4 +132,11 @@ impl DbOrUser {
DbOrUser::User(user) => user.as_str(), DbOrUser::User(user) => user.as_str(),
} }
} }
pub fn prefix(&self) -> &str {
match self {
DbOrUser::Database(db) => db.split('_').next().unwrap_or("?"),
DbOrUser::User(user) => user.split('_').next().unwrap_or("?"),
}
}
} }

View File

@@ -1,4 +1,4 @@
mod authorization; pub mod authorization;
pub mod command; pub mod command;
mod common; mod common;
pub mod config; pub mod config;

View File

@@ -1,25 +1,127 @@
use std::{collections::HashSet, path::Path};
use anyhow::Context;
use nix::unistd::Group;
use crate::core::{ use crate::core::{
common::UnixUser, common::UnixUser,
protocol::{CheckAuthorizationError, request_validation::validate_db_or_user_request}, protocol::{
CheckAuthorizationError,
request_validation::{GroupDenylist, validate_db_or_user_request},
},
types::DbOrUser, types::DbOrUser,
}; };
pub async fn check_authorization( pub async fn check_authorization(
dbs_or_users: Vec<DbOrUser>, dbs_or_users: Vec<DbOrUser>,
unix_user: &UnixUser, unix_user: &UnixUser,
group_denylist: &GroupDenylist,
) -> std::collections::BTreeMap<DbOrUser, Result<(), CheckAuthorizationError>> { ) -> std::collections::BTreeMap<DbOrUser, Result<(), CheckAuthorizationError>> {
let mut results = std::collections::BTreeMap::new(); let mut results = std::collections::BTreeMap::new();
for db_or_user in dbs_or_users { for db_or_user in dbs_or_users {
if let Err(err) = if let Err(err) = validate_db_or_user_request(&db_or_user, unix_user, group_denylist)
validate_db_or_user_request(&db_or_user, unix_user).map_err(CheckAuthorizationError) .map_err(CheckAuthorizationError)
{ {
results.insert(db_or_user.clone(), Err(err)); results.insert(db_or_user.clone(), Err(err));
continue; continue;
} }
results.insert(db_or_user.clone(), Ok(())); results.insert(db_or_user.clone(), Ok(()));
} }
results results
} }
/// Reads and parses a group denylist file, returning a set of GUIDs
///
/// The format of the denylist file is expected to be one group name or GID per line.
/// Lines starting with '#' are treated as comments and ignored.
/// Empty lines are also ignored.
///
/// Each line looks like one of the following:
/// - `gid:1001`
/// - `group:admins`
pub fn read_and_parse_group_denylist(denylist_path: &Path) -> anyhow::Result<GroupDenylist> {
let content = std::fs::read_to_string(denylist_path).context(format!(
"Failed to read denylist file at {:?}",
denylist_path
))?;
let mut groups = HashSet::with_capacity(content.lines().count());
for (line_number, line) in content.lines().enumerate() {
let trimmed_line = line.trim();
if trimmed_line.is_empty() || trimmed_line.starts_with('#') {
continue;
}
let parts: Vec<&str> = trimmed_line.splitn(2, ':').collect();
if parts.len() != 2 {
anyhow::bail!(
"Invalid format in denylist file at {:?} on line {}: {}",
denylist_path,
line_number + 1,
line
);
}
match parts[0] {
"gid" => {
let gid: u32 = parts[1].parse().with_context(|| {
format!(
"Invalid GID in denylist file at {:?} on line {}: {}",
denylist_path,
line_number + 1,
parts[1]
)
})?;
let group = Group::from_gid(nix::unistd::Gid::from_raw(gid))
.context(format!(
"Failed to get group for GID {} in denylist file at {:?} on line {}",
gid,
denylist_path,
line_number + 1
))?
.ok_or_else(|| {
anyhow::anyhow!(
"No group found for GID {} in denylist file at {:?} on line {}",
gid,
denylist_path,
line_number + 1
)
})?;
groups.insert(group.gid.as_raw());
}
"group" => {
let group = Group::from_name(parts[1])
.context(format!(
"Failed to get group for name '{}' in denylist file at {:?} on line {}",
parts[1],
denylist_path,
line_number + 1
))?
.ok_or_else(|| {
anyhow::anyhow!(
"No group found for name '{}' in denylist file at {:?} on line {}",
parts[1],
denylist_path,
line_number + 1
)
})?;
groups.insert(group.gid.as_raw());
}
_ => {
anyhow::bail!(
"Invalid prefix '{}' in denylist file at {:?} on line {}: {}",
parts[0],
denylist_path,
line_number + 1,
line
);
}
}
}
Ok(groups)
}

View File

@@ -1,13 +1,37 @@
use crate::core::common::UnixUser; use crate::core::{common::UnixUser, protocol::request_validation::GroupDenylist};
use nix::unistd::Group;
use sqlx::prelude::*; use sqlx::prelude::*;
/// This function retrieves the groups of a user, filtering out any groups
/// that are present in the provided denylist.
pub fn get_user_filtered_groups(user: &UnixUser, group_denylist: &GroupDenylist) -> Vec<String> {
user.groups
.iter()
.cloned()
.filter_map(|group_name| {
match Group::from_name(&group_name) {
Ok(Some(group)) => {
if group_denylist.contains(&group.gid.as_raw()) {
None
} else {
Some(group.name)
}
}
// NOTE: allow non-existing groups to pass through the filter
_ => Some(group_name),
}
})
.collect()
}
/// This function creates a regex that matches items (users, databases) /// This function creates a regex that matches items (users, databases)
/// that belong to the user or any of the user's groups. /// that belong to the user or any of the user's groups.
pub fn create_user_group_matching_regex(user: &UnixUser) -> String { pub fn create_user_group_matching_regex(user: &UnixUser, group_denylist: &GroupDenylist) -> String {
if user.groups.is_empty() { let filtered_groups = get_user_filtered_groups(user, group_denylist);
if filtered_groups.is_empty() {
format!("{}_.+", user.username) format!("{}_.+", user.username)
} else { } else {
format!("({}|{})_.+", user.username, user.groups.join("|")) format!("({}|{})_.+", user.username, filtered_groups.join("|"))
} }
} }
@@ -37,7 +61,8 @@ mod tests {
groups: vec!["group1".to_owned(), "group2".to_owned()], groups: vec!["group1".to_owned(), "group2".to_owned()],
}; };
let regex = create_user_group_matching_regex(&user); let regex = create_user_group_matching_regex(&user, &GroupDenylist::new());
println!("Generated regex: {}", regex);
let re = Regex::new(&regex).unwrap(); let re = Regex::new(&regex).unwrap();
assert!(re.is_match("user_something")); assert!(re.is_match("user_something"));

View File

@@ -78,9 +78,15 @@ impl MysqlConfig {
} }
} }
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
pub struct AuthorizationConfig {
pub group_denylist_file: Option<PathBuf>,
}
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
pub struct ServerConfig { pub struct ServerConfig {
pub socket_path: Option<PathBuf>, pub socket_path: Option<PathBuf>,
pub authorization: AuthorizationConfig,
pub mysql: MysqlConfig, pub mysql: MysqlConfig,
} }

View File

@@ -11,11 +11,12 @@ use crate::{
common::UnixUser, common::UnixUser,
protocol::{ protocol::{
Request, Response, ServerToClientMessageStream, SetPasswordError, Request, Response, ServerToClientMessageStream, SetPasswordError,
create_server_to_client_message_stream, create_server_to_client_message_stream, request_validation::GroupDenylist,
}, },
}, },
server::{ server::{
authorization::check_authorization, authorization::check_authorization,
common::get_user_filtered_groups,
sql::{ sql::{
database_operations::{ database_operations::{
complete_database_name, create_databases, drop_databases, complete_database_name, create_databases, drop_databases,
@@ -39,6 +40,7 @@ pub async fn session_handler(
socket: UnixStream, socket: UnixStream,
db_pool: Arc<RwLock<MySqlPool>>, db_pool: Arc<RwLock<MySqlPool>>,
db_is_mariadb: bool, db_is_mariadb: bool,
group_denylist: &GroupDenylist,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let uid = match socket.peer_cred() { let uid = match socket.peer_cred() {
Ok(cred) => cred.uid(), Ok(cred) => cred.uid(),
@@ -85,8 +87,14 @@ pub async fn session_handler(
(async move { (async move {
tracing::info!("Accepted connection from user: {}", unix_user); tracing::info!("Accepted connection from user: {}", unix_user);
let result = let result = session_handler_with_unix_user(
session_handler_with_unix_user(socket, &unix_user, db_pool, db_is_mariadb).await; socket,
&unix_user,
db_pool,
db_is_mariadb,
group_denylist,
)
.await;
tracing::info!( tracing::info!(
"Finished handling requests for connection from user: {}", "Finished handling requests for connection from user: {}",
@@ -104,6 +112,7 @@ pub async fn session_handler_with_unix_user(
unix_user: &UnixUser, unix_user: &UnixUser,
db_pool: Arc<RwLock<MySqlPool>>, db_pool: Arc<RwLock<MySqlPool>>,
db_is_mariadb: bool, db_is_mariadb: bool,
group_denylist: &GroupDenylist,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let mut message_stream = create_server_to_client_message_stream(socket); let mut message_stream = create_server_to_client_message_stream(socket);
@@ -131,6 +140,7 @@ pub async fn session_handler_with_unix_user(
unix_user, unix_user,
&mut db_connection, &mut db_connection,
db_is_mariadb, db_is_mariadb,
group_denylist,
) )
.await; .await;
@@ -147,6 +157,7 @@ async fn session_handler_with_db_connection(
unix_user: &UnixUser, unix_user: &UnixUser,
db_connection: &mut MySqlConnection, db_connection: &mut MySqlConnection,
db_is_mariadb: bool, db_is_mariadb: bool,
group_denylist: &GroupDenylist,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
stream.send(Response::Ready).await?; stream.send(Response::Ready).await?;
loop { loop {
@@ -178,18 +189,14 @@ async fn session_handler_with_db_connection(
let response = match request { let response = match request {
Request::CheckAuthorization(dbs_or_users) => { Request::CheckAuthorization(dbs_or_users) => {
let result = check_authorization(dbs_or_users, unix_user).await; let result = check_authorization(dbs_or_users, unix_user, group_denylist).await;
Response::CheckAuthorization(result) Response::CheckAuthorization(result)
} }
Request::ListValidNamePrefixes => { Request::ListValidNamePrefixes => {
let mut result = Vec::with_capacity(unix_user.groups.len() + 1); let mut result = Vec::with_capacity(unix_user.groups.len() + 1);
result.push(unix_user.username.to_owned()); result.push(unix_user.username.to_owned());
for group in unix_user for group in get_user_filtered_groups(unix_user, group_denylist) {
.groups
.iter()
.filter(|x| *x != &unix_user.username)
{
result.push(group.to_owned()); result.push(group.to_owned());
} }
@@ -208,6 +215,7 @@ async fn session_handler_with_db_connection(
unix_user, unix_user,
db_connection, db_connection,
db_is_mariadb, db_is_mariadb,
group_denylist,
) )
.await; .await;
Response::CompleteDatabaseName(result) Response::CompleteDatabaseName(result)
@@ -226,32 +234,54 @@ async fn session_handler_with_db_connection(
unix_user, unix_user,
db_connection, db_connection,
db_is_mariadb, db_is_mariadb,
group_denylist,
) )
.await; .await;
Response::CompleteUserName(result) Response::CompleteUserName(result)
} }
} }
Request::CreateDatabases(databases_names) => { Request::CreateDatabases(databases_names) => {
let result = let result = create_databases(
create_databases(databases_names, unix_user, db_connection, db_is_mariadb) databases_names,
.await; unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::CreateDatabases(result) Response::CreateDatabases(result)
} }
Request::DropDatabases(databases_names) => { Request::DropDatabases(databases_names) => {
let result = let result = drop_databases(
drop_databases(databases_names, unix_user, db_connection, db_is_mariadb).await; databases_names,
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::DropDatabases(result) Response::DropDatabases(result)
} }
Request::ListDatabases(database_names) => match database_names { Request::ListDatabases(database_names) => match database_names {
Some(database_names) => { Some(database_names) => {
let result = let result = list_databases(
list_databases(database_names, unix_user, db_connection, db_is_mariadb) database_names,
.await; unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::ListDatabases(result) Response::ListDatabases(result)
} }
None => { None => {
let result = let result = list_all_databases_for_user(
list_all_databases_for_user(unix_user, db_connection, db_is_mariadb).await; unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::ListAllDatabases(result) Response::ListAllDatabases(result)
} }
}, },
@@ -262,13 +292,19 @@ async fn session_handler_with_db_connection(
unix_user, unix_user,
db_connection, db_connection,
db_is_mariadb, db_is_mariadb,
group_denylist,
) )
.await; .await;
Response::ListPrivileges(privilege_data) Response::ListPrivileges(privilege_data)
} }
None => { None => {
let privilege_data = let privilege_data = get_all_database_privileges(
get_all_database_privileges(unix_user, db_connection, db_is_mariadb).await; unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::ListAllPrivileges(privilege_data) Response::ListAllPrivileges(privilege_data)
} }
}, },
@@ -278,18 +314,31 @@ async fn session_handler_with_db_connection(
unix_user, unix_user,
db_connection, db_connection,
db_is_mariadb, db_is_mariadb,
group_denylist,
) )
.await; .await;
Response::ModifyPrivileges(result) Response::ModifyPrivileges(result)
} }
Request::CreateUsers(db_users) => { Request::CreateUsers(db_users) => {
let result = let result = create_database_users(
create_database_users(db_users, unix_user, db_connection, db_is_mariadb).await; db_users,
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::CreateUsers(result) Response::CreateUsers(result)
} }
Request::DropUsers(db_users) => { Request::DropUsers(db_users) => {
let result = let result = drop_database_users(
drop_database_users(db_users, unix_user, db_connection, db_is_mariadb).await; db_users,
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::DropUsers(result) Response::DropUsers(result)
} }
Request::PasswdUser((db_user, password)) => { Request::PasswdUser((db_user, password)) => {
@@ -299,15 +348,21 @@ async fn session_handler_with_db_connection(
unix_user, unix_user,
db_connection, db_connection,
db_is_mariadb, db_is_mariadb,
group_denylist,
) )
.await; .await;
Response::SetUserPassword(result) Response::SetUserPassword(result)
} }
Request::ListUsers(db_users) => match db_users { Request::ListUsers(db_users) => match db_users {
Some(db_users) => { Some(db_users) => {
let result = let result = list_database_users(
list_database_users(db_users, unix_user, db_connection, db_is_mariadb) db_users,
.await; unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::ListUsers(result) Response::ListUsers(result)
} }
None => { None => {
@@ -315,19 +370,32 @@ async fn session_handler_with_db_connection(
unix_user, unix_user,
db_connection, db_connection,
db_is_mariadb, db_is_mariadb,
group_denylist,
) )
.await; .await;
Response::ListAllUsers(result) Response::ListAllUsers(result)
} }
}, },
Request::LockUsers(db_users) => { Request::LockUsers(db_users) => {
let result = let result = lock_database_users(
lock_database_users(db_users, unix_user, db_connection, db_is_mariadb).await; db_users,
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::LockUsers(result) Response::LockUsers(result)
} }
Request::UnlockUsers(db_users) => { Request::UnlockUsers(db_users) => {
let result = let result = unlock_database_users(
unlock_database_users(db_users, unix_user, db_connection, db_is_mariadb).await; db_users,
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::UnlockUsers(result) Response::UnlockUsers(result)
} }
Request::Exit => { Request::Exit => {

View File

@@ -6,6 +6,7 @@ use sqlx::prelude::*;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::core::protocol::CompleteDatabaseNameResponse; use crate::core::protocol::CompleteDatabaseNameResponse;
use crate::core::protocol::request_validation::GroupDenylist;
use crate::core::protocol::request_validation::validate_db_or_user_request; use crate::core::protocol::request_validation::validate_db_or_user_request;
use crate::core::types::DbOrUser; use crate::core::types::DbOrUser;
use crate::core::types::MySQLDatabase; use crate::core::types::MySQLDatabase;
@@ -49,6 +50,7 @@ pub async fn complete_database_name(
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
_db_is_mariadb: bool, _db_is_mariadb: bool,
group_denylist: &GroupDenylist,
) -> CompleteDatabaseNameResponse { ) -> CompleteDatabaseNameResponse {
let result = sqlx::query( let result = sqlx::query(
r#" r#"
@@ -59,7 +61,7 @@ pub async fn complete_database_name(
AND `SCHEMA_NAME` LIKE ? AND `SCHEMA_NAME` LIKE ?
"#, "#,
) )
.bind(create_user_group_matching_regex(unix_user)) .bind(create_user_group_matching_regex(unix_user, group_denylist))
.bind(format!("{}%", database_prefix)) .bind(format!("{}%", database_prefix))
.fetch_all(connection) .fetch_all(connection)
.await; .await;
@@ -89,13 +91,17 @@ pub async fn create_databases(
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
_db_is_mariadb: bool, _db_is_mariadb: bool,
group_denylist: &GroupDenylist,
) -> CreateDatabasesResponse { ) -> CreateDatabasesResponse {
let mut results = BTreeMap::new(); let mut results = BTreeMap::new();
for database_name in database_names { for database_name in database_names {
if let Err(err) = if let Err(err) = validate_db_or_user_request(
validate_db_or_user_request(&DbOrUser::Database(database_name.clone()), unix_user) &DbOrUser::Database(database_name.clone()),
.map_err(CreateDatabaseError::ValidationError) unix_user,
group_denylist,
)
.map_err(CreateDatabaseError::ValidationError)
{ {
results.insert(database_name.to_owned(), Err(err)); results.insert(database_name.to_owned(), Err(err));
continue; continue;
@@ -141,13 +147,17 @@ pub async fn drop_databases(
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
_db_is_mariadb: bool, _db_is_mariadb: bool,
group_denylist: &GroupDenylist,
) -> DropDatabasesResponse { ) -> DropDatabasesResponse {
let mut results = BTreeMap::new(); let mut results = BTreeMap::new();
for database_name in database_names { for database_name in database_names {
if let Err(err) = if let Err(err) = validate_db_or_user_request(
validate_db_or_user_request(&DbOrUser::Database(database_name.clone()), unix_user) &DbOrUser::Database(database_name.clone()),
.map_err(DropDatabaseError::ValidationError) unix_user,
group_denylist,
)
.map_err(DropDatabaseError::ValidationError)
{ {
results.insert(database_name.to_owned(), Err(err)); results.insert(database_name.to_owned(), Err(err));
continue; continue;
@@ -236,13 +246,17 @@ pub async fn list_databases(
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
_db_is_mariadb: bool, _db_is_mariadb: bool,
group_denylist: &GroupDenylist,
) -> ListDatabasesResponse { ) -> ListDatabasesResponse {
let mut results = BTreeMap::new(); let mut results = BTreeMap::new();
for database_name in database_names { for database_name in database_names {
if let Err(err) = if let Err(err) = validate_db_or_user_request(
validate_db_or_user_request(&DbOrUser::Database(database_name.clone()), unix_user) &DbOrUser::Database(database_name.clone()),
.map_err(ListDatabasesError::ValidationError) unix_user,
group_denylist,
)
.map_err(ListDatabasesError::ValidationError)
{ {
results.insert(database_name.to_owned(), Err(err)); results.insert(database_name.to_owned(), Err(err));
continue; continue;
@@ -296,6 +310,7 @@ pub async fn list_all_databases_for_user(
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
_db_is_mariadb: bool, _db_is_mariadb: bool,
group_denylist: &GroupDenylist,
) -> ListAllDatabasesResponse { ) -> ListAllDatabasesResponse {
let result = sqlx::query_as::<_, DatabaseRow>( let result = sqlx::query_as::<_, DatabaseRow>(
r#" r#"
@@ -319,7 +334,7 @@ pub async fn list_all_databases_for_user(
GROUP BY `information_schema`.`SCHEMATA`.`SCHEMA_NAME` GROUP BY `information_schema`.`SCHEMATA`.`SCHEMA_NAME`
"#, "#,
) )
.bind(create_user_group_matching_regex(unix_user)) .bind(create_user_group_matching_regex(unix_user, group_denylist))
.fetch_all(connection) .fetch_all(connection)
.await .await
.map_err(|err| ListAllDatabasesError::MySqlError(err.to_string())); .map_err(|err| ListAllDatabasesError::MySqlError(err.to_string()));

View File

@@ -31,7 +31,7 @@ use crate::{
DiffDoesNotApplyError, GetAllDatabasesPrivilegeDataError, DiffDoesNotApplyError, GetAllDatabasesPrivilegeDataError,
GetDatabasesPrivilegeDataError, ListAllPrivilegesResponse, ListPrivilegesResponse, GetDatabasesPrivilegeDataError, ListAllPrivilegesResponse, ListPrivilegesResponse,
ModifyDatabasePrivilegesError, ModifyPrivilegesResponse, ModifyDatabasePrivilegesError, ModifyPrivilegesResponse,
request_validation::validate_db_or_user_request, request_validation::{GroupDenylist, validate_db_or_user_request},
}, },
types::{DbOrUser, MySQLDatabase, MySQLUser}, types::{DbOrUser, MySQLDatabase, MySQLUser},
}, },
@@ -143,13 +143,17 @@ pub async fn get_databases_privilege_data(
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
_db_is_mariadb: bool, _db_is_mariadb: bool,
group_denylist: &GroupDenylist,
) -> ListPrivilegesResponse { ) -> ListPrivilegesResponse {
let mut results = BTreeMap::new(); let mut results = BTreeMap::new();
for database_name in database_names.iter() { for database_name in database_names.iter() {
if let Err(err) = if let Err(err) = validate_db_or_user_request(
validate_db_or_user_request(&DbOrUser::Database(database_name.clone()), unix_user) &DbOrUser::Database(database_name.clone()),
.map_err(GetDatabasesPrivilegeDataError::ValidationError) unix_user,
group_denylist,
)
.map_err(GetDatabasesPrivilegeDataError::ValidationError)
{ {
results.insert(database_name.to_owned(), Err(err)); results.insert(database_name.to_owned(), Err(err));
continue; continue;
@@ -200,9 +204,10 @@ pub async fn get_all_database_privileges(
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
_db_is_mariadb: bool, _db_is_mariadb: bool,
group_denylist: &GroupDenylist,
) -> ListAllPrivilegesResponse { ) -> ListAllPrivilegesResponse {
let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&get_all_db_privs_query()) let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&get_all_db_privs_query())
.bind(create_user_group_matching_regex(unix_user)) .bind(create_user_group_matching_regex(unix_user, group_denylist))
.fetch_all(connection) .fetch_all(connection)
.await .await
.map_err(|e| GetAllDatabasesPrivilegeDataError::MySqlError(e.to_string())); .map_err(|e| GetAllDatabasesPrivilegeDataError::MySqlError(e.to_string()));
@@ -397,6 +402,7 @@ pub async fn apply_privilege_diffs(
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
_db_is_mariadb: bool, _db_is_mariadb: bool,
group_denylist: &GroupDenylist,
) -> ModifyPrivilegesResponse { ) -> ModifyPrivilegesResponse {
let mut results: BTreeMap<(MySQLDatabase, MySQLUser), _> = BTreeMap::new(); let mut results: BTreeMap<(MySQLDatabase, MySQLUser), _> = BTreeMap::new();
@@ -408,6 +414,7 @@ pub async fn apply_privilege_diffs(
if let Err(err) = validate_db_or_user_request( if let Err(err) = validate_db_or_user_request(
&DbOrUser::Database(diff.get_database_name().to_owned()), &DbOrUser::Database(diff.get_database_name().to_owned()),
unix_user, unix_user,
group_denylist,
) )
.map_err(ModifyDatabasePrivilegesError::UserValidationError) .map_err(ModifyDatabasePrivilegesError::UserValidationError)
{ {
@@ -415,9 +422,12 @@ pub async fn apply_privilege_diffs(
continue; continue;
} }
if let Err(err) = if let Err(err) = validate_db_or_user_request(
validate_db_or_user_request(&DbOrUser::User(diff.get_user_name().to_owned()), unix_user) &DbOrUser::User(diff.get_user_name().to_owned()),
.map_err(ModifyDatabasePrivilegesError::UserValidationError) unix_user,
group_denylist,
)
.map_err(ModifyDatabasePrivilegesError::UserValidationError)
{ {
results.insert(key, Err(err)); results.insert(key, Err(err));
continue; continue;

View File

@@ -7,6 +7,7 @@ use serde::{Deserialize, Serialize};
use sqlx::MySqlConnection; use sqlx::MySqlConnection;
use sqlx::prelude::*; use sqlx::prelude::*;
use crate::core::protocol::request_validation::GroupDenylist;
use crate::core::protocol::request_validation::validate_db_or_user_request; use crate::core::protocol::request_validation::validate_db_or_user_request;
use crate::core::types::DbOrUser; use crate::core::types::DbOrUser;
use crate::{ use crate::{
@@ -58,6 +59,7 @@ pub async fn complete_user_name(
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
_db_is_mariadb: bool, _db_is_mariadb: bool,
group_denylist: &GroupDenylist,
) -> Vec<MySQLUser> { ) -> Vec<MySQLUser> {
let result = sqlx::query( let result = sqlx::query(
r#" r#"
@@ -67,7 +69,7 @@ pub async fn complete_user_name(
AND `User` LIKE ? AND `User` LIKE ?
"#, "#,
) )
.bind(create_user_group_matching_regex(unix_user)) .bind(create_user_group_matching_regex(unix_user, group_denylist))
.bind(format!("{}%", user_prefix)) .bind(format!("{}%", user_prefix))
.fetch_all(connection) .fetch_all(connection)
.await; .await;
@@ -97,12 +99,14 @@ pub async fn create_database_users(
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
_db_is_mariadb: bool, _db_is_mariadb: bool,
group_denylist: &GroupDenylist,
) -> CreateUsersResponse { ) -> CreateUsersResponse {
let mut results = BTreeMap::new(); let mut results = BTreeMap::new();
for db_user in db_users { for db_user in db_users {
if let Err(err) = validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user) if let Err(err) =
.map_err(CreateUserError::ValidationError) validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
.map_err(CreateUserError::ValidationError)
{ {
results.insert(db_user, Err(err)); results.insert(db_user, Err(err));
continue; continue;
@@ -141,12 +145,14 @@ pub async fn drop_database_users(
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
_db_is_mariadb: bool, _db_is_mariadb: bool,
group_denylist: &GroupDenylist,
) -> DropUsersResponse { ) -> DropUsersResponse {
let mut results = BTreeMap::new(); let mut results = BTreeMap::new();
for db_user in db_users { for db_user in db_users {
if let Err(err) = validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user) if let Err(err) =
.map_err(DropUserError::ValidationError) validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
.map_err(DropUserError::ValidationError)
{ {
results.insert(db_user, Err(err)); results.insert(db_user, Err(err));
continue; continue;
@@ -186,8 +192,9 @@ pub async fn set_password_for_database_user(
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
_db_is_mariadb: bool, _db_is_mariadb: bool,
group_denylist: &GroupDenylist,
) -> SetUserPasswordResponse { ) -> SetUserPasswordResponse {
validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user) validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
.map_err(SetPasswordError::ValidationError)?; .map_err(SetPasswordError::ValidationError)?;
match unsafe_user_exists(db_user, &mut *connection).await { match unsafe_user_exists(db_user, &mut *connection).await {
@@ -269,12 +276,14 @@ pub async fn lock_database_users(
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
db_is_mariadb: bool, db_is_mariadb: bool,
group_denylist: &GroupDenylist,
) -> LockUsersResponse { ) -> LockUsersResponse {
let mut results = BTreeMap::new(); let mut results = BTreeMap::new();
for db_user in db_users { for db_user in db_users {
if let Err(err) = validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user) if let Err(err) =
.map_err(LockUserError::ValidationError) validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
.map_err(LockUserError::ValidationError)
{ {
results.insert(db_user, Err(err)); results.insert(db_user, Err(err));
continue; continue;
@@ -327,12 +336,14 @@ pub async fn unlock_database_users(
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
db_is_mariadb: bool, db_is_mariadb: bool,
group_denylist: &GroupDenylist,
) -> UnlockUsersResponse { ) -> UnlockUsersResponse {
let mut results = BTreeMap::new(); let mut results = BTreeMap::new();
for db_user in db_users { for db_user in db_users {
if let Err(err) = validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user) if let Err(err) =
.map_err(UnlockUserError::ValidationError) validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
.map_err(UnlockUserError::ValidationError)
{ {
results.insert(db_user, Err(err)); results.insert(db_user, Err(err));
continue; continue;
@@ -433,12 +444,14 @@ pub async fn list_database_users(
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
db_is_mariadb: bool, db_is_mariadb: bool,
group_denylist: &GroupDenylist,
) -> ListUsersResponse { ) -> ListUsersResponse {
let mut results = BTreeMap::new(); let mut results = BTreeMap::new();
for db_user in db_users { for db_user in db_users {
if let Err(err) = validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user) if let Err(err) =
.map_err(ListUsersError::ValidationError) validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
.map_err(ListUsersError::ValidationError)
{ {
results.insert(db_user, Err(err)); results.insert(db_user, Err(err));
continue; continue;
@@ -477,6 +490,7 @@ pub async fn list_all_database_users_for_unix_user(
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
db_is_mariadb: bool, db_is_mariadb: bool,
group_denylist: &GroupDenylist,
) -> ListAllUsersResponse { ) -> ListAllUsersResponse {
let mut result = sqlx::query_as::<_, DatabaseUser>( let mut result = sqlx::query_as::<_, DatabaseUser>(
&(if db_is_mariadb { &(if db_is_mariadb {
@@ -485,7 +499,7 @@ pub async fn list_all_database_users_for_unix_user(
DB_USER_SELECT_STATEMENT_MYSQL.to_string() DB_USER_SELECT_STATEMENT_MYSQL.to_string()
} + "WHERE `user`.`User` REGEXP ?"), } + "WHERE `user`.`User` REGEXP ?"),
) )
.bind(create_user_group_matching_regex(unix_user)) .bind(create_user_group_matching_regex(unix_user, group_denylist))
.fetch_all(&mut *connection) .fetch_all(&mut *connection)
.await .await
.map_err(|err| ListAllUsersError::MySqlError(err.to_string())); .map_err(|err| ListAllUsersError::MySqlError(err.to_string()));

View File

@@ -17,9 +17,13 @@ use tokio::{
}; };
use tokio_util::{sync::CancellationToken, task::TaskTracker}; use tokio_util::{sync::CancellationToken, task::TaskTracker};
use crate::server::{ use crate::{
config::{MysqlConfig, ServerConfig}, core::protocol::request_validation::GroupDenylist,
session_handler::session_handler, server::{
authorization::read_and_parse_group_denylist,
config::{MysqlConfig, ServerConfig},
session_handler::session_handler,
},
}; };
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@@ -36,6 +40,7 @@ pub struct ReloadEvent;
pub struct Supervisor { pub struct Supervisor {
config_path: PathBuf, config_path: PathBuf,
config: Arc<Mutex<ServerConfig>>, config: Arc<Mutex<ServerConfig>>,
group_deny_list: Arc<RwLock<GroupDenylist>>,
systemd_mode: bool, systemd_mode: bool,
shutdown_cancel_token: CancellationToken, shutdown_cancel_token: CancellationToken,
@@ -66,6 +71,23 @@ impl Supervisor {
let config = ServerConfig::read_config_from_path(&config_path) let config = ServerConfig::read_config_from_path(&config_path)
.context("Failed to read server configuration")?; .context("Failed to read server configuration")?;
let group_deny_list = match &config.authorization.group_denylist_file {
Some(denylist_path) => {
let denylist = read_and_parse_group_denylist(denylist_path)
.context("Failed to read group denylist file")?;
tracing::debug!(
"Loaded group denylist with {} entries from {:?}",
denylist.len(),
denylist_path
);
Arc::new(RwLock::new(denylist))
}
None => {
tracing::debug!("No group denylist file specified, proceeding without a denylist");
Arc::new(RwLock::new(GroupDenylist::new()))
}
};
let mut watchdog_duration = None; let mut watchdog_duration = None;
let mut watchdog_micro_seconds = 0; let mut watchdog_micro_seconds = 0;
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
@@ -148,12 +170,14 @@ impl Supervisor {
db_connection_pool.clone(), db_connection_pool.clone(),
rx, rx,
db_is_mariadb.clone(), db_is_mariadb.clone(),
group_deny_list.clone(),
)) ))
}; };
Ok(Self { Ok(Self {
config_path, config_path,
config: Arc::new(Mutex::new(config)), config: Arc::new(Mutex::new(config)),
group_deny_list,
systemd_mode, systemd_mode,
reload_message_receiver: reload_rx, reload_message_receiver: reload_rx,
shutdown_cancel_token, shutdown_cancel_token,
@@ -196,6 +220,26 @@ impl Supervisor {
.context("Failed to read server configuration")?; .context("Failed to read server configuration")?;
let mut config = self.config.clone().lock_owned().await; let mut config = self.config.clone().lock_owned().await;
*config = new_config; *config = new_config;
let group_deny_list = match &config.authorization.group_denylist_file {
Some(denylist_path) => {
let denylist = read_and_parse_group_denylist(denylist_path)
.context("Failed to read group denylist file")?;
tracing::debug!(
"Loaded group denylist with {} entries from {:?}",
denylist.len(),
denylist_path
);
denylist
}
None => {
tracing::debug!("No group denylist file specified, proceeding without a denylist");
GroupDenylist::new()
}
};
let mut group_deny_list_lock = self.group_deny_list.write().await;
*group_deny_list_lock = group_deny_list;
Ok(()) Ok(())
} }
@@ -502,6 +546,7 @@ async fn listener_task(
db_pool: Arc<RwLock<MySqlPool>>, db_pool: Arc<RwLock<MySqlPool>>,
mut supervisor_message_receiver: broadcast::Receiver<SupervisorMessage>, mut supervisor_message_receiver: broadcast::Receiver<SupervisorMessage>,
db_is_mariadb: Arc<RwLock<bool>>, db_is_mariadb: Arc<RwLock<bool>>,
group_denylist: Arc<RwLock<GroupDenylist>>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
sd_notify::notify(false, &[sd_notify::NotifyState::Ready])?; sd_notify::notify(false, &[sd_notify::NotifyState::Ready])?;
@@ -539,8 +584,14 @@ async fn listener_task(
let db_pool_clone = db_pool.clone(); let db_pool_clone = db_pool.clone();
let db_is_mariadb_clone = *db_is_mariadb.read().await; let db_is_mariadb_clone = *db_is_mariadb.read().await;
let group_denylist_arc_clone = group_denylist.clone();
task_tracker.spawn(async move { task_tracker.spawn(async move {
match session_handler(conn, db_pool_clone, db_is_mariadb_clone).await { match session_handler(
conn,
db_pool_clone,
db_is_mariadb_clone,
&*group_denylist_arc_clone.read().await,
).await {
Ok(()) => {} Ok(()) => {}
Err(e) => { Err(e) => {
tracing::error!("Failed to run server: {}", e); tracing::error!("Failed to run server: {}", e);