server: determine sql server variant, fix lock-user,unlock-user
All checks were successful
Build and test / check (push) Successful in 1m54s
Build and test / build (push) Successful in 3m10s
Build and test / test (push) Successful in 3m30s
Build and test / check-license (push) Successful in 7m25s
Build and test / docs (push) Successful in 5m26s

This commit is contained in:
2025-12-14 03:30:40 +09:00
parent dc7b72efe5
commit 4c82da390f
7 changed files with 194 additions and 55 deletions

View File

@@ -277,8 +277,23 @@ fn run_forked_server(
.block_on(async { .block_on(async {
let socket = TokioUnixStream::from_std(server_socket)?; let socket = TokioUnixStream::from_std(server_socket)?;
let db_pool = construct_single_connection_mysql_pool(&config.mysql).await?; let db_pool = construct_single_connection_mysql_pool(&config.mysql).await?;
let db_is_mariadb = {
let mut conn = db_pool.acquire().await?;
let version_row: String = sqlx::query_scalar("SELECT VERSION()")
.fetch_one(&mut *conn)
.await
.context("Failed to query MySQL version")?;
version_row.to_lowercase().contains("mariadb")
};
let db_pool = Arc::new(RwLock::new(db_pool)); let db_pool = Arc::new(RwLock::new(db_pool));
session_handler::session_handler_with_unix_user(socket, &unix_user, db_pool).await?; session_handler::session_handler_with_unix_user(
socket,
&unix_user,
db_pool,
db_is_mariadb,
)
.await?;
Ok(()) Ok(())
}); });

View File

@@ -102,9 +102,7 @@ impl ModifyDatabasePrivilegesError {
ModifyDatabasePrivilegesError::DatabaseDoesNotExist => { ModifyDatabasePrivilegesError::DatabaseDoesNotExist => {
"database-does-not-exist".to_string() "database-does-not-exist".to_string()
} }
ModifyDatabasePrivilegesError::UserDoesNotExist => { ModifyDatabasePrivilegesError::UserDoesNotExist => "user-does-not-exist".to_string(),
"user-does-not-exist".to_string()
}
ModifyDatabasePrivilegesError::DiffDoesNotApply(err) => { ModifyDatabasePrivilegesError::DiffDoesNotApply(err) => {
format!("diff-does-not-apply/{}", err.error_type()) format!("diff-does-not-apply/{}", err.error_type())
} }

View File

@@ -38,6 +38,7 @@ use crate::{
pub async fn session_handler( pub async fn session_handler(
socket: UnixStream, socket: UnixStream,
db_pool: Arc<RwLock<MySqlPool>>, db_pool: Arc<RwLock<MySqlPool>>,
db_is_mariadb: bool,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let uid = match socket.peer_cred() { let uid = match socket.peer_cred() {
Ok(cred) => cred.uid(), Ok(cred) => cred.uid(),
@@ -84,7 +85,8 @@ pub async fn session_handler(
(async move { (async move {
tracing::info!("Accepted connection from user: {}", unix_user); tracing::info!("Accepted connection from user: {}", unix_user);
let result = session_handler_with_unix_user(socket, &unix_user, db_pool).await; let result =
session_handler_with_unix_user(socket, &unix_user, db_pool, db_is_mariadb).await;
tracing::info!( tracing::info!(
"Finished handling requests for connection from user: {}", "Finished handling requests for connection from user: {}",
@@ -101,6 +103,7 @@ pub async fn session_handler_with_unix_user(
socket: UnixStream, socket: UnixStream,
unix_user: &UnixUser, unix_user: &UnixUser,
db_pool: Arc<RwLock<MySqlPool>>, db_pool: Arc<RwLock<MySqlPool>>,
db_is_mariadb: bool,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let mut message_stream = create_server_to_client_message_stream(socket); let mut message_stream = create_server_to_client_message_stream(socket);
@@ -123,8 +126,13 @@ pub async fn session_handler_with_unix_user(
}; };
tracing::debug!("Successfully acquired database connection from pool"); tracing::debug!("Successfully acquired database connection from pool");
let result = let result = session_handler_with_db_connection(
session_handler_with_db_connection(message_stream, unix_user, &mut db_connection).await; message_stream,
unix_user,
&mut db_connection,
db_is_mariadb,
)
.await;
tracing::debug!("Releasing database connection back to pool"); tracing::debug!("Releasing database connection back to pool");
@@ -138,6 +146,7 @@ async fn session_handler_with_db_connection(
mut stream: ServerToClientMessageStream, mut stream: ServerToClientMessageStream,
unix_user: &UnixUser, unix_user: &UnixUser,
db_connection: &mut MySqlConnection, db_connection: &mut MySqlConnection,
db_is_mariadb: bool,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
stream.send(Response::Ready).await?; stream.send(Response::Ready).await?;
loop { loop {
@@ -180,9 +189,13 @@ async fn session_handler_with_db_connection(
{ {
Response::CompleteDatabaseName(vec![]) Response::CompleteDatabaseName(vec![])
} else { } else {
let result = let result = complete_database_name(
complete_database_name(partial_database_name, unix_user, db_connection) partial_database_name,
.await; unix_user,
db_connection,
db_is_mariadb,
)
.await;
Response::CompleteDatabaseName(result) Response::CompleteDatabaseName(result)
} }
} }
@@ -194,39 +207,54 @@ async fn session_handler_with_db_connection(
{ {
Response::CompleteUserName(vec![]) Response::CompleteUserName(vec![])
} else { } else {
let result = let result = complete_user_name(
complete_user_name(partial_user_name, unix_user, db_connection).await; partial_user_name,
unix_user,
db_connection,
db_is_mariadb,
)
.await;
Response::CompleteUserName(result) Response::CompleteUserName(result)
} }
} }
Request::CreateDatabases(databases_names) => { Request::CreateDatabases(databases_names) => {
let result = create_databases(databases_names, unix_user, db_connection).await; let result =
create_databases(databases_names, unix_user, db_connection, db_is_mariadb)
.await;
Response::CreateDatabases(result) Response::CreateDatabases(result)
} }
Request::DropDatabases(databases_names) => { Request::DropDatabases(databases_names) => {
let result = drop_databases(databases_names, unix_user, db_connection).await; let result =
drop_databases(databases_names, unix_user, db_connection, db_is_mariadb).await;
Response::DropDatabases(result) Response::DropDatabases(result)
} }
Request::ListDatabases(database_names) => match database_names { Request::ListDatabases(database_names) => match database_names {
Some(database_names) => { Some(database_names) => {
let result = list_databases(database_names, unix_user, db_connection).await; let result =
list_databases(database_names, unix_user, db_connection, db_is_mariadb)
.await;
Response::ListDatabases(result) Response::ListDatabases(result)
} }
None => { None => {
let result = list_all_databases_for_user(unix_user, db_connection).await; let result =
list_all_databases_for_user(unix_user, db_connection, db_is_mariadb).await;
Response::ListAllDatabases(result) Response::ListAllDatabases(result)
} }
}, },
Request::ListPrivileges(database_names) => match database_names { Request::ListPrivileges(database_names) => match database_names {
Some(database_names) => { Some(database_names) => {
let privilege_data = let privilege_data = get_databases_privilege_data(
get_databases_privilege_data(database_names, unix_user, db_connection) database_names,
.await; unix_user,
db_connection,
db_is_mariadb,
)
.await;
Response::ListPrivileges(privilege_data) Response::ListPrivileges(privilege_data)
} }
None => { None => {
let privilege_data = let privilege_data =
get_all_database_privileges(unix_user, db_connection).await; get_all_database_privileges(unix_user, db_connection, db_is_mariadb).await;
Response::ListAllPrivileges(privilege_data) Response::ListAllPrivileges(privilege_data)
} }
}, },
@@ -235,41 +263,57 @@ async fn session_handler_with_db_connection(
BTreeSet::from_iter(database_privilege_diffs), BTreeSet::from_iter(database_privilege_diffs),
unix_user, unix_user,
db_connection, db_connection,
db_is_mariadb,
) )
.await; .await;
Response::ModifyPrivileges(result) Response::ModifyPrivileges(result)
} }
Request::CreateUsers(db_users) => { Request::CreateUsers(db_users) => {
let result = create_database_users(db_users, unix_user, db_connection).await; let result =
create_database_users(db_users, unix_user, db_connection, db_is_mariadb).await;
Response::CreateUsers(result) Response::CreateUsers(result)
} }
Request::DropUsers(db_users) => { Request::DropUsers(db_users) => {
let result = drop_database_users(db_users, unix_user, db_connection).await; let result =
drop_database_users(db_users, unix_user, db_connection, db_is_mariadb).await;
Response::DropUsers(result) Response::DropUsers(result)
} }
Request::PasswdUser((db_user, password)) => { Request::PasswdUser((db_user, password)) => {
let result = let result = set_password_for_database_user(
set_password_for_database_user(&db_user, &password, unix_user, db_connection) &db_user,
.await; &password,
unix_user,
db_connection,
db_is_mariadb,
)
.await;
Response::SetUserPassword(result) Response::SetUserPassword(result)
} }
Request::ListUsers(db_users) => match db_users { Request::ListUsers(db_users) => match db_users {
Some(db_users) => { Some(db_users) => {
let result = list_database_users(db_users, unix_user, db_connection).await; let result =
list_database_users(db_users, unix_user, db_connection, db_is_mariadb)
.await;
Response::ListUsers(result) Response::ListUsers(result)
} }
None => { None => {
let result = let result = list_all_database_users_for_unix_user(
list_all_database_users_for_unix_user(unix_user, db_connection).await; unix_user,
db_connection,
db_is_mariadb,
)
.await;
Response::ListAllUsers(result) Response::ListAllUsers(result)
} }
}, },
Request::LockUsers(db_users) => { Request::LockUsers(db_users) => {
let result = lock_database_users(db_users, unix_user, db_connection).await; let result =
lock_database_users(db_users, unix_user, db_connection, db_is_mariadb).await;
Response::LockUsers(result) Response::LockUsers(result)
} }
Request::UnlockUsers(db_users) => { Request::UnlockUsers(db_users) => {
let result = unlock_database_users(db_users, unix_user, db_connection).await; let result =
unlock_database_users(db_users, unix_user, db_connection, db_is_mariadb).await;
Response::UnlockUsers(result) Response::UnlockUsers(result)
} }
Request::Exit => { Request::Exit => {

View File

@@ -48,6 +48,7 @@ pub async fn complete_database_name(
database_prefix: String, database_prefix: String,
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
_db_is_mariadb: bool,
) -> CompleteDatabaseNameResponse { ) -> CompleteDatabaseNameResponse {
let result = sqlx::query( let result = sqlx::query(
r#" r#"
@@ -87,6 +88,7 @@ pub async fn create_databases(
database_names: Vec<MySQLDatabase>, database_names: Vec<MySQLDatabase>,
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
_db_is_mariadb: bool,
) -> CreateDatabasesResponse { ) -> CreateDatabasesResponse {
let mut results = BTreeMap::new(); let mut results = BTreeMap::new();
@@ -146,6 +148,7 @@ pub async fn drop_databases(
database_names: Vec<MySQLDatabase>, database_names: Vec<MySQLDatabase>,
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
_db_is_mariadb: bool,
) -> DropDatabasesResponse { ) -> DropDatabasesResponse {
let mut results = BTreeMap::new(); let mut results = BTreeMap::new();
@@ -218,6 +221,7 @@ pub async fn list_databases(
database_names: Vec<MySQLDatabase>, database_names: Vec<MySQLDatabase>,
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
_db_is_mariadb: bool,
) -> ListDatabasesResponse { ) -> ListDatabasesResponse {
let mut results = BTreeMap::new(); let mut results = BTreeMap::new();
@@ -268,6 +272,7 @@ pub async fn list_databases(
pub async fn list_all_databases_for_user( pub async fn list_all_databases_for_user(
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
_db_is_mariadb: bool,
) -> ListAllDatabasesResponse { ) -> ListAllDatabasesResponse {
let result = sqlx::query_as::<_, DatabaseRow>( let result = sqlx::query_as::<_, DatabaseRow>(
r#" r#"

View File

@@ -140,6 +140,7 @@ pub async fn get_databases_privilege_data(
database_names: Vec<MySQLDatabase>, database_names: Vec<MySQLDatabase>,
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
_db_is_mariadb: bool,
) -> ListPrivilegesResponse { ) -> ListPrivilegesResponse {
let mut results = BTreeMap::new(); let mut results = BTreeMap::new();
@@ -187,6 +188,7 @@ pub async fn get_databases_privilege_data(
pub async fn get_all_database_privileges( pub async fn get_all_database_privileges(
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
_db_is_mariadb: bool,
) -> ListAllPrivilegesResponse { ) -> ListAllPrivilegesResponse {
let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!( let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
indoc! {r#" indoc! {r#"
@@ -394,6 +396,7 @@ pub async fn apply_privilege_diffs(
database_privilege_diffs: BTreeSet<DatabasePrivilegesDiff>, database_privilege_diffs: BTreeSet<DatabasePrivilegesDiff>,
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
_db_is_mariadb: bool,
) -> ModifyPrivilegesResponse { ) -> ModifyPrivilegesResponse {
let mut results: BTreeMap<(MySQLDatabase, MySQLUser), _> = BTreeMap::new(); let mut results: BTreeMap<(MySQLDatabase, MySQLUser), _> = BTreeMap::new();
@@ -451,10 +454,7 @@ pub async fn apply_privilege_diffs(
.await .await
.unwrap() .unwrap()
{ {
results.insert( results.insert(key, Err(ModifyDatabasePrivilegesError::UserDoesNotExist));
key,
Err(ModifyDatabasePrivilegesError::UserDoesNotExist),
);
continue; continue;
} }

View File

@@ -55,6 +55,7 @@ pub async fn complete_user_name(
user_prefix: String, user_prefix: String,
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
_db_is_mariadb: bool,
) -> Vec<MySQLUser> { ) -> Vec<MySQLUser> {
let result = sqlx::query( let result = sqlx::query(
r#" r#"
@@ -93,6 +94,7 @@ pub async fn create_database_users(
db_users: Vec<MySQLUser>, db_users: Vec<MySQLUser>,
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
_db_is_mariadb: bool,
) -> CreateUsersResponse { ) -> CreateUsersResponse {
let mut results = BTreeMap::new(); let mut results = BTreeMap::new();
@@ -139,6 +141,7 @@ pub async fn drop_database_users(
db_users: Vec<MySQLUser>, db_users: Vec<MySQLUser>,
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
_db_is_mariadb: bool,
) -> DropUsersResponse { ) -> DropUsersResponse {
let mut results = BTreeMap::new(); let mut results = BTreeMap::new();
@@ -186,6 +189,7 @@ pub async fn set_password_for_database_user(
password: &str, password: &str,
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
_db_is_mariadb: bool,
) -> SetUserPasswordResponse { ) -> SetUserPasswordResponse {
if let Err(err) = validate_name(db_user) { if let Err(err) = validate_name(db_user) {
return Err(SetPasswordError::SanitizationError(err)); return Err(SetPasswordError::SanitizationError(err));
@@ -224,26 +228,39 @@ pub async fn set_password_for_database_user(
result result
} }
const DATABASE_USER_LOCK_STATUS_QUERY_MARIADB: &str = r#"
SELECT COALESCE(
JSON_EXTRACT(`mysql`.`global_priv`.`priv`, "$.account_locked"),
'false'
) != 'false'
FROM `mysql`.`global_priv`
WHERE `User` = ?
AND `Host` = '%'
"#;
const DATABASE_USER_LOCK_STATUS_QUERY_MYSQL: &str = r#"
SELECT `mysql`.`user`.`account_locked` = 'Y'
FROM `mysql`.`user`
WHERE `User` = ?
AND `Host` = '%'
"#;
// NOTE: this function is unsafe because it does no input validation. // NOTE: this function is unsafe because it does no input validation.
async fn database_user_is_locked_unsafe( async fn database_user_is_locked_unsafe(
db_user: &str, db_user: &str,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
db_is_mariadb: bool,
) -> Result<bool, sqlx::Error> { ) -> Result<bool, sqlx::Error> {
let result = sqlx::query( let result = sqlx::query(if db_is_mariadb {
r#" DATABASE_USER_LOCK_STATUS_QUERY_MARIADB
SELECT COALESCE( } else {
JSON_EXTRACT(`mysql`.`global_priv`.`priv`, "$.account_locked"), DATABASE_USER_LOCK_STATUS_QUERY_MYSQL
'false' })
) != 'false'
FROM `mysql`.`global_priv`
WHERE `User` = ?
AND `Host` = '%'
"#,
)
.bind(db_user) .bind(db_user)
.fetch_one(connection) .fetch_one(connection)
.await .await
.map(|row| row.get::<bool, _>(0)); .map(|row| row.try_get(0))
.and_then(|res| res);
if let Err(err) = &result { if let Err(err) = &result {
tracing::error!( tracing::error!(
@@ -260,6 +277,7 @@ pub async fn lock_database_users(
db_users: Vec<MySQLUser>, db_users: Vec<MySQLUser>,
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
db_is_mariadb: bool,
) -> LockUsersResponse { ) -> LockUsersResponse {
let mut results = BTreeMap::new(); let mut results = BTreeMap::new();
@@ -286,7 +304,7 @@ pub async fn lock_database_users(
} }
} }
match database_user_is_locked_unsafe(&db_user, &mut *connection).await { match database_user_is_locked_unsafe(&db_user, &mut *connection, db_is_mariadb).await {
Ok(false) => {} Ok(false) => {}
Ok(true) => { Ok(true) => {
results.insert(db_user, Err(LockUserError::UserIsAlreadyLocked)); results.insert(db_user, Err(LockUserError::UserIsAlreadyLocked));
@@ -320,6 +338,7 @@ pub async fn unlock_database_users(
db_users: Vec<MySQLUser>, db_users: Vec<MySQLUser>,
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
db_is_mariadb: bool,
) -> UnlockUsersResponse { ) -> UnlockUsersResponse {
let mut results = BTreeMap::new(); let mut results = BTreeMap::new();
@@ -346,7 +365,7 @@ pub async fn unlock_database_users(
_ => {} _ => {}
} }
match database_user_is_locked_unsafe(&db_user, &mut *connection).await { match database_user_is_locked_unsafe(&db_user, &mut *connection, db_is_mariadb).await {
Ok(false) => { Ok(false) => {
results.insert(db_user, Err(UnlockUserError::UserIsAlreadyUnlocked)); results.insert(db_user, Err(UnlockUserError::UserIsAlreadyUnlocked));
continue; continue;
@@ -394,13 +413,13 @@ impl FromRow<'_, sqlx::mysql::MySqlRow> for DatabaseUser {
user: try_get_with_binary_fallback(row, "User")?.into(), user: try_get_with_binary_fallback(row, "User")?.into(),
host: try_get_with_binary_fallback(row, "Host")?, host: try_get_with_binary_fallback(row, "Host")?,
has_password: row.try_get("has_password")?, has_password: row.try_get("has_password")?,
is_locked: row.try_get("is_locked")?, is_locked: row.try_get("account_locked")?,
databases: Vec::new(), databases: Vec::new(),
}) })
} }
} }
const DB_USER_SELECT_STATEMENT: &str = r#" const DB_USER_SELECT_STATEMENT_MARIADB: &str = r#"
SELECT SELECT
`user`.`User`, `user`.`User`,
`user`.`Host`, `user`.`Host`,
@@ -408,17 +427,27 @@ SELECT
COALESCE( COALESCE(
JSON_EXTRACT(`global_priv`.`priv`, "$.account_locked"), JSON_EXTRACT(`global_priv`.`priv`, "$.account_locked"),
'false' 'false'
) != 'false' AS `is_locked` ) != 'false' AS `account_locked`
FROM `user` FROM `user`
JOIN `global_priv` ON JOIN `global_priv` ON
`user`.`User` = `global_priv`.`User` `user`.`User` = `global_priv`.`User`
AND `user`.`Host` = `global_priv`.`Host` AND `user`.`Host` = `global_priv`.`Host`
"#; "#;
const DB_USER_SELECT_STATEMENT_MYSQL: &str = r#"
SELECT
`user`.`User`,
`user`.`Host`,
`user`.`authentication_string` != '' AS `has_password`,
`user`.`account_locked` = 'Y' AS `account_locked`
FROM `user`
"#;
pub async fn list_database_users( pub async fn list_database_users(
db_users: Vec<MySQLUser>, db_users: Vec<MySQLUser>,
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
db_is_mariadb: bool,
) -> ListUsersResponse { ) -> ListUsersResponse {
let mut results = BTreeMap::new(); let mut results = BTreeMap::new();
@@ -434,7 +463,11 @@ pub async fn list_database_users(
} }
let mut result = sqlx::query_as::<_, DatabaseUser>( let mut result = sqlx::query_as::<_, DatabaseUser>(
&(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `mysql`.`user`.`User` = ?"), &(if db_is_mariadb {
DB_USER_SELECT_STATEMENT_MARIADB.to_string()
} else {
DB_USER_SELECT_STATEMENT_MYSQL.to_string()
} + "WHERE `mysql`.`user`.`User` = ?"),
) )
.bind(db_user.as_str()) .bind(db_user.as_str())
.fetch_optional(&mut *connection) .fetch_optional(&mut *connection)
@@ -461,9 +494,14 @@ pub async fn list_database_users(
pub async fn list_all_database_users_for_unix_user( pub async fn list_all_database_users_for_unix_user(
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
db_is_mariadb: bool,
) -> ListAllUsersResponse { ) -> ListAllUsersResponse {
let mut result = sqlx::query_as::<_, DatabaseUser>( let mut result = sqlx::query_as::<_, DatabaseUser>(
&(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `user`.`User` REGEXP ?"), &(if db_is_mariadb {
DB_USER_SELECT_STATEMENT_MARIADB.to_string()
} else {
DB_USER_SELECT_STATEMENT_MYSQL.to_string()
} + "WHERE `user`.`User` REGEXP ?"),
) )
.bind(create_user_group_matching_regex(unix_user)) .bind(create_user_group_matching_regex(unix_user))
.fetch_all(&mut *connection) .fetch_all(&mut *connection)

View File

@@ -43,6 +43,7 @@ pub struct Supervisor {
signal_handler_task: JoinHandle<()>, signal_handler_task: JoinHandle<()>,
db_connection_pool: Arc<RwLock<MySqlPool>>, db_connection_pool: Arc<RwLock<MySqlPool>>,
db_is_mariadb: Arc<RwLock<bool>>,
listener: Arc<RwLock<TokioUnixListener>>, listener: Arc<RwLock<TokioUnixListener>>,
listener_task: JoinHandle<anyhow::Result<()>>, listener_task: JoinHandle<anyhow::Result<()>>,
handler_task_tracker: TaskTracker, handler_task_tracker: TaskTracker,
@@ -83,6 +84,22 @@ impl Supervisor {
let db_connection_pool = let db_connection_pool =
Arc::new(RwLock::new(create_db_connection_pool(&config.mysql).await?)); Arc::new(RwLock::new(create_db_connection_pool(&config.mysql).await?));
let db_is_mariadb = {
let connection = db_connection_pool.read().await;
let version: String = sqlx::query_scalar("SELECT VERSION()")
.fetch_one(&*connection)
.await
.context("Failed to query database version")?;
let result = version.to_lowercase().contains("mariadb");
tracing::debug!(
"Connected to {} database server",
if result { "MariaDB" } else { "MySQL" }
);
Arc::new(RwLock::new(result))
};
let task_tracker = TaskTracker::new(); let task_tracker = TaskTracker::new();
let status_notifier_task = if systemd_mode { let status_notifier_task = if systemd_mode {
@@ -112,6 +129,7 @@ impl Supervisor {
task_tracker_clone, task_tracker_clone,
db_connection_pool.clone(), db_connection_pool.clone(),
rx, rx,
db_is_mariadb.clone(),
)) ))
}; };
@@ -123,6 +141,7 @@ impl Supervisor {
shutdown_cancel_token, shutdown_cancel_token,
signal_handler_task, signal_handler_task,
db_connection_pool, db_connection_pool,
db_is_mariadb,
listener, listener,
listener_task, listener_task,
handler_task_tracker: task_tracker, handler_task_tracker: task_tracker,
@@ -165,8 +184,26 @@ impl Supervisor {
async fn restart_db_connection_pool(&self) -> anyhow::Result<()> { async fn restart_db_connection_pool(&self) -> anyhow::Result<()> {
let config = self.config.lock().await; let config = self.config.lock().await;
let mut connection_pool = self.db_connection_pool.clone().write_owned().await; let mut connection_pool = self.db_connection_pool.clone().write_owned().await;
let mut db_is_mariadb_lock = self.db_is_mariadb.write().await;
let new_db_pool = create_db_connection_pool(&config.mysql).await?; let new_db_pool = create_db_connection_pool(&config.mysql).await?;
let db_is_mariadb = {
let version: String = sqlx::query_scalar("SELECT VERSION()")
.fetch_one(&new_db_pool)
.await
.context("Failed to query database version")?;
let result = version.to_lowercase().contains("mariadb");
tracing::debug!(
"Connected to {} database server",
if result { "MariaDB" } else { "MySQL" }
);
result
};
*connection_pool = new_db_pool; *connection_pool = new_db_pool;
*db_is_mariadb_lock = db_is_mariadb;
Ok(()) Ok(())
} }
@@ -429,6 +466,7 @@ async fn listener_task(
task_tracker: TaskTracker, task_tracker: TaskTracker,
db_pool: Arc<RwLock<MySqlPool>>, db_pool: Arc<RwLock<MySqlPool>>,
mut supervisor_message_receiver: broadcast::Receiver<SupervisorMessage>, mut supervisor_message_receiver: broadcast::Receiver<SupervisorMessage>,
db_is_mariadb: Arc<RwLock<bool>>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
sd_notify::notify(false, &[sd_notify::NotifyState::Ready])?; sd_notify::notify(false, &[sd_notify::NotifyState::Ready])?;
@@ -464,8 +502,9 @@ async fn listener_task(
tracing::debug!("Got new connection"); tracing::debug!("Got new connection");
let db_pool_clone = db_pool.clone(); let db_pool_clone = db_pool.clone();
task_tracker.spawn(async { let db_is_mariadb_clone = *db_is_mariadb.read().await;
match session_handler(conn, db_pool_clone).await { task_tracker.spawn(async move {
match session_handler(conn, db_pool_clone, db_is_mariadb_clone).await {
Ok(()) => {} Ok(()) => {}
Err(e) => { Err(e) => {
tracing::error!("Failed to run server: {}", e); tracing::error!("Failed to run server: {}", e);