From 39a3f8ffd135553e6e1248552db6803d4b41567a Mon Sep 17 00:00:00 2001
From: h7x4 <h7x4@nani.wtf>
Date: Thu, 8 Aug 2024 19:30:27 +0200
Subject: [PATCH] core/common: make testable, fix some status messages

---
 .../mysql_admutils_compatibility/common.rs    |  29 +-
 .../mysql_dbadm.rs                            |   7 +-
 .../mysql_useradm.rs                          |   7 +-
 src/core/common.rs                            | 305 ++++++++++++++----
 src/core/database_operations.rs               |   8 +-
 src/core/user_operations.rs                   |  13 +-
 6 files changed, 266 insertions(+), 103 deletions(-)

diff --git a/src/cli/mysql_admutils_compatibility/common.rs b/src/cli/mysql_admutils_compatibility/common.rs
index 1a53f0f..506b9c0 100644
--- a/src/cli/mysql_admutils_compatibility/common.rs
+++ b/src/cli/mysql_admutils_compatibility/common.rs
@@ -1,30 +1,7 @@
 use crate::core::common::{
-    get_current_unix_user, validate_name_token, validate_ownership_by_user_prefix,
+    get_current_unix_user, validate_name_or_error, validate_ownership_or_error, DbOrUser,
 };
 
-/// This enum is used to differentiate between database and user operations.
-/// Their output are very similar, but there are slight differences in the words used.
-pub enum DbOrUser {
-    Database,
-    User,
-}
-
-impl DbOrUser {
-    pub fn lowercased(&self) -> String {
-        match self {
-            DbOrUser::Database => "database".to_string(),
-            DbOrUser::User => "user".to_string(),
-        }
-    }
-
-    pub fn capitalized(&self) -> String {
-        match self {
-            DbOrUser::Database => "Database".to_string(),
-            DbOrUser::User => "User".to_string(),
-        }
-    }
-}
-
 /// In contrast to the new implementation which reports errors on any invalid name
 /// for any reason, mysql-admutils would only log the error and skip that particular
 /// name. This function replicates that behavior.
