From d4771189bcd19beab49d97f9170d2129698c35b7 Mon Sep 17 00:00:00 2001 From: h7x4 Date: Mon, 5 Jan 2026 19:48:23 +0900 Subject: [PATCH] server: properly support socket activation --- Cargo.lock | 10 +++++ Cargo.toml | 2 +- src/bin/roowhod.rs | 106 ++++++++++++++++++++++++++++++++++----------- 3 files changed, 91 insertions(+), 27 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d70009c..2d93764 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -817,6 +817,7 @@ dependencies = [ "clap", "futures-util", "nix", + "sd-notify", "serde", "tokio", "toml", @@ -857,6 +858,15 @@ version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a50f4cf475b65d88e057964e0e9bb1f0aa9bbb2036dc65c64596b42932536984" +[[package]] +name = "sd-notify" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b943eadf71d8b69e661330cb0e2656e31040acf21ee7708e2c238a0ec6af2bf4" +dependencies = [ + "libc", +] + [[package]] name = "self_cell" version = "1.2.2" diff --git a/Cargo.toml b/Cargo.toml index 42457c2..e734881 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,7 +26,7 @@ toml = "0.9.10" tracing = "0.1.44" tracing-subscriber = { version = "0.3.22", features = ["env-filter"] } # onc-rpc = "0.3.2" -# sd-notify = "0.4.5" +sd-notify = "0.4.5" # serde_json = "1.0.148" uucore = { version = "0.5.0", features = ["utmpx"] } zlink = { version = "0.2.0", features = ["introspection"] } diff --git a/src/bin/roowhod.rs b/src/bin/roowhod.rs index 05a625f..2183641 100644 --- a/src/bin/roowhod.rs +++ b/src/bin/roowhod.rs @@ -1,39 +1,90 @@ use std::{ collections::HashMap, - net::{Ipv4Addr, SocketAddrV4}, + net::IpAddr, + os::fd::{AsRawFd, FromRawFd, OwnedFd}, + path::PathBuf, sync::Arc, }; use anyhow::Context; -use roowho2_lib::server::rwhod::{ - RWHOD_BROADCAST_PORT, rwhod_client_server_task, rwhod_packet_receiver_task, - rwhod_packet_sender_task, +use clap::Parser; +use roowho2_lib::{ + proto::WhodStatusUpdate, + server::rwhod::{ + rwhod_client_server_task, rwhod_packet_receiver_task, rwhod_packet_sender_task, + }, }; -use tokio::sync::RwLock; +use tokio::{net::UdpSocket, sync::RwLock}; use tracing_subscriber::{EnvFilter, fmt, layer::SubscriberExt, util::SubscriberInitExt}; +#[derive(Parser)] +struct Args { + /// Path to configuration file + #[arg( + short = 'c', + long = "config", + default_value = "/etc/roowho2/roowho2.toml", + value_name = "PATH" + )] + config_path: PathBuf, +} + #[tokio::main] async fn main() -> anyhow::Result<()> { + let args = Args::parse(); + tracing_subscriber::registry() .with(fmt::layer()) .with(EnvFilter::from_default_env()) .init(); let config = toml::from_str::( - &std::fs::read_to_string("/etc/roowho2/roowho2.toml") + &std::fs::read_to_string(args.config_path) .context("Failed to read configuration file /etc/roowho2/roowho2.toml")?, )?; + let fd_map: HashMap = HashMap::from_iter( + sd_notify::listen_fds_with_names(false)?.map(|(fd_num, name)| { + ( + name.clone(), + // SAFETY: please don't mess around with file descriptors in random places + // around the codebase lol + unsafe { std::os::fd::OwnedFd::from_raw_fd(fd_num) }, + ) + }), + ); + let mut join_set = tokio::task::JoinSet::new(); + let whod_status_store = Arc::new(RwLock::new(HashMap::new())); + if config.rwhod.enable { tracing::info!("Starting RWHOD server"); - join_set.spawn(rwhod_server()); + let socket = fd_map + .get("rwhod_socket") + .map(|fd| { + // SAFETY: see above + let std_socket = unsafe { std::net::UdpSocket::from_raw_fd(fd.as_raw_fd()) }; + std_socket.set_nonblocking(true)?; + UdpSocket::from_std(std_socket) + }) + .context("RWHOD server is enabled, but socket fd not provided by systemd")??; + + join_set.spawn(rwhod_server(socket, whod_status_store.clone())); } else { tracing::debug!("RWHOD server is disabled in configuration"); } + join_set.spawn(client_server( + fd_map + .get("client_socket") + .context("RWHOD client-server socket fd not provided by systemd")? + .try_clone() + .context("Failed to clone RWHOD client-server socket fd")?, + whod_status_store.clone(), + )); + join_set.spawn(ctrl_c_handler()); join_set.join_next().await.unwrap()??; @@ -47,34 +98,37 @@ async fn ctrl_c_handler() -> anyhow::Result<()> { .map_err(|e| anyhow::anyhow!("Failed to listen for Ctrl-C: {}", e)) } -async fn rwhod_server() -> anyhow::Result<()> { - let addr = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, RWHOD_BROADCAST_PORT); - tracing::debug!("Binding RWHOD socket to {}", addr); - let socket = tokio::net::UdpSocket::bind(addr) - .await - .context("Failed to bind RWHOD UDP socket") - .and_then(|socket| { - socket.set_broadcast(true)?; - Ok(socket) - }) - .context("Failed to enable broadcast on RWHOD UDP socket") - .map(Arc::new)?; +async fn rwhod_server( + socket: UdpSocket, + whod_status_store: Arc>>, +) -> anyhow::Result<()> { + let socket = Arc::new(socket); let interfaces = roowho2_lib::server::rwhod::determine_relevant_interfaces()?; let sender_task = rwhod_packet_sender_task(socket.clone(), interfaces); - let status_store = Arc::new(RwLock::new(HashMap::new())); - let receiver_task = rwhod_packet_receiver_task(socket.clone(), status_store.clone()); - - tracing::debug!("Binding RWHOD client-server socket at /run/roowho2/rwhod.socket"); - let client_server_socket = zlink::unix::bind("/run/roowho2/rwhod.varlink")?; - let client_server_task = rwhod_client_server_task(client_server_socket, status_store.clone()); + let receiver_task = rwhod_packet_receiver_task(socket.clone(), whod_status_store); tokio::select! { res = sender_task => res?, res = receiver_task => res?, - res = client_server_task => res?, } Ok(()) } + +async fn client_server( + socket_fd: OwnedFd, + whod_status_store: Arc>>, +) -> anyhow::Result<()> { + // SAFETY: see above + let std_socket = + unsafe { std::os::unix::net::UnixListener::from_raw_fd(socket_fd.as_raw_fd()) }; + std_socket.set_nonblocking(true)?; + let zlink_listener = zlink::unix::Listener::try_from(OwnedFd::from(std_socket))?; + let client_server_task = rwhod_client_server_task(zlink_listener, whod_status_store); + + client_server_task.await?; + + Ok(()) +}