core/common: make testable, fix some status messages

This commit is contained in:
Oystein Kristoffer Tveit 2024-08-08 19:30:27 +02:00
parent 69870147f5
commit 39a3f8ffd1
Signed by: oysteikt
GPG Key ID: 9F2F7D8250F35146
6 changed files with 266 additions and 103 deletions

View File

@ -1,30 +1,7 @@
use crate::core::common::{
get_current_unix_user, validate_name_token, validate_ownership_by_user_prefix,
get_current_unix_user, validate_name_or_error, validate_ownership_or_error, DbOrUser,
};
/// This enum is used to differentiate between database and user operations.
/// Their output are very similar, but there are slight differences in the words used.
pub enum DbOrUser {
Database,
User,
}
impl DbOrUser {
pub fn lowercased(&self) -> String {
match self {
DbOrUser::Database => "database".to_string(),
DbOrUser::User => "user".to_string(),
}
}
pub fn capitalized(&self) -> String {
match self {
DbOrUser::Database => "Database".to_string(),
DbOrUser::User => "User".to_string(),
}
}
}
/// In contrast to the new implementation which reports errors on any invalid name
/// for any reason, mysql-admutils would only log the error and skip that particular
/// name. This function replicates that behavior.
@ -45,7 +22,7 @@ pub fn filter_db_or_user_names(
// here.
.map(|name| name.chars().take(32).collect::<String>())
.filter(|name| {
if let Err(_err) = validate_ownership_by_user_prefix(name, &unix_user) {
if let Err(_err) = validate_ownership_or_error(name, &unix_user, db_or_user) {
println!(
"You are not in charge of mysql-{}: '{}'. Skipping.",
db_or_user.lowercased(),
@ -60,7 +37,7 @@ pub fn filter_db_or_user_names(
// the name is already truncated to 32 characters. So
// if there is an error, it's guaranteed to be due to
// invalid characters.
if let Err(_err) = validate_name_token(name) {
if let Err(_err) = validate_name_or_error(name, db_or_user) {
println!(
concat!(
"{}: {} name '{}' contains invalid characters.\n",

View File

@ -2,12 +2,9 @@ use clap::Parser;
use sqlx::MySqlConnection;
use crate::{
cli::{
database_command,
mysql_admutils_compatibility::common::{filter_db_or_user_names, DbOrUser},
},
cli::{database_command, mysql_admutils_compatibility::common::filter_db_or_user_names},
core::{
common::yn,
common::{yn, DbOrUser},
config::{get_config, mysql_connection_from_config, GlobalConfigArgs},
database_operations::{create_database, drop_database, get_database_list},
database_privilege_operations,

View File

@ -2,12 +2,9 @@ use clap::Parser;
use sqlx::MySqlConnection;
use crate::{
cli::{
mysql_admutils_compatibility::common::{filter_db_or_user_names, DbOrUser},
user_command,
},
cli::{mysql_admutils_compatibility::common::filter_db_or_user_names, user_command},
core::{
common::{close_database_connection, get_current_unix_user},
common::{close_database_connection, get_current_unix_user, DbOrUser},
config::{get_config, mysql_connection_from_config, GlobalConfigArgs},
user_operations::*,
},

View File

@ -87,84 +87,165 @@ pub fn create_user_group_matching_regex(user: &User) -> String {
}
}
pub fn validate_name_token(name: &str) -> anyhow::Result<()> {
/// This enum is used to differentiate between database and user operations.
/// Their output are very similar, but there are slight differences in the words used.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum DbOrUser {
Database,
User,
}
impl DbOrUser {
pub fn lowercased(&self) -> String {
match self {
DbOrUser::Database => "database".to_string(),
DbOrUser::User => "user".to_string(),
}
}
pub fn capitalized(&self) -> String {
match self {
DbOrUser::Database => "Database".to_string(),
DbOrUser::User => "User".to_string(),
}
}
}
#[derive(Debug, PartialEq, Eq)]
pub enum NameValidationResult {
Valid,
EmptyString,
InvalidCharacters,
TooLong,
}
pub fn validate_name(name: &str) -> NameValidationResult {
if name.is_empty() {
anyhow::bail!("Database name cannot be empty.");
}
if name.len() > 64 {
anyhow::bail!("Database name is too long. Maximum length is 64 characters.");
}
if !name
NameValidationResult::EmptyString
} else if name.len() > 64 {
NameValidationResult::TooLong
} else if !name
.chars()
.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
{
anyhow::bail!(
NameValidationResult::InvalidCharacters
} else {
NameValidationResult::Valid
}
}
pub fn validate_name_or_error(name: &str, db_or_user: DbOrUser) -> anyhow::Result<()> {
match validate_name(name) {
NameValidationResult::Valid => Ok(()),
NameValidationResult::EmptyString => {
anyhow::bail!("{} name cannot be empty.", db_or_user.capitalized())
}
NameValidationResult::TooLong => anyhow::bail!(
"{} is too long. Maximum length is 64 characters.",
db_or_user.capitalized()
),
NameValidationResult::InvalidCharacters => anyhow::bail!(
indoc! {r#"
Invalid characters in name: '{}'
Invalid characters in {} name: '{}'
Only A-Z, a-z, 0-9, _ (underscore) and - (dash) are permitted.
"#},
db_or_user.lowercased(),
name
);
),
}
Ok(())
}
pub fn validate_ownership_by_user_prefix<'a>(
#[derive(Debug, PartialEq, Eq)]
pub enum OwnerValidationResult {
// The name is valid and matches one of the given prefixes
Match,
// The name is valid, but none of the given prefixes matched the name
NoMatch,
// The name is empty, which is invalid
StringEmpty,
// The name is in the format "_<postfix>", which is invalid
MissingPrefix,
// The name is in the format "<prefix>_", which is invalid
MissingPostfix,
}
/// Core logic for validating the ownership of a database name.
/// This function checks if the given name matches any of the given prefixes.
/// These prefixes will in most cases be the user's unix username and any
/// unix groups the user is a member of.
pub fn validate_ownership_by_prefixes(name: &str, prefixes: &[String]) -> OwnerValidationResult {
if name.is_empty() {
return OwnerValidationResult::StringEmpty;
}
if name.starts_with('_') {
return OwnerValidationResult::MissingPrefix;
}
let (prefix, _) = match name.split_once('_') {
Some(pair) => pair,
None => return OwnerValidationResult::MissingPostfix,
};
if prefixes.iter().any(|g| g == prefix) {
OwnerValidationResult::Match
} else {
OwnerValidationResult::NoMatch
}
}
/// Validate the ownership of a database name or database user name.
/// This function takes the name of a database or user and a unix user,
/// for which it fetches the user's groups. It then checks if the name
/// is prefixed with the user's username or any of the user's groups.
pub fn validate_ownership_or_error<'a>(
name: &'a str,
user: &User,
db_or_user: DbOrUser,
) -> anyhow::Result<&'a str> {
let user_groups = get_unix_groups(user)?;
let prefixes = std::iter::once(user.name.clone())
.chain(user_groups.iter().map(|g| g.name.clone()))
.collect::<Vec<String>>();
let mut split_name = name.split('_');
match validate_ownership_by_prefixes(name, &prefixes) {
OwnerValidationResult::Match => Ok(name),
OwnerValidationResult::NoMatch => {
anyhow::bail!(
indoc! {r#"
Invalid {} name prefix: '{}' does not match your username or any of your groups.
Are you sure you are allowed to create {} names with this prefix?
let prefix = split_name
.next()
.ok_or(anyhow::anyhow!(indoc! {r#"
Failed to find prefix.
"#},))
.and_then(|prefix| {
if user.name == prefix || user_groups.iter().any(|g| g.name == prefix) {
Ok(prefix)
} else {
anyhow::bail!(
indoc! {r#"
Invalid prefix: '{}' does not match your username or any of your groups.
Are you sure you are allowed to create databases or users with this prefix?
Allowed prefixes:
- {}
{}
"#},
prefix,
user.name,
user_groups
.iter()
.filter(|g| g.name != user.name)
.map(|g| format!(" - {}", g.name))
.sorted()
.join("\n"),
);
}
})?;
if !split_name.next().is_some_and(|s| !s.is_empty()) {
anyhow::bail!(
indoc! {r#"
Missing the rest of the name after the user/group prefix.
The name should be in the format: '{}_<name>'
"#},
prefix
);
Allowed prefixes:
- {}
{}
"#},
db_or_user.lowercased(),
name,
db_or_user.lowercased(),
user.name,
user_groups
.iter()
.filter(|g| g.name != user.name)
.map(|g| format!(" - {}", g.name))
.sorted()
.join("\n"),
);
}
_ => anyhow::bail!(
"'{}' is not a valid {} name.",
name,
db_or_user.lowercased()
),
}
Ok(prefix)
}
/// Gracefully close a MySQL connection.
pub async fn close_database_connection(connection: MySqlConnection) {
if let Err(e) = connection
.close()
@ -203,3 +284,113 @@ pub(crate) fn rev_yn(s: &str) -> Option<bool> {
_ => None,
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_yn() {
assert_eq!(yn(true), "Y");
assert_eq!(yn(false), "N");
}
#[test]
fn test_rev_yn() {
assert_eq!(rev_yn("Y"), Some(true));
assert_eq!(rev_yn("y"), Some(true));
assert_eq!(rev_yn("N"), Some(false));
assert_eq!(rev_yn("n"), Some(false));
assert_eq!(rev_yn("X"), None);
}
#[test]
fn test_quote_literal() {
let payload = "' OR 1=1 --";
assert_eq!(quote_literal(payload), r#"'\' OR 1=1 --'"#);
}
#[test]
fn test_quote_identifier() {
let payload = "` OR 1=1 --";
assert_eq!(quote_identifier(payload), r#"`\` OR 1=1 --`"#);
}
#[test]
fn test_validate_name() {
assert_eq!(validate_name(""), NameValidationResult::EmptyString);
assert_eq!(
validate_name("abcdefghijklmnopqrstuvwxyz"),
NameValidationResult::Valid
);
assert_eq!(
validate_name("ABCDEFGHIJKLMNOPQRSTUVWXYZ"),
NameValidationResult::Valid
);
assert_eq!(validate_name("0123456789_-"), NameValidationResult::Valid);
for c in "\n\t\r !@#$%^&*()+=[]{}|;:,.<>?/".chars() {
assert_eq!(
validate_name(&c.to_string()),
NameValidationResult::InvalidCharacters
);
}
assert_eq!(validate_name(&"a".repeat(64)), NameValidationResult::Valid);
assert_eq!(
validate_name(&"a".repeat(65)),
NameValidationResult::TooLong
);
}
#[test]
fn test_validate_owner_by_prefixes() {
let prefixes = vec!["user".to_string(), "group".to_string()];
assert_eq!(
validate_ownership_by_prefixes("", &prefixes),
OwnerValidationResult::StringEmpty
);
assert_eq!(
validate_ownership_by_prefixes("user", &prefixes),
OwnerValidationResult::MissingPostfix
);
assert_eq!(
validate_ownership_by_prefixes("something", &prefixes),
OwnerValidationResult::MissingPostfix
);
assert_eq!(
validate_ownership_by_prefixes("user-testdb", &prefixes),
OwnerValidationResult::MissingPostfix
);
assert_eq!(
validate_ownership_by_prefixes("_testdb", &prefixes),
OwnerValidationResult::MissingPrefix
);
assert_eq!(
validate_ownership_by_prefixes("user_testdb", &prefixes),
OwnerValidationResult::Match
);
assert_eq!(
validate_ownership_by_prefixes("group_testdb", &prefixes),
OwnerValidationResult::Match
);
assert_eq!(
validate_ownership_by_prefixes("group_test_db", &prefixes),
OwnerValidationResult::Match
);
assert_eq!(
validate_ownership_by_prefixes("group_test-db", &prefixes),
OwnerValidationResult::Match
);
assert_eq!(
validate_ownership_by_prefixes("nonexistent_testdb", &prefixes),
OwnerValidationResult::NoMatch
);
}
}

View File

@ -8,7 +8,7 @@ use sqlx::{prelude::*, MySqlConnection};
use crate::core::{
common::{
create_user_group_matching_regex, get_current_unix_user, quote_identifier,
validate_name_token, validate_ownership_by_user_prefix,
validate_name_or_error, validate_ownership_or_error, DbOrUser,
},
database_privilege_operations::DATABASE_PRIVILEGE_FIELDS,
};
@ -112,8 +112,10 @@ pub async fn get_databases_where_user_has_privileges(
/// the database name as a parameter to the query. This means that we have
/// to validate the database name ourselves to prevent SQL injection.
pub fn validate_database_name(name: &str, user: &User) -> anyhow::Result<()> {
validate_name_token(name).context("Invalid database name")?;
validate_ownership_by_user_prefix(name, user).context("Invalid database name")?;
validate_name_or_error(name, DbOrUser::Database)
.context(format!("Invalid database name: '{}'", name))?;
validate_ownership_or_error(name, user, DbOrUser::Database)
.context(format!("Invalid database name: '{}'", name))?;
Ok(())
}

View File

@ -3,11 +3,9 @@ use nix::unistd::User;
use serde::{Deserialize, Serialize};
use sqlx::{prelude::*, MySqlConnection};
use crate::core::common::quote_literal;
use super::common::{
create_user_group_matching_regex, get_current_unix_user, validate_name_token,
validate_ownership_by_user_prefix,
use crate::core::common::{
create_user_group_matching_regex, get_current_unix_user, quote_literal, validate_name_or_error,
validate_ownership_or_error, DbOrUser,
};
pub async fn user_exists(db_user: &str, connection: &mut MySqlConnection) -> anyhow::Result<bool> {
@ -242,8 +240,9 @@ pub async fn get_database_user_for_user(
/// the database name as a parameter to the query. This means that we have
/// to validate the database name ourselves to prevent SQL injection.
pub fn validate_user_name(name: &str, user: &User) -> anyhow::Result<()> {
validate_name_token(name).context(format!("Invalid username: '{}'", name))?;
validate_ownership_by_user_prefix(name, user)
validate_name_or_error(name, DbOrUser::User)
.context(format!("Invalid username: '{}'", name))?;
validate_ownership_or_error(name, user, DbOrUser::User)
.context(format!("Invalid username: '{}'", name))?;
Ok(())