Rewrite entire codebase to split into client and server
This commit was merged in pull request #55.
This commit is contained in:
177
src/core/bootstrap.rs
Normal file
177
src/core/bootstrap.rs
Normal file
@@ -0,0 +1,177 @@
|
||||
use std::{fs, path::PathBuf};
|
||||
|
||||
use anyhow::Context;
|
||||
use nix::libc::{exit, EXIT_SUCCESS};
|
||||
use std::os::unix::net::UnixStream as StdUnixStream;
|
||||
use tokio::net::UnixStream as TokioUnixStream;
|
||||
|
||||
use crate::{
|
||||
core::{
|
||||
bootstrap::authenticated_unix_socket::client_authenticate,
|
||||
common::{UnixUser, DEFAULT_CONFIG_PATH, DEFAULT_SOCKET_PATH},
|
||||
},
|
||||
server::{config::read_config_form_path, server_loop::handle_requests_for_single_session},
|
||||
};
|
||||
|
||||
pub mod authenticated_unix_socket;
|
||||
|
||||
// TODO: this function is security critical, it should be integration tested
|
||||
// in isolation.
|
||||
/// Drop privileges to the real user and group of the process.
|
||||
/// If the process is not running with elevated privileges, this function
|
||||
/// is a no-op.
|
||||
pub fn drop_privs() -> anyhow::Result<()> {
|
||||
log::debug!("Dropping privileges");
|
||||
let real_uid = nix::unistd::getuid();
|
||||
let real_gid = nix::unistd::getgid();
|
||||
|
||||
nix::unistd::setuid(real_uid).context("Failed to drop privileges")?;
|
||||
nix::unistd::setgid(real_gid).context("Failed to drop privileges")?;
|
||||
|
||||
debug_assert_eq!(nix::unistd::getuid(), real_uid);
|
||||
debug_assert_eq!(nix::unistd::getgid(), real_gid);
|
||||
|
||||
log::debug!("Privileges dropped successfully");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// This function is used to bootstrap the connection to the server.
|
||||
/// This can happen in two ways:
|
||||
/// 1. If a socket path is provided, or exists in the default location,
|
||||
/// the function will connect to the socket and authenticate with the
|
||||
/// server to ensure that the server knows the uid of the client.
|
||||
/// 2. If a config path is provided, or exists in the default location,
|
||||
/// and the config is readable, the function will assume it is either
|
||||
/// setuid or setgid, and will fork a child process to run the server
|
||||
/// with the provided config. The server will exit silently by itself
|
||||
/// when it is done, and this function will only return for the client
|
||||
/// with the socket for the server.
|
||||
/// If neither of these options are available, the function will fail.
|
||||
pub fn bootstrap_server_connection_and_drop_privileges(
|
||||
server_socket_path: Option<PathBuf>,
|
||||
config_path: Option<PathBuf>,
|
||||
) -> anyhow::Result<StdUnixStream> {
|
||||
if server_socket_path.is_some() && config_path.is_some() {
|
||||
anyhow::bail!("Cannot provide both a socket path and a config path");
|
||||
}
|
||||
|
||||
log::debug!("Starting the server connection bootstrap process");
|
||||
|
||||
let (socket, do_authenticate) = bootstrap_server_connection(server_socket_path, config_path)?;
|
||||
|
||||
drop_privs()?;
|
||||
|
||||
let result: anyhow::Result<StdUnixStream> = if do_authenticate {
|
||||
tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.unwrap()
|
||||
.block_on(async {
|
||||
let mut socket = TokioUnixStream::from_std(socket)?;
|
||||
client_authenticate(&mut socket, None).await?;
|
||||
Ok(socket.into_std()?)
|
||||
})
|
||||
} else {
|
||||
Ok(socket)
|
||||
};
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Inner function for [`bootstrap_server_connection_and_drop_privileges`].
|
||||
/// See that function for more information.
|
||||
fn bootstrap_server_connection(
|
||||
socket_path: Option<PathBuf>,
|
||||
config_path: Option<PathBuf>,
|
||||
) -> anyhow::Result<(StdUnixStream, bool)> {
|
||||
// TODO: ensure this is both readable and writable
|
||||
if let Some(socket_path) = socket_path {
|
||||
log::debug!("Connecting to socket at {:?}", socket_path);
|
||||
return match StdUnixStream::connect(socket_path) {
|
||||
Ok(socket) => Ok((socket, true)),
|
||||
Err(e) => match e.kind() {
|
||||
std::io::ErrorKind::NotFound => Err(anyhow::anyhow!("Socket not found")),
|
||||
std::io::ErrorKind::PermissionDenied => Err(anyhow::anyhow!("Permission denied")),
|
||||
_ => Err(anyhow::anyhow!("Failed to connect to socket: {}", e)),
|
||||
},
|
||||
};
|
||||
}
|
||||
if let Some(config_path) = config_path {
|
||||
// ensure config exists and is readable
|
||||
if fs::metadata(&config_path).is_err() {
|
||||
return Err(anyhow::anyhow!("Config file not found or not readable"));
|
||||
}
|
||||
|
||||
log::debug!("Starting server with config at {:?}", config_path);
|
||||
return invoke_server_with_config(config_path).map(|socket| (socket, false));
|
||||
}
|
||||
|
||||
if fs::metadata(DEFAULT_SOCKET_PATH).is_ok() {
|
||||
return match StdUnixStream::connect(DEFAULT_SOCKET_PATH) {
|
||||
Ok(socket) => Ok((socket, true)),
|
||||
Err(e) => match e.kind() {
|
||||
std::io::ErrorKind::NotFound => Err(anyhow::anyhow!("Socket not found")),
|
||||
std::io::ErrorKind::PermissionDenied => Err(anyhow::anyhow!("Permission denied")),
|
||||
_ => Err(anyhow::anyhow!("Failed to connect to socket: {}", e)),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
let config_path = PathBuf::from(DEFAULT_CONFIG_PATH);
|
||||
if fs::metadata(&config_path).is_ok() {
|
||||
return invoke_server_with_config(config_path).map(|socket| (socket, false));
|
||||
}
|
||||
|
||||
anyhow::bail!("No socket path or config path provided, and no default socket or config found");
|
||||
}
|
||||
|
||||
// TODO: we should somehow ensure that the forked process is killed on completion,
|
||||
// just in case the client does not behave properly.
|
||||
/// Fork a child process to run the server with the provided config.
|
||||
/// The server will exit silently by itself when it is done, and this function
|
||||
/// will only return for the client with the socket for the server.
|
||||
fn invoke_server_with_config(config_path: PathBuf) -> anyhow::Result<StdUnixStream> {
|
||||
let (server_socket, client_socket) = StdUnixStream::pair()?;
|
||||
let unix_user = UnixUser::from_uid(nix::unistd::getuid().as_raw())?;
|
||||
|
||||
match (unsafe { nix::unistd::fork() }).context("Failed to fork")? {
|
||||
nix::unistd::ForkResult::Parent { child } => {
|
||||
log::debug!("Forked child process with PID {}", child);
|
||||
Ok(client_socket)
|
||||
}
|
||||
nix::unistd::ForkResult::Child => {
|
||||
log::debug!("Running server in child process");
|
||||
|
||||
match run_forked_server(config_path, server_socket, unix_user) {
|
||||
Err(e) => Err(e),
|
||||
Ok(_) => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Run the server in the forked child process.
|
||||
/// This function will not return, but will exit the process with a success code.
|
||||
fn run_forked_server(
|
||||
config_path: PathBuf,
|
||||
server_socket: StdUnixStream,
|
||||
unix_user: UnixUser,
|
||||
) -> anyhow::Result<()> {
|
||||
let config = read_config_form_path(Some(config_path))?;
|
||||
|
||||
let result: anyhow::Result<()> = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.unwrap()
|
||||
.block_on(async {
|
||||
let socket = TokioUnixStream::from_std(server_socket)?;
|
||||
handle_requests_for_single_session(socket, &unix_user, &config).await?;
|
||||
Ok(())
|
||||
});
|
||||
|
||||
result?;
|
||||
|
||||
unsafe {
|
||||
exit(EXIT_SUCCESS);
|
||||
}
|
||||
}
|
||||
439
src/core/bootstrap/authenticated_unix_socket.rs
Normal file
439
src/core/bootstrap/authenticated_unix_socket.rs
Normal file
@@ -0,0 +1,439 @@
|
||||
//! This module provides a way to authenticate a client uid to a server over a Unix socket.
|
||||
//! This is needed so that the server can trust the client's uid, which it depends on to
|
||||
//! make modifications for that user in the database. It is crucial that the server can trust
|
||||
//! that the client is the unix user it claims to be.
|
||||
//!
|
||||
//! This works by having the client respond to a challenge on a socket that is verifiably owned
|
||||
//! by the client. In more detailed steps, the following should happen:
|
||||
//!
|
||||
//! 1. Before initializing it's request, the client should open an "authentication" socket with permissions 644
|
||||
//! and owned by the uid of the current user.
|
||||
//! 2. The client opens a request to the server on the "normal" socket where the server is listening,
|
||||
//! In this request, the client should include the following:
|
||||
//! - The address of it's authentication socket
|
||||
//! - The uid of the user currently using the client
|
||||
//! - A challenge string that has been randomly generated
|
||||
//! 3. The server validates the following:
|
||||
//! - The address of the auth socket is valid
|
||||
//! - The owner of the auth socket address is the same as the uid
|
||||
//! 4. Server connects to the auth socket address and receives another challenge string.
|
||||
//! The server should close the connection after receiving the challenge string.
|
||||
//! 5. Server verifies that the challenge is the same as the one it originally received.
|
||||
//! It responds to the client with an "Authenticated" message if the challenge matches,
|
||||
//! or an error message if it does not.
|
||||
//! 6. Client closes the authentication socket. Normal socket is used for communication.
|
||||
//!
|
||||
//! Note that the server can at any point in the process send an error message to the client
|
||||
//! over it's initial connection, and the client should respond by closing the authentication
|
||||
//! socket, it's connection to the normal socket, and reporting the error to the user.
|
||||
//!
|
||||
//! Also note that it is essential that the client does not send any sensitive information
|
||||
//! over it's authentication socket, since it is readable by any user on the system.
|
||||
|
||||
// TODO: rewrite this so that it can be used with a normal std::os::unix::net::UnixStream
|
||||
|
||||
use std::os::unix::io::AsRawFd;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use async_bincode::{tokio::AsyncBincodeStream, AsyncDestination};
|
||||
use derive_more::derive::{Display, Error};
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use nix::{sys::stat, unistd::Uid};
|
||||
use rand::distributions::Alphanumeric;
|
||||
use rand::Rng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::net::{UnixListener, UnixStream};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub enum ClientRequest {
|
||||
Initialize {
|
||||
uid: u32,
|
||||
challenge: u64,
|
||||
auth_socket: String,
|
||||
},
|
||||
Cancel,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Display, Error)]
|
||||
pub enum ServerResponse {
|
||||
Authenticated,
|
||||
ChallengeDidNotMatch,
|
||||
ServerError(ServerError),
|
||||
}
|
||||
|
||||
// TODO: wrap more data into the errors
|
||||
|
||||
#[derive(Debug, Display, PartialEq, Serialize, Deserialize, Clone, Error)]
|
||||
pub enum ServerError {
|
||||
InvalidRequest,
|
||||
UnableToReadPermissionsFromAuthSocket,
|
||||
CouldNotConnectToAuthSocket,
|
||||
AuthSocketClosedEarly,
|
||||
UidMismatch,
|
||||
ChallengeMismatch,
|
||||
InvalidChallenge,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Display, Error)]
|
||||
pub enum ClientError {
|
||||
UnableToConnectToServer,
|
||||
UnableToOpenAuthSocket,
|
||||
UnableToConfigureAuthSocket,
|
||||
AuthSocketClosedEarly,
|
||||
UnableToCloseAuthSocket,
|
||||
AuthenticationError,
|
||||
UnableToParseServerResponse,
|
||||
NoServerResponse,
|
||||
ServerError(ServerError),
|
||||
}
|
||||
|
||||
async fn create_auth_socket(socket_addr: &PathBuf) -> Result<UnixListener, ClientError> {
|
||||
let auth_socket =
|
||||
UnixListener::bind(socket_addr).map_err(|_err| ClientError::UnableToOpenAuthSocket)?;
|
||||
|
||||
stat::fchmod(
|
||||
auth_socket.as_raw_fd(),
|
||||
stat::Mode::S_IRUSR | stat::Mode::S_IWUSR | stat::Mode::S_IRGRP,
|
||||
)
|
||||
.map_err(|_err| ClientError::UnableToConfigureAuthSocket)?;
|
||||
|
||||
Ok(auth_socket)
|
||||
}
|
||||
|
||||
type ClientToServerStream<'a> =
|
||||
AsyncBincodeStream<&'a mut UnixStream, ServerResponse, ClientRequest, AsyncDestination>;
|
||||
type ServerToClientStream<'a> =
|
||||
AsyncBincodeStream<&'a mut UnixStream, ClientRequest, ServerResponse, AsyncDestination>;
|
||||
|
||||
// TODO: make the challenge constant size and use socket directly, this is overkill
|
||||
type AuthStream<'a> = AsyncBincodeStream<&'a mut UnixStream, u64, u64, AsyncDestination>;
|
||||
|
||||
// TODO: add timeout
|
||||
|
||||
// TODO: respect $XDG_RUNTIME_DIR and $TMPDIR
|
||||
|
||||
const AUTH_SOCKET_NAME: &str = "mysqladm-rs-cli-auth.sock";
|
||||
|
||||
pub async fn client_authenticate(
|
||||
normal_socket: &mut UnixStream,
|
||||
auth_socket_dir: Option<PathBuf>,
|
||||
) -> Result<(), ClientError> {
|
||||
let random_prefix: String = rand::thread_rng()
|
||||
.sample_iter(&Alphanumeric)
|
||||
.take(16)
|
||||
.map(char::from)
|
||||
.collect();
|
||||
|
||||
let socket_name = format!("{}-{}", random_prefix, AUTH_SOCKET_NAME);
|
||||
|
||||
let auth_socket_address = auth_socket_dir
|
||||
.unwrap_or(std::env::temp_dir())
|
||||
.join(socket_name);
|
||||
|
||||
client_authenticate_with_auth_socket_address(normal_socket, &auth_socket_address).await
|
||||
}
|
||||
|
||||
async fn client_authenticate_with_auth_socket_address(
|
||||
normal_socket: &mut UnixStream,
|
||||
auth_socket_address: &PathBuf,
|
||||
) -> Result<(), ClientError> {
|
||||
let auth_socket = create_auth_socket(auth_socket_address).await?;
|
||||
|
||||
let result =
|
||||
client_authenticate_with_auth_socket(normal_socket, auth_socket, auth_socket_address).await;
|
||||
|
||||
std::fs::remove_file(auth_socket_address)
|
||||
.map_err(|_err| ClientError::UnableToCloseAuthSocket)?;
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
async fn client_authenticate_with_auth_socket(
|
||||
normal_socket: &mut UnixStream,
|
||||
auth_socket: UnixListener,
|
||||
auth_socket_address: &Path,
|
||||
) -> Result<(), ClientError> {
|
||||
let challenge = rand::random::<u64>();
|
||||
let uid = nix::unistd::getuid();
|
||||
|
||||
let mut normal_socket: ClientToServerStream =
|
||||
AsyncBincodeStream::from(normal_socket).for_async();
|
||||
|
||||
let challenge_replier_cancellation_token = CancellationToken::new();
|
||||
let challenge_replier_cancellation_token_clone = challenge_replier_cancellation_token.clone();
|
||||
let challenge_replier_handle = tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::select! {
|
||||
socket = auth_socket.accept() =>
|
||||
{
|
||||
match socket {
|
||||
Ok((mut conn, _addr)) => {
|
||||
let mut stream: AuthStream = AsyncBincodeStream::from(&mut conn).for_async();
|
||||
stream.send(challenge).await.ok();
|
||||
stream.close().await.ok();
|
||||
}
|
||||
Err(_err) => return Err(ClientError::AuthSocketClosedEarly),
|
||||
}
|
||||
}
|
||||
|
||||
_ = challenge_replier_cancellation_token_clone.cancelled() => {
|
||||
break Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let client_hello = ClientRequest::Initialize {
|
||||
uid: uid.into(),
|
||||
challenge,
|
||||
auth_socket: auth_socket_address
|
||||
.to_str()
|
||||
.ok_or(ClientError::UnableToConfigureAuthSocket)?
|
||||
.to_owned(),
|
||||
};
|
||||
|
||||
normal_socket
|
||||
.send(client_hello)
|
||||
.await
|
||||
.map_err(|err| match *err {
|
||||
bincode::ErrorKind::Io(_err) => ClientError::UnableToConnectToServer,
|
||||
_ => ClientError::NoServerResponse,
|
||||
})?;
|
||||
|
||||
match normal_socket.next().await {
|
||||
Some(Ok(ServerResponse::Authenticated)) => {}
|
||||
Some(Ok(ServerResponse::ChallengeDidNotMatch)) => {
|
||||
return Err(ClientError::AuthenticationError)
|
||||
}
|
||||
Some(Ok(ServerResponse::ServerError(err))) => return Err(ClientError::ServerError(err)),
|
||||
Some(Err(err)) => match *err {
|
||||
bincode::ErrorKind::Io(_err) => return Err(ClientError::NoServerResponse),
|
||||
_ => return Err(ClientError::UnableToParseServerResponse),
|
||||
},
|
||||
None => return Err(ClientError::NoServerResponse),
|
||||
}
|
||||
|
||||
challenge_replier_cancellation_token.cancel();
|
||||
challenge_replier_handle.await.ok();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
macro_rules! report_server_error_and_return {
|
||||
($normal_socket:expr, $err:expr) => {{
|
||||
$normal_socket
|
||||
.send(ServerResponse::ServerError($err))
|
||||
.await
|
||||
.ok();
|
||||
return Err($err);
|
||||
}};
|
||||
}
|
||||
|
||||
pub async fn server_authenticate(normal_socket: &mut UnixStream) -> Result<Uid, ServerError> {
|
||||
_server_authenticate(normal_socket, None).await
|
||||
}
|
||||
|
||||
pub async fn _server_authenticate(
|
||||
normal_socket: &mut UnixStream,
|
||||
unix_user_uid: Option<u32>,
|
||||
) -> Result<Uid, ServerError> {
|
||||
let mut normal_socket: ServerToClientStream =
|
||||
AsyncBincodeStream::from(normal_socket).for_async();
|
||||
|
||||
let (uid, challenge, auth_socket) = match normal_socket.next().await {
|
||||
Some(Ok(ClientRequest::Initialize {
|
||||
uid,
|
||||
challenge,
|
||||
auth_socket,
|
||||
})) => (uid, challenge, auth_socket),
|
||||
// TODO: more granular errros
|
||||
_ => report_server_error_and_return!(normal_socket, ServerError::InvalidRequest),
|
||||
};
|
||||
|
||||
let auth_socket_uid = match unix_user_uid {
|
||||
Some(uid) => uid,
|
||||
None => match stat::stat(auth_socket.as_str()) {
|
||||
Ok(stat) => stat.st_uid,
|
||||
Err(_err) => report_server_error_and_return!(
|
||||
normal_socket,
|
||||
ServerError::UnableToReadPermissionsFromAuthSocket
|
||||
),
|
||||
},
|
||||
};
|
||||
|
||||
if uid != auth_socket_uid {
|
||||
report_server_error_and_return!(normal_socket, ServerError::UidMismatch);
|
||||
}
|
||||
|
||||
let mut authenticated_unix_socket = match UnixStream::connect(auth_socket).await {
|
||||
Ok(socket) => socket,
|
||||
Err(_err) => {
|
||||
report_server_error_and_return!(normal_socket, ServerError::CouldNotConnectToAuthSocket)
|
||||
}
|
||||
};
|
||||
let mut authenticated_unix_socket: AuthStream =
|
||||
AsyncBincodeStream::from(&mut authenticated_unix_socket).for_async();
|
||||
|
||||
let challenge_2 = match authenticated_unix_socket.next().await {
|
||||
Some(Ok(challenge)) => challenge,
|
||||
Some(Err(_)) => {
|
||||
report_server_error_and_return!(normal_socket, ServerError::InvalidChallenge)
|
||||
}
|
||||
None => report_server_error_and_return!(normal_socket, ServerError::AuthSocketClosedEarly),
|
||||
};
|
||||
|
||||
authenticated_unix_socket.close().await.ok();
|
||||
|
||||
if challenge != challenge_2 {
|
||||
normal_socket
|
||||
.send(ServerResponse::ChallengeDidNotMatch)
|
||||
.await
|
||||
.ok();
|
||||
return Err(ServerError::ChallengeMismatch);
|
||||
}
|
||||
|
||||
normal_socket.send(ServerResponse::Authenticated).await.ok();
|
||||
|
||||
Ok(Uid::from_raw(uid))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
use std::time::Duration;
|
||||
use tokio::time::sleep;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_valid_authentication() {
|
||||
let (mut client, mut server) = UnixStream::pair().unwrap();
|
||||
|
||||
let client_handle =
|
||||
tokio::spawn(async move { client_authenticate(&mut client, None).await });
|
||||
|
||||
let server_handle = tokio::spawn(async move { server_authenticate(&mut server).await });
|
||||
|
||||
client_handle.await.unwrap().unwrap();
|
||||
server_handle.await.unwrap().unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_ensure_auth_socket_does_not_exist() {
|
||||
let (mut client, mut server) = UnixStream::pair().unwrap();
|
||||
|
||||
let client_handle = tokio::spawn(async move {
|
||||
client_authenticate_with_auth_socket_address(
|
||||
&mut client,
|
||||
&PathBuf::from("/tmp/test_auth_socket_does_not_exist.sock"),
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
let server_handle = tokio::spawn(async move { server_authenticate(&mut server).await });
|
||||
|
||||
client_handle.await.unwrap().unwrap();
|
||||
server_handle.await.unwrap().unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_uid_mismatch() {
|
||||
let (mut client, mut server) = UnixStream::pair().unwrap();
|
||||
|
||||
let client_handle = tokio::spawn(async move {
|
||||
let err = client_authenticate(&mut client, None).await;
|
||||
assert_eq!(err, Err(ClientError::ServerError(ServerError::UidMismatch)));
|
||||
});
|
||||
|
||||
let server_handle = tokio::spawn(async move {
|
||||
let uid: u32 = nix::unistd::getuid().into();
|
||||
let err = _server_authenticate(&mut server, Some(uid + 1)).await;
|
||||
assert_eq!(err, Err(ServerError::UidMismatch));
|
||||
});
|
||||
|
||||
client_handle.await.unwrap();
|
||||
server_handle.await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_snooping_connection() {
|
||||
let (mut client, mut server) = UnixStream::pair().unwrap();
|
||||
|
||||
let socket_path = std::env::temp_dir().join("socket_to_snoop.sock");
|
||||
let socket_path_clone = socket_path.clone();
|
||||
let client_handle = tokio::spawn(async move {
|
||||
client_authenticate_with_auth_socket_address(&mut client, &socket_path_clone).await
|
||||
});
|
||||
|
||||
for i in 0..100 {
|
||||
if socket_path.exists() {
|
||||
break;
|
||||
}
|
||||
sleep(Duration::from_millis(10)).await;
|
||||
|
||||
if i == 99 {
|
||||
panic!("Socket not created after 1 second, assuming test failure");
|
||||
}
|
||||
}
|
||||
|
||||
let mut snooper = UnixStream::connect(socket_path.clone()).await.unwrap();
|
||||
let mut snooper: AuthStream = AsyncBincodeStream::from(&mut snooper).for_async();
|
||||
let message = snooper.next().await.unwrap().unwrap();
|
||||
|
||||
let mut other_snooper = UnixStream::connect(socket_path.clone()).await.unwrap();
|
||||
let mut other_snooper: AuthStream =
|
||||
AsyncBincodeStream::from(&mut other_snooper).for_async();
|
||||
let other_message = other_snooper.next().await.unwrap().unwrap();
|
||||
|
||||
assert_eq!(message, other_message);
|
||||
|
||||
let third_snooper_handle = tokio::spawn(async move {
|
||||
let mut third_snooper = UnixStream::connect(socket_path.clone()).await.unwrap();
|
||||
let mut third_snooper: AuthStream =
|
||||
AsyncBincodeStream::from(&mut third_snooper).for_async();
|
||||
// NOTE: Should hang
|
||||
third_snooper.send(1234).await.unwrap()
|
||||
});
|
||||
|
||||
sleep(Duration::from_millis(10)).await;
|
||||
|
||||
let server_handle = tokio::spawn(async move { server_authenticate(&mut server).await });
|
||||
|
||||
client_handle.await.unwrap().unwrap();
|
||||
server_handle.await.unwrap().unwrap();
|
||||
|
||||
third_snooper_handle.abort();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_dead_server() {
|
||||
let (mut client, server) = UnixStream::pair().unwrap();
|
||||
std::mem::drop(server);
|
||||
|
||||
let client_handle = tokio::spawn(async move {
|
||||
let err = client_authenticate(&mut client, None).await;
|
||||
assert_eq!(err, Err(ClientError::UnableToConnectToServer));
|
||||
});
|
||||
|
||||
client_handle.await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_no_response_from_server() {
|
||||
let (mut client, server) = UnixStream::pair().unwrap();
|
||||
|
||||
let client_handle = tokio::spawn(async move {
|
||||
let err = client_authenticate(&mut client, None).await;
|
||||
assert_eq!(err, Err(ClientError::NoServerResponse));
|
||||
});
|
||||
|
||||
sleep(Duration::from_millis(200)).await;
|
||||
|
||||
std::mem::drop(server);
|
||||
|
||||
client_handle.await.unwrap();
|
||||
}
|
||||
|
||||
// TODO: Test challenge mismatch
|
||||
// TODO: Test invoking server with junk data
|
||||
}
|
||||
@@ -1,56 +1,32 @@
|
||||
use anyhow::Context;
|
||||
use indoc::indoc;
|
||||
use itertools::Itertools;
|
||||
use nix::unistd::{getuid, Group, User};
|
||||
use sqlx::{Connection, MySqlConnection};
|
||||
use nix::unistd::{Group as LibcGroup, User as LibcUser};
|
||||
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
use std::ffi::CString;
|
||||
|
||||
/// Report the result status of a command.
|
||||
/// This is used to display a status message to the user.
|
||||
pub enum CommandStatus {
|
||||
/// The command was successful,
|
||||
/// and made modification to the database.
|
||||
SuccessfullyModified,
|
||||
pub const DEFAULT_CONFIG_PATH: &str = "/etc/mysqladm/config.toml";
|
||||
pub const DEFAULT_SOCKET_PATH: &str = "/run/mysqladm/mysqladm.sock";
|
||||
|
||||
/// The command was mostly successful,
|
||||
/// and modifications have been made to the database.
|
||||
/// However, some of the requested modifications failed.
|
||||
PartiallySuccessfullyModified,
|
||||
|
||||
/// The command was successful,
|
||||
/// but no modifications were needed.
|
||||
NoModificationsNeeded,
|
||||
|
||||
/// The command was successful,
|
||||
/// and made no modification to the database.
|
||||
NoModificationsIntended,
|
||||
|
||||
/// The command was cancelled, either through a dialog or a signal.
|
||||
/// No modifications have been made to the database.
|
||||
Cancelled,
|
||||
pub struct UnixUser {
|
||||
pub username: String,
|
||||
pub groups: Vec<String>,
|
||||
}
|
||||
|
||||
pub fn get_current_unix_user() -> anyhow::Result<User> {
|
||||
User::from_uid(getuid())
|
||||
.context("Failed to look up your UNIX username")
|
||||
.and_then(|u| u.ok_or(anyhow::anyhow!("Failed to look up your UNIX username")))
|
||||
}
|
||||
// TODO: these functions are somewhat critical, and should have integration tests
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
pub fn get_unix_groups(_user: &User) -> anyhow::Result<Vec<Group>> {
|
||||
fn get_unix_groups(_user: &LibcUser) -> anyhow::Result<Vec<LibcGroup>> {
|
||||
// Return an empty list on macOS since there is no `getgrouplist` function
|
||||
Ok(vec![])
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
pub fn get_unix_groups(user: &User) -> anyhow::Result<Vec<Group>> {
|
||||
fn get_unix_groups(user: &LibcUser) -> anyhow::Result<Vec<LibcGroup>> {
|
||||
let user_cstr =
|
||||
CString::new(user.name.as_bytes()).context("Failed to convert username to CStr")?;
|
||||
let groups = nix::unistd::getgrouplist(&user_cstr, user.gid)?
|
||||
.iter()
|
||||
.filter_map(|gid| match Group::from_gid(*gid) {
|
||||
.filter_map(|gid| match LibcGroup::from_gid(*gid) {
|
||||
Ok(Some(group)) => Some(group),
|
||||
Ok(None) => None,
|
||||
Err(e) => {
|
||||
@@ -62,211 +38,32 @@ pub fn get_unix_groups(user: &User) -> anyhow::Result<Vec<Group>> {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<Vec<Group>>();
|
||||
.collect::<Vec<LibcGroup>>();
|
||||
|
||||
Ok(groups)
|
||||
}
|
||||
|
||||
/// This function creates a regex that matches items (users, databases)
|
||||
/// that belong to the user or any of the user's groups.
|
||||
pub fn create_user_group_matching_regex(user: &User) -> String {
|
||||
let groups = get_unix_groups(user).unwrap_or_default();
|
||||
impl UnixUser {
|
||||
pub fn from_uid(uid: u32) -> anyhow::Result<Self> {
|
||||
let libc_uid = nix::unistd::Uid::from_raw(uid);
|
||||
let libc_user = LibcUser::from_uid(libc_uid)
|
||||
.context("Failed to look up your UNIX username")?
|
||||
.ok_or(anyhow::anyhow!("Failed to look up your UNIX username"))?;
|
||||
|
||||
if groups.is_empty() {
|
||||
format!("{}(_.+)?", user.name)
|
||||
} else {
|
||||
format!(
|
||||
"({}|{})(_.+)?",
|
||||
user.name,
|
||||
groups
|
||||
.iter()
|
||||
.map(|g| g.name.as_str())
|
||||
.collect::<Vec<_>>()
|
||||
.join("|")
|
||||
)
|
||||
}
|
||||
}
|
||||
let groups = get_unix_groups(&libc_user)?;
|
||||
|
||||
/// This enum is used to differentiate between database and user operations.
|
||||
/// Their output are very similar, but there are slight differences in the words used.
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
pub enum DbOrUser {
|
||||
Database,
|
||||
User,
|
||||
}
|
||||
|
||||
impl DbOrUser {
|
||||
pub fn lowercased(&self) -> String {
|
||||
match self {
|
||||
DbOrUser::Database => "database".to_string(),
|
||||
DbOrUser::User => "user".to_string(),
|
||||
}
|
||||
Ok(UnixUser {
|
||||
username: libc_user.name,
|
||||
groups: groups.iter().map(|g| g.name.clone()).collect(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn capitalized(&self) -> String {
|
||||
match self {
|
||||
DbOrUser::Database => "Database".to_string(),
|
||||
DbOrUser::User => "User".to_string(),
|
||||
}
|
||||
pub fn from_enviroment() -> anyhow::Result<Self> {
|
||||
let libc_uid = nix::unistd::getuid();
|
||||
UnixUser::from_uid(libc_uid.as_raw())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub enum NameValidationResult {
|
||||
Valid,
|
||||
EmptyString,
|
||||
InvalidCharacters,
|
||||
TooLong,
|
||||
}
|
||||
|
||||
pub fn validate_name(name: &str) -> NameValidationResult {
|
||||
if name.is_empty() {
|
||||
NameValidationResult::EmptyString
|
||||
} else if name.len() > 64 {
|
||||
NameValidationResult::TooLong
|
||||
} else if !name
|
||||
.chars()
|
||||
.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
|
||||
{
|
||||
NameValidationResult::InvalidCharacters
|
||||
} else {
|
||||
NameValidationResult::Valid
|
||||
}
|
||||
}
|
||||
|
||||
pub fn validate_name_or_error(name: &str, db_or_user: DbOrUser) -> anyhow::Result<()> {
|
||||
match validate_name(name) {
|
||||
NameValidationResult::Valid => Ok(()),
|
||||
NameValidationResult::EmptyString => {
|
||||
anyhow::bail!("{} name cannot be empty.", db_or_user.capitalized())
|
||||
}
|
||||
NameValidationResult::TooLong => anyhow::bail!(
|
||||
"{} is too long. Maximum length is 64 characters.",
|
||||
db_or_user.capitalized()
|
||||
),
|
||||
NameValidationResult::InvalidCharacters => anyhow::bail!(
|
||||
indoc! {r#"
|
||||
Invalid characters in {} name: '{}'
|
||||
|
||||
Only A-Z, a-z, 0-9, _ (underscore) and - (dash) are permitted.
|
||||
"#},
|
||||
db_or_user.lowercased(),
|
||||
name
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub enum OwnerValidationResult {
|
||||
// The name is valid and matches one of the given prefixes
|
||||
Match,
|
||||
|
||||
// The name is valid, but none of the given prefixes matched the name
|
||||
NoMatch,
|
||||
|
||||
// The name is empty, which is invalid
|
||||
StringEmpty,
|
||||
|
||||
// The name is in the format "_<postfix>", which is invalid
|
||||
MissingPrefix,
|
||||
|
||||
// The name is in the format "<prefix>_", which is invalid
|
||||
MissingPostfix,
|
||||
}
|
||||
|
||||
/// Core logic for validating the ownership of a database name.
|
||||
/// This function checks if the given name matches any of the given prefixes.
|
||||
/// These prefixes will in most cases be the user's unix username and any
|
||||
/// unix groups the user is a member of.
|
||||
pub fn validate_ownership_by_prefixes(name: &str, prefixes: &[String]) -> OwnerValidationResult {
|
||||
if name.is_empty() {
|
||||
return OwnerValidationResult::StringEmpty;
|
||||
}
|
||||
|
||||
if name.starts_with('_') {
|
||||
return OwnerValidationResult::MissingPrefix;
|
||||
}
|
||||
|
||||
let (prefix, _) = match name.split_once('_') {
|
||||
Some(pair) => pair,
|
||||
None => return OwnerValidationResult::MissingPostfix,
|
||||
};
|
||||
|
||||
if prefixes.iter().any(|g| g == prefix) {
|
||||
OwnerValidationResult::Match
|
||||
} else {
|
||||
OwnerValidationResult::NoMatch
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate the ownership of a database name or database user name.
|
||||
/// This function takes the name of a database or user and a unix user,
|
||||
/// for which it fetches the user's groups. It then checks if the name
|
||||
/// is prefixed with the user's username or any of the user's groups.
|
||||
pub fn validate_ownership_or_error<'a>(
|
||||
name: &'a str,
|
||||
user: &User,
|
||||
db_or_user: DbOrUser,
|
||||
) -> anyhow::Result<&'a str> {
|
||||
let user_groups = get_unix_groups(user)?;
|
||||
let prefixes = std::iter::once(user.name.clone())
|
||||
.chain(user_groups.iter().map(|g| g.name.clone()))
|
||||
.collect::<Vec<String>>();
|
||||
|
||||
match validate_ownership_by_prefixes(name, &prefixes) {
|
||||
OwnerValidationResult::Match => Ok(name),
|
||||
OwnerValidationResult::NoMatch => {
|
||||
anyhow::bail!(
|
||||
indoc! {r#"
|
||||
Invalid {} name prefix: '{}' does not match your username or any of your groups.
|
||||
Are you sure you are allowed to create {} names with this prefix?
|
||||
|
||||
Allowed prefixes:
|
||||
- {}
|
||||
{}
|
||||
"#},
|
||||
db_or_user.lowercased(),
|
||||
name,
|
||||
db_or_user.lowercased(),
|
||||
user.name,
|
||||
user_groups
|
||||
.iter()
|
||||
.filter(|g| g.name != user.name)
|
||||
.map(|g| format!(" - {}", g.name))
|
||||
.sorted()
|
||||
.join("\n"),
|
||||
);
|
||||
}
|
||||
_ => anyhow::bail!(
|
||||
"'{}' is not a valid {} name.",
|
||||
name,
|
||||
db_or_user.lowercased()
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
/// Gracefully close a MySQL connection.
|
||||
pub async fn close_database_connection(connection: MySqlConnection) {
|
||||
if let Err(e) = connection
|
||||
.close()
|
||||
.await
|
||||
.context("Failed to close connection properly")
|
||||
{
|
||||
eprintln!("{}", e);
|
||||
eprintln!("Ignoring...");
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn quote_literal(s: &str) -> String {
|
||||
format!("'{}'", s.replace('\'', r"\'"))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn quote_identifier(s: &str) -> String {
|
||||
format!("`{}`", s.replace('`', r"\`"))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub(crate) fn yn(b: bool) -> &'static str {
|
||||
if b {
|
||||
@@ -303,94 +100,4 @@ mod test {
|
||||
assert_eq!(rev_yn("n"), Some(false));
|
||||
assert_eq!(rev_yn("X"), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_quote_literal() {
|
||||
let payload = "' OR 1=1 --";
|
||||
assert_eq!(quote_literal(payload), r#"'\' OR 1=1 --'"#);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_quote_identifier() {
|
||||
let payload = "` OR 1=1 --";
|
||||
assert_eq!(quote_identifier(payload), r#"`\` OR 1=1 --`"#);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_name() {
|
||||
assert_eq!(validate_name(""), NameValidationResult::EmptyString);
|
||||
assert_eq!(
|
||||
validate_name("abcdefghijklmnopqrstuvwxyz"),
|
||||
NameValidationResult::Valid
|
||||
);
|
||||
assert_eq!(
|
||||
validate_name("ABCDEFGHIJKLMNOPQRSTUVWXYZ"),
|
||||
NameValidationResult::Valid
|
||||
);
|
||||
assert_eq!(validate_name("0123456789_-"), NameValidationResult::Valid);
|
||||
|
||||
for c in "\n\t\r !@#$%^&*()+=[]{}|;:,.<>?/".chars() {
|
||||
assert_eq!(
|
||||
validate_name(&c.to_string()),
|
||||
NameValidationResult::InvalidCharacters
|
||||
);
|
||||
}
|
||||
|
||||
assert_eq!(validate_name(&"a".repeat(64)), NameValidationResult::Valid);
|
||||
|
||||
assert_eq!(
|
||||
validate_name(&"a".repeat(65)),
|
||||
NameValidationResult::TooLong
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_owner_by_prefixes() {
|
||||
let prefixes = vec!["user".to_string(), "group".to_string()];
|
||||
|
||||
assert_eq!(
|
||||
validate_ownership_by_prefixes("", &prefixes),
|
||||
OwnerValidationResult::StringEmpty
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
validate_ownership_by_prefixes("user", &prefixes),
|
||||
OwnerValidationResult::MissingPostfix
|
||||
);
|
||||
assert_eq!(
|
||||
validate_ownership_by_prefixes("something", &prefixes),
|
||||
OwnerValidationResult::MissingPostfix
|
||||
);
|
||||
assert_eq!(
|
||||
validate_ownership_by_prefixes("user-testdb", &prefixes),
|
||||
OwnerValidationResult::MissingPostfix
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
validate_ownership_by_prefixes("_testdb", &prefixes),
|
||||
OwnerValidationResult::MissingPrefix
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
validate_ownership_by_prefixes("user_testdb", &prefixes),
|
||||
OwnerValidationResult::Match
|
||||
);
|
||||
assert_eq!(
|
||||
validate_ownership_by_prefixes("group_testdb", &prefixes),
|
||||
OwnerValidationResult::Match
|
||||
);
|
||||
assert_eq!(
|
||||
validate_ownership_by_prefixes("group_test_db", &prefixes),
|
||||
OwnerValidationResult::Match
|
||||
);
|
||||
assert_eq!(
|
||||
validate_ownership_by_prefixes("group_test-db", &prefixes),
|
||||
OwnerValidationResult::Match
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
validate_ownership_by_prefixes("nonexistent_testdb", &prefixes),
|
||||
OwnerValidationResult::NoMatch
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,121 +0,0 @@
|
||||
use std::{fs, path::PathBuf, time::Duration};
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
use clap::Parser;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::{mysql::MySqlConnectOptions, ConnectOptions, MySqlConnection};
|
||||
|
||||
// NOTE: this might look empty now, and the extra wrapping for the mysql
|
||||
// config seems unnecessary, but it will be useful later when we
|
||||
// add more configuration options.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Config {
|
||||
pub mysql: MysqlConfig,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename = "mysql")]
|
||||
pub struct MysqlConfig {
|
||||
pub host: String,
|
||||
pub port: Option<u16>,
|
||||
pub username: String,
|
||||
pub password: String,
|
||||
pub timeout: Option<u64>,
|
||||
}
|
||||
|
||||
const DEFAULT_PORT: u16 = 3306;
|
||||
const DEFAULT_TIMEOUT: u64 = 2;
|
||||
|
||||
#[derive(Parser)]
|
||||
pub struct GlobalConfigArgs {
|
||||
/// Path to the configuration file.
|
||||
#[arg(
|
||||
short,
|
||||
long,
|
||||
value_name = "PATH",
|
||||
global = true,
|
||||
hide_short_help = true,
|
||||
default_value = "/etc/mysqladm/config.toml"
|
||||
)]
|
||||
config_file: String,
|
||||
|
||||
/// Hostname of the MySQL server.
|
||||
#[arg(long, value_name = "HOST", global = true, hide_short_help = true)]
|
||||
mysql_host: Option<String>,
|
||||
|
||||
/// Port of the MySQL server.
|
||||
#[arg(long, value_name = "PORT", global = true, hide_short_help = true)]
|
||||
mysql_port: Option<u16>,
|
||||
|
||||
/// Username to use for the MySQL connection.
|
||||
#[arg(long, value_name = "USER", global = true, hide_short_help = true)]
|
||||
mysql_user: Option<String>,
|
||||
|
||||
/// Path to a file containing the MySQL password.
|
||||
#[arg(long, value_name = "PATH", global = true, hide_short_help = true)]
|
||||
mysql_password_file: Option<String>,
|
||||
|
||||
/// Seconds to wait for the MySQL connection to be established.
|
||||
#[arg(long, value_name = "SECONDS", global = true, hide_short_help = true)]
|
||||
mysql_connect_timeout: Option<u64>,
|
||||
}
|
||||
|
||||
/// Use the arguments and whichever configuration file which might or might not
|
||||
/// be found and default values to determine the configuration for the program.
|
||||
pub fn get_config(args: GlobalConfigArgs) -> anyhow::Result<Config> {
|
||||
let config_path = PathBuf::from(args.config_file);
|
||||
|
||||
let config: Config = fs::read_to_string(&config_path)
|
||||
.context(format!(
|
||||
"Failed to read config file from {:?}",
|
||||
&config_path
|
||||
))
|
||||
.and_then(|c| toml::from_str(&c).context("Failed to parse config file"))
|
||||
.context(format!(
|
||||
"Failed to parse config file from {:?}",
|
||||
&config_path
|
||||
))?;
|
||||
|
||||
let mysql = &config.mysql;
|
||||
|
||||
let password = if let Some(path) = args.mysql_password_file {
|
||||
fs::read_to_string(path)
|
||||
.context("Failed to read MySQL password file")
|
||||
.map(|s| s.trim().to_owned())?
|
||||
} else {
|
||||
mysql.password.to_owned()
|
||||
};
|
||||
|
||||
let mysql_config = MysqlConfig {
|
||||
host: args.mysql_host.unwrap_or(mysql.host.to_owned()),
|
||||
port: args.mysql_port.or(mysql.port),
|
||||
username: args.mysql_user.unwrap_or(mysql.username.to_owned()),
|
||||
password,
|
||||
timeout: args.mysql_connect_timeout.or(mysql.timeout),
|
||||
};
|
||||
|
||||
Ok(Config {
|
||||
mysql: mysql_config,
|
||||
})
|
||||
}
|
||||
|
||||
/// Use the provided configuration to establish a connection to a MySQL server.
|
||||
pub async fn create_mysql_connection_from_config(
|
||||
config: MysqlConfig,
|
||||
) -> anyhow::Result<MySqlConnection> {
|
||||
match tokio::time::timeout(
|
||||
Duration::from_secs(config.timeout.unwrap_or(DEFAULT_TIMEOUT)),
|
||||
MySqlConnectOptions::new()
|
||||
.host(&config.host)
|
||||
.username(&config.username)
|
||||
.password(&config.password)
|
||||
.port(config.port.unwrap_or(DEFAULT_PORT))
|
||||
.database("mysql")
|
||||
.connect(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(connection) => connection.context("Failed to connect to MySQL"),
|
||||
Err(_) => Err(anyhow!("Timed out after 2 seconds")).context("Failed to connect to MySQL"),
|
||||
}
|
||||
}
|
||||
@@ -1,120 +0,0 @@
|
||||
use anyhow::Context;
|
||||
use indoc::formatdoc;
|
||||
use itertools::Itertools;
|
||||
use nix::unistd::User;
|
||||
use sqlx::{prelude::*, MySqlConnection};
|
||||
|
||||
use crate::core::{
|
||||
common::{
|
||||
create_user_group_matching_regex, get_current_unix_user, quote_identifier,
|
||||
validate_name_or_error, validate_ownership_or_error, DbOrUser,
|
||||
},
|
||||
database_privilege_operations::DATABASE_PRIVILEGE_FIELDS,
|
||||
};
|
||||
|
||||
pub async fn create_database(name: &str, connection: &mut MySqlConnection) -> anyhow::Result<()> {
|
||||
let user = get_current_unix_user()?;
|
||||
validate_database_name(name, &user)?;
|
||||
|
||||
// NOTE: see the note about SQL injections in `validate_owner_of_database_name`
|
||||
sqlx::query(&format!("CREATE DATABASE {}", quote_identifier(name)))
|
||||
.execute(connection)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
if e.to_string().contains("database exists") {
|
||||
anyhow::anyhow!("Database '{}' already exists", name)
|
||||
} else {
|
||||
e.into()
|
||||
}
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn drop_database(name: &str, connection: &mut MySqlConnection) -> anyhow::Result<()> {
|
||||
let user = get_current_unix_user()?;
|
||||
validate_database_name(name, &user)?;
|
||||
|
||||
// NOTE: see the note about SQL injections in `validate_owner_of_database_name`
|
||||
sqlx::query(&format!("DROP DATABASE {}", quote_identifier(name)))
|
||||
.execute(connection)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
if e.to_string().contains("doesn't exist") {
|
||||
anyhow::anyhow!("Database '{}' does not exist", name)
|
||||
} else {
|
||||
e.into()
|
||||
}
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_database_list(connection: &mut MySqlConnection) -> anyhow::Result<Vec<String>> {
|
||||
let unix_user = get_current_unix_user()?;
|
||||
|
||||
let databases: Vec<String> = sqlx::query(
|
||||
r#"
|
||||
SELECT `SCHEMA_NAME` AS `database`
|
||||
FROM `information_schema`.`SCHEMATA`
|
||||
WHERE `SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys')
|
||||
AND `SCHEMA_NAME` REGEXP ?
|
||||
"#,
|
||||
)
|
||||
.bind(create_user_group_matching_regex(&unix_user))
|
||||
.fetch_all(connection)
|
||||
.await
|
||||
.and_then(|row| {
|
||||
row.into_iter()
|
||||
.map(|row| row.try_get::<String, _>("database"))
|
||||
.collect::<Result<_, _>>()
|
||||
})
|
||||
.context(format!(
|
||||
"Failed to get databases for user '{}'",
|
||||
unix_user.name
|
||||
))?;
|
||||
|
||||
Ok(databases)
|
||||
}
|
||||
|
||||
pub async fn get_databases_where_user_has_privileges(
|
||||
username: &str,
|
||||
connection: &mut MySqlConnection,
|
||||
) -> anyhow::Result<Vec<String>> {
|
||||
let result = sqlx::query(
|
||||
formatdoc!(
|
||||
r#"
|
||||
SELECT `db` AS `database`
|
||||
FROM `db`
|
||||
WHERE `user` = ?
|
||||
AND ({})
|
||||
"#,
|
||||
DATABASE_PRIVILEGE_FIELDS
|
||||
.iter()
|
||||
.map(|field| format!("`{}` = 'Y'", field))
|
||||
.join(" OR "),
|
||||
)
|
||||
.as_str(),
|
||||
)
|
||||
.bind(username)
|
||||
.fetch_all(connection)
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(|databases| databases.try_get::<String, _>("database").unwrap())
|
||||
.collect();
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// NOTE: It is very critical that this function validates the database name
|
||||
/// properly. MySQL does not seem to allow for prepared statements, binding
|
||||
/// the database name as a parameter to the query. This means that we have
|
||||
/// to validate the database name ourselves to prevent SQL injection.
|
||||
pub fn validate_database_name(name: &str, user: &User) -> anyhow::Result<()> {
|
||||
validate_name_or_error(name, DbOrUser::Database)
|
||||
.context(format!("Invalid database name: '{}'", name))?;
|
||||
validate_ownership_or_error(name, user, DbOrUser::Database)
|
||||
.context(format!("Invalid database name: '{}'", name))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,52 +1,16 @@
|
||||
//! Database privilege operations
|
||||
//!
|
||||
//! This module contains functions for querying, modifying,
|
||||
//! displaying and comparing database privileges.
|
||||
//!
|
||||
//! A lot of the complexity comes from two core components:
|
||||
//!
|
||||
//! - The privilege editor that needs to be able to print
|
||||
//! an editable table of privileges and reparse the content
|
||||
//! after the user has made manual changes.
|
||||
//!
|
||||
//! - The comparison functionality that tells the user what
|
||||
//! changes will be made when applying a set of changes
|
||||
//! to the list of database privileges.
|
||||
|
||||
use std::collections::{BTreeSet, HashMap};
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
use indoc::indoc;
|
||||
use itertools::Itertools;
|
||||
use prettytable::Table;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::{mysql::MySqlRow, prelude::*, MySqlConnection};
|
||||
|
||||
use crate::core::{
|
||||
common::{
|
||||
create_user_group_matching_regex, get_current_unix_user, quote_identifier, rev_yn, yn,
|
||||
},
|
||||
database_operations::validate_database_name,
|
||||
use std::{
|
||||
cmp::max,
|
||||
collections::{BTreeSet, HashMap},
|
||||
};
|
||||
|
||||
/// This is the list of fields that are used to fetch the db + user + privileges
|
||||
/// from the `db` table in the database. If you need to add or remove privilege
|
||||
/// fields, this is a good place to start.
|
||||
pub const DATABASE_PRIVILEGE_FIELDS: [&str; 13] = [
|
||||
"db",
|
||||
"user",
|
||||
"select_priv",
|
||||
"insert_priv",
|
||||
"update_priv",
|
||||
"delete_priv",
|
||||
"create_priv",
|
||||
"drop_priv",
|
||||
"alter_priv",
|
||||
"index_priv",
|
||||
"create_tmp_table_priv",
|
||||
"lock_tables_priv",
|
||||
"references_priv",
|
||||
];
|
||||
use super::common::{rev_yn, yn};
|
||||
use crate::server::sql::database_privilege_operations::{
|
||||
DatabasePrivilegeRow, DATABASE_PRIVILEGE_FIELDS,
|
||||
};
|
||||
|
||||
pub fn db_priv_field_human_readable_name(name: &str) -> String {
|
||||
match name {
|
||||
@@ -67,162 +31,24 @@ pub fn db_priv_field_human_readable_name(name: &str) -> String {
|
||||
}
|
||||
}
|
||||
|
||||
/// This struct represents the set of privileges for a single user on a single database.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
|
||||
pub struct DatabasePrivilegeRow {
|
||||
pub db: String,
|
||||
pub user: String,
|
||||
pub select_priv: bool,
|
||||
pub insert_priv: bool,
|
||||
pub update_priv: bool,
|
||||
pub delete_priv: bool,
|
||||
pub create_priv: bool,
|
||||
pub drop_priv: bool,
|
||||
pub alter_priv: bool,
|
||||
pub index_priv: bool,
|
||||
pub create_tmp_table_priv: bool,
|
||||
pub lock_tables_priv: bool,
|
||||
pub references_priv: bool,
|
||||
}
|
||||
pub fn diff(row1: &DatabasePrivilegeRow, row2: &DatabasePrivilegeRow) -> DatabasePrivilegeRowDiff {
|
||||
debug_assert!(row1.db == row2.db && row1.user == row2.user);
|
||||
|
||||
impl DatabasePrivilegeRow {
|
||||
pub fn empty(db: &str, user: &str) -> Self {
|
||||
Self {
|
||||
db: db.to_owned(),
|
||||
user: user.to_owned(),
|
||||
select_priv: false,
|
||||
insert_priv: false,
|
||||
update_priv: false,
|
||||
delete_priv: false,
|
||||
create_priv: false,
|
||||
drop_priv: false,
|
||||
alter_priv: false,
|
||||
index_priv: false,
|
||||
create_tmp_table_priv: false,
|
||||
lock_tables_priv: false,
|
||||
references_priv: false,
|
||||
}
|
||||
DatabasePrivilegeRowDiff {
|
||||
db: row1.db.clone(),
|
||||
user: row1.user.clone(),
|
||||
diff: DATABASE_PRIVILEGE_FIELDS
|
||||
.into_iter()
|
||||
.skip(2)
|
||||
.filter_map(|field| {
|
||||
DatabasePrivilegeChange::new(
|
||||
row1.get_privilege_by_name(field),
|
||||
row2.get_privilege_by_name(field),
|
||||
field,
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
}
|
||||
|
||||
pub fn get_privilege_by_name(&self, name: &str) -> bool {
|
||||
match name {
|
||||
"select_priv" => self.select_priv,
|
||||
"insert_priv" => self.insert_priv,
|
||||
"update_priv" => self.update_priv,
|
||||
"delete_priv" => self.delete_priv,
|
||||
"create_priv" => self.create_priv,
|
||||
"drop_priv" => self.drop_priv,
|
||||
"alter_priv" => self.alter_priv,
|
||||
"index_priv" => self.index_priv,
|
||||
"create_tmp_table_priv" => self.create_tmp_table_priv,
|
||||
"lock_tables_priv" => self.lock_tables_priv,
|
||||
"references_priv" => self.references_priv,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn diff(&self, other: &DatabasePrivilegeRow) -> DatabasePrivilegeRowDiff {
|
||||
debug_assert!(self.db == other.db && self.user == other.user);
|
||||
|
||||
DatabasePrivilegeRowDiff {
|
||||
db: self.db.clone(),
|
||||
user: self.user.clone(),
|
||||
diff: DATABASE_PRIVILEGE_FIELDS
|
||||
.into_iter()
|
||||
.skip(2)
|
||||
.filter_map(|field| {
|
||||
DatabasePrivilegeChange::new(
|
||||
self.get_privilege_by_name(field),
|
||||
other.get_privilege_by_name(field),
|
||||
field,
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn get_mysql_row_priv_field(row: &MySqlRow, position: usize) -> Result<bool, sqlx::Error> {
|
||||
let field = DATABASE_PRIVILEGE_FIELDS[position];
|
||||
let value = row.try_get(position)?;
|
||||
match rev_yn(value) {
|
||||
Some(val) => Ok(val),
|
||||
_ => {
|
||||
log::warn!(r#"Invalid value for privilege "{}": '{}'"#, field, value);
|
||||
Ok(false)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FromRow<'_, MySqlRow> for DatabasePrivilegeRow {
|
||||
fn from_row(row: &MySqlRow) -> Result<Self, sqlx::Error> {
|
||||
Ok(Self {
|
||||
db: row.try_get("db")?,
|
||||
user: row.try_get("user")?,
|
||||
select_priv: get_mysql_row_priv_field(row, 2)?,
|
||||
insert_priv: get_mysql_row_priv_field(row, 3)?,
|
||||
update_priv: get_mysql_row_priv_field(row, 4)?,
|
||||
delete_priv: get_mysql_row_priv_field(row, 5)?,
|
||||
create_priv: get_mysql_row_priv_field(row, 6)?,
|
||||
drop_priv: get_mysql_row_priv_field(row, 7)?,
|
||||
alter_priv: get_mysql_row_priv_field(row, 8)?,
|
||||
index_priv: get_mysql_row_priv_field(row, 9)?,
|
||||
create_tmp_table_priv: get_mysql_row_priv_field(row, 10)?,
|
||||
lock_tables_priv: get_mysql_row_priv_field(row, 11)?,
|
||||
references_priv: get_mysql_row_priv_field(row, 12)?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Get all users + privileges for a single database.
|
||||
pub async fn get_database_privileges(
|
||||
database_name: &str,
|
||||
connection: &mut MySqlConnection,
|
||||
) -> anyhow::Result<Vec<DatabasePrivilegeRow>> {
|
||||
let unix_user = get_current_unix_user()?;
|
||||
validate_database_name(database_name, &unix_user)?;
|
||||
|
||||
let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
|
||||
"SELECT {} FROM `db` WHERE `db` = ?",
|
||||
DATABASE_PRIVILEGE_FIELDS
|
||||
.iter()
|
||||
.map(|field| quote_identifier(field))
|
||||
.join(","),
|
||||
))
|
||||
.bind(database_name)
|
||||
.fetch_all(connection)
|
||||
.await
|
||||
.context("Failed to show database")?;
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Get all database + user + privileges pairs that are owned by the current user.
|
||||
pub async fn get_all_database_privileges(
|
||||
connection: &mut MySqlConnection,
|
||||
) -> anyhow::Result<Vec<DatabasePrivilegeRow>> {
|
||||
let unix_user = get_current_unix_user()?;
|
||||
|
||||
let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
|
||||
indoc! {r#"
|
||||
SELECT {} FROM `db` WHERE `db` IN
|
||||
(SELECT DISTINCT `SCHEMA_NAME` AS `database`
|
||||
FROM `information_schema`.`SCHEMATA`
|
||||
WHERE `SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys')
|
||||
AND `SCHEMA_NAME` REGEXP ?)
|
||||
"#},
|
||||
DATABASE_PRIVILEGE_FIELDS
|
||||
.iter()
|
||||
.map(|field| format!("`{field}`"))
|
||||
.join(","),
|
||||
))
|
||||
.bind(create_user_group_matching_regex(&unix_user))
|
||||
.fetch_all(connection)
|
||||
.await
|
||||
.context("Failed to show databases")?;
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/*************************/
|
||||
@@ -340,17 +166,23 @@ pub fn generate_editor_content_from_privilege_data(
|
||||
// editor will be the example user and example db name.
|
||||
// Hence, it's put as the fallback value, despite not really
|
||||
// being a "fallback" in the normal sense.
|
||||
let longest_username = privilege_data
|
||||
.iter()
|
||||
.map(|p| p.user.len())
|
||||
.max()
|
||||
.unwrap_or(example_user.len());
|
||||
let longest_username = max(
|
||||
privilege_data
|
||||
.iter()
|
||||
.map(|p| p.user.len())
|
||||
.max()
|
||||
.unwrap_or(example_user.len()),
|
||||
"User".len(),
|
||||
);
|
||||
|
||||
let longest_database_name = privilege_data
|
||||
.iter()
|
||||
.map(|p| p.db.len())
|
||||
.max()
|
||||
.unwrap_or(example_db.len());
|
||||
let longest_database_name = max(
|
||||
privilege_data
|
||||
.iter()
|
||||
.map(|p| p.db.len())
|
||||
.max()
|
||||
.unwrap_or(example_db.len()),
|
||||
"Database".len(),
|
||||
);
|
||||
|
||||
let mut header: Vec<_> = DATABASE_PRIVILEGE_FIELDS
|
||||
.into_iter()
|
||||
@@ -578,7 +410,7 @@ pub fn parse_privilege_data_from_editor_content(
|
||||
/// instances of privilege sets for a single user on a single database.
|
||||
///
|
||||
/// The `User` and `Database` are the same for both instances.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, PartialOrd, Ord)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
|
||||
pub struct DatabasePrivilegeRowDiff {
|
||||
pub db: String,
|
||||
pub user: String,
|
||||
@@ -586,7 +418,7 @@ pub struct DatabasePrivilegeRowDiff {
|
||||
}
|
||||
|
||||
/// This enum represents a change for a single privilege.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, PartialOrd, Ord)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
|
||||
pub enum DatabasePrivilegeChange {
|
||||
YesToNo(String),
|
||||
NoToYes(String),
|
||||
@@ -603,13 +435,31 @@ impl DatabasePrivilegeChange {
|
||||
}
|
||||
|
||||
/// This enum encapsulates whether a [`DatabasePrivilegeRow`] was intrduced, modified or deleted.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, PartialOrd, Ord)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)]
|
||||
pub enum DatabasePrivilegesDiff {
|
||||
New(DatabasePrivilegeRow),
|
||||
Modified(DatabasePrivilegeRowDiff),
|
||||
Deleted(DatabasePrivilegeRow),
|
||||
}
|
||||
|
||||
impl DatabasePrivilegesDiff {
|
||||
pub fn get_database_name(&self) -> &str {
|
||||
match self {
|
||||
DatabasePrivilegesDiff::New(p) => &p.db,
|
||||
DatabasePrivilegesDiff::Modified(p) => &p.db,
|
||||
DatabasePrivilegesDiff::Deleted(p) => &p.db,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_user_name(&self) -> &str {
|
||||
match self {
|
||||
DatabasePrivilegesDiff::New(p) => &p.user,
|
||||
DatabasePrivilegesDiff::Modified(p) => &p.user,
|
||||
DatabasePrivilegesDiff::Deleted(p) => &p.user,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// This function calculates the differences between two sets of database privileges.
|
||||
/// It returns a set of [`DatabasePrivilegesDiff`] that can be used to display or
|
||||
/// apply a set of privilege modifications to the database.
|
||||
@@ -633,7 +483,7 @@ pub fn diff_privileges(
|
||||
|
||||
for p in to {
|
||||
if let Some(old_p) = from_lookup_table.get(&(p.db.clone(), p.user.clone())) {
|
||||
let diff = old_p.diff(p);
|
||||
let diff = diff(old_p, p);
|
||||
if !diff.diff.is_empty() {
|
||||
result.insert(DatabasePrivilegesDiff::Modified(diff));
|
||||
}
|
||||
@@ -651,72 +501,6 @@ pub fn diff_privileges(
|
||||
result
|
||||
}
|
||||
|
||||
/// Uses the result of [`diff_privileges`] to modify privileges in the database.
|
||||
pub async fn apply_privilege_diffs(
|
||||
diffs: BTreeSet<DatabasePrivilegesDiff>,
|
||||
connection: &mut MySqlConnection,
|
||||
) -> anyhow::Result<()> {
|
||||
for diff in diffs {
|
||||
match diff {
|
||||
DatabasePrivilegesDiff::New(p) => {
|
||||
let tables = DATABASE_PRIVILEGE_FIELDS
|
||||
.iter()
|
||||
.map(|field| format!("`{field}`"))
|
||||
.join(",");
|
||||
|
||||
let question_marks = std::iter::repeat("?")
|
||||
.take(DATABASE_PRIVILEGE_FIELDS.len())
|
||||
.join(",");
|
||||
|
||||
sqlx::query(
|
||||
format!("INSERT INTO `db` ({}) VALUES ({})", tables, question_marks).as_str(),
|
||||
)
|
||||
.bind(p.db)
|
||||
.bind(p.user)
|
||||
.bind(yn(p.select_priv))
|
||||
.bind(yn(p.insert_priv))
|
||||
.bind(yn(p.update_priv))
|
||||
.bind(yn(p.delete_priv))
|
||||
.bind(yn(p.create_priv))
|
||||
.bind(yn(p.drop_priv))
|
||||
.bind(yn(p.alter_priv))
|
||||
.bind(yn(p.index_priv))
|
||||
.bind(yn(p.create_tmp_table_priv))
|
||||
.bind(yn(p.lock_tables_priv))
|
||||
.bind(yn(p.references_priv))
|
||||
.execute(&mut *connection)
|
||||
.await?;
|
||||
}
|
||||
DatabasePrivilegesDiff::Modified(p) => {
|
||||
let tables = p
|
||||
.diff
|
||||
.iter()
|
||||
.map(|diff| match diff {
|
||||
DatabasePrivilegeChange::YesToNo(name) => format!("`{}` = 'N'", name),
|
||||
DatabasePrivilegeChange::NoToYes(name) => format!("`{}` = 'Y'", name),
|
||||
})
|
||||
.join(",");
|
||||
|
||||
sqlx::query(
|
||||
format!("UPDATE `db` SET {} WHERE `db` = ? AND `user` = ?", tables).as_str(),
|
||||
)
|
||||
.bind(p.db)
|
||||
.bind(p.user)
|
||||
.execute(&mut *connection)
|
||||
.await?;
|
||||
}
|
||||
DatabasePrivilegesDiff::Deleted(p) => {
|
||||
sqlx::query("DELETE FROM `db` WHERE `db` = ? AND `user` = ?")
|
||||
.bind(p.db)
|
||||
.bind(p.user)
|
||||
.execute(&mut *connection)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn display_privilege_cell(diff: &DatabasePrivilegeRowDiff) -> String {
|
||||
diff.diff
|
||||
.iter()
|
||||
@@ -731,6 +515,20 @@ fn display_privilege_cell(diff: &DatabasePrivilegeRowDiff) -> String {
|
||||
.join("\n")
|
||||
}
|
||||
|
||||
fn display_new_privileges_list(row: &DatabasePrivilegeRow) -> String {
|
||||
DATABASE_PRIVILEGE_FIELDS
|
||||
.into_iter()
|
||||
.skip(2)
|
||||
.map(|field| {
|
||||
if row.get_privilege_by_name(field) {
|
||||
format!("{}: Y", db_priv_field_human_readable_name(field))
|
||||
} else {
|
||||
format!("{}: N", db_priv_field_human_readable_name(field))
|
||||
}
|
||||
})
|
||||
.join("\n")
|
||||
}
|
||||
|
||||
/// Displays the difference between two sets of database privileges.
|
||||
pub fn display_privilege_diffs(diffs: &BTreeSet<DatabasePrivilegesDiff>) -> String {
|
||||
let mut table = Table::new();
|
||||
@@ -741,24 +539,14 @@ pub fn display_privilege_diffs(diffs: &BTreeSet<DatabasePrivilegesDiff>) -> Stri
|
||||
table.add_row(row![
|
||||
p.db,
|
||||
p.user,
|
||||
"(New user)\n".to_string()
|
||||
+ &display_privilege_cell(
|
||||
&DatabasePrivilegeRow::empty(&p.db, &p.user).diff(p)
|
||||
)
|
||||
"(New user)\n".to_string() + &display_new_privileges_list(p)
|
||||
]);
|
||||
}
|
||||
DatabasePrivilegesDiff::Modified(p) => {
|
||||
table.add_row(row![p.db, p.user, display_privilege_cell(p),]);
|
||||
}
|
||||
DatabasePrivilegesDiff::Deleted(p) => {
|
||||
table.add_row(row![
|
||||
p.db,
|
||||
p.user,
|
||||
"(All privileges removed)\n".to_string()
|
||||
+ &display_privilege_cell(
|
||||
&p.diff(&DatabasePrivilegeRow::empty(&p.db, &p.user))
|
||||
)
|
||||
]);
|
||||
table.add_row(row![p.db, p.user, "Removed".to_string()]);
|
||||
}
|
||||
}
|
||||
}
|
||||
5
src/core/protocol.rs
Normal file
5
src/core/protocol.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
pub mod request_response;
|
||||
pub mod server_responses;
|
||||
|
||||
pub use request_response::*;
|
||||
pub use server_responses::*;
|
||||
79
src/core/protocol/request_response.rs
Normal file
79
src/core/protocol/request_response.rs
Normal file
@@ -0,0 +1,79 @@
|
||||
use std::collections::BTreeSet;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::net::UnixStream;
|
||||
use tokio_serde::{formats::Bincode, Framed as SerdeFramed};
|
||||
use tokio_util::codec::{Framed, LengthDelimitedCodec};
|
||||
|
||||
use crate::core::{database_privileges::DatabasePrivilegesDiff, protocol::*};
|
||||
|
||||
pub type ServerToClientMessageStream = SerdeFramed<
|
||||
Framed<UnixStream, LengthDelimitedCodec>,
|
||||
Request,
|
||||
Response,
|
||||
Bincode<Request, Response>,
|
||||
>;
|
||||
|
||||
pub type ClientToServerMessageStream = SerdeFramed<
|
||||
Framed<UnixStream, LengthDelimitedCodec>,
|
||||
Response,
|
||||
Request,
|
||||
Bincode<Response, Request>,
|
||||
>;
|
||||
|
||||
pub fn create_server_to_client_message_stream(socket: UnixStream) -> ServerToClientMessageStream {
|
||||
let length_delimited = Framed::new(socket, LengthDelimitedCodec::new());
|
||||
tokio_serde::Framed::new(length_delimited, Bincode::default())
|
||||
}
|
||||
|
||||
pub fn create_client_to_server_message_stream(socket: UnixStream) -> ClientToServerMessageStream {
|
||||
let length_delimited = Framed::new(socket, LengthDelimitedCodec::new());
|
||||
tokio_serde::Framed::new(length_delimited, Bincode::default())
|
||||
}
|
||||
|
||||
#[non_exhaustive]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub enum Request {
|
||||
CreateDatabases(Vec<String>),
|
||||
DropDatabases(Vec<String>),
|
||||
ListDatabases,
|
||||
ListPrivileges(Option<Vec<String>>),
|
||||
ModifyPrivileges(BTreeSet<DatabasePrivilegesDiff>),
|
||||
|
||||
CreateUsers(Vec<String>),
|
||||
DropUsers(Vec<String>),
|
||||
PasswdUser(String, String),
|
||||
ListUsers(Option<Vec<String>>),
|
||||
LockUsers(Vec<String>),
|
||||
UnlockUsers(Vec<String>),
|
||||
|
||||
// Commit,
|
||||
Exit,
|
||||
}
|
||||
|
||||
// TODO: include a generic "message" that will display a message to the user?
|
||||
|
||||
#[non_exhaustive]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub enum Response {
|
||||
// Specific data for specific commands
|
||||
CreateDatabases(CreateDatabasesOutput),
|
||||
DropDatabases(DropDatabasesOutput),
|
||||
ListAllDatabases(ListAllDatabasesOutput),
|
||||
ListPrivileges(GetDatabasesPrivilegeData),
|
||||
ListAllPrivileges(GetAllDatabasesPrivilegeData),
|
||||
ModifyPrivileges(ModifyDatabasePrivilegesOutput),
|
||||
|
||||
CreateUsers(CreateUsersOutput),
|
||||
DropUsers(DropUsersOutput),
|
||||
PasswdUser(SetPasswordOutput),
|
||||
ListUsers(ListUsersOutput),
|
||||
ListAllUsers(ListAllUsersOutput),
|
||||
LockUsers(LockUsersOutput),
|
||||
UnlockUsers(UnlockUsersOutput),
|
||||
|
||||
// Generic responses
|
||||
OperationAborted,
|
||||
Error(String),
|
||||
Exit,
|
||||
}
|
||||
611
src/core/protocol/server_responses.rs
Normal file
611
src/core/protocol/server_responses.rs
Normal file
@@ -0,0 +1,611 @@
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use indoc::indoc;
|
||||
use itertools::Itertools;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
core::{common::UnixUser, database_privileges::DatabasePrivilegeRowDiff},
|
||||
server::sql::{
|
||||
database_privilege_operations::DatabasePrivilegeRow, user_operations::DatabaseUser,
|
||||
},
|
||||
};
|
||||
|
||||
/// This enum is used to differentiate between database and user operations.
|
||||
/// Their output are very similar, but there are slight differences in the words used.
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
pub enum DbOrUser {
|
||||
Database,
|
||||
User,
|
||||
}
|
||||
|
||||
impl DbOrUser {
|
||||
pub fn lowercased(&self) -> String {
|
||||
match self {
|
||||
DbOrUser::Database => "database".to_string(),
|
||||
DbOrUser::User => "user".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn capitalized(&self) -> String {
|
||||
match self {
|
||||
DbOrUser::Database => "Database".to_string(),
|
||||
DbOrUser::User => "User".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
|
||||
pub enum NameValidationError {
|
||||
EmptyString,
|
||||
InvalidCharacters,
|
||||
TooLong,
|
||||
}
|
||||
|
||||
impl NameValidationError {
|
||||
pub fn to_error_message(self, name: &str, db_or_user: DbOrUser) -> String {
|
||||
match self {
|
||||
NameValidationError::EmptyString => {
|
||||
format!("{} name cannot be empty.", db_or_user.capitalized()).to_owned()
|
||||
}
|
||||
NameValidationError::TooLong => format!(
|
||||
"{} is too long. Maximum length is 64 characters.",
|
||||
db_or_user.capitalized()
|
||||
)
|
||||
.to_owned(),
|
||||
NameValidationError::InvalidCharacters => format!(
|
||||
indoc! {r#"
|
||||
Invalid characters in {} name: '{}'
|
||||
|
||||
Only A-Z, a-z, 0-9, _ (underscore) and - (dash) are permitted.
|
||||
"#},
|
||||
db_or_user.lowercased(),
|
||||
name
|
||||
)
|
||||
.to_owned(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl OwnerValidationError {
|
||||
pub fn to_error_message(self, name: &str, db_or_user: DbOrUser) -> String {
|
||||
let user = UnixUser::from_enviroment();
|
||||
|
||||
match self {
|
||||
OwnerValidationError::NoMatch => format!(
|
||||
indoc! {r#"
|
||||
Invalid {} name prefix: '{}' does not match your username or any of your groups.
|
||||
Are you sure you are allowed to create {} names with this prefix?
|
||||
|
||||
Allowed prefixes:
|
||||
- {}
|
||||
{}
|
||||
"#},
|
||||
db_or_user.lowercased(),
|
||||
name,
|
||||
db_or_user.lowercased(),
|
||||
user.as_ref()
|
||||
.map(|u| u.username.clone())
|
||||
.unwrap_or("???".to_string()),
|
||||
user.map(|u| u.groups)
|
||||
.unwrap_or_default()
|
||||
.iter()
|
||||
.map(|g| format!(" - {}", g))
|
||||
.sorted()
|
||||
.join("\n"),
|
||||
)
|
||||
.to_owned(),
|
||||
|
||||
_ => format!(
|
||||
"'{}' is not a valid {} name.",
|
||||
name,
|
||||
db_or_user.lowercased()
|
||||
)
|
||||
.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
|
||||
pub enum OwnerValidationError {
|
||||
// The name is valid, but none of the given prefixes matched the name
|
||||
NoMatch,
|
||||
|
||||
// The name is empty, which is invalid
|
||||
StringEmpty,
|
||||
|
||||
// The name is in the format "_<postfix>", which is invalid
|
||||
MissingPrefix,
|
||||
|
||||
// The name is in the format "<prefix>_", which is invalid
|
||||
MissingPostfix,
|
||||
}
|
||||
|
||||
pub type CreateDatabasesOutput = BTreeMap<String, Result<(), CreateDatabaseError>>;
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub enum CreateDatabaseError {
|
||||
SanitizationError(NameValidationError),
|
||||
OwnershipError(OwnerValidationError),
|
||||
DatabaseAlreadyExists,
|
||||
MySqlError(String),
|
||||
}
|
||||
|
||||
pub fn print_create_databases_output_status(output: &CreateDatabasesOutput) {
|
||||
for (database_name, result) in output {
|
||||
match result {
|
||||
Ok(()) => {
|
||||
println!("Database '{}' created successfully.", database_name);
|
||||
}
|
||||
Err(err) => {
|
||||
println!("{}", err.to_error_message(database_name));
|
||||
println!("Skipping...");
|
||||
}
|
||||
}
|
||||
println!();
|
||||
}
|
||||
}
|
||||
|
||||
impl CreateDatabaseError {
|
||||
pub fn to_error_message(&self, database_name: &str) -> String {
|
||||
match self {
|
||||
CreateDatabaseError::SanitizationError(err) => {
|
||||
err.to_error_message(database_name, DbOrUser::Database)
|
||||
}
|
||||
CreateDatabaseError::OwnershipError(err) => {
|
||||
err.to_error_message(database_name, DbOrUser::Database)
|
||||
}
|
||||
CreateDatabaseError::DatabaseAlreadyExists => {
|
||||
format!("Database {} already exists.", database_name)
|
||||
}
|
||||
CreateDatabaseError::MySqlError(err) => {
|
||||
format!("MySQL error: {}", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type DropDatabasesOutput = BTreeMap<String, Result<(), DropDatabaseError>>;
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub enum DropDatabaseError {
|
||||
SanitizationError(NameValidationError),
|
||||
OwnershipError(OwnerValidationError),
|
||||
DatabaseDoesNotExist,
|
||||
MySqlError(String),
|
||||
}
|
||||
|
||||
pub fn print_drop_databases_output_status(output: &DropDatabasesOutput) {
|
||||
for (database_name, result) in output {
|
||||
match result {
|
||||
Ok(()) => {
|
||||
println!("Database '{}' dropped successfully.", database_name);
|
||||
}
|
||||
Err(err) => {
|
||||
println!("{}", err.to_error_message(database_name));
|
||||
println!("Skipping...");
|
||||
}
|
||||
}
|
||||
println!();
|
||||
}
|
||||
}
|
||||
|
||||
impl DropDatabaseError {
|
||||
pub fn to_error_message(&self, database_name: &str) -> String {
|
||||
match self {
|
||||
DropDatabaseError::SanitizationError(err) => {
|
||||
err.to_error_message(database_name, DbOrUser::Database)
|
||||
}
|
||||
DropDatabaseError::OwnershipError(err) => {
|
||||
err.to_error_message(database_name, DbOrUser::Database)
|
||||
}
|
||||
DropDatabaseError::DatabaseDoesNotExist => {
|
||||
format!("Database {} does not exist.", database_name)
|
||||
}
|
||||
DropDatabaseError::MySqlError(err) => {
|
||||
format!("MySQL error: {}", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type ListAllDatabasesOutput = Result<Vec<String>, ListDatabasesError>;
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub enum ListDatabasesError {
|
||||
MySqlError(String),
|
||||
}
|
||||
|
||||
impl ListDatabasesError {
|
||||
pub fn to_error_message(&self) -> String {
|
||||
match self {
|
||||
ListDatabasesError::MySqlError(err) => format!("MySQL error: {}", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: merge all rows into a single collection.
|
||||
// they already contain which database they belong to.
|
||||
// no need to index by database name.
|
||||
|
||||
pub type GetDatabasesPrivilegeData =
|
||||
BTreeMap<String, Result<Vec<DatabasePrivilegeRow>, GetDatabasesPrivilegeDataError>>;
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub enum GetDatabasesPrivilegeDataError {
|
||||
SanitizationError(NameValidationError),
|
||||
OwnershipError(OwnerValidationError),
|
||||
DatabaseDoesNotExist,
|
||||
MySqlError(String),
|
||||
}
|
||||
|
||||
impl GetDatabasesPrivilegeDataError {
|
||||
pub fn to_error_message(&self, database_name: &str) -> String {
|
||||
match self {
|
||||
GetDatabasesPrivilegeDataError::SanitizationError(err) => {
|
||||
err.to_error_message(database_name, DbOrUser::Database)
|
||||
}
|
||||
GetDatabasesPrivilegeDataError::OwnershipError(err) => {
|
||||
err.to_error_message(database_name, DbOrUser::Database)
|
||||
}
|
||||
GetDatabasesPrivilegeDataError::DatabaseDoesNotExist => {
|
||||
format!("Database '{}' does not exist.", database_name)
|
||||
}
|
||||
GetDatabasesPrivilegeDataError::MySqlError(err) => {
|
||||
format!("MySQL error: {}", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type GetAllDatabasesPrivilegeData =
|
||||
Result<Vec<DatabasePrivilegeRow>, GetAllDatabasesPrivilegeDataError>;
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub enum GetAllDatabasesPrivilegeDataError {
|
||||
MySqlError(String),
|
||||
}
|
||||
|
||||
impl GetAllDatabasesPrivilegeDataError {
|
||||
pub fn to_error_message(&self) -> String {
|
||||
match self {
|
||||
GetAllDatabasesPrivilegeDataError::MySqlError(err) => format!("MySQL error: {}", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type ModifyDatabasePrivilegesOutput =
|
||||
BTreeMap<(String, String), Result<(), ModifyDatabasePrivilegesError>>;
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub enum ModifyDatabasePrivilegesError {
|
||||
DatabaseSanitizationError(NameValidationError),
|
||||
DatabaseOwnershipError(OwnerValidationError),
|
||||
UserSanitizationError(NameValidationError),
|
||||
UserOwnershipError(OwnerValidationError),
|
||||
DatabaseDoesNotExist,
|
||||
DiffDoesNotApply(DiffDoesNotApplyError),
|
||||
MySqlError(String),
|
||||
}
|
||||
|
||||
#[allow(clippy::enum_variant_names)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub enum DiffDoesNotApplyError {
|
||||
RowAlreadyExists(String, String),
|
||||
RowDoesNotExist(String, String),
|
||||
RowPrivilegeChangeDoesNotApply(DatabasePrivilegeRowDiff, DatabasePrivilegeRow),
|
||||
}
|
||||
|
||||
pub fn print_modify_database_privileges_output_status(output: &ModifyDatabasePrivilegesOutput) {
|
||||
for ((database_name, username), result) in output {
|
||||
match result {
|
||||
Ok(()) => {
|
||||
println!(
|
||||
"Privileges for user '{}' on database '{}' modified successfully.",
|
||||
username, database_name
|
||||
);
|
||||
}
|
||||
Err(err) => {
|
||||
println!("{}", err.to_error_message(database_name, username));
|
||||
println!("Skipping...");
|
||||
}
|
||||
}
|
||||
println!();
|
||||
}
|
||||
}
|
||||
|
||||
impl ModifyDatabasePrivilegesError {
|
||||
pub fn to_error_message(&self, database_name: &str, username: &str) -> String {
|
||||
match self {
|
||||
ModifyDatabasePrivilegesError::DatabaseSanitizationError(err) => {
|
||||
err.to_error_message(database_name, DbOrUser::Database)
|
||||
}
|
||||
ModifyDatabasePrivilegesError::DatabaseOwnershipError(err) => {
|
||||
err.to_error_message(database_name, DbOrUser::Database)
|
||||
}
|
||||
ModifyDatabasePrivilegesError::UserSanitizationError(err) => {
|
||||
err.to_error_message(username, DbOrUser::User)
|
||||
}
|
||||
ModifyDatabasePrivilegesError::UserOwnershipError(err) => {
|
||||
err.to_error_message(username, DbOrUser::User)
|
||||
}
|
||||
ModifyDatabasePrivilegesError::DatabaseDoesNotExist => {
|
||||
format!("Database '{}' does not exist.", database_name)
|
||||
}
|
||||
ModifyDatabasePrivilegesError::DiffDoesNotApply(diff) => {
|
||||
format!(
|
||||
"Could not apply privilege change:\n{}",
|
||||
diff.to_error_message()
|
||||
)
|
||||
}
|
||||
ModifyDatabasePrivilegesError::MySqlError(err) => {
|
||||
format!("MySQL error: {}", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DiffDoesNotApplyError {
|
||||
pub fn to_error_message(&self) -> String {
|
||||
match self {
|
||||
DiffDoesNotApplyError::RowAlreadyExists(database_name, username) => {
|
||||
format!(
|
||||
"Privileges for user '{}' on database '{}' already exist.",
|
||||
username, database_name
|
||||
)
|
||||
}
|
||||
DiffDoesNotApplyError::RowDoesNotExist(database_name, username) => {
|
||||
format!(
|
||||
"Privileges for user '{}' on database '{}' do not exist.",
|
||||
username, database_name
|
||||
)
|
||||
}
|
||||
DiffDoesNotApplyError::RowPrivilegeChangeDoesNotApply(diff, row) => {
|
||||
format!(
|
||||
"Could not apply privilege change {:?} to row {:?}",
|
||||
diff, row
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type CreateUsersOutput = BTreeMap<String, Result<(), CreateUserError>>;
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub enum CreateUserError {
|
||||
SanitizationError(NameValidationError),
|
||||
OwnershipError(OwnerValidationError),
|
||||
UserAlreadyExists,
|
||||
MySqlError(String),
|
||||
}
|
||||
|
||||
pub fn print_create_users_output_status(output: &CreateUsersOutput) {
|
||||
for (username, result) in output {
|
||||
match result {
|
||||
Ok(()) => {
|
||||
println!("User '{}' created successfully.", username);
|
||||
}
|
||||
Err(err) => {
|
||||
println!("{}", err.to_error_message(username));
|
||||
println!("Skipping...");
|
||||
}
|
||||
}
|
||||
println!();
|
||||
}
|
||||
}
|
||||
|
||||
impl CreateUserError {
|
||||
pub fn to_error_message(&self, username: &str) -> String {
|
||||
match self {
|
||||
CreateUserError::SanitizationError(err) => {
|
||||
err.to_error_message(username, DbOrUser::User)
|
||||
}
|
||||
CreateUserError::OwnershipError(err) => err.to_error_message(username, DbOrUser::User),
|
||||
CreateUserError::UserAlreadyExists => {
|
||||
format!("User '{}' already exists.", username)
|
||||
}
|
||||
CreateUserError::MySqlError(err) => {
|
||||
format!("MySQL error: {}", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type DropUsersOutput = BTreeMap<String, Result<(), DropUserError>>;
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub enum DropUserError {
|
||||
SanitizationError(NameValidationError),
|
||||
OwnershipError(OwnerValidationError),
|
||||
UserDoesNotExist,
|
||||
MySqlError(String),
|
||||
}
|
||||
|
||||
pub fn print_drop_users_output_status(output: &DropUsersOutput) {
|
||||
for (username, result) in output {
|
||||
match result {
|
||||
Ok(()) => {
|
||||
println!("User '{}' dropped successfully.", username);
|
||||
}
|
||||
Err(err) => {
|
||||
println!("{}", err.to_error_message(username));
|
||||
println!("Skipping...");
|
||||
}
|
||||
}
|
||||
println!();
|
||||
}
|
||||
}
|
||||
|
||||
impl DropUserError {
|
||||
pub fn to_error_message(&self, username: &str) -> String {
|
||||
match self {
|
||||
DropUserError::SanitizationError(err) => err.to_error_message(username, DbOrUser::User),
|
||||
DropUserError::OwnershipError(err) => err.to_error_message(username, DbOrUser::User),
|
||||
DropUserError::UserDoesNotExist => {
|
||||
format!("User '{}' does not exist.", username)
|
||||
}
|
||||
DropUserError::MySqlError(err) => {
|
||||
format!("MySQL error: {}", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type SetPasswordOutput = Result<(), SetPasswordError>;
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub enum SetPasswordError {
|
||||
SanitizationError(NameValidationError),
|
||||
OwnershipError(OwnerValidationError),
|
||||
UserDoesNotExist,
|
||||
MySqlError(String),
|
||||
}
|
||||
|
||||
pub fn print_set_password_output_status(output: &SetPasswordOutput, username: &str) {
|
||||
match output {
|
||||
Ok(()) => {
|
||||
println!("Password for user '{}' set successfully.", username);
|
||||
}
|
||||
Err(err) => {
|
||||
println!("{}", err.to_error_message(username));
|
||||
println!("Skipping...");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SetPasswordError {
|
||||
pub fn to_error_message(&self, username: &str) -> String {
|
||||
match self {
|
||||
SetPasswordError::SanitizationError(err) => {
|
||||
err.to_error_message(username, DbOrUser::User)
|
||||
}
|
||||
SetPasswordError::OwnershipError(err) => err.to_error_message(username, DbOrUser::User),
|
||||
SetPasswordError::UserDoesNotExist => {
|
||||
format!("User '{}' does not exist.", username)
|
||||
}
|
||||
SetPasswordError::MySqlError(err) => {
|
||||
format!("MySQL error: {}", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type LockUsersOutput = BTreeMap<String, Result<(), LockUserError>>;
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub enum LockUserError {
|
||||
SanitizationError(NameValidationError),
|
||||
OwnershipError(OwnerValidationError),
|
||||
UserDoesNotExist,
|
||||
UserIsAlreadyLocked,
|
||||
MySqlError(String),
|
||||
}
|
||||
|
||||
pub fn print_lock_users_output_status(output: &LockUsersOutput) {
|
||||
for (username, result) in output {
|
||||
match result {
|
||||
Ok(()) => {
|
||||
println!("User '{}' locked successfully.", username);
|
||||
}
|
||||
Err(err) => {
|
||||
println!("{}", err.to_error_message(username));
|
||||
println!("Skipping...");
|
||||
}
|
||||
}
|
||||
println!();
|
||||
}
|
||||
}
|
||||
|
||||
impl LockUserError {
|
||||
pub fn to_error_message(&self, username: &str) -> String {
|
||||
match self {
|
||||
LockUserError::SanitizationError(err) => err.to_error_message(username, DbOrUser::User),
|
||||
LockUserError::OwnershipError(err) => err.to_error_message(username, DbOrUser::User),
|
||||
LockUserError::UserDoesNotExist => {
|
||||
format!("User '{}' does not exist.", username)
|
||||
}
|
||||
LockUserError::UserIsAlreadyLocked => {
|
||||
format!("User '{}' is already locked.", username)
|
||||
}
|
||||
LockUserError::MySqlError(err) => {
|
||||
format!("MySQL error: {}", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type UnlockUsersOutput = BTreeMap<String, Result<(), UnlockUserError>>;
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub enum UnlockUserError {
|
||||
SanitizationError(NameValidationError),
|
||||
OwnershipError(OwnerValidationError),
|
||||
UserDoesNotExist,
|
||||
UserIsAlreadyUnlocked,
|
||||
MySqlError(String),
|
||||
}
|
||||
|
||||
pub fn print_unlock_users_output_status(output: &UnlockUsersOutput) {
|
||||
for (username, result) in output {
|
||||
match result {
|
||||
Ok(()) => {
|
||||
println!("User '{}' unlocked successfully.", username);
|
||||
}
|
||||
Err(err) => {
|
||||
println!("{}", err.to_error_message(username));
|
||||
println!("Skipping...");
|
||||
}
|
||||
}
|
||||
println!();
|
||||
}
|
||||
}
|
||||
|
||||
impl UnlockUserError {
|
||||
pub fn to_error_message(&self, username: &str) -> String {
|
||||
match self {
|
||||
UnlockUserError::SanitizationError(err) => {
|
||||
err.to_error_message(username, DbOrUser::User)
|
||||
}
|
||||
UnlockUserError::OwnershipError(err) => err.to_error_message(username, DbOrUser::User),
|
||||
UnlockUserError::UserDoesNotExist => {
|
||||
format!("User '{}' does not exist.", username)
|
||||
}
|
||||
UnlockUserError::UserIsAlreadyUnlocked => {
|
||||
format!("User '{}' is already unlocked.", username)
|
||||
}
|
||||
UnlockUserError::MySqlError(err) => {
|
||||
format!("MySQL error: {}", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type ListUsersOutput = BTreeMap<String, Result<DatabaseUser, ListUsersError>>;
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub enum ListUsersError {
|
||||
SanitizationError(NameValidationError),
|
||||
OwnershipError(OwnerValidationError),
|
||||
UserDoesNotExist,
|
||||
MySqlError(String),
|
||||
}
|
||||
|
||||
impl ListUsersError {
|
||||
pub fn to_error_message(&self, username: &str) -> String {
|
||||
match self {
|
||||
ListUsersError::SanitizationError(err) => {
|
||||
err.to_error_message(username, DbOrUser::User)
|
||||
}
|
||||
ListUsersError::OwnershipError(err) => err.to_error_message(username, DbOrUser::User),
|
||||
ListUsersError::UserDoesNotExist => {
|
||||
format!("User '{}' does not exist.", username)
|
||||
}
|
||||
ListUsersError::MySqlError(err) => {
|
||||
format!("MySQL error: {}", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type ListAllUsersOutput = Result<Vec<DatabaseUser>, ListAllUsersError>;
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub enum ListAllUsersError {
|
||||
MySqlError(String),
|
||||
}
|
||||
|
||||
impl ListAllUsersError {
|
||||
pub fn to_error_message(&self) -> String {
|
||||
match self {
|
||||
ListAllUsersError::MySqlError(err) => format!("MySQL error: {}", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,249 +0,0 @@
|
||||
use anyhow::Context;
|
||||
use nix::unistd::User;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::{prelude::*, MySqlConnection};
|
||||
|
||||
use crate::core::common::{
|
||||
create_user_group_matching_regex, get_current_unix_user, quote_literal, validate_name_or_error,
|
||||
validate_ownership_or_error, DbOrUser,
|
||||
};
|
||||
|
||||
pub async fn user_exists(db_user: &str, connection: &mut MySqlConnection) -> anyhow::Result<bool> {
|
||||
let unix_user = get_current_unix_user()?;
|
||||
|
||||
validate_user_name(db_user, &unix_user)?;
|
||||
|
||||
let user_exists = sqlx::query(
|
||||
r#"
|
||||
SELECT EXISTS(
|
||||
SELECT 1
|
||||
FROM `mysql`.`user`
|
||||
WHERE `User` = ?
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.bind(db_user)
|
||||
.fetch_one(connection)
|
||||
.await?
|
||||
.get::<bool, _>(0);
|
||||
|
||||
Ok(user_exists)
|
||||
}
|
||||
|
||||
pub async fn create_database_user(
|
||||
db_user: &str,
|
||||
connection: &mut MySqlConnection,
|
||||
) -> anyhow::Result<()> {
|
||||
let unix_user = get_current_unix_user()?;
|
||||
|
||||
validate_user_name(db_user, &unix_user)?;
|
||||
|
||||
if user_exists(db_user, connection).await? {
|
||||
anyhow::bail!("User '{}' already exists", db_user);
|
||||
}
|
||||
|
||||
// NOTE: see the note about SQL injections in `validate_ownership_of_user_name`
|
||||
sqlx::query(format!("CREATE USER {}@'%'", quote_literal(db_user),).as_str())
|
||||
.execute(connection)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn delete_database_user(
|
||||
db_user: &str,
|
||||
connection: &mut MySqlConnection,
|
||||
) -> anyhow::Result<()> {
|
||||
let unix_user = get_current_unix_user()?;
|
||||
|
||||
validate_user_name(db_user, &unix_user)?;
|
||||
|
||||
if !user_exists(db_user, connection).await? {
|
||||
anyhow::bail!("User '{}' does not exist", db_user);
|
||||
}
|
||||
|
||||
// NOTE: see the note about SQL injections in `validate_ownership_of_user_name`
|
||||
sqlx::query(format!("DROP USER {}@'%'", quote_literal(db_user),).as_str())
|
||||
.execute(connection)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn set_password_for_database_user(
|
||||
db_user: &str,
|
||||
password: &str,
|
||||
connection: &mut MySqlConnection,
|
||||
) -> anyhow::Result<()> {
|
||||
let unix_user = crate::core::common::get_current_unix_user()?;
|
||||
validate_user_name(db_user, &unix_user)?;
|
||||
|
||||
if !user_exists(db_user, connection).await? {
|
||||
anyhow::bail!("User '{}' does not exist", db_user);
|
||||
}
|
||||
|
||||
// NOTE: see the note about SQL injections in `validate_ownership_of_user_name`
|
||||
sqlx::query(
|
||||
format!(
|
||||
"ALTER USER {}@'%' IDENTIFIED BY {}",
|
||||
quote_literal(db_user),
|
||||
quote_literal(password).as_str()
|
||||
)
|
||||
.as_str(),
|
||||
)
|
||||
.execute(connection)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn user_is_locked(db_user: &str, connection: &mut MySqlConnection) -> anyhow::Result<bool> {
|
||||
let unix_user = get_current_unix_user()?;
|
||||
|
||||
validate_user_name(db_user, &unix_user)?;
|
||||
|
||||
if !user_exists(db_user, connection).await? {
|
||||
anyhow::bail!("User '{}' does not exist", db_user);
|
||||
}
|
||||
|
||||
let is_locked = sqlx::query(
|
||||
r#"
|
||||
SELECT JSON_EXTRACT(`mysql`.`global_priv`.`priv`, "$.account_locked") = 'true'
|
||||
FROM `mysql`.`global_priv`
|
||||
WHERE `User` = ?
|
||||
AND `Host` = '%'
|
||||
"#,
|
||||
)
|
||||
.bind(db_user)
|
||||
.fetch_one(connection)
|
||||
.await?
|
||||
.get::<bool, _>(0);
|
||||
|
||||
Ok(is_locked)
|
||||
}
|
||||
|
||||
pub async fn lock_database_user(
|
||||
db_user: &str,
|
||||
connection: &mut MySqlConnection,
|
||||
) -> anyhow::Result<()> {
|
||||
let unix_user = get_current_unix_user()?;
|
||||
|
||||
validate_user_name(db_user, &unix_user)?;
|
||||
|
||||
if !user_exists(db_user, connection).await? {
|
||||
anyhow::bail!("User '{}' does not exist", db_user);
|
||||
}
|
||||
|
||||
if user_is_locked(db_user, connection).await? {
|
||||
anyhow::bail!("User '{}' is already locked", db_user);
|
||||
}
|
||||
|
||||
// NOTE: see the note about SQL injections in `validate_ownership_of_user_name`
|
||||
sqlx::query(format!("ALTER USER {}@'%' ACCOUNT LOCK", quote_literal(db_user),).as_str())
|
||||
.execute(connection)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn unlock_database_user(
|
||||
db_user: &str,
|
||||
connection: &mut MySqlConnection,
|
||||
) -> anyhow::Result<()> {
|
||||
let unix_user = get_current_unix_user()?;
|
||||
|
||||
validate_user_name(db_user, &unix_user)?;
|
||||
|
||||
if !user_exists(db_user, connection).await? {
|
||||
anyhow::bail!("User '{}' does not exist", db_user);
|
||||
}
|
||||
|
||||
if !user_is_locked(db_user, connection).await? {
|
||||
anyhow::bail!("User '{}' is already unlocked", db_user);
|
||||
}
|
||||
|
||||
// NOTE: see the note about SQL injections in `validate_ownership_of_user_name`
|
||||
sqlx::query(format!("ALTER USER {}@'%' ACCOUNT UNLOCK", quote_literal(db_user),).as_str())
|
||||
.execute(connection)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// This struct contains information about a database user.
|
||||
/// This can be extended if we need more information in the future.
|
||||
#[derive(Debug, Clone, FromRow, Serialize, Deserialize)]
|
||||
pub struct DatabaseUser {
|
||||
#[sqlx(rename = "User")]
|
||||
pub user: String,
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[serde(skip)]
|
||||
#[sqlx(rename = "Host")]
|
||||
pub host: String,
|
||||
|
||||
#[sqlx(rename = "has_password")]
|
||||
pub has_password: bool,
|
||||
|
||||
#[sqlx(rename = "is_locked")]
|
||||
pub is_locked: bool,
|
||||
}
|
||||
|
||||
const DB_USER_SELECT_STATEMENT: &str = r#"
|
||||
SELECT
|
||||
`mysql`.`user`.`User`,
|
||||
`mysql`.`user`.`Host`,
|
||||
`mysql`.`user`.`Password` != '' OR `mysql`.`user`.`authentication_string` != '' AS `has_password`,
|
||||
COALESCE(
|
||||
JSON_EXTRACT(`mysql`.`global_priv`.`priv`, "$.account_locked"),
|
||||
'false'
|
||||
) != 'false' AS `is_locked`
|
||||
FROM `mysql`.`user`
|
||||
JOIN `mysql`.`global_priv` ON
|
||||
`mysql`.`user`.`User` = `mysql`.`global_priv`.`User`
|
||||
AND `mysql`.`user`.`Host` = `mysql`.`global_priv`.`Host`
|
||||
"#;
|
||||
|
||||
/// This function fetches all database users that have a prefix matching the
|
||||
/// unix username and group names of the given unix user.
|
||||
pub async fn get_all_database_users_for_unix_user(
|
||||
unix_user: &User,
|
||||
connection: &mut MySqlConnection,
|
||||
) -> anyhow::Result<Vec<DatabaseUser>> {
|
||||
let users = sqlx::query_as::<_, DatabaseUser>(
|
||||
&(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `mysql`.`user`.`User` REGEXP ?"),
|
||||
)
|
||||
.bind(create_user_group_matching_regex(unix_user))
|
||||
.fetch_all(connection)
|
||||
.await?;
|
||||
|
||||
Ok(users)
|
||||
}
|
||||
|
||||
/// This function fetches a database user if it exists.
|
||||
pub async fn get_database_user_for_user(
|
||||
username: &str,
|
||||
connection: &mut MySqlConnection,
|
||||
) -> anyhow::Result<Option<DatabaseUser>> {
|
||||
let user = sqlx::query_as::<_, DatabaseUser>(
|
||||
&(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `mysql`.`user`.`User` = ?"),
|
||||
)
|
||||
.bind(username)
|
||||
.fetch_optional(connection)
|
||||
.await?;
|
||||
|
||||
Ok(user)
|
||||
}
|
||||
|
||||
/// NOTE: It is very critical that this function validates the database name
|
||||
/// properly. MySQL does not seem to allow for prepared statements, binding
|
||||
/// the database name as a parameter to the query. This means that we have
|
||||
/// to validate the database name ourselves to prevent SQL injection.
|
||||
pub fn validate_user_name(name: &str, user: &User) -> anyhow::Result<()> {
|
||||
validate_name_or_error(name, DbOrUser::User)
|
||||
.context(format!("Invalid username: '{}'", name))?;
|
||||
validate_ownership_or_error(name, user, DbOrUser::User)
|
||||
.context(format!("Invalid username: '{}'", name))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Reference in New Issue
Block a user