diff --git a/src/cli.rs b/src/cli.rs index 1b29138..65e1b12 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -1,5 +1,6 @@ mod common; pub mod database_command; +pub mod other_command; pub mod user_command; #[cfg(feature = "mysql-admutils-compatibility")] diff --git a/src/cli/other_command.rs b/src/cli/other_command.rs new file mode 100644 index 0000000..a3d1656 --- /dev/null +++ b/src/cli/other_command.rs @@ -0,0 +1,59 @@ +use clap::Parser; +use futures_util::{SinkExt, StreamExt}; + +use crate::core::protocol::{ + ClientToServerMessageStream, Request, Response +}; + +use super::common::erroneous_server_response; + +#[allow(clippy::enum_variant_names)] +#[derive(Parser, Debug, Clone)] +pub enum OtherCommand { + /// Check if the tool is set up correctly, and the server is running. + #[command()] + Status(StatusArgs), +} + +#[derive(Parser, Debug, Clone)] +pub struct StatusArgs { + /// Print the information as JSON + #[arg(short, long)] + json: bool, +} + +pub async fn handle_command( + command: OtherCommand, + server_connection: ClientToServerMessageStream, +) -> anyhow::Result<()> { + match command { + OtherCommand::Status(args) => status(args, server_connection).await, + } +} + +/// TODO: this should be moved all the way out to the main function, so that +/// we can teste the server connection before it fails to be established. +async fn status( + args: StatusArgs, + mut server_connection: ClientToServerMessageStream, +) -> anyhow::Result<()> { + if let Err(err) = server_connection.send(Request::Ping).await { + server_connection.close().await.ok(); + anyhow::bail!(err); + } + + match server_connection.next().await { + Some(Ok(Response::Pong)) => (), + response => return erroneous_server_response(response), + }; + + server_connection.send(Request::Exit).await?; + + if args.json { + // print_drop_users_output_status_json(&result); + } else { + // print_drop_users_output_status(&result); + } + + Ok(()) +} diff --git a/src/core/protocol/request_response.rs b/src/core/protocol/request_response.rs index dcda57d..7e04e5a 100644 --- a/src/core/protocol/request_response.rs +++ b/src/core/protocol/request_response.rs @@ -125,6 +125,8 @@ impl From for MySQLDatabase { #[non_exhaustive] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum Request { + Ping, + CreateDatabases(Vec), DropDatabases(Vec), ListDatabases(Option>), @@ -147,6 +149,8 @@ pub enum Request { #[non_exhaustive] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum Response { + Pong, + // Specific data for specific commands CreateDatabases(CreateDatabasesOutput), DropDatabases(DropDatabasesOutput), diff --git a/src/main.rs b/src/main.rs index 4b7fabd..4368cf0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -82,6 +82,9 @@ enum Command { #[command(flatten)] User(cli::user_command::UserCommand), + #[command(flatten)] + Other(cli::other_command::OtherCommand), + #[command(hide = true)] Server(server::command::ServerArgs), @@ -247,6 +250,9 @@ fn tokio_run_command(command: Command, server_connection: StdUnixStream) -> anyh Command::Db(db_args) => { cli::database_command::handle_command(db_args, message_stream).await } + Command::Other(other_args) => { + cli::other_command::handle_command(other_args, message_stream).await + } Command::Server(_) => unreachable!(), Command::GenerateCompletions(_) => unreachable!(), } diff --git a/src/server/server_loop.rs b/src/server/server_loop.rs index 5966f30..801402e 100644 --- a/src/server/server_loop.rs +++ b/src/server/server_loop.rs @@ -241,6 +241,8 @@ async fn handle_requests_for_single_session_with_db_connection( log::info!("Received request: {:#?}", request_to_display); let response = match request { + Request::Ping => Response::Pong, + Request::CreateDatabases(databases_names) => { let result = create_databases(databases_names, unix_user, db_connection).await; Response::CreateDatabases(result)