server/sql: fixes for new sqlx crate version
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use sqlx::AssertSqlSafe;
|
||||
use sqlx::MySqlConnection;
|
||||
use sqlx::prelude::*;
|
||||
|
||||
@@ -125,12 +126,15 @@ pub async fn create_databases(
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let result =
|
||||
sqlx::query(format!("CREATE DATABASE {}", quote_identifier(&database_name)).as_str())
|
||||
.execute(&mut *connection)
|
||||
.await
|
||||
.map(|_| ())
|
||||
.map_err(|err| CreateDatabaseError::MySqlError(err.to_string()));
|
||||
let statement = AssertSqlSafe(format!(
|
||||
"CREATE DATABASE {}",
|
||||
quote_identifier(&database_name)
|
||||
));
|
||||
let result = sqlx::query(statement)
|
||||
.execute(&mut *connection)
|
||||
.await
|
||||
.map(|_| ())
|
||||
.map_err(|err| CreateDatabaseError::MySqlError(err.to_string()));
|
||||
|
||||
if let Err(err) = &result {
|
||||
tracing::error!("Failed to create database '{}': {:?}", &database_name, err);
|
||||
@@ -181,12 +185,15 @@ pub async fn drop_databases(
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let result =
|
||||
sqlx::query(format!("DROP DATABASE {}", quote_identifier(&database_name)).as_str())
|
||||
.execute(&mut *connection)
|
||||
.await
|
||||
.map(|_| ())
|
||||
.map_err(|err| DropDatabaseError::MySqlError(err.to_string()));
|
||||
let statement = AssertSqlSafe(format!(
|
||||
"DROP DATABASE {}",
|
||||
quote_identifier(&database_name)
|
||||
));
|
||||
let result = sqlx::query(statement)
|
||||
.execute(&mut *connection)
|
||||
.await
|
||||
.map(|_| ())
|
||||
.map_err(|err| DropDatabaseError::MySqlError(err.to_string()));
|
||||
|
||||
if let Err(err) = &result {
|
||||
tracing::error!("Failed to drop database '{}': {:?}", &database_name, err);
|
||||
|
||||
@@ -18,7 +18,7 @@ use std::collections::{BTreeMap, BTreeSet};
|
||||
|
||||
use indoc::indoc;
|
||||
use itertools::Itertools;
|
||||
use sqlx::{MySqlConnection, mysql::MySqlRow, prelude::*};
|
||||
use sqlx::{AssertSqlSafe, MySqlConnection, mysql::MySqlRow, prelude::*};
|
||||
|
||||
use crate::{
|
||||
core::{
|
||||
@@ -84,16 +84,17 @@ async fn unsafe_get_database_privileges(
|
||||
database_name: &str,
|
||||
connection: &mut MySqlConnection,
|
||||
) -> Result<Vec<DatabasePrivilegeRow>, sqlx::Error> {
|
||||
let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
|
||||
let statement = AssertSqlSafe(format!(
|
||||
"SELECT {} FROM `db` WHERE `Db` = ?",
|
||||
DATABASE_PRIVILEGE_FIELDS
|
||||
.iter()
|
||||
.map(|field| quote_identifier(field))
|
||||
.join(","),
|
||||
))
|
||||
.bind(database_name)
|
||||
.fetch_all(connection)
|
||||
.await;
|
||||
));
|
||||
let result = sqlx::query_as::<_, DatabasePrivilegeRow>(statement)
|
||||
.bind(database_name)
|
||||
.fetch_all(connection)
|
||||
.await;
|
||||
|
||||
if let Err(e) = &result {
|
||||
tracing::error!(
|
||||
@@ -113,17 +114,18 @@ pub async fn unsafe_get_database_privileges_for_db_user_pair(
|
||||
user_name: &MySQLUser,
|
||||
connection: &mut MySqlConnection,
|
||||
) -> Result<Option<DatabasePrivilegeRow>, sqlx::Error> {
|
||||
let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
|
||||
let statement = AssertSqlSafe(format!(
|
||||
"SELECT {} FROM `db` WHERE `Db` = ? AND `User` = ? AND `Host` = '%'",
|
||||
DATABASE_PRIVILEGE_FIELDS
|
||||
.iter()
|
||||
.map(|field| quote_identifier(field))
|
||||
.join(","),
|
||||
))
|
||||
.bind(database_name.as_str())
|
||||
.bind(user_name.as_str())
|
||||
.fetch_optional(connection)
|
||||
.await;
|
||||
));
|
||||
let result = sqlx::query_as::<_, DatabasePrivilegeRow>(statement)
|
||||
.bind(database_name.as_str())
|
||||
.bind(user_name.as_str())
|
||||
.fetch_optional(connection)
|
||||
.await;
|
||||
|
||||
if let Err(e) = &result {
|
||||
tracing::error!(
|
||||
@@ -189,8 +191,8 @@ pub async fn get_databases_privilege_data(
|
||||
}
|
||||
|
||||
/// TODO: make this constant
|
||||
fn get_all_db_privs_query() -> String {
|
||||
format!(
|
||||
fn get_all_db_privs_query() -> AssertSqlSafe<String> {
|
||||
AssertSqlSafe(format!(
|
||||
indoc! {r"
|
||||
SELECT {} FROM `db` WHERE `db` IN
|
||||
(SELECT DISTINCT CAST(`SCHEMA_NAME` AS CHAR(64)) AS `database`
|
||||
@@ -202,7 +204,7 @@ fn get_all_db_privs_query() -> String {
|
||||
.iter()
|
||||
.map(|field| quote_identifier(field))
|
||||
.join(","),
|
||||
)
|
||||
))
|
||||
}
|
||||
|
||||
/// Get all database + user + privileges pairs that are owned by the current user.
|
||||
@@ -212,7 +214,7 @@ pub async fn get_all_database_privileges(
|
||||
_db_is_mariadb: bool,
|
||||
group_denylist: &GroupDenylist,
|
||||
) -> ListAllPrivilegesResponse {
|
||||
let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&get_all_db_privs_query())
|
||||
let result = sqlx::query_as::<_, DatabasePrivilegeRow>(get_all_db_privs_query())
|
||||
.bind(create_user_group_matching_regex(unix_user, group_denylist))
|
||||
.fetch_all(connection)
|
||||
.await
|
||||
@@ -241,7 +243,10 @@ async fn unsafe_apply_privilege_diff(
|
||||
let question_marks =
|
||||
std::iter::repeat_n("?", DATABASE_PRIVILEGE_FIELDS.len() + 1).join(",");
|
||||
|
||||
sqlx::query(format!("INSERT INTO `db` ({tables}) VALUES ({question_marks})").as_str())
|
||||
let statement = AssertSqlSafe(format!(
|
||||
"INSERT INTO `db` ({tables}) VALUES ({question_marks})"
|
||||
));
|
||||
sqlx::query(statement)
|
||||
.bind(p.db.to_string())
|
||||
.bind(p.user.to_string())
|
||||
.bind(yn(p.select_priv))
|
||||
@@ -280,27 +285,27 @@ async fn unsafe_apply_privilege_diff(
|
||||
}
|
||||
}
|
||||
|
||||
sqlx::query(
|
||||
format!("UPDATE `db` SET {changes} WHERE `Db` = ? AND `User` = ? AND `Host` = ?")
|
||||
.as_str(),
|
||||
)
|
||||
.bind(p.select_priv.map(change_to_yn))
|
||||
.bind(p.insert_priv.map(change_to_yn))
|
||||
.bind(p.update_priv.map(change_to_yn))
|
||||
.bind(p.delete_priv.map(change_to_yn))
|
||||
.bind(p.create_priv.map(change_to_yn))
|
||||
.bind(p.drop_priv.map(change_to_yn))
|
||||
.bind(p.alter_priv.map(change_to_yn))
|
||||
.bind(p.index_priv.map(change_to_yn))
|
||||
.bind(p.create_tmp_table_priv.map(change_to_yn))
|
||||
.bind(p.lock_tables_priv.map(change_to_yn))
|
||||
.bind(p.references_priv.map(change_to_yn))
|
||||
.bind(p.db.to_string())
|
||||
.bind(p.user.to_string())
|
||||
.bind("%")
|
||||
.execute(connection)
|
||||
.await
|
||||
.map(|_| ())
|
||||
let statement = AssertSqlSafe(format!(
|
||||
"UPDATE `db` SET {changes} WHERE `Db` = ? AND `User` = ? AND `Host` = ?"
|
||||
));
|
||||
sqlx::query(statement)
|
||||
.bind(p.select_priv.map(change_to_yn))
|
||||
.bind(p.insert_priv.map(change_to_yn))
|
||||
.bind(p.update_priv.map(change_to_yn))
|
||||
.bind(p.delete_priv.map(change_to_yn))
|
||||
.bind(p.create_priv.map(change_to_yn))
|
||||
.bind(p.drop_priv.map(change_to_yn))
|
||||
.bind(p.alter_priv.map(change_to_yn))
|
||||
.bind(p.index_priv.map(change_to_yn))
|
||||
.bind(p.create_tmp_table_priv.map(change_to_yn))
|
||||
.bind(p.lock_tables_priv.map(change_to_yn))
|
||||
.bind(p.references_priv.map(change_to_yn))
|
||||
.bind(p.db.to_string())
|
||||
.bind(p.user.to_string())
|
||||
.bind("%")
|
||||
.execute(connection)
|
||||
.await
|
||||
.map(|_| ())
|
||||
}
|
||||
DatabasePrivilegesDiff::Deleted(p) => {
|
||||
sqlx::query("DELETE FROM `db` WHERE `Db` = ? AND `User` = ? AND `Host` = ?")
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use indoc::formatdoc;
|
||||
use itertools::Itertools;
|
||||
use sqlx::AssertSqlSafe;
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -126,7 +127,8 @@ pub async fn create_database_users(
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let result = sqlx::query(format!("CREATE USER {}@'%'", quote_literal(&db_user),).as_str())
|
||||
let statement = AssertSqlSafe(format!("CREATE USER {}@'%'", quote_literal(&db_user),));
|
||||
let result = sqlx::query(statement)
|
||||
.execute(&mut *connection)
|
||||
.await
|
||||
.map(|_| ())
|
||||
@@ -172,7 +174,8 @@ pub async fn drop_database_users(
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let result = sqlx::query(format!("DROP USER {}@'%'", quote_literal(&db_user),).as_str())
|
||||
let statement = AssertSqlSafe(format!("DROP USER {}@'%'", quote_literal(&db_user),));
|
||||
let result = sqlx::query(statement)
|
||||
.execute(&mut *connection)
|
||||
.await
|
||||
.map(|_| ())
|
||||
@@ -205,18 +208,16 @@ pub async fn set_password_for_database_user(
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let result = sqlx::query(
|
||||
format!(
|
||||
"ALTER USER {}@'%' IDENTIFIED BY {}",
|
||||
quote_literal(db_user),
|
||||
quote_literal(password).as_str(),
|
||||
)
|
||||
.as_str(),
|
||||
)
|
||||
.execute(&mut *connection)
|
||||
.await
|
||||
.map(|_| ())
|
||||
.map_err(|err| SetPasswordError::MySqlError(err.to_string()));
|
||||
let statement = AssertSqlSafe(format!(
|
||||
"ALTER USER {}@'%' IDENTIFIED BY {}",
|
||||
quote_literal(db_user),
|
||||
quote_literal(password).as_str(),
|
||||
));
|
||||
let result = sqlx::query(statement)
|
||||
.execute(&mut *connection)
|
||||
.await
|
||||
.map(|_| ())
|
||||
.map_err(|err| SetPasswordError::MySqlError(err.to_string()));
|
||||
|
||||
if result.is_err() {
|
||||
tracing::error!(
|
||||
@@ -315,13 +316,15 @@ pub async fn lock_database_users(
|
||||
}
|
||||
}
|
||||
|
||||
let result = sqlx::query(
|
||||
format!("ALTER USER {}@'%' ACCOUNT LOCK", quote_literal(&db_user),).as_str(),
|
||||
)
|
||||
.execute(&mut *connection)
|
||||
.await
|
||||
.map(|_| ())
|
||||
.map_err(|err| LockUserError::MySqlError(err.to_string()));
|
||||
let statement = AssertSqlSafe(format!(
|
||||
"ALTER USER {}@'%' ACCOUNT LOCK",
|
||||
quote_literal(&db_user),
|
||||
));
|
||||
let result = sqlx::query(statement)
|
||||
.execute(&mut *connection)
|
||||
.await
|
||||
.map(|_| ())
|
||||
.map_err(|err| LockUserError::MySqlError(err.to_string()));
|
||||
|
||||
if let Err(err) = &result {
|
||||
tracing::error!("Failed to lock database user '{}': {:?}", &db_user, err);
|
||||
@@ -375,13 +378,15 @@ pub async fn unlock_database_users(
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let result = sqlx::query(
|
||||
format!("ALTER USER {}@'%' ACCOUNT UNLOCK", quote_literal(&db_user),).as_str(),
|
||||
)
|
||||
.execute(&mut *connection)
|
||||
.await
|
||||
.map(|_| ())
|
||||
.map_err(|err| UnlockUserError::MySqlError(err.to_string()));
|
||||
let statement = AssertSqlSafe(format!(
|
||||
"ALTER USER {}@'%' ACCOUNT UNLOCK",
|
||||
quote_literal(&db_user),
|
||||
));
|
||||
let result = sqlx::query(statement)
|
||||
.execute(&mut *connection)
|
||||
.await
|
||||
.map(|_| ())
|
||||
.map_err(|err| UnlockUserError::MySqlError(err.to_string()));
|
||||
|
||||
if let Err(err) = &result {
|
||||
tracing::error!("Failed to unlock database user '{}': {:?}", &db_user, err);
|
||||
@@ -459,16 +464,17 @@ pub async fn list_database_users(
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut result = sqlx::query_as::<_, DatabaseUser>(
|
||||
&(if db_is_mariadb {
|
||||
let statement = AssertSqlSafe(
|
||||
if db_is_mariadb {
|
||||
DB_USER_SELECT_STATEMENT_MARIADB.to_string()
|
||||
} else {
|
||||
DB_USER_SELECT_STATEMENT_MYSQL.to_string()
|
||||
} + "WHERE `mysql`.`user`.`User` = ? AND `mysql`.`user`.`Host` = '%'"),
|
||||
)
|
||||
.bind(db_user.as_str())
|
||||
.fetch_optional(&mut *connection)
|
||||
.await;
|
||||
} + "WHERE `mysql`.`user`.`User` = ? AND `mysql`.`user`.`Host` = '%'",
|
||||
);
|
||||
let mut result = sqlx::query_as::<_, DatabaseUser>(statement)
|
||||
.bind(db_user.as_str())
|
||||
.fetch_optional(&mut *connection)
|
||||
.await;
|
||||
|
||||
if let Err(err) = &result {
|
||||
tracing::error!("Failed to list database user '{}': {:?}", &db_user, err);
|
||||
@@ -496,17 +502,18 @@ pub async fn list_all_database_users_for_unix_user(
|
||||
db_is_mariadb: bool,
|
||||
group_denylist: &GroupDenylist,
|
||||
) -> ListAllUsersResponse {
|
||||
let mut result = sqlx::query_as::<_, DatabaseUser>(
|
||||
&(if db_is_mariadb {
|
||||
let statement = AssertSqlSafe(
|
||||
if db_is_mariadb {
|
||||
DB_USER_SELECT_STATEMENT_MARIADB.to_string()
|
||||
} else {
|
||||
DB_USER_SELECT_STATEMENT_MYSQL.to_string()
|
||||
} + "WHERE `user`.`User` REGEXP ? AND `user`.`Host` = '%'"),
|
||||
)
|
||||
.bind(create_user_group_matching_regex(unix_user, group_denylist))
|
||||
.fetch_all(&mut *connection)
|
||||
.await
|
||||
.map_err(|err| ListAllUsersError::MySqlError(err.to_string()));
|
||||
} + "WHERE `user`.`User` REGEXP ? AND `user`.`Host` = '%'",
|
||||
);
|
||||
let mut result = sqlx::query_as::<_, DatabaseUser>(statement)
|
||||
.bind(create_user_group_matching_regex(unix_user, group_denylist))
|
||||
.fetch_all(&mut *connection)
|
||||
.await
|
||||
.map_err(|err| ListAllUsersError::MySqlError(err.to_string()));
|
||||
|
||||
if let Err(err) = &result {
|
||||
tracing::error!("Failed to list all database users: {:?}", err);
|
||||
@@ -531,23 +538,21 @@ pub async fn set_databases_where_user_has_privileges(
|
||||
db_user: &mut DatabaseUser,
|
||||
connection: &mut MySqlConnection,
|
||||
) -> Result<(), sqlx::Error> {
|
||||
let database_list = sqlx::query(
|
||||
formatdoc!(
|
||||
r"
|
||||
SELECT `Db` AS `database`
|
||||
FROM `db`
|
||||
WHERE `User` = ? AND `Host` = '%' AND ({})
|
||||
",
|
||||
DATABASE_PRIVILEGE_FIELDS
|
||||
.iter()
|
||||
.map(|field| format!("`{field}` = 'Y'"))
|
||||
.join(" OR "),
|
||||
)
|
||||
.as_str(),
|
||||
)
|
||||
.bind(db_user.user.as_str())
|
||||
.fetch_all(&mut *connection)
|
||||
.await;
|
||||
let statement = AssertSqlSafe(formatdoc!(
|
||||
r"
|
||||
SELECT `Db` AS `database`
|
||||
FROM `db`
|
||||
WHERE `User` = ? AND `Host` = '%' AND ({})
|
||||
",
|
||||
DATABASE_PRIVILEGE_FIELDS
|
||||
.iter()
|
||||
.map(|field| format!("`{field}` = 'Y'"))
|
||||
.join(" OR "),
|
||||
));
|
||||
let database_list = sqlx::query(statement)
|
||||
.bind(db_user.user.as_str())
|
||||
.fetch_all(&mut *connection)
|
||||
.await;
|
||||
|
||||
if let Err(err) = &database_list {
|
||||
tracing::error!(
|
||||
|
||||
Reference in New Issue
Block a user