diff --git a/src/core/bootstrap.rs b/src/core/bootstrap.rs index 4286987..7f5ed7f 100644 --- a/src/core/bootstrap.rs +++ b/src/core/bootstrap.rs @@ -22,7 +22,7 @@ use crate::{ authorization::read_and_parse_group_denylist, config::{MysqlConfig, ServerConfig}, landlock::landlock_restrict_server, - session_handler, + session_handler::{self, SessionId}, }, }; @@ -308,9 +308,11 @@ fn run_forked_server( version_row.to_lowercase().contains("mariadb") }; + let session_id = SessionId::new(0); let db_pool = Arc::new(RwLock::new(db_pool)); session_handler::session_handler_with_unix_user( socket, + session_id, unix_user, db_pool, db_is_mariadb, diff --git a/src/core/common.rs b/src/core/common.rs index f090aab..3e1a4bd 100644 --- a/src/core/common.rs +++ b/src/core/common.rs @@ -24,6 +24,7 @@ pub const KIND_REGARDS: &str = concat!( "If you experience any bugs or turbulence, please give us a heads up :)", ); +/// TODO: store and display UID #[derive(Debug, Clone)] pub struct UnixUser { pub username: String, diff --git a/src/core/protocol/commands.rs b/src/core/protocol/commands.rs index e5bb77e..35ff9a1 100644 --- a/src/core/protocol/commands.rs +++ b/src/core/protocol/commands.rs @@ -17,8 +17,6 @@ mod modify_privileges; mod passwd_user; mod unlock_users; -use std::collections::BTreeSet; - pub use check_authorization::*; pub use complete_database_name::*; pub use complete_user_name::*; @@ -38,6 +36,9 @@ pub use modify_privileges::*; pub use passwd_user::*; pub use unlock_users::*; +use std::collections::BTreeSet; +use std::fmt; + use serde::{Deserialize, Serialize}; use tokio::net::UnixStream; use tokio_serde::{Framed as SerdeFramed, formats::Bincode}; @@ -109,6 +110,7 @@ pub enum Request { } impl Request { + /// Get the command name associated with this request. pub fn command_name(&self) -> &str { match self { Request::CheckAuthorization(_) => "check-authorization", @@ -130,6 +132,43 @@ impl Request { } } + /// Generate a short summary string representing this request for logging purposes. + pub fn log_summary(&self) -> String { + match self { + Request::CheckAuthorization(req) => format!("{}({})", self.command_name(), req.len()), + + Request::CreateDatabases(req) => format!("{}({})", self.command_name(), req.len()), + Request::DropDatabases(req) => format!("{}({})", self.command_name(), req.len()), + Request::ListDatabases(req) => format!( + "{}{}", + self.command_name(), + req.as_ref() + .map_or("".to_string(), |r| format!("({})", r.len())) + ), + Request::ListPrivileges(req) => format!( + "{}{}", + self.command_name(), + req.as_ref() + .map_or("".to_string(), |r| format!("({})", r.len())) + ), + Request::ModifyPrivileges(req) => format!("{}({})", self.command_name(), req.len()), + + Request::CreateUsers(req) => format!("{}({})", self.command_name(), req.len()), + Request::DropUsers(req) => format!("{}({})", self.command_name(), req.len()), + Request::ListUsers(req) => format!( + "{}{}", + self.command_name(), + req.as_ref() + .map_or("".to_string(), |r| format!("({})", r.len())) + ), + Request::LockUsers(req) => format!("{}({})", self.command_name(), req.len()), + Request::UnlockUsers(req) => format!("{}({})", self.command_name(), req.len()), + + _ => self.command_name().to_string(), + } + } + + /// Get the set of users affected by this request. pub fn affected_users(&self) -> BTreeSet { match self { Request::CheckAuthorization(_) => Default::default(), @@ -158,6 +197,7 @@ impl Request { } } + /// Get the set of databases affected by this request. pub fn affected_databases(&self) -> BTreeSet { match self { Request::CheckAuthorization(_) => Default::default(), @@ -219,3 +259,95 @@ pub enum Response { Ready, Error(String), } + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum ResponseOkStatus { + Success, + PartialSuccess(usize, usize), // succeeded, total + Error, +} + +impl ResponseOkStatus { + pub fn from_counts(total: usize, succeeded: usize) -> Self { + if succeeded == total { + ResponseOkStatus::Success + } else if succeeded == 0 { + ResponseOkStatus::Error + } else { + ResponseOkStatus::PartialSuccess(succeeded, total) + } + } + + pub fn from_bool(is_ok: bool) -> Self { + if is_ok { + ResponseOkStatus::Success + } else { + ResponseOkStatus::Error + } + } +} + +impl fmt::Display for ResponseOkStatus { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ResponseOkStatus::Success => write!(f, "OK"), + ResponseOkStatus::PartialSuccess(succeeded, total) => { + write!(f, "PARTIAL_OK({}/{})", succeeded, total) + } + ResponseOkStatus::Error => write!(f, "ERR"), + } + } +} + +impl Response { + pub fn ok_status(&self) -> ResponseOkStatus { + match self { + Response::CheckAuthorization(res) => { + ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count()) + } + + Response::ListValidNamePrefixes(_) => ResponseOkStatus::Success, + Response::CompleteDatabaseName(_) => ResponseOkStatus::Success, + Response::CompleteUserName(_) => ResponseOkStatus::Success, + + Response::CreateDatabases(res) => { + ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count()) + } + Response::DropDatabases(res) => { + ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count()) + } + Response::ListDatabases(res) => { + ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count()) + } + Response::ListAllDatabases(res) => ResponseOkStatus::from_bool(res.is_ok()), + Response::ListPrivileges(res) => { + ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count()) + } + Response::ListAllPrivileges(res) => ResponseOkStatus::from_bool(res.is_ok()), + Response::ModifyPrivileges(res) => { + ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count()) + } + + Response::CreateUsers(res) => { + ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count()) + } + Response::DropUsers(res) => { + ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count()) + } + Response::SetUserPassword(res) => ResponseOkStatus::from_bool(res.is_ok()), + Response::ListUsers(res) => { + ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count()) + } + Response::ListAllUsers(res) => ResponseOkStatus::from_bool(res.is_ok()), + Response::LockUsers(res) => { + ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count()) + } + Response::UnlockUsers(res) => { + ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count()) + } + + Response::Ready => ResponseOkStatus::Success, + Response::Error(_) => ResponseOkStatus::Error, + } + } +} diff --git a/src/server/authorization.rs b/src/server/authorization.rs index 1433099..c791cc7 100644 --- a/src/server/authorization.rs +++ b/src/server/authorization.rs @@ -13,23 +13,19 @@ use crate::core::{ }; pub async fn check_authorization( - dbs_or_users: Vec, + dbs_or_users: &[DbOrUser], unix_user: &UnixUser, group_denylist: &GroupDenylist, ) -> std::collections::BTreeMap> { - let mut results = std::collections::BTreeMap::new(); - - for db_or_user in dbs_or_users { - if let Err(err) = validate_db_or_user_request(&db_or_user, unix_user, group_denylist) - .map_err(CheckAuthorizationError) - { - results.insert(db_or_user.clone(), Err(err)); - continue; - } - results.insert(db_or_user.clone(), Ok(())); - } - - results + dbs_or_users + .iter() + .cloned() + .map(|db_or_user| { + let result = validate_db_or_user_request(&db_or_user, unix_user, group_denylist) + .map_err(CheckAuthorizationError); + (db_or_user, result) + }) + .collect() } /// Reads and parses a group denylist file, returning a set of GUIDs diff --git a/src/server/session_handler.rs b/src/server/session_handler.rs index a31f739..d2a174e 100644 --- a/src/server/session_handler.rs +++ b/src/server/session_handler.rs @@ -1,4 +1,4 @@ -use std::{collections::BTreeSet, sync::Arc}; +use std::sync::Arc; use futures_util::{SinkExt, StreamExt}; use indoc::concatdoc; @@ -35,10 +35,24 @@ use crate::{ }, }; +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct SessionId(u64); + +impl SessionId { + pub fn new(id: u64) -> Self { + SessionId(id) + } + + pub fn inner(&self) -> u64 { + self.0 + } +} + // TODO: don't use database connection unless necessary. pub async fn session_handler( socket: UnixStream, + session_id: SessionId, db_pool: Arc>, db_is_mariadb: bool, group_denylist: &GroupDenylist, @@ -83,13 +97,18 @@ pub async fn session_handler( } }; - let span = tracing::info_span!("user_session", user = %unix_user); + let span = tracing::info_span!( + "user_session", + session_id = session_id.inner(), + user = %unix_user, + ); (async move { - tracing::info!("Accepted connection from user: {}", unix_user); + tracing::debug!("Accepted connection from user: {}", unix_user); let result = session_handler_with_unix_user( socket, + session_id, &unix_user, db_pool, db_is_mariadb, @@ -97,7 +116,7 @@ pub async fn session_handler( ) .await; - tracing::info!( + tracing::debug!( "Finished handling requests for connection from user: {}", unix_user, ); @@ -110,6 +129,7 @@ pub async fn session_handler( pub async fn session_handler_with_unix_user( socket: UnixStream, + session_id: SessionId, unix_user: &UnixUser, db_pool: Arc>, db_is_mariadb: bool, @@ -138,6 +158,7 @@ pub async fn session_handler_with_unix_user( let result = session_handler_with_db_connection( message_stream, + session_id, unix_user, &mut db_connection, db_is_mariadb, @@ -155,6 +176,7 @@ pub async fn session_handler_with_unix_user( async fn session_handler_with_db_connection( mut stream: ServerToClientMessageStream, + session_id: SessionId, unix_user: &UnixUser, db_connection: &mut MySqlConnection, db_is_mariadb: bool, @@ -178,6 +200,7 @@ async fn session_handler_with_db_connection( if !handle_request( request, + session_id, unix_user, db_connection, db_is_mariadb, @@ -199,6 +222,7 @@ async fn session_handler_with_db_connection( /// If the function returns `true`, the session should continue. async fn handle_request( request: Request, + session_id: SessionId, unix_user: &UnixUser, db_connection: &mut MySqlConnection, db_is_mariadb: bool, @@ -207,11 +231,11 @@ async fn handle_request( ) -> anyhow::Result { match &request { Request::Exit => tracing::debug!("Received request: {:#?}", request), - Request::PasswdUser((db_user, _)) => tracing::info!( + Request::PasswdUser((db_user, _)) => tracing::debug!( "Received request: {:#?}", Request::PasswdUser((db_user.to_owned(), "".to_string())) ), - request => tracing::info!("Received request: {:#?}", request), + request => tracing::debug!("Request:\n{}", serde_json::to_string_pretty(request)?), } let affected_dbs = request.affected_databases(); @@ -231,7 +255,7 @@ async fn handle_request( } let response = match request { - Request::CheckAuthorization(dbs_or_users) => { + Request::CheckAuthorization(ref dbs_or_users) => { let result = check_authorization(dbs_or_users, unix_user, group_denylist).await; Response::CheckAuthorization(result) } @@ -245,7 +269,7 @@ async fn handle_request( Response::ListValidNamePrefixes(result) } - Request::CompleteDatabaseName(partial_database_name) => { + Request::CompleteDatabaseName(ref partial_database_name) => { // TODO: more correct validation here if partial_database_name .chars() @@ -264,7 +288,7 @@ async fn handle_request( Response::CompleteDatabaseName(vec![]) } } - Request::CompleteUserName(partial_user_name) => { + Request::CompleteUserName(ref partial_user_name) => { // TODO: more correct validation here if partial_user_name .chars() @@ -283,7 +307,7 @@ async fn handle_request( Response::CompleteUserName(vec![]) } } - Request::CreateDatabases(databases_names) => { + Request::CreateDatabases(ref databases_names) => { let result = create_databases( databases_names, unix_user, @@ -294,7 +318,7 @@ async fn handle_request( .await; Response::CreateDatabases(result) } - Request::DropDatabases(databases_names) => { + Request::DropDatabases(ref databases_names) => { let result = drop_databases( databases_names, unix_user, @@ -305,7 +329,7 @@ async fn handle_request( .await; Response::DropDatabases(result) } - Request::ListDatabases(database_names) => { + Request::ListDatabases(ref database_names) => { if let Some(database_names) = database_names { let result = list_databases( database_names, @@ -327,7 +351,7 @@ async fn handle_request( Response::ListAllDatabases(result) } } - Request::ListPrivileges(database_names) => { + Request::ListPrivileges(ref database_names) => { if let Some(database_names) = database_names { let privilege_data = get_databases_privilege_data( database_names, @@ -349,9 +373,9 @@ async fn handle_request( Response::ListAllPrivileges(privilege_data) } } - Request::ModifyPrivileges(database_privilege_diffs) => { + Request::ModifyPrivileges(ref database_privilege_diffs) => { let result = apply_privilege_diffs( - BTreeSet::from_iter(database_privilege_diffs), + database_privilege_diffs, unix_user, db_connection, db_is_mariadb, @@ -360,7 +384,7 @@ async fn handle_request( .await; Response::ModifyPrivileges(result) } - Request::CreateUsers(db_users) => { + Request::CreateUsers(ref db_users) => { let result = create_database_users( db_users, unix_user, @@ -371,7 +395,7 @@ async fn handle_request( .await; Response::CreateUsers(result) } - Request::DropUsers(db_users) => { + Request::DropUsers(ref db_users) => { let result = drop_database_users( db_users, unix_user, @@ -382,10 +406,10 @@ async fn handle_request( .await; Response::DropUsers(result) } - Request::PasswdUser((db_user, password)) => { + Request::PasswdUser((ref db_user, ref password)) => { let result = set_password_for_database_user( - &db_user, - &password, + db_user, + password, unix_user, db_connection, db_is_mariadb, @@ -394,7 +418,7 @@ async fn handle_request( .await; Response::SetUserPassword(result) } - Request::ListUsers(db_users) => { + Request::ListUsers(ref db_users) => { if let Some(db_users) = db_users { let result = list_database_users( db_users, @@ -416,7 +440,7 @@ async fn handle_request( Response::ListAllUsers(result) } } - Request::LockUsers(db_users) => { + Request::LockUsers(ref db_users) => { let result = lock_database_users( db_users, unix_user, @@ -427,7 +451,7 @@ async fn handle_request( .await; Response::LockUsers(result) } - Request::UnlockUsers(db_users) => { + Request::UnlockUsers(ref db_users) => { let result = unlock_database_users( db_users, unix_user, @@ -449,7 +473,12 @@ async fn handle_request( } response => response, }; - tracing::debug!("Response: {:#?}", response_to_display); + tracing::debug!( + "Response:\n{}", + serde_json::to_string_pretty(&response_to_display)? + ); + + log_request(session_id, unix_user, &request, &response); stream.send(response).await?; stream.flush().await?; @@ -457,3 +486,18 @@ async fn handle_request( Ok(true) } + +/// Log a summary of the request and its result. +fn log_request( + session_id: SessionId, + unix_user: &UnixUser, + request: &Request, + response: &Response, +) { + tracing::info!( + "[{}|session:{}|user:{unix_user}] {}", + response.ok_status(), + session_id.inner(), + request.log_summary(), + ); +} diff --git a/src/server/sql/database_operations.rs b/src/server/sql/database_operations.rs index d064395..53fbbe5 100644 --- a/src/server/sql/database_operations.rs +++ b/src/server/sql/database_operations.rs @@ -46,7 +46,7 @@ pub(super) async fn unsafe_database_exists( } pub async fn complete_database_name( - database_prefix: String, + database_prefix: &str, unix_user: &UnixUser, connection: &mut MySqlConnection, _db_is_mariadb: bool, @@ -87,7 +87,7 @@ pub async fn complete_database_name( } pub async fn create_databases( - database_names: Vec, + database_names: &[MySQLDatabase], unix_user: &UnixUser, connection: &mut MySqlConnection, _db_is_mariadb: bool, @@ -95,7 +95,7 @@ pub async fn create_databases( ) -> CreateDatabasesResponse { let mut results = BTreeMap::new(); - for database_name in database_names { + for database_name in database_names.iter().cloned() { if let Err(err) = validate_db_or_user_request( &DbOrUser::Database(database_name.clone()), unix_user, @@ -143,7 +143,7 @@ pub async fn create_databases( } pub async fn drop_databases( - database_names: Vec, + database_names: &[MySQLDatabase], unix_user: &UnixUser, connection: &mut MySqlConnection, _db_is_mariadb: bool, @@ -151,7 +151,7 @@ pub async fn drop_databases( ) -> DropDatabasesResponse { let mut results = BTreeMap::new(); - for database_name in database_names { + for database_name in database_names.iter().cloned() { if let Err(err) = validate_db_or_user_request( &DbOrUser::Database(database_name.clone()), unix_user, @@ -242,7 +242,7 @@ impl FromRow<'_, sqlx::mysql::MySqlRow> for DatabaseRow { } pub async fn list_databases( - database_names: Vec, + database_names: &[MySQLDatabase], unix_user: &UnixUser, connection: &mut MySqlConnection, _db_is_mariadb: bool, @@ -250,7 +250,7 @@ pub async fn list_databases( ) -> ListDatabasesResponse { let mut results = BTreeMap::new(); - for database_name in database_names { + for database_name in database_names.iter().cloned() { if let Err(err) = validate_db_or_user_request( &DbOrUser::Database(database_name.clone()), unix_user, diff --git a/src/server/sql/database_privilege_operations.rs b/src/server/sql/database_privilege_operations.rs index 2cf54f3..afb1088 100644 --- a/src/server/sql/database_privilege_operations.rs +++ b/src/server/sql/database_privilege_operations.rs @@ -138,7 +138,7 @@ pub async fn unsafe_get_database_privileges_for_db_user_pair( } pub async fn get_databases_privilege_data( - database_names: Vec, + database_names: &[MySQLDatabase], unix_user: &UnixUser, connection: &mut MySqlConnection, _db_is_mariadb: bool, @@ -146,19 +146,19 @@ pub async fn get_databases_privilege_data( ) -> ListPrivilegesResponse { let mut results = BTreeMap::new(); - for database_name in &database_names { + for database_name in database_names.iter().cloned() { if let Err(err) = validate_db_or_user_request( - &DbOrUser::Database(database_name.clone()), + &DbOrUser::Database(database_name.to_owned()), unix_user, group_denylist, ) .map_err(ListPrivilegesError::ValidationError) { - results.insert(database_name.to_owned(), Err(err)); + results.insert(database_name, Err(err)); continue; } - match unsafe_database_exists(database_name, connection).await { + match unsafe_database_exists(&database_name, connection).await { Ok(false) => { results.insert( database_name.to_owned(), @@ -176,7 +176,7 @@ pub async fn get_databases_privilege_data( Ok(true) => {} } - let result = unsafe_get_database_privileges(database_name, connection) + let result = unsafe_get_database_privileges(&database_name, connection) .await .map_err(|e| ListPrivilegesError::MySqlError(e.to_string())); @@ -400,7 +400,7 @@ async fn validate_diff( /// Uses the result of [`diff_privileges`] to modify privileges in the database. pub async fn apply_privilege_diffs( - database_privilege_diffs: BTreeSet, + database_privilege_diffs: &BTreeSet, unix_user: &UnixUser, connection: &mut MySqlConnection, _db_is_mariadb: bool, @@ -468,12 +468,12 @@ pub async fn apply_privilege_diffs( Ok(true) => {} } - if let Err(err) = validate_diff(&diff, connection).await { + if let Err(err) = validate_diff(diff, connection).await { results.insert(key, Err(err)); continue; } - let result = unsafe_apply_privilege_diff(&diff, connection) + let result = unsafe_apply_privilege_diff(diff, connection) .await .map_err(|e| ModifyDatabasePrivilegesError::MySqlError(e.to_string())); diff --git a/src/server/sql/user_operations.rs b/src/server/sql/user_operations.rs index 2c059ed..b420a2b 100644 --- a/src/server/sql/user_operations.rs +++ b/src/server/sql/user_operations.rs @@ -55,7 +55,7 @@ pub(super) async fn unsafe_user_exists( } pub async fn complete_user_name( - user_prefix: String, + user_prefix: &str, unix_user: &UnixUser, connection: &mut MySqlConnection, _db_is_mariadb: bool, @@ -95,7 +95,7 @@ pub async fn complete_user_name( } pub async fn create_database_users( - db_users: Vec, + db_users: &[MySQLUser], unix_user: &UnixUser, connection: &mut MySqlConnection, _db_is_mariadb: bool, @@ -103,7 +103,7 @@ pub async fn create_database_users( ) -> CreateUsersResponse { let mut results = BTreeMap::new(); - for db_user in db_users { + for db_user in db_users.iter().cloned() { if let Err(err) = validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist) .map_err(CreateUserError::ValidationError) @@ -141,7 +141,7 @@ pub async fn create_database_users( } pub async fn drop_database_users( - db_users: Vec, + db_users: &[MySQLUser], unix_user: &UnixUser, connection: &mut MySqlConnection, _db_is_mariadb: bool, @@ -149,7 +149,7 @@ pub async fn drop_database_users( ) -> DropUsersResponse { let mut results = BTreeMap::new(); - for db_user in db_users { + for db_user in db_users.iter().cloned() { if let Err(err) = validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist) .map_err(DropUserError::ValidationError) @@ -272,7 +272,7 @@ async fn database_user_is_locked_unsafe( } pub async fn lock_database_users( - db_users: Vec, + db_users: &[MySQLUser], unix_user: &UnixUser, connection: &mut MySqlConnection, db_is_mariadb: bool, @@ -280,7 +280,7 @@ pub async fn lock_database_users( ) -> LockUsersResponse { let mut results = BTreeMap::new(); - for db_user in db_users { + for db_user in db_users.iter().cloned() { if let Err(err) = validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist) .map_err(LockUserError::ValidationError) @@ -332,7 +332,7 @@ pub async fn lock_database_users( } pub async fn unlock_database_users( - db_users: Vec, + db_users: &[MySQLUser], unix_user: &UnixUser, connection: &mut MySqlConnection, db_is_mariadb: bool, @@ -340,7 +340,7 @@ pub async fn unlock_database_users( ) -> UnlockUsersResponse { let mut results = BTreeMap::new(); - for db_user in db_users { + for db_user in db_users.iter().cloned() { if let Err(err) = validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist) .map_err(UnlockUserError::ValidationError) @@ -440,7 +440,7 @@ FROM `user` "; pub async fn list_database_users( - db_users: Vec, + db_users: &[MySQLUser], unix_user: &UnixUser, connection: &mut MySqlConnection, db_is_mariadb: bool, @@ -448,7 +448,7 @@ pub async fn list_database_users( ) -> ListUsersResponse { let mut results = BTreeMap::new(); - for db_user in db_users { + for db_user in db_users.iter().cloned() { if let Err(err) = validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist) .map_err(ListUsersError::ValidationError) diff --git a/src/server/supervisor.rs b/src/server/supervisor.rs index be5d30d..e94d20a 100644 --- a/src/server/supervisor.rs +++ b/src/server/supervisor.rs @@ -2,7 +2,10 @@ use std::{ fs, os::{fd::FromRawFd, unix::net::UnixListener as StdUnixListener}, path::PathBuf, - sync::Arc, + sync::{ + Arc, + atomic::{AtomicU64, Ordering}, + }, time::Duration, }; @@ -22,7 +25,7 @@ use crate::{ server::{ authorization::read_and_parse_group_denylist, config::{MysqlConfig, ServerConfig}, - session_handler::session_handler, + session_handler::{SessionId, session_handler}, }, }; @@ -548,6 +551,8 @@ async fn listener_task( #[cfg(target_os = "linux")] sd_notify::notify(false, &[sd_notify::NotifyState::Ready])?; + let connection_counter = AtomicU64::new(0); + loop { tokio::select! { biased; @@ -577,14 +582,20 @@ async fn listener_task( } => { match accept_result { Ok((conn, _addr)) => { - tracing::debug!("Got new connection"); + connection_counter.fetch_add(1, Ordering::Relaxed); + let conn_id = connection_counter.load(Ordering::Relaxed); + + tracing::debug!("Got new connection, assigned session ID {}", conn_id); + + let session_id = SessionId::new(conn_id); let db_pool_clone = db_pool.clone(); let db_is_mariadb_clone = *db_is_mariadb.read().await; let group_denylist_arc_clone = group_denylist.clone(); task_tracker.spawn(async move { match session_handler( conn, + session_id, db_pool_clone, db_is_mariadb_clone, &*group_denylist_arc_clone.read().await,