diff --git a/src/client/commands/create_db.rs b/src/client/commands/create_db.rs index b59ee26..cd08c7c 100644 --- a/src/client/commands/create_db.rs +++ b/src/client/commands/create_db.rs @@ -1,10 +1,12 @@ use clap::Parser; +use clap_complete::ArgValueCompleter; use futures_util::SinkExt; use tokio_stream::StreamExt; use crate::{ client::commands::{erroneous_server_response, print_authorization_owner_hint}, core::{ + completion::prefix_completer, protocol::{ ClientToServerMessageStream, CreateDatabaseError, Request, Response, print_create_databases_output_status, print_create_databases_output_status_json, @@ -18,6 +20,7 @@ use crate::{ pub struct CreateDbArgs { /// The MySQL database(s) to create #[arg(num_args = 1.., value_name = "DB_NAME")] + #[cfg_attr(not(feature = "suid-sgid-mode"), arg(add = ArgValueCompleter::new(prefix_completer)))] name: Vec, /// Print the information as JSON diff --git a/src/client/commands/create_user.rs b/src/client/commands/create_user.rs index 26cd14c..72a4340 100644 --- a/src/client/commands/create_user.rs +++ b/src/client/commands/create_user.rs @@ -1,4 +1,5 @@ use clap::Parser; +use clap_complete::ArgValueCompleter; use dialoguer::Confirm; use futures_util::SinkExt; use tokio_stream::StreamExt; @@ -9,6 +10,7 @@ use crate::{ read_password_from_stdin_with_double_check, }, core::{ + completion::prefix_completer, protocol::{ ClientToServerMessageStream, CreateUserError, Request, Response, print_create_users_output_status, print_create_users_output_status_json, @@ -22,6 +24,7 @@ use crate::{ pub struct CreateUserArgs { /// The MySQL user(s) to create #[arg(num_args = 1.., value_name = "USER_NAME")] + #[cfg_attr(not(feature = "suid-sgid-mode"), arg(add = ArgValueCompleter::new(prefix_completer)))] username: Vec, /// Do not ask for a password, leave it unset diff --git a/src/client/mysql_admutils_compatibility/mysql_dbadm.rs b/src/client/mysql_admutils_compatibility/mysql_dbadm.rs index 9de318e..d37e485 100644 --- a/src/client/mysql_admutils_compatibility/mysql_dbadm.rs +++ b/src/client/mysql_admutils_compatibility/mysql_dbadm.rs @@ -18,7 +18,7 @@ use crate::{ }, core::{ bootstrap::bootstrap_server_connection_and_drop_privileges, - completion::mysql_database_completer, + completion::{mysql_database_completer, prefix_completer}, database_privileges::DatabasePrivilegeRow, protocol::{ ClientToServerMessageStream, ListPrivilegesError, Request, Response, @@ -124,6 +124,7 @@ pub enum Command { pub struct CreateArgs { /// The name of the DATABASE(s) to create. #[arg(num_args = 1..)] + #[cfg_attr(not(feature = "suid-sgid-mode"), arg(add = ArgValueCompleter::new(prefix_completer)))] name: Vec, } diff --git a/src/client/mysql_admutils_compatibility/mysql_useradm.rs b/src/client/mysql_admutils_compatibility/mysql_useradm.rs index c0fa4bf..d61f52c 100644 --- a/src/client/mysql_admutils_compatibility/mysql_useradm.rs +++ b/src/client/mysql_admutils_compatibility/mysql_useradm.rs @@ -18,7 +18,7 @@ use crate::{ }, core::{ bootstrap::bootstrap_server_connection_and_drop_privileges, - completion::mysql_user_completer, + completion::{mysql_user_completer, prefix_completer}, protocol::{ ClientToServerMessageStream, Request, Response, create_client_to_server_message_stream, }, @@ -87,6 +87,7 @@ pub enum Command { pub struct CreateArgs { /// The name of the USER(s) to create. #[arg(num_args = 1..)] + #[cfg_attr(not(feature = "suid-sgid-mode"), arg(add = ArgValueCompleter::new(prefix_completer)))] name: Vec, } diff --git a/src/core/completion.rs b/src/core/completion.rs index 2e68e28..d4dead0 100644 --- a/src/core/completion.rs +++ b/src/core/completion.rs @@ -1,5 +1,7 @@ mod mysql_database_completer; mod mysql_user_completer; +mod prefix_completer; pub use mysql_database_completer::*; pub use mysql_user_completer::*; +pub use prefix_completer::*; diff --git a/src/core/completion/prefix_completer.rs b/src/core/completion/prefix_completer.rs new file mode 100644 index 0000000..5eee537 --- /dev/null +++ b/src/core/completion/prefix_completer.rs @@ -0,0 +1,75 @@ +use clap_complete::CompletionCandidate; +use clap_verbosity_flag::Verbosity; +use futures_util::SinkExt; +use tokio::net::UnixStream as TokioUnixStream; +use tokio_stream::StreamExt; + +use crate::{ + client::commands::erroneous_server_response, + core::{ + bootstrap::bootstrap_server_connection_and_drop_privileges, + protocol::{Request, Response, create_client_to_server_message_stream}, + }, +}; + +pub fn prefix_completer(current: &std::ffi::OsStr) -> Vec { + match tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + { + Ok(runtime) => match runtime.block_on(prefix_completer_(current)) { + Ok(completions) => completions, + Err(err) => { + eprintln!("Error getting prefix completions: {}", err); + Vec::new() + } + }, + Err(err) => { + eprintln!("Error starting Tokio runtime: {}", err); + Vec::new() + } + } +} + +/// Connect to the server to get MySQL user completions. +async fn prefix_completer_(_current: &std::ffi::OsStr) -> anyhow::Result> { + let server_connection = + bootstrap_server_connection_and_drop_privileges(None, None, Verbosity::new(0, 1))?; + + let tokio_socket = TokioUnixStream::from_std(server_connection)?; + let mut server_connection = create_client_to_server_message_stream(tokio_socket); + + while let Some(Ok(message)) = server_connection.next().await { + match message { + Response::Error(err) => { + anyhow::bail!("{}", err); + } + Response::Ready => break, + message => { + eprintln!("Unexpected message from server: {:?}", message); + } + } + } + + let message = Request::ListValidNamePrefixes; + + if let Err(err) = server_connection.send(message).await { + server_connection.close().await.ok(); + anyhow::bail!(anyhow::Error::from(err).context("Failed to communicate with server")); + } + + let result = match server_connection.next().await { + Some(Ok(Response::ListValidNamePrefixes(prefixes))) => prefixes, + response => return erroneous_server_response(response).map(|_| vec![]), + }; + + server_connection.send(Request::Exit).await?; + + let result = result + .into_iter() + .map(|prefix| prefix + "_") + .map(CompletionCandidate::new) + .collect(); + + Ok(result) +}