@@ -45,7 +22,7 @@ pub fn filter_db_or_user_names(
         //       here.
         .map(|name| name.chars().take(32).collect::<String>())
         .filter(|name| {
-            if let Err(_err) = validate_ownership_by_user_prefix(name, &unix_user) {
+            if let Err(_err) = validate_ownership_or_error(name, &unix_user, db_or_user) {
                 println!(
                     "You are not in charge of mysql-{}: '{}'.  Skipping.",
                     db_or_user.lowercased(),
@@ -60,7 +37,7 @@ pub fn filter_db_or_user_names(
             //       the name is already truncated to 32 characters. So
             //       if there is an error, it's guaranteed to be due to
             //       invalid characters.
-            if let Err(_err) = validate_name_token(name) {
+            if let Err(_err) = validate_name_or_error(name, db_or_user) {
                 println!(
                     concat!(
                         "{}: {} name '{}' contains invalid characters.\n",
diff --git a/src/cli/mysql_admutils_compatibility/mysql_dbadm.rs b/src/cli/mysql_admutils_compatibility/mysql_dbadm.rs
index 2c0ba21..34c6bc4 100644
--- a/src/cli/mysql_admutils_compatibility/mysql_dbadm.rs
+++ b/src/cli/mysql_admutils_compatibility/mysql_dbadm.rs
@@ -2,12 +2,9 @@ use clap::Parser;
 use sqlx::MySqlConnection;
 
 use crate::{
-    cli::{
-        database_command,
-        mysql_admutils_compatibility::common::{filter_db_or_user_names, DbOrUser},
-    },
+    cli::{database_command, mysql_admutils_compatibility::common::filter_db_or_user_names},
     core::{
-        common::yn,
+        common::{yn, DbOrUser},
         config::{get_config, mysql_connection_from_config, GlobalConfigArgs},
         database_operations::{create_database, drop_database, get_database_list},
         database_privilege_operations,
diff --git a/src/cli/mysql_admutils_compatibility/mysql_useradm.rs b/src/cli/mysql_admutils_compatibility/mysql_useradm.rs
index a44a00f..08614fc 100644
--- a/src/cli/mysql_admutils_compatibility/mysql_useradm.rs
+++ b/src/cli/mysql_admutils_compatibility/mysql_useradm.rs
@@ -2,12 +2,9 @@ use clap::Parser;
 use sqlx::MySqlConnection;
 
 use crate::{
-    cli::{
-        mysql_admutils_compatibility::common::{filter_db_or_user_names, DbOrUser},
-        user_command,
-    },
+    cli::{mysql_admutils_compatibility::common::filter_db_or_user_names, user_command},
     core::{
-        common::{close_database_connection, get_current_unix_user},
+        common::{close_database_connection, get_current_unix_user, DbOrUser},
         config::{get_config, mysql_connection_from_config, GlobalConfigArgs},
         user_operations::*,
     },
diff --git a/src/core/common.rs b/src/core/common.rs
index 5c27e43..9beb916 100644
--- a/src/core/common.rs
+++ b/src/core/common.rs
@@ -87,84 +87,165 @@ pub fn create_user_group_matching_regex(user: &User) -> String {
     }
 }
 
-pub fn validate_name_token(name: &str) -> anyhow::Result<()> {
+/// This enum is used to differentiate between database and user operations.
+/// Their output are very similar, but there are slight differences in the words used.
+#[derive(Debug, PartialEq, Eq, Clone, Copy)]
+pub enum DbOrUser {
+    Database,
+    User,
+}
+
+impl DbOrUser {
+    pub fn lowercased(&self) -> String {
+        match self {
+            DbOrUser::Database => "database".to_string(),
+            DbOrUser::User => "user".to_string(),
+        }
+    }
+
+    pub fn capitalized(&self) -> String {
+        match self {
+            DbOrUser::Database => "Database".to_string(),
+            DbOrUser::User => "User".to_string(),
+        }
+    }
+}
+
+#[derive(Debug, PartialEq, Eq)]
+pub enum NameValidationResult {
+    Valid,
+    EmptyString,
+    InvalidCharacters,
+    TooLong,
+}
+
+pub fn validate_name(name: &str) -> NameValidationResult {
     if name.is_empty() {
-        anyhow::bail!("Database name cannot be empty.");
-    }
-
-    if name.len() > 64 {
-        anyhow::bail!("Database name is too long. Maximum length is 64 characters.");
-    }
-
-    if !name
+        NameValidationResult::EmptyString
+    } else if name.len() > 64 {
+        NameValidationResult::TooLong
+    } else if !name
         .chars()
         .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
     {
-        anyhow::bail!(
+        NameValidationResult::InvalidCharacters
+    } else {
+        NameValidationResult::Valid
+    }
+}
+
+pub fn validate_name_or_error(name: &str, db_or_user: DbOrUser) -> anyhow::Result<()> {
+    match validate_name(name) {
+        NameValidationResult::Valid => Ok(()),
+        NameValidationResult::EmptyString => {
+            anyhow::bail!("{} name cannot be empty.", db_or_user.capitalized())
+        }
+        NameValidationResult::TooLong => anyhow::bail!(
+            "{} is too long. Maximum length is 64 characters.",
+            db_or_user.capitalized()
+        ),
+        NameValidationResult::InvalidCharacters => anyhow::bail!(
             indoc! {r#"
-              Invalid characters in name: '{}'
+              Invalid characters in {} name: '{}'
 
               Only A-Z, a-z, 0-9, _ (underscore) and - (dash) are permitted.
             "#},
+            db_or_user.lowercased(),
             name
-        );
+        ),
     }
-
-    Ok(())
 }
 
-pub fn validate_ownership_by_user_prefix<'a>(
+#[derive(Debug, PartialEq, Eq)]
+pub enum OwnerValidationResult {
+    // The name is valid and matches one of the given prefixes
+    Match,
+
+    // The name is valid, but none of the given prefixes matched the name
+    NoMatch,
+
+    // The name is empty, which is invalid
+    StringEmpty,
+
+    // The name is in the format "_<postfix>", which is invalid
+    MissingPrefix,
+
+    // The name is in the format "<prefix>_", which is invalid
+    MissingPostfix,
+}
+
+/// Core logic for validating the ownership of a database name.
+/// This function checks if the given name matches any of the given prefixes.
+/// These prefixes will in most cases be the user's unix username and any
+/// unix groups the user is a member of.
+pub fn validate_ownership_by_prefixes(name: &str, prefixes: &[String]) -> OwnerValidationResult {
+    if name.is_empty() {
+        return OwnerValidationResult::StringEmpty;
+    }
+
+    if name.starts_with('_') {
+        return OwnerValidationResult::MissingPrefix;
+    }
+
+    let (prefix, _) = match name.split_once('_') {
+        Some(pair) => pair,
+        None => return OwnerValidationResult::MissingPostfix,
+    };
+
+    if prefixes.iter().any(|g| g == prefix) {
+        OwnerValidationResult::Match
+    } else {
+        OwnerValidationResult::NoMatch
+    }
+}
+
+/// Validate the ownership of a database name or database user name.
+/// This function takes the name of a database or user and a unix user,
+/// for which it fetches the user's groups. It then checks if the name
+/// is prefixed with the user's username or any of the user's groups.
+pub fn validate_ownership_or_error<'a>(
     name: &'a str,
     user: &User,
+    db_or_user: DbOrUser,
 ) -> anyhow::Result<&'a str> {
     let user_groups = get_unix_groups(user)?;
+    let prefixes = std::iter::once(user.name.clone())
+        .chain(user_groups.iter().map(|g| g.name.clone()))
+        .collect::<Vec<String>>();
 
-    let mut split_name = name.split('_');
+    match validate_ownership_by_prefixes(name, &prefixes) {
+        OwnerValidationResult::Match => Ok(name),
+        OwnerValidationResult::NoMatch => {
+            anyhow::bail!(
+                indoc! {r#"
+                  Invalid {} name prefix: '{}' does not match your username or any of your groups.
+                  Are you sure you are allowed to create {} names with this prefix?
 
-    let prefix = split_name
-        .next()
-        .ok_or(anyhow::anyhow!(indoc! {r#"
-              Failed to find prefix.
-            "#},))
-        .and_then(|prefix| {
-            if user.name == prefix || user_groups.iter().any(|g| g.name == prefix) {
-                Ok(prefix)
-            } else {
-                anyhow::bail!(
-                    indoc! {r#"
-                      Invalid prefix: '{}' does not match your username or any of your groups.
-                      Are you sure you are allowed to create databases or users with this prefix?
-
-                      Allowed prefixes:
-                        - {}
-                      {}
-                    "#},
-                    prefix,
-                    user.name,
-                    user_groups
-                        .iter()
-                        .filter(|g| g.name != user.name)
-                        .map(|g| format!("  - {}", g.name))
-                        .sorted()
-                        .join("\n"),
-                );
-            }
-        })?;
-
-    if !split_name.next().is_some_and(|s| !s.is_empty()) {
-        anyhow::bail!(
-            indoc! {r#"
-              Missing the rest of the name after the user/group prefix.
-
-              The name should be in the format: '{}_<name>'
-            "#},
-            prefix
-        );
+                  Allowed prefixes:
+                    - {}
+                  {}
+                "#},
+                db_or_user.lowercased(),
+                name,
+                db_or_user.lowercased(),
+                user.name,
+                user_groups
+                    .iter()
+                    .filter(|g| g.name != user.name)
+                    .map(|g| format!("  - {}", g.name))
+                    .sorted()
+                    .join("\n"),
+            );
+        }
+        _ => anyhow::bail!(
+            "'{}' is not a valid {} name.",
+            name,
+            db_or_user.lowercased()
+        ),
     }
-
-    Ok(prefix)
 }
 
+/// Gracefully close a MySQL connection.
 pub async fn close_database_connection(connection: MySqlConnection) {
     if let Err(e) = connection
         .close()
@@ -203,3 +284,113 @@ pub(crate) fn rev_yn(s: &str) -> Option<bool> {
         _ => None,
     }
 }
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    #[test]
+    fn test_yn() {
+        assert_eq!(yn(true), "Y");
+        assert_eq!(yn(false), "N");
+    }
+
+    #[test]
+    fn test_rev_yn() {
+        assert_eq!(rev_yn("Y"), Some(true));
+        assert_eq!(rev_yn("y"), Some(true));
+        assert_eq!(rev_yn("N"), Some(false));
+        assert_eq!(rev_yn("n"), Some(false));
+        assert_eq!(rev_yn("X"), None);
+    }
+
+    #[test]
+    fn test_quote_literal() {
+        let payload = "' OR 1=1 --";
+        assert_eq!(quote_literal(payload), r#"'\' OR 1=1 --'"#);
+    }
+
+    #[test]
+    fn test_quote_identifier() {
+        let payload = "` OR 1=1 --";
+        assert_eq!(quote_identifier(payload), r#"`\` OR 1=1 --`"#);
+    }
+
+    #[test]
+    fn test_validate_name() {
+        assert_eq!(validate_name(""), NameValidationResult::EmptyString);
+        assert_eq!(
+            validate_name("abcdefghijklmnopqrstuvwxyz"),
+            NameValidationResult::Valid
+        );
+        assert_eq!(
+            validate_name("ABCDEFGHIJKLMNOPQRSTUVWXYZ"),
+            NameValidationResult::Valid
+        );
+        assert_eq!(validate_name("0123456789_-"), NameValidationResult::Valid);
+
+        for c in "\n\t\r !@#$%^&*()+=[]{}|;:,.<>?/".chars() {
+            assert_eq!(
+                validate_name(&c.to_string()),
+                NameValidationResult::InvalidCharacters
+            );
+        }
+
+        assert_eq!(validate_name(&"a".repeat(64)), NameValidationResult::Valid);
+
+        assert_eq!(
+            validate_name(&"a".repeat(65)),
+            NameValidationResult::TooLong
+        );
+    }
+
+    #[test]
+    fn test_validate_owner_by_prefixes() {
+        let prefixes = vec!["user".to_string(), "group".to_string()];
+
+        assert_eq!(
+            validate_ownership_by_prefixes("", &prefixes),
+            OwnerValidationResult::StringEmpty
+        );
+
+        assert_eq!(
+            validate_ownership_by_prefixes("user", &prefixes),
+            OwnerValidationResult::MissingPostfix
+        );
+        assert_eq!(
+            validate_ownership_by_prefixes("something", &prefixes),
+            OwnerValidationResult::MissingPostfix
+        );
+        assert_eq!(
+            validate_ownership_by_prefixes("user-testdb", &prefixes),
+            OwnerValidationResult::MissingPostfix
+        );
+
+        assert_eq!(
+            validate_ownership_by_prefixes("_testdb", &prefixes),
+            OwnerValidationResult::MissingPrefix
+        );
+
+        assert_eq!(
+            validate_ownership_by_prefixes("user_testdb", &prefixes),
+            OwnerValidationResult::Match
+        );
+        assert_eq!(
+            validate_ownership_by_prefixes("group_testdb", &prefixes),
+            OwnerValidationResult::Match
+        );
+        assert_eq!(
+            validate_ownership_by_prefixes("group_test_db", &prefixes),
+            OwnerValidationResult::Match
+        );
+        assert_eq!(
+            validate_ownership_by_prefixes("group_test-db", &prefixes),
+            OwnerValidationResult::Match
+        );
+
+        assert_eq!(
+            validate_ownership_by_prefixes("nonexistent_testdb", &prefixes),
+            OwnerValidationResult::NoMatch
+        );
+    }
+}
diff --git a/src/core/database_operations.rs b/src/core/database_operations.rs
index 63687ad..cae9dae 100644
--- a/src/core/database_operations.rs
+++ b/src/core/database_operations.rs
@@ -8,7 +8,7 @@ use sqlx::{prelude::*, MySqlConnection};
 use crate::core::{
     common::{
         create_user_group_matching_regex, get_current_unix_user, quote_identifier,
-        validate_name_token, validate_ownership_by_user_prefix,
+        validate_name_or_error, validate_ownership_or_error, DbOrUser,
     },
     database_privilege_operations::DATABASE_PRIVILEGE_FIELDS,
 };
@@ -112,8 +112,10 @@ pub async fn get_databases_where_user_has_privileges(
 ///       the database name as a parameter to the query. This means that we have
 ///       to validate the database name ourselves to prevent SQL injection.
 pub fn validate_database_name(name: &str, user: &User) -> anyhow::Result<()> {
-    validate_name_token(name).context("Invalid database name")?;
-    validate_ownership_by_user_prefix(name, user).context("Invalid database name")?;
+    validate_name_or_error(name, DbOrUser::Database)
+        .context(format!("Invalid database name: '{}'", name))?;
+    validate_ownership_or_error(name, user, DbOrUser::Database)
+        .context(format!("Invalid database name: '{}'", name))?;
 
     Ok(())
 }
diff --git a/src/core/user_operations.rs b/src/core/user_operations.rs
index d567f66..5d3ac8a 100644
--- a/src/core/user_operations.rs
+++ b/src/core/user_operations.rs
@@ -3,11 +3,9 @@ use nix::unistd::User;
 use serde::{Deserialize, Serialize};
 use sqlx::{prelude::*, MySqlConnection};
 
-use crate::core::common::quote_literal;
-
-use super::common::{
-    create_user_group_matching_regex, get_current_unix_user, validate_name_token,
-    validate_ownership_by_user_prefix,
+use crate::core::common::{
+    create_user_group_matching_regex, get_current_unix_user, quote_literal, validate_name_or_error,
+    validate_ownership_or_error, DbOrUser,
 };
 
 pub async fn user_exists(db_user: &str, connection: &mut MySqlConnection) -> anyhow::Result<bool> {
@@ -242,8 +240,9 @@ pub async fn get_database_user_for_user(
 ///       the database name as a parameter to the query. This means that we have
 ///       to validate the database name ourselves to prevent SQL injection.
 pub fn validate_user_name(name: &str, user: &User) -> anyhow::Result<()> {
-    validate_name_token(name).context(format!("Invalid username: '{}'", name))?;
-    validate_ownership_by_user_prefix(name, user)
+    validate_name_or_error(name, DbOrUser::User)
+        .context(format!("Invalid username: '{}'", name))?;
+    validate_ownership_or_error(name, user, DbOrUser::User)
         .context(format!("Invalid username: '{}'", name))?;
 
     Ok(())