use std::collections::BTreeMap; use sqlx::MySqlConnection; use sqlx::prelude::*; use serde::{Deserialize, Serialize}; use crate::core::protocol::CompleteDatabaseNameResponse; use crate::core::protocol::request_validation::GroupDenylist; use crate::core::protocol::request_validation::validate_db_or_user_request; use crate::core::types::DbOrUser; use crate::core::types::MySQLDatabase; use crate::core::types::MySQLUser; use crate::{ core::{ common::UnixUser, protocol::{ CreateDatabaseError, CreateDatabasesResponse, DropDatabaseError, DropDatabasesResponse, ListAllDatabasesError, ListAllDatabasesResponse, ListDatabasesError, ListDatabasesResponse, }, }, server::{common::create_user_group_matching_regex, sql::quote_identifier}, }; // NOTE: this function is unsafe because it does no input validation. pub(super) async fn unsafe_database_exists( database_name: &str, connection: &mut MySqlConnection, ) -> Result { let result = sqlx::query("SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?") .bind(database_name) .fetch_optional(connection) .await; if let Err(err) = &result { tracing::error!( "Failed to check if database '{}' exists: {:?}", &database_name, err ); } Ok(result?.is_some()) } pub async fn complete_database_name( database_prefix: &str, unix_user: &UnixUser, connection: &mut MySqlConnection, _db_is_mariadb: bool, group_denylist: &GroupDenylist, ) -> CompleteDatabaseNameResponse { let result = sqlx::query( r" SELECT CAST(`SCHEMA_NAME` AS CHAR(64)) AS `database` FROM `information_schema`.`SCHEMATA` WHERE `SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys') AND `SCHEMA_NAME` REGEXP ? AND `SCHEMA_NAME` LIKE ? ", ) .bind(create_user_group_matching_regex(unix_user, group_denylist)) .bind(format!("{database_prefix}%")) .fetch_all(connection) .await; match result { Ok(rows) => rows .into_iter() .filter_map(|row| { let database: String = row.try_get("database").ok()?; Some(database.into()) }) .collect(), Err(err) => { tracing::error!( "Failed to complete database name for prefix '{}' and user '{}': {:?}", database_prefix, unix_user.username, err ); vec![] } } } pub async fn create_databases( database_names: &[MySQLDatabase], unix_user: &UnixUser, connection: &mut MySqlConnection, _db_is_mariadb: bool, group_denylist: &GroupDenylist, ) -> CreateDatabasesResponse { let mut results = BTreeMap::new(); for database_name in database_names.iter().cloned() { if let Err(err) = validate_db_or_user_request( &DbOrUser::Database(database_name.clone()), unix_user, group_denylist, ) .map_err(CreateDatabaseError::ValidationError) { results.insert(database_name.clone(), Err(err)); continue; } match unsafe_database_exists(&database_name, &mut *connection).await { Ok(true) => { results.insert( database_name.clone(), Err(CreateDatabaseError::DatabaseAlreadyExists), ); continue; } Err(err) => { results.insert( database_name.clone(), Err(CreateDatabaseError::MySqlError(err.to_string())), ); continue; } _ => {} } let result = sqlx::query(format!("CREATE DATABASE {}", quote_identifier(&database_name)).as_str()) .execute(&mut *connection) .await .map(|_| ()) .map_err(|err| CreateDatabaseError::MySqlError(err.to_string())); if let Err(err) = &result { tracing::error!("Failed to create database '{}': {:?}", &database_name, err); } results.insert(database_name, result); } results } pub async fn drop_databases( database_names: &[MySQLDatabase], unix_user: &UnixUser, connection: &mut MySqlConnection, _db_is_mariadb: bool, group_denylist: &GroupDenylist, ) -> DropDatabasesResponse { let mut results = BTreeMap::new(); for database_name in database_names.iter().cloned() { if let Err(err) = validate_db_or_user_request( &DbOrUser::Database(database_name.clone()), unix_user, group_denylist, ) .map_err(DropDatabaseError::ValidationError) { results.insert(database_name.clone(), Err(err)); continue; } match unsafe_database_exists(&database_name, &mut *connection).await { Ok(false) => { results.insert( database_name.clone(), Err(DropDatabaseError::DatabaseDoesNotExist), ); continue; } Err(err) => { results.insert( database_name.clone(), Err(DropDatabaseError::MySqlError(err.to_string())), ); continue; } _ => {} } let result = sqlx::query(format!("DROP DATABASE {}", quote_identifier(&database_name)).as_str()) .execute(&mut *connection) .await .map(|_| ()) .map_err(|err| DropDatabaseError::MySqlError(err.to_string())); if let Err(err) = &result { tracing::error!("Failed to drop database '{}': {:?}", &database_name, err); } results.insert(database_name, result); } results } #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct DatabaseRow { pub database: MySQLDatabase, pub tables: Vec, pub users: Vec, pub collation: Option, pub character_set: Option, pub size_bytes: u64, } impl FromRow<'_, sqlx::mysql::MySqlRow> for DatabaseRow { fn from_row(row: &sqlx::mysql::MySqlRow) -> Result { Ok(DatabaseRow { database: row.try_get::("database")?.into(), tables: { let s: Option = row.try_get("tables")?; s.and_then(|s| { if s.is_empty() { None } else { Some(s.split(',').map(std::borrow::ToOwned::to_owned).collect()) } }) .unwrap_or_default() }, users: { let s: Option = row.try_get("users")?; s.and_then(|s| { if s.is_empty() { None } else { Some(s.split(',').map(|s| s.to_owned().into()).collect()) } }) .unwrap_or_default() }, collation: row.try_get::, _>("collation")?, character_set: row.try_get::, _>("character_set")?, size_bytes: row.try_get::("size_bytes")?, }) } } pub async fn list_databases( database_names: &[MySQLDatabase], unix_user: &UnixUser, connection: &mut MySqlConnection, _db_is_mariadb: bool, group_denylist: &GroupDenylist, ) -> ListDatabasesResponse { let mut results = BTreeMap::new(); for database_name in database_names.iter().cloned() { if let Err(err) = validate_db_or_user_request( &DbOrUser::Database(database_name.clone()), unix_user, group_denylist, ) .map_err(ListDatabasesError::ValidationError) { results.insert(database_name.clone(), Err(err)); continue; } let result = sqlx::query_as::<_, DatabaseRow>( r" SELECT CAST(`information_schema`.`SCHEMATA`.`SCHEMA_NAME` AS CHAR(64)) AS `database`, GROUP_CONCAT(DISTINCT CAST(`information_schema`.`TABLES`.`TABLE_NAME` AS CHAR(64)) SEPARATOR ',') AS `tables`, GROUP_CONCAT(DISTINCT CAST(`mysql`.`db`.`User` AS CHAR(64)) SEPARATOR ',') AS `users`, MAX(`information_schema`.`SCHEMATA`.`DEFAULT_COLLATION_NAME`) AS `collation`, MAX(`information_schema`.`SCHEMATA`.`DEFAULT_CHARACTER_SET_NAME`) AS `character_set`, CAST(IFNULL( SUM(`information_schema`.`TABLES`.`DATA_LENGTH` + `information_schema`.`TABLES`.`INDEX_LENGTH`), 0 ) AS UNSIGNED INTEGER) AS `size_bytes` FROM `information_schema`.`SCHEMATA` LEFT OUTER JOIN `information_schema`.`TABLES` ON `information_schema`.`SCHEMATA`.`SCHEMA_NAME` = `TABLES`.`TABLE_SCHEMA` LEFT OUTER JOIN `mysql`.`db` ON `information_schema`.`SCHEMATA`.`SCHEMA_NAME` = `mysql`.`db`.`DB` WHERE `information_schema`.`SCHEMATA`.`SCHEMA_NAME` = ? GROUP BY `information_schema`.`SCHEMATA`.`SCHEMA_NAME` ", ) .bind(database_name.to_string()) .fetch_optional(&mut *connection) .await .map_err(|err| ListDatabasesError::MySqlError(err.to_string())) .and_then(|database| { database.map_or_else(|| Err(ListDatabasesError::DatabaseDoesNotExist), Ok) }); if let Err(err) = &result { tracing::error!("Failed to list database '{}': {:?}", &database_name, err); } // TODO: should we assert that the users are also owned by the unix_user from the request? results.insert(database_name, result); } results } pub async fn list_all_databases_for_user( unix_user: &UnixUser, connection: &mut MySqlConnection, _db_is_mariadb: bool, group_denylist: &GroupDenylist, ) -> ListAllDatabasesResponse { let result = sqlx::query_as::<_, DatabaseRow>( r" SELECT CAST(`information_schema`.`SCHEMATA`.`SCHEMA_NAME` AS CHAR(64)) AS `database`, GROUP_CONCAT(DISTINCT CAST(`information_schema`.`TABLES`.`TABLE_NAME` AS CHAR(64)) SEPARATOR ',') AS `tables`, GROUP_CONCAT(DISTINCT CAST(`mysql`.`db`.`User` AS CHAR(64)) SEPARATOR ',') AS `users`, MAX(`information_schema`.`SCHEMATA`.`DEFAULT_COLLATION_NAME`) AS `collation`, MAX(`information_schema`.`SCHEMATA`.`DEFAULT_CHARACTER_SET_NAME`) AS `character_set`, CAST(IFNULL( SUM(`information_schema`.`TABLES`.`DATA_LENGTH` + `information_schema`.`TABLES`.`INDEX_LENGTH`), 0 ) AS UNSIGNED INTEGER) AS `size_bytes` FROM `information_schema`.`SCHEMATA` LEFT OUTER JOIN `information_schema`.`TABLES` ON `information_schema`.`SCHEMATA`.`SCHEMA_NAME` = `TABLES`.`TABLE_SCHEMA` LEFT OUTER JOIN `mysql`.`db` ON `information_schema`.`SCHEMATA`.`SCHEMA_NAME` = `mysql`.`db`.`DB` WHERE `information_schema`.`SCHEMATA`.`SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys') AND `information_schema`.`SCHEMATA`.`SCHEMA_NAME` REGEXP ? GROUP BY `information_schema`.`SCHEMATA`.`SCHEMA_NAME` ", ) .bind(create_user_group_matching_regex(unix_user, group_denylist)) .fetch_all(connection) .await .map_err(|err| ListAllDatabasesError::MySqlError(err.to_string())); // TODO: should we assert that the users are also owned by the unix_user from the request? if let Err(err) = &result { tracing::error!( "Failed to list databases for user '{}': {:?}", unix_user.username, err ); } result }