22 Commits

Author SHA1 Message Date
oysteikt 6404e5011a CHANGELOG.md: add release notes, Cargo.toml: bump version number
Build and test / check-license (push) Successful in 52s
Build and test / check (push) Successful in 1m44s
Build and test / build (push) Successful in 3m1s
Build and test / test (push) Successful in 3m12s
Build and test / docs (push) Successful in 7m24s
2026-04-28 18:14:48 +09:00
oysteikt 531fdfc2e9 Cargo.{toml,lock}: bump deps 2026-04-28 18:14:19 +09:00
oysteikt af74e8e540 .gitea/workflows/publish-deb: temporarily disable ubuntu resolute
Build and test / check-license (push) Successful in 51s
Build and test / check (push) Successful in 1m47s
Build and test / build (push) Successful in 2m44s
Build and test / test (push) Successful in 3m31s
Build and test / docs (push) Successful in 5m31s
2026-04-28 18:02:34 +09:00
oysteikt 4132fb58e8 client/various: sort output
Build and test / check (push) Successful in 1m54s
Build and test / check-license (push) Successful in 2m11s
Build and test / build (push) Successful in 2m47s
Build and test / test (push) Successful in 3m13s
Build and test / docs (push) Successful in 6m12s
2026-04-28 17:58:17 +09:00
oysteikt 40c7a935b3 assets/systemd: always restart service when it dies
Build and test / check-license (push) Successful in 46s
Build and test / check (push) Successful in 1m53s
Build and test / build (push) Successful in 3m20s
Build and test / test (push) Successful in 3m33s
Build and test / docs (push) Successful in 5m41s
2026-04-28 17:34:25 +09:00
oysteikt 5aca2314c4 core/protocol: make ModifyPrivileges response serializable
Build and test / check-license (push) Successful in 46s
Build and test / check (push) Successful in 2m25s
Build and test / build (push) Successful in 3m0s
Build and test / test (push) Successful in 3m34s
Build and test / docs (push) Successful in 5m55s
2026-04-28 17:27:40 +09:00
oysteikt 7a9b233611 core/types: add custom de/serialization for DbOrUser
Build and test / check (push) Successful in 1m46s
Build and test / check-license (push) Successful in 2m5s
Build and test / test (push) Failing after 2m54s
Build and test / build (push) Successful in 3m0s
Build and test / docs (push) Successful in 5m20s
2026-04-28 07:45:46 +09:00
oysteikt 5444ab46ca core/protocol: test de/serialization of all protocol messages
Build and test / check (push) Successful in 1m42s
Build and test / check-license (push) Successful in 57s
Build and test / build (push) Successful in 2m40s
Build and test / test (push) Failing after 3m45s
Build and test / docs (push) Successful in 5m23s
2026-04-28 07:30:08 +09:00
oysteikt b12acbf3b4 .gitea/workflows/publish-deb: remove double file extension 2026-04-28 07:14:07 +09:00
oysteikt 8e2aace9d4 server: specify Host for all relevant sql queries 2026-04-28 07:14:06 +09:00
oysteikt 913aad5758 .gitea/workflows/publish-deb: build for ubuntu resolute
Build and test / check-license (push) Successful in 1m26s
Build and test / check (push) Successful in 2m27s
Build and test / build (push) Successful in 3m33s
Build and test / test (push) Successful in 3m15s
Build and test / docs (push) Successful in 6m4s
2026-04-24 05:14:46 +09:00
oysteikt 1d4a19c299 core/types: better fmt::Display implementation for newtypes
Build and test / check-license (push) Successful in 58s
Build and test / check (push) Successful in 2m0s
Build and test / build (push) Successful in 2m50s
Build and test / test (push) Successful in 3m24s
Build and test / docs (push) Successful in 6m53s
2026-04-15 05:09:49 +09:00
oysteikt 9b279a4956 flake.lock: bump, Cargo.{toml,lock}: update inputs
Build and test / check-license (push) Successful in 1m7s
Build and test / check (push) Successful in 1m47s
Build and test / build (push) Successful in 2m48s
Build and test / test (push) Successful in 3m48s
Build and test / docs (push) Successful in 6m46s
2026-04-02 14:00:30 +09:00
oysteikt 124cf9e69e nix/package: fix license meta field
Build and test / check-license (push) Successful in 55s
Build and test / check (push) Successful in 1m53s
Build and test / build (push) Successful in 3m7s
Build and test / test (push) Successful in 4m55s
Build and test / docs (push) Successful in 5m42s
2026-02-12 11:28:14 +09:00
oysteikt 3fe6a3edea flake.lock: bump, Cargo.{toml,lock}: update inputs
Build and test / check (push) Successful in 1m50s
Build and test / check-license (push) Successful in 2m4s
Build and test / build (push) Successful in 3m8s
Build and test / test (push) Successful in 4m7s
Build and test / docs (push) Successful in 6m14s
2026-01-31 12:22:53 +09:00
oysteikt 65e02192dd scripts/download-and-publish-debs: comment out DEL request
Build and test / check-license (push) Successful in 57s
Build and test / check (push) Successful in 2m2s
Build and test / build (push) Successful in 3m43s
Build and test / test (push) Successful in 3m21s
Build and test / docs (push) Successful in 6m22s
This is not a good thing to do now that we have published a stable
version.
2026-01-14 00:48:51 +09:00
oysteikt b2d9400f0e Cargo.toml: 0.1.0 -> 1.0.0
Build and test / check-license (push) Successful in 1m3s
Build and test / check (push) Successful in 2m3s
Build and test / build (push) Successful in 3m35s
Build and test / test (push) Successful in 3m19s
Build and test / docs (push) Successful in 6m41s
2026-01-14 00:30:18 +09:00
oysteikt 2838c584d3 Cargo.toml: state Programvareverkstedet as author 2026-01-14 00:29:46 +09:00
oysteikt ce75aa509d client: add better error messages on failed server connection
Build and test / check-license (push) Successful in 53s
Build and test / check (push) Successful in 2m28s
Build and test / build (push) Successful in 3m11s
Build and test / test (push) Successful in 3m18s
Build and test / docs (push) Successful in 8m0s
2026-01-12 21:12:18 +09:00
oysteikt 87ef63b680 assets/debian/systemd: remove socket on stop 2026-01-12 21:06:25 +09:00
oysteikt 6686b3bbe7 scripts/download-and-upload-debs: add extra logging
Build and test / check-license (push) Successful in 1m54s
Build and test / check (push) Successful in 2m2s
Build and test / build (push) Successful in 2m51s
Build and test / test (push) Successful in 4m49s
Build and test / docs (push) Successful in 6m25s
2026-01-12 16:50:05 +09:00
oysteikt f75b34f40c server: don't warn on empty/comment only lines in group denylists
Build and test / check-license (push) Successful in 57s
Build and test / docs (push) Has been cancelled
Build and test / build (push) Has been cancelled
Build and test / check (push) Has been cancelled
Build and test / test (push) Has been cancelled
2026-01-12 16:48:58 +09:00
41 changed files with 1312 additions and 664 deletions
+8 -2
View File
@@ -31,7 +31,13 @@ jobs:
build-deb:
strategy:
matrix:
os: [debian-trixie, debian-bookworm, ubuntu-noble, ubuntu-jammy]
os: [
debian-trixie,
debian-bookworm,
# ubuntu-resolute,
ubuntu-noble,
ubuntu-jammy,
]
name: Build and publish for ${{ matrix.os }}
runs-on: ${{ matrix.os }}
steps:
@@ -63,7 +69,7 @@ jobs:
- name: Upload deb package artifact
uses: actions/upload-artifact@v3
with:
name: muscl-deb-${{ matrix.os }}-${{ gitea.sha }}.zip
name: muscl-deb-${{ matrix.os }}-${{ gitea.sha }}
path: target/debian/*.deb
if-no-files-found: error
retention-days: 30
+12
View File
@@ -1,5 +1,17 @@
# Changelog
## v1.0.1
Patch release with some important bug fixes
### Notable changes
- `mysql.db.Host` would usually be unset when creating privileges for users, this should be fixed now.
- You might have to manually set this field for rows created with the previous version of muscl to have those privileges work properly.
- Fixed an issue where a few select server responses would refuse to serialize properly, leading to an error message: "No response from server"
- The output of various commands is now being sorted.
- Bump dependencies
## v1.0.0 - Initial Release
This is the initial release of `muscl`.
Generated
+537 -288
View File
File diff suppressed because it is too large Load Diff
+19 -20
View File
@@ -1,12 +1,11 @@
[package]
name = "muscl"
version = "0.1.0"
version = "1.0.1"
edition = "2024"
resolver = "2"
license = "BSD-3-Clause"
authors = [
"oysteikt@pvv.ntnu.no",
"felixalb@pvv.ntnu.no",
"Programvareverkstedet <projects@pvv.ntnu.no>",
]
homepage = "https://git.pvv.ntnu.no/Projects/muscl"
repository = "https://git.pvv.ntnu.no/Projects/muscl"
@@ -19,50 +18,50 @@ autobins = false
autolib = false
[dependencies]
anyhow = "1.0.100"
anyhow = "1.0.102"
async-bincode = "0.8.0"
bincode = "2.0.1"
clap = { version = "4.5.54", features = ["cargo", "derive"] }
clap = { version = "4.6.1", features = ["cargo", "derive"] }
clap-verbosity-flag = { version = "3.0.4", features = [ "tracing" ] }
clap_complete = { version = "4.5.65", features = ["unstable-dynamic"] }
clap_complete = { version = "4.6.3", features = ["unstable-dynamic"] }
color-print = "0.3.7"
const_format = "0.2.35"
const_format = "0.2.36"
derive_more = { version = "2.1.1", features = ["display", "error"] }
dialoguer = "0.12.0"
futures-util = "0.3.31"
futures-util = "0.3.32"
humansize = "2.1.3"
indoc = "2.0.7"
itertools = "0.14.0"
nix = { version = "0.30.1", features = ["fs", "process", "socket", "user"] }
nix = { version = "0.31.2", features = ["fs", "process", "socket", "user"] }
num_cpus = "1.17.0"
prettytable = "0.10.0"
rand = "0.9.2"
rand = "0.10.1"
serde = "1.0.228"
serde_json = { version = "1.0.149", features = ["preserve_order"] }
sqlx = { version = "0.8.6", features = ["runtime-tokio", "mysql", "tls-rustls"] }
thiserror = "2.0.17"
tokio = { version = "1.49.0", features = ["rt-multi-thread", "macros", "signal"] }
thiserror = "2.0.18"
tokio = { version = "1.52.1", features = ["rt-multi-thread", "macros", "signal"] }
tokio-serde = { version = "0.9.0", features = ["bincode"] }
tokio-stream = "0.1.18"
tokio-util = { version = "0.7.18", features = ["codec", "rt"] }
toml = "0.9.11"
toml = "1.1.2"
tracing = { version = "0.1.44", features = ["log"] }
tracing-subscriber = "0.3.22"
uuid = { version = "1.19.0", features = ["v4"] }
tracing-subscriber = "0.3.23"
uuid = { version = "1.23.1", features = ["v4"] }
[target.'cfg(target_os = "linux")'.dependencies]
landlock = "0.4.4"
sd-notify = "0.4.5"
sd-notify = "0.5.0"
tracing-journald = "0.3.2"
[build-dependencies]
anyhow = "1.0.100"
build-info-build = "0.0.42"
git2 = { version = "0.20.3", default-features = false }
anyhow = "1.0.102"
build-info-build = "0.0.44"
git2 = { version = "0.20.4", default-features = false }
[dev-dependencies]
pretty_assertions = "1.4.1"
regex = "1.12.2"
regex = "1.12.3"
[features]
default = ["mysql-admutils-compatibility"]
+2
View File
@@ -9,6 +9,8 @@ ExecStart=/usr/bin/muscl-server --systemd --disable-landlock socket-activate
ExecReload=/usr/bin/kill -HUP $MAINPID
WatchdogSec=3min
Restart=always
RestartSec=10s
# Although this is a multi-instance unit, the constant `User` field is needed
# for authentication via mysql's auth_socket plugin to work.
+1
View File
@@ -3,6 +3,7 @@ Description=Muscl MySQL admin tool
[Socket]
ListenStream=/run/muscl/muscl.sock
RemoveOnStop=true
Accept=no
PassCredentials=true
@@ -1,53 +0,0 @@
[Unit]
Description=Authorization daemon for Muscl
[Service]
Type=notify
ExecStart=/usr/local/bin/muscl_auth_daemon.py
# WatchdogSec=15
User=muscl
Group=muscl
DynamicUser=yes
; ConfigurationDirectory=muscl
; RuntimeDirectory=muscl
; # This is required to read unix user/group details.
; PrivateUsers=false
; # Needed to communicate with MySQL.
; PrivateNetwork=false
; PrivateIPC=false
; AmbientCapabilities=
; CapabilityBoundingSet=
; DeviceAllow=
; DevicePolicy=closed
; LockPersonality=true
; MemoryDenyWriteExecute=true
; NoNewPrivileges=true
; PrivateDevices=true
; PrivateMounts=true
; PrivateTmp=yes
; ProcSubset=pid
; ProtectClock=true
; ProtectControlGroups=strict
; ProtectHome=true
; ProtectHostname=true
; ProtectKernelLogs=true
; ProtectKernelModules=true
; ProtectKernelTunables=true
; ProtectProc=invisible
; ProtectSystem=strict
; RemoveIPC=true
; RestrictAddressFamilies=AF_UNIX AF_INET AF_INET6
; RestrictNamespaces=true
; RestrictRealtime=true
; RestrictSUIDSGID=true
; SocketBindDeny=any
; SystemCallArchitectures=native
; SystemCallFilter=@system-service
; SystemCallFilter=~@privileged @resources
; UMask=0777
@@ -1,8 +0,0 @@
[Unit]
Description=Authorization daemon for Muscl
WantedBy=sockets.target
[Socket]
ListenStream=/run/muscl/muscl-auth-daemon.socket
Accept=no
SocketMode=0660
@@ -1,84 +0,0 @@
#!/usr/bin/env python3
# TODO: create pool of workers to handle requests concurrently
# the socket should be a listener socket and each worker should accept connections from it
# the socket should accept requests as newline-separated JSON objects
# there should be a watchdog to monitor worker health and restart them if they die
# graceful shutdown should be implemented for the workers
# optional logging of requests and responses
# use systemd notify to signal readiness and amount of connections handled
import json
import os
from socket import AF_UNIX, SOCK_DGRAM, SOCK_STREAM, fromfd, socket
from multiprocessing import Pool
def get_listener_from_systemd() -> socket:
listen_fds = int(os.getenv("LISTEN_FDS", "0"))
listen_pid = int(os.getenv("LISTEN_PID", "0"))
if listen_fds != 1 or listen_pid != os.getpid():
raise RuntimeError("No socket passed from systemd")
assert listen_fds == 1
sock = fromfd(3, AF_UNIX, SOCK_STREAM)
sock.setblocking(False)
return sock
def get_notify_socket_from_systemd() -> socket:
notify_socket_path = os.getenv("NOTIFY_SOCKET")
if not notify_socket_path:
raise RuntimeError("No notify socket path found in environment")
sock = socket(AF_UNIX, SOCK_DGRAM)
sock.connect(notify_socket_path)
return sock
def run_auth_daemon(sock: socket):
sock.listen()
print("Auth daemon is running and listening for connections...")
with Pool() as worker_pool:
with get_notify_socket_from_systemd() as notify_socket:
notify_socket.sendall(b"READY=1\n")
while True:
conn, _ = sock.accept()
worker_pool.apply_async(session_handler, args=(conn,))
def session_handler(sock: socket):
buffer = ""
while True:
data = sock.recv(4096).decode("utf-8")
if not data:
print("Connection closed by client")
break
buffer += data
if buffer.endswith("\n"):
requests = buffer.strip().split("\n")
buffer = ""
for request in requests:
try:
req_json = json.loads(request)
username = req_json.get("username", "")
groups = req_json.get("groups", [])
resource_type = req_json.get("resource_type", "")
resource = req_json.get("resource", "")
allowed = process_request(username, groups, resource_type, resource)
response = {"allowed": allowed}
except json.JSONDecodeError:
response = {"error": "Invalid JSON"}
sock.sendall((json.dumps(response) + "\n").encode("utf-8"))
def process_request(
username: str,
groups: list[str],
resource_type: str,
resource: str,
) -> bool:
...
if __name__ == "__main__":
listener_socket = get_listener_from_systemd()
run_auth_daemon(listener_socket)
Generated
+9 -9
View File
@@ -2,11 +2,11 @@
"nodes": {
"crane": {
"locked": {
"lastModified": 1767744144,
"narHash": "sha256-9/9ntI0D+HbN4G0TrK3KmHbTvwgswz7p8IEJsWyef8Q=",
"lastModified": 1774313767,
"narHash": "sha256-hy0XTQND6avzGEUFrJtYBBpFa/POiiaGBr2vpU6Y9tY=",
"owner": "ipetkov",
"repo": "crane",
"rev": "2fb033290bf6b23f226d4c8b32f7f7a16b043d7e",
"rev": "3d9df76e29656c679c744968b17fbaf28f0e923d",
"type": "github"
},
"original": {
@@ -17,11 +17,11 @@
},
"nixpkgs": {
"locked": {
"lastModified": 1768127708,
"narHash": "sha256-1Sm77VfZh3mU0F5OqKABNLWxOuDeHIlcFjsXeeiPazs=",
"lastModified": 1775036866,
"narHash": "sha256-ZojAnPuCdy657PbTq5V0Y+AHKhZAIwSIT2cb8UgAz/U=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "ffbc9f8cbaacfb331b6017d5a5abb21a492c9a38",
"rev": "6201e203d09599479a3b3450ed24fa81537ebc4e",
"type": "github"
},
"original": {
@@ -45,11 +45,11 @@
]
},
"locked": {
"lastModified": 1768186348,
"narHash": "sha256-nkpIe3zkpeoFuOl8xBpexulECsHLQ9Ljg1gW3bPCjSI=",
"lastModified": 1775099554,
"narHash": "sha256-3xBsGnGDLOFtnPZ1D3j2LU19wpAlYefRKTlkv648rU0=",
"owner": "oxalica",
"repo": "rust-overlay",
"rev": "af69e497567a5945a64057717bc9b17c8478097e",
"rev": "8d6387ed6d8e6e6672fd3ed4b61b59d44b124d99",
"type": "github"
},
"original": {
-1
View File
@@ -105,7 +105,6 @@
fileset = lib.fileset.unions [
(craneLib.fileset.commonCargoSources ./.)
./assets
./examples
];
};
in {
+1 -4
View File
@@ -85,13 +85,10 @@ buildFunction ({
install -Dm644 assets/systemd/muscl.service -t "$out/lib/systemd/system"
substituteInPlace "$out/lib/systemd/system/muscl.service" \
--replace-fail '/usr/bin/muscl-server' "$out/bin/muscl-server"
mkdir -p "$out/share/muscl"
cp -r examples "$out/share/muscl"
'';
meta = with lib; {
license = licenses.mit;
license = licenses.bsd3;
platforms = platforms.linux ++ platforms.darwin;
inherit mainProgram;
};
-92
View File
@@ -27,31 +27,6 @@ in
}.${level};
};
authHandler = lib.mkOption {
type = with lib.types; nullOr lines;
default = null;
description = "Custom authentication handler, written in python";
example = ''
def process_request(
username: str,
groups: list[str],
resource_type: str,
resource: str,
) -> bool:
if resource_type == "database":
if resource.startswith(username) or any(
resource.startswith(group) for group in groups
):
return True
elif resource_type == "user":
if resource.startswith(username) or any(
resource.startswith(group) for group in groups
):
return True
return False
'';
};
settings = lib.mkOption {
default = { };
type = lib.types.submodule {
@@ -216,72 +191,5 @@ in
++ (lib.optionals (cfg.settings.mysql.host != null) [ "AF_INET" "AF_INET6" ]);
};
};
systemd.sockets."muscl-auth-daemon" = lib.mkIf (cfg.authHandler != null) {
description = "Authorization daemon for Muscl";
wantedBy = [ "sockets.target" ];
socketConfig = {
ListenStream = "/run/muscl/muscl-auth-daemon.sock";
Accept = "no";
};
};
systemd.services."muscl-auth-daemon" = lib.mkIf (cfg.authHandler != null) {
description = "Authorization daemon for Muscl";
requires = [ "muscl-auth-daemon.socket" ];
serviceConfig = {
Type = "notify";
ExecStart = let
authScript = lib.pipe ../examples/auth_daemon_python/muscl_auth_daemon.py [
lib.fileContents
(lib.replaceString ''
def process_request(
username: str,
groups: list[str],
resource_type: str,
resource: str,
) -> bool:
...
'' cfg.authHandler)
(pkgs.writers.writePyPy3Bin "muscl-auth-handler.py" { })
];
in lib.getExe authScript;
User = "muscl-auth-daemon";
Group = "muscl-auth-daemon";
DynamicUser = true;
AmbientCapabilities = [ "" ];
CapabilityBoundingSet = [ "" ];
DeviceAllow = [ "" ];
LockPersonality = true;
NoNewPrivileges = true;
PrivateDevices = true;
PrivateMounts = true;
PrivateTmp = "yes";
ProcSubset = "pid";
ProtectClock = true;
ProtectControlGroups = "strict";
ProtectHome = true;
ProtectHostname = true;
ProtectKernelLogs = true;
ProtectKernelModules = true;
ProtectKernelTunables = true;
ProtectProc = "invisible";
ProtectSystem = "strict";
RemoveIPC = true;
UMask = "0777";
RestrictNamespaces = true;
RestrictRealtime = true;
RestrictSUIDSGID = true;
SystemCallArchitectures = "native";
SocketBindDeny = [ "any" ];
SystemCallFilter = [
"@system-service"
"~@privileged"
"~@resources"
];
};
};
};
}
-19
View File
@@ -56,25 +56,6 @@ nixpkgs.lib.nixosSystem {
enable = true;
logLevel = "trace";
createLocalDatabaseUser = true;
authHandler = ''
def process_request(
username: str,
groups: list[str],
resource_type: str,
resource: str,
) -> bool:
if resource_type == "database":
if resource.startswith(username) or any(
resource.startswith(group) for group in groups
):
return True
elif resource_type == "user":
if resource.startswith(username) or any(
resource.startswith(group) for group in groups
):
return True
return False
'';
};
programs.vim = {
+15 -5
View File
@@ -42,7 +42,15 @@ declare -r GIT_SHA="$2"
TMPDIR="$(mktemp -d)"
for variant in debian-bookworm debian-trixie ubuntu-jammy ubuntu-noble; do
declare -a OS_VARIANTS=(
"debian-bookworm"
"debian-trixie"
"ubuntu-jammy"
"ubuntu-noble"
# "ubuntu-resolute"
)
for variant in "${OS_VARIANTS[@]}"; do
echo "Downloading and uploading debs for variant: $variant"
curl "https://git.pvv.ntnu.no/Projects/muscl/actions/runs/$RUN_NUMBER/artifacts/muscl-deb-$variant-$GIT_SHA.zip" --output "$TMPDIR/muscl-deb-$variant-$GIT_SHA.zip"
@@ -54,11 +62,13 @@ for variant in debian-bookworm debian-trixie ubuntu-jammy ubuntu-noble; do
DEB_VERSION=$(find "$TMPDIR/muscl-deb-$variant-$GIT_SHA"/*.deb -print0 | xargs -0 -n1 basename | cut -d'_' -f2 | head -n1)
DEB_ARCH=$(find "$TMPDIR/muscl-deb-$variant-$GIT_SHA"/*.deb -print0 | xargs -0 -n1 basename | cut -d'_' -f3 | cut -d'.' -f1 | head -n1)
curl \
-X DELETE \
--user "$GITEA_USER:$GITEA_TOKEN" \
"https://git.pvv.ntnu.no/api/packages/Projects/debian/pool/$DISTRO_VERSION_NAME/main/$DEB_NAME/$DEB_VERSION/$DEB_ARCH"
# echo "[DELETE] https://git.pvv.ntnu.no/api/packages/Projects/debian/pool/$DISTRO_VERSION_NAME/main/$DEB_NAME/$DEB_VERSION/$DEB_ARCH"
# curl \
# -X DELETE \
# --user "$GITEA_USER:$GITEA_TOKEN" \
# "https://git.pvv.ntnu.no/api/packages/Projects/debian/pool/$DISTRO_VERSION_NAME/main/$DEB_NAME/$DEB_VERSION/$DEB_ARCH"
echo "[PUT] https://git.pvv.ntnu.no/api/packages/Projects/debian/pool/$DISTRO_VERSION_NAME/main/upload"
curl \
-X PUT \
--user "$GITEA_USER:$GITEA_TOKEN" \
+2 -1
View File
@@ -319,7 +319,8 @@ fn main() -> anyhow::Result<()> {
#[cfg(not(feature = "suid-sgid-mode"))]
None,
args.verbose,
)?;
)
.context("Failed to connect to the server")?;
tokio_run_command(args.command, connection)?;
+8 -3
View File
@@ -8,6 +8,7 @@ use clap::{Args, Parser};
use clap_complete::ArgValueCompleter;
use dialoguer::{Confirm, Editor};
use futures_util::SinkExt;
use itertools::Itertools;
use nix::unistd::{User, getuid};
use tokio_stream::StreamExt;
@@ -203,9 +204,13 @@ pub async fn edit_database_privileges(
}
})
.flatten()
.sorted_by_key(|row| (row.db.clone(), row.user.clone()))
.collect::<Vec<_>>(),
Some(Ok(Response::ListAllPrivileges(privilege_rows))) => match privilege_rows {
Ok(list) => list,
Ok(list) => list
.into_iter()
.sorted_by_key(|row| (row.db.clone(), row.user.clone()))
.collect(),
Err(err) => {
server_connection.send(Request::Exit).await?;
return Err(anyhow::anyhow!(err.to_error_message())
@@ -305,7 +310,7 @@ pub async fn edit_database_privileges(
print_modify_database_privileges_output_status(&result);
if result.iter().any(|(_, res)| {
if result.values().flatten().any(|(_, res)| {
matches!(
res,
Err(ModifyDatabasePrivilegesError::UserValidationError(
@@ -320,7 +325,7 @@ pub async fn edit_database_privileges(
server_connection.send(Request::Exit).await?;
if result.values().any(std::result::Result::is_err) {
if result.values().flatten().any(|(_, res)| res.is_err()) {
std::process::exit(1);
}
+29 -3
View File
@@ -1,5 +1,6 @@
use std::{
fs,
os::unix::fs::FileTypeExt,
path::{Path, PathBuf},
sync::Arc,
time::Duration,
@@ -7,7 +8,10 @@ use std::{
use anyhow::{Context, anyhow};
use clap_verbosity_flag::{InfoLevel, Verbosity};
use nix::libc::{EXIT_SUCCESS, exit};
use nix::{
libc::{EXIT_SUCCESS, exit},
unistd::{AccessFlags, access},
};
use sqlx::mysql::MySqlPoolOptions;
use std::os::unix::net::UnixStream as StdUnixStream;
use tokio::{net::UnixStream as TokioUnixStream, sync::RwLock};
@@ -130,11 +134,28 @@ pub fn bootstrap_server_connection_and_drop_privileges(
}
}
fn socket_path_is_ok(path: &Path) -> anyhow::Result<()> {
fs::metadata(path)
.context(format!("Failed to get metadata for {:?}", path))
.and_then(|meta| {
if !meta.file_type().is_socket() {
anyhow::bail!("{:?} is not a unix socket", path);
}
access(path, AccessFlags::R_OK | AccessFlags::W_OK)
.with_context(|| format!("Socket at {:?} is not readable/writable", path))?;
Ok(())
})
}
fn connect_to_external_server(
server_socket_path: Option<PathBuf>,
) -> anyhow::Result<StdUnixStream> {
// TODO: ensure this is both readable and writable
if let Some(socket_path) = server_socket_path {
tracing::trace!("Checking socket at {:?}", socket_path);
socket_path_is_ok(&socket_path)?;
tracing::debug!("Connecting to socket at {:?}", socket_path);
return match StdUnixStream::connect(socket_path) {
Ok(socket) => Ok(socket),
@@ -147,6 +168,9 @@ fn connect_to_external_server(
}
if fs::metadata(DEFAULT_SOCKET_PATH).is_ok() {
tracing::trace!("Checking socket at {:?}", DEFAULT_SOCKET_PATH);
socket_path_is_ok(Path::new(DEFAULT_SOCKET_PATH))?;
tracing::debug!("Connecting to default socket at {:?}", DEFAULT_SOCKET_PATH);
return match StdUnixStream::connect(DEFAULT_SOCKET_PATH) {
Ok(socket) => Ok(socket),
@@ -158,7 +182,9 @@ fn connect_to_external_server(
};
}
anyhow::bail!("No socket path provided, and no default socket found");
anyhow::bail!(
"No socket path provided, and no socket found found at default location {DEFAULT_SOCKET_PATH}"
);
}
// TODO: this function is security critical, it should be integration tested
+1 -1
View File
@@ -29,7 +29,7 @@ pub const DATABASE_PRIVILEGE_FIELDS: [&str; 13] = [
// doesn't have any natural implementation semantics.
/// Representation of the set of privileges for a single user on a single database.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord, Default)]
pub struct DatabasePrivilegeRow {
// TODO: don't store the db and user here, let the type be stored in a mapping
pub db: MySQLDatabase,
+6 -3
View File
@@ -324,9 +324,12 @@ impl Response {
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
}
Response::ListAllPrivileges(res) => ResponseOkStatus::from_bool(res.is_ok()),
Response::ModifyPrivileges(res) => {
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
}
Response::ModifyPrivileges(res) => ResponseOkStatus::from_counts(
res.len(),
res.values()
.map(|user_map| user_map.values().filter(|v| v.is_ok()).count())
.sum(),
),
Response::CreateUsers(res) => {
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
@@ -67,3 +67,43 @@ impl CheckAuthorizationError {
self.0.error_type()
}
}
#[cfg(test)]
mod tests {
use crate::core::protocol::request_validation::NameValidationError;
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request: CheckAuthorizationRequest = vec![
DbOrUser::Database("test_db".into()),
DbOrUser::User("test_user".into()),
];
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: CheckAuthorizationRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response: CheckAuthorizationResponse = BTreeMap::from([
(DbOrUser::Database("test_db".into()), Ok(())),
(
DbOrUser::User("test_user".into()),
Err(CheckAuthorizationError(
ValidationError::NameValidationError(NameValidationError::TooLong),
)),
),
]);
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: CheckAuthorizationResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
}
@@ -87,3 +87,41 @@ impl CreateDatabaseError {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request: CreateDatabasesRequest =
vec!["test_db1".into(), "test_db2".into(), "test_db3".into()];
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: CreateDatabasesRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response: CreateDatabasesResponse = BTreeMap::from([
("test_db1".into(), Ok(())),
(
"test_db2".into(),
Err(CreateDatabaseError::DatabaseAlreadyExists),
),
(
"test_db3".into(),
Err(CreateDatabaseError::MySqlError("Some MySQL error".into())),
),
]);
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: CreateDatabasesResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
}
@@ -87,3 +87,37 @@ impl CreateUserError {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request: CreateUsersRequest = vec!["alice".into(), "bob".into(), "charlie".into()];
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: CreateUsersRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response: CreateUsersResponse = BTreeMap::from([
("alice".into(), Ok(())),
("bob".into(), Err(CreateUserError::UserAlreadyExists)),
(
"charlie".into(),
Err(CreateUserError::MySqlError("Some MySQL error".into())),
),
]);
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: CreateUsersResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
}
@@ -90,3 +90,37 @@ impl DropDatabaseError {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request: DropDatabasesRequest = vec!["db1".into(), "db2".into(), "db3".into()];
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: DropDatabasesRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response: DropDatabasesResponse = BTreeMap::from([
("db1".into(), Ok(())),
("db2".into(), Err(DropDatabaseError::DatabaseDoesNotExist)),
(
"db3".into(),
Err(DropDatabaseError::MySqlError("Some MySQL error".into())),
),
]);
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: DropDatabasesResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
}
+34
View File
@@ -87,3 +87,37 @@ impl DropUserError {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request: DropUsersRequest = vec!["alice".into(), "bob".into(), "charlie".into()];
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: DropUsersRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response: DropUsersResponse = BTreeMap::from([
("alice".into(), Ok(())),
("bob".into(), Err(DropUserError::UserDoesNotExist)),
(
"charlie".into(),
Err(DropUserError::MySqlError("Some MySQL error".into())),
),
]);
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: DropUsersResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
}
@@ -27,3 +27,36 @@ impl ListAllDatabasesError {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_response() {
let response: ListAllDatabasesResponse = Ok(vec![
DatabaseRow {
database: "db1".into(),
tables: vec!["table1".into(), "table2".into()],
users: vec!["user1".into(), "user2".into()],
collation: Some("utf8mb4_general_ci".into()),
character_set: Some("utf8mb4".into()),
size_bytes: 1024,
},
DatabaseRow {
database: "db2".into(),
tables: vec!["table3".into(), "table4".into()],
users: vec!["user3".into(), "user4".into()],
collation: Some("utf8mb4_general_ci".into()),
character_set: Some("utf8mb4".into()),
size_bytes: 2048,
},
]);
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: ListAllDatabasesResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
}
@@ -27,3 +27,34 @@ impl ListAllPrivilegesError {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_response() {
let response: ListAllPrivilegesResponse = Ok(vec![
DatabasePrivilegeRow {
user: "user1".into(),
db: "db1".into(),
select_priv: true,
insert_priv: false,
..Default::default()
},
DatabasePrivilegeRow {
user: "user2".into(),
db: "db2".into(),
select_priv: false,
insert_priv: true,
..Default::default()
},
]);
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: ListAllPrivilegesResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
}
@@ -27,3 +27,37 @@ impl ListAllUsersError {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_response() {
let response: ListAllUsersResponse = Ok(vec![
DatabaseUser {
user: "user1".into(),
host: "%".into(),
has_password: true,
is_locked: false,
databases: vec!["db1".into(), "db2".into()],
},
DatabaseUser {
user: "user2".into(),
host: "%".into(),
has_password: false,
is_locked: true,
databases: vec!["db3".into()],
},
]);
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let mut deserialized: ListAllUsersResponse = serde_json::from_str(&json).unwrap();
deserialized.as_mut().unwrap()[0].host = "%".into();
deserialized.as_mut().unwrap()[1].host = "%".into();
assert_eq!(response, deserialized);
}
}
+42 -1
View File
@@ -61,7 +61,7 @@ pub fn print_list_databases_output_status(
"Size"
}
]);
for db in final_database_list {
for db in final_database_list.iter().sorted_by_key(|db| &db.database) {
table.add_row(row![
db.database,
db.tables.join("\n"),
@@ -137,3 +137,44 @@ impl ListDatabasesError {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request = Some(vec!["db1".into(), "db2".into()]);
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: ListDatabasesRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response: ListDatabasesResponse = vec![
(
"db1".into(),
Ok(DatabaseRow {
database: "db1".into(),
tables: vec!["table1".to_string(), "table2".to_string()],
users: vec!["user1".into(), "user2".into()],
collation: Some("utf8mb4_general_ci".to_string()),
character_set: Some("utf8mb4".to_string()),
size_bytes: 1024,
}),
),
("db2".into(), Err(ListDatabasesError::DatabaseDoesNotExist)),
]
.into_iter()
.collect();
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: ListDatabasesResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
}
+62 -18
View File
@@ -64,25 +64,28 @@ pub fn print_list_privileges_output_status(output: &ListPrivilegesResponse, long
.collect(),
));
for (_database, rows) in final_privs_map {
for row in &rows {
table.add_row(row![
row.db,
row.user,
c->yn(row.select_priv),
c->yn(row.insert_priv),
c->yn(row.update_priv),
c->yn(row.delete_priv),
c->yn(row.create_priv),
c->yn(row.drop_priv),
c->yn(row.alter_priv),
c->yn(row.index_priv),
c->yn(row.create_tmp_table_priv),
c->yn(row.lock_tables_priv),
c->yn(row.references_priv),
]);
}
for row in final_privs_map
.values()
.flatten()
.sorted_by_key(|row| (&row.db, &row.user))
{
table.add_row(row![
row.db,
row.user,
c->yn(row.select_priv),
c->yn(row.insert_priv),
c->yn(row.update_priv),
c->yn(row.delete_priv),
c->yn(row.create_priv),
c->yn(row.drop_priv),
c->yn(row.alter_priv),
c->yn(row.index_priv),
c->yn(row.create_tmp_table_priv),
c->yn(row.lock_tables_priv),
c->yn(row.references_priv),
]);
}
// }
table.printstd();
}
@@ -153,3 +156,44 @@ impl ListPrivilegesError {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request: ListPrivilegesRequest = Some(vec!["test_db1".into(), "test_db2".into()]);
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: ListPrivilegesRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response: ListPrivilegesResponse = BTreeMap::from([
(
"test_db1".into(),
Ok(vec![DatabasePrivilegeRow {
db: "test_db1".into(),
user: "user1".into(),
select_priv: true,
insert_priv: false,
..Default::default()
}]),
),
(
"test_db2".into(),
Err(ListPrivilegesError::DatabaseDoesNotExist),
),
]);
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: ListPrivilegesResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
}
+47 -1
View File
@@ -1,5 +1,6 @@
use std::collections::BTreeMap;
use itertools::Itertools;
use prettytable::Table;
use serde::{Deserialize, Serialize};
use serde_json::json;
@@ -51,7 +52,7 @@ pub fn print_list_users_output_status(output: &ListUsersResponse) {
"Locked",
"Databases where user has privileges"
]);
for user in final_user_list {
for user in final_user_list.iter().sorted_by_key(|user| &user.user) {
table.add_row(row![
user.user,
user.has_password,
@@ -121,3 +122,48 @@ impl ListUsersError {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request: ListUsersRequest = Some(vec!["test_user1".into(), "test_user2".into()]);
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: ListUsersRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response_ok: ListUsersResponse = BTreeMap::from([
(
"test_user1".into(),
Ok(DatabaseUser {
user: "test_user1".into(),
host: "%".into(),
has_password: true,
is_locked: false,
databases: vec!["db1".into(), "db2".into()],
}),
),
("test_user2".into(), Err(ListUsersError::UserDoesNotExist)),
]);
let json = serde_json::to_string_pretty(&response_ok).unwrap();
println!("Serialized response:\n{}", json);
let mut deserialized: ListUsersResponse = serde_json::from_str(&json).unwrap();
deserialized
.get_mut(&"test_user1".into())
.unwrap()
.as_mut()
.unwrap()
.host = "%".into();
assert_eq!(response_ok, deserialized);
}
}
+30
View File
@@ -94,3 +94,33 @@ impl LockUserError {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request: LockUsersRequest = vec!["test_user1".into(), "test_user2".into()];
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: LockUsersRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response_ok: LockUsersResponse = BTreeMap::from([
("test_user1".into(), Ok(())),
("test_user2".into(), Err(LockUserError::UserDoesNotExist)),
]);
let json = serde_json::to_string_pretty(&response_ok).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: LockUsersResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response_ok, deserialized);
}
}
@@ -12,7 +12,7 @@ use crate::core::{
pub type ModifyPrivilegesRequest = BTreeSet<DatabasePrivilegesDiff>;
pub type ModifyPrivilegesResponse =
BTreeMap<(MySQLDatabase, MySQLUser), Result<(), ModifyDatabasePrivilegesError>>;
BTreeMap<MySQLDatabase, BTreeMap<MySQLUser, Result<(), ModifyDatabasePrivilegesError>>>;
#[derive(Error, Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ModifyDatabasePrivilegesError {
@@ -49,7 +49,11 @@ pub enum DiffDoesNotApplyError {
}
pub fn print_modify_database_privileges_output_status(output: &ModifyPrivilegesResponse) {
for ((database_name, username), result) in output {
for ((database_name, username), result) in output.iter().flat_map(|(db, user_map)| {
user_map
.iter()
.map(move |(user, result)| ((db, user), result))
}) {
match result {
Ok(()) => {
println!(
@@ -144,3 +148,46 @@ impl DiffDoesNotApplyError {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::core::*;
#[test]
fn test_serialize_deserialize_request() {
let request =
BTreeSet::from([DatabasePrivilegesDiff::Modified(DatabasePrivilegeRowDiff {
db: "test_db".into(),
user: "test_user".into(),
select_priv: Some(database_privileges::DatabasePrivilegeChange::NoToYes),
..Default::default()
})]);
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: ModifyPrivilegesRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response: ModifyPrivilegesResponse = BTreeMap::from([(
"test_db".into(),
BTreeMap::from([
("test_user".into(), Ok(())),
(
"invalid_user".into(),
Err(ModifyDatabasePrivilegesError::UserDoesNotExist),
),
]),
)]);
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: ModifyPrivilegesResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
}
+32
View File
@@ -60,3 +60,35 @@ impl SetPasswordError {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request: SetUserPasswordRequest = ("test_user".into(), "new_password".into());
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: SetUserPasswordRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response_ok: SetUserPasswordResponse = Ok(());
let response_err: SetUserPasswordResponse = Err(SetPasswordError::UserDoesNotExist);
let json_ok = serde_json::to_string_pretty(&response_ok).unwrap();
let json_err = serde_json::to_string_pretty(&response_err).unwrap();
println!("Serialized OK response:\n{}", json_ok);
println!("Serialized Error response:\n{}", json_err);
let deserialized_ok: SetUserPasswordResponse = serde_json::from_str(&json_ok).unwrap();
let deserialized_err: SetUserPasswordResponse = serde_json::from_str(&json_err).unwrap();
assert_eq!(response_ok, deserialized_ok);
assert_eq!(response_err, deserialized_err);
}
}
@@ -94,3 +94,33 @@ impl UnlockUserError {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request: UnlockUsersRequest = vec!["test_user1".into(), "test_user2".into()];
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: UnlockUsersRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response_ok: UnlockUsersResponse = BTreeMap::from([
("test_user1".into(), Ok(())),
("test_user2".into(), Err(UnlockUserError::UserDoesNotExist)),
]);
let json = serde_json::to_string_pretty(&response_ok).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: UnlockUsersResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response_ok, deserialized);
}
}
+34 -3
View File
@@ -34,7 +34,7 @@ impl DerefMut for MySQLUser {
impl fmt::Display for MySQLUser {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{:<width$}", self.0, width = f.width().unwrap_or(0))
self.0.fmt(f)
}
}
@@ -83,7 +83,7 @@ impl DerefMut for MySQLDatabase {
impl fmt::Display for MySQLDatabase {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{:<width$}", self.0, width = f.width().unwrap_or(0))
self.0.fmt(f)
}
}
@@ -105,12 +105,43 @@ impl From<MySQLDatabase> for OsString {
}
}
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum DbOrUser {
Database(MySQLDatabase),
User(MySQLUser),
}
impl Serialize for DbOrUser {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
match self {
DbOrUser::Database(db) => ("d:".to_string() + &db.to_string()).serialize(serializer),
DbOrUser::User(user) => ("u:".to_string() + &user.to_string()).serialize(serializer),
}
}
}
impl<'de> Deserialize<'de> for DbOrUser {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
if let Some(rest) = s.strip_prefix("d:") {
Ok(DbOrUser::Database(MySQLDatabase(rest.to_string())))
} else if let Some(rest) = s.strip_prefix("u:") {
Ok(DbOrUser::User(MySQLUser(rest.to_string())))
} else {
Err(serde::de::Error::custom(format!(
"Invalid DbOrUser format: {}",
s
)))
}
}
}
impl DbOrUser {
#[must_use]
pub fn lowercased_noun(&self) -> &'static str {
+4
View File
@@ -59,6 +59,10 @@ fn parse_group_denylist(denylist_path: &Path, lines: Lines) -> GroupDenylist {
}
.trim();
if trimmed_line.is_empty() {
continue;
}
let parts: Vec<&str> = trimmed_line.splitn(2, ':').collect();
if parts.len() != 2 {
tracing::warn!(
+36 -20
View File
@@ -114,7 +114,7 @@ pub async fn unsafe_get_database_privileges_for_db_user_pair(
connection: &mut MySqlConnection,
) -> Result<Option<DatabasePrivilegeRow>, sqlx::Error> {
let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
"SELECT {} FROM `db` WHERE `Db` = ? AND `User` = ?",
"SELECT {} FROM `db` WHERE `Db` = ? AND `User` = ? AND `Host` = '%'",
DATABASE_PRIVILEGE_FIELDS
.iter()
.map(|field| quote_identifier(field))
@@ -234,11 +234,12 @@ async fn unsafe_apply_privilege_diff(
DatabasePrivilegesDiff::New(p) => {
let tables = DATABASE_PRIVILEGE_FIELDS
.iter()
.chain(&["Host"])
.map(|field| quote_identifier(field))
.join(",");
let question_marks =
std::iter::repeat_n("?", DATABASE_PRIVILEGE_FIELDS.len()).join(",");
std::iter::repeat_n("?", DATABASE_PRIVILEGE_FIELDS.len() + 1).join(",");
sqlx::query(format!("INSERT INTO `db` ({tables}) VALUES ({question_marks})").as_str())
.bind(p.db.to_string())
@@ -254,6 +255,7 @@ async fn unsafe_apply_privilege_diff(
.bind(yn(p.create_tmp_table_priv))
.bind(yn(p.lock_tables_priv))
.bind(yn(p.references_priv))
.bind("%")
.execute(connection)
.await
.map(|_| ())
@@ -278,28 +280,33 @@ async fn unsafe_apply_privilege_diff(
}
}
sqlx::query(format!("UPDATE `db` SET {changes} WHERE `Db` = ? AND `User` = ?").as_str())
.bind(p.select_priv.map(change_to_yn))
.bind(p.insert_priv.map(change_to_yn))
.bind(p.update_priv.map(change_to_yn))
.bind(p.delete_priv.map(change_to_yn))
.bind(p.create_priv.map(change_to_yn))
.bind(p.drop_priv.map(change_to_yn))
.bind(p.alter_priv.map(change_to_yn))
.bind(p.index_priv.map(change_to_yn))
.bind(p.create_tmp_table_priv.map(change_to_yn))
.bind(p.lock_tables_priv.map(change_to_yn))
.bind(p.references_priv.map(change_to_yn))
.bind(p.db.to_string())
.bind(p.user.to_string())
.execute(connection)
.await
.map(|_| ())
sqlx::query(
format!("UPDATE `db` SET {changes} WHERE `Db` = ? AND `User` = ? AND `Host` = ?")
.as_str(),
)
.bind(p.select_priv.map(change_to_yn))
.bind(p.insert_priv.map(change_to_yn))
.bind(p.update_priv.map(change_to_yn))
.bind(p.delete_priv.map(change_to_yn))
.bind(p.create_priv.map(change_to_yn))
.bind(p.drop_priv.map(change_to_yn))
.bind(p.alter_priv.map(change_to_yn))
.bind(p.index_priv.map(change_to_yn))
.bind(p.create_tmp_table_priv.map(change_to_yn))
.bind(p.lock_tables_priv.map(change_to_yn))
.bind(p.references_priv.map(change_to_yn))
.bind(p.db.to_string())
.bind(p.user.to_string())
.bind("%")
.execute(connection)
.await
.map(|_| ())
}
DatabasePrivilegesDiff::Deleted(p) => {
sqlx::query("DELETE FROM `db` WHERE `Db` = ? AND `User` = ?")
sqlx::query("DELETE FROM `db` WHERE `Db` = ? AND `User` = ? AND `Host` = ?")
.bind(p.db.to_string())
.bind(p.user.to_string())
.bind("%")
.execute(connection)
.await
.map(|_| ())
@@ -481,4 +488,13 @@ pub async fn apply_privilege_diffs(
}
results
.into_iter()
.map(|((k1, k2), v)| (k1, (k2, v)))
.into_group_map()
.into_iter()
.map(|(k1, pairs)| {
let inner = pairs.into_iter().collect::<BTreeMap<_, _>>();
(k1, inner)
})
.collect()
}
+5 -3
View File
@@ -39,6 +39,7 @@ pub(super) async fn unsafe_user_exists(
SELECT 1
FROM `mysql`.`user`
WHERE `User` = ?
AND `Host` = '%'
)
",
)
@@ -67,6 +68,7 @@ pub async fn complete_user_name(
FROM `mysql`.`user`
WHERE `User` REGEXP ?
AND `User` LIKE ?
AND `Host` = '%'
",
)
.bind(create_user_group_matching_regex(unix_user, group_denylist))
@@ -462,7 +464,7 @@ pub async fn list_database_users(
DB_USER_SELECT_STATEMENT_MARIADB.to_string()
} else {
DB_USER_SELECT_STATEMENT_MYSQL.to_string()
} + "WHERE `mysql`.`user`.`User` = ?"),
} + "WHERE `mysql`.`user`.`User` = ? AND `mysql`.`user`.`Host` = '%'"),
)
.bind(db_user.as_str())
.fetch_optional(&mut *connection)
@@ -499,7 +501,7 @@ pub async fn list_all_database_users_for_unix_user(
DB_USER_SELECT_STATEMENT_MARIADB.to_string()
} else {
DB_USER_SELECT_STATEMENT_MYSQL.to_string()
} + "WHERE `user`.`User` REGEXP ?"),
} + "WHERE `user`.`User` REGEXP ? AND `user`.`Host` = '%'"),
)
.bind(create_user_group_matching_regex(unix_user, group_denylist))
.fetch_all(&mut *connection)
@@ -534,7 +536,7 @@ pub async fn set_databases_where_user_has_privileges(
r"
SELECT `Db` AS `database`
FROM `db`
WHERE `User` = ? AND ({})
WHERE `User` = ? AND `Host` = '%' AND ({})
",
DATABASE_PRIVILEGE_FIELDS
.iter()
+13 -20
View File
@@ -90,14 +90,12 @@ impl Supervisor {
};
let mut watchdog_duration = None;
let mut watchdog_micro_seconds = 0;
#[cfg(target_os = "linux")]
let watchdog_task =
if systemd_mode && sd_notify::watchdog_enabled(true, &mut watchdog_micro_seconds) {
let watchdog_duration_ = Duration::from_micros(watchdog_micro_seconds);
if systemd_mode && let Some(watchdog_duration_) = sd_notify::watchdog_enabled() {
tracing::debug!(
"Systemd watchdog enabled with {} millisecond interval",
watchdog_micro_seconds.div_ceil(1000),
watchdog_duration_.as_millis()
);
watchdog_duration = Some(watchdog_duration_);
Some(spawn_watchdog_task(watchdog_duration_))
@@ -295,15 +293,12 @@ impl Supervisor {
pub async fn reload(&self) -> anyhow::Result<()> {
#[cfg(target_os = "linux")]
sd_notify::notify(
false,
&[
sd_notify::NotifyState::Reloading,
sd_notify::NotifyState::monotonic_usec_now()
.expect("Failed to get monotonic time to send to systemd while reloading"),
sd_notify::NotifyState::Status("Reloading configuration"),
],
)?;
sd_notify::notify(&[
sd_notify::NotifyState::Reloading,
sd_notify::NotifyState::monotonic_usec_now()
.expect("Failed to get monotonic time to send to systemd while reloading"),
sd_notify::NotifyState::Status("Reloading configuration"),
])?;
let previous_config = self.config.lock().await.clone();
self.reload_config().await?;
@@ -340,14 +335,14 @@ impl Supervisor {
}
#[cfg(target_os = "linux")]
sd_notify::notify(false, &[sd_notify::NotifyState::Ready])?;
sd_notify::notify(&[sd_notify::NotifyState::Ready])?;
Ok(())
}
pub async fn shutdown(&self) -> anyhow::Result<()> {
#[cfg(target_os = "linux")]
sd_notify::notify(false, &[sd_notify::NotifyState::Stopping])?;
sd_notify::notify(&[sd_notify::NotifyState::Stopping])?;
tracing::debug!("Stop accepting new connections");
self.stop_receiving_new_connections()?;
@@ -417,7 +412,7 @@ fn spawn_watchdog_task(duration: Duration) -> JoinHandle<()> {
);
loop {
interval.tick().await;
if let Err(err) = sd_notify::notify(false, &[sd_notify::NotifyState::Watchdog]) {
if let Err(err) = sd_notify::notify(&[sd_notify::NotifyState::Watchdog]) {
tracing::warn!("Failed to notify systemd watchdog: {}", err);
}
}
@@ -440,9 +435,7 @@ fn spawn_status_notifier_task(task_tracker: TaskTracker) -> JoinHandle<()> {
"Waiting for connections".to_string()
};
if let Err(e) =
sd_notify::notify(false, &[sd_notify::NotifyState::Status(message.as_str())])
{
if let Err(e) = sd_notify::notify(&[sd_notify::NotifyState::Status(message.as_str())]) {
tracing::warn!("Failed to send systemd status notification: {}", e);
}
}
@@ -557,7 +550,7 @@ async fn listener_task(
group_denylist: Arc<RwLock<GroupDenylist>>,
) -> anyhow::Result<()> {
#[cfg(target_os = "linux")]
sd_notify::notify(false, &[sd_notify::NotifyState::Ready])?;
sd_notify::notify(&[sd_notify::NotifyState::Ready])?;
let connection_counter = AtomicU64::new(0);