diff --git a/src/core/bootstrap.rs b/src/core/bootstrap.rs index 1d21f97..df90160 100644 --- a/src/core/bootstrap.rs +++ b/src/core/bootstrap.rs @@ -253,7 +253,7 @@ fn run_forked_server( .block_on(async { let socket = TokioUnixStream::from_std(server_socket)?; let db_pool = construct_single_connection_mysql_pool(&config.mysql).await?; - session_handler::session_handler(socket, &unix_user, db_pool).await?; + session_handler::session_handler_with_unix_user(socket, &unix_user, db_pool).await?; Ok(()) }); diff --git a/src/server/session_handler.rs b/src/server/session_handler.rs index 24fd005..1121067 100644 --- a/src/server/session_handler.rs +++ b/src/server/session_handler.rs @@ -2,7 +2,7 @@ use std::collections::BTreeSet; use futures_util::{SinkExt, StreamExt}; use indoc::concatdoc; -use sqlx::{MySql, MySqlConnection, MySqlPool, pool::PoolConnection}; +use sqlx::{MySqlConnection, MySqlPool}; use tokio::net::UnixStream; use crate::{ @@ -30,15 +30,58 @@ use crate::{ // TODO: don't use database connection unless necessary. -pub async fn session_handler( +pub async fn session_handler(socket: UnixStream, db_pool: MySqlPool) -> anyhow::Result<()> { + let uid = match socket.peer_cred() { + Ok(cred) => cred.uid(), + Err(e) => { + log::error!("Failed to get peer credentials from socket: {}", e); + let mut message_stream = create_server_to_client_message_stream(socket); + message_stream + .send(Response::Error( + (concatdoc! { + "Server failed to get peer credentials from socket\n", + "Please check the server logs or contact the system administrators" + }) + .to_string(), + )) + .await + .ok(); + anyhow::bail!("Failed to get peer credentials from socket"); + } + }; + + log::debug!("Validated peer UID: {}", uid); + + let unix_user = match UnixUser::from_uid(uid) { + Ok(user) => user, + Err(e) => { + log::error!("Failed to get username from uid: {}", e); + let mut message_stream = create_server_to_client_message_stream(socket); + message_stream + .send(Response::Error( + (concatdoc! { + "Server failed to get user data from the system\n", + "Please check the server logs or contact the system administrators" + }) + .to_string(), + )) + .await + .ok(); + anyhow::bail!("Failed to get username from uid: {}", e); + } + }; + + session_handler_with_unix_user(socket, &unix_user, db_pool).await +} + +pub async fn session_handler_with_unix_user( socket: UnixStream, unix_user: &UnixUser, db_pool: MySqlPool, ) -> anyhow::Result<()> { let mut message_stream = create_server_to_client_message_stream(socket); - log::debug!("Opening connection to database"); - + log::debug!("Requesting database connection from pool"); let mut db_connection = match db_pool.acquire().await { Ok(connection) => connection, Err(err) => { @@ -55,13 +98,12 @@ pub async fn session_handler( return Err(err.into()); } }; - - log::debug!("Successfully connected to database"); + log::debug!("Successfully acquired database connection from pool"); let result = session_handler_with_db_connection(message_stream, unix_user, &mut db_connection).await; - close_or_ignore_db_connection(db_connection).await; + log::debug!("Releasing database connection back to pool"); result } @@ -192,11 +234,3 @@ async fn session_handler_with_db_connection( Ok(()) } - -async fn close_or_ignore_db_connection(db_connection: PoolConnection) { - if let Err(e) = db_connection.close().await { - log::error!("Failed to close database connection: {}", e); - log::error!("{}", e); - log::error!("Ignoring..."); - } -} diff --git a/src/server/supervisor.rs b/src/server/supervisor.rs index 76b17c5..16555dd 100644 --- a/src/server/supervisor.rs +++ b/src/server/supervisor.rs @@ -7,22 +7,14 @@ use std::{ }; use anyhow::{Context, anyhow}; -use futures_util::SinkExt; -use indoc::concatdoc; use sqlx::MySqlPool; use tokio::{net::UnixListener as TokioUnixListener, task::JoinHandle, time::interval}; use tokio_util::task::TaskTracker; // use tokio_util::sync::CancellationToken; -use crate::{ - core::{ - common::UnixUser, - protocol::{Response, create_server_to_client_message_stream}, - }, - server::{ - config::{MysqlConfig, ServerConfig}, - session_handler::session_handler, - }, +use crate::server::{ + config::{MysqlConfig, ServerConfig}, + session_handler::session_handler, }; // TODO: implement graceful shutdown and graceful restarts @@ -262,56 +254,17 @@ async fn spawn_listener_task( while let Ok((conn, _addr)) = listener.accept().await { log::debug!("Got new connection"); - let uid = match conn.peer_cred() { - Ok(cred) => cred.uid(), - Err(e) => { - log::error!("Failed to get peer credentials from socket: {}", e); - let mut message_stream = create_server_to_client_message_stream(conn); - message_stream - .send(Response::Error( - (concatdoc! { - "Server failed to get peer credentials from socket\n", - "Please check the server logs or contact the system administrators" - }) - .to_string(), - )) - .await - .ok(); - continue; - } - }; - - log::debug!("Validated peer UID: {}", uid); - + let db_pool_clone = db_pool.clone(); let _connection_counter_guard = Arc::clone(&connection_counter); - - let unix_user = match UnixUser::from_uid(uid) { - Ok(user) => user, - Err(e) => { - log::error!("Failed to get username from uid: {}", e); - let mut message_stream = create_server_to_client_message_stream(conn); - message_stream - .send(Response::Error( - (concatdoc! { - "Server failed to get user data from the system\n", - "Please check the server logs or contact the system administrators" - }) - .to_string(), - )) - .await - .ok(); - continue; + tokio::spawn(async { + let _guard = _connection_counter_guard; + match session_handler(conn, db_pool_clone).await { + Ok(()) => {} + Err(e) => { + log::error!("Failed to run server: {}", e); + } } - }; - - log::info!("Accepted connection from UNIX user: {}", unix_user.username); - - match session_handler(conn, &unix_user, db_pool.clone()).await { - Ok(()) => {} - Err(e) => { - log::error!("Failed to run server: {}", e); - } - } + }); } Ok(())