Rewrite entire codebase to split into client and server
This commit is contained in:
Cargo.lockCargo.toml
src
cli.rs
cli
common.rsdatabase_command.rsmysql_admutils_compatibility.rs
core.rsmysql_admutils_compatibility
user_command.rscore
bootstrap.rs
main.rsserver.rsbootstrap
common.rsdatabase_operations.rsdatabase_privileges.rsprotocol.rsprotocol
user_operations.rsserver
77
src/server/command.rs
Normal file
77
src/server/command.rs
Normal file
@ -0,0 +1,77 @@
|
||||
use std::os::fd::FromRawFd;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use anyhow::Context;
|
||||
use clap::Parser;
|
||||
|
||||
use std::os::unix::net::UnixStream as StdUnixStream;
|
||||
use tokio::net::UnixStream as TokioUnixStream;
|
||||
|
||||
use crate::core::bootstrap::authenticated_unix_socket;
|
||||
use crate::core::common::UnixUser;
|
||||
use crate::server::config::read_config_from_path_with_arg_overrides;
|
||||
use crate::server::server_loop::listen_for_incoming_connections;
|
||||
use crate::server::{
|
||||
config::{ServerConfig, ServerConfigArgs},
|
||||
server_loop::handle_requests_for_single_session,
|
||||
};
|
||||
|
||||
#[derive(Parser, Debug, Clone)]
|
||||
pub struct ServerArgs {
|
||||
#[command(subcommand)]
|
||||
subcmd: ServerCommand,
|
||||
|
||||
#[command(flatten)]
|
||||
config_overrides: ServerConfigArgs,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug, Clone)]
|
||||
pub enum ServerCommand {
|
||||
#[command()]
|
||||
Listen,
|
||||
|
||||
#[command()]
|
||||
SocketActivate,
|
||||
}
|
||||
|
||||
pub async fn handle_command(
|
||||
socket_path: Option<PathBuf>,
|
||||
config_path: Option<PathBuf>,
|
||||
args: ServerArgs,
|
||||
) -> anyhow::Result<()> {
|
||||
let config = read_config_from_path_with_arg_overrides(config_path, args.config_overrides)?;
|
||||
|
||||
// if let Err(e) = &result {
|
||||
// eprintln!("{}", e);
|
||||
// }
|
||||
|
||||
match args.subcmd {
|
||||
ServerCommand::Listen => listen_for_incoming_connections(socket_path, config).await,
|
||||
ServerCommand::SocketActivate => socket_activate(config).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn socket_activate(config: ServerConfig) -> anyhow::Result<()> {
|
||||
// TODO: allow getting socket path from other socket activation sources
|
||||
let mut conn = get_socket_from_systemd().await?;
|
||||
let uid = authenticated_unix_socket::server_authenticate(&mut conn).await?;
|
||||
let unix_user = UnixUser::from_uid(uid.into())?;
|
||||
handle_requests_for_single_session(conn, &unix_user, &config).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_socket_from_systemd() -> anyhow::Result<TokioUnixStream> {
|
||||
let fd = std::env::var("LISTEN_FDS")
|
||||
.context("LISTEN_FDS not set, not running under systemd?")?
|
||||
.parse::<i32>()
|
||||
.context("Failed to parse LISTEN_FDS")?;
|
||||
|
||||
if fd != 1 {
|
||||
return Err(anyhow::anyhow!("Unexpected LISTEN_FDS value: {}", fd));
|
||||
}
|
||||
|
||||
let std_unix_stream = unsafe { StdUnixStream::from_raw_fd(fd) };
|
||||
let socket = TokioUnixStream::from_std(std_unix_stream)?;
|
||||
Ok(socket)
|
||||
}
|
11
src/server/common.rs
Normal file
11
src/server/common.rs
Normal file
@ -0,0 +1,11 @@
|
||||
use crate::core::common::UnixUser;
|
||||
|
||||
/// This function creates a regex that matches items (users, databases)
|
||||
/// that belong to the user or any of the user's groups.
|
||||
pub fn create_user_group_matching_regex(user: &UnixUser) -> String {
|
||||
if user.groups.is_empty() {
|
||||
format!("{}(_.+)?", user.username)
|
||||
} else {
|
||||
format!("({}|{})(_.+)?", user.username, user.groups.join("|"))
|
||||
}
|
||||
}
|
117
src/server/config.rs
Normal file
117
src/server/config.rs
Normal file
@ -0,0 +1,117 @@
|
||||
use std::{fs, path::PathBuf, time::Duration};
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
use clap::Parser;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::{mysql::MySqlConnectOptions, ConnectOptions, MySqlConnection};
|
||||
|
||||
use crate::core::common::DEFAULT_CONFIG_PATH;
|
||||
|
||||
pub const DEFAULT_PORT: u16 = 3306;
|
||||
pub const DEFAULT_TIMEOUT: u64 = 2;
|
||||
|
||||
// NOTE: this might look empty now, and the extra wrapping for the mysql
|
||||
// config seems unnecessary, but it will be useful later when we
|
||||
// add more configuration options.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ServerConfig {
|
||||
pub mysql: MysqlConfig,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename = "mysql")]
|
||||
pub struct MysqlConfig {
|
||||
pub host: String,
|
||||
pub port: Option<u16>,
|
||||
pub username: String,
|
||||
pub password: String,
|
||||
pub timeout: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug, Clone)]
|
||||
pub struct ServerConfigArgs {
|
||||
/// Hostname of the MySQL server.
|
||||
#[arg(long, value_name = "HOST", global = true)]
|
||||
mysql_host: Option<String>,
|
||||
|
||||
/// Port of the MySQL server.
|
||||
#[arg(long, value_name = "PORT", global = true)]
|
||||
mysql_port: Option<u16>,
|
||||
|
||||
/// Username to use for the MySQL connection.
|
||||
#[arg(long, value_name = "USER", global = true)]
|
||||
mysql_user: Option<String>,
|
||||
|
||||
/// Path to a file containing the MySQL password.
|
||||
#[arg(long, value_name = "PATH", global = true)]
|
||||
mysql_password_file: Option<String>,
|
||||
|
||||
/// Seconds to wait for the MySQL connection to be established.
|
||||
#[arg(long, value_name = "SECONDS", global = true)]
|
||||
mysql_connect_timeout: Option<u64>,
|
||||
}
|
||||
|
||||
/// Use the arguments and whichever configuration file which might or might not
|
||||
/// be found and default values to determine the configuration for the program.
|
||||
pub fn read_config_from_path_with_arg_overrides(
|
||||
config_path: Option<PathBuf>,
|
||||
args: ServerConfigArgs,
|
||||
) -> anyhow::Result<ServerConfig> {
|
||||
let config = read_config_form_path(config_path)?;
|
||||
|
||||
let mysql = &config.mysql;
|
||||
|
||||
let password = if let Some(path) = args.mysql_password_file {
|
||||
fs::read_to_string(path)
|
||||
.context("Failed to read MySQL password file")
|
||||
.map(|s| s.trim().to_owned())?
|
||||
} else {
|
||||
mysql.password.to_owned()
|
||||
};
|
||||
|
||||
Ok(ServerConfig {
|
||||
mysql: MysqlConfig {
|
||||
host: args.mysql_host.unwrap_or(mysql.host.to_owned()),
|
||||
port: args.mysql_port.or(mysql.port),
|
||||
username: args.mysql_user.unwrap_or(mysql.username.to_owned()),
|
||||
password,
|
||||
timeout: args.mysql_connect_timeout.or(mysql.timeout),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
pub fn read_config_form_path(config_path: Option<PathBuf>) -> anyhow::Result<ServerConfig> {
|
||||
let config_path = config_path.unwrap_or_else(|| PathBuf::from(DEFAULT_CONFIG_PATH));
|
||||
|
||||
fs::read_to_string(&config_path)
|
||||
.context(format!(
|
||||
"Failed to read config file from {:?}",
|
||||
&config_path
|
||||
))
|
||||
.and_then(|c| toml::from_str(&c).context("Failed to parse config file"))
|
||||
.context(format!(
|
||||
"Failed to parse config file from {:?}",
|
||||
&config_path
|
||||
))
|
||||
}
|
||||
|
||||
/// Use the provided configuration to establish a connection to a MySQL server.
|
||||
pub async fn create_mysql_connection_from_config(
|
||||
config: &MysqlConfig,
|
||||
) -> anyhow::Result<MySqlConnection> {
|
||||
match tokio::time::timeout(
|
||||
Duration::from_secs(config.timeout.unwrap_or(DEFAULT_TIMEOUT)),
|
||||
MySqlConnectOptions::new()
|
||||
.host(&config.host)
|
||||
.username(&config.username)
|
||||
.password(&config.password)
|
||||
.port(config.port.unwrap_or(DEFAULT_PORT))
|
||||
.database("mysql")
|
||||
.connect(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(connection) => connection.context("Failed to connect to MySQL"),
|
||||
Err(_) => Err(anyhow!("Timed out after 2 seconds")).context("Failed to connect to MySQL"),
|
||||
}
|
||||
}
|
158
src/server/input_sanitization.rs
Normal file
158
src/server/input_sanitization.rs
Normal file
@ -0,0 +1,158 @@
|
||||
use crate::core::{
|
||||
common::UnixUser,
|
||||
protocol::server_responses::{NameValidationError, OwnerValidationError},
|
||||
};
|
||||
|
||||
const MAX_NAME_LENGTH: usize = 64;
|
||||
|
||||
pub fn validate_name(name: &str) -> Result<(), NameValidationError> {
|
||||
if name.is_empty() {
|
||||
Err(NameValidationError::EmptyString)
|
||||
} else if name.len() > MAX_NAME_LENGTH {
|
||||
Err(NameValidationError::TooLong)
|
||||
} else if !name
|
||||
.chars()
|
||||
.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
|
||||
{
|
||||
Err(NameValidationError::InvalidCharacters)
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn validate_ownership_by_unix_user(
|
||||
name: &str,
|
||||
user: &UnixUser,
|
||||
) -> Result<(), OwnerValidationError> {
|
||||
let prefixes = std::iter::once(user.username.clone())
|
||||
.chain(user.groups.iter().cloned())
|
||||
.collect::<Vec<String>>();
|
||||
|
||||
validate_ownership_by_prefixes(name, &prefixes)
|
||||
}
|
||||
|
||||
/// Core logic for validating the ownership of a database name.
|
||||
/// This function checks if the given name matches any of the given prefixes.
|
||||
/// These prefixes will in most cases be the user's unix username and any
|
||||
/// unix groups the user is a member of.
|
||||
pub fn validate_ownership_by_prefixes(
|
||||
name: &str,
|
||||
prefixes: &[String],
|
||||
) -> Result<(), OwnerValidationError> {
|
||||
if name.is_empty() {
|
||||
return Err(OwnerValidationError::StringEmpty);
|
||||
}
|
||||
|
||||
if name.starts_with('_') {
|
||||
return Err(OwnerValidationError::MissingPrefix);
|
||||
}
|
||||
|
||||
let (prefix, _) = match name.split_once('_') {
|
||||
Some(pair) => pair,
|
||||
None => return Err(OwnerValidationError::MissingPostfix),
|
||||
};
|
||||
|
||||
if !prefixes.iter().any(|g| g == prefix) {
|
||||
return Err(OwnerValidationError::NoMatch);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn quote_literal(s: &str) -> String {
|
||||
format!("'{}'", s.replace('\'', r"\'"))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn quote_identifier(s: &str) -> String {
|
||||
format!("`{}`", s.replace('`', r"\`"))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
#[test]
|
||||
fn test_quote_literal() {
|
||||
let payload = "' OR 1=1 --";
|
||||
assert_eq!(quote_literal(payload), r#"'\' OR 1=1 --'"#);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_quote_identifier() {
|
||||
let payload = "` OR 1=1 --";
|
||||
assert_eq!(quote_identifier(payload), r#"`\` OR 1=1 --`"#);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_name() {
|
||||
assert_eq!(validate_name(""), Err(NameValidationError::EmptyString));
|
||||
assert_eq!(validate_name("abcdefghijklmnopqrstuvwxyz"), Ok(()));
|
||||
assert_eq!(validate_name("ABCDEFGHIJKLMNOPQRSTUVWXYZ"), Ok(()));
|
||||
assert_eq!(validate_name("0123456789_-"), Ok(()));
|
||||
|
||||
for c in "\n\t\r !@#$%^&*()+=[]{}|;:,.<>?/".chars() {
|
||||
assert_eq!(
|
||||
validate_name(&c.to_string()),
|
||||
Err(NameValidationError::InvalidCharacters)
|
||||
);
|
||||
}
|
||||
|
||||
assert_eq!(validate_name(&"a".repeat(MAX_NAME_LENGTH)), Ok(()));
|
||||
|
||||
assert_eq!(
|
||||
validate_name(&"a".repeat(MAX_NAME_LENGTH + 1)),
|
||||
Err(NameValidationError::TooLong)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_owner_by_prefixes() {
|
||||
let prefixes = vec!["user".to_string(), "group".to_string()];
|
||||
|
||||
assert_eq!(
|
||||
validate_ownership_by_prefixes("", &prefixes),
|
||||
Err(OwnerValidationError::StringEmpty)
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
validate_ownership_by_prefixes("user", &prefixes),
|
||||
Err(OwnerValidationError::MissingPostfix)
|
||||
);
|
||||
assert_eq!(
|
||||
validate_ownership_by_prefixes("something", &prefixes),
|
||||
Err(OwnerValidationError::MissingPostfix)
|
||||
);
|
||||
assert_eq!(
|
||||
validate_ownership_by_prefixes("user-testdb", &prefixes),
|
||||
Err(OwnerValidationError::MissingPostfix)
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
validate_ownership_by_prefixes("_testdb", &prefixes),
|
||||
Err(OwnerValidationError::MissingPrefix)
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
validate_ownership_by_prefixes("user_testdb", &prefixes),
|
||||
Ok(())
|
||||
);
|
||||
assert_eq!(
|
||||
validate_ownership_by_prefixes("group_testdb", &prefixes),
|
||||
Ok(())
|
||||
);
|
||||
assert_eq!(
|
||||
validate_ownership_by_prefixes("group_test_db", &prefixes),
|
||||
Ok(())
|
||||
);
|
||||
assert_eq!(
|
||||
validate_ownership_by_prefixes("group_test-db", &prefixes),
|
||||
Ok(())
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
validate_ownership_by_prefixes("nonexistent_testdb", &prefixes),
|
||||
Err(OwnerValidationError::NoMatch)
|
||||
);
|
||||
}
|
||||
}
|
229
src/server/server_loop.rs
Normal file
229
src/server/server_loop.rs
Normal file
@ -0,0 +1,229 @@
|
||||
use std::{collections::BTreeSet, fs, path::PathBuf};
|
||||
|
||||
use anyhow::Context;
|
||||
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::net::{UnixListener, UnixStream};
|
||||
|
||||
use sqlx::prelude::*;
|
||||
use sqlx::MySqlConnection;
|
||||
|
||||
use crate::{
|
||||
core::{
|
||||
bootstrap::authenticated_unix_socket,
|
||||
common::{UnixUser, DEFAULT_SOCKET_PATH},
|
||||
protocol::request_response::{
|
||||
create_server_to_client_message_stream, Request, Response, ServerToClientMessageStream,
|
||||
},
|
||||
},
|
||||
server::{
|
||||
config::{create_mysql_connection_from_config, ServerConfig},
|
||||
sql::{
|
||||
database_operations::{create_databases, drop_databases, list_databases_for_user},
|
||||
database_privilege_operations::{
|
||||
apply_privilege_diffs, get_all_database_privileges, get_databases_privilege_data,
|
||||
},
|
||||
user_operations::{
|
||||
create_database_users, drop_database_users, list_all_database_users_for_unix_user,
|
||||
list_database_users, lock_database_users, set_password_for_database_user,
|
||||
unlock_database_users,
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
// TODO: consider using a connection pool
|
||||
|
||||
// TODO: use tracing for login, so we can scope the log messages per incoming connection
|
||||
|
||||
pub async fn listen_for_incoming_connections(
|
||||
socket_path: Option<PathBuf>,
|
||||
config: ServerConfig,
|
||||
// db_connection: &mut MySqlConnection,
|
||||
) -> anyhow::Result<()> {
|
||||
let socket_path = socket_path.unwrap_or(PathBuf::from(DEFAULT_SOCKET_PATH));
|
||||
|
||||
let parent_directory = socket_path.parent().unwrap();
|
||||
if !parent_directory.exists() {
|
||||
println!("Creating directory {:?}", parent_directory);
|
||||
fs::create_dir_all(parent_directory)?;
|
||||
}
|
||||
|
||||
println!("Listening on {:?}", socket_path);
|
||||
match fs::remove_file(socket_path.as_path()) {
|
||||
Ok(_) => {}
|
||||
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
|
||||
Err(e) => return Err(e.into()),
|
||||
}
|
||||
|
||||
let listener = UnixListener::bind(socket_path)?;
|
||||
|
||||
while let Ok((mut conn, _addr)) = listener.accept().await {
|
||||
let uid = match authenticated_unix_socket::server_authenticate(&mut conn).await {
|
||||
Ok(uid) => uid,
|
||||
Err(e) => {
|
||||
eprintln!("Failed to authenticate client: {}", e);
|
||||
conn.shutdown().await?;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let unix_user = match UnixUser::from_uid(uid.into()) {
|
||||
Ok(user) => user,
|
||||
Err(e) => {
|
||||
eprintln!("Failed to get UnixUser from uid: {}", e);
|
||||
conn.shutdown().await?;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
match handle_requests_for_single_session(conn, &unix_user, &config).await {
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
eprintln!("Failed to run server: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn handle_requests_for_single_session(
|
||||
socket: UnixStream,
|
||||
unix_user: &UnixUser,
|
||||
config: &ServerConfig,
|
||||
) -> anyhow::Result<()> {
|
||||
let message_stream = create_server_to_client_message_stream(socket);
|
||||
let mut db_connection = create_mysql_connection_from_config(&config.mysql).await?;
|
||||
|
||||
let result = handle_requests_for_single_session_with_db_connection(
|
||||
message_stream,
|
||||
unix_user,
|
||||
&mut db_connection,
|
||||
)
|
||||
.await;
|
||||
|
||||
if let Err(e) = db_connection
|
||||
.close()
|
||||
.await
|
||||
.context("Failed to close connection properly")
|
||||
{
|
||||
eprintln!("{}", e);
|
||||
eprintln!("Ignoring...");
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
// TODO: ensure proper db_connection hygiene for functions that invoke
|
||||
// this function
|
||||
|
||||
pub async fn handle_requests_for_single_session_with_db_connection(
|
||||
mut stream: ServerToClientMessageStream,
|
||||
unix_user: &UnixUser,
|
||||
db_connection: &mut MySqlConnection,
|
||||
) -> anyhow::Result<()> {
|
||||
loop {
|
||||
// TODO: better error handling
|
||||
let request = match stream.next().await {
|
||||
Some(Ok(request)) => request,
|
||||
Some(Err(e)) => return Err(e.into()),
|
||||
None => {
|
||||
log::warn!("Client disconnected without sending an exit message");
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
match request {
|
||||
Request::CreateDatabases(databases_names) => {
|
||||
let result = create_databases(databases_names, unix_user, db_connection).await;
|
||||
stream.send(Response::CreateDatabases(result)).await?;
|
||||
stream.flush().await?;
|
||||
}
|
||||
Request::DropDatabases(databases_names) => {
|
||||
let result = drop_databases(databases_names, unix_user, db_connection).await;
|
||||
stream.send(Response::DropDatabases(result)).await?;
|
||||
stream.flush().await?;
|
||||
}
|
||||
Request::ListDatabases => {
|
||||
let result = list_databases_for_user(unix_user, db_connection).await;
|
||||
stream.send(Response::ListAllDatabases(result)).await?;
|
||||
stream.flush().await?;
|
||||
}
|
||||
Request::ListPrivileges(database_names) => {
|
||||
let response = match database_names {
|
||||
Some(database_names) => {
|
||||
let privilege_data =
|
||||
get_databases_privilege_data(database_names, unix_user, db_connection)
|
||||
.await;
|
||||
Response::ListPrivileges(privilege_data)
|
||||
}
|
||||
None => {
|
||||
let privilege_data =
|
||||
get_all_database_privileges(unix_user, db_connection).await;
|
||||
Response::ListAllPrivileges(privilege_data)
|
||||
}
|
||||
};
|
||||
|
||||
stream.send(response).await?;
|
||||
stream.flush().await?;
|
||||
}
|
||||
Request::ModifyPrivileges(database_privilege_diffs) => {
|
||||
let result = apply_privilege_diffs(
|
||||
BTreeSet::from_iter(database_privilege_diffs),
|
||||
unix_user,
|
||||
db_connection,
|
||||
)
|
||||
.await;
|
||||
stream.send(Response::ModifyPrivileges(result)).await?;
|
||||
stream.flush().await?;
|
||||
}
|
||||
Request::CreateUsers(db_users) => {
|
||||
let result = create_database_users(db_users, unix_user, db_connection).await;
|
||||
stream.send(Response::CreateUsers(result)).await?;
|
||||
stream.flush().await?;
|
||||
}
|
||||
Request::DropUsers(db_users) => {
|
||||
let result = drop_database_users(db_users, unix_user, db_connection).await;
|
||||
stream.send(Response::DropUsers(result)).await?;
|
||||
stream.flush().await?;
|
||||
}
|
||||
Request::PasswdUser(db_user, password) => {
|
||||
let result =
|
||||
set_password_for_database_user(&db_user, &password, unix_user, db_connection)
|
||||
.await;
|
||||
stream.send(Response::PasswdUser(result)).await?;
|
||||
stream.flush().await?;
|
||||
}
|
||||
Request::ListUsers(db_users) => {
|
||||
let response = match db_users {
|
||||
Some(db_users) => {
|
||||
let result = list_database_users(db_users, unix_user, db_connection).await;
|
||||
Response::ListUsers(result)
|
||||
}
|
||||
None => {
|
||||
let result =
|
||||
list_all_database_users_for_unix_user(unix_user, db_connection).await;
|
||||
Response::ListAllUsers(result)
|
||||
}
|
||||
};
|
||||
stream.send(response).await?;
|
||||
stream.flush().await?;
|
||||
}
|
||||
Request::LockUsers(db_users) => {
|
||||
let result = lock_database_users(db_users, unix_user, db_connection).await;
|
||||
stream.send(Response::LockUsers(result)).await?;
|
||||
stream.flush().await?;
|
||||
}
|
||||
Request::UnlockUsers(db_users) => {
|
||||
let result = unlock_database_users(db_users, unix_user, db_connection).await;
|
||||
stream.send(Response::UnlockUsers(result)).await?;
|
||||
stream.flush().await?;
|
||||
}
|
||||
Request::Exit => {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
3
src/server/sql.rs
Normal file
3
src/server/sql.rs
Normal file
@ -0,0 +1,3 @@
|
||||
pub mod database_operations;
|
||||
pub mod database_privilege_operations;
|
||||
pub mod user_operations;
|
165
src/server/sql/database_operations.rs
Normal file
165
src/server/sql/database_operations.rs
Normal file
@ -0,0 +1,165 @@
|
||||
use crate::{
|
||||
core::{
|
||||
common::UnixUser,
|
||||
protocol::{
|
||||
CreateDatabaseError, CreateDatabasesOutput, DropDatabaseError, DropDatabasesOutput,
|
||||
ListDatabasesError,
|
||||
},
|
||||
},
|
||||
server::{
|
||||
common::create_user_group_matching_regex,
|
||||
input_sanitization::{quote_identifier, validate_name, validate_ownership_by_unix_user},
|
||||
},
|
||||
};
|
||||
|
||||
use sqlx::prelude::*;
|
||||
|
||||
use sqlx::MySqlConnection;
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
// NOTE: this function is unsafe because it does no input validation.
|
||||
pub(super) async fn unsafe_database_exists(
|
||||
database_name: &str,
|
||||
connection: &mut MySqlConnection,
|
||||
) -> Result<bool, sqlx::Error> {
|
||||
let result =
|
||||
sqlx::query("SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?")
|
||||
.bind(database_name)
|
||||
.fetch_optional(connection)
|
||||
.await?;
|
||||
|
||||
Ok(result.is_some())
|
||||
}
|
||||
|
||||
pub async fn create_databases(
|
||||
database_names: Vec<String>,
|
||||
unix_user: &UnixUser,
|
||||
connection: &mut MySqlConnection,
|
||||
) -> CreateDatabasesOutput {
|
||||
let mut results = BTreeMap::new();
|
||||
|
||||
for database_name in database_names {
|
||||
if let Err(err) = validate_name(&database_name) {
|
||||
results.insert(
|
||||
database_name.clone(),
|
||||
Err(CreateDatabaseError::SanitizationError(err)),
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Err(err) = validate_ownership_by_unix_user(&database_name, unix_user) {
|
||||
results.insert(
|
||||
database_name.clone(),
|
||||
Err(CreateDatabaseError::OwnershipError(err)),
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
match unsafe_database_exists(&database_name, &mut *connection).await {
|
||||
Ok(true) => {
|
||||
results.insert(
|
||||
database_name.clone(),
|
||||
Err(CreateDatabaseError::DatabaseAlreadyExists),
|
||||
);
|
||||
continue;
|
||||
}
|
||||
Err(err) => {
|
||||
results.insert(
|
||||
database_name.clone(),
|
||||
Err(CreateDatabaseError::MySqlError(err.to_string())),
|
||||
);
|
||||
continue;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let result =
|
||||
sqlx::query(format!("CREATE DATABASE {}", quote_identifier(&database_name)).as_str())
|
||||
.execute(&mut *connection)
|
||||
.await
|
||||
.map(|_| ())
|
||||
.map_err(|err| CreateDatabaseError::MySqlError(err.to_string()));
|
||||
|
||||
results.insert(database_name, result);
|
||||
}
|
||||
|
||||
results
|
||||
}
|
||||
|
||||
pub async fn drop_databases(
|
||||
database_names: Vec<String>,
|
||||
unix_user: &UnixUser,
|
||||
connection: &mut MySqlConnection,
|
||||
) -> DropDatabasesOutput {
|
||||
let mut results = BTreeMap::new();
|
||||
|
||||
for database_name in database_names {
|
||||
if let Err(err) = validate_name(&database_name) {
|
||||
results.insert(
|
||||
database_name.clone(),
|
||||
Err(DropDatabaseError::SanitizationError(err)),
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Err(err) = validate_ownership_by_unix_user(&database_name, unix_user) {
|
||||
results.insert(
|
||||
database_name.clone(),
|
||||
Err(DropDatabaseError::OwnershipError(err)),
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
match unsafe_database_exists(&database_name, &mut *connection).await {
|
||||
Ok(false) => {
|
||||
results.insert(
|
||||
database_name.clone(),
|
||||
Err(DropDatabaseError::DatabaseDoesNotExist),
|
||||
);
|
||||
continue;
|
||||
}
|
||||
Err(err) => {
|
||||
results.insert(
|
||||
database_name.clone(),
|
||||
Err(DropDatabaseError::MySqlError(err.to_string())),
|
||||
);
|
||||
continue;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let result =
|
||||
sqlx::query(format!("DROP DATABASE {}", quote_identifier(&database_name)).as_str())
|
||||
.execute(&mut *connection)
|
||||
.await
|
||||
.map(|_| ())
|
||||
.map_err(|err| DropDatabaseError::MySqlError(err.to_string()));
|
||||
|
||||
results.insert(database_name, result);
|
||||
}
|
||||
|
||||
results
|
||||
}
|
||||
|
||||
pub async fn list_databases_for_user(
|
||||
unix_user: &UnixUser,
|
||||
connection: &mut MySqlConnection,
|
||||
) -> Result<Vec<String>, ListDatabasesError> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
SELECT `SCHEMA_NAME` AS `database`
|
||||
FROM `information_schema`.`SCHEMATA`
|
||||
WHERE `SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys')
|
||||
AND `SCHEMA_NAME` REGEXP ?
|
||||
"#,
|
||||
)
|
||||
.bind(create_user_group_matching_regex(unix_user))
|
||||
.fetch_all(connection)
|
||||
.await
|
||||
.and_then(|rows| {
|
||||
rows.into_iter()
|
||||
.map(|row| row.try_get::<String, _>("database"))
|
||||
.collect::<Result<Vec<String>, sqlx::Error>>()
|
||||
})
|
||||
.map_err(|err| ListDatabasesError::MySqlError(err.to_string()))
|
||||
}
|
452
src/server/sql/database_privilege_operations.rs
Normal file
452
src/server/sql/database_privilege_operations.rs
Normal file
@ -0,0 +1,452 @@
|
||||
// TODO: fix comment
|
||||
//! Database privilege operations
|
||||
//!
|
||||
//! This module contains functions for querying, modifying,
|
||||
//! displaying and comparing database privileges.
|
||||
//!
|
||||
//! A lot of the complexity comes from two core components:
|
||||
//!
|
||||
//! - The privilege editor that needs to be able to print
|
||||
//! an editable table of privileges and reparse the content
|
||||
//! after the user has made manual changes.
|
||||
//!
|
||||
//! - The comparison functionality that tells the user what
|
||||
//! changes will be made when applying a set of changes
|
||||
//! to the list of database privileges.
|
||||
|
||||
use std::collections::{BTreeMap, BTreeSet};
|
||||
|
||||
use indoc::indoc;
|
||||
use itertools::Itertools;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::{mysql::MySqlRow, prelude::*, MySqlConnection};
|
||||
|
||||
use crate::{
|
||||
core::{
|
||||
common::{rev_yn, yn, UnixUser},
|
||||
database_privileges::{DatabasePrivilegeChange, DatabasePrivilegesDiff},
|
||||
protocol::{
|
||||
DiffDoesNotApplyError, GetAllDatabasesPrivilegeData, GetAllDatabasesPrivilegeDataError,
|
||||
GetDatabasesPrivilegeData, GetDatabasesPrivilegeDataError,
|
||||
ModifyDatabasePrivilegesError, ModifyDatabasePrivilegesOutput,
|
||||
},
|
||||
},
|
||||
server::{
|
||||
common::create_user_group_matching_regex,
|
||||
input_sanitization::{quote_identifier, validate_name, validate_ownership_by_unix_user},
|
||||
sql::database_operations::unsafe_database_exists,
|
||||
},
|
||||
};
|
||||
|
||||
/// This is the list of fields that are used to fetch the db + user + privileges
|
||||
/// from the `db` table in the database. If you need to add or remove privilege
|
||||
/// fields, this is a good place to start.
|
||||
pub const DATABASE_PRIVILEGE_FIELDS: [&str; 13] = [
|
||||
"db",
|
||||
"user",
|
||||
"select_priv",
|
||||
"insert_priv",
|
||||
"update_priv",
|
||||
"delete_priv",
|
||||
"create_priv",
|
||||
"drop_priv",
|
||||
"alter_priv",
|
||||
"index_priv",
|
||||
"create_tmp_table_priv",
|
||||
"lock_tables_priv",
|
||||
"references_priv",
|
||||
];
|
||||
|
||||
// NOTE: ord is needed for BTreeSet to accept the type, but it
|
||||
// doesn't have any natural implementation semantics.
|
||||
|
||||
/// This struct represents the set of privileges for a single user on a single database.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
|
||||
pub struct DatabasePrivilegeRow {
|
||||
pub db: String,
|
||||
pub user: String,
|
||||
pub select_priv: bool,
|
||||
pub insert_priv: bool,
|
||||
pub update_priv: bool,
|
||||
pub delete_priv: bool,
|
||||
pub create_priv: bool,
|
||||
pub drop_priv: bool,
|
||||
pub alter_priv: bool,
|
||||
pub index_priv: bool,
|
||||
pub create_tmp_table_priv: bool,
|
||||
pub lock_tables_priv: bool,
|
||||
pub references_priv: bool,
|
||||
}
|
||||
|
||||
impl DatabasePrivilegeRow {
|
||||
pub fn get_privilege_by_name(&self, name: &str) -> bool {
|
||||
match name {
|
||||
"select_priv" => self.select_priv,
|
||||
"insert_priv" => self.insert_priv,
|
||||
"update_priv" => self.update_priv,
|
||||
"delete_priv" => self.delete_priv,
|
||||
"create_priv" => self.create_priv,
|
||||
"drop_priv" => self.drop_priv,
|
||||
"alter_priv" => self.alter_priv,
|
||||
"index_priv" => self.index_priv,
|
||||
"create_tmp_table_priv" => self.create_tmp_table_priv,
|
||||
"lock_tables_priv" => self.lock_tables_priv,
|
||||
"references_priv" => self.references_priv,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn get_mysql_row_priv_field(row: &MySqlRow, position: usize) -> Result<bool, sqlx::Error> {
|
||||
let field = DATABASE_PRIVILEGE_FIELDS[position];
|
||||
let value = row.try_get(position)?;
|
||||
match rev_yn(value) {
|
||||
Some(val) => Ok(val),
|
||||
_ => {
|
||||
log::warn!(r#"Invalid value for privilege "{}": '{}'"#, field, value);
|
||||
Ok(false)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FromRow<'_, MySqlRow> for DatabasePrivilegeRow {
|
||||
fn from_row(row: &MySqlRow) -> Result<Self, sqlx::Error> {
|
||||
Ok(Self {
|
||||
db: row.try_get("db")?,
|
||||
user: row.try_get("user")?,
|
||||
select_priv: get_mysql_row_priv_field(row, 2)?,
|
||||
insert_priv: get_mysql_row_priv_field(row, 3)?,
|
||||
update_priv: get_mysql_row_priv_field(row, 4)?,
|
||||
delete_priv: get_mysql_row_priv_field(row, 5)?,
|
||||
create_priv: get_mysql_row_priv_field(row, 6)?,
|
||||
drop_priv: get_mysql_row_priv_field(row, 7)?,
|
||||
alter_priv: get_mysql_row_priv_field(row, 8)?,
|
||||
index_priv: get_mysql_row_priv_field(row, 9)?,
|
||||
create_tmp_table_priv: get_mysql_row_priv_field(row, 10)?,
|
||||
lock_tables_priv: get_mysql_row_priv_field(row, 11)?,
|
||||
references_priv: get_mysql_row_priv_field(row, 12)?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: this function is unsafe because it does no input validation.
|
||||
/// Get all users + privileges for a single database.
|
||||
async fn unsafe_get_database_privileges(
|
||||
database_name: &str,
|
||||
connection: &mut MySqlConnection,
|
||||
) -> Result<Vec<DatabasePrivilegeRow>, sqlx::Error> {
|
||||
sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
|
||||
"SELECT {} FROM `db` WHERE `db` = ?",
|
||||
DATABASE_PRIVILEGE_FIELDS
|
||||
.iter()
|
||||
.map(|field| quote_identifier(field))
|
||||
.join(","),
|
||||
))
|
||||
.bind(database_name)
|
||||
.fetch_all(connection)
|
||||
.await
|
||||
}
|
||||
|
||||
// NOTE: this function is unsafe because it does no input validation.
|
||||
/// Get all users + privileges for a single database-user pair.
|
||||
pub async fn unsafe_get_database_privileges_for_db_user_pair(
|
||||
database_name: &str,
|
||||
user_name: &str,
|
||||
connection: &mut MySqlConnection,
|
||||
) -> Result<Option<DatabasePrivilegeRow>, sqlx::Error> {
|
||||
sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
|
||||
"SELECT {} FROM `db` WHERE `db` = ? AND `user` = ?",
|
||||
DATABASE_PRIVILEGE_FIELDS
|
||||
.iter()
|
||||
.map(|field| quote_identifier(field))
|
||||
.join(","),
|
||||
))
|
||||
.bind(database_name)
|
||||
.bind(user_name)
|
||||
.fetch_optional(connection)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn get_databases_privilege_data(
|
||||
database_names: Vec<String>,
|
||||
unix_user: &UnixUser,
|
||||
connection: &mut MySqlConnection,
|
||||
) -> GetDatabasesPrivilegeData {
|
||||
let mut results = BTreeMap::new();
|
||||
|
||||
for database_name in database_names.iter() {
|
||||
if let Err(err) = validate_name(database_name) {
|
||||
results.insert(
|
||||
database_name.clone(),
|
||||
Err(GetDatabasesPrivilegeDataError::SanitizationError(err)),
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Err(err) = validate_ownership_by_unix_user(database_name, unix_user) {
|
||||
results.insert(
|
||||
database_name.clone(),
|
||||
Err(GetDatabasesPrivilegeDataError::OwnershipError(err)),
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
if !unsafe_database_exists(database_name, connection)
|
||||
.await
|
||||
.unwrap()
|
||||
{
|
||||
results.insert(
|
||||
database_name.clone(),
|
||||
Err(GetDatabasesPrivilegeDataError::DatabaseDoesNotExist),
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
let result = unsafe_get_database_privileges(database_name, connection)
|
||||
.await
|
||||
.map_err(|e| GetDatabasesPrivilegeDataError::MySqlError(e.to_string()));
|
||||
|
||||
results.insert(database_name.clone(), result);
|
||||
}
|
||||
|
||||
debug_assert!(database_names.len() == results.len());
|
||||
|
||||
results
|
||||
}
|
||||
|
||||
/// Get all database + user + privileges pairs that are owned by the current user.
|
||||
pub async fn get_all_database_privileges(
|
||||
unix_user: &UnixUser,
|
||||
connection: &mut MySqlConnection,
|
||||
) -> GetAllDatabasesPrivilegeData {
|
||||
sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
|
||||
indoc! {r#"
|
||||
SELECT {} FROM `db` WHERE `db` IN
|
||||
(SELECT DISTINCT `SCHEMA_NAME` AS `database`
|
||||
FROM `information_schema`.`SCHEMATA`
|
||||
WHERE `SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys')
|
||||
AND `SCHEMA_NAME` REGEXP ?)
|
||||
"#},
|
||||
DATABASE_PRIVILEGE_FIELDS
|
||||
.iter()
|
||||
.map(|field| quote_identifier(field))
|
||||
.join(","),
|
||||
))
|
||||
.bind(create_user_group_matching_regex(unix_user))
|
||||
.fetch_all(connection)
|
||||
.await
|
||||
.map_err(|e| GetAllDatabasesPrivilegeDataError::MySqlError(e.to_string()))
|
||||
}
|
||||
|
||||
async fn unsafe_apply_privilege_diff(
|
||||
database_privilege_diff: &DatabasePrivilegesDiff,
|
||||
connection: &mut MySqlConnection,
|
||||
) -> Result<(), sqlx::Error> {
|
||||
match database_privilege_diff {
|
||||
DatabasePrivilegesDiff::New(p) => {
|
||||
let tables = DATABASE_PRIVILEGE_FIELDS
|
||||
.iter()
|
||||
.map(|field| quote_identifier(field))
|
||||
.join(",");
|
||||
|
||||
let question_marks = std::iter::repeat("?")
|
||||
.take(DATABASE_PRIVILEGE_FIELDS.len())
|
||||
.join(",");
|
||||
|
||||
sqlx::query(
|
||||
format!("INSERT INTO `db` ({}) VALUES ({})", tables, question_marks).as_str(),
|
||||
)
|
||||
.bind(p.db.to_string())
|
||||
.bind(p.user.to_string())
|
||||
.bind(yn(p.select_priv))
|
||||
.bind(yn(p.insert_priv))
|
||||
.bind(yn(p.update_priv))
|
||||
.bind(yn(p.delete_priv))
|
||||
.bind(yn(p.create_priv))
|
||||
.bind(yn(p.drop_priv))
|
||||
.bind(yn(p.alter_priv))
|
||||
.bind(yn(p.index_priv))
|
||||
.bind(yn(p.create_tmp_table_priv))
|
||||
.bind(yn(p.lock_tables_priv))
|
||||
.bind(yn(p.references_priv))
|
||||
.execute(connection)
|
||||
.await
|
||||
.map(|_| ())
|
||||
}
|
||||
DatabasePrivilegesDiff::Modified(p) => {
|
||||
let changes = p
|
||||
.diff
|
||||
.iter()
|
||||
.map(|diff| match diff {
|
||||
DatabasePrivilegeChange::YesToNo(name) => {
|
||||
format!("{} = 'N'", quote_identifier(name))
|
||||
}
|
||||
DatabasePrivilegeChange::NoToYes(name) => {
|
||||
format!("{} = 'Y'", quote_identifier(name))
|
||||
}
|
||||
})
|
||||
.join(",");
|
||||
|
||||
sqlx::query(
|
||||
format!("UPDATE `db` SET {} WHERE `db` = ? AND `user` = ?", changes).as_str(),
|
||||
)
|
||||
.bind(p.db.to_string())
|
||||
.bind(p.user.to_string())
|
||||
.execute(connection)
|
||||
.await
|
||||
.map(|_| ())
|
||||
}
|
||||
DatabasePrivilegesDiff::Deleted(p) => {
|
||||
sqlx::query("DELETE FROM `db` WHERE `db` = ? AND `user` = ?")
|
||||
.bind(p.db.to_string())
|
||||
.bind(p.user.to_string())
|
||||
.execute(connection)
|
||||
.await
|
||||
.map(|_| ())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn validate_diff(
|
||||
diff: &DatabasePrivilegesDiff,
|
||||
connection: &mut MySqlConnection,
|
||||
) -> Result<(), ModifyDatabasePrivilegesError> {
|
||||
let privilege_row = unsafe_get_database_privileges_for_db_user_pair(
|
||||
diff.get_database_name(),
|
||||
diff.get_user_name(),
|
||||
connection,
|
||||
)
|
||||
.await;
|
||||
|
||||
let privilege_row = match privilege_row {
|
||||
Ok(privilege_row) => privilege_row,
|
||||
Err(e) => return Err(ModifyDatabasePrivilegesError::MySqlError(e.to_string())),
|
||||
};
|
||||
|
||||
let result = match diff {
|
||||
DatabasePrivilegesDiff::New(_) => {
|
||||
if privilege_row.is_some() {
|
||||
Err(ModifyDatabasePrivilegesError::DiffDoesNotApply(
|
||||
DiffDoesNotApplyError::RowAlreadyExists(
|
||||
diff.get_user_name().to_string(),
|
||||
diff.get_database_name().to_string(),
|
||||
),
|
||||
))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
DatabasePrivilegesDiff::Modified(_) if privilege_row.is_none() => {
|
||||
Err(ModifyDatabasePrivilegesError::DiffDoesNotApply(
|
||||
DiffDoesNotApplyError::RowDoesNotExist(
|
||||
diff.get_user_name().to_string(),
|
||||
diff.get_database_name().to_string(),
|
||||
),
|
||||
))
|
||||
}
|
||||
DatabasePrivilegesDiff::Modified(row_diff) => {
|
||||
let row = privilege_row.unwrap();
|
||||
|
||||
let error_exists = row_diff.diff.iter().any(|change| match change {
|
||||
DatabasePrivilegeChange::YesToNo(name) => !row.get_privilege_by_name(name),
|
||||
DatabasePrivilegeChange::NoToYes(name) => row.get_privilege_by_name(name),
|
||||
});
|
||||
|
||||
if error_exists {
|
||||
Err(ModifyDatabasePrivilegesError::DiffDoesNotApply(
|
||||
DiffDoesNotApplyError::RowPrivilegeChangeDoesNotApply(row_diff.clone(), row),
|
||||
))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
DatabasePrivilegesDiff::Deleted(_) => {
|
||||
if privilege_row.is_none() {
|
||||
Err(ModifyDatabasePrivilegesError::DiffDoesNotApply(
|
||||
DiffDoesNotApplyError::RowDoesNotExist(
|
||||
diff.get_user_name().to_string(),
|
||||
diff.get_database_name().to_string(),
|
||||
),
|
||||
))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Uses the result of [`diff_privileges`] to modify privileges in the database.
|
||||
pub async fn apply_privilege_diffs(
|
||||
database_privilege_diffs: BTreeSet<DatabasePrivilegesDiff>,
|
||||
unix_user: &UnixUser,
|
||||
connection: &mut MySqlConnection,
|
||||
) -> ModifyDatabasePrivilegesOutput {
|
||||
let mut results: BTreeMap<(String, String), _> = BTreeMap::new();
|
||||
|
||||
for diff in database_privilege_diffs {
|
||||
let key = (
|
||||
diff.get_database_name().to_string(),
|
||||
diff.get_user_name().to_string(),
|
||||
);
|
||||
if let Err(err) = validate_name(diff.get_database_name()) {
|
||||
results.insert(
|
||||
key,
|
||||
Err(ModifyDatabasePrivilegesError::DatabaseSanitizationError(
|
||||
err,
|
||||
)),
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Err(err) = validate_ownership_by_unix_user(diff.get_database_name(), unix_user) {
|
||||
results.insert(
|
||||
key,
|
||||
Err(ModifyDatabasePrivilegesError::DatabaseOwnershipError(err)),
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Err(err) = validate_name(diff.get_user_name()) {
|
||||
results.insert(
|
||||
key,
|
||||
Err(ModifyDatabasePrivilegesError::UserSanitizationError(err)),
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Err(err) = validate_ownership_by_unix_user(diff.get_user_name(), unix_user) {
|
||||
results.insert(
|
||||
key,
|
||||
Err(ModifyDatabasePrivilegesError::UserOwnershipError(err)),
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
if !unsafe_database_exists(diff.get_database_name(), connection)
|
||||
.await
|
||||
.unwrap()
|
||||
{
|
||||
results.insert(
|
||||
key,
|
||||
Err(ModifyDatabasePrivilegesError::DatabaseDoesNotExist),
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Err(err) = validate_diff(&diff, connection).await {
|
||||
results.insert(key, Err(err));
|
||||
continue;
|
||||
}
|
||||
|
||||
let result = unsafe_apply_privilege_diff(&diff, connection)
|
||||
.await
|
||||
.map_err(|e| ModifyDatabasePrivilegesError::MySqlError(e.to_string()));
|
||||
|
||||
results.insert(key, result);
|
||||
}
|
||||
|
||||
results
|
||||
}
|
375
src/server/sql/user_operations.rs
Normal file
375
src/server/sql/user_operations.rs
Normal file
@ -0,0 +1,375 @@
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use sqlx::prelude::*;
|
||||
use sqlx::MySqlConnection;
|
||||
|
||||
use crate::{
|
||||
core::{
|
||||
common::UnixUser,
|
||||
protocol::{
|
||||
CreateUserError, CreateUsersOutput, DropUserError, DropUsersOutput, ListAllUsersError,
|
||||
ListAllUsersOutput, ListUsersError, ListUsersOutput, LockUserError, LockUsersOutput,
|
||||
SetPasswordError, SetPasswordOutput, UnlockUserError, UnlockUsersOutput,
|
||||
},
|
||||
},
|
||||
server::{
|
||||
common::create_user_group_matching_regex,
|
||||
input_sanitization::{quote_literal, validate_name, validate_ownership_by_unix_user},
|
||||
},
|
||||
};
|
||||
|
||||
// NOTE: this function is unsafe because it does no input validation.
|
||||
async fn unsafe_user_exists(
|
||||
db_user: &str,
|
||||
connection: &mut MySqlConnection,
|
||||
) -> Result<bool, sqlx::Error> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
SELECT EXISTS(
|
||||
SELECT 1
|
||||
FROM `mysql`.`user`
|
||||
WHERE `User` = ?
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.bind(db_user)
|
||||
.fetch_one(connection)
|
||||
.await
|
||||
.map(|row| row.get::<bool, _>(0))
|
||||
}
|
||||
|
||||
pub async fn create_database_users(
|
||||
db_users: Vec<String>,
|
||||
unix_user: &UnixUser,
|
||||
connection: &mut MySqlConnection,
|
||||
) -> CreateUsersOutput {
|
||||
let mut results = BTreeMap::new();
|
||||
|
||||
for db_user in db_users {
|
||||
if let Err(err) = validate_name(&db_user) {
|
||||
results.insert(db_user, Err(CreateUserError::SanitizationError(err)));
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Err(err) = validate_ownership_by_unix_user(&db_user, unix_user) {
|
||||
results.insert(db_user, Err(CreateUserError::OwnershipError(err)));
|
||||
continue;
|
||||
}
|
||||
|
||||
match unsafe_user_exists(&db_user, &mut *connection).await {
|
||||
Ok(true) => {
|
||||
results.insert(db_user, Err(CreateUserError::UserAlreadyExists));
|
||||
continue;
|
||||
}
|
||||
Err(err) => {
|
||||
results.insert(db_user, Err(CreateUserError::MySqlError(err.to_string())));
|
||||
continue;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let result = sqlx::query(format!("CREATE USER {}@'%'", quote_literal(&db_user),).as_str())
|
||||
.execute(&mut *connection)
|
||||
.await
|
||||
.map(|_| ())
|
||||
.map_err(|err| CreateUserError::MySqlError(err.to_string()));
|
||||
|
||||
results.insert(db_user, result);
|
||||
}
|
||||
|
||||
results
|
||||
}
|
||||
|
||||
pub async fn drop_database_users(
|
||||
db_users: Vec<String>,
|
||||
unix_user: &UnixUser,
|
||||
connection: &mut MySqlConnection,
|
||||
) -> DropUsersOutput {
|
||||
let mut results = BTreeMap::new();
|
||||
|
||||
for db_user in db_users {
|
||||
if let Err(err) = validate_name(&db_user) {
|
||||
results.insert(db_user, Err(DropUserError::SanitizationError(err)));
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Err(err) = validate_ownership_by_unix_user(&db_user, unix_user) {
|
||||
results.insert(db_user, Err(DropUserError::OwnershipError(err)));
|
||||
continue;
|
||||
}
|
||||
|
||||
match unsafe_user_exists(&db_user, &mut *connection).await {
|
||||
Ok(false) => {
|
||||
results.insert(db_user, Err(DropUserError::UserDoesNotExist));
|
||||
continue;
|
||||
}
|
||||
Err(err) => {
|
||||
results.insert(db_user, Err(DropUserError::MySqlError(err.to_string())));
|
||||
continue;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let result = sqlx::query(format!("DROP USER {}@'%'", quote_literal(&db_user),).as_str())
|
||||
.execute(&mut *connection)
|
||||
.await
|
||||
.map(|_| ())
|
||||
.map_err(|err| DropUserError::MySqlError(err.to_string()));
|
||||
|
||||
results.insert(db_user, result);
|
||||
}
|
||||
|
||||
results
|
||||
}
|
||||
|
||||
pub async fn set_password_for_database_user(
|
||||
db_user: &str,
|
||||
password: &str,
|
||||
unix_user: &UnixUser,
|
||||
connection: &mut MySqlConnection,
|
||||
) -> SetPasswordOutput {
|
||||
if let Err(err) = validate_name(db_user) {
|
||||
return Err(SetPasswordError::SanitizationError(err));
|
||||
}
|
||||
|
||||
if let Err(err) = validate_ownership_by_unix_user(db_user, unix_user) {
|
||||
return Err(SetPasswordError::OwnershipError(err));
|
||||
}
|
||||
|
||||
match unsafe_user_exists(db_user, &mut *connection).await {
|
||||
Ok(false) => return Err(SetPasswordError::UserDoesNotExist),
|
||||
Err(err) => return Err(SetPasswordError::MySqlError(err.to_string())),
|
||||
_ => {}
|
||||
}
|
||||
|
||||
sqlx::query(
|
||||
format!(
|
||||
"ALTER USER {}@'%' IDENTIFIED BY {}",
|
||||
quote_literal(db_user),
|
||||
quote_literal(password).as_str()
|
||||
)
|
||||
.as_str(),
|
||||
)
|
||||
.execute(&mut *connection)
|
||||
.await
|
||||
.map(|_| ())
|
||||
.map_err(|err| SetPasswordError::MySqlError(err.to_string()))
|
||||
}
|
||||
|
||||
// NOTE: this function is unsafe because it does no input validation.
|
||||
async fn database_user_is_locked_unsafe(
|
||||
db_user: &str,
|
||||
connection: &mut MySqlConnection,
|
||||
) -> Result<bool, sqlx::Error> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
SELECT COALESCE(
|
||||
JSON_EXTRACT(`mysql`.`global_priv`.`priv`, "$.account_locked"),
|
||||
'false'
|
||||
) != 'false'
|
||||
FROM `mysql`.`global_priv`
|
||||
WHERE `User` = ?
|
||||
AND `Host` = '%'
|
||||
"#,
|
||||
)
|
||||
.bind(db_user)
|
||||
.fetch_one(connection)
|
||||
.await
|
||||
.map(|row| row.get::<bool, _>(0))
|
||||
}
|
||||
|
||||
pub async fn lock_database_users(
|
||||
db_users: Vec<String>,
|
||||
unix_user: &UnixUser,
|
||||
connection: &mut MySqlConnection,
|
||||
) -> LockUsersOutput {
|
||||
let mut results = BTreeMap::new();
|
||||
|
||||
for db_user in db_users {
|
||||
if let Err(err) = validate_name(&db_user) {
|
||||
results.insert(db_user, Err(LockUserError::SanitizationError(err)));
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Err(err) = validate_ownership_by_unix_user(&db_user, unix_user) {
|
||||
results.insert(db_user, Err(LockUserError::OwnershipError(err)));
|
||||
continue;
|
||||
}
|
||||
|
||||
match unsafe_user_exists(&db_user, &mut *connection).await {
|
||||
Ok(true) => {}
|
||||
Ok(false) => {
|
||||
results.insert(db_user, Err(LockUserError::UserDoesNotExist));
|
||||
continue;
|
||||
}
|
||||
Err(err) => {
|
||||
results.insert(db_user, Err(LockUserError::MySqlError(err.to_string())));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
match database_user_is_locked_unsafe(&db_user, &mut *connection).await {
|
||||
Ok(false) => {}
|
||||
Ok(true) => {
|
||||
results.insert(db_user, Err(LockUserError::UserIsAlreadyLocked));
|
||||
continue;
|
||||
}
|
||||
Err(err) => {
|
||||
results.insert(db_user, Err(LockUserError::MySqlError(err.to_string())));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
let result = sqlx::query(
|
||||
format!("ALTER USER {}@'%' ACCOUNT LOCK", quote_literal(&db_user),).as_str(),
|
||||
)
|
||||
.execute(&mut *connection)
|
||||
.await
|
||||
.map(|_| ())
|
||||
.map_err(|err| LockUserError::MySqlError(err.to_string()));
|
||||
|
||||
results.insert(db_user, result);
|
||||
}
|
||||
|
||||
results
|
||||
}
|
||||
|
||||
pub async fn unlock_database_users(
|
||||
db_users: Vec<String>,
|
||||
unix_user: &UnixUser,
|
||||
connection: &mut MySqlConnection,
|
||||
) -> UnlockUsersOutput {
|
||||
let mut results = BTreeMap::new();
|
||||
|
||||
for db_user in db_users {
|
||||
if let Err(err) = validate_name(&db_user) {
|
||||
results.insert(db_user, Err(UnlockUserError::SanitizationError(err)));
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Err(err) = validate_ownership_by_unix_user(&db_user, unix_user) {
|
||||
results.insert(db_user, Err(UnlockUserError::OwnershipError(err)));
|
||||
continue;
|
||||
}
|
||||
|
||||
match unsafe_user_exists(&db_user, &mut *connection).await {
|
||||
Ok(false) => {
|
||||
results.insert(db_user, Err(UnlockUserError::UserDoesNotExist));
|
||||
continue;
|
||||
}
|
||||
Err(err) => {
|
||||
results.insert(db_user, Err(UnlockUserError::MySqlError(err.to_string())));
|
||||
continue;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
match database_user_is_locked_unsafe(&db_user, &mut *connection).await {
|
||||
Ok(false) => {
|
||||
results.insert(db_user, Err(UnlockUserError::UserIsAlreadyUnlocked));
|
||||
continue;
|
||||
}
|
||||
Err(err) => {
|
||||
results.insert(db_user, Err(UnlockUserError::MySqlError(err.to_string())));
|
||||
continue;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let result = sqlx::query(
|
||||
format!("ALTER USER {}@'%' ACCOUNT UNLOCK", quote_literal(&db_user),).as_str(),
|
||||
)
|
||||
.execute(&mut *connection)
|
||||
.await
|
||||
.map(|_| ())
|
||||
.map_err(|err| UnlockUserError::MySqlError(err.to_string()));
|
||||
|
||||
results.insert(db_user, result);
|
||||
}
|
||||
|
||||
results
|
||||
}
|
||||
|
||||
/// This struct contains information about a database user.
|
||||
/// This can be extended if we need more information in the future.
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)]
|
||||
pub struct DatabaseUser {
|
||||
#[sqlx(rename = "User")]
|
||||
pub user: String,
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[serde(skip)]
|
||||
#[sqlx(rename = "Host")]
|
||||
pub host: String,
|
||||
|
||||
#[sqlx(rename = "has_password")]
|
||||
pub has_password: bool,
|
||||
|
||||
#[sqlx(rename = "is_locked")]
|
||||
pub is_locked: bool,
|
||||
}
|
||||
|
||||
const DB_USER_SELECT_STATEMENT: &str = r#"
|
||||
SELECT
|
||||
`mysql`.`user`.`User`,
|
||||
`mysql`.`user`.`Host`,
|
||||
`mysql`.`user`.`Password` != '' OR `mysql`.`user`.`authentication_string` != '' AS `has_password`,
|
||||
COALESCE(
|
||||
JSON_EXTRACT(`mysql`.`global_priv`.`priv`, "$.account_locked"),
|
||||
'false'
|
||||
) != 'false' AS `is_locked`
|
||||
FROM `mysql`.`user`
|
||||
JOIN `mysql`.`global_priv` ON
|
||||
`mysql`.`user`.`User` = `mysql`.`global_priv`.`User`
|
||||
AND `mysql`.`user`.`Host` = `mysql`.`global_priv`.`Host`
|
||||
"#;
|
||||
|
||||
pub async fn list_database_users(
|
||||
db_users: Vec<String>,
|
||||
unix_user: &UnixUser,
|
||||
connection: &mut MySqlConnection,
|
||||
) -> ListUsersOutput {
|
||||
let mut results = BTreeMap::new();
|
||||
|
||||
for db_user in db_users {
|
||||
if let Err(err) = validate_name(&db_user) {
|
||||
results.insert(db_user, Err(ListUsersError::SanitizationError(err)));
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Err(err) = validate_ownership_by_unix_user(&db_user, unix_user) {
|
||||
results.insert(db_user, Err(ListUsersError::OwnershipError(err)));
|
||||
continue;
|
||||
}
|
||||
|
||||
let result = sqlx::query_as::<_, DatabaseUser>(
|
||||
&(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `mysql`.`user`.`User` = ?"),
|
||||
)
|
||||
.bind(&db_user)
|
||||
.fetch_optional(&mut *connection)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(Some(user)) => results.insert(db_user, Ok(user)),
|
||||
Ok(None) => results.insert(db_user, Err(ListUsersError::UserDoesNotExist)),
|
||||
Err(err) => results.insert(db_user, Err(ListUsersError::MySqlError(err.to_string()))),
|
||||
};
|
||||
}
|
||||
|
||||
results
|
||||
}
|
||||
|
||||
pub async fn list_all_database_users_for_unix_user(
|
||||
unix_user: &UnixUser,
|
||||
connection: &mut MySqlConnection,
|
||||
) -> ListAllUsersOutput {
|
||||
sqlx::query_as::<_, DatabaseUser>(
|
||||
&(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `mysql`.`user`.`User` REGEXP ?"),
|
||||
)
|
||||
.bind(create_user_group_matching_regex(unix_user))
|
||||
.fetch_all(connection)
|
||||
.await
|
||||
.map_err(|err| ListAllUsersError::MySqlError(err.to_string()))
|
||||
}
|
Reference in New Issue
Block a user