diff --git a/src/api/websocket_v1.rs b/src/api/websocket_v1.rs index 3ef1f32..8436e4f 100644 --- a/src/api/websocket_v1.rs +++ b/src/api/websocket_v1.rs @@ -1,4 +1,7 @@ -use std::net::SocketAddr; +use std::{ + net::SocketAddr, + sync::{Arc, Mutex}, +}; use anyhow::Context; use futures::{stream::FuturesUnordered, StreamExt}; @@ -18,29 +21,45 @@ use mpvipc_async::{ Switch, }; use serde_json::{json, Value}; -use tokio::select; +use tokio::{select, sync::watch}; -pub fn websocket_api(mpv: Mpv) -> Router { +use crate::util::IdPool; + +#[derive(Debug, Clone)] +struct WebsocketState { + mpv: Mpv, + id_pool: Arc>, +} + +pub fn websocket_api(mpv: Mpv, id_pool: Arc>) -> Router { + let state = WebsocketState { mpv, id_pool }; Router::new() .route("/", any(websocket_handler)) - .with_state(mpv) + .with_state(state) } async fn websocket_handler( ws: WebSocketUpgrade, ConnectInfo(addr): ConnectInfo, - State(mpv): State, + State(WebsocketState { mpv, id_pool }): State, ) -> impl IntoResponse { let mpv = mpv.clone(); + let id = match id_pool.lock().unwrap().request_id() { + Ok(id) => id, + Err(e) => { + log::error!("Failed to get id from id pool: {:?}", e); + return axum::http::StatusCode::INTERNAL_SERVER_ERROR.into_response(); + } + }; - // TODO: get an id provisioned by the id pool - ws.on_upgrade(move |socket| handle_connection(socket, addr, mpv, 1)) + ws.on_upgrade(move |socket| handle_connection(socket, addr, mpv, id, id_pool)) } #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct InitialState { pub cached_timestamp: Option, pub chapters: Vec, + pub connections: u64, pub current_percent_pos: Option, pub current_track: String, pub duration: f64, @@ -53,7 +72,7 @@ pub struct InitialState { pub volume: f64, } -async fn get_initial_state(mpv: &Mpv) -> InitialState { +async fn get_initial_state(mpv: &Mpv, id_pool: Arc>) -> InitialState { let cached_timestamp = mpv .get_property_value("demuxer-cache-state") .await @@ -69,6 +88,7 @@ async fn get_initial_state(mpv: &Mpv) -> InitialState { Ok(Some(Value::Array(chapters))) => chapters, _ => vec![], }; + let connections = id_pool.lock().unwrap().id_count(); let current_percent_pos = mpv.get_property("percent-pos").await.unwrap_or(None); let current_track = mpv.get_file_path().await.unwrap_or("".to_string()); let duration = mpv.get_duration().await.unwrap_or(0.0); @@ -104,6 +124,7 @@ async fn get_initial_state(mpv: &Mpv) -> InitialState { InitialState { cached_timestamp, chapters, + connections, current_percent_pos, current_track, duration, @@ -147,12 +168,18 @@ async fn setup_default_subscribes(mpv: &Mpv) -> anyhow::Result<()> { Ok(()) } -async fn handle_connection(mut socket: WebSocket, addr: SocketAddr, mpv: Mpv, channel_id: u64) { +async fn handle_connection( + mut socket: WebSocket, + addr: SocketAddr, + mpv: Mpv, + channel_id: u64, + id_pool: Arc>, +) { // TODO: There is an asynchronous gap between gathering the initial state and subscribing to the properties // This could lead to missing events if they happen in that gap. Send initial state, but also ensure // that there is an additional "initial state" sent upon subscription to all properties to ensure that // the state is correct. - let initial_state = get_initial_state(&mpv).await; + let initial_state = get_initial_state(&mpv, id_pool.clone()).await; let message = Message::Text( json!({ @@ -166,89 +193,17 @@ async fn handle_connection(mut socket: WebSocket, addr: SocketAddr, mpv: Mpv, ch setup_default_subscribes(&mpv).await.unwrap(); - let connection_loop_mpv = mpv.clone(); - let connection_loop = tokio::spawn(async move { - let mut event_stream = connection_loop_mpv.get_event_stream().await; - loop { - select! { - message = socket.recv() => { - log::trace!("Received command from {:?}: {:?}", addr, message); + let id_count_watch_receiver = id_pool.lock().unwrap().get_id_count_watch_receiver(); - let ws_message_content = message - .ok_or(anyhow::anyhow!("Event stream ended for {:?}", addr)) - .and_then(|message| { - match message { - Ok(message) => Ok(message), - err => Err(anyhow::anyhow!("Error reading message for {:?}: {:?}", addr, err)), - } - })?; + let connection_loop_result = tokio::spawn(connection_loop( + socket, + addr, + mpv.clone(), + channel_id, + id_count_watch_receiver, + )); - if let Message::Close(_) = ws_message_content { - log::trace!("Closing connection for {:?}", addr); - return Ok(()); - } - - if let Message::Ping(xs) = ws_message_content { - log::trace!("Ponging {:?} with {:?}", addr, xs); - socket.send(Message::Pong(xs)).await?; - continue; - } - - let message_content = match ws_message_content { - Message::Text(text) => text, - m => anyhow::bail!("Unexpected message type: {:?}", m), - }; - - let message_json = match serde_json::from_str::(&message_content) { - Ok(json) => json, - Err(e) => anyhow::bail!("Error parsing message from {:?}: {:?}", addr, e), - }; - - log::trace!("Handling command from {:?}: {:?}", addr, message_json); - - // TODO: handle errors - match handle_message(message_json, connection_loop_mpv.clone(), channel_id).await { - Ok(Some(response)) => { - log::trace!("Handled command from {:?} successfully, sending response", addr); - let message = Message::Text(json!({ - "type": "response", - "value": response, - }).to_string()); - socket.send(message).await?; - } - Ok(None) => { - log::trace!("Handled command from {:?} successfully", addr); - } - Err(e) => { - log::error!("Error handling message from {:?}: {:?}", addr, e); - } - } - } - event = event_stream.next() => { - match event { - Some(Ok(event)) => { - log::trace!("Sending event to {:?}: {:?}", addr, event); - let message = Message::Text(json!({ - "type": "event", - "value": event, - }).to_string()); - socket.send(message).await?; - } - Some(Err(e)) => { - log::error!("Error reading event stream for {:?}: {:?}", addr, e); - anyhow::bail!("Error reading event stream for {:?}: {:?}", addr, e); - } - None => { - log::trace!("Event stream ended for {:?}", addr); - return Ok(()); - } - } - } - } - } - }); - - match connection_loop.await { + match connection_loop_result.await { Ok(Ok(())) => { log::trace!("Connection loop ended for {:?}", addr); } @@ -272,6 +227,114 @@ async fn handle_connection(mut socket: WebSocket, addr: SocketAddr, mpv: Mpv, ch ); } } + + match id_pool.lock().unwrap().release_id(channel_id) { + Ok(()) => { + log::trace!("Released id {} for {:?}", channel_id, addr); + } + Err(e) => { + log::error!("Error releasing id {} for {:?}: {:?}", channel_id, addr, e); + } + } +} + +async fn connection_loop( + mut socket: WebSocket, + addr: SocketAddr, + mpv: Mpv, + channel_id: u64, + mut id_count_watch_receiver: watch::Receiver, +) -> Result<(), anyhow::Error> { + let mut event_stream = mpv.get_event_stream().await; + loop { + select! { + id_count = id_count_watch_receiver.changed() => { + if let Err(e) = id_count { + anyhow::bail!("Error reading id count watch receiver for {:?}: {:?}", addr, e); + } + + let message = Message::Text(json!({ + "type": "connection_count", + "value": id_count_watch_receiver.borrow().clone(), + }).to_string()); + + socket.send(message).await?; + } + message = socket.recv() => { + log::trace!("Received command from {:?}: {:?}", addr, message); + + let ws_message_content = message + .ok_or(anyhow::anyhow!("Event stream ended for {:?}", addr)) + .and_then(|message| { + match message { + Ok(message) => Ok(message), + err => Err(anyhow::anyhow!("Error reading message for {:?}: {:?}", addr, err)), + } + })?; + + if let Message::Close(_) = ws_message_content { + log::trace!("Closing connection for {:?}", addr); + return Ok(()); + } + + if let Message::Ping(xs) = ws_message_content { + log::trace!("Ponging {:?} with {:?}", addr, xs); + socket.send(Message::Pong(xs)).await?; + continue; + } + + let message_content = match ws_message_content { + Message::Text(text) => text, + m => anyhow::bail!("Unexpected message type: {:?}", m), + }; + + let message_json = match serde_json::from_str::(&message_content) { + Ok(json) => json, + Err(e) => anyhow::bail!("Error parsing message from {:?}: {:?}", addr, e), + }; + + log::trace!("Handling command from {:?}: {:?}", addr, message_json); + + // TODO: handle errors + match handle_message(message_json, mpv.clone(), channel_id).await { + Ok(Some(response)) => { + log::trace!("Handled command from {:?} successfully, sending response", addr); + let message = Message::Text(json!({ + "type": "response", + "value": response, + }).to_string()); + socket.send(message).await?; + } + Ok(None) => { + log::trace!("Handled command from {:?} successfully", addr); + } + Err(e) => { + log::error!("Error handling message from {:?}: {:?}", addr, e); + } + } + } + event = event_stream.next() => { + match event { + Some(Ok(event)) => { + log::trace!("Sending event to {:?}: {:?}", addr, event); + let message = Message::Text(json!({ + "type": "event", + "value": event, + }).to_string()); + socket.send(message).await?; + } + Some(Err(e)) => { + log::error!("Error reading event stream for {:?}: {:?}", addr, e); + anyhow::bail!("Error reading event stream for {:?}: {:?}", addr, e); + } + None => { + log::trace!("Event stream ended for {:?}", addr); + return Ok(()); + } + } + } + } + } } #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -279,7 +342,6 @@ async fn handle_connection(mut socket: WebSocket, addr: SocketAddr, mpv: Mpv, ch pub enum WSCommand { // Subscribe { property: String }, // UnsubscribeAll, - Load { urls: Vec }, TogglePlayback, Volume { volume: f64 }, diff --git a/src/main.rs b/src/main.rs index 0028575..bd9d29a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,13 +5,18 @@ use clap_verbosity_flag::Verbosity; use futures::StreamExt; use mpv_setup::{connect_to_mpv, create_mpv_config_file, show_grzegorz_image}; use mpvipc_async::{Event, Mpv, MpvDataType, MpvExt}; -use std::net::{IpAddr, SocketAddr}; +use std::{ + net::{IpAddr, SocketAddr}, + sync::{Arc, Mutex}, +}; use systemd_journal_logger::JournalLog; use tempfile::NamedTempFile; use tokio::task::JoinHandle; +use util::IdPool; mod api; mod mpv_setup; +mod util; #[derive(Parser)] struct Args { @@ -119,29 +124,26 @@ async fn setup_systemd_notifier(mpv: Mpv) -> anyhow::Result> { systemd_update_play_status(playing, ¤t_song); loop { - match event_stream.next().await { - Some(Ok(Event::PropertyChange { name, data, .. })) => { - match (name.as_str(), data) { - ("media-title", Some(MpvDataType::String(s))) => { - current_song = Some(s); - } - ("media-title", None) => { - current_song = None; - } - ("pause", Some(MpvDataType::Bool(b))) => { - playing = !b; - } - (event_name, _) => { - log::trace!( - "Received unexpected property change on systemd notifier thread: {}", - event_name - ); - } + if let Some(Ok(Event::PropertyChange { name, data, .. })) = event_stream.next().await { + match (name.as_str(), data) { + ("media-title", Some(MpvDataType::String(s))) => { + current_song = Some(s); + } + ("media-title", None) => { + current_song = None; + } + ("pause", Some(MpvDataType::Bool(b))) => { + playing = !b; + } + (event_name, _) => { + log::trace!( + "Received unexpected property change on systemd notifier thread: {}", + event_name + ); } - - systemd_update_play_status(playing, ¤t_song) } - _ => {} + + systemd_update_play_status(playing, ¤t_song) } } }); @@ -226,9 +228,11 @@ async fn main() -> anyhow::Result<()> { let socket_addr = SocketAddr::new(addr, args.port); log::info!("Starting API on {}", socket_addr); + let id_pool = Arc::new(Mutex::new(IdPool::new_with_max_limit(1024))); + let app = Router::new() .nest("/api", api::rest_api_routes(mpv.clone())) - .nest("/ws", api::websocket_api(mpv.clone())) + .nest("/ws", api::websocket_api(mpv.clone(), id_pool.clone())) .merge(api::rest_api_docs(mpv.clone())) .into_make_service_with_connect_info::(); diff --git a/src/util.rs b/src/util.rs new file mode 100644 index 0000000..451206e --- /dev/null +++ b/src/util.rs @@ -0,0 +1,3 @@ +mod id_pool; + +pub use id_pool::IdPool; diff --git a/src/util/id_pool.rs b/src/util/id_pool.rs new file mode 100644 index 0000000..747ae07 --- /dev/null +++ b/src/util/id_pool.rs @@ -0,0 +1,145 @@ +use std::{collections::BTreeSet, fmt::Debug}; + +use tokio::sync::watch; + +/// A relatively naive ID pool implementation. +pub struct IdPool { + max_id: u64, + free_ids: BTreeSet, + id_count: u64, + id_count_watch_sender: watch::Sender, + id_count_watch_receiver: watch::Receiver, +} + +impl Debug for IdPool { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("IdPool") + .field("max_id", &self.max_id) + .field("free_ids", &self.free_ids) + .field("id_count", &self.id_count) + .finish() + } +} + +impl Default for IdPool { + fn default() -> Self { + let (id_count_watch_sender, id_count_watch_receiver) = watch::channel(0); + Self { + max_id: u64::MAX, + free_ids: BTreeSet::new(), + id_count: 0, + id_count_watch_sender, + id_count_watch_receiver, + } + } +} + +//TODO: thiserror + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum IdPoolError { + NoFreeIds, + IdNotInUse(u64), + IdOutOfBound(u64), +} + +impl IdPool { + pub fn new_with_max_limit(max_id: u64) -> Self { + let (id_count_watch_sender, id_count_watch_receiver) = watch::channel(0); + Self { + max_id, + free_ids: BTreeSet::new(), + id_count: 0, + id_count_watch_sender, + id_count_watch_receiver, + } + } + + pub fn id_count(&self) -> u64 { + self.id_count - self.free_ids.len() as u64 + } + + pub fn id_is_used(&self, id: u64) -> Result { + if id > self.max_id { + Err(IdPoolError::IdOutOfBound(id)) + } else if self.free_ids.contains(&id) { + return Ok(false); + } else { + return Ok(id <= self.id_count); + } + } + + pub fn request_id(&mut self) -> Result { + if !self.free_ids.is_empty() { + let id = self.free_ids.pop_first().unwrap(); + self.update_watch(); + Ok(id) + } else if self.id_count < self.max_id { + self.id_count += 1; + self.update_watch(); + Ok(self.id_count) + } else { + Err(IdPoolError::NoFreeIds) + } + } + + pub fn release_id(&mut self, id: u64) -> Result<(), IdPoolError> { + if !self.id_is_used(id)? { + Err(IdPoolError::IdNotInUse(id)) + } else { + self.free_ids.insert(id); + self.update_watch(); + Ok(()) + } + } + + fn update_watch(&self) { + self.id_count_watch_sender.send(self.id_count()).unwrap(); + } + + pub fn get_id_count_watch_receiver(&self) -> watch::Receiver { + self.id_count_watch_receiver.clone() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_id_pool() { + let mut pool = IdPool::new_with_max_limit(10); + assert_eq!(pool.request_id(), Ok(1)); + assert_eq!(pool.request_id(), Ok(2)); + assert_eq!(pool.request_id(), Ok(3)); + assert_eq!(pool.request_id(), Ok(4)); + assert_eq!(pool.id_count(), 4); + assert_eq!(pool.request_id(), Ok(5)); + assert_eq!(pool.request_id(), Ok(6)); + assert_eq!(pool.request_id(), Ok(7)); + assert_eq!(pool.request_id(), Ok(8)); + assert_eq!(pool.request_id(), Ok(9)); + assert_eq!(pool.request_id(), Ok(10)); + assert_eq!(pool.id_count(), 10); + assert_eq!(pool.request_id(), Err(IdPoolError::NoFreeIds)); + assert_eq!(pool.release_id(5), Ok(())); + assert_eq!(pool.release_id(5), Err(IdPoolError::IdNotInUse(5))); + assert_eq!(pool.id_count(), 9); + assert_eq!(pool.request_id(), Ok(5)); + assert_eq!(pool.release_id(11), Err(IdPoolError::IdOutOfBound(11))); + } + + #[test] + fn test_id_pool_watch() { + let mut pool = IdPool::new_with_max_limit(10); + let receiver = pool.get_id_count_watch_receiver(); + + assert_eq!(receiver.borrow().clone(), 0); + pool.request_id().unwrap(); + assert_eq!(receiver.borrow().clone(), 1); + pool.request_id().unwrap(); + assert_eq!(receiver.borrow().clone(), 2); + pool.release_id(1).unwrap(); + assert_eq!(receiver.borrow().clone(), 1); + } +}