Compare commits

..

1 Commits

Author SHA1 Message Date
Oystein Kristoffer Tveit 5b7eafd7ca
WIP 2024-08-19 00:19:17 +02:00
10 changed files with 538 additions and 243 deletions

17
Cargo.lock generated
View File

@ -1010,15 +1010,6 @@ version = "2.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3"
[[package]]
name = "memoffset"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
dependencies = [
"autocfg",
]
[[package]] [[package]]
name = "minimal-lexical" name = "minimal-lexical"
version = "0.2.1" version = "0.2.1"
@ -1067,7 +1058,6 @@ dependencies = [
"prettytable", "prettytable",
"rand", "rand",
"ratatui", "ratatui",
"sd-notify",
"serde", "serde",
"serde_json", "serde_json",
"sqlx", "sqlx",
@ -1089,7 +1079,6 @@ dependencies = [
"cfg-if", "cfg-if",
"cfg_aliases", "cfg_aliases",
"libc", "libc",
"memoffset",
] ]
[[package]] [[package]]
@ -1535,12 +1524,6 @@ dependencies = [
"untrusted", "untrusted",
] ]
[[package]]
name = "sd-notify"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4646d6f919800cd25c50edb49438a1381e2cd4833c027e75e8897981c50b8b5e"
[[package]] [[package]]
name = "serde" name = "serde"
version = "1.0.208" version = "1.0.208"

View File

@ -16,11 +16,10 @@ futures-util = "0.3.30"
indoc = "2.0.5" indoc = "2.0.5"
itertools = "0.13.0" itertools = "0.13.0"
log = "0.4.22" log = "0.4.22"
nix = { version = "0.29.0", features = ["fs", "process", "socket", "user"] } nix = { version = "0.29.0", features = ["fs", "process", "user"] }
prettytable = "0.10.0" prettytable = "0.10.0"
rand = "0.8.5" rand = "0.8.5"
ratatui = { version = "0.28.0", optional = true } ratatui = { version = "0.28.0", optional = true }
sd-notify = "0.4.2"
serde = "1.0.208" serde = "1.0.208"
serde_json = { version = "1.0.125", features = ["preserve_order"] } serde_json = { version = "1.0.125", features = ["preserve_order"] }
sqlx = { version = "0.8.0", features = ["runtime-tokio", "mysql", "tls-rustls"] } sqlx = { version = "0.8.0", features = ["runtime-tokio", "mysql", "tls-rustls"] }

View File

