From 807017ea70327e01d9d4f73caf95a7f8cf7437f4 Mon Sep 17 00:00:00 2001
From: h7x4 <h7x4@nani.wtf>
Date: Mon, 19 Aug 2024 02:22:18 +0200
Subject: [PATCH] add shell completion

---
 Cargo.lock                                    |  10 ++
 Cargo.toml                                    |   1 +
 nix/default.nix                               |  15 ++-
 .../mysql_dbadm.rs                            |  19 ++-
 .../mysql_useradm.rs                          |  18 ++-
 src/main.rs                                   | 121 ++++++++++++++----
 6 files changed, 143 insertions(+), 41 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index dee32b3..4e921e0 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -265,6 +265,15 @@ dependencies = [
  "strsim",
 ]
 
+[[package]]
+name = "clap_complete"
+version = "4.5.18"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1ee158892bd7ce77aa15c208abbdb73e155d191c287a659b57abd5adb92feb03"
+dependencies = [
+ "clap",
+]
+
 [[package]]
 name = "clap_derive"
 version = "4.5.13"
@@ -1055,6 +1064,7 @@ dependencies = [
  "async-bincode",
  "bincode",
  "clap",
+ "clap_complete",
  "derive_more",
  "dialoguer",
  "env_logger",
diff --git a/Cargo.toml b/Cargo.toml
index 34877bc..4517c66 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -8,6 +8,7 @@ anyhow = "1.0.86"
 async-bincode = "0.7.2"
 bincode = "1.3.3"
 clap = { version = "4.5.16", features = ["derive"] }
+clap_complete = "4.5.18"
 derive_more = { version = "1.0.0", features = ["display", "error"] }
 dialoguer = "0.11.0"
 env_logger = "0.11.5"
diff --git a/nix/default.nix b/nix/default.nix
index ccc2cfe..7447158 100644
--- a/nix/default.nix
+++ b/nix/default.nix
@@ -4,8 +4,10 @@
 , cargoToml
 , cargoLock
 , src
+, installShellFiles
 }:
 let
