23 Commits

Author SHA1 Message Date
oysteikt 43e4cc45ca server/sql: great performance improvements for listing databases
Build and test / check-license (push) Successful in 53s
Build and test / check (push) Successful in 1m54s
Build and test / build (push) Successful in 3m5s
Build and test / test (push) Successful in 3m9s
Build and test / docs (push) Successful in 5m36s
2026-05-31 02:01:44 +09:00
oysteikt 62b1b66bb6 CHANGELOG.md: fix broken link
Build and test / check-license (push) Successful in 55s
Build and test / check (push) Successful in 2m38s
Build and test / build (push) Successful in 2m42s
Build and test / test (push) Successful in 3m8s
Build and test / docs (push) Successful in 6m8s
2026-05-31 00:44:52 +09:00
oysteikt f16239aceb server/sql: fixes for new sqlx crate version
Build and test / check-license (push) Successful in 49s
Build and test / check (push) Successful in 1m51s
Build and test / build (push) Successful in 2m42s
Build and test / test (push) Successful in 5m2s
Build and test / docs (push) Successful in 7m6s
2026-05-31 00:24:53 +09:00
oysteikt 8f475eced1 CHANGELOG.md: add release notes, Cargo.toml: bump version number
Build and test / check-license (push) Successful in 49s
Build and test / check (push) Failing after 1m57s
Build and test / test (push) Failing after 2m53s
Build and test / build (push) Failing after 3m11s
Build and test / docs (push) Failing after 4m15s
2026-05-31 00:09:49 +09:00
oysteikt 6849e99c11 flake.lock: bump, Cargo.{toml,lock}: update inputs 2026-05-31 00:09:40 +09:00
oysteikt 759df9ef42 server/sql: flush privileges after modification
Build and test / check-license (push) Successful in 54s
Build and test / check (push) Successful in 1m42s
Build and test / test (push) Successful in 3m34s
Build and test / build (push) Successful in 3m39s
Build and test / docs (push) Successful in 5m33s
2026-04-28 19:10:16 +09:00
oysteikt a64d1fa1bf scripts/download-and-upload-debs: fix download path
Build and test / check-license (push) Successful in 47s
Build and test / check (push) Successful in 2m20s
Build and test / build (push) Successful in 3m6s
Build and test / test (push) Successful in 3m12s
Build and test / docs (push) Successful in 6m20s
2026-04-28 18:32:28 +09:00
oysteikt 6404e5011a CHANGELOG.md: add release notes, Cargo.toml: bump version number
Build and test / check-license (push) Successful in 52s
Build and test / check (push) Successful in 1m44s
Build and test / build (push) Successful in 3m1s
Build and test / test (push) Successful in 3m12s
Build and test / docs (push) Successful in 7m24s
2026-04-28 18:14:48 +09:00
oysteikt 531fdfc2e9 Cargo.{toml,lock}: bump deps 2026-04-28 18:14:19 +09:00
oysteikt af74e8e540 .gitea/workflows/publish-deb: temporarily disable ubuntu resolute
Build and test / check-license (push) Successful in 51s
Build and test / check (push) Successful in 1m47s
Build and test / build (push) Successful in 2m44s
Build and test / test (push) Successful in 3m31s
Build and test / docs (push) Successful in 5m31s
2026-04-28 18:02:34 +09:00
oysteikt 4132fb58e8 client/various: sort output
Build and test / check (push) Successful in 1m54s
Build and test / check-license (push) Successful in 2m11s
Build and test / build (push) Successful in 2m47s
Build and test / test (push) Successful in 3m13s
Build and test / docs (push) Successful in 6m12s
2026-04-28 17:58:17 +09:00
oysteikt 40c7a935b3 assets/systemd: always restart service when it dies
Build and test / check-license (push) Successful in 46s
Build and test / check (push) Successful in 1m53s
Build and test / build (push) Successful in 3m20s
Build and test / test (push) Successful in 3m33s
Build and test / docs (push) Successful in 5m41s
2026-04-28 17:34:25 +09:00
oysteikt 5aca2314c4 core/protocol: make ModifyPrivileges response serializable
Build and test / check-license (push) Successful in 46s
Build and test / check (push) Successful in 2m25s
Build and test / build (push) Successful in 3m0s
Build and test / test (push) Successful in 3m34s
Build and test / docs (push) Successful in 5m55s
2026-04-28 17:27:40 +09:00
oysteikt 7a9b233611 core/types: add custom de/serialization for DbOrUser
Build and test / check (push) Successful in 1m46s
Build and test / check-license (push) Successful in 2m5s
Build and test / test (push) Failing after 2m54s
Build and test / build (push) Successful in 3m0s
Build and test / docs (push) Successful in 5m20s
2026-04-28 07:45:46 +09:00
oysteikt 5444ab46ca core/protocol: test de/serialization of all protocol messages
Build and test / check (push) Successful in 1m42s
Build and test / check-license (push) Successful in 57s
Build and test / build (push) Successful in 2m40s
Build and test / test (push) Failing after 3m45s
Build and test / docs (push) Successful in 5m23s
2026-04-28 07:30:08 +09:00
oysteikt b12acbf3b4 .gitea/workflows/publish-deb: remove double file extension 2026-04-28 07:14:07 +09:00
oysteikt 8e2aace9d4 server: specify Host for all relevant sql queries 2026-04-28 07:14:06 +09:00
oysteikt 913aad5758 .gitea/workflows/publish-deb: build for ubuntu resolute
Build and test / check-license (push) Successful in 1m26s
Build and test / check (push) Successful in 2m27s
Build and test / build (push) Successful in 3m33s
Build and test / test (push) Successful in 3m15s
Build and test / docs (push) Successful in 6m4s
2026-04-24 05:14:46 +09:00
oysteikt 1d4a19c299 core/types: better fmt::Display implementation for newtypes
Build and test / check-license (push) Successful in 58s
Build and test / check (push) Successful in 2m0s
Build and test / build (push) Successful in 2m50s
Build and test / test (push) Successful in 3m24s
Build and test / docs (push) Successful in 6m53s
2026-04-15 05:09:49 +09:00
oysteikt 9b279a4956 flake.lock: bump, Cargo.{toml,lock}: update inputs
Build and test / check-license (push) Successful in 1m7s
Build and test / check (push) Successful in 1m47s
Build and test / build (push) Successful in 2m48s
Build and test / test (push) Successful in 3m48s
Build and test / docs (push) Successful in 6m46s
2026-04-02 14:00:30 +09:00
oysteikt 124cf9e69e nix/package: fix license meta field
Build and test / check-license (push) Successful in 55s
Build and test / check (push) Successful in 1m53s
Build and test / build (push) Successful in 3m7s
Build and test / test (push) Successful in 4m55s
Build and test / docs (push) Successful in 5m42s
2026-02-12 11:28:14 +09:00
oysteikt 3fe6a3edea flake.lock: bump, Cargo.{toml,lock}: update inputs
Build and test / check (push) Successful in 1m50s
Build and test / check-license (push) Successful in 2m4s
Build and test / build (push) Successful in 3m8s
Build and test / test (push) Successful in 4m7s
Build and test / docs (push) Successful in 6m14s
2026-01-31 12:22:53 +09:00
oysteikt 65e02192dd scripts/download-and-publish-debs: comment out DEL request
Build and test / check-license (push) Successful in 57s
Build and test / check (push) Successful in 2m2s
Build and test / build (push) Successful in 3m43s
Build and test / test (push) Successful in 3m21s
Build and test / docs (push) Successful in 6m22s
This is not a good thing to do now that we have published a stable
version.
2026-01-14 00:48:51 +09:00
32 changed files with 1597 additions and 860 deletions
+8 -2
View File
@@ -31,7 +31,13 @@ jobs:
build-deb: build-deb:
strategy: strategy:
matrix: matrix:
os: [debian-trixie, debian-bookworm, ubuntu-noble, ubuntu-jammy] os: [
debian-trixie,
debian-bookworm,
# ubuntu-resolute,
ubuntu-noble,
ubuntu-jammy,
]
name: Build and publish for ${{ matrix.os }} name: Build and publish for ${{ matrix.os }}
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
steps: steps:
@@ -63,7 +69,7 @@ jobs:
- name: Upload deb package artifact - name: Upload deb package artifact
uses: actions/upload-artifact@v3 uses: actions/upload-artifact@v3
with: with:
name: muscl-deb-${{ matrix.os }}-${{ gitea.sha }}.zip name: muscl-deb-${{ matrix.os }}-${{ gitea.sha }}
path: target/debian/*.deb path: target/debian/*.deb
if-no-files-found: error if-no-files-found: error
retention-days: 30 retention-days: 30
+23 -1
View File
@@ -1,5 +1,27 @@
# Changelog # Changelog
## v1.0.2
Patch release with an important bug fix
### Notable changes
- Run `FLUSH PRIVILEGES` on the server whenever users modify privileges.
- You will have to grant `RELOAD` for the muscl admin user on all databases, see the [installation docs](./docs/installation.md) for details.
- Bump dependencies
## v1.0.1
Patch release with some important bug fixes
### Notable changes
- `mysql.db.Host` would usually be unset when creating privileges for users, this should be fixed now.
- You might have to manually set this field for rows created with the previous version of muscl to have those privileges work properly.
- Fixed an issue where a few select server responses would refuse to serialize properly, leading to an error message: "No response from server"
- The output of various commands is now being sorted.
- Bump dependencies
## v1.0.0 - Initial Release ## v1.0.0 - Initial Release
This is the initial release of `muscl`. This is the initial release of `muscl`.
@@ -54,7 +76,7 @@ This is the initial release of `muscl`.
interactive tool, there shouldn't have been any scripts relying on the old formatting. interactive tool, there shouldn't have been any scripts relying on the old formatting.
- The configuration file is shared for all variants of the program, and `muscl` will use - The configuration file is shared for all variants of the program, and `muscl` will use
its new logic to look for and parse this file. See the example config and its new logic to look for and parse this file. See the example config and
[installation instructions][installation-instructions] for more information about how to [installation instructions](./docs/installation.md) for more information about how to
configure the software. configure the software.
- The order in which input is validated might be differ from the original - The order in which input is validated might be differ from the original
(e.g. database ownership checks, invalid character checks, existence checks, ...). (e.g. database ownership checks, invalid character checks, existence checks, ...).
Generated
+665 -639
View File
File diff suppressed because it is too large Load Diff
+21 -21
View File
@@ -1,6 +1,6 @@
[package] [package]
name = "muscl" name = "muscl"
version = "1.0.0" version = "1.0.2"
edition = "2024" edition = "2024"
resolver = "2" resolver = "2"
license = "BSD-3-Clause" license = "BSD-3-Clause"
@@ -18,50 +18,50 @@ autobins = false
autolib = false autolib = false
[dependencies] [dependencies]
anyhow = "1.0.100" anyhow = "1.0.102"
async-bincode = "0.8.0" async-bincode = "0.8.0"
bincode = "2.0.1" bincode = "2.0.1"
clap = { version = "4.5.54", features = ["cargo", "derive"] } clap = { version = "4.6.1", features = ["cargo", "derive"] }
clap-verbosity-flag = { version = "3.0.4", features = [ "tracing" ] } clap-verbosity-flag = { version = "3.0.4", features = [ "tracing" ] }
clap_complete = { version = "4.5.65", features = ["unstable-dynamic"] } clap_complete = { version = "4.6.5", features = ["unstable-dynamic"] }
color-print = "0.3.7" color-print = "0.3.7"
const_format = "0.2.35" const_format = "0.2.36"
derive_more = { version = "2.1.1", features = ["display", "error"] } derive_more = { version = "2.1.1", features = ["display", "error"] }
dialoguer = "0.12.0" dialoguer = "0.12.0"
futures-util = "0.3.31" futures-util = "0.3.32"
humansize = "2.1.3" humansize = "2.1.3"
indoc = "2.0.7" indoc = "2.0.7"
itertools = "0.14.0" itertools = "0.14.0"
nix = { version = "0.30.1", features = ["fs", "process", "socket", "user"] } nix = { version = "0.31.3", features = ["fs", "process", "socket", "user"] }
num_cpus = "1.17.0" num_cpus = "1.17.0"
prettytable = "0.10.0" prettytable = "0.10.0"
rand = "0.9.2" rand = "0.10.1"
serde = "1.0.228" serde = "1.0.228"
serde_json = { version = "1.0.149", features = ["preserve_order"] } serde_json = { version = "1.0.150", features = ["preserve_order"] }
sqlx = { version = "0.8.6", features = ["runtime-tokio", "mysql", "tls-rustls"] } sqlx = { version = "0.9.0", features = ["runtime-tokio", "mysql", "tls-rustls"] }
thiserror = "2.0.17" thiserror = "2.0.18"
tokio = { version = "1.49.0", features = ["rt-multi-thread", "macros", "signal"] } tokio = { version = "1.52.3", features = ["rt-multi-thread", "macros", "signal"] }
tokio-serde = { version = "0.9.0", features = ["bincode"] } tokio-serde = { version = "0.9.0", features = ["bincode"] }
tokio-stream = "0.1.18" tokio-stream = "0.1.18"
tokio-util = { version = "0.7.18", features = ["codec", "rt"] } tokio-util = { version = "0.7.18", features = ["codec", "rt"] }
toml = "0.9.11" toml = "1.1.2"
tracing = { version = "0.1.44", features = ["log"] } tracing = { version = "0.1.44", features = ["log"] }
tracing-subscriber = "0.3.22" tracing-subscriber = "0.3.23"
uuid = { version = "1.19.0", features = ["v4"] } uuid = { version = "1.23.2", features = ["v4"] }
[target.'cfg(target_os = "linux")'.dependencies] [target.'cfg(target_os = "linux")'.dependencies]
landlock = "0.4.4" landlock = "0.4.5"
sd-notify = "0.4.5" sd-notify = "0.5.0"
tracing-journald = "0.3.2" tracing-journald = "0.3.2"
[build-dependencies] [build-dependencies]
anyhow = "1.0.100" anyhow = "1.0.102"
build-info-build = "0.0.42" build-info-build = "0.0.44"
git2 = { version = "0.20.3", default-features = false } git2 = { version = "0.21.0", default-features = false }
[dev-dependencies] [dev-dependencies]
pretty_assertions = "1.4.1" pretty_assertions = "1.4.1"
regex = "1.12.2" regex = "1.12.3"
[features] [features]
default = ["mysql-admutils-compatibility"] default = ["mysql-admutils-compatibility"]
+2
View File
@@ -9,6 +9,8 @@ ExecStart=/usr/bin/muscl-server --systemd --disable-landlock socket-activate
ExecReload=/usr/bin/kill -HUP $MAINPID ExecReload=/usr/bin/kill -HUP $MAINPID
WatchdogSec=3min WatchdogSec=3min
Restart=always
RestartSec=10s
# Although this is a multi-instance unit, the constant `User` field is needed # Although this is a multi-instance unit, the constant `User` field is needed
# for authentication via mysql's auth_socket plugin to work. # for authentication via mysql's auth_socket plugin to work.
+1 -1
View File
@@ -42,7 +42,7 @@ on the MySQL server as the admin user (or another user with sufficient privilege
```sql ```sql
CREATE USER `muscl`@`localhost` IDENTIFIED BY '<strong_password_here>'; CREATE USER `muscl`@`localhost` IDENTIFIED BY '<strong_password_here>';
GRANT SELECT, INSERT, UPDATE, DELETE ON `mysql`.* TO `muscl`@`localhost`; GRANT SELECT, INSERT, UPDATE, DELETE ON `mysql`.* TO `muscl`@`localhost`;
GRANT GRANT OPTION, CREATE, DROP ON *.* TO `muscl`@`localhost`; GRANT GRANT OPTION, CREATE, DROP, RELOAD ON *.* TO `muscl`@`localhost`;
FLUSH PRIVILEGES; FLUSH PRIVILEGES;
``` ```
Generated
+9 -9
View File
@@ -2,11 +2,11 @@
"nodes": { "nodes": {
"crane": { "crane": {
"locked": { "locked": {
"lastModified": 1767744144, "lastModified": 1780099841,
"narHash": "sha256-9/9ntI0D+HbN4G0TrK3KmHbTvwgswz7p8IEJsWyef8Q=", "narHash": "sha256-EVZd2RsbpreRUDSi9rBwPY+ZxoyMaiEBbZxxhljbaS4=",
"owner": "ipetkov", "owner": "ipetkov",
"repo": "crane", "repo": "crane",
"rev": "2fb033290bf6b23f226d4c8b32f7f7a16b043d7e", "rev": "0532eb17955225173906d671fb36306bdeb1e2dc",
"type": "github" "type": "github"
}, },
"original": { "original": {
@@ -17,11 +17,11 @@
}, },
"nixpkgs": { "nixpkgs": {
"locked": { "locked": {
"lastModified": 1768127708, "lastModified": 1779560665,
"narHash": "sha256-1Sm77VfZh3mU0F5OqKABNLWxOuDeHIlcFjsXeeiPazs=", "narHash": "sha256-tpyBcxPpcQb8ukyNF7DoCwfSY3VPsxHoYwj00Cayv5o=",
"owner": "NixOS", "owner": "NixOS",
"repo": "nixpkgs", "repo": "nixpkgs",
"rev": "ffbc9f8cbaacfb331b6017d5a5abb21a492c9a38", "rev": "64c08a7ca051951c8eae34e3e3cb1e202fe36786",
"type": "github" "type": "github"
}, },
"original": { "original": {
@@ -45,11 +45,11 @@
] ]
}, },
"locked": { "locked": {
"lastModified": 1768186348, "lastModified": 1780110990,
"narHash": "sha256-nkpIe3zkpeoFuOl8xBpexulECsHLQ9Ljg1gW3bPCjSI=", "narHash": "sha256-6QBThUi7SuK+dgA+DCaEkQGZN4kYx6DpXmK45+MG9zI=",
"owner": "oxalica", "owner": "oxalica",
"repo": "rust-overlay", "repo": "rust-overlay",
"rev": "af69e497567a5945a64057717bc9b17c8478097e", "rev": "85570ef134d92a8702de6afd1f6f0209c863fa91",
"type": "github" "type": "github"
}, },
"original": { "original": {
+1 -1
View File
@@ -88,7 +88,7 @@ buildFunction ({
''; '';
meta = with lib; { meta = with lib; {
license = licenses.mit; license = licenses.bsd3;
platforms = platforms.linux ++ platforms.darwin; platforms = platforms.linux ++ platforms.darwin;
inherit mainProgram; inherit mainProgram;
}; };
+15 -7
View File
@@ -42,9 +42,17 @@ declare -r GIT_SHA="$2"
TMPDIR="$(mktemp -d)" TMPDIR="$(mktemp -d)"
for variant in debian-bookworm debian-trixie ubuntu-jammy ubuntu-noble; do declare -a OS_VARIANTS=(
"debian-bookworm"
"debian-trixie"
"ubuntu-jammy"
"ubuntu-noble"
# "ubuntu-resolute"
)
for variant in "${OS_VARIANTS[@]}"; do
echo "Downloading and uploading debs for variant: $variant" echo "Downloading and uploading debs for variant: $variant"
curl "https://git.pvv.ntnu.no/Projects/muscl/actions/runs/$RUN_NUMBER/artifacts/muscl-deb-$variant-$GIT_SHA.zip" --output "$TMPDIR/muscl-deb-$variant-$GIT_SHA.zip" curl "https://git.pvv.ntnu.no/Projects/muscl/actions/runs/$RUN_NUMBER/artifacts/muscl-deb-$variant-$GIT_SHA" --output "$TMPDIR/muscl-deb-$variant-$GIT_SHA.zip"
unzip "$TMPDIR/muscl-deb-$variant-$GIT_SHA.zip" -d "$TMPDIR/muscl-deb-$variant-$GIT_SHA" unzip "$TMPDIR/muscl-deb-$variant-$GIT_SHA.zip" -d "$TMPDIR/muscl-deb-$variant-$GIT_SHA"
@@ -54,11 +62,11 @@ for variant in debian-bookworm debian-trixie ubuntu-jammy ubuntu-noble; do
DEB_VERSION=$(find "$TMPDIR/muscl-deb-$variant-$GIT_SHA"/*.deb -print0 | xargs -0 -n1 basename | cut -d'_' -f2 | head -n1) DEB_VERSION=$(find "$TMPDIR/muscl-deb-$variant-$GIT_SHA"/*.deb -print0 | xargs -0 -n1 basename | cut -d'_' -f2 | head -n1)
DEB_ARCH=$(find "$TMPDIR/muscl-deb-$variant-$GIT_SHA"/*.deb -print0 | xargs -0 -n1 basename | cut -d'_' -f3 | cut -d'.' -f1 | head -n1) DEB_ARCH=$(find "$TMPDIR/muscl-deb-$variant-$GIT_SHA"/*.deb -print0 | xargs -0 -n1 basename | cut -d'_' -f3 | cut -d'.' -f1 | head -n1)
echo "[DELETE] https://git.pvv.ntnu.no/api/packages/Projects/debian/pool/$DISTRO_VERSION_NAME/main/$DEB_NAME/$DEB_VERSION/$DEB_ARCH" # echo "[DELETE] https://git.pvv.ntnu.no/api/packages/Projects/debian/pool/$DISTRO_VERSION_NAME/main/$DEB_NAME/$DEB_VERSION/$DEB_ARCH"
curl \ # curl \
-X DELETE \ # -X DELETE \
--user "$GITEA_USER:$GITEA_TOKEN" \ # --user "$GITEA_USER:$GITEA_TOKEN" \
"https://git.pvv.ntnu.no/api/packages/Projects/debian/pool/$DISTRO_VERSION_NAME/main/$DEB_NAME/$DEB_VERSION/$DEB_ARCH" # "https://git.pvv.ntnu.no/api/packages/Projects/debian/pool/$DISTRO_VERSION_NAME/main/$DEB_NAME/$DEB_VERSION/$DEB_ARCH"
echo "[PUT] https://git.pvv.ntnu.no/api/packages/Projects/debian/pool/$DISTRO_VERSION_NAME/main/upload" echo "[PUT] https://git.pvv.ntnu.no/api/packages/Projects/debian/pool/$DISTRO_VERSION_NAME/main/upload"
curl \ curl \
+8 -3
View File
@@ -8,6 +8,7 @@ use clap::{Args, Parser};
use clap_complete::ArgValueCompleter; use clap_complete::ArgValueCompleter;
use dialoguer::{Confirm, Editor}; use dialoguer::{Confirm, Editor};
use futures_util::SinkExt; use futures_util::SinkExt;
use itertools::Itertools;
use nix::unistd::{User, getuid}; use nix::unistd::{User, getuid};
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
@@ -203,9 +204,13 @@ pub async fn edit_database_privileges(
} }
}) })
.flatten() .flatten()
.sorted_by_key(|row| (row.db.clone(), row.user.clone()))
.collect::<Vec<_>>(), .collect::<Vec<_>>(),
Some(Ok(Response::ListAllPrivileges(privilege_rows))) => match privilege_rows { Some(Ok(Response::ListAllPrivileges(privilege_rows))) => match privilege_rows {
Ok(list) => list, Ok(list) => list
.into_iter()
.sorted_by_key(|row| (row.db.clone(), row.user.clone()))
.collect(),
Err(err) => { Err(err) => {
server_connection.send(Request::Exit).await?; server_connection.send(Request::Exit).await?;
return Err(anyhow::anyhow!(err.to_error_message()) return Err(anyhow::anyhow!(err.to_error_message())
@@ -305,7 +310,7 @@ pub async fn edit_database_privileges(
print_modify_database_privileges_output_status(&result); print_modify_database_privileges_output_status(&result);
if result.iter().any(|(_, res)| { if result.values().flatten().any(|(_, res)| {
matches!( matches!(
res, res,
Err(ModifyDatabasePrivilegesError::UserValidationError( Err(ModifyDatabasePrivilegesError::UserValidationError(
@@ -320,7 +325,7 @@ pub async fn edit_database_privileges(
server_connection.send(Request::Exit).await?; server_connection.send(Request::Exit).await?;
if result.values().any(std::result::Result::is_err) { if result.values().flatten().any(|(_, res)| res.is_err()) {
std::process::exit(1); std::process::exit(1);
} }
+1 -1
View File
@@ -29,7 +29,7 @@ pub const DATABASE_PRIVILEGE_FIELDS: [&str; 13] = [
// doesn't have any natural implementation semantics. // doesn't have any natural implementation semantics.
/// Representation of the set of privileges for a single user on a single database. /// Representation of the set of privileges for a single user on a single database.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord, Default)]
pub struct DatabasePrivilegeRow { pub struct DatabasePrivilegeRow {
// TODO: don't store the db and user here, let the type be stored in a mapping // TODO: don't store the db and user here, let the type be stored in a mapping
pub db: MySQLDatabase, pub db: MySQLDatabase,
+6 -3
View File
@@ -324,9 +324,12 @@ impl Response {
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count()) ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
} }
Response::ListAllPrivileges(res) => ResponseOkStatus::from_bool(res.is_ok()), Response::ListAllPrivileges(res) => ResponseOkStatus::from_bool(res.is_ok()),
Response::ModifyPrivileges(res) => { Response::ModifyPrivileges(res) => ResponseOkStatus::from_counts(
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count()) res.len(),
} res.values()
.map(|user_map| user_map.values().filter(|v| v.is_ok()).count())
.sum(),
),
Response::CreateUsers(res) => { Response::CreateUsers(res) => {
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count()) ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
@@ -67,3 +67,43 @@ impl CheckAuthorizationError {
self.0.error_type() self.0.error_type()
} }
} }
#[cfg(test)]
mod tests {
use crate::core::protocol::request_validation::NameValidationError;
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request: CheckAuthorizationRequest = vec![
DbOrUser::Database("test_db".into()),
DbOrUser::User("test_user".into()),
];
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: CheckAuthorizationRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response: CheckAuthorizationResponse = BTreeMap::from([
(DbOrUser::Database("test_db".into()), Ok(())),
(
DbOrUser::User("test_user".into()),
Err(CheckAuthorizationError(
ValidationError::NameValidationError(NameValidationError::TooLong),
)),
),
]);
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: CheckAuthorizationResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
}
@@ -87,3 +87,41 @@ impl CreateDatabaseError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request: CreateDatabasesRequest =
vec!["test_db1".into(), "test_db2".into(), "test_db3".into()];
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: CreateDatabasesRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response: CreateDatabasesResponse = BTreeMap::from([
("test_db1".into(), Ok(())),
(
"test_db2".into(),
Err(CreateDatabaseError::DatabaseAlreadyExists),
),
(
"test_db3".into(),
Err(CreateDatabaseError::MySqlError("Some MySQL error".into())),
),
]);
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: CreateDatabasesResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
}
@@ -87,3 +87,37 @@ impl CreateUserError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request: CreateUsersRequest = vec!["alice".into(), "bob".into(), "charlie".into()];
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: CreateUsersRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response: CreateUsersResponse = BTreeMap::from([
("alice".into(), Ok(())),
("bob".into(), Err(CreateUserError::UserAlreadyExists)),
(
"charlie".into(),
Err(CreateUserError::MySqlError("Some MySQL error".into())),
),
]);
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: CreateUsersResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
}
@@ -90,3 +90,37 @@ impl DropDatabaseError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request: DropDatabasesRequest = vec!["db1".into(), "db2".into(), "db3".into()];
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: DropDatabasesRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response: DropDatabasesResponse = BTreeMap::from([
("db1".into(), Ok(())),
("db2".into(), Err(DropDatabaseError::DatabaseDoesNotExist)),
(
"db3".into(),
Err(DropDatabaseError::MySqlError("Some MySQL error".into())),
),
]);
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: DropDatabasesResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
}
+34
View File
@@ -87,3 +87,37 @@ impl DropUserError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request: DropUsersRequest = vec!["alice".into(), "bob".into(), "charlie".into()];
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: DropUsersRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response: DropUsersResponse = BTreeMap::from([
("alice".into(), Ok(())),
("bob".into(), Err(DropUserError::UserDoesNotExist)),
(
"charlie".into(),
Err(DropUserError::MySqlError("Some MySQL error".into())),
),
]);
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: DropUsersResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
}
@@ -27,3 +27,36 @@ impl ListAllDatabasesError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_response() {
let response: ListAllDatabasesResponse = Ok(vec![
DatabaseRow {
database: "db1".into(),
tables: vec!["table1".into(), "table2".into()],
users: vec!["user1".into(), "user2".into()],
collation: Some("utf8mb4_general_ci".into()),
character_set: Some("utf8mb4".into()),
size_bytes: 1024,
},
DatabaseRow {
database: "db2".into(),
tables: vec!["table3".into(), "table4".into()],
users: vec!["user3".into(), "user4".into()],
collation: Some("utf8mb4_general_ci".into()),
character_set: Some("utf8mb4".into()),
size_bytes: 2048,
},
]);
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: ListAllDatabasesResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
}
@@ -27,3 +27,34 @@ impl ListAllPrivilegesError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_response() {
let response: ListAllPrivilegesResponse = Ok(vec![
DatabasePrivilegeRow {
user: "user1".into(),
db: "db1".into(),
select_priv: true,
insert_priv: false,
..Default::default()
},
DatabasePrivilegeRow {
user: "user2".into(),
db: "db2".into(),
select_priv: false,
insert_priv: true,
..Default::default()
},
]);
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: ListAllPrivilegesResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
}
@@ -27,3 +27,37 @@ impl ListAllUsersError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_response() {
let response: ListAllUsersResponse = Ok(vec![
DatabaseUser {
user: "user1".into(),
host: "%".into(),
has_password: true,
is_locked: false,
databases: vec!["db1".into(), "db2".into()],
},
DatabaseUser {
user: "user2".into(),
host: "%".into(),
has_password: false,
is_locked: true,
databases: vec!["db3".into()],
},
]);
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let mut deserialized: ListAllUsersResponse = serde_json::from_str(&json).unwrap();
deserialized.as_mut().unwrap()[0].host = "%".into();
deserialized.as_mut().unwrap()[1].host = "%".into();
assert_eq!(response, deserialized);
}
}
+42 -1
View File
@@ -61,7 +61,7 @@ pub fn print_list_databases_output_status(
"Size" "Size"
} }
]); ]);
for db in final_database_list { for db in final_database_list.iter().sorted_by_key(|db| &db.database) {
table.add_row(row![ table.add_row(row![
db.database, db.database,
db.tables.join("\n"), db.tables.join("\n"),
@@ -137,3 +137,44 @@ impl ListDatabasesError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request = Some(vec!["db1".into(), "db2".into()]);
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: ListDatabasesRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response: ListDatabasesResponse = vec![
(
"db1".into(),
Ok(DatabaseRow {
database: "db1".into(),
tables: vec!["table1".to_string(), "table2".to_string()],
users: vec!["user1".into(), "user2".into()],
collation: Some("utf8mb4_general_ci".to_string()),
character_set: Some("utf8mb4".to_string()),
size_bytes: 1024,
}),
),
("db2".into(), Err(ListDatabasesError::DatabaseDoesNotExist)),
]
.into_iter()
.collect();
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: ListDatabasesResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
}
+47 -3
View File
@@ -64,8 +64,11 @@ pub fn print_list_privileges_output_status(output: &ListPrivilegesResponse, long
.collect(), .collect(),
)); ));
for (_database, rows) in final_privs_map { for row in final_privs_map
for row in &rows { .values()
.flatten()
.sorted_by_key(|row| (&row.db, &row.user))
{
table.add_row(row![ table.add_row(row![
row.db, row.db,
row.user, row.user,
@@ -82,7 +85,7 @@ pub fn print_list_privileges_output_status(output: &ListPrivilegesResponse, long
c->yn(row.references_priv), c->yn(row.references_priv),
]); ]);
} }
} // }
table.printstd(); table.printstd();
} }
@@ -153,3 +156,44 @@ impl ListPrivilegesError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request: ListPrivilegesRequest = Some(vec!["test_db1".into(), "test_db2".into()]);
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: ListPrivilegesRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response: ListPrivilegesResponse = BTreeMap::from([
(
"test_db1".into(),
Ok(vec![DatabasePrivilegeRow {
db: "test_db1".into(),
user: "user1".into(),
select_priv: true,
insert_priv: false,
..Default::default()
}]),
),
(
"test_db2".into(),
Err(ListPrivilegesError::DatabaseDoesNotExist),
),
]);
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: ListPrivilegesResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
}
+47 -1
View File
@@ -1,5 +1,6 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use itertools::Itertools;
use prettytable::Table; use prettytable::Table;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::json; use serde_json::json;
@@ -51,7 +52,7 @@ pub fn print_list_users_output_status(output: &ListUsersResponse) {
"Locked", "Locked",
"Databases where user has privileges" "Databases where user has privileges"
]); ]);
for user in final_user_list { for user in final_user_list.iter().sorted_by_key(|user| &user.user) {
table.add_row(row![ table.add_row(row![
user.user, user.user,
user.has_password, user.has_password,
@@ -121,3 +122,48 @@ impl ListUsersError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request: ListUsersRequest = Some(vec!["test_user1".into(), "test_user2".into()]);
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: ListUsersRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response_ok: ListUsersResponse = BTreeMap::from([
(
"test_user1".into(),
Ok(DatabaseUser {
user: "test_user1".into(),
host: "%".into(),
has_password: true,
is_locked: false,
databases: vec!["db1".into(), "db2".into()],
}),
),
("test_user2".into(), Err(ListUsersError::UserDoesNotExist)),
]);
let json = serde_json::to_string_pretty(&response_ok).unwrap();
println!("Serialized response:\n{}", json);
let mut deserialized: ListUsersResponse = serde_json::from_str(&json).unwrap();
deserialized
.get_mut(&"test_user1".into())
.unwrap()
.as_mut()
.unwrap()
.host = "%".into();
assert_eq!(response_ok, deserialized);
}
}
+30
View File
@@ -94,3 +94,33 @@ impl LockUserError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request: LockUsersRequest = vec!["test_user1".into(), "test_user2".into()];
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: LockUsersRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response_ok: LockUsersResponse = BTreeMap::from([
("test_user1".into(), Ok(())),
("test_user2".into(), Err(LockUserError::UserDoesNotExist)),
]);
let json = serde_json::to_string_pretty(&response_ok).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: LockUsersResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response_ok, deserialized);
}
}
@@ -12,7 +12,7 @@ use crate::core::{
pub type ModifyPrivilegesRequest = BTreeSet<DatabasePrivilegesDiff>; pub type ModifyPrivilegesRequest = BTreeSet<DatabasePrivilegesDiff>;
pub type ModifyPrivilegesResponse = pub type ModifyPrivilegesResponse =
BTreeMap<(MySQLDatabase, MySQLUser), Result<(), ModifyDatabasePrivilegesError>>; BTreeMap<MySQLDatabase, BTreeMap<MySQLUser, Result<(), ModifyDatabasePrivilegesError>>>;
#[derive(Error, Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Error, Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ModifyDatabasePrivilegesError { pub enum ModifyDatabasePrivilegesError {
@@ -49,7 +49,11 @@ pub enum DiffDoesNotApplyError {
} }
pub fn print_modify_database_privileges_output_status(output: &ModifyPrivilegesResponse) { pub fn print_modify_database_privileges_output_status(output: &ModifyPrivilegesResponse) {
for ((database_name, username), result) in output { for ((database_name, username), result) in output.iter().flat_map(|(db, user_map)| {
user_map
.iter()
.map(move |(user, result)| ((db, user), result))
}) {
match result { match result {
Ok(()) => { Ok(()) => {
println!( println!(
@@ -144,3 +148,46 @@ impl DiffDoesNotApplyError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
use crate::core::*;
#[test]
fn test_serialize_deserialize_request() {
let request =
BTreeSet::from([DatabasePrivilegesDiff::Modified(DatabasePrivilegeRowDiff {
db: "test_db".into(),
user: "test_user".into(),
select_priv: Some(database_privileges::DatabasePrivilegeChange::NoToYes),
..Default::default()
})]);
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: ModifyPrivilegesRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response: ModifyPrivilegesResponse = BTreeMap::from([(
"test_db".into(),
BTreeMap::from([
("test_user".into(), Ok(())),
(
"invalid_user".into(),
Err(ModifyDatabasePrivilegesError::UserDoesNotExist),
),
]),
)]);
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: ModifyPrivilegesResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
}
+32
View File
@@ -60,3 +60,35 @@ impl SetPasswordError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request: SetUserPasswordRequest = ("test_user".into(), "new_password".into());
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: SetUserPasswordRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response_ok: SetUserPasswordResponse = Ok(());
let response_err: SetUserPasswordResponse = Err(SetPasswordError::UserDoesNotExist);
let json_ok = serde_json::to_string_pretty(&response_ok).unwrap();
let json_err = serde_json::to_string_pretty(&response_err).unwrap();
println!("Serialized OK response:\n{}", json_ok);
println!("Serialized Error response:\n{}", json_err);
let deserialized_ok: SetUserPasswordResponse = serde_json::from_str(&json_ok).unwrap();
let deserialized_err: SetUserPasswordResponse = serde_json::from_str(&json_err).unwrap();
assert_eq!(response_ok, deserialized_ok);
assert_eq!(response_err, deserialized_err);
}
}
@@ -94,3 +94,33 @@ impl UnlockUserError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request: UnlockUsersRequest = vec!["test_user1".into(), "test_user2".into()];
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: UnlockUsersRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response_ok: UnlockUsersResponse = BTreeMap::from([
("test_user1".into(), Ok(())),
("test_user2".into(), Err(UnlockUserError::UserDoesNotExist)),
]);
let json = serde_json::to_string_pretty(&response_ok).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: UnlockUsersResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response_ok, deserialized);
}
}
+34 -3
View File
@@ -34,7 +34,7 @@ impl DerefMut for MySQLUser {
impl fmt::Display for MySQLUser { impl fmt::Display for MySQLUser {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{:<width$}", self.0, width = f.width().unwrap_or(0)) self.0.fmt(f)
} }
} }
@@ -83,7 +83,7 @@ impl DerefMut for MySQLDatabase {
impl fmt::Display for MySQLDatabase { impl fmt::Display for MySQLDatabase {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{:<width$}", self.0, width = f.width().unwrap_or(0)) self.0.fmt(f)
} }
} }
@@ -105,12 +105,43 @@ impl From<MySQLDatabase> for OsString {
} }
} }
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum DbOrUser { pub enum DbOrUser {
Database(MySQLDatabase), Database(MySQLDatabase),
User(MySQLUser), User(MySQLUser),
} }
impl Serialize for DbOrUser {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
match self {
DbOrUser::Database(db) => ("d:".to_string() + &db.to_string()).serialize(serializer),
DbOrUser::User(user) => ("u:".to_string() + &user.to_string()).serialize(serializer),
}
}
}
impl<'de> Deserialize<'de> for DbOrUser {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
if let Some(rest) = s.strip_prefix("d:") {
Ok(DbOrUser::Database(MySQLDatabase(rest.to_string())))
} else if let Some(rest) = s.strip_prefix("u:") {
Ok(DbOrUser::User(MySQLUser(rest.to_string())))
} else {
Err(serde::de::Error::custom(format!(
"Invalid DbOrUser format: {}",
s
)))
}
}
}
impl DbOrUser { impl DbOrUser {
#[must_use] #[must_use]
pub fn lowercased_noun(&self) -> &'static str { pub fn lowercased_noun(&self) -> &'static str {
+99 -38
View File
@@ -1,5 +1,6 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use sqlx::AssertSqlSafe;
use sqlx::MySqlConnection; use sqlx::MySqlConnection;
use sqlx::prelude::*; use sqlx::prelude::*;
@@ -125,8 +126,11 @@ pub async fn create_databases(
_ => {} _ => {}
} }
let result = let statement = AssertSqlSafe(format!(
sqlx::query(format!("CREATE DATABASE {}", quote_identifier(&database_name)).as_str()) "CREATE DATABASE {}",
quote_identifier(&database_name)
));
let result = sqlx::query(statement)
.execute(&mut *connection) .execute(&mut *connection)
.await .await
.map(|_| ()) .map(|_| ())
@@ -181,8 +185,11 @@ pub async fn drop_databases(
_ => {} _ => {}
} }
let result = let statement = AssertSqlSafe(format!(
sqlx::query(format!("DROP DATABASE {}", quote_identifier(&database_name)).as_str()) "DROP DATABASE {}",
quote_identifier(&database_name)
));
let result = sqlx::query(statement)
.execute(&mut *connection) .execute(&mut *connection)
.await .await
.map(|_| ()) .map(|_| ())
@@ -265,26 +272,49 @@ pub async fn list_databases(
let result = sqlx::query_as::<_, DatabaseRow>( let result = sqlx::query_as::<_, DatabaseRow>(
r" r"
SELECT SELECT
CAST(`information_schema`.`SCHEMATA`.`SCHEMA_NAME` AS CHAR(64)) AS `database`, CAST(s.SCHEMA_NAME AS CHAR(64)) AS `database`,
GROUP_CONCAT(DISTINCT CAST(`information_schema`.`TABLES`.`TABLE_NAME` AS CHAR(64)) SEPARATOR ',') AS `tables`, t.tables,
GROUP_CONCAT(DISTINCT CAST(`mysql`.`db`.`User` AS CHAR(64)) SEPARATOR ',') AS `users`, u.users,
MAX(`information_schema`.`SCHEMATA`.`DEFAULT_COLLATION_NAME`) AS `collation`, s.DEFAULT_COLLATION_NAME AS `collation`,
MAX(`information_schema`.`SCHEMATA`.`DEFAULT_CHARACTER_SET_NAME`) AS `character_set`, s.DEFAULT_CHARACTER_SET_NAME AS `character_set`,
CAST(IFNULL( CAST(COALESCE(t.size_bytes, 0) AS UNSIGNED) AS `size_bytes`
SUM(`information_schema`.`TABLES`.`DATA_LENGTH` + `information_schema`.`TABLES`.`INDEX_LENGTH`), FROM information_schema.SCHEMATA s
0
) AS UNSIGNED INTEGER) AS `size_bytes`
FROM `information_schema`.`SCHEMATA`
LEFT OUTER JOIN `information_schema`.`TABLES`
ON `information_schema`.`SCHEMATA`.`SCHEMA_NAME` = `TABLES`.`TABLE_SCHEMA`
LEFT OUTER JOIN `mysql`.`db`
ON `information_schema`.`SCHEMATA`.`SCHEMA_NAME` = `mysql`.`db`.`DB`
WHERE `information_schema`.`SCHEMATA`.`SCHEMA_NAME` = ?
GROUP BY `information_schema`.`SCHEMATA`.`SCHEMA_NAME`
",
LEFT JOIN (
SELECT
TABLE_SCHEMA,
GROUP_CONCAT(
DISTINCT CAST(TABLE_NAME AS CHAR(64))
ORDER BY TABLE_NAME
SEPARATOR ','
) AS tables,
SUM(DATA_LENGTH + INDEX_LENGTH) AS size_bytes
FROM information_schema.TABLES
WHERE TABLE_SCHEMA = ?
GROUP BY TABLE_SCHEMA
) t
ON t.TABLE_SCHEMA = s.SCHEMA_NAME
LEFT JOIN (
SELECT
DB,
GROUP_CONCAT(
DISTINCT CAST(User AS CHAR(64))
ORDER BY User
SEPARATOR ','
) AS users
FROM mysql.db
WHERE DB = ?
GROUP BY DB
) u
ON u.DB = s.SCHEMA_NAME
WHERE s.SCHEMA_NAME = ?;
",
) )
.bind(database_name.to_string()) .bind(database_name.to_string())
.bind(database_name.to_string())
.bind(database_name.to_string())
.fetch_optional(&mut *connection) .fetch_optional(&mut *connection)
.await .await
.map_err(|err| ListDatabasesError::MySqlError(err.to_string())) .map_err(|err| ListDatabasesError::MySqlError(err.to_string()))
@@ -313,26 +343,57 @@ pub async fn list_all_databases_for_user(
let result = sqlx::query_as::<_, DatabaseRow>( let result = sqlx::query_as::<_, DatabaseRow>(
r" r"
SELECT SELECT
CAST(`information_schema`.`SCHEMATA`.`SCHEMA_NAME` AS CHAR(64)) AS `database`, CAST(s.SCHEMA_NAME AS CHAR(64)) AS `database`,
GROUP_CONCAT(DISTINCT CAST(`information_schema`.`TABLES`.`TABLE_NAME` AS CHAR(64)) SEPARATOR ',') AS `tables`, t.tables,
GROUP_CONCAT(DISTINCT CAST(`mysql`.`db`.`User` AS CHAR(64)) SEPARATOR ',') AS `users`, u.users,
MAX(`information_schema`.`SCHEMATA`.`DEFAULT_COLLATION_NAME`) AS `collation`, s.DEFAULT_COLLATION_NAME AS collation,
MAX(`information_schema`.`SCHEMATA`.`DEFAULT_CHARACTER_SET_NAME`) AS `character_set`, s.DEFAULT_CHARACTER_SET_NAME AS character_set,
CAST(IFNULL( CAST(COALESCE(t.size_bytes, 0) AS UNSIGNED) AS size_bytes
SUM(`information_schema`.`TABLES`.`DATA_LENGTH` + `information_schema`.`TABLES`.`INDEX_LENGTH`), FROM information_schema.SCHEMATA s
0
) AS UNSIGNED INTEGER) AS `size_bytes` LEFT JOIN (
FROM `information_schema`.`SCHEMATA` SELECT
LEFT OUTER JOIN `information_schema`.`TABLES` TABLE_SCHEMA,
ON `information_schema`.`SCHEMATA`.`SCHEMA_NAME` = `TABLES`.`TABLE_SCHEMA` GROUP_CONCAT(
LEFT OUTER JOIN `mysql`.`db` DISTINCT CAST(TABLE_NAME AS CHAR(64))
ON `information_schema`.`SCHEMATA`.`SCHEMA_NAME` = `mysql`.`db`.`DB` ORDER BY TABLE_NAME
WHERE `information_schema`.`SCHEMATA`.`SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys') SEPARATOR ','
AND `information_schema`.`SCHEMATA`.`SCHEMA_NAME` REGEXP ? ) AS tables,
GROUP BY `information_schema`.`SCHEMATA`.`SCHEMA_NAME` SUM(DATA_LENGTH + INDEX_LENGTH) AS size_bytes
FROM information_schema.TABLES
WHERE TABLE_SCHEMA REGEXP ?
GROUP BY TABLE_SCHEMA
) t
ON t.TABLE_SCHEMA = s.SCHEMA_NAME
LEFT JOIN (
SELECT
DB,
GROUP_CONCAT(
DISTINCT CAST(User AS CHAR(64))
ORDER BY User
SEPARATOR ','
) AS users
FROM mysql.db
WHERE DB REGEXP ?
GROUP BY DB
) u
ON u.DB = s.SCHEMA_NAME
WHERE s.SCHEMA_NAME REGEXP ?
AND s.SCHEMA_NAME NOT IN (
'information_schema',
'performance_schema',
'mysql',
'sys'
)
ORDER BY s.SCHEMA_NAME
", ",
) )
.bind(create_user_group_matching_regex(unix_user, group_denylist)) .bind(create_user_group_matching_regex(unix_user, group_denylist))
.bind(create_user_group_matching_regex(unix_user, group_denylist))
.bind(create_user_group_matching_regex(unix_user, group_denylist))
.fetch_all(connection) .fetch_all(connection)
.await .await
.map_err(|err| ListAllDatabasesError::MySqlError(err.to_string())); .map_err(|err| ListAllDatabasesError::MySqlError(err.to_string()));
+39 -14
View File
@@ -18,7 +18,7 @@ use std::collections::{BTreeMap, BTreeSet};
use indoc::indoc; use indoc::indoc;
use itertools::Itertools; use itertools::Itertools;
use sqlx::{MySqlConnection, mysql::MySqlRow, prelude::*}; use sqlx::{AssertSqlSafe, MySqlConnection, mysql::MySqlRow, prelude::*};
use crate::{ use crate::{
core::{ core::{
@@ -84,13 +84,14 @@ async fn unsafe_get_database_privileges(
database_name: &str, database_name: &str,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
) -> Result<Vec<DatabasePrivilegeRow>, sqlx::Error> { ) -> Result<Vec<DatabasePrivilegeRow>, sqlx::Error> {
let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!( let statement = AssertSqlSafe(format!(
"SELECT {} FROM `db` WHERE `Db` = ?", "SELECT {} FROM `db` WHERE `Db` = ?",
DATABASE_PRIVILEGE_FIELDS DATABASE_PRIVILEGE_FIELDS
.iter() .iter()
.map(|field| quote_identifier(field)) .map(|field| quote_identifier(field))
.join(","), .join(","),
)) ));
let result = sqlx::query_as::<_, DatabasePrivilegeRow>(statement)
.bind(database_name) .bind(database_name)
.fetch_all(connection) .fetch_all(connection)
.await; .await;
@@ -113,13 +114,14 @@ pub async fn unsafe_get_database_privileges_for_db_user_pair(
user_name: &MySQLUser, user_name: &MySQLUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
) -> Result<Option<DatabasePrivilegeRow>, sqlx::Error> { ) -> Result<Option<DatabasePrivilegeRow>, sqlx::Error> {
let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!( let statement = AssertSqlSafe(format!(
"SELECT {} FROM `db` WHERE `Db` = ? AND `User` = ?", "SELECT {} FROM `db` WHERE `Db` = ? AND `User` = ? AND `Host` = '%'",
DATABASE_PRIVILEGE_FIELDS DATABASE_PRIVILEGE_FIELDS
.iter() .iter()
.map(|field| quote_identifier(field)) .map(|field| quote_identifier(field))
.join(","), .join(","),
)) ));
let result = sqlx::query_as::<_, DatabasePrivilegeRow>(statement)
.bind(database_name.as_str()) .bind(database_name.as_str())
.bind(user_name.as_str()) .bind(user_name.as_str())
.fetch_optional(connection) .fetch_optional(connection)
@@ -189,8 +191,8 @@ pub async fn get_databases_privilege_data(
} }
/// TODO: make this constant /// TODO: make this constant
fn get_all_db_privs_query() -> String { fn get_all_db_privs_query() -> AssertSqlSafe<String> {
format!( AssertSqlSafe(format!(
indoc! {r" indoc! {r"
SELECT {} FROM `db` WHERE `db` IN SELECT {} FROM `db` WHERE `db` IN
(SELECT DISTINCT CAST(`SCHEMA_NAME` AS CHAR(64)) AS `database` (SELECT DISTINCT CAST(`SCHEMA_NAME` AS CHAR(64)) AS `database`
@@ -202,7 +204,7 @@ fn get_all_db_privs_query() -> String {
.iter() .iter()
.map(|field| quote_identifier(field)) .map(|field| quote_identifier(field))
.join(","), .join(","),
) ))
} }
/// Get all database + user + privileges pairs that are owned by the current user. /// Get all database + user + privileges pairs that are owned by the current user.
@@ -212,7 +214,7 @@ pub async fn get_all_database_privileges(
_db_is_mariadb: bool, _db_is_mariadb: bool,
group_denylist: &GroupDenylist, group_denylist: &GroupDenylist,
) -> ListAllPrivilegesResponse { ) -> ListAllPrivilegesResponse {
let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&get_all_db_privs_query()) let result = sqlx::query_as::<_, DatabasePrivilegeRow>(get_all_db_privs_query())
.bind(create_user_group_matching_regex(unix_user, group_denylist)) .bind(create_user_group_matching_regex(unix_user, group_denylist))
.fetch_all(connection) .fetch_all(connection)
.await .await
@@ -234,13 +236,17 @@ async fn unsafe_apply_privilege_diff(
DatabasePrivilegesDiff::New(p) => { DatabasePrivilegesDiff::New(p) => {
let tables = DATABASE_PRIVILEGE_FIELDS let tables = DATABASE_PRIVILEGE_FIELDS
.iter() .iter()
.chain(&["Host"])
.map(|field| quote_identifier(field)) .map(|field| quote_identifier(field))
.join(","); .join(",");
let question_marks = let question_marks =
std::iter::repeat_n("?", DATABASE_PRIVILEGE_FIELDS.len()).join(","); std::iter::repeat_n("?", DATABASE_PRIVILEGE_FIELDS.len() + 1).join(",");
sqlx::query(format!("INSERT INTO `db` ({tables}) VALUES ({question_marks})").as_str()) let statement = AssertSqlSafe(format!(
"INSERT INTO `db` ({tables}) VALUES ({question_marks})"
));
sqlx::query(statement)
.bind(p.db.to_string()) .bind(p.db.to_string())
.bind(p.user.to_string()) .bind(p.user.to_string())
.bind(yn(p.select_priv)) .bind(yn(p.select_priv))
@@ -254,6 +260,7 @@ async fn unsafe_apply_privilege_diff(
.bind(yn(p.create_tmp_table_priv)) .bind(yn(p.create_tmp_table_priv))
.bind(yn(p.lock_tables_priv)) .bind(yn(p.lock_tables_priv))
.bind(yn(p.references_priv)) .bind(yn(p.references_priv))
.bind("%")
.execute(connection) .execute(connection)
.await .await
.map(|_| ()) .map(|_| ())
@@ -278,7 +285,10 @@ async fn unsafe_apply_privilege_diff(
} }
} }
sqlx::query(format!("UPDATE `db` SET {changes} WHERE `Db` = ? AND `User` = ?").as_str()) let statement = AssertSqlSafe(format!(
"UPDATE `db` SET {changes} WHERE `Db` = ? AND `User` = ? AND `Host` = ?"
));
sqlx::query(statement)
.bind(p.select_priv.map(change_to_yn)) .bind(p.select_priv.map(change_to_yn))
.bind(p.insert_priv.map(change_to_yn)) .bind(p.insert_priv.map(change_to_yn))
.bind(p.update_priv.map(change_to_yn)) .bind(p.update_priv.map(change_to_yn))
@@ -292,14 +302,16 @@ async fn unsafe_apply_privilege_diff(
.bind(p.references_priv.map(change_to_yn)) .bind(p.references_priv.map(change_to_yn))
.bind(p.db.to_string()) .bind(p.db.to_string())
.bind(p.user.to_string()) .bind(p.user.to_string())
.bind("%")
.execute(connection) .execute(connection)
.await .await
.map(|_| ()) .map(|_| ())
} }
DatabasePrivilegesDiff::Deleted(p) => { DatabasePrivilegesDiff::Deleted(p) => {
sqlx::query("DELETE FROM `db` WHERE `Db` = ? AND `User` = ?") sqlx::query("DELETE FROM `db` WHERE `Db` = ? AND `User` = ? AND `Host` = ?")
.bind(p.db.to_string()) .bind(p.db.to_string())
.bind(p.user.to_string()) .bind(p.user.to_string())
.bind("%")
.execute(connection) .execute(connection)
.await .await
.map(|_| ()) .map(|_| ())
@@ -480,5 +492,18 @@ pub async fn apply_privilege_diffs(
results.insert(key, result); results.insert(key, result);
} }
if let Err(err) = connection.execute("FLUSH PRIVILEGES").await {
tracing::error!("Failed to flush privileges: {}", err);
}
results results
.into_iter()
.map(|((k1, k2), v)| (k1, (k2, v)))
.into_group_map()
.into_iter()
.map(|(k1, pairs)| {
let inner = pairs.into_iter().collect::<BTreeMap<_, _>>();
(k1, inner)
})
.collect()
} }
+34 -27
View File
@@ -1,5 +1,6 @@
use indoc::formatdoc; use indoc::formatdoc;
use itertools::Itertools; use itertools::Itertools;
use sqlx::AssertSqlSafe;
use std::collections::BTreeMap; use std::collections::BTreeMap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@@ -39,6 +40,7 @@ pub(super) async fn unsafe_user_exists(
SELECT 1 SELECT 1
FROM `mysql`.`user` FROM `mysql`.`user`
WHERE `User` = ? WHERE `User` = ?
AND `Host` = '%'
) )
", ",
) )
@@ -67,6 +69,7 @@ pub async fn complete_user_name(
FROM `mysql`.`user` FROM `mysql`.`user`
WHERE `User` REGEXP ? WHERE `User` REGEXP ?
AND `User` LIKE ? AND `User` LIKE ?
AND `Host` = '%'
", ",
) )
.bind(create_user_group_matching_regex(unix_user, group_denylist)) .bind(create_user_group_matching_regex(unix_user, group_denylist))
@@ -124,7 +127,8 @@ pub async fn create_database_users(
_ => {} _ => {}
} }
let result = sqlx::query(format!("CREATE USER {}@'%'", quote_literal(&db_user),).as_str()) let statement = AssertSqlSafe(format!("CREATE USER {}@'%'", quote_literal(&db_user),));
let result = sqlx::query(statement)
.execute(&mut *connection) .execute(&mut *connection)
.await .await
.map(|_| ()) .map(|_| ())
@@ -170,7 +174,8 @@ pub async fn drop_database_users(
_ => {} _ => {}
} }
let result = sqlx::query(format!("DROP USER {}@'%'", quote_literal(&db_user),).as_str()) let statement = AssertSqlSafe(format!("DROP USER {}@'%'", quote_literal(&db_user),));
let result = sqlx::query(statement)
.execute(&mut *connection) .execute(&mut *connection)
.await .await
.map(|_| ()) .map(|_| ())
@@ -203,14 +208,12 @@ pub async fn set_password_for_database_user(
_ => {} _ => {}
} }
let result = sqlx::query( let statement = AssertSqlSafe(format!(
format!(
"ALTER USER {}@'%' IDENTIFIED BY {}", "ALTER USER {}@'%' IDENTIFIED BY {}",
quote_literal(db_user), quote_literal(db_user),
quote_literal(password).as_str(), quote_literal(password).as_str(),
) ));
.as_str(), let result = sqlx::query(statement)
)
.execute(&mut *connection) .execute(&mut *connection)
.await .await
.map(|_| ()) .map(|_| ())
@@ -313,9 +316,11 @@ pub async fn lock_database_users(
} }
} }
let result = sqlx::query( let statement = AssertSqlSafe(format!(
format!("ALTER USER {}@'%' ACCOUNT LOCK", quote_literal(&db_user),).as_str(), "ALTER USER {}@'%' ACCOUNT LOCK",
) quote_literal(&db_user),
));
let result = sqlx::query(statement)
.execute(&mut *connection) .execute(&mut *connection)
.await .await
.map(|_| ()) .map(|_| ())
@@ -373,9 +378,11 @@ pub async fn unlock_database_users(
_ => {} _ => {}
} }
let result = sqlx::query( let statement = AssertSqlSafe(format!(
format!("ALTER USER {}@'%' ACCOUNT UNLOCK", quote_literal(&db_user),).as_str(), "ALTER USER {}@'%' ACCOUNT UNLOCK",
) quote_literal(&db_user),
));
let result = sqlx::query(statement)
.execute(&mut *connection) .execute(&mut *connection)
.await .await
.map(|_| ()) .map(|_| ())
@@ -457,13 +464,14 @@ pub async fn list_database_users(
continue; continue;
} }
let mut result = sqlx::query_as::<_, DatabaseUser>( let statement = AssertSqlSafe(
&(if db_is_mariadb { if db_is_mariadb {
DB_USER_SELECT_STATEMENT_MARIADB.to_string() DB_USER_SELECT_STATEMENT_MARIADB.to_string()
} else { } else {
DB_USER_SELECT_STATEMENT_MYSQL.to_string() DB_USER_SELECT_STATEMENT_MYSQL.to_string()
} + "WHERE `mysql`.`user`.`User` = ?"), } + "WHERE `mysql`.`user`.`User` = ? AND `mysql`.`user`.`Host` = '%'",
) );
let mut result = sqlx::query_as::<_, DatabaseUser>(statement)
.bind(db_user.as_str()) .bind(db_user.as_str())
.fetch_optional(&mut *connection) .fetch_optional(&mut *connection)
.await; .await;
@@ -494,13 +502,14 @@ pub async fn list_all_database_users_for_unix_user(
db_is_mariadb: bool, db_is_mariadb: bool,
group_denylist: &GroupDenylist, group_denylist: &GroupDenylist,
) -> ListAllUsersResponse { ) -> ListAllUsersResponse {
let mut result = sqlx::query_as::<_, DatabaseUser>( let statement = AssertSqlSafe(
&(if db_is_mariadb { if db_is_mariadb {
DB_USER_SELECT_STATEMENT_MARIADB.to_string() DB_USER_SELECT_STATEMENT_MARIADB.to_string()
} else { } else {
DB_USER_SELECT_STATEMENT_MYSQL.to_string() DB_USER_SELECT_STATEMENT_MYSQL.to_string()
} + "WHERE `user`.`User` REGEXP ?"), } + "WHERE `user`.`User` REGEXP ? AND `user`.`Host` = '%'",
) );
let mut result = sqlx::query_as::<_, DatabaseUser>(statement)
.bind(create_user_group_matching_regex(unix_user, group_denylist)) .bind(create_user_group_matching_regex(unix_user, group_denylist))
.fetch_all(&mut *connection) .fetch_all(&mut *connection)
.await .await
@@ -529,20 +538,18 @@ pub async fn set_databases_where_user_has_privileges(
db_user: &mut DatabaseUser, db_user: &mut DatabaseUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
) -> Result<(), sqlx::Error> { ) -> Result<(), sqlx::Error> {
let database_list = sqlx::query( let statement = AssertSqlSafe(formatdoc!(
formatdoc!(
r" r"
SELECT `Db` AS `database` SELECT `Db` AS `database`
FROM `db` FROM `db`
WHERE `User` = ? AND ({}) WHERE `User` = ? AND `Host` = '%' AND ({})
", ",
DATABASE_PRIVILEGE_FIELDS DATABASE_PRIVILEGE_FIELDS
.iter() .iter()
.map(|field| format!("`{field}` = 'Y'")) .map(|field| format!("`{field}` = 'Y'"))
.join(" OR "), .join(" OR "),
) ));
.as_str(), let database_list = sqlx::query(statement)
)
.bind(db_user.user.as_str()) .bind(db_user.user.as_str())
.fetch_all(&mut *connection) .fetch_all(&mut *connection)
.await; .await;
+9 -16
View File
@@ -90,14 +90,12 @@ impl Supervisor {
}; };
let mut watchdog_duration = None; let mut watchdog_duration = None;
let mut watchdog_micro_seconds = 0;
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
let watchdog_task = let watchdog_task =
if systemd_mode && sd_notify::watchdog_enabled(true, &mut watchdog_micro_seconds) { if systemd_mode && let Some(watchdog_duration_) = sd_notify::watchdog_enabled() {
let watchdog_duration_ = Duration::from_micros(watchdog_micro_seconds);
tracing::debug!( tracing::debug!(
"Systemd watchdog enabled with {} millisecond interval", "Systemd watchdog enabled with {} millisecond interval",
watchdog_micro_seconds.div_ceil(1000), watchdog_duration_.as_millis()
); );
watchdog_duration = Some(watchdog_duration_); watchdog_duration = Some(watchdog_duration_);
Some(spawn_watchdog_task(watchdog_duration_)) Some(spawn_watchdog_task(watchdog_duration_))
@@ -295,15 +293,12 @@ impl Supervisor {
pub async fn reload(&self) -> anyhow::Result<()> { pub async fn reload(&self) -> anyhow::Result<()> {
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
sd_notify::notify( sd_notify::notify(&[
false,
&[
sd_notify::NotifyState::Reloading, sd_notify::NotifyState::Reloading,
sd_notify::NotifyState::monotonic_usec_now() sd_notify::NotifyState::monotonic_usec_now()
.expect("Failed to get monotonic time to send to systemd while reloading"), .expect("Failed to get monotonic time to send to systemd while reloading"),
sd_notify::NotifyState::Status("Reloading configuration"), sd_notify::NotifyState::Status("Reloading configuration"),
], ])?;
)?;
let previous_config = self.config.lock().await.clone(); let previous_config = self.config.lock().await.clone();
self.reload_config().await?; self.reload_config().await?;
@@ -340,14 +335,14 @@ impl Supervisor {
} }
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
sd_notify::notify(false, &[sd_notify::NotifyState::Ready])?; sd_notify::notify(&[sd_notify::NotifyState::Ready])?;
Ok(()) Ok(())
} }
pub async fn shutdown(&self) -> anyhow::Result<()> { pub async fn shutdown(&self) -> anyhow::Result<()> {
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
sd_notify::notify(false, &[sd_notify::NotifyState::Stopping])?; sd_notify::notify(&[sd_notify::NotifyState::Stopping])?;
tracing::debug!("Stop accepting new connections"); tracing::debug!("Stop accepting new connections");
self.stop_receiving_new_connections()?; self.stop_receiving_new_connections()?;
@@ -417,7 +412,7 @@ fn spawn_watchdog_task(duration: Duration) -> JoinHandle<()> {
); );
loop { loop {
interval.tick().await; interval.tick().await;
if let Err(err) = sd_notify::notify(false, &[sd_notify::NotifyState::Watchdog]) { if let Err(err) = sd_notify::notify(&[sd_notify::NotifyState::Watchdog]) {
tracing::warn!("Failed to notify systemd watchdog: {}", err); tracing::warn!("Failed to notify systemd watchdog: {}", err);
} }
} }
@@ -440,9 +435,7 @@ fn spawn_status_notifier_task(task_tracker: TaskTracker) -> JoinHandle<()> {
"Waiting for connections".to_string() "Waiting for connections".to_string()
}; };
if let Err(e) = if let Err(e) = sd_notify::notify(&[sd_notify::NotifyState::Status(message.as_str())]) {
sd_notify::notify(false, &[sd_notify::NotifyState::Status(message.as_str())])
{
tracing::warn!("Failed to send systemd status notification: {}", e); tracing::warn!("Failed to send systemd status notification: {}", e);
} }
} }
@@ -557,7 +550,7 @@ async fn listener_task(
group_denylist: Arc<RwLock<GroupDenylist>>, group_denylist: Arc<RwLock<GroupDenylist>>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
sd_notify::notify(false, &[sd_notify::NotifyState::Ready])?; sd_notify::notify(&[sd_notify::NotifyState::Ready])?;
let connection_counter = AtomicU64::new(0); let connection_counter = AtomicU64::new(0);