server: implement graceful shutdown and reloads

2025-11-29 21:57:25 +09:00
parent 4a6e49110a
commit 1fe08b59a3
8 changed files with 313 additions and 116 deletions

Cargo.lock generated
View File

@@ -1690,6 +1690,15 @@ version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
[[package]]
name = "signal-hook-registry"
version = "1.4.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7664a098b8e616bdfcc2dc0e9ac44eb231eedf41db4e9fe95d8d32ec728dedad"
dependencies = [
"libc",
]
[[package]]
name = "signature"
version = "2.2.0"
@@ -2094,6 +2103,7 @@ dependencies = [
"libc",
"mio",
"pin-project-lite",
"signal-hook-registry",
"socket2",
"tokio-macros",
"windows-sys 0.61.2",

View File

@@ -39,7 +39,7 @@ serde = "1.0.228"
serde_json = { version = "1.0.145", features = ["preserve_order"] }
sqlx = { version = "0.8.6", features = ["runtime-tokio", "mysql", "tls-rustls"] }
systemd-journal-logger = "2.2.2"
tokio = { version = "1.48.0", features = ["rt-multi-thread", "macros"] }
tokio = { version = "1.48.0", features = ["rt-multi-thread", "macros", "signal"] }
tokio-serde = { version = "0.9.0", features = ["bincode"] }
tokio-stream = "0.1.17"
tokio-util = { version = "0.7.17", features = ["codec", "rt"] }
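
The new `signal` feature is what pulls `signal-hook-registry` into Cargo.lock above and exposes `tokio::signal::unix`, which the supervisor below builds on. A minimal, self-contained sketch of the two signal streams involved (illustrative only; the real handler is `spawn_signal_handler_task` further down):

use tokio::signal::unix::{signal, SignalKind};

#[tokio::main]
async fn main() -> std::io::Result<()> {
    // SIGHUP drives reloads, SIGTERM drives graceful shutdown.
    let mut sighup = signal(SignalKind::hangup())?;
    let mut sigterm = signal(SignalKind::terminate())?;
    tokio::select! {
        _ = sighup.recv() => println!("SIGHUP: would trigger a reload"),
        _ = sigterm.recv() => println!("SIGTERM: would begin graceful shutdown"),
    }
    Ok(())
}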

View File

@@ -5,6 +5,7 @@ Requires=muscl.socket
[Service]
Type=notify
ExecStart=/usr/bin/muscl server --systemd socket-activate
ExecReload=/usr/bin/kill -HUP $MAINPID
WatchdogSec=15
@@ -15,7 +16,7 @@ Group=muscl
DynamicUser=yes
ConfigurationDirectory=muscl
RuntimeDirectory=muscl
# RuntimeDirectory=muscl
# This is required to read unix user/group details.
PrivateUsers=false
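
With `Type=notify` the service must tell systemd when it is ready, `WatchdogSec=15` requires periodic keep-alive pings, and the new `ExecReload` delivers SIGHUP to request a reload. A compact sketch of the matching `sd_notify` calls (the function name is hypothetical; the supervisor below issues the same states at the appropriate points):

use sd_notify::NotifyState;

fn notify_systemd_lifecycle() -> std::io::Result<()> {
    // Type=notify: systemd waits for READY=1 before the unit counts as started.
    sd_notify::notify(false, &[NotifyState::Ready])?;
    // WatchdogSec=15: ping WATCHDOG=1 comfortably within the 15s window.
    sd_notify::notify(false, &[NotifyState::Watchdog])?;
    // On SIGHUP (ExecReload), bracket the reload with RELOADING=1 / READY=1.
    sd_notify::notify(false, &[NotifyState::Reloading])?;
    sd_notify::notify(false, &[NotifyState::Ready])?;
    Ok(())
}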

View File

@@ -101,13 +101,18 @@ in
systemd.sockets."muscl".wantedBy = [ "sockets.target" ];
systemd.services."muscl" = {
restartTriggers = [ config.environment.etc."muscl/config.toml".source ];
reloadTriggers = [ config.environment.etc."muscl/config.toml".source ];
serviceConfig = {
ExecStart = [
""
"${lib.getExe cfg.package} ${cfg.logLevel} server --systemd socket-activate"
];
ExecReload = [
""
"${lib.getExe' pkgs.coreutils "kill"} -HUP $MAINPID"
];
IPAddressDeny = "any";
IPAddressAllow = [
"127.0.0.0/8"

View File

@@ -1,11 +1,11 @@
use std::{fs, path::PathBuf, time::Duration};
use std::{fs, path::PathBuf, sync::Arc, time::Duration};
use anyhow::{Context, anyhow};
use clap_verbosity_flag::Verbosity;
use nix::libc::{EXIT_SUCCESS, exit};
use sqlx::mysql::MySqlPoolOptions;
use std::os::unix::net::UnixStream as StdUnixStream;
use tokio::net::UnixStream as TokioUnixStream;
use tokio::{net::UnixStream as TokioUnixStream, sync::RwLock};
use crate::{
core::common::{
@@ -254,6 +254,7 @@ fn run_forked_server(
.block_on(async {
let socket = TokioUnixStream::from_std(server_socket)?;
let db_pool = construct_single_connection_mysql_pool(&config.mysql).await?;
let db_pool = Arc::new(RwLock::new(db_pool));
session_handler::session_handler_with_unix_user(socket, &unix_user, db_pool).await?;
Ok(())
});
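
The pool now travels as `Arc<RwLock<MySqlPool>>` so that a reload can replace it while session handlers keep shared access. A sketch of the swap this enables (`swap_pool` is a hypothetical helper; the commit's `restart_db_connection_pool` below overwrites the pool in place and lets the old one drop):

use std::sync::Arc;

use sqlx::MySqlPool;
use tokio::sync::RwLock;

async fn swap_pool(shared: &Arc<RwLock<MySqlPool>>, new_pool: MySqlPool) {
    // Handlers only take short-lived read guards, so the write guard
    // becomes available as soon as no handler is mid-checkout.
    let old = {
        let mut guard = shared.write().await;
        std::mem::replace(&mut *guard, new_pool)
    };
    // Connections already checked out of the old pool stay valid until dropped.
    old.close().await;
}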

View File

@@ -17,7 +17,7 @@ fn default_mysql_timeout() -> u64 {
DEFAULT_TIMEOUT
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename = "mysql")]
pub struct MysqlConfig {
pub socket_path: Option<PathBuf>,
@@ -70,7 +70,7 @@ impl MysqlConfig {
}
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
pub struct ServerConfig {
pub socket_path: Option<PathBuf>,
pub mysql: MysqlConfig,
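
The `PartialEq`/`Eq` derives exist so the reload path can diff the previous and freshly-read configurations and rebuild only what changed. A sketch of that decision (`plan_reload` is a hypothetical helper; the actual comparisons appear in `reload()` below):

use crate::server::config::ServerConfig;

fn plan_reload(old: &ServerConfig, new: &ServerConfig) -> (bool, bool) {
    let restart_db_pool = old.mysql != new.mysql; // needs PartialEq on MysqlConfig
    let rebind_listener = old.socket_path != new.socket_path;
    (restart_db_pool, rebind_listener)
}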

View File

@@ -1,9 +1,9 @@
use std::collections::BTreeSet;
use std::{collections::BTreeSet, sync::Arc};
use futures_util::{SinkExt, StreamExt};
use indoc::concatdoc;
use sqlx::{MySqlConnection, MySqlPool};
use tokio::net::UnixStream;
use tokio::{net::UnixStream, sync::RwLock};
use crate::{
core::{
@@ -33,7 +33,10 @@ use crate::{
// TODO: don't use a database connection unless necessary.
pub async fn session_handler(socket: UnixStream, db_pool: MySqlPool) -> anyhow::Result<()> {
pub async fn session_handler(
socket: UnixStream,
db_pool: Arc<RwLock<MySqlPool>>,
) -> anyhow::Result<()> {
let uid = match socket.peer_cred() {
Ok(cred) => cred.uid(),
Err(e) => {
@@ -80,12 +83,12 @@ pub async fn session_handler(socket: UnixStream, db_pool: MySqlPool) -> anyhow::
pub async fn session_handler_with_unix_user(
socket: UnixStream,
unix_user: &UnixUser,
db_pool: MySqlPool,
db_pool: Arc<RwLock<MySqlPool>>,
) -> anyhow::Result<()> {
let mut message_stream = create_server_to_client_message_stream(socket);
log::debug!("Requesting database connection from pool");
let mut db_connection = match db_pool.acquire().await {
let mut db_connection = match db_pool.read().await.acquire().await {
Ok(connection) => connection,
Err(err) => {
message_stream
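
Note the locking shape of `db_pool.read().await.acquire().await`: the read guard is a temporary that lives only for that statement, so a session holds the lock just long enough to check a connection out, and the connection stays valid even if a reload swaps the pool afterwards. The same pattern with the guard scope spelled out (`checkout` is a hypothetical helper):

use std::sync::Arc;

use sqlx::{MySql, MySqlPool, pool::PoolConnection};
use tokio::sync::RwLock;

async fn checkout(
    pool: &Arc<RwLock<MySqlPool>>,
) -> Result<PoolConnection<MySql>, sqlx::Error> {
    let conn = {
        let guard = pool.read().await; // read guard lives only for this block
        guard.acquire().await? // the connection does not borrow the guard
    };
    Ok(conn) // a reload's write().await is not blocked by idle sessions
}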

View File

@@ -8,58 +8,49 @@ use std::{
use anyhow::{Context, anyhow};
use sqlx::MySqlPool;
use tokio::{net::UnixListener as TokioUnixListener, task::JoinHandle, time::interval};
use tokio_util::task::TaskTracker;
// use tokio_util::sync::CancellationToken;
use tokio::{
net::UnixListener as TokioUnixListener,
select,
sync::{Mutex, RwLock, broadcast},
task::JoinHandle,
time::interval,
};
use tokio_util::{sync::CancellationToken, task::TaskTracker};
use crate::server::{
config::{MysqlConfig, ServerConfig},
session_handler::session_handler,
};
// TODO: implement graceful shutdown and graceful reloads
#[derive(Clone, Debug)]
pub enum SupervisorMessage {
StopAcceptingNewConnections,
ResumeAcceptingNewConnections,
Shutdown,
}
// Graceful shutdown process:
// 1. Notify systemd that shutdown is starting.
// 2. Stop accepting new connections.
// 3. Wait for existing connections to:
// - Finish all requests
// - Forcefully terminate after a timeout
// 3.5: Log every time a connection is terminated, and warn if it was forcefully terminated.
// 4. Shutdown the database connection pool.
// 5. Cleanup resources and exit.
// Graceful reload process:
// 1. Notify systemd that reload is starting.
// 2. Get ahold of the configuration mutex (and hence stop accepting new connections)
// 3. Reload configuration from file.
// 4. If the configuration is invalid, log an error and abort the reload (drop the mutex and resume as if the reload had been performed).
// 5. Set mutex contents to new configuration.
// 6. If database configuration has changed:
// - Wait for existing connections to finish (as in shutdown step 3).
// - Shutdown old database connection pool.
// - Create new database connection pool.
// 7. Drop config mutex (and hence resume accepting new connections).
// 8. Notify systemd that reload is complete.
#[derive(Clone, Debug)]
pub struct ReloadEvent;
#[allow(dead_code)]
pub struct Supervisor {
config_path: PathBuf,
config: ServerConfig,
config: Arc<Mutex<ServerConfig>>,
systemd_mode: bool,
// sighup_cancel_token: CancellationToken,
// sigterm_cancel_token: CancellationToken,
// signal_handler_task: JoinHandle<()>,
db_connection_pool: MySqlPool,
// listener: TokioUnixListener,
shutdown_cancel_token: CancellationToken,
reload_message_receiver: broadcast::Receiver<ReloadEvent>,
signal_handler_task: JoinHandle<()>,
db_connection_pool: Arc<RwLock<MySqlPool>>,
listener: Arc<RwLock<TokioUnixListener>>,
listener_task: JoinHandle<anyhow::Result<()>>,
handler_task_tracker: TaskTracker,
supervisor_message_sender: broadcast::Sender<SupervisorMessage>,
watchdog_timeout: Option<Duration>,
systemd_watchdog_task: Option<JoinHandle<()>>,
connection_counter: std::sync::Arc<()>,
status_notifier_task: Option<JoinHandle<()>>,
}
@@ -89,53 +80,212 @@ impl Supervisor {
None
};
let db_connection_pool = create_db_connection_pool(&config.mysql).await?;
let db_connection_pool =
Arc::new(RwLock::new(create_db_connection_pool(&config.mysql).await?));
let task_tracker = TaskTracker::new();
let connection_counter = Arc::new(());
let status_notifier_task = if systemd_mode {
Some(spawn_status_notifier_task(connection_counter.clone()))
Some(spawn_status_notifier_task(task_tracker.clone()))
} else {
None
};
let (tx, rx) = broadcast::channel(1);
// TODO: try to detect a systemd socket before using the provided socket path
let listener = match config.socket_path {
let listener = Arc::new(RwLock::new(match config.socket_path {
Some(ref path) => create_unix_listener_with_socket_path(path.clone()).await?,
None => create_unix_listener_with_systemd_socket().await?,
};
}));
let (reload_tx, reload_rx) = broadcast::channel(1);
let shutdown_cancel_token = CancellationToken::new();
let signal_handler_task =
spawn_signal_handler_task(reload_tx, shutdown_cancel_token.clone());
let listener_clone = listener.clone();
let task_tracker_clone = task_tracker.clone();
let listener_task = {
let connection_counter = connection_counter.clone();
tokio::spawn(spawn_listener_task(
listener,
connection_counter,
tokio::spawn(listener_task(
listener_clone,
task_tracker_clone,
db_connection_pool.clone(),
rx,
))
};
// let sighup_cancel_token = CancellationToken::new();
// let sigterm_cancel_token = CancellationToken::new();
Ok(Self {
config_path,
config,
config: Arc::new(Mutex::new(config)),
systemd_mode,
// sighup_cancel_token,
// sigterm_cancel_token,
// signal_handler_task,
reload_message_receiver: reload_rx,
shutdown_cancel_token,
signal_handler_task,
db_connection_pool,
// listener,
listener,
listener_task,
handler_task_tracker: TaskTracker::new(),
handler_task_tracker: task_tracker,
supervisor_message_sender: tx,
watchdog_timeout: watchdog_duration,
systemd_watchdog_task: watchdog_task,
connection_counter,
status_notifier_task,
})
}
pub async fn run(self) -> anyhow::Result<()> {
self.listener_task.await?
async fn stop_receiving_new_connections(&self) -> anyhow::Result<()> {
self.handler_task_tracker.close();
self.supervisor_message_sender
.send(SupervisorMessage::StopAcceptingNewConnections)
.context("Failed to send stop accepting new connections message to listener task")?;
Ok(())
}
async fn resume_receiving_new_connections(&self) -> anyhow::Result<()> {
self.handler_task_tracker.reopen();
self.supervisor_message_sender
.send(SupervisorMessage::ResumeAcceptingNewConnections)
.context("Failed to send resume accepting new connections message to listener task")?;
Ok(())
}
async fn wait_for_existing_connections_to_finish(&self) -> anyhow::Result<()> {
self.handler_task_tracker.wait().await;
Ok(())
}
async fn reload_config(&self) -> anyhow::Result<()> {
let new_config = ServerConfig::read_config_from_path(&self.config_path)
.context("Failed to read server configuration")?;
let mut config = self.config.clone().lock_owned().await;
*config = new_config;
Ok(())
}
async fn restart_db_connection_pool(&self) -> anyhow::Result<()> {
let config = self.config.lock().await;
let mut connection_pool = self.db_connection_pool.clone().write_owned().await;
let new_db_pool = create_db_connection_pool(&config.mysql).await?;
*connection_pool = new_db_pool;
Ok(())
}
// NOTE: the listener task holds a read lock while accepting, so it will block
// the write lock here unless it is paused or cancelled first. Make sure to
// handle that appropriately to avoid a deadlock.
async fn reload_listener(&self) -> anyhow::Result<()> {
let config = self.config.lock().await;
let new_listener = match config.socket_path {
Some(ref path) => create_unix_listener_with_socket_path(path.clone()).await?,
None => create_unix_listener_with_systemd_socket().await?,
};
let mut listener = self.listener.write().await;
*listener = new_listener;
Ok(())
}
pub async fn reload(&self) -> anyhow::Result<()> {
sd_notify::notify(false, &[sd_notify::NotifyState::Reloading])?;
let previous_config = self.config.lock().await.clone();
self.reload_config().await?;
let mut listener_task_was_stopped = false;
// NOTE: even after the old db pool is closed, any already-acquired connections remain valid until they are dropped,
// so we don't need to close existing connections here.
if self.config.lock().await.mysql != previous_config.mysql {
log::debug!("MySQL configuration has changed");
log::debug!("Restarting database connection pool with new configuration");
self.restart_db_connection_pool().await?;
}
if self.config.lock().await.socket_path != previous_config.socket_path {
log::debug!("Socket path configuration has changed, reloading listener");
if !listener_task_was_stopped {
listener_task_was_stopped = true;
log::debug!("Stop accepting new connections");
self.stop_receiving_new_connections().await?;
log::debug!("Waiting for existing connections to finish");
self.wait_for_existing_connections_to_finish().await?;
}
log::debug!("Reloading listener with new socket path");
self.reload_listener().await?;
}
if listener_task_was_stopped {
log::debug!("Resuming listener task");
self.resume_receiving_new_connections().await?;
}
sd_notify::notify(false, &[sd_notify::NotifyState::Ready])?;
Ok(())
}
pub async fn shutdown(&self) -> anyhow::Result<()> {
sd_notify::notify(false, &[sd_notify::NotifyState::Stopping])?;
log::debug!("Stop accepting new connections");
self.stop_receiving_new_connections().await?;
let connection_count = self.handler_task_tracker.len();
log::debug!(
"Waiting for {} existing connections to finish",
connection_count
);
self.wait_for_existing_connections_to_finish().await?;
log::debug!("Shutting down listener task");
self.supervisor_message_sender
.send(SupervisorMessage::Shutdown)
.unwrap_or_else(|e| {
log::warn!(
"Failed to send shutdown message to listener task: {}",
e
);
0
});
log::debug!("Shutting down database connection pool");
self.db_connection_pool.read().await.close().await;
log::debug!("Server shutdown complete");
std::process::exit(0);
}
pub async fn run(&self) -> anyhow::Result<()> {
loop {
select! {
biased;
_ = async {
let mut rx = self.reload_message_receiver.resubscribe();
rx.recv().await
} => {
log::info!("Reloading configuration");
match self.reload().await {
Ok(()) => {
log::info!("Configuration reloaded successfully");
}
Err(e) => {
log::error!("Failed to reload configuration: {}", e);
}
}
}
_ = self.shutdown_cancel_token.cancelled() => {
log::info!("Shutting down server");
self.shutdown().await?;
break;
}
}
}
Ok(())
}
}
@@ -155,23 +305,14 @@ fn spawn_watchdog_task(duration: Duration) -> JoinHandle<()> {
})
}
fn spawn_status_notifier_task(connection_counter: std::sync::Arc<()>) -> JoinHandle<()> {
const NON_CONNECTION_ARC_COUNT: usize = 3;
fn spawn_status_notifier_task(task_tracker: TaskTracker) -> JoinHandle<()> {
const STATUS_UPDATE_INTERVAL_SECS: Duration = Duration::from_secs(1);
tokio::spawn(async move {
let mut interval = interval(STATUS_UPDATE_INTERVAL_SECS);
loop {
interval.tick().await;
let count = match Arc::strong_count(&connection_counter)
.checked_sub(NON_CONNECTION_ARC_COUNT)
{
Some(c) => c,
None => {
debug_assert!(false, "Connection counter calculation underflowed");
0
}
};
let count = task_tracker.len();
let message = if count > 0 {
format!("Handling {} connections", count)
@@ -258,53 +399,89 @@ async fn create_db_connection_pool(config: &MysqlConfig) -> anyhow::Result<MySql
Ok(pool)
}
// fn spawn_signal_handler_task(
// sighup_token: CancellationToken,
// sigterm_token: CancellationToken,
// ) -> JoinHandle<()> {
// tokio::spawn(async move {
// let mut sighup_stream = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::hangup())
// .expect("Failed to set up SIGHUP handler");
// let mut sigterm_stream = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
// .expect("Failed to set up SIGTERM handler");
fn spawn_signal_handler_task(
reload_sender: broadcast::Sender<ReloadEvent>,
shutdown_token: CancellationToken,
) -> JoinHandle<()> {
tokio::spawn(async move {
let mut sighup_stream =
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::hangup())
.expect("Failed to set up SIGHUP handler");
let mut sigterm_stream =
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
.expect("Failed to set up SIGTERM handler");
// loop {
// tokio::select! {
// _ = sighup_stream.recv() => {
// log::info!("Received SIGHUP signal");
// sighup_token.cancel();
// }
// _ = sigterm_stream.recv() => {
// log::info!("Received SIGTERM signal");
// sigterm_token.cancel();
// break;
// }
// }
// }
// })
// }
loop {
tokio::select! {
_ = sighup_stream.recv() => {
log::info!("Received SIGHUP signal");
reload_sender.send(ReloadEvent).ok();
}
_ = sigterm_stream.recv() => {
log::info!("Received SIGTERM signal");
shutdown_token.cancel();
break;
}
}
}
})
}
async fn spawn_listener_task(
listener: TokioUnixListener,
connection_counter: Arc<()>,
db_pool: MySqlPool,
async fn listener_task(
listener: Arc<RwLock<TokioUnixListener>>,
task_tracker: TaskTracker,
db_pool: Arc<RwLock<MySqlPool>>,
mut supervisor_message_receiver: broadcast::Receiver<SupervisorMessage>,
) -> anyhow::Result<()> {
sd_notify::notify(false, &[sd_notify::NotifyState::Ready])?;
while let Ok((conn, _addr)) = listener.accept().await {
log::debug!("Got new connection");
loop {
tokio::select! {
biased;
let db_pool_clone = db_pool.clone();
let _connection_counter_guard = Arc::clone(&connection_counter);
tokio::spawn(async {
let _guard = _connection_counter_guard;
match session_handler(conn, db_pool_clone).await {
Ok(()) => {}
Err(e) => {
log::error!("Failed to run server: {}", e);
Ok(message) = supervisor_message_receiver.recv() => {
match message {
SupervisorMessage::StopAcceptingNewConnections => {
log::info!("Listener task received stop accepting new connections message, stopping listener");
while let Ok(msg) = supervisor_message_receiver.try_recv() {
if let SupervisorMessage::ResumeAcceptingNewConnections = msg {
log::info!("Listener task received resume accepting new connections message, resuming listener");
break;
}
}
}
SupervisorMessage::Shutdown => {
log::info!("Listener task received shutdown message, exiting listener task");
break;
}
_ => {}
}
}
});
accept_result = async {
let listener = listener.read().await;
listener.accept().await
} => {
match accept_result {
Ok((conn, _addr)) => {
log::debug!("Got new connection");
let db_pool_clone = db_pool.clone();
task_tracker.spawn(async {
match session_handler(conn, db_pool_clone).await {
Ok(()) => {}
Err(e) => {
log::error!("Failed to run server: {}", e);
}
}
});
}
Err(e) => {
log::error!("Failed to accept new connection: {}", e);
}
}
}
}
}
Ok(())
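
Taken together, the supervisor follows a common tokio shutdown shape: a `TaskTracker` replaces the old `Arc` strong-count trick for counting live sessions, and a `CancellationToken` carries the SIGTERM signal. A condensed, self-contained sketch of that shape, assuming tokio's `time` feature alongside `tokio-util` (in the commit the token wakes the supervisor's `run()` loop, which then drains the tracker, rather than cancelling handlers directly):

use std::time::Duration;

use tokio_util::{sync::CancellationToken, task::TaskTracker};

#[tokio::main]
async fn main() {
    let tracker = TaskTracker::new(); // plays the role of handler_task_tracker
    let token = CancellationToken::new(); // plays the role of shutdown_cancel_token

    for i in 0..3 {
        let token = token.clone();
        tracker.spawn(async move {
            tokio::select! {
                _ = token.cancelled() => println!("task {i}: asked to stop"),
                _ = tokio::time::sleep(Duration::from_secs(60)) => {}
            }
        });
    }

    tracker.close();      // cf. stop_receiving_new_connections()
    token.cancel();       // cf. the SIGTERM arm of the signal handler
    tracker.wait().await; // cf. wait_for_existing_connections_to_finish()
}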