diff --git a/src/core/types.rs b/src/core/types.rs index d9981c2..4a14d1b 100644 --- a/src/core/types.rs +++ b/src/core/types.rs @@ -105,12 +105,43 @@ impl From for OsString { } } -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum DbOrUser { Database(MySQLDatabase), User(MySQLUser), } +impl Serialize for DbOrUser { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + match self { + DbOrUser::Database(db) => ("d:".to_string() + &db.to_string()).serialize(serializer), + DbOrUser::User(user) => ("u:".to_string() + &user.to_string()).serialize(serializer), + } + } +} + +impl<'de> Deserialize<'de> for DbOrUser { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + if let Some(rest) = s.strip_prefix("d:") { + Ok(DbOrUser::Database(MySQLDatabase(rest.to_string()))) + } else if let Some(rest) = s.strip_prefix("u:") { + Ok(DbOrUser::User(MySQLUser(rest.to_string()))) + } else { + Err(serde::de::Error::custom(format!( + "Invalid DbOrUser format: {}", + s + ))) + } + } +} + impl DbOrUser { #[must_use] pub fn lowercased_noun(&self) -> &'static str {