Files
muscl/src/core/bootstrap.rs
h7x4 4c82da390f
All checks were successful
Build and test / check (push) Successful in 1m54s
Build and test / build (push) Successful in 3m10s
Build and test / test (push) Successful in 3m30s
Build and test / check-license (push) Successful in 7m25s
Build and test / docs (push) Successful in 5m26s
server: determine sql server variant, fix lock-user,unlock-user
2025-12-14 03:30:40 +09:00

306 lines
11 KiB
Rust

use std::{fs, path::PathBuf, sync::Arc, time::Duration};
use anyhow::{Context, anyhow};
use clap_verbosity_flag::{InfoLevel, Verbosity};
use nix::libc::{EXIT_SUCCESS, exit};
use sqlx::mysql::MySqlPoolOptions;
use std::os::unix::net::UnixStream as StdUnixStream;
use tokio::{net::UnixStream as TokioUnixStream, sync::RwLock};
use tracing_subscriber::prelude::*;
use crate::{
core::common::{
DEFAULT_CONFIG_PATH, DEFAULT_SOCKET_PATH, UnixUser, executing_in_suid_sgid_mode,
},
server::{
config::{MysqlConfig, ServerConfig},
landlock::landlock_restrict_server,
session_handler,
},
};
/// Determine whether we will make a connection to an external server
/// or start an internal server with elevated privileges.
///
/// If neither is feasible, an error is returned.
fn will_connect_to_external_server(
server_socket_path: Option<&PathBuf>,
// This parameter is only used in suid-sgid-mode
#[allow(unused_variables)] config_path: Option<&PathBuf>,
) -> anyhow::Result<bool> {
if server_socket_path.is_some() {
return Ok(true);
}
#[cfg(feature = "suid-sgid-mode")]
if config_path.is_some() {
return Ok(false);
}
if fs::metadata(DEFAULT_SOCKET_PATH).is_ok() {
return Ok(true);
}
#[cfg(feature = "suid-sgid-mode")]
if fs::metadata(DEFAULT_CONFIG_PATH).is_ok() {
return Ok(false);
}
#[cfg(feature = "suid-sgid-mode")]
anyhow::bail!("No socket path or config path provided, and no default socket or config found");
#[cfg(not(feature = "suid-sgid-mode"))]
anyhow::bail!("No socket path provided, and no default socket found");
}
/// Establish the connection to the server, dropping privileges if needed.
///
/// There are two ways this can play out:
///
/// 1. A socket path was provided, or a socket exists in the default
///    location: the function connects to that socket, so that the external
///    server can authenticate the client and learn its uid.
///
/// 2. A config path was provided, or a config exists in the default
///    location: the executable is assumed to be setuid or setgid, and a
///    child process is forked to run the server with that config. The
///    server exits by itself when it is done; this function only returns
///    in the client, handing back the client's end of the socket.
///
/// If neither option is available, an error is returned.
///
/// Logging is also initialized here rather than by the caller, because in
/// the internal-server case privileges have to be dropped before the tracing
/// subscriber may be installed.
///
/// # Panics
///
/// Panics if an external server was selected while the executable is running
/// in SUID/SGID mode.
///
/// **WARNING:** This function may be run with elevated privileges.
pub fn bootstrap_server_connection_and_drop_privileges(
    server_socket_path: Option<PathBuf>,
    config: Option<PathBuf>,
    verbose: Verbosity<InfoLevel>,
) -> anyhow::Result<StdUnixStream> {
    // Installs the process-wide tracing subscriber. Shared by both branches
    // below, which differ only in *when* they are allowed to call it.
    let install_tracing_subscriber = || -> anyhow::Result<()> {
        let fmt_layer = tracing_subscriber::fmt::layer()
            .with_line_number(cfg!(debug_assertions))
            .with_target(cfg!(debug_assertions))
            .with_thread_ids(false)
            .with_thread_names(false);
        let subscriber = tracing_subscriber::Registry::default()
            .with(verbose.tracing_level_filter())
            .with(fmt_layer);
        tracing::subscriber::set_global_default(subscriber)
            .context("Failed to set global default tracing subscriber")
    };

    let use_external_server =
        will_connect_to_external_server(server_socket_path.as_ref(), config.as_ref())?;

    if use_external_server {
        assert!(
            !executing_in_suid_sgid_mode()?,
            "The executable should not be SUID or SGID when connecting to an external server"
        );
        install_tracing_subscriber()?;
        return connect_to_external_server(server_socket_path);
    }

    if !cfg!(feature = "suid-sgid-mode") {
        anyhow::bail!("SUID/SGID support is not enabled, cannot start internal server");
    }

    // NOTE: We need to be really careful with the code up until this point,
    //       as we might be running with elevated privileges.
    let server_connection = bootstrap_internal_server_and_drop_privs(config)?;
    install_tracing_subscriber()?;
    Ok(server_connection)
}
/// Connect to an externally running server over its unix socket.
///
/// The explicitly provided `server_socket_path` is used when present;
/// otherwise the default socket location is used, provided something exists
/// there.
///
/// # Errors
///
/// Returns an error if no socket could be located, or if connecting to it
/// fails. "Not found" and "permission denied" failures get dedicated
/// messages.
fn connect_to_external_server(
    server_socket_path: Option<PathBuf>,
) -> anyhow::Result<StdUnixStream> {
    /// Open the stream and translate I/O failures into user-facing errors.
    fn open_stream(path: impl AsRef<std::path::Path>) -> anyhow::Result<StdUnixStream> {
        StdUnixStream::connect(path).map_err(|e| match e.kind() {
            std::io::ErrorKind::NotFound => anyhow::anyhow!("Socket not found"),
            std::io::ErrorKind::PermissionDenied => anyhow::anyhow!("Permission denied"),
            _ => anyhow::anyhow!("Failed to connect to socket: {}", e),
        })
    }

    // TODO: ensure this is both readable and writable
    match server_socket_path {
        Some(socket_path) => {
            tracing::debug!("Connecting to socket at {:?}", socket_path);
            open_stream(socket_path)
        }
        None if fs::metadata(DEFAULT_SOCKET_PATH).is_ok() => {
            tracing::debug!("Connecting to default socket at {:?}", DEFAULT_SOCKET_PATH);
            open_stream(DEFAULT_SOCKET_PATH)
        }
        None => anyhow::bail!("No socket path provided, and no default socket found"),
    }
}
// TODO: this function is security critical, it should be integration tested
//       in isolation.
/// Drop privileges to the real user and group of the process.
///
/// If the process is not running with elevated privileges, this function
/// is a no-op.
///
/// The group is reset *before* the user. Once `setuid` has given up an
/// effective uid of root, the process no longer has the privilege needed for
/// `setgid` to reset every group id, so doing it the other way around can
/// leave a privileged saved set-group-ID behind.
///
/// # Errors
///
/// Returns an error if either system call fails, or if the effective ids do
/// not match the real ids afterwards.
pub fn drop_privs() -> anyhow::Result<()> {
    tracing::debug!("Dropping privileges");

    let real_uid = nix::unistd::getuid();
    let real_gid = nix::unistd::getgid();

    // Order matters: group first, while we may still be privileged enough
    // to change it, then the user.
    nix::unistd::setgid(real_gid).context("Failed to drop privileges")?;
    nix::unistd::setuid(real_uid).context("Failed to drop privileges")?;

    // The real ids never change, so it is the *effective* ids that have to be
    // verified. This is checked in release builds too, since continuing with
    // elevated privileges would be a security problem.
    anyhow::ensure!(
        nix::unistd::getegid() == real_gid,
        "Failed to drop privileges: effective gid does not match real gid"
    );
    anyhow::ensure!(
        nix::unistd::geteuid() == real_uid,
        "Failed to drop privileges: effective uid does not match real uid"
    );

    tracing::debug!("Privileges dropped successfully");
    Ok(())
}
/// Start the internal server in a forked child process and drop privileges
/// in the calling process.
///
/// The config is taken from `config_path` when provided, and from the default
/// config location otherwise. The executable must be running in SUID/SGID
/// mode; if it is not, the function refuses to start the server.
///
/// Once the SUID/SGID check has passed, privileges are dropped on every
/// path out of this function, including when the config is missing or when
/// starting the server fails.
///
/// # Errors
///
/// Returns an error if no config can be found, if the executable is not
/// SUID/SGID, if the server could not be started, or if dropping privileges
/// fails.
fn bootstrap_internal_server_and_drop_privs(
    config_path: Option<PathBuf>,
) -> anyhow::Result<StdUnixStream> {
    // Resolve which config to use, remembering whether it was given explicitly.
    let (config_path, explicit) = match config_path {
        Some(path) => (path, true),
        None => {
            let default_path = PathBuf::from(DEFAULT_CONFIG_PATH);
            if fs::metadata(&default_path).is_err() {
                anyhow::bail!("No config path provided, and no default config found");
            }
            (default_path, false)
        }
    };

    if !executing_in_suid_sgid_mode()? {
        anyhow::bail!("Executable is not SUID/SGID - refusing to start internal server");
    }

    // An explicitly provided config has not been checked yet; the default one
    // was already checked above.
    let socket = if explicit && fs::metadata(&config_path).is_err() {
        Err(anyhow::anyhow!("Config file not found or not readable"))
    } else {
        if explicit {
            tracing::debug!("Starting server with config at {:?}", config_path);
        } else {
            tracing::debug!("Starting server with default config at {:?}", config_path);
        }
        invoke_server_with_config(config_path)
    };

    // Drop privileges before looking at the outcome, so that an error return
    // never leaves the caller running with elevated privileges.
    drop_privs()?;
    socket
}
// TODO: we should somehow ensure that the forked process is killed on completion,
//       just in case the client does not behave properly.
/// Fork a child process to run the server with the provided config.
///
/// A connected pair of unix sockets is created before forking. The parent
/// (the client) gets one end back as the return value, and the child serves
/// the session on the other end. Each process closes the end it does not use,
/// so that closing one side is observable as end-of-file on the other.
///
/// The server will exit silently by itself when it is done, and this function
/// will only return `Ok` for the client with the socket for the server.
///
/// # Errors
///
/// Returns an error if the socket pair cannot be created, the calling user
/// cannot be looked up, or the fork fails. In the child process, an error is
/// returned if the Landlock restrictions cannot be applied or the server
/// fails.
fn invoke_server_with_config(config_path: PathBuf) -> anyhow::Result<StdUnixStream> {
    let (server_socket, client_socket) = StdUnixStream::pair()?;
    let unix_user = UnixUser::from_uid(nix::unistd::getuid().as_raw())?;

    // NOTE(review): `fork` is only sound if the child restricts itself to
    // async-signal-safe operations when other threads exist. This presumably
    // relies on the process still being single-threaded at this point -
    // confirm against the caller.
    match (unsafe { nix::unistd::fork() }).context("Failed to fork")? {
        nix::unistd::ForkResult::Parent { child } => {
            // The parent never serves requests; close its copy of the server end.
            drop(server_socket);
            tracing::debug!("Forked child process with PID {}", child);
            Ok(client_socket)
        }
        nix::unistd::ForkResult::Child => {
            // Close the inherited copy of the client end. While the child holds
            // it open, the server cannot see end-of-file when the client
            // closes its end or exits.
            drop(client_socket);
            tracing::debug!("Running server in child process");
            landlock_restrict_server(Some(config_path.as_path()))
                .context("Failed to apply Landlock restrictions to the server process")?;
            match run_forked_server(config_path, server_socket, unix_user) {
                Err(e) => Err(e),
                // On success the server exits the process instead of returning.
                Ok(_) => unreachable!(),
            }
        }
    }
}
/// Create a database pool that is limited to exactly one connection.
///
/// The connection attempt is bounded by `config.timeout` seconds.
///
/// # Errors
///
/// Returns an error if the connect options cannot be built from `config`, if
/// the connection attempt fails, or if it does not finish within the timeout.
async fn construct_single_connection_mysql_pool(
    config: &MysqlConfig,
) -> anyhow::Result<sqlx::MySqlPool> {
    let connect_options = config.as_mysql_connect_options()?;
    let pool_options = MySqlPoolOptions::new()
        .max_connections(1)
        .min_connections(1);

    config.log_connection_notice();

    let deadline = Duration::from_secs(config.timeout);
    let attempt =
        tokio::time::timeout(deadline, pool_options.connect_with(connect_options)).await;

    // The outer `Err` means the deadline elapsed before the attempt finished.
    let Ok(connection) = attempt else {
        return Err(anyhow!("Timed out after {} seconds", config.timeout))
            .context("Failed to connect to the database");
    };

    connection.context("Failed to connect to the database")
}
/// Run the server in the forked child process.
///
/// Reads the server config, builds a single-threaded tokio runtime, connects
/// to the database, determines whether the database reports itself as
/// MariaDB, and then serves one session on `server_socket` on behalf of
/// `unix_user`.
///
/// On success this function does not return; it exits the process with a
/// success code.
///
/// # Errors
///
/// Returns an error if the config cannot be read, the runtime cannot be
/// built, the database connection or version query fails, or the session
/// handler fails.
fn run_forked_server(
    config_path: PathBuf,
    server_socket: StdUnixStream,
    unix_user: UnixUser,
) -> anyhow::Result<()> {
    let config = ServerConfig::read_config_from_path(&config_path)
        .context("Failed to read server config in forked process")?;

    let result: anyhow::Result<()> = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        // Building the runtime can fail; report it as an error rather than
        // panicking in the forked child.
        .context("Failed to build tokio runtime in forked process")?
        .block_on(async {
            let socket = TokioUnixStream::from_std(server_socket)?;
            let db_pool = construct_single_connection_mysql_pool(&config.mysql).await?;

            // The version string is used to tell MariaDB apart from other
            // server variants.
            let db_is_mariadb = {
                let mut conn = db_pool.acquire().await?;
                let version_row: String = sqlx::query_scalar("SELECT VERSION()")
                    .fetch_one(&mut *conn)
                    .await
                    .context("Failed to query MySQL version")?;
                version_row.to_lowercase().contains("mariadb")
            };

            let db_pool = Arc::new(RwLock::new(db_pool));
            session_handler::session_handler_with_unix_user(
                socket,
                &unix_user,
                db_pool,
                db_is_mariadb,
            )
            .await?;
            Ok(())
        });

    result?;

    // The session finished successfully: terminate the child here so that it
    // never returns into the client's code path.
    unsafe {
        exit(EXIT_SUCCESS);
    }
}