server: determine sql server variant, fix lock-user,unlock-user
This commit is contained in:
@@ -43,6 +43,7 @@ pub struct Supervisor {
|
||||
signal_handler_task: JoinHandle<()>,
|
||||
|
||||
db_connection_pool: Arc<RwLock<MySqlPool>>,
|
||||
db_is_mariadb: Arc<RwLock<bool>>,
|
||||
listener: Arc<RwLock<TokioUnixListener>>,
|
||||
listener_task: JoinHandle<anyhow::Result<()>>,
|
||||
handler_task_tracker: TaskTracker,
|
||||
@@ -83,6 +84,22 @@ impl Supervisor {
|
||||
let db_connection_pool =
|
||||
Arc::new(RwLock::new(create_db_connection_pool(&config.mysql).await?));
|
||||
|
||||
let db_is_mariadb = {
|
||||
let connection = db_connection_pool.read().await;
|
||||
let version: String = sqlx::query_scalar("SELECT VERSION()")
|
||||
.fetch_one(&*connection)
|
||||
.await
|
||||
.context("Failed to query database version")?;
|
||||
|
||||
let result = version.to_lowercase().contains("mariadb");
|
||||
tracing::debug!(
|
||||
"Connected to {} database server",
|
||||
if result { "MariaDB" } else { "MySQL" }
|
||||
);
|
||||
|
||||
Arc::new(RwLock::new(result))
|
||||
};
|
||||
|
||||
let task_tracker = TaskTracker::new();
|
||||
|
||||
let status_notifier_task = if systemd_mode {
|
||||
@@ -112,6 +129,7 @@ impl Supervisor {
|
||||
task_tracker_clone,
|
||||
db_connection_pool.clone(),
|
||||
rx,
|
||||
db_is_mariadb.clone(),
|
||||
))
|
||||
};
|
||||
|
||||
@@ -123,6 +141,7 @@ impl Supervisor {
|
||||
shutdown_cancel_token,
|
||||
signal_handler_task,
|
||||
db_connection_pool,
|
||||
db_is_mariadb,
|
||||
listener,
|
||||
listener_task,
|
||||
handler_task_tracker: task_tracker,
|
||||
@@ -165,8 +184,26 @@ impl Supervisor {
|
||||
async fn restart_db_connection_pool(&self) -> anyhow::Result<()> {
|
||||
let config = self.config.lock().await;
|
||||
let mut connection_pool = self.db_connection_pool.clone().write_owned().await;
|
||||
let mut db_is_mariadb_lock = self.db_is_mariadb.write().await;
|
||||
|
||||
let new_db_pool = create_db_connection_pool(&config.mysql).await?;
|
||||
let db_is_mariadb = {
|
||||
let version: String = sqlx::query_scalar("SELECT VERSION()")
|
||||
.fetch_one(&new_db_pool)
|
||||
.await
|
||||
.context("Failed to query database version")?;
|
||||
|
||||
let result = version.to_lowercase().contains("mariadb");
|
||||
tracing::debug!(
|
||||
"Connected to {} database server",
|
||||
if result { "MariaDB" } else { "MySQL" }
|
||||
);
|
||||
|
||||
result
|
||||
};
|
||||
|
||||
*connection_pool = new_db_pool;
|
||||
*db_is_mariadb_lock = db_is_mariadb;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -429,6 +466,7 @@ async fn listener_task(
|
||||
task_tracker: TaskTracker,
|
||||
db_pool: Arc<RwLock<MySqlPool>>,
|
||||
mut supervisor_message_receiver: broadcast::Receiver<SupervisorMessage>,
|
||||
db_is_mariadb: Arc<RwLock<bool>>,
|
||||
) -> anyhow::Result<()> {
|
||||
sd_notify::notify(false, &[sd_notify::NotifyState::Ready])?;
|
||||
|
||||
@@ -464,8 +502,9 @@ async fn listener_task(
|
||||
tracing::debug!("Got new connection");
|
||||
|
||||
let db_pool_clone = db_pool.clone();
|
||||
task_tracker.spawn(async {
|
||||
match session_handler(conn, db_pool_clone).await {
|
||||
let db_is_mariadb_clone = *db_is_mariadb.read().await;
|
||||
task_tracker.spawn(async move {
|
||||
match session_handler(conn, db_pool_clone, db_is_mariadb_clone).await {
|
||||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to run server: {}", e);
|
||||
|
||||
Reference in New Issue
Block a user