diff --git a/Cargo.lock b/Cargo.lock index aed6637..9059870 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -624,6 +624,7 @@ dependencies = [ "tokio", "tower", "tower-http", + "tungstenite", "utoipa", "utoipa-axum", "utoipa-swagger-ui", diff --git a/Cargo.toml b/Cargo.toml index 7c4a74d..abfdb3f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,7 @@ tempfile = "3.11.0" tokio = { version = "1.32.0", features = ["full"] } tower = { version = "0.5.2", features = ["full"] } tower-http = { version = "0.6.2", features = ["full"] } +tungstenite = "0.26.2" utoipa = { version = "5.1.3", features = ["axum_extras"] } utoipa-axum = "0.2.0" utoipa-swagger-ui = { version = "9.0.1", features = ["axum", "vendored"] } diff --git a/src/api/websocket_v1.rs b/src/api/websocket_v1.rs index 503ed64..d97b4eb 100644 --- a/src/api/websocket_v1.rs +++ b/src/api/websocket_v1.rs @@ -289,91 +289,108 @@ async fn connection_loop( let mut event_stream = mpv.get_event_stream().await; loop { select! { - id_count = id_count_watch_receiver.changed() => { - if let Err(e) = id_count { - anyhow::bail!("Error reading id count watch receiver for {:?}: {:?}", addr, e); - } - - let message = Message::Text(json!({ - "type": "connection_count", - "value": id_count_watch_receiver.borrow().clone(), - }).to_string().into(),); - - socket.send(message).await?; - } - message = socket.recv() => { - log::trace!("Received command from {:?}: {:?}", addr, message); - - let ws_message_content = message - .ok_or(anyhow::anyhow!("Event stream ended for {:?}", addr)) - .and_then(|message| { - match message { - Ok(message) => Ok(message), - err => Err(anyhow::anyhow!("Error reading message for {:?}: {:?}", addr, err)), + id_count = id_count_watch_receiver.changed() => { + if let Err(e) = id_count { + anyhow::bail!("Error reading id count watch receiver for {:?}: {:?}", addr, e); } - })?; - if let Message::Close(_) = ws_message_content { - log::trace!("Closing connection for {:?}", addr); - return Ok(()); - } - - if let Message::Ping(xs) = ws_message_content { - log::trace!("Ponging {:?} with {:?}", addr, xs); - socket.send(Message::Pong(xs)).await?; - continue; - } - - let message_content = match ws_message_content { - Message::Text(text) => text, - m => anyhow::bail!("Unexpected message type: {:?}", m), - }; - - let message_json = match serde_json::from_str::(&message_content) { - Ok(json) => json, - Err(e) => anyhow::bail!("Error parsing message from {:?}: {:?}", addr, e), - }; - - log::trace!("Handling command from {:?}: {:?}", addr, message_json); - - // TODO: handle errors - match handle_message(message_json, mpv.clone(), channel_id).await { - Ok(Some(response)) => { - log::trace!("Handled command from {:?} successfully, sending response", addr); - let message = Message::Text(json!({ - "type": "response", - "value": response, - }).to_string().into(),); - socket.send(message).await?; - } - Ok(None) => { - log::trace!("Handled command from {:?} successfully", addr); - } - Err(e) => { - log::error!("Error handling message from {:?}: {:?}", addr, e); - } - } - } - event = event_stream.next() => { - match event { - Some(Ok(event)) => { - log::trace!("Sending event to {:?}: {:?}", addr, event); let message = Message::Text(json!({ - "type": "event", - "value": event, + "type": "connection_count", + "value": id_count_watch_receiver.borrow().clone(), }).to_string().into(),); + socket.send(message).await?; - } - Some(Err(e)) => { - log::error!("Error reading event stream for {:?}: {:?}", addr, e); - anyhow::bail!("Error reading event stream for {:?}: {:?}", addr, e); - } - None => { - log::trace!("Event stream ended for {:?}", addr); - return Ok(()); - } } - } + + message = socket.recv() => { + log::trace!("Received command from {:?}: {:?}", addr, message); + + let ws_message_content = match message { + Some(Ok(message)) => message, + + None => { + log::debug!("Connection closed for {:?}", addr); + return Ok(()); + }, + + Some(Err(e)) => { + let inner_error = e.into_inner(); + if inner_error + .downcast_ref::() + .is_some_and(|e| match *e { + tungstenite::error::Error::Protocol(tungstenite::error::ProtocolError::ResetWithoutClosingHandshake) => true, + _ => false, + }) { + log::warn!("Connection reset without closing handshake for {:?}", addr); + return Ok(()); + } else { + log::error!("Error reading message for {:?}: {:?}", addr, inner_error); + anyhow::bail!("Error reading message for {:?}: {:?}", addr, inner_error); + } + }, + }; + + if let Message::Close(_) = ws_message_content { + log::trace!("Closing connection for {:?}", addr); + return Ok(()); + } + + if let Message::Ping(xs) = ws_message_content { + log::trace!("Ponging {:?} with {:?}", addr, xs); + socket.send(Message::Pong(xs)).await?; + continue; + } + + let message_content = match ws_message_content { + Message::Text(text) => text, + m => anyhow::bail!("Unexpected message type: {:?}", m), + }; + + let message_json = match serde_json::from_str::(&message_content) { + Ok(json) => json, + Err(e) => anyhow::bail!("Error parsing message from {:?}: {:?}", addr, e), + }; + + log::trace!("Handling command from {:?}: {:?}", addr, message_json); + + // TODO: handle errors + match handle_message(message_json, mpv.clone(), channel_id).await { + Ok(Some(response)) => { + log::trace!("Handled command from {:?} successfully, sending response", addr); + let message = Message::Text(json!({ + "type": "response", + "value": response, + }).to_string().into(),); + socket.send(message).await?; + } + Ok(None) => { + log::trace!("Handled command from {:?} successfully", addr); + } + Err(e) => { + log::error!("Error handling message from {:?}: {:?}", addr, e); + } + } + } + event = event_stream.next() => { + match event { + Some(Ok(event)) => { + log::trace!("Sending event to {:?}: {:?}", addr, event); + let message = Message::Text(json!({ + "type": "event", + "value": event, + }).to_string().into(),); + socket.send(message).await?; + } + Some(Err(e)) => { + log::error!("Error reading event stream for {:?}: {:?}", addr, e); + anyhow::bail!("Error reading event stream for {:?}: {:?}", addr, e); + } + None => { + log::trace!("Event stream ended for {:?}", addr); + return Ok(()); + } + } + } } } }