diff --git a/src/commands/client_to_client/channels.rs b/src/commands/client_to_client/channels.rs index 41a6e0b..f703095 100644 --- a/src/commands/client_to_client/channels.rs +++ b/src/commands/client_to_client/channels.rs @@ -37,7 +37,10 @@ impl Command for Channels { for (key, value) in parts { debug_assert!(key == "channels"); let channel_name = expect_property_type!(Some(value), "channels", Text); - channel_names.push(channel_name.to_string()); + let channel_name = channel_name + .parse() + .map_err(|_| ResponseParserError::SyntaxError(0, channel_name))?; + channel_names.push(channel_name); } Ok(ChannelsResponse { @@ -64,7 +67,11 @@ mod tests { assert_eq!( response, ChannelsResponse { - channels: vec!["foo".to_string(), "bar".to_string(), "baz".to_string()] + channels: vec![ + "foo".parse().unwrap(), + "bar".parse().unwrap(), + "baz".parse().unwrap(), + ] } ); } diff --git a/src/commands/client_to_client/readmessages.rs b/src/commands/client_to_client/readmessages.rs index c67fde5..d7b2d22 100644 --- a/src/commands/client_to_client/readmessages.rs +++ b/src/commands/client_to_client/readmessages.rs @@ -47,7 +47,11 @@ impl Command for ReadMessages { debug_assert!(ckey == "channel"); debug_assert!(mkey == "message"); - let channel = expect_property_type!(Some(cvalue), "channel", Text).to_string(); + let channel = expect_property_type!(Some(cvalue), "channel", Text); + let channel = channel + .parse() + .map_err(|_| ResponseParserError::SyntaxError(0, channel))?; + let message = expect_property_type!(Some(mvalue), "message", Text).to_string(); messages.push(ReadMessagesResponseEntry { channel, message }); @@ -77,11 +81,11 @@ mod tests { result, Ok(vec![ ReadMessagesResponseEntry { - channel: "channel1".to_string(), + channel: "channel1".parse().unwrap(), message: "message1".to_string(), }, ReadMessagesResponseEntry { - channel: "channel2".to_string(), + channel: "channel2".parse().unwrap(), message: "message2".to_string(), }, ]) diff --git a/src/commands/client_to_client/sendmessage.rs b/src/commands/client_to_client/sendmessage.rs index dc9b1eb..2f6b974 100644 --- a/src/commands/client_to_client/sendmessage.rs +++ b/src/commands/client_to_client/sendmessage.rs @@ -26,13 +26,16 @@ impl Command for SendMessage { fn parse_request(mut parts: RequestTokenizer<'_>) -> RequestParserResult<'_> { let channel = parts.next().ok_or(RequestParserError::UnexpectedEOF)?; + let channel = channel + .parse() + .map_err(|_| RequestParserError::SyntaxError(0, channel.to_owned()))?; // TODO: SplitWhitespace::remainder() is unstable, use when stable let message = parts.collect::>().join(" "); debug_assert!(!message.is_empty()); - Ok((Request::SendMessage(channel.to_string(), message), "")) + Ok((Request::SendMessage(channel, message), "")) } fn parse_response( diff --git a/src/commands/client_to_client/subscribe.rs b/src/commands/client_to_client/subscribe.rs index c054be5..1a8722a 100644 --- a/src/commands/client_to_client/subscribe.rs +++ b/src/commands/client_to_client/subscribe.rs @@ -18,10 +18,13 @@ impl Command for Subscribe { fn parse_request(mut parts: RequestTokenizer<'_>) -> RequestParserResult<'_> { let channel_name = parts.next().ok_or(RequestParserError::UnexpectedEOF)?; + let channel_name = channel_name + .parse() + .map_err(|_| RequestParserError::SyntaxError(0, channel_name.to_owned()))?; debug_assert!(parts.next().is_none()); - Ok((Request::Subscribe(channel_name.to_string()), "")) + Ok((Request::Subscribe(channel_name), "")) } fn parse_response( diff --git a/src/commands/client_to_client/unsubscribe.rs b/src/commands/client_to_client/unsubscribe.rs index bb7b475..200f2ca 100644 --- a/src/commands/client_to_client/unsubscribe.rs +++ b/src/commands/client_to_client/unsubscribe.rs @@ -18,10 +18,13 @@ impl Command for Unsubscribe { fn parse_request(mut parts: RequestTokenizer<'_>) -> RequestParserResult<'_> { let channel_name = parts.next().ok_or(RequestParserError::UnexpectedEOF)?; + let channel_name = channel_name + .parse() + .map_err(|_| RequestParserError::SyntaxError(0, channel_name.to_owned()))?; debug_assert!(parts.next().is_none()); - Ok((Request::Unsubscribe(channel_name.to_string()), "")) + Ok((Request::Unsubscribe(channel_name), "")) } fn parse_response( diff --git a/src/common/types.rs b/src/common/types.rs index a80b586..7a36b7a 100644 --- a/src/common/types.rs +++ b/src/common/types.rs @@ -1,6 +1,7 @@ mod absolute_relative_song_position; mod audio; mod bool_or_oneshot; +mod channel_name; mod group_type; mod one_or_range; mod replay_gain_mode_mode; @@ -16,6 +17,7 @@ mod window_range; pub use absolute_relative_song_position::AbsouluteRelativeSongPosition; pub use audio::Audio; pub use bool_or_oneshot::BoolOrOneshot; +pub use channel_name::ChannelName; pub use group_type::GroupType; pub use one_or_range::OneOrRange; pub use replay_gain_mode_mode::ReplayGainModeMode; @@ -39,7 +41,6 @@ pub type TimeWithFractions = f64; // TODO: use a proper types pub type AudioOutputId = String; -pub type ChannelName = String; pub type Feature = String; pub type PartitionName = String; pub type Path = String; diff --git a/src/common/types/channel_name.rs b/src/common/types/channel_name.rs new file mode 100644 index 0000000..44b6b19 --- /dev/null +++ b/src/common/types/channel_name.rs @@ -0,0 +1,39 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct ChannelName(String); + +impl ChannelName { + fn new(name: String) -> Self { + ChannelName(name) + } + + pub fn as_str(&self) -> &str { + &self.0 + } + + pub fn into_inner(self) -> String { + self.0 + } +} + +impl std::fmt::Display for ChannelName { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl std::str::FromStr for ChannelName { + type Err = &'static str; + + fn from_str(s: &str) -> Result { + if !s + .chars() + .all(|c| c.is_ascii_alphanumeric() || "-_.:".contains(c)) + { + Err("Invalid channel name") + } else { + Ok(ChannelName::new(s.to_string())) + } + } +}