1 Commits

Author SHA1 Message Date
oysteikt 3f89eab11a WIP 2026-01-09 19:48:17 +09:00
43 changed files with 517 additions and 1441 deletions
+2 -8
View File
@@ -31,13 +31,7 @@ jobs:
build-deb: build-deb:
strategy: strategy:
matrix: matrix:
os: [ os: [debian-trixie, debian-bookworm, ubuntu-noble, ubuntu-jammy]
debian-trixie,
debian-bookworm,
# ubuntu-resolute,
ubuntu-noble,
ubuntu-jammy,
]
name: Build and publish for ${{ matrix.os }} name: Build and publish for ${{ matrix.os }}
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
steps: steps:
@@ -69,7 +63,7 @@ jobs:
- name: Upload deb package artifact - name: Upload deb package artifact
uses: actions/upload-artifact@v3 uses: actions/upload-artifact@v3
with: with:
name: muscl-deb-${{ matrix.os }}-${{ gitea.sha }} name: muscl-deb-${{ matrix.os }}-${{ gitea.sha }}.zip
path: target/debian/*.deb path: target/debian/*.deb
if-no-files-found: error if-no-files-found: error
retention-days: 30 retention-days: 30
-12
View File
@@ -1,17 +1,5 @@
# Changelog # Changelog
## v1.0.1
Patch release with some important bug fixes
### Notable changes
- `mysql.db.Host` would usually be unset when creating privileges for users, this should be fixed now.
- You might have to manually set this field for rows created with the previous version of muscl to have those privileges work properly.
- Fixed an issue where a few select server responses would refuse to serialize properly, leading to an error message: "No response from server"
- The output of various commands is now being sorted.
- Bump dependencies
## v1.0.0 - Initial Release ## v1.0.0 - Initial Release
This is the initial release of `muscl`. This is the initial release of `muscl`.
Generated
+312 -545
View File
File diff suppressed because it is too large Load Diff
+24 -22
View File
@@ -1,11 +1,12 @@
[package] [package]
name = "muscl" name = "muscl"
version = "1.0.1" version = "0.1.0"
edition = "2024" edition = "2024"
resolver = "2" resolver = "2"
license = "BSD-3-Clause" license = "BSD-3-Clause"
authors = [ authors = [
"Programvareverkstedet <projects@pvv.ntnu.no>", "oysteikt@pvv.ntnu.no",
"felixalb@pvv.ntnu.no",
] ]
homepage = "https://git.pvv.ntnu.no/Projects/muscl" homepage = "https://git.pvv.ntnu.no/Projects/muscl"
repository = "https://git.pvv.ntnu.no/Projects/muscl" repository = "https://git.pvv.ntnu.no/Projects/muscl"
@@ -18,50 +19,51 @@ autobins = false
autolib = false autolib = false
[dependencies] [dependencies]
anyhow = "1.0.102" anyhow = "1.0.100"
async-bincode = "0.8.0" async-bincode = "0.8.0"
bincode = "2.0.1" bincode = "2.0.1"
clap = { version = "4.6.1", features = ["cargo", "derive"] } clap = { version = "4.5.53", features = ["cargo", "derive"] }
clap-verbosity-flag = { version = "3.0.4", features = [ "tracing" ] } clap-verbosity-flag = { version = "3.0.4", features = [ "tracing" ] }
clap_complete = { version = "4.6.3", features = ["unstable-dynamic"] } clap_complete = { version = "4.5.62", features = ["unstable-dynamic"] }
clap_mangen = "0.2.31"
color-print = "0.3.7" color-print = "0.3.7"
const_format = "0.2.36" const_format = "0.2.35"
derive_more = { version = "2.1.1", features = ["display", "error"] } derive_more = { version = "2.1.1", features = ["display", "error"] }
dialoguer = "0.12.0" dialoguer = "0.12.0"
futures-util = "0.3.32" futures-util = "0.3.31"
humansize = "2.1.3" humansize = "2.1.3"
indoc = "2.0.7" indoc = "2.0.7"
itertools = "0.14.0" itertools = "0.14.0"
nix = { version = "0.31.2", features = ["fs", "process", "socket", "user"] } nix = { version = "0.30.1", features = ["fs", "process", "socket", "user"] }
num_cpus = "1.17.0" num_cpus = "1.17.0"
prettytable = "0.10.0" prettytable = "0.10.0"
rand = "0.10.1" rand = "0.9.2"
serde = "1.0.228" serde = "1.0.228"
serde_json = { version = "1.0.149", features = ["preserve_order"] } serde_json = { version = "1.0.148", features = ["preserve_order"] }
sqlx = { version = "0.8.6", features = ["runtime-tokio", "mysql", "tls-rustls"] } sqlx = { version = "0.8.6", features = ["runtime-tokio", "mysql", "tls-rustls"] }
thiserror = "2.0.18" thiserror = "2.0.17"
tokio = { version = "1.52.1", features = ["rt-multi-thread", "macros", "signal"] } tokio = { version = "1.48.0", features = ["rt-multi-thread", "macros", "signal"] }
tokio-serde = { version = "0.9.0", features = ["bincode"] } tokio-serde = { version = "0.9.0", features = ["bincode"] }
tokio-stream = "0.1.18" tokio-stream = "0.1.17"
tokio-util = { version = "0.7.18", features = ["codec", "rt"] } tokio-util = { version = "0.7.17", features = ["codec", "rt"] }
toml = "1.1.2" toml = "0.9.10"
tracing = { version = "0.1.44", features = ["log"] } tracing = { version = "0.1.44", features = ["log"] }
tracing-subscriber = "0.3.23" tracing-subscriber = "0.3.22"
uuid = { version = "1.23.1", features = ["v4"] } uuid = { version = "1.19.0", features = ["v4"] }
[target.'cfg(target_os = "linux")'.dependencies] [target.'cfg(target_os = "linux")'.dependencies]
landlock = "0.4.4" landlock = "0.4.4"
sd-notify = "0.5.0" sd-notify = "0.4.5"
tracing-journald = "0.3.2" tracing-journald = "0.3.2"
[build-dependencies] [build-dependencies]
anyhow = "1.0.102" anyhow = "1.0.100"
build-info-build = "0.0.44" build-info-build = "0.0.42"
git2 = { version = "0.20.4", default-features = false } git2 = { version = "0.20.3", default-features = false }
[dev-dependencies] [dev-dependencies]
pretty_assertions = "1.4.1" pretty_assertions = "1.4.1"
regex = "1.12.3" regex = "1.12.2"
[features] [features]
default = ["mysql-admutils-compatibility"] default = ["mysql-admutils-compatibility"]
+1 -2
View File
@@ -47,8 +47,7 @@ over a IPC, which then performs the requested operations on behalf of the client
## Documentation ## Documentation
- [Installation and initial configuration](docs/installation.md) - [Installation and configuration](docs/installation.md)
- [Administration and further configuration](docs/administration.md)
- [Development and testing](docs/development.md) - [Development and testing](docs/development.md)
- [Compiling and packaging](docs/compiling.md) - [Compiling and packaging](docs/compiling.md)
- [Compatibility mode with mysql-admutils](docs/mysql-admutils-compatibility.md) - [Compatibility mode with mysql-admutils](docs/mysql-admutils-compatibility.md)
+1 -1
View File
@@ -1,5 +1,5 @@
# These are the default system groups on debian. # These are the default system groups on debian.
# You can also add groups by gid by prefixing the line with 'gid:'. # You can alos add groups by gid by prefixing the line with 'gid:'.
group:_ssh group:_ssh
group:adm group:adm
+1 -7
View File
@@ -8,9 +8,7 @@ Type=notify
ExecStart=/usr/bin/muscl-server --systemd --disable-landlock socket-activate ExecStart=/usr/bin/muscl-server --systemd --disable-landlock socket-activate
ExecReload=/usr/bin/kill -HUP $MAINPID ExecReload=/usr/bin/kill -HUP $MAINPID
WatchdogSec=3min WatchdogSec=15
Restart=always
RestartSec=10s
# Although this is a multi-instance unit, the constant `User` field is needed # Although this is a multi-instance unit, the constant `User` field is needed
# for authentication via mysql's auth_socket plugin to work. # for authentication via mysql's auth_socket plugin to work.
@@ -63,7 +61,3 @@ SystemCallFilter=@system-service
SystemCallFilter=~@privileged @resources SystemCallFilter=~@privileged @resources
UMask=0777 UMask=0777
[Install]
Also=muscl.socket
WantedBy=multi-user.target
-1
View File
@@ -3,7 +3,6 @@ Description=Muscl MySQL admin tool
[Socket] [Socket]
ListenStream=/run/muscl/muscl.sock ListenStream=/run/muscl/muscl.sock
RemoveOnStop=true
Accept=no Accept=no
PassCredentials=true PassCredentials=true
-90
View File
@@ -1,90 +0,0 @@
# Administration and further configuration
This page describes some additional configuration options and administration tasks for muscl.
## Configuring group denylists
In `/etc/muscl/muscl.conf`, you will find an option below `[authorization]` named `group_denylist_file`,
which points to `/etc/muscl/group_denylist.txt` by default.
In this file, you can add unix group names or GIDs to disallow the groups from being used as prefixes.
The deb package comes with a default denylist that disallows some common system groups.
The format of the file is one group name or GID per line. Lines starting with `#` and empty lines are ignored.
```
# Disallow using the 'root' group as a prefix
gid:0
# Disallow using the 'adm' group as a prefix
group:adm
```
> [!NOTE]
> If a user is named the same as a disallowed group, that user will still be able to use their username as a prefix.
## Configuring logging
By default, muscl logs to the systemd journal when run as a systemd service,
and also limits the log level to `info`. You can request more verbose logging
by appending `-v` flags to the `ExecStart=` line in the systemd service file.
To do this on a system where muscl was installed using a package, you can override
the service like this:
```bash
sudo systemctl edit muscl.service
```
This will open an editor where you can add the following lines:
```ini
[Service]
ExecStart=
ExecStart=/usr/bin/muscl-server -v ...
```
> [!NOTE]
> The first `ExecStart=` line is necessary to clear the previous value, as systemd
> interprets multiple `ExecStart=` lines as a list of commands to run in sequence.
You set either `-v` or `-vv` for `debug` and `trace` logging, respectively.
> [!WARNING]
> Be careful when enabling trace logging on production systems, as it might log
> passwords and credentials in plaintext.
## Querying logs in the systemd journal
Although invisible if you just run `journalctl -u muscl.service`, muscl adds a set of so-called
"fields" to its log entries to make it easier to filter and search them.
Here are some examples of how you can filter logs using `journalctl`:
```bash
# Show only logs related to a specific user
journalctl -eu muscl F_USER="<username>"
journalctl -eu muscl F_USER=johndoe
# Show only logs for a specific command types
journalctl -eu muscl F_COMMAND="<operation>"
journalctl -eu muscl F_COMMAND=create-db
# Show logs emitted for a specific session id
journalctl -eu muscl F_SESSION_ID="<session-id>"
journalctl -eu muscl F_SESSION_ID=123
# Show all of these fields together with the log message in a json format
journalctl --output json-pretty --output-fields MESSAGE,F_USER,F_COMMAND,F_SESSION_ID -eu muscl
```
See [`journalctl(1)`][journalctl_1] and [`systemd.journal-fields(7)`][systemd_journal-fields_7] for more information.
> [!NOTE]
> Please note that the commands are not 1-1 mapped to muscl subcommands.
> Rather, they are the available requests in the protocol used between the muscl client and server.
> These requests will often have the same name as the subcommands, but this is not always the case.
[journalctl_1]: https://man7.org/linux/man-pages/man1/journalctl.1.html
[systemd_journal-fields_7]: https://man7.org/linux/man-pages/man7/systemd.journal-fields.7.html
+9 -6
View File
@@ -39,12 +39,7 @@ docker stop mariadb
## Development using Nix ## Development using Nix
> [!NOTE] If you have nix installed, you can easily test your changes in a NixOS vm by running:
> We have created some nix code to generate a QEMU VM with a setup similar to a production deployment
> There is not necessarily any VMs running in a production setup, and if so then at least not this VM.
> It is mainly there for easy access to interactive testing, as well as for testing the NixOS module.
If you have nix installed, you can easily test your changes in a NixOS test VM by running:
```bash ```bash
nix run .#vm # Start a NixOS VM in QEMU with muscl and MariaDB installed nix run .#vm # Start a NixOS VM in QEMU with muscl and MariaDB installed
@@ -52,3 +47,11 @@ nix run .#vm-mysql # Start a NixOS VM in QEMU with muscl and MySQL installed
``` ```
You can configure the vm in `flake.nix` You can configure the vm in `flake.nix`
## Filter logs by user with journalctl
If you want to filter the server logs by user, you can use journalctl's built-in filtering capabilities.
```bash
journalctl -eu muscl F_USER=<username>
```
+23 -3
View File
@@ -1,11 +1,9 @@
# Installation and initial configuration # Installation and configuration
This document contains instructions for the recommended way of installing and configuring muscl. This document contains instructions for the recommended way of installing and configuring muscl.
Note that there are separate instructions for [installing on NixOS](nixos.md) and [installing with SUID/SGID mode](suid-sgid-mode.md). Note that there are separate instructions for [installing on NixOS](nixos.md) and [installing with SUID/SGID mode](suid-sgid-mode.md).
After installation, you might want to look at the [Administration and further configuration](administration.md) page.
## Installing with deb on Debian ## Installing with deb on Debian
You can install muscl by adding the [PVV apt repository][pvv-apt-repository] and installing the package: You can install muscl by adding the [PVV apt repository][pvv-apt-repository] and installing the package:
@@ -105,6 +103,28 @@ If you are using systemd, you should also create an override to unset the `Impor
ImportCredential= ImportCredential=
``` ```
## Configuring group denylists
In `/etc/muscl/muscl.conf`, you will find an option below `[authorization]` named `group_denylist_file`,
which points to `/etc/muscl/group_denylist.txt` by default.
In this file, you can add unix group names or GIDs to disallow the groups from being used as prefixes.
The deb package comes with a default denylist that disallows some common system groups.
The format of the file is one group name or GID per line. Lines starting with `#` and empty lines are ignored.
```
# Disallow using the 'root' group as a prefix
gid:0
# Disallow using the 'adm' group as a prefix
group:adm
```
> [!NOTE]
> If a user is named the same as a disallowed group, that user will still be able to use their username as a prefix.
## A note on minimum version requirements ## A note on minimum version requirements
The muscl server will work with older versions of systemd, but the recommended version is 254 or newer. The muscl server will work with older versions of systemd, but the recommended version is 254 or newer.
+2 -2
View File
@@ -4,8 +4,8 @@
> This will be deprecated in a future release, see https://git.pvv.ntnu.no/Projects/muscl/issues/101 > This will be deprecated in a future release, see https://git.pvv.ntnu.no/Projects/muscl/issues/101
> >
> We do not recommend you use this mode unless you absolutely have to. The biggest reason why `muscl` was rewritten from scratch > We do not recommend you use this mode unless you absolutely have to. The biggest reason why `muscl` was rewritten from scratch
> was to fix an architectural issue that easily caused vulnerabilities due to reliance on SUID/SGID. Although the architecture now > was to fix an architectural issue that easily caused vulnerabilites due to reliance on SUID/SGID. Althought the architecture now
> is more resistant against such vulnerabilities, it is not failsafe. > is more resistant against such vulnerabilites, it is not failsafe.
For backwards compatibility reasons, it is possible to run the program without a daemon by utilizing SUID/SGID. For backwards compatibility reasons, it is possible to run the program without a daemon by utilizing SUID/SGID.
Generated
+9 -9
View File
@@ -2,11 +2,11 @@
"nodes": { "nodes": {
"crane": { "crane": {
"locked": { "locked": {
"lastModified": 1774313767, "lastModified": 1766774972,
"narHash": "sha256-hy0XTQND6avzGEUFrJtYBBpFa/POiiaGBr2vpU6Y9tY=", "narHash": "sha256-8qxEFpj4dVmIuPn9j9z6NTbU+hrcGjBOvaxTzre5HmM=",
"owner": "ipetkov", "owner": "ipetkov",
"repo": "crane", "repo": "crane",
"rev": "3d9df76e29656c679c744968b17fbaf28f0e923d", "rev": "01bc1d404a51a0a07e9d8759cd50a7903e218c82",
"type": "github" "type": "github"
}, },
"original": { "original": {
@@ -17,11 +17,11 @@
}, },
"nixpkgs": { "nixpkgs": {
"locked": { "locked": {
"lastModified": 1775036866, "lastModified": 1766902085,
"narHash": "sha256-ZojAnPuCdy657PbTq5V0Y+AHKhZAIwSIT2cb8UgAz/U=", "narHash": "sha256-coBu0ONtFzlwwVBzmjacUQwj3G+lybcZ1oeNSQkgC0M=",
"owner": "NixOS", "owner": "NixOS",
"repo": "nixpkgs", "repo": "nixpkgs",
"rev": "6201e203d09599479a3b3450ed24fa81537ebc4e", "rev": "c0b0e0fddf73fd517c3471e546c0df87a42d53f4",
"type": "github" "type": "github"
}, },
"original": { "original": {
@@ -45,11 +45,11 @@
] ]
}, },
"locked": { "locked": {
"lastModified": 1775099554, "lastModified": 1766976750,
"narHash": "sha256-3xBsGnGDLOFtnPZ1D3j2LU19wpAlYefRKTlkv648rU0=", "narHash": "sha256-w+o3AIBI56tzfMJRqRXg9tSXnpQRN5hAT15o2t9rxYw=",
"owner": "oxalica", "owner": "oxalica",
"repo": "rust-overlay", "repo": "rust-overlay",
"rev": "8d6387ed6d8e6e6672fd3ed4b61b59d44b124d99", "rev": "9fe44e7f05b734a64a01f92fc51ad064fb0a884f",
"type": "github" "type": "github"
}, },
"original": { "original": {
+1 -1
View File
@@ -88,7 +88,7 @@ buildFunction ({
''; '';
meta = with lib; { meta = with lib; {
license = licenses.bsd3; license = licenses.mit;
platforms = platforms.linux ++ platforms.darwin; platforms = platforms.linux ++ platforms.darwin;
inherit mainProgram; inherit mainProgram;
}; };
+4 -3
View File
@@ -155,14 +155,15 @@ in
systemd.services."muscl" = { systemd.services."muscl" = {
reloadTriggers = [ config.environment.etc."muscl/config.toml".source ]; reloadTriggers = [ config.environment.etc."muscl/config.toml".source ];
serviceConfig = { serviceConfig = {
Type = "notify-reload";
ExecStart = [ ExecStart = [
"" ""
"${lib.getExe' cfg.package "muscl-server"} ${cfg.logLevel} --systemd --disable-landlock socket-activate" "${lib.getExe' cfg.package "muscl-server"} ${cfg.logLevel} --systemd --disable-landlock socket-activate"
]; ];
ExecReload = ""; ExecReload = [
ReloadSignal = "SIGHUP"; ""
"${lib.getExe' pkgs.coreutils "kill"} -HUP $MAINPID"
];
RuntimeDirectory = "muscl/root-mnt"; RuntimeDirectory = "muscl/root-mnt";
RuntimeDirectoryMode = "0700"; RuntimeDirectoryMode = "0700";
+5 -15
View File
@@ -42,15 +42,7 @@ declare -r GIT_SHA="$2"
TMPDIR="$(mktemp -d)" TMPDIR="$(mktemp -d)"
declare -a OS_VARIANTS=( for variant in debian-bookworm debian-trixie ubuntu-jammy ubuntu-noble; do
"debian-bookworm"
"debian-trixie"
"ubuntu-jammy"
"ubuntu-noble"
# "ubuntu-resolute"
)
for variant in "${OS_VARIANTS[@]}"; do
echo "Downloading and uploading debs for variant: $variant" echo "Downloading and uploading debs for variant: $variant"
curl "https://git.pvv.ntnu.no/Projects/muscl/actions/runs/$RUN_NUMBER/artifacts/muscl-deb-$variant-$GIT_SHA.zip" --output "$TMPDIR/muscl-deb-$variant-$GIT_SHA.zip" curl "https://git.pvv.ntnu.no/Projects/muscl/actions/runs/$RUN_NUMBER/artifacts/muscl-deb-$variant-$GIT_SHA.zip" --output "$TMPDIR/muscl-deb-$variant-$GIT_SHA.zip"
@@ -62,13 +54,11 @@ for variant in "${OS_VARIANTS[@]}"; do
DEB_VERSION=$(find "$TMPDIR/muscl-deb-$variant-$GIT_SHA"/*.deb -print0 | xargs -0 -n1 basename | cut -d'_' -f2 | head -n1) DEB_VERSION=$(find "$TMPDIR/muscl-deb-$variant-$GIT_SHA"/*.deb -print0 | xargs -0 -n1 basename | cut -d'_' -f2 | head -n1)
DEB_ARCH=$(find "$TMPDIR/muscl-deb-$variant-$GIT_SHA"/*.deb -print0 | xargs -0 -n1 basename | cut -d'_' -f3 | cut -d'.' -f1 | head -n1) DEB_ARCH=$(find "$TMPDIR/muscl-deb-$variant-$GIT_SHA"/*.deb -print0 | xargs -0 -n1 basename | cut -d'_' -f3 | cut -d'.' -f1 | head -n1)
# echo "[DELETE] https://git.pvv.ntnu.no/api/packages/Projects/debian/pool/$DISTRO_VERSION_NAME/main/$DEB_NAME/$DEB_VERSION/$DEB_ARCH" curl \
# curl \ -X DELETE \
# -X DELETE \ --user "$GITEA_USER:$GITEA_TOKEN" \
# --user "$GITEA_USER:$GITEA_TOKEN" \ "https://git.pvv.ntnu.no/api/packages/Projects/debian/pool/$DISTRO_VERSION_NAME/main/$DEB_NAME/$DEB_VERSION/$DEB_ARCH"
# "https://git.pvv.ntnu.no/api/packages/Projects/debian/pool/$DISTRO_VERSION_NAME/main/$DEB_NAME/$DEB_VERSION/$DEB_ARCH"
echo "[PUT] https://git.pvv.ntnu.no/api/packages/Projects/debian/pool/$DISTRO_VERSION_NAME/main/upload"
curl \ curl \
-X PUT \ -X PUT \
--user "$GITEA_USER:$GITEA_TOKEN" \ --user "$GITEA_USER:$GITEA_TOKEN" \
+48 -3
View File
@@ -305,6 +305,10 @@ fn main() -> anyhow::Result<()> {
return Ok(()); return Ok(());
} }
if handle_manpage_command()?.is_some() {
return Ok(());
}
#[cfg(feature = "mysql-admutils-compatibility")] #[cfg(feature = "mysql-admutils-compatibility")]
if handle_mysql_admutils_command()?.is_some() { if handle_mysql_admutils_command()?.is_some() {
return Ok(()); return Ok(());
@@ -319,8 +323,7 @@ fn main() -> anyhow::Result<()> {
#[cfg(not(feature = "suid-sgid-mode"))] #[cfg(not(feature = "suid-sgid-mode"))]
None, None,
args.verbose, args.verbose,
) )?;
.context("Failed to connect to the server")?;
tokio_run_command(args.command, connection)?; tokio_run_command(args.command, connection)?;
@@ -362,6 +365,48 @@ fn handle_dynamic_completion() -> anyhow::Result<Option<()>> {
} }
} }
/// **WARNING:** This function may be run with elevated privileges.
fn handle_manpage_command() -> anyhow::Result<Option<()>> {
let argv1: Option<String> = std::env::args().nth(1);
match argv1.as_deref() {
Some("generate-manpages") => {
#[cfg(feature = "suid-sgid-mode")]
if executing_in_suid_sgid_mode()? {
use muscl_lib::core::bootstrap::drop_privs;
drop_privs()?
}
let output_dir = std::env::args().nth(2).ok_or(anyhow::anyhow!(
"Output directory argument missing for manpage generation"
))?;
let output_dir = PathBuf::from(&output_dir);
if !output_dir.is_dir() {
anyhow::bail!(
"Output directory `{:?}` does not exist or is not a directory",
output_dir,
);
}
let mut roff = clap_mangen::roff::Roff::new();
let man = clap_mangen::Man::new(Args::command());
man.render_title(&mut std::io::stdout())?;
man.render_name_section(&mut std::io::stdout())?;
man.render_synopsis_section(&mut std::io::stdout())?;
man.render_subcommands_section(&mut std::io::stdout())?;
man.render_options_section(&mut std::io::stdout())?;
roff.control("SH", ["VERSION"]);
roff.text([clap_mangen::roff::roman(AFTER_LONG_HELP)]);
roff.to_writer(&mut std::io::stdout())?;
Ok(Some(()))
}
_ => Ok(None),
}
}
/// **WARNING:** This function may be run with elevated privileges. /// **WARNING:** This function may be run with elevated privileges.
fn handle_mysql_admutils_command() -> anyhow::Result<Option<()>> { fn handle_mysql_admutils_command() -> anyhow::Result<Option<()>> {
let argv0 = std::env::args().next().and_then(|s| { let argv0 = std::env::args().next().and_then(|s| {
@@ -377,7 +422,7 @@ fn handle_mysql_admutils_command() -> anyhow::Result<Option<()>> {
} }
} }
/// Run the given command (from the client side) using Tokio. /// Run the given commmand (from the client side) using Tokio.
fn tokio_run_command( fn tokio_run_command(
command: ClientCommand, command: ClientCommand,
server_connection: StdUnixStream, server_connection: StdUnixStream,
+3 -8
View File
@@ -8,7 +8,6 @@ use clap::{Args, Parser};
use clap_complete::ArgValueCompleter; use clap_complete::ArgValueCompleter;
use dialoguer::{Confirm, Editor}; use dialoguer::{Confirm, Editor};
use futures_util::SinkExt; use futures_util::SinkExt;
use itertools::Itertools;
use nix::unistd::{User, getuid}; use nix::unistd::{User, getuid};
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
@@ -204,13 +203,9 @@ pub async fn edit_database_privileges(
} }
}) })
.flatten() .flatten()
.sorted_by_key(|row| (row.db.clone(), row.user.clone()))
.collect::<Vec<_>>(), .collect::<Vec<_>>(),
Some(Ok(Response::ListAllPrivileges(privilege_rows))) => match privilege_rows { Some(Ok(Response::ListAllPrivileges(privilege_rows))) => match privilege_rows {
Ok(list) => list Ok(list) => list,
.into_iter()
.sorted_by_key(|row| (row.db.clone(), row.user.clone()))
.collect(),
Err(err) => { Err(err) => {
server_connection.send(Request::Exit).await?; server_connection.send(Request::Exit).await?;
return Err(anyhow::anyhow!(err.to_error_message()) return Err(anyhow::anyhow!(err.to_error_message())
@@ -310,7 +305,7 @@ pub async fn edit_database_privileges(
print_modify_database_privileges_output_status(&result); print_modify_database_privileges_output_status(&result);
if result.values().flatten().any(|(_, res)| { if result.iter().any(|(_, res)| {
matches!( matches!(
res, res,
Err(ModifyDatabasePrivilegesError::UserValidationError( Err(ModifyDatabasePrivilegesError::UserValidationError(
@@ -325,7 +320,7 @@ pub async fn edit_database_privileges(
server_connection.send(Request::Exit).await?; server_connection.send(Request::Exit).await?;
if result.values().flatten().any(|(_, res)| res.is_err()) { if result.values().any(std::result::Result::is_err) {
std::process::exit(1); std::process::exit(1);
} }
@@ -35,7 +35,7 @@ spawn the editor stored in the $EDITOR environment variable.
(pico will be used if the variable is unset) (pico will be used if the variable is unset)
The file should contain one line per user, starting with the The file should contain one line per user, starting with the
username and followed by ten Y/N-values separated by whitespace. username and followed by ten Y/N-values seperated by whitespace.
Lines starting with # are ignored. Lines starting with # are ignored.
The Y/N-values corresponds to the following mysql privileges: The Y/N-values corresponds to the following mysql privileges:
+3 -29
View File
@@ -1,6 +1,5 @@
use std::{ use std::{
fs, fs,
os::unix::fs::FileTypeExt,
path::{Path, PathBuf}, path::{Path, PathBuf},
sync::Arc, sync::Arc,
time::Duration, time::Duration,
@@ -8,10 +7,7 @@ use std::{
use anyhow::{Context, anyhow}; use anyhow::{Context, anyhow};
use clap_verbosity_flag::{InfoLevel, Verbosity}; use clap_verbosity_flag::{InfoLevel, Verbosity};
use nix::{ use nix::libc::{EXIT_SUCCESS, exit};
libc::{EXIT_SUCCESS, exit},
unistd::{AccessFlags, access},
};
use sqlx::mysql::MySqlPoolOptions; use sqlx::mysql::MySqlPoolOptions;
use std::os::unix::net::UnixStream as StdUnixStream; use std::os::unix::net::UnixStream as StdUnixStream;
use tokio::{net::UnixStream as TokioUnixStream, sync::RwLock}; use tokio::{net::UnixStream as TokioUnixStream, sync::RwLock};
@@ -134,28 +130,11 @@ pub fn bootstrap_server_connection_and_drop_privileges(
} }
} }
fn socket_path_is_ok(path: &Path) -> anyhow::Result<()> {
fs::metadata(path)
.context(format!("Failed to get metadata for {:?}", path))
.and_then(|meta| {
if !meta.file_type().is_socket() {
anyhow::bail!("{:?} is not a unix socket", path);
}
access(path, AccessFlags::R_OK | AccessFlags::W_OK)
.with_context(|| format!("Socket at {:?} is not readable/writable", path))?;
Ok(())
})
}
fn connect_to_external_server( fn connect_to_external_server(
server_socket_path: Option<PathBuf>, server_socket_path: Option<PathBuf>,
) -> anyhow::Result<StdUnixStream> { ) -> anyhow::Result<StdUnixStream> {
// TODO: ensure this is both readable and writable
if let Some(socket_path) = server_socket_path { if let Some(socket_path) = server_socket_path {
tracing::trace!("Checking socket at {:?}", socket_path);
socket_path_is_ok(&socket_path)?;
tracing::debug!("Connecting to socket at {:?}", socket_path); tracing::debug!("Connecting to socket at {:?}", socket_path);
return match StdUnixStream::connect(socket_path) { return match StdUnixStream::connect(socket_path) {
Ok(socket) => Ok(socket), Ok(socket) => Ok(socket),
@@ -168,9 +147,6 @@ fn connect_to_external_server(
} }
if fs::metadata(DEFAULT_SOCKET_PATH).is_ok() { if fs::metadata(DEFAULT_SOCKET_PATH).is_ok() {
tracing::trace!("Checking socket at {:?}", DEFAULT_SOCKET_PATH);
socket_path_is_ok(Path::new(DEFAULT_SOCKET_PATH))?;
tracing::debug!("Connecting to default socket at {:?}", DEFAULT_SOCKET_PATH); tracing::debug!("Connecting to default socket at {:?}", DEFAULT_SOCKET_PATH);
return match StdUnixStream::connect(DEFAULT_SOCKET_PATH) { return match StdUnixStream::connect(DEFAULT_SOCKET_PATH) {
Ok(socket) => Ok(socket), Ok(socket) => Ok(socket),
@@ -182,9 +158,7 @@ fn connect_to_external_server(
}; };
} }
anyhow::bail!( anyhow::bail!("No socket path provided, and no default socket found");
"No socket path provided, and no socket found found at default location {DEFAULT_SOCKET_PATH}"
);
} }
// TODO: this function is security critical, it should be integration tested // TODO: this function is security critical, it should be integration tested
+1 -1
View File
@@ -100,7 +100,7 @@ impl UnixUser {
}) })
} }
// pub fn from_environment() -> anyhow::Result<Self> { // pub fn from_enviroment() -> anyhow::Result<Self> {
// let libc_uid = nix::unistd::getuid(); // let libc_uid = nix::unistd::getuid();
// UnixUser::from_uid(libc_uid.as_raw()) // UnixUser::from_uid(libc_uid.as_raw())
// } // }
+1 -1
View File
@@ -29,7 +29,7 @@ pub const DATABASE_PRIVILEGE_FIELDS: [&str; 13] = [
// doesn't have any natural implementation semantics. // doesn't have any natural implementation semantics.
/// Representation of the set of privileges for a single user on a single database. /// Representation of the set of privileges for a single user on a single database.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord, Default)] #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
pub struct DatabasePrivilegeRow { pub struct DatabasePrivilegeRow {
// TODO: don't store the db and user here, let the type be stored in a mapping // TODO: don't store the db and user here, let the type be stored in a mapping
pub db: MySQLDatabase, pub db: MySQLDatabase,
+3 -6
View File
@@ -324,12 +324,9 @@ impl Response {
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count()) ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
} }
Response::ListAllPrivileges(res) => ResponseOkStatus::from_bool(res.is_ok()), Response::ListAllPrivileges(res) => ResponseOkStatus::from_bool(res.is_ok()),
Response::ModifyPrivileges(res) => ResponseOkStatus::from_counts( Response::ModifyPrivileges(res) => {
res.len(), ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
res.values() }
.map(|user_map| user_map.values().filter(|v| v.is_ok()).count())
.sum(),
),
Response::CreateUsers(res) => { Response::CreateUsers(res) => {
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count()) ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
@@ -67,43 +67,3 @@ impl CheckAuthorizationError {
self.0.error_type() self.0.error_type()
} }
} }
#[cfg(test)]
mod tests {
use crate::core::protocol::request_validation::NameValidationError;
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request: CheckAuthorizationRequest = vec![
DbOrUser::Database("test_db".into()),
DbOrUser::User("test_user".into()),
];
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: CheckAuthorizationRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response: CheckAuthorizationResponse = BTreeMap::from([
(DbOrUser::Database("test_db".into()), Ok(())),
(
DbOrUser::User("test_user".into()),
Err(CheckAuthorizationError(
ValidationError::NameValidationError(NameValidationError::TooLong),
)),
),
]);
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: CheckAuthorizationResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
}
@@ -87,41 +87,3 @@ impl CreateDatabaseError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request: CreateDatabasesRequest =
vec!["test_db1".into(), "test_db2".into(), "test_db3".into()];
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: CreateDatabasesRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response: CreateDatabasesResponse = BTreeMap::from([
("test_db1".into(), Ok(())),
(
"test_db2".into(),
Err(CreateDatabaseError::DatabaseAlreadyExists),
),
(
"test_db3".into(),
Err(CreateDatabaseError::MySqlError("Some MySQL error".into())),
),
]);
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: CreateDatabasesResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
}
@@ -87,37 +87,3 @@ impl CreateUserError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request: CreateUsersRequest = vec!["alice".into(), "bob".into(), "charlie".into()];
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: CreateUsersRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response: CreateUsersResponse = BTreeMap::from([
("alice".into(), Ok(())),
("bob".into(), Err(CreateUserError::UserAlreadyExists)),
(
"charlie".into(),
Err(CreateUserError::MySqlError("Some MySQL error".into())),
),
]);
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: CreateUsersResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
}
@@ -90,37 +90,3 @@ impl DropDatabaseError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request: DropDatabasesRequest = vec!["db1".into(), "db2".into(), "db3".into()];
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: DropDatabasesRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response: DropDatabasesResponse = BTreeMap::from([
("db1".into(), Ok(())),
("db2".into(), Err(DropDatabaseError::DatabaseDoesNotExist)),
(
"db3".into(),
Err(DropDatabaseError::MySqlError("Some MySQL error".into())),
),
]);
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: DropDatabasesResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
}
-34
View File
@@ -87,37 +87,3 @@ impl DropUserError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request: DropUsersRequest = vec!["alice".into(), "bob".into(), "charlie".into()];
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: DropUsersRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response: DropUsersResponse = BTreeMap::from([
("alice".into(), Ok(())),
("bob".into(), Err(DropUserError::UserDoesNotExist)),
(
"charlie".into(),
Err(DropUserError::MySqlError("Some MySQL error".into())),
),
]);
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: DropUsersResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
}
@@ -27,36 +27,3 @@ impl ListAllDatabasesError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_response() {
let response: ListAllDatabasesResponse = Ok(vec![
DatabaseRow {
database: "db1".into(),
tables: vec!["table1".into(), "table2".into()],
users: vec!["user1".into(), "user2".into()],
collation: Some("utf8mb4_general_ci".into()),
character_set: Some("utf8mb4".into()),
size_bytes: 1024,
},
DatabaseRow {
database: "db2".into(),
tables: vec!["table3".into(), "table4".into()],
users: vec!["user3".into(), "user4".into()],
collation: Some("utf8mb4_general_ci".into()),
character_set: Some("utf8mb4".into()),
size_bytes: 2048,
},
]);
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: ListAllDatabasesResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
}
@@ -27,34 +27,3 @@ impl ListAllPrivilegesError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_response() {
let response: ListAllPrivilegesResponse = Ok(vec![
DatabasePrivilegeRow {
user: "user1".into(),
db: "db1".into(),
select_priv: true,
insert_priv: false,
..Default::default()
},
DatabasePrivilegeRow {
user: "user2".into(),
db: "db2".into(),
select_priv: false,
insert_priv: true,
..Default::default()
},
]);
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: ListAllPrivilegesResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
}
@@ -27,37 +27,3 @@ impl ListAllUsersError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_response() {
let response: ListAllUsersResponse = Ok(vec![
DatabaseUser {
user: "user1".into(),
host: "%".into(),
has_password: true,
is_locked: false,
databases: vec!["db1".into(), "db2".into()],
},
DatabaseUser {
user: "user2".into(),
host: "%".into(),
has_password: false,
is_locked: true,
databases: vec!["db3".into()],
},
]);
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let mut deserialized: ListAllUsersResponse = serde_json::from_str(&json).unwrap();
deserialized.as_mut().unwrap()[0].host = "%".into();
deserialized.as_mut().unwrap()[1].host = "%".into();
assert_eq!(response, deserialized);
}
}
+1 -42
View File
@@ -61,7 +61,7 @@ pub fn print_list_databases_output_status(
"Size" "Size"
} }
]); ]);
for db in final_database_list.iter().sorted_by_key(|db| &db.database) { for db in final_database_list {
table.add_row(row![ table.add_row(row![
db.database, db.database,
db.tables.join("\n"), db.tables.join("\n"),
@@ -137,44 +137,3 @@ impl ListDatabasesError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request = Some(vec!["db1".into(), "db2".into()]);
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: ListDatabasesRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response: ListDatabasesResponse = vec![
(
"db1".into(),
Ok(DatabaseRow {
database: "db1".into(),
tables: vec!["table1".to_string(), "table2".to_string()],
users: vec!["user1".into(), "user2".into()],
collation: Some("utf8mb4_general_ci".to_string()),
character_set: Some("utf8mb4".to_string()),
size_bytes: 1024,
}),
),
("db2".into(), Err(ListDatabasesError::DatabaseDoesNotExist)),
]
.into_iter()
.collect();
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: ListDatabasesResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
}
+18 -62
View File
@@ -64,28 +64,25 @@ pub fn print_list_privileges_output_status(output: &ListPrivilegesResponse, long
.collect(), .collect(),
)); ));
for row in final_privs_map for (_database, rows) in final_privs_map {
.values() for row in &rows {
.flatten() table.add_row(row![
.sorted_by_key(|row| (&row.db, &row.user)) row.db,
{ row.user,
table.add_row(row![ c->yn(row.select_priv),
row.db, c->yn(row.insert_priv),
row.user, c->yn(row.update_priv),
c->yn(row.select_priv), c->yn(row.delete_priv),
c->yn(row.insert_priv), c->yn(row.create_priv),
c->yn(row.update_priv), c->yn(row.drop_priv),
c->yn(row.delete_priv), c->yn(row.alter_priv),
c->yn(row.create_priv), c->yn(row.index_priv),
c->yn(row.drop_priv), c->yn(row.create_tmp_table_priv),
c->yn(row.alter_priv), c->yn(row.lock_tables_priv),
c->yn(row.index_priv), c->yn(row.references_priv),
c->yn(row.create_tmp_table_priv), ]);
c->yn(row.lock_tables_priv), }
c->yn(row.references_priv),
]);
} }
// }
table.printstd(); table.printstd();
} }
@@ -156,44 +153,3 @@ impl ListPrivilegesError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request: ListPrivilegesRequest = Some(vec!["test_db1".into(), "test_db2".into()]);
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: ListPrivilegesRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response: ListPrivilegesResponse = BTreeMap::from([
(
"test_db1".into(),
Ok(vec![DatabasePrivilegeRow {
db: "test_db1".into(),
user: "user1".into(),
select_priv: true,
insert_priv: false,
..Default::default()
}]),
),
(
"test_db2".into(),
Err(ListPrivilegesError::DatabaseDoesNotExist),
),
]);
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: ListPrivilegesResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
}
+1 -47
View File
@@ -1,6 +1,5 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use itertools::Itertools;
use prettytable::Table; use prettytable::Table;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::json; use serde_json::json;
@@ -52,7 +51,7 @@ pub fn print_list_users_output_status(output: &ListUsersResponse) {
"Locked", "Locked",
"Databases where user has privileges" "Databases where user has privileges"
]); ]);
for user in final_user_list.iter().sorted_by_key(|user| &user.user) { for user in final_user_list {
table.add_row(row![ table.add_row(row![
user.user, user.user,
user.has_password, user.has_password,
@@ -122,48 +121,3 @@ impl ListUsersError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request: ListUsersRequest = Some(vec!["test_user1".into(), "test_user2".into()]);
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: ListUsersRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response_ok: ListUsersResponse = BTreeMap::from([
(
"test_user1".into(),
Ok(DatabaseUser {
user: "test_user1".into(),
host: "%".into(),
has_password: true,
is_locked: false,
databases: vec!["db1".into(), "db2".into()],
}),
),
("test_user2".into(), Err(ListUsersError::UserDoesNotExist)),
]);
let json = serde_json::to_string_pretty(&response_ok).unwrap();
println!("Serialized response:\n{}", json);
let mut deserialized: ListUsersResponse = serde_json::from_str(&json).unwrap();
deserialized
.get_mut(&"test_user1".into())
.unwrap()
.as_mut()
.unwrap()
.host = "%".into();
assert_eq!(response_ok, deserialized);
}
}
-30
View File
@@ -94,33 +94,3 @@ impl LockUserError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request: LockUsersRequest = vec!["test_user1".into(), "test_user2".into()];
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: LockUsersRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response_ok: LockUsersResponse = BTreeMap::from([
("test_user1".into(), Ok(())),
("test_user2".into(), Err(LockUserError::UserDoesNotExist)),
]);
let json = serde_json::to_string_pretty(&response_ok).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: LockUsersResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response_ok, deserialized);
}
}
@@ -12,7 +12,7 @@ use crate::core::{
pub type ModifyPrivilegesRequest = BTreeSet<DatabasePrivilegesDiff>; pub type ModifyPrivilegesRequest = BTreeSet<DatabasePrivilegesDiff>;
pub type ModifyPrivilegesResponse = pub type ModifyPrivilegesResponse =
BTreeMap<MySQLDatabase, BTreeMap<MySQLUser, Result<(), ModifyDatabasePrivilegesError>>>; BTreeMap<(MySQLDatabase, MySQLUser), Result<(), ModifyDatabasePrivilegesError>>;
#[derive(Error, Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Error, Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ModifyDatabasePrivilegesError { pub enum ModifyDatabasePrivilegesError {
@@ -49,11 +49,7 @@ pub enum DiffDoesNotApplyError {
} }
pub fn print_modify_database_privileges_output_status(output: &ModifyPrivilegesResponse) { pub fn print_modify_database_privileges_output_status(output: &ModifyPrivilegesResponse) {
for ((database_name, username), result) in output.iter().flat_map(|(db, user_map)| { for ((database_name, username), result) in output {
user_map
.iter()
.map(move |(user, result)| ((db, user), result))
}) {
match result { match result {
Ok(()) => { Ok(()) => {
println!( println!(
@@ -148,46 +144,3 @@ impl DiffDoesNotApplyError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
use crate::core::*;
#[test]
fn test_serialize_deserialize_request() {
let request =
BTreeSet::from([DatabasePrivilegesDiff::Modified(DatabasePrivilegeRowDiff {
db: "test_db".into(),
user: "test_user".into(),
select_priv: Some(database_privileges::DatabasePrivilegeChange::NoToYes),
..Default::default()
})]);
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: ModifyPrivilegesRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response: ModifyPrivilegesResponse = BTreeMap::from([(
"test_db".into(),
BTreeMap::from([
("test_user".into(), Ok(())),
(
"invalid_user".into(),
Err(ModifyDatabasePrivilegesError::UserDoesNotExist),
),
]),
)]);
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: ModifyPrivilegesResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
}
-32
View File
@@ -60,35 +60,3 @@ impl SetPasswordError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request: SetUserPasswordRequest = ("test_user".into(), "new_password".into());
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: SetUserPasswordRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response_ok: SetUserPasswordResponse = Ok(());
let response_err: SetUserPasswordResponse = Err(SetPasswordError::UserDoesNotExist);
let json_ok = serde_json::to_string_pretty(&response_ok).unwrap();
let json_err = serde_json::to_string_pretty(&response_err).unwrap();
println!("Serialized OK response:\n{}", json_ok);
println!("Serialized Error response:\n{}", json_err);
let deserialized_ok: SetUserPasswordResponse = serde_json::from_str(&json_ok).unwrap();
let deserialized_err: SetUserPasswordResponse = serde_json::from_str(&json_err).unwrap();
assert_eq!(response_ok, deserialized_ok);
assert_eq!(response_err, deserialized_err);
}
}
@@ -94,33 +94,3 @@ impl UnlockUserError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request: UnlockUsersRequest = vec!["test_user1".into(), "test_user2".into()];
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: UnlockUsersRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response_ok: UnlockUsersResponse = BTreeMap::from([
("test_user1".into(), Ok(())),
("test_user2".into(), Err(UnlockUserError::UserDoesNotExist)),
]);
let json = serde_json::to_string_pretty(&response_ok).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: UnlockUsersResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response_ok, deserialized);
}
}
+3 -34
View File
@@ -34,7 +34,7 @@ impl DerefMut for MySQLUser {
impl fmt::Display for MySQLUser { impl fmt::Display for MySQLUser {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f) write!(f, "{:<width$}", self.0, width = f.width().unwrap_or(0))
} }
} }
@@ -83,7 +83,7 @@ impl DerefMut for MySQLDatabase {
impl fmt::Display for MySQLDatabase { impl fmt::Display for MySQLDatabase {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f) write!(f, "{:<width$}", self.0, width = f.width().unwrap_or(0))
} }
} }
@@ -105,43 +105,12 @@ impl From<MySQLDatabase> for OsString {
} }
} }
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum DbOrUser { pub enum DbOrUser {
Database(MySQLDatabase), Database(MySQLDatabase),
User(MySQLUser), User(MySQLUser),
} }
impl Serialize for DbOrUser {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
match self {
DbOrUser::Database(db) => ("d:".to_string() + &db.to_string()).serialize(serializer),
DbOrUser::User(user) => ("u:".to_string() + &user.to_string()).serialize(serializer),
}
}
}
impl<'de> Deserialize<'de> for DbOrUser {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
if let Some(rest) = s.strip_prefix("d:") {
Ok(DbOrUser::Database(MySQLDatabase(rest.to_string())))
} else if let Some(rest) = s.strip_prefix("u:") {
Ok(DbOrUser::User(MySQLUser(rest.to_string())))
} else {
Err(serde::de::Error::custom(format!(
"Invalid DbOrUser format: {}",
s
)))
}
}
}
impl DbOrUser { impl DbOrUser {
#[must_use] #[must_use]
pub fn lowercased_noun(&self) -> &'static str { pub fn lowercased_noun(&self) -> &'static str {
-4
View File
@@ -59,10 +59,6 @@ fn parse_group_denylist(denylist_path: &Path, lines: Lines) -> GroupDenylist {
} }
.trim(); .trim();
if trimmed_line.is_empty() {
continue;
}
let parts: Vec<&str> = trimmed_line.splitn(2, ':').collect(); let parts: Vec<&str> = trimmed_line.splitn(2, ':').collect();
if parts.len() != 2 { if parts.len() != 2 {
tracing::warn!( tracing::warn!(
+22 -38
View File
@@ -114,7 +114,7 @@ pub async fn unsafe_get_database_privileges_for_db_user_pair(
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
) -> Result<Option<DatabasePrivilegeRow>, sqlx::Error> { ) -> Result<Option<DatabasePrivilegeRow>, sqlx::Error> {
let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!( let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
"SELECT {} FROM `db` WHERE `Db` = ? AND `User` = ? AND `Host` = '%'", "SELECT {} FROM `db` WHERE `Db` = ? AND `User` = ?",
DATABASE_PRIVILEGE_FIELDS DATABASE_PRIVILEGE_FIELDS
.iter() .iter()
.map(|field| quote_identifier(field)) .map(|field| quote_identifier(field))
@@ -234,12 +234,11 @@ async fn unsafe_apply_privilege_diff(
DatabasePrivilegesDiff::New(p) => { DatabasePrivilegesDiff::New(p) => {
let tables = DATABASE_PRIVILEGE_FIELDS let tables = DATABASE_PRIVILEGE_FIELDS
.iter() .iter()
.chain(&["Host"])
.map(|field| quote_identifier(field)) .map(|field| quote_identifier(field))
.join(","); .join(",");
let question_marks = let question_marks =
std::iter::repeat_n("?", DATABASE_PRIVILEGE_FIELDS.len() + 1).join(","); std::iter::repeat_n("?", DATABASE_PRIVILEGE_FIELDS.len()).join(",");
sqlx::query(format!("INSERT INTO `db` ({tables}) VALUES ({question_marks})").as_str()) sqlx::query(format!("INSERT INTO `db` ({tables}) VALUES ({question_marks})").as_str())
.bind(p.db.to_string()) .bind(p.db.to_string())
@@ -255,7 +254,6 @@ async fn unsafe_apply_privilege_diff(
.bind(yn(p.create_tmp_table_priv)) .bind(yn(p.create_tmp_table_priv))
.bind(yn(p.lock_tables_priv)) .bind(yn(p.lock_tables_priv))
.bind(yn(p.references_priv)) .bind(yn(p.references_priv))
.bind("%")
.execute(connection) .execute(connection)
.await .await
.map(|_| ()) .map(|_| ())
@@ -280,33 +278,28 @@ async fn unsafe_apply_privilege_diff(
} }
} }
sqlx::query( sqlx::query(format!("UPDATE `db` SET {changes} WHERE `Db` = ? AND `User` = ?").as_str())
format!("UPDATE `db` SET {changes} WHERE `Db` = ? AND `User` = ? AND `Host` = ?") .bind(p.select_priv.map(change_to_yn))
.as_str(), .bind(p.insert_priv.map(change_to_yn))
) .bind(p.update_priv.map(change_to_yn))
.bind(p.select_priv.map(change_to_yn)) .bind(p.delete_priv.map(change_to_yn))
.bind(p.insert_priv.map(change_to_yn)) .bind(p.create_priv.map(change_to_yn))
.bind(p.update_priv.map(change_to_yn)) .bind(p.drop_priv.map(change_to_yn))
.bind(p.delete_priv.map(change_to_yn)) .bind(p.alter_priv.map(change_to_yn))
.bind(p.create_priv.map(change_to_yn)) .bind(p.index_priv.map(change_to_yn))
.bind(p.drop_priv.map(change_to_yn)) .bind(p.create_tmp_table_priv.map(change_to_yn))
.bind(p.alter_priv.map(change_to_yn)) .bind(p.lock_tables_priv.map(change_to_yn))
.bind(p.index_priv.map(change_to_yn)) .bind(p.references_priv.map(change_to_yn))
.bind(p.create_tmp_table_priv.map(change_to_yn)) .bind(p.db.to_string())
.bind(p.lock_tables_priv.map(change_to_yn)) .bind(p.user.to_string())
.bind(p.references_priv.map(change_to_yn)) .execute(connection)
.bind(p.db.to_string()) .await
.bind(p.user.to_string()) .map(|_| ())
.bind("%") }
.execute(connection) DatabasePrivilegesDiff::Deleted(p) => {
.await sqlx::query("DELETE FROM `db` WHERE `Db` = ? AND `User` = ?")
.map(|_| ())
}
DatabasePrivilegesDiff::Deleted(p) => {
sqlx::query("DELETE FROM `db` WHERE `Db` = ? AND `User` = ? AND `Host` = ?")
.bind(p.db.to_string()) .bind(p.db.to_string())
.bind(p.user.to_string()) .bind(p.user.to_string())
.bind("%")
.execute(connection) .execute(connection)
.await .await
.map(|_| ()) .map(|_| ())
@@ -488,13 +481,4 @@ pub async fn apply_privilege_diffs(
} }
results results
.into_iter()
.map(|((k1, k2), v)| (k1, (k2, v)))
.into_group_map()
.into_iter()
.map(|(k1, pairs)| {
let inner = pairs.into_iter().collect::<BTreeMap<_, _>>();
(k1, inner)
})
.collect()
} }
+3 -5
View File
@@ -39,7 +39,6 @@ pub(super) async fn unsafe_user_exists(
SELECT 1 SELECT 1
FROM `mysql`.`user` FROM `mysql`.`user`
WHERE `User` = ? WHERE `User` = ?
AND `Host` = '%'
) )
", ",
) )
@@ -68,7 +67,6 @@ pub async fn complete_user_name(
FROM `mysql`.`user` FROM `mysql`.`user`
WHERE `User` REGEXP ? WHERE `User` REGEXP ?
AND `User` LIKE ? AND `User` LIKE ?
AND `Host` = '%'
", ",
) )
.bind(create_user_group_matching_regex(unix_user, group_denylist)) .bind(create_user_group_matching_regex(unix_user, group_denylist))
@@ -464,7 +462,7 @@ pub async fn list_database_users(
DB_USER_SELECT_STATEMENT_MARIADB.to_string() DB_USER_SELECT_STATEMENT_MARIADB.to_string()
} else { } else {
DB_USER_SELECT_STATEMENT_MYSQL.to_string() DB_USER_SELECT_STATEMENT_MYSQL.to_string()
} + "WHERE `mysql`.`user`.`User` = ? AND `mysql`.`user`.`Host` = '%'"), } + "WHERE `mysql`.`user`.`User` = ?"),
) )
.bind(db_user.as_str()) .bind(db_user.as_str())
.fetch_optional(&mut *connection) .fetch_optional(&mut *connection)
@@ -501,7 +499,7 @@ pub async fn list_all_database_users_for_unix_user(
DB_USER_SELECT_STATEMENT_MARIADB.to_string() DB_USER_SELECT_STATEMENT_MARIADB.to_string()
} else { } else {
DB_USER_SELECT_STATEMENT_MYSQL.to_string() DB_USER_SELECT_STATEMENT_MYSQL.to_string()
} + "WHERE `user`.`User` REGEXP ? AND `user`.`Host` = '%'"), } + "WHERE `user`.`User` REGEXP ?"),
) )
.bind(create_user_group_matching_regex(unix_user, group_denylist)) .bind(create_user_group_matching_regex(unix_user, group_denylist))
.fetch_all(&mut *connection) .fetch_all(&mut *connection)
@@ -536,7 +534,7 @@ pub async fn set_databases_where_user_has_privileges(
r" r"
SELECT `Db` AS `database` SELECT `Db` AS `database`
FROM `db` FROM `db`
WHERE `User` = ? AND `Host` = '%' AND ({}) WHERE `User` = ? AND ({})
", ",
DATABASE_PRIVILEGE_FIELDS DATABASE_PRIVILEGE_FIELDS
.iter() .iter()
+13 -14
View File
@@ -90,12 +90,14 @@ impl Supervisor {
}; };
let mut watchdog_duration = None; let mut watchdog_duration = None;
let mut watchdog_micro_seconds = 0;
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
let watchdog_task = let watchdog_task =
if systemd_mode && let Some(watchdog_duration_) = sd_notify::watchdog_enabled() { if systemd_mode && sd_notify::watchdog_enabled(true, &mut watchdog_micro_seconds) {
let watchdog_duration_ = Duration::from_micros(watchdog_micro_seconds);
tracing::debug!( tracing::debug!(
"Systemd watchdog enabled with {} millisecond interval", "Systemd watchdog enabled with {} millisecond interval",
watchdog_duration_.as_millis() watchdog_micro_seconds.div_ceil(1000),
); );
watchdog_duration = Some(watchdog_duration_); watchdog_duration = Some(watchdog_duration_);
Some(spawn_watchdog_task(watchdog_duration_)) Some(spawn_watchdog_task(watchdog_duration_))
@@ -138,7 +140,7 @@ impl Supervisor {
let (tx, rx) = broadcast::channel(1); let (tx, rx) = broadcast::channel(1);
// TODO: try to detect systemd socket before using the provided socket path // TODO: try to detech systemd socket before using the provided socket path
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
let listener = Arc::new(RwLock::new(match config.socket_path { let listener = Arc::new(RwLock::new(match config.socket_path {
Some(ref path) => create_unix_listener_with_socket_path(path.clone()).await?, Some(ref path) => create_unix_listener_with_socket_path(path.clone()).await?,
@@ -293,12 +295,7 @@ impl Supervisor {
pub async fn reload(&self) -> anyhow::Result<()> { pub async fn reload(&self) -> anyhow::Result<()> {
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
sd_notify::notify(&[ sd_notify::notify(false, &[sd_notify::NotifyState::Reloading])?;
sd_notify::NotifyState::Reloading,
sd_notify::NotifyState::monotonic_usec_now()
.expect("Failed to get monotonic time to send to systemd while reloading"),
sd_notify::NotifyState::Status("Reloading configuration"),
])?;
let previous_config = self.config.lock().await.clone(); let previous_config = self.config.lock().await.clone();
self.reload_config().await?; self.reload_config().await?;
@@ -335,14 +332,14 @@ impl Supervisor {
} }
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
sd_notify::notify(&[sd_notify::NotifyState::Ready])?; sd_notify::notify(false, &[sd_notify::NotifyState::Ready])?;
Ok(()) Ok(())
} }
pub async fn shutdown(&self) -> anyhow::Result<()> { pub async fn shutdown(&self) -> anyhow::Result<()> {
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
sd_notify::notify(&[sd_notify::NotifyState::Stopping])?; sd_notify::notify(false, &[sd_notify::NotifyState::Stopping])?;
tracing::debug!("Stop accepting new connections"); tracing::debug!("Stop accepting new connections");
self.stop_receiving_new_connections()?; self.stop_receiving_new_connections()?;
@@ -412,7 +409,7 @@ fn spawn_watchdog_task(duration: Duration) -> JoinHandle<()> {
); );
loop { loop {
interval.tick().await; interval.tick().await;
if let Err(err) = sd_notify::notify(&[sd_notify::NotifyState::Watchdog]) { if let Err(err) = sd_notify::notify(false, &[sd_notify::NotifyState::Watchdog]) {
tracing::warn!("Failed to notify systemd watchdog: {}", err); tracing::warn!("Failed to notify systemd watchdog: {}", err);
} }
} }
@@ -435,7 +432,9 @@ fn spawn_status_notifier_task(task_tracker: TaskTracker) -> JoinHandle<()> {
"Waiting for connections".to_string() "Waiting for connections".to_string()
}; };
if let Err(e) = sd_notify::notify(&[sd_notify::NotifyState::Status(message.as_str())]) { if let Err(e) =
sd_notify::notify(false, &[sd_notify::NotifyState::Status(message.as_str())])
{
tracing::warn!("Failed to send systemd status notification: {}", e); tracing::warn!("Failed to send systemd status notification: {}", e);
} }
} }
@@ -550,7 +549,7 @@ async fn listener_task(
group_denylist: Arc<RwLock<GroupDenylist>>, group_denylist: Arc<RwLock<GroupDenylist>>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
sd_notify::notify(&[sd_notify::NotifyState::Ready])?; sd_notify::notify(false, &[sd_notify::NotifyState::Ready])?;
let connection_counter = AtomicU64::new(0); let connection_counter = AtomicU64::new(0);