Have server notify the client about db connection errors

This commit is contained in:
Oystein Kristoffer Tveit 2024-08-19 16:46:12 +02:00
parent 8fdfe457ac
commit 48240489a7
Signed by: oysteikt
GPG Key ID: 9F2F7D8250F35146
3 changed files with 77 additions and 27 deletions

View File

@ -73,7 +73,6 @@ pub enum Response {
UnlockUsers(UnlockUsersOutput),
// Generic responses
OperationAborted,
Ready,
Error(String),
Exit,
}

View File

@ -9,10 +9,12 @@ use std::path::PathBuf;
use std::os::unix::net::UnixStream as StdUnixStream;
use tokio::net::UnixStream as TokioUnixStream;
use futures::StreamExt;
use crate::{
core::{
bootstrap::{bootstrap_server_connection_and_drop_privileges, drop_privs},
protocol::create_client_to_server_message_stream,
protocol::{create_client_to_server_message_stream, Response},
},
server::command::ServerArgs,
};
@ -205,7 +207,20 @@ fn tokio_run_command(command: Command, server_connection: StdUnixStream) -> anyh
.unwrap()
.block_on(async {
let tokio_socket = TokioUnixStream::from_std(server_connection)?;
let message_stream = create_client_to_server_message_stream(tokio_socket);
let mut message_stream = create_client_to_server_message_stream(tokio_socket);
while let Some(Ok(message)) = message_stream.next().await {
match message {
Response::Error(err) => {
anyhow::bail!("{}", err);
}
Response::Ready => break,
message => {
eprintln!("Unexpected message from server: {:?}", message);
}
}
}
match command {
Command::User(user_args) => {
cli::user_command::handle_command(user_args, message_stream).await

View File

@ -1,7 +1,7 @@
use std::{collections::BTreeSet, fs, path::PathBuf};
use futures_util::{SinkExt, StreamExt};
use tokio::io::AsyncWriteExt;
use indoc::concatdoc;
use tokio::net::{UnixListener, UnixStream};
use sqlx::prelude::*;
@ -57,15 +57,43 @@ pub async fn listen_for_incoming_connections(
sd_notify::notify(true, &[sd_notify::NotifyState::Ready]).ok();
while let Ok((mut conn, _addr)) = listener.accept().await {
let uid = conn.peer_cred()?.uid();
while let Ok((conn, _addr)) = listener.accept().await {
let uid = match conn.peer_cred() {
Ok(cred) => cred.uid(),
Err(e) => {
log::error!("Failed to get peer credentials from socket: {}", e);
let mut message_stream = create_server_to_client_message_stream(conn);
message_stream
.send(Response::Error(
(concatdoc! {
"Server failed to get peer credentials from socket\n",
"Please check the server logs or contact the system administrators"
})
.to_string(),
))
.await
.ok();
continue;
}
};
log::trace!("Accepted connection from uid {}", uid);
let unix_user = match UnixUser::from_uid(uid) {
Ok(user) => user,
Err(e) => {
eprintln!("Failed to get UnixUser from uid: {}", e);
conn.shutdown().await?;
log::error!("Failed to get username from uid: {}", e);
let mut message_stream = create_server_to_client_message_stream(conn);
message_stream
.send(Response::Error(
(concatdoc! {
"Server failed to get user data from the system\n",
"Please check the server logs or contact the system administrators"
})
.to_string(),
))
.await
.ok();
continue;
}
};
@ -73,9 +101,9 @@ pub async fn listen_for_incoming_connections(
log::info!("Accepted connection from {}", unix_user.username);
match handle_requests_for_single_session(conn, &unix_user, &config).await {
Ok(_) => {}
Ok(()) => {}
Err(e) => {
eprintln!("Failed to run server: {}", e);
log::error!("Failed to run server: {}", e);
}
}
}
@ -88,8 +116,24 @@ pub async fn handle_requests_for_single_session(
unix_user: &UnixUser,
config: &ServerConfig,
) -> anyhow::Result<()> {
let message_stream = create_server_to_client_message_stream(socket);
let mut db_connection = create_mysql_connection_from_config(&config.mysql).await?;
let mut message_stream = create_server_to_client_message_stream(socket);
let mut db_connection = match create_mysql_connection_from_config(&config.mysql).await {
Ok(connection) => connection,
Err(err) => {
message_stream
.send(Response::Error(
(concatdoc! {
"Server failed to connect to database\n",
"Please check the server logs or contact the system administrators"
})
.to_string(),
))
.await?;
message_stream.flush().await?;
return Err(err);
}
};
log::debug!("Successfully connected to database");
let result = handle_requests_for_single_session_with_db_connection(
@ -100,9 +144,9 @@ pub async fn handle_requests_for_single_session(
.await;
if let Err(e) = db_connection.close().await {
eprintln!("Failed to close database connection: {}", e);
eprintln!("{}", e);
eprintln!("Ignoring...");
log::error!("Failed to close database connection: {}", e);
log::error!("{}", e);
log::error!("Ignoring...");
}
result
@ -116,6 +160,7 @@ pub async fn handle_requests_for_single_session_with_db_connection(
unix_user: &UnixUser,
db_connection: &mut MySqlConnection,
) -> anyhow::Result<()> {
stream.send(Response::Ready).await?;
loop {
// TODO: better error handling
let request = match stream.next().await {
@ -133,17 +178,14 @@ pub async fn handle_requests_for_single_session_with_db_connection(
Request::CreateDatabases(databases_names) => {
let result = create_databases(databases_names, unix_user, db_connection).await;
stream.send(Response::CreateDatabases(result)).await?;
stream.flush().await?;
}
Request::DropDatabases(databases_names) => {
let result = drop_databases(databases_names, unix_user, db_connection).await;
stream.send(Response::DropDatabases(result)).await?;
stream.flush().await?;
}
Request::ListDatabases => {
let result = list_databases_for_user(unix_user, db_connection).await;
stream.send(Response::ListAllDatabases(result)).await?;
stream.flush().await?;
}
Request::ListPrivileges(database_names) => {
let response = match database_names {
@ -161,7 +203,6 @@ pub async fn handle_requests_for_single_session_with_db_connection(
};
stream.send(response).await?;
stream.flush().await?;
}
Request::ModifyPrivileges(database_privilege_diffs) => {
let result = apply_privilege_diffs(
@ -171,24 +212,20 @@ pub async fn handle_requests_for_single_session_with_db_connection(
)
.await;
stream.send(Response::ModifyPrivileges(result)).await?;
stream.flush().await?;
}
Request::CreateUsers(db_users) => {
let result = create_database_users(db_users, unix_user, db_connection).await;
stream.send(Response::CreateUsers(result)).await?;
stream.flush().await?;
}
Request::DropUsers(db_users) => {
let result = drop_database_users(db_users, unix_user, db_connection).await;
stream.send(Response::DropUsers(result)).await?;
stream.flush().await?;
}
Request::PasswdUser(db_user, password) => {
let result =
set_password_for_database_user(&db_user, &password, unix_user, db_connection)
.await;
stream.send(Response::PasswdUser(result)).await?;
stream.flush().await?;
}
Request::ListUsers(db_users) => {
let response = match db_users {
@ -203,22 +240,21 @@ pub async fn handle_requests_for_single_session_with_db_connection(
}
};
stream.send(response).await?;
stream.flush().await?;
}
Request::LockUsers(db_users) => {
let result = lock_database_users(db_users, unix_user, db_connection).await;
stream.send(Response::LockUsers(result)).await?;
stream.flush().await?;
}
Request::UnlockUsers(db_users) => {
let result = unlock_database_users(db_users, unix_user, db_connection).await;
stream.send(Response::UnlockUsers(result)).await?;
stream.flush().await?;
}
Request::Exit => {
break;
}
}
stream.flush().await?;
}
Ok(())