Compare commits
13 Commits
password-c
...
manpages
| Author | SHA1 | Date | |
|---|---|---|---|
|
3f89eab11a
|
|||
|
ee33c96120
|
|||
|
94996038c2
|
|||
|
beb08e1b35
|
|||
|
6a3212bde2
|
|||
|
3ce2a13711
|
|||
|
fbe594d486
|
|||
|
2ec31cd146
|
|||
|
6e648004b5
|
|||
|
cb4b8a78dc
|
|||
|
b9f11d0413
|
|||
|
9f45c2e5da
|
|||
|
107333208c
|
40
Cargo.lock
generated
40
Cargo.lock
generated
@@ -285,10 +285,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "145052bdd345b87320e369255277e3fb5152762ad123a901ef5c262dd38fe8d2"
|
||||
dependencies = [
|
||||
"iana-time-zone",
|
||||
"js-sys",
|
||||
"num-traits",
|
||||
"serde",
|
||||
"wasm-bindgen",
|
||||
"windows-link",
|
||||
]
|
||||
|
||||
@@ -355,6 +353,16 @@ version = "0.7.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a1d728cc89cf3aee9ff92b05e62b19ee65a02b5702cff7d5a377e32c6ae29d8d"
|
||||
|
||||
[[package]]
|
||||
name = "clap_mangen"
|
||||
version = "0.2.31"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "439ea63a92086df93893164221ad4f24142086d535b3a0957b9b9bea2dc86301"
|
||||
dependencies = [
|
||||
"clap",
|
||||
"roff",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "color-print"
|
||||
version = "0.3.7"
|
||||
@@ -697,7 +705,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"windows-sys 0.52.0",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1129,7 +1137,7 @@ checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46"
|
||||
dependencies = [
|
||||
"hermit-abi",
|
||||
"libc",
|
||||
"windows-sys 0.52.0",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1336,10 +1344,10 @@ dependencies = [
|
||||
"async-bincode",
|
||||
"bincode 2.0.1",
|
||||
"build-info-build",
|
||||
"chrono",
|
||||
"clap",
|
||||
"clap-verbosity-flag",
|
||||
"clap_complete",
|
||||
"clap_mangen",
|
||||
"color-print",
|
||||
"const_format",
|
||||
"derive_more",
|
||||
@@ -1772,6 +1780,12 @@ dependencies = [
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "roff"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "88f8660c1ff60292143c98d08fc6e2f654d722db50410e3f3797d40baaf9d8f3"
|
||||
|
||||
[[package]]
|
||||
name = "rsa"
|
||||
version = "0.9.9"
|
||||
@@ -1811,7 +1825,7 @@ dependencies = [
|
||||
"errno",
|
||||
"libc",
|
||||
"linux-raw-sys",
|
||||
"windows-sys 0.52.0",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1917,16 +1931,16 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.146"
|
||||
version = "1.0.148"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "217ca874ae0207aac254aa02c957ded05585a90892cc8d87f9e5fa49669dadd8"
|
||||
checksum = "3084b546a1dd6289475996f182a22aba973866ea8e8b02c51d9f46b1336a22da"
|
||||
dependencies = [
|
||||
"indexmap",
|
||||
"itoa",
|
||||
"memchr",
|
||||
"ryu",
|
||||
"serde",
|
||||
"serde_core",
|
||||
"zmij",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2307,7 +2321,7 @@ dependencies = [
|
||||
"getrandom 0.3.4",
|
||||
"once_cell",
|
||||
"rustix",
|
||||
"windows-sys 0.52.0",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3222,6 +3236,12 @@ dependencies = [
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zmij"
|
||||
version = "1.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0f4a4e8e9dc5c62d159f04fcdbe07f4c3fb710415aab4754bf11505501e3251d"
|
||||
|
||||
[[package]]
|
||||
name = "zstd"
|
||||
version = "0.13.3"
|
||||
|
||||
@@ -22,10 +22,10 @@ autolib = false
|
||||
anyhow = "1.0.100"
|
||||
async-bincode = "0.8.0"
|
||||
bincode = "2.0.1"
|
||||
chrono = { version = "0.4.42", features = ["serde"] }
|
||||
clap = { version = "4.5.53", features = ["cargo", "derive"] }
|
||||
clap-verbosity-flag = { version = "3.0.4", features = [ "tracing" ] }
|
||||
clap_complete = { version = "4.5.62", features = ["unstable-dynamic"] }
|
||||
clap_mangen = "0.2.31"
|
||||
color-print = "0.3.7"
|
||||
const_format = "0.2.35"
|
||||
derive_more = { version = "2.1.1", features = ["display", "error"] }
|
||||
@@ -39,7 +39,7 @@ num_cpus = "1.17.0"
|
||||
prettytable = "0.10.0"
|
||||
rand = "0.9.2"
|
||||
serde = "1.0.228"
|
||||
serde_json = { version = "1.0.146", features = ["preserve_order"] }
|
||||
serde_json = { version = "1.0.148", features = ["preserve_order"] }
|
||||
sqlx = { version = "0.8.6", features = ["runtime-tokio", "mysql", "tls-rustls"] }
|
||||
thiserror = "2.0.17"
|
||||
tokio = { version = "1.48.0", features = ["rt-multi-thread", "macros", "signal"] }
|
||||
@@ -77,12 +77,12 @@ path = "src/lib.rs"
|
||||
[[bin]]
|
||||
name = "muscl"
|
||||
bench = false
|
||||
path = "src/entrypoints/muscl.rs"
|
||||
path = "src/bin/muscl.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "muscl-server"
|
||||
bench = false
|
||||
path = "src/entrypoints/muscl_server.rs"
|
||||
path = "src/bin/muscl_server.rs"
|
||||
|
||||
[profile.release-lto]
|
||||
inherits = "release"
|
||||
|
||||
11
README.md
11
README.md
@@ -7,7 +7,7 @@ Dropping DBs (dumbbells) and having MySQL spasms since 2024
|
||||
|
||||
## What is this?
|
||||
|
||||
`muscl is a secure MySQL administration tool for multi-user systems.
|
||||
`muscl` is a secure MySQL administration tool for multi-user systems.
|
||||
It allows unprivileged users to manage their own databases and database users without granting them direct access to the MySQL server.
|
||||
Authorization is handled by a prefix-based model tied to Unix users and groups, making it ideal for shared hosting environments, like university servers, tilde servers, or similar.
|
||||
|
||||
@@ -53,3 +53,12 @@ over a IPC, which then performs the requested operations on behalf of the client
|
||||
- [Compatibility mode with mysql-admutils](docs/mysql-admutils-compatibility.md)
|
||||
- [Use with NixOS](docs/nixos.md)
|
||||
- [SUID/SGID mode](docs/suid-sgid-mode.md)
|
||||
|
||||
## History
|
||||
|
||||
This is a rewrite of an older piece of software called [mysql-admutils](https://git.pvv.ntnu.no/Projects/mysql-admutils).
|
||||
Programvareverkstedet used this a lot back in the day, and it was great.
|
||||
But it had some security issues inherent to the software design, particularly related to the use of SUID/SGID.
|
||||
We tried patching it multiple times, but the issue kept popping up again in different ways.
|
||||
The rewrite was intended to iron this issue out completely by splitting the software into two pieces - a client and a server.
|
||||
As far as we know, this was successful, and it is unlikely for similar issues to resurface in the future.
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# These are the default system groups on debian.
|
||||
# You can alos add groups by gid by prefixing the line with 'gid:'.
|
||||
|
||||
group:_ssh
|
||||
group:adm
|
||||
group:audio
|
||||
group:avahi
|
||||
@@ -12,6 +13,7 @@ group:daemon
|
||||
group:dialout
|
||||
group:dip
|
||||
group:disk
|
||||
group:docker
|
||||
group:fax
|
||||
group:floppy
|
||||
group:games
|
||||
@@ -22,9 +24,12 @@ group:kmem
|
||||
group:kvm
|
||||
group:list
|
||||
group:lp
|
||||
group:lxd
|
||||
group:mail
|
||||
group:man
|
||||
group:messagebus
|
||||
group:mlocate
|
||||
group:mysql
|
||||
group:netdev
|
||||
group:news
|
||||
group:nogroup
|
||||
@@ -42,15 +47,18 @@ group:src
|
||||
group:staff
|
||||
group:sudo
|
||||
group:sys
|
||||
group:syslog
|
||||
group:systemd-journal
|
||||
group:systemd-network
|
||||
group:systemd-resolve
|
||||
group:systemd-timesync
|
||||
group:tape
|
||||
group:tcpdump
|
||||
group:tty
|
||||
group:users
|
||||
group:utmp
|
||||
group:uucp
|
||||
group:uuidd
|
||||
group:video
|
||||
group:voice
|
||||
group:www-data
|
||||
|
||||
18
flake.lock
generated
18
flake.lock
generated
@@ -2,11 +2,11 @@
|
||||
"nodes": {
|
||||
"crane": {
|
||||
"locked": {
|
||||
"lastModified": 1766194365,
|
||||
"narHash": "sha256-4AFsUZ0kl6MXSm4BaQgItD0VGlEKR3iq7gIaL7TjBvc=",
|
||||
"lastModified": 1766774972,
|
||||
"narHash": "sha256-8qxEFpj4dVmIuPn9j9z6NTbU+hrcGjBOvaxTzre5HmM=",
|
||||
"owner": "ipetkov",
|
||||
"repo": "crane",
|
||||
"rev": "7d8ec2c71771937ab99790b45e6d9b93d15d9379",
|
||||
"rev": "01bc1d404a51a0a07e9d8759cd50a7903e218c82",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -17,11 +17,11 @@
|
||||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1766309749,
|
||||
"narHash": "sha256-3xY8CZ4rSnQ0NqGhMKAy5vgC+2IVK0NoVEzDoOh4DA4=",
|
||||
"lastModified": 1766902085,
|
||||
"narHash": "sha256-coBu0ONtFzlwwVBzmjacUQwj3G+lybcZ1oeNSQkgC0M=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "a6531044f6d0bef691ea18d4d4ce44d0daa6e816",
|
||||
"rev": "c0b0e0fddf73fd517c3471e546c0df87a42d53f4",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -45,11 +45,11 @@
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1766457837,
|
||||
"narHash": "sha256-aeBbkQ0HPFNOIsUeEsXmZHXbYq4bG8ipT9JRlCcKHgU=",
|
||||
"lastModified": 1766976750,
|
||||
"narHash": "sha256-w+o3AIBI56tzfMJRqRXg9tSXnpQRN5hAT15o2t9rxYw=",
|
||||
"owner": "oxalica",
|
||||
"repo": "rust-overlay",
|
||||
"rev": "2c7510a559416d07242621d036847152d970612b",
|
||||
"rev": "9fe44e7f05b734a64a01f92fc51ad064fb0a884f",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
|
||||
@@ -42,9 +42,9 @@ in
|
||||
|
||||
authorization = {
|
||||
group_denylist = lib.mkOption {
|
||||
type = with lib.types; nullOr (listOf str);
|
||||
type = with lib.types; nullOr (listOf (either str ints.unsigned));
|
||||
default = [ "wheel" ];
|
||||
description = "List of groups that are denied access";
|
||||
description = "List of groups/GIDs that can not be used as prefixes for databases/database users";
|
||||
};
|
||||
};
|
||||
|
||||
@@ -110,7 +110,32 @@ in
|
||||
];
|
||||
|
||||
environment.etc."muscl/group-denylist" = lib.mkIf (cfg.settings.authorization.group_denylist != [ ]) {
|
||||
text = lib.concatMapStringsSep "\n" (group: "group:${group}") cfg.settings.authorization.group_denylist;
|
||||
text = let
|
||||
nameToGidMapping = lib.pipe config.users.groups [
|
||||
(lib.filterAttrs (_: group: group.gid != null))
|
||||
(lib.mapAttrsToList (name: group: { name = name; value = group.gid; }))
|
||||
lib.listToAttrs
|
||||
];
|
||||
|
||||
gidToNameMapping = lib.pipe config.users.groups [
|
||||
(lib.filterAttrs (_: group: group.gid != null))
|
||||
(lib.mapAttrsToList (name: group: { name = toString group.gid; value = name; }))
|
||||
lib.listToAttrs
|
||||
];
|
||||
in lib.pipe cfg.settings.authorization.group_denylist [
|
||||
# Prefer GIDs for groups we know the GID
|
||||
(map (group: if builtins.isString group
|
||||
then (nameToGidMapping.${group} or group)
|
||||
else group))
|
||||
|
||||
# Then render back to strings
|
||||
(map (group:
|
||||
if builtins.isString group
|
||||
then "group:${group}"
|
||||
else "gid:${toString group} # ${gidToNameMapping.${toString group} or "unknown"}"))
|
||||
|
||||
(lib.concatStringsSep "\n")
|
||||
];
|
||||
};
|
||||
|
||||
services.mysql.ensureUsers = lib.mkIf cfg.createLocalDatabaseUser [
|
||||
|
||||
@@ -82,9 +82,11 @@ const EXAMPLES: &str = const_format::concatcp!(
|
||||
|
||||
# Show all databases
|
||||
muscl show-db
|
||||
muscl sd
|
||||
|
||||
# Show which users have privileges on which databases
|
||||
muscl show-privs
|
||||
muscl sp
|
||||
"#,
|
||||
);
|
||||
|
||||
@@ -169,22 +171,27 @@ const EDIT_PRIVS_EXAMPLES: &str = color_print::cstr!(
|
||||
#[command(subcommand_required = true)]
|
||||
pub enum ClientCommand {
|
||||
/// Check whether you are authorized to manage the specified databases or users.
|
||||
#[command(alias = "ca")]
|
||||
CheckAuth(CheckAuthArgs),
|
||||
|
||||
/// Create one or more databases
|
||||
#[command(alias = "cd")]
|
||||
CreateDb(CreateDbArgs),
|
||||
|
||||
/// Delete one or more databases
|
||||
#[command(alias = "dd")]
|
||||
DropDb(DropDbArgs),
|
||||
|
||||
/// Print information about one or more databases
|
||||
///
|
||||
/// If no database name is provided, all databases you have access will be shown.
|
||||
#[command(alias = "sd")]
|
||||
ShowDb(ShowDbArgs),
|
||||
|
||||
/// Print user privileges for one or more databases
|
||||
///
|
||||
/// If no database names are provided, all databases you have access to will be shown.
|
||||
#[command(alias = "sp")]
|
||||
ShowPrivs(ShowPrivsArgs),
|
||||
|
||||
/// Change user privileges for one or more databases. See `edit-privs --help` for details.
|
||||
@@ -239,27 +246,34 @@ pub enum ClientCommand {
|
||||
verbatim_doc_comment,
|
||||
override_usage = "muscl edit-privs [OPTIONS] [ -p <DB_NAME:USER_NAME:[+-]PRIVILEGES>... | <DB_NAME> <USER_NAME> <[+-]PRIVILEGES> ]",
|
||||
after_long_help = EDIT_PRIVS_EXAMPLES,
|
||||
alias = "ep",
|
||||
)]
|
||||
EditPrivs(EditPrivsArgs),
|
||||
|
||||
/// Create one or more users
|
||||
#[command(alias = "cu")]
|
||||
CreateUser(CreateUserArgs),
|
||||
|
||||
/// Delete one or more users
|
||||
#[command(alias = "du")]
|
||||
DropUser(DropUserArgs),
|
||||
|
||||
/// Change the MySQL password for a user
|
||||
#[command(alias = "pu")]
|
||||
PasswdUser(PasswdUserArgs),
|
||||
|
||||
/// Print information about one or more users
|
||||
///
|
||||
/// If no username is provided, all users you have access will be shown.
|
||||
#[command(alias = "su")]
|
||||
ShowUser(ShowUserArgs),
|
||||
|
||||
/// Lock account for one or more users
|
||||
#[command(alias = "lu")]
|
||||
LockUser(LockUserArgs),
|
||||
|
||||
/// Unlock account for one or more users
|
||||
#[command(alias = "uu")]
|
||||
UnlockUser(UnlockUserArgs),
|
||||
}
|
||||
|
||||
@@ -291,6 +305,10 @@ fn main() -> anyhow::Result<()> {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if handle_manpage_command()?.is_some() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
#[cfg(feature = "mysql-admutils-compatibility")]
|
||||
if handle_mysql_admutils_command()?.is_some() {
|
||||
return Ok(());
|
||||
@@ -347,6 +365,48 @@ fn handle_dynamic_completion() -> anyhow::Result<Option<()>> {
|
||||
}
|
||||
}
|
||||
|
||||
/// **WARNING:** This function may be run with elevated privileges.
|
||||
fn handle_manpage_command() -> anyhow::Result<Option<()>> {
|
||||
let argv1: Option<String> = std::env::args().nth(1);
|
||||
|
||||
match argv1.as_deref() {
|
||||
Some("generate-manpages") => {
|
||||
#[cfg(feature = "suid-sgid-mode")]
|
||||
if executing_in_suid_sgid_mode()? {
|
||||
use muscl_lib::core::bootstrap::drop_privs;
|
||||
drop_privs()?
|
||||
}
|
||||
|
||||
let output_dir = std::env::args().nth(2).ok_or(anyhow::anyhow!(
|
||||
"Output directory argument missing for manpage generation"
|
||||
))?;
|
||||
|
||||
let output_dir = PathBuf::from(&output_dir);
|
||||
if !output_dir.is_dir() {
|
||||
anyhow::bail!(
|
||||
"Output directory `{:?}` does not exist or is not a directory",
|
||||
output_dir,
|
||||
);
|
||||
}
|
||||
|
||||
let mut roff = clap_mangen::roff::Roff::new();
|
||||
let man = clap_mangen::Man::new(Args::command());
|
||||
man.render_title(&mut std::io::stdout())?;
|
||||
man.render_name_section(&mut std::io::stdout())?;
|
||||
man.render_synopsis_section(&mut std::io::stdout())?;
|
||||
man.render_subcommands_section(&mut std::io::stdout())?;
|
||||
man.render_options_section(&mut std::io::stdout())?;
|
||||
|
||||
roff.control("SH", ["VERSION"]);
|
||||
roff.text([clap_mangen::roff::roman(AFTER_LONG_HELP)]);
|
||||
roff.to_writer(&mut std::io::stdout())?;
|
||||
|
||||
Ok(Some(()))
|
||||
}
|
||||
_ => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
/// **WARNING:** This function may be run with elevated privileges.
|
||||
fn handle_mysql_admutils_command() -> anyhow::Result<Option<()>> {
|
||||
let argv0 = std::env::args().next().and_then(|s| {
|
||||
@@ -1,3 +1,5 @@
|
||||
use std::io::IsTerminal;
|
||||
|
||||
use clap::Parser;
|
||||
use clap_complete::ArgValueCompleter;
|
||||
use dialoguer::Confirm;
|
||||
@@ -6,16 +8,15 @@ use tokio_stream::StreamExt;
|
||||
|
||||
use crate::{
|
||||
client::commands::{
|
||||
erroneous_server_response, interactive_password_dialogue_with_double_check,
|
||||
interactive_password_expiry_dialogue, print_authorization_owner_hint,
|
||||
erroneous_server_response, print_authorization_owner_hint,
|
||||
read_password_from_stdin_with_double_check,
|
||||
},
|
||||
core::{
|
||||
completion::prefix_completer,
|
||||
protocol::{
|
||||
ClientToServerMessageStream, CreateUserError, Request, Response,
|
||||
SetUserPasswordRequest, print_create_users_output_status,
|
||||
print_create_users_output_status_json, print_set_password_output_status,
|
||||
request_validation::ValidationError,
|
||||
print_create_users_output_status, print_create_users_output_status_json,
|
||||
print_set_password_output_status, request_validation::ValidationError,
|
||||
},
|
||||
types::MySQLUser,
|
||||
},
|
||||
@@ -79,6 +80,15 @@ pub async fn create_users(
|
||||
.filter_map(|(username, result)| result.as_ref().ok().map(|()| username))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if !std::io::stdin().is_terminal()
|
||||
&& !args.no_password
|
||||
&& !successfully_created_users.is_empty()
|
||||
{
|
||||
anyhow::bail!(
|
||||
"Cannot prompt for passwords in non-interactive mode. Use --no-password to skip setting passwords."
|
||||
);
|
||||
}
|
||||
|
||||
for username in successfully_created_users {
|
||||
if !args.no_password
|
||||
&& Confirm::new()
|
||||
@@ -88,14 +98,8 @@ pub async fn create_users(
|
||||
.default(false)
|
||||
.interact()?
|
||||
{
|
||||
let password = interactive_password_dialogue_with_double_check(username)?;
|
||||
let expiry = interactive_password_expiry_dialogue(username)?;
|
||||
|
||||
let message = Request::PasswdUser(SetUserPasswordRequest {
|
||||
user: username.clone(),
|
||||
new_password: Some(password),
|
||||
expiry: expiry,
|
||||
});
|
||||
let password = read_password_from_stdin_with_double_check(username)?;
|
||||
let message = Request::PasswdUser((username.to_owned(), password));
|
||||
|
||||
if let Err(err) = server_connection.send(message).await {
|
||||
server_connection.close().await.ok();
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use std::io::IsTerminal;
|
||||
|
||||
use clap::Parser;
|
||||
use clap_complete::ArgValueCompleter;
|
||||
use dialoguer::Confirm;
|
||||
@@ -41,6 +43,12 @@ pub async fn drop_databases(
|
||||
anyhow::bail!("No database names provided");
|
||||
}
|
||||
|
||||
if !std::io::stdin().is_terminal() && !args.yes {
|
||||
anyhow::bail!(
|
||||
"Cannot prompt for confirmation in non-interactive mode. Use --yes to automatically confirm."
|
||||
);
|
||||
}
|
||||
|
||||
if !args.yes {
|
||||
let confirmation = Confirm::new()
|
||||
.with_prompt(format!(
|
||||
@@ -53,7 +61,6 @@ pub async fn drop_databases(
|
||||
))
|
||||
.interact()?;
|
||||
|
||||
//
|
||||
if !confirmation {
|
||||
// TODO: should we return with an error code here?
|
||||
println!("Aborting drop operation.");
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use std::io::IsTerminal;
|
||||
|
||||
use clap::Parser;
|
||||
use clap_complete::ArgValueCompleter;
|
||||
use dialoguer::Confirm;
|
||||
@@ -41,6 +43,12 @@ pub async fn drop_users(
|
||||
anyhow::bail!("No usernames provided");
|
||||
}
|
||||
|
||||
if !std::io::stdin().is_terminal() && !args.yes {
|
||||
anyhow::bail!(
|
||||
"Cannot prompt for confirmation in non-interactive mode. Use --yes to automatically confirm."
|
||||
);
|
||||
}
|
||||
|
||||
if !args.yes {
|
||||
let confirmation = Confirm::new()
|
||||
.with_prompt(format!(
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
use std::collections::{BTreeMap, BTreeSet};
|
||||
use std::{
|
||||
collections::{BTreeMap, BTreeSet},
|
||||
io::IsTerminal,
|
||||
};
|
||||
|
||||
use anyhow::Context;
|
||||
use clap::{Args, Parser};
|
||||
@@ -213,6 +216,11 @@ pub async fn edit_database_privileges(
|
||||
};
|
||||
|
||||
let diffs: BTreeSet<DatabasePrivilegesDiff> = if privs.is_empty() {
|
||||
if !std::io::stdin().is_terminal() {
|
||||
anyhow::bail!(
|
||||
"Cannot launch editor in non-interactive mode. Please provide privileges via command line arguments."
|
||||
);
|
||||
}
|
||||
let privileges_to_change =
|
||||
edit_privileges_with_editor(&existing_privilege_rows, use_database.as_ref())?;
|
||||
diff_privileges(&existing_privilege_rows, &privileges_to_change)
|
||||
@@ -275,7 +283,8 @@ pub async fn edit_database_privileges(
|
||||
println!("The following changes will be made:\n");
|
||||
println!("{}", display_privilege_diffs(&diffs));
|
||||
|
||||
if !args.yes
|
||||
if std::io::stdin().is_terminal()
|
||||
&& !args.yes
|
||||
&& !Confirm::new()
|
||||
.with_prompt("Do you want to apply these changes?")
|
||||
.default(false)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use std::path::PathBuf;
|
||||
use std::{io::IsTerminal, path::PathBuf};
|
||||
|
||||
use anyhow::Context;
|
||||
use clap::Parser;
|
||||
@@ -13,8 +13,7 @@ use crate::{
|
||||
completion::mysql_user_completer,
|
||||
protocol::{
|
||||
ClientToServerMessageStream, ListUsersError, Request, Response, SetPasswordError,
|
||||
SetUserPasswordRequest, print_set_password_output_status,
|
||||
request_validation::ValidationError,
|
||||
print_set_password_output_status, request_validation::ValidationError,
|
||||
},
|
||||
types::MySQLUser,
|
||||
},
|
||||
@@ -38,21 +37,9 @@ pub struct PasswdUserArgs {
|
||||
/// Print the information as JSON
|
||||
#[arg(short, long)]
|
||||
json: bool,
|
||||
|
||||
/// Set the password to expire on the given date (YYYY-MM-DD)
|
||||
#[arg(short, long, value_name = "DATE", conflicts_with = "no-expire")]
|
||||
expire_on: Option<chrono::NaiveDate>,
|
||||
|
||||
/// Set the password to never expire
|
||||
#[arg(short, long, conflicts_with = "expire_on")]
|
||||
no_expire: bool,
|
||||
|
||||
/// Clear the password for the user instead of setting a new one
|
||||
#[arg(short, long, conflicts_with_all = &["password_file", "stdin", "expire_on", "no-expire"])]
|
||||
clear: bool,
|
||||
}
|
||||
|
||||
pub fn interactive_password_dialogue_with_double_check(username: &MySQLUser) -> anyhow::Result<String> {
|
||||
pub fn read_password_from_stdin_with_double_check(username: &MySQLUser) -> anyhow::Result<String> {
|
||||
Password::new()
|
||||
.with_prompt(format!("New MySQL password for user '{username}'"))
|
||||
.with_confirmation(
|
||||
@@ -63,29 +50,6 @@ pub fn interactive_password_dialogue_with_double_check(username: &MySQLUser) ->
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
pub fn interactive_password_expiry_dialogue(username: &MySQLUser) -> anyhow::Result<Option<chrono::NaiveDate>> {
|
||||
let input = dialoguer::Input::<String>::new()
|
||||
.with_prompt(format!(
|
||||
"Enter the password expiry date for user '{username}' (YYYY-MM-DD)"
|
||||
))
|
||||
.allow_empty(true)
|
||||
.validate_with(|input: &String| {
|
||||
chrono::NaiveDate::parse_from_str(input, "%Y-%m-%d")
|
||||
.map(|_| ())
|
||||
.map_err(|_| "Invalid date format. Please use YYYY-MM-DD".to_string())
|
||||
})
|
||||
.interact_text()?;
|
||||
|
||||
if input.trim().is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let date = chrono::NaiveDate::parse_from_str(&input, "%Y-%m-%d")
|
||||
.map_err(|e| anyhow::anyhow!("Failed to parse date: {}", e))?;
|
||||
|
||||
Ok(Some(date))
|
||||
}
|
||||
|
||||
pub async fn passwd_user(
|
||||
args: PasswdUserArgs,
|
||||
mut server_connection: ClientToServerMessageStream,
|
||||
@@ -112,38 +76,27 @@ pub async fn passwd_user(
|
||||
}
|
||||
}
|
||||
|
||||
let password: Option<String> = if let Some(password_file) = args.password_file {
|
||||
Some(
|
||||
std::fs::read_to_string(password_file)
|
||||
.context("Failed to read password file")?
|
||||
.trim()
|
||||
.to_string(),
|
||||
)
|
||||
let password = if let Some(password_file) = args.password_file {
|
||||
std::fs::read_to_string(password_file)
|
||||
.context("Failed to read password file")?
|
||||
.trim()
|
||||
.to_string()
|
||||
} else if args.stdin {
|
||||
let mut buffer = String::new();
|
||||
std::io::stdin()
|
||||
.read_line(&mut buffer)
|
||||
.context("Failed to read password from stdin")?;
|
||||
Some(buffer.trim().to_string())
|
||||
} else if args.clear {
|
||||
None
|
||||
buffer.trim().to_string()
|
||||
} else {
|
||||
Some(interactive_password_dialogue_with_double_check(&args.username)?)
|
||||
if !std::io::stdin().is_terminal() {
|
||||
anyhow::bail!(
|
||||
"Cannot prompt for password in non-interactive mode. Use --stdin or --password-file to provide the password."
|
||||
);
|
||||
}
|
||||
read_password_from_stdin_with_double_check(&args.username)?
|
||||
};
|
||||
|
||||
let expiry_date = if args.no_expire {
|
||||
None
|
||||
} else if let Some(date) = args.expire_on {
|
||||
Some(date)
|
||||
} else {
|
||||
interactive_password_expiry_dialogue(&args.username)?
|
||||
};
|
||||
|
||||
let message = Request::PasswdUser(SetUserPasswordRequest {
|
||||
user: args.username.clone(),
|
||||
new_password: password,
|
||||
expiry: expiry_date,
|
||||
});
|
||||
let message = Request::PasswdUser((args.username.clone(), password));
|
||||
|
||||
if let Err(err) = server_connection.send(message).await {
|
||||
server_connection.close().await.ok();
|
||||
|
||||
@@ -8,7 +8,7 @@ use tokio::net::UnixStream as TokioUnixStream;
|
||||
|
||||
use crate::{
|
||||
client::{
|
||||
commands::{erroneous_server_response, interactive_password_dialogue_with_double_check},
|
||||
commands::{erroneous_server_response, read_password_from_stdin_with_double_check},
|
||||
mysql_admutils_compatibility::{
|
||||
common::trim_user_name_to_32_chars,
|
||||
error_messages::{
|
||||
@@ -20,7 +20,7 @@ use crate::{
|
||||
bootstrap::bootstrap_server_connection_and_drop_privileges,
|
||||
completion::{mysql_user_completer, prefix_completer},
|
||||
protocol::{
|
||||
ClientToServerMessageStream, Request, Response, SetUserPasswordRequest, create_client_to_server_message_stream
|
||||
ClientToServerMessageStream, Request, Response, create_client_to_server_message_stream,
|
||||
},
|
||||
types::MySQLUser,
|
||||
},
|
||||
@@ -252,12 +252,8 @@ async fn passwd_users(
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
for user in users {
|
||||
let password = interactive_password_dialogue_with_double_check(&user.user)?;
|
||||
let message = Request::PasswdUser(SetUserPasswordRequest {
|
||||
user: user.user.clone(),
|
||||
new_password: Some(password),
|
||||
expiry: None,
|
||||
});
|
||||
let password = read_password_from_stdin_with_double_check(&user.user)?;
|
||||
let message = Request::PasswdUser((user.user.clone(), password));
|
||||
server_connection.send(message).await?;
|
||||
match server_connection.next().await {
|
||||
Some(Ok(Response::SetUserPassword(result))) => match result {
|
||||
|
||||
@@ -22,7 +22,7 @@ use crate::{
|
||||
authorization::read_and_parse_group_denylist,
|
||||
config::{MysqlConfig, ServerConfig},
|
||||
landlock::landlock_restrict_server,
|
||||
session_handler,
|
||||
session_handler::{self, SessionId},
|
||||
},
|
||||
};
|
||||
|
||||
@@ -308,9 +308,11 @@ fn run_forked_server(
|
||||
version_row.to_lowercase().contains("mariadb")
|
||||
};
|
||||
|
||||
let session_id = SessionId::new(0);
|
||||
let db_pool = Arc::new(RwLock::new(db_pool));
|
||||
session_handler::session_handler_with_unix_user(
|
||||
socket,
|
||||
session_id,
|
||||
unix_user,
|
||||
db_pool,
|
||||
db_is_mariadb,
|
||||
|
||||
@@ -24,6 +24,7 @@ pub const KIND_REGARDS: &str = concat!(
|
||||
"If you experience any bugs or turbulence, please give us a heads up :)",
|
||||
);
|
||||
|
||||
/// TODO: store and display UID
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct UnixUser {
|
||||
pub username: String,
|
||||
|
||||
@@ -36,11 +36,16 @@ pub use modify_privileges::*;
|
||||
pub use passwd_user::*;
|
||||
pub use unlock_users::*;
|
||||
|
||||
use std::collections::BTreeSet;
|
||||
use std::fmt;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::net::UnixStream;
|
||||
use tokio_serde::{Framed as SerdeFramed, formats::Bincode};
|
||||
use tokio_util::codec::{Framed, LengthDelimitedCodec};
|
||||
|
||||
use crate::core::types::{MySQLDatabase, MySQLUser};
|
||||
|
||||
pub type ServerToClientMessageStream = SerdeFramed<
|
||||
Framed<UnixStream, LengthDelimitedCodec>,
|
||||
Request,
|
||||
@@ -104,6 +109,124 @@ pub enum Request {
|
||||
Exit,
|
||||
}
|
||||
|
||||
impl Request {
|
||||
/// Get the command name associated with this request.
|
||||
pub fn command_name(&self) -> &str {
|
||||
match self {
|
||||
Request::CheckAuthorization(_) => "check-authorization",
|
||||
Request::ListValidNamePrefixes => "list-valid-name-prefixes",
|
||||
Request::CompleteDatabaseName(_) => "complete-database-name",
|
||||
Request::CompleteUserName(_) => "complete-user-name",
|
||||
Request::CreateDatabases(_) => "create-databases",
|
||||
Request::DropDatabases(_) => "drop-databases",
|
||||
Request::ListDatabases(_) => "list-databases",
|
||||
Request::ListPrivileges(_) => "list-privileges",
|
||||
Request::ModifyPrivileges(_) => "modify-privileges",
|
||||
Request::CreateUsers(_) => "create-users",
|
||||
Request::DropUsers(_) => "drop-users",
|
||||
Request::PasswdUser(_) => "passwd-user",
|
||||
Request::ListUsers(_) => "list-users",
|
||||
Request::LockUsers(_) => "lock-users",
|
||||
Request::UnlockUsers(_) => "unlock-users",
|
||||
Request::Exit => "exit",
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate a short summary string representing this request for logging purposes.
|
||||
pub fn log_summary(&self) -> String {
|
||||
match self {
|
||||
Request::CheckAuthorization(req) => format!("{}({})", self.command_name(), req.len()),
|
||||
|
||||
Request::CreateDatabases(req) => format!("{}({})", self.command_name(), req.len()),
|
||||
Request::DropDatabases(req) => format!("{}({})", self.command_name(), req.len()),
|
||||
Request::ListDatabases(req) => format!(
|
||||
"{}{}",
|
||||
self.command_name(),
|
||||
req.as_ref()
|
||||
.map_or("".to_string(), |r| format!("({})", r.len()))
|
||||
),
|
||||
Request::ListPrivileges(req) => format!(
|
||||
"{}{}",
|
||||
self.command_name(),
|
||||
req.as_ref()
|
||||
.map_or("".to_string(), |r| format!("({})", r.len()))
|
||||
),
|
||||
Request::ModifyPrivileges(req) => format!("{}({})", self.command_name(), req.len()),
|
||||
|
||||
Request::CreateUsers(req) => format!("{}({})", self.command_name(), req.len()),
|
||||
Request::DropUsers(req) => format!("{}({})", self.command_name(), req.len()),
|
||||
Request::ListUsers(req) => format!(
|
||||
"{}{}",
|
||||
self.command_name(),
|
||||
req.as_ref()
|
||||
.map_or("".to_string(), |r| format!("({})", r.len()))
|
||||
),
|
||||
Request::LockUsers(req) => format!("{}({})", self.command_name(), req.len()),
|
||||
Request::UnlockUsers(req) => format!("{}({})", self.command_name(), req.len()),
|
||||
|
||||
_ => self.command_name().to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the set of users affected by this request.
|
||||
pub fn affected_users(&self) -> BTreeSet<MySQLUser> {
|
||||
match self {
|
||||
Request::CheckAuthorization(_) => Default::default(),
|
||||
Request::ListValidNamePrefixes => Default::default(),
|
||||
Request::CompleteDatabaseName(_) => Default::default(),
|
||||
Request::CompleteUserName(_) => Default::default(),
|
||||
Request::CreateDatabases(_) => Default::default(),
|
||||
Request::DropDatabases(_) => Default::default(),
|
||||
Request::ListDatabases(_) => Default::default(),
|
||||
Request::ListPrivileges(_) => Default::default(),
|
||||
Request::ModifyPrivileges(priv_diffs) => priv_diffs
|
||||
.iter()
|
||||
.map(|priv_diff| priv_diff.get_user_name().clone())
|
||||
.collect(),
|
||||
Request::CreateUsers(users) => users.iter().cloned().collect(),
|
||||
Request::DropUsers(users) => users.iter().cloned().collect(),
|
||||
Request::PasswdUser(user_passwd_req) => {
|
||||
let mut result = BTreeSet::new();
|
||||
result.insert(user_passwd_req.0.clone());
|
||||
result
|
||||
}
|
||||
Request::ListUsers(users) => users.clone().unwrap_or_default().into_iter().collect(),
|
||||
Request::LockUsers(users) => users.iter().cloned().collect(),
|
||||
Request::UnlockUsers(users) => users.iter().cloned().collect(),
|
||||
Request::Exit => Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the set of databases affected by this request.
|
||||
pub fn affected_databases(&self) -> BTreeSet<MySQLDatabase> {
|
||||
match self {
|
||||
Request::CheckAuthorization(_) => Default::default(),
|
||||
Request::ListValidNamePrefixes => Default::default(),
|
||||
Request::CompleteDatabaseName(_) => Default::default(),
|
||||
Request::CompleteUserName(_) => Default::default(),
|
||||
Request::CreateDatabases(databases) => databases.iter().cloned().collect(),
|
||||
Request::DropDatabases(databases) => databases.iter().cloned().collect(),
|
||||
Request::ListDatabases(databases) => {
|
||||
databases.clone().unwrap_or_default().into_iter().collect()
|
||||
}
|
||||
Request::ListPrivileges(databases) => {
|
||||
databases.clone().unwrap_or_default().into_iter().collect()
|
||||
}
|
||||
Request::ModifyPrivileges(priv_diffs) => priv_diffs
|
||||
.iter()
|
||||
.map(|priv_diff| priv_diff.get_database_name().clone())
|
||||
.collect(),
|
||||
Request::CreateUsers(_) => Default::default(),
|
||||
Request::DropUsers(_) => Default::default(),
|
||||
Request::PasswdUser(_) => Default::default(),
|
||||
Request::ListUsers(_) => Default::default(),
|
||||
Request::LockUsers(_) => Default::default(),
|
||||
Request::UnlockUsers(_) => Default::default(),
|
||||
Request::Exit => Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: include a generic "message" that will display a message to the user?
|
||||
|
||||
#[non_exhaustive]
|
||||
@@ -136,3 +259,95 @@ pub enum Response {
|
||||
Ready,
|
||||
Error(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub enum ResponseOkStatus {
|
||||
Success,
|
||||
PartialSuccess(usize, usize), // succeeded, total
|
||||
Error,
|
||||
}
|
||||
|
||||
impl ResponseOkStatus {
|
||||
pub fn from_counts(total: usize, succeeded: usize) -> Self {
|
||||
if succeeded == total {
|
||||
ResponseOkStatus::Success
|
||||
} else if succeeded == 0 {
|
||||
ResponseOkStatus::Error
|
||||
} else {
|
||||
ResponseOkStatus::PartialSuccess(succeeded, total)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_bool(is_ok: bool) -> Self {
|
||||
if is_ok {
|
||||
ResponseOkStatus::Success
|
||||
} else {
|
||||
ResponseOkStatus::Error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for ResponseOkStatus {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
ResponseOkStatus::Success => write!(f, "OK"),
|
||||
ResponseOkStatus::PartialSuccess(succeeded, total) => {
|
||||
write!(f, "PARTIAL_OK({}/{})", succeeded, total)
|
||||
}
|
||||
ResponseOkStatus::Error => write!(f, "ERR"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Response {
|
||||
pub fn ok_status(&self) -> ResponseOkStatus {
|
||||
match self {
|
||||
Response::CheckAuthorization(res) => {
|
||||
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
|
||||
}
|
||||
|
||||
Response::ListValidNamePrefixes(_) => ResponseOkStatus::Success,
|
||||
Response::CompleteDatabaseName(_) => ResponseOkStatus::Success,
|
||||
Response::CompleteUserName(_) => ResponseOkStatus::Success,
|
||||
|
||||
Response::CreateDatabases(res) => {
|
||||
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
|
||||
}
|
||||
Response::DropDatabases(res) => {
|
||||
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
|
||||
}
|
||||
Response::ListDatabases(res) => {
|
||||
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
|
||||
}
|
||||
Response::ListAllDatabases(res) => ResponseOkStatus::from_bool(res.is_ok()),
|
||||
Response::ListPrivileges(res) => {
|
||||
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
|
||||
}
|
||||
Response::ListAllPrivileges(res) => ResponseOkStatus::from_bool(res.is_ok()),
|
||||
Response::ModifyPrivileges(res) => {
|
||||
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
|
||||
}
|
||||
|
||||
Response::CreateUsers(res) => {
|
||||
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
|
||||
}
|
||||
Response::DropUsers(res) => {
|
||||
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
|
||||
}
|
||||
Response::SetUserPassword(res) => ResponseOkStatus::from_bool(res.is_ok()),
|
||||
Response::ListUsers(res) => {
|
||||
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
|
||||
}
|
||||
Response::ListAllUsers(res) => ResponseOkStatus::from_bool(res.is_ok()),
|
||||
Response::LockUsers(res) => {
|
||||
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
|
||||
}
|
||||
Response::UnlockUsers(res) => {
|
||||
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
|
||||
}
|
||||
|
||||
Response::Ready => ResponseOkStatus::Success,
|
||||
Response::Error(_) => ResponseOkStatus::Error,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,12 +6,7 @@ use crate::core::{
|
||||
types::{DbOrUser, MySQLUser},
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct SetUserPasswordRequest {
|
||||
pub user: MySQLUser,
|
||||
pub new_password: Option<String>,
|
||||
pub expiry: Option<chrono::NaiveDate>,
|
||||
}
|
||||
pub type SetUserPasswordRequest = (MySQLUser, String);
|
||||
|
||||
pub type SetUserPasswordResponse = Result<(), SetPasswordError>;
|
||||
|
||||
@@ -23,9 +18,6 @@ pub enum SetPasswordError {
|
||||
#[error("User does not exist")]
|
||||
UserDoesNotExist,
|
||||
|
||||
#[error("Cannot clear password with an expiry date set")]
|
||||
ClearPasswordWithExpiry,
|
||||
|
||||
#[error("MySQL error: {0}")]
|
||||
MySqlError(String),
|
||||
}
|
||||
@@ -52,9 +44,6 @@ impl SetPasswordError {
|
||||
SetPasswordError::UserDoesNotExist => {
|
||||
format!("User '{username}' does not exist.")
|
||||
}
|
||||
SetPasswordError::ClearPasswordWithExpiry => {
|
||||
format!("Cannot clear password for user '{username}' when an expiry date is set.")
|
||||
}
|
||||
SetPasswordError::MySqlError(err) => {
|
||||
format!("MySQL error: {err}")
|
||||
}
|
||||
@@ -67,7 +56,6 @@ impl SetPasswordError {
|
||||
match self {
|
||||
SetPasswordError::ValidationError(err) => err.error_type(),
|
||||
SetPasswordError::UserDoesNotExist => "user-does-not-exist".to_string(),
|
||||
SetPasswordError::ClearPasswordWithExpiry => "clear-password-with-expiry".to_string(),
|
||||
SetPasswordError::MySqlError(_) => "mysql-error".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use std::{collections::HashSet, path::Path};
|
||||
use std::{collections::HashSet, path::Path, str::Lines};
|
||||
|
||||
use anyhow::Context;
|
||||
use nix::unistd::Group;
|
||||
@@ -13,23 +13,19 @@ use crate::core::{
|
||||
};
|
||||
|
||||
pub async fn check_authorization(
|
||||
dbs_or_users: Vec<DbOrUser>,
|
||||
dbs_or_users: &[DbOrUser],
|
||||
unix_user: &UnixUser,
|
||||
group_denylist: &GroupDenylist,
|
||||
) -> std::collections::BTreeMap<DbOrUser, Result<(), CheckAuthorizationError>> {
|
||||
let mut results = std::collections::BTreeMap::new();
|
||||
|
||||
for db_or_user in dbs_or_users {
|
||||
if let Err(err) = validate_db_or_user_request(&db_or_user, unix_user, group_denylist)
|
||||
.map_err(CheckAuthorizationError)
|
||||
{
|
||||
results.insert(db_or_user.clone(), Err(err));
|
||||
continue;
|
||||
}
|
||||
results.insert(db_or_user.clone(), Ok(()));
|
||||
}
|
||||
|
||||
results
|
||||
dbs_or_users
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(|db_or_user| {
|
||||
let result = validate_db_or_user_request(&db_or_user, unix_user, group_denylist)
|
||||
.map_err(CheckAuthorizationError);
|
||||
(db_or_user, result)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Reads and parses a group denylist file, returning a set of GUIDs
|
||||
@@ -45,14 +41,23 @@ pub fn read_and_parse_group_denylist(denylist_path: &Path) -> anyhow::Result<Gro
|
||||
let content = std::fs::read_to_string(denylist_path)
|
||||
.context(format!("Failed to read denylist file at {denylist_path:?}"))?;
|
||||
|
||||
let mut groups = HashSet::with_capacity(content.lines().count());
|
||||
let lines = content.lines();
|
||||
|
||||
for (line_number, line) in content.lines().enumerate() {
|
||||
let trimmed_line = line.trim();
|
||||
let groups = parse_group_denylist(denylist_path, lines);
|
||||
|
||||
if trimmed_line.is_empty() || trimmed_line.starts_with('#') {
|
||||
continue;
|
||||
Ok(groups)
|
||||
}
|
||||
|
||||
fn parse_group_denylist(denylist_path: &Path, lines: Lines) -> GroupDenylist {
|
||||
let mut groups = HashSet::<u32>::new();
|
||||
|
||||
for (line_number, line) in lines.enumerate() {
|
||||
let trimmed_line = if let Some(comment_start) = line.find('#') {
|
||||
&line[..comment_start]
|
||||
} else {
|
||||
line
|
||||
}
|
||||
.trim();
|
||||
|
||||
let parts: Vec<&str> = trimmed_line.splitn(2, ':').collect();
|
||||
if parts.len() != 2 {
|
||||
@@ -141,5 +146,32 @@ pub fn read_and_parse_group_denylist(denylist_path: &Path) -> anyhow::Result<Gro
|
||||
}
|
||||
}
|
||||
|
||||
Ok(groups)
|
||||
groups
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use indoc::indoc;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_group_denylist() {
|
||||
let denylist_content = indoc! {"
|
||||
# Valid entries
|
||||
gid:0 # This is usually the 'root' group
|
||||
group:root # This is also the 'root' group, should deduplicate
|
||||
|
||||
# Invalid entries
|
||||
invalid_line
|
||||
gid:not_a_number
|
||||
group:nonexistent_group
|
||||
"};
|
||||
|
||||
let lines = denylist_content.lines();
|
||||
let group_denylist = parse_group_denylist(Path::new("test_denylist"), lines);
|
||||
|
||||
assert_eq!(group_denylist.len(), 1);
|
||||
assert!(group_denylist.contains(&0));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
use std::{collections::BTreeSet, sync::Arc};
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use indoc::concatdoc;
|
||||
use itertools::Itertools;
|
||||
use sqlx::{MySqlConnection, MySqlPool};
|
||||
use tokio::{net::UnixStream, sync::RwLock};
|
||||
use tracing::Instrument;
|
||||
@@ -11,8 +12,7 @@ use crate::{
|
||||
common::UnixUser,
|
||||
protocol::{
|
||||
Request, Response, ServerToClientMessageStream, SetPasswordError,
|
||||
SetUserPasswordRequest, create_server_to_client_message_stream,
|
||||
request_validation::GroupDenylist,
|
||||
create_server_to_client_message_stream, request_validation::GroupDenylist,
|
||||
},
|
||||
},
|
||||
server::{
|
||||
@@ -35,10 +35,24 @@ use crate::{
|
||||
},
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub struct SessionId(u64);
|
||||
|
||||
impl SessionId {
|
||||
pub fn new(id: u64) -> Self {
|
||||
SessionId(id)
|
||||
}
|
||||
|
||||
pub fn inner(&self) -> u64 {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: don't use database connection unless necessary.
|
||||
|
||||
pub async fn session_handler(
|
||||
socket: UnixStream,
|
||||
session_id: SessionId,
|
||||
db_pool: Arc<RwLock<MySqlPool>>,
|
||||
db_is_mariadb: bool,
|
||||
group_denylist: &GroupDenylist,
|
||||
@@ -62,7 +76,7 @@ pub async fn session_handler(
|
||||
}
|
||||
};
|
||||
|
||||
tracing::debug!("Validated peer UID: {}", uid);
|
||||
tracing::trace!("Validated peer UID: {}", uid);
|
||||
|
||||
let unix_user = match UnixUser::from_uid(uid) {
|
||||
Ok(user) => user,
|
||||
@@ -83,13 +97,18 @@ pub async fn session_handler(
|
||||
}
|
||||
};
|
||||
|
||||
let span = tracing::info_span!("user_session", user = %unix_user);
|
||||
let span = tracing::info_span!(
|
||||
"user_session",
|
||||
session_id = session_id.inner(),
|
||||
user = %unix_user,
|
||||
);
|
||||
|
||||
(async move {
|
||||
tracing::info!("Accepted connection from user: {}", unix_user);
|
||||
tracing::debug!("Accepted connection from user: {}", unix_user);
|
||||
|
||||
let result = session_handler_with_unix_user(
|
||||
socket,
|
||||
session_id,
|
||||
&unix_user,
|
||||
db_pool,
|
||||
db_is_mariadb,
|
||||
@@ -97,7 +116,7 @@ pub async fn session_handler(
|
||||
)
|
||||
.await;
|
||||
|
||||
tracing::info!(
|
||||
tracing::debug!(
|
||||
"Finished handling requests for connection from user: {}",
|
||||
unix_user,
|
||||
);
|
||||
@@ -110,6 +129,7 @@ pub async fn session_handler(
|
||||
|
||||
pub async fn session_handler_with_unix_user(
|
||||
socket: UnixStream,
|
||||
session_id: SessionId,
|
||||
unix_user: &UnixUser,
|
||||
db_pool: Arc<RwLock<MySqlPool>>,
|
||||
db_is_mariadb: bool,
|
||||
@@ -117,7 +137,7 @@ pub async fn session_handler_with_unix_user(
|
||||
) -> anyhow::Result<()> {
|
||||
let mut message_stream = create_server_to_client_message_stream(socket);
|
||||
|
||||
tracing::debug!("Requesting database connection from pool");
|
||||
tracing::trace!("Requesting database connection from pool");
|
||||
let mut db_connection = match db_pool.read().await.acquire().await {
|
||||
Ok(connection) => connection,
|
||||
Err(err) => {
|
||||
@@ -134,10 +154,11 @@ pub async fn session_handler_with_unix_user(
|
||||
return Err(err.into());
|
||||
}
|
||||
};
|
||||
tracing::debug!("Successfully acquired database connection from pool");
|
||||
tracing::trace!("Successfully acquired database connection from pool");
|
||||
|
||||
let result = session_handler_with_db_connection(
|
||||
message_stream,
|
||||
session_id,
|
||||
unix_user,
|
||||
&mut db_connection,
|
||||
db_is_mariadb,
|
||||
@@ -145,7 +166,7 @@ pub async fn session_handler_with_unix_user(
|
||||
)
|
||||
.await;
|
||||
|
||||
tracing::debug!("Releasing database connection back to pool");
|
||||
tracing::trace!("Releasing database connection back to pool");
|
||||
|
||||
result
|
||||
}
|
||||
@@ -155,6 +176,7 @@ pub async fn session_handler_with_unix_user(
|
||||
|
||||
async fn session_handler_with_db_connection(
|
||||
mut stream: ServerToClientMessageStream,
|
||||
session_id: SessionId,
|
||||
unix_user: &UnixUser,
|
||||
db_connection: &mut MySqlConnection,
|
||||
db_is_mariadb: bool,
|
||||
@@ -174,258 +196,311 @@ async fn session_handler_with_db_connection(
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: don't clone the request
|
||||
let request_to_display = match &request {
|
||||
Request::PasswdUser(SetUserPasswordRequest {
|
||||
user,
|
||||
new_password,
|
||||
expiry,
|
||||
}) => Request::PasswdUser(SetUserPasswordRequest {
|
||||
user: user.clone(),
|
||||
new_password: new_password.as_ref().map(|_| "<REDACTED>".to_string()),
|
||||
expiry: *expiry,
|
||||
}),
|
||||
request => request.to_owned(),
|
||||
};
|
||||
let request_span = tracing::info_span!("request", command = request.command_name());
|
||||
|
||||
if request_to_display == Request::Exit {
|
||||
tracing::debug!("Received request: {:#?}", request_to_display);
|
||||
} else {
|
||||
tracing::info!("Received request: {:#?}", request_to_display);
|
||||
if !handle_request(
|
||||
request,
|
||||
session_id,
|
||||
unix_user,
|
||||
db_connection,
|
||||
db_is_mariadb,
|
||||
group_denylist,
|
||||
&mut stream,
|
||||
)
|
||||
.instrument(request_span)
|
||||
.await?
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
let response = match request {
|
||||
Request::CheckAuthorization(dbs_or_users) => {
|
||||
let result = check_authorization(dbs_or_users, unix_user, group_denylist).await;
|
||||
Response::CheckAuthorization(result)
|
||||
}
|
||||
Request::ListValidNamePrefixes => {
|
||||
let mut result = Vec::with_capacity(unix_user.groups.len() + 1);
|
||||
result.push(unix_user.username.clone());
|
||||
|
||||
for group in get_user_filtered_groups(unix_user, group_denylist) {
|
||||
result.push(group.clone());
|
||||
}
|
||||
|
||||
Response::ListValidNamePrefixes(result)
|
||||
}
|
||||
Request::CompleteDatabaseName(partial_database_name) => {
|
||||
// TODO: more correct validation here
|
||||
if partial_database_name
|
||||
.chars()
|
||||
.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
|
||||
{
|
||||
let result = complete_database_name(
|
||||
partial_database_name,
|
||||
unix_user,
|
||||
db_connection,
|
||||
db_is_mariadb,
|
||||
group_denylist,
|
||||
)
|
||||
.await;
|
||||
Response::CompleteDatabaseName(result)
|
||||
} else {
|
||||
Response::CompleteDatabaseName(vec![])
|
||||
}
|
||||
}
|
||||
Request::CompleteUserName(partial_user_name) => {
|
||||
// TODO: more correct validation here
|
||||
if partial_user_name
|
||||
.chars()
|
||||
.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
|
||||
{
|
||||
let result = complete_user_name(
|
||||
partial_user_name,
|
||||
unix_user,
|
||||
db_connection,
|
||||
db_is_mariadb,
|
||||
group_denylist,
|
||||
)
|
||||
.await;
|
||||
Response::CompleteUserName(result)
|
||||
} else {
|
||||
Response::CompleteUserName(vec![])
|
||||
}
|
||||
}
|
||||
Request::CreateDatabases(databases_names) => {
|
||||
let result = create_databases(
|
||||
databases_names,
|
||||
unix_user,
|
||||
db_connection,
|
||||
db_is_mariadb,
|
||||
group_denylist,
|
||||
)
|
||||
.await;
|
||||
Response::CreateDatabases(result)
|
||||
}
|
||||
Request::DropDatabases(databases_names) => {
|
||||
let result = drop_databases(
|
||||
databases_names,
|
||||
unix_user,
|
||||
db_connection,
|
||||
db_is_mariadb,
|
||||
group_denylist,
|
||||
)
|
||||
.await;
|
||||
Response::DropDatabases(result)
|
||||
}
|
||||
Request::ListDatabases(database_names) => {
|
||||
if let Some(database_names) = database_names {
|
||||
let result = list_databases(
|
||||
database_names,
|
||||
unix_user,
|
||||
db_connection,
|
||||
db_is_mariadb,
|
||||
group_denylist,
|
||||
)
|
||||
.await;
|
||||
Response::ListDatabases(result)
|
||||
} else {
|
||||
let result = list_all_databases_for_user(
|
||||
unix_user,
|
||||
db_connection,
|
||||
db_is_mariadb,
|
||||
group_denylist,
|
||||
)
|
||||
.await;
|
||||
Response::ListAllDatabases(result)
|
||||
}
|
||||
}
|
||||
Request::ListPrivileges(database_names) => {
|
||||
if let Some(database_names) = database_names {
|
||||
let privilege_data = get_databases_privilege_data(
|
||||
database_names,
|
||||
unix_user,
|
||||
db_connection,
|
||||
db_is_mariadb,
|
||||
group_denylist,
|
||||
)
|
||||
.await;
|
||||
Response::ListPrivileges(privilege_data)
|
||||
} else {
|
||||
let privilege_data = get_all_database_privileges(
|
||||
unix_user,
|
||||
db_connection,
|
||||
db_is_mariadb,
|
||||
group_denylist,
|
||||
)
|
||||
.await;
|
||||
Response::ListAllPrivileges(privilege_data)
|
||||
}
|
||||
}
|
||||
Request::ModifyPrivileges(database_privilege_diffs) => {
|
||||
let result = apply_privilege_diffs(
|
||||
BTreeSet::from_iter(database_privilege_diffs),
|
||||
unix_user,
|
||||
db_connection,
|
||||
db_is_mariadb,
|
||||
group_denylist,
|
||||
)
|
||||
.await;
|
||||
Response::ModifyPrivileges(result)
|
||||
}
|
||||
Request::CreateUsers(db_users) => {
|
||||
let result = create_database_users(
|
||||
db_users,
|
||||
unix_user,
|
||||
db_connection,
|
||||
db_is_mariadb,
|
||||
group_denylist,
|
||||
)
|
||||
.await;
|
||||
Response::CreateUsers(result)
|
||||
}
|
||||
Request::DropUsers(db_users) => {
|
||||
let result = drop_database_users(
|
||||
db_users,
|
||||
unix_user,
|
||||
db_connection,
|
||||
db_is_mariadb,
|
||||
group_denylist,
|
||||
)
|
||||
.await;
|
||||
Response::DropUsers(result)
|
||||
}
|
||||
Request::PasswdUser(SetUserPasswordRequest {
|
||||
user,
|
||||
new_password,
|
||||
expiry,
|
||||
}) => {
|
||||
let result = set_password_for_database_user(
|
||||
&user,
|
||||
new_password.as_deref(),
|
||||
expiry,
|
||||
unix_user,
|
||||
db_connection,
|
||||
db_is_mariadb,
|
||||
group_denylist,
|
||||
)
|
||||
.await;
|
||||
Response::SetUserPassword(result)
|
||||
}
|
||||
Request::ListUsers(db_users) => {
|
||||
if let Some(db_users) = db_users {
|
||||
let result = list_database_users(
|
||||
db_users,
|
||||
unix_user,
|
||||
db_connection,
|
||||
db_is_mariadb,
|
||||
group_denylist,
|
||||
)
|
||||
.await;
|
||||
Response::ListUsers(result)
|
||||
} else {
|
||||
let result = list_all_database_users_for_unix_user(
|
||||
unix_user,
|
||||
db_connection,
|
||||
db_is_mariadb,
|
||||
group_denylist,
|
||||
)
|
||||
.await;
|
||||
Response::ListAllUsers(result)
|
||||
}
|
||||
}
|
||||
Request::LockUsers(db_users) => {
|
||||
let result = lock_database_users(
|
||||
db_users,
|
||||
unix_user,
|
||||
db_connection,
|
||||
db_is_mariadb,
|
||||
group_denylist,
|
||||
)
|
||||
.await;
|
||||
Response::LockUsers(result)
|
||||
}
|
||||
Request::UnlockUsers(db_users) => {
|
||||
let result = unlock_database_users(
|
||||
db_users,
|
||||
unix_user,
|
||||
db_connection,
|
||||
db_is_mariadb,
|
||||
group_denylist,
|
||||
)
|
||||
.await;
|
||||
Response::UnlockUsers(result)
|
||||
}
|
||||
Request::Exit => {
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
let response_to_display = match &response {
|
||||
Response::SetUserPassword(Err(SetPasswordError::MySqlError(_))) => {
|
||||
&Response::SetUserPassword(Err(SetPasswordError::MySqlError(
|
||||
"<REDACTED>".to_string(),
|
||||
)))
|
||||
}
|
||||
response => response,
|
||||
};
|
||||
tracing::debug!("Response: {:#?}", response_to_display);
|
||||
|
||||
stream.send(response).await?;
|
||||
stream.flush().await?;
|
||||
tracing::debug!("Successfully processed request");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle a single request from a client.
|
||||
///
|
||||
/// If the function returns `true`, the session should continue.
|
||||
async fn handle_request(
|
||||
request: Request,
|
||||
session_id: SessionId,
|
||||
unix_user: &UnixUser,
|
||||
db_connection: &mut MySqlConnection,
|
||||
db_is_mariadb: bool,
|
||||
group_denylist: &GroupDenylist,
|
||||
stream: &mut ServerToClientMessageStream,
|
||||
) -> anyhow::Result<bool> {
|
||||
match &request {
|
||||
Request::Exit => tracing::debug!("Request: exit"),
|
||||
Request::PasswdUser((db_user, _)) => tracing::debug!(
|
||||
"Request:\n{}",
|
||||
serde_json::to_string_pretty(&Request::PasswdUser((
|
||||
db_user.to_owned(),
|
||||
"<REDACTED>".to_string()
|
||||
)))?
|
||||
),
|
||||
request => tracing::debug!("Request:\n{}", serde_json::to_string_pretty(request)?),
|
||||
}
|
||||
|
||||
let affected_dbs = request.affected_databases();
|
||||
if !affected_dbs.is_empty() {
|
||||
tracing::trace!(
|
||||
"Affected databases: {}",
|
||||
affected_dbs.into_iter().map(|db| db.to_string()).join(", ")
|
||||
);
|
||||
}
|
||||
|
||||
let affected_users = request.affected_users();
|
||||
if !affected_users.is_empty() {
|
||||
tracing::trace!(
|
||||
"Affected users: {}",
|
||||
affected_users.into_iter().map(|u| u.to_string()).join(", "),
|
||||
);
|
||||
}
|
||||
|
||||
let response = match request {
|
||||
Request::CheckAuthorization(ref dbs_or_users) => {
|
||||
let result = check_authorization(dbs_or_users, unix_user, group_denylist).await;
|
||||
Response::CheckAuthorization(result)
|
||||
}
|
||||
Request::ListValidNamePrefixes => {
|
||||
let mut result = Vec::with_capacity(unix_user.groups.len() + 1);
|
||||
result.push(unix_user.username.clone());
|
||||
|
||||
for group in get_user_filtered_groups(unix_user, group_denylist) {
|
||||
result.push(group.clone());
|
||||
}
|
||||
|
||||
Response::ListValidNamePrefixes(result)
|
||||
}
|
||||
Request::CompleteDatabaseName(ref partial_database_name) => {
|
||||
// TODO: more correct validation here
|
||||
if partial_database_name
|
||||
.chars()
|
||||
.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
|
||||
{
|
||||
let result = complete_database_name(
|
||||
partial_database_name,
|
||||
unix_user,
|
||||
db_connection,
|
||||
db_is_mariadb,
|
||||
group_denylist,
|
||||
)
|
||||
.await;
|
||||
Response::CompleteDatabaseName(result)
|
||||
} else {
|
||||
Response::CompleteDatabaseName(vec![])
|
||||
}
|
||||
}
|
||||
Request::CompleteUserName(ref partial_user_name) => {
|
||||
// TODO: more correct validation here
|
||||
if partial_user_name
|
||||
.chars()
|
||||
.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
|
||||
{
|
||||
let result = complete_user_name(
|
||||
partial_user_name,
|
||||
unix_user,
|
||||
db_connection,
|
||||
db_is_mariadb,
|
||||
group_denylist,
|
||||
)
|
||||
.await;
|
||||
Response::CompleteUserName(result)
|
||||
} else {
|
||||
Response::CompleteUserName(vec![])
|
||||
}
|
||||
}
|
||||
Request::CreateDatabases(ref databases_names) => {
|
||||
let result = create_databases(
|
||||
databases_names,
|
||||
unix_user,
|
||||
db_connection,
|
||||
db_is_mariadb,
|
||||
group_denylist,
|
||||
)
|
||||
.await;
|
||||
Response::CreateDatabases(result)
|
||||
}
|
||||
Request::DropDatabases(ref databases_names) => {
|
||||
let result = drop_databases(
|
||||
databases_names,
|
||||
unix_user,
|
||||
db_connection,
|
||||
db_is_mariadb,
|
||||
group_denylist,
|
||||
)
|
||||
.await;
|
||||
Response::DropDatabases(result)
|
||||
}
|
||||
Request::ListDatabases(ref database_names) => {
|
||||
if let Some(database_names) = database_names {
|
||||
let result = list_databases(
|
||||
database_names,
|
||||
unix_user,
|
||||
db_connection,
|
||||
db_is_mariadb,
|
||||
group_denylist,
|
||||
)
|
||||
.await;
|
||||
Response::ListDatabases(result)
|
||||
} else {
|
||||
let result = list_all_databases_for_user(
|
||||
unix_user,
|
||||
db_connection,
|
||||
db_is_mariadb,
|
||||
group_denylist,
|
||||
)
|
||||
.await;
|
||||
Response::ListAllDatabases(result)
|
||||
}
|
||||
}
|
||||
Request::ListPrivileges(ref database_names) => {
|
||||
if let Some(database_names) = database_names {
|
||||
let privilege_data = get_databases_privilege_data(
|
||||
database_names,
|
||||
unix_user,
|
||||
db_connection,
|
||||
db_is_mariadb,
|
||||
group_denylist,
|
||||
)
|
||||
.await;
|
||||
Response::ListPrivileges(privilege_data)
|
||||
} else {
|
||||
let privilege_data = get_all_database_privileges(
|
||||
unix_user,
|
||||
db_connection,
|
||||
db_is_mariadb,
|
||||
group_denylist,
|
||||
)
|
||||
.await;
|
||||
Response::ListAllPrivileges(privilege_data)
|
||||
}
|
||||
}
|
||||
Request::ModifyPrivileges(ref database_privilege_diffs) => {
|
||||
let result = apply_privilege_diffs(
|
||||
database_privilege_diffs,
|
||||
unix_user,
|
||||
db_connection,
|
||||
db_is_mariadb,
|
||||
group_denylist,
|
||||
)
|
||||
.await;
|
||||
Response::ModifyPrivileges(result)
|
||||
}
|
||||
Request::CreateUsers(ref db_users) => {
|
||||
let result = create_database_users(
|
||||
db_users,
|
||||
unix_user,
|
||||
db_connection,
|
||||
db_is_mariadb,
|
||||
group_denylist,
|
||||
)
|
||||
.await;
|
||||
Response::CreateUsers(result)
|
||||
}
|
||||
Request::DropUsers(ref db_users) => {
|
||||
let result = drop_database_users(
|
||||
db_users,
|
||||
unix_user,
|
||||
db_connection,
|
||||
db_is_mariadb,
|
||||
group_denylist,
|
||||
)
|
||||
.await;
|
||||
Response::DropUsers(result)
|
||||
}
|
||||
Request::PasswdUser((ref db_user, ref password)) => {
|
||||
let result = set_password_for_database_user(
|
||||
db_user,
|
||||
password,
|
||||
unix_user,
|
||||
db_connection,
|
||||
db_is_mariadb,
|
||||
group_denylist,
|
||||
)
|
||||
.await;
|
||||
Response::SetUserPassword(result)
|
||||
}
|
||||
Request::ListUsers(ref db_users) => {
|
||||
if let Some(db_users) = db_users {
|
||||
let result = list_database_users(
|
||||
db_users,
|
||||
unix_user,
|
||||
db_connection,
|
||||
db_is_mariadb,
|
||||
group_denylist,
|
||||
)
|
||||
.await;
|
||||
Response::ListUsers(result)
|
||||
} else {
|
||||
let result = list_all_database_users_for_unix_user(
|
||||
unix_user,
|
||||
db_connection,
|
||||
db_is_mariadb,
|
||||
group_denylist,
|
||||
)
|
||||
.await;
|
||||
Response::ListAllUsers(result)
|
||||
}
|
||||
}
|
||||
Request::LockUsers(ref db_users) => {
|
||||
let result = lock_database_users(
|
||||
db_users,
|
||||
unix_user,
|
||||
db_connection,
|
||||
db_is_mariadb,
|
||||
group_denylist,
|
||||
)
|
||||
.await;
|
||||
Response::LockUsers(result)
|
||||
}
|
||||
Request::UnlockUsers(ref db_users) => {
|
||||
let result = unlock_database_users(
|
||||
db_users,
|
||||
unix_user,
|
||||
db_connection,
|
||||
db_is_mariadb,
|
||||
group_denylist,
|
||||
)
|
||||
.await;
|
||||
Response::UnlockUsers(result)
|
||||
}
|
||||
Request::Exit => {
|
||||
return Ok(false);
|
||||
}
|
||||
};
|
||||
|
||||
let response_to_display = match &response {
|
||||
Response::SetUserPassword(Err(SetPasswordError::MySqlError(_))) => {
|
||||
&Response::SetUserPassword(Err(SetPasswordError::MySqlError("<REDACTED>".to_string())))
|
||||
}
|
||||
response => response,
|
||||
};
|
||||
tracing::debug!(
|
||||
"Response:\n{}",
|
||||
serde_json::to_string_pretty(&response_to_display)?
|
||||
);
|
||||
|
||||
log_request(session_id, unix_user, &request, &response);
|
||||
|
||||
stream.send(response).await?;
|
||||
stream.flush().await?;
|
||||
tracing::trace!("Successfully processed request");
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
/// Log a summary of the request and its result.
|
||||
fn log_request(
|
||||
session_id: SessionId,
|
||||
unix_user: &UnixUser,
|
||||
request: &Request,
|
||||
response: &Response,
|
||||
) {
|
||||
tracing::info!(
|
||||
"[{}|session:{}|user:{unix_user}] {}",
|
||||
response.ok_status(),
|
||||
session_id.inner(),
|
||||
request.log_summary(),
|
||||
);
|
||||
}
|
||||
|
||||
@@ -46,7 +46,7 @@ pub(super) async fn unsafe_database_exists(
|
||||
}
|
||||
|
||||
pub async fn complete_database_name(
|
||||
database_prefix: String,
|
||||
database_prefix: &str,
|
||||
unix_user: &UnixUser,
|
||||
connection: &mut MySqlConnection,
|
||||
_db_is_mariadb: bool,
|
||||
@@ -87,7 +87,7 @@ pub async fn complete_database_name(
|
||||
}
|
||||
|
||||
pub async fn create_databases(
|
||||
database_names: Vec<MySQLDatabase>,
|
||||
database_names: &[MySQLDatabase],
|
||||
unix_user: &UnixUser,
|
||||
connection: &mut MySqlConnection,
|
||||
_db_is_mariadb: bool,
|
||||
@@ -95,7 +95,7 @@ pub async fn create_databases(
|
||||
) -> CreateDatabasesResponse {
|
||||
let mut results = BTreeMap::new();
|
||||
|
||||
for database_name in database_names {
|
||||
for database_name in database_names.iter().cloned() {
|
||||
if let Err(err) = validate_db_or_user_request(
|
||||
&DbOrUser::Database(database_name.clone()),
|
||||
unix_user,
|
||||
@@ -143,7 +143,7 @@ pub async fn create_databases(
|
||||
}
|
||||
|
||||
pub async fn drop_databases(
|
||||
database_names: Vec<MySQLDatabase>,
|
||||
database_names: &[MySQLDatabase],
|
||||
unix_user: &UnixUser,
|
||||
connection: &mut MySqlConnection,
|
||||
_db_is_mariadb: bool,
|
||||
@@ -151,7 +151,7 @@ pub async fn drop_databases(
|
||||
) -> DropDatabasesResponse {
|
||||
let mut results = BTreeMap::new();
|
||||
|
||||
for database_name in database_names {
|
||||
for database_name in database_names.iter().cloned() {
|
||||
if let Err(err) = validate_db_or_user_request(
|
||||
&DbOrUser::Database(database_name.clone()),
|
||||
unix_user,
|
||||
@@ -242,7 +242,7 @@ impl FromRow<'_, sqlx::mysql::MySqlRow> for DatabaseRow {
|
||||
}
|
||||
|
||||
pub async fn list_databases(
|
||||
database_names: Vec<MySQLDatabase>,
|
||||
database_names: &[MySQLDatabase],
|
||||
unix_user: &UnixUser,
|
||||
connection: &mut MySqlConnection,
|
||||
_db_is_mariadb: bool,
|
||||
@@ -250,7 +250,7 @@ pub async fn list_databases(
|
||||
) -> ListDatabasesResponse {
|
||||
let mut results = BTreeMap::new();
|
||||
|
||||
for database_name in database_names {
|
||||
for database_name in database_names.iter().cloned() {
|
||||
if let Err(err) = validate_db_or_user_request(
|
||||
&DbOrUser::Database(database_name.clone()),
|
||||
unix_user,
|
||||
|
||||
@@ -138,7 +138,7 @@ pub async fn unsafe_get_database_privileges_for_db_user_pair(
|
||||
}
|
||||
|
||||
pub async fn get_databases_privilege_data(
|
||||
database_names: Vec<MySQLDatabase>,
|
||||
database_names: &[MySQLDatabase],
|
||||
unix_user: &UnixUser,
|
||||
connection: &mut MySqlConnection,
|
||||
_db_is_mariadb: bool,
|
||||
@@ -146,19 +146,19 @@ pub async fn get_databases_privilege_data(
|
||||
) -> ListPrivilegesResponse {
|
||||
let mut results = BTreeMap::new();
|
||||
|
||||
for database_name in &database_names {
|
||||
for database_name in database_names.iter().cloned() {
|
||||
if let Err(err) = validate_db_or_user_request(
|
||||
&DbOrUser::Database(database_name.clone()),
|
||||
&DbOrUser::Database(database_name.to_owned()),
|
||||
unix_user,
|
||||
group_denylist,
|
||||
)
|
||||
.map_err(ListPrivilegesError::ValidationError)
|
||||
{
|
||||
results.insert(database_name.to_owned(), Err(err));
|
||||
results.insert(database_name, Err(err));
|
||||
continue;
|
||||
}
|
||||
|
||||
match unsafe_database_exists(database_name, connection).await {
|
||||
match unsafe_database_exists(&database_name, connection).await {
|
||||
Ok(false) => {
|
||||
results.insert(
|
||||
database_name.to_owned(),
|
||||
@@ -176,7 +176,7 @@ pub async fn get_databases_privilege_data(
|
||||
Ok(true) => {}
|
||||
}
|
||||
|
||||
let result = unsafe_get_database_privileges(database_name, connection)
|
||||
let result = unsafe_get_database_privileges(&database_name, connection)
|
||||
.await
|
||||
.map_err(|e| ListPrivilegesError::MySqlError(e.to_string()));
|
||||
|
||||
@@ -400,7 +400,7 @@ async fn validate_diff(
|
||||
|
||||
/// Uses the result of [`diff_privileges`] to modify privileges in the database.
|
||||
pub async fn apply_privilege_diffs(
|
||||
database_privilege_diffs: BTreeSet<DatabasePrivilegesDiff>,
|
||||
database_privilege_diffs: &BTreeSet<DatabasePrivilegesDiff>,
|
||||
unix_user: &UnixUser,
|
||||
connection: &mut MySqlConnection,
|
||||
_db_is_mariadb: bool,
|
||||
@@ -468,12 +468,12 @@ pub async fn apply_privilege_diffs(
|
||||
Ok(true) => {}
|
||||
}
|
||||
|
||||
if let Err(err) = validate_diff(&diff, connection).await {
|
||||
if let Err(err) = validate_diff(diff, connection).await {
|
||||
results.insert(key, Err(err));
|
||||
continue;
|
||||
}
|
||||
|
||||
let result = unsafe_apply_privilege_diff(&diff, connection)
|
||||
let result = unsafe_apply_privilege_diff(diff, connection)
|
||||
.await
|
||||
.map_err(|e| ModifyDatabasePrivilegesError::MySqlError(e.to_string()));
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ pub(super) async fn unsafe_user_exists(
|
||||
}
|
||||
|
||||
pub async fn complete_user_name(
|
||||
user_prefix: String,
|
||||
user_prefix: &str,
|
||||
unix_user: &UnixUser,
|
||||
connection: &mut MySqlConnection,
|
||||
_db_is_mariadb: bool,
|
||||
@@ -95,7 +95,7 @@ pub async fn complete_user_name(
|
||||
}
|
||||
|
||||
pub async fn create_database_users(
|
||||
db_users: Vec<MySQLUser>,
|
||||
db_users: &[MySQLUser],
|
||||
unix_user: &UnixUser,
|
||||
connection: &mut MySqlConnection,
|
||||
_db_is_mariadb: bool,
|
||||
@@ -103,7 +103,7 @@ pub async fn create_database_users(
|
||||
) -> CreateUsersResponse {
|
||||
let mut results = BTreeMap::new();
|
||||
|
||||
for db_user in db_users {
|
||||
for db_user in db_users.iter().cloned() {
|
||||
if let Err(err) =
|
||||
validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
|
||||
.map_err(CreateUserError::ValidationError)
|
||||
@@ -141,7 +141,7 @@ pub async fn create_database_users(
|
||||
}
|
||||
|
||||
pub async fn drop_database_users(
|
||||
db_users: Vec<MySQLUser>,
|
||||
db_users: &[MySQLUser],
|
||||
unix_user: &UnixUser,
|
||||
connection: &mut MySqlConnection,
|
||||
_db_is_mariadb: bool,
|
||||
@@ -149,7 +149,7 @@ pub async fn drop_database_users(
|
||||
) -> DropUsersResponse {
|
||||
let mut results = BTreeMap::new();
|
||||
|
||||
for db_user in db_users {
|
||||
for db_user in db_users.iter().cloned() {
|
||||
if let Err(err) =
|
||||
validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
|
||||
.map_err(DropUserError::ValidationError)
|
||||
@@ -188,8 +188,7 @@ pub async fn drop_database_users(
|
||||
|
||||
pub async fn set_password_for_database_user(
|
||||
db_user: &MySQLUser,
|
||||
password: Option<&str>,
|
||||
expiry: Option<chrono::NaiveDate>,
|
||||
password: &str,
|
||||
unix_user: &UnixUser,
|
||||
connection: &mut MySqlConnection,
|
||||
_db_is_mariadb: bool,
|
||||
@@ -198,44 +197,24 @@ pub async fn set_password_for_database_user(
|
||||
validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
|
||||
.map_err(SetPasswordError::ValidationError)?;
|
||||
|
||||
if password.is_none() && expiry.is_some() {
|
||||
return Err(SetPasswordError::ClearPasswordWithExpiry);
|
||||
}
|
||||
|
||||
match unsafe_user_exists(db_user, &mut *connection).await {
|
||||
Ok(false) => return Err(SetPasswordError::UserDoesNotExist),
|
||||
Err(err) => return Err(SetPasswordError::MySqlError(err.to_string())),
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let result = if let Some(password) = password {
|
||||
let mut query = format!(
|
||||
let result = sqlx::query(
|
||||
format!(
|
||||
"ALTER USER {}@'%' IDENTIFIED BY {}",
|
||||
quote_literal(db_user),
|
||||
quote_literal(password).as_str(),
|
||||
);
|
||||
|
||||
if let Some(expiry_date) = expiry {
|
||||
query.push_str(&format!(" PASSWORD EXPIRE DATE '{}'", expiry_date));
|
||||
}
|
||||
|
||||
sqlx::query(query.as_str())
|
||||
.execute(&mut *connection)
|
||||
.await
|
||||
.map(|_| ())
|
||||
.map_err(|err| SetPasswordError::MySqlError(err.to_string()))
|
||||
} else {
|
||||
let query = format!(
|
||||
"ALTER USER {}@'%' IDENTIFIED WITH mysql_native_password AS ''",
|
||||
quote_literal(db_user),
|
||||
);
|
||||
|
||||
sqlx::query(query.as_str())
|
||||
.execute(&mut *connection)
|
||||
.await
|
||||
.map(|_| ())
|
||||
.map_err(|err| SetPasswordError::MySqlError(err.to_string()))
|
||||
};
|
||||
)
|
||||
.as_str(),
|
||||
)
|
||||
.execute(&mut *connection)
|
||||
.await
|
||||
.map(|_| ())
|
||||
.map_err(|err| SetPasswordError::MySqlError(err.to_string()));
|
||||
|
||||
if result.is_err() {
|
||||
tracing::error!(
|
||||
@@ -293,7 +272,7 @@ async fn database_user_is_locked_unsafe(
|
||||
}
|
||||
|
||||
pub async fn lock_database_users(
|
||||
db_users: Vec<MySQLUser>,
|
||||
db_users: &[MySQLUser],
|
||||
unix_user: &UnixUser,
|
||||
connection: &mut MySqlConnection,
|
||||
db_is_mariadb: bool,
|
||||
@@ -301,7 +280,7 @@ pub async fn lock_database_users(
|
||||
) -> LockUsersResponse {
|
||||
let mut results = BTreeMap::new();
|
||||
|
||||
for db_user in db_users {
|
||||
for db_user in db_users.iter().cloned() {
|
||||
if let Err(err) =
|
||||
validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
|
||||
.map_err(LockUserError::ValidationError)
|
||||
@@ -353,7 +332,7 @@ pub async fn lock_database_users(
|
||||
}
|
||||
|
||||
pub async fn unlock_database_users(
|
||||
db_users: Vec<MySQLUser>,
|
||||
db_users: &[MySQLUser],
|
||||
unix_user: &UnixUser,
|
||||
connection: &mut MySqlConnection,
|
||||
db_is_mariadb: bool,
|
||||
@@ -361,7 +340,7 @@ pub async fn unlock_database_users(
|
||||
) -> UnlockUsersResponse {
|
||||
let mut results = BTreeMap::new();
|
||||
|
||||
for db_user in db_users {
|
||||
for db_user in db_users.iter().cloned() {
|
||||
if let Err(err) =
|
||||
validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
|
||||
.map_err(UnlockUserError::ValidationError)
|
||||
@@ -461,7 +440,7 @@ FROM `user`
|
||||
";
|
||||
|
||||
pub async fn list_database_users(
|
||||
db_users: Vec<MySQLUser>,
|
||||
db_users: &[MySQLUser],
|
||||
unix_user: &UnixUser,
|
||||
connection: &mut MySqlConnection,
|
||||
db_is_mariadb: bool,
|
||||
@@ -469,7 +448,7 @@ pub async fn list_database_users(
|
||||
) -> ListUsersResponse {
|
||||
let mut results = BTreeMap::new();
|
||||
|
||||
for db_user in db_users {
|
||||
for db_user in db_users.iter().cloned() {
|
||||
if let Err(err) =
|
||||
validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
|
||||
.map_err(ListUsersError::ValidationError)
|
||||
|
||||
@@ -2,7 +2,10 @@ use std::{
|
||||
fs,
|
||||
os::{fd::FromRawFd, unix::net::UnixListener as StdUnixListener},
|
||||
path::PathBuf,
|
||||
sync::Arc,
|
||||
sync::{
|
||||
Arc,
|
||||
atomic::{AtomicU64, Ordering},
|
||||
},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
@@ -22,7 +25,7 @@ use crate::{
|
||||
server::{
|
||||
authorization::read_and_parse_group_denylist,
|
||||
config::{MysqlConfig, ServerConfig},
|
||||
session_handler::session_handler,
|
||||
session_handler::{SessionId, session_handler},
|
||||
},
|
||||
};
|
||||
|
||||
@@ -548,6 +551,8 @@ async fn listener_task(
|
||||
#[cfg(target_os = "linux")]
|
||||
sd_notify::notify(false, &[sd_notify::NotifyState::Ready])?;
|
||||
|
||||
let connection_counter = AtomicU64::new(0);
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
biased;
|
||||
@@ -577,28 +582,29 @@ async fn listener_task(
|
||||
} => {
|
||||
match accept_result {
|
||||
Ok((conn, _addr)) => {
|
||||
tracing::debug!("Got new connection");
|
||||
connection_counter.fetch_add(1, Ordering::Relaxed);
|
||||
let conn_id = connection_counter.load(Ordering::Relaxed);
|
||||
|
||||
tracing::debug!("Got new connection, assigned session ID {}", conn_id);
|
||||
|
||||
let session_id = SessionId::new(conn_id);
|
||||
let db_pool_clone = db_pool.clone();
|
||||
let db_is_mariadb_clone = *db_is_mariadb.read().await;
|
||||
let group_denylist_arc_clone = group_denylist.clone();
|
||||
task_tracker.spawn(async move {
|
||||
match session_handler(
|
||||
conn,
|
||||
session_id,
|
||||
db_pool_clone,
|
||||
db_is_mariadb_clone,
|
||||
&*group_denylist_arc_clone.read().await,
|
||||
).await {
|
||||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to run server: {}", e);
|
||||
}
|
||||
Ok(()) => {},
|
||||
Err(e) => tracing::error!("Session {} failed: {}", conn_id, e),
|
||||
}
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to accept new connection: {}", e);
|
||||
}
|
||||
},
|
||||
Err(e) => tracing::error!("Failed to accept new connection: {}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user