1 Commits

Author SHA1 Message Date
oysteikt 7f45c49a79 WIP 2026-01-12 16:32:30 +09:00
46 changed files with 1161 additions and 1775 deletions
+2 -8
View File
@@ -31,13 +31,7 @@ jobs:
build-deb: build-deb:
strategy: strategy:
matrix: matrix:
os: [ os: [debian-trixie, debian-bookworm, ubuntu-noble, ubuntu-jammy]
debian-trixie,
debian-bookworm,
# ubuntu-resolute,
ubuntu-noble,
ubuntu-jammy,
]
name: Build and publish for ${{ matrix.os }} name: Build and publish for ${{ matrix.os }}
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
steps: steps:
@@ -69,7 +63,7 @@ jobs:
- name: Upload deb package artifact - name: Upload deb package artifact
uses: actions/upload-artifact@v3 uses: actions/upload-artifact@v3
with: with:
name: muscl-deb-${{ matrix.os }}-${{ gitea.sha }} name: muscl-deb-${{ matrix.os }}-${{ gitea.sha }}.zip
path: target/debian/*.deb path: target/debian/*.deb
if-no-files-found: error if-no-files-found: error
retention-days: 30 retention-days: 30
+1 -23
View File
@@ -1,27 +1,5 @@
# Changelog # Changelog
## v1.0.2
Patch release with an important bug fix
### Notable changes
- Run `FLUSH PRIVILEGES` on the server whenever users modify privileges.
- You will have to grant `RELOAD` for the muscl admin user on all databases, see the [installation docs](./docs/installation.md) for details.
- Bump dependencies
## v1.0.1
Patch release with some important bug fixes
### Notable changes
- `mysql.db.Host` would usually be unset when creating privileges for users, this should be fixed now.
- You might have to manually set this field for rows created with the previous version of muscl to have those privileges work properly.
- Fixed an issue where a few select server responses would refuse to serialize properly, leading to an error message: "No response from server"
- The output of various commands is now being sorted.
- Bump dependencies
## v1.0.0 - Initial Release ## v1.0.0 - Initial Release
This is the initial release of `muscl`. This is the initial release of `muscl`.
@@ -76,7 +54,7 @@ This is the initial release of `muscl`.
interactive tool, there shouldn't have been any scripts relying on the old formatting. interactive tool, there shouldn't have been any scripts relying on the old formatting.
- The configuration file is shared for all variants of the program, and `muscl` will use - The configuration file is shared for all variants of the program, and `muscl` will use
its new logic to look for and parse this file. See the example config and its new logic to look for and parse this file. See the example config and
[installation instructions](./docs/installation.md) for more information about how to [installation instructions][installation-instructions] for more information about how to
configure the software. configure the software.
- The order in which input is validated might be differ from the original - The order in which input is validated might be differ from the original
(e.g. database ownership checks, invalid character checks, existence checks, ...). (e.g. database ownership checks, invalid character checks, existence checks, ...).
Generated
+636 -662
View File
File diff suppressed because it is too large Load Diff
+23 -22
View File
@@ -1,11 +1,12 @@
[package] [package]
name = "muscl" name = "muscl"
version = "1.0.2" version = "0.1.0"
edition = "2024" edition = "2024"
resolver = "2" resolver = "2"
license = "BSD-3-Clause" license = "BSD-3-Clause"
authors = [ authors = [
"Programvareverkstedet <projects@pvv.ntnu.no>", "oysteikt@pvv.ntnu.no",
"felixalb@pvv.ntnu.no",
] ]
homepage = "https://git.pvv.ntnu.no/Projects/muscl" homepage = "https://git.pvv.ntnu.no/Projects/muscl"
repository = "https://git.pvv.ntnu.no/Projects/muscl" repository = "https://git.pvv.ntnu.no/Projects/muscl"
@@ -18,50 +19,50 @@ autobins = false
autolib = false autolib = false
[dependencies] [dependencies]
anyhow = "1.0.102" anyhow = "1.0.100"
async-bincode = "0.8.0" async-bincode = "0.8.0"
bincode = "2.0.1" bincode = "2.0.1"
clap = { version = "4.6.1", features = ["cargo", "derive"] } clap = { version = "4.5.54", features = ["cargo", "derive"] }
clap-verbosity-flag = { version = "3.0.4", features = [ "tracing" ] } clap-verbosity-flag = { version = "3.0.4", features = [ "tracing" ] }
clap_complete = { version = "4.6.5", features = ["unstable-dynamic"] } clap_complete = { version = "4.5.65", features = ["unstable-dynamic"] }
color-print = "0.3.7" color-print = "0.3.7"
const_format = "0.2.36" const_format = "0.2.35"
derive_more = { version = "2.1.1", features = ["display", "error"] } derive_more = { version = "2.1.1", features = ["display", "error"] }
dialoguer = "0.12.0" dialoguer = "0.12.0"
futures-util = "0.3.32" futures-util = "0.3.31"
humansize = "2.1.3" humansize = "2.1.3"
indoc = "2.0.7" indoc = "2.0.7"
itertools = "0.14.0" itertools = "0.14.0"
nix = { version = "0.31.3", features = ["fs", "process", "socket", "user"] } nix = { version = "0.30.1", features = ["fs", "process", "socket", "user"] }
num_cpus = "1.17.0" num_cpus = "1.17.0"
prettytable = "0.10.0" prettytable = "0.10.0"
rand = "0.10.1" rand = "0.9.2"
serde = "1.0.228" serde = "1.0.228"
serde_json = { version = "1.0.150", features = ["preserve_order"] } serde_json = { version = "1.0.149", features = ["preserve_order"] }
sqlx = { version = "0.9.0", features = ["runtime-tokio", "mysql", "tls-rustls"] } sqlx = { version = "0.8.6", features = ["runtime-tokio", "mysql", "tls-rustls"] }
thiserror = "2.0.18" thiserror = "2.0.17"
tokio = { version = "1.52.3", features = ["rt-multi-thread", "macros", "signal"] } tokio = { version = "1.49.0", features = ["rt-multi-thread", "macros", "signal"] }
tokio-serde = { version = "0.9.0", features = ["bincode"] } tokio-serde = { version = "0.9.0", features = ["bincode"] }
tokio-stream = "0.1.18" tokio-stream = "0.1.18"
tokio-util = { version = "0.7.18", features = ["codec", "rt"] } tokio-util = { version = "0.7.18", features = ["codec", "rt"] }
toml = "1.1.2" toml = "0.9.11"
tracing = { version = "0.1.44", features = ["log"] } tracing = { version = "0.1.44", features = ["log"] }
tracing-subscriber = "0.3.23" tracing-subscriber = "0.3.22"
uuid = { version = "1.23.2", features = ["v4"] } uuid = { version = "1.19.0", features = ["v4"] }
[target.'cfg(target_os = "linux")'.dependencies] [target.'cfg(target_os = "linux")'.dependencies]
landlock = "0.4.5" landlock = "0.4.4"
sd-notify = "0.5.0" sd-notify = "0.4.5"
tracing-journald = "0.3.2" tracing-journald = "0.3.2"
[build-dependencies] [build-dependencies]
anyhow = "1.0.102" anyhow = "1.0.100"
build-info-build = "0.0.44" build-info-build = "0.0.42"
git2 = { version = "0.21.0", default-features = false } git2 = { version = "0.20.3", default-features = false }
[dev-dependencies] [dev-dependencies]
pretty_assertions = "1.4.1" pretty_assertions = "1.4.1"
regex = "1.12.3" regex = "1.12.2"
[features] [features]
default = ["mysql-admutils-compatibility"] default = ["mysql-admutils-compatibility"]
-2
View File
@@ -9,8 +9,6 @@ ExecStart=/usr/bin/muscl-server --systemd --disable-landlock socket-activate
ExecReload=/usr/bin/kill -HUP $MAINPID ExecReload=/usr/bin/kill -HUP $MAINPID
WatchdogSec=3min WatchdogSec=3min
Restart=always
RestartSec=10s
# Although this is a multi-instance unit, the constant `User` field is needed # Although this is a multi-instance unit, the constant `User` field is needed
# for authentication via mysql's auth_socket plugin to work. # for authentication via mysql's auth_socket plugin to work.
-1
View File
@@ -3,7 +3,6 @@ Description=Muscl MySQL admin tool
[Socket] [Socket]
ListenStream=/run/muscl/muscl.sock ListenStream=/run/muscl/muscl.sock
RemoveOnStop=true
Accept=no Accept=no
PassCredentials=true PassCredentials=true
+1 -1
View File
@@ -42,7 +42,7 @@ on the MySQL server as the admin user (or another user with sufficient privilege
```sql ```sql
CREATE USER `muscl`@`localhost` IDENTIFIED BY '<strong_password_here>'; CREATE USER `muscl`@`localhost` IDENTIFIED BY '<strong_password_here>';
GRANT SELECT, INSERT, UPDATE, DELETE ON `mysql`.* TO `muscl`@`localhost`; GRANT SELECT, INSERT, UPDATE, DELETE ON `mysql`.* TO `muscl`@`localhost`;
GRANT GRANT OPTION, CREATE, DROP, RELOAD ON *.* TO `muscl`@`localhost`; GRANT GRANT OPTION, CREATE, DROP ON *.* TO `muscl`@`localhost`;
FLUSH PRIVILEGES; FLUSH PRIVILEGES;
``` ```
@@ -0,0 +1,53 @@
[Unit]
Description=Authorization daemon for Muscl
[Service]
Type=notify
ExecStart=/usr/local/bin/muscl_auth_daemon.py
# WatchdogSec=15
User=muscl
Group=muscl
DynamicUser=yes
; ConfigurationDirectory=muscl
; RuntimeDirectory=muscl
; # This is required to read unix user/group details.
; PrivateUsers=false
; # Needed to communicate with MySQL.
; PrivateNetwork=false
; PrivateIPC=false
; AmbientCapabilities=
; CapabilityBoundingSet=
; DeviceAllow=
; DevicePolicy=closed
; LockPersonality=true
; MemoryDenyWriteExecute=true
; NoNewPrivileges=true
; PrivateDevices=true
; PrivateMounts=true
; PrivateTmp=yes
; ProcSubset=pid
; ProtectClock=true
; ProtectControlGroups=strict
; ProtectHome=true
; ProtectHostname=true
; ProtectKernelLogs=true
; ProtectKernelModules=true
; ProtectKernelTunables=true
; ProtectProc=invisible
; ProtectSystem=strict
; RemoveIPC=true
; RestrictAddressFamilies=AF_UNIX AF_INET AF_INET6
; RestrictNamespaces=true
; RestrictRealtime=true
; RestrictSUIDSGID=true
; SocketBindDeny=any
; SystemCallArchitectures=native
; SystemCallFilter=@system-service
; SystemCallFilter=~@privileged @resources
; UMask=0777
@@ -0,0 +1,8 @@
[Unit]
Description=Authorization daemon for Muscl
WantedBy=sockets.target
[Socket]
ListenStream=/run/muscl/muscl-auth-daemon.socket
Accept=no
SocketMode=0660
@@ -0,0 +1,84 @@
#!/usr/bin/env python3
# TODO: create pool of workers to handle requests concurrently
# the socket should be a listener socket and each worker should accept connections from it
# the socket should accept requests as newline-separated JSON objects
# there should be a watchdog to monitor worker health and restart them if they die
# graceful shutdown should be implemented for the workers
# optional logging of requests and responses
# use systemd notify to signal readiness and amount of connections handled
import json
import os
from socket import AF_UNIX, SOCK_DGRAM, SOCK_STREAM, fromfd, socket
from multiprocessing import Pool
def get_listener_from_systemd() -> socket:
listen_fds = int(os.getenv("LISTEN_FDS", "0"))
listen_pid = int(os.getenv("LISTEN_PID", "0"))
if listen_fds != 1 or listen_pid != os.getpid():
raise RuntimeError("No socket passed from systemd")
assert listen_fds == 1
sock = fromfd(3, AF_UNIX, SOCK_STREAM)
sock.setblocking(False)
return sock
def get_notify_socket_from_systemd() -> socket:
notify_socket_path = os.getenv("NOTIFY_SOCKET")
if not notify_socket_path:
raise RuntimeError("No notify socket path found in environment")
sock = socket(AF_UNIX, SOCK_DGRAM)
sock.connect(notify_socket_path)
return sock
def run_auth_daemon(sock: socket):
sock.listen()
print("Auth daemon is running and listening for connections...")
with Pool() as worker_pool:
with get_notify_socket_from_systemd() as notify_socket:
notify_socket.sendall(b"READY=1\n")
while True:
conn, _ = sock.accept()
worker_pool.apply_async(session_handler, args=(conn,))
def session_handler(sock: socket):
buffer = ""
while True:
data = sock.recv(4096).decode("utf-8")
if not data:
print("Connection closed by client")
break
buffer += data
if buffer.endswith("\n"):
requests = buffer.strip().split("\n")
buffer = ""
for request in requests:
try:
req_json = json.loads(request)
username = req_json.get("username", "")
groups = req_json.get("groups", [])
resource_type = req_json.get("resource_type", "")
resource = req_json.get("resource", "")
allowed = process_request(username, groups, resource_type, resource)
response = {"allowed": allowed}
except json.JSONDecodeError:
response = {"error": "Invalid JSON"}
sock.sendall((json.dumps(response) + "\n").encode("utf-8"))
def process_request(
username: str,
groups: list[str],
resource_type: str,
resource: str,
) -> bool:
...
if __name__ == "__main__":
listener_socket = get_listener_from_systemd()
run_auth_daemon(listener_socket)
Generated
+9 -9
View File
@@ -2,11 +2,11 @@
"nodes": { "nodes": {
"crane": { "crane": {
"locked": { "locked": {
"lastModified": 1780099841, "lastModified": 1767744144,
"narHash": "sha256-EVZd2RsbpreRUDSi9rBwPY+ZxoyMaiEBbZxxhljbaS4=", "narHash": "sha256-9/9ntI0D+HbN4G0TrK3KmHbTvwgswz7p8IEJsWyef8Q=",
"owner": "ipetkov", "owner": "ipetkov",
"repo": "crane", "repo": "crane",
"rev": "0532eb17955225173906d671fb36306bdeb1e2dc", "rev": "2fb033290bf6b23f226d4c8b32f7f7a16b043d7e",
"type": "github" "type": "github"
}, },
"original": { "original": {
@@ -17,11 +17,11 @@
}, },
"nixpkgs": { "nixpkgs": {
"locked": { "locked": {
"lastModified": 1779560665, "lastModified": 1768127708,
"narHash": "sha256-tpyBcxPpcQb8ukyNF7DoCwfSY3VPsxHoYwj00Cayv5o=", "narHash": "sha256-1Sm77VfZh3mU0F5OqKABNLWxOuDeHIlcFjsXeeiPazs=",
"owner": "NixOS", "owner": "NixOS",
"repo": "nixpkgs", "repo": "nixpkgs",
"rev": "64c08a7ca051951c8eae34e3e3cb1e202fe36786", "rev": "ffbc9f8cbaacfb331b6017d5a5abb21a492c9a38",
"type": "github" "type": "github"
}, },
"original": { "original": {
@@ -45,11 +45,11 @@
] ]
}, },
"locked": { "locked": {
"lastModified": 1780110990, "lastModified": 1768186348,
"narHash": "sha256-6QBThUi7SuK+dgA+DCaEkQGZN4kYx6DpXmK45+MG9zI=", "narHash": "sha256-nkpIe3zkpeoFuOl8xBpexulECsHLQ9Ljg1gW3bPCjSI=",
"owner": "oxalica", "owner": "oxalica",
"repo": "rust-overlay", "repo": "rust-overlay",
"rev": "85570ef134d92a8702de6afd1f6f0209c863fa91", "rev": "af69e497567a5945a64057717bc9b17c8478097e",
"type": "github" "type": "github"
}, },
"original": { "original": {
+1
View File
@@ -105,6 +105,7 @@
fileset = lib.fileset.unions [ fileset = lib.fileset.unions [
(craneLib.fileset.commonCargoSources ./.) (craneLib.fileset.commonCargoSources ./.)
./assets ./assets
./examples
]; ];
}; };
in { in {
+4 -1
View File
@@ -85,10 +85,13 @@ buildFunction ({
install -Dm644 assets/systemd/muscl.service -t "$out/lib/systemd/system" install -Dm644 assets/systemd/muscl.service -t "$out/lib/systemd/system"
substituteInPlace "$out/lib/systemd/system/muscl.service" \ substituteInPlace "$out/lib/systemd/system/muscl.service" \
--replace-fail '/usr/bin/muscl-server' "$out/bin/muscl-server" --replace-fail '/usr/bin/muscl-server' "$out/bin/muscl-server"
mkdir -p "$out/share/muscl"
cp -r examples "$out/share/muscl"
''; '';
meta = with lib; { meta = with lib; {
license = licenses.bsd3; license = licenses.mit;
platforms = platforms.linux ++ platforms.darwin; platforms = platforms.linux ++ platforms.darwin;
inherit mainProgram; inherit mainProgram;
}; };
+92
View File
@@ -27,6 +27,31 @@ in
}.${level}; }.${level};
}; };
authHandler = lib.mkOption {
type = with lib.types; nullOr lines;
default = null;
description = "Custom authentication handler, written in python";
example = ''
def process_request(
username: str,
groups: list[str],
resource_type: str,
resource: str,
) -> bool:
if resource_type == "database":
if resource.startswith(username) or any(
resource.startswith(group) for group in groups
):
return True
elif resource_type == "user":
if resource.startswith(username) or any(
resource.startswith(group) for group in groups
):
return True
return False
'';
};
settings = lib.mkOption { settings = lib.mkOption {
default = { }; default = { };
type = lib.types.submodule { type = lib.types.submodule {
@@ -191,5 +216,72 @@ in
++ (lib.optionals (cfg.settings.mysql.host != null) [ "AF_INET" "AF_INET6" ]); ++ (lib.optionals (cfg.settings.mysql.host != null) [ "AF_INET" "AF_INET6" ]);
}; };
}; };
systemd.sockets."muscl-auth-daemon" = lib.mkIf (cfg.authHandler != null) {
description = "Authorization daemon for Muscl";
wantedBy = [ "sockets.target" ];
socketConfig = {
ListenStream = "/run/muscl/muscl-auth-daemon.sock";
Accept = "no";
};
};
systemd.services."muscl-auth-daemon" = lib.mkIf (cfg.authHandler != null) {
description = "Authorization daemon for Muscl";
requires = [ "muscl-auth-daemon.socket" ];
serviceConfig = {
Type = "notify";
ExecStart = let
authScript = lib.pipe ../examples/auth_daemon_python/muscl_auth_daemon.py [
lib.fileContents
(lib.replaceString ''
def process_request(
username: str,
groups: list[str],
resource_type: str,
resource: str,
) -> bool:
...
'' cfg.authHandler)
(pkgs.writers.writePyPy3Bin "muscl-auth-handler.py" { })
];
in lib.getExe authScript;
User = "muscl-auth-daemon";
Group = "muscl-auth-daemon";
DynamicUser = true;
AmbientCapabilities = [ "" ];
CapabilityBoundingSet = [ "" ];
DeviceAllow = [ "" ];
LockPersonality = true;
NoNewPrivileges = true;
PrivateDevices = true;
PrivateMounts = true;
PrivateTmp = "yes";
ProcSubset = "pid";
ProtectClock = true;
ProtectControlGroups = "strict";
ProtectHome = true;
ProtectHostname = true;
ProtectKernelLogs = true;
ProtectKernelModules = true;
ProtectKernelTunables = true;
ProtectProc = "invisible";
ProtectSystem = "strict";
RemoveIPC = true;
UMask = "0777";
RestrictNamespaces = true;
RestrictRealtime = true;
RestrictSUIDSGID = true;
SystemCallArchitectures = "native";
SocketBindDeny = [ "any" ];
SystemCallFilter = [
"@system-service"
"~@privileged"
"~@resources"
];
};
};
}; };
} }
+19
View File
@@ -56,6 +56,25 @@ nixpkgs.lib.nixosSystem {
enable = true; enable = true;
logLevel = "trace"; logLevel = "trace";
createLocalDatabaseUser = true; createLocalDatabaseUser = true;
authHandler = ''
def process_request(
username: str,
groups: list[str],
resource_type: str,
resource: str,
) -> bool:
if resource_type == "database":
if resource.startswith(username) or any(
resource.startswith(group) for group in groups
):
return True
elif resource_type == "user":
if resource.startswith(username) or any(
resource.startswith(group) for group in groups
):
return True
return False
'';
}; };
programs.vim = { programs.vim = {
+6 -16
View File
@@ -42,17 +42,9 @@ declare -r GIT_SHA="$2"
TMPDIR="$(mktemp -d)" TMPDIR="$(mktemp -d)"
declare -a OS_VARIANTS=( for variant in debian-bookworm debian-trixie ubuntu-jammy ubuntu-noble; do
"debian-bookworm"
"debian-trixie"
"ubuntu-jammy"
"ubuntu-noble"
# "ubuntu-resolute"
)
for variant in "${OS_VARIANTS[@]}"; do
echo "Downloading and uploading debs for variant: $variant" echo "Downloading and uploading debs for variant: $variant"
curl "https://git.pvv.ntnu.no/Projects/muscl/actions/runs/$RUN_NUMBER/artifacts/muscl-deb-$variant-$GIT_SHA" --output "$TMPDIR/muscl-deb-$variant-$GIT_SHA.zip" curl "https://git.pvv.ntnu.no/Projects/muscl/actions/runs/$RUN_NUMBER/artifacts/muscl-deb-$variant-$GIT_SHA.zip" --output "$TMPDIR/muscl-deb-$variant-$GIT_SHA.zip"
unzip "$TMPDIR/muscl-deb-$variant-$GIT_SHA.zip" -d "$TMPDIR/muscl-deb-$variant-$GIT_SHA" unzip "$TMPDIR/muscl-deb-$variant-$GIT_SHA.zip" -d "$TMPDIR/muscl-deb-$variant-$GIT_SHA"
@@ -62,13 +54,11 @@ for variant in "${OS_VARIANTS[@]}"; do
DEB_VERSION=$(find "$TMPDIR/muscl-deb-$variant-$GIT_SHA"/*.deb -print0 | xargs -0 -n1 basename | cut -d'_' -f2 | head -n1) DEB_VERSION=$(find "$TMPDIR/muscl-deb-$variant-$GIT_SHA"/*.deb -print0 | xargs -0 -n1 basename | cut -d'_' -f2 | head -n1)
DEB_ARCH=$(find "$TMPDIR/muscl-deb-$variant-$GIT_SHA"/*.deb -print0 | xargs -0 -n1 basename | cut -d'_' -f3 | cut -d'.' -f1 | head -n1) DEB_ARCH=$(find "$TMPDIR/muscl-deb-$variant-$GIT_SHA"/*.deb -print0 | xargs -0 -n1 basename | cut -d'_' -f3 | cut -d'.' -f1 | head -n1)
# echo "[DELETE] https://git.pvv.ntnu.no/api/packages/Projects/debian/pool/$DISTRO_VERSION_NAME/main/$DEB_NAME/$DEB_VERSION/$DEB_ARCH" curl \
# curl \ -X DELETE \
# -X DELETE \ --user "$GITEA_USER:$GITEA_TOKEN" \
# --user "$GITEA_USER:$GITEA_TOKEN" \ "https://git.pvv.ntnu.no/api/packages/Projects/debian/pool/$DISTRO_VERSION_NAME/main/$DEB_NAME/$DEB_VERSION/$DEB_ARCH"
# "https://git.pvv.ntnu.no/api/packages/Projects/debian/pool/$DISTRO_VERSION_NAME/main/$DEB_NAME/$DEB_VERSION/$DEB_ARCH"
echo "[PUT] https://git.pvv.ntnu.no/api/packages/Projects/debian/pool/$DISTRO_VERSION_NAME/main/upload"
curl \ curl \
-X PUT \ -X PUT \
--user "$GITEA_USER:$GITEA_TOKEN" \ --user "$GITEA_USER:$GITEA_TOKEN" \
+1 -2
View File
@@ -319,8 +319,7 @@ fn main() -> anyhow::Result<()> {
#[cfg(not(feature = "suid-sgid-mode"))] #[cfg(not(feature = "suid-sgid-mode"))]
None, None,
args.verbose, args.verbose,
) )?;
.context("Failed to connect to the server")?;
tokio_run_command(args.command, connection)?; tokio_run_command(args.command, connection)?;
+5 -10
View File
@@ -8,7 +8,6 @@ use clap::{Args, Parser};
use clap_complete::ArgValueCompleter; use clap_complete::ArgValueCompleter;
use dialoguer::{Confirm, Editor}; use dialoguer::{Confirm, Editor};
use futures_util::SinkExt; use futures_util::SinkExt;
use itertools::Itertools;
use nix::unistd::{User, getuid}; use nix::unistd::{User, getuid};
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
@@ -23,7 +22,7 @@ use crate::{
parse_privilege_data_from_editor_content, reduce_privilege_diffs, parse_privilege_data_from_editor_content, reduce_privilege_diffs,
}, },
protocol::{ protocol::{
ClientToServerMessageStream, ListDatabasesError, ListDatabasesRequest, ListUsersError, ClientToServerMessageStream, ListDatabasesError, ListUsersError,
ModifyDatabasePrivilegesError, Request, Response, ModifyDatabasePrivilegesError, Request, Response,
print_modify_database_privileges_output_status, request_validation::ValidationError, print_modify_database_privileges_output_status, request_validation::ValidationError,
}, },
@@ -132,7 +131,7 @@ async fn databases_exist(
.map(|diff| diff.get_database_name().clone()) .map(|diff| diff.get_database_name().clone())
.collect(); .collect();
let message = Request::ListDatabases(ListDatabasesRequest::new(Some(database_list), false)); let message = Request::ListDatabases(Some(database_list));
server_connection.send(message).await?; server_connection.send(message).await?;
let result = match server_connection.next().await { let result = match server_connection.next().await {
@@ -204,13 +203,9 @@ pub async fn edit_database_privileges(
} }
}) })
.flatten() .flatten()
.sorted_by_key(|row| (row.db.clone(), row.user.clone()))
.collect::<Vec<_>>(), .collect::<Vec<_>>(),
Some(Ok(Response::ListAllPrivileges(privilege_rows))) => match privilege_rows { Some(Ok(Response::ListAllPrivileges(privilege_rows))) => match privilege_rows {
Ok(list) => list Ok(list) => list,
.into_iter()
.sorted_by_key(|row| (row.db.clone(), row.user.clone()))
.collect(),
Err(err) => { Err(err) => {
server_connection.send(Request::Exit).await?; server_connection.send(Request::Exit).await?;
return Err(anyhow::anyhow!(err.to_error_message()) return Err(anyhow::anyhow!(err.to_error_message())
@@ -310,7 +305,7 @@ pub async fn edit_database_privileges(
print_modify_database_privileges_output_status(&result); print_modify_database_privileges_output_status(&result);
if result.values().flatten().any(|(_, res)| { if result.iter().any(|(_, res)| {
matches!( matches!(
res, res,
Err(ModifyDatabasePrivilegesError::UserValidationError( Err(ModifyDatabasePrivilegesError::UserValidationError(
@@ -325,7 +320,7 @@ pub async fn edit_database_privileges(
server_connection.send(Request::Exit).await?; server_connection.send(Request::Exit).await?;
if result.values().flatten().any(|(_, res)| res.is_err()) { if result.values().any(std::result::Result::is_err) {
std::process::exit(1); std::process::exit(1);
} }
+7 -14
View File
@@ -8,8 +8,8 @@ use crate::{
core::{ core::{
completion::mysql_database_completer, completion::mysql_database_completer,
protocol::{ protocol::{
ClientToServerMessageStream, ListDatabasesError, ListDatabasesRequest, Request, ClientToServerMessageStream, ListDatabasesError, Request, Response,
Response, print_list_databases_output_status, print_list_databases_output_status_json, print_list_databases_output_status, print_list_databases_output_status_json,
request_validation::ValidationError, request_validation::ValidationError,
}, },
types::MySQLDatabase, types::MySQLDatabase,
@@ -27,10 +27,6 @@ pub struct ShowDbArgs {
#[arg(short, long)] #[arg(short, long)]
json: bool, json: bool,
/// Show all tables and users for each database
#[arg(short = 'a', long)]
all: bool,
/// Show sizes in bytes instead of human-readable format /// Show sizes in bytes instead of human-readable format
#[arg(short, long)] #[arg(short, long)]
bytes: bool, bytes: bool,
@@ -40,14 +36,11 @@ pub async fn show_databases(
args: ShowDbArgs, args: ShowDbArgs,
mut server_connection: ClientToServerMessageStream, mut server_connection: ClientToServerMessageStream,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let message = Request::ListDatabases(ListDatabasesRequest::new( let message = if args.name.is_empty() {
if args.name.is_empty() { Request::ListDatabases(None)
None } else {
} else { Request::ListDatabases(Some(args.name.clone()))
Some(args.name.clone()) };
},
args.all || args.json,
));
server_connection.send(message).await?; server_connection.send(message).await?;
@@ -22,8 +22,8 @@ use crate::{
completion::{mysql_database_completer, prefix_completer}, completion::{mysql_database_completer, prefix_completer},
database_privileges::DatabasePrivilegeRow, database_privileges::DatabasePrivilegeRow,
protocol::{ protocol::{
ClientToServerMessageStream, ListDatabasesRequest, ListPrivilegesError, Request, ClientToServerMessageStream, ListPrivilegesError, Request, Response,
Response, create_client_to_server_message_stream, create_client_to_server_message_stream,
}, },
types::MySQLDatabase, types::MySQLDatabase,
}, },
@@ -285,7 +285,7 @@ async fn show_databases(
args.name.iter().map(trim_db_name_to_32_chars).collect(); args.name.iter().map(trim_db_name_to_32_chars).collect();
let message = if database_names.is_empty() { let message = if database_names.is_empty() {
let message = Request::ListDatabases(ListDatabasesRequest::new(None, false)); let message = Request::ListDatabases(None);
server_connection.send(message).await?; server_connection.send(message).await?;
let response = server_connection.next().await; let response = server_connection.next().await;
let databases = match response { let databases = match response {
+3 -29
View File
@@ -1,6 +1,5 @@
use std::{ use std::{
fs, fs,
os::unix::fs::FileTypeExt,
path::{Path, PathBuf}, path::{Path, PathBuf},
sync::Arc, sync::Arc,
time::Duration, time::Duration,
@@ -8,10 +7,7 @@ use std::{
use anyhow::{Context, anyhow}; use anyhow::{Context, anyhow};
use clap_verbosity_flag::{InfoLevel, Verbosity}; use clap_verbosity_flag::{InfoLevel, Verbosity};
use nix::{ use nix::libc::{EXIT_SUCCESS, exit};
libc::{EXIT_SUCCESS, exit},
unistd::{AccessFlags, access},
};
use sqlx::mysql::MySqlPoolOptions; use sqlx::mysql::MySqlPoolOptions;
use std::os::unix::net::UnixStream as StdUnixStream; use std::os::unix::net::UnixStream as StdUnixStream;
use tokio::{net::UnixStream as TokioUnixStream, sync::RwLock}; use tokio::{net::UnixStream as TokioUnixStream, sync::RwLock};
@@ -134,28 +130,11 @@ pub fn bootstrap_server_connection_and_drop_privileges(
} }
} }
fn socket_path_is_ok(path: &Path) -> anyhow::Result<()> {
fs::metadata(path)
.context(format!("Failed to get metadata for {:?}", path))
.and_then(|meta| {
if !meta.file_type().is_socket() {
anyhow::bail!("{:?} is not a unix socket", path);
}
access(path, AccessFlags::R_OK | AccessFlags::W_OK)
.with_context(|| format!("Socket at {:?} is not readable/writable", path))?;
Ok(())
})
}
fn connect_to_external_server( fn connect_to_external_server(
server_socket_path: Option<PathBuf>, server_socket_path: Option<PathBuf>,
) -> anyhow::Result<StdUnixStream> { ) -> anyhow::Result<StdUnixStream> {
// TODO: ensure this is both readable and writable
if let Some(socket_path) = server_socket_path { if let Some(socket_path) = server_socket_path {
tracing::trace!("Checking socket at {:?}", socket_path);
socket_path_is_ok(&socket_path)?;
tracing::debug!("Connecting to socket at {:?}", socket_path); tracing::debug!("Connecting to socket at {:?}", socket_path);
return match StdUnixStream::connect(socket_path) { return match StdUnixStream::connect(socket_path) {
Ok(socket) => Ok(socket), Ok(socket) => Ok(socket),
@@ -168,9 +147,6 @@ fn connect_to_external_server(
} }
if fs::metadata(DEFAULT_SOCKET_PATH).is_ok() { if fs::metadata(DEFAULT_SOCKET_PATH).is_ok() {
tracing::trace!("Checking socket at {:?}", DEFAULT_SOCKET_PATH);
socket_path_is_ok(Path::new(DEFAULT_SOCKET_PATH))?;
tracing::debug!("Connecting to default socket at {:?}", DEFAULT_SOCKET_PATH); tracing::debug!("Connecting to default socket at {:?}", DEFAULT_SOCKET_PATH);
return match StdUnixStream::connect(DEFAULT_SOCKET_PATH) { return match StdUnixStream::connect(DEFAULT_SOCKET_PATH) {
Ok(socket) => Ok(socket), Ok(socket) => Ok(socket),
@@ -182,9 +158,7 @@ fn connect_to_external_server(
}; };
} }
anyhow::bail!( anyhow::bail!("No socket path provided, and no default socket found");
"No socket path provided, and no socket found found at default location {DEFAULT_SOCKET_PATH}"
);
} }
// TODO: this function is security critical, it should be integration tested // TODO: this function is security critical, it should be integration tested
+1 -1
View File
@@ -29,7 +29,7 @@ pub const DATABASE_PRIVILEGE_FIELDS: [&str; 13] = [
// doesn't have any natural implementation semantics. // doesn't have any natural implementation semantics.
/// Representation of the set of privileges for a single user on a single database. /// Representation of the set of privileges for a single user on a single database.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord, Default)] #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
pub struct DatabasePrivilegeRow { pub struct DatabasePrivilegeRow {
// TODO: don't store the db and user here, let the type be stored in a mapping // TODO: don't store the db and user here, let the type be stored in a mapping
pub db: MySQLDatabase, pub db: MySQLDatabase,
+7 -14
View File
@@ -142,8 +142,7 @@ impl Request {
Request::ListDatabases(req) => format!( Request::ListDatabases(req) => format!(
"{}{}", "{}{}",
self.command_name(), self.command_name(),
req.names req.as_ref()
.as_ref()
.map_or("".to_string(), |r| format!("({})", r.len())) .map_or("".to_string(), |r| format!("({})", r.len()))
), ),
Request::ListPrivileges(req) => format!( Request::ListPrivileges(req) => format!(
@@ -207,12 +206,9 @@ impl Request {
Request::CompleteUserName(_) => Default::default(), Request::CompleteUserName(_) => Default::default(),
Request::CreateDatabases(databases) => databases.iter().cloned().collect(), Request::CreateDatabases(databases) => databases.iter().cloned().collect(),
Request::DropDatabases(databases) => databases.iter().cloned().collect(), Request::DropDatabases(databases) => databases.iter().cloned().collect(),
Request::ListDatabases(request) => request Request::ListDatabases(databases) => {
.names databases.clone().unwrap_or_default().into_iter().collect()
.clone() }
.unwrap_or_default()
.into_iter()
.collect(),
Request::ListPrivileges(databases) => { Request::ListPrivileges(databases) => {
databases.clone().unwrap_or_default().into_iter().collect() databases.clone().unwrap_or_default().into_iter().collect()
} }
@@ -328,12 +324,9 @@ impl Response {
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count()) ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
} }
Response::ListAllPrivileges(res) => ResponseOkStatus::from_bool(res.is_ok()), Response::ListAllPrivileges(res) => ResponseOkStatus::from_bool(res.is_ok()),
Response::ModifyPrivileges(res) => ResponseOkStatus::from_counts( Response::ModifyPrivileges(res) => {
res.len(), ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
res.values() }
.map(|user_map| user_map.values().filter(|v| v.is_ok()).count())
.sum(),
),
Response::CreateUsers(res) => { Response::CreateUsers(res) => {
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count()) ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
@@ -67,43 +67,3 @@ impl CheckAuthorizationError {
self.0.error_type() self.0.error_type()
} }
} }
#[cfg(test)]
mod tests {
use crate::core::protocol::request_validation::NameValidationError;
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request: CheckAuthorizationRequest = vec![
DbOrUser::Database("test_db".into()),
DbOrUser::User("test_user".into()),
];
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: CheckAuthorizationRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response: CheckAuthorizationResponse = BTreeMap::from([
(DbOrUser::Database("test_db".into()), Ok(())),
(
DbOrUser::User("test_user".into()),
Err(CheckAuthorizationError(
ValidationError::NameValidationError(NameValidationError::TooLong),
)),
),
]);
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: CheckAuthorizationResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
}
@@ -87,41 +87,3 @@ impl CreateDatabaseError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request: CreateDatabasesRequest =
vec!["test_db1".into(), "test_db2".into(), "test_db3".into()];
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: CreateDatabasesRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response: CreateDatabasesResponse = BTreeMap::from([
("test_db1".into(), Ok(())),
(
"test_db2".into(),
Err(CreateDatabaseError::DatabaseAlreadyExists),
),
(
"test_db3".into(),
Err(CreateDatabaseError::MySqlError("Some MySQL error".into())),
),
]);
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: CreateDatabasesResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
}
@@ -87,37 +87,3 @@ impl CreateUserError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request: CreateUsersRequest = vec!["alice".into(), "bob".into(), "charlie".into()];
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: CreateUsersRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response: CreateUsersResponse = BTreeMap::from([
("alice".into(), Ok(())),
("bob".into(), Err(CreateUserError::UserAlreadyExists)),
(
"charlie".into(),
Err(CreateUserError::MySqlError("Some MySQL error".into())),
),
]);
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: CreateUsersResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
}
@@ -90,37 +90,3 @@ impl DropDatabaseError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request: DropDatabasesRequest = vec!["db1".into(), "db2".into(), "db3".into()];
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: DropDatabasesRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response: DropDatabasesResponse = BTreeMap::from([
("db1".into(), Ok(())),
("db2".into(), Err(DropDatabaseError::DatabaseDoesNotExist)),
(
"db3".into(),
Err(DropDatabaseError::MySqlError("Some MySQL error".into())),
),
]);
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: DropDatabasesResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
}
-34
View File
@@ -87,37 +87,3 @@ impl DropUserError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request: DropUsersRequest = vec!["alice".into(), "bob".into(), "charlie".into()];
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: DropUsersRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response: DropUsersResponse = BTreeMap::from([
("alice".into(), Ok(())),
("bob".into(), Err(DropUserError::UserDoesNotExist)),
(
"charlie".into(),
Err(DropUserError::MySqlError("Some MySQL error".into())),
),
]);
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: DropUsersResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
}
@@ -27,36 +27,3 @@ impl ListAllDatabasesError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_response() {
let response: ListAllDatabasesResponse = Ok(vec![
DatabaseRow {
database: "db1".into(),
tables: vec!["table1".into(), "table2".into()],
users: vec!["user1".into(), "user2".into()],
collation: Some("utf8mb4_general_ci".into()),
character_set: Some("utf8mb4".into()),
size_bytes: 1024,
},
DatabaseRow {
database: "db2".into(),
tables: vec!["table3".into(), "table4".into()],
users: vec!["user3".into(), "user4".into()],
collation: Some("utf8mb4_general_ci".into()),
character_set: Some("utf8mb4".into()),
size_bytes: 2048,
},
]);
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: ListAllDatabasesResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
}
@@ -27,34 +27,3 @@ impl ListAllPrivilegesError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_response() {
let response: ListAllPrivilegesResponse = Ok(vec![
DatabasePrivilegeRow {
user: "user1".into(),
db: "db1".into(),
select_priv: true,
insert_priv: false,
..Default::default()
},
DatabasePrivilegeRow {
user: "user2".into(),
db: "db2".into(),
select_priv: false,
insert_priv: true,
..Default::default()
},
]);
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: ListAllPrivilegesResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
}
@@ -27,37 +27,3 @@ impl ListAllUsersError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_response() {
let response: ListAllUsersResponse = Ok(vec![
DatabaseUser {
user: "user1".into(),
host: "%".into(),
has_password: true,
is_locked: false,
databases: vec!["db1".into(), "db2".into()],
},
DatabaseUser {
user: "user2".into(),
host: "%".into(),
has_password: false,
is_locked: true,
databases: vec!["db3".into()],
},
]);
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let mut deserialized: ListAllUsersResponse = serde_json::from_str(&json).unwrap();
deserialized.as_mut().unwrap()[0].host = "%".into();
deserialized.as_mut().unwrap()[1].host = "%".into();
assert_eq!(response, deserialized);
}
}
+2 -71
View File
@@ -14,21 +14,7 @@ use crate::{
server::sql::database_operations::DatabaseRow, server::sql::database_operations::DatabaseRow,
}; };
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub type ListDatabasesRequest = Option<Vec<MySQLDatabase>>;
pub struct ListDatabasesRequest {
pub names: Option<Vec<MySQLDatabase>>,
#[serde(default)]
pub include_all_tables_and_users: bool,
}
impl ListDatabasesRequest {
pub fn new(names: Option<Vec<MySQLDatabase>>, include_all_tables_and_users: bool) -> Self {
Self {
names,
include_all_tables_and_users,
}
}
}
pub type ListDatabasesResponse = BTreeMap<MySQLDatabase, Result<DatabaseRow, ListDatabasesError>>; pub type ListDatabasesResponse = BTreeMap<MySQLDatabase, Result<DatabaseRow, ListDatabasesError>>;
@@ -75,7 +61,7 @@ pub fn print_list_databases_output_status(
"Size" "Size"
} }
]); ]);
for db in final_database_list.iter().sorted_by_key(|db| &db.database) { for db in final_database_list {
table.add_row(row![ table.add_row(row![
db.database, db.database,
db.tables.join("\n"), db.tables.join("\n"),
@@ -151,58 +137,3 @@ impl ListDatabasesError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request = ListDatabasesRequest::new(Some(vec!["db1".into(), "db2".into()]), true);
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: ListDatabasesRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_deserialize_request_without_include_all_tables_and_users_defaults_to_false() {
let json = serde_json::json!({
"names": ["db1", "db2"]
})
.to_string();
let deserialized: ListDatabasesRequest = serde_json::from_str(&json).unwrap();
assert_eq!(
deserialized,
ListDatabasesRequest::new(Some(vec!["db1".into(), "db2".into()]), false)
);
}
#[test]
fn test_serialize_deserialize_response() {
let response: ListDatabasesResponse = vec![
(
"db1".into(),
Ok(DatabaseRow {
database: "db1".into(),
tables: vec!["table1".to_string(), "table2".to_string()],
users: vec!["user1".into(), "user2".into()],
collation: Some("utf8mb4_general_ci".to_string()),
character_set: Some("utf8mb4".to_string()),
size_bytes: 1024,
}),
),
("db2".into(), Err(ListDatabasesError::DatabaseDoesNotExist)),
]
.into_iter()
.collect();
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: ListDatabasesResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
}
+18 -62
View File
@@ -64,28 +64,25 @@ pub fn print_list_privileges_output_status(output: &ListPrivilegesResponse, long
.collect(), .collect(),
)); ));
for row in final_privs_map for (_database, rows) in final_privs_map {
.values() for row in &rows {
.flatten() table.add_row(row![
.sorted_by_key(|row| (&row.db, &row.user)) row.db,
{ row.user,
table.add_row(row![ c->yn(row.select_priv),
row.db, c->yn(row.insert_priv),
row.user, c->yn(row.update_priv),
c->yn(row.select_priv), c->yn(row.delete_priv),
c->yn(row.insert_priv), c->yn(row.create_priv),
c->yn(row.update_priv), c->yn(row.drop_priv),
c->yn(row.delete_priv), c->yn(row.alter_priv),
c->yn(row.create_priv), c->yn(row.index_priv),
c->yn(row.drop_priv), c->yn(row.create_tmp_table_priv),
c->yn(row.alter_priv), c->yn(row.lock_tables_priv),
c->yn(row.index_priv), c->yn(row.references_priv),
c->yn(row.create_tmp_table_priv), ]);
c->yn(row.lock_tables_priv), }
c->yn(row.references_priv),
]);
} }
// }
table.printstd(); table.printstd();
} }
@@ -156,44 +153,3 @@ impl ListPrivilegesError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request: ListPrivilegesRequest = Some(vec!["test_db1".into(), "test_db2".into()]);
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: ListPrivilegesRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response: ListPrivilegesResponse = BTreeMap::from([
(
"test_db1".into(),
Ok(vec![DatabasePrivilegeRow {
db: "test_db1".into(),
user: "user1".into(),
select_priv: true,
insert_priv: false,
..Default::default()
}]),
),
(
"test_db2".into(),
Err(ListPrivilegesError::DatabaseDoesNotExist),
),
]);
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: ListPrivilegesResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
}
+1 -47
View File
@@ -1,6 +1,5 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use itertools::Itertools;
use prettytable::Table; use prettytable::Table;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::json; use serde_json::json;
@@ -52,7 +51,7 @@ pub fn print_list_users_output_status(output: &ListUsersResponse) {
"Locked", "Locked",
"Databases where user has privileges" "Databases where user has privileges"
]); ]);
for user in final_user_list.iter().sorted_by_key(|user| &user.user) { for user in final_user_list {
table.add_row(row![ table.add_row(row![
user.user, user.user,
user.has_password, user.has_password,
@@ -122,48 +121,3 @@ impl ListUsersError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request: ListUsersRequest = Some(vec!["test_user1".into(), "test_user2".into()]);
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: ListUsersRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response_ok: ListUsersResponse = BTreeMap::from([
(
"test_user1".into(),
Ok(DatabaseUser {
user: "test_user1".into(),
host: "%".into(),
has_password: true,
is_locked: false,
databases: vec!["db1".into(), "db2".into()],
}),
),
("test_user2".into(), Err(ListUsersError::UserDoesNotExist)),
]);
let json = serde_json::to_string_pretty(&response_ok).unwrap();
println!("Serialized response:\n{}", json);
let mut deserialized: ListUsersResponse = serde_json::from_str(&json).unwrap();
deserialized
.get_mut(&"test_user1".into())
.unwrap()
.as_mut()
.unwrap()
.host = "%".into();
assert_eq!(response_ok, deserialized);
}
}
-30
View File
@@ -94,33 +94,3 @@ impl LockUserError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request: LockUsersRequest = vec!["test_user1".into(), "test_user2".into()];
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: LockUsersRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response_ok: LockUsersResponse = BTreeMap::from([
("test_user1".into(), Ok(())),
("test_user2".into(), Err(LockUserError::UserDoesNotExist)),
]);
let json = serde_json::to_string_pretty(&response_ok).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: LockUsersResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response_ok, deserialized);
}
}
@@ -12,7 +12,7 @@ use crate::core::{
pub type ModifyPrivilegesRequest = BTreeSet<DatabasePrivilegesDiff>; pub type ModifyPrivilegesRequest = BTreeSet<DatabasePrivilegesDiff>;
pub type ModifyPrivilegesResponse = pub type ModifyPrivilegesResponse =
BTreeMap<MySQLDatabase, BTreeMap<MySQLUser, Result<(), ModifyDatabasePrivilegesError>>>; BTreeMap<(MySQLDatabase, MySQLUser), Result<(), ModifyDatabasePrivilegesError>>;
#[derive(Error, Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Error, Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ModifyDatabasePrivilegesError { pub enum ModifyDatabasePrivilegesError {
@@ -49,11 +49,7 @@ pub enum DiffDoesNotApplyError {
} }
pub fn print_modify_database_privileges_output_status(output: &ModifyPrivilegesResponse) { pub fn print_modify_database_privileges_output_status(output: &ModifyPrivilegesResponse) {
for ((database_name, username), result) in output.iter().flat_map(|(db, user_map)| { for ((database_name, username), result) in output {
user_map
.iter()
.map(move |(user, result)| ((db, user), result))
}) {
match result { match result {
Ok(()) => { Ok(()) => {
println!( println!(
@@ -148,46 +144,3 @@ impl DiffDoesNotApplyError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
use crate::core::*;
#[test]
fn test_serialize_deserialize_request() {
let request =
BTreeSet::from([DatabasePrivilegesDiff::Modified(DatabasePrivilegeRowDiff {
db: "test_db".into(),
user: "test_user".into(),
select_priv: Some(database_privileges::DatabasePrivilegeChange::NoToYes),
..Default::default()
})]);
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: ModifyPrivilegesRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response: ModifyPrivilegesResponse = BTreeMap::from([(
"test_db".into(),
BTreeMap::from([
("test_user".into(), Ok(())),
(
"invalid_user".into(),
Err(ModifyDatabasePrivilegesError::UserDoesNotExist),
),
]),
)]);
let json = serde_json::to_string_pretty(&response).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: ModifyPrivilegesResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response, deserialized);
}
}
-32
View File
@@ -60,35 +60,3 @@ impl SetPasswordError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request: SetUserPasswordRequest = ("test_user".into(), "new_password".into());
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: SetUserPasswordRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response_ok: SetUserPasswordResponse = Ok(());
let response_err: SetUserPasswordResponse = Err(SetPasswordError::UserDoesNotExist);
let json_ok = serde_json::to_string_pretty(&response_ok).unwrap();
let json_err = serde_json::to_string_pretty(&response_err).unwrap();
println!("Serialized OK response:\n{}", json_ok);
println!("Serialized Error response:\n{}", json_err);
let deserialized_ok: SetUserPasswordResponse = serde_json::from_str(&json_ok).unwrap();
let deserialized_err: SetUserPasswordResponse = serde_json::from_str(&json_err).unwrap();
assert_eq!(response_ok, deserialized_ok);
assert_eq!(response_err, deserialized_err);
}
}
@@ -94,33 +94,3 @@ impl UnlockUserError {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialize_deserialize_request() {
let request: UnlockUsersRequest = vec!["test_user1".into(), "test_user2".into()];
let json = serde_json::to_string_pretty(&request).unwrap();
println!("Serialized request:\n{}", json);
let deserialized: UnlockUsersRequest = serde_json::from_str(&json).unwrap();
assert_eq!(request, deserialized);
}
#[test]
fn test_serialize_deserialize_response() {
let response_ok: UnlockUsersResponse = BTreeMap::from([
("test_user1".into(), Ok(())),
("test_user2".into(), Err(UnlockUserError::UserDoesNotExist)),
]);
let json = serde_json::to_string_pretty(&response_ok).unwrap();
println!("Serialized response:\n{}", json);
let deserialized: UnlockUsersResponse = serde_json::from_str(&json).unwrap();
assert_eq!(response_ok, deserialized);
}
}
+3 -34
View File
@@ -34,7 +34,7 @@ impl DerefMut for MySQLUser {
impl fmt::Display for MySQLUser { impl fmt::Display for MySQLUser {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f) write!(f, "{:<width$}", self.0, width = f.width().unwrap_or(0))
} }
} }
@@ -83,7 +83,7 @@ impl DerefMut for MySQLDatabase {
impl fmt::Display for MySQLDatabase { impl fmt::Display for MySQLDatabase {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f) write!(f, "{:<width$}", self.0, width = f.width().unwrap_or(0))
} }
} }
@@ -105,43 +105,12 @@ impl From<MySQLDatabase> for OsString {
} }
} }
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum DbOrUser { pub enum DbOrUser {
Database(MySQLDatabase), Database(MySQLDatabase),
User(MySQLUser), User(MySQLUser),
} }
impl Serialize for DbOrUser {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
match self {
DbOrUser::Database(db) => ("d:".to_string() + &db.to_string()).serialize(serializer),
DbOrUser::User(user) => ("u:".to_string() + &user.to_string()).serialize(serializer),
}
}
}
impl<'de> Deserialize<'de> for DbOrUser {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
if let Some(rest) = s.strip_prefix("d:") {
Ok(DbOrUser::Database(MySQLDatabase(rest.to_string())))
} else if let Some(rest) = s.strip_prefix("u:") {
Ok(DbOrUser::User(MySQLUser(rest.to_string())))
} else {
Err(serde::de::Error::custom(format!(
"Invalid DbOrUser format: {}",
s
)))
}
}
}
impl DbOrUser { impl DbOrUser {
#[must_use] #[must_use]
pub fn lowercased_noun(&self) -> &'static str { pub fn lowercased_noun(&self) -> &'static str {
-4
View File
@@ -59,10 +59,6 @@ fn parse_group_denylist(denylist_path: &Path, lines: Lines) -> GroupDenylist {
} }
.trim(); .trim();
if trimmed_line.is_empty() {
continue;
}
let parts: Vec<&str> = trimmed_line.splitn(2, ':').collect(); let parts: Vec<&str> = trimmed_line.splitn(2, ':').collect();
if parts.len() != 2 { if parts.len() != 2 {
tracing::warn!( tracing::warn!(
+2 -4
View File
@@ -332,15 +332,14 @@ async fn handle_request(
.await; .await;
Response::DropDatabases(result) Response::DropDatabases(result)
} }
Request::ListDatabases(ref request) => { Request::ListDatabases(ref database_names) => {
if let Some(database_names) = &request.names { if let Some(database_names) = database_names {
let result = list_databases( let result = list_databases(
database_names, database_names,
unix_user, unix_user,
db_connection, db_connection,
db_is_mariadb, db_is_mariadb,
group_denylist, group_denylist,
request.include_all_tables_and_users,
) )
.await; .await;
Response::ListDatabases(result) Response::ListDatabases(result)
@@ -350,7 +349,6 @@ async fn handle_request(
db_connection, db_connection,
db_is_mariadb, db_is_mariadb,
group_denylist, group_denylist,
request.include_all_tables_and_users,
) )
.await; .await;
Response::ListAllDatabases(result) Response::ListAllDatabases(result)
+66 -190
View File
@@ -1,6 +1,5 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use sqlx::AssertSqlSafe;
use sqlx::MySqlConnection; use sqlx::MySqlConnection;
use sqlx::prelude::*; use sqlx::prelude::*;
@@ -24,8 +23,6 @@ use crate::{
server::{common::create_user_group_matching_regex, sql::quote_identifier}, server::{common::create_user_group_matching_regex, sql::quote_identifier},
}; };
const MAX_SHOW_DB_RELATED_ITEMS: usize = 5;
// NOTE: this function is unsafe because it does no input validation. // NOTE: this function is unsafe because it does no input validation.
pub(super) async fn unsafe_database_exists( pub(super) async fn unsafe_database_exists(
database_name: &str, database_name: &str,
@@ -128,15 +125,12 @@ pub async fn create_databases(
_ => {} _ => {}
} }
let statement = AssertSqlSafe(format!( let result =
"CREATE DATABASE {}", sqlx::query(format!("CREATE DATABASE {}", quote_identifier(&database_name)).as_str())
quote_identifier(&database_name) .execute(&mut *connection)
)); .await
let result = sqlx::query(statement) .map(|_| ())
.execute(&mut *connection) .map_err(|err| CreateDatabaseError::MySqlError(err.to_string()));
.await
.map(|_| ())
.map_err(|err| CreateDatabaseError::MySqlError(err.to_string()));
if let Err(err) = &result { if let Err(err) = &result {
tracing::error!("Failed to create database '{}': {:?}", &database_name, err); tracing::error!("Failed to create database '{}': {:?}", &database_name, err);
@@ -187,15 +181,12 @@ pub async fn drop_databases(
_ => {} _ => {}
} }
let statement = AssertSqlSafe(format!( let result =
"DROP DATABASE {}", sqlx::query(format!("DROP DATABASE {}", quote_identifier(&database_name)).as_str())
quote_identifier(&database_name) .execute(&mut *connection)
)); .await
let result = sqlx::query(statement) .map(|_| ())
.execute(&mut *connection) .map_err(|err| DropDatabaseError::MySqlError(err.to_string()));
.await
.map(|_| ())
.map_err(|err| DropDatabaseError::MySqlError(err.to_string()));
if let Err(err) = &result { if let Err(err) = &result {
tracing::error!("Failed to drop database '{}': {:?}", &database_name, err); tracing::error!("Failed to drop database '{}': {:?}", &database_name, err);
@@ -250,84 +241,12 @@ impl FromRow<'_, sqlx::mysql::MySqlRow> for DatabaseRow {
} }
} }
fn list_database_query(include_all_tables_and_users: bool) -> AssertSqlSafe<String> {
let limit_clause = if include_all_tables_and_users {
"".to_string()
} else {
format!(" LIMIT {}", MAX_SHOW_DB_RELATED_ITEMS)
};
AssertSqlSafe(format!(
r"
SELECT
CAST(s.SCHEMA_NAME AS CHAR(64)) AS `database`,
t.tables,
u.users,
s.DEFAULT_COLLATION_NAME AS `collation`,
s.DEFAULT_CHARACTER_SET_NAME AS `character_set`,
CAST(COALESCE(sz.size_bytes, 0) AS UNSIGNED) AS size_bytes
FROM information_schema.SCHEMATA s
LEFT JOIN (
SELECT
x.TABLE_SCHEMA,
GROUP_CONCAT(x.TABLE_NAME ORDER BY x.TABLE_NAME SEPARATOR ',') AS tables
FROM (
SELECT
TABLE_SCHEMA,
TABLE_NAME
FROM information_schema.TABLES
WHERE TABLE_SCHEMA = ?
ORDER BY TABLE_NAME{limit_clause}
) x
GROUP BY x.TABLE_SCHEMA
) t
ON t.TABLE_SCHEMA = s.SCHEMA_NAME
LEFT JOIN (
SELECT
x.DB,
GROUP_CONCAT(DISTINCT x.User ORDER BY x.User SEPARATOR ',') AS users
FROM (
SELECT
DB,
User
FROM mysql.db
WHERE DB = ?
ORDER BY User{limit_clause}
) x
GROUP BY x.DB
) u
ON u.DB = s.SCHEMA_NAME
LEFT JOIN (
SELECT
TABLE_SCHEMA,
SUM(DATA_LENGTH + INDEX_LENGTH) AS size_bytes
FROM information_schema.TABLES
WHERE TABLE_SCHEMA = ?
GROUP BY TABLE_SCHEMA
) sz
ON sz.TABLE_SCHEMA = s.SCHEMA_NAME
WHERE s.SCHEMA_NAME REGEXP ?
AND s.SCHEMA_NAME NOT IN (
'information_schema',
'performance_schema',
'mysql',
'sys'
)
"
))
}
pub async fn list_databases( pub async fn list_databases(
database_names: &[MySQLDatabase], database_names: &[MySQLDatabase],
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
_db_is_mariadb: bool, _db_is_mariadb: bool,
group_denylist: &GroupDenylist, group_denylist: &GroupDenylist,
include_all_tables_and_users: bool,
) -> ListDatabasesResponse { ) -> ListDatabasesResponse {
let mut results = BTreeMap::new(); let mut results = BTreeMap::new();
@@ -343,19 +262,35 @@ pub async fn list_databases(
continue; continue;
} }
let query = list_database_query(include_all_tables_and_users); let result = sqlx::query_as::<_, DatabaseRow>(
r"
SELECT
CAST(`information_schema`.`SCHEMATA`.`SCHEMA_NAME` AS CHAR(64)) AS `database`,
GROUP_CONCAT(DISTINCT CAST(`information_schema`.`TABLES`.`TABLE_NAME` AS CHAR(64)) SEPARATOR ',') AS `tables`,
GROUP_CONCAT(DISTINCT CAST(`mysql`.`db`.`User` AS CHAR(64)) SEPARATOR ',') AS `users`,
MAX(`information_schema`.`SCHEMATA`.`DEFAULT_COLLATION_NAME`) AS `collation`,
MAX(`information_schema`.`SCHEMATA`.`DEFAULT_CHARACTER_SET_NAME`) AS `character_set`,
CAST(IFNULL(
SUM(`information_schema`.`TABLES`.`DATA_LENGTH` + `information_schema`.`TABLES`.`INDEX_LENGTH`),
0
) AS UNSIGNED INTEGER) AS `size_bytes`
FROM `information_schema`.`SCHEMATA`
LEFT OUTER JOIN `information_schema`.`TABLES`
ON `information_schema`.`SCHEMATA`.`SCHEMA_NAME` = `TABLES`.`TABLE_SCHEMA`
LEFT OUTER JOIN `mysql`.`db`
ON `information_schema`.`SCHEMATA`.`SCHEMA_NAME` = `mysql`.`db`.`DB`
WHERE `information_schema`.`SCHEMATA`.`SCHEMA_NAME` = ?
GROUP BY `information_schema`.`SCHEMATA`.`SCHEMA_NAME`
",
let result = sqlx::query_as::<_, DatabaseRow>(query) )
.bind(database_name.to_string()) .bind(database_name.to_string())
.bind(database_name.to_string()) .fetch_optional(&mut *connection)
.bind(database_name.to_string()) .await
.bind(database_name.to_string()) .map_err(|err| ListDatabasesError::MySqlError(err.to_string()))
.fetch_optional(&mut *connection) .and_then(|database| {
.await database.map_or_else(|| Err(ListDatabasesError::DatabaseDoesNotExist), Ok)
.map_err(|err| ListDatabasesError::MySqlError(err.to_string())) });
.and_then(|database| {
database.map_or_else(|| Err(ListDatabasesError::DatabaseDoesNotExist), Ok)
});
if let Err(err) = &result { if let Err(err) = &result {
tracing::error!("Failed to list database '{}': {:?}", &database_name, err); tracing::error!("Failed to list database '{}': {:?}", &database_name, err);
@@ -369,97 +304,38 @@ pub async fn list_databases(
results results
} }
fn list_all_databases_for_user_query(include_all_tables_and_users: bool) -> AssertSqlSafe<String> {
let limit_clause = if include_all_tables_and_users {
"".to_string()
} else {
format!(" LIMIT {}", MAX_SHOW_DB_RELATED_ITEMS)
};
AssertSqlSafe(format!(
r"
SELECT
CAST(s.SCHEMA_NAME AS CHAR(64)) AS `database`,
t.tables,
u.users,
s.DEFAULT_COLLATION_NAME AS `collation`,
s.DEFAULT_CHARACTER_SET_NAME AS `character_set`,
CAST(COALESCE(sz.size_bytes, 0) AS UNSIGNED) AS size_bytes
FROM information_schema.SCHEMATA s
LEFT JOIN (
SELECT
x.TABLE_SCHEMA,
GROUP_CONCAT(x.TABLE_NAME ORDER BY x.TABLE_NAME SEPARATOR ',') AS tables
FROM (
SELECT
TABLE_SCHEMA,
TABLE_NAME
FROM information_schema.TABLES
WHERE TABLE_SCHEMA REGEXP ?
ORDER BY TABLE_NAME{limit_clause}
) x
GROUP BY x.TABLE_SCHEMA
) t
ON t.TABLE_SCHEMA = s.SCHEMA_NAME
LEFT JOIN (
SELECT
x.DB,
GROUP_CONCAT(DISTINCT x.User ORDER BY x.User SEPARATOR ',') AS users
FROM (
SELECT
DB,
User
FROM mysql.db
WHERE DB REGEXP ?
ORDER BY User{limit_clause}
) x
GROUP BY x.DB
) u
ON u.DB = s.SCHEMA_NAME
LEFT JOIN (
SELECT
TABLE_SCHEMA,
SUM(DATA_LENGTH + INDEX_LENGTH) AS size_bytes
FROM information_schema.TABLES
WHERE TABLE_SCHEMA REGEXP ?
GROUP BY TABLE_SCHEMA
) sz
ON sz.TABLE_SCHEMA = s.SCHEMA_NAME
WHERE s.SCHEMA_NAME REGEXP ?
AND s.SCHEMA_NAME NOT IN (
'information_schema',
'performance_schema',
'mysql',
'sys'
)
ORDER BY s.SCHEMA_NAME
"
))
}
pub async fn list_all_databases_for_user( pub async fn list_all_databases_for_user(
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
_db_is_mariadb: bool, _db_is_mariadb: bool,
group_denylist: &GroupDenylist, group_denylist: &GroupDenylist,
include_all_tables_and_users: bool,
) -> ListAllDatabasesResponse { ) -> ListAllDatabasesResponse {
let query = list_all_databases_for_user_query(include_all_tables_and_users); let result = sqlx::query_as::<_, DatabaseRow>(
let user_group_regex = create_user_group_matching_regex(unix_user, group_denylist); r"
SELECT
let result = sqlx::query_as::<_, DatabaseRow>(query) CAST(`information_schema`.`SCHEMATA`.`SCHEMA_NAME` AS CHAR(64)) AS `database`,
.bind(&user_group_regex) GROUP_CONCAT(DISTINCT CAST(`information_schema`.`TABLES`.`TABLE_NAME` AS CHAR(64)) SEPARATOR ',') AS `tables`,
.bind(&user_group_regex) GROUP_CONCAT(DISTINCT CAST(`mysql`.`db`.`User` AS CHAR(64)) SEPARATOR ',') AS `users`,
.bind(&user_group_regex) MAX(`information_schema`.`SCHEMATA`.`DEFAULT_COLLATION_NAME`) AS `collation`,
.bind(&user_group_regex) MAX(`information_schema`.`SCHEMATA`.`DEFAULT_CHARACTER_SET_NAME`) AS `character_set`,
.fetch_all(connection) CAST(IFNULL(
.await SUM(`information_schema`.`TABLES`.`DATA_LENGTH` + `information_schema`.`TABLES`.`INDEX_LENGTH`),
.map_err(|err| ListAllDatabasesError::MySqlError(err.to_string())); 0
) AS UNSIGNED INTEGER) AS `size_bytes`
FROM `information_schema`.`SCHEMATA`
LEFT OUTER JOIN `information_schema`.`TABLES`
ON `information_schema`.`SCHEMATA`.`SCHEMA_NAME` = `TABLES`.`TABLE_SCHEMA`
LEFT OUTER JOIN `mysql`.`db`
ON `information_schema`.`SCHEMATA`.`SCHEMA_NAME` = `mysql`.`db`.`DB`
WHERE `information_schema`.`SCHEMATA`.`SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys')
AND `information_schema`.`SCHEMATA`.`SCHEMA_NAME` REGEXP ?
GROUP BY `information_schema`.`SCHEMATA`.`SCHEMA_NAME`
",
)
.bind(create_user_group_matching_regex(unix_user, group_denylist))
.fetch_all(connection)
.await
.map_err(|err| ListAllDatabasesError::MySqlError(err.to_string()));
// TODO: should we assert that the users are also owned by the unix_user from the request? // TODO: should we assert that the users are also owned by the unix_user from the request?
+21 -46
View File
@@ -18,7 +18,7 @@ use std::collections::{BTreeMap, BTreeSet};
use indoc::indoc; use indoc::indoc;
use itertools::Itertools; use itertools::Itertools;
use sqlx::{AssertSqlSafe, MySqlConnection, mysql::MySqlRow, prelude::*}; use sqlx::{MySqlConnection, mysql::MySqlRow, prelude::*};
use crate::{ use crate::{
core::{ core::{
@@ -84,17 +84,16 @@ async fn unsafe_get_database_privileges(
database_name: &str, database_name: &str,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
) -> Result<Vec<DatabasePrivilegeRow>, sqlx::Error> { ) -> Result<Vec<DatabasePrivilegeRow>, sqlx::Error> {
let statement = AssertSqlSafe(format!( let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
"SELECT {} FROM `db` WHERE `Db` = ?", "SELECT {} FROM `db` WHERE `Db` = ?",
DATABASE_PRIVILEGE_FIELDS DATABASE_PRIVILEGE_FIELDS
.iter() .iter()
.map(|field| quote_identifier(field)) .map(|field| quote_identifier(field))
.join(","), .join(","),
)); ))
let result = sqlx::query_as::<_, DatabasePrivilegeRow>(statement) .bind(database_name)
.bind(database_name) .fetch_all(connection)
.fetch_all(connection) .await;
.await;
if let Err(e) = &result { if let Err(e) = &result {
tracing::error!( tracing::error!(
@@ -114,18 +113,17 @@ pub async fn unsafe_get_database_privileges_for_db_user_pair(
user_name: &MySQLUser, user_name: &MySQLUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
) -> Result<Option<DatabasePrivilegeRow>, sqlx::Error> { ) -> Result<Option<DatabasePrivilegeRow>, sqlx::Error> {
let statement = AssertSqlSafe(format!( let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
"SELECT {} FROM `db` WHERE `Db` = ? AND `User` = ? AND `Host` = '%'", "SELECT {} FROM `db` WHERE `Db` = ? AND `User` = ?",
DATABASE_PRIVILEGE_FIELDS DATABASE_PRIVILEGE_FIELDS
.iter() .iter()
.map(|field| quote_identifier(field)) .map(|field| quote_identifier(field))
.join(","), .join(","),
)); ))
let result = sqlx::query_as::<_, DatabasePrivilegeRow>(statement) .bind(database_name.as_str())
.bind(database_name.as_str()) .bind(user_name.as_str())
.bind(user_name.as_str()) .fetch_optional(connection)
.fetch_optional(connection) .await;
.await;
if let Err(e) = &result { if let Err(e) = &result {
tracing::error!( tracing::error!(
@@ -191,8 +189,8 @@ pub async fn get_databases_privilege_data(
} }
/// TODO: make this constant /// TODO: make this constant
fn get_all_db_privs_query() -> AssertSqlSafe<String> { fn get_all_db_privs_query() -> String {
AssertSqlSafe(format!( format!(
indoc! {r" indoc! {r"
SELECT {} FROM `db` WHERE `db` IN SELECT {} FROM `db` WHERE `db` IN
(SELECT DISTINCT CAST(`SCHEMA_NAME` AS CHAR(64)) AS `database` (SELECT DISTINCT CAST(`SCHEMA_NAME` AS CHAR(64)) AS `database`
@@ -204,7 +202,7 @@ fn get_all_db_privs_query() -> AssertSqlSafe<String> {
.iter() .iter()
.map(|field| quote_identifier(field)) .map(|field| quote_identifier(field))
.join(","), .join(","),
)) )
} }
/// Get all database + user + privileges pairs that are owned by the current user. /// Get all database + user + privileges pairs that are owned by the current user.
@@ -214,7 +212,7 @@ pub async fn get_all_database_privileges(
_db_is_mariadb: bool, _db_is_mariadb: bool,
group_denylist: &GroupDenylist, group_denylist: &GroupDenylist,
) -> ListAllPrivilegesResponse { ) -> ListAllPrivilegesResponse {
let result = sqlx::query_as::<_, DatabasePrivilegeRow>(get_all_db_privs_query()) let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&get_all_db_privs_query())
.bind(create_user_group_matching_regex(unix_user, group_denylist)) .bind(create_user_group_matching_regex(unix_user, group_denylist))
.fetch_all(connection) .fetch_all(connection)
.await .await
@@ -236,17 +234,13 @@ async fn unsafe_apply_privilege_diff(
DatabasePrivilegesDiff::New(p) => { DatabasePrivilegesDiff::New(p) => {
let tables = DATABASE_PRIVILEGE_FIELDS let tables = DATABASE_PRIVILEGE_FIELDS
.iter() .iter()
.chain(&["Host"])
.map(|field| quote_identifier(field)) .map(|field| quote_identifier(field))
.join(","); .join(",");
let question_marks = let question_marks =
std::iter::repeat_n("?", DATABASE_PRIVILEGE_FIELDS.len() + 1).join(","); std::iter::repeat_n("?", DATABASE_PRIVILEGE_FIELDS.len()).join(",");
let statement = AssertSqlSafe(format!( sqlx::query(format!("INSERT INTO `db` ({tables}) VALUES ({question_marks})").as_str())
"INSERT INTO `db` ({tables}) VALUES ({question_marks})"
));
sqlx::query(statement)
.bind(p.db.to_string()) .bind(p.db.to_string())
.bind(p.user.to_string()) .bind(p.user.to_string())
.bind(yn(p.select_priv)) .bind(yn(p.select_priv))
@@ -260,7 +254,6 @@ async fn unsafe_apply_privilege_diff(
.bind(yn(p.create_tmp_table_priv)) .bind(yn(p.create_tmp_table_priv))
.bind(yn(p.lock_tables_priv)) .bind(yn(p.lock_tables_priv))
.bind(yn(p.references_priv)) .bind(yn(p.references_priv))
.bind("%")
.execute(connection) .execute(connection)
.await .await
.map(|_| ()) .map(|_| ())
@@ -285,10 +278,7 @@ async fn unsafe_apply_privilege_diff(
} }
} }
let statement = AssertSqlSafe(format!( sqlx::query(format!("UPDATE `db` SET {changes} WHERE `Db` = ? AND `User` = ?").as_str())
"UPDATE `db` SET {changes} WHERE `Db` = ? AND `User` = ? AND `Host` = ?"
));
sqlx::query(statement)
.bind(p.select_priv.map(change_to_yn)) .bind(p.select_priv.map(change_to_yn))
.bind(p.insert_priv.map(change_to_yn)) .bind(p.insert_priv.map(change_to_yn))
.bind(p.update_priv.map(change_to_yn)) .bind(p.update_priv.map(change_to_yn))
@@ -302,16 +292,14 @@ async fn unsafe_apply_privilege_diff(
.bind(p.references_priv.map(change_to_yn)) .bind(p.references_priv.map(change_to_yn))
.bind(p.db.to_string()) .bind(p.db.to_string())
.bind(p.user.to_string()) .bind(p.user.to_string())
.bind("%")
.execute(connection) .execute(connection)
.await .await
.map(|_| ()) .map(|_| ())
} }
DatabasePrivilegesDiff::Deleted(p) => { DatabasePrivilegesDiff::Deleted(p) => {
sqlx::query("DELETE FROM `db` WHERE `Db` = ? AND `User` = ? AND `Host` = ?") sqlx::query("DELETE FROM `db` WHERE `Db` = ? AND `User` = ?")
.bind(p.db.to_string()) .bind(p.db.to_string())
.bind(p.user.to_string()) .bind(p.user.to_string())
.bind("%")
.execute(connection) .execute(connection)
.await .await
.map(|_| ()) .map(|_| ())
@@ -492,18 +480,5 @@ pub async fn apply_privilege_diffs(
results.insert(key, result); results.insert(key, result);
} }
if let Err(err) = connection.execute("FLUSH PRIVILEGES").await {
tracing::error!("Failed to flush privileges: {}", err);
}
results results
.into_iter()
.map(|((k1, k2), v)| (k1, (k2, v)))
.into_group_map()
.into_iter()
.map(|(k1, pairs)| {
let inner = pairs.into_iter().collect::<BTreeMap<_, _>>();
(k1, inner)
})
.collect()
} }
+60 -67
View File
@@ -1,6 +1,5 @@
use indoc::formatdoc; use indoc::formatdoc;
use itertools::Itertools; use itertools::Itertools;
use sqlx::AssertSqlSafe;
use std::collections::BTreeMap; use std::collections::BTreeMap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@@ -40,7 +39,6 @@ pub(super) async fn unsafe_user_exists(
SELECT 1 SELECT 1
FROM `mysql`.`user` FROM `mysql`.`user`
WHERE `User` = ? WHERE `User` = ?
AND `Host` = '%'
) )
", ",
) )
@@ -69,7 +67,6 @@ pub async fn complete_user_name(
FROM `mysql`.`user` FROM `mysql`.`user`
WHERE `User` REGEXP ? WHERE `User` REGEXP ?
AND `User` LIKE ? AND `User` LIKE ?
AND `Host` = '%'
", ",
) )
.bind(create_user_group_matching_regex(unix_user, group_denylist)) .bind(create_user_group_matching_regex(unix_user, group_denylist))
@@ -127,8 +124,7 @@ pub async fn create_database_users(
_ => {} _ => {}
} }
let statement = AssertSqlSafe(format!("CREATE USER {}@'%'", quote_literal(&db_user),)); let result = sqlx::query(format!("CREATE USER {}@'%'", quote_literal(&db_user),).as_str())
let result = sqlx::query(statement)
.execute(&mut *connection) .execute(&mut *connection)
.await .await
.map(|_| ()) .map(|_| ())
@@ -174,8 +170,7 @@ pub async fn drop_database_users(
_ => {} _ => {}
} }
let statement = AssertSqlSafe(format!("DROP USER {}@'%'", quote_literal(&db_user),)); let result = sqlx::query(format!("DROP USER {}@'%'", quote_literal(&db_user),).as_str())
let result = sqlx::query(statement)
.execute(&mut *connection) .execute(&mut *connection)
.await .await
.map(|_| ()) .map(|_| ())
@@ -208,16 +203,18 @@ pub async fn set_password_for_database_user(
_ => {} _ => {}
} }
let statement = AssertSqlSafe(format!( let result = sqlx::query(
"ALTER USER {}@'%' IDENTIFIED BY {}", format!(
quote_literal(db_user), "ALTER USER {}@'%' IDENTIFIED BY {}",
quote_literal(password).as_str(), quote_literal(db_user),
)); quote_literal(password).as_str(),
let result = sqlx::query(statement) )
.execute(&mut *connection) .as_str(),
.await )
.map(|_| ()) .execute(&mut *connection)
.map_err(|err| SetPasswordError::MySqlError(err.to_string())); .await
.map(|_| ())
.map_err(|err| SetPasswordError::MySqlError(err.to_string()));
if result.is_err() { if result.is_err() {
tracing::error!( tracing::error!(
@@ -316,15 +313,13 @@ pub async fn lock_database_users(
} }
} }
let statement = AssertSqlSafe(format!( let result = sqlx::query(
"ALTER USER {}@'%' ACCOUNT LOCK", format!("ALTER USER {}@'%' ACCOUNT LOCK", quote_literal(&db_user),).as_str(),
quote_literal(&db_user), )
)); .execute(&mut *connection)
let result = sqlx::query(statement) .await
.execute(&mut *connection) .map(|_| ())
.await .map_err(|err| LockUserError::MySqlError(err.to_string()));
.map(|_| ())
.map_err(|err| LockUserError::MySqlError(err.to_string()));
if let Err(err) = &result { if let Err(err) = &result {
tracing::error!("Failed to lock database user '{}': {:?}", &db_user, err); tracing::error!("Failed to lock database user '{}': {:?}", &db_user, err);
@@ -378,15 +373,13 @@ pub async fn unlock_database_users(
_ => {} _ => {}
} }
let statement = AssertSqlSafe(format!( let result = sqlx::query(
"ALTER USER {}@'%' ACCOUNT UNLOCK", format!("ALTER USER {}@'%' ACCOUNT UNLOCK", quote_literal(&db_user),).as_str(),
quote_literal(&db_user), )
)); .execute(&mut *connection)
let result = sqlx::query(statement) .await
.execute(&mut *connection) .map(|_| ())
.await .map_err(|err| UnlockUserError::MySqlError(err.to_string()));
.map(|_| ())
.map_err(|err| UnlockUserError::MySqlError(err.to_string()));
if let Err(err) = &result { if let Err(err) = &result {
tracing::error!("Failed to unlock database user '{}': {:?}", &db_user, err); tracing::error!("Failed to unlock database user '{}': {:?}", &db_user, err);
@@ -464,17 +457,16 @@ pub async fn list_database_users(
continue; continue;
} }
let statement = AssertSqlSafe( let mut result = sqlx::query_as::<_, DatabaseUser>(
if db_is_mariadb { &(if db_is_mariadb {
DB_USER_SELECT_STATEMENT_MARIADB.to_string() DB_USER_SELECT_STATEMENT_MARIADB.to_string()
} else { } else {
DB_USER_SELECT_STATEMENT_MYSQL.to_string() DB_USER_SELECT_STATEMENT_MYSQL.to_string()
} + "WHERE `mysql`.`user`.`User` = ? AND `mysql`.`user`.`Host` = '%'", } + "WHERE `mysql`.`user`.`User` = ?"),
); )
let mut result = sqlx::query_as::<_, DatabaseUser>(statement) .bind(db_user.as_str())
.bind(db_user.as_str()) .fetch_optional(&mut *connection)
.fetch_optional(&mut *connection) .await;
.await;
if let Err(err) = &result { if let Err(err) = &result {
tracing::error!("Failed to list database user '{}': {:?}", &db_user, err); tracing::error!("Failed to list database user '{}': {:?}", &db_user, err);
@@ -502,18 +494,17 @@ pub async fn list_all_database_users_for_unix_user(
db_is_mariadb: bool, db_is_mariadb: bool,
group_denylist: &GroupDenylist, group_denylist: &GroupDenylist,
) -> ListAllUsersResponse { ) -> ListAllUsersResponse {
let statement = AssertSqlSafe( let mut result = sqlx::query_as::<_, DatabaseUser>(
if db_is_mariadb { &(if db_is_mariadb {
DB_USER_SELECT_STATEMENT_MARIADB.to_string() DB_USER_SELECT_STATEMENT_MARIADB.to_string()
} else { } else {
DB_USER_SELECT_STATEMENT_MYSQL.to_string() DB_USER_SELECT_STATEMENT_MYSQL.to_string()
} + "WHERE `user`.`User` REGEXP ? AND `user`.`Host` = '%'", } + "WHERE `user`.`User` REGEXP ?"),
); )
let mut result = sqlx::query_as::<_, DatabaseUser>(statement) .bind(create_user_group_matching_regex(unix_user, group_denylist))
.bind(create_user_group_matching_regex(unix_user, group_denylist)) .fetch_all(&mut *connection)
.fetch_all(&mut *connection) .await
.await .map_err(|err| ListAllUsersError::MySqlError(err.to_string()));
.map_err(|err| ListAllUsersError::MySqlError(err.to_string()));
if let Err(err) = &result { if let Err(err) = &result {
tracing::error!("Failed to list all database users: {:?}", err); tracing::error!("Failed to list all database users: {:?}", err);
@@ -538,21 +529,23 @@ pub async fn set_databases_where_user_has_privileges(
db_user: &mut DatabaseUser, db_user: &mut DatabaseUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
) -> Result<(), sqlx::Error> { ) -> Result<(), sqlx::Error> {
let statement = AssertSqlSafe(formatdoc!( let database_list = sqlx::query(
r" formatdoc!(
SELECT `Db` AS `database` r"
FROM `db` SELECT `Db` AS `database`
WHERE `User` = ? AND `Host` = '%' AND ({}) FROM `db`
", WHERE `User` = ? AND ({})
DATABASE_PRIVILEGE_FIELDS ",
.iter() DATABASE_PRIVILEGE_FIELDS
.map(|field| format!("`{field}` = 'Y'")) .iter()
.join(" OR "), .map(|field| format!("`{field}` = 'Y'"))
)); .join(" OR "),
let database_list = sqlx::query(statement) )
.bind(db_user.user.as_str()) .as_str(),
.fetch_all(&mut *connection) )
.await; .bind(db_user.user.as_str())
.fetch_all(&mut *connection)
.await;
if let Err(err) = &database_list { if let Err(err) = &database_list {
tracing::error!( tracing::error!(
+20 -13
View File
@@ -90,12 +90,14 @@ impl Supervisor {
}; };
let mut watchdog_duration = None; let mut watchdog_duration = None;
let mut watchdog_micro_seconds = 0;
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
let watchdog_task = let watchdog_task =
if systemd_mode && let Some(watchdog_duration_) = sd_notify::watchdog_enabled() { if systemd_mode && sd_notify::watchdog_enabled(true, &mut watchdog_micro_seconds) {
let watchdog_duration_ = Duration::from_micros(watchdog_micro_seconds);
tracing::debug!( tracing::debug!(
"Systemd watchdog enabled with {} millisecond interval", "Systemd watchdog enabled with {} millisecond interval",
watchdog_duration_.as_millis() watchdog_micro_seconds.div_ceil(1000),
); );
watchdog_duration = Some(watchdog_duration_); watchdog_duration = Some(watchdog_duration_);
Some(spawn_watchdog_task(watchdog_duration_)) Some(spawn_watchdog_task(watchdog_duration_))
@@ -293,12 +295,15 @@ impl Supervisor {
pub async fn reload(&self) -> anyhow::Result<()> { pub async fn reload(&self) -> anyhow::Result<()> {
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
sd_notify::notify(&[ sd_notify::notify(
sd_notify::NotifyState::Reloading, false,
sd_notify::NotifyState::monotonic_usec_now() &[
.expect("Failed to get monotonic time to send to systemd while reloading"), sd_notify::NotifyState::Reloading,
sd_notify::NotifyState::Status("Reloading configuration"), sd_notify::NotifyState::monotonic_usec_now()
])?; .expect("Failed to get monotonic time to send to systemd while reloading"),
sd_notify::NotifyState::Status("Reloading configuration"),
],
)?;
let previous_config = self.config.lock().await.clone(); let previous_config = self.config.lock().await.clone();
self.reload_config().await?; self.reload_config().await?;
@@ -335,14 +340,14 @@ impl Supervisor {
} }
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
sd_notify::notify(&[sd_notify::NotifyState::Ready])?; sd_notify::notify(false, &[sd_notify::NotifyState::Ready])?;
Ok(()) Ok(())
} }
pub async fn shutdown(&self) -> anyhow::Result<()> { pub async fn shutdown(&self) -> anyhow::Result<()> {
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
sd_notify::notify(&[sd_notify::NotifyState::Stopping])?; sd_notify::notify(false, &[sd_notify::NotifyState::Stopping])?;
tracing::debug!("Stop accepting new connections"); tracing::debug!("Stop accepting new connections");
self.stop_receiving_new_connections()?; self.stop_receiving_new_connections()?;
@@ -412,7 +417,7 @@ fn spawn_watchdog_task(duration: Duration) -> JoinHandle<()> {
); );
loop { loop {
interval.tick().await; interval.tick().await;
if let Err(err) = sd_notify::notify(&[sd_notify::NotifyState::Watchdog]) { if let Err(err) = sd_notify::notify(false, &[sd_notify::NotifyState::Watchdog]) {
tracing::warn!("Failed to notify systemd watchdog: {}", err); tracing::warn!("Failed to notify systemd watchdog: {}", err);
} }
} }
@@ -435,7 +440,9 @@ fn spawn_status_notifier_task(task_tracker: TaskTracker) -> JoinHandle<()> {
"Waiting for connections".to_string() "Waiting for connections".to_string()
}; };
if let Err(e) = sd_notify::notify(&[sd_notify::NotifyState::Status(message.as_str())]) { if let Err(e) =
sd_notify::notify(false, &[sd_notify::NotifyState::Status(message.as_str())])
{
tracing::warn!("Failed to send systemd status notification: {}", e); tracing::warn!("Failed to send systemd status notification: {}", e);
} }
} }
@@ -550,7 +557,7 @@ async fn listener_task(
group_denylist: Arc<RwLock<GroupDenylist>>, group_denylist: Arc<RwLock<GroupDenylist>>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
sd_notify::notify(&[sd_notify::NotifyState::Ready])?; sd_notify::notify(false, &[sd_notify::NotifyState::Ready])?;
let connection_counter = AtomicU64::new(0); let connection_counter = AtomicU64::new(0);