server: determine sql server variant, fix lock-user,unlock-user
Build and test / check (push) Successful in 1m54s
Build and test / build (push) Successful in 3m10s
Build and test / test (push) Successful in 3m30s
Build and test / check-license (push) Successful in 7m25s
Build and test / docs (push) Successful in 5m26s

This commit is contained in:
2025-12-14 03:30:40 +09:00
parent dc7b72efe5
commit 4c82da390f
7 changed files with 194 additions and 55 deletions
+41 -2
View File
@@ -43,6 +43,7 @@ pub struct Supervisor {
signal_handler_task: JoinHandle<()>,
db_connection_pool: Arc<RwLock<MySqlPool>>,
db_is_mariadb: Arc<RwLock<bool>>,
listener: Arc<RwLock<TokioUnixListener>>,
listener_task: JoinHandle<anyhow::Result<()>>,
handler_task_tracker: TaskTracker,
@@ -83,6 +84,22 @@ impl Supervisor {
let db_connection_pool =
Arc::new(RwLock::new(create_db_connection_pool(&config.mysql).await?));
let db_is_mariadb = {
let connection = db_connection_pool.read().await;
let version: String = sqlx::query_scalar("SELECT VERSION()")
.fetch_one(&*connection)
.await
.context("Failed to query database version")?;
let result = version.to_lowercase().contains("mariadb");
tracing::debug!(
"Connected to {} database server",
if result { "MariaDB" } else { "MySQL" }
);
Arc::new(RwLock::new(result))
};
let task_tracker = TaskTracker::new();
let status_notifier_task = if systemd_mode {
@@ -112,6 +129,7 @@ impl Supervisor {
task_tracker_clone,
db_connection_pool.clone(),
rx,
db_is_mariadb.clone(),
))
};
@@ -123,6 +141,7 @@ impl Supervisor {
shutdown_cancel_token,
signal_handler_task,
db_connection_pool,
db_is_mariadb,
listener,
listener_task,
handler_task_tracker: task_tracker,
@@ -165,8 +184,26 @@ impl Supervisor {
async fn restart_db_connection_pool(&self) -> anyhow::Result<()> {
let config = self.config.lock().await;
let mut connection_pool = self.db_connection_pool.clone().write_owned().await;
let mut db_is_mariadb_lock = self.db_is_mariadb.write().await;
let new_db_pool = create_db_connection_pool(&config.mysql).await?;
let db_is_mariadb = {
let version: String = sqlx::query_scalar("SELECT VERSION()")
.fetch_one(&new_db_pool)
.await
.context("Failed to query database version")?;
let result = version.to_lowercase().contains("mariadb");
tracing::debug!(
"Connected to {} database server",
if result { "MariaDB" } else { "MySQL" }
);
result
};
*connection_pool = new_db_pool;
*db_is_mariadb_lock = db_is_mariadb;
Ok(())
}
@@ -429,6 +466,7 @@ async fn listener_task(
task_tracker: TaskTracker,
db_pool: Arc<RwLock<MySqlPool>>,
mut supervisor_message_receiver: broadcast::Receiver<SupervisorMessage>,
db_is_mariadb: Arc<RwLock<bool>>,
) -> anyhow::Result<()> {
sd_notify::notify(false, &[sd_notify::NotifyState::Ready])?;
@@ -464,8 +502,9 @@ async fn listener_task(
tracing::debug!("Got new connection");
let db_pool_clone = db_pool.clone();
task_tracker.spawn(async {
match session_handler(conn, db_pool_clone).await {
let db_is_mariadb_clone = *db_is_mariadb.read().await;
task_tracker.spawn(async move {
match session_handler(conn, db_pool_clone, db_is_mariadb_clone).await {
Ok(()) => {}
Err(e) => {
tracing::error!("Failed to run server: {}", e);