diff --git a/src/server/sql/database_operations.rs b/src/server/sql/database_operations.rs index 53fbbe5..61dc9a5 100644 --- a/src/server/sql/database_operations.rs +++ b/src/server/sql/database_operations.rs @@ -1,5 +1,6 @@ use std::collections::BTreeMap; +use sqlx::AssertSqlSafe; use sqlx::MySqlConnection; use sqlx::prelude::*; @@ -125,12 +126,15 @@ pub async fn create_databases( _ => {} } - let result = - sqlx::query(format!("CREATE DATABASE {}", quote_identifier(&database_name)).as_str()) - .execute(&mut *connection) - .await - .map(|_| ()) - .map_err(|err| CreateDatabaseError::MySqlError(err.to_string())); + let statement = AssertSqlSafe(format!( + "CREATE DATABASE {}", + quote_identifier(&database_name) + )); + let result = sqlx::query(statement) + .execute(&mut *connection) + .await + .map(|_| ()) + .map_err(|err| CreateDatabaseError::MySqlError(err.to_string())); if let Err(err) = &result { tracing::error!("Failed to create database '{}': {:?}", &database_name, err); @@ -181,12 +185,15 @@ pub async fn drop_databases( _ => {} } - let result = - sqlx::query(format!("DROP DATABASE {}", quote_identifier(&database_name)).as_str()) - .execute(&mut *connection) - .await - .map(|_| ()) - .map_err(|err| DropDatabaseError::MySqlError(err.to_string())); + let statement = AssertSqlSafe(format!( + "DROP DATABASE {}", + quote_identifier(&database_name) + )); + let result = sqlx::query(statement) + .execute(&mut *connection) + .await + .map(|_| ()) + .map_err(|err| DropDatabaseError::MySqlError(err.to_string())); if let Err(err) = &result { tracing::error!("Failed to drop database '{}': {:?}", &database_name, err); diff --git a/src/server/sql/database_privilege_operations.rs b/src/server/sql/database_privilege_operations.rs index de2bbf3..b0c5027 100644 --- a/src/server/sql/database_privilege_operations.rs +++ b/src/server/sql/database_privilege_operations.rs @@ -18,7 +18,7 @@ use std::collections::{BTreeMap, BTreeSet}; use indoc::indoc; use itertools::Itertools; -use sqlx::{MySqlConnection, mysql::MySqlRow, prelude::*}; +use sqlx::{AssertSqlSafe, MySqlConnection, mysql::MySqlRow, prelude::*}; use crate::{ core::{ @@ -84,16 +84,17 @@ async fn unsafe_get_database_privileges( database_name: &str, connection: &mut MySqlConnection, ) -> Result, sqlx::Error> { - let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!( + let statement = AssertSqlSafe(format!( "SELECT {} FROM `db` WHERE `Db` = ?", DATABASE_PRIVILEGE_FIELDS .iter() .map(|field| quote_identifier(field)) .join(","), - )) - .bind(database_name) - .fetch_all(connection) - .await; + )); + let result = sqlx::query_as::<_, DatabasePrivilegeRow>(statement) + .bind(database_name) + .fetch_all(connection) + .await; if let Err(e) = &result { tracing::error!( @@ -113,17 +114,18 @@ pub async fn unsafe_get_database_privileges_for_db_user_pair( user_name: &MySQLUser, connection: &mut MySqlConnection, ) -> Result, sqlx::Error> { - let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!( + let statement = AssertSqlSafe(format!( "SELECT {} FROM `db` WHERE `Db` = ? AND `User` = ? AND `Host` = '%'", DATABASE_PRIVILEGE_FIELDS .iter() .map(|field| quote_identifier(field)) .join(","), - )) - .bind(database_name.as_str()) - .bind(user_name.as_str()) - .fetch_optional(connection) - .await; + )); + let result = sqlx::query_as::<_, DatabasePrivilegeRow>(statement) + .bind(database_name.as_str()) + .bind(user_name.as_str()) + .fetch_optional(connection) + .await; if let Err(e) = &result { tracing::error!( @@ -189,8 +191,8 @@ pub async fn get_databases_privilege_data( } /// TODO: make this constant -fn get_all_db_privs_query() -> String { - format!( +fn get_all_db_privs_query() -> AssertSqlSafe { + AssertSqlSafe(format!( indoc! {r" SELECT {} FROM `db` WHERE `db` IN (SELECT DISTINCT CAST(`SCHEMA_NAME` AS CHAR(64)) AS `database` @@ -202,7 +204,7 @@ fn get_all_db_privs_query() -> String { .iter() .map(|field| quote_identifier(field)) .join(","), - ) + )) } /// Get all database + user + privileges pairs that are owned by the current user. @@ -212,7 +214,7 @@ pub async fn get_all_database_privileges( _db_is_mariadb: bool, group_denylist: &GroupDenylist, ) -> ListAllPrivilegesResponse { - let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&get_all_db_privs_query()) + let result = sqlx::query_as::<_, DatabasePrivilegeRow>(get_all_db_privs_query()) .bind(create_user_group_matching_regex(unix_user, group_denylist)) .fetch_all(connection) .await @@ -241,7 +243,10 @@ async fn unsafe_apply_privilege_diff( let question_marks = std::iter::repeat_n("?", DATABASE_PRIVILEGE_FIELDS.len() + 1).join(","); - sqlx::query(format!("INSERT INTO `db` ({tables}) VALUES ({question_marks})").as_str()) + let statement = AssertSqlSafe(format!( + "INSERT INTO `db` ({tables}) VALUES ({question_marks})" + )); + sqlx::query(statement) .bind(p.db.to_string()) .bind(p.user.to_string()) .bind(yn(p.select_priv)) @@ -280,27 +285,27 @@ async fn unsafe_apply_privilege_diff( } } - sqlx::query( - format!("UPDATE `db` SET {changes} WHERE `Db` = ? AND `User` = ? AND `Host` = ?") - .as_str(), - ) - .bind(p.select_priv.map(change_to_yn)) - .bind(p.insert_priv.map(change_to_yn)) - .bind(p.update_priv.map(change_to_yn)) - .bind(p.delete_priv.map(change_to_yn)) - .bind(p.create_priv.map(change_to_yn)) - .bind(p.drop_priv.map(change_to_yn)) - .bind(p.alter_priv.map(change_to_yn)) - .bind(p.index_priv.map(change_to_yn)) - .bind(p.create_tmp_table_priv.map(change_to_yn)) - .bind(p.lock_tables_priv.map(change_to_yn)) - .bind(p.references_priv.map(change_to_yn)) - .bind(p.db.to_string()) - .bind(p.user.to_string()) - .bind("%") - .execute(connection) - .await - .map(|_| ()) + let statement = AssertSqlSafe(format!( + "UPDATE `db` SET {changes} WHERE `Db` = ? AND `User` = ? AND `Host` = ?" + )); + sqlx::query(statement) + .bind(p.select_priv.map(change_to_yn)) + .bind(p.insert_priv.map(change_to_yn)) + .bind(p.update_priv.map(change_to_yn)) + .bind(p.delete_priv.map(change_to_yn)) + .bind(p.create_priv.map(change_to_yn)) + .bind(p.drop_priv.map(change_to_yn)) + .bind(p.alter_priv.map(change_to_yn)) + .bind(p.index_priv.map(change_to_yn)) + .bind(p.create_tmp_table_priv.map(change_to_yn)) + .bind(p.lock_tables_priv.map(change_to_yn)) + .bind(p.references_priv.map(change_to_yn)) + .bind(p.db.to_string()) + .bind(p.user.to_string()) + .bind("%") + .execute(connection) + .await + .map(|_| ()) } DatabasePrivilegesDiff::Deleted(p) => { sqlx::query("DELETE FROM `db` WHERE `Db` = ? AND `User` = ? AND `Host` = ?") diff --git a/src/server/sql/user_operations.rs b/src/server/sql/user_operations.rs index c296a46..49a84d7 100644 --- a/src/server/sql/user_operations.rs +++ b/src/server/sql/user_operations.rs @@ -1,5 +1,6 @@ use indoc::formatdoc; use itertools::Itertools; +use sqlx::AssertSqlSafe; use std::collections::BTreeMap; use serde::{Deserialize, Serialize}; @@ -126,7 +127,8 @@ pub async fn create_database_users( _ => {} } - let result = sqlx::query(format!("CREATE USER {}@'%'", quote_literal(&db_user),).as_str()) + let statement = AssertSqlSafe(format!("CREATE USER {}@'%'", quote_literal(&db_user),)); + let result = sqlx::query(statement) .execute(&mut *connection) .await .map(|_| ()) @@ -172,7 +174,8 @@ pub async fn drop_database_users( _ => {} } - let result = sqlx::query(format!("DROP USER {}@'%'", quote_literal(&db_user),).as_str()) + let statement = AssertSqlSafe(format!("DROP USER {}@'%'", quote_literal(&db_user),)); + let result = sqlx::query(statement) .execute(&mut *connection) .await .map(|_| ()) @@ -205,18 +208,16 @@ pub async fn set_password_for_database_user( _ => {} } - let result = sqlx::query( - format!( - "ALTER USER {}@'%' IDENTIFIED BY {}", - quote_literal(db_user), - quote_literal(password).as_str(), - ) - .as_str(), - ) - .execute(&mut *connection) - .await - .map(|_| ()) - .map_err(|err| SetPasswordError::MySqlError(err.to_string())); + let statement = AssertSqlSafe(format!( + "ALTER USER {}@'%' IDENTIFIED BY {}", + quote_literal(db_user), + quote_literal(password).as_str(), + )); + let result = sqlx::query(statement) + .execute(&mut *connection) + .await + .map(|_| ()) + .map_err(|err| SetPasswordError::MySqlError(err.to_string())); if result.is_err() { tracing::error!( @@ -315,13 +316,15 @@ pub async fn lock_database_users( } } - let result = sqlx::query( - format!("ALTER USER {}@'%' ACCOUNT LOCK", quote_literal(&db_user),).as_str(), - ) - .execute(&mut *connection) - .await - .map(|_| ()) - .map_err(|err| LockUserError::MySqlError(err.to_string())); + let statement = AssertSqlSafe(format!( + "ALTER USER {}@'%' ACCOUNT LOCK", + quote_literal(&db_user), + )); + let result = sqlx::query(statement) + .execute(&mut *connection) + .await + .map(|_| ()) + .map_err(|err| LockUserError::MySqlError(err.to_string())); if let Err(err) = &result { tracing::error!("Failed to lock database user '{}': {:?}", &db_user, err); @@ -375,13 +378,15 @@ pub async fn unlock_database_users( _ => {} } - let result = sqlx::query( - format!("ALTER USER {}@'%' ACCOUNT UNLOCK", quote_literal(&db_user),).as_str(), - ) - .execute(&mut *connection) - .await - .map(|_| ()) - .map_err(|err| UnlockUserError::MySqlError(err.to_string())); + let statement = AssertSqlSafe(format!( + "ALTER USER {}@'%' ACCOUNT UNLOCK", + quote_literal(&db_user), + )); + let result = sqlx::query(statement) + .execute(&mut *connection) + .await + .map(|_| ()) + .map_err(|err| UnlockUserError::MySqlError(err.to_string())); if let Err(err) = &result { tracing::error!("Failed to unlock database user '{}': {:?}", &db_user, err); @@ -459,16 +464,17 @@ pub async fn list_database_users( continue; } - let mut result = sqlx::query_as::<_, DatabaseUser>( - &(if db_is_mariadb { + let statement = AssertSqlSafe( + if db_is_mariadb { DB_USER_SELECT_STATEMENT_MARIADB.to_string() } else { DB_USER_SELECT_STATEMENT_MYSQL.to_string() - } + "WHERE `mysql`.`user`.`User` = ? AND `mysql`.`user`.`Host` = '%'"), - ) - .bind(db_user.as_str()) - .fetch_optional(&mut *connection) - .await; + } + "WHERE `mysql`.`user`.`User` = ? AND `mysql`.`user`.`Host` = '%'", + ); + let mut result = sqlx::query_as::<_, DatabaseUser>(statement) + .bind(db_user.as_str()) + .fetch_optional(&mut *connection) + .await; if let Err(err) = &result { tracing::error!("Failed to list database user '{}': {:?}", &db_user, err); @@ -496,17 +502,18 @@ pub async fn list_all_database_users_for_unix_user( db_is_mariadb: bool, group_denylist: &GroupDenylist, ) -> ListAllUsersResponse { - let mut result = sqlx::query_as::<_, DatabaseUser>( - &(if db_is_mariadb { + let statement = AssertSqlSafe( + if db_is_mariadb { DB_USER_SELECT_STATEMENT_MARIADB.to_string() } else { DB_USER_SELECT_STATEMENT_MYSQL.to_string() - } + "WHERE `user`.`User` REGEXP ? AND `user`.`Host` = '%'"), - ) - .bind(create_user_group_matching_regex(unix_user, group_denylist)) - .fetch_all(&mut *connection) - .await - .map_err(|err| ListAllUsersError::MySqlError(err.to_string())); + } + "WHERE `user`.`User` REGEXP ? AND `user`.`Host` = '%'", + ); + let mut result = sqlx::query_as::<_, DatabaseUser>(statement) + .bind(create_user_group_matching_regex(unix_user, group_denylist)) + .fetch_all(&mut *connection) + .await + .map_err(|err| ListAllUsersError::MySqlError(err.to_string())); if let Err(err) = &result { tracing::error!("Failed to list all database users: {:?}", err); @@ -531,23 +538,21 @@ pub async fn set_databases_where_user_has_privileges( db_user: &mut DatabaseUser, connection: &mut MySqlConnection, ) -> Result<(), sqlx::Error> { - let database_list = sqlx::query( - formatdoc!( - r" - SELECT `Db` AS `database` - FROM `db` - WHERE `User` = ? AND `Host` = '%' AND ({}) - ", - DATABASE_PRIVILEGE_FIELDS - .iter() - .map(|field| format!("`{field}` = 'Y'")) - .join(" OR "), - ) - .as_str(), - ) - .bind(db_user.user.as_str()) - .fetch_all(&mut *connection) - .await; + let statement = AssertSqlSafe(formatdoc!( + r" + SELECT `Db` AS `database` + FROM `db` + WHERE `User` = ? AND `Host` = '%' AND ({}) + ", + DATABASE_PRIVILEGE_FIELDS + .iter() + .map(|field| format!("`{field}` = 'Y'")) + .join(" OR "), + )); + let database_list = sqlx::query(statement) + .bind(db_user.user.as_str()) + .fetch_all(&mut *connection) + .await; if let Err(err) = &database_list { tracing::error!(