Have server notify the client about db connection errors

This commit is contained in:
Oystein Kristoffer Tveit 2024-08-19 16:46:12 +02:00
parent 8fdfe457ac
commit 48240489a7
Signed by: oysteikt
GPG Key ID: 9F2F7D8250F35146
3 changed files with 77 additions and 27 deletions

View File

@ -73,7 +73,6 @@ pub enum Response {
UnlockUsers(UnlockUsersOutput), UnlockUsers(UnlockUsersOutput),
// Generic responses // Generic responses
OperationAborted, Ready,
Error(String), Error(String),
Exit,
} }

View File

@ -9,10 +9,12 @@ use std::path::PathBuf;
use std::os::unix::net::UnixStream as StdUnixStream; use std::os::unix::net::UnixStream as StdUnixStream;
use tokio::net::UnixStream as TokioUnixStream; use tokio::net::UnixStream as TokioUnixStream;
use futures::StreamExt;
use crate::{ use crate::{
core::{ core::{
bootstrap::{bootstrap_server_connection_and_drop_privileges, drop_privs}, bootstrap::{bootstrap_server_connection_and_drop_privileges, drop_privs},
protocol::create_client_to_server_message_stream, protocol::{create_client_to_server_message_stream, Response},
}, },
server::command::ServerArgs, server::command::ServerArgs,
}; };
@ -205,7 +207,20 @@ fn tokio_run_command(command: Command, server_connection: StdUnixStream) -> anyh
.unwrap() .unwrap()
.block_on(async { .block_on(async {
let tokio_socket = TokioUnixStream::from_std(server_connection)?; let tokio_socket = TokioUnixStream::from_std(server_connection)?;
let message_stream = create_client_to_server_message_stream(tokio_socket); let mut message_stream = create_client_to_server_message_stream(tokio_socket);
while let Some(Ok(message)) = message_stream.next().await {
match message {
Response::Error(err) => {
anyhow::bail!("{}", err);
}
Response::Ready => break,
message => {
eprintln!("Unexpected message from server: {:?}", message);
}
}
}
match command { match command {
Command::User(user_args) => { Command::User(user_args) => {
cli::user_command::handle_command(user_args, message_stream).await cli::user_command::handle_command(user_args, message_stream).await

View File

@ -1,7 +1,7 @@
use std::{collections::BTreeSet, fs, path::PathBuf}; use std::{collections::BTreeSet, fs, path::PathBuf};
use futures_util::{SinkExt, StreamExt}; use futures_util::{SinkExt, StreamExt};
use tokio::io::AsyncWriteExt; use indoc::concatdoc;
use tokio::net::{UnixListener, UnixStream}; use tokio::net::{UnixListener, UnixStream};
use sqlx::prelude::*; use sqlx::prelude::*;
@ -57,15 +57,43 @@ pub async fn listen_for_incoming_connections(
sd_notify::notify(true, &[sd_notify::NotifyState::Ready]).ok(); sd_notify::notify(true, &[sd_notify::NotifyState::Ready]).ok();
while let Ok((mut conn, _addr)) = listener.accept().await { while let Ok((conn, _addr)) = listener.accept().await {
let uid = conn.peer_cred()?.uid(); let uid = match conn.peer_cred() {
Ok(cred) => cred.uid(),
Err(e) => {
log::error!("Failed to get peer credentials from socket: {}", e);
let mut message_stream = create_server_to_client_message_stream(conn);
message_stream
.send(Response::Error(
(concatdoc! {
"Server failed to get peer credentials from socket\n",
"Please check the server logs or contact the system administrators"
})
.to_string(),
))
.await
.ok();
continue;
}
};
log::trace!("Accepted connection from uid {}", uid); log::trace!("Accepted connection from uid {}", uid);
let unix_user = match UnixUser::from_uid(uid) { let unix_user = match UnixUser::from_uid(uid) {
Ok(user) => user, Ok(user) => user,
Err(e) => { Err(e) => {
eprintln!("Failed to get UnixUser from uid: {}", e); log::error!("Failed to get username from uid: {}", e);
conn.shutdown().await?; let mut message_stream = create_server_to_client_message_stream(conn);
message_stream
.send(Response::Error(
(concatdoc! {
"Server failed to get user data from the system\n",
"Please check the server logs or contact the system administrators"
})
.to_string(),
))
.await
.ok();
continue; continue;
} }
}; };
@ -73,9 +101,9 @@ pub async fn listen_for_incoming_connections(
log::info!("Accepted connection from {}", unix_user.username); log::info!("Accepted connection from {}", unix_user.username);
match handle_requests_for_single_session(conn, &unix_user, &config).await { match handle_requests_for_single_session(conn, &unix_user, &config).await {
Ok(_) => {} Ok(()) => {}
Err(e) => { Err(e) => {
eprintln!("Failed to run server: {}", e); log::error!("Failed to run server: {}", e);
} }
} }
} }
@ -88,8 +116,24 @@ pub async fn handle_requests_for_single_session(
unix_user: &UnixUser, unix_user: &UnixUser,
config: &ServerConfig, config: &ServerConfig,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let message_stream = create_server_to_client_message_stream(socket); let mut message_stream = create_server_to_client_message_stream(socket);
let mut db_connection = create_mysql_connection_from_config(&config.mysql).await?; let mut db_connection = match create_mysql_connection_from_config(&config.mysql).await {
Ok(connection) => connection,
Err(err) => {
message_stream
.send(Response::Error(
(concatdoc! {
"Server failed to connect to database\n",
"Please check the server logs or contact the system administrators"
})
.to_string(),
))
.await?;
message_stream.flush().await?;
return Err(err);
}
};
log::debug!("Successfully connected to database"); log::debug!("Successfully connected to database");
let result = handle_requests_for_single_session_with_db_connection( let result = handle_requests_for_single_session_with_db_connection(
@ -100,9 +144,9 @@ pub async fn handle_requests_for_single_session(
.await; .await;
if let Err(e) = db_connection.close().await { if let Err(e) = db_connection.close().await {
eprintln!("Failed to close database connection: {}", e); log::error!("Failed to close database connection: {}", e);
eprintln!("{}", e); log::error!("{}", e);
eprintln!("Ignoring..."); log::error!("Ignoring...");
} }
result result
@ -116,6 +160,7 @@ pub async fn handle_requests_for_single_session_with_db_connection(
unix_user: &UnixUser, unix_user: &UnixUser,
db_connection: &mut MySqlConnection, db_connection: &mut MySqlConnection,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
stream.send(Response::Ready).await?;
loop { loop {
// TODO: better error handling // TODO: better error handling
let request = match stream.next().await { let request = match stream.next().await {
@ -133,17 +178,14 @@ pub async fn handle_requests_for_single_session_with_db_connection(
Request::CreateDatabases(databases_names) => { Request::CreateDatabases(databases_names) => {
let result = create_databases(databases_names, unix_user, db_connection).await; let result = create_databases(databases_names, unix_user, db_connection).await;
stream.send(Response::CreateDatabases(result)).await?; stream.send(Response::CreateDatabases(result)).await?;
stream.flush().await?;
} }
Request::DropDatabases(databases_names) => { Request::DropDatabases(databases_names) => {
let result = drop_databases(databases_names, unix_user, db_connection).await; let result = drop_databases(databases_names, unix_user, db_connection).await;
stream.send(Response::DropDatabases(result)).await?; stream.send(Response::DropDatabases(result)).await?;
stream.flush().await?;
} }
Request::ListDatabases => { Request::ListDatabases => {
let result = list_databases_for_user(unix_user, db_connection).await; let result = list_databases_for_user(unix_user, db_connection).await;
stream.send(Response::ListAllDatabases(result)).await?; stream.send(Response::ListAllDatabases(result)).await?;
stream.flush().await?;
} }
Request::ListPrivileges(database_names) => { Request::ListPrivileges(database_names) => {
let response = match database_names { let response = match database_names {
@ -161,7 +203,6 @@ pub async fn handle_requests_for_single_session_with_db_connection(
}; };
stream.send(response).await?; stream.send(response).await?;
stream.flush().await?;
} }
Request::ModifyPrivileges(database_privilege_diffs) => { Request::ModifyPrivileges(database_privilege_diffs) => {
let result = apply_privilege_diffs( let result = apply_privilege_diffs(
@ -171,24 +212,20 @@ pub async fn handle_requests_for_single_session_with_db_connection(
) )
.await; .await;
stream.send(Response::ModifyPrivileges(result)).await?; stream.send(Response::ModifyPrivileges(result)).await?;
stream.flush().await?;
} }
Request::CreateUsers(db_users) => { Request::CreateUsers(db_users) => {
let result = create_database_users(db_users, unix_user, db_connection).await; let result = create_database_users(db_users, unix_user, db_connection).await;
stream.send(Response::CreateUsers(result)).await?; stream.send(Response::CreateUsers(result)).await?;
stream.flush().await?;
} }
Request::DropUsers(db_users) => { Request::DropUsers(db_users) => {
let result = drop_database_users(db_users, unix_user, db_connection).await; let result = drop_database_users(db_users, unix_user, db_connection).await;
stream.send(Response::DropUsers(result)).await?; stream.send(Response::DropUsers(result)).await?;
stream.flush().await?;
} }
Request::PasswdUser(db_user, password) => { Request::PasswdUser(db_user, password) => {
let result = let result =
set_password_for_database_user(&db_user, &password, unix_user, db_connection) set_password_for_database_user(&db_user, &password, unix_user, db_connection)
.await; .await;
stream.send(Response::PasswdUser(result)).await?; stream.send(Response::PasswdUser(result)).await?;
stream.flush().await?;
} }
Request::ListUsers(db_users) => { Request::ListUsers(db_users) => {
let response = match db_users { let response = match db_users {
@ -203,22 +240,21 @@ pub async fn handle_requests_for_single_session_with_db_connection(
} }
}; };
stream.send(response).await?; stream.send(response).await?;
stream.flush().await?;
} }
Request::LockUsers(db_users) => { Request::LockUsers(db_users) => {
let result = lock_database_users(db_users, unix_user, db_connection).await; let result = lock_database_users(db_users, unix_user, db_connection).await;
stream.send(Response::LockUsers(result)).await?; stream.send(Response::LockUsers(result)).await?;
stream.flush().await?;
} }
Request::UnlockUsers(db_users) => { Request::UnlockUsers(db_users) => {
let result = unlock_database_users(db_users, unix_user, db_connection).await; let result = unlock_database_users(db_users, unix_user, db_connection).await;
stream.send(Response::UnlockUsers(result)).await?; stream.send(Response::UnlockUsers(result)).await?;
stream.flush().await?;
} }
Request::Exit => { Request::Exit => {
break; break;
} }
} }
stream.flush().await?;
} }
Ok(()) Ok(())