core/common: make testable, fix some status messages

This commit is contained in:
Oystein Kristoffer Tveit 2024-08-08 19:30:27 +02:00
parent 69870147f5
commit d8ca543087
Signed by: oysteikt
GPG Key ID: 9F2F7D8250F35146
6 changed files with 266 additions and 103 deletions

View File

@ -1,30 +1,7 @@
use crate::core::common::{ use crate::core::common::{
get_current_unix_user, validate_name_token, validate_ownership_by_user_prefix, get_current_unix_user, validate_name_or_error, validate_ownership_or_error, DbOrUser,
}; };
/// This enum is used to differentiate between database and user operations.
/// Their output are very similar, but there are slight differences in the words used.
pub enum DbOrUser {
Database,
User,
}
impl DbOrUser {
pub fn lowercased(&self) -> String {
match self {
DbOrUser::Database => "database".to_string(),
DbOrUser::User => "user".to_string(),
}
}
pub fn capitalized(&self) -> String {
match self {
DbOrUser::Database => "Database".to_string(),
DbOrUser::User => "User".to_string(),
}
}
}
/// In contrast to the new implementation which reports errors on any invalid name /// In contrast to the new implementation which reports errors on any invalid name
/// for any reason, mysql-admutils would only log the error and skip that particular /// for any reason, mysql-admutils would only log the error and skip that particular
/// name. This function replicates that behavior. /// name. This function replicates that behavior.
@ -45,7 +22,7 @@ pub fn filter_db_or_user_names(
// here. // here.
.map(|name| name.chars().take(32).collect::<String>()) .map(|name| name.chars().take(32).collect::<String>())
.filter(|name| { .filter(|name| {
if let Err(_err) = validate_ownership_by_user_prefix(name, &unix_user) { if let Err(_err) = validate_ownership_or_error(name, &unix_user, db_or_user) {
println!( println!(
"You are not in charge of mysql-{}: '{}'. Skipping.", "You are not in charge of mysql-{}: '{}'. Skipping.",
db_or_user.lowercased(), db_or_user.lowercased(),
@ -60,7 +37,7 @@ pub fn filter_db_or_user_names(
// the name is already truncated to 32 characters. So // the name is already truncated to 32 characters. So
// if there is an error, it's guaranteed to be due to // if there is an error, it's guaranteed to be due to
// invalid characters. // invalid characters.
if let Err(_err) = validate_name_token(name) { if let Err(_err) = validate_name_or_error(name, db_or_user) {
println!( println!(
concat!( concat!(
"{}: {} name '{}' contains invalid characters.\n", "{}: {} name '{}' contains invalid characters.\n",

View File

@ -2,12 +2,9 @@ use clap::Parser;
use sqlx::MySqlConnection; use sqlx::MySqlConnection;
use crate::{ use crate::{
cli::{ cli::{database_command, mysql_admutils_compatibility::common::filter_db_or_user_names},
database_command,
mysql_admutils_compatibility::common::{filter_db_or_user_names, DbOrUser},
},
core::{ core::{
common::yn, common::{yn, DbOrUser},
config::{get_config, mysql_connection_from_config, GlobalConfigArgs}, config::{get_config, mysql_connection_from_config, GlobalConfigArgs},
database_operations::{create_database, drop_database, get_database_list}, database_operations::{create_database, drop_database, get_database_list},
database_privilege_operations, database_privilege_operations,

View File

@ -2,12 +2,9 @@ use clap::Parser;
use sqlx::MySqlConnection; use sqlx::MySqlConnection;
use crate::{ use crate::{
cli::{ cli::{mysql_admutils_compatibility::common::filter_db_or_user_names, user_command},
mysql_admutils_compatibility::common::{filter_db_or_user_names, DbOrUser},
user_command,
},
core::{ core::{
common::{close_database_connection, get_current_unix_user}, common::{close_database_connection, get_current_unix_user, DbOrUser},
config::{get_config, mysql_connection_from_config, GlobalConfigArgs}, config::{get_config, mysql_connection_from_config, GlobalConfigArgs},
user_operations::*, user_operations::*,
}, },

View File

@ -87,84 +87,165 @@ pub fn create_user_group_matching_regex(user: &User) -> String {
} }
} }
pub fn validate_name_token(name: &str) -> anyhow::Result<()> { /// This enum is used to differentiate between database and user operations.
/// Their output are very similar, but there are slight differences in the words used.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum DbOrUser {
Database,
User,
}
impl DbOrUser {
pub fn lowercased(&self) -> String {
match self {
DbOrUser::Database => "database".to_string(),
DbOrUser::User => "user".to_string(),
}
}
pub fn capitalized(&self) -> String {
match self {
DbOrUser::Database => "Database".to_string(),
DbOrUser::User => "User".to_string(),
}
}
}
#[derive(Debug, PartialEq, Eq)]
pub enum NameValidationResult {
Valid,
EmptyString,
InvalidCharacters,
TooLong,
}
pub fn validate_name(name: &str) -> NameValidationResult {
if name.is_empty() { if name.is_empty() {
anyhow::bail!("Database name cannot be empty."); NameValidationResult::EmptyString
} } else if name.len() > 64 {
NameValidationResult::TooLong
if name.len() > 64 { } else if !name
anyhow::bail!("Database name is too long. Maximum length is 64 characters.");
}
if !name
.chars() .chars()
.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-') .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
{ {
anyhow::bail!( NameValidationResult::InvalidCharacters
} else {
NameValidationResult::Valid
}
}
pub fn validate_name_or_error(name: &str, db_or_user: DbOrUser) -> anyhow::Result<()> {
match validate_name(name) {
NameValidationResult::Valid => Ok(()),
NameValidationResult::EmptyString => {
anyhow::bail!("{} name cannot be empty.", db_or_user.capitalized())
}
NameValidationResult::TooLong => anyhow::bail!(
"{} is too long. Maximum length is 64 characters.",
db_or_user.capitalized()
),
NameValidationResult::InvalidCharacters => anyhow::bail!(
indoc! {r#" indoc! {r#"
Invalid characters in name: '{}' Invalid characters in {} name: '{}'
Only A-Z, a-z, 0-9, _ (underscore) and - (dash) are permitted. Only A-Z, a-z, 0-9, _ (underscore) and - (dash) are permitted.
"#}, "#},
db_or_user.lowercased(),
name name
); ),
} }
Ok(())
} }
pub fn validate_ownership_by_user_prefix<'a>( #[derive(Debug, PartialEq, Eq)]
pub enum OwnerValidationResult {
// The name is valid and matches one of the given prefixes
Match,
// The name is valid, but none of the given prefixes matched the name
NoMatch,
// The name is empty, which is invalid
StringEmpty,
// The name is in the format "_<postfix>", which is invalid
MissingPrefix,
// The name is in the format "<prefix>_", which is invalid
MissingPostfix,
}
/// Core logic for validating the ownership of a database name.
/// This function checks if the given name matches any of the given prefixes.
/// These prefixes will in most cases be the user's unix username and any
/// unix groups the user is a member of.
pub fn validate_ownership_by_prefixes(name: &str, prefixes: &[String]) -> OwnerValidationResult {
if name.is_empty() {
return OwnerValidationResult::StringEmpty;
}
if name.starts_with('_') {
return OwnerValidationResult::MissingPrefix;
}
let (prefix, _) = match name.split_once('_') {
Some(pair) => pair,
None => return OwnerValidationResult::MissingPostfix,
};
if prefixes.iter().any(|g| g == prefix) {
OwnerValidationResult::Match
} else {
OwnerValidationResult::NoMatch
}
}
/// Validate the ownership of a database name or database user name.
/// This function takes the name of a database or user and a unix user,
/// for which it fetches the user's groups. It then checks if the name
/// is prefixed with the user's username or any of the user's groups.
pub fn validate_ownership_or_error<'a>(
name: &'a str, name: &'a str,
user: &User, user: &User,
db_or_user: DbOrUser,
) -> anyhow::Result<&'a str> { ) -> anyhow::Result<&'a str> {
let user_groups = get_unix_groups(user)?; let user_groups = get_unix_groups(user)?;
let prefixes = std::iter::once(user.name.clone())
.chain(user_groups.iter().map(|g| g.name.clone()))
.collect::<Vec<String>>();
let mut split_name = name.split('_'); match validate_ownership_by_prefixes(name, &prefixes) {
OwnerValidationResult::Match => Ok(name),
OwnerValidationResult::NoMatch => {
anyhow::bail!(
indoc! {r#"
Invalid {} name prefix: '{}' does not match your username or any of your groups.
Are you sure you are allowed to create {} names with this prefix?
let prefix = split_name Allowed prefixes:
.next() - {}
.ok_or(anyhow::anyhow!(indoc! {r#" {}
Failed to find prefix. "#},
"#},)) db_or_user.lowercased(),
.and_then(|prefix| { name,
if user.name == prefix || user_groups.iter().any(|g| g.name == prefix) { db_or_user.lowercased(),
Ok(prefix) user.name,
} else { user_groups
anyhow::bail!( .iter()
indoc! {r#" .filter(|g| g.name != user.name)
Invalid prefix: '{}' does not match your username or any of your groups. .map(|g| format!(" - {}", g.name))
Are you sure you are allowed to create databases or users with this prefix? .sorted()
.join("\n"),
Allowed prefixes: );
- {} }
{} _ => anyhow::bail!(
"#}, "'{}' is not a valid {} name.",
prefix, name,
user.name, db_or_user.lowercased()
user_groups ),
.iter()
.filter(|g| g.name != user.name)
.map(|g| format!(" - {}", g.name))
.sorted()
.join("\n"),
);
}
})?;
if !split_name.next().is_some_and(|s| !s.is_empty()) {
anyhow::bail!(
indoc! {r#"
Missing the rest of the name after the user/group prefix.
The name should be in the format: '{}_<name>'
"#},
prefix
);
} }
Ok(prefix)
} }
/// Gracefully close a MySQL connection.
pub async fn close_database_connection(connection: MySqlConnection) { pub async fn close_database_connection(connection: MySqlConnection) {
if let Err(e) = connection if let Err(e) = connection
.close() .close()
@ -203,3 +284,113 @@ pub(crate) fn rev_yn(s: &str) -> Option<bool> {
_ => None, _ => None,
} }
} }
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_yn() {
assert_eq!(yn(true), "Y");
assert_eq!(yn(false), "N");
}
#[test]
fn test_rev_yn() {
assert_eq!(rev_yn("Y"), Some(true));
assert_eq!(rev_yn("y"), Some(true));
assert_eq!(rev_yn("N"), Some(false));
assert_eq!(rev_yn("n"), Some(false));
assert_eq!(rev_yn("X"), None);
}
#[test]
fn test_quote_literal() {
let payload = "' OR 1=1 --";
assert_eq!(quote_literal(payload), r#"'\' OR 1=1 --'"#);
}
#[test]
fn test_quote_identifier() {
let payload = "` OR 1=1 --";
assert_eq!(quote_identifier(payload), r#"`\` OR 1=1 --`"#);
}
#[test]
fn test_validate_name() {
assert_eq!(validate_name(""), NameValidationResult::EmptyString);
assert_eq!(
validate_name("abcdefghijklmnopqrstuvwxyz"),
NameValidationResult::Valid
);
assert_eq!(
validate_name("ABCDEFGHIJKLMNOPQRSTUVWXYZ"),
NameValidationResult::Valid
);
assert_eq!(validate_name("0123456789_-"), NameValidationResult::Valid);
for c in "\n\t\r !@#$%^&*()+=[]{}|;:,.<>?/".chars() {
assert_eq!(
validate_name(&c.to_string()),
NameValidationResult::InvalidCharacters
);
}
assert_eq!(validate_name(&"a".repeat(64)), NameValidationResult::Valid);
assert_eq!(
validate_name(&"a".repeat(65)),
NameValidationResult::TooLong
);
}
#[test]
fn test_validate_owner_by_prefixes() {
let prefixes = vec!["user".to_string(), "group".to_string()];
assert_eq!(
validate_ownership_by_prefixes("", &prefixes),
OwnerValidationResult::StringEmpty
);
assert_eq!(
validate_ownership_by_prefixes("user", &prefixes),
OwnerValidationResult::MissingPostfix
);
assert_eq!(
validate_ownership_by_prefixes("something", &prefixes),
OwnerValidationResult::MissingPostfix
);
assert_eq!(
validate_ownership_by_prefixes("user-testdb", &prefixes),
OwnerValidationResult::MissingPostfix
);
assert_eq!(
validate_ownership_by_prefixes("_testdb", &prefixes),
OwnerValidationResult::MissingPrefix
);
assert_eq!(
validate_ownership_by_prefixes("user_testdb", &prefixes),
OwnerValidationResult::Match
);
assert_eq!(
validate_ownership_by_prefixes("group_testdb", &prefixes),
OwnerValidationResult::Match
);
assert_eq!(
validate_ownership_by_prefixes("group_test_db", &prefixes),
OwnerValidationResult::Match
);
assert_eq!(
validate_ownership_by_prefixes("group_test-db", &prefixes),
OwnerValidationResult::Match
);
assert_eq!(
validate_ownership_by_prefixes("nonexistent_testdb", &prefixes),
OwnerValidationResult::NoMatch
);
}
}

View File

@ -8,7 +8,7 @@ use sqlx::{prelude::*, MySqlConnection};
use crate::core::{ use crate::core::{
common::{ common::{
create_user_group_matching_regex, get_current_unix_user, quote_identifier, create_user_group_matching_regex, get_current_unix_user, quote_identifier,
validate_name_token, validate_ownership_by_user_prefix, validate_ownership_or_error, validate_name_or_error, DbOrUser
}, },
database_privilege_operations::DATABASE_PRIVILEGE_FIELDS, database_privilege_operations::DATABASE_PRIVILEGE_FIELDS,
}; };
@ -112,8 +112,10 @@ pub async fn get_databases_where_user_has_privileges(
/// the database name as a parameter to the query. This means that we have /// the database name as a parameter to the query. This means that we have
/// to validate the database name ourselves to prevent SQL injection. /// to validate the database name ourselves to prevent SQL injection.
pub fn validate_database_name(name: &str, user: &User) -> anyhow::Result<()> { pub fn validate_database_name(name: &str, user: &User) -> anyhow::Result<()> {
validate_name_token(name).context("Invalid database name")?; validate_name_or_error(name, DbOrUser::Database)
validate_ownership_by_user_prefix(name, user).context("Invalid database name")?; .context(format!("Invalid database name: '{}'", name))?;
validate_ownership_or_error(name, user, DbOrUser::Database)
.context(format!("Invalid database name: '{}'", name))?;
Ok(()) Ok(())
} }

View File

@ -3,11 +3,9 @@ use nix::unistd::User;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sqlx::{prelude::*, MySqlConnection}; use sqlx::{prelude::*, MySqlConnection};
use crate::core::common::quote_literal; use crate::core::common::{
create_user_group_matching_regex, get_current_unix_user, quote_literal, validate_name_or_error,
use super::common::{ validate_ownership_or_error, DbOrUser,
create_user_group_matching_regex, get_current_unix_user, validate_name_token,
validate_ownership_by_user_prefix,
}; };
pub async fn user_exists(db_user: &str, connection: &mut MySqlConnection) -> anyhow::Result<bool> { pub async fn user_exists(db_user: &str, connection: &mut MySqlConnection) -> anyhow::Result<bool> {
@ -242,8 +240,9 @@ pub async fn get_database_user_for_user(
/// the database name as a parameter to the query. This means that we have /// the database name as a parameter to the query. This means that we have
/// to validate the database name ourselves to prevent SQL injection. /// to validate the database name ourselves to prevent SQL injection.
pub fn validate_user_name(name: &str, user: &User) -> anyhow::Result<()> { pub fn validate_user_name(name: &str, user: &User) -> anyhow::Result<()> {
validate_name_token(name).context(format!("Invalid username: '{}'", name))?; validate_name_or_error(name, DbOrUser::User)
validate_ownership_by_user_prefix(name, user) .context(format!("Invalid username: '{}'", name))?;
validate_ownership_or_error(name, user, DbOrUser::User)
.context(format!("Invalid username: '{}'", name))?; .context(format!("Invalid username: '{}'", name))?;
Ok(()) Ok(())