diff --git a/src/core/protocol/request_response.rs b/src/core/protocol/request_response.rs index 13bc013..94725cb 100644 --- a/src/core/protocol/request_response.rs +++ b/src/core/protocol/request_response.rs @@ -73,7 +73,6 @@ pub enum Response { UnlockUsers(UnlockUsersOutput), // Generic responses - OperationAborted, + Ready, Error(String), - Exit, } diff --git a/src/main.rs b/src/main.rs index 165e46e..a539de3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,10 +9,12 @@ use std::path::PathBuf; use std::os::unix::net::UnixStream as StdUnixStream; use tokio::net::UnixStream as TokioUnixStream; +use futures::StreamExt; + use crate::{ core::{ bootstrap::{bootstrap_server_connection_and_drop_privileges, drop_privs}, - protocol::create_client_to_server_message_stream, + protocol::{create_client_to_server_message_stream, Response}, }, server::command::ServerArgs, }; @@ -205,7 +207,20 @@ fn tokio_run_command(command: Command, server_connection: StdUnixStream) -> anyh .unwrap() .block_on(async { let tokio_socket = TokioUnixStream::from_std(server_connection)?; - let message_stream = create_client_to_server_message_stream(tokio_socket); + let mut message_stream = create_client_to_server_message_stream(tokio_socket); + + while let Some(Ok(message)) = message_stream.next().await { + match message { + Response::Error(err) => { + anyhow::bail!("{}", err); + } + Response::Ready => break, + message => { + eprintln!("Unexpected message from server: {:?}", message); + } + } + } + match command { Command::User(user_args) => { cli::user_command::handle_command(user_args, message_stream).await diff --git a/src/server/server_loop.rs b/src/server/server_loop.rs index f768568..1fe6f30 100644 --- a/src/server/server_loop.rs +++ b/src/server/server_loop.rs @@ -1,7 +1,7 @@ use std::{collections::BTreeSet, fs, path::PathBuf}; use futures_util::{SinkExt, StreamExt}; -use tokio::io::AsyncWriteExt; +use indoc::concatdoc; use tokio::net::{UnixListener, UnixStream}; use sqlx::prelude::*; @@ -57,15 +57,43 @@ pub async fn listen_for_incoming_connections( sd_notify::notify(true, &[sd_notify::NotifyState::Ready]).ok(); - while let Ok((mut conn, _addr)) = listener.accept().await { - let uid = conn.peer_cred()?.uid(); + while let Ok((conn, _addr)) = listener.accept().await { + let uid = match conn.peer_cred() { + Ok(cred) => cred.uid(), + Err(e) => { + log::error!("Failed to get peer credentials from socket: {}", e); + let mut message_stream = create_server_to_client_message_stream(conn); + message_stream + .send(Response::Error( + (concatdoc! { + "Server failed to get peer credentials from socket\n", + "Please check the server logs or contact the system administrators" + }) + .to_string(), + )) + .await + .ok(); + continue; + } + }; + log::trace!("Accepted connection from uid {}", uid); let unix_user = match UnixUser::from_uid(uid) { Ok(user) => user, Err(e) => { - eprintln!("Failed to get UnixUser from uid: {}", e); - conn.shutdown().await?; + log::error!("Failed to get username from uid: {}", e); + let mut message_stream = create_server_to_client_message_stream(conn); + message_stream + .send(Response::Error( + (concatdoc! { + "Server failed to get user data from the system\n", + "Please check the server logs or contact the system administrators" + }) + .to_string(), + )) + .await + .ok(); continue; } }; @@ -73,9 +101,9 @@ pub async fn listen_for_incoming_connections( log::info!("Accepted connection from {}", unix_user.username); match handle_requests_for_single_session(conn, &unix_user, &config).await { - Ok(_) => {} + Ok(()) => {} Err(e) => { - eprintln!("Failed to run server: {}", e); + log::error!("Failed to run server: {}", e); } } } @@ -88,8 +116,24 @@ pub async fn handle_requests_for_single_session( unix_user: &UnixUser, config: &ServerConfig, ) -> anyhow::Result<()> { - let message_stream = create_server_to_client_message_stream(socket); - let mut db_connection = create_mysql_connection_from_config(&config.mysql).await?; + let mut message_stream = create_server_to_client_message_stream(socket); + let mut db_connection = match create_mysql_connection_from_config(&config.mysql).await { + Ok(connection) => connection, + Err(err) => { + message_stream + .send(Response::Error( + (concatdoc! { + "Server failed to connect to database\n", + "Please check the server logs or contact the system administrators" + }) + .to_string(), + )) + .await?; + message_stream.flush().await?; + return Err(err); + } + }; + log::debug!("Successfully connected to database"); let result = handle_requests_for_single_session_with_db_connection( @@ -100,9 +144,9 @@ pub async fn handle_requests_for_single_session( .await; if let Err(e) = db_connection.close().await { - eprintln!("Failed to close database connection: {}", e); - eprintln!("{}", e); - eprintln!("Ignoring..."); + log::error!("Failed to close database connection: {}", e); + log::error!("{}", e); + log::error!("Ignoring..."); } result @@ -116,6 +160,7 @@ pub async fn handle_requests_for_single_session_with_db_connection( unix_user: &UnixUser, db_connection: &mut MySqlConnection, ) -> anyhow::Result<()> { + stream.send(Response::Ready).await?; loop { // TODO: better error handling let request = match stream.next().await { @@ -133,17 +178,14 @@ pub async fn handle_requests_for_single_session_with_db_connection( Request::CreateDatabases(databases_names) => { let result = create_databases(databases_names, unix_user, db_connection).await; stream.send(Response::CreateDatabases(result)).await?; - stream.flush().await?; } Request::DropDatabases(databases_names) => { let result = drop_databases(databases_names, unix_user, db_connection).await; stream.send(Response::DropDatabases(result)).await?; - stream.flush().await?; } Request::ListDatabases => { let result = list_databases_for_user(unix_user, db_connection).await; stream.send(Response::ListAllDatabases(result)).await?; - stream.flush().await?; } Request::ListPrivileges(database_names) => { let response = match database_names { @@ -161,7 +203,6 @@ pub async fn handle_requests_for_single_session_with_db_connection( }; stream.send(response).await?; - stream.flush().await?; } Request::ModifyPrivileges(database_privilege_diffs) => { let result = apply_privilege_diffs( @@ -171,24 +212,20 @@ pub async fn handle_requests_for_single_session_with_db_connection( ) .await; stream.send(Response::ModifyPrivileges(result)).await?; - stream.flush().await?; } Request::CreateUsers(db_users) => { let result = create_database_users(db_users, unix_user, db_connection).await; stream.send(Response::CreateUsers(result)).await?; - stream.flush().await?; } Request::DropUsers(db_users) => { let result = drop_database_users(db_users, unix_user, db_connection).await; stream.send(Response::DropUsers(result)).await?; - stream.flush().await?; } Request::PasswdUser(db_user, password) => { let result = set_password_for_database_user(&db_user, &password, unix_user, db_connection) .await; stream.send(Response::PasswdUser(result)).await?; - stream.flush().await?; } Request::ListUsers(db_users) => { let response = match db_users { @@ -203,22 +240,21 @@ pub async fn handle_requests_for_single_session_with_db_connection( } }; stream.send(response).await?; - stream.flush().await?; } Request::LockUsers(db_users) => { let result = lock_database_users(db_users, unix_user, db_connection).await; stream.send(Response::LockUsers(result)).await?; - stream.flush().await?; } Request::UnlockUsers(db_users) => { let result = unlock_database_users(db_users, unix_user, db_connection).await; stream.send(Response::UnlockUsers(result)).await?; - stream.flush().await?; } Request::Exit => { break; } } + + stream.flush().await?; } Ok(())