Compare commits
8 Commits
5b7eafd7ca
...
d554280741
Author | SHA1 | Date |
---|---|---|
Oystein Kristoffer Tveit | d554280741 | |
Oystein Kristoffer Tveit | cd0b2c3e6d | |
Oystein Kristoffer Tveit | 93469a6e84 | |
Oystein Kristoffer Tveit | e4da639d5c | |
Oystein Kristoffer Tveit | daa8e069d3 | |
Oystein Kristoffer Tveit | 86b5b47f1e | |
Oystein Kristoffer Tveit | 9d88c95f33 | |
Oystein Kristoffer Tveit | 53f19b3d05 |
|
@ -1010,6 +1010,15 @@ version = "2.7.4"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3"
|
checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "memoffset"
|
||||||
|
version = "0.9.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "minimal-lexical"
|
name = "minimal-lexical"
|
||||||
version = "0.2.1"
|
version = "0.2.1"
|
||||||
|
@ -1058,6 +1067,7 @@ dependencies = [
|
||||||
"prettytable",
|
"prettytable",
|
||||||
"rand",
|
"rand",
|
||||||
"ratatui",
|
"ratatui",
|
||||||
|
"sd-notify",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"sqlx",
|
"sqlx",
|
||||||
|
@ -1079,6 +1089,7 @@ dependencies = [
|
||||||
"cfg-if",
|
"cfg-if",
|
||||||
"cfg_aliases",
|
"cfg_aliases",
|
||||||
"libc",
|
"libc",
|
||||||
|
"memoffset",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -1524,6 +1535,12 @@ dependencies = [
|
||||||
"untrusted",
|
"untrusted",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "sd-notify"
|
||||||
|
version = "0.4.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "4646d6f919800cd25c50edb49438a1381e2cd4833c027e75e8897981c50b8b5e"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "serde"
|
name = "serde"
|
||||||
version = "1.0.208"
|
version = "1.0.208"
|
||||||
|
|
|
@ -16,10 +16,11 @@ futures-util = "0.3.30"
|
||||||
indoc = "2.0.5"
|
indoc = "2.0.5"
|
||||||
itertools = "0.13.0"
|
itertools = "0.13.0"
|
||||||
log = "0.4.22"
|
log = "0.4.22"
|
||||||
nix = { version = "0.29.0", features = ["fs", "process", "user"] }
|
nix = { version = "0.29.0", features = ["fs", "process", "socket", "user"] }
|
||||||
prettytable = "0.10.0"
|
prettytable = "0.10.0"
|
||||||
rand = "0.8.5"
|
rand = "0.8.5"
|
||||||
ratatui = { version = "0.28.0", optional = true }
|
ratatui = { version = "0.28.0", optional = true }
|
||||||
|
sd-notify = "0.4.2"
|
||||||
serde = "1.0.208"
|
serde = "1.0.208"
|
||||||
serde_json = { version = "1.0.125", features = ["preserve_order"] }
|
serde_json = { version = "1.0.125", features = ["preserve_order"] }
|
||||||
sqlx = { version = "0.8.0", features = ["runtime-tokio", "mysql", "tls-rustls"] }
|
sqlx = { version = "0.8.0", features = ["runtime-tokio", "mysql", "tls-rustls"] }
|
||||||
|
|
|
@ -6,15 +6,10 @@ use std::os::unix::net::UnixStream as StdUnixStream;
|
||||||
use tokio::net::UnixStream as TokioUnixStream;
|
use tokio::net::UnixStream as TokioUnixStream;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
core::{
|
core::common::{UnixUser, DEFAULT_CONFIG_PATH, DEFAULT_SOCKET_PATH},
|
||||||
bootstrap::authenticated_unix_socket::client_authenticate,
|
|
||||||
common::{UnixUser, DEFAULT_CONFIG_PATH, DEFAULT_SOCKET_PATH},
|
|
||||||
},
|
|
||||||
server::{config::read_config_form_path, server_loop::handle_requests_for_single_session},
|
server::{config::read_config_form_path, server_loop::handle_requests_for_single_session},
|
||||||
};
|
};
|
||||||
|
|
||||||
pub mod authenticated_unix_socket;
|
|
||||||
|
|
||||||
// TODO: this function is security critical, it should be integration tested
|
// TODO: this function is security critical, it should be integration tested
|
||||||
// in isolation.
|
// in isolation.
|
||||||
/// Drop privileges to the real user and group of the process.
|
/// Drop privileges to the real user and group of the process.
|
||||||
|
@ -57,25 +52,11 @@ pub fn bootstrap_server_connection_and_drop_privileges(
|
||||||
|
|
||||||
log::debug!("Starting the server connection bootstrap process");
|
log::debug!("Starting the server connection bootstrap process");
|
||||||
|
|
||||||
let (socket, do_authenticate) = bootstrap_server_connection(server_socket_path, config_path)?;
|
let socket = bootstrap_server_connection(server_socket_path, config_path)?;
|
||||||
|
|
||||||
drop_privs()?;
|
drop_privs()?;
|
||||||
|
|
||||||
let result: anyhow::Result<StdUnixStream> = if do_authenticate {
|
|
||||||
tokio::runtime::Builder::new_current_thread()
|
|
||||||
.enable_all()
|
|
||||||
.build()
|
|
||||||
.unwrap()
|
|
||||||
.block_on(async {
|
|
||||||
let mut socket = TokioUnixStream::from_std(socket)?;
|
|
||||||
client_authenticate(&mut socket, None).await?;
|
|
||||||
Ok(socket.into_std()?)
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
Ok(socket)
|
Ok(socket)
|
||||||
};
|
|
||||||
|
|
||||||
result
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Inner function for [`bootstrap_server_connection_and_drop_privileges`].
|
/// Inner function for [`bootstrap_server_connection_and_drop_privileges`].
|
||||||
|
@ -83,12 +64,12 @@ pub fn bootstrap_server_connection_and_drop_privileges(
|
||||||
fn bootstrap_server_connection(
|
fn bootstrap_server_connection(
|
||||||
socket_path: Option<PathBuf>,
|
socket_path: Option<PathBuf>,
|
||||||
config_path: Option<PathBuf>,
|
config_path: Option<PathBuf>,
|
||||||
) -> anyhow::Result<(StdUnixStream, bool)> {
|
) -> anyhow::Result<StdUnixStream> {
|
||||||
// TODO: ensure this is both readable and writable
|
// TODO: ensure this is both readable and writable
|
||||||
if let Some(socket_path) = socket_path {
|
if let Some(socket_path) = socket_path {
|
||||||
log::debug!("Connecting to socket at {:?}", socket_path);
|
log::debug!("Connecting to socket at {:?}", socket_path);
|
||||||
return match StdUnixStream::connect(socket_path) {
|
return match StdUnixStream::connect(socket_path) {
|
||||||
Ok(socket) => Ok((socket, true)),
|
Ok(socket) => Ok(socket),
|
||||||
Err(e) => match e.kind() {
|
Err(e) => match e.kind() {
|
||||||
std::io::ErrorKind::NotFound => Err(anyhow::anyhow!("Socket not found")),
|
std::io::ErrorKind::NotFound => Err(anyhow::anyhow!("Socket not found")),
|
||||||
std::io::ErrorKind::PermissionDenied => Err(anyhow::anyhow!("Permission denied")),
|
std::io::ErrorKind::PermissionDenied => Err(anyhow::anyhow!("Permission denied")),
|
||||||
|
@ -103,12 +84,13 @@ fn bootstrap_server_connection(
|
||||||
}
|
}
|
||||||
|
|
||||||
log::debug!("Starting server with config at {:?}", config_path);
|
log::debug!("Starting server with config at {:?}", config_path);
|
||||||
return invoke_server_with_config(config_path).map(|socket| (socket, false));
|
return invoke_server_with_config(config_path);
|
||||||
}
|
}
|
||||||
|
|
||||||
if fs::metadata(DEFAULT_SOCKET_PATH).is_ok() {
|
if fs::metadata(DEFAULT_SOCKET_PATH).is_ok() {
|
||||||
|
log::debug!("Connecting to default socket at {:?}", DEFAULT_SOCKET_PATH);
|
||||||
return match StdUnixStream::connect(DEFAULT_SOCKET_PATH) {
|
return match StdUnixStream::connect(DEFAULT_SOCKET_PATH) {
|
||||||
Ok(socket) => Ok((socket, true)),
|
Ok(socket) => Ok(socket),
|
||||||
Err(e) => match e.kind() {
|
Err(e) => match e.kind() {
|
||||||
std::io::ErrorKind::NotFound => Err(anyhow::anyhow!("Socket not found")),
|
std::io::ErrorKind::NotFound => Err(anyhow::anyhow!("Socket not found")),
|
||||||
std::io::ErrorKind::PermissionDenied => Err(anyhow::anyhow!("Permission denied")),
|
std::io::ErrorKind::PermissionDenied => Err(anyhow::anyhow!("Permission denied")),
|
||||||
|
@ -119,7 +101,8 @@ fn bootstrap_server_connection(
|
||||||
|
|
||||||
let config_path = PathBuf::from(DEFAULT_CONFIG_PATH);
|
let config_path = PathBuf::from(DEFAULT_CONFIG_PATH);
|
||||||
if fs::metadata(&config_path).is_ok() {
|
if fs::metadata(&config_path).is_ok() {
|
||||||
return invoke_server_with_config(config_path).map(|socket| (socket, false));
|
log::debug!("Starting server with default config at {:?}", config_path);
|
||||||
|
return invoke_server_with_config(config_path);
|
||||||
}
|
}
|
||||||
|
|
||||||
anyhow::bail!("No socket path or config path provided, and no default socket or config found");
|
anyhow::bail!("No socket path or config path provided, and no default socket or config found");
|
||||||
|
|
|
@ -1,439 +0,0 @@
|
||||||
//! This module provides a way to authenticate a client uid to a server over a Unix socket.
|
|
||||||
//! This is needed so that the server can trust the client's uid, which it depends on to
|
|
||||||
//! make modifications for that user in the database. It is crucial that the server can trust
|
|
||||||
//! that the client is the unix user it claims to be.
|
|
||||||
//!
|
|
||||||
//! This works by having the client respond to a challenge on a socket that is verifiably owned
|
|
||||||
//! by the client. In more detailed steps, the following should happen:
|
|
||||||
//!
|
|
||||||
//! 1. Before initializing it's request, the client should open an "authentication" socket with permissions 644
|
|
||||||
//! and owned by the uid of the current user.
|
|
||||||
//! 2. The client opens a request to the server on the "normal" socket where the server is listening,
|
|
||||||
//! In this request, the client should include the following:
|
|
||||||
//! - The address of it's authentication socket
|
|
||||||
//! - The uid of the user currently using the client
|
|
||||||
//! - A challenge string that has been randomly generated
|
|
||||||
//! 3. The server validates the following:
|
|
||||||
//! - The address of the auth socket is valid
|
|
||||||
//! - The owner of the auth socket address is the same as the uid
|
|
||||||
//! 4. Server connects to the auth socket address and receives another challenge string.
|
|
||||||
//! The server should close the connection after receiving the challenge string.
|
|
||||||
//! 5. Server verifies that the challenge is the same as the one it originally received.
|
|
||||||
//! It responds to the client with an "Authenticated" message if the challenge matches,
|
|
||||||
//! or an error message if it does not.
|
|
||||||
//! 6. Client closes the authentication socket. Normal socket is used for communication.
|
|
||||||
//!
|
|
||||||
//! Note that the server can at any point in the process send an error message to the client
|
|
||||||
//! over it's initial connection, and the client should respond by closing the authentication
|
|
||||||
//! socket, it's connection to the normal socket, and reporting the error to the user.
|
|
||||||
//!
|
|
||||||
//! Also note that it is essential that the client does not send any sensitive information
|
|
||||||
//! over it's authentication socket, since it is readable by any user on the system.
|
|
||||||
|
|
||||||
// TODO: rewrite this so that it can be used with a normal std::os::unix::net::UnixStream
|
|
||||||
|
|
||||||
use std::os::unix::io::AsRawFd;
|
|
||||||
use std::path::{Path, PathBuf};
|
|
||||||
|
|
||||||
use async_bincode::{tokio::AsyncBincodeStream, AsyncDestination};
|
|
||||||
use derive_more::derive::{Display, Error};
|
|
||||||
use futures::{SinkExt, StreamExt};
|
|
||||||
use nix::{sys::stat, unistd::Uid};
|
|
||||||
use rand::distributions::Alphanumeric;
|
|
||||||
use rand::Rng;
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use tokio::net::{UnixListener, UnixStream};
|
|
||||||
use tokio_util::sync::CancellationToken;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
|
||||||
pub enum ClientRequest {
|
|
||||||
Initialize {
|
|
||||||
uid: u32,
|
|
||||||
challenge: u64,
|
|
||||||
auth_socket: String,
|
|
||||||
},
|
|
||||||
Cancel,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Display, Error)]
|
|
||||||
pub enum ServerResponse {
|
|
||||||
Authenticated,
|
|
||||||
ChallengeDidNotMatch,
|
|
||||||
ServerError(ServerError),
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: wrap more data into the errors
|
|
||||||
|
|
||||||
#[derive(Debug, Display, PartialEq, Serialize, Deserialize, Clone, Error)]
|
|
||||||
pub enum ServerError {
|
|
||||||
InvalidRequest,
|
|
||||||
UnableToReadPermissionsFromAuthSocket,
|
|
||||||
CouldNotConnectToAuthSocket,
|
|
||||||
AuthSocketClosedEarly,
|
|
||||||
UidMismatch,
|
|
||||||
ChallengeMismatch,
|
|
||||||
InvalidChallenge,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Display, Error)]
|
|
||||||
pub enum ClientError {
|
|
||||||
UnableToConnectToServer,
|
|
||||||
UnableToOpenAuthSocket,
|
|
||||||
UnableToConfigureAuthSocket,
|
|
||||||
AuthSocketClosedEarly,
|
|
||||||
UnableToCloseAuthSocket,
|
|
||||||
AuthenticationError,
|
|
||||||
UnableToParseServerResponse,
|
|
||||||
NoServerResponse,
|
|
||||||
ServerError(ServerError),
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn create_auth_socket(socket_addr: &PathBuf) -> Result<UnixListener, ClientError> {
|
|
||||||
let auth_socket =
|
|
||||||
UnixListener::bind(socket_addr).map_err(|_err| ClientError::UnableToOpenAuthSocket)?;
|
|
||||||
|
|
||||||
stat::fchmod(
|
|
||||||
auth_socket.as_raw_fd(),
|
|
||||||
stat::Mode::S_IRUSR | stat::Mode::S_IWUSR | stat::Mode::S_IRGRP,
|
|
||||||
)
|
|
||||||
.map_err(|_err| ClientError::UnableToConfigureAuthSocket)?;
|
|
||||||
|
|
||||||
Ok(auth_socket)
|
|
||||||
}
|
|
||||||
|
|
||||||
type ClientToServerStream<'a> =
|
|
||||||
AsyncBincodeStream<&'a mut UnixStream, ServerResponse, ClientRequest, AsyncDestination>;
|
|
||||||
type ServerToClientStream<'a> =
|
|
||||||
AsyncBincodeStream<&'a mut UnixStream, ClientRequest, ServerResponse, AsyncDestination>;
|
|
||||||
|
|
||||||
// TODO: make the challenge constant size and use socket directly, this is overkill
|
|
||||||
type AuthStream<'a> = AsyncBincodeStream<&'a mut UnixStream, u64, u64, AsyncDestination>;
|
|
||||||
|
|
||||||
// TODO: add timeout
|
|
||||||
|
|
||||||
// TODO: respect $XDG_RUNTIME_DIR and $TMPDIR
|
|
||||||
|
|
||||||
const AUTH_SOCKET_NAME: &str = "mysqladm-rs-cli-auth.sock";
|
|
||||||
|
|
||||||
pub async fn client_authenticate(
|
|
||||||
normal_socket: &mut UnixStream,
|
|
||||||
auth_socket_dir: Option<PathBuf>,
|
|
||||||
) -> Result<(), ClientError> {
|
|
||||||
let random_prefix: String = rand::thread_rng()
|
|
||||||
.sample_iter(&Alphanumeric)
|
|
||||||
.take(16)
|
|
||||||
.map(char::from)
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let socket_name = format!("{}-{}", random_prefix, AUTH_SOCKET_NAME);
|
|
||||||
|
|
||||||
let auth_socket_address = auth_socket_dir
|
|
||||||
.unwrap_or(std::env::temp_dir())
|
|
||||||
.join(socket_name);
|
|
||||||
|
|
||||||
client_authenticate_with_auth_socket_address(normal_socket, &auth_socket_address).await
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn client_authenticate_with_auth_socket_address(
|
|
||||||
normal_socket: &mut UnixStream,
|
|
||||||
auth_socket_address: &PathBuf,
|
|
||||||
) -> Result<(), ClientError> {
|
|
||||||
let auth_socket = create_auth_socket(auth_socket_address).await?;
|
|
||||||
|
|
||||||
let result =
|
|
||||||
client_authenticate_with_auth_socket(normal_socket, auth_socket, auth_socket_address).await;
|
|
||||||
|
|
||||||
std::fs::remove_file(auth_socket_address)
|
|
||||||
.map_err(|_err| ClientError::UnableToCloseAuthSocket)?;
|
|
||||||
|
|
||||||
result
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn client_authenticate_with_auth_socket(
|
|
||||||
normal_socket: &mut UnixStream,
|
|
||||||
auth_socket: UnixListener,
|
|
||||||
auth_socket_address: &Path,
|
|
||||||
) -> Result<(), ClientError> {
|
|
||||||
let challenge = rand::random::<u64>();
|
|
||||||
let uid = nix::unistd::getuid();
|
|
||||||
|
|
||||||
let mut normal_socket: ClientToServerStream =
|
|
||||||
AsyncBincodeStream::from(normal_socket).for_async();
|
|
||||||
|
|
||||||
let challenge_replier_cancellation_token = CancellationToken::new();
|
|
||||||
let challenge_replier_cancellation_token_clone = challenge_replier_cancellation_token.clone();
|
|
||||||
let challenge_replier_handle = tokio::spawn(async move {
|
|
||||||
loop {
|
|
||||||
tokio::select! {
|
|
||||||
socket = auth_socket.accept() =>
|
|
||||||
{
|
|
||||||
match socket {
|
|
||||||
Ok((mut conn, _addr)) => {
|
|
||||||
let mut stream: AuthStream = AsyncBincodeStream::from(&mut conn).for_async();
|
|
||||||
stream.send(challenge).await.ok();
|
|
||||||
stream.close().await.ok();
|
|
||||||
}
|
|
||||||
Err(_err) => return Err(ClientError::AuthSocketClosedEarly),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = challenge_replier_cancellation_token_clone.cancelled() => {
|
|
||||||
break Ok(());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
let client_hello = ClientRequest::Initialize {
|
|
||||||
uid: uid.into(),
|
|
||||||
challenge,
|
|
||||||
auth_socket: auth_socket_address
|
|
||||||
.to_str()
|
|
||||||
.ok_or(ClientError::UnableToConfigureAuthSocket)?
|
|
||||||
.to_owned(),
|
|
||||||
};
|
|
||||||
|
|
||||||
normal_socket
|
|
||||||
.send(client_hello)
|
|
||||||
.await
|
|
||||||
.map_err(|err| match *err {
|
|
||||||
bincode::ErrorKind::Io(_err) => ClientError::UnableToConnectToServer,
|
|
||||||
_ => ClientError::NoServerResponse,
|
|
||||||
})?;
|
|
||||||
|
|
||||||
match normal_socket.next().await {
|
|
||||||
Some(Ok(ServerResponse::Authenticated)) => {}
|
|
||||||
Some(Ok(ServerResponse::ChallengeDidNotMatch)) => {
|
|
||||||
return Err(ClientError::AuthenticationError)
|
|
||||||
}
|
|
||||||
Some(Ok(ServerResponse::ServerError(err))) => return Err(ClientError::ServerError(err)),
|
|
||||||
Some(Err(err)) => match *err {
|
|
||||||
bincode::ErrorKind::Io(_err) => return Err(ClientError::NoServerResponse),
|
|
||||||
_ => return Err(ClientError::UnableToParseServerResponse),
|
|
||||||
},
|
|
||||||
None => return Err(ClientError::NoServerResponse),
|
|
||||||
}
|
|
||||||
|
|
||||||
challenge_replier_cancellation_token.cancel();
|
|
||||||
challenge_replier_handle.await.ok();
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
macro_rules! report_server_error_and_return {
|
|
||||||
($normal_socket:expr, $err:expr) => {{
|
|
||||||
$normal_socket
|
|
||||||
.send(ServerResponse::ServerError($err))
|
|
||||||
.await
|
|
||||||
.ok();
|
|
||||||
return Err($err);
|
|
||||||
}};
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn server_authenticate(normal_socket: &mut UnixStream) -> Result<Uid, ServerError> {
|
|
||||||
_server_authenticate(normal_socket, None).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn _server_authenticate(
|
|
||||||
normal_socket: &mut UnixStream,
|
|
||||||
unix_user_uid: Option<u32>,
|
|
||||||
) -> Result<Uid, ServerError> {
|
|
||||||
let mut normal_socket: ServerToClientStream =
|
|
||||||
AsyncBincodeStream::from(normal_socket).for_async();
|
|
||||||
|
|
||||||
let (uid, challenge, auth_socket) = match normal_socket.next().await {
|
|
||||||
Some(Ok(ClientRequest::Initialize {
|
|
||||||
uid,
|
|
||||||
challenge,
|
|
||||||
auth_socket,
|
|
||||||
})) => (uid, challenge, auth_socket),
|
|
||||||
// TODO: more granular errros
|
|
||||||
_ => report_server_error_and_return!(normal_socket, ServerError::InvalidRequest),
|
|
||||||
};
|
|
||||||
|
|
||||||
let auth_socket_uid = match unix_user_uid {
|
|
||||||
Some(uid) => uid,
|
|
||||||
None => match stat::stat(auth_socket.as_str()) {
|
|
||||||
Ok(stat) => stat.st_uid,
|
|
||||||
Err(_err) => report_server_error_and_return!(
|
|
||||||
normal_socket,
|
|
||||||
ServerError::UnableToReadPermissionsFromAuthSocket
|
|
||||||
),
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
if uid != auth_socket_uid {
|
|
||||||
report_server_error_and_return!(normal_socket, ServerError::UidMismatch);
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut authenticated_unix_socket = match UnixStream::connect(auth_socket).await {
|
|
||||||
Ok(socket) => socket,
|
|
||||||
Err(_err) => {
|
|
||||||
report_server_error_and_return!(normal_socket, ServerError::CouldNotConnectToAuthSocket)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let mut authenticated_unix_socket: AuthStream =
|
|
||||||
AsyncBincodeStream::from(&mut authenticated_unix_socket).for_async();
|
|
||||||
|
|
||||||
let challenge_2 = match authenticated_unix_socket.next().await {
|
|
||||||
Some(Ok(challenge)) => challenge,
|
|
||||||
Some(Err(_)) => {
|
|
||||||
report_server_error_and_return!(normal_socket, ServerError::InvalidChallenge)
|
|
||||||
}
|
|
||||||
None => report_server_error_and_return!(normal_socket, ServerError::AuthSocketClosedEarly),
|
|
||||||
};
|
|
||||||
|
|
||||||
authenticated_unix_socket.close().await.ok();
|
|
||||||
|
|
||||||
if challenge != challenge_2 {
|
|
||||||
normal_socket
|
|
||||||
.send(ServerResponse::ChallengeDidNotMatch)
|
|
||||||
.await
|
|
||||||
.ok();
|
|
||||||
return Err(ServerError::ChallengeMismatch);
|
|
||||||
}
|
|
||||||
|
|
||||||
normal_socket.send(ServerResponse::Authenticated).await.ok();
|
|
||||||
|
|
||||||
Ok(Uid::from_raw(uid))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod test {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
use std::time::Duration;
|
|
||||||
use tokio::time::sleep;
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_valid_authentication() {
|
|
||||||
let (mut client, mut server) = UnixStream::pair().unwrap();
|
|
||||||
|
|
||||||
let client_handle =
|
|
||||||
tokio::spawn(async move { client_authenticate(&mut client, None).await });
|
|
||||||
|
|
||||||
let server_handle = tokio::spawn(async move { server_authenticate(&mut server).await });
|
|
||||||
|
|
||||||
client_handle.await.unwrap().unwrap();
|
|
||||||
server_handle.await.unwrap().unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_ensure_auth_socket_does_not_exist() {
|
|
||||||
let (mut client, mut server) = UnixStream::pair().unwrap();
|
|
||||||
|
|
||||||
let client_handle = tokio::spawn(async move {
|
|
||||||
client_authenticate_with_auth_socket_address(
|
|
||||||
&mut client,
|
|
||||||
&PathBuf::from("/tmp/test_auth_socket_does_not_exist.sock"),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
});
|
|
||||||
|
|
||||||
let server_handle = tokio::spawn(async move { server_authenticate(&mut server).await });
|
|
||||||
|
|
||||||
client_handle.await.unwrap().unwrap();
|
|
||||||
server_handle.await.unwrap().unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_uid_mismatch() {
|
|
||||||
let (mut client, mut server) = UnixStream::pair().unwrap();
|
|
||||||
|
|
||||||
let client_handle = tokio::spawn(async move {
|
|
||||||
let err = client_authenticate(&mut client, None).await;
|
|
||||||
assert_eq!(err, Err(ClientError::ServerError(ServerError::UidMismatch)));
|
|
||||||
});
|
|
||||||
|
|
||||||
let server_handle = tokio::spawn(async move {
|
|
||||||
let uid: u32 = nix::unistd::getuid().into();
|
|
||||||
let err = _server_authenticate(&mut server, Some(uid + 1)).await;
|
|
||||||
assert_eq!(err, Err(ServerError::UidMismatch));
|
|
||||||
});
|
|
||||||
|
|
||||||
client_handle.await.unwrap();
|
|
||||||
server_handle.await.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_snooping_connection() {
|
|
||||||
let (mut client, mut server) = UnixStream::pair().unwrap();
|
|
||||||
|
|
||||||
let socket_path = std::env::temp_dir().join("socket_to_snoop.sock");
|
|
||||||
let socket_path_clone = socket_path.clone();
|
|
||||||
let client_handle = tokio::spawn(async move {
|
|
||||||
client_authenticate_with_auth_socket_address(&mut client, &socket_path_clone).await
|
|
||||||
});
|
|
||||||
|
|
||||||
for i in 0..100 {
|
|
||||||
if socket_path.exists() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
sleep(Duration::from_millis(10)).await;
|
|
||||||
|
|
||||||
if i == 99 {
|
|
||||||
panic!("Socket not created after 1 second, assuming test failure");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut snooper = UnixStream::connect(socket_path.clone()).await.unwrap();
|
|
||||||
let mut snooper: AuthStream = AsyncBincodeStream::from(&mut snooper).for_async();
|
|
||||||
let message = snooper.next().await.unwrap().unwrap();
|
|
||||||
|
|
||||||
let mut other_snooper = UnixStream::connect(socket_path.clone()).await.unwrap();
|
|
||||||
let mut other_snooper: AuthStream =
|
|
||||||
AsyncBincodeStream::from(&mut other_snooper).for_async();
|
|
||||||
let other_message = other_snooper.next().await.unwrap().unwrap();
|
|
||||||
|
|
||||||
assert_eq!(message, other_message);
|
|
||||||
|
|
||||||
let third_snooper_handle = tokio::spawn(async move {
|
|
||||||
let mut third_snooper = UnixStream::connect(socket_path.clone()).await.unwrap();
|
|
||||||
let mut third_snooper: AuthStream =
|
|
||||||
AsyncBincodeStream::from(&mut third_snooper).for_async();
|
|
||||||
// NOTE: Should hang
|
|
||||||
third_snooper.send(1234).await.unwrap()
|
|
||||||
});
|
|
||||||
|
|
||||||
sleep(Duration::from_millis(10)).await;
|
|
||||||
|
|
||||||
let server_handle = tokio::spawn(async move { server_authenticate(&mut server).await });
|
|
||||||
|
|
||||||
client_handle.await.unwrap().unwrap();
|
|
||||||
server_handle.await.unwrap().unwrap();
|
|
||||||
|
|
||||||
third_snooper_handle.abort();
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_dead_server() {
|
|
||||||
let (mut client, server) = UnixStream::pair().unwrap();
|
|
||||||
std::mem::drop(server);
|
|
||||||
|
|
||||||
let client_handle = tokio::spawn(async move {
|
|
||||||
let err = client_authenticate(&mut client, None).await;
|
|
||||||
assert_eq!(err, Err(ClientError::UnableToConnectToServer));
|
|
||||||
});
|
|
||||||
|
|
||||||
client_handle.await.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_no_response_from_server() {
|
|
||||||
let (mut client, server) = UnixStream::pair().unwrap();
|
|
||||||
|
|
||||||
let client_handle = tokio::spawn(async move {
|
|
||||||
let err = client_authenticate(&mut client, None).await;
|
|
||||||
assert_eq!(err, Err(ClientError::NoServerResponse));
|
|
||||||
});
|
|
||||||
|
|
||||||
sleep(Duration::from_millis(200)).await;
|
|
||||||
|
|
||||||
std::mem::drop(server);
|
|
||||||
|
|
||||||
client_handle.await.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Test challenge mismatch
|
|
||||||
// TODO: Test invoking server with junk data
|
|
||||||
}
|
|
|
@ -7,7 +7,6 @@ use clap::Parser;
|
||||||
use std::os::unix::net::UnixStream as StdUnixStream;
|
use std::os::unix::net::UnixStream as StdUnixStream;
|
||||||
use tokio::net::UnixStream as TokioUnixStream;
|
use tokio::net::UnixStream as TokioUnixStream;
|
||||||
|
|
||||||
use crate::core::bootstrap::authenticated_unix_socket;
|
|
||||||
use crate::core::common::UnixUser;
|
use crate::core::common::UnixUser;
|
||||||
use crate::server::config::read_config_from_path_with_arg_overrides;
|
use crate::server::config::read_config_from_path_with_arg_overrides;
|
||||||
use crate::server::server_loop::listen_for_incoming_connections;
|
use crate::server::server_loop::listen_for_incoming_connections;
|
||||||
|
@ -41,10 +40,6 @@ pub async fn handle_command(
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
let config = read_config_from_path_with_arg_overrides(config_path, args.config_overrides)?;
|
let config = read_config_from_path_with_arg_overrides(config_path, args.config_overrides)?;
|
||||||
|
|
||||||
// if let Err(e) = &result {
|
|
||||||
// eprintln!("{}", e);
|
|
||||||
// }
|
|
||||||
|
|
||||||
match args.subcmd {
|
match args.subcmd {
|
||||||
ServerCommand::Listen => listen_for_incoming_connections(socket_path, config).await,
|
ServerCommand::Listen => listen_for_incoming_connections(socket_path, config).await,
|
||||||
ServerCommand::SocketActivate => socket_activate(config).await,
|
ServerCommand::SocketActivate => socket_activate(config).await,
|
||||||
|
@ -53,23 +48,26 @@ pub async fn handle_command(
|
||||||
|
|
||||||
async fn socket_activate(config: ServerConfig) -> anyhow::Result<()> {
|
async fn socket_activate(config: ServerConfig) -> anyhow::Result<()> {
|
||||||
// TODO: allow getting socket path from other socket activation sources
|
// TODO: allow getting socket path from other socket activation sources
|
||||||
let mut conn = get_socket_from_systemd().await?;
|
let conn = get_socket_from_systemd().await?;
|
||||||
let uid = authenticated_unix_socket::server_authenticate(&mut conn).await?;
|
let uid = conn.peer_cred()?.uid();
|
||||||
let unix_user = UnixUser::from_uid(uid.into())?;
|
let unix_user = UnixUser::from_uid(uid)?;
|
||||||
|
|
||||||
|
log::info!("Accepted connection from {}", unix_user.username);
|
||||||
|
|
||||||
|
sd_notify::notify(true, &[sd_notify::NotifyState::Ready]).ok();
|
||||||
|
|
||||||
handle_requests_for_single_session(conn, &unix_user, &config).await?;
|
handle_requests_for_single_session(conn, &unix_user, &config).await?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_socket_from_systemd() -> anyhow::Result<TokioUnixStream> {
|
async fn get_socket_from_systemd() -> anyhow::Result<TokioUnixStream> {
|
||||||
let fd = std::env::var("LISTEN_FDS")
|
let fd = sd_notify::listen_fds()
|
||||||
.context("LISTEN_FDS not set, not running under systemd?")?
|
.context("Failed to get file descriptors from systemd")?
|
||||||
.parse::<i32>()
|
.next()
|
||||||
.context("Failed to parse LISTEN_FDS")?;
|
.context("No file descriptors received from systemd")?;
|
||||||
|
|
||||||
if fd != 1 {
|
log::debug!("Received file descriptor from systemd: {}", fd);
|
||||||
return Err(anyhow::anyhow!("Unexpected LISTEN_FDS value: {}", fd));
|
|
||||||
}
|
|
||||||
|
|
||||||
let std_unix_stream = unsafe { StdUnixStream::from_raw_fd(fd) };
|
let std_unix_stream = unsafe { StdUnixStream::from_raw_fd(fd) };
|
||||||
let socket = TokioUnixStream::from_std(std_unix_stream)?;
|
let socket = TokioUnixStream::from_std(std_unix_stream)?;
|
||||||
|
|
|
@ -83,6 +83,8 @@ pub fn read_config_from_path_with_arg_overrides(
|
||||||
pub fn read_config_form_path(config_path: Option<PathBuf>) -> anyhow::Result<ServerConfig> {
|
pub fn read_config_form_path(config_path: Option<PathBuf>) -> anyhow::Result<ServerConfig> {
|
||||||
let config_path = config_path.unwrap_or_else(|| PathBuf::from(DEFAULT_CONFIG_PATH));
|
let config_path = config_path.unwrap_or_else(|| PathBuf::from(DEFAULT_CONFIG_PATH));
|
||||||
|
|
||||||
|
log::debug!("Reading config from {:?}", &config_path);
|
||||||
|
|
||||||
fs::read_to_string(&config_path)
|
fs::read_to_string(&config_path)
|
||||||
.context(format!(
|
.context(format!(
|
||||||
"Failed to read config file from {:?}",
|
"Failed to read config file from {:?}",
|
||||||
|
@ -99,6 +101,13 @@ pub fn read_config_form_path(config_path: Option<PathBuf>) -> anyhow::Result<Ser
|
||||||
pub async fn create_mysql_connection_from_config(
|
pub async fn create_mysql_connection_from_config(
|
||||||
config: &MysqlConfig,
|
config: &MysqlConfig,
|
||||||
) -> anyhow::Result<MySqlConnection> {
|
) -> anyhow::Result<MySqlConnection> {
|
||||||
|
let mut display_config = config.clone();
|
||||||
|
"<REDACTED>".clone_into(&mut display_config.password);
|
||||||
|
log::debug!(
|
||||||
|
"Connecting to MySQL server with parameters: {:#?}",
|
||||||
|
display_config
|
||||||
|
);
|
||||||
|
|
||||||
match tokio::time::timeout(
|
match tokio::time::timeout(
|
||||||
Duration::from_secs(config.timeout.unwrap_or(DEFAULT_TIMEOUT)),
|
Duration::from_secs(config.timeout.unwrap_or(DEFAULT_TIMEOUT)),
|
||||||
MySqlConnectOptions::new()
|
MySqlConnectOptions::new()
|
||||||
|
|
|
@ -9,7 +9,6 @@ use sqlx::MySqlConnection;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
core::{
|
core::{
|
||||||
bootstrap::authenticated_unix_socket,
|
|
||||||
common::{UnixUser, DEFAULT_SOCKET_PATH},
|
common::{UnixUser, DEFAULT_SOCKET_PATH},
|
||||||
protocol::request_response::{
|
protocol::request_response::{
|
||||||
create_server_to_client_message_stream, Request, Response, ServerToClientMessageStream,
|
create_server_to_client_message_stream, Request, Response, ServerToClientMessageStream,
|
||||||
|
@ -44,11 +43,12 @@ pub async fn listen_for_incoming_connections(
|
||||||
|
|
||||||
let parent_directory = socket_path.parent().unwrap();
|
let parent_directory = socket_path.parent().unwrap();
|
||||||
if !parent_directory.exists() {
|
if !parent_directory.exists() {
|
||||||
println!("Creating directory {:?}", parent_directory);
|
log::debug!("Creating directory {:?}", parent_directory);
|
||||||
fs::create_dir_all(parent_directory)?;
|
fs::create_dir_all(parent_directory)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
println!("Listening on {:?}", socket_path);
|
log::info!("Listening on socket {:?}", socket_path);
|
||||||
|
|
||||||
match fs::remove_file(socket_path.as_path()) {
|
match fs::remove_file(socket_path.as_path()) {
|
||||||
Ok(_) => {}
|
Ok(_) => {}
|
||||||
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
|
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
|
||||||
|
@ -57,16 +57,13 @@ pub async fn listen_for_incoming_connections(
|
||||||
|
|
||||||
let listener = UnixListener::bind(socket_path)?;
|
let listener = UnixListener::bind(socket_path)?;
|
||||||
|
|
||||||
|
sd_notify::notify(true, &[sd_notify::NotifyState::Ready]).ok();
|
||||||
|
|
||||||
while let Ok((mut conn, _addr)) = listener.accept().await {
|
while let Ok((mut conn, _addr)) = listener.accept().await {
|
||||||
let uid = match authenticated_unix_socket::server_authenticate(&mut conn).await {
|
let uid = conn.peer_cred()?.uid();
|
||||||
Ok(uid) => uid,
|
log::trace!("Accepted connection from uid {}", uid);
|
||||||
Err(e) => {
|
|
||||||
eprintln!("Failed to authenticate client: {}", e);
|
let unix_user = match UnixUser::from_uid(uid) {
|
||||||
conn.shutdown().await?;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let unix_user = match UnixUser::from_uid(uid.into()) {
|
|
||||||
Ok(user) => user,
|
Ok(user) => user,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
eprintln!("Failed to get UnixUser from uid: {}", e);
|
eprintln!("Failed to get UnixUser from uid: {}", e);
|
||||||
|
@ -74,6 +71,9 @@ pub async fn listen_for_incoming_connections(
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
log::info!("Accepted connection from {}", unix_user.username);
|
||||||
|
|
||||||
match handle_requests_for_single_session(conn, &unix_user, &config).await {
|
match handle_requests_for_single_session(conn, &unix_user, &config).await {
|
||||||
Ok(_) => {}
|
Ok(_) => {}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
|
@ -92,6 +92,7 @@ pub async fn handle_requests_for_single_session(
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
let message_stream = create_server_to_client_message_stream(socket);
|
let message_stream = create_server_to_client_message_stream(socket);
|
||||||
let mut db_connection = create_mysql_connection_from_config(&config.mysql).await?;
|
let mut db_connection = create_mysql_connection_from_config(&config.mysql).await?;
|
||||||
|
log::debug!("Successfully connected to database");
|
||||||
|
|
||||||
let result = handle_requests_for_single_session_with_db_connection(
|
let result = handle_requests_for_single_session_with_db_connection(
|
||||||
message_stream,
|
message_stream,
|
||||||
|
@ -128,6 +129,8 @@ pub async fn handle_requests_for_single_session_with_db_connection(
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
log::trace!("Received request: {:?}", request);
|
||||||
|
|
||||||
match request {
|
match request {
|
||||||
Request::CreateDatabases(databases_names) => {
|
Request::CreateDatabases(databases_names) => {
|
||||||
let result = create_databases(databases_names, unix_user, db_connection).await;
|
let result = create_databases(databases_names, unix_user, db_connection).await;
|
||||||
|
|
|
@ -26,9 +26,17 @@ pub(super) async fn unsafe_database_exists(
|
||||||
sqlx::query("SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?")
|
sqlx::query("SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?")
|
||||||
.bind(database_name)
|
.bind(database_name)
|
||||||
.fetch_optional(connection)
|
.fetch_optional(connection)
|
||||||
.await?;
|
.await;
|
||||||
|
|
||||||
Ok(result.is_some())
|
if let Err(err) = &result {
|
||||||
|
log::error!(
|
||||||
|
"Failed to check if database '{}' exists: {:?}",
|
||||||
|
&database_name,
|
||||||
|
err
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(result?.is_some())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn create_databases(
|
pub async fn create_databases(
|
||||||
|
@ -80,6 +88,10 @@ pub async fn create_databases(
|
||||||
.map(|_| ())
|
.map(|_| ())
|
||||||
.map_err(|err| CreateDatabaseError::MySqlError(err.to_string()));
|
.map_err(|err| CreateDatabaseError::MySqlError(err.to_string()));
|
||||||
|
|
||||||
|
if let Err(err) = &result {
|
||||||
|
log::error!("Failed to create database '{}': {:?}", &database_name, err);
|
||||||
|
}
|
||||||
|
|
||||||
results.insert(database_name, result);
|
results.insert(database_name, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -135,6 +147,10 @@ pub async fn drop_databases(
|
||||||
.map(|_| ())
|
.map(|_| ())
|
||||||
.map_err(|err| DropDatabaseError::MySqlError(err.to_string()));
|
.map_err(|err| DropDatabaseError::MySqlError(err.to_string()));
|
||||||
|
|
||||||
|
if let Err(err) = &result {
|
||||||
|
log::error!("Failed to drop database '{}': {:?}", &database_name, err);
|
||||||
|
}
|
||||||
|
|
||||||
results.insert(database_name, result);
|
results.insert(database_name, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -145,7 +161,7 @@ pub async fn list_databases_for_user(
|
||||||
unix_user: &UnixUser,
|
unix_user: &UnixUser,
|
||||||
connection: &mut MySqlConnection,
|
connection: &mut MySqlConnection,
|
||||||
) -> Result<Vec<String>, ListDatabasesError> {
|
) -> Result<Vec<String>, ListDatabasesError> {
|
||||||
sqlx::query(
|
let result = sqlx::query(
|
||||||
r#"
|
r#"
|
||||||
SELECT `SCHEMA_NAME` AS `database`
|
SELECT `SCHEMA_NAME` AS `database`
|
||||||
FROM `information_schema`.`SCHEMATA`
|
FROM `information_schema`.`SCHEMATA`
|
||||||
|
@ -161,5 +177,15 @@ pub async fn list_databases_for_user(
|
||||||
.map(|row| row.try_get::<String, _>("database"))
|
.map(|row| row.try_get::<String, _>("database"))
|
||||||
.collect::<Result<Vec<String>, sqlx::Error>>()
|
.collect::<Result<Vec<String>, sqlx::Error>>()
|
||||||
})
|
})
|
||||||
.map_err(|err| ListDatabasesError::MySqlError(err.to_string()))
|
.map_err(|err| ListDatabasesError::MySqlError(err.to_string()));
|
||||||
|
|
||||||
|
if let Err(err) = &result {
|
||||||
|
log::error!(
|
||||||
|
"Failed to list databases for user '{}': {:?}",
|
||||||
|
unix_user.username,
|
||||||
|
err
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
result
|
||||||
}
|
}
|
||||||
|
|
|
@ -136,7 +136,7 @@ async fn unsafe_get_database_privileges(
|
||||||
database_name: &str,
|
database_name: &str,
|
||||||
connection: &mut MySqlConnection,
|
connection: &mut MySqlConnection,
|
||||||
) -> Result<Vec<DatabasePrivilegeRow>, sqlx::Error> {
|
) -> Result<Vec<DatabasePrivilegeRow>, sqlx::Error> {
|
||||||
sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
|
let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
|
||||||
"SELECT {} FROM `db` WHERE `db` = ?",
|
"SELECT {} FROM `db` WHERE `db` = ?",
|
||||||
DATABASE_PRIVILEGE_FIELDS
|
DATABASE_PRIVILEGE_FIELDS
|
||||||
.iter()
|
.iter()
|
||||||
|
@ -145,7 +145,17 @@ async fn unsafe_get_database_privileges(
|
||||||
))
|
))
|
||||||
.bind(database_name)
|
.bind(database_name)
|
||||||
.fetch_all(connection)
|
.fetch_all(connection)
|
||||||
.await
|
.await;
|
||||||
|
|
||||||
|
if let Err(e) = &result {
|
||||||
|
log::error!(
|
||||||
|
"Failed to get database privileges for '{}': {}",
|
||||||
|
&database_name,
|
||||||
|
e
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOTE: this function is unsafe because it does no input validation.
|
// NOTE: this function is unsafe because it does no input validation.
|
||||||
|
@ -155,7 +165,7 @@ pub async fn unsafe_get_database_privileges_for_db_user_pair(
|
||||||
user_name: &str,
|
user_name: &str,
|
||||||
connection: &mut MySqlConnection,
|
connection: &mut MySqlConnection,
|
||||||
) -> Result<Option<DatabasePrivilegeRow>, sqlx::Error> {
|
) -> Result<Option<DatabasePrivilegeRow>, sqlx::Error> {
|
||||||
sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
|
let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
|
||||||
"SELECT {} FROM `db` WHERE `db` = ? AND `user` = ?",
|
"SELECT {} FROM `db` WHERE `db` = ? AND `user` = ?",
|
||||||
DATABASE_PRIVILEGE_FIELDS
|
DATABASE_PRIVILEGE_FIELDS
|
||||||
.iter()
|
.iter()
|
||||||
|
@ -165,7 +175,18 @@ pub async fn unsafe_get_database_privileges_for_db_user_pair(
|
||||||
.bind(database_name)
|
.bind(database_name)
|
||||||
.bind(user_name)
|
.bind(user_name)
|
||||||
.fetch_optional(connection)
|
.fetch_optional(connection)
|
||||||
.await
|
.await;
|
||||||
|
|
||||||
|
if let Err(e) = &result {
|
||||||
|
log::error!(
|
||||||
|
"Failed to get database privileges for '{}.{}': {}",
|
||||||
|
&database_name,
|
||||||
|
&user_name,
|
||||||
|
e
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_databases_privilege_data(
|
pub async fn get_databases_privilege_data(
|
||||||
|
@ -220,7 +241,7 @@ pub async fn get_all_database_privileges(
|
||||||
unix_user: &UnixUser,
|
unix_user: &UnixUser,
|
||||||
connection: &mut MySqlConnection,
|
connection: &mut MySqlConnection,
|
||||||
) -> GetAllDatabasesPrivilegeData {
|
) -> GetAllDatabasesPrivilegeData {
|
||||||
sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
|
let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
|
||||||
indoc! {r#"
|
indoc! {r#"
|
||||||
SELECT {} FROM `db` WHERE `db` IN
|
SELECT {} FROM `db` WHERE `db` IN
|
||||||
(SELECT DISTINCT `SCHEMA_NAME` AS `database`
|
(SELECT DISTINCT `SCHEMA_NAME` AS `database`
|
||||||
|
@ -236,14 +257,20 @@ pub async fn get_all_database_privileges(
|
||||||
.bind(create_user_group_matching_regex(unix_user))
|
.bind(create_user_group_matching_regex(unix_user))
|
||||||
.fetch_all(connection)
|
.fetch_all(connection)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| GetAllDatabasesPrivilegeDataError::MySqlError(e.to_string()))
|
.map_err(|e| GetAllDatabasesPrivilegeDataError::MySqlError(e.to_string()));
|
||||||
|
|
||||||
|
if let Err(e) = &result {
|
||||||
|
log::error!("Failed to get all database privileges: {:?}", e);
|
||||||
|
}
|
||||||
|
|
||||||
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn unsafe_apply_privilege_diff(
|
async fn unsafe_apply_privilege_diff(
|
||||||
database_privilege_diff: &DatabasePrivilegesDiff,
|
database_privilege_diff: &DatabasePrivilegesDiff,
|
||||||
connection: &mut MySqlConnection,
|
connection: &mut MySqlConnection,
|
||||||
) -> Result<(), sqlx::Error> {
|
) -> Result<(), sqlx::Error> {
|
||||||
match database_privilege_diff {
|
let result = match database_privilege_diff {
|
||||||
DatabasePrivilegesDiff::New(p) => {
|
DatabasePrivilegesDiff::New(p) => {
|
||||||
let tables = DATABASE_PRIVILEGE_FIELDS
|
let tables = DATABASE_PRIVILEGE_FIELDS
|
||||||
.iter()
|
.iter()
|
||||||
|
@ -305,7 +332,13 @@ async fn unsafe_apply_privilege_diff(
|
||||||
.await
|
.await
|
||||||
.map(|_| ())
|
.map(|_| ())
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Err(e) = &result {
|
||||||
|
log::error!("Failed to apply database privilege diff: {}", e);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn validate_diff(
|
async fn validate_diff(
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
|
use indoc::formatdoc;
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
use std::collections::BTreeMap;
|
use std::collections::BTreeMap;
|
||||||
use indoc::formatdoc;
|
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
@ -29,7 +29,7 @@ async fn unsafe_user_exists(
|
||||||
db_user: &str,
|
db_user: &str,
|
||||||
connection: &mut MySqlConnection,
|
connection: &mut MySqlConnection,
|
||||||
) -> Result<bool, sqlx::Error> {
|
) -> Result<bool, sqlx::Error> {
|
||||||
sqlx::query(
|
let result = sqlx::query(
|
||||||
r#"
|
r#"
|
||||||
SELECT EXISTS(
|
SELECT EXISTS(
|
||||||
SELECT 1
|
SELECT 1
|
||||||
|
@ -41,7 +41,13 @@ async fn unsafe_user_exists(
|
||||||
.bind(db_user)
|
.bind(db_user)
|
||||||
.fetch_one(connection)
|
.fetch_one(connection)
|
||||||
.await
|
.await
|
||||||
.map(|row| row.get::<bool, _>(0))
|
.map(|row| row.get::<bool, _>(0));
|
||||||
|
|
||||||
|
if let Err(err) = &result {
|
||||||
|
log::error!("Failed to check if database user exists: {:?}", err);
|
||||||
|
}
|
||||||
|
|
||||||
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn create_database_users(
|
pub async fn create_database_users(
|
||||||
|
@ -80,6 +86,10 @@ pub async fn create_database_users(
|
||||||
.map(|_| ())
|
.map(|_| ())
|
||||||
.map_err(|err| CreateUserError::MySqlError(err.to_string()));
|
.map_err(|err| CreateUserError::MySqlError(err.to_string()));
|
||||||
|
|
||||||
|
if let Err(err) = &result {
|
||||||
|
log::error!("Failed to create database user '{}': {:?}", &db_user, err);
|
||||||
|
}
|
||||||
|
|
||||||
results.insert(db_user, result);
|
results.insert(db_user, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -122,6 +132,10 @@ pub async fn drop_database_users(
|
||||||
.map(|_| ())
|
.map(|_| ())
|
||||||
.map_err(|err| DropUserError::MySqlError(err.to_string()));
|
.map_err(|err| DropUserError::MySqlError(err.to_string()));
|
||||||
|
|
||||||
|
if let Err(err) = &result {
|
||||||
|
log::error!("Failed to drop database user '{}': {:?}", &db_user, err);
|
||||||
|
}
|
||||||
|
|
||||||
results.insert(db_user, result);
|
results.insert(db_user, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -148,7 +162,7 @@ pub async fn set_password_for_database_user(
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
|
|
||||||
sqlx::query(
|
let result = sqlx::query(
|
||||||
format!(
|
format!(
|
||||||
"ALTER USER {}@'%' IDENTIFIED BY {}",
|
"ALTER USER {}@'%' IDENTIFIED BY {}",
|
||||||
quote_literal(db_user),
|
quote_literal(db_user),
|
||||||
|
@ -159,7 +173,17 @@ pub async fn set_password_for_database_user(
|
||||||
.execute(&mut *connection)
|
.execute(&mut *connection)
|
||||||
.await
|
.await
|
||||||
.map(|_| ())
|
.map(|_| ())
|
||||||
.map_err(|err| SetPasswordError::MySqlError(err.to_string()))
|
.map_err(|err| SetPasswordError::MySqlError(err.to_string()));
|
||||||
|
|
||||||
|
if let Err(err) = &result {
|
||||||
|
log::error!(
|
||||||
|
"Failed to set password for database user '{}': {:?}",
|
||||||
|
&db_user,
|
||||||
|
err
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOTE: this function is unsafe because it does no input validation.
|
// NOTE: this function is unsafe because it does no input validation.
|
||||||
|
@ -167,7 +191,7 @@ async fn database_user_is_locked_unsafe(
|
||||||
db_user: &str,
|
db_user: &str,
|
||||||
connection: &mut MySqlConnection,
|
connection: &mut MySqlConnection,
|
||||||
) -> Result<bool, sqlx::Error> {
|
) -> Result<bool, sqlx::Error> {
|
||||||
sqlx::query(
|
let result = sqlx::query(
|
||||||
r#"
|
r#"
|
||||||
SELECT COALESCE(
|
SELECT COALESCE(
|
||||||
JSON_EXTRACT(`mysql`.`global_priv`.`priv`, "$.account_locked"),
|
JSON_EXTRACT(`mysql`.`global_priv`.`priv`, "$.account_locked"),
|
||||||
|
@ -181,7 +205,17 @@ async fn database_user_is_locked_unsafe(
|
||||||
.bind(db_user)
|
.bind(db_user)
|
||||||
.fetch_one(connection)
|
.fetch_one(connection)
|
||||||
.await
|
.await
|
||||||
.map(|row| row.get::<bool, _>(0))
|
.map(|row| row.get::<bool, _>(0));
|
||||||
|
|
||||||
|
if let Err(err) = &result {
|
||||||
|
log::error!(
|
||||||
|
"Failed to check if database user is locked '{}': {:?}",
|
||||||
|
&db_user,
|
||||||
|
err
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn lock_database_users(
|
pub async fn lock_database_users(
|
||||||
|
@ -234,6 +268,10 @@ pub async fn lock_database_users(
|
||||||
.map(|_| ())
|
.map(|_| ())
|
||||||
.map_err(|err| LockUserError::MySqlError(err.to_string()));
|
.map_err(|err| LockUserError::MySqlError(err.to_string()));
|
||||||
|
|
||||||
|
if let Err(err) = &result {
|
||||||
|
log::error!("Failed to lock database user '{}': {:?}", &db_user, err);
|
||||||
|
}
|
||||||
|
|
||||||
results.insert(db_user, result);
|
results.insert(db_user, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -290,6 +328,10 @@ pub async fn unlock_database_users(
|
||||||
.map(|_| ())
|
.map(|_| ())
|
||||||
.map_err(|err| UnlockUserError::MySqlError(err.to_string()));
|
.map_err(|err| UnlockUserError::MySqlError(err.to_string()));
|
||||||
|
|
||||||
|
if let Err(err) = &result {
|
||||||
|
log::error!("Failed to unlock database user '{}': {:?}", &db_user, err);
|
||||||
|
}
|
||||||
|
|
||||||
results.insert(db_user, result);
|
results.insert(db_user, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -298,39 +340,55 @@ pub async fn unlock_database_users(
|
||||||
|
|
||||||
/// This struct contains information about a database user.
|
/// This struct contains information about a database user.
|
||||||
/// This can be extended if we need more information in the future.
|
/// This can be extended if we need more information in the future.
|
||||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)]
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||||
pub struct DatabaseUser {
|
pub struct DatabaseUser {
|
||||||
#[sqlx(rename = "User")]
|
|
||||||
pub user: String,
|
pub user: String,
|
||||||
|
|
||||||
#[allow(dead_code)]
|
|
||||||
#[serde(skip)]
|
#[serde(skip)]
|
||||||
#[sqlx(rename = "Host")]
|
|
||||||
pub host: String,
|
pub host: String,
|
||||||
|
|
||||||
#[sqlx(rename = "has_password")]
|
|
||||||
pub has_password: bool,
|
pub has_password: bool,
|
||||||
|
|
||||||
#[sqlx(rename = "is_locked")]
|
|
||||||
pub is_locked: bool,
|
pub is_locked: bool,
|
||||||
|
|
||||||
#[sqlx(skip)]
|
|
||||||
pub databases: Vec<String>,
|
pub databases: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Some mysql versions with some collations mark some columns as binary fields,
|
||||||
|
/// which in the current version of sqlx is not parsable as string.
|
||||||
|
/// See: https://github.com/launchbadge/sqlx/issues/3387
|
||||||
|
#[inline]
|
||||||
|
fn try_get_with_binary_fallback(
|
||||||
|
row: &sqlx::mysql::MySqlRow,
|
||||||
|
column: &str,
|
||||||
|
) -> Result<String, sqlx::Error> {
|
||||||
|
row.try_get(column).or_else(|_| {
|
||||||
|
row.try_get::<Vec<u8>, _>(column)
|
||||||
|
.map(|v| String::from_utf8_lossy(&v).to_string())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromRow<'_, sqlx::mysql::MySqlRow> for DatabaseUser {
|
||||||
|
fn from_row(row: &sqlx::mysql::MySqlRow) -> Result<Self, sqlx::Error> {
|
||||||
|
Ok(Self {
|
||||||
|
user: try_get_with_binary_fallback(row, "User")?,
|
||||||
|
host: try_get_with_binary_fallback(row, "Host")?,
|
||||||
|
has_password: row.try_get("has_password")?,
|
||||||
|
is_locked: row.try_get("is_locked")?,
|
||||||
|
databases: Vec::new(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const DB_USER_SELECT_STATEMENT: &str = r#"
|
const DB_USER_SELECT_STATEMENT: &str = r#"
|
||||||
SELECT
|
SELECT
|
||||||
`mysql`.`user`.`User`,
|
`user`.`User`,
|
||||||
`mysql`.`user`.`Host`,
|
`user`.`Host`,
|
||||||
`mysql`.`user`.`Password` != '' OR `mysql`.`user`.`authentication_string` != '' AS `has_password`,
|
`user`.`Password` != '' OR `user`.`authentication_string` != '' AS `has_password`,
|
||||||
COALESCE(
|
COALESCE(
|
||||||
JSON_EXTRACT(`mysql`.`global_priv`.`priv`, "$.account_locked"),
|
JSON_EXTRACT(`global_priv`.`priv`, "$.account_locked"),
|
||||||
'false'
|
'false'
|
||||||
) != 'false' AS `is_locked`
|
) != 'false' AS `is_locked`
|
||||||
FROM `mysql`.`user`
|
FROM `user`
|
||||||
JOIN `mysql`.`global_priv` ON
|
JOIN `global_priv` ON
|
||||||
`mysql`.`user`.`User` = `mysql`.`global_priv`.`User`
|
`user`.`User` = `global_priv`.`User`
|
||||||
AND `mysql`.`user`.`Host` = `mysql`.`global_priv`.`Host`
|
AND `user`.`Host` = `global_priv`.`Host`
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
pub async fn list_database_users(
|
pub async fn list_database_users(
|
||||||
|
@ -358,6 +416,10 @@ pub async fn list_database_users(
|
||||||
.fetch_optional(&mut *connection)
|
.fetch_optional(&mut *connection)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
|
if let Err(err) = &result {
|
||||||
|
log::error!("Failed to list database user '{}': {:?}", &db_user, err);
|
||||||
|
}
|
||||||
|
|
||||||
if let Ok(Some(user)) = result.as_mut() {
|
if let Ok(Some(user)) = result.as_mut() {
|
||||||
append_databases_where_user_has_privileges(user, &mut *connection).await;
|
append_databases_where_user_has_privileges(user, &mut *connection).await;
|
||||||
}
|
}
|
||||||
|
@ -377,13 +439,17 @@ pub async fn list_all_database_users_for_unix_user(
|
||||||
connection: &mut MySqlConnection,
|
connection: &mut MySqlConnection,
|
||||||
) -> ListAllUsersOutput {
|
) -> ListAllUsersOutput {
|
||||||
let mut result = sqlx::query_as::<_, DatabaseUser>(
|
let mut result = sqlx::query_as::<_, DatabaseUser>(
|
||||||
&(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `mysql`.`user`.`User` REGEXP ?"),
|
&(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `user`.`User` REGEXP ?"),
|
||||||
)
|
)
|
||||||
.bind(create_user_group_matching_regex(unix_user))
|
.bind(create_user_group_matching_regex(unix_user))
|
||||||
.fetch_all(&mut *connection)
|
.fetch_all(&mut *connection)
|
||||||
.await
|
.await
|
||||||
.map_err(|err| ListAllUsersError::MySqlError(err.to_string()));
|
.map_err(|err| ListAllUsersError::MySqlError(err.to_string()));
|
||||||
|
|
||||||
|
if let Err(err) = &result {
|
||||||
|
log::error!("Failed to list all database users: {:?}", err);
|
||||||
|
}
|
||||||
|
|
||||||
if let Ok(users) = result.as_mut() {
|
if let Ok(users) = result.as_mut() {
|
||||||
for user in users {
|
for user in users {
|
||||||
append_databases_where_user_has_privileges(user, &mut *connection).await;
|
append_databases_where_user_has_privileges(user, &mut *connection).await;
|
||||||
|
@ -394,15 +460,15 @@ pub async fn list_all_database_users_for_unix_user(
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn append_databases_where_user_has_privileges(
|
pub async fn append_databases_where_user_has_privileges(
|
||||||
database_user: &mut DatabaseUser,
|
db_user: &mut DatabaseUser,
|
||||||
connection: &mut MySqlConnection,
|
connection: &mut MySqlConnection,
|
||||||
) {
|
) {
|
||||||
let database_list = sqlx::query(
|
let database_list = sqlx::query(
|
||||||
formatdoc!(
|
formatdoc!(
|
||||||
r#"
|
r#"
|
||||||
SELECT `db` AS `database`
|
SELECT `Db` AS `database`
|
||||||
FROM `db`
|
FROM `db`
|
||||||
WHERE `user` = ? AND ({})
|
WHERE `User` = ? AND ({})
|
||||||
"#,
|
"#,
|
||||||
DATABASE_PRIVILEGE_FIELDS
|
DATABASE_PRIVILEGE_FIELDS
|
||||||
.iter()
|
.iter()
|
||||||
|
@ -411,14 +477,22 @@ pub async fn append_databases_where_user_has_privileges(
|
||||||
)
|
)
|
||||||
.as_str(),
|
.as_str(),
|
||||||
)
|
)
|
||||||
.bind(database_user.user.clone())
|
.bind(db_user.user.clone())
|
||||||
.fetch_all(&mut *connection)
|
.fetch_all(&mut *connection)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
database_user.databases = database_list
|
if let Err(err) = &database_list {
|
||||||
|
log::error!(
|
||||||
|
"Failed to list databases for user '{}': {:?}",
|
||||||
|
&db_user.user,
|
||||||
|
err
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
db_user.databases = database_list
|
||||||
.map(|rows| {
|
.map(|rows| {
|
||||||
rows.into_iter()
|
rows.into_iter()
|
||||||
.map(|row| row.get::<String, _>("database"))
|
.map(|row| try_get_with_binary_fallback(&row, "database").unwrap())
|
||||||
.collect()
|
.collect()
|
||||||
})
|
})
|
||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
|
|
Loading…
Reference in New Issue