diff --git a/examples/mpd-client/main.rs b/examples/mpd-client/main.rs index 09e8f988..b384518d 100644 --- a/examples/mpd-client/main.rs +++ b/examples/mpd-client/main.rs @@ -4,9 +4,8 @@ use empidee::MpdClient; async fn main() -> anyhow::Result<()> { let socket = tokio::net::TcpSocket::new_v4()?; let mut stream = socket.connect("127.0.0.1:6600".parse()?).await?; - - let mut client = MpdClient::new(&mut stream); - println!("{}", client.read_initial_mpd_version().await?); + let mut client = MpdClient::new(&mut stream).await?; + println!("Connected to MPD server: {}", client.get_mpd_version().unwrap_or("unknown")); client.play(None).await?; diff --git a/src/client.rs b/src/client.rs index c73e7ee5..b8f01ed8 100644 --- a/src/client.rs +++ b/src/client.rs @@ -9,18 +9,20 @@ use crate::{Request, commands::*, types::SongPosition}; #[cfg(feature = "futures")] use futures_util::{ AsyncBufReadExt, - io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader}, + io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, BufWriter}, }; -use thiserror::Error; #[cfg(feature = "tokio")] -use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader}; +use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufStream}; + +use thiserror::Error; pub struct MpdClient<'a, T> where T: AsyncWrite + AsyncRead + Unpin, { - connection: &'a mut T, + stream: BufStream<&'a mut T>, + mpd_version: Option, } #[derive(Error, Debug)] @@ -39,30 +41,53 @@ impl<'a, T> MpdClient<'a, T> where T: AsyncWrite + AsyncRead + Unpin, { - pub fn new(connection: &'a mut T) -> Self { - MpdClient { connection } + pub async fn new(connection: &'a mut T) -> Result { + let mut client = MpdClient { + stream: BufStream::new(connection), + mpd_version: None, + }; + + client.read_initial_mpd_version().await?; + + Ok(client) } - pub async fn read_initial_mpd_version(&mut self) -> Result { - let mut reader = BufReader::new(&mut self.connection); + pub async fn wrap_existing(connection: &'a mut T, mpd_version: Option) -> Self { + MpdClient { + stream: BufStream::new(connection), + mpd_version, + } + } + + pub fn into_connection(self) -> &'a mut T { + self.stream.into_inner() + } + + pub fn get_mpd_version(&self) -> Option<&str> { + self.mpd_version.as_deref() + } + + async fn read_initial_mpd_version(&mut self) -> Result<(), MpdClientError> { let mut version_line = String::new(); - reader + self.stream .read_line(&mut version_line) .await .map_err(MpdClientError::ConnectionError)?; - Ok(version_line.trim().to_string()) + self.mpd_version = Some(version_line.trim().to_string()); + + Ok(()) } async fn read_response(&mut self) -> Result, MpdClientError> { let mut response = Vec::new(); - let mut reader = BufReader::new(&mut self.connection); loop { let mut line = Vec::new(); - let bytes_read = reader + let bytes_read = self + .stream .read_until(b'\n', &mut line) .await .map_err(MpdClientError::ConnectionError)?; @@ -88,12 +113,12 @@ where let message = Request::Play(position); let payload = message.serialize(); - self.connection + self.stream .write_all(payload.as_bytes()) .await .map_err(MpdClientError::ConnectionError)?; - self.connection + self.stream .flush() .await .map_err(MpdClientError::ConnectionError)?;