diff --git a/examples/make_realtime/main.rs b/examples/make_realtime/main.rs index 104b826..ab491ba 100644 --- a/examples/make_realtime/main.rs +++ b/examples/make_realtime/main.rs @@ -1,8 +1,7 @@ use nix::unistd::gettid; use std::thread; -// use rtkit_client_rs::::RTKitProxyBlocking; -use rtkit_client_rs::{make_current_thread_realtime, Error}; +use rtkit_client_rs::{set_current_thread_priority, Error}; fn main() { println!("Main tid: {}", gettid()); @@ -11,10 +10,21 @@ fn main() { fn important_thread() { println!("Important thread tid: {}", gettid()); + let requested_priority = 40; - let actual_priority = make_current_thread_realtime(Some(requested_priority)); + let actual_priority = set_current_thread_priority(Some(requested_priority.try_into().unwrap())); match actual_priority { - Ok(actual_priority) => println!("Requested priority: {requested_priority}, Actual priority: {actual_priority}"), + Ok(actual_priority) => println!("Requested priority: {requested_priority}, Actual priority: {}", actual_priority.value()), + Err(Error::PermissionDenied) => println!("Permission denied. Do you have polkit rules set up, or otherwise have the necessary permissions?"), + Err(e) => println!("Internal zbus error: {e}"), + } + + let requested_nice_level = -10; + let actual_nice_level = rtkit_client_rs::set_current_thread_nice_level(Some( + requested_nice_level.try_into().unwrap(), + )); + match actual_nice_level { + Ok(actual_nice_level) => println!("Requested nice level: {requested_nice_level}, Actual nice level: {}", actual_nice_level.value()), Err(Error::PermissionDenied) => println!("Permission denied. Do you have polkit rules set up, or otherwise have the necessary permissions?"), Err(e) => println!("Internal zbus error: {e}"), } diff --git a/src/lib.rs b/src/lib.rs index 99c4a74..89afd17 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,10 +1,25 @@ mod low_level_zbus_api; +mod types; + +pub use types::{NiceLevel, Priority}; #[derive(thiserror::Error, Debug)] pub enum Error { #[error("Permission denied")] PermissionDenied, + #[error("Invalid PID")] + InvalidPid, + + #[error("Invalid thread ID")] + InvalidThreadId, + + #[error("Got invalid max priority from RTKit")] + InvalidMaxPriority(i32), + + #[error("Got invalid min nice level from RTKit")] + InvalidMinNiceLevel(i32), + #[error(transparent)] ZbusError(#[from] zbus::Error), @@ -13,9 +28,9 @@ pub enum Error { } /// Helper for [make_thread_realtime], uses the current thread's tid. -pub fn make_current_thread_realtime(priority: Option) -> Result { +pub fn set_current_thread_priority(priority: Option) -> Result { let thread_id = nix::unistd::gettid().as_raw().try_into().unwrap(); - make_thread_realtime(thread_id, priority) + set_thread_priority(thread_id, priority) } /// Elevates the realtime priority of a thread as close as possible to the requested priority. @@ -23,26 +38,20 @@ pub fn make_current_thread_realtime(priority: Option) -> Result /// If no priority is specified, the maximum priority allowed by the system will be used. /// /// Returns the actual priority that was set. -pub fn make_thread_realtime(tid: u64, priority: Option) -> Result { - debug_assert!( - priority.is_some_and(|p| p <= 99), - "priority must be between 0 and 99" - ); - +pub fn set_thread_priority(tid: u64, priority: Option) -> Result { let connection = zbus::blocking::Connection::system()?; let proxy = low_level_zbus_api::RTKitProxyBlocking::builder(&connection) .cache_properties(zbus::proxy::CacheProperties::No) .build()?; - let max_realtime_priority = proxy.max_realtime_priority()?; - debug_assert!( - (0..=99).contains(&max_realtime_priority), - "max_realtime_priority must be between 0 and 99" - ); - let priority = priority.unwrap_or(max_realtime_priority as u32); + let max_priority = proxy.max_realtime_priority()?; + let max_priority = max_priority + .try_into() + .map_err(|_| Error::InvalidMaxPriority(max_priority))?; + let priority = priority.unwrap_or_default().min(max_priority); proxy - .make_thread_realtime(tid, priority) + .make_thread_realtime(tid, priority.value()) .map(|_| priority) .map_err(|e| match e { zbus::fdo::Error::AccessDenied(_) => Error::PermissionDenied, @@ -55,33 +64,91 @@ pub fn make_thread_realtime(tid: u64, priority: Option) -> Result, -) -> Result { - debug_assert!( - priority.is_some_and(|p| p <= 99), - "priority must be between 0 and 99" - ); - + priority: Option, +) -> Result { let connection = zbus::blocking::Connection::system()?; let proxy = low_level_zbus_api::RTKitProxyBlocking::builder(&connection) .cache_properties(zbus::proxy::CacheProperties::No) .build()?; - let max_realtime_priority = proxy.max_realtime_priority()?; - debug_assert!( - (0..=99).contains(&max_realtime_priority), - "max_realtime_priority must be between 0 and 99" - ); - let priority = priority.unwrap_or(max_realtime_priority as u32); + let max_priority = proxy.max_realtime_priority()?; + let max_priority = max_priority + .try_into() + .map_err(|_| Error::InvalidMaxPriority(max_priority))?; + let priority = priority.unwrap_or_default().min(max_priority); proxy - .make_thread_realtime_with_pid(pid, tid, priority) + .make_thread_realtime_with_pid(pid, tid, priority.value()) .map(|_| priority) .map_err(|e| match e { zbus::fdo::Error::AccessDenied(_) => Error::PermissionDenied, e => Error::ZbusMethodError(e), }) } + +/// Helper for [set_thread_nice_level], uses the current thread's tid. +pub fn set_current_thread_nice_level(nice_level: Option) -> Result { + let thread_id = nix::unistd::gettid().as_raw().try_into().unwrap(); + set_thread_nice_level(thread_id, nice_level) +} + +/// Elevates the nice level of a thread as close as possible to the requested nice level. +/// +/// If no nice level is specified, the maximum nice level allowed by the system will be used. +/// +/// Returns the actual nice level that was set. +pub fn set_thread_nice_level(tid: u64, nice_level: Option) -> Result { + let connection = zbus::blocking::Connection::system()?; + let proxy = low_level_zbus_api::RTKitProxyBlocking::builder(&connection) + .cache_properties(zbus::proxy::CacheProperties::No) + .build()?; + + let nice_level = nice_level.unwrap_or_default(); + + let max_nice_level = proxy.min_nice_level()?; + let max_nice_level = max_nice_level + .try_into() + .map_err(|_| Error::InvalidMinNiceLevel(max_nice_level))?; + let nice_level = nice_level.max(max_nice_level); + + proxy + .make_thread_high_priority(tid, nice_level.value()) + .map(|_| nice_level) + .map_err(|e| match e { + zbus::fdo::Error::AccessDenied(_) => Error::PermissionDenied, + e => Error::ZbusMethodError(e), + }) +} + +/// Elevates the nice level of a specified processes thread as close as possible to the requested nice level. +/// +/// If no nice level is specified, the maximum nice level allowed by the system will be used. +/// Returns the actual nice level that was set. +pub fn set_process_thread_nice_level( + pid: u64, + tid: u64, + nice_level: Option, +) -> Result { + let connection = zbus::blocking::Connection::system()?; + let proxy = low_level_zbus_api::RTKitProxyBlocking::builder(&connection) + .cache_properties(zbus::proxy::CacheProperties::No) + .build()?; + + let nice_level = nice_level.unwrap_or_default(); + let max_nice_level = proxy.min_nice_level()?; + let max_nice_level = max_nice_level + .try_into() + .map_err(|_| Error::InvalidMinNiceLevel(max_nice_level))?; + let nice_level = nice_level.max(max_nice_level); + + proxy + .make_thread_high_priority_with_pid(pid, tid, nice_level.value()) + .map(|_| nice_level) + .map_err(|e| match e { + zbus::fdo::Error::AccessDenied(_) => Error::PermissionDenied, + e => Error::ZbusMethodError(e), + }) +} diff --git a/src/types.rs b/src/types.rs new file mode 100644 index 0000000..460fe8c --- /dev/null +++ b/src/types.rs @@ -0,0 +1,124 @@ +use std::ops::Deref; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct Priority(u32); + +impl Priority { + pub const MIN: u32 = 0; + pub const MAX: u32 = 99; + + /// Creates a new Priority, ensuring it is within valid bounds (0-99). + pub fn new(value: u32) -> Option { + if value <= Self::MAX { + Some(Priority(value)) + } else { + None + } + } + + /// Returns the underlying priority value. + pub fn value(&self) -> u32 { + self.0 + } +} + +impl From for u32 { + fn from(priority: Priority) -> Self { + priority.0 + } +} + +impl TryFrom for Priority { + type Error = &'static str; + + fn try_from(value: u32) -> Result { + Priority::new(value).ok_or("Priority must be between 0 and 99") + } +} + +impl TryFrom for Priority { + type Error = &'static str; + + fn try_from(value: i32) -> Result { + if value < 0 { + return Err("Priority must be between 0 and 99"); + } + Priority::new(value as u32).ok_or("Priority must be between 0 and 99") + } +} + +impl Default for Priority { + fn default() -> Self { + Priority(99) + } +} + +impl AsRef for Priority { + fn as_ref(&self) -> &u32 { + &self.0 + } +} + +impl Deref for Priority { + type Target = u32; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct NiceLevel(i32); + +impl NiceLevel { + pub const MIN: i32 = -20; + pub const MAX: i32 = 19; + + /// Creates a new NiceLevel, ensuring it is within valid bounds (-20 to 19). + pub fn new(value: i32) -> Option { + if (Self::MIN..=Self::MAX).contains(&value) { + Some(NiceLevel(value)) + } else { + None + } + } + + /// Returns the underlying nice level value. + pub fn value(&self) -> i32 { + self.0 + } +} + +impl From for i32 { + fn from(nice_level: NiceLevel) -> Self { + nice_level.0 + } +} + +impl TryFrom for NiceLevel { + type Error = &'static str; + + fn try_from(value: i32) -> Result { + NiceLevel::new(value).ok_or("NiceLevel must be between -20 and 19") + } +} + +impl Default for NiceLevel { + fn default() -> Self { + NiceLevel(Self::MIN) + } +} + +impl AsRef for NiceLevel { + fn as_ref(&self) -> &i32 { + &self.0 + } +} + +impl Deref for NiceLevel { + type Target = i32; + + fn deref(&self) -> &Self::Target { + &self.0 + } +}