diff --git a/Cargo.lock b/Cargo.lock index 17395b3..673b481 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1690,6 +1690,15 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +[[package]] +name = "signal-hook-registry" +version = "1.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7664a098b8e616bdfcc2dc0e9ac44eb231eedf41db4e9fe95d8d32ec728dedad" +dependencies = [ + "libc", +] + [[package]] name = "signature" version = "2.2.0" @@ -2094,6 +2103,7 @@ dependencies = [ "libc", "mio", "pin-project-lite", + "signal-hook-registry", "socket2", "tokio-macros", "windows-sys 0.61.2", diff --git a/Cargo.toml b/Cargo.toml index 099c78d..0e9b4d7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,7 +39,7 @@ serde = "1.0.228" serde_json = { version = "1.0.145", features = ["preserve_order"] } sqlx = { version = "0.8.6", features = ["runtime-tokio", "mysql", "tls-rustls"] } systemd-journal-logger = "2.2.2" -tokio = { version = "1.48.0", features = ["rt-multi-thread", "macros"] } +tokio = { version = "1.48.0", features = ["rt-multi-thread", "macros", "signal"] } tokio-serde = { version = "0.9.0", features = ["bincode"] } tokio-stream = "0.1.17" tokio-util = { version = "0.7.17", features = ["codec", "rt"] } diff --git a/assets/systemd/muscl.service b/assets/systemd/muscl.service index ce22fe6..44172bb 100644 --- a/assets/systemd/muscl.service +++ b/assets/systemd/muscl.service @@ -5,6 +5,7 @@ Requires=muscl.socket [Service] Type=notify ExecStart=/usr/bin/muscl server --systemd socket-activate +ExecReload=/usr/bin/kill -HUP $MAINPID WatchdogSec=15 @@ -15,7 +16,7 @@ Group=muscl DynamicUser=yes ConfigurationDirectory=muscl -RuntimeDirectory=muscl +# RuntimeDirectory=muscl # This is required to read unix user/group details. PrivateUsers=false diff --git a/nix/module.nix b/nix/module.nix index ae37411..90cc0df 100644 --- a/nix/module.nix +++ b/nix/module.nix @@ -101,13 +101,18 @@ in systemd.sockets."muscl".wantedBy = [ "sockets.target" ]; systemd.services."muscl" = { - restartTriggers = [ config.environment.etc."muscl/config.toml".source ]; + reloadTriggers = [ config.environment.etc."muscl/config.toml".source ]; serviceConfig = { ExecStart = [ "" "${lib.getExe cfg.package} ${cfg.logLevel} server --systemd socket-activate" ]; + ExecReload = [ + "" + "${lib.getExe' pkgs.coreutils "kill"} -HUP $MAINPID" + ]; + IPAddressDeny = "any"; IPAddressAllow = [ "127.0.0.0/8" diff --git a/src/core/bootstrap.rs b/src/core/bootstrap.rs index 2e58233..cf9c8a2 100644 --- a/src/core/bootstrap.rs +++ b/src/core/bootstrap.rs @@ -1,11 +1,11 @@ -use std::{fs, path::PathBuf, time::Duration}; +use std::{fs, path::PathBuf, sync::Arc, time::Duration}; use anyhow::{Context, anyhow}; use clap_verbosity_flag::Verbosity; use nix::libc::{EXIT_SUCCESS, exit}; use sqlx::mysql::MySqlPoolOptions; use std::os::unix::net::UnixStream as StdUnixStream; -use tokio::net::UnixStream as TokioUnixStream; +use tokio::{net::UnixStream as TokioUnixStream, sync::RwLock}; use crate::{ core::common::{ @@ -254,6 +254,7 @@ fn run_forked_server( .block_on(async { let socket = TokioUnixStream::from_std(server_socket)?; let db_pool = construct_single_connection_mysql_pool(&config.mysql).await?; + let db_pool = Arc::new(RwLock::new(db_pool)); session_handler::session_handler_with_unix_user(socket, &unix_user, db_pool).await?; Ok(()) }); diff --git a/src/server/config.rs b/src/server/config.rs index 82cb9de..73f44c8 100644 --- a/src/server/config.rs +++ b/src/server/config.rs @@ -17,7 +17,7 @@ fn default_mysql_timeout() -> u64 { DEFAULT_TIMEOUT } -#[derive(Debug, Clone, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] #[serde(rename = "mysql")] pub struct MysqlConfig { pub socket_path: Option, @@ -70,7 +70,7 @@ impl MysqlConfig { } } -#[derive(Debug, Clone, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] pub struct ServerConfig { pub socket_path: Option, pub mysql: MysqlConfig, diff --git a/src/server/session_handler.rs b/src/server/session_handler.rs index 370c5e1..43edd8b 100644 --- a/src/server/session_handler.rs +++ b/src/server/session_handler.rs @@ -1,9 +1,9 @@ -use std::collections::BTreeSet; +use std::{collections::BTreeSet, sync::Arc}; use futures_util::{SinkExt, StreamExt}; use indoc::concatdoc; use sqlx::{MySqlConnection, MySqlPool}; -use tokio::net::UnixStream; +use tokio::{net::UnixStream, sync::RwLock}; use crate::{ core::{ @@ -33,7 +33,10 @@ use crate::{ // TODO: don't use database connection unless necessary. -pub async fn session_handler(socket: UnixStream, db_pool: MySqlPool) -> anyhow::Result<()> { +pub async fn session_handler( + socket: UnixStream, + db_pool: Arc>, +) -> anyhow::Result<()> { let uid = match socket.peer_cred() { Ok(cred) => cred.uid(), Err(e) => { @@ -80,12 +83,12 @@ pub async fn session_handler(socket: UnixStream, db_pool: MySqlPool) -> anyhow:: pub async fn session_handler_with_unix_user( socket: UnixStream, unix_user: &UnixUser, - db_pool: MySqlPool, + db_pool: Arc>, ) -> anyhow::Result<()> { let mut message_stream = create_server_to_client_message_stream(socket); log::debug!("Requesting database connection from pool"); - let mut db_connection = match db_pool.acquire().await { + let mut db_connection = match db_pool.read().await.acquire().await { Ok(connection) => connection, Err(err) => { message_stream diff --git a/src/server/supervisor.rs b/src/server/supervisor.rs index 50124cb..9701078 100644 --- a/src/server/supervisor.rs +++ b/src/server/supervisor.rs @@ -8,58 +8,49 @@ use std::{ use anyhow::{Context, anyhow}; use sqlx::MySqlPool; -use tokio::{net::UnixListener as TokioUnixListener, task::JoinHandle, time::interval}; -use tokio_util::task::TaskTracker; -// use tokio_util::sync::CancellationToken; +use tokio::{ + net::UnixListener as TokioUnixListener, + select, + sync::{Mutex, RwLock, broadcast}, + task::JoinHandle, + time::interval, +}; +use tokio_util::{sync::CancellationToken, task::TaskTracker}; use crate::server::{ config::{MysqlConfig, ServerConfig}, session_handler::session_handler, }; -// TODO: implement graceful shutdown and graceful reloads +#[derive(Clone, Debug)] +pub enum SupervisorMessage { + StopAcceptingNewConnections, + ResumeAcceptingNewConnections, + Shutdown, +} -// Graceful shutdown process: -// 1. Notify systemd that shutdown is starting. -// 2. Stop accepting new connections. -// 3. Wait for existing connections to: -// - Finish all requests -// - Forcefully terminate after a timeout -// 3.5: Log everytime a connection is terminated, and warn if it was forcefully terminated. -// 4. Shutdown the database connection pool. -// 5. Cleanup resources and exit. - -// Graceful reload process: -// 1. Notify systemd that reload is starting. -// 2. Get ahold of the configuration mutex (and hence stop accepting new connections) -// 3. Reload configuration from file. -// 4. If the configuration is invalid, log an error and abort the reload (drop mutex, resume as if reload was performed). -// 5. Set mutex contents to new configuration. -// 6. If database configuration has changed: -// - Wait for existing connections to finish (as in shutdown step 3). -// - Shutdown old database connection pool. -// - Create new database connection pool. -// 7. Drop config mutex (and hence resume accepting new connections). -// 8. Notify systemd that reload is complete. +#[derive(Clone, Debug)] +pub struct ReloadEvent; #[allow(dead_code)] pub struct Supervisor { config_path: PathBuf, - config: ServerConfig, + config: Arc>, systemd_mode: bool, - // sighup_cancel_token: CancellationToken, - // sigterm_cancel_token: CancellationToken, - // signal_handler_task: JoinHandle<()>, - db_connection_pool: MySqlPool, - // listener: TokioUnixListener, + shutdown_cancel_token: CancellationToken, + reload_message_receiver: broadcast::Receiver, + signal_handler_task: JoinHandle<()>, + + db_connection_pool: Arc>, + listener: Arc>, listener_task: JoinHandle>, handler_task_tracker: TaskTracker, + supervisor_message_sender: broadcast::Sender, watchdog_timeout: Option, systemd_watchdog_task: Option>, - connection_counter: std::sync::Arc<()>, status_notifier_task: Option>, } @@ -89,53 +80,212 @@ impl Supervisor { None }; - let db_connection_pool = create_db_connection_pool(&config.mysql).await?; + let db_connection_pool = + Arc::new(RwLock::new(create_db_connection_pool(&config.mysql).await?)); + + let task_tracker = TaskTracker::new(); - let connection_counter = Arc::new(()); let status_notifier_task = if systemd_mode { - Some(spawn_status_notifier_task(connection_counter.clone())) + Some(spawn_status_notifier_task(task_tracker.clone())) } else { None }; + let (tx, rx) = broadcast::channel(1); + // TODO: try to detech systemd socket before using the provided socket path - let listener = match config.socket_path { + let listener = Arc::new(RwLock::new(match config.socket_path { Some(ref path) => create_unix_listener_with_socket_path(path.clone()).await?, None => create_unix_listener_with_systemd_socket().await?, - }; + })); + let (reload_tx, reload_rx) = broadcast::channel(1); + let shutdown_cancel_token = CancellationToken::new(); + let signal_handler_task = + spawn_signal_handler_task(reload_tx, shutdown_cancel_token.clone()); + + let listener_clone = listener.clone(); + let task_tracker_clone = task_tracker.clone(); let listener_task = { - let connection_counter = connection_counter.clone(); - tokio::spawn(spawn_listener_task( - listener, - connection_counter, + tokio::spawn(listener_task( + listener_clone, + task_tracker_clone, db_connection_pool.clone(), + rx, )) }; - // let sighup_cancel_token = CancellationToken::new(); - // let sigterm_cancel_token = CancellationToken::new(); - Ok(Self { config_path, - config, + config: Arc::new(Mutex::new(config)), systemd_mode, - // sighup_cancel_token, - // sigterm_cancel_token, - // signal_handler_task, + reload_message_receiver: reload_rx, + shutdown_cancel_token, + signal_handler_task, db_connection_pool, - // listener, + listener, listener_task, - handler_task_tracker: TaskTracker::new(), + handler_task_tracker: task_tracker, + supervisor_message_sender: tx, watchdog_timeout: watchdog_duration, systemd_watchdog_task: watchdog_task, - connection_counter, status_notifier_task, }) } - pub async fn run(self) -> anyhow::Result<()> { - self.listener_task.await? + async fn stop_receiving_new_connections(&self) -> anyhow::Result<()> { + self.handler_task_tracker.close(); + self.supervisor_message_sender + .send(SupervisorMessage::StopAcceptingNewConnections) + .context("Failed to send stop accepting new connections message to listener task")?; + Ok(()) + } + + async fn resume_receiving_new_connections(&self) -> anyhow::Result<()> { + self.handler_task_tracker.reopen(); + self.supervisor_message_sender + .send(SupervisorMessage::ResumeAcceptingNewConnections) + .context("Failed to send resume accepting new connections message to listener task")?; + Ok(()) + } + + async fn wait_for_existing_connections_to_finish(&self) -> anyhow::Result<()> { + self.handler_task_tracker.wait().await; + Ok(()) + } + + async fn reload_config(&self) -> anyhow::Result<()> { + let new_config = ServerConfig::read_config_from_path(&self.config_path) + .context("Failed to read server configuration")?; + let mut config = self.config.clone().lock_owned().await; + *config = new_config; + Ok(()) + } + + async fn restart_db_connection_pool(&self) -> anyhow::Result<()> { + let config = self.config.lock().await; + let mut connection_pool = self.db_connection_pool.clone().write_owned().await; + let new_db_pool = create_db_connection_pool(&config.mysql).await?; + *connection_pool = new_db_pool; + Ok(()) + } + + // NOTE: the listener task will block the write lock unless the task is cancelled + // first. Make sure to handle that appropriately to avoid a deadlock. + async fn reload_listener(&self) -> anyhow::Result<()> { + let config = self.config.lock().await; + let new_listener = match config.socket_path { + Some(ref path) => create_unix_listener_with_socket_path(path.clone()).await?, + None => create_unix_listener_with_systemd_socket().await?, + }; + let mut listener = self.listener.write().await; + *listener = new_listener; + Ok(()) + } + + pub async fn reload(&self) -> anyhow::Result<()> { + sd_notify::notify(false, &[sd_notify::NotifyState::Reloading])?; + + let previous_config = self.config.lock().await.clone(); + self.reload_config().await?; + + let mut listener_task_was_stopped = false; + + // NOTE: despite closing the existing db pool, any already acquired connections will remain valid until dropped, + // so we don't need to close existing connections here. + if self.config.lock().await.mysql != previous_config.mysql { + log::debug!("MySQL configuration has changed"); + + log::debug!("Restarting database connection pool with new configuration"); + self.restart_db_connection_pool().await?; + } + + if self.config.lock().await.socket_path != previous_config.socket_path { + log::debug!("Socket path configuration has changed, reloading listener"); + if !listener_task_was_stopped { + listener_task_was_stopped = true; + log::debug!("Stop accepting new connections"); + self.stop_receiving_new_connections().await?; + + log::debug!("Waiting for existing connections to finish"); + self.wait_for_existing_connections_to_finish().await?; + } + + log::debug!("Reloading listener with new socket path"); + self.reload_listener().await?; + } + + if listener_task_was_stopped { + log::debug!("Resuming listener task"); + self.resume_receiving_new_connections().await?; + } + + sd_notify::notify(false, &[sd_notify::NotifyState::Ready])?; + + Ok(()) + } + + pub async fn shutdown(&self) -> anyhow::Result<()> { + sd_notify::notify(false, &[sd_notify::NotifyState::Stopping])?; + + log::debug!("Stop accepting new connections"); + self.stop_receiving_new_connections().await?; + + let connection_count = self.handler_task_tracker.len(); + log::debug!( + "Waiting for {} existing connections to finish", + connection_count + ); + self.wait_for_existing_connections_to_finish().await?; + + log::debug!("Shutting down listener task"); + self.supervisor_message_sender + .send(SupervisorMessage::Shutdown) + .unwrap_or_else(|e| { + log::warn!( + "Failed to send shutdown message to listener task: {}", + e + ); + 0 + }); + + log::debug!("Shutting down database connection pool"); + self.db_connection_pool.read().await.close().await; + + log::debug!("Server shutdown complete"); + + std::process::exit(0); + } + + pub async fn run(&self) -> anyhow::Result<()> { + loop { + select! { + biased; + + _ = async { + let mut rx = self.reload_message_receiver.resubscribe(); + rx.recv().await + } => { + log::info!("Reloading configuration"); + match self.reload().await { + Ok(()) => { + log::info!("Configuration reloaded successfully"); + } + Err(e) => { + log::error!("Failed to reload configuration: {}", e); + } + } + } + + _ = self.shutdown_cancel_token.cancelled() => { + log::info!("Shutting down server"); + self.shutdown().await?; + break; + } + } + } + + Ok(()) } } @@ -155,23 +305,14 @@ fn spawn_watchdog_task(duration: Duration) -> JoinHandle<()> { }) } -fn spawn_status_notifier_task(connection_counter: std::sync::Arc<()>) -> JoinHandle<()> { - const NON_CONNECTION_ARC_COUNT: usize = 3; +fn spawn_status_notifier_task(task_tracker: TaskTracker) -> JoinHandle<()> { const STATUS_UPDATE_INTERVAL_SECS: Duration = Duration::from_secs(1); tokio::spawn(async move { let mut interval = interval(STATUS_UPDATE_INTERVAL_SECS); loop { interval.tick().await; - let count = match Arc::strong_count(&connection_counter) - .checked_sub(NON_CONNECTION_ARC_COUNT) - { - Some(c) => c, - None => { - debug_assert!(false, "Connection counter calculation underflowed"); - 0 - } - }; + let count = task_tracker.len(); let message = if count > 0 { format!("Handling {} connections", count) @@ -258,53 +399,89 @@ async fn create_db_connection_pool(config: &MysqlConfig) -> anyhow::Result JoinHandle<()> { -// tokio::spawn(async move { -// let mut sighup_stream = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::hangup()) -// .expect("Failed to set up SIGHUP handler"); -// let mut sigterm_stream = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) -// .expect("Failed to set up SIGTERM handler"); +fn spawn_signal_handler_task( + reload_sender: broadcast::Sender, + shutdown_token: CancellationToken, +) -> JoinHandle<()> { + tokio::spawn(async move { + let mut sighup_stream = + tokio::signal::unix::signal(tokio::signal::unix::SignalKind::hangup()) + .expect("Failed to set up SIGHUP handler"); + let mut sigterm_stream = + tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) + .expect("Failed to set up SIGTERM handler"); -// loop { -// tokio::select! { -// _ = sighup_stream.recv() => { -// log::info!("Received SIGHUP signal"); -// sighup_token.cancel(); -// } -// _ = sigterm_stream.recv() => { -// log::info!("Received SIGTERM signal"); -// sigterm_token.cancel(); -// break; -// } -// } -// } -// }) -// } + loop { + tokio::select! { + _ = sighup_stream.recv() => { + log::info!("Received SIGHUP signal"); + reload_sender.send(ReloadEvent).ok(); + } + _ = sigterm_stream.recv() => { + log::info!("Received SIGTERM signal"); + shutdown_token.cancel(); + break; + } + } + } + }) +} -async fn spawn_listener_task( - listener: TokioUnixListener, - connection_counter: Arc<()>, - db_pool: MySqlPool, +async fn listener_task( + listener: Arc>, + task_tracker: TaskTracker, + db_pool: Arc>, + mut supervisor_message_receiver: broadcast::Receiver, ) -> anyhow::Result<()> { sd_notify::notify(false, &[sd_notify::NotifyState::Ready])?; - while let Ok((conn, _addr)) = listener.accept().await { - log::debug!("Got new connection"); + loop { + tokio::select! { + biased; - let db_pool_clone = db_pool.clone(); - let _connection_counter_guard = Arc::clone(&connection_counter); - tokio::spawn(async { - let _guard = _connection_counter_guard; - match session_handler(conn, db_pool_clone).await { - Ok(()) => {} - Err(e) => { - log::error!("Failed to run server: {}", e); + Ok(message) = supervisor_message_receiver.recv() => { + match message { + SupervisorMessage::StopAcceptingNewConnections => { + log::info!("Listener task received stop accepting new connections message, stopping listener"); + while let Ok(msg) = supervisor_message_receiver.try_recv() { + if let SupervisorMessage::ResumeAcceptingNewConnections = msg { + log::info!("Listener task received resume accepting new connections message, resuming listener"); + break; + } + } + } + SupervisorMessage::Shutdown => { + log::info!("Listener task received shutdown message, exiting listener task"); + break; + } + _ => {} } } - }); + + accept_result = async { + let listener = listener.read().await; + listener.accept().await + } => { + match accept_result { + Ok((conn, _addr)) => { + log::debug!("Got new connection"); + + let db_pool_clone = db_pool.clone(); + task_tracker.spawn(async { + match session_handler(conn, db_pool_clone).await { + Ok(()) => {} + Err(e) => { + log::error!("Failed to run server: {}", e); + } + } + }); + } + Err(e) => { + log::error!("Failed to accept new connection: {}", e); + } + } + } + } } Ok(())