diff --git a/src/proto/rwhod_protocol.rs b/src/proto/rwhod_protocol.rs index 809abfb..9f0f28c 100644 --- a/src/proto/rwhod_protocol.rs +++ b/src/proto/rwhod_protocol.rs @@ -172,10 +172,17 @@ impl Whod { let mut bytes = bytes::Bytes::copy_from_slice(input); let wd_vers = bytes.get_u8(); - debug_assert!(wd_vers == Self::WHODVERSION); + if wd_vers != Self::WHODVERSION { + return Err(anyhow::anyhow!( + "Unsupported whod protocol version: {}", + wd_vers + )); + } let wd_type = bytes.get_u8(); - debug_assert!(wd_type == Self::WHODTYPE_STATUS); + if wd_type != Self::WHODTYPE_STATUS { + return Err(anyhow::anyhow!("Unsupported whod packet type: {}", wd_type)); + } bytes.advance(2); // skip wd_pad @@ -521,4 +528,30 @@ mod tests { assert_eq!(original_status, final_status); } + + #[test] + fn test_parser_invalid_bytes() { + // Too short + let short_bytes = vec![0u8; Whod::HEADER_SIZE - 1]; + assert!(Whod::from_bytes(&short_bytes).is_err()); + + // Too long + let long_bytes = vec![0u8; Whod::MAX_SIZE + 1]; + assert!(Whod::from_bytes(&long_bytes).is_err()); + + // Misaligned length + let misaligned_bytes = vec![0u8; Whod::HEADER_SIZE + 1]; + assert!(Whod::from_bytes(&misaligned_bytes).is_err()); + + // Invalid version + let mut invalid_version_bytes = vec![0u8; Whod::HEADER_SIZE]; + invalid_version_bytes[0] = 99; // invalid version + assert!(Whod::from_bytes(&invalid_version_bytes).is_err()); + + // Invalid packet type + let mut invalid_type_bytes = vec![0u8; Whod::HEADER_SIZE]; + invalid_type_bytes[0] = Whod::WHODVERSION; + invalid_type_bytes[1] = 99; // invalid type + assert!(Whod::from_bytes(&invalid_type_bytes).is_err()); + } }