+  mainProgram = (lib.head cargoToml.bin).name;
 in
 rustPlatform.buildRustPackage {
   pname = cargoToml.package.name;
@@ -14,9 +16,20 @@ rustPlatform.buildRustPackage {
 
   cargoLock.lockFile = cargoLock;
 
+  nativeBuildInputs = [ installShellFiles ];
+  postInstall = let
+    commands = lib.mapCartesianProduct ({ shell, command }: ''
+      "$out/bin/${mainProgram}" generate-completions --shell "${shell}" --command "${command}" > "$TMP/mysqladm.${shell}"
+      installShellCompletion "--${shell}" --cmd "${command}" "$TMP/mysqladm.${shell}"
+    '') {
+      shell = [ "bash" "zsh" "fish" ];
+      command = [ "mysqladm" "mysql-dbadm" "mysql-useradm" ];
+    };
+  in lib.concatStringsSep "\n" commands;
+
   meta = with lib; {
     license = licenses.mit;
     platforms = platforms.linux ++ platforms.darwin;
-    mainProgram = (lib.head cargoToml.bin).name;
+    inherit mainProgram;
   };
 }
diff --git a/src/cli/mysql_admutils_compatibility/mysql_dbadm.rs b/src/cli/mysql_admutils_compatibility/mysql_dbadm.rs
index 63d6d3a..1e05a2c 100644
--- a/src/cli/mysql_admutils_compatibility/mysql_dbadm.rs
+++ b/src/cli/mysql_admutils_compatibility/mysql_dbadm.rs
@@ -49,7 +49,19 @@ The Y/N-values corresponds to the following mysql privileges:
   References - Enables use of REFERENCES
 "#;
 
+/// Create, drop or edit permissions for the DATABASE(s),
+/// as determined by the COMMAND.
+///
+/// This is a compatibility layer for the mysql-dbadm command.
+/// Please consider using the newer mysqladm command instead.
 #[derive(Parser)]
+#[command(
+  bin_name = "mysql-dbadm",
+  version,
+  about,
+  disable_help_subcommand = true,
+  verbatim_doc_comment,
+)]
 pub struct Args {
     #[command(subcommand)]
     pub command: Option<Command>,
@@ -82,14 +94,7 @@ pub struct Args {
 // NOTE: mysql-dbadm explicitly calls privileges "permissions".
 //       This is something we're trying to move away from.
 //       See https://git.pvv.ntnu.no/Projects/mysqladm-rs/issues/29
-
-/// Create, drop or edit permissions for the DATABASE(s),
-/// as determined by the COMMAND.
-///
-/// This is a compatibility layer for the mysql-dbadm command.
-/// Please consider using the newer mysqladm command instead.
 #[derive(Parser)]
-#[command(version, about, disable_help_subcommand = true, verbatim_doc_comment)]
 pub enum Command {
     /// create the DATABASE(s).
     Create(CreateArgs),
diff --git a/src/cli/mysql_admutils_compatibility/mysql_useradm.rs b/src/cli/mysql_admutils_compatibility/mysql_useradm.rs
index 5f17494..5de17f0 100644
--- a/src/cli/mysql_admutils_compatibility/mysql_useradm.rs
+++ b/src/cli/mysql_admutils_compatibility/mysql_useradm.rs
@@ -25,7 +25,19 @@ use crate::{
     server::sql::user_operations::DatabaseUser,
 };
 
+/// Create, delete or change password for the USER(s),
+/// as determined by the COMMAND.
+///
+/// This is a compatibility layer for the mysql-useradm command.
+/// Please consider using the newer mysqladm command instead.
 #[derive(Parser)]
+#[command(
+  bin_name = "mysql-useradm",
+  version,
+  about,
+  disable_help_subcommand = true,
+  verbatim_doc_comment,
+)]
 pub struct Args {
     #[command(subcommand)]
     pub command: Option<Command>,
@@ -51,13 +63,7 @@ pub struct Args {
     config: Option<PathBuf>,
 }
 
-/// Create, delete or change password for the USER(s),
-/// as determined by the COMMAND.
-///
-/// This is a compatibility layer for the mysql-useradm command.
-/// Please consider using the newer mysqladm command instead.
 #[derive(Parser)]
-#[command(version, about, disable_help_subcommand = true, verbatim_doc_comment)]
 pub enum Command {
     /// create the USER(s).
     Create(CreateArgs),
diff --git a/src/main.rs b/src/main.rs
index 0ca5515..165e46e 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,7 +1,8 @@
 #[macro_use]
 extern crate prettytable;
 
-use clap::Parser;
+use clap::{CommandFactory, Parser, ValueEnum};
+use clap_complete::{generate, Shell};
 
 use std::path::PathBuf;
 
@@ -27,7 +28,14 @@ mod core;
 #[cfg(feature = "tui")]
 mod tui;
 
+/// Database administration tool for non-admin users to manage their own MySQL databases and users.
+///
+/// This tool allows you to manage users and databases in MySQL.
+///
+/// You are only allowed to manage databases and users that are prefixed with
+/// either your username, or a group that you are a member of.
 #[derive(Parser, Debug)]
+#[command(bin_name = "mysqladm", version, about, disable_help_subcommand = true)]
 struct Args {
     #[command(subcommand)]
     command: Command,
@@ -57,14 +65,7 @@ struct Args {
     interactive: bool,
 }
 
-// Database administration tool for non-admin users to manage their own MySQL databases and users.
-//
-// This tool allows you to manage users and databases in MySQL.
-//
-// You are only allowed to manage databases and users that are prefixed with
-// either your username, or a group that you are a member of.
 #[derive(Parser, Debug, Clone)]
-#[command(version, about, disable_help_subcommand = true)]
 enum Command {
     #[command(flatten)]
     Db(cli::database_command::DatabaseCommand),
@@ -74,6 +75,26 @@ enum Command {
 
     #[command(hide = true)]
     Server(server::command::ServerArgs),
+
+    #[command(hide = true)]
+    GenerateCompletions(GenerateCompletionArgs),
+}
+
+#[derive(Parser, Debug, Clone)]
+struct GenerateCompletionArgs {
+    #[arg(long, default_value = "bash")]
+    shell: Shell,
+
+    #[arg(long, default_value = "mysqladm")]
+    command: ToplevelCommands,
+}
+
+#[cfg(feature = "mysql-admutils-compatibility")]
+#[derive(ValueEnum, Debug, Clone)]
+enum ToplevelCommands {
+    Mysqladm,
+    MysqlDbadm,
+    MysqlUseradm,
 }
 
 // TODO: tag all functions that are run with elevated privileges with
@@ -86,28 +107,18 @@ fn main() -> anyhow::Result<()> {
     env_logger::init();
 
     #[cfg(feature = "mysql-admutils-compatibility")]
-    {
-        let argv0 = std::env::args().next().and_then(|s| {
-            PathBuf::from(s)
-                .file_name()
-                .map(|s| s.to_string_lossy().to_string())
-        });
-
-        match argv0.as_deref() {
-            Some("mysql-dbadm") => return mysql_dbadm::main(),
-            Some("mysql-useradm") => return mysql_useradm::main(),
-            _ => { /* fall through */ }
-        }
+    if let Some(_) = handle_mysql_admutils_command()? {
+        return Ok(());
     }
 
     let args: Args = Args::parse();
-    match args.command {
-        Command::Server(ref command) => {
-            drop_privs()?;
-            tokio_start_server(args.server_socket_path, args.config, command.clone())?;
-            return Ok(());
-        }
-        _ => { /* fall through */ }
+
+    if let Some(_) = handle_server_command(&args)? {
+        return Ok(());
+    }
+
+    if let Some(_) = handle_generate_completions_command(&args)? {
+        return Ok(());
     }
 
     let server_connection =
@@ -118,6 +129,61 @@ fn main() -> anyhow::Result<()> {
     Ok(())
 }
 
+fn handle_mysql_admutils_command() -> anyhow::Result<Option<()>> {
+    let argv0 = std::env::args().next().and_then(|s| {
+        PathBuf::from(s)
+            .file_name()
+            .map(|s| s.to_string_lossy().to_string())
+    });
+
+    match argv0.as_deref() {
+        Some("mysql-dbadm") => mysql_dbadm::main().map(|result| Some(result)),
+        Some("mysql-useradm") => mysql_useradm::main().map(|result| Some(result)),
+        _ => Ok(None),
+    }
+}
+
+fn handle_server_command(args: &Args) -> anyhow::Result<Option<()>> {
+    match args.command {
+        Command::Server(ref command) => {
+            drop_privs()?;
+            tokio_start_server(
+                args.server_socket_path.clone(),
+                args.config.clone(),
+                command.clone(),
+            )?;
+            Ok(Some(()))
+        }
+        _ => Ok(None),
+    }
+}
+
+fn handle_generate_completions_command(args: &Args) -> anyhow::Result<Option<()>> {
+    match args.command {
+        Command::GenerateCompletions(ref completion_args) => {
+            let mut cmd = match completion_args.command {
+                ToplevelCommands::Mysqladm => Args::command(),
+                #[cfg(feature = "mysql-admutils-compatibility")]
+                ToplevelCommands::MysqlDbadm => mysql_dbadm::Args::command(),
+                #[cfg(feature = "mysql-admutils-compatibility")]
+                ToplevelCommands::MysqlUseradm => mysql_useradm::Args::command(),
+            };
+
+            let binary_name = cmd.get_bin_name().unwrap().to_owned();
+
+            generate(
+                completion_args.shell,
+                &mut cmd,
+                binary_name,
+                &mut std::io::stdout(),
+            );
+
+            Ok(Some(()))
+        }
+        _ => Ok(None),
+    }
+}
+
 fn tokio_start_server(
     server_socket_path: Option<PathBuf>,
     config_path: Option<PathBuf>,
@@ -148,6 +214,7 @@ fn tokio_run_command(command: Command, server_connection: StdUnixStream) -> anyh
                     cli::database_command::handle_command(db_args, message_stream).await
                 }
                 Command::Server(_) => unreachable!(),
+                Command::GenerateCompletions(_) => unreachable!(),
             }
         })
 }