diff --git a/src/server/varlink_api.rs b/src/server/varlink_api.rs index bea243d..586d2a0 100644 --- a/src/server/varlink_api.rs +++ b/src/server/varlink_api.rs @@ -1,8 +1,9 @@ -use std::os::fd::OwnedFd; +use std::{os::fd::OwnedFd, time::Duration}; use anyhow::Context; use itertools::Itertools; use serde::{Deserialize, Serialize}; +use tokio::time::timeout; use zlink::{ReplyError, service::MethodReply}; use crate::{ @@ -54,6 +55,7 @@ pub type VarlinkRuptimeResponse = Vec; #[zlink(interface = "no.ntnu.pvv.roowho2.rwhod")] pub enum VarlinkRwhodClientError { InvalidRequest, + TimedOut, } // Types for 'no.ntnu.pvv.roowho2.finger' @@ -95,6 +97,7 @@ pub type VarlinkFingerResponse = Vec; #[zlink(interface = "no.ntnu.pvv.roowho2.finger")] pub enum VarlinkFingerClientError { InvalidRequest, + TimedOut, } // -------------------- @@ -231,39 +234,88 @@ impl zlink::Service for VarlinkRoowhoo2ClientServer { Self::ReplyError<'service>, > { match call.method() { - VarlinkMethod::Rwhod(VarlinkRwhodClientRequest::Rwho { all }) => ( - MethodReply::Single(Some(VarlinkReply::Rwhod(VarlinkRwhodClientResponse::Rwho( - self.handle_rwho_request(*all).await, - )))), - Default::default(), - ), - VarlinkMethod::Rwhod(VarlinkRwhodClientRequest::Ruptime) => ( - MethodReply::Single(Some(VarlinkReply::Rwhod( - VarlinkRwhodClientResponse::Ruptime(self.handle_ruptime_request().await), - ))), - Default::default(), - ), + VarlinkMethod::Rwhod(VarlinkRwhodClientRequest::Rwho { all }) => { + let result = + match timeout(Duration::from_secs(2), self.handle_rwho_request(*all)).await { + Ok(response) => response, + Err(_) => { + tracing::error!("Rwho request timed out after 2 seconds"); + return ( + MethodReply::Error(VarlinkReplyError::Rwhod( + VarlinkRwhodClientError::TimedOut, + )), + Default::default(), + ); + } + }; + + ( + MethodReply::Single(Some(VarlinkReply::Rwhod( + VarlinkRwhodClientResponse::Rwho(result), + ))), + Default::default(), + ) + } + VarlinkMethod::Rwhod(VarlinkRwhodClientRequest::Ruptime) => { + let result = + match timeout(Duration::from_secs(2), self.handle_ruptime_request()).await { + Ok(response) => response, + Err(_) => { + tracing::error!("Ruptime request timed out after 2 seconds"); + return ( + MethodReply::Error(VarlinkReplyError::Rwhod( + VarlinkRwhodClientError::TimedOut, + )), + Default::default(), + ); + } + }; + + ( + MethodReply::Single(Some(VarlinkReply::Rwhod( + VarlinkRwhodClientResponse::Ruptime(result), + ))), + Default::default(), + ) + } VarlinkMethod::Finger(VarlinkFingerClientRequest::Finger { user_queries, match_fullnames, request_info, request_networking, disable_user_account_db, - }) => ( - MethodReply::Single(Some(VarlinkReply::Finger( - VarlinkFingerClientResponse::Finger( - self.handle_finger_request( - user_queries.clone(), - *match_fullnames, - request_info.clone(), - request_networking.clone(), - *disable_user_account_db, - ) - .await, + }) => { + let result = match timeout( + Duration::from_secs(2), + self.handle_finger_request( + user_queries.clone(), + *match_fullnames, + request_info.clone(), + request_networking.clone(), + *disable_user_account_db, ), - ))), - Default::default(), - ), + ) + .await + { + Ok(response) => response, + Err(_) => { + tracing::error!("Finger request timed out after 2 seconds"); + return ( + MethodReply::Error(VarlinkReplyError::Finger( + VarlinkFingerClientError::TimedOut, + )), + Default::default(), + ); + } + }; + + ( + MethodReply::Single(Some(VarlinkReply::Finger( + VarlinkFingerClientResponse::Finger(result), + ))), + Default::default(), + ) + } } } }