1 Commits

Author SHA1 Message Date
77ad080f83 passwd-user: allow clearing, allow setting expiry 2025-12-23 15:14:13 +09:00
24 changed files with 450 additions and 847 deletions

40
Cargo.lock generated
View File

@@ -285,8 +285,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "145052bdd345b87320e369255277e3fb5152762ad123a901ef5c262dd38fe8d2"
dependencies = [
"iana-time-zone",
"js-sys",
"num-traits",
"serde",
"wasm-bindgen",
"windows-link",
]
@@ -353,16 +355,6 @@ version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1d728cc89cf3aee9ff92b05e62b19ee65a02b5702cff7d5a377e32c6ae29d8d"
[[package]]
name = "clap_mangen"
version = "0.2.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "439ea63a92086df93893164221ad4f24142086d535b3a0957b9b9bea2dc86301"
dependencies = [
"clap",
"roff",
]
[[package]]
name = "color-print"
version = "0.3.7"
@@ -705,7 +697,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
dependencies = [
"libc",
"windows-sys 0.61.2",
"windows-sys 0.52.0",
]
[[package]]
@@ -1137,7 +1129,7 @@ checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46"
dependencies = [
"hermit-abi",
"libc",
"windows-sys 0.61.2",
"windows-sys 0.52.0",
]
[[package]]
@@ -1344,10 +1336,10 @@ dependencies = [
"async-bincode",
"bincode 2.0.1",
"build-info-build",
"chrono",
"clap",
"clap-verbosity-flag",
"clap_complete",
"clap_mangen",
"color-print",
"const_format",
"derive_more",
@@ -1780,12 +1772,6 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "roff"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "88f8660c1ff60292143c98d08fc6e2f654d722db50410e3f3797d40baaf9d8f3"
[[package]]
name = "rsa"
version = "0.9.9"
@@ -1825,7 +1811,7 @@ dependencies = [
"errno",
"libc",
"linux-raw-sys",
"windows-sys 0.61.2",
"windows-sys 0.52.0",
]
[[package]]
@@ -1931,16 +1917,16 @@ dependencies = [
[[package]]
name = "serde_json"
version = "1.0.148"
version = "1.0.146"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3084b546a1dd6289475996f182a22aba973866ea8e8b02c51d9f46b1336a22da"
checksum = "217ca874ae0207aac254aa02c957ded05585a90892cc8d87f9e5fa49669dadd8"
dependencies = [
"indexmap",
"itoa",
"memchr",
"ryu",
"serde",
"serde_core",
"zmij",
]
[[package]]
@@ -2321,7 +2307,7 @@ dependencies = [
"getrandom 0.3.4",
"once_cell",
"rustix",
"windows-sys 0.61.2",
"windows-sys 0.52.0",
]
[[package]]
@@ -3236,12 +3222,6 @@ dependencies = [
"syn",
]
[[package]]
name = "zmij"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f4a4e8e9dc5c62d159f04fcdbe07f4c3fb710415aab4754bf11505501e3251d"
[[package]]
name = "zstd"
version = "0.13.3"

View File

@@ -22,10 +22,10 @@ autolib = false
anyhow = "1.0.100"
async-bincode = "0.8.0"
bincode = "2.0.1"
chrono = { version = "0.4.42", features = ["serde"] }
clap = { version = "4.5.53", features = ["cargo", "derive"] }
clap-verbosity-flag = { version = "3.0.4", features = [ "tracing" ] }
clap_complete = { version = "4.5.62", features = ["unstable-dynamic"] }
clap_mangen = "0.2.31"
color-print = "0.3.7"
const_format = "0.2.35"
derive_more = { version = "2.1.1", features = ["display", "error"] }
@@ -39,7 +39,7 @@ num_cpus = "1.17.0"
prettytable = "0.10.0"
rand = "0.9.2"
serde = "1.0.228"
serde_json = { version = "1.0.148", features = ["preserve_order"] }
serde_json = { version = "1.0.146", features = ["preserve_order"] }
sqlx = { version = "0.8.6", features = ["runtime-tokio", "mysql", "tls-rustls"] }
thiserror = "2.0.17"
tokio = { version = "1.48.0", features = ["rt-multi-thread", "macros", "signal"] }
@@ -77,12 +77,12 @@ path = "src/lib.rs"
[[bin]]
name = "muscl"
bench = false
path = "src/bin/muscl.rs"
path = "src/entrypoints/muscl.rs"
[[bin]]
name = "muscl-server"
bench = false
path = "src/bin/muscl_server.rs"
path = "src/entrypoints/muscl_server.rs"
[profile.release-lto]
inherits = "release"

View File

@@ -7,7 +7,7 @@ Dropping DBs (dumbbells) and having MySQL spasms since 2024
## What is this?
`muscl` is a secure MySQL administration tool for multi-user systems.
`muscl is a secure MySQL administration tool for multi-user systems.
It allows unprivileged users to manage their own databases and database users without granting them direct access to the MySQL server.
Authorization is handled by a prefix-based model tied to Unix users and groups, making it ideal for shared hosting environments, like university servers, tilde servers, or similar.
@@ -53,12 +53,3 @@ over a IPC, which then performs the requested operations on behalf of the client
- [Compatibility mode with mysql-admutils](docs/mysql-admutils-compatibility.md)
- [Use with NixOS](docs/nixos.md)
- [SUID/SGID mode](docs/suid-sgid-mode.md)
## History
This is a rewrite of an older piece of software called [mysql-admutils](https://git.pvv.ntnu.no/Projects/mysql-admutils).
Programvareverkstedet used this a lot back in the day, and it was great.
But it had some security issues inherent to the software design, particularly related to the use of SUID/SGID.
We tried patching it multiple times, but the issue kept popping up again in different ways.
The rewrite was intended to iron this issue out completely by splitting the software into two pieces - a client and a server.
As far as we know, this was successful, and it is unlikely for similar issues to resurface in the future.

View File

@@ -1,7 +1,6 @@
# These are the default system groups on debian.
# You can alos add groups by gid by prefixing the line with 'gid:'.
group:_ssh
group:adm
group:audio
group:avahi
@@ -13,7 +12,6 @@ group:daemon
group:dialout
group:dip
group:disk
group:docker
group:fax
group:floppy
group:games
@@ -24,12 +22,9 @@ group:kmem
group:kvm
group:list
group:lp
group:lxd
group:mail
group:man
group:messagebus
group:mlocate
group:mysql
group:netdev
group:news
group:nogroup
@@ -47,18 +42,15 @@ group:src
group:staff
group:sudo
group:sys
group:syslog
group:systemd-journal
group:systemd-network
group:systemd-resolve
group:systemd-timesync
group:tape
group:tcpdump
group:tty
group:users
group:utmp
group:uucp
group:uuidd
group:video
group:voice
group:www-data

18
flake.lock generated
View File

@@ -2,11 +2,11 @@
"nodes": {
"crane": {
"locked": {
"lastModified": 1766774972,
"narHash": "sha256-8qxEFpj4dVmIuPn9j9z6NTbU+hrcGjBOvaxTzre5HmM=",
"lastModified": 1766194365,
"narHash": "sha256-4AFsUZ0kl6MXSm4BaQgItD0VGlEKR3iq7gIaL7TjBvc=",
"owner": "ipetkov",
"repo": "crane",
"rev": "01bc1d404a51a0a07e9d8759cd50a7903e218c82",
"rev": "7d8ec2c71771937ab99790b45e6d9b93d15d9379",
"type": "github"
},
"original": {
@@ -17,11 +17,11 @@
},
"nixpkgs": {
"locked": {
"lastModified": 1766902085,
"narHash": "sha256-coBu0ONtFzlwwVBzmjacUQwj3G+lybcZ1oeNSQkgC0M=",
"lastModified": 1766309749,
"narHash": "sha256-3xY8CZ4rSnQ0NqGhMKAy5vgC+2IVK0NoVEzDoOh4DA4=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "c0b0e0fddf73fd517c3471e546c0df87a42d53f4",
"rev": "a6531044f6d0bef691ea18d4d4ce44d0daa6e816",
"type": "github"
},
"original": {
@@ -45,11 +45,11 @@
]
},
"locked": {
"lastModified": 1766976750,
"narHash": "sha256-w+o3AIBI56tzfMJRqRXg9tSXnpQRN5hAT15o2t9rxYw=",
"lastModified": 1766457837,
"narHash": "sha256-aeBbkQ0HPFNOIsUeEsXmZHXbYq4bG8ipT9JRlCcKHgU=",
"owner": "oxalica",
"repo": "rust-overlay",
"rev": "9fe44e7f05b734a64a01f92fc51ad064fb0a884f",
"rev": "2c7510a559416d07242621d036847152d970612b",
"type": "github"
},
"original": {

View File

@@ -42,9 +42,9 @@ in
authorization = {
group_denylist = lib.mkOption {
type = with lib.types; nullOr (listOf (either str ints.unsigned));
type = with lib.types; nullOr (listOf str);
default = [ "wheel" ];
description = "List of groups/GIDs that can not be used as prefixes for databases/database users";
description = "List of groups that are denied access";
};
};
@@ -110,32 +110,7 @@ in
];
environment.etc."muscl/group-denylist" = lib.mkIf (cfg.settings.authorization.group_denylist != [ ]) {
text = let
nameToGidMapping = lib.pipe config.users.groups [
(lib.filterAttrs (_: group: group.gid != null))
(lib.mapAttrsToList (name: group: { name = name; value = group.gid; }))
lib.listToAttrs
];
gidToNameMapping = lib.pipe config.users.groups [
(lib.filterAttrs (_: group: group.gid != null))
(lib.mapAttrsToList (name: group: { name = toString group.gid; value = name; }))
lib.listToAttrs
];
in lib.pipe cfg.settings.authorization.group_denylist [
# Prefer GIDs for groups we know the GID
(map (group: if builtins.isString group
then (nameToGidMapping.${group} or group)
else group))
# Then render back to strings
(map (group:
if builtins.isString group
then "group:${group}"
else "gid:${toString group} # ${gidToNameMapping.${toString group} or "unknown"}"))
(lib.concatStringsSep "\n")
];
text = lib.concatMapStringsSep "\n" (group: "group:${group}") cfg.settings.authorization.group_denylist;
};
services.mysql.ensureUsers = lib.mkIf cfg.createLocalDatabaseUser [

View File

@@ -1,5 +1,3 @@
use std::io::IsTerminal;
use clap::Parser;
use clap_complete::ArgValueCompleter;
use dialoguer::Confirm;
@@ -8,15 +6,16 @@ use tokio_stream::StreamExt;
use crate::{
client::commands::{
erroneous_server_response, print_authorization_owner_hint,
read_password_from_stdin_with_double_check,
erroneous_server_response, interactive_password_dialogue_with_double_check,
interactive_password_expiry_dialogue, print_authorization_owner_hint,
},
core::{
completion::prefix_completer,
protocol::{
ClientToServerMessageStream, CreateUserError, Request, Response,
print_create_users_output_status, print_create_users_output_status_json,
print_set_password_output_status, request_validation::ValidationError,
SetUserPasswordRequest, print_create_users_output_status,
print_create_users_output_status_json, print_set_password_output_status,
request_validation::ValidationError,
},
types::MySQLUser,
},
@@ -80,15 +79,6 @@ pub async fn create_users(
.filter_map(|(username, result)| result.as_ref().ok().map(|()| username))
.collect::<Vec<_>>();
if !std::io::stdin().is_terminal()
&& !args.no_password
&& !successfully_created_users.is_empty()
{
anyhow::bail!(
"Cannot prompt for passwords in non-interactive mode. Use --no-password to skip setting passwords."
);
}
for username in successfully_created_users {
if !args.no_password
&& Confirm::new()
@@ -98,8 +88,14 @@ pub async fn create_users(
.default(false)
.interact()?
{
let password = read_password_from_stdin_with_double_check(username)?;
let message = Request::PasswdUser((username.to_owned(), password));
let password = interactive_password_dialogue_with_double_check(username)?;
let expiry = interactive_password_expiry_dialogue(username)?;
let message = Request::PasswdUser(SetUserPasswordRequest {
user: username.clone(),
new_password: Some(password),
expiry: expiry,
});
if let Err(err) = server_connection.send(message).await {
server_connection.close().await.ok();

View File

@@ -1,5 +1,3 @@
use std::io::IsTerminal;
use clap::Parser;
use clap_complete::ArgValueCompleter;
use dialoguer::Confirm;
@@ -43,12 +41,6 @@ pub async fn drop_databases(
anyhow::bail!("No database names provided");
}
if !std::io::stdin().is_terminal() && !args.yes {
anyhow::bail!(
"Cannot prompt for confirmation in non-interactive mode. Use --yes to automatically confirm."
);
}
if !args.yes {
let confirmation = Confirm::new()
.with_prompt(format!(
@@ -61,6 +53,7 @@ pub async fn drop_databases(
))
.interact()?;
//
if !confirmation {
// TODO: should we return with an error code here?
println!("Aborting drop operation.");

View File

@@ -1,5 +1,3 @@
use std::io::IsTerminal;
use clap::Parser;
use clap_complete::ArgValueCompleter;
use dialoguer::Confirm;
@@ -43,12 +41,6 @@ pub async fn drop_users(
anyhow::bail!("No usernames provided");
}
if !std::io::stdin().is_terminal() && !args.yes {
anyhow::bail!(
"Cannot prompt for confirmation in non-interactive mode. Use --yes to automatically confirm."
);
}
if !args.yes {
let confirmation = Confirm::new()
.with_prompt(format!(

View File

@@ -1,7 +1,4 @@
use std::{
collections::{BTreeMap, BTreeSet},
io::IsTerminal,
};
use std::collections::{BTreeMap, BTreeSet};
use anyhow::Context;
use clap::{Args, Parser};
@@ -216,11 +213,6 @@ pub async fn edit_database_privileges(
};
let diffs: BTreeSet<DatabasePrivilegesDiff> = if privs.is_empty() {
if !std::io::stdin().is_terminal() {
anyhow::bail!(
"Cannot launch editor in non-interactive mode. Please provide privileges via command line arguments."
);
}
let privileges_to_change =
edit_privileges_with_editor(&existing_privilege_rows, use_database.as_ref())?;
diff_privileges(&existing_privilege_rows, &privileges_to_change)
@@ -283,8 +275,7 @@ pub async fn edit_database_privileges(
println!("The following changes will be made:\n");
println!("{}", display_privilege_diffs(&diffs));
if std::io::stdin().is_terminal()
&& !args.yes
if !args.yes
&& !Confirm::new()
.with_prompt("Do you want to apply these changes?")
.default(false)

View File

@@ -1,4 +1,4 @@
use std::{io::IsTerminal, path::PathBuf};
use std::path::PathBuf;
use anyhow::Context;
use clap::Parser;
@@ -13,7 +13,8 @@ use crate::{
completion::mysql_user_completer,
protocol::{
ClientToServerMessageStream, ListUsersError, Request, Response, SetPasswordError,
print_set_password_output_status, request_validation::ValidationError,
SetUserPasswordRequest, print_set_password_output_status,
request_validation::ValidationError,
},
types::MySQLUser,
},
@@ -37,9 +38,21 @@ pub struct PasswdUserArgs {
/// Print the information as JSON
#[arg(short, long)]
json: bool,
/// Set the password to expire on the given date (YYYY-MM-DD)
#[arg(short, long, value_name = "DATE", conflicts_with = "no-expire")]
expire_on: Option<chrono::NaiveDate>,
/// Set the password to never expire
#[arg(short, long, conflicts_with = "expire_on")]
no_expire: bool,
/// Clear the password for the user instead of setting a new one
#[arg(short, long, conflicts_with_all = &["password_file", "stdin", "expire_on", "no-expire"])]
clear: bool,
}
pub fn read_password_from_stdin_with_double_check(username: &MySQLUser) -> anyhow::Result<String> {
pub fn interactive_password_dialogue_with_double_check(username: &MySQLUser) -> anyhow::Result<String> {
Password::new()
.with_prompt(format!("New MySQL password for user '{username}'"))
.with_confirmation(
@@ -50,6 +63,29 @@ pub fn read_password_from_stdin_with_double_check(username: &MySQLUser) -> anyho
.map_err(Into::into)
}
pub fn interactive_password_expiry_dialogue(username: &MySQLUser) -> anyhow::Result<Option<chrono::NaiveDate>> {
let input = dialoguer::Input::<String>::new()
.with_prompt(format!(
"Enter the password expiry date for user '{username}' (YYYY-MM-DD)"
))
.allow_empty(true)
.validate_with(|input: &String| {
chrono::NaiveDate::parse_from_str(input, "%Y-%m-%d")
.map(|_| ())
.map_err(|_| "Invalid date format. Please use YYYY-MM-DD".to_string())
})
.interact_text()?;
if input.trim().is_empty() {
return Ok(None);
}
let date = chrono::NaiveDate::parse_from_str(&input, "%Y-%m-%d")
.map_err(|e| anyhow::anyhow!("Failed to parse date: {}", e))?;
Ok(Some(date))
}
pub async fn passwd_user(
args: PasswdUserArgs,
mut server_connection: ClientToServerMessageStream,
@@ -76,27 +112,38 @@ pub async fn passwd_user(
}
}
let password = if let Some(password_file) = args.password_file {
std::fs::read_to_string(password_file)
.context("Failed to read password file")?
.trim()
.to_string()
let password: Option<String> = if let Some(password_file) = args.password_file {
Some(
std::fs::read_to_string(password_file)
.context("Failed to read password file")?
.trim()
.to_string(),
)
} else if args.stdin {
let mut buffer = String::new();
std::io::stdin()
.read_line(&mut buffer)
.context("Failed to read password from stdin")?;
buffer.trim().to_string()
Some(buffer.trim().to_string())
} else if args.clear {
None
} else {
if !std::io::stdin().is_terminal() {
anyhow::bail!(
"Cannot prompt for password in non-interactive mode. Use --stdin or --password-file to provide the password."
);
}
read_password_from_stdin_with_double_check(&args.username)?
Some(interactive_password_dialogue_with_double_check(&args.username)?)
};
let message = Request::PasswdUser((args.username.clone(), password));
let expiry_date = if args.no_expire {
None
} else if let Some(date) = args.expire_on {
Some(date)
} else {
interactive_password_expiry_dialogue(&args.username)?
};
let message = Request::PasswdUser(SetUserPasswordRequest {
user: args.username.clone(),
new_password: password,
expiry: expiry_date,
});
if let Err(err) = server_connection.send(message).await {
server_connection.close().await.ok();

View File

@@ -8,7 +8,7 @@ use tokio::net::UnixStream as TokioUnixStream;
use crate::{
client::{
commands::{erroneous_server_response, read_password_from_stdin_with_double_check},
commands::{erroneous_server_response, interactive_password_dialogue_with_double_check},
mysql_admutils_compatibility::{
common::trim_user_name_to_32_chars,
error_messages::{
@@ -20,7 +20,7 @@ use crate::{
bootstrap::bootstrap_server_connection_and_drop_privileges,
completion::{mysql_user_completer, prefix_completer},
protocol::{
ClientToServerMessageStream, Request, Response, create_client_to_server_message_stream,
ClientToServerMessageStream, Request, Response, SetUserPasswordRequest, create_client_to_server_message_stream
},
types::MySQLUser,
},
@@ -252,8 +252,12 @@ async fn passwd_users(
.collect::<Vec<_>>();
for user in users {
let password = read_password_from_stdin_with_double_check(&user.user)?;
let message = Request::PasswdUser((user.user.clone(), password));
let password = interactive_password_dialogue_with_double_check(&user.user)?;
let message = Request::PasswdUser(SetUserPasswordRequest {
user: user.user.clone(),
new_password: Some(password),
expiry: None,
});
server_connection.send(message).await?;
match server_connection.next().await {
Some(Ok(Response::SetUserPassword(result))) => match result {

View File

@@ -22,7 +22,7 @@ use crate::{
authorization::read_and_parse_group_denylist,
config::{MysqlConfig, ServerConfig},
landlock::landlock_restrict_server,
session_handler::{self, SessionId},
session_handler,
},
};
@@ -308,11 +308,9 @@ fn run_forked_server(
version_row.to_lowercase().contains("mariadb")
};
let session_id = SessionId::new(0);
let db_pool = Arc::new(RwLock::new(db_pool));
session_handler::session_handler_with_unix_user(
socket,
session_id,
unix_user,
db_pool,
db_is_mariadb,

View File

@@ -24,7 +24,6 @@ pub const KIND_REGARDS: &str = concat!(
"If you experience any bugs or turbulence, please give us a heads up :)",
);
/// TODO: store and display UID
#[derive(Debug, Clone)]
pub struct UnixUser {
pub username: String,

View File

@@ -36,16 +36,11 @@ pub use modify_privileges::*;
pub use passwd_user::*;
pub use unlock_users::*;
use std::collections::BTreeSet;
use std::fmt;
use serde::{Deserialize, Serialize};
use tokio::net::UnixStream;
use tokio_serde::{Framed as SerdeFramed, formats::Bincode};
use tokio_util::codec::{Framed, LengthDelimitedCodec};
use crate::core::types::{MySQLDatabase, MySQLUser};
pub type ServerToClientMessageStream = SerdeFramed<
Framed<UnixStream, LengthDelimitedCodec>,
Request,
@@ -109,124 +104,6 @@ pub enum Request {
Exit,
}
impl Request {
/// Get the command name associated with this request.
pub fn command_name(&self) -> &str {
match self {
Request::CheckAuthorization(_) => "check-authorization",
Request::ListValidNamePrefixes => "list-valid-name-prefixes",
Request::CompleteDatabaseName(_) => "complete-database-name",
Request::CompleteUserName(_) => "complete-user-name",
Request::CreateDatabases(_) => "create-databases",
Request::DropDatabases(_) => "drop-databases",
Request::ListDatabases(_) => "list-databases",
Request::ListPrivileges(_) => "list-privileges",
Request::ModifyPrivileges(_) => "modify-privileges",
Request::CreateUsers(_) => "create-users",
Request::DropUsers(_) => "drop-users",
Request::PasswdUser(_) => "passwd-user",
Request::ListUsers(_) => "list-users",
Request::LockUsers(_) => "lock-users",
Request::UnlockUsers(_) => "unlock-users",
Request::Exit => "exit",
}
}
/// Generate a short summary string representing this request for logging purposes.
pub fn log_summary(&self) -> String {
match self {
Request::CheckAuthorization(req) => format!("{}({})", self.command_name(), req.len()),
Request::CreateDatabases(req) => format!("{}({})", self.command_name(), req.len()),
Request::DropDatabases(req) => format!("{}({})", self.command_name(), req.len()),
Request::ListDatabases(req) => format!(
"{}{}",
self.command_name(),
req.as_ref()
.map_or("".to_string(), |r| format!("({})", r.len()))
),
Request::ListPrivileges(req) => format!(
"{}{}",
self.command_name(),
req.as_ref()
.map_or("".to_string(), |r| format!("({})", r.len()))
),
Request::ModifyPrivileges(req) => format!("{}({})", self.command_name(), req.len()),
Request::CreateUsers(req) => format!("{}({})", self.command_name(), req.len()),
Request::DropUsers(req) => format!("{}({})", self.command_name(), req.len()),
Request::ListUsers(req) => format!(
"{}{}",
self.command_name(),
req.as_ref()
.map_or("".to_string(), |r| format!("({})", r.len()))
),
Request::LockUsers(req) => format!("{}({})", self.command_name(), req.len()),
Request::UnlockUsers(req) => format!("{}({})", self.command_name(), req.len()),
_ => self.command_name().to_string(),
}
}
/// Get the set of users affected by this request.
pub fn affected_users(&self) -> BTreeSet<MySQLUser> {
match self {
Request::CheckAuthorization(_) => Default::default(),
Request::ListValidNamePrefixes => Default::default(),
Request::CompleteDatabaseName(_) => Default::default(),
Request::CompleteUserName(_) => Default::default(),
Request::CreateDatabases(_) => Default::default(),
Request::DropDatabases(_) => Default::default(),
Request::ListDatabases(_) => Default::default(),
Request::ListPrivileges(_) => Default::default(),
Request::ModifyPrivileges(priv_diffs) => priv_diffs
.iter()
.map(|priv_diff| priv_diff.get_user_name().clone())
.collect(),
Request::CreateUsers(users) => users.iter().cloned().collect(),
Request::DropUsers(users) => users.iter().cloned().collect(),
Request::PasswdUser(user_passwd_req) => {
let mut result = BTreeSet::new();
result.insert(user_passwd_req.0.clone());
result
}
Request::ListUsers(users) => users.clone().unwrap_or_default().into_iter().collect(),
Request::LockUsers(users) => users.iter().cloned().collect(),
Request::UnlockUsers(users) => users.iter().cloned().collect(),
Request::Exit => Default::default(),
}
}
/// Get the set of databases affected by this request.
pub fn affected_databases(&self) -> BTreeSet<MySQLDatabase> {
match self {
Request::CheckAuthorization(_) => Default::default(),
Request::ListValidNamePrefixes => Default::default(),
Request::CompleteDatabaseName(_) => Default::default(),
Request::CompleteUserName(_) => Default::default(),
Request::CreateDatabases(databases) => databases.iter().cloned().collect(),
Request::DropDatabases(databases) => databases.iter().cloned().collect(),
Request::ListDatabases(databases) => {
databases.clone().unwrap_or_default().into_iter().collect()
}
Request::ListPrivileges(databases) => {
databases.clone().unwrap_or_default().into_iter().collect()
}
Request::ModifyPrivileges(priv_diffs) => priv_diffs
.iter()
.map(|priv_diff| priv_diff.get_database_name().clone())
.collect(),
Request::CreateUsers(_) => Default::default(),
Request::DropUsers(_) => Default::default(),
Request::PasswdUser(_) => Default::default(),
Request::ListUsers(_) => Default::default(),
Request::LockUsers(_) => Default::default(),
Request::UnlockUsers(_) => Default::default(),
Request::Exit => Default::default(),
}
}
}
// TODO: include a generic "message" that will display a message to the user?
#[non_exhaustive]
@@ -259,95 +136,3 @@ pub enum Response {
Ready,
Error(String),
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ResponseOkStatus {
Success,
PartialSuccess(usize, usize), // succeeded, total
Error,
}
impl ResponseOkStatus {
pub fn from_counts(total: usize, succeeded: usize) -> Self {
if succeeded == total {
ResponseOkStatus::Success
} else if succeeded == 0 {
ResponseOkStatus::Error
} else {
ResponseOkStatus::PartialSuccess(succeeded, total)
}
}
pub fn from_bool(is_ok: bool) -> Self {
if is_ok {
ResponseOkStatus::Success
} else {
ResponseOkStatus::Error
}
}
}
impl fmt::Display for ResponseOkStatus {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ResponseOkStatus::Success => write!(f, "OK"),
ResponseOkStatus::PartialSuccess(succeeded, total) => {
write!(f, "PARTIAL_OK({}/{})", succeeded, total)
}
ResponseOkStatus::Error => write!(f, "ERR"),
}
}
}
impl Response {
pub fn ok_status(&self) -> ResponseOkStatus {
match self {
Response::CheckAuthorization(res) => {
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
}
Response::ListValidNamePrefixes(_) => ResponseOkStatus::Success,
Response::CompleteDatabaseName(_) => ResponseOkStatus::Success,
Response::CompleteUserName(_) => ResponseOkStatus::Success,
Response::CreateDatabases(res) => {
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
}
Response::DropDatabases(res) => {
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
}
Response::ListDatabases(res) => {
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
}
Response::ListAllDatabases(res) => ResponseOkStatus::from_bool(res.is_ok()),
Response::ListPrivileges(res) => {
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
}
Response::ListAllPrivileges(res) => ResponseOkStatus::from_bool(res.is_ok()),
Response::ModifyPrivileges(res) => {
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
}
Response::CreateUsers(res) => {
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
}
Response::DropUsers(res) => {
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
}
Response::SetUserPassword(res) => ResponseOkStatus::from_bool(res.is_ok()),
Response::ListUsers(res) => {
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
}
Response::ListAllUsers(res) => ResponseOkStatus::from_bool(res.is_ok()),
Response::LockUsers(res) => {
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
}
Response::UnlockUsers(res) => {
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
}
Response::Ready => ResponseOkStatus::Success,
Response::Error(_) => ResponseOkStatus::Error,
}
}
}

View File

@@ -6,7 +6,12 @@ use crate::core::{
types::{DbOrUser, MySQLUser},
};
pub type SetUserPasswordRequest = (MySQLUser, String);
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SetUserPasswordRequest {
pub user: MySQLUser,
pub new_password: Option<String>,
pub expiry: Option<chrono::NaiveDate>,
}
pub type SetUserPasswordResponse = Result<(), SetPasswordError>;
@@ -18,6 +23,9 @@ pub enum SetPasswordError {
#[error("User does not exist")]
UserDoesNotExist,
#[error("Cannot clear password with an expiry date set")]
ClearPasswordWithExpiry,
#[error("MySQL error: {0}")]
MySqlError(String),
}
@@ -44,6 +52,9 @@ impl SetPasswordError {
SetPasswordError::UserDoesNotExist => {
format!("User '{username}' does not exist.")
}
SetPasswordError::ClearPasswordWithExpiry => {
format!("Cannot clear password for user '{username}' when an expiry date is set.")
}
SetPasswordError::MySqlError(err) => {
format!("MySQL error: {err}")
}
@@ -56,6 +67,7 @@ impl SetPasswordError {
match self {
SetPasswordError::ValidationError(err) => err.error_type(),
SetPasswordError::UserDoesNotExist => "user-does-not-exist".to_string(),
SetPasswordError::ClearPasswordWithExpiry => "clear-password-with-expiry".to_string(),
SetPasswordError::MySqlError(_) => "mysql-error".to_string(),
}
}

View File

@@ -82,11 +82,9 @@ const EXAMPLES: &str = const_format::concatcp!(
# Show all databases
muscl show-db
muscl sd
# Show which users have privileges on which databases
muscl show-privs
muscl sp
"#,
);
@@ -171,27 +169,22 @@ const EDIT_PRIVS_EXAMPLES: &str = color_print::cstr!(
#[command(subcommand_required = true)]
pub enum ClientCommand {
/// Check whether you are authorized to manage the specified databases or users.
#[command(alias = "ca")]
CheckAuth(CheckAuthArgs),
/// Create one or more databases
#[command(alias = "cd")]
CreateDb(CreateDbArgs),
/// Delete one or more databases
#[command(alias = "dd")]
DropDb(DropDbArgs),
/// Print information about one or more databases
///
/// If no database name is provided, all databases you have access will be shown.
#[command(alias = "sd")]
ShowDb(ShowDbArgs),
/// Print user privileges for one or more databases
///
/// If no database names are provided, all databases you have access to will be shown.
#[command(alias = "sp")]
ShowPrivs(ShowPrivsArgs),
/// Change user privileges for one or more databases. See `edit-privs --help` for details.
@@ -246,34 +239,27 @@ pub enum ClientCommand {
verbatim_doc_comment,
override_usage = "muscl edit-privs [OPTIONS] [ -p <DB_NAME:USER_NAME:[+-]PRIVILEGES>... | <DB_NAME> <USER_NAME> <[+-]PRIVILEGES> ]",
after_long_help = EDIT_PRIVS_EXAMPLES,
alias = "ep",
)]
EditPrivs(EditPrivsArgs),
/// Create one or more users
#[command(alias = "cu")]
CreateUser(CreateUserArgs),
/// Delete one or more users
#[command(alias = "du")]
DropUser(DropUserArgs),
/// Change the MySQL password for a user
#[command(alias = "pu")]
PasswdUser(PasswdUserArgs),
/// Print information about one or more users
///
/// If no username is provided, all users you have access will be shown.
#[command(alias = "su")]
ShowUser(ShowUserArgs),
/// Lock account for one or more users
#[command(alias = "lu")]
LockUser(LockUserArgs),
/// Unlock account for one or more users
#[command(alias = "uu")]
UnlockUser(UnlockUserArgs),
}
@@ -305,10 +291,6 @@ fn main() -> anyhow::Result<()> {
return Ok(());
}
if handle_manpage_command()?.is_some() {
return Ok(());
}
#[cfg(feature = "mysql-admutils-compatibility")]
if handle_mysql_admutils_command()?.is_some() {
return Ok(());
@@ -365,48 +347,6 @@ fn handle_dynamic_completion() -> anyhow::Result<Option<()>> {
}
}
/// **WARNING:** This function may be run with elevated privileges.
fn handle_manpage_command() -> anyhow::Result<Option<()>> {
let argv1: Option<String> = std::env::args().nth(1);
match argv1.as_deref() {
Some("generate-manpages") => {
#[cfg(feature = "suid-sgid-mode")]
if executing_in_suid_sgid_mode()? {
use muscl_lib::core::bootstrap::drop_privs;
drop_privs()?
}
let output_dir = std::env::args().nth(2).ok_or(anyhow::anyhow!(
"Output directory argument missing for manpage generation"
))?;
let output_dir = PathBuf::from(&output_dir);
if !output_dir.is_dir() {
anyhow::bail!(
"Output directory `{:?}` does not exist or is not a directory",
output_dir,
);
}
let mut roff = clap_mangen::roff::Roff::new();
let man = clap_mangen::Man::new(Args::command());
man.render_title(&mut std::io::stdout())?;
man.render_name_section(&mut std::io::stdout())?;
man.render_synopsis_section(&mut std::io::stdout())?;
man.render_subcommands_section(&mut std::io::stdout())?;
man.render_options_section(&mut std::io::stdout())?;
roff.control("SH", ["VERSION"]);
roff.text([clap_mangen::roff::roman(AFTER_LONG_HELP)]);
roff.to_writer(&mut std::io::stdout())?;
Ok(Some(()))
}
_ => Ok(None),
}
}
/// **WARNING:** This function may be run with elevated privileges.
fn handle_mysql_admutils_command() -> anyhow::Result<Option<()>> {
let argv0 = std::env::args().next().and_then(|s| {

View File

@@ -1,4 +1,4 @@
use std::{collections::HashSet, path::Path, str::Lines};
use std::{collections::HashSet, path::Path};
use anyhow::Context;
use nix::unistd::Group;
@@ -13,19 +13,23 @@ use crate::core::{
};
pub async fn check_authorization(
dbs_or_users: &[DbOrUser],
dbs_or_users: Vec<DbOrUser>,
unix_user: &UnixUser,
group_denylist: &GroupDenylist,
) -> std::collections::BTreeMap<DbOrUser, Result<(), CheckAuthorizationError>> {
dbs_or_users
.iter()
.cloned()
.map(|db_or_user| {
let result = validate_db_or_user_request(&db_or_user, unix_user, group_denylist)
.map_err(CheckAuthorizationError);
(db_or_user, result)
})
.collect()
let mut results = std::collections::BTreeMap::new();
for db_or_user in dbs_or_users {
if let Err(err) = validate_db_or_user_request(&db_or_user, unix_user, group_denylist)
.map_err(CheckAuthorizationError)
{
results.insert(db_or_user.clone(), Err(err));
continue;
}
results.insert(db_or_user.clone(), Ok(()));
}
results
}
/// Reads and parses a group denylist file, returning a set of GUIDs
@@ -41,23 +45,14 @@ pub fn read_and_parse_group_denylist(denylist_path: &Path) -> anyhow::Result<Gro
let content = std::fs::read_to_string(denylist_path)
.context(format!("Failed to read denylist file at {denylist_path:?}"))?;
let lines = content.lines();
let mut groups = HashSet::with_capacity(content.lines().count());
let groups = parse_group_denylist(denylist_path, lines);
for (line_number, line) in content.lines().enumerate() {
let trimmed_line = line.trim();
Ok(groups)
}
fn parse_group_denylist(denylist_path: &Path, lines: Lines) -> GroupDenylist {
let mut groups = HashSet::<u32>::new();
for (line_number, line) in lines.enumerate() {
let trimmed_line = if let Some(comment_start) = line.find('#') {
&line[..comment_start]
} else {
line
if trimmed_line.is_empty() || trimmed_line.starts_with('#') {
continue;
}
.trim();
let parts: Vec<&str> = trimmed_line.splitn(2, ':').collect();
if parts.len() != 2 {
@@ -146,32 +141,5 @@ fn parse_group_denylist(denylist_path: &Path, lines: Lines) -> GroupDenylist {
}
}
groups
}
#[cfg(test)]
mod tests {
use indoc::indoc;
use super::*;
#[test]
fn test_parse_group_denylist() {
let denylist_content = indoc! {"
# Valid entries
gid:0 # This is usually the 'root' group
group:root # This is also the 'root' group, should deduplicate
# Invalid entries
invalid_line
gid:not_a_number
group:nonexistent_group
"};
let lines = denylist_content.lines();
let group_denylist = parse_group_denylist(Path::new("test_denylist"), lines);
assert_eq!(group_denylist.len(), 1);
assert!(group_denylist.contains(&0));
}
Ok(groups)
}

View File

@@ -1,8 +1,7 @@
use std::sync::Arc;
use std::{collections::BTreeSet, sync::Arc};
use futures_util::{SinkExt, StreamExt};
use indoc::concatdoc;
use itertools::Itertools;
use sqlx::{MySqlConnection, MySqlPool};
use tokio::{net::UnixStream, sync::RwLock};
use tracing::Instrument;
@@ -12,7 +11,8 @@ use crate::{
common::UnixUser,
protocol::{
Request, Response, ServerToClientMessageStream, SetPasswordError,
create_server_to_client_message_stream, request_validation::GroupDenylist,
SetUserPasswordRequest, create_server_to_client_message_stream,
request_validation::GroupDenylist,
},
},
server::{
@@ -35,24 +35,10 @@ use crate::{
},
};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SessionId(u64);
impl SessionId {
pub fn new(id: u64) -> Self {
SessionId(id)
}
pub fn inner(&self) -> u64 {
self.0
}
}
// TODO: don't use database connection unless necessary.
pub async fn session_handler(
socket: UnixStream,
session_id: SessionId,
db_pool: Arc<RwLock<MySqlPool>>,
db_is_mariadb: bool,
group_denylist: &GroupDenylist,
@@ -76,7 +62,7 @@ pub async fn session_handler(
}
};
tracing::trace!("Validated peer UID: {}", uid);
tracing::debug!("Validated peer UID: {}", uid);
let unix_user = match UnixUser::from_uid(uid) {
Ok(user) => user,
@@ -97,18 +83,13 @@ pub async fn session_handler(
}
};
let span = tracing::info_span!(
"user_session",
session_id = session_id.inner(),
user = %unix_user,
);
let span = tracing::info_span!("user_session", user = %unix_user);
(async move {
tracing::debug!("Accepted connection from user: {}", unix_user);
tracing::info!("Accepted connection from user: {}", unix_user);
let result = session_handler_with_unix_user(
socket,
session_id,
&unix_user,
db_pool,
db_is_mariadb,
@@ -116,7 +97,7 @@ pub async fn session_handler(
)
.await;
tracing::debug!(
tracing::info!(
"Finished handling requests for connection from user: {}",
unix_user,
);
@@ -129,7 +110,6 @@ pub async fn session_handler(
pub async fn session_handler_with_unix_user(
socket: UnixStream,
session_id: SessionId,
unix_user: &UnixUser,
db_pool: Arc<RwLock<MySqlPool>>,
db_is_mariadb: bool,
@@ -137,7 +117,7 @@ pub async fn session_handler_with_unix_user(
) -> anyhow::Result<()> {
let mut message_stream = create_server_to_client_message_stream(socket);
tracing::trace!("Requesting database connection from pool");
tracing::debug!("Requesting database connection from pool");
let mut db_connection = match db_pool.read().await.acquire().await {
Ok(connection) => connection,
Err(err) => {
@@ -154,11 +134,10 @@ pub async fn session_handler_with_unix_user(
return Err(err.into());
}
};
tracing::trace!("Successfully acquired database connection from pool");
tracing::debug!("Successfully acquired database connection from pool");
let result = session_handler_with_db_connection(
message_stream,
session_id,
unix_user,
&mut db_connection,
db_is_mariadb,
@@ -166,7 +145,7 @@ pub async fn session_handler_with_unix_user(
)
.await;
tracing::trace!("Releasing database connection back to pool");
tracing::debug!("Releasing database connection back to pool");
result
}
@@ -176,7 +155,6 @@ pub async fn session_handler_with_unix_user(
async fn session_handler_with_db_connection(
mut stream: ServerToClientMessageStream,
session_id: SessionId,
unix_user: &UnixUser,
db_connection: &mut MySqlConnection,
db_is_mariadb: bool,
@@ -196,234 +174,158 @@ async fn session_handler_with_db_connection(
}
};
let request_span = tracing::info_span!("request", command = request.command_name());
// TODO: don't clone the request
let request_to_display = match &request {
Request::PasswdUser(SetUserPasswordRequest {
user,
new_password,
expiry,
}) => Request::PasswdUser(SetUserPasswordRequest {
user: user.clone(),
new_password: new_password.as_ref().map(|_| "<REDACTED>".to_string()),
expiry: *expiry,
}),
request => request.to_owned(),
};
if !handle_request(
request,
session_id,
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
&mut stream,
)
.instrument(request_span)
.await?
{
break;
if request_to_display == Request::Exit {
tracing::debug!("Received request: {:#?}", request_to_display);
} else {
tracing::info!("Received request: {:#?}", request_to_display);
}
}
Ok(())
}
/// Handle a single request from a client.
///
/// If the function returns `true`, the session should continue.
async fn handle_request(
request: Request,
session_id: SessionId,
unix_user: &UnixUser,
db_connection: &mut MySqlConnection,
db_is_mariadb: bool,
group_denylist: &GroupDenylist,
stream: &mut ServerToClientMessageStream,
) -> anyhow::Result<bool> {
match &request {
Request::Exit => tracing::debug!("Request: exit"),
Request::PasswdUser((db_user, _)) => tracing::debug!(
"Request:\n{}",
serde_json::to_string_pretty(&Request::PasswdUser((
db_user.to_owned(),
"<REDACTED>".to_string()
)))?
),
request => tracing::debug!("Request:\n{}", serde_json::to_string_pretty(request)?),
}
let affected_dbs = request.affected_databases();
if !affected_dbs.is_empty() {
tracing::trace!(
"Affected databases: {}",
affected_dbs.into_iter().map(|db| db.to_string()).join(", ")
);
}
let affected_users = request.affected_users();
if !affected_users.is_empty() {
tracing::trace!(
"Affected users: {}",
affected_users.into_iter().map(|u| u.to_string()).join(", "),
);
}
let response = match request {
Request::CheckAuthorization(ref dbs_or_users) => {
let result = check_authorization(dbs_or_users, unix_user, group_denylist).await;
Response::CheckAuthorization(result)
}
Request::ListValidNamePrefixes => {
let mut result = Vec::with_capacity(unix_user.groups.len() + 1);
result.push(unix_user.username.clone());
for group in get_user_filtered_groups(unix_user, group_denylist) {
result.push(group.clone());
let response = match request {
Request::CheckAuthorization(dbs_or_users) => {
let result = check_authorization(dbs_or_users, unix_user, group_denylist).await;
Response::CheckAuthorization(result)
}
Request::ListValidNamePrefixes => {
let mut result = Vec::with_capacity(unix_user.groups.len() + 1);
result.push(unix_user.username.clone());
Response::ListValidNamePrefixes(result)
}
Request::CompleteDatabaseName(ref partial_database_name) => {
// TODO: more correct validation here
if partial_database_name
.chars()
.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
{
let result = complete_database_name(
partial_database_name,
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::CompleteDatabaseName(result)
} else {
Response::CompleteDatabaseName(vec![])
for group in get_user_filtered_groups(unix_user, group_denylist) {
result.push(group.clone());
}
Response::ListValidNamePrefixes(result)
}
}
Request::CompleteUserName(ref partial_user_name) => {
// TODO: more correct validation here
if partial_user_name
.chars()
.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
{
let result = complete_user_name(
partial_user_name,
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::CompleteUserName(result)
} else {
Response::CompleteUserName(vec![])
Request::CompleteDatabaseName(partial_database_name) => {
// TODO: more correct validation here
if partial_database_name
.chars()
.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
{
let result = complete_database_name(
partial_database_name,
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::CompleteDatabaseName(result)
} else {
Response::CompleteDatabaseName(vec![])
}
}
}
Request::CreateDatabases(ref databases_names) => {
let result = create_databases(
databases_names,
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::CreateDatabases(result)
}
Request::DropDatabases(ref databases_names) => {
let result = drop_databases(
databases_names,
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::DropDatabases(result)
}
Request::ListDatabases(ref database_names) => {
if let Some(database_names) = database_names {
let result = list_databases(
database_names,
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::ListDatabases(result)
} else {
let result = list_all_databases_for_user(
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::ListAllDatabases(result)
Request::CompleteUserName(partial_user_name) => {
// TODO: more correct validation here
if partial_user_name
.chars()
.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
{
let result = complete_user_name(
partial_user_name,
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::CompleteUserName(result)
} else {
Response::CompleteUserName(vec![])
}
}
}
Request::ListPrivileges(ref database_names) => {
if let Some(database_names) = database_names {
let privilege_data = get_databases_privilege_data(
database_names,
Request::CreateDatabases(databases_names) => {
let result = create_databases(
databases_names,
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::ListPrivileges(privilege_data)
} else {
let privilege_data = get_all_database_privileges(
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::ListAllPrivileges(privilege_data)
Response::CreateDatabases(result)
}
}
Request::ModifyPrivileges(ref database_privilege_diffs) => {
let result = apply_privilege_diffs(
database_privilege_diffs,
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::ModifyPrivileges(result)
}
Request::CreateUsers(ref db_users) => {
let result = create_database_users(
db_users,
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::CreateUsers(result)
}
Request::DropUsers(ref db_users) => {
let result = drop_database_users(
db_users,
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::DropUsers(result)
}
Request::PasswdUser((ref db_user, ref password)) => {
let result = set_password_for_database_user(
db_user,
password,
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::SetUserPassword(result)
}
Request::ListUsers(ref db_users) => {
if let Some(db_users) = db_users {
let result = list_database_users(
Request::DropDatabases(databases_names) => {
let result = drop_databases(
databases_names,
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::DropDatabases(result)
}
Request::ListDatabases(database_names) => {
if let Some(database_names) = database_names {
let result = list_databases(
database_names,
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::ListDatabases(result)
} else {
let result = list_all_databases_for_user(
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::ListAllDatabases(result)
}
}
Request::ListPrivileges(database_names) => {
if let Some(database_names) = database_names {
let privilege_data = get_databases_privilege_data(
database_names,
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::ListPrivileges(privilege_data)
} else {
let privilege_data = get_all_database_privileges(
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::ListAllPrivileges(privilege_data)
}
}
Request::ModifyPrivileges(database_privilege_diffs) => {
let result = apply_privilege_diffs(
BTreeSet::from_iter(database_privilege_diffs),
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::ModifyPrivileges(result)
}
Request::CreateUsers(db_users) => {
let result = create_database_users(
db_users,
unix_user,
db_connection,
@@ -431,76 +333,99 @@ async fn handle_request(
group_denylist,
)
.await;
Response::ListUsers(result)
} else {
let result = list_all_database_users_for_unix_user(
Response::CreateUsers(result)
}
Request::DropUsers(db_users) => {
let result = drop_database_users(
db_users,
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::ListAllUsers(result)
Response::DropUsers(result)
}
}
Request::LockUsers(ref db_users) => {
let result = lock_database_users(
db_users,
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::LockUsers(result)
}
Request::UnlockUsers(ref db_users) => {
let result = unlock_database_users(
db_users,
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::UnlockUsers(result)
}
Request::Exit => {
return Ok(false);
}
};
Request::PasswdUser(SetUserPasswordRequest {
user,
new_password,
expiry,
}) => {
let result = set_password_for_database_user(
&user,
new_password.as_deref(),
expiry,
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::SetUserPassword(result)
}
Request::ListUsers(db_users) => {
if let Some(db_users) = db_users {
let result = list_database_users(
db_users,
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::ListUsers(result)
} else {
let result = list_all_database_users_for_unix_user(
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::ListAllUsers(result)
}
}
Request::LockUsers(db_users) => {
let result = lock_database_users(
db_users,
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::LockUsers(result)
}
Request::UnlockUsers(db_users) => {
let result = unlock_database_users(
db_users,
unix_user,
db_connection,
db_is_mariadb,
group_denylist,
)
.await;
Response::UnlockUsers(result)
}
Request::Exit => {
break;
}
};
let response_to_display = match &response {
Response::SetUserPassword(Err(SetPasswordError::MySqlError(_))) => {
&Response::SetUserPassword(Err(SetPasswordError::MySqlError("<REDACTED>".to_string())))
}
response => response,
};
tracing::debug!(
"Response:\n{}",
serde_json::to_string_pretty(&response_to_display)?
);
let response_to_display = match &response {
Response::SetUserPassword(Err(SetPasswordError::MySqlError(_))) => {
&Response::SetUserPassword(Err(SetPasswordError::MySqlError(
"<REDACTED>".to_string(),
)))
}
response => response,
};
tracing::debug!("Response: {:#?}", response_to_display);
log_request(session_id, unix_user, &request, &response);
stream.send(response).await?;
stream.flush().await?;
tracing::debug!("Successfully processed request");
}
stream.send(response).await?;
stream.flush().await?;
tracing::trace!("Successfully processed request");
Ok(true)
}
/// Log a summary of the request and its result.
fn log_request(
session_id: SessionId,
unix_user: &UnixUser,
request: &Request,
response: &Response,
) {
tracing::info!(
"[{}|session:{}|user:{unix_user}] {}",
response.ok_status(),
session_id.inner(),
request.log_summary(),
);
Ok(())
}

View File

@@ -46,7 +46,7 @@ pub(super) async fn unsafe_database_exists(
}
pub async fn complete_database_name(
database_prefix: &str,
database_prefix: String,
unix_user: &UnixUser,
connection: &mut MySqlConnection,
_db_is_mariadb: bool,
@@ -87,7 +87,7 @@ pub async fn complete_database_name(
}
pub async fn create_databases(
database_names: &[MySQLDatabase],
database_names: Vec<MySQLDatabase>,
unix_user: &UnixUser,
connection: &mut MySqlConnection,
_db_is_mariadb: bool,
@@ -95,7 +95,7 @@ pub async fn create_databases(
) -> CreateDatabasesResponse {
let mut results = BTreeMap::new();
for database_name in database_names.iter().cloned() {
for database_name in database_names {
if let Err(err) = validate_db_or_user_request(
&DbOrUser::Database(database_name.clone()),
unix_user,
@@ -143,7 +143,7 @@ pub async fn create_databases(
}
pub async fn drop_databases(
database_names: &[MySQLDatabase],
database_names: Vec<MySQLDatabase>,
unix_user: &UnixUser,
connection: &mut MySqlConnection,
_db_is_mariadb: bool,
@@ -151,7 +151,7 @@ pub async fn drop_databases(
) -> DropDatabasesResponse {
let mut results = BTreeMap::new();
for database_name in database_names.iter().cloned() {
for database_name in database_names {
if let Err(err) = validate_db_or_user_request(
&DbOrUser::Database(database_name.clone()),
unix_user,
@@ -242,7 +242,7 @@ impl FromRow<'_, sqlx::mysql::MySqlRow> for DatabaseRow {
}
pub async fn list_databases(
database_names: &[MySQLDatabase],
database_names: Vec<MySQLDatabase>,
unix_user: &UnixUser,
connection: &mut MySqlConnection,
_db_is_mariadb: bool,
@@ -250,7 +250,7 @@ pub async fn list_databases(
) -> ListDatabasesResponse {
let mut results = BTreeMap::new();
for database_name in database_names.iter().cloned() {
for database_name in database_names {
if let Err(err) = validate_db_or_user_request(
&DbOrUser::Database(database_name.clone()),
unix_user,

View File

@@ -138,7 +138,7 @@ pub async fn unsafe_get_database_privileges_for_db_user_pair(
}
pub async fn get_databases_privilege_data(
database_names: &[MySQLDatabase],
database_names: Vec<MySQLDatabase>,
unix_user: &UnixUser,
connection: &mut MySqlConnection,
_db_is_mariadb: bool,
@@ -146,19 +146,19 @@ pub async fn get_databases_privilege_data(
) -> ListPrivilegesResponse {
let mut results = BTreeMap::new();
for database_name in database_names.iter().cloned() {
for database_name in &database_names {
if let Err(err) = validate_db_or_user_request(
&DbOrUser::Database(database_name.to_owned()),
&DbOrUser::Database(database_name.clone()),
unix_user,
group_denylist,
)
.map_err(ListPrivilegesError::ValidationError)
{
results.insert(database_name, Err(err));
results.insert(database_name.to_owned(), Err(err));
continue;
}
match unsafe_database_exists(&database_name, connection).await {
match unsafe_database_exists(database_name, connection).await {
Ok(false) => {
results.insert(
database_name.to_owned(),
@@ -176,7 +176,7 @@ pub async fn get_databases_privilege_data(
Ok(true) => {}
}
let result = unsafe_get_database_privileges(&database_name, connection)
let result = unsafe_get_database_privileges(database_name, connection)
.await
.map_err(|e| ListPrivilegesError::MySqlError(e.to_string()));
@@ -400,7 +400,7 @@ async fn validate_diff(
/// Uses the result of [`diff_privileges`] to modify privileges in the database.
pub async fn apply_privilege_diffs(
database_privilege_diffs: &BTreeSet<DatabasePrivilegesDiff>,
database_privilege_diffs: BTreeSet<DatabasePrivilegesDiff>,
unix_user: &UnixUser,
connection: &mut MySqlConnection,
_db_is_mariadb: bool,
@@ -468,12 +468,12 @@ pub async fn apply_privilege_diffs(
Ok(true) => {}
}
if let Err(err) = validate_diff(diff, connection).await {
if let Err(err) = validate_diff(&diff, connection).await {
results.insert(key, Err(err));
continue;
}
let result = unsafe_apply_privilege_diff(diff, connection)
let result = unsafe_apply_privilege_diff(&diff, connection)
.await
.map_err(|e| ModifyDatabasePrivilegesError::MySqlError(e.to_string()));

View File

@@ -55,7 +55,7 @@ pub(super) async fn unsafe_user_exists(
}
pub async fn complete_user_name(
user_prefix: &str,
user_prefix: String,
unix_user: &UnixUser,
connection: &mut MySqlConnection,
_db_is_mariadb: bool,
@@ -95,7 +95,7 @@ pub async fn complete_user_name(
}
pub async fn create_database_users(
db_users: &[MySQLUser],
db_users: Vec<MySQLUser>,
unix_user: &UnixUser,
connection: &mut MySqlConnection,
_db_is_mariadb: bool,
@@ -103,7 +103,7 @@ pub async fn create_database_users(
) -> CreateUsersResponse {
let mut results = BTreeMap::new();
for db_user in db_users.iter().cloned() {
for db_user in db_users {
if let Err(err) =
validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
.map_err(CreateUserError::ValidationError)
@@ -141,7 +141,7 @@ pub async fn create_database_users(
}
pub async fn drop_database_users(
db_users: &[MySQLUser],
db_users: Vec<MySQLUser>,
unix_user: &UnixUser,
connection: &mut MySqlConnection,
_db_is_mariadb: bool,
@@ -149,7 +149,7 @@ pub async fn drop_database_users(
) -> DropUsersResponse {
let mut results = BTreeMap::new();
for db_user in db_users.iter().cloned() {
for db_user in db_users {
if let Err(err) =
validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
.map_err(DropUserError::ValidationError)
@@ -188,7 +188,8 @@ pub async fn drop_database_users(
pub async fn set_password_for_database_user(
db_user: &MySQLUser,
password: &str,
password: Option<&str>,
expiry: Option<chrono::NaiveDate>,
unix_user: &UnixUser,
connection: &mut MySqlConnection,
_db_is_mariadb: bool,
@@ -197,24 +198,44 @@ pub async fn set_password_for_database_user(
validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
.map_err(SetPasswordError::ValidationError)?;
if password.is_none() && expiry.is_some() {
return Err(SetPasswordError::ClearPasswordWithExpiry);
}
match unsafe_user_exists(db_user, &mut *connection).await {
Ok(false) => return Err(SetPasswordError::UserDoesNotExist),
Err(err) => return Err(SetPasswordError::MySqlError(err.to_string())),
_ => {}
}
let result = sqlx::query(
format!(
let result = if let Some(password) = password {
let mut query = format!(
"ALTER USER {}@'%' IDENTIFIED BY {}",
quote_literal(db_user),
quote_literal(password).as_str(),
)
.as_str(),
)
.execute(&mut *connection)
.await
.map(|_| ())
.map_err(|err| SetPasswordError::MySqlError(err.to_string()));
);
if let Some(expiry_date) = expiry {
query.push_str(&format!(" PASSWORD EXPIRE DATE '{}'", expiry_date));
}
sqlx::query(query.as_str())
.execute(&mut *connection)
.await
.map(|_| ())
.map_err(|err| SetPasswordError::MySqlError(err.to_string()))
} else {
let query = format!(
"ALTER USER {}@'%' IDENTIFIED WITH mysql_native_password AS ''",
quote_literal(db_user),
);
sqlx::query(query.as_str())
.execute(&mut *connection)
.await
.map(|_| ())
.map_err(|err| SetPasswordError::MySqlError(err.to_string()))
};
if result.is_err() {
tracing::error!(
@@ -272,7 +293,7 @@ async fn database_user_is_locked_unsafe(
}
pub async fn lock_database_users(
db_users: &[MySQLUser],
db_users: Vec<MySQLUser>,
unix_user: &UnixUser,
connection: &mut MySqlConnection,
db_is_mariadb: bool,
@@ -280,7 +301,7 @@ pub async fn lock_database_users(
) -> LockUsersResponse {
let mut results = BTreeMap::new();
for db_user in db_users.iter().cloned() {
for db_user in db_users {
if let Err(err) =
validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
.map_err(LockUserError::ValidationError)
@@ -332,7 +353,7 @@ pub async fn lock_database_users(
}
pub async fn unlock_database_users(
db_users: &[MySQLUser],
db_users: Vec<MySQLUser>,
unix_user: &UnixUser,
connection: &mut MySqlConnection,
db_is_mariadb: bool,
@@ -340,7 +361,7 @@ pub async fn unlock_database_users(
) -> UnlockUsersResponse {
let mut results = BTreeMap::new();
for db_user in db_users.iter().cloned() {
for db_user in db_users {
if let Err(err) =
validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
.map_err(UnlockUserError::ValidationError)
@@ -440,7 +461,7 @@ FROM `user`
";
pub async fn list_database_users(
db_users: &[MySQLUser],
db_users: Vec<MySQLUser>,
unix_user: &UnixUser,
connection: &mut MySqlConnection,
db_is_mariadb: bool,
@@ -448,7 +469,7 @@ pub async fn list_database_users(
) -> ListUsersResponse {
let mut results = BTreeMap::new();
for db_user in db_users.iter().cloned() {
for db_user in db_users {
if let Err(err) =
validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
.map_err(ListUsersError::ValidationError)

View File

@@ -2,10 +2,7 @@ use std::{
fs,
os::{fd::FromRawFd, unix::net::UnixListener as StdUnixListener},
path::PathBuf,
sync::{
Arc,
atomic::{AtomicU64, Ordering},
},
sync::Arc,
time::Duration,
};
@@ -25,7 +22,7 @@ use crate::{
server::{
authorization::read_and_parse_group_denylist,
config::{MysqlConfig, ServerConfig},
session_handler::{SessionId, session_handler},
session_handler::session_handler,
},
};
@@ -551,8 +548,6 @@ async fn listener_task(
#[cfg(target_os = "linux")]
sd_notify::notify(false, &[sd_notify::NotifyState::Ready])?;
let connection_counter = AtomicU64::new(0);
loop {
tokio::select! {
biased;
@@ -582,29 +577,28 @@ async fn listener_task(
} => {
match accept_result {
Ok((conn, _addr)) => {
connection_counter.fetch_add(1, Ordering::Relaxed);
let conn_id = connection_counter.load(Ordering::Relaxed);
tracing::debug!("Got new connection");
tracing::debug!("Got new connection, assigned session ID {}", conn_id);
let session_id = SessionId::new(conn_id);
let db_pool_clone = db_pool.clone();
let db_is_mariadb_clone = *db_is_mariadb.read().await;
let group_denylist_arc_clone = group_denylist.clone();
task_tracker.spawn(async move {
match session_handler(
conn,
session_id,
db_pool_clone,
db_is_mariadb_clone,
&*group_denylist_arc_clone.read().await,
).await {
Ok(()) => {},
Err(e) => tracing::error!("Session {} failed: {}", conn_id, e),
Ok(()) => {}
Err(e) => {
tracing::error!("Failed to run server: {}", e);
}
}
});
},
Err(e) => tracing::error!("Failed to accept new connection: {}", e),
}
Err(e) => {
tracing::error!("Failed to accept new connection: {}", e);
}
}
}
}