server: log once per request, add session ids
This commit is contained in:
@@ -22,7 +22,7 @@ use crate::{
|
||||
authorization::read_and_parse_group_denylist,
|
||||
config::{MysqlConfig, ServerConfig},
|
||||
landlock::landlock_restrict_server,
|
||||
session_handler,
|
||||
session_handler::{self, SessionId},
|
||||
},
|
||||
};
|
||||
|
||||
@@ -308,9 +308,11 @@ fn run_forked_server(
|
||||
version_row.to_lowercase().contains("mariadb")
|
||||
};
|
||||
|
||||
let session_id = SessionId::new(0);
|
||||
let db_pool = Arc::new(RwLock::new(db_pool));
|
||||
session_handler::session_handler_with_unix_user(
|
||||
socket,
|
||||
session_id,
|
||||
unix_user,
|
||||
db_pool,
|
||||
db_is_mariadb,
|
||||
|
||||
@@ -24,6 +24,7 @@ pub const KIND_REGARDS: &str = concat!(
|
||||
"If you experience any bugs or turbulence, please give us a heads up :)",
|
||||
);
|
||||
|
||||
/// TODO: store and display UID
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct UnixUser {
|
||||
pub username: String,
|
||||
|
||||
@@ -17,8 +17,6 @@ mod modify_privileges;
|
||||
mod passwd_user;
|
||||
mod unlock_users;
|
||||
|
||||
use std::collections::BTreeSet;
|
||||
|
||||
pub use check_authorization::*;
|
||||
pub use complete_database_name::*;
|
||||
pub use complete_user_name::*;
|
||||
@@ -38,6 +36,9 @@ pub use modify_privileges::*;
|
||||
pub use passwd_user::*;
|
||||
pub use unlock_users::*;
|
||||
|
||||
use std::collections::BTreeSet;
|
||||
use std::fmt;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::net::UnixStream;
|
||||
use tokio_serde::{Framed as SerdeFramed, formats::Bincode};
|
||||
@@ -109,6 +110,7 @@ pub enum Request {
|
||||
}
|
||||
|
||||
impl Request {
|
||||
/// Get the command name associated with this request.
|
||||
pub fn command_name(&self) -> &str {
|
||||
match self {
|
||||
Request::CheckAuthorization(_) => "check-authorization",
|
||||
@@ -130,6 +132,43 @@ impl Request {
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate a short summary string representing this request for logging purposes.
|
||||
pub fn log_summary(&self) -> String {
|
||||
match self {
|
||||
Request::CheckAuthorization(req) => format!("{}({})", self.command_name(), req.len()),
|
||||
|
||||
Request::CreateDatabases(req) => format!("{}({})", self.command_name(), req.len()),
|
||||
Request::DropDatabases(req) => format!("{}({})", self.command_name(), req.len()),
|
||||
Request::ListDatabases(req) => format!(
|
||||
"{}{}",
|
||||
self.command_name(),
|
||||
req.as_ref()
|
||||
.map_or("".to_string(), |r| format!("({})", r.len()))
|
||||
),
|
||||
Request::ListPrivileges(req) => format!(
|
||||
"{}{}",
|
||||
self.command_name(),
|
||||
req.as_ref()
|
||||
.map_or("".to_string(), |r| format!("({})", r.len()))
|
||||
),
|
||||
Request::ModifyPrivileges(req) => format!("{}({})", self.command_name(), req.len()),
|
||||
|
||||
Request::CreateUsers(req) => format!("{}({})", self.command_name(), req.len()),
|
||||
Request::DropUsers(req) => format!("{}({})", self.command_name(), req.len()),
|
||||
Request::ListUsers(req) => format!(
|
||||
"{}{}",
|
||||
self.command_name(),
|
||||
req.as_ref()
|
||||
.map_or("".to_string(), |r| format!("({})", r.len()))
|
||||
),
|
||||
Request::LockUsers(req) => format!("{}({})", self.command_name(), req.len()),
|
||||
Request::UnlockUsers(req) => format!("{}({})", self.command_name(), req.len()),
|
||||
|
||||
_ => self.command_name().to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the set of users affected by this request.
|
||||
pub fn affected_users(&self) -> BTreeSet<MySQLUser> {
|
||||
match self {
|
||||
Request::CheckAuthorization(_) => Default::default(),
|
||||
@@ -158,6 +197,7 @@ impl Request {
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the set of databases affected by this request.
|
||||
pub fn affected_databases(&self) -> BTreeSet<MySQLDatabase> {
|
||||
match self {
|
||||
Request::CheckAuthorization(_) => Default::default(),
|
||||
@@ -219,3 +259,95 @@ pub enum Response {
|
||||
Ready,
|
||||
Error(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub enum ResponseOkStatus {
|
||||
Success,
|
||||
PartialSuccess(usize, usize), // succeeded, total
|
||||
Error,
|
||||
}
|
||||
|
||||
impl ResponseOkStatus {
|
||||
pub fn from_counts(total: usize, succeeded: usize) -> Self {
|
||||
if succeeded == total {
|
||||
ResponseOkStatus::Success
|
||||
} else if succeeded == 0 {
|
||||
ResponseOkStatus::Error
|
||||
} else {
|
||||
ResponseOkStatus::PartialSuccess(succeeded, total)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_bool(is_ok: bool) -> Self {
|
||||
if is_ok {
|
||||
ResponseOkStatus::Success
|
||||
} else {
|
||||
ResponseOkStatus::Error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for ResponseOkStatus {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
ResponseOkStatus::Success => write!(f, "OK"),
|
||||
ResponseOkStatus::PartialSuccess(succeeded, total) => {
|
||||
write!(f, "PARTIAL_OK({}/{})", succeeded, total)
|
||||
}
|
||||
ResponseOkStatus::Error => write!(f, "ERR"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Response {
|
||||
pub fn ok_status(&self) -> ResponseOkStatus {
|
||||
match self {
|
||||
Response::CheckAuthorization(res) => {
|
||||
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
|
||||
}
|
||||
|
||||
Response::ListValidNamePrefixes(_) => ResponseOkStatus::Success,
|
||||
Response::CompleteDatabaseName(_) => ResponseOkStatus::Success,
|
||||
Response::CompleteUserName(_) => ResponseOkStatus::Success,
|
||||
|
||||
Response::CreateDatabases(res) => {
|
||||
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
|
||||
}
|
||||
Response::DropDatabases(res) => {
|
||||
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
|
||||
}
|
||||
Response::ListDatabases(res) => {
|
||||
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
|
||||
}
|
||||
Response::ListAllDatabases(res) => ResponseOkStatus::from_bool(res.is_ok()),
|
||||
Response::ListPrivileges(res) => {
|
||||
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
|
||||
}
|
||||
Response::ListAllPrivileges(res) => ResponseOkStatus::from_bool(res.is_ok()),
|
||||
Response::ModifyPrivileges(res) => {
|
||||
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
|
||||
}
|
||||
|
||||
Response::CreateUsers(res) => {
|
||||
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
|
||||
}
|
||||
Response::DropUsers(res) => {
|
||||
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
|
||||
}
|
||||
Response::SetUserPassword(res) => ResponseOkStatus::from_bool(res.is_ok()),
|
||||
Response::ListUsers(res) => {
|
||||
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
|
||||
}
|
||||
Response::ListAllUsers(res) => ResponseOkStatus::from_bool(res.is_ok()),
|
||||
Response::LockUsers(res) => {
|
||||
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
|
||||
}
|
||||
Response::UnlockUsers(res) => {
|
||||
ResponseOkStatus::from_counts(res.len(), res.values().filter(|v| v.is_ok()).count())
|
||||
}
|
||||
|
||||
Response::Ready => ResponseOkStatus::Success,
|
||||
Response::Error(_) => ResponseOkStatus::Error,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,23 +13,19 @@ use crate::core::{
|
||||
};
|
||||
|
||||
pub async fn check_authorization(
|
||||
dbs_or_users: Vec<DbOrUser>,
|
||||
dbs_or_users: &[DbOrUser],
|
||||
unix_user: &UnixUser,
|
||||
group_denylist: &GroupDenylist,
|
||||
) -> std::collections::BTreeMap<DbOrUser, Result<(), CheckAuthorizationError>> {
|
||||
let mut results = std::collections::BTreeMap::new();
|
||||
|
||||
for db_or_user in dbs_or_users {
|
||||
if let Err(err) = validate_db_or_user_request(&db_or_user, unix_user, group_denylist)
|
||||
.map_err(CheckAuthorizationError)
|
||||
{
|
||||
results.insert(db_or_user.clone(), Err(err));
|
||||
continue;
|
||||
}
|
||||
results.insert(db_or_user.clone(), Ok(()));
|
||||
}
|
||||
|
||||
results
|
||||
dbs_or_users
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(|db_or_user| {
|
||||
let result = validate_db_or_user_request(&db_or_user, unix_user, group_denylist)
|
||||
.map_err(CheckAuthorizationError);
|
||||
(db_or_user, result)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Reads and parses a group denylist file, returning a set of GUIDs
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use std::{collections::BTreeSet, sync::Arc};
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use indoc::concatdoc;
|
||||
@@ -35,10 +35,24 @@ use crate::{
|
||||
},
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub struct SessionId(u64);
|
||||
|
||||
impl SessionId {
|
||||
pub fn new(id: u64) -> Self {
|
||||
SessionId(id)
|
||||
}
|
||||
|
||||
pub fn inner(&self) -> u64 {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: don't use database connection unless necessary.
|
||||
|
||||
pub async fn session_handler(
|
||||
socket: UnixStream,
|
||||
session_id: SessionId,
|
||||
db_pool: Arc<RwLock<MySqlPool>>,
|
||||
db_is_mariadb: bool,
|
||||
group_denylist: &GroupDenylist,
|
||||
@@ -83,13 +97,18 @@ pub async fn session_handler(
|
||||
}
|
||||
};
|
||||
|
||||
let span = tracing::info_span!("user_session", user = %unix_user);
|
||||
let span = tracing::info_span!(
|
||||
"user_session",
|
||||
session_id = session_id.inner(),
|
||||
user = %unix_user,
|
||||
);
|
||||
|
||||
(async move {
|
||||
tracing::info!("Accepted connection from user: {}", unix_user);
|
||||
tracing::debug!("Accepted connection from user: {}", unix_user);
|
||||
|
||||
let result = session_handler_with_unix_user(
|
||||
socket,
|
||||
session_id,
|
||||
&unix_user,
|
||||
db_pool,
|
||||
db_is_mariadb,
|
||||
@@ -97,7 +116,7 @@ pub async fn session_handler(
|
||||
)
|
||||
.await;
|
||||
|
||||
tracing::info!(
|
||||
tracing::debug!(
|
||||
"Finished handling requests for connection from user: {}",
|
||||
unix_user,
|
||||
);
|
||||
@@ -110,6 +129,7 @@ pub async fn session_handler(
|
||||
|
||||
pub async fn session_handler_with_unix_user(
|
||||
socket: UnixStream,
|
||||
session_id: SessionId,
|
||||
unix_user: &UnixUser,
|
||||
db_pool: Arc<RwLock<MySqlPool>>,
|
||||
db_is_mariadb: bool,
|
||||
@@ -138,6 +158,7 @@ pub async fn session_handler_with_unix_user(
|
||||
|
||||
let result = session_handler_with_db_connection(
|
||||
message_stream,
|
||||
session_id,
|
||||
unix_user,
|
||||
&mut db_connection,
|
||||
db_is_mariadb,
|
||||
@@ -155,6 +176,7 @@ pub async fn session_handler_with_unix_user(
|
||||
|
||||
async fn session_handler_with_db_connection(
|
||||
mut stream: ServerToClientMessageStream,
|
||||
session_id: SessionId,
|
||||
unix_user: &UnixUser,
|
||||
db_connection: &mut MySqlConnection,
|
||||
db_is_mariadb: bool,
|
||||
@@ -178,6 +200,7 @@ async fn session_handler_with_db_connection(
|
||||
|
||||
if !handle_request(
|
||||
request,
|
||||
session_id,
|
||||
unix_user,
|
||||
db_connection,
|
||||
db_is_mariadb,
|
||||
@@ -199,6 +222,7 @@ async fn session_handler_with_db_connection(
|
||||
/// If the function returns `true`, the session should continue.
|
||||
async fn handle_request(
|
||||
request: Request,
|
||||
session_id: SessionId,
|
||||
unix_user: &UnixUser,
|
||||
db_connection: &mut MySqlConnection,
|
||||
db_is_mariadb: bool,
|
||||
@@ -207,11 +231,11 @@ async fn handle_request(
|
||||
) -> anyhow::Result<bool> {
|
||||
match &request {
|
||||
Request::Exit => tracing::debug!("Received request: {:#?}", request),
|
||||
Request::PasswdUser((db_user, _)) => tracing::info!(
|
||||
Request::PasswdUser((db_user, _)) => tracing::debug!(
|
||||
"Received request: {:#?}",
|
||||
Request::PasswdUser((db_user.to_owned(), "<REDACTED>".to_string()))
|
||||
),
|
||||
request => tracing::info!("Received request: {:#?}", request),
|
||||
request => tracing::debug!("Request:\n{}", serde_json::to_string_pretty(request)?),
|
||||
}
|
||||
|
||||
let affected_dbs = request.affected_databases();
|
||||
@@ -231,7 +255,7 @@ async fn handle_request(
|
||||
}
|
||||
|
||||
let response = match request {
|
||||
Request::CheckAuthorization(dbs_or_users) => {
|
||||
Request::CheckAuthorization(ref dbs_or_users) => {
|
||||
let result = check_authorization(dbs_or_users, unix_user, group_denylist).await;
|
||||
Response::CheckAuthorization(result)
|
||||
}
|
||||
@@ -245,7 +269,7 @@ async fn handle_request(
|
||||
|
||||
Response::ListValidNamePrefixes(result)
|
||||
}
|
||||
Request::CompleteDatabaseName(partial_database_name) => {
|
||||
Request::CompleteDatabaseName(ref partial_database_name) => {
|
||||
// TODO: more correct validation here
|
||||
if partial_database_name
|
||||
.chars()
|
||||
@@ -264,7 +288,7 @@ async fn handle_request(
|
||||
Response::CompleteDatabaseName(vec![])
|
||||
}
|
||||
}
|
||||
Request::CompleteUserName(partial_user_name) => {
|
||||
Request::CompleteUserName(ref partial_user_name) => {
|
||||
// TODO: more correct validation here
|
||||
if partial_user_name
|
||||
.chars()
|
||||
@@ -283,7 +307,7 @@ async fn handle_request(
|
||||
Response::CompleteUserName(vec![])
|
||||
}
|
||||
}
|
||||
Request::CreateDatabases(databases_names) => {
|
||||
Request::CreateDatabases(ref databases_names) => {
|
||||
let result = create_databases(
|
||||
databases_names,
|
||||
unix_user,
|
||||
@@ -294,7 +318,7 @@ async fn handle_request(
|
||||
.await;
|
||||
Response::CreateDatabases(result)
|
||||
}
|
||||
Request::DropDatabases(databases_names) => {
|
||||
Request::DropDatabases(ref databases_names) => {
|
||||
let result = drop_databases(
|
||||
databases_names,
|
||||
unix_user,
|
||||
@@ -305,7 +329,7 @@ async fn handle_request(
|
||||
.await;
|
||||
Response::DropDatabases(result)
|
||||
}
|
||||
Request::ListDatabases(database_names) => {
|
||||
Request::ListDatabases(ref database_names) => {
|
||||
if let Some(database_names) = database_names {
|
||||
let result = list_databases(
|
||||
database_names,
|
||||
@@ -327,7 +351,7 @@ async fn handle_request(
|
||||
Response::ListAllDatabases(result)
|
||||
}
|
||||
}
|
||||
Request::ListPrivileges(database_names) => {
|
||||
Request::ListPrivileges(ref database_names) => {
|
||||
if let Some(database_names) = database_names {
|
||||
let privilege_data = get_databases_privilege_data(
|
||||
database_names,
|
||||
@@ -349,9 +373,9 @@ async fn handle_request(
|
||||
Response::ListAllPrivileges(privilege_data)
|
||||
}
|
||||
}
|
||||
Request::ModifyPrivileges(database_privilege_diffs) => {
|
||||
Request::ModifyPrivileges(ref database_privilege_diffs) => {
|
||||
let result = apply_privilege_diffs(
|
||||
BTreeSet::from_iter(database_privilege_diffs),
|
||||
database_privilege_diffs,
|
||||
unix_user,
|
||||
db_connection,
|
||||
db_is_mariadb,
|
||||
@@ -360,7 +384,7 @@ async fn handle_request(
|
||||
.await;
|
||||
Response::ModifyPrivileges(result)
|
||||
}
|
||||
Request::CreateUsers(db_users) => {
|
||||
Request::CreateUsers(ref db_users) => {
|
||||
let result = create_database_users(
|
||||
db_users,
|
||||
unix_user,
|
||||
@@ -371,7 +395,7 @@ async fn handle_request(
|
||||
.await;
|
||||
Response::CreateUsers(result)
|
||||
}
|
||||
Request::DropUsers(db_users) => {
|
||||
Request::DropUsers(ref db_users) => {
|
||||
let result = drop_database_users(
|
||||
db_users,
|
||||
unix_user,
|
||||
@@ -382,10 +406,10 @@ async fn handle_request(
|
||||
.await;
|
||||
Response::DropUsers(result)
|
||||
}
|
||||
Request::PasswdUser((db_user, password)) => {
|
||||
Request::PasswdUser((ref db_user, ref password)) => {
|
||||
let result = set_password_for_database_user(
|
||||
&db_user,
|
||||
&password,
|
||||
db_user,
|
||||
password,
|
||||
unix_user,
|
||||
db_connection,
|
||||
db_is_mariadb,
|
||||
@@ -394,7 +418,7 @@ async fn handle_request(
|
||||
.await;
|
||||
Response::SetUserPassword(result)
|
||||
}
|
||||
Request::ListUsers(db_users) => {
|
||||
Request::ListUsers(ref db_users) => {
|
||||
if let Some(db_users) = db_users {
|
||||
let result = list_database_users(
|
||||
db_users,
|
||||
@@ -416,7 +440,7 @@ async fn handle_request(
|
||||
Response::ListAllUsers(result)
|
||||
}
|
||||
}
|
||||
Request::LockUsers(db_users) => {
|
||||
Request::LockUsers(ref db_users) => {
|
||||
let result = lock_database_users(
|
||||
db_users,
|
||||
unix_user,
|
||||
@@ -427,7 +451,7 @@ async fn handle_request(
|
||||
.await;
|
||||
Response::LockUsers(result)
|
||||
}
|
||||
Request::UnlockUsers(db_users) => {
|
||||
Request::UnlockUsers(ref db_users) => {
|
||||
let result = unlock_database_users(
|
||||
db_users,
|
||||
unix_user,
|
||||
@@ -449,7 +473,12 @@ async fn handle_request(
|
||||
}
|
||||
response => response,
|
||||
};
|
||||
tracing::debug!("Response: {:#?}", response_to_display);
|
||||
tracing::debug!(
|
||||
"Response:\n{}",
|
||||
serde_json::to_string_pretty(&response_to_display)?
|
||||
);
|
||||
|
||||
log_request(session_id, unix_user, &request, &response);
|
||||
|
||||
stream.send(response).await?;
|
||||
stream.flush().await?;
|
||||
@@ -457,3 +486,18 @@ async fn handle_request(
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
/// Log a summary of the request and its result.
|
||||
fn log_request(
|
||||
session_id: SessionId,
|
||||
unix_user: &UnixUser,
|
||||
request: &Request,
|
||||
response: &Response,
|
||||
) {
|
||||
tracing::info!(
|
||||
"[{}|session:{}|user:{unix_user}] {}",
|
||||
response.ok_status(),
|
||||
session_id.inner(),
|
||||
request.log_summary(),
|
||||
);
|
||||
}
|
||||
|
||||
@@ -46,7 +46,7 @@ pub(super) async fn unsafe_database_exists(
|
||||
}
|
||||
|
||||
pub async fn complete_database_name(
|
||||
database_prefix: String,
|
||||
database_prefix: &str,
|
||||
unix_user: &UnixUser,
|
||||
connection: &mut MySqlConnection,
|
||||
_db_is_mariadb: bool,
|
||||
@@ -87,7 +87,7 @@ pub async fn complete_database_name(
|
||||
}
|
||||
|
||||
pub async fn create_databases(
|
||||
database_names: Vec<MySQLDatabase>,
|
||||
database_names: &[MySQLDatabase],
|
||||
unix_user: &UnixUser,
|
||||
connection: &mut MySqlConnection,
|
||||
_db_is_mariadb: bool,
|
||||
@@ -95,7 +95,7 @@ pub async fn create_databases(
|
||||
) -> CreateDatabasesResponse {
|
||||
let mut results = BTreeMap::new();
|
||||
|
||||
for database_name in database_names {
|
||||
for database_name in database_names.iter().cloned() {
|
||||
if let Err(err) = validate_db_or_user_request(
|
||||
&DbOrUser::Database(database_name.clone()),
|
||||
unix_user,
|
||||
@@ -143,7 +143,7 @@ pub async fn create_databases(
|
||||
}
|
||||
|
||||
pub async fn drop_databases(
|
||||
database_names: Vec<MySQLDatabase>,
|
||||
database_names: &[MySQLDatabase],
|
||||
unix_user: &UnixUser,
|
||||
connection: &mut MySqlConnection,
|
||||
_db_is_mariadb: bool,
|
||||
@@ -151,7 +151,7 @@ pub async fn drop_databases(
|
||||
) -> DropDatabasesResponse {
|
||||
let mut results = BTreeMap::new();
|
||||
|
||||
for database_name in database_names {
|
||||
for database_name in database_names.iter().cloned() {
|
||||
if let Err(err) = validate_db_or_user_request(
|
||||
&DbOrUser::Database(database_name.clone()),
|
||||
unix_user,
|
||||
@@ -242,7 +242,7 @@ impl FromRow<'_, sqlx::mysql::MySqlRow> for DatabaseRow {
|
||||
}
|
||||
|
||||
pub async fn list_databases(
|
||||
database_names: Vec<MySQLDatabase>,
|
||||
database_names: &[MySQLDatabase],
|
||||
unix_user: &UnixUser,
|
||||
connection: &mut MySqlConnection,
|
||||
_db_is_mariadb: bool,
|
||||
@@ -250,7 +250,7 @@ pub async fn list_databases(
|
||||
) -> ListDatabasesResponse {
|
||||
let mut results = BTreeMap::new();
|
||||
|
||||
for database_name in database_names {
|
||||
for database_name in database_names.iter().cloned() {
|
||||
if let Err(err) = validate_db_or_user_request(
|
||||
&DbOrUser::Database(database_name.clone()),
|
||||
unix_user,
|
||||
|
||||
@@ -138,7 +138,7 @@ pub async fn unsafe_get_database_privileges_for_db_user_pair(
|
||||
}
|
||||
|
||||
pub async fn get_databases_privilege_data(
|
||||
database_names: Vec<MySQLDatabase>,
|
||||
database_names: &[MySQLDatabase],
|
||||
unix_user: &UnixUser,
|
||||
connection: &mut MySqlConnection,
|
||||
_db_is_mariadb: bool,
|
||||
@@ -146,19 +146,19 @@ pub async fn get_databases_privilege_data(
|
||||
) -> ListPrivilegesResponse {
|
||||
let mut results = BTreeMap::new();
|
||||
|
||||
for database_name in &database_names {
|
||||
for database_name in database_names.iter().cloned() {
|
||||
if let Err(err) = validate_db_or_user_request(
|
||||
&DbOrUser::Database(database_name.clone()),
|
||||
&DbOrUser::Database(database_name.to_owned()),
|
||||
unix_user,
|
||||
group_denylist,
|
||||
)
|
||||
.map_err(ListPrivilegesError::ValidationError)
|
||||
{
|
||||
results.insert(database_name.to_owned(), Err(err));
|
||||
results.insert(database_name, Err(err));
|
||||
continue;
|
||||
}
|
||||
|
||||
match unsafe_database_exists(database_name, connection).await {
|
||||
match unsafe_database_exists(&database_name, connection).await {
|
||||
Ok(false) => {
|
||||
results.insert(
|
||||
database_name.to_owned(),
|
||||
@@ -176,7 +176,7 @@ pub async fn get_databases_privilege_data(
|
||||
Ok(true) => {}
|
||||
}
|
||||
|
||||
let result = unsafe_get_database_privileges(database_name, connection)
|
||||
let result = unsafe_get_database_privileges(&database_name, connection)
|
||||
.await
|
||||
.map_err(|e| ListPrivilegesError::MySqlError(e.to_string()));
|
||||
|
||||
@@ -400,7 +400,7 @@ async fn validate_diff(
|
||||
|
||||
/// Uses the result of [`diff_privileges`] to modify privileges in the database.
|
||||
pub async fn apply_privilege_diffs(
|
||||
database_privilege_diffs: BTreeSet<DatabasePrivilegesDiff>,
|
||||
database_privilege_diffs: &BTreeSet<DatabasePrivilegesDiff>,
|
||||
unix_user: &UnixUser,
|
||||
connection: &mut MySqlConnection,
|
||||
_db_is_mariadb: bool,
|
||||
@@ -468,12 +468,12 @@ pub async fn apply_privilege_diffs(
|
||||
Ok(true) => {}
|
||||
}
|
||||
|
||||
if let Err(err) = validate_diff(&diff, connection).await {
|
||||
if let Err(err) = validate_diff(diff, connection).await {
|
||||
results.insert(key, Err(err));
|
||||
continue;
|
||||
}
|
||||
|
||||
let result = unsafe_apply_privilege_diff(&diff, connection)
|
||||
let result = unsafe_apply_privilege_diff(diff, connection)
|
||||
.await
|
||||
.map_err(|e| ModifyDatabasePrivilegesError::MySqlError(e.to_string()));
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ pub(super) async fn unsafe_user_exists(
|
||||
}
|
||||
|
||||
pub async fn complete_user_name(
|
||||
user_prefix: String,
|
||||
user_prefix: &str,
|
||||
unix_user: &UnixUser,
|
||||
connection: &mut MySqlConnection,
|
||||
_db_is_mariadb: bool,
|
||||
@@ -95,7 +95,7 @@ pub async fn complete_user_name(
|
||||
}
|
||||
|
||||
pub async fn create_database_users(
|
||||
db_users: Vec<MySQLUser>,
|
||||
db_users: &[MySQLUser],
|
||||
unix_user: &UnixUser,
|
||||
connection: &mut MySqlConnection,
|
||||
_db_is_mariadb: bool,
|
||||
@@ -103,7 +103,7 @@ pub async fn create_database_users(
|
||||
) -> CreateUsersResponse {
|
||||
let mut results = BTreeMap::new();
|
||||
|
||||
for db_user in db_users {
|
||||
for db_user in db_users.iter().cloned() {
|
||||
if let Err(err) =
|
||||
validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
|
||||
.map_err(CreateUserError::ValidationError)
|
||||
@@ -141,7 +141,7 @@ pub async fn create_database_users(
|
||||
}
|
||||
|
||||
pub async fn drop_database_users(
|
||||
db_users: Vec<MySQLUser>,
|
||||
db_users: &[MySQLUser],
|
||||
unix_user: &UnixUser,
|
||||
connection: &mut MySqlConnection,
|
||||
_db_is_mariadb: bool,
|
||||
@@ -149,7 +149,7 @@ pub async fn drop_database_users(
|
||||
) -> DropUsersResponse {
|
||||
let mut results = BTreeMap::new();
|
||||
|
||||
for db_user in db_users {
|
||||
for db_user in db_users.iter().cloned() {
|
||||
if let Err(err) =
|
||||
validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
|
||||
.map_err(DropUserError::ValidationError)
|
||||
@@ -272,7 +272,7 @@ async fn database_user_is_locked_unsafe(
|
||||
}
|
||||
|
||||
pub async fn lock_database_users(
|
||||
db_users: Vec<MySQLUser>,
|
||||
db_users: &[MySQLUser],
|
||||
unix_user: &UnixUser,
|
||||
connection: &mut MySqlConnection,
|
||||
db_is_mariadb: bool,
|
||||
@@ -280,7 +280,7 @@ pub async fn lock_database_users(
|
||||
) -> LockUsersResponse {
|
||||
let mut results = BTreeMap::new();
|
||||
|
||||
for db_user in db_users {
|
||||
for db_user in db_users.iter().cloned() {
|
||||
if let Err(err) =
|
||||
validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
|
||||
.map_err(LockUserError::ValidationError)
|
||||
@@ -332,7 +332,7 @@ pub async fn lock_database_users(
|
||||
}
|
||||
|
||||
pub async fn unlock_database_users(
|
||||
db_users: Vec<MySQLUser>,
|
||||
db_users: &[MySQLUser],
|
||||
unix_user: &UnixUser,
|
||||
connection: &mut MySqlConnection,
|
||||
db_is_mariadb: bool,
|
||||
@@ -340,7 +340,7 @@ pub async fn unlock_database_users(
|
||||
) -> UnlockUsersResponse {
|
||||
let mut results = BTreeMap::new();
|
||||
|
||||
for db_user in db_users {
|
||||
for db_user in db_users.iter().cloned() {
|
||||
if let Err(err) =
|
||||
validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
|
||||
.map_err(UnlockUserError::ValidationError)
|
||||
@@ -440,7 +440,7 @@ FROM `user`
|
||||
";
|
||||
|
||||
pub async fn list_database_users(
|
||||
db_users: Vec<MySQLUser>,
|
||||
db_users: &[MySQLUser],
|
||||
unix_user: &UnixUser,
|
||||
connection: &mut MySqlConnection,
|
||||
db_is_mariadb: bool,
|
||||
@@ -448,7 +448,7 @@ pub async fn list_database_users(
|
||||
) -> ListUsersResponse {
|
||||
let mut results = BTreeMap::new();
|
||||
|
||||
for db_user in db_users {
|
||||
for db_user in db_users.iter().cloned() {
|
||||
if let Err(err) =
|
||||
validate_db_or_user_request(&DbOrUser::User(db_user.clone()), unix_user, group_denylist)
|
||||
.map_err(ListUsersError::ValidationError)
|
||||
|
||||
@@ -2,7 +2,10 @@ use std::{
|
||||
fs,
|
||||
os::{fd::FromRawFd, unix::net::UnixListener as StdUnixListener},
|
||||
path::PathBuf,
|
||||
sync::Arc,
|
||||
sync::{
|
||||
Arc,
|
||||
atomic::{AtomicU64, Ordering},
|
||||
},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
@@ -22,7 +25,7 @@ use crate::{
|
||||
server::{
|
||||
authorization::read_and_parse_group_denylist,
|
||||
config::{MysqlConfig, ServerConfig},
|
||||
session_handler::session_handler,
|
||||
session_handler::{SessionId, session_handler},
|
||||
},
|
||||
};
|
||||
|
||||
@@ -548,6 +551,8 @@ async fn listener_task(
|
||||
#[cfg(target_os = "linux")]
|
||||
sd_notify::notify(false, &[sd_notify::NotifyState::Ready])?;
|
||||
|
||||
let connection_counter = AtomicU64::new(0);
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
biased;
|
||||
@@ -577,14 +582,20 @@ async fn listener_task(
|
||||
} => {
|
||||
match accept_result {
|
||||
Ok((conn, _addr)) => {
|
||||
tracing::debug!("Got new connection");
|
||||
|
||||
connection_counter.fetch_add(1, Ordering::Relaxed);
|
||||
let conn_id = connection_counter.load(Ordering::Relaxed);
|
||||
|
||||
tracing::debug!("Got new connection, assigned session ID {}", conn_id);
|
||||
|
||||
let session_id = SessionId::new(conn_id);
|
||||
let db_pool_clone = db_pool.clone();
|
||||
let db_is_mariadb_clone = *db_is_mariadb.read().await;
|
||||
let group_denylist_arc_clone = group_denylist.clone();
|
||||
task_tracker.spawn(async move {
|
||||
match session_handler(
|
||||
conn,
|
||||
session_id,
|
||||
db_pool_clone,
|
||||
db_is_mariadb_clone,
|
||||
&*group_denylist_arc_clone.read().await,
|
||||
|
||||
Reference in New Issue
Block a user