server: implement graceful shutdown and reloads

2025-11-29 21:57:25 +09:00
parent 4a6e49110a
commit 1fe08b59a3
8 changed files with 313 additions and 116 deletions

Cargo.lock (generated)

@@ -1690,6 +1690,15 @@ version = "1.3.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
 
+[[package]]
+name = "signal-hook-registry"
+version = "1.4.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7664a098b8e616bdfcc2dc0e9ac44eb231eedf41db4e9fe95d8d32ec728dedad"
+dependencies = [
+ "libc",
+]
+
 [[package]]
 name = "signature"
 version = "2.2.0"
@@ -2094,6 +2103,7 @@ dependencies = [
  "libc",
  "mio",
  "pin-project-lite",
+ "signal-hook-registry",
  "socket2",
  "tokio-macros",
  "windows-sys 0.61.2",


@@ -39,7 +39,7 @@ serde = "1.0.228"
 serde_json = { version = "1.0.145", features = ["preserve_order"] }
 sqlx = { version = "0.8.6", features = ["runtime-tokio", "mysql", "tls-rustls"] }
 systemd-journal-logger = "2.2.2"
-tokio = { version = "1.48.0", features = ["rt-multi-thread", "macros"] }
+tokio = { version = "1.48.0", features = ["rt-multi-thread", "macros", "signal"] }
 tokio-serde = { version = "0.9.0", features = ["bincode"] }
 tokio-stream = "0.1.17"
 tokio-util = { version = "0.7.17", features = ["codec", "rt"] }
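
The newly enabled "signal" feature is what gates tokio::signal::unix, which the supervisor further down uses for SIGHUP/SIGTERM handling (and which pulls signal-hook-registry into Cargo.lock as a transitive dependency). A minimal standalone sketch of that API, separate from this commit's code:

use tokio::signal::unix::{SignalKind, signal};

// With tokio's "signal" feature enabled, each Unix signal is exposed as a
// stream; recv() resolves once per delivered signal.
async fn wait_for_sighup() -> std::io::Result<()> {
    let mut sighup = signal(SignalKind::hangup())?;
    sighup.recv().await;
    Ok(())
}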


@@ -5,6 +5,7 @@ Requires=muscl.socket
 [Service]
 Type=notify
 ExecStart=/usr/bin/muscl server --systemd socket-activate
+ExecReload=/usr/bin/kill -HUP $MAINPID
 WatchdogSec=15
@@ -15,7 +16,7 @@ Group=muscl
 DynamicUser=yes
 ConfigurationDirectory=muscl
-RuntimeDirectory=muscl
+# RuntimeDirectory=muscl
 # This is required to read unix user/group details.
 PrivateUsers=false
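
With Type=notify, `systemctl reload` runs the ExecReload line added above, which simply delivers SIGHUP to the main process; the daemon then acknowledges the reload over the notify socket. A condensed sketch of that handshake using the sd_notify crate (the same notify calls appear in the supervisor changes below; the helper name here is illustrative, not part of this commit):

// Illustrative helper: wrap a reload in the RELOADING=1 / READY=1 handshake
// that systemd expects from a Type=notify service.
fn with_reload_notification(
    reload: impl FnOnce() -> std::io::Result<()>,
) -> std::io::Result<()> {
    sd_notify::notify(false, &[sd_notify::NotifyState::Reloading])?;
    reload()?; // re-read config, swap the pool/listener, etc.
    sd_notify::notify(false, &[sd_notify::NotifyState::Ready])
}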


@@ -101,13 +101,18 @@ in
   systemd.sockets."muscl".wantedBy = [ "sockets.target" ];
   systemd.services."muscl" = {
-    restartTriggers = [ config.environment.etc."muscl/config.toml".source ];
+    reloadTriggers = [ config.environment.etc."muscl/config.toml".source ];
     serviceConfig = {
       ExecStart = [
         ""
         "${lib.getExe cfg.package} ${cfg.logLevel} server --systemd socket-activate"
       ];
+      ExecReload = [
+        ""
+        "${lib.getExe' pkgs.coreutils "kill"} -HUP $MAINPID"
+      ];
       IPAddressDeny = "any";
       IPAddressAllow = [
         "127.0.0.0/8"


@@ -1,11 +1,11 @@
-use std::{fs, path::PathBuf, time::Duration};
+use std::{fs, path::PathBuf, sync::Arc, time::Duration};
 
 use anyhow::{Context, anyhow};
 use clap_verbosity_flag::Verbosity;
 use nix::libc::{EXIT_SUCCESS, exit};
 use sqlx::mysql::MySqlPoolOptions;
 use std::os::unix::net::UnixStream as StdUnixStream;
-use tokio::net::UnixStream as TokioUnixStream;
+use tokio::{net::UnixStream as TokioUnixStream, sync::RwLock};
 
 use crate::{
     core::common::{
@@ -254,6 +254,7 @@ fn run_forked_server(
         .block_on(async {
             let socket = TokioUnixStream::from_std(server_socket)?;
             let db_pool = construct_single_connection_mysql_pool(&config.mysql).await?;
+            let db_pool = Arc::new(RwLock::new(db_pool));
             session_handler::session_handler_with_unix_user(socket, &unix_user, db_pool).await?;
             Ok(())
         });


@@ -17,7 +17,7 @@ fn default_mysql_timeout() -> u64 {
     DEFAULT_TIMEOUT
 }
 
-#[derive(Debug, Clone, Deserialize, Serialize)]
+#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
 #[serde(rename = "mysql")]
 pub struct MysqlConfig {
     pub socket_path: Option<PathBuf>,
@@ -70,7 +70,7 @@ impl MysqlConfig {
     }
 }
 
-#[derive(Debug, Clone, Deserialize, Serialize)]
+#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
 pub struct ServerConfig {
     pub socket_path: Option<PathBuf>,
     pub mysql: MysqlConfig,
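
Deriving PartialEq/Eq on both config structs is what lets the reload path below compare the old and new configuration and rebuild only what actually changed. A hypothetical helper in that spirit (the function is illustrative, not from this commit):

// Decide which resources a reload must rebuild; each comparison relies on
// the PartialEq derives added above.
fn reload_plan(old: &ServerConfig, new: &ServerConfig) -> (bool, bool) {
    let restart_db_pool = old.mysql != new.mysql;
    let rebind_listener = old.socket_path != new.socket_path;
    (restart_db_pool, rebind_listener)
}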


@@ -1,9 +1,9 @@
-use std::collections::BTreeSet;
+use std::{collections::BTreeSet, sync::Arc};
 
 use futures_util::{SinkExt, StreamExt};
 use indoc::concatdoc;
 use sqlx::{MySqlConnection, MySqlPool};
-use tokio::net::UnixStream;
+use tokio::{net::UnixStream, sync::RwLock};
 
 use crate::{
     core::{
@@ -33,7 +33,10 @@ use crate::{
 // TODO: don't use database connection unless necessary.
-pub async fn session_handler(socket: UnixStream, db_pool: MySqlPool) -> anyhow::Result<()> {
+pub async fn session_handler(
+    socket: UnixStream,
+    db_pool: Arc<RwLock<MySqlPool>>,
+) -> anyhow::Result<()> {
     let uid = match socket.peer_cred() {
         Ok(cred) => cred.uid(),
         Err(e) => {
@@ -80,12 +83,12 @@ pub async fn session_handler(socket: UnixStream, db_pool: MySqlPool) -> anyhow::
 pub async fn session_handler_with_unix_user(
     socket: UnixStream,
     unix_user: &UnixUser,
-    db_pool: MySqlPool,
+    db_pool: Arc<RwLock<MySqlPool>>,
 ) -> anyhow::Result<()> {
     let mut message_stream = create_server_to_client_message_stream(socket);
 
     log::debug!("Requesting database connection from pool");
-    let mut db_connection = match db_pool.acquire().await {
+    let mut db_connection = match db_pool.read().await.acquire().await {
         Ok(connection) => connection,
         Err(err) => {
             message_stream
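
Both handlers now receive the pool as Arc<RwLock<MySqlPool>> instead of an owned MySqlPool, so a reload can swap the pool behind the lock without touching in-flight sessions. The acquire pattern used above, reduced to a standalone sketch (the helper name is illustrative; types match this commit's sqlx/tokio usage):

use std::sync::Arc;

use sqlx::{MySql, MySqlPool, pool::PoolConnection};
use tokio::sync::RwLock;

// The read guard is only held while checking a connection out; per the NOTE
// in the supervisor below, an already-acquired connection stays valid even
// if a reload replaces the pool afterwards.
async fn acquire(
    pool: &Arc<RwLock<MySqlPool>>,
) -> Result<PoolConnection<MySql>, sqlx::Error> {
    pool.read().await.acquire().await
}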


@@ -8,58 +8,49 @@ use std::{
 use anyhow::{Context, anyhow};
 use sqlx::MySqlPool;
-use tokio::{net::UnixListener as TokioUnixListener, task::JoinHandle, time::interval};
-use tokio_util::task::TaskTracker;
-// use tokio_util::sync::CancellationToken;
+use tokio::{
+    net::UnixListener as TokioUnixListener,
+    select,
+    sync::{Mutex, RwLock, broadcast},
+    task::JoinHandle,
+    time::interval,
+};
+use tokio_util::{sync::CancellationToken, task::TaskTracker};
 
 use crate::server::{
     config::{MysqlConfig, ServerConfig},
     session_handler::session_handler,
 };
 
-// TODO: implement graceful shutdown and graceful reloads
-
-// Graceful shutdown process:
-// 1. Notify systemd that shutdown is starting.
-// 2. Stop accepting new connections.
-// 3. Wait for existing connections to:
-//    - Finish all requests
-//    - Forcefully terminate after a timeout
-// 3.5: Log everytime a connection is terminated, and warn if it was forcefully terminated.
-// 4. Shutdown the database connection pool.
-// 5. Cleanup resources and exit.
-
-// Graceful reload process:
-// 1. Notify systemd that reload is starting.
-// 2. Get ahold of the configuration mutex (and hence stop accepting new connections)
-// 3. Reload configuration from file.
-// 4. If the configuration is invalid, log an error and abort the reload (drop mutex, resume as if reload was performed).
-// 5. Set mutex contents to new configuration.
-// 6. If database configuration has changed:
-//    - Wait for existing connections to finish (as in shutdown step 3).
-//    - Shutdown old database connection pool.
-//    - Create new database connection pool.
-// 7. Drop config mutex (and hence resume accepting new connections).
-// 8. Notify systemd that reload is complete.
+#[derive(Clone, Debug)]
+pub enum SupervisorMessage {
+    StopAcceptingNewConnections,
+    ResumeAcceptingNewConnections,
+    Shutdown,
+}
+
+#[derive(Clone, Debug)]
+pub struct ReloadEvent;
 
 #[allow(dead_code)]
 pub struct Supervisor {
     config_path: PathBuf,
-    config: ServerConfig,
+    config: Arc<Mutex<ServerConfig>>,
     systemd_mode: bool,
-    // sighup_cancel_token: CancellationToken,
-    // sigterm_cancel_token: CancellationToken,
-    // signal_handler_task: JoinHandle<()>,
-    db_connection_pool: MySqlPool,
-    // listener: TokioUnixListener,
+    shutdown_cancel_token: CancellationToken,
+    reload_message_receiver: broadcast::Receiver<ReloadEvent>,
+    signal_handler_task: JoinHandle<()>,
+    db_connection_pool: Arc<RwLock<MySqlPool>>,
+    listener: Arc<RwLock<TokioUnixListener>>,
     listener_task: JoinHandle<anyhow::Result<()>>,
     handler_task_tracker: TaskTracker,
+    supervisor_message_sender: broadcast::Sender<SupervisorMessage>,
     watchdog_timeout: Option<Duration>,
     systemd_watchdog_task: Option<JoinHandle<()>>,
-    connection_counter: std::sync::Arc<()>,
     status_notifier_task: Option<JoinHandle<()>>,
 }
@@ -89,53 +80,212 @@ impl Supervisor {
             None
         };
 
-        let db_connection_pool = create_db_connection_pool(&config.mysql).await?;
+        let db_connection_pool =
+            Arc::new(RwLock::new(create_db_connection_pool(&config.mysql).await?));
+
+        let task_tracker = TaskTracker::new();
 
-        let connection_counter = Arc::new(());
         let status_notifier_task = if systemd_mode {
-            Some(spawn_status_notifier_task(connection_counter.clone()))
+            Some(spawn_status_notifier_task(task_tracker.clone()))
        } else {
             None
         };
 
+        let (tx, rx) = broadcast::channel(1);
+
         // TODO: try to detech systemd socket before using the provided socket path
-        let listener = match config.socket_path {
+        let listener = Arc::new(RwLock::new(match config.socket_path {
             Some(ref path) => create_unix_listener_with_socket_path(path.clone()).await?,
             None => create_unix_listener_with_systemd_socket().await?,
-        };
+        }));
+
+        let (reload_tx, reload_rx) = broadcast::channel(1);
+        let shutdown_cancel_token = CancellationToken::new();
+        let signal_handler_task =
+            spawn_signal_handler_task(reload_tx, shutdown_cancel_token.clone());
+
+        let listener_clone = listener.clone();
+        let task_tracker_clone = task_tracker.clone();
         let listener_task = {
-            let connection_counter = connection_counter.clone();
-            tokio::spawn(spawn_listener_task(
-                listener,
-                connection_counter,
+            tokio::spawn(listener_task(
+                listener_clone,
+                task_tracker_clone,
                 db_connection_pool.clone(),
+                rx,
             ))
         };
 
-        // let sighup_cancel_token = CancellationToken::new();
-        // let sigterm_cancel_token = CancellationToken::new();
-
         Ok(Self {
             config_path,
-            config,
+            config: Arc::new(Mutex::new(config)),
             systemd_mode,
-            // sighup_cancel_token,
-            // sigterm_cancel_token,
-            // signal_handler_task,
+            reload_message_receiver: reload_rx,
+            shutdown_cancel_token,
+            signal_handler_task,
             db_connection_pool,
-            // listener,
+            listener,
             listener_task,
-            handler_task_tracker: TaskTracker::new(),
+            handler_task_tracker: task_tracker,
+            supervisor_message_sender: tx,
             watchdog_timeout: watchdog_duration,
             systemd_watchdog_task: watchdog_task,
-            connection_counter,
             status_notifier_task,
         })
     }
 
-    pub async fn run(self) -> anyhow::Result<()> {
-        self.listener_task.await?
+    async fn stop_receiving_new_connections(&self) -> anyhow::Result<()> {
+        self.handler_task_tracker.close();
+        self.supervisor_message_sender
+            .send(SupervisorMessage::StopAcceptingNewConnections)
+            .context("Failed to send stop accepting new connections message to listener task")?;
+        Ok(())
+    }
+
+    async fn resume_receiving_new_connections(&self) -> anyhow::Result<()> {
+        self.handler_task_tracker.reopen();
+        self.supervisor_message_sender
+            .send(SupervisorMessage::ResumeAcceptingNewConnections)
+            .context("Failed to send resume accepting new connections message to listener task")?;
+        Ok(())
+    }
+
+    async fn wait_for_existing_connections_to_finish(&self) -> anyhow::Result<()> {
+        self.handler_task_tracker.wait().await;
+        Ok(())
+    }
+
+    async fn reload_config(&self) -> anyhow::Result<()> {
+        let new_config = ServerConfig::read_config_from_path(&self.config_path)
+            .context("Failed to read server configuration")?;
+        let mut config = self.config.clone().lock_owned().await;
+        *config = new_config;
+        Ok(())
+    }
+
+    async fn restart_db_connection_pool(&self) -> anyhow::Result<()> {
+        let config = self.config.lock().await;
+        let mut connection_pool = self.db_connection_pool.clone().write_owned().await;
+        let new_db_pool = create_db_connection_pool(&config.mysql).await?;
+        *connection_pool = new_db_pool;
+        Ok(())
+    }
+
+    // NOTE: the listener task will block the write lock unless the task is cancelled
+    // first. Make sure to handle that appropriately to avoid a deadlock.
+    async fn reload_listener(&self) -> anyhow::Result<()> {
+        let config = self.config.lock().await;
+        let new_listener = match config.socket_path {
+            Some(ref path) => create_unix_listener_with_socket_path(path.clone()).await?,
+            None => create_unix_listener_with_systemd_socket().await?,
+        };
+        let mut listener = self.listener.write().await;
+        *listener = new_listener;
+        Ok(())
+    }
+
+    pub async fn reload(&self) -> anyhow::Result<()> {
+        sd_notify::notify(false, &[sd_notify::NotifyState::Reloading])?;
+
+        let previous_config = self.config.lock().await.clone();
+        self.reload_config().await?;
+
+        let mut listener_task_was_stopped = false;
+
+        // NOTE: despite closing the existing db pool, any already acquired connections will remain valid until dropped,
+        // so we don't need to close existing connections here.
+        if self.config.lock().await.mysql != previous_config.mysql {
+            log::debug!("MySQL configuration has changed");
+            log::debug!("Restarting database connection pool with new configuration");
+            self.restart_db_connection_pool().await?;
+        }
+
+        if self.config.lock().await.socket_path != previous_config.socket_path {
+            log::debug!("Socket path configuration has changed, reloading listener");
+            if !listener_task_was_stopped {
+                listener_task_was_stopped = true;
+                log::debug!("Stop accepting new connections");
+                self.stop_receiving_new_connections().await?;
+                log::debug!("Waiting for existing connections to finish");
+                self.wait_for_existing_connections_to_finish().await?;
+            }
+            log::debug!("Reloading listener with new socket path");
+            self.reload_listener().await?;
+        }
+
+        if listener_task_was_stopped {
+            log::debug!("Resuming listener task");
+            self.resume_receiving_new_connections().await?;
+        }
+
+        sd_notify::notify(false, &[sd_notify::NotifyState::Ready])?;
+        Ok(())
+    }
+
+    pub async fn shutdown(&self) -> anyhow::Result<()> {
+        sd_notify::notify(false, &[sd_notify::NotifyState::Stopping])?;
+
+        log::debug!("Stop accepting new connections");
+        self.stop_receiving_new_connections().await?;
+
+        let connection_count = self.handler_task_tracker.len();
+        log::debug!(
+            "Waiting for {} existing connections to finish",
+            connection_count
+        );
+        self.wait_for_existing_connections_to_finish().await?;
+
+        log::debug!("Shutting down listener task");
+        self.supervisor_message_sender
+            .send(SupervisorMessage::Shutdown)
+            .unwrap_or_else(|e| {
+                log::warn!(
+                    "Failed to send shutdown message to listener task: {}",
+                    e
+                );
+                0
+            });
+
+        log::debug!("Shutting down database connection pool");
+        self.db_connection_pool.read().await.close().await;
+
+        log::debug!("Server shutdown complete");
+        std::process::exit(0);
+    }
+
+    pub async fn run(&self) -> anyhow::Result<()> {
+        loop {
+            select! {
+                biased;
+                _ = async {
+                    let mut rx = self.reload_message_receiver.resubscribe();
+                    rx.recv().await
+                } => {
+                    log::info!("Reloading configuration");
+                    match self.reload().await {
+                        Ok(()) => {
+                            log::info!("Configuration reloaded successfully");
+                        }
+                        Err(e) => {
+                            log::error!("Failed to reload configuration: {}", e);
+                        }
+                    }
+                }
+                _ = self.shutdown_cancel_token.cancelled() => {
+                    log::info!("Shutting down server");
+                    self.shutdown().await?;
+                    break;
+                }
+            }
+        }
+        Ok(())
     }
 }
@@ -155,23 +305,14 @@ fn spawn_watchdog_task(duration: Duration) -> JoinHandle<()> {
     })
 }
 
-fn spawn_status_notifier_task(connection_counter: std::sync::Arc<()>) -> JoinHandle<()> {
-    const NON_CONNECTION_ARC_COUNT: usize = 3;
+fn spawn_status_notifier_task(task_tracker: TaskTracker) -> JoinHandle<()> {
     const STATUS_UPDATE_INTERVAL_SECS: Duration = Duration::from_secs(1);
 
     tokio::spawn(async move {
         let mut interval = interval(STATUS_UPDATE_INTERVAL_SECS);
         loop {
             interval.tick().await;
-            let count = match Arc::strong_count(&connection_counter)
-                .checked_sub(NON_CONNECTION_ARC_COUNT)
-            {
-                Some(c) => c,
-                None => {
-                    debug_assert!(false, "Connection counter calculation underflowed");
-                    0
-                }
-            };
+            let count = task_tracker.len();
 
             let message = if count > 0 {
                 format!("Handling {} connections", count)
@@ -258,46 +399,75 @@ async fn create_db_connection_pool(config: &MysqlConfig) -> anyhow::Result<MySql
     Ok(pool)
 }
 
-// fn spawn_signal_handler_task(
-//     sighup_token: CancellationToken,
-//     sigterm_token: CancellationToken,
-// ) -> JoinHandle<()> {
-//     tokio::spawn(async move {
-//         let mut sighup_stream = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::hangup())
-//             .expect("Failed to set up SIGHUP handler");
-//         let mut sigterm_stream = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
-//             .expect("Failed to set up SIGTERM handler");
-//
-//         loop {
-//             tokio::select! {
-//                 _ = sighup_stream.recv() => {
-//                     log::info!("Received SIGHUP signal");
-//                     sighup_token.cancel();
-//                 }
-//                 _ = sigterm_stream.recv() => {
-//                     log::info!("Received SIGTERM signal");
-//                     sigterm_token.cancel();
-//                     break;
-//                 }
-//             }
-//         }
-//     })
-// }
+fn spawn_signal_handler_task(
+    reload_sender: broadcast::Sender<ReloadEvent>,
+    shutdown_token: CancellationToken,
+) -> JoinHandle<()> {
+    tokio::spawn(async move {
+        let mut sighup_stream =
+            tokio::signal::unix::signal(tokio::signal::unix::SignalKind::hangup())
+                .expect("Failed to set up SIGHUP handler");
+        let mut sigterm_stream =
+            tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
+                .expect("Failed to set up SIGTERM handler");
+
+        loop {
+            tokio::select! {
+                _ = sighup_stream.recv() => {
+                    log::info!("Received SIGHUP signal");
+                    reload_sender.send(ReloadEvent).ok();
+                }
+                _ = sigterm_stream.recv() => {
+                    log::info!("Received SIGTERM signal");
+                    shutdown_token.cancel();
+                    break;
+                }
+            }
+        }
+    })
+}
 
-async fn spawn_listener_task(
-    listener: TokioUnixListener,
-    connection_counter: Arc<()>,
-    db_pool: MySqlPool,
+async fn listener_task(
+    listener: Arc<RwLock<TokioUnixListener>>,
+    task_tracker: TaskTracker,
+    db_pool: Arc<RwLock<MySqlPool>>,
+    mut supervisor_message_receiver: broadcast::Receiver<SupervisorMessage>,
 ) -> anyhow::Result<()> {
     sd_notify::notify(false, &[sd_notify::NotifyState::Ready])?;
 
-    while let Ok((conn, _addr)) = listener.accept().await {
+    loop {
+        tokio::select! {
+            biased;
+            Ok(message) = supervisor_message_receiver.recv() => {
+                match message {
+                    SupervisorMessage::StopAcceptingNewConnections => {
+                        log::info!("Listener task received stop accepting new connections message, stopping listener");
+                        while let Ok(msg) = supervisor_message_receiver.try_recv() {
+                            if let SupervisorMessage::ResumeAcceptingNewConnections = msg {
+                                log::info!("Listener task received resume accepting new connections message, resuming listener");
+                                break;
+                            }
+                        }
+                    }
+                    SupervisorMessage::Shutdown => {
+                        log::info!("Listener task received shutdown message, exiting listener task");
+                        break;
+                    }
+                    _ => {}
+                }
+            }
+            accept_result = async {
+                let listener = listener.read().await;
+                listener.accept().await
+            } => {
+                match accept_result {
+                    Ok((conn, _addr)) => {
                         log::debug!("Got new connection");
                         let db_pool_clone = db_pool.clone();
-        let _connection_counter_guard = Arc::clone(&connection_counter);
-        tokio::spawn(async {
-            let _guard = _connection_counter_guard;
+                        task_tracker.spawn(async {
                             match session_handler(conn, db_pool_clone).await {
                                 Ok(()) => {}
                                 Err(e) => {
@@ -306,6 +476,13 @@ async fn spawn_listener_task(
                             }
                         });
                     }
+                    Err(e) => {
+                        log::error!("Failed to accept new connection: {}", e);
+                    }
+                }
+            }
+        }
+    }
 
     Ok(())
 }
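
Taken together: construction spawns the signal handler, listener, watchdog, and status tasks; `run` then loops, reloading on each SIGHUP broadcast and shutting down once SIGTERM cancels the token (with `shutdown` exiting the process itself). A hypothetical entry point showing the intended wiring; the constructor name and arguments are assumptions, since only the struct and its methods appear in this diff:

// Sketch only: drive the supervisor until SIGTERM.
async fn serve(config_path: std::path::PathBuf, systemd_mode: bool) -> anyhow::Result<()> {
    // Assumed constructor; the diff shows only the tail of it (`Ok(Self { .. })`).
    let supervisor = Supervisor::new(config_path, systemd_mode).await?;
    // The SIGTERM path never returns here: shutdown() calls std::process::exit(0).
    supervisor.run().await
}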