diff --git a/Cargo.toml b/Cargo.toml index 10f5267..5993462 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,7 +22,8 @@ tokio-stream = { version = "0.1.15", features = ["sync"] } [dev-dependencies] env_logger = "0.10.0" test-log = "0.2.15" -tokio = { version = "1.37.0", features = ["rt-multi-thread", "time"] } +tokio = { version = "1.37.0", features = ["rt-multi-thread", "time", "process"] } +uuid = { version = "1.8.0", features = ["v4"] } [lib] doctest = false diff --git a/src/core_api.rs b/src/core_api.rs index e816f8c..9371ed8 100644 --- a/src/core_api.rs +++ b/src/core_api.rs @@ -68,7 +68,7 @@ pub(crate) trait IntoRawCommandPart { } /// Generic data type representing all possible data types that mpv can return. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum MpvDataType { Array(Vec), Bool(bool), diff --git a/src/event_parser.rs b/src/event_parser.rs index e751c0b..d22e7ab 100644 --- a/src/event_parser.rs +++ b/src/event_parser.rs @@ -16,7 +16,7 @@ use crate::{ipc::MpvIpcEvent, Error, ErrorCode, MpvDataType}; /// /// See for /// the upstream list of properties. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum Property { Path(Option), Pause(bool), @@ -35,7 +35,7 @@ pub enum Property { /// /// See for /// the upstream list of events. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum Event { Shutdown, StartFile, diff --git a/src/ipc.rs b/src/ipc.rs index 3159be9..47c0b7b 100644 --- a/src/ipc.rs +++ b/src/ipc.rs @@ -2,12 +2,11 @@ use futures::{SinkExt, StreamExt}; use serde_json::{json, Value}; -use std::mem; use tokio::{ net::UnixStream, - sync::{broadcast, mpsc, oneshot, Mutex}, + sync::{broadcast, mpsc, oneshot}, }; -use tokio_util::codec::{Framed, LinesCodec, LinesCodecError}; +use tokio_util::codec::{Framed, LinesCodec}; use crate::{Error, ErrorCode}; @@ -15,9 +14,6 @@ use crate::{Error, ErrorCode}; /// and message passing with [`Mpv`](crate::Mpv) controllers. pub(crate) struct MpvIpc { socket: Framed, - // I had trouble with reading and writing to the socket when it was wrapped - // in a MutexGuard, so I'm using a separate Mutex to lock the socket when needed. - socket_lock: Mutex<()>, command_channel: mpsc::Receiver<(MpvIpcCommand, oneshot::Sender)>, event_channel: broadcast::Sender, } @@ -50,14 +46,14 @@ impl MpvIpc { MpvIpc { socket: Framed::new(socket, LinesCodec::new()), command_channel, - socket_lock: Mutex::new(()), event_channel, } } pub(crate) async fn send_command(&mut self, command: &[Value]) -> Result, Error> { - let lock = self.socket_lock.lock().await; + // let lock = self.socket_lock.lock().await; // START CRITICAL SECTION + let ipc_command = json!({ "command": command }); let ipc_command_str = serde_json::to_string(&ipc_command) .map_err(|why| Error(ErrorCode::JsonParseError(why.to_string())))?; @@ -69,21 +65,34 @@ impl MpvIpc { .await .map_err(|why| Error(ErrorCode::ConnectError(why.to_string())))?; - let response = self - .socket - .next() - .await - .ok_or(Error(ErrorCode::MissingValue))? - .map_err(|why| Error(ErrorCode::ConnectError(why.to_string())))?; + let response = loop { + let response = self + .socket + .next() + .await + .ok_or(Error(ErrorCode::MissingValue))? + .map_err(|why| Error(ErrorCode::ConnectError(why.to_string())))?; + let parsed_response = serde_json::from_str::(&response) + .map_err(|why| Error(ErrorCode::JsonParseError(why.to_string()))); + + if parsed_response + .as_ref() + .ok() + .and_then(|v| v.as_object().map(|o| o.contains_key("event"))) + .unwrap_or(false) + { + self.handle_event(parsed_response).await; + } else { + break parsed_response; + } + }; // END CRITICAL SECTION - mem::drop(lock); + // mem::drop(lock); - log::trace!("Received response: {}", response); + log::trace!("Received response: {:?}", response); - serde_json::from_str::(&response) - .map_err(|why| Error(ErrorCode::JsonParseError(why.to_string()))) - .and_then(parse_mpv_response_data) + parse_mpv_response_data(response?) } pub(crate) async fn get_mpv_property( @@ -117,16 +126,8 @@ impl MpvIpc { .await } - async fn handle_event(&mut self, event: Result) { - let parsed_event = event - .as_ref() - .map_err(|why| Error(ErrorCode::ConnectError(why.to_string()))) - .and_then(|event| { - serde_json::from_str::(event) - .map_err(|why| Error(ErrorCode::JsonParseError(why.to_string()))) - }); - - match parsed_event { + async fn handle_event(&mut self, event: Result) { + match &event { Ok(event) => { log::trace!("Parsed event: {:?}", event); if let Err(broadcast::error::SendError(_)) = @@ -136,7 +137,7 @@ impl MpvIpc { } } Err(e) => { - log::trace!("Error parsing event, ignoring:\n {:?}\n {:?}", event, e); + log::trace!("Error parsing event, ignoring:\n {:?}\n {:?}", &event, e); } } } @@ -146,8 +147,14 @@ impl MpvIpc { tokio::select! { Some(event) = self.socket.next() => { log::trace!("Got event: {:?}", event); - // TODO: error handling - self.handle_event(event).await; + + let parsed_event = event + .map_err(|why| Error(ErrorCode::ConnectError(why.to_string()))) + .and_then(|event| + serde_json::from_str::(&event) + .map_err(|why| Error(ErrorCode::JsonParseError(why.to_string())))); + + self.handle_event(parsed_event).await; } Some((cmd, tx)) = self.command_channel.recv() => { log::trace!("Handling command: {:?}", cmd); diff --git a/src/message_parser.rs b/src/message_parser.rs index 7c038ec..f1e6e80 100644 --- a/src/message_parser.rs +++ b/src/message_parser.rs @@ -136,61 +136,27 @@ pub(crate) fn json_map_to_hashmap( } pub(crate) fn json_array_to_vec(array: &[Value]) -> Vec { - let mut output: Vec = Vec::new(); - if !array.is_empty() { - match array[0] { - Value::Array(_) => { - for entry in array { - if let Value::Array(ref a) = *entry { - output.push(MpvDataType::Array(json_array_to_vec(a))); - } + array + .iter() + .map(|entry| match entry { + Value::Array(a) => MpvDataType::Array(json_array_to_vec(a)), + Value::Bool(b) => MpvDataType::Bool(*b), + Value::Number(n) => { + if n.is_u64() { + MpvDataType::Usize(n.as_u64().unwrap() as usize) + } else if n.is_f64() { + MpvDataType::Double(n.as_f64().unwrap()) + } else { + panic!("unimplemented number"); } } - - Value::Bool(_) => { - for entry in array { - if let Value::Bool(ref b) = *entry { - output.push(MpvDataType::Bool(*b)); - } - } - } - - Value::Number(_) => { - for entry in array { - if let Value::Number(ref n) = *entry { - if n.is_u64() { - output.push(MpvDataType::Usize(n.as_u64().unwrap() as usize)); - } else if n.is_f64() { - output.push(MpvDataType::Double(n.as_f64().unwrap())); - } else { - panic!("unimplemented number"); - } - } - } - } - - Value::Object(_) => { - for entry in array { - if let Value::Object(ref map) = *entry { - output.push(MpvDataType::HashMap(json_map_to_hashmap(map))); - } - } - } - - Value::String(_) => { - for entry in array { - if let Value::String(ref s) = *entry { - output.push(MpvDataType::String(s.to_string())); - } - } - } - + Value::Object(ref o) => MpvDataType::HashMap(json_map_to_hashmap(o)), + Value::String(s) => MpvDataType::String(s.to_owned()), Value::Null => { unimplemented!(); } - } - } - output + }) + .collect() } pub(crate) fn json_array_to_playlist(array: &[Value]) -> Vec { @@ -217,3 +183,137 @@ pub(crate) fn json_array_to_playlist(array: &[Value]) -> Vec { } output } + +#[cfg(test)] +mod test { + use super::*; + use crate::MpvDataType; + use serde_json::json; + use std::collections::HashMap; + + #[test] + fn test_json_map_to_hashmap() { + let json = json!({ + "array": [1, 2, 3], + "bool": true, + "double": 1.0, + "usize": 1, + "string": "string", + "object": { + "key": "value" + } + }); + + let mut expected = HashMap::new(); + expected.insert( + "array".to_string(), + MpvDataType::Array(vec![ + MpvDataType::Usize(1), + MpvDataType::Usize(2), + MpvDataType::Usize(3), + ]), + ); + expected.insert("bool".to_string(), MpvDataType::Bool(true)); + expected.insert("double".to_string(), MpvDataType::Double(1.0)); + expected.insert("usize".to_string(), MpvDataType::Usize(1)); + expected.insert( + "string".to_string(), + MpvDataType::String("string".to_string()), + ); + expected.insert( + "object".to_string(), + MpvDataType::HashMap(HashMap::from([( + "key".to_string(), + MpvDataType::String("value".to_string()), + )])), + ); + + assert_eq!(json_map_to_hashmap(json.as_object().unwrap()), expected); + } + + #[test] + #[should_panic] + fn test_json_map_to_hashmap_fail_on_null() { + json_map_to_hashmap( + json!({ + "null": null + }) + .as_object() + .unwrap(), + ); + } + + #[test] + fn test_json_array_to_vec() { + let json = json!([ + [1, 2, 3], + true, + 1.0, + 1, + "string", + { + "key": "value" + } + ]); + + println!("{:?}", json.as_array().unwrap()); + println!("{:?}", json_array_to_vec(json.as_array().unwrap())); + + let expected = vec![ + MpvDataType::Array(vec![ + MpvDataType::Usize(1), + MpvDataType::Usize(2), + MpvDataType::Usize(3), + ]), + MpvDataType::Bool(true), + MpvDataType::Double(1.0), + MpvDataType::Usize(1), + MpvDataType::String("string".to_string()), + MpvDataType::HashMap(HashMap::from([( + "key".to_string(), + MpvDataType::String("value".to_string()), + )])), + ]; + + assert_eq!(json_array_to_vec(json.as_array().unwrap()), expected); + } + + #[test] + #[should_panic] + fn test_json_array_to_vec_fail_on_null() { + json_array_to_vec(json!([null]).as_array().unwrap().as_slice()); + } + + #[test] + fn test_json_array_to_playlist() { + let json = json!([ + { + "filename": "file1", + "title": "title1", + "current": true + }, + { + "filename": "file2", + "title": "title2", + "current": false + } + ]); + + let expected = vec![ + PlaylistEntry { + id: 0, + filename: "file1".to_string(), + title: "title1".to_string(), + current: true, + }, + PlaylistEntry { + id: 1, + filename: "file2".to_string(), + title: "title2".to_string(), + current: false, + }, + ]; + + assert_eq!(json_array_to_playlist(json.as_array().unwrap()), expected); + } +} diff --git a/tests/integration.rs b/tests/integration.rs new file mode 100644 index 0000000..f967885 --- /dev/null +++ b/tests/integration.rs @@ -0,0 +1,100 @@ +use mpvipc::{Error, Mpv, MpvExt}; +use std::path::Path; +use tokio::{ + process::{Child, Command}, + time::{sleep, timeout, Duration}, +}; + +#[cfg(target_family = "unix")] +async fn spawn_headless_mpv() -> Result<(Child, Mpv), Error> { + let socket_path_str = format!("/tmp/mpv-ipc-{}", uuid::Uuid::new_v4()); + let socket_path = Path::new(&socket_path_str); + + let process_handle = Command::new("mpv") + .arg("--no-config") + .arg("--idle") + .arg("--no-video") + .arg("--no-audio") + .arg(format!( + "--input-ipc-server={}", + &socket_path.to_str().unwrap() + )) + .spawn() + .expect("Failed to start mpv"); + + if timeout(Duration::from_millis(500), async { + while !&socket_path.exists() { + sleep(Duration::from_millis(10)).await; + } + }) + .await + .is_err() + { + panic!("Failed to create mpv socket at {:?}", &socket_path); + } + + let mpv = Mpv::connect(socket_path.to_str().unwrap()).await.unwrap(); + Ok((process_handle, mpv)) +} + +#[tokio::test] +#[cfg(target_family = "unix")] +async fn test_get_mpv_version() { + let (mut proc, mpv) = spawn_headless_mpv().await.unwrap(); + let version: String = mpv.get_property("mpv-version").await.unwrap(); + assert!(version.starts_with("mpv")); + + mpv.kill().await.unwrap(); + proc.kill().await.unwrap(); +} + +#[tokio::test] +#[cfg(target_family = "unix")] +async fn test_set_property() { + let (mut proc, mpv) = spawn_headless_mpv().await.unwrap(); + mpv.set_property("pause", true).await.unwrap(); + let paused: bool = mpv.get_property("pause").await.unwrap(); + assert!(paused); + + mpv.kill().await.unwrap(); + proc.kill().await.unwrap(); +} + +#[tokio::test] +#[cfg(target_family = "unix")] +async fn test_events() { + use futures::stream::StreamExt; + + let (mut proc, mpv) = spawn_headless_mpv().await.unwrap(); + + mpv.observe_property(1337, "pause").await.unwrap(); + + let mut events = mpv.get_event_stream().await; + let event_checking_thread = tokio::spawn(async move { + loop { + let event = events.next().await.unwrap().unwrap(); + if let mpvipc::Event::PropertyChange { id, property } = event { + if id == 1337 { + assert_eq!(property, mpvipc::Property::Pause(true)); + break; + } + } + } + }); + + tokio::time::sleep(Duration::from_millis(10)).await; + + mpv.set_property("pause", true).await.unwrap(); + + if let Err(_) = tokio::time::timeout( + tokio::time::Duration::from_millis(500), + event_checking_thread, + ) + .await + { + panic!("Event checking thread timed out"); + } + + mpv.kill().await.unwrap(); + proc.kill().await.unwrap(); +}