server: refactor server logic into supervisor + session handler
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -2134,6 +2134,7 @@ dependencies = [
|
|||||||
"bytes",
|
"bytes",
|
||||||
"futures-core",
|
"futures-core",
|
||||||
"futures-sink",
|
"futures-sink",
|
||||||
|
"futures-util",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
"tokio",
|
"tokio",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ systemd-journal-logger = "2.2.2"
|
|||||||
tokio = { version = "1.48.0", features = ["rt-multi-thread", "macros"] }
|
tokio = { version = "1.48.0", features = ["rt-multi-thread", "macros"] }
|
||||||
tokio-serde = { version = "0.9.0", features = ["bincode"] }
|
tokio-serde = { version = "0.9.0", features = ["bincode"] }
|
||||||
tokio-stream = "0.1.17"
|
tokio-stream = "0.1.17"
|
||||||
tokio-util = { version = "0.7.17", features = ["codec"] }
|
tokio-util = { version = "0.7.17", features = ["codec", "rt"] }
|
||||||
toml = "0.9.8"
|
toml = "0.9.8"
|
||||||
uuid = { version = "1.18.1", features = ["v4"] }
|
uuid = { version = "1.18.1", features = ["v4"] }
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ use crate::{
|
|||||||
core::common::{
|
core::common::{
|
||||||
DEFAULT_CONFIG_PATH, DEFAULT_SOCKET_PATH, UnixUser, executable_is_suid_or_sgid,
|
DEFAULT_CONFIG_PATH, DEFAULT_SOCKET_PATH, UnixUser, executable_is_suid_or_sgid,
|
||||||
},
|
},
|
||||||
server::{config::read_config_from_path, server_loop::handle_requests_for_single_session},
|
server::{config::read_config_from_path, session_handler},
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Determine whether we will make a connection to an external server
|
/// Determine whether we will make a connection to an external server
|
||||||
@@ -223,7 +223,7 @@ fn run_forked_server(
|
|||||||
.unwrap()
|
.unwrap()
|
||||||
.block_on(async {
|
.block_on(async {
|
||||||
let socket = TokioUnixStream::from_std(server_socket)?;
|
let socket = TokioUnixStream::from_std(server_socket)?;
|
||||||
handle_requests_for_single_session(socket, &unix_user, &config).await?;
|
session_handler::session_handler(socket, &unix_user, &config).await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@@ -147,7 +147,6 @@ fn handle_server_command(args: &Args) -> anyhow::Result<Option<()>> {
|
|||||||
"The executable should not be SUID or SGID when running the server manually"
|
"The executable should not be SUID or SGID when running the server manually"
|
||||||
);
|
);
|
||||||
tokio_start_server(
|
tokio_start_server(
|
||||||
args.server_socket_path.to_owned(),
|
|
||||||
args.config.to_owned(),
|
args.config.to_owned(),
|
||||||
args.verbose.to_owned(),
|
args.verbose.to_owned(),
|
||||||
command.to_owned(),
|
command.to_owned(),
|
||||||
@@ -191,7 +190,6 @@ fn handle_generate_completions_command(args: &Args) -> anyhow::Result<Option<()>
|
|||||||
|
|
||||||
/// Start a long-lived server using Tokio.
|
/// Start a long-lived server using Tokio.
|
||||||
fn tokio_start_server(
|
fn tokio_start_server(
|
||||||
server_socket_path: Option<PathBuf>,
|
|
||||||
config_path: Option<PathBuf>,
|
config_path: Option<PathBuf>,
|
||||||
verbosity: Verbosity,
|
verbosity: Verbosity,
|
||||||
args: ServerArgs,
|
args: ServerArgs,
|
||||||
@@ -200,9 +198,7 @@ fn tokio_start_server(
|
|||||||
.enable_all()
|
.enable_all()
|
||||||
.build()
|
.build()
|
||||||
.context("Failed to start Tokio runtime")?
|
.context("Failed to start Tokio runtime")?
|
||||||
.block_on(async {
|
.block_on(async { server::command::handle_command(config_path, verbosity, args).await })
|
||||||
server::command::handle_command(server_socket_path, config_path, verbosity, args).await
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Run the given commmand (from the client side) using Tokio.
|
/// Run the given commmand (from the client side) using Tokio.
|
||||||
|
|||||||
@@ -2,5 +2,7 @@ pub mod command;
|
|||||||
mod common;
|
mod common;
|
||||||
pub mod config;
|
pub mod config;
|
||||||
pub mod input_sanitization;
|
pub mod input_sanitization;
|
||||||
pub mod server_loop;
|
// pub mod server_loop;
|
||||||
|
pub mod session_handler;
|
||||||
pub mod sql;
|
pub mod sql;
|
||||||
|
pub mod supervisor;
|
||||||
|
|||||||
@@ -7,10 +7,11 @@ use systemd_journal_logger::JournalLog;
|
|||||||
|
|
||||||
use crate::server::{
|
use crate::server::{
|
||||||
config::{ServerConfigArgs, read_config_from_path_with_arg_overrides},
|
config::{ServerConfigArgs, read_config_from_path_with_arg_overrides},
|
||||||
server_loop::{
|
supervisor::Supervisor,
|
||||||
listen_for_incoming_connections_with_socket_path,
|
// server_loop::{
|
||||||
listen_for_incoming_connections_with_systemd_socket,
|
// listen_for_incoming_connections_with_socket_path,
|
||||||
},
|
// listen_for_incoming_connections_with_systemd_socket,
|
||||||
|
// },
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(Parser, Debug, Clone)]
|
#[derive(Parser, Debug, Clone)]
|
||||||
@@ -46,7 +47,6 @@ const LOG_LEVEL_WARNING: &str = r#"
|
|||||||
"#;
|
"#;
|
||||||
|
|
||||||
pub async fn handle_command(
|
pub async fn handle_command(
|
||||||
socket_path: Option<PathBuf>,
|
|
||||||
config_path: Option<PathBuf>,
|
config_path: Option<PathBuf>,
|
||||||
verbosity: Verbosity,
|
verbosity: Verbosity,
|
||||||
args: ServerArgs,
|
args: ServerArgs,
|
||||||
@@ -78,7 +78,7 @@ pub async fn handle_command(
|
|||||||
log::info!("Running in systemd mode");
|
log::info!("Running in systemd mode");
|
||||||
}
|
}
|
||||||
|
|
||||||
start_watchdog_thread_if_enabled();
|
// start_watchdog_thread_if_enabled();
|
||||||
} else {
|
} else {
|
||||||
env_logger::Builder::new()
|
env_logger::Builder::new()
|
||||||
.filter_level(verbosity.log_level_filter())
|
.filter_level(verbosity.log_level_filter())
|
||||||
@@ -90,9 +90,7 @@ pub async fn handle_command(
|
|||||||
let config = read_config_from_path_with_arg_overrides(config_path, args.config_overrides)?;
|
let config = read_config_from_path_with_arg_overrides(config_path, args.config_overrides)?;
|
||||||
|
|
||||||
match args.subcmd {
|
match args.subcmd {
|
||||||
ServerCommand::Listen => {
|
ServerCommand::Listen => Supervisor::new(config, systemd_mode).await?.run().await,
|
||||||
listen_for_incoming_connections_with_socket_path(socket_path, config).await
|
|
||||||
}
|
|
||||||
ServerCommand::SocketActivate => {
|
ServerCommand::SocketActivate => {
|
||||||
if !args.systemd {
|
if !args.systemd {
|
||||||
anyhow::bail!(concat!(
|
anyhow::bail!(concat!(
|
||||||
@@ -101,33 +99,33 @@ pub async fn handle_command(
|
|||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
listen_for_incoming_connections_with_systemd_socket(config).await
|
Supervisor::new(config, systemd_mode).await?.run().await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn start_watchdog_thread_if_enabled() {
|
// fn start_watchdog_thread_if_enabled() {
|
||||||
let mut micro_seconds: u64 = 0;
|
// let mut micro_seconds: u64 = 0;
|
||||||
let watchdog_enabled = sd_notify::watchdog_enabled(false, &mut micro_seconds);
|
// let watchdog_enabled = sd_notify::watchdog_enabled(false, &mut micro_seconds);
|
||||||
|
|
||||||
if watchdog_enabled {
|
// if watchdog_enabled {
|
||||||
micro_seconds = micro_seconds.max(2_000_000).div_ceil(2);
|
// micro_seconds = micro_seconds.max(2_000_000).div_ceil(2);
|
||||||
|
|
||||||
tokio::spawn(async move {
|
// tokio::spawn(async move {
|
||||||
log::debug!(
|
// log::debug!(
|
||||||
"Starting systemd watchdog thread with {} millisecond interval",
|
// "Starting systemd watchdog thread with {} millisecond interval",
|
||||||
micro_seconds.div_ceil(1000)
|
// micro_seconds.div_ceil(1000)
|
||||||
);
|
// );
|
||||||
loop {
|
// loop {
|
||||||
tokio::time::sleep(tokio::time::Duration::from_micros(micro_seconds)).await;
|
// tokio::time::sleep(tokio::time::Duration::from_micros(micro_seconds)).await;
|
||||||
if let Err(err) = sd_notify::notify(false, &[sd_notify::NotifyState::Watchdog]) {
|
// if let Err(err) = sd_notify::notify(false, &[sd_notify::NotifyState::Watchdog]) {
|
||||||
log::warn!("Failed to notify systemd watchdog: {}", err);
|
// log::warn!("Failed to notify systemd watchdog: {}", err);
|
||||||
} else {
|
// } else {
|
||||||
log::trace!("Ping sent to systemd watchdog");
|
// log::trace!("Ping sent to systemd watchdog");
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
});
|
// });
|
||||||
} else {
|
// } else {
|
||||||
log::debug!("Systemd watchdog not enabled, skipping watchdog thread");
|
// log::debug!("Systemd watchdog not enabled, skipping watchdog thread");
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|||||||
@@ -8,13 +8,18 @@ use sqlx::{ConnectOptions, MySqlConnection, mysql::MySqlConnectOptions};
|
|||||||
use crate::core::common::DEFAULT_CONFIG_PATH;
|
use crate::core::common::DEFAULT_CONFIG_PATH;
|
||||||
|
|
||||||
pub const DEFAULT_PORT: u16 = 3306;
|
pub const DEFAULT_PORT: u16 = 3306;
|
||||||
pub const DEFAULT_TIMEOUT: u64 = 2;
|
fn default_mysql_port() -> u16 {
|
||||||
|
DEFAULT_PORT
|
||||||
|
}
|
||||||
|
|
||||||
|
pub const DEFAULT_TIMEOUT: u64 = 2;
|
||||||
|
fn default_mysql_timeout() -> u64 {
|
||||||
|
DEFAULT_TIMEOUT
|
||||||
|
}
|
||||||
|
|
||||||
// NOTE: this might look empty now, and the extra wrapping for the mysql
|
|
||||||
// config seems unnecessary, but it will be useful later when we
|
|
||||||
// add more configuration options.
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
pub struct ServerConfig {
|
pub struct ServerConfig {
|
||||||
|
pub socket_path: Option<PathBuf>,
|
||||||
pub mysql: MysqlConfig,
|
pub mysql: MysqlConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -23,19 +28,25 @@ pub struct ServerConfig {
|
|||||||
pub struct MysqlConfig {
|
pub struct MysqlConfig {
|
||||||
pub socket_path: Option<PathBuf>,
|
pub socket_path: Option<PathBuf>,
|
||||||
pub host: Option<String>,
|
pub host: Option<String>,
|
||||||
pub port: Option<u16>,
|
#[serde(default = "default_mysql_port")]
|
||||||
|
pub port: u16,
|
||||||
pub username: Option<String>,
|
pub username: Option<String>,
|
||||||
pub password: Option<String>,
|
pub password: Option<String>,
|
||||||
pub password_file: Option<PathBuf>,
|
pub password_file: Option<PathBuf>,
|
||||||
pub timeout: Option<u64>,
|
#[serde(default = "default_mysql_timeout")]
|
||||||
|
pub timeout: u64,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Parser, Debug, Clone)]
|
#[derive(Parser, Debug, Clone)]
|
||||||
pub struct ServerConfigArgs {
|
pub struct ServerConfigArgs {
|
||||||
/// Path to the socket of the MySQL server.
|
/// Path where the server socket should be created.
|
||||||
#[arg(long, value_name = "PATH", global = true)]
|
#[arg(long, value_name = "PATH", global = true)]
|
||||||
socket_path: Option<PathBuf>,
|
socket_path: Option<PathBuf>,
|
||||||
|
|
||||||
|
/// Path to the socket of the MySQL server.
|
||||||
|
#[arg(long, value_name = "PATH", global = true)]
|
||||||
|
mysql_socket_path: Option<PathBuf>,
|
||||||
|
|
||||||
/// Hostname of the MySQL server.
|
/// Hostname of the MySQL server.
|
||||||
#[arg(
|
#[arg(
|
||||||
long,
|
long,
|
||||||
@@ -94,14 +105,15 @@ pub fn read_config_from_path_with_arg_overrides(
|
|||||||
};
|
};
|
||||||
|
|
||||||
Ok(ServerConfig {
|
Ok(ServerConfig {
|
||||||
|
socket_path: args.socket_path.or(config.socket_path),
|
||||||
mysql: MysqlConfig {
|
mysql: MysqlConfig {
|
||||||
socket_path: args.socket_path.or(mysql.socket_path),
|
socket_path: args.mysql_socket_path.or(mysql.socket_path),
|
||||||
host: args.mysql_host.or(mysql.host),
|
host: args.mysql_host.or(mysql.host),
|
||||||
port: args.mysql_port.or(mysql.port),
|
port: args.mysql_port.unwrap_or(mysql.port),
|
||||||
username: args.mysql_user.or(mysql.username.to_owned()),
|
username: args.mysql_user.or(mysql.username.to_owned()),
|
||||||
password,
|
password,
|
||||||
password_file: args.mysql_password_file.or(mysql.password_file),
|
password_file: args.mysql_password_file.or(mysql.password_file),
|
||||||
timeout: args.mysql_connect_timeout.or(mysql.timeout),
|
timeout: args.mysql_connect_timeout.unwrap_or(mysql.timeout),
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -151,17 +163,12 @@ pub async fn create_mysql_connection_from_config(
|
|||||||
mysql_options = mysql_options.socket(socket_path);
|
mysql_options = mysql_options.socket(socket_path);
|
||||||
} else if let Some(host) = &config.host {
|
} else if let Some(host) = &config.host {
|
||||||
mysql_options = mysql_options.host(host);
|
mysql_options = mysql_options.host(host);
|
||||||
mysql_options = mysql_options.port(config.port.unwrap_or(DEFAULT_PORT));
|
mysql_options = mysql_options.port(config.port);
|
||||||
} else {
|
} else {
|
||||||
anyhow::bail!("No MySQL host or socket path provided");
|
anyhow::bail!("No MySQL host or socket path provided");
|
||||||
}
|
}
|
||||||
|
|
||||||
match tokio::time::timeout(
|
match tokio::time::timeout(Duration::from_secs(config.timeout), mysql_options.connect()).await {
|
||||||
Duration::from_secs(config.timeout.unwrap_or(DEFAULT_TIMEOUT)),
|
|
||||||
mysql_options.connect(),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(connection) => connection.context("Failed to connect to the database"),
|
Ok(connection) => connection.context("Failed to connect to the database"),
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
Err(anyhow!("Timed out after 2 seconds")).context("Failed to connect to the database")
|
Err(anyhow!("Timed out after 2 seconds")).context("Failed to connect to the database")
|
||||||
|
|||||||
@@ -1,19 +1,8 @@
|
|||||||
use std::{
|
use std::collections::BTreeSet;
|
||||||
collections::BTreeSet,
|
|
||||||
fs,
|
|
||||||
os::unix::{io::FromRawFd, net::UnixListener as StdUnixListener},
|
|
||||||
path::PathBuf,
|
|
||||||
sync::Arc,
|
|
||||||
time::Duration,
|
|
||||||
};
|
|
||||||
|
|
||||||
use anyhow::Context;
|
|
||||||
use futures_util::{SinkExt, StreamExt};
|
use futures_util::{SinkExt, StreamExt};
|
||||||
use indoc::concatdoc;
|
use indoc::concatdoc;
|
||||||
use tokio::{
|
use tokio::net::UnixStream;
|
||||||
net::{UnixListener as TokioUnixListener, UnixStream as TokioUnixStream},
|
|
||||||
time::interval,
|
|
||||||
};
|
|
||||||
|
|
||||||
use sqlx::MySqlConnection;
|
use sqlx::MySqlConnection;
|
||||||
use sqlx::prelude::*;
|
use sqlx::prelude::*;
|
||||||
@@ -22,7 +11,7 @@ use crate::core::protocol::SetPasswordError;
|
|||||||
use crate::server::sql::database_operations::list_databases;
|
use crate::server::sql::database_operations::list_databases;
|
||||||
use crate::{
|
use crate::{
|
||||||
core::{
|
core::{
|
||||||
common::{DEFAULT_SOCKET_PATH, UnixUser},
|
common::UnixUser,
|
||||||
protocol::{
|
protocol::{
|
||||||
Request, Response, ServerToClientMessageStream, create_server_to_client_message_stream,
|
Request, Response, ServerToClientMessageStream, create_server_to_client_message_stream,
|
||||||
},
|
},
|
||||||
@@ -43,141 +32,10 @@ use crate::{
|
|||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
// TODO: consider using a connection pool
|
// TODO: don't use database connection unless necessary.
|
||||||
|
|
||||||
pub async fn listen_for_incoming_connections_with_socket_path(
|
pub async fn session_handler(
|
||||||
socket_path: Option<PathBuf>,
|
socket: UnixStream,
|
||||||
config: ServerConfig,
|
|
||||||
) -> anyhow::Result<()> {
|
|
||||||
let socket_path = socket_path.unwrap_or(PathBuf::from(DEFAULT_SOCKET_PATH));
|
|
||||||
|
|
||||||
let parent_directory = socket_path.parent().unwrap();
|
|
||||||
if !parent_directory.exists() {
|
|
||||||
log::debug!("Creating directory {:?}", parent_directory);
|
|
||||||
fs::create_dir_all(parent_directory)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
log::info!("Listening on socket {:?}", socket_path);
|
|
||||||
|
|
||||||
match fs::remove_file(socket_path.as_path()) {
|
|
||||||
Ok(_) => {}
|
|
||||||
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
|
|
||||||
Err(e) => return Err(e.into()),
|
|
||||||
}
|
|
||||||
|
|
||||||
let listener = TokioUnixListener::bind(socket_path)?;
|
|
||||||
|
|
||||||
listen_for_incoming_connections_with_listener(listener, config).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn listen_for_incoming_connections_with_systemd_socket(
|
|
||||||
config: ServerConfig,
|
|
||||||
) -> anyhow::Result<()> {
|
|
||||||
let fd = sd_notify::listen_fds()
|
|
||||||
.context("Failed to get file descriptors from systemd")?
|
|
||||||
.next()
|
|
||||||
.context("No file descriptors received from systemd")?;
|
|
||||||
|
|
||||||
debug_assert!(fd == 3, "Unexpected file descriptor from systemd: {}", fd);
|
|
||||||
|
|
||||||
log::debug!(
|
|
||||||
"Received file descriptor from systemd with id: '{}', assuming socket",
|
|
||||||
fd
|
|
||||||
);
|
|
||||||
|
|
||||||
let std_unix_listener = unsafe { StdUnixListener::from_raw_fd(fd) };
|
|
||||||
let listener = TokioUnixListener::from_std(std_unix_listener)?;
|
|
||||||
listen_for_incoming_connections_with_listener(listener, config).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn listen_for_incoming_connections_with_listener(
|
|
||||||
listener: TokioUnixListener,
|
|
||||||
config: ServerConfig,
|
|
||||||
) -> anyhow::Result<()> {
|
|
||||||
let connection_counter = Arc::new(());
|
|
||||||
let connection_counter_for_log = Arc::clone(&connection_counter);
|
|
||||||
tokio::spawn(async move {
|
|
||||||
let mut interval = interval(Duration::from_secs(1));
|
|
||||||
loop {
|
|
||||||
interval.tick().await;
|
|
||||||
let count = Arc::strong_count(&connection_counter_for_log) - 2;
|
|
||||||
let message = if count > 0 {
|
|
||||||
format!("Handling {} connections", count)
|
|
||||||
} else {
|
|
||||||
"Waiting for connections".to_string()
|
|
||||||
};
|
|
||||||
sd_notify::notify(false, &[sd_notify::NotifyState::Status(message.as_str())]).ok();
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
sd_notify::notify(false, &[sd_notify::NotifyState::Ready]).ok();
|
|
||||||
|
|
||||||
while let Ok((conn, _addr)) = listener.accept().await {
|
|
||||||
let uid = match conn.peer_cred() {
|
|
||||||
Ok(cred) => cred.uid(),
|
|
||||||
Err(e) => {
|
|
||||||
log::error!("Failed to get peer credentials from socket: {}", e);
|
|
||||||
let mut message_stream = create_server_to_client_message_stream(conn);
|
|
||||||
message_stream
|
|
||||||
.send(Response::Error(
|
|
||||||
(concatdoc! {
|
|
||||||
"Server failed to get peer credentials from socket\n",
|
|
||||||
"Please check the server logs or contact the system administrators"
|
|
||||||
})
|
|
||||||
.to_string(),
|
|
||||||
))
|
|
||||||
.await
|
|
||||||
.ok();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let _connection_counter_guard = Arc::clone(&connection_counter);
|
|
||||||
|
|
||||||
log::debug!("Accepted connection from uid {}", uid);
|
|
||||||
|
|
||||||
let unix_user = match UnixUser::from_uid(uid) {
|
|
||||||
Ok(user) => user,
|
|
||||||
Err(e) => {
|
|
||||||
log::error!("Failed to get username from uid: {}", e);
|
|
||||||
let mut message_stream = create_server_to_client_message_stream(conn);
|
|
||||||
message_stream
|
|
||||||
.send(Response::Error(
|
|
||||||
(concatdoc! {
|
|
||||||
"Server failed to get user data from the system\n",
|
|
||||||
"Please check the server logs or contact the system administrators"
|
|
||||||
})
|
|
||||||
.to_string(),
|
|
||||||
))
|
|
||||||
.await
|
|
||||||
.ok();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
log::info!("Accepted connection from {}", unix_user.username);
|
|
||||||
|
|
||||||
match handle_requests_for_single_session(conn, &unix_user, &config).await {
|
|
||||||
Ok(()) => {}
|
|
||||||
Err(e) => {
|
|
||||||
log::error!("Failed to run server: {}", e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn close_or_ignore_db_connection(db_connection: MySqlConnection) {
|
|
||||||
if let Err(e) = db_connection.close().await {
|
|
||||||
log::error!("Failed to close database connection: {}", e);
|
|
||||||
log::error!("{}", e);
|
|
||||||
log::error!("Ignoring...");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn handle_requests_for_single_session(
|
|
||||||
socket: TokioUnixStream,
|
|
||||||
unix_user: &UnixUser,
|
unix_user: &UnixUser,
|
||||||
config: &ServerConfig,
|
config: &ServerConfig,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
@@ -222,12 +80,8 @@ pub async fn handle_requests_for_single_session(
|
|||||||
|
|
||||||
log::debug!("Successfully connected to database");
|
log::debug!("Successfully connected to database");
|
||||||
|
|
||||||
let result = handle_requests_for_single_session_with_db_connection(
|
let result =
|
||||||
message_stream,
|
session_handler_with_db_connection(message_stream, unix_user, &mut db_connection).await;
|
||||||
unix_user,
|
|
||||||
&mut db_connection,
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
close_or_ignore_db_connection(db_connection).await;
|
close_or_ignore_db_connection(db_connection).await;
|
||||||
|
|
||||||
@@ -237,7 +91,7 @@ pub async fn handle_requests_for_single_session(
|
|||||||
// TODO: ensure proper db_connection hygiene for functions that invoke
|
// TODO: ensure proper db_connection hygiene for functions that invoke
|
||||||
// this function
|
// this function
|
||||||
|
|
||||||
async fn handle_requests_for_single_session_with_db_connection(
|
async fn session_handler_with_db_connection(
|
||||||
mut stream: ServerToClientMessageStream,
|
mut stream: ServerToClientMessageStream,
|
||||||
unix_user: &UnixUser,
|
unix_user: &UnixUser,
|
||||||
db_connection: &mut MySqlConnection,
|
db_connection: &mut MySqlConnection,
|
||||||
@@ -360,3 +214,11 @@ async fn handle_requests_for_single_session_with_db_connection(
|
|||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn close_or_ignore_db_connection(db_connection: MySqlConnection) {
|
||||||
|
if let Err(e) = db_connection.close().await {
|
||||||
|
log::error!("Failed to close database connection: {}", e);
|
||||||
|
log::error!("{}", e);
|
||||||
|
log::error!("Ignoring...");
|
||||||
|
}
|
||||||
|
}
|
||||||
317
src/server/supervisor.rs
Normal file
317
src/server/supervisor.rs
Normal file
@@ -0,0 +1,317 @@
|
|||||||
|
use std::{
|
||||||
|
fs,
|
||||||
|
os::{fd::FromRawFd, unix::net::UnixListener as StdUnixListener},
|
||||||
|
path::PathBuf,
|
||||||
|
sync::Arc,
|
||||||
|
time::Duration,
|
||||||
|
};
|
||||||
|
|
||||||
|
use anyhow::{Context, anyhow};
|
||||||
|
use futures_util::SinkExt;
|
||||||
|
use indoc::concatdoc;
|
||||||
|
use sqlx::{MySqlPool, mysql::MySqlConnectOptions, prelude::*};
|
||||||
|
use tokio::{net::UnixListener as TokioUnixListener, task::JoinHandle, time::interval};
|
||||||
|
use tokio_util::task::TaskTracker;
|
||||||
|
// use tokio_util::sync::CancellationToken;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
core::{
|
||||||
|
common::UnixUser,
|
||||||
|
protocol::{Response, create_server_to_client_message_stream},
|
||||||
|
},
|
||||||
|
server::{
|
||||||
|
config::{MysqlConfig, ServerConfig},
|
||||||
|
session_handler::session_handler,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
// TODO: implement graceful shutdown and graceful restarts
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub struct Supervisor {
|
||||||
|
config: ServerConfig,
|
||||||
|
systemd_mode: bool,
|
||||||
|
|
||||||
|
// sighup_cancel_token: CancellationToken,
|
||||||
|
// sigterm_cancel_token: CancellationToken,
|
||||||
|
// signal_handler_task: JoinHandle<()>,
|
||||||
|
db_connection_pool: MySqlPool,
|
||||||
|
// listener: TokioUnixListener,
|
||||||
|
listener_task: JoinHandle<anyhow::Result<()>>,
|
||||||
|
handler_task_tracker: TaskTracker,
|
||||||
|
|
||||||
|
watchdog_timeout: Option<Duration>,
|
||||||
|
systemd_watchdog_task: Option<JoinHandle<()>>,
|
||||||
|
|
||||||
|
connection_counter: std::sync::Arc<()>,
|
||||||
|
status_notifier_task: Option<JoinHandle<()>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Supervisor {
|
||||||
|
pub async fn new(config: ServerConfig, systemd_mode: bool) -> anyhow::Result<Self> {
|
||||||
|
let mut watchdog_duration = None;
|
||||||
|
let mut watchdog_micro_seconds = 0;
|
||||||
|
let watchdog_task =
|
||||||
|
if systemd_mode && sd_notify::watchdog_enabled(true, &mut watchdog_micro_seconds) {
|
||||||
|
watchdog_duration = Some(Duration::from_micros(watchdog_micro_seconds));
|
||||||
|
log::debug!(
|
||||||
|
"Systemd watchdog enabled with {} millisecond interval",
|
||||||
|
watchdog_micro_seconds.div_ceil(1000),
|
||||||
|
);
|
||||||
|
Some(spawn_watchdog_task(watchdog_duration.unwrap()))
|
||||||
|
} else {
|
||||||
|
log::debug!("Systemd watchdog not enabled, skipping watchdog thread");
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
let db_connection_pool = create_db_connection_pool(&config.mysql).await?;
|
||||||
|
|
||||||
|
let connection_counter = Arc::new(());
|
||||||
|
let status_notifier_task = if systemd_mode {
|
||||||
|
Some(spawn_status_notifier_task(connection_counter.clone()))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
// TODO: try to detech systemd socket before using the provided socket path
|
||||||
|
let listener = match config.socket_path {
|
||||||
|
Some(ref path) => create_unix_listener_with_socket_path(path.clone()).await?,
|
||||||
|
None => create_unix_listener_with_systemd_socket().await?,
|
||||||
|
};
|
||||||
|
|
||||||
|
let listener_task = {
|
||||||
|
let connection_counter = connection_counter.clone();
|
||||||
|
let config_clone = config.clone();
|
||||||
|
tokio::spawn(spawn_listener_task(
|
||||||
|
listener,
|
||||||
|
config_clone,
|
||||||
|
connection_counter,
|
||||||
|
))
|
||||||
|
};
|
||||||
|
|
||||||
|
// let sighup_cancel_token = CancellationToken::new();
|
||||||
|
// let sigterm_cancel_token = CancellationToken::new();
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
config,
|
||||||
|
systemd_mode,
|
||||||
|
// sighup_cancel_token,
|
||||||
|
// sigterm_cancel_token,
|
||||||
|
// signal_handler_task,
|
||||||
|
db_connection_pool,
|
||||||
|
// listener,
|
||||||
|
listener_task,
|
||||||
|
handler_task_tracker: TaskTracker::new(),
|
||||||
|
watchdog_timeout: watchdog_duration,
|
||||||
|
systemd_watchdog_task: watchdog_task,
|
||||||
|
connection_counter,
|
||||||
|
status_notifier_task,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run(self) -> anyhow::Result<()> {
|
||||||
|
self.listener_task.await?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn spawn_watchdog_task(duration: Duration) -> JoinHandle<()> {
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let mut interval = interval(duration.div_f32(2.0));
|
||||||
|
log::debug!(
|
||||||
|
"Starting systemd watchdog task, pinging every {} milliseconds",
|
||||||
|
duration.div_f32(2.0).as_millis()
|
||||||
|
);
|
||||||
|
loop {
|
||||||
|
interval.tick().await;
|
||||||
|
if let Err(err) = sd_notify::notify(false, &[sd_notify::NotifyState::Watchdog]) {
|
||||||
|
log::warn!("Failed to notify systemd watchdog: {}", err);
|
||||||
|
} else {
|
||||||
|
log::trace!("Ping sent to systemd watchdog");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn spawn_status_notifier_task(connection_counter: std::sync::Arc<()>) -> JoinHandle<()> {
|
||||||
|
const NON_CONNECTION_ARC_COUNT: usize = 4;
|
||||||
|
const STATUS_UPDATE_INTERVAL_SECS: Duration = Duration::from_secs(1);
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let mut interval = interval(STATUS_UPDATE_INTERVAL_SECS);
|
||||||
|
loop {
|
||||||
|
interval.tick().await;
|
||||||
|
log::trace!("Updating systemd status notification");
|
||||||
|
let count = Arc::strong_count(&connection_counter) - NON_CONNECTION_ARC_COUNT;
|
||||||
|
let message = if count > 0 {
|
||||||
|
format!("Handling {} connections", count)
|
||||||
|
} else {
|
||||||
|
"Waiting for connections".to_string()
|
||||||
|
};
|
||||||
|
sd_notify::notify(false, &[sd_notify::NotifyState::Status(message.as_str())]).ok();
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn create_unix_listener_with_socket_path(
|
||||||
|
socket_path: PathBuf,
|
||||||
|
) -> anyhow::Result<TokioUnixListener> {
|
||||||
|
let parent_directory = socket_path.parent().unwrap();
|
||||||
|
if !parent_directory.exists() {
|
||||||
|
log::debug!("Creating directory {:?}", parent_directory);
|
||||||
|
fs::create_dir_all(parent_directory)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
log::info!("Listening on socket {:?}", socket_path);
|
||||||
|
|
||||||
|
match fs::remove_file(socket_path.as_path()) {
|
||||||
|
Ok(_) => {}
|
||||||
|
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
|
||||||
|
Err(e) => return Err(e.into()),
|
||||||
|
}
|
||||||
|
|
||||||
|
let listener = TokioUnixListener::bind(socket_path)?;
|
||||||
|
|
||||||
|
Ok(listener)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn create_unix_listener_with_systemd_socket() -> anyhow::Result<TokioUnixListener> {
|
||||||
|
let fd = sd_notify::listen_fds()
|
||||||
|
.context("Failed to get file descriptors from systemd")?
|
||||||
|
.next()
|
||||||
|
.context("No file descriptors received from systemd")?;
|
||||||
|
|
||||||
|
debug_assert!(fd == 3, "Unexpected file descriptor from systemd: {}", fd);
|
||||||
|
|
||||||
|
log::debug!(
|
||||||
|
"Received file descriptor from systemd with id: '{}', assuming socket",
|
||||||
|
fd
|
||||||
|
);
|
||||||
|
|
||||||
|
let std_unix_listener = unsafe { StdUnixListener::from_raw_fd(fd) };
|
||||||
|
let listener = TokioUnixListener::from_std(std_unix_listener)?;
|
||||||
|
Ok(listener)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn create_db_connection_pool(config: &MysqlConfig) -> anyhow::Result<MySqlPool> {
|
||||||
|
let mut mysql_options = MySqlConnectOptions::new()
|
||||||
|
.database("mysql")
|
||||||
|
.log_statements(log::LevelFilter::Trace);
|
||||||
|
|
||||||
|
if let Some(username) = config.username.as_ref() {
|
||||||
|
mysql_options = mysql_options.username(username);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(password) = config.password.as_ref() {
|
||||||
|
mysql_options = mysql_options.password(password);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(socket_path) = config.socket_path.as_ref() {
|
||||||
|
mysql_options = mysql_options.socket(socket_path);
|
||||||
|
} else if let Some(host) = config.host.as_ref() {
|
||||||
|
mysql_options = mysql_options.host(host);
|
||||||
|
mysql_options = mysql_options.port(config.port);
|
||||||
|
} else {
|
||||||
|
anyhow::bail!("No MySQL host or socket path provided");
|
||||||
|
}
|
||||||
|
|
||||||
|
match tokio::time::timeout(
|
||||||
|
Duration::from_secs(config.timeout),
|
||||||
|
MySqlPool::connect_with(mysql_options),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(connection) => connection.context("Failed to connect to the database"),
|
||||||
|
Err(_) => Err(anyhow!("Timed out after {} seconds", config.timeout))
|
||||||
|
.context("Failed to connect to the database"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// fn spawn_signal_handler_task(
|
||||||
|
// sighup_token: CancellationToken,
|
||||||
|
// sigterm_token: CancellationToken,
|
||||||
|
// ) -> JoinHandle<()> {
|
||||||
|
// tokio::spawn(async move {
|
||||||
|
// let mut sighup_stream = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::hangup())
|
||||||
|
// .expect("Failed to set up SIGHUP handler");
|
||||||
|
// let mut sigterm_stream = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
|
||||||
|
// .expect("Failed to set up SIGTERM handler");
|
||||||
|
|
||||||
|
// loop {
|
||||||
|
// tokio::select! {
|
||||||
|
// _ = sighup_stream.recv() => {
|
||||||
|
// log::info!("Received SIGHUP signal");
|
||||||
|
// sighup_token.cancel();
|
||||||
|
// }
|
||||||
|
// _ = sigterm_stream.recv() => {
|
||||||
|
// log::info!("Received SIGTERM signal");
|
||||||
|
// sigterm_token.cancel();
|
||||||
|
// break;
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// })
|
||||||
|
// }
|
||||||
|
|
||||||
|
async fn spawn_listener_task(
|
||||||
|
listener: TokioUnixListener,
|
||||||
|
config: ServerConfig,
|
||||||
|
connection_counter: Arc<()>,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
sd_notify::notify(false, &[sd_notify::NotifyState::Ready])?;
|
||||||
|
|
||||||
|
while let Ok((conn, _addr)) = listener.accept().await {
|
||||||
|
log::debug!("Got new connection");
|
||||||
|
|
||||||
|
let uid = match conn.peer_cred() {
|
||||||
|
Ok(cred) => cred.uid(),
|
||||||
|
Err(e) => {
|
||||||
|
log::error!("Failed to get peer credentials from socket: {}", e);
|
||||||
|
let mut message_stream = create_server_to_client_message_stream(conn);
|
||||||
|
message_stream
|
||||||
|
.send(Response::Error(
|
||||||
|
(concatdoc! {
|
||||||
|
"Server failed to get peer credentials from socket\n",
|
||||||
|
"Please check the server logs or contact the system administrators"
|
||||||
|
})
|
||||||
|
.to_string(),
|
||||||
|
))
|
||||||
|
.await
|
||||||
|
.ok();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
log::debug!("Validated peer UID: {}", uid);
|
||||||
|
|
||||||
|
let _connection_counter_guard = Arc::clone(&connection_counter);
|
||||||
|
|
||||||
|
let unix_user = match UnixUser::from_uid(uid) {
|
||||||
|
Ok(user) => user,
|
||||||
|
Err(e) => {
|
||||||
|
log::error!("Failed to get username from uid: {}", e);
|
||||||
|
let mut message_stream = create_server_to_client_message_stream(conn);
|
||||||
|
message_stream
|
||||||
|
.send(Response::Error(
|
||||||
|
(concatdoc! {
|
||||||
|
"Server failed to get user data from the system\n",
|
||||||
|
"Please check the server logs or contact the system administrators"
|
||||||
|
})
|
||||||
|
.to_string(),
|
||||||
|
))
|
||||||
|
.await
|
||||||
|
.ok();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
log::info!("Accepted connection from UNIX user: {}", unix_user.username);
|
||||||
|
|
||||||
|
match session_handler(conn, &unix_user, &config).await {
|
||||||
|
Ok(()) => {}
|
||||||
|
Err(e) => {
|
||||||
|
log::error!("Failed to run server: {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user