@ -6,10 +6,15 @@ use std::os::unix::net::UnixStream as StdUnixStream;
use tokio::net::UnixStream as TokioUnixStream; use tokio::net::UnixStream as TokioUnixStream;
use crate::{ use crate::{
core::common::{UnixUser, DEFAULT_CONFIG_PATH, DEFAULT_SOCKET_PATH}, core::{
bootstrap::authenticated_unix_socket::client_authenticate,
common::{UnixUser, DEFAULT_CONFIG_PATH, DEFAULT_SOCKET_PATH},
},
server::{config::read_config_form_path, server_loop::handle_requests_for_single_session}, server::{config::read_config_form_path, server_loop::handle_requests_for_single_session},
}; };
pub mod authenticated_unix_socket;
// TODO: this function is security critical, it should be integration tested // TODO: this function is security critical, it should be integration tested
// in isolation. // in isolation.
/// Drop privileges to the real user and group of the process. /// Drop privileges to the real user and group of the process.
@ -52,11 +57,25 @@ pub fn bootstrap_server_connection_and_drop_privileges(
log::debug!("Starting the server connection bootstrap process"); log::debug!("Starting the server connection bootstrap process");
let socket = bootstrap_server_connection(server_socket_path, config_path)?; let (socket, do_authenticate) = bootstrap_server_connection(server_socket_path, config_path)?;
drop_privs()?; drop_privs()?;
Ok(socket) let result: anyhow::Result<StdUnixStream> = if do_authenticate {
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap()
.block_on(async {
let mut socket = TokioUnixStream::from_std(socket)?;
client_authenticate(&mut socket, None).await?;
Ok(socket.into_std()?)
})
} else {
Ok(socket)
};
result
} }
/// Inner function for [`bootstrap_server_connection_and_drop_privileges`]. /// Inner function for [`bootstrap_server_connection_and_drop_privileges`].
@ -64,12 +83,12 @@ pub fn bootstrap_server_connection_and_drop_privileges(
fn bootstrap_server_connection( fn bootstrap_server_connection(
socket_path: Option<PathBuf>, socket_path: Option<PathBuf>,
config_path: Option<PathBuf>, config_path: Option<PathBuf>,
) -> anyhow::Result<StdUnixStream> { ) -> anyhow::Result<(StdUnixStream, bool)> {
// TODO: ensure this is both readable and writable // TODO: ensure this is both readable and writable
if let Some(socket_path) = socket_path { if let Some(socket_path) = socket_path {
log::debug!("Connecting to socket at {:?}", socket_path); log::debug!("Connecting to socket at {:?}", socket_path);
return match StdUnixStream::connect(socket_path) { return match StdUnixStream::connect(socket_path) {
Ok(socket) => Ok(socket), Ok(socket) => Ok((socket, true)),
Err(e) => match e.kind() { Err(e) => match e.kind() {
std::io::ErrorKind::NotFound => Err(anyhow::anyhow!("Socket not found")), std::io::ErrorKind::NotFound => Err(anyhow::anyhow!("Socket not found")),
std::io::ErrorKind::PermissionDenied => Err(anyhow::anyhow!("Permission denied")), std::io::ErrorKind::PermissionDenied => Err(anyhow::anyhow!("Permission denied")),
@ -84,13 +103,12 @@ fn bootstrap_server_connection(
} }
log::debug!("Starting server with config at {:?}", config_path); log::debug!("Starting server with config at {:?}", config_path);
return invoke_server_with_config(config_path); return invoke_server_with_config(config_path).map(|socket| (socket, false));
} }
if fs::metadata(DEFAULT_SOCKET_PATH).is_ok() { if fs::metadata(DEFAULT_SOCKET_PATH).is_ok() {
log::debug!("Connecting to default socket at {:?}", DEFAULT_SOCKET_PATH);
return match StdUnixStream::connect(DEFAULT_SOCKET_PATH) { return match StdUnixStream::connect(DEFAULT_SOCKET_PATH) {
Ok(socket) => Ok(socket), Ok(socket) => Ok((socket, true)),
Err(e) => match e.kind() { Err(e) => match e.kind() {
std::io::ErrorKind::NotFound => Err(anyhow::anyhow!("Socket not found")), std::io::ErrorKind::NotFound => Err(anyhow::anyhow!("Socket not found")),
std::io::ErrorKind::PermissionDenied => Err(anyhow::anyhow!("Permission denied")), std::io::ErrorKind::PermissionDenied => Err(anyhow::anyhow!("Permission denied")),
@ -101,8 +119,7 @@ fn bootstrap_server_connection(
let config_path = PathBuf::from(DEFAULT_CONFIG_PATH); let config_path = PathBuf::from(DEFAULT_CONFIG_PATH);
if fs::metadata(&config_path).is_ok() { if fs::metadata(&config_path).is_ok() {
log::debug!("Starting server with default config at {:?}", config_path); return invoke_server_with_config(config_path).map(|socket| (socket, false));
return invoke_server_with_config(config_path);
} }
anyhow::bail!("No socket path or config path provided, and no default socket or config found"); anyhow::bail!("No socket path or config path provided, and no default socket or config found");

View File

@ -0,0 +1,439 @@
//! This module provides a way to authenticate a client uid to a server over a Unix socket.
//! This is needed so that the server can trust the client's uid, which it depends on to
//! make modifications for that user in the database. It is crucial that the server can trust
//! that the client is the unix user it claims to be.
//!
//! This works by having the client respond to a challenge on a socket that is verifiably owned
//! by the client. In more detailed steps, the following should happen:
//!
//! 1. Before initializing it's request, the client should open an "authentication" socket with permissions 644
//! and owned by the uid of the current user.
//! 2. The client opens a request to the server on the "normal" socket where the server is listening,
//! In this request, the client should include the following:
//! - The address of it's authentication socket
//! - The uid of the user currently using the client
//! - A challenge string that has been randomly generated
//! 3. The server validates the following:
//! - The address of the auth socket is valid
//! - The owner of the auth socket address is the same as the uid
//! 4. Server connects to the auth socket address and receives another challenge string.
//! The server should close the connection after receiving the challenge string.
//! 5. Server verifies that the challenge is the same as the one it originally received.
//! It responds to the client with an "Authenticated" message if the challenge matches,
//! or an error message if it does not.
//! 6. Client closes the authentication socket. Normal socket is used for communication.
//!
//! Note that the server can at any point in the process send an error message to the client
//! over it's initial connection, and the client should respond by closing the authentication
//! socket, it's connection to the normal socket, and reporting the error to the user.
//!
//! Also note that it is essential that the client does not send any sensitive information
//! over it's authentication socket, since it is readable by any user on the system.
// TODO: rewrite this so that it can be used with a normal std::os::unix::net::UnixStream
use std::os::unix::io::AsRawFd;
use std::path::{Path, PathBuf};
use async_bincode::{tokio::AsyncBincodeStream, AsyncDestination};
use derive_more::derive::{Display, Error};
use futures::{SinkExt, StreamExt};
use nix::{sys::stat, unistd::Uid};
use rand::distributions::Alphanumeric;
use rand::Rng;
use serde::{Deserialize, Serialize};
use tokio::net::{UnixListener, UnixStream};
use tokio_util::sync::CancellationToken;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ClientRequest {
Initialize {
uid: u32,
challenge: u64,
auth_socket: String,
},
Cancel,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Display, Error)]
pub enum ServerResponse {
Authenticated,
ChallengeDidNotMatch,
ServerError(ServerError),
}
// TODO: wrap more data into the errors
#[derive(Debug, Display, PartialEq, Serialize, Deserialize, Clone, Error)]
pub enum ServerError {
InvalidRequest,
UnableToReadPermissionsFromAuthSocket,
CouldNotConnectToAuthSocket,
AuthSocketClosedEarly,
UidMismatch,
ChallengeMismatch,
InvalidChallenge,
}
#[derive(Debug, PartialEq, Display, Error)]
pub enum ClientError {
UnableToConnectToServer,
UnableToOpenAuthSocket,
UnableToConfigureAuthSocket,
AuthSocketClosedEarly,
UnableToCloseAuthSocket,
AuthenticationError,
UnableToParseServerResponse,
NoServerResponse,
ServerError(ServerError),
}
async fn create_auth_socket(socket_addr: &PathBuf) -> Result<UnixListener, ClientError> {
let auth_socket =
UnixListener::bind(socket_addr).map_err(|_err| ClientError::UnableToOpenAuthSocket)?;
stat::fchmod(
auth_socket.as_raw_fd(),
stat::Mode::S_IRUSR | stat::Mode::S_IWUSR | stat::Mode::S_IRGRP,
)
.map_err(|_err| ClientError::UnableToConfigureAuthSocket)?;
Ok(auth_socket)
}
type ClientToServerStream<'a> =
AsyncBincodeStream<&'a mut UnixStream, ServerResponse, ClientRequest, AsyncDestination>;
type ServerToClientStream<'a> =
AsyncBincodeStream<&'a mut UnixStream, ClientRequest, ServerResponse, AsyncDestination>;
// TODO: make the challenge constant size and use socket directly, this is overkill
type AuthStream<'a> = AsyncBincodeStream<&'a mut UnixStream, u64, u64, AsyncDestination>;
// TODO: add timeout
// TODO: respect $XDG_RUNTIME_DIR and $TMPDIR
const AUTH_SOCKET_NAME: &str = "mysqladm-rs-cli-auth.sock";
pub async fn client_authenticate(
normal_socket: &mut UnixStream,
auth_socket_dir: Option<PathBuf>,
) -> Result<(), ClientError> {
let random_prefix: String = rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(16)
.map(char::from)
.collect();
let socket_name = format!("{}-{}", random_prefix, AUTH_SOCKET_NAME);
let auth_socket_address = auth_socket_dir
.unwrap_or(std::env::temp_dir())
.join(socket_name);
client_authenticate_with_auth_socket_address(normal_socket, &auth_socket_address).await
}
async fn client_authenticate_with_auth_socket_address(
normal_socket: &mut UnixStream,
auth_socket_address: &PathBuf,
) -> Result<(), ClientError> {
let auth_socket = create_auth_socket(auth_socket_address).await?;
let result =
client_authenticate_with_auth_socket(normal_socket, auth_socket, auth_socket_address).await;
std::fs::remove_file(auth_socket_address)
.map_err(|_err| ClientError::UnableToCloseAuthSocket)?;
result
}
async fn client_authenticate_with_auth_socket(
normal_socket: &mut UnixStream,
auth_socket: UnixListener,
auth_socket_address: &Path,
) -> Result<(), ClientError> {
let challenge = rand::random::<u64>();
let uid = nix::unistd::getuid();
let mut normal_socket: ClientToServerStream =
AsyncBincodeStream::from(normal_socket).for_async();
let challenge_replier_cancellation_token = CancellationToken::new();
let challenge_replier_cancellation_token_clone = challenge_replier_cancellation_token.clone();
let challenge_replier_handle = tokio::spawn(async move {
loop {
tokio::select! {
socket = auth_socket.accept() =>
{
match socket {
Ok((mut conn, _addr)) => {
let mut stream: AuthStream = AsyncBincodeStream::from(&mut conn).for_async();
stream.send(challenge).await.ok();
stream.close().await.ok();
}
Err(_err) => return Err(ClientError::AuthSocketClosedEarly),
}
}
_ = challenge_replier_cancellation_token_clone.cancelled() => {
break Ok(());
}
}
}
});
let client_hello = ClientRequest::Initialize {
uid: uid.into(),
challenge,
auth_socket: auth_socket_address
.to_str()
.ok_or(ClientError::UnableToConfigureAuthSocket)?
.to_owned(),
};
normal_socket
.send(client_hello)
.await
.map_err(|err| match *err {
bincode::ErrorKind::Io(_err) => ClientError::UnableToConnectToServer,
_ => ClientError::NoServerResponse,
})?;
match normal_socket.next().await {
Some(Ok(ServerResponse::Authenticated)) => {}
Some(Ok(ServerResponse::ChallengeDidNotMatch)) => {
return Err(ClientError::AuthenticationError)
}
Some(Ok(ServerResponse::ServerError(err))) => return Err(ClientError::ServerError(err)),
Some(Err(err)) => match *err {
bincode::ErrorKind::Io(_err) => return Err(ClientError::NoServerResponse),
_ => return Err(ClientError::UnableToParseServerResponse),
},
None => return Err(ClientError::NoServerResponse),
}
challenge_replier_cancellation_token.cancel();
challenge_replier_handle.await.ok();
Ok(())
}
macro_rules! report_server_error_and_return {
($normal_socket:expr, $err:expr) => {{
$normal_socket
.send(ServerResponse::ServerError($err))
.await
.ok();
return Err($err);
}};
}
pub async fn server_authenticate(normal_socket: &mut UnixStream) -> Result<Uid, ServerError> {
_server_authenticate(normal_socket, None).await
}
pub async fn _server_authenticate(
normal_socket: &mut UnixStream,
unix_user_uid: Option<u32>,
) -> Result<Uid, ServerError> {
let mut normal_socket: ServerToClientStream =
AsyncBincodeStream::from(normal_socket).for_async();
let (uid, challenge, auth_socket) = match normal_socket.next().await {
Some(Ok(ClientRequest::Initialize {
uid,
challenge,
auth_socket,
})) => (uid, challenge, auth_socket),
// TODO: more granular errros
_ => report_server_error_and_return!(normal_socket, ServerError::InvalidRequest),
};
let auth_socket_uid = match unix_user_uid {
Some(uid) => uid,
None => match stat::stat(auth_socket.as_str()) {
Ok(stat) => stat.st_uid,
Err(_err) => report_server_error_and_return!(
normal_socket,
ServerError::UnableToReadPermissionsFromAuthSocket
),
},
};
if uid != auth_socket_uid {
report_server_error_and_return!(normal_socket, ServerError::UidMismatch);
}
let mut authenticated_unix_socket = match UnixStream::connect(auth_socket).await {
Ok(socket) => socket,
Err(_err) => {
report_server_error_and_return!(normal_socket, ServerError::CouldNotConnectToAuthSocket)
}
};
let mut authenticated_unix_socket: AuthStream =
AsyncBincodeStream::from(&mut authenticated_unix_socket).for_async();
let challenge_2 = match authenticated_unix_socket.next().await {
Some(Ok(challenge)) => challenge,
Some(Err(_)) => {
report_server_error_and_return!(normal_socket, ServerError::InvalidChallenge)
}
None => report_server_error_and_return!(normal_socket, ServerError::AuthSocketClosedEarly),
};
authenticated_unix_socket.close().await.ok();
if challenge != challenge_2 {
normal_socket
.send(ServerResponse::ChallengeDidNotMatch)
.await
.ok();
return Err(ServerError::ChallengeMismatch);
}
normal_socket.send(ServerResponse::Authenticated).await.ok();
Ok(Uid::from_raw(uid))
}
#[cfg(test)]
mod test {
use super::*;
use std::time::Duration;
use tokio::time::sleep;
#[tokio::test]
async fn test_valid_authentication() {
let (mut client, mut server) = UnixStream::pair().unwrap();
let client_handle =
tokio::spawn(async move { client_authenticate(&mut client, None).await });
let server_handle = tokio::spawn(async move { server_authenticate(&mut server).await });
client_handle.await.unwrap().unwrap();
server_handle.await.unwrap().unwrap();
}
#[tokio::test]
async fn test_ensure_auth_socket_does_not_exist() {
let (mut client, mut server) = UnixStream::pair().unwrap();
let client_handle = tokio::spawn(async move {
client_authenticate_with_auth_socket_address(
&mut client,
&PathBuf::from("/tmp/test_auth_socket_does_not_exist.sock"),
)
.await
});
let server_handle = tokio::spawn(async move { server_authenticate(&mut server).await });
client_handle.await.unwrap().unwrap();
server_handle.await.unwrap().unwrap();
}
#[tokio::test]
async fn test_uid_mismatch() {
let (mut client, mut server) = UnixStream::pair().unwrap();
let client_handle = tokio::spawn(async move {
let err = client_authenticate(&mut client, None).await;
assert_eq!(err, Err(ClientError::ServerError(ServerError::UidMismatch)));
});
let server_handle = tokio::spawn(async move {
let uid: u32 = nix::unistd::getuid().into();
let err = _server_authenticate(&mut server, Some(uid + 1)).await;
assert_eq!(err, Err(ServerError::UidMismatch));
});
client_handle.await.unwrap();
server_handle.await.unwrap();
}
#[tokio::test]
async fn test_snooping_connection() {
let (mut client, mut server) = UnixStream::pair().unwrap();
let socket_path = std::env::temp_dir().join("socket_to_snoop.sock");
let socket_path_clone = socket_path.clone();
let client_handle = tokio::spawn(async move {
client_authenticate_with_auth_socket_address(&mut client, &socket_path_clone).await
});
for i in 0..100 {
if socket_path.exists() {
break;
}
sleep(Duration::from_millis(10)).await;
if i == 99 {
panic!("Socket not created after 1 second, assuming test failure");
}
}
let mut snooper = UnixStream::connect(socket_path.clone()).await.unwrap();
let mut snooper: AuthStream = AsyncBincodeStream::from(&mut snooper).for_async();
let message = snooper.next().await.unwrap().unwrap();
let mut other_snooper = UnixStream::connect(socket_path.clone()).await.unwrap();
let mut other_snooper: AuthStream =
AsyncBincodeStream::from(&mut other_snooper).for_async();
let other_message = other_snooper.next().await.unwrap().unwrap();
assert_eq!(message, other_message);
let third_snooper_handle = tokio::spawn(async move {
let mut third_snooper = UnixStream::connect(socket_path.clone()).await.unwrap();
let mut third_snooper: AuthStream =
AsyncBincodeStream::from(&mut third_snooper).for_async();
// NOTE: Should hang
third_snooper.send(1234).await.unwrap()
});
sleep(Duration::from_millis(10)).await;
let server_handle = tokio::spawn(async move { server_authenticate(&mut server).await });
client_handle.await.unwrap().unwrap();
server_handle.await.unwrap().unwrap();
third_snooper_handle.abort();
}
#[tokio::test]
async fn test_dead_server() {
let (mut client, server) = UnixStream::pair().unwrap();
std::mem::drop(server);
let client_handle = tokio::spawn(async move {
let err = client_authenticate(&mut client, None).await;
assert_eq!(err, Err(ClientError::UnableToConnectToServer));
});
client_handle.await.unwrap();
}
#[tokio::test]
async fn test_no_response_from_server() {
let (mut client, server) = UnixStream::pair().unwrap();
let client_handle = tokio::spawn(async move {
let err = client_authenticate(&mut client, None).await;
assert_eq!(err, Err(ClientError::NoServerResponse));
});
sleep(Duration::from_millis(200)).await;
std::mem::drop(server);
client_handle.await.unwrap();
}
// TODO: Test challenge mismatch
// TODO: Test invoking server with junk data
}

View File

@ -7,6 +7,7 @@ use clap::Parser;
use std::os::unix::net::UnixStream as StdUnixStream; use std::os::unix::net::UnixStream as StdUnixStream;
use tokio::net::UnixStream as TokioUnixStream; use tokio::net::UnixStream as TokioUnixStream;
use crate::core::bootstrap::authenticated_unix_socket;
use crate::core::common::UnixUser; use crate::core::common::UnixUser;
use crate::server::config::read_config_from_path_with_arg_overrides; use crate::server::config::read_config_from_path_with_arg_overrides;
use crate::server::server_loop::listen_for_incoming_connections; use crate::server::server_loop::listen_for_incoming_connections;
@ -40,6 +41,10 @@ pub async fn handle_command(
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let config = read_config_from_path_with_arg_overrides(config_path, args.config_overrides)?; let config = read_config_from_path_with_arg_overrides(config_path, args.config_overrides)?;
// if let Err(e) = &result {
// eprintln!("{}", e);
// }
match args.subcmd { match args.subcmd {
ServerCommand::Listen => listen_for_incoming_connections(socket_path, config).await, ServerCommand::Listen => listen_for_incoming_connections(socket_path, config).await,
ServerCommand::SocketActivate => socket_activate(config).await, ServerCommand::SocketActivate => socket_activate(config).await,
@ -48,26 +53,23 @@ pub async fn handle_command(
async fn socket_activate(config: ServerConfig) -> anyhow::Result<()> { async fn socket_activate(config: ServerConfig) -> anyhow::Result<()> {
// TODO: allow getting socket path from other socket activation sources // TODO: allow getting socket path from other socket activation sources
let conn = get_socket_from_systemd().await?; let mut conn = get_socket_from_systemd().await?;
let uid = conn.peer_cred()?.uid(); let uid = authenticated_unix_socket::server_authenticate(&mut conn).await?;
let unix_user = UnixUser::from_uid(uid)?; let unix_user = UnixUser::from_uid(uid.into())?;
log::info!("Accepted connection from {}", unix_user.username);
sd_notify::notify(true, &[sd_notify::NotifyState::Ready]).ok();
handle_requests_for_single_session(conn, &unix_user, &config).await?; handle_requests_for_single_session(conn, &unix_user, &config).await?;
Ok(()) Ok(())
} }
async fn get_socket_from_systemd() -> anyhow::Result<TokioUnixStream> { async fn get_socket_from_systemd() -> anyhow::Result<TokioUnixStream> {
let fd = sd_notify::listen_fds() let fd = std::env::var("LISTEN_FDS")
.context("Failed to get file descriptors from systemd")? .context("LISTEN_FDS not set, not running under systemd?")?
.next() .parse::<i32>()
.context("No file descriptors received from systemd")?; .context("Failed to parse LISTEN_FDS")?;
log::debug!("Received file descriptor from systemd: {}", fd); if fd != 1 {
return Err(anyhow::anyhow!("Unexpected LISTEN_FDS value: {}", fd));
}
let std_unix_stream = unsafe { StdUnixStream::from_raw_fd(fd) }; let std_unix_stream = unsafe { StdUnixStream::from_raw_fd(fd) };
let socket = TokioUnixStream::from_std(std_unix_stream)?; let socket = TokioUnixStream::from_std(std_unix_stream)?;

View File

@ -83,8 +83,6 @@ pub fn read_config_from_path_with_arg_overrides(
pub fn read_config_form_path(config_path: Option<PathBuf>) -> anyhow::Result<ServerConfig> { pub fn read_config_form_path(config_path: Option<PathBuf>) -> anyhow::Result<ServerConfig> {
let config_path = config_path.unwrap_or_else(|| PathBuf::from(DEFAULT_CONFIG_PATH)); let config_path = config_path.unwrap_or_else(|| PathBuf::from(DEFAULT_CONFIG_PATH));
log::debug!("Reading config from {:?}", &config_path);
fs::read_to_string(&config_path) fs::read_to_string(&config_path)
.context(format!( .context(format!(
"Failed to read config file from {:?}", "Failed to read config file from {:?}",
@ -101,13 +99,6 @@ pub fn read_config_form_path(config_path: Option<PathBuf>) -> anyhow::Result<Ser
pub async fn create_mysql_connection_from_config( pub async fn create_mysql_connection_from_config(
config: &MysqlConfig, config: &MysqlConfig,
) -> anyhow::Result<MySqlConnection> { ) -> anyhow::Result<MySqlConnection> {
let mut display_config = config.clone();
"<REDACTED>".clone_into(&mut display_config.password);
log::debug!(
"Connecting to MySQL server with parameters: {:#?}",
display_config
);
match tokio::time::timeout( match tokio::time::timeout(
Duration::from_secs(config.timeout.unwrap_or(DEFAULT_TIMEOUT)), Duration::from_secs(config.timeout.unwrap_or(DEFAULT_TIMEOUT)),
MySqlConnectOptions::new() MySqlConnectOptions::new()

View File

@ -9,6 +9,7 @@ use sqlx::MySqlConnection;
use crate::{ use crate::{
core::{ core::{
bootstrap::authenticated_unix_socket,
common::{UnixUser, DEFAULT_SOCKET_PATH}, common::{UnixUser, DEFAULT_SOCKET_PATH},
protocol::request_response::{ protocol::request_response::{
create_server_to_client_message_stream, Request, Response, ServerToClientMessageStream, create_server_to_client_message_stream, Request, Response, ServerToClientMessageStream,
@ -43,12 +44,11 @@ pub async fn listen_for_incoming_connections(
let parent_directory = socket_path.parent().unwrap(); let parent_directory = socket_path.parent().unwrap();
if !parent_directory.exists() { if !parent_directory.exists() {
log::debug!("Creating directory {:?}", parent_directory); println!("Creating directory {:?}", parent_directory);
fs::create_dir_all(parent_directory)?; fs::create_dir_all(parent_directory)?;
} }
log::info!("Listening on socket {:?}", socket_path); println!("Listening on {:?}", socket_path);
match fs::remove_file(socket_path.as_path()) { match fs::remove_file(socket_path.as_path()) {
Ok(_) => {} Ok(_) => {}
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {} Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
@ -57,13 +57,16 @@ pub async fn listen_for_incoming_connections(
let listener = UnixListener::bind(socket_path)?; let listener = UnixListener::bind(socket_path)?;
sd_notify::notify(true, &[sd_notify::NotifyState::Ready]).ok();
while let Ok((mut conn, _addr)) = listener.accept().await { while let Ok((mut conn, _addr)) = listener.accept().await {
let uid = conn.peer_cred()?.uid(); let uid = match authenticated_unix_socket::server_authenticate(&mut conn).await {
log::trace!("Accepted connection from uid {}", uid); Ok(uid) => uid,
Err(e) => {
let unix_user = match UnixUser::from_uid(uid) { eprintln!("Failed to authenticate client: {}", e);
conn.shutdown().await?;
continue;
}
};
let unix_user = match UnixUser::from_uid(uid.into()) {
Ok(user) => user, Ok(user) => user,
Err(e) => { Err(e) => {
eprintln!("Failed to get UnixUser from uid: {}", e); eprintln!("Failed to get UnixUser from uid: {}", e);
@ -71,9 +74,6 @@ pub async fn listen_for_incoming_connections(
continue; continue;
} }
}; };
log::info!("Accepted connection from {}", unix_user.username);
match handle_requests_for_single_session(conn, &unix_user, &config).await { match handle_requests_for_single_session(conn, &unix_user, &config).await {
Ok(_) => {} Ok(_) => {}
Err(e) => { Err(e) => {
@ -92,7 +92,6 @@ pub async fn handle_requests_for_single_session(
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let message_stream = create_server_to_client_message_stream(socket); let message_stream = create_server_to_client_message_stream(socket);
let mut db_connection = create_mysql_connection_from_config(&config.mysql).await?; let mut db_connection = create_mysql_connection_from_config(&config.mysql).await?;
log::debug!("Successfully connected to database");
let result = handle_requests_for_single_session_with_db_connection( let result = handle_requests_for_single_session_with_db_connection(
message_stream, message_stream,
@ -129,8 +128,6 @@ pub async fn handle_requests_for_single_session_with_db_connection(
} }
}; };
log::trace!("Received request: {:?}", request);
match request { match request {
Request::CreateDatabases(databases_names) => { Request::CreateDatabases(databases_names) => {
let result = create_databases(databases_names, unix_user, db_connection).await; let result = create_databases(databases_names, unix_user, db_connection).await;

View File

@ -26,17 +26,9 @@ pub(super) async fn unsafe_database_exists(
sqlx::query("SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?") sqlx::query("SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?")
.bind(database_name) .bind(database_name)
.fetch_optional(connection) .fetch_optional(connection)
.await; .await?;
if let Err(err) = &result { Ok(result.is_some())
log::error!(
"Failed to check if database '{}' exists: {:?}",
&database_name,
err
);
}
Ok(result?.is_some())
} }
pub async fn create_databases( pub async fn create_databases(
@ -88,10 +80,6 @@ pub async fn create_databases(
.map(|_| ()) .map(|_| ())
.map_err(|err| CreateDatabaseError::MySqlError(err.to_string())); .map_err(|err| CreateDatabaseError::MySqlError(err.to_string()));
if let Err(err) = &result {
log::error!("Failed to create database '{}': {:?}", &database_name, err);
}
results.insert(database_name, result); results.insert(database_name, result);
} }
@ -147,10 +135,6 @@ pub async fn drop_databases(
.map(|_| ()) .map(|_| ())
.map_err(|err| DropDatabaseError::MySqlError(err.to_string())); .map_err(|err| DropDatabaseError::MySqlError(err.to_string()));
if let Err(err) = &result {
log::error!("Failed to drop database '{}': {:?}", &database_name, err);
}
results.insert(database_name, result); results.insert(database_name, result);
} }
@ -161,7 +145,7 @@ pub async fn list_databases_for_user(
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
) -> Result<Vec<String>, ListDatabasesError> { ) -> Result<Vec<String>, ListDatabasesError> {
let result = sqlx::query( sqlx::query(
r#" r#"
SELECT `SCHEMA_NAME` AS `database` SELECT `SCHEMA_NAME` AS `database`
FROM `information_schema`.`SCHEMATA` FROM `information_schema`.`SCHEMATA`
@ -177,15 +161,5 @@ pub async fn list_databases_for_user(
.map(|row| row.try_get::<String, _>("database")) .map(|row| row.try_get::<String, _>("database"))
.collect::<Result<Vec<String>, sqlx::Error>>() .collect::<Result<Vec<String>, sqlx::Error>>()
}) })
.map_err(|err| ListDatabasesError::MySqlError(err.to_string())); .map_err(|err| ListDatabasesError::MySqlError(err.to_string()))
if let Err(err) = &result {
log::error!(
"Failed to list databases for user '{}': {:?}",
unix_user.username,
err
);
}
result
} }

View File

@ -136,7 +136,7 @@ async fn unsafe_get_database_privileges(
database_name: &str, database_name: &str,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
) -> Result<Vec<DatabasePrivilegeRow>, sqlx::Error> { ) -> Result<Vec<DatabasePrivilegeRow>, sqlx::Error> {
let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!( sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
"SELECT {} FROM `db` WHERE `db` = ?", "SELECT {} FROM `db` WHERE `db` = ?",
DATABASE_PRIVILEGE_FIELDS DATABASE_PRIVILEGE_FIELDS
.iter() .iter()
@ -145,17 +145,7 @@ async fn unsafe_get_database_privileges(
)) ))
.bind(database_name) .bind(database_name)
.fetch_all(connection) .fetch_all(connection)
.await; .await
if let Err(e) = &result {
log::error!(
"Failed to get database privileges for '{}': {}",
&database_name,
e
);
}
result
} }
// NOTE: this function is unsafe because it does no input validation. // NOTE: this function is unsafe because it does no input validation.
@ -165,7 +155,7 @@ pub async fn unsafe_get_database_privileges_for_db_user_pair(
user_name: &str, user_name: &str,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
) -> Result<Option<DatabasePrivilegeRow>, sqlx::Error> { ) -> Result<Option<DatabasePrivilegeRow>, sqlx::Error> {
let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!( sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
"SELECT {} FROM `db` WHERE `db` = ? AND `user` = ?", "SELECT {} FROM `db` WHERE `db` = ? AND `user` = ?",
DATABASE_PRIVILEGE_FIELDS DATABASE_PRIVILEGE_FIELDS
.iter() .iter()
@ -175,18 +165,7 @@ pub async fn unsafe_get_database_privileges_for_db_user_pair(
.bind(database_name) .bind(database_name)
.bind(user_name) .bind(user_name)
.fetch_optional(connection) .fetch_optional(connection)
.await; .await
if let Err(e) = &result {
log::error!(
"Failed to get database privileges for '{}.{}': {}",
&database_name,
&user_name,
e
);
}
result
} }
pub async fn get_databases_privilege_data( pub async fn get_databases_privilege_data(
@ -241,7 +220,7 @@ pub async fn get_all_database_privileges(
unix_user: &UnixUser, unix_user: &UnixUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
) -> GetAllDatabasesPrivilegeData { ) -> GetAllDatabasesPrivilegeData {
let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!( sqlx::query_as::<_, DatabasePrivilegeRow>(&format!(
indoc! {r#" indoc! {r#"
SELECT {} FROM `db` WHERE `db` IN SELECT {} FROM `db` WHERE `db` IN
(SELECT DISTINCT `SCHEMA_NAME` AS `database` (SELECT DISTINCT `SCHEMA_NAME` AS `database`
@ -257,20 +236,14 @@ pub async fn get_all_database_privileges(
.bind(create_user_group_matching_regex(unix_user)) .bind(create_user_group_matching_regex(unix_user))
.fetch_all(connection) .fetch_all(connection)
.await .await
.map_err(|e| GetAllDatabasesPrivilegeDataError::MySqlError(e.to_string())); .map_err(|e| GetAllDatabasesPrivilegeDataError::MySqlError(e.to_string()))
if let Err(e) = &result {
log::error!("Failed to get all database privileges: {:?}", e);
}
result
} }
async fn unsafe_apply_privilege_diff( async fn unsafe_apply_privilege_diff(
database_privilege_diff: &DatabasePrivilegesDiff, database_privilege_diff: &DatabasePrivilegesDiff,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
) -> Result<(), sqlx::Error> { ) -> Result<(), sqlx::Error> {
let result = match database_privilege_diff { match database_privilege_diff {
DatabasePrivilegesDiff::New(p) => { DatabasePrivilegesDiff::New(p) => {
let tables = DATABASE_PRIVILEGE_FIELDS let tables = DATABASE_PRIVILEGE_FIELDS
.iter() .iter()
@ -332,13 +305,7 @@ async fn unsafe_apply_privilege_diff(
.await .await
.map(|_| ()) .map(|_| ())
} }
};
if let Err(e) = &result {
log::error!("Failed to apply database privilege diff: {}", e);
} }
result
} }
async fn validate_diff( async fn validate_diff(

View File

@ -1,6 +1,6 @@
use indoc::formatdoc;
use itertools::Itertools; use itertools::Itertools;
use std::collections::BTreeMap; use std::collections::BTreeMap;
use indoc::formatdoc;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -29,7 +29,7 @@ async fn unsafe_user_exists(
db_user: &str, db_user: &str,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
) -> Result<bool, sqlx::Error> { ) -> Result<bool, sqlx::Error> {
let result = sqlx::query( sqlx::query(
r#" r#"
SELECT EXISTS( SELECT EXISTS(
SELECT 1 SELECT 1
@ -41,13 +41,7 @@ async fn unsafe_user_exists(
.bind(db_user) .bind(db_user)
.fetch_one(connection) .fetch_one(connection)
.await .await
.map(|row| row.get::<bool, _>(0)); .map(|row| row.get::<bool, _>(0))
if let Err(err) = &result {
log::error!("Failed to check if database user exists: {:?}", err);
}
result
} }
pub async fn create_database_users( pub async fn create_database_users(
@ -86,10 +80,6 @@ pub async fn create_database_users(
.map(|_| ()) .map(|_| ())
.map_err(|err| CreateUserError::MySqlError(err.to_string())); .map_err(|err| CreateUserError::MySqlError(err.to_string()));
if let Err(err) = &result {
log::error!("Failed to create database user '{}': {:?}", &db_user, err);
}
results.insert(db_user, result); results.insert(db_user, result);
} }
@ -132,10 +122,6 @@ pub async fn drop_database_users(
.map(|_| ()) .map(|_| ())
.map_err(|err| DropUserError::MySqlError(err.to_string())); .map_err(|err| DropUserError::MySqlError(err.to_string()));
if let Err(err) = &result {
log::error!("Failed to drop database user '{}': {:?}", &db_user, err);
}
results.insert(db_user, result); results.insert(db_user, result);
} }
@ -162,7 +148,7 @@ pub async fn set_password_for_database_user(
_ => {} _ => {}
} }
let result = sqlx::query( sqlx::query(
format!( format!(
"ALTER USER {}@'%' IDENTIFIED BY {}", "ALTER USER {}@'%' IDENTIFIED BY {}",
quote_literal(db_user), quote_literal(db_user),
@ -173,17 +159,7 @@ pub async fn set_password_for_database_user(
.execute(&mut *connection) .execute(&mut *connection)
.await .await
.map(|_| ()) .map(|_| ())
.map_err(|err| SetPasswordError::MySqlError(err.to_string())); .map_err(|err| SetPasswordError::MySqlError(err.to_string()))
if let Err(err) = &result {
log::error!(
"Failed to set password for database user '{}': {:?}",
&db_user,
err
);
}
result
} }
// NOTE: this function is unsafe because it does no input validation. // NOTE: this function is unsafe because it does no input validation.
@ -191,7 +167,7 @@ async fn database_user_is_locked_unsafe(
db_user: &str, db_user: &str,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
) -> Result<bool, sqlx::Error> { ) -> Result<bool, sqlx::Error> {
let result = sqlx::query( sqlx::query(
r#" r#"
SELECT COALESCE( SELECT COALESCE(
JSON_EXTRACT(`mysql`.`global_priv`.`priv`, "$.account_locked"), JSON_EXTRACT(`mysql`.`global_priv`.`priv`, "$.account_locked"),
@ -205,17 +181,7 @@ async fn database_user_is_locked_unsafe(
.bind(db_user) .bind(db_user)
.fetch_one(connection) .fetch_one(connection)
.await .await
.map(|row| row.get::<bool, _>(0)); .map(|row| row.get::<bool, _>(0))
if let Err(err) = &result {
log::error!(
"Failed to check if database user is locked '{}': {:?}",
&db_user,
err
);
}
result
} }
pub async fn lock_database_users( pub async fn lock_database_users(
@ -268,10 +234,6 @@ pub async fn lock_database_users(
.map(|_| ()) .map(|_| ())
.map_err(|err| LockUserError::MySqlError(err.to_string())); .map_err(|err| LockUserError::MySqlError(err.to_string()));
if let Err(err) = &result {
log::error!("Failed to lock database user '{}': {:?}", &db_user, err);
}
results.insert(db_user, result); results.insert(db_user, result);
} }
@ -328,10 +290,6 @@ pub async fn unlock_database_users(
.map(|_| ()) .map(|_| ())
.map_err(|err| UnlockUserError::MySqlError(err.to_string())); .map_err(|err| UnlockUserError::MySqlError(err.to_string()));
if let Err(err) = &result {
log::error!("Failed to unlock database user '{}': {:?}", &db_user, err);
}
results.insert(db_user, result); results.insert(db_user, result);
} }
@ -340,55 +298,39 @@ pub async fn unlock_database_users(
/// This struct contains information about a database user. /// This struct contains information about a database user.
/// This can be extended if we need more information in the future. /// This can be extended if we need more information in the future.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, FromRow)]
pub struct DatabaseUser { pub struct DatabaseUser {
#[sqlx(rename = "User")]
pub user: String, pub user: String,
#[allow(dead_code)]
#[serde(skip)] #[serde(skip)]
#[sqlx(rename = "Host")]
pub host: String, pub host: String,
#[sqlx(rename = "has_password")]
pub has_password: bool, pub has_password: bool,
#[sqlx(rename = "is_locked")]
pub is_locked: bool, pub is_locked: bool,
#[sqlx(skip)]
pub databases: Vec<String>, pub databases: Vec<String>,
} }
/// Some mysql versions with some collations mark some columns as binary fields,
/// which in the current version of sqlx is not parsable as string.
/// See: https://github.com/launchbadge/sqlx/issues/3387
#[inline]
fn try_get_with_binary_fallback(
row: &sqlx::mysql::MySqlRow,
column: &str,
) -> Result<String, sqlx::Error> {
row.try_get(column).or_else(|_| {
row.try_get::<Vec<u8>, _>(column)
.map(|v| String::from_utf8_lossy(&v).to_string())
})
}
impl FromRow<'_, sqlx::mysql::MySqlRow> for DatabaseUser {
fn from_row(row: &sqlx::mysql::MySqlRow) -> Result<Self, sqlx::Error> {
Ok(Self {
user: try_get_with_binary_fallback(row, "User")?,
host: try_get_with_binary_fallback(row, "Host")?,
has_password: row.try_get("has_password")?,
is_locked: row.try_get("is_locked")?,
databases: Vec::new(),
})
}
}
const DB_USER_SELECT_STATEMENT: &str = r#" const DB_USER_SELECT_STATEMENT: &str = r#"
SELECT SELECT
`user`.`User`, `mysql`.`user`.`User`,
`user`.`Host`, `mysql`.`user`.`Host`,
`user`.`Password` != '' OR `user`.`authentication_string` != '' AS `has_password`, `mysql`.`user`.`Password` != '' OR `mysql`.`user`.`authentication_string` != '' AS `has_password`,
COALESCE( COALESCE(
JSON_EXTRACT(`global_priv`.`priv`, "$.account_locked"), JSON_EXTRACT(`mysql`.`global_priv`.`priv`, "$.account_locked"),
'false' 'false'
) != 'false' AS `is_locked` ) != 'false' AS `is_locked`
FROM `user` FROM `mysql`.`user`
JOIN `global_priv` ON JOIN `mysql`.`global_priv` ON
`user`.`User` = `global_priv`.`User` `mysql`.`user`.`User` = `mysql`.`global_priv`.`User`
AND `user`.`Host` = `global_priv`.`Host` AND `mysql`.`user`.`Host` = `mysql`.`global_priv`.`Host`
"#; "#;
pub async fn list_database_users( pub async fn list_database_users(
@ -416,10 +358,6 @@ pub async fn list_database_users(
.fetch_optional(&mut *connection) .fetch_optional(&mut *connection)
.await; .await;
if let Err(err) = &result {
log::error!("Failed to list database user '{}': {:?}", &db_user, err);
}
if let Ok(Some(user)) = result.as_mut() { if let Ok(Some(user)) = result.as_mut() {
append_databases_where_user_has_privileges(user, &mut *connection).await; append_databases_where_user_has_privileges(user, &mut *connection).await;
} }
@ -439,17 +377,13 @@ pub async fn list_all_database_users_for_unix_user(
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
) -> ListAllUsersOutput { ) -> ListAllUsersOutput {
let mut result = sqlx::query_as::<_, DatabaseUser>( let mut result = sqlx::query_as::<_, DatabaseUser>(
&(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `user`.`User` REGEXP ?"), &(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `mysql`.`user`.`User` REGEXP ?"),
) )
.bind(create_user_group_matching_regex(unix_user)) .bind(create_user_group_matching_regex(unix_user))
.fetch_all(&mut *connection) .fetch_all(&mut *connection)
.await .await
.map_err(|err| ListAllUsersError::MySqlError(err.to_string())); .map_err(|err| ListAllUsersError::MySqlError(err.to_string()));
if let Err(err) = &result {
log::error!("Failed to list all database users: {:?}", err);
}
if let Ok(users) = result.as_mut() { if let Ok(users) = result.as_mut() {
for user in users { for user in users {
append_databases_where_user_has_privileges(user, &mut *connection).await; append_databases_where_user_has_privileges(user, &mut *connection).await;
@ -460,15 +394,15 @@ pub async fn list_all_database_users_for_unix_user(
} }
pub async fn append_databases_where_user_has_privileges( pub async fn append_databases_where_user_has_privileges(
db_user: &mut DatabaseUser, database_user: &mut DatabaseUser,
connection: &mut MySqlConnection, connection: &mut MySqlConnection,
) { ) {
let database_list = sqlx::query( let database_list = sqlx::query(
formatdoc!( formatdoc!(
r#" r#"
SELECT `Db` AS `database` SELECT `db` AS `database`
FROM `db` FROM `db`
WHERE `User` = ? AND ({}) WHERE `user` = ? AND ({})
"#, "#,
DATABASE_PRIVILEGE_FIELDS DATABASE_PRIVILEGE_FIELDS
.iter() .iter()
@ -477,22 +411,14 @@ pub async fn append_databases_where_user_has_privileges(
) )
.as_str(), .as_str(),
) )
.bind(db_user.user.clone()) .bind(database_user.user.clone())
.fetch_all(&mut *connection) .fetch_all(&mut *connection)
.await; .await;
if let Err(err) = &database_list { database_user.databases = database_list
log::error!(
"Failed to list databases for user '{}': {:?}",
&db_user.user,
err
);
}
db_user.databases = database_list
.map(|rows| { .map(|rows| {
rows.into_iter() rows.into_iter()
.map(|row| try_get_with_binary_fallback(&row, "database").unwrap()) .map(|row| row.get::<String, _>("database"))
.collect() .collect()
}) })
.unwrap_or_default(); .unwrap_or_default();