From f348e67622f67599f5f3be5ce226467222c63d12 Mon Sep 17 00:00:00 2001 From: h7x4 Date: Mon, 1 Dec 2025 17:26:17 +0900 Subject: [PATCH] Add dynamic completion for users and databases --- Cargo.lock | 12 +++ Cargo.toml | 2 +- nix/default.nix | 10 ++- src/client/commands/drop_db.rs | 4 +- src/client/commands/drop_user.rs | 4 +- src/client/commands/edit_privs.rs | 3 + src/client/commands/lock_user.rs | 4 +- src/client/commands/passwd_user.rs | 3 + src/client/commands/show_db.rs | 4 +- src/client/commands/show_privs.rs | 4 +- src/client/commands/show_user.rs | 4 +- src/client/commands/unlock_user.rs | 4 +- src/core.rs | 1 + src/core/bootstrap.rs | 2 +- src/core/completion.rs | 5 ++ .../completion/mysql_database_completer.rs | 73 +++++++++++++++++++ src/core/completion/mysql_user_completer.rs | 73 +++++++++++++++++++ src/core/protocol/commands.rs | 10 +++ .../commands/complete_database_name.rs | 5 ++ .../protocol/commands/complete_user_name.rs | 5 ++ src/core/protocol/commands/user_exists.rs | 0 src/core/types.rs | 13 ++++ src/main.rs | 41 ++++++++++- src/server/session_handler.rs | 36 ++++++++- src/server/sql/database_operations.rs | 40 ++++++++++ src/server/sql/user_operations.rs | 38 ++++++++++ 26 files changed, 383 insertions(+), 17 deletions(-) create mode 100644 src/core/completion.rs create mode 100644 src/core/completion/mysql_database_completer.rs create mode 100644 src/core/completion/mysql_user_completer.rs create mode 100644 src/core/protocol/commands/complete_database_name.rs create mode 100644 src/core/protocol/commands/complete_user_name.rs create mode 100644 src/core/protocol/commands/user_exists.rs diff --git a/Cargo.lock b/Cargo.lock index c54dc8a..f4b658b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -242,6 +242,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39615915e2ece2550c0149addac32fb5bd312c657f43845bb9088cb9c8a7c992" dependencies = [ "clap", + "clap_lex", + "is_executable", + "shlex", ] [[package]] @@ -923,6 +926,15 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "is_executable" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baabb8b4867b26294d818bf3f651a454b6901431711abb96e296245888d6e8c4" +dependencies = [ + "windows-sys 0.60.2", +] + [[package]] name = "is_terminal_polyfill" version = "1.70.2" diff --git a/Cargo.toml b/Cargo.toml index 26e91f9..2de5619 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,7 +22,7 @@ async-bincode = "0.8.0" bincode = "2.0.1" clap = { version = "4.5.53", features = ["derive"] } clap-verbosity-flag = { version = "3.0.4", features = [ "tracing" ] } -clap_complete = "4.5.61" +clap_complete = { version = "4.5.61", features = ["unstable-dynamic"] } derive_more = { version = "2.0.1", features = ["display", "error"] } dialoguer = "0.12.0" futures-util = "0.3.31" diff --git a/nix/default.nix b/nix/default.nix index 4f43563..6edceb8 100644 --- a/nix/default.nix +++ b/nix/default.nix @@ -24,14 +24,18 @@ buildFunction { nativeBuildInputs = [ installShellFiles ]; postInstall = let + # "$out/bin/${mainProgram}" generate-completions --shell "${shell}" --command "${command}" > "$TMP/muscl.${shell}" commands = lib.mapCartesianProduct ({ shell, command }: '' - "$out/bin/${mainProgram}" generate-completions --shell "${shell}" --command "${command}" > "$TMP/muscl.${shell}" - installShellCompletion "--${shell}" --cmd "${command}" "$TMP/muscl.${shell}" + COMPLETE=${shell} "$out/bin/${command}" > "$TMP/${command}.${shell}" + installShellCompletion "--${shell}" --cmd "${command}" "$TMP/${command}.${shell}" '') { shell = [ "bash" "zsh" "fish" ]; command = [ "muscl" "mysql-dbadm" "mysql-useradm" ]; }; - in lib.concatStringsSep "\n" commands + '' + in '' + ln -sr "$out/bin/muscl" "$out/bin/mysql-dbadm" + ln -sr "$out/bin/muscl" "$out/bin/mysql-useradm" + '' + lib.concatStringsSep "\n" commands + '' install -Dm444 assets/systemd/muscl.socket -t "$out/lib/systemd/system" install -Dm644 assets/systemd/muscl.service -t "$out/lib/systemd/system" substituteInPlace "$out/lib/systemd/system/muscl.service" \ diff --git a/src/client/commands/drop_db.rs b/src/client/commands/drop_db.rs index ae64afc..d195389 100644 --- a/src/client/commands/drop_db.rs +++ b/src/client/commands/drop_db.rs @@ -1,10 +1,12 @@ use clap::Parser; +use clap_complete::ArgValueCompleter; use futures_util::SinkExt; use tokio_stream::StreamExt; use crate::{ client::commands::erroneous_server_response, core::{ + completion::mysql_database_completer, protocol::{ ClientToServerMessageStream, Request, Response, print_drop_databases_output_status, print_drop_databases_output_status_json, @@ -16,7 +18,7 @@ use crate::{ #[derive(Parser, Debug, Clone)] pub struct DropDbArgs { /// The MySQL database(s) to drop - #[arg(num_args = 1..)] + #[arg(num_args = 1.., add = ArgValueCompleter::new(mysql_database_completer))] name: Vec, /// Print the information as JSON diff --git a/src/client/commands/drop_user.rs b/src/client/commands/drop_user.rs index 00e988d..5a698c1 100644 --- a/src/client/commands/drop_user.rs +++ b/src/client/commands/drop_user.rs @@ -1,10 +1,12 @@ use clap::Parser; +use clap_complete::ArgValueCompleter; use futures_util::SinkExt; use tokio_stream::StreamExt; use crate::{ client::commands::erroneous_server_response, core::{ + completion::mysql_user_completer, protocol::{ ClientToServerMessageStream, Request, Response, print_drop_users_output_status, print_drop_users_output_status_json, @@ -16,7 +18,7 @@ use crate::{ #[derive(Parser, Debug, Clone)] pub struct DropUserArgs { /// The MySQL user(s) to drop - #[arg(num_args = 1..)] + #[arg(num_args = 1.., add = ArgValueCompleter::new(mysql_user_completer))] username: Vec, /// Print the information as JSON diff --git a/src/client/commands/edit_privs.rs b/src/client/commands/edit_privs.rs index 751189f..853d1e0 100644 --- a/src/client/commands/edit_privs.rs +++ b/src/client/commands/edit_privs.rs @@ -2,6 +2,7 @@ use std::collections::BTreeSet; use anyhow::Context; use clap::Parser; +use clap_complete::ArgValueCompleter; use dialoguer::{Confirm, Editor}; use futures_util::SinkExt; use nix::unistd::{User, getuid}; @@ -10,6 +11,7 @@ use tokio_stream::StreamExt; use crate::{ client::commands::erroneous_server_response, core::{ + completion::mysql_database_completer, database_privileges::{ DatabasePrivilegeEditEntry, DatabasePrivilegeRow, DatabasePrivilegeRowDiff, DatabasePrivilegesDiff, create_or_modify_privilege_rows, diff_privileges, @@ -27,6 +29,7 @@ use crate::{ #[derive(Parser, Debug, Clone)] pub struct EditPrivsArgs { /// The MySQL database to edit privileges for + #[arg(add = ArgValueCompleter::new(mysql_database_completer))] pub name: Option, #[arg( diff --git a/src/client/commands/lock_user.rs b/src/client/commands/lock_user.rs index 50b5407..a59ccc8 100644 --- a/src/client/commands/lock_user.rs +++ b/src/client/commands/lock_user.rs @@ -1,10 +1,12 @@ use clap::Parser; +use clap_complete::ArgValueCompleter; use futures_util::SinkExt; use tokio_stream::StreamExt; use crate::{ client::commands::erroneous_server_response, core::{ + completion::mysql_user_completer, protocol::{ ClientToServerMessageStream, Request, Response, print_lock_users_output_status, print_lock_users_output_status_json, @@ -16,7 +18,7 @@ use crate::{ #[derive(Parser, Debug, Clone)] pub struct LockUserArgs { /// The MySQL user(s) to lock - #[arg(num_args = 1..)] + #[arg(num_args = 1.., add = ArgValueCompleter::new(mysql_user_completer))] username: Vec, /// Print the information as JSON diff --git a/src/client/commands/passwd_user.rs b/src/client/commands/passwd_user.rs index c53906c..50f28d7 100644 --- a/src/client/commands/passwd_user.rs +++ b/src/client/commands/passwd_user.rs @@ -2,6 +2,7 @@ use std::path::PathBuf; use anyhow::Context; use clap::Parser; +use clap_complete::ArgValueCompleter; use dialoguer::Password; use futures_util::SinkExt; use tokio_stream::StreamExt; @@ -9,6 +10,7 @@ use tokio_stream::StreamExt; use crate::{ client::commands::erroneous_server_response, core::{ + completion::mysql_user_completer, protocol::{ ClientToServerMessageStream, ListUsersError, Request, Response, print_set_password_output_status, @@ -20,6 +22,7 @@ use crate::{ #[derive(Parser, Debug, Clone)] pub struct PasswdUserArgs { /// The MySQL user whose password is to be changed + #[arg(add = ArgValueCompleter::new(mysql_user_completer))] username: MySQLUser, /// Read the new password from a file instead of prompting for it diff --git a/src/client/commands/show_db.rs b/src/client/commands/show_db.rs index d497f83..9a72087 100644 --- a/src/client/commands/show_db.rs +++ b/src/client/commands/show_db.rs @@ -1,4 +1,5 @@ use clap::Parser; +use clap_complete::ArgValueCompleter; use futures_util::SinkExt; use prettytable::{Cell, Row, Table}; use tokio_stream::StreamExt; @@ -6,6 +7,7 @@ use tokio_stream::StreamExt; use crate::{ client::commands::erroneous_server_response, core::{ + completion::mysql_database_completer, protocol::{ClientToServerMessageStream, Request, Response}, types::MySQLDatabase, }, @@ -14,7 +16,7 @@ use crate::{ #[derive(Parser, Debug, Clone)] pub struct ShowDbArgs { /// The MySQL database(s) to show - #[arg(num_args = 0..)] + #[arg(num_args = 0.., add = ArgValueCompleter::new(mysql_database_completer))] name: Vec, /// Print the information as JSON diff --git a/src/client/commands/show_privs.rs b/src/client/commands/show_privs.rs index 104c613..6210e58 100644 --- a/src/client/commands/show_privs.rs +++ b/src/client/commands/show_privs.rs @@ -1,4 +1,5 @@ use clap::Parser; +use clap_complete::ArgValueCompleter; use futures_util::SinkExt; use prettytable::{Cell, Row, Table}; use tokio_stream::StreamExt; @@ -7,6 +8,7 @@ use crate::{ client::commands::erroneous_server_response, core::{ common::yn, + completion::mysql_database_completer, database_privileges::{DATABASE_PRIVILEGE_FIELDS, db_priv_field_human_readable_name}, protocol::{ClientToServerMessageStream, Request, Response}, types::MySQLDatabase, @@ -16,7 +18,7 @@ use crate::{ #[derive(Parser, Debug, Clone)] pub struct ShowPrivsArgs { /// The MySQL database(s) to show privileges for - #[arg(num_args = 0..)] + #[arg(num_args = 0.., add = ArgValueCompleter::new(mysql_database_completer))] name: Vec, /// Print the information as JSON diff --git a/src/client/commands/show_user.rs b/src/client/commands/show_user.rs index 1548ac7..ecdde64 100644 --- a/src/client/commands/show_user.rs +++ b/src/client/commands/show_user.rs @@ -1,11 +1,13 @@ use anyhow::Context; use clap::Parser; +use clap_complete::ArgValueCompleter; use futures_util::SinkExt; use tokio_stream::StreamExt; use crate::{ client::commands::erroneous_server_response, core::{ + completion::mysql_user_completer, protocol::{ClientToServerMessageStream, Request, Response}, types::MySQLUser, }, @@ -14,7 +16,7 @@ use crate::{ #[derive(Parser, Debug, Clone)] pub struct ShowUserArgs { /// The MySQL user(s) to show - #[arg(num_args = 0..)] + #[arg(num_args = 0.., add = ArgValueCompleter::new(mysql_user_completer))] username: Vec, /// Print the information as JSON diff --git a/src/client/commands/unlock_user.rs b/src/client/commands/unlock_user.rs index d5a6da3..6af1ac5 100644 --- a/src/client/commands/unlock_user.rs +++ b/src/client/commands/unlock_user.rs @@ -1,10 +1,12 @@ use clap::Parser; +use clap_complete::ArgValueCompleter; use futures_util::SinkExt; use tokio_stream::StreamExt; use crate::{ client::commands::erroneous_server_response, core::{ + completion::mysql_user_completer, protocol::{ ClientToServerMessageStream, Request, Response, print_unlock_users_output_status, print_unlock_users_output_status_json, @@ -16,7 +18,7 @@ use crate::{ #[derive(Parser, Debug, Clone)] pub struct UnlockUserArgs { /// The MySQL user(s) to unlock - #[arg(num_args = 1..)] + #[arg(num_args = 1.., add = ArgValueCompleter::new(mysql_user_completer))] username: Vec, /// Print the information as JSON diff --git a/src/core.rs b/src/core.rs index 5ef5054..e91f286 100644 --- a/src/core.rs +++ b/src/core.rs @@ -1,5 +1,6 @@ pub mod bootstrap; pub mod common; +pub mod completion; pub mod database_privileges; pub mod protocol; pub mod types; diff --git a/src/core/bootstrap.rs b/src/core/bootstrap.rs index 37251f0..3615d18 100644 --- a/src/core/bootstrap.rs +++ b/src/core/bootstrap.rs @@ -159,7 +159,7 @@ fn connect_to_external_server( /// Drop privileges to the real user and group of the process. /// If the process is not running with elevated privileges, this function /// is a no-op. -fn drop_privs() -> anyhow::Result<()> { +pub fn drop_privs() -> anyhow::Result<()> { tracing::debug!("Dropping privileges"); let real_uid = nix::unistd::getuid(); let real_gid = nix::unistd::getgid(); diff --git a/src/core/completion.rs b/src/core/completion.rs new file mode 100644 index 0000000..2e68e28 --- /dev/null +++ b/src/core/completion.rs @@ -0,0 +1,5 @@ +mod mysql_database_completer; +mod mysql_user_completer; + +pub use mysql_database_completer::*; +pub use mysql_user_completer::*; diff --git a/src/core/completion/mysql_database_completer.rs b/src/core/completion/mysql_database_completer.rs new file mode 100644 index 0000000..a344bf6 --- /dev/null +++ b/src/core/completion/mysql_database_completer.rs @@ -0,0 +1,73 @@ +use clap_complete::CompletionCandidate; +use clap_verbosity_flag::Verbosity; +use futures_util::SinkExt; +use tokio::net::UnixStream as TokioUnixStream; +use tokio_stream::StreamExt; + +use crate::{ + client::commands::erroneous_server_response, + core::{ + bootstrap::bootstrap_server_connection_and_drop_privileges, + protocol::{Request, Response, create_client_to_server_message_stream}, + }, +}; + +pub fn mysql_database_completer(current: &std::ffi::OsStr) -> Vec { + match tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + { + Ok(runtime) => match runtime.block_on(mysql_database_completer_(current)) { + Ok(completions) => completions, + Err(err) => { + eprintln!("Error getting MySQL database completions: {}", err); + Vec::new() + } + }, + Err(err) => { + eprintln!("Error starting Tokio runtime: {}", err); + Vec::new() + } + } +} + +/// Connect to the server to get MySQL database completions. +async fn mysql_database_completer_( + current: &std::ffi::OsStr, +) -> anyhow::Result> { + let server_connection = + bootstrap_server_connection_and_drop_privileges(None, None, Verbosity::new(0, 1))?; + + let tokio_socket = TokioUnixStream::from_std(server_connection)?; + let mut server_connection = create_client_to_server_message_stream(tokio_socket); + + while let Some(Ok(message)) = server_connection.next().await { + match message { + Response::Error(err) => { + anyhow::bail!("{}", err); + } + Response::Ready => break, + message => { + eprintln!("Unexpected message from server: {:?}", message); + } + } + } + + let message = Request::CompleteDatabaseName(current.to_string_lossy().to_string()); + + if let Err(err) = server_connection.send(message).await { + server_connection.close().await.ok(); + anyhow::bail!(anyhow::Error::from(err).context("Failed to communicate with server")); + } + + let result = match server_connection.next().await { + Some(Ok(Response::CompleteDatabaseName(suggestions))) => suggestions, + response => return erroneous_server_response(response).map(|_| vec![]), + }; + + server_connection.send(Request::Exit).await?; + + let result = result.into_iter().map(CompletionCandidate::new).collect(); + + Ok(result) +} diff --git a/src/core/completion/mysql_user_completer.rs b/src/core/completion/mysql_user_completer.rs new file mode 100644 index 0000000..68dcc46 --- /dev/null +++ b/src/core/completion/mysql_user_completer.rs @@ -0,0 +1,73 @@ +use clap_complete::CompletionCandidate; +use clap_verbosity_flag::Verbosity; +use futures_util::SinkExt; +use tokio::net::UnixStream as TokioUnixStream; +use tokio_stream::StreamExt; + +use crate::{ + client::commands::erroneous_server_response, + core::{ + bootstrap::bootstrap_server_connection_and_drop_privileges, + protocol::{Request, Response, create_client_to_server_message_stream}, + }, +}; + +pub fn mysql_user_completer(current: &std::ffi::OsStr) -> Vec { + match tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + { + Ok(runtime) => match runtime.block_on(mysql_user_completer_(current)) { + Ok(completions) => completions, + Err(err) => { + eprintln!("Error getting MySQL user completions: {}", err); + Vec::new() + } + }, + Err(err) => { + eprintln!("Error starting Tokio runtime: {}", err); + Vec::new() + } + } +} + +/// Connect to the server to get MySQL user completions. +async fn mysql_user_completer_( + current: &std::ffi::OsStr, +) -> anyhow::Result> { + let server_connection = + bootstrap_server_connection_and_drop_privileges(None, None, Verbosity::new(0, 1))?; + + let tokio_socket = TokioUnixStream::from_std(server_connection)?; + let mut server_connection = create_client_to_server_message_stream(tokio_socket); + + while let Some(Ok(message)) = server_connection.next().await { + match message { + Response::Error(err) => { + anyhow::bail!("{}", err); + } + Response::Ready => break, + message => { + eprintln!("Unexpected message from server: {:?}", message); + } + } + } + + let message = Request::CompleteUserName(current.to_string_lossy().to_string()); + + if let Err(err) = server_connection.send(message).await { + server_connection.close().await.ok(); + anyhow::bail!(anyhow::Error::from(err).context("Failed to communicate with server")); + } + + let result = match server_connection.next().await { + Some(Ok(Response::CompleteUserName(suggestions))) => suggestions, + response => return erroneous_server_response(response).map(|_| vec![]), + }; + + server_connection.send(Request::Exit).await?; + + let result = result.into_iter().map(CompletionCandidate::new).collect(); + + Ok(result) +} diff --git a/src/core/protocol/commands.rs b/src/core/protocol/commands.rs index 719eaa9..5d0cc45 100644 --- a/src/core/protocol/commands.rs +++ b/src/core/protocol/commands.rs @@ -1,4 +1,6 @@ mod check_authorization; +mod complete_database_name; +mod complete_user_name; mod create_databases; mod create_users; mod drop_databases; @@ -15,6 +17,8 @@ mod passwd_user; mod unlock_users; pub use check_authorization::*; +pub use complete_database_name::*; +pub use complete_user_name::*; pub use create_databases::*; pub use create_users::*; pub use drop_databases::*; @@ -64,6 +68,9 @@ pub fn create_client_to_server_message_stream(socket: UnixStream) -> ClientToSer pub enum Request { CheckAuthorization(CheckAuthorizationRequest), + CompleteDatabaseName(CompleteDatabaseNameRequest), + CompleteUserName(CompleteUserNameRequest), + CreateDatabases(CreateDatabasesRequest), DropDatabases(DropDatabasesRequest), ListDatabases(ListDatabasesRequest), @@ -88,6 +95,9 @@ pub enum Request { pub enum Response { CheckAuthorization(CheckAuthorizationResponse), + CompleteDatabaseName(CompleteDatabaseNameResponse), + CompleteUserName(CompleteUserNameResponse), + // Specific data for specific commands CreateDatabases(CreateDatabasesResponse), DropDatabases(DropDatabasesResponse), diff --git a/src/core/protocol/commands/complete_database_name.rs b/src/core/protocol/commands/complete_database_name.rs new file mode 100644 index 0000000..65efc2e --- /dev/null +++ b/src/core/protocol/commands/complete_database_name.rs @@ -0,0 +1,5 @@ +use crate::core::types::MySQLDatabase; + +pub type CompleteDatabaseNameRequest = String; + +pub type CompleteDatabaseNameResponse = Vec; diff --git a/src/core/protocol/commands/complete_user_name.rs b/src/core/protocol/commands/complete_user_name.rs new file mode 100644 index 0000000..2fb39b0 --- /dev/null +++ b/src/core/protocol/commands/complete_user_name.rs @@ -0,0 +1,5 @@ +use crate::core::types::MySQLUser; + +pub type CompleteUserNameRequest = String; + +pub type CompleteUserNameResponse = Vec; diff --git a/src/core/protocol/commands/user_exists.rs b/src/core/protocol/commands/user_exists.rs new file mode 100644 index 0000000..e69de29 diff --git a/src/core/types.rs b/src/core/types.rs index d1308e3..a466881 100644 --- a/src/core/types.rs +++ b/src/core/types.rs @@ -1,4 +1,5 @@ use std::{ + ffi::OsString, fmt, ops::{Deref, DerefMut}, str::FromStr, @@ -49,6 +50,12 @@ impl From for MySQLUser { } } +impl From for OsString { + fn from(val: MySQLUser) -> Self { + val.0.into() + } +} + #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, Default)] pub struct MySQLDatabase(String); @@ -92,6 +99,12 @@ impl From for MySQLDatabase { } } +impl From for OsString { + fn from(val: MySQLDatabase) -> Self { + val.0.into() + } +} + #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] pub enum DbOrUser { Database(MySQLDatabase), diff --git a/src/main.rs b/src/main.rs index 1f49d3d..1e7a5f6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,7 +3,7 @@ extern crate prettytable; use anyhow::Context; use clap::{CommandFactory, Parser, ValueEnum}; -use clap_complete::{Shell, generate}; +use clap_complete::{CompleteEnv, Shell, generate}; use clap_verbosity_flag::Verbosity; use std::path::PathBuf; @@ -103,6 +103,10 @@ enum ToplevelCommands { /// **WARNING:** This function may be run with elevated privileges. fn main() -> anyhow::Result<()> { + if handle_dynamic_completion()?.is_some() { + return Ok(()); + } + #[cfg(feature = "mysql-admutils-compatibility")] if handle_mysql_admutils_command()?.is_some() { return Ok(()); @@ -129,6 +133,41 @@ fn main() -> anyhow::Result<()> { Ok(()) } +/// **WARNING:** This function may be run with elevated privileges. +fn handle_dynamic_completion() -> anyhow::Result> { + if std::env::var_os("COMPLETE").is_some() { + #[cfg(feature = "suid-sgid-mode")] + if executable_is_suid_or_sgid()? { + use crate::core::bootstrap::drop_privs; + drop_privs()? + } + + let argv0 = std::env::args() + .next() + .and_then(|s| { + PathBuf::from(s) + .file_name() + .map(|s| s.to_string_lossy().to_string()) + }) + .ok_or(anyhow::anyhow!( + "Could not determine executable name for completion" + ))?; + + let command = match argv0.as_str() { + "muscl" => Args::command(), + "mysql-dbadm" => mysql_dbadm::Command::command(), + "mysql-useradm" => mysql_useradm::Command::command(), + command => anyhow::bail!("Unknown executable name: `{}`", command), + }; + + CompleteEnv::with_factory(move || command.clone()).complete(); + + Ok(Some(())) + } else { + Ok(None) + } +} + /// **WARNING:** This function may be run with elevated privileges. fn handle_mysql_admutils_command() -> anyhow::Result> { let argv0 = std::env::args().next().and_then(|s| { diff --git a/src/server/session_handler.rs b/src/server/session_handler.rs index c193543..ac18497 100644 --- a/src/server/session_handler.rs +++ b/src/server/session_handler.rs @@ -18,15 +18,16 @@ use crate::{ authorization::check_authorization, sql::{ database_operations::{ - create_databases, drop_databases, list_all_databases_for_user, list_databases, + complete_database_name, create_databases, drop_databases, + list_all_databases_for_user, list_databases, }, database_privilege_operations::{ apply_privilege_diffs, get_all_database_privileges, get_databases_privilege_data, }, user_operations::{ - create_database_users, drop_database_users, list_all_database_users_for_unix_user, - list_database_users, lock_database_users, set_password_for_database_user, - unlock_database_users, + complete_user_name, create_database_users, drop_database_users, + list_all_database_users_for_unix_user, list_database_users, lock_database_users, + set_password_for_database_user, unlock_database_users, }, }, }, @@ -171,6 +172,33 @@ async fn session_handler_with_db_connection( let result = check_authorization(dbs_or_users, unix_user).await; Response::CheckAuthorization(result) } + Request::CompleteDatabaseName(partial_database_name) => { + // TODO: more correct validation here + if !partial_database_name + .chars() + .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-') + { + Response::CompleteDatabaseName(vec![]) + } else { + let result = + complete_database_name(partial_database_name, unix_user, db_connection) + .await; + Response::CompleteDatabaseName(result) + } + } + Request::CompleteUserName(partial_user_name) => { + // TODO: more correct validation here + if !partial_user_name + .chars() + .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-') + { + Response::CompleteUserName(vec![]) + } else { + let result = + complete_user_name(partial_user_name, unix_user, db_connection).await; + Response::CompleteUserName(result) + } + } Request::CreateDatabases(databases_names) => { let result = create_databases(databases_names, unix_user, db_connection).await; Response::CreateDatabases(result) diff --git a/src/server/sql/database_operations.rs b/src/server/sql/database_operations.rs index cb8f53b..d5fe151 100644 --- a/src/server/sql/database_operations.rs +++ b/src/server/sql/database_operations.rs @@ -5,6 +5,7 @@ use sqlx::prelude::*; use serde::{Deserialize, Serialize}; +use crate::core::protocol::CompleteDatabaseNameResponse; use crate::core::types::MySQLDatabase; use crate::{ core::{ @@ -43,6 +44,45 @@ pub(super) async fn unsafe_database_exists( Ok(result?.is_some()) } +pub async fn complete_database_name( + database_prefix: String, + unix_user: &UnixUser, + connection: &mut MySqlConnection, +) -> CompleteDatabaseNameResponse { + let result = sqlx::query( + r#" + SELECT `SCHEMA_NAME` AS `database` + FROM `information_schema`.`SCHEMATA` + WHERE `SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys') + AND `SCHEMA_NAME` REGEXP ? + AND `SCHEMA_NAME` LIKE ? + "#, + ) + .bind(create_user_group_matching_regex(unix_user)) + .bind(format!("{}%", database_prefix)) + .fetch_all(connection) + .await; + + match result { + Ok(rows) => rows + .into_iter() + .filter_map(|row| { + let database: String = row.try_get("database").ok()?; + Some(database.into()) + }) + .collect(), + Err(err) => { + tracing::error!( + "Failed to complete database name for prefix '{}' and user '{}': {:?}", + database_prefix, + unix_user.username, + err + ); + vec![] + } + } +} + pub async fn create_databases( database_names: Vec, unix_user: &UnixUser, diff --git a/src/server/sql/user_operations.rs b/src/server/sql/user_operations.rs index 13628bb..8262d9c 100644 --- a/src/server/sql/user_operations.rs +++ b/src/server/sql/user_operations.rs @@ -51,6 +51,44 @@ async fn unsafe_user_exists( result } +pub async fn complete_user_name( + user_prefix: String, + unix_user: &UnixUser, + connection: &mut MySqlConnection, +) -> Vec { + let result = sqlx::query( + r#" + SELECT `User` AS `user` + FROM `mysql`.`user` + WHERE `User` REGEXP ? + AND `User` LIKE ? + "#, + ) + .bind(create_user_group_matching_regex(unix_user)) + .bind(format!("{}%", user_prefix)) + .fetch_all(connection) + .await; + + match result { + Ok(rows) => rows + .into_iter() + .filter_map(|row| { + let user: String = try_get_with_binary_fallback(&row, "user").ok()?; + Some(user.into()) + }) + .collect(), + Err(err) => { + tracing::error!( + "Failed to complete user name for prefix '{}' and user '{}': {:?}", + user_prefix, + unix_user.username, + err + ); + vec![] + } + } +} + pub async fn create_database_users( db_users: Vec, unix_user: &UnixUser,