clippy pedantic fix + get rid of a few unwraps
All checks were successful
Build and test / docs (push) Successful in 7m1s
Build and test / check-license (push) Successful in 57s
Build and test / check (push) Successful in 2m46s
Build and test / build (push) Successful in 3m12s
Build and test / test (push) Successful in 3m25s

This commit is contained in:
2025-12-23 13:40:46 +09:00
parent c866400b4a
commit 4c3677d6d3
51 changed files with 596 additions and 545 deletions

View File

@@ -39,13 +39,13 @@ pub fn erroneous_server_response(
) -> anyhow::Result<()> {
match response {
Some(Ok(Response::Error(e))) => {
anyhow::bail!("Server returned error: {}", e);
anyhow::bail!("Server returned error: {e}");
}
Some(Err(e)) => {
anyhow::bail!(e);
}
Some(response) => {
anyhow::bail!("Unexpected response from server: {:?}", response);
anyhow::bail!("Unexpected response from server: {response:?}");
}
None => {
anyhow::bail!("No response from server");
@@ -72,7 +72,7 @@ async fn print_authorization_owner_hint(
eprintln!(
"Note: You are allowed to manage databases and users with the following prefixes:\n{}",
response.into_iter().map(|p| format!(" - {}", p)).join("\n")
response.into_iter().map(|p| format!(" - {p}")).join("\n")
);
Ok(())

View File

@@ -14,7 +14,7 @@ use tokio_stream::StreamExt;
#[derive(Parser, Debug, Clone)]
pub struct CheckAuthArgs {
/// The MySQL database(s) or user(s) to check authorization for
/// The `MySQL` database(s) or user(s) to check authorization for
#[arg(num_args = 1.., value_name = "NAME")]
name: Vec<String>,
@@ -63,7 +63,7 @@ pub async fn check_authorization(
print_check_authorization_output_status(&result);
}
if result.values().any(|res| res.is_err()) {
if result.values().any(std::result::Result::is_err) {
std::process::exit(1);
}

View File

@@ -18,7 +18,7 @@ use crate::{
#[derive(Parser, Debug, Clone)]
pub struct CreateDbArgs {
/// The MySQL database(s) to create
/// The `MySQL` database(s) to create
#[arg(num_args = 1.., value_name = "DB_NAME")]
#[cfg_attr(not(feature = "suid-sgid-mode"), arg(add = ArgValueCompleter::new(prefix_completer)))]
name: Vec<MySQLDatabase>,
@@ -36,7 +36,7 @@ pub async fn create_databases(
anyhow::bail!("No database names provided");
}
let message = Request::CreateDatabases(args.name.to_owned());
let message = Request::CreateDatabases(args.name.clone());
server_connection.send(message).await?;
let result = match server_connection.next().await {
@@ -57,13 +57,13 @@ pub async fn create_databases(
))
)
}) {
print_authorization_owner_hint(&mut server_connection).await?
print_authorization_owner_hint(&mut server_connection).await?;
}
}
server_connection.send(Request::Exit).await?;
if result.values().any(|res| res.is_err()) {
if result.values().any(std::result::Result::is_err) {
std::process::exit(1);
}

View File

@@ -22,7 +22,7 @@ use crate::{
#[derive(Parser, Debug, Clone)]
pub struct CreateUserArgs {
/// The MySQL user(s) to create
/// The `MySQL` user(s) to create
#[arg(num_args = 1.., value_name = "USER_NAME")]
#[cfg_attr(not(feature = "suid-sgid-mode"), arg(add = ArgValueCompleter::new(prefix_completer)))]
username: Vec<MySQLUser>,
@@ -46,7 +46,7 @@ pub async fn create_users(
anyhow::bail!("No usernames provided");
}
let message = Request::CreateUsers(args.username.to_owned());
let message = Request::CreateUsers(args.username.clone());
if let Err(err) = server_connection.send(message).await {
server_connection.close().await.ok();
anyhow::bail!(anyhow::Error::from(err).context("Failed to communicate with server"));
@@ -70,20 +70,19 @@ pub async fn create_users(
))
)
}) {
print_authorization_owner_hint(&mut server_connection).await?
print_authorization_owner_hint(&mut server_connection).await?;
}
let successfully_created_users = result
.iter()
.filter_map(|(username, result)| result.as_ref().ok().map(|_| username))
.filter_map(|(username, result)| result.as_ref().ok().map(|()| username))
.collect::<Vec<_>>();
for username in successfully_created_users {
if !args.no_password
&& Confirm::new()
.with_prompt(format!(
"Do you want to set a password for user '{}'?",
username
"Do you want to set a password for user '{username}'?"
))
.default(false)
.interact()?
@@ -98,7 +97,7 @@ pub async fn create_users(
match server_connection.next().await {
Some(Ok(Response::SetUserPassword(result))) => {
print_set_password_output_status(&result, username)
print_set_password_output_status(&result, username);
}
response => return erroneous_server_response(response),
}
@@ -110,7 +109,7 @@ pub async fn create_users(
server_connection.send(Request::Exit).await?;
if result.values().any(|res| res.is_err()) {
if result.values().any(std::result::Result::is_err) {
std::process::exit(1);
}

View File

@@ -19,7 +19,7 @@ use crate::{
#[derive(Parser, Debug, Clone)]
pub struct DropDbArgs {
/// The MySQL database(s) to drop
/// The `MySQL` database(s) to drop
#[arg(num_args = 1.., value_name = "DB_NAME")]
#[cfg_attr(not(feature = "suid-sgid-mode"), arg(add = ArgValueCompleter::new(mysql_database_completer)))]
name: Vec<MySQLDatabase>,
@@ -47,7 +47,7 @@ pub async fn drop_databases(
"Are you sure you want to drop the databases?\n\n{}\n\nThis action cannot be undone",
args.name
.iter()
.map(|d| format!("- {}", d))
.map(|d| format!("- {d}"))
.collect::<Vec<_>>()
.join("\n")
))
@@ -62,7 +62,7 @@ pub async fn drop_databases(
}
}
let message = Request::DropDatabases(args.name.to_owned());
let message = Request::DropDatabases(args.name.clone());
server_connection.send(message).await?;
let result = match server_connection.next().await {
@@ -83,13 +83,13 @@ pub async fn drop_databases(
))
)
}) {
print_authorization_owner_hint(&mut server_connection).await?
print_authorization_owner_hint(&mut server_connection).await?;
}
};
}
server_connection.send(Request::Exit).await?;
if result.values().any(|res| res.is_err()) {
if result.values().any(std::result::Result::is_err) {
std::process::exit(1);
}

View File

@@ -19,7 +19,7 @@ use crate::{
#[derive(Parser, Debug, Clone)]
pub struct DropUserArgs {
/// The MySQL user(s) to drop
/// The `MySQL` user(s) to drop
#[arg(num_args = 1.., value_name = "USER_NAME")]
#[cfg_attr(not(feature = "suid-sgid-mode"), arg(add = ArgValueCompleter::new(mysql_user_completer)))]
username: Vec<MySQLUser>,
@@ -47,7 +47,7 @@ pub async fn drop_users(
"Are you sure you want to drop the users?\n\n{}\n\nThis action cannot be undone",
args.username
.iter()
.map(|d| format!("- {}", d))
.map(|d| format!("- {d}"))
.collect::<Vec<_>>()
.join("\n")
))
@@ -61,7 +61,7 @@ pub async fn drop_users(
}
}
let message = Request::DropUsers(args.username.to_owned());
let message = Request::DropUsers(args.username.clone());
if let Err(err) = server_connection.send(message).await {
server_connection.close().await.ok();
@@ -86,13 +86,13 @@ pub async fn drop_users(
))
)
}) {
print_authorization_owner_hint(&mut server_connection).await?
print_authorization_owner_hint(&mut server_connection).await?;
}
}
server_connection.send(Request::Exit).await?;
if result.values().any(|res| res.is_err()) {
if result.values().any(std::result::Result::is_err) {
std::process::exit(1);
}

View File

@@ -67,7 +67,7 @@ pub struct EditPrivsArgs {
#[derive(Args, Debug, Clone)]
pub struct SinglePrivilegeEditArgs {
/// The MySQL database to edit privileges for
/// The `MySQL` database to edit privileges for
#[cfg_attr(not(feature = "suid-sgid-mode"), arg(add = ArgValueCompleter::new(mysql_database_completer)))]
#[arg(
value_name = "DB_NAME",
@@ -76,7 +76,7 @@ pub struct SinglePrivilegeEditArgs {
)]
pub db_name: Option<MySQLDatabase>,
/// The MySQL database to edit privileges for
/// The `MySQL` database to edit privileges for
#[cfg_attr(not(feature = "suid-sgid-mode"), arg(add = ArgValueCompleter::new(mysql_user_completer)))]
#[arg(value_name = "USER_NAME")]
pub user_name: Option<MySQLUser>,
@@ -212,13 +212,13 @@ pub async fn edit_database_privileges(
response => return erroneous_server_response(response),
};
let diffs: BTreeSet<DatabasePrivilegesDiff> = if !privs.is_empty() {
let privileges_to_change = parse_privilege_tables(&privs)?;
create_or_modify_privilege_rows(&existing_privilege_rows, &privileges_to_change)?
} else {
let diffs: BTreeSet<DatabasePrivilegesDiff> = if privs.is_empty() {
let privileges_to_change =
edit_privileges_with_editor(&existing_privilege_rows, use_database.as_ref())?;
diff_privileges(&existing_privilege_rows, &privileges_to_change)
} else {
let privileges_to_change = parse_privilege_tables(&privs)?;
create_or_modify_privilege_rows(&existing_privilege_rows, &privileges_to_change)?
};
let database_existence_map = databases_exist(&mut server_connection, &diffs).await?;
@@ -306,12 +306,12 @@ pub async fn edit_database_privileges(
))
)
}) {
print_authorization_owner_hint(&mut server_connection).await?
print_authorization_owner_hint(&mut server_connection).await?;
}
server_connection.send(Request::Exit).await?;
if result.values().any(|res| res.is_err()) {
if result.values().any(std::result::Result::is_err) {
std::process::exit(1);
}
@@ -328,8 +328,7 @@ fn parse_privilege_tables(
priv_edit_entry
.as_database_privileges_diff()
.context(format!(
"Failed parsing database privileges: `{}`",
priv_edit_entry
"Failed parsing database privileges: `{priv_edit_entry}`"
))
})
.collect::<anyhow::Result<BTreeSet<DatabasePrivilegeRowDiff>>>()
@@ -352,7 +351,7 @@ fn edit_privileges_with_editor(
match result {
None => Ok(privilege_data.to_vec()),
Some(result) => parse_privilege_data_from_editor_content(result)
Some(result) => parse_privilege_data_from_editor_content(&result)
.context("Could not parse privilege data from editor"),
}
}

View File

@@ -18,7 +18,7 @@ use crate::{
#[derive(Parser, Debug, Clone)]
pub struct LockUserArgs {
/// The MySQL user(s) to loc
/// The `MySQL` user(s) to loc
#[arg(num_args = 1.., value_name = "USER_NAME")]
#[cfg_attr(not(feature = "suid-sgid-mode"), arg(add = ArgValueCompleter::new(mysql_user_completer)))]
username: Vec<MySQLUser>,
@@ -36,7 +36,7 @@ pub async fn lock_users(
anyhow::bail!("No usernames provided");
}
let message = Request::LockUsers(args.username.to_owned());
let message = Request::LockUsers(args.username.clone());
if let Err(err) = server_connection.send(message).await {
server_connection.close().await.ok();
@@ -61,13 +61,13 @@ pub async fn lock_users(
))
)
}) {
print_authorization_owner_hint(&mut server_connection).await?
print_authorization_owner_hint(&mut server_connection).await?;
}
}
server_connection.send(Request::Exit).await?;
if result.values().any(|res| res.is_err()) {
if result.values().any(std::result::Result::is_err) {
std::process::exit(1);
}

View File

@@ -21,7 +21,7 @@ use crate::{
#[derive(Parser, Debug, Clone)]
pub struct PasswdUserArgs {
/// The MySQL user whose password is to be changed
/// The `MySQL` user whose password is to be changed
#[cfg_attr(not(feature = "suid-sgid-mode"), arg(add = ArgValueCompleter::new(mysql_user_completer)))]
#[arg(value_name = "USER_NAME")]
username: MySQLUser,
@@ -41,9 +41,9 @@ pub struct PasswdUserArgs {
pub fn read_password_from_stdin_with_double_check(username: &MySQLUser) -> anyhow::Result<String> {
Password::new()
.with_prompt(format!("New MySQL password for user '{}'", username))
.with_prompt(format!("New MySQL password for user '{username}'"))
.with_confirmation(
format!("Retype new MySQL password for user '{}'", username),
format!("Retype new MySQL password for user '{username}'"),
"Passwords do not match",
)
.interact()
@@ -55,7 +55,7 @@ pub async fn passwd_user(
mut server_connection: ClientToServerMessageStream,
) -> anyhow::Result<()> {
// TODO: create a "user" exists check" command
let message = Request::ListUsers(Some(vec![args.username.to_owned()]));
let message = Request::ListUsers(Some(vec![args.username.clone()]));
if let Err(err) = server_connection.send(message).await {
server_connection.close().await.ok();
anyhow::bail!(err);
@@ -91,7 +91,7 @@ pub async fn passwd_user(
read_password_from_stdin_with_double_check(&args.username)?
};
let message = Request::PasswdUser((args.username.to_owned(), password));
let message = Request::PasswdUser((args.username.clone(), password));
if let Err(err) = server_connection.send(message).await {
server_connection.close().await.ok();
@@ -111,7 +111,7 @@ pub async fn passwd_user(
ValidationError::AuthorizationError(_)
))
) {
print_authorization_owner_hint(&mut server_connection).await?
print_authorization_owner_hint(&mut server_connection).await?;
}
server_connection.send(Request::Exit).await?;

View File

@@ -18,7 +18,7 @@ use crate::{
#[derive(Parser, Debug, Clone)]
pub struct ShowDbArgs {
/// The MySQL database(s) to show
/// The `MySQL` database(s) to show
#[arg(num_args = 0.., value_name = "DB_NAME")]
#[cfg_attr(not(feature = "suid-sgid-mode"), arg(add = ArgValueCompleter::new(mysql_database_completer)))]
name: Vec<MySQLDatabase>,
@@ -39,7 +39,7 @@ pub async fn show_databases(
let message = if args.name.is_empty() {
Request::ListDatabases(None)
} else {
Request::ListDatabases(Some(args.name.to_owned()))
Request::ListDatabases(Some(args.name.clone()))
};
server_connection.send(message).await?;
@@ -74,13 +74,13 @@ pub async fn show_databases(
))
)
}) {
print_authorization_owner_hint(&mut server_connection).await?
print_authorization_owner_hint(&mut server_connection).await?;
}
}
server_connection.send(Request::Exit).await?;
if databases.values().any(|res| res.is_err()) {
if databases.values().any(std::result::Result::is_err) {
std::process::exit(1);
}

View File

@@ -19,7 +19,7 @@ use crate::{
#[derive(Parser, Debug, Clone)]
pub struct ShowPrivsArgs {
/// The MySQL database(s) to show privileges for
/// The `MySQL` database(s) to show privileges for
#[arg(num_args = 0.., value_name = "DB_NAME")]
#[cfg_attr(not(feature = "suid-sgid-mode"), arg(add = ArgValueCompleter::new(mysql_database_completer)))]
name: Vec<MySQLDatabase>,
@@ -42,7 +42,7 @@ pub async fn show_database_privileges(
let message = if args.name.is_empty() {
Request::ListPrivileges(None)
} else {
Request::ListPrivileges(Some(args.name.to_owned()))
Request::ListPrivileges(Some(args.name.clone()))
};
server_connection.send(message).await?;
@@ -78,13 +78,13 @@ pub async fn show_database_privileges(
))
)
}) {
print_authorization_owner_hint(&mut server_connection).await?
print_authorization_owner_hint(&mut server_connection).await?;
}
}
server_connection.send(Request::Exit).await?;
if privilege_data.values().any(|res| res.is_err()) {
if privilege_data.values().any(std::result::Result::is_err) {
std::process::exit(1);
}

View File

@@ -18,7 +18,7 @@ use crate::{
#[derive(Parser, Debug, Clone)]
pub struct ShowUserArgs {
/// The MySQL user(s) to show
/// The `MySQL` user(s) to show
#[cfg_attr(not(feature = "suid-sgid-mode"), arg(add = ArgValueCompleter::new(mysql_user_completer)))]
#[arg(num_args = 0.., value_name = "USER_NAME")]
username: Vec<MySQLUser>,
@@ -35,7 +35,7 @@ pub async fn show_users(
let message = if args.username.is_empty() {
Request::ListUsers(None)
} else {
Request::ListUsers(Some(args.username.to_owned()))
Request::ListUsers(Some(args.username.clone()))
};
if let Err(err) = server_connection.send(message).await {
@@ -73,13 +73,13 @@ pub async fn show_users(
))
)
}) {
print_authorization_owner_hint(&mut server_connection).await?
print_authorization_owner_hint(&mut server_connection).await?;
}
}
server_connection.send(Request::Exit).await?;
if users.values().any(|result| result.is_err()) {
if users.values().any(std::result::Result::is_err) {
std::process::exit(1);
}

View File

@@ -18,7 +18,7 @@ use crate::{
#[derive(Parser, Debug, Clone)]
pub struct UnlockUserArgs {
/// The MySQL user(s) to unlock
/// The `MySQL` user(s) to unlock
#[cfg_attr(not(feature = "suid-sgid-mode"), arg(add = ArgValueCompleter::new(mysql_user_completer)))]
#[arg(num_args = 1.., value_name = "USER_NAME")]
username: Vec<MySQLUser>,
@@ -36,7 +36,7 @@ pub async fn unlock_users(
anyhow::bail!("No usernames provided");
}
let message = Request::UnlockUsers(args.username.to_owned());
let message = Request::UnlockUsers(args.username.clone());
if let Err(err) = server_connection.send(message).await {
server_connection.close().await.ok();
@@ -61,13 +61,13 @@ pub async fn unlock_users(
))
)
}) {
print_authorization_owner_hint(&mut server_connection).await?
print_authorization_owner_hint(&mut server_connection).await?;
}
}
server_connection.send(Request::Exit).await?;
if result.values().any(|result| result.is_err()) {
if result.values().any(std::result::Result::is_err) {
std::process::exit(1);
}

View File

@@ -1,11 +1,13 @@
use crate::core::types::{MySQLDatabase, MySQLUser};
#[inline]
#[must_use]
pub fn trim_db_name_to_32_chars(db_name: &MySQLDatabase) -> MySQLDatabase {
db_name.chars().take(32).collect::<String>().into()
}
#[inline]
#[must_use]
pub fn trim_user_name_to_32_chars(user_name: &MySQLUser) -> MySQLUser {
user_name.chars().take(32).collect::<String>().into()
}

View File

@@ -6,7 +6,7 @@ use crate::core::{
types::DbOrUser,
};
pub fn name_validation_error_to_error_message(db_or_user: DbOrUser) -> String {
pub fn name_validation_error_to_error_message(db_or_user: &DbOrUser) -> String {
let argv0 = std::env::args().next().unwrap_or_else(|| match db_or_user {
DbOrUser::Database(_) => "mysql-dbadm".to_string(),
DbOrUser::User(_) => "mysql-useradm".to_string(),
@@ -23,7 +23,7 @@ pub fn name_validation_error_to_error_message(db_or_user: DbOrUser) -> String {
)
}
pub fn authorization_error_message(db_or_user: DbOrUser) -> String {
pub fn authorization_error_message(db_or_user: &DbOrUser) -> String {
format!(
"You are not in charge of mysql-{}: '{}'. Skipping.",
db_or_user.lowercased_noun(),
@@ -31,7 +31,7 @@ pub fn authorization_error_message(db_or_user: DbOrUser) -> String {
)
}
pub fn handle_create_user_error(error: CreateUserError, name: &str) {
pub fn handle_create_user_error(error: &CreateUserError, name: &str) {
let argv0 = std::env::args()
.next()
.unwrap_or_else(|| "mysql-useradm".to_string());
@@ -39,22 +39,22 @@ pub fn handle_create_user_error(error: CreateUserError, name: &str) {
CreateUserError::ValidationError(ValidationError::NameValidationError(_)) => {
eprintln!(
"{}",
name_validation_error_to_error_message(DbOrUser::User(name.into()))
name_validation_error_to_error_message(&DbOrUser::User(name.into()))
);
}
CreateUserError::ValidationError(ValidationError::AuthorizationError(_)) => {
eprintln!(
"{}",
authorization_error_message(DbOrUser::User(name.into()))
authorization_error_message(&DbOrUser::User(name.into()))
);
}
CreateUserError::MySqlError(_) | CreateUserError::UserAlreadyExists => {
eprintln!("{}: Failed to create user '{}'.", argv0, name);
eprintln!("{argv0}: Failed to create user '{name}'.");
}
}
}
pub fn handle_drop_user_error(error: DropUserError, name: &str) {
pub fn handle_drop_user_error(error: &DropUserError, name: &str) {
let argv0 = std::env::args()
.next()
.unwrap_or_else(|| "mysql-useradm".to_string());
@@ -62,22 +62,22 @@ pub fn handle_drop_user_error(error: DropUserError, name: &str) {
DropUserError::ValidationError(ValidationError::NameValidationError(_)) => {
eprintln!(
"{}",
name_validation_error_to_error_message(DbOrUser::User(name.into()))
name_validation_error_to_error_message(&DbOrUser::User(name.into()))
);
}
DropUserError::ValidationError(ValidationError::AuthorizationError(_)) => {
eprintln!(
"{}",
authorization_error_message(DbOrUser::User(name.into()))
authorization_error_message(&DbOrUser::User(name.into()))
);
}
DropUserError::MySqlError(_) | DropUserError::UserDoesNotExist => {
eprintln!("{}: Failed to delete user '{}'.", argv0, name);
eprintln!("{argv0}: Failed to delete user '{name}'.");
}
}
}
pub fn handle_list_users_error(error: ListUsersError, name: &str) {
pub fn handle_list_users_error(error: &ListUsersError, name: &str) {
let argv0 = std::env::args()
.next()
.unwrap_or_else(|| "mysql-useradm".to_string());
@@ -85,30 +85,27 @@ pub fn handle_list_users_error(error: ListUsersError, name: &str) {
ListUsersError::ValidationError(ValidationError::NameValidationError(_)) => {
eprintln!(
"{}",
name_validation_error_to_error_message(DbOrUser::User(name.into()))
name_validation_error_to_error_message(&DbOrUser::User(name.into()))
);
}
ListUsersError::ValidationError(ValidationError::AuthorizationError(_)) => {
eprintln!(
"{}",
authorization_error_message(DbOrUser::User(name.into()))
authorization_error_message(&DbOrUser::User(name.into()))
);
}
ListUsersError::UserDoesNotExist => {
eprintln!(
"{}: User '{}' does not exist. You must create it first.",
argv0, name,
);
eprintln!("{argv0}: User '{name}' does not exist. You must create it first.",);
}
ListUsersError::MySqlError(_) => {
eprintln!("{}: Failed to look up password for user '{}'", argv0, name);
eprintln!("{argv0}: Failed to look up password for user '{name}'");
}
}
}
// ----------------------------------------------------------------------------
pub fn handle_create_database_error(error: CreateDatabaseError, name: &str) {
pub fn handle_create_database_error(error: &CreateDatabaseError, name: &str) {
let argv0 = std::env::args()
.next()
.unwrap_or_else(|| "mysql-dbadm".to_string());
@@ -116,26 +113,26 @@ pub fn handle_create_database_error(error: CreateDatabaseError, name: &str) {
CreateDatabaseError::ValidationError(ValidationError::NameValidationError(_)) => {
eprintln!(
"{}",
name_validation_error_to_error_message(DbOrUser::Database(name.into()))
name_validation_error_to_error_message(&DbOrUser::Database(name.into()))
);
}
CreateDatabaseError::ValidationError(ValidationError::AuthorizationError(_)) => {
eprintln!(
"{}",
authorization_error_message(DbOrUser::Database(name.into()))
authorization_error_message(&DbOrUser::Database(name.into()))
);
}
CreateDatabaseError::MySqlError(_) => {
eprintln!("{}: Cannot create database '{}'.", argv0, name);
eprintln!("{argv0}: Cannot create database '{name}'.");
}
CreateDatabaseError::DatabaseAlreadyExists => {
eprintln!("{}: Database '{}' already exists.", argv0, name);
eprintln!("{argv0}: Database '{name}' already exists.");
}
}
}
pub fn handle_drop_database_error(error: DropDatabaseError, name: &str) {
pub fn handle_drop_database_error(error: &DropDatabaseError, name: &str) {
let argv0 = std::env::args()
.next()
.unwrap_or_else(|| "mysql-dbadm".to_string());
@@ -143,44 +140,41 @@ pub fn handle_drop_database_error(error: DropDatabaseError, name: &str) {
DropDatabaseError::ValidationError(ValidationError::NameValidationError(_)) => {
eprintln!(
"{}",
name_validation_error_to_error_message(DbOrUser::Database(name.into()))
name_validation_error_to_error_message(&DbOrUser::Database(name.into()))
);
}
DropDatabaseError::ValidationError(ValidationError::AuthorizationError(_)) => {
eprintln!(
"{}",
authorization_error_message(DbOrUser::Database(name.into()))
authorization_error_message(&DbOrUser::Database(name.into()))
);
}
DropDatabaseError::MySqlError(_) => {
eprintln!("{}: Cannot drop database '{}'.", argv0, name);
eprintln!("{argv0}: Cannot drop database '{name}'.");
}
DropDatabaseError::DatabaseDoesNotExist => {
eprintln!("{}: Database '{}' doesn't exist.", argv0, name);
eprintln!("{argv0}: Database '{name}' doesn't exist.");
}
}
}
pub fn format_show_database_error_message(error: ListPrivilegesError, name: &str) -> String {
pub fn format_show_database_error_message(error: &ListPrivilegesError, name: &str) -> String {
let argv0 = std::env::args()
.next()
.unwrap_or_else(|| "mysql-dbadm".to_string());
match error {
ListPrivilegesError::ValidationError(ValidationError::NameValidationError(_)) => {
name_validation_error_to_error_message(DbOrUser::Database(name.into()))
name_validation_error_to_error_message(&DbOrUser::Database(name.into()))
}
ListPrivilegesError::ValidationError(ValidationError::AuthorizationError(_)) => {
authorization_error_message(DbOrUser::Database(name.into()))
authorization_error_message(&DbOrUser::Database(name.into()))
}
ListPrivilegesError::MySqlError(err) => {
format!(
"{}: Failed to look up privileges for database '{}': {}",
argv0, name, err
)
format!("{argv0}: Failed to look up privileges for database '{name}': {err}")
}
ListPrivilegesError::DatabaseDoesNotExist => {
format!("{}: Database '{}' doesn't exist.", argv0, name)
format!("{argv0}: Database '{name}' doesn't exist.")
}
}
}

View File

@@ -1,5 +1,6 @@
use clap::{Parser, Subcommand};
use clap_complete::ArgValueCompleter;
use clap_verbosity_flag::Verbosity;
use futures_util::{SinkExt, StreamExt};
use std::os::unix::net::UnixStream as StdUnixStream;
use std::path::PathBuf;
@@ -28,7 +29,7 @@ use crate::{
},
};
const HELP_DB_PERM: &str = r#"
const HELP_DB_PERM: &str = r"
Edit permissions for the DATABASE(s). Running this command will
spawn the editor stored in the $EDITOR environment variable.
(pico will be used if the variable is unset)
@@ -49,7 +50,7 @@ The Y/N-values corresponds to the following mysql privileges:
Temp - Enables use of CREATE TEMPORARY TABLE
Lock - Enables use of LOCK TABLE
References - Enables use of REFERENCES
"#;
";
/// Create, drop or edit permissions for the DATABASE(s),
/// as determined by the COMMAND.
@@ -156,25 +157,22 @@ pub fn main() -> anyhow::Result<()> {
let args: Args = Args::parse();
if args.help_editperm {
println!("{}", HELP_DB_PERM);
println!("{HELP_DB_PERM}");
return Ok(());
}
let server_connection = bootstrap_server_connection_and_drop_privileges(
args.server_socket_path,
args.config,
Default::default(),
Verbosity::default(),
)?;
let command = match args.command {
Some(command) => command,
None => {
println!(
"Try `{} --help' for more information.",
std::env::args().next().unwrap_or("mysql-dbadm".to_string())
);
return Ok(());
}
let Some(command) = args.command else {
println!(
"Try `{} --help' for more information.",
std::env::args().next().unwrap_or("mysql-dbadm".to_string())
);
return Ok(());
};
tokio_run_command(command, server_connection)?;
@@ -194,11 +192,11 @@ fn tokio_run_command(command: Command, server_connection: StdUnixStream) -> anyh
while let Some(Ok(message)) = message_stream.next().await {
match message {
Response::Error(err) => {
anyhow::bail!("{}", err);
anyhow::bail!("{err}");
}
Response::Ready => break,
message => {
eprintln!("Unexpected message from server: {:?}", message);
eprintln!("Unexpected message from server: {message:?}");
}
}
}
@@ -245,8 +243,8 @@ async fn create_databases(
for (name, result) in result {
match result {
Ok(()) => println!("Database {} created.", name),
Err(err) => handle_create_database_error(err, &name),
Ok(()) => println!("Database {name} created."),
Err(err) => handle_create_database_error(&err, &name),
}
}
@@ -271,8 +269,8 @@ async fn drop_databases(
for (name, result) in result {
match result {
Ok(()) => println!("Database {} dropped.", name),
Err(err) => handle_drop_database_error(err, &name),
Ok(()) => println!("Database {name} dropped."),
Err(err) => handle_drop_database_error(&err, &name),
}
}
@@ -312,24 +310,21 @@ async fn show_databases(
let results: Vec<Result<(MySQLDatabase, Vec<DatabasePrivilegeRow>), String>> = match response {
Some(Ok(Response::ListPrivileges(result))) => result
.into_iter()
.map(
|(name, rows)| match rows.map(|rows| (name.to_owned(), rows)) {
Ok(rows) => Ok(rows),
Err(ListPrivilegesError::DatabaseDoesNotExist) => Ok((name, vec![])),
Err(err) => Err(format_show_database_error_message(err, &name)),
},
)
.map(|(name, rows)| match rows.map(|rows| (name.clone(), rows)) {
Ok(rows) => Ok(rows),
Err(ListPrivilegesError::DatabaseDoesNotExist) => Ok((name, vec![])),
Err(err) => Err(format_show_database_error_message(&err, &name)),
})
.collect(),
response => return erroneous_server_response(response),
};
results.into_iter().try_for_each(|result| match result {
Ok((name, rows)) => print_db_privs(&name, rows),
Err(err) => {
eprintln!("{}", err);
Ok(())
for result in results {
match result {
Ok((name, rows)) => print_db_privs(&name, rows),
Err(err) => eprintln!("{err}"),
}
})?;
}
Ok(())
}
@@ -339,7 +334,7 @@ fn yn(value: bool) -> &'static str {
if value { "Y" } else { "N" }
}
fn print_db_privs(name: &str, rows: Vec<DatabasePrivilegeRow>) -> anyhow::Result<()> {
fn print_db_privs(name: &str, rows: Vec<DatabasePrivilegeRow>) {
println!(
concat!(
"Database '{}':\n",
@@ -369,6 +364,4 @@ fn print_db_privs(name: &str, rows: Vec<DatabasePrivilegeRow>) -> anyhow::Result
);
}
}
Ok(())
}

View File

@@ -75,7 +75,7 @@ pub enum Command {
/// delete the USER(s).
Delete(DeleteArgs),
/// change the MySQL password for the USER(s).
/// change the `MySQL` password for the USER(s).
Passwd(PasswdArgs),
/// give information about the USERS(s), or, if
@@ -119,17 +119,14 @@ pub struct ShowArgs {
pub fn main() -> anyhow::Result<()> {
let args: Args = Args::parse();
let command = match args.command {
Some(command) => command,
None => {
println!(
"Try `{} --help' for more information.",
std::env::args()
.next()
.unwrap_or("mysql-useradm".to_string())
);
return Ok(());
}
let Some(command) = args.command else {
println!(
"Try `{} --help' for more information.",
std::env::args()
.next()
.unwrap_or("mysql-useradm".to_string())
);
return Ok(());
};
let server_connection = bootstrap_server_connection_and_drop_privileges(
@@ -155,11 +152,11 @@ fn tokio_run_command(command: Command, server_connection: StdUnixStream) -> anyh
while let Some(Ok(message)) = message_stream.next().await {
match message {
Response::Error(err) => {
anyhow::bail!("{}", err);
anyhow::bail!("{err}");
}
Response::Ready => break,
message => {
eprintln!("Unexpected message from server: {:?}", message);
eprintln!("Unexpected message from server: {message:?}");
}
}
}
@@ -191,8 +188,8 @@ async fn create_user(
for (name, result) in result {
match result {
Ok(()) => println!("User '{}' created.", name),
Err(err) => handle_create_user_error(err, &name),
Ok(()) => println!("User '{name}' created."),
Err(err) => handle_create_user_error(&err, &name),
}
}
@@ -217,8 +214,8 @@ async fn drop_users(
for (name, result) in result {
match result {
Ok(()) => println!("User '{}' deleted.", name),
Err(err) => handle_drop_user_error(err, &name),
Ok(()) => println!("User '{name}' deleted."),
Err(err) => handle_drop_user_error(&err, &name),
}
}
@@ -248,7 +245,7 @@ async fn passwd_users(
.filter_map(|(name, result)| match result {
Ok(user) => Some(user),
Err(err) => {
handle_list_users_error(err, &name);
handle_list_users_error(&err, &name);
None
}
})
@@ -256,7 +253,7 @@ async fn passwd_users(
for user in users {
let password = read_password_from_stdin_with_double_check(&user.user)?;
let message = Request::PasswdUser((user.user.to_owned(), password));
let message = Request::PasswdUser((user.user.clone(), password));
server_connection.send(message).await?;
match server_connection.next().await {
Some(Ok(Response::SetUserPassword(result))) => match result {
@@ -292,7 +289,7 @@ async fn show_users(
Some(Ok(Response::ListAllUsers(result))) => match result {
Ok(users) => users,
Err(err) => {
eprintln!("Failed to list users: {:?}", err);
eprintln!("Failed to list users: {err:?}");
return Ok(());
}
},
@@ -301,7 +298,7 @@ async fn show_users(
.filter_map(|(name, result)| match result {
Ok(user) => Some(user),
Err(err) => {
handle_list_users_error(err, &name);
handle_list_users_error(&err, &name);
None
}
})

View File

@@ -1,4 +1,9 @@
use std::{fs, path::PathBuf, sync::Arc, time::Duration};
use std::{
fs,
path::{Path, PathBuf},
sync::Arc,
time::Duration,
};
use anyhow::{Context, anyhow};
use clap_verbosity_flag::{InfoLevel, Verbosity};
@@ -136,7 +141,7 @@ fn connect_to_external_server(
Err(e) => match e.kind() {
std::io::ErrorKind::NotFound => Err(anyhow::anyhow!("Socket not found")),
std::io::ErrorKind::PermissionDenied => Err(anyhow::anyhow!("Permission denied")),
_ => Err(anyhow::anyhow!("Failed to connect to socket: {}", e)),
_ => Err(anyhow::anyhow!("Failed to connect to socket: {e}")),
},
};
}
@@ -148,7 +153,7 @@ fn connect_to_external_server(
Err(e) => match e.kind() {
std::io::ErrorKind::NotFound => Err(anyhow::anyhow!("Socket not found")),
std::io::ErrorKind::PermissionDenied => Err(anyhow::anyhow!("Permission denied")),
_ => Err(anyhow::anyhow!("Failed to connect to socket: {}", e)),
_ => Err(anyhow::anyhow!("Failed to connect to socket: {e}")),
},
};
}
@@ -192,10 +197,10 @@ fn bootstrap_internal_server_and_drop_privs(
}
tracing::debug!("Starting server with config at {:?}", config_path);
let socket = invoke_server_with_config(config_path)?;
let socket = invoke_server_with_config(&config_path)?;
drop_privs()?;
return Ok(socket);
};
}
let config_path = PathBuf::from(DEFAULT_CONFIG_PATH);
if fs::metadata(&config_path).is_ok() {
@@ -203,10 +208,10 @@ fn bootstrap_internal_server_and_drop_privs(
anyhow::bail!("Executable is not SUID/SGID - refusing to start internal sever");
}
tracing::debug!("Starting server with default config at {:?}", config_path);
let socket = invoke_server_with_config(config_path)?;
let socket = invoke_server_with_config(&config_path)?;
drop_privs()?;
return Ok(socket);
};
}
anyhow::bail!("No config path provided, and no default config found");
}
@@ -216,7 +221,7 @@ fn bootstrap_internal_server_and_drop_privs(
/// Fork a child process to run the server with the provided config.
/// The server will exit silently by itself when it is done, and this function
/// will only return for the client with the socket for the server.
fn invoke_server_with_config(config_path: PathBuf) -> anyhow::Result<StdUnixStream> {
fn invoke_server_with_config(config_path: &Path) -> anyhow::Result<StdUnixStream> {
let (server_socket, client_socket) = StdUnixStream::pair()?;
let unix_user = UnixUser::from_uid(nix::unistd::getuid().as_raw())?;
@@ -228,18 +233,18 @@ fn invoke_server_with_config(config_path: PathBuf) -> anyhow::Result<StdUnixStre
nix::unistd::ForkResult::Child => {
tracing::debug!("Running server in child process");
landlock_restrict_server(Some(config_path.as_path()))
landlock_restrict_server(Some(config_path))
.context("Failed to apply Landlock restrictions to the server process")?;
match run_forked_server(config_path, server_socket, unix_user) {
match run_forked_server(config_path, server_socket, &unix_user) {
Err(e) => Err(e),
Ok(_) => unreachable!(),
Ok(()) => unreachable!(),
}
}
}
}
/// Construct a MySQL connection pool that consists of exactly one connection.
/// Construct a `MySQL` connection pool that consists of exactly one connection.
///
/// This is used for the internal server in SUID/SGID mode, where the server session
/// only ever will get a single client.
@@ -273,11 +278,11 @@ async fn construct_single_connection_mysql_pool(
/// This function will not return, but will exit the process with a success code.
/// The function assumes that it's caller has already forked the process.
fn run_forked_server(
config_path: PathBuf,
config_path: &Path,
server_socket: StdUnixStream,
unix_user: UnixUser,
unix_user: &UnixUser,
) -> anyhow::Result<()> {
let config = ServerConfig::read_config_from_path(&config_path)
let config = ServerConfig::read_config_from_path(config_path)
.context("Failed to read server config in forked process")?;
let group_denylist = if let Some(denylist_path) = &config.authorization.group_denylist_file {
@@ -306,7 +311,7 @@ fn run_forked_server(
let db_pool = Arc::new(RwLock::new(db_pool));
session_handler::session_handler_with_unix_user(
socket,
&unix_user,
unix_user,
db_pool,
db_is_mariadb,
&group_denylist,

View File

@@ -10,13 +10,13 @@ pub const DEFAULT_CONFIG_PATH: &str = "/etc/muscl/config.toml";
pub const DEFAULT_SOCKET_PATH: &str = "/run/muscl/muscl.sock";
pub const ASCII_BANNER: &str = indoc! {
r#"
r"
__
____ ___ __ ____________/ /
/ __ `__ \/ / / / ___/ ___/ /
/ / / / / / /_/ (__ ) /__/ /
/_/ /_/ /_/\__,_/____/\___/_/
"#
"
};
pub const KIND_REGARDS: &str = concat!(
@@ -95,7 +95,7 @@ impl UnixUser {
Ok(UnixUser {
username: libc_user.name,
groups: groups.iter().map(|g| g.name.to_owned()).collect(),
groups: groups.iter().map(|g| g.name.clone()).collect(),
})
}

View File

@@ -12,6 +12,7 @@ use crate::{
},
};
#[must_use]
pub fn mysql_database_completer(current: &std::ffi::OsStr) -> Vec<CompletionCandidate> {
match tokio::runtime::Builder::new_current_thread()
.enable_all()
@@ -20,18 +21,18 @@ pub fn mysql_database_completer(current: &std::ffi::OsStr) -> Vec<CompletionCand
Ok(runtime) => match runtime.block_on(mysql_database_completer_(current)) {
Ok(completions) => completions,
Err(err) => {
eprintln!("Error getting MySQL database completions: {}", err);
eprintln!("Error getting MySQL database completions: {err}");
Vec::new()
}
},
Err(err) => {
eprintln!("Error starting Tokio runtime: {}", err);
eprintln!("Error starting Tokio runtime: {err}");
Vec::new()
}
}
}
/// Connect to the server to get MySQL database completions.
/// Connect to the server to get `MySQL` database completions.
async fn mysql_database_completer_(
current: &std::ffi::OsStr,
) -> anyhow::Result<Vec<CompletionCandidate>> {
@@ -44,11 +45,11 @@ async fn mysql_database_completer_(
while let Some(Ok(message)) = server_connection.next().await {
match message {
Response::Error(err) => {
anyhow::bail!("{}", err);
anyhow::bail!("{err}");
}
Response::Ready => break,
message => {
eprintln!("Unexpected message from server: {:?}", message);
eprintln!("Unexpected message from server: {message:?}");
}
}
}
@@ -62,7 +63,7 @@ async fn mysql_database_completer_(
let result = match server_connection.next().await {
Some(Ok(Response::CompleteDatabaseName(suggestions))) => suggestions,
response => return erroneous_server_response(response).map(|_| vec![]),
response => return erroneous_server_response(response).map(|()| vec![]),
};
server_connection.send(Request::Exit).await?;

View File

@@ -12,6 +12,7 @@ use crate::{
},
};
#[must_use]
pub fn mysql_user_completer(current: &std::ffi::OsStr) -> Vec<CompletionCandidate> {
match tokio::runtime::Builder::new_current_thread()
.enable_all()
@@ -20,18 +21,18 @@ pub fn mysql_user_completer(current: &std::ffi::OsStr) -> Vec<CompletionCandidat
Ok(runtime) => match runtime.block_on(mysql_user_completer_(current)) {
Ok(completions) => completions,
Err(err) => {
eprintln!("Error getting MySQL user completions: {}", err);
eprintln!("Error getting MySQL user completions: {err}");
Vec::new()
}
},
Err(err) => {
eprintln!("Error starting Tokio runtime: {}", err);
eprintln!("Error starting Tokio runtime: {err}");
Vec::new()
}
}
}
/// Connect to the server to get MySQL user completions.
/// Connect to the server to get `MySQL` user completions.
async fn mysql_user_completer_(
current: &std::ffi::OsStr,
) -> anyhow::Result<Vec<CompletionCandidate>> {
@@ -44,11 +45,11 @@ async fn mysql_user_completer_(
while let Some(Ok(message)) = server_connection.next().await {
match message {
Response::Error(err) => {
anyhow::bail!("{}", err);
anyhow::bail!("{err}");
}
Response::Ready => break,
message => {
eprintln!("Unexpected message from server: {:?}", message);
eprintln!("Unexpected message from server: {message:?}");
}
}
}
@@ -62,7 +63,7 @@ async fn mysql_user_completer_(
let result = match server_connection.next().await {
Some(Ok(Response::CompleteUserName(suggestions))) => suggestions,
response => return erroneous_server_response(response).map(|_| vec![]),
response => return erroneous_server_response(response).map(|()| vec![]),
};
server_connection.send(Request::Exit).await?;

View File

@@ -12,6 +12,7 @@ use crate::{
},
};
#[must_use]
pub fn prefix_completer(current: &std::ffi::OsStr) -> Vec<CompletionCandidate> {
match tokio::runtime::Builder::new_current_thread()
.enable_all()
@@ -20,18 +21,18 @@ pub fn prefix_completer(current: &std::ffi::OsStr) -> Vec<CompletionCandidate> {
Ok(runtime) => match runtime.block_on(prefix_completer_(current)) {
Ok(completions) => completions,
Err(err) => {
eprintln!("Error getting prefix completions: {}", err);
eprintln!("Error getting prefix completions: {err}");
Vec::new()
}
},
Err(err) => {
eprintln!("Error starting Tokio runtime: {}", err);
eprintln!("Error starting Tokio runtime: {err}");
Vec::new()
}
}
}
/// Connect to the server to get MySQL user completions.
/// Connect to the server to get `MySQL` user completions.
async fn prefix_completer_(_current: &std::ffi::OsStr) -> anyhow::Result<Vec<CompletionCandidate>> {
let server_connection =
bootstrap_server_connection_and_drop_privileges(None, None, Verbosity::new(0, 1))?;
@@ -42,11 +43,11 @@ async fn prefix_completer_(_current: &std::ffi::OsStr) -> anyhow::Result<Vec<Com
while let Some(Ok(message)) = server_connection.next().await {
match message {
Response::Error(err) => {
anyhow::bail!("{}", err);
anyhow::bail!("{err}");
}
Response::Ready => break,
message => {
eprintln!("Unexpected message from server: {:?}", message);
eprintln!("Unexpected message from server: {message:?}");
}
}
}
@@ -60,7 +61,7 @@ async fn prefix_completer_(_current: &std::ffi::OsStr) -> anyhow::Result<Vec<Com
let result = match server_connection.next().await {
Some(Ok(Response::ListValidNamePrefixes(prefixes))) => prefixes,
response => return erroneous_server_response(response).map(|_| vec![]),
response => return erroneous_server_response(response).map(|()| vec![]),
};
server_connection.send(Request::Exit).await?;

View File

@@ -1,5 +1,5 @@
//! This module contains some base datastructures and functionality for dealing with
//! database privileges in MySQL.
//! database privileges in `MySQL`.
use std::fmt;
@@ -49,6 +49,7 @@ pub struct DatabasePrivilegeRow {
impl DatabasePrivilegeRow {
/// Gets the value of a privilege by its name as a &str.
#[must_use]
pub fn get_privilege_by_name(&self, name: &str) -> Option<bool> {
match name {
"select_priv" => Some(self.select_priv),
@@ -83,6 +84,7 @@ impl fmt::Display for DatabasePrivilegeRow {
}
/// Converts a database privilege field name to a human-readable name.
#[must_use]
pub fn db_priv_field_human_readable_name(name: &str) -> String {
match name {
"Db" => "Database".to_owned(),
@@ -98,12 +100,13 @@ pub fn db_priv_field_human_readable_name(name: &str) -> String {
"create_tmp_table_priv" => "Temp".to_owned(),
"lock_tables_priv" => "Lock".to_owned(),
"references_priv" => "References".to_owned(),
_ => format!("Unknown({})", name),
_ => format!("Unknown({name})"),
}
}
/// Converts a database privilege field name to a single-character name.
/// (the characters from the cli privilege editor)
#[must_use]
pub fn db_priv_field_single_character_name(name: &str) -> &str {
match name {
"select_priv" => "s",

View File

@@ -51,9 +51,7 @@ impl DatabasePrivilegeEdit {
.map(|c| format!("'{c}'"))
.join(", ");
anyhow::bail!(
"Invalid character(s) in privilege edit entry: {}\n\nValid characters are: {}",
invalid_chars,
valid_characters,
"Invalid character(s) in privilege edit entry: {invalid_chars}\n\nValid characters are: {valid_characters}",
);
}
@@ -72,7 +70,7 @@ impl std::fmt::Display for DatabasePrivilegeEdit {
DatabasePrivilegeEditEntryType::Remove => write!(f, "-")?,
}
for priv_char in &self.privileges {
write!(f, "{}", priv_char)?;
write!(f, "{priv_char}")?;
}
Ok(())
@@ -99,7 +97,7 @@ impl DatabasePrivilegeEditEntry {
/// `database_name:username:[+|-]privileges`
///
/// where:
/// - database_name is the name of the database to edit privileges for
/// - `database_name` is the name of the database to edit privileges for
/// - username is the name of the user to edit privileges for
/// - privileges is a string of characters representing the privileges to add, set or remove
/// - the `+` or `-` prefix indicates whether to add or remove the privileges, if omitted the privileges are set directly
@@ -107,13 +105,13 @@ impl DatabasePrivilegeEditEntry {
pub fn parse_from_str(arg: &str) -> anyhow::Result<Self> {
let parts: Vec<&str> = arg.split(':').collect();
if parts.len() != 3 {
anyhow::bail!("Invalid privilege edit entry format: {}", arg);
anyhow::bail!("Invalid privilege edit entry format: {arg}");
}
let (database, user, user_privs) = (parts[0].to_string(), parts[1].to_string(), parts[2]);
if user.is_empty() {
anyhow::bail!("Username cannot be empty in privilege edit entry: {}", arg);
anyhow::bail!("Username cannot be empty in privilege edit entry: {arg}");
}
let privilege_edit = DatabasePrivilegeEdit::parse_from_str(user_privs)?;

View File

@@ -18,6 +18,7 @@ pub enum DatabasePrivilegeChange {
}
impl DatabasePrivilegeChange {
#[must_use]
pub fn new(p1: bool, p2: bool) -> Option<DatabasePrivilegeChange> {
match (p1, p2) {
(true, false) => Some(DatabasePrivilegeChange::YesToNo),
@@ -49,6 +50,7 @@ pub struct DatabasePrivilegeRowDiff {
impl DatabasePrivilegeRowDiff {
/// Calculates the difference between two [`DatabasePrivilegeRow`] instances.
#[must_use]
pub fn from_rows(
row1: &DatabasePrivilegeRow,
row2: &DatabasePrivilegeRow,
@@ -56,8 +58,8 @@ impl DatabasePrivilegeRowDiff {
debug_assert!(row1.db == row2.db && row1.user == row2.user);
DatabasePrivilegeRowDiff {
db: row1.db.to_owned(),
user: row1.user.to_owned(),
db: row1.db.clone(),
user: row1.user.clone(),
select_priv: DatabasePrivilegeChange::new(row1.select_priv, row2.select_priv),
insert_priv: DatabasePrivilegeChange::new(row1.insert_priv, row2.insert_priv),
update_priv: DatabasePrivilegeChange::new(row1.update_priv, row2.update_priv),
@@ -82,6 +84,7 @@ impl DatabasePrivilegeRowDiff {
}
/// Returns true if there are no changes in this diff.
#[must_use]
pub fn is_empty(&self) -> bool {
self.select_priv.is_none()
&& self.insert_priv.is_none()
@@ -113,7 +116,7 @@ impl DatabasePrivilegeRowDiff {
"create_tmp_table_priv" => Ok(self.create_tmp_table_priv),
"lock_tables_priv" => Ok(self.lock_tables_priv),
"references_priv" => Ok(self.references_priv),
_ => anyhow::bail!("Unknown privilege name: {}", privilege_name),
_ => anyhow::bail!("Unknown privilege name: {privilege_name}"),
}
}
@@ -159,7 +162,7 @@ impl DatabasePrivilegeRowDiff {
/// Removes any no-op changes from the diff, based on the original privilege row.
fn remove_noops(&mut self, from: &DatabasePrivilegeRow) {
fn new_value(
change: &Option<DatabasePrivilegeChange>,
change: Option<&DatabasePrivilegeChange>,
from_value: bool,
) -> Option<DatabasePrivilegeChange> {
change.as_ref().and_then(|c| match c {
@@ -173,22 +176,24 @@ impl DatabasePrivilegeRowDiff {
})
}
self.select_priv = new_value(&self.select_priv, from.select_priv);
self.insert_priv = new_value(&self.insert_priv, from.insert_priv);
self.update_priv = new_value(&self.update_priv, from.update_priv);
self.delete_priv = new_value(&self.delete_priv, from.delete_priv);
self.create_priv = new_value(&self.create_priv, from.create_priv);
self.drop_priv = new_value(&self.drop_priv, from.drop_priv);
self.alter_priv = new_value(&self.alter_priv, from.alter_priv);
self.index_priv = new_value(&self.index_priv, from.index_priv);
self.create_tmp_table_priv =
new_value(&self.create_tmp_table_priv, from.create_tmp_table_priv);
self.lock_tables_priv = new_value(&self.lock_tables_priv, from.lock_tables_priv);
self.references_priv = new_value(&self.references_priv, from.references_priv);
self.select_priv = new_value(self.select_priv.as_ref(), from.select_priv);
self.insert_priv = new_value(self.insert_priv.as_ref(), from.insert_priv);
self.update_priv = new_value(self.update_priv.as_ref(), from.update_priv);
self.delete_priv = new_value(self.delete_priv.as_ref(), from.delete_priv);
self.create_priv = new_value(self.create_priv.as_ref(), from.create_priv);
self.drop_priv = new_value(self.drop_priv.as_ref(), from.drop_priv);
self.alter_priv = new_value(self.alter_priv.as_ref(), from.alter_priv);
self.index_priv = new_value(self.index_priv.as_ref(), from.index_priv);
self.create_tmp_table_priv = new_value(
self.create_tmp_table_priv.as_ref(),
from.create_tmp_table_priv,
);
self.lock_tables_priv = new_value(self.lock_tables_priv.as_ref(), from.lock_tables_priv);
self.references_priv = new_value(self.references_priv.as_ref(), from.references_priv);
}
fn apply(&self, base: &mut DatabasePrivilegeRow) {
fn apply_change(change: &Option<DatabasePrivilegeChange>, target: &mut bool) {
fn apply_change(change: Option<&DatabasePrivilegeChange>, target: &mut bool) {
match change {
Some(DatabasePrivilegeChange::YesToNo) => *target = false,
Some(DatabasePrivilegeChange::NoToYes) => *target = true,
@@ -196,17 +201,20 @@ impl DatabasePrivilegeRowDiff {
}
}
apply_change(&self.select_priv, &mut base.select_priv);
apply_change(&self.insert_priv, &mut base.insert_priv);
apply_change(&self.update_priv, &mut base.update_priv);
apply_change(&self.delete_priv, &mut base.delete_priv);
apply_change(&self.create_priv, &mut base.create_priv);
apply_change(&self.drop_priv, &mut base.drop_priv);
apply_change(&self.alter_priv, &mut base.alter_priv);
apply_change(&self.index_priv, &mut base.index_priv);
apply_change(&self.create_tmp_table_priv, &mut base.create_tmp_table_priv);
apply_change(&self.lock_tables_priv, &mut base.lock_tables_priv);
apply_change(&self.references_priv, &mut base.references_priv);
apply_change(self.select_priv.as_ref(), &mut base.select_priv);
apply_change(self.insert_priv.as_ref(), &mut base.insert_priv);
apply_change(self.update_priv.as_ref(), &mut base.update_priv);
apply_change(self.delete_priv.as_ref(), &mut base.delete_priv);
apply_change(self.create_priv.as_ref(), &mut base.create_priv);
apply_change(self.drop_priv.as_ref(), &mut base.drop_priv);
apply_change(self.alter_priv.as_ref(), &mut base.alter_priv);
apply_change(self.index_priv.as_ref(), &mut base.index_priv);
apply_change(
self.create_tmp_table_priv.as_ref(),
&mut base.create_tmp_table_priv,
);
apply_change(self.lock_tables_priv.as_ref(), &mut base.lock_tables_priv);
apply_change(self.references_priv.as_ref(), &mut base.references_priv);
}
}
@@ -214,7 +222,7 @@ impl fmt::Display for DatabasePrivilegeRowDiff {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fn format_change(
f: &mut fmt::Formatter<'_>,
change: &Option<DatabasePrivilegeChange>,
change: Option<DatabasePrivilegeChange>,
field_name: &str,
) -> fmt::Result {
if let Some(change) = change {
@@ -233,17 +241,17 @@ impl fmt::Display for DatabasePrivilegeRowDiff {
}
}
format_change(f, &self.select_priv, "select_priv")?;
format_change(f, &self.insert_priv, "insert_priv")?;
format_change(f, &self.update_priv, "update_priv")?;
format_change(f, &self.delete_priv, "delete_priv")?;
format_change(f, &self.create_priv, "create_priv")?;
format_change(f, &self.drop_priv, "drop_priv")?;
format_change(f, &self.alter_priv, "alter_priv")?;
format_change(f, &self.index_priv, "index_priv")?;
format_change(f, &self.create_tmp_table_priv, "create_tmp_table_priv")?;
format_change(f, &self.lock_tables_priv, "lock_tables_priv")?;
format_change(f, &self.references_priv, "references_priv")?;
format_change(f, self.select_priv, "select_priv")?;
format_change(f, self.insert_priv, "insert_priv")?;
format_change(f, self.update_priv, "update_priv")?;
format_change(f, self.delete_priv, "delete_priv")?;
format_change(f, self.create_priv, "create_priv")?;
format_change(f, self.drop_priv, "drop_priv")?;
format_change(f, self.alter_priv, "alter_priv")?;
format_change(f, self.index_priv, "index_priv")?;
format_change(f, self.create_tmp_table_priv, "create_tmp_table_priv")?;
format_change(f, self.lock_tables_priv, "lock_tables_priv")?;
format_change(f, self.references_priv, "references_priv")?;
Ok(())
}
@@ -259,6 +267,7 @@ pub enum DatabasePrivilegesDiff {
}
impl DatabasePrivilegesDiff {
#[must_use]
pub fn get_database_name(&self) -> &MySQLDatabase {
match self {
DatabasePrivilegesDiff::New(p) => &p.db,
@@ -268,6 +277,7 @@ impl DatabasePrivilegesDiff {
}
}
#[must_use]
pub fn get_user_name(&self) -> &MySQLUser {
match self {
DatabasePrivilegesDiff::New(p) => &p.user,
@@ -305,7 +315,7 @@ impl DatabasePrivilegesDiff {
}
if matches!(self, DatabasePrivilegesDiff::Noop { .. }) {
*self = other.to_owned();
other.clone_into(self);
return Ok(());
} else if matches!(other, DatabasePrivilegesDiff::Noop { .. }) {
return Ok(());
@@ -327,8 +337,8 @@ impl DatabasePrivilegesDiff {
inner_diff.mappend(modified);
if inner_diff.is_empty() {
let db = inner_diff.db.to_owned();
let user = inner_diff.user.to_owned();
let db = inner_diff.db.clone();
let user = inner_diff.user.clone();
*self = DatabasePrivilegesDiff::Noop { db, user };
}
}
@@ -352,28 +362,27 @@ pub type DatabasePrivilegeState<'a> = &'a [DatabasePrivilegeRow];
/// This function calculates the differences between two sets of database privileges.
/// It returns a set of [`DatabasePrivilegesDiff`] that can be used to display or
/// apply a set of privilege modifications to the database.
#[must_use]
pub fn diff_privileges(
from: DatabasePrivilegeState<'_>,
to: &[DatabasePrivilegeRow],
) -> BTreeSet<DatabasePrivilegesDiff> {
let from_lookup_table: HashMap<(MySQLDatabase, MySQLUser), DatabasePrivilegeRow> =
HashMap::from_iter(
from.iter()
.cloned()
.map(|p| ((p.db.to_owned(), p.user.to_owned()), p)),
);
let from_lookup_table: HashMap<(MySQLDatabase, MySQLUser), DatabasePrivilegeRow> = from
.iter()
.cloned()
.map(|p| ((p.db.clone(), p.user.clone()), p))
.collect();
let to_lookup_table: HashMap<(MySQLDatabase, MySQLUser), DatabasePrivilegeRow> =
HashMap::from_iter(
to.iter()
.cloned()
.map(|p| ((p.db.to_owned(), p.user.to_owned()), p)),
);
let to_lookup_table: HashMap<(MySQLDatabase, MySQLUser), DatabasePrivilegeRow> = to
.iter()
.cloned()
.map(|p| ((p.db.clone(), p.user.clone()), p))
.collect();
let mut result = BTreeSet::new();
for p in to {
if let Some(old_p) = from_lookup_table.get(&(p.db.to_owned(), p.user.to_owned())) {
if let Some(old_p) = from_lookup_table.get(&(p.db.clone(), p.user.clone())) {
let diff = DatabasePrivilegeRowDiff::from_rows(old_p, p);
if !diff.is_empty() {
result.insert(DatabasePrivilegesDiff::Modified(diff));
@@ -384,7 +393,7 @@ pub fn diff_privileges(
}
for p in from {
if !to_lookup_table.contains_key(&(p.db.to_owned(), p.user.to_owned())) {
if !to_lookup_table.contains_key(&(p.db.clone(), p.user.clone())) {
result.insert(DatabasePrivilegesDiff::Deleted(p.to_owned()));
}
}
@@ -400,17 +409,16 @@ pub fn create_or_modify_privilege_rows(
from: DatabasePrivilegeState<'_>,
to: &BTreeSet<DatabasePrivilegeRowDiff>,
) -> anyhow::Result<BTreeSet<DatabasePrivilegesDiff>> {
let from_lookup_table: HashMap<(MySQLDatabase, MySQLUser), DatabasePrivilegeRow> =
HashMap::from_iter(
from.iter()
.cloned()
.map(|p| ((p.db.to_owned(), p.user.to_owned()), p)),
);
let from_lookup_table: HashMap<(MySQLDatabase, MySQLUser), DatabasePrivilegeRow> = from
.iter()
.cloned()
.map(|p| ((p.db.clone(), p.user.clone()), p))
.collect();
let mut result = BTreeSet::new();
for diff in to {
if let Some(old_p) = from_lookup_table.get(&(diff.db.to_owned(), diff.user.to_owned())) {
if let Some(old_p) = from_lookup_table.get(&(diff.db.clone(), diff.user.clone())) {
let mut modified_diff = diff.to_owned();
modified_diff.remove_noops(old_p);
if !modified_diff.is_empty() {
@@ -418,8 +426,8 @@ pub fn create_or_modify_privilege_rows(
}
} else {
let mut new_row = DatabasePrivilegeRow {
db: diff.db.to_owned(),
user: diff.user.to_owned(),
db: diff.db.clone(),
user: diff.user.clone(),
select_priv: false,
insert_priv: false,
update_priv: false,
@@ -450,12 +458,11 @@ pub fn reduce_privilege_diffs(
from: DatabasePrivilegeState<'_>,
to: BTreeSet<DatabasePrivilegesDiff>,
) -> anyhow::Result<BTreeSet<DatabasePrivilegesDiff>> {
let from_lookup_table: HashMap<(MySQLDatabase, MySQLUser), DatabasePrivilegeRow> =
HashMap::from_iter(
from.iter()
.cloned()
.map(|p| ((p.db.to_owned(), p.user.to_owned()), p)),
);
let from_lookup_table: HashMap<(MySQLDatabase, MySQLUser), DatabasePrivilegeRow> = from
.iter()
.cloned()
.map(|p| ((p.db.clone(), p.user.clone()), p))
.collect();
let mut result: HashMap<(MySQLDatabase, MySQLUser), DatabasePrivilegesDiff> = from_lookup_table
.iter()
@@ -481,19 +488,19 @@ pub fn reduce_privilege_diffs(
existing_diff.mappend(&diff)?;
}
Entry::Vacant(vacant_entry) => {
vacant_entry.insert(diff.to_owned());
vacant_entry.insert(diff.clone());
}
}
}
for (key, diff) in result.iter_mut() {
for (key, diff) in &mut result {
if let Some(from_row) = from_lookup_table.get(key)
&& let DatabasePrivilegesDiff::Modified(modified_diff) = diff
{
modified_diff.remove_noops(from_row);
if modified_diff.is_empty() {
let db = modified_diff.db.to_owned();
let user = modified_diff.user.to_owned();
let db = modified_diff.db.clone();
let user = modified_diff.user.clone();
*diff = DatabasePrivilegesDiff::Noop { db, user };
}
}
@@ -506,6 +513,7 @@ pub fn reduce_privilege_diffs(
}
/// Renders a set of [`DatabasePrivilegesDiff`] into a human-readable formatted table.
#[must_use]
pub fn display_privilege_diffs(diffs: &BTreeSet<DatabasePrivilegesDiff>) -> String {
let mut table = Table::new();
table.set_titles(row!["Database", "User", "Privilege diff",]);

View File

@@ -13,6 +13,7 @@ use itertools::Itertools;
use std::cmp::max;
/// Generates a single row of the privileges table for the editor.
#[must_use]
pub fn format_privileges_line_for_editor(
privs: &DatabasePrivilegeRow,
database_name_len: usize,
@@ -25,6 +26,7 @@ pub fn format_privileges_line_for_editor(
"User" => format!("{:width$}", privs.user, width = username_len),
privilege => format!(
"{:width$}",
// SAFETY: unwrap is safe here because the field names are static
yn(privs.get_privilege_by_name(privilege).unwrap()),
width = db_priv_field_human_readable_name(privilege).len()
),
@@ -34,14 +36,14 @@ pub fn format_privileges_line_for_editor(
.to_string()
}
const EDITOR_COMMENT: &str = r#"
const EDITOR_COMMENT: &str = r"
# Welcome to the privilege editor.
# Each line defines what privileges a single user has on a single database.
# The first two columns respectively represent the database name and the user, and the remaining columns are the privileges.
# If the user should have a certain privilege, write 'Y', otherwise write 'N'.
#
# Lines starting with '#' are comments and will be ignored.
"#;
";
/// Generates the content for the privilege editor.
///
@@ -52,9 +54,9 @@ pub fn generate_editor_content_from_privilege_data(
unix_user: &str,
database_name: Option<&MySQLDatabase>,
) -> String {
let example_user = format!("{}_user", unix_user);
let example_user = format!("{unix_user}_user");
let example_db = database_name
.unwrap_or(&format!("{}_db", unix_user).into())
.unwrap_or(&format!("{unix_user}_db").into())
.to_string();
// NOTE: `.max()`` fails when the iterator is empty.
@@ -114,7 +116,7 @@ pub fn generate_editor_content_from_privilege_data(
EDITOR_COMMENT,
header.join(" "),
if privilege_data.is_empty() {
format!("# {}", example_line)
format!("# {example_line}")
} else {
privilege_data
.iter()
@@ -145,11 +147,8 @@ enum PrivilegeRowParseResult {
fn parse_privilege_cell_from_editor(yn: &str, name: &str) -> anyhow::Result<bool> {
let human_readable_name = db_priv_field_human_readable_name(name);
rev_yn(yn)
.ok_or_else(|| anyhow!("Expected Y or N, found {}", yn))
.context(format!(
"Could not parse '{}' privilege",
human_readable_name
))
.ok_or_else(|| anyhow!("Expected Y or N, found {yn}"))
.context(format!("Could not parse '{human_readable_name}' privilege"))
}
#[inline]
@@ -272,12 +271,12 @@ fn parse_privilege_row_from_editor(row: &str) -> PrivilegeRowParseResult {
}
pub fn parse_privilege_data_from_editor_content(
content: String,
content: &str,
) -> anyhow::Result<Vec<DatabasePrivilegeRow>> {
content
.trim()
.split('\n')
.map(|line| line.trim())
.lines()
.map(str::trim)
.enumerate()
.map(|(i, line)| {
let mut header: Vec<_> = DATABASE_PRIVILEGE_FIELDS
@@ -314,7 +313,7 @@ pub fn parse_privilege_data_from_editor_content(
PrivilegeRowParseResult::Empty => Ok(None),
}
})
.filter_map(|result| result.transpose())
.filter_map(std::result::Result::transpose)
.collect::<anyhow::Result<Vec<DatabasePrivilegeRow>>>()
}
@@ -417,7 +416,7 @@ mod tests {
let content = generate_editor_content_from_privilege_data(&permissions, "user", None);
let parsed_permissions = parse_privilege_data_from_editor_content(content).unwrap();
let parsed_permissions = parse_privilege_data_from_editor_content(&content).unwrap();
assert_eq!(permissions, parsed_permissions);
}

View File

@@ -57,10 +57,12 @@ pub fn print_check_authorization_output_status_json(output: &CheckAuthorizationR
}
impl CheckAuthorizationError {
#[must_use]
pub fn to_error_message(&self, db_or_user: &DbOrUser) -> String {
self.0.to_error_message(db_or_user.clone())
self.0.to_error_message(db_or_user)
}
#[must_use]
pub fn error_type(&self) -> String {
self.0.error_type()
}

View File

@@ -29,7 +29,7 @@ pub fn print_create_databases_output_status(output: &CreateDatabasesResponse) {
for (database_name, result) in output {
match result {
Ok(()) => {
println!("Database '{}' created successfully.", database_name);
println!("Database '{database_name}' created successfully.");
}
Err(err) => {
eprintln!("{}", err.to_error_message(database_name));
@@ -63,20 +63,22 @@ pub fn print_create_databases_output_status_json(output: &CreateDatabasesRespons
}
impl CreateDatabaseError {
#[must_use]
pub fn to_error_message(&self, database_name: &MySQLDatabase) -> String {
match self {
CreateDatabaseError::ValidationError(err) => {
err.to_error_message(DbOrUser::Database(database_name.clone()))
err.to_error_message(&DbOrUser::Database(database_name.clone()))
}
CreateDatabaseError::DatabaseAlreadyExists => {
format!("Database {} already exists.", database_name)
format!("Database {database_name} already exists.")
}
CreateDatabaseError::MySqlError(err) => {
format!("MySQL error: {}", err)
format!("MySQL error: {err}")
}
}
}
#[must_use]
pub fn error_type(&self) -> String {
match self {
CreateDatabaseError::ValidationError(err) => err.error_type(),

View File

@@ -29,7 +29,7 @@ pub fn print_create_users_output_status(output: &CreateUsersResponse) {
for (username, result) in output {
match result {
Ok(()) => {
println!("User '{}' created successfully.", username);
println!("User '{username}' created successfully.");
}
Err(err) => {
eprintln!("{}", err.to_error_message(username));
@@ -63,20 +63,22 @@ pub fn print_create_users_output_status_json(output: &CreateUsersResponse) {
}
impl CreateUserError {
#[must_use]
pub fn to_error_message(&self, username: &MySQLUser) -> String {
match self {
CreateUserError::ValidationError(err) => {
err.to_error_message(DbOrUser::User(username.clone()))
err.to_error_message(&DbOrUser::User(username.clone()))
}
CreateUserError::UserAlreadyExists => {
format!("User '{}' already exists.", username)
format!("User '{username}' already exists.")
}
CreateUserError::MySqlError(err) => {
format!("MySQL error: {}", err)
format!("MySQL error: {err}")
}
}
}
#[must_use]
pub fn error_type(&self) -> String {
match self {
CreateUserError::ValidationError(err) => err.error_type(),

View File

@@ -66,20 +66,22 @@ pub fn print_drop_databases_output_status_json(output: &DropDatabasesResponse) {
}
impl DropDatabaseError {
#[must_use]
pub fn to_error_message(&self, database_name: &MySQLDatabase) -> String {
match self {
DropDatabaseError::ValidationError(err) => {
err.to_error_message(DbOrUser::Database(database_name.clone()))
err.to_error_message(&DbOrUser::Database(database_name.clone()))
}
DropDatabaseError::DatabaseDoesNotExist => {
format!("Database {} does not exist.", database_name)
format!("Database {database_name} does not exist.")
}
DropDatabaseError::MySqlError(err) => {
format!("MySQL error: {}", err)
format!("MySQL error: {err}")
}
}
}
#[must_use]
pub fn error_type(&self) -> String {
match self {
DropDatabaseError::ValidationError(err) => err.error_type(),

View File

@@ -29,7 +29,7 @@ pub fn print_drop_users_output_status(output: &DropUsersResponse) {
for (username, result) in output {
match result {
Ok(()) => {
println!("User '{}' dropped successfully.", username);
println!("User '{username}' dropped successfully.");
}
Err(err) => {
eprintln!("{}", err.to_error_message(username));
@@ -63,20 +63,22 @@ pub fn print_drop_users_output_status_json(output: &DropUsersResponse) {
}
impl DropUserError {
#[must_use]
pub fn to_error_message(&self, username: &MySQLUser) -> String {
match self {
DropUserError::ValidationError(err) => {
err.to_error_message(DbOrUser::User(username.clone()))
err.to_error_message(&DbOrUser::User(username.clone()))
}
DropUserError::UserDoesNotExist => {
format!("User '{}' does not exist.", username)
format!("User '{username}' does not exist.")
}
DropUserError::MySqlError(err) => {
format!("MySQL error: {}", err)
format!("MySQL error: {err}")
}
}
}
#[must_use]
pub fn error_type(&self) -> String {
match self {
DropUserError::ValidationError(err) => err.error_type(),

View File

@@ -12,13 +12,15 @@ pub enum ListAllDatabasesError {
}
impl ListAllDatabasesError {
#[must_use]
pub fn to_error_message(&self) -> String {
match self {
ListAllDatabasesError::MySqlError(err) => format!("MySQL error: {}", err),
ListAllDatabasesError::MySqlError(err) => format!("MySQL error: {err}"),
}
}
#[allow(dead_code)]
#[must_use]
pub fn error_type(&self) -> String {
match self {
ListAllDatabasesError::MySqlError(_) => "mysql-error".to_string(),

View File

@@ -12,13 +12,15 @@ pub enum ListAllPrivilegesError {
}
impl ListAllPrivilegesError {
#[must_use]
pub fn to_error_message(&self) -> String {
match self {
ListAllPrivilegesError::MySqlError(err) => format!("MySQL error: {}", err),
ListAllPrivilegesError::MySqlError(err) => format!("MySQL error: {err}"),
}
}
#[allow(dead_code)]
#[must_use]
pub fn error_type(&self) -> String {
match self {
ListAllPrivilegesError::MySqlError(_) => "mysql-error".to_string(),

View File

@@ -12,13 +12,15 @@ pub enum ListAllUsersError {
}
impl ListAllUsersError {
#[must_use]
pub fn to_error_message(&self) -> String {
match self {
ListAllUsersError::MySqlError(err) => format!("MySQL error: {}", err),
ListAllUsersError::MySqlError(err) => format!("MySQL error: {err}"),
}
}
#[allow(dead_code)]
#[must_use]
pub fn error_type(&self) -> String {
match self {
ListAllUsersError::MySqlError(_) => "mysql-error".to_string(),

View File

@@ -113,20 +113,22 @@ pub fn print_list_databases_output_status_json(output: &ListDatabasesResponse) {
}
impl ListDatabasesError {
#[must_use]
pub fn to_error_message(&self, database_name: &MySQLDatabase) -> String {
match self {
ListDatabasesError::ValidationError(err) => {
err.to_error_message(DbOrUser::Database(database_name.clone()))
err.to_error_message(&DbOrUser::Database(database_name.clone()))
}
ListDatabasesError::DatabaseDoesNotExist => {
format!("Database '{}' does not exist.", database_name)
format!("Database '{database_name}' does not exist.")
}
ListDatabasesError::MySqlError(err) => {
format!("MySQL error: {}", err)
format!("MySQL error: {err}")
}
}
}
#[must_use]
pub fn error_type(&self) -> String {
match self {
ListDatabasesError::ValidationError(err) => err.error_type(),

View File

@@ -65,7 +65,7 @@ pub fn print_list_privileges_output_status(output: &ListPrivilegesResponse, long
));
for (_database, rows) in final_privs_map {
for row in rows.iter() {
for row in &rows {
table.add_row(row![
row.db,
row.user,
@@ -129,20 +129,22 @@ pub enum ListPrivilegesError {
}
impl ListPrivilegesError {
#[must_use]
pub fn to_error_message(&self, database_name: &MySQLDatabase) -> String {
match self {
ListPrivilegesError::ValidationError(err) => {
err.to_error_message(DbOrUser::Database(database_name.clone()))
err.to_error_message(&DbOrUser::Database(database_name.clone()))
}
ListPrivilegesError::DatabaseDoesNotExist => {
format!("Database '{}' does not exist.", database_name)
format!("Database '{database_name}' does not exist.")
}
ListPrivilegesError::MySqlError(err) => {
format!("MySQL error: {}", err)
format!("MySQL error: {err}")
}
}
}
#[must_use]
pub fn error_type(&self) -> String {
match self {
ListPrivilegesError::ValidationError(err) => err.error_type(),

View File

@@ -97,20 +97,22 @@ pub fn print_list_users_output_status_json(output: &ListUsersResponse) {
}
impl ListUsersError {
#[must_use]
pub fn to_error_message(&self, username: &MySQLUser) -> String {
match self {
ListUsersError::ValidationError(err) => {
err.to_error_message(DbOrUser::User(username.clone()))
err.to_error_message(&DbOrUser::User(username.clone()))
}
ListUsersError::UserDoesNotExist => {
format!("User '{}' does not exist.", username)
format!("User '{username}' does not exist.")
}
ListUsersError::MySqlError(err) => {
format!("MySQL error: {}", err)
format!("MySQL error: {err}")
}
}
}
#[must_use]
pub fn error_type(&self) -> String {
match self {
ListUsersError::ValidationError(err) => err.error_type(),

View File

@@ -32,7 +32,7 @@ pub fn print_lock_users_output_status(output: &LockUsersResponse) {
for (username, result) in output {
match result {
Ok(()) => {
println!("User '{}' locked successfully.", username);
println!("User '{username}' locked successfully.");
}
Err(err) => {
eprintln!("{}", err.to_error_message(username));
@@ -66,23 +66,25 @@ pub fn print_lock_users_output_status_json(output: &LockUsersResponse) {
}
impl LockUserError {
#[must_use]
pub fn to_error_message(&self, username: &MySQLUser) -> String {
match self {
LockUserError::ValidationError(err) => {
err.to_error_message(DbOrUser::User(username.clone()))
err.to_error_message(&DbOrUser::User(username.clone()))
}
LockUserError::UserDoesNotExist => {
format!("User '{}' does not exist.", username)
format!("User '{username}' does not exist.")
}
LockUserError::UserIsAlreadyLocked => {
format!("User '{}' is already locked.", username)
format!("User '{username}' is already locked.")
}
LockUserError::MySqlError(err) => {
format!("MySQL error: {}", err)
format!("MySQL error: {err}")
}
}
}
#[must_use]
pub fn error_type(&self) -> String {
match self {
LockUserError::ValidationError(err) => err.error_type(),

View File

@@ -53,8 +53,7 @@ pub fn print_modify_database_privileges_output_status(output: &ModifyPrivilegesR
match result {
Ok(()) => {
println!(
"Privileges for user '{}' on database '{}' modified successfully.",
username, database_name
"Privileges for user '{username}' on database '{database_name}' modified successfully."
);
}
Err(err) => {
@@ -67,19 +66,20 @@ pub fn print_modify_database_privileges_output_status(output: &ModifyPrivilegesR
}
impl ModifyDatabasePrivilegesError {
#[must_use]
pub fn to_error_message(&self, database_name: &MySQLDatabase, username: &MySQLUser) -> String {
match self {
ModifyDatabasePrivilegesError::DatabaseValidationError(err) => {
err.to_error_message(DbOrUser::Database(database_name.clone()))
err.to_error_message(&DbOrUser::Database(database_name.clone()))
}
ModifyDatabasePrivilegesError::UserValidationError(err) => {
err.to_error_message(DbOrUser::User(username.clone()))
err.to_error_message(&DbOrUser::User(username.clone()))
}
ModifyDatabasePrivilegesError::DatabaseDoesNotExist => {
format!("Database '{}' does not exist.", database_name)
format!("Database '{database_name}' does not exist.")
}
ModifyDatabasePrivilegesError::UserDoesNotExist => {
format!("User '{}' does not exist.", username)
format!("User '{username}' does not exist.")
}
ModifyDatabasePrivilegesError::DiffDoesNotApply(diff) => {
format!(
@@ -88,12 +88,13 @@ impl ModifyDatabasePrivilegesError {
)
}
ModifyDatabasePrivilegesError::MySqlError(err) => {
format!("MySQL error: {}", err)
format!("MySQL error: {err}")
}
}
}
#[allow(dead_code)]
#[must_use]
pub fn error_type(&self) -> String {
match self {
ModifyDatabasePrivilegesError::DatabaseValidationError(err) => {
@@ -113,29 +114,26 @@ impl ModifyDatabasePrivilegesError {
}
impl DiffDoesNotApplyError {
#[must_use]
pub fn to_error_message(&self) -> String {
match self {
DiffDoesNotApplyError::RowAlreadyExists(database_name, username) => {
format!(
"Privileges for user '{}' on database '{}' already exist.",
username, database_name
"Privileges for user '{username}' on database '{database_name}' already exist."
)
}
DiffDoesNotApplyError::RowDoesNotExist(database_name, username) => {
format!(
"Privileges for user '{}' on database '{}' do not exist.",
username, database_name
"Privileges for user '{username}' on database '{database_name}' do not exist."
)
}
DiffDoesNotApplyError::RowPrivilegeChangeDoesNotApply(diff, row) => {
format!(
"Could not apply privilege change {:?} to row {:?}",
diff, row
)
format!("Could not apply privilege change {diff:?} to row {row:?}")
}
}
}
#[must_use]
pub fn error_type(&self) -> String {
match self {
DiffDoesNotApplyError::RowAlreadyExists(_, _) => "row-already-exists".to_string(),

View File

@@ -25,7 +25,7 @@ pub enum SetPasswordError {
pub fn print_set_password_output_status(output: &SetUserPasswordResponse, username: &MySQLUser) {
match output {
Ok(()) => {
println!("Password for user '{}' set successfully.", username);
println!("Password for user '{username}' set successfully.");
}
Err(err) => {
eprintln!("{}", err.to_error_message(username));
@@ -35,21 +35,23 @@ pub fn print_set_password_output_status(output: &SetUserPasswordResponse, userna
}
impl SetPasswordError {
#[must_use]
pub fn to_error_message(&self, username: &MySQLUser) -> String {
match self {
SetPasswordError::ValidationError(err) => {
err.to_error_message(DbOrUser::User(username.clone()))
err.to_error_message(&DbOrUser::User(username.clone()))
}
SetPasswordError::UserDoesNotExist => {
format!("User '{}' does not exist.", username)
format!("User '{username}' does not exist.")
}
SetPasswordError::MySqlError(err) => {
format!("MySQL error: {}", err)
format!("MySQL error: {err}")
}
}
}
#[allow(dead_code)]
#[must_use]
pub fn error_type(&self) -> String {
match self {
SetPasswordError::ValidationError(err) => err.error_type(),

View File

@@ -32,7 +32,7 @@ pub fn print_unlock_users_output_status(output: &UnlockUsersResponse) {
for (username, result) in output {
match result {
Ok(()) => {
println!("User '{}' unlocked successfully.", username);
println!("User '{username}' unlocked successfully.");
}
Err(err) => {
eprintln!("{}", err.to_error_message(username));
@@ -66,23 +66,25 @@ pub fn print_unlock_users_output_status_json(output: &UnlockUsersResponse) {
}
impl UnlockUserError {
#[must_use]
pub fn to_error_message(&self, username: &MySQLUser) -> String {
match self {
UnlockUserError::ValidationError(err) => {
err.to_error_message(DbOrUser::User(username.clone()))
err.to_error_message(&DbOrUser::User(username.clone()))
}
UnlockUserError::UserDoesNotExist => {
format!("User '{}' does not exist.", username)
format!("User '{username}' does not exist.")
}
UnlockUserError::UserIsAlreadyUnlocked => {
format!("User '{}' is already unlocked.", username)
format!("User '{username}' is already unlocked.")
}
UnlockUserError::MySqlError(err) => {
format!("MySQL error: {}", err)
format!("MySQL error: {err}")
}
}
}
#[must_use]
pub fn error_type(&self) -> String {
match self {
UnlockUserError::ValidationError(err) => err.error_type(),

View File

@@ -22,7 +22,8 @@ pub enum NameValidationError {
}
impl NameValidationError {
pub fn to_error_message(self, db_or_user: DbOrUser) -> String {
#[must_use]
pub fn to_error_message(self, db_or_user: &DbOrUser) -> String {
match self {
NameValidationError::EmptyString => {
format!("{} name can not be empty.", db_or_user.capitalized_noun())
@@ -32,15 +33,16 @@ impl NameValidationError {
db_or_user.capitalized_noun()
),
NameValidationError::InvalidCharacters => format!(
indoc! {r#"
indoc! {r"
Invalid characters in {} name: '{}', only A-Z, a-z, 0-9, _ (underscore) and - (dash) are permitted.
"#},
"},
db_or_user.lowercased_noun(),
db_or_user.name(),
),
}
}
#[must_use]
pub fn error_type(&self) -> &'static str {
match self {
NameValidationError::EmptyString => "empty-string",
@@ -64,7 +66,8 @@ pub enum AuthorizationError {
}
impl AuthorizationError {
pub fn to_error_message(self, db_or_user: DbOrUser) -> String {
#[must_use]
pub fn to_error_message(self, db_or_user: &DbOrUser) -> String {
match self {
AuthorizationError::IllegalPrefix => format!(
"Illegal {} name prefix: you are not allowed to manage databases or users prefixed with '{}'",
@@ -82,6 +85,7 @@ impl AuthorizationError {
}
}
#[must_use]
pub fn error_type(&self) -> &'static str {
match self {
AuthorizationError::IllegalPrefix => "illegal-prefix",
@@ -102,7 +106,8 @@ pub enum ValidationError {
}
impl ValidationError {
pub fn to_error_message(&self, db_or_user: DbOrUser) -> String {
#[must_use]
pub fn to_error_message(&self, db_or_user: &DbOrUser) -> String {
match self {
ValidationError::NameValidationError(err) => err.to_error_message(db_or_user),
ValidationError::AuthorizationError(err) => err.to_error_message(db_or_user),
@@ -116,6 +121,7 @@ impl ValidationError {
}
}
#[must_use]
pub fn error_type(&self) -> String {
match self {
ValidationError::NameValidationError(err) => {
@@ -153,7 +159,7 @@ pub fn validate_authorization_by_unix_user(
name: &str,
user: &UnixUser,
) -> Result<(), AuthorizationError> {
let prefixes = std::iter::once(user.username.to_owned())
let prefixes = std::iter::once(user.username.clone())
.chain(user.groups.iter().cloned())
.collect::<Vec<String>>();
@@ -174,12 +180,12 @@ pub fn validate_authorization_by_prefixes(
if prefixes
.iter()
.filter(|p| name.starts_with(&(p.to_string() + "_")))
.filter(|p| name.starts_with(&((*p).clone() + "_")))
.collect::<Vec<_>>()
.is_empty()
{
return Err(AuthorizationError::IllegalPrefix);
};
}
Ok(())
}

View File

@@ -112,6 +112,7 @@ pub enum DbOrUser {
}
impl DbOrUser {
#[must_use]
pub fn lowercased_noun(&self) -> &'static str {
match self {
DbOrUser::Database(_) => "database",
@@ -119,6 +120,7 @@ impl DbOrUser {
}
}
#[must_use]
pub fn capitalized_noun(&self) -> &'static str {
match self {
DbOrUser::Database(_) => "Database",
@@ -126,6 +128,7 @@ impl DbOrUser {
}
}
#[must_use]
pub fn name(&self) -> &str {
match self {
DbOrUser::Database(db) => db.as_str(),
@@ -133,6 +136,7 @@ impl DbOrUser {
}
}
#[must_use]
pub fn prefix(&self) -> &str {
match self {
DbOrUser::Database(db) => db.split('_').next().unwrap_or("?"),

View File

@@ -42,10 +42,8 @@ pub async fn check_authorization(
/// - `gid:1001`
/// - `group:admins`
pub fn read_and_parse_group_denylist(denylist_path: &Path) -> anyhow::Result<GroupDenylist> {
let content = std::fs::read_to_string(denylist_path).context(format!(
"Failed to read denylist file at {:?}",
denylist_path
))?;
let content = std::fs::read_to_string(denylist_path)
.context(format!("Failed to read denylist file at {denylist_path:?}"))?;
let mut groups = HashSet::with_capacity(content.lines().count());
@@ -128,7 +126,6 @@ pub fn read_and_parse_group_denylist(denylist_path: &Path) -> anyhow::Result<Gro
line_number + 1,
err
);
continue;
}
},
_ => {

View File

@@ -44,7 +44,7 @@ impl MysqlConfig {
if let Some(password_file) = &self.password_file {
let password = fs::read_to_string(password_file)
.with_context(|| {
format!("Failed to read MySQL password file at {:?}", password_file)
format!("Failed to read MySQL password file at {password_file:?}")
})?
.trim()
.to_owned();
@@ -96,8 +96,8 @@ impl ServerConfig {
tracing::debug!("Reading config file at {:?}", config_path);
fs::read_to_string(config_path)
.context(format!("Failed to read config file at {:?}", config_path))
.context(format!("Failed to read config file at {config_path:?}"))
.and_then(|c| toml::from_str(&c).context("Failed to parse config file"))
.context(format!("Failed to parse config file at {:?}", config_path))
.context(format!("Failed to parse config file at {config_path:?}"))
}
}

View File

@@ -78,7 +78,7 @@ pub async fn session_handler(
))
.await
.ok();
anyhow::bail!("Failed to get username from uid: {}", e);
anyhow::bail!("Failed to get username from uid: {e}");
}
};
@@ -181,10 +181,10 @@ async fn session_handler_with_db_connection(
request => request.to_owned(),
};
if request_to_display != Request::Exit {
tracing::info!("Received request: {:#?}", request_to_display);
} else {
if request_to_display == Request::Exit {
tracing::debug!("Received request: {:#?}", request_to_display);
} else {
tracing::info!("Received request: {:#?}", request_to_display);
}
let response = match request {
@@ -194,22 +194,20 @@ async fn session_handler_with_db_connection(
}
Request::ListValidNamePrefixes => {
let mut result = Vec::with_capacity(unix_user.groups.len() + 1);
result.push(unix_user.username.to_owned());
result.push(unix_user.username.clone());
for group in get_user_filtered_groups(unix_user, group_denylist) {
result.push(group.to_owned());
result.push(group.clone());
}
Response::ListValidNamePrefixes(result)
}
Request::CompleteDatabaseName(partial_database_name) => {
// TODO: more correct validation here
if !partial_database_name
if partial_database_name
.chars()
.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
{
Response::CompleteDatabaseName(vec![])
} else {
let result = complete_database_name(
partial_database_name,
unix_user,
@@ -219,16 +217,16 @@ async fn session_handler_with_db_connection(
)
.await;
Response::CompleteDatabaseName(result)
} else {
Response::CompleteDatabaseName(vec![])
}
}
Request::CompleteUserName(partial_user_name) => {
// TODO: more correct validation here
if !partial_user_name
if partial_user_name
.chars()
.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
{
Response::CompleteUserName(vec![])
} else {
let result = complete_user_name(
partial_user_name,
unix_user,
@@ -238,6 +236,8 @@ async fn session_handler_with_db_connection(
)
.await;
Response::CompleteUserName(result)
} else {
Response::CompleteUserName(vec![])
}
}
Request::CreateDatabases(databases_names) => {
@@ -262,8 +262,8 @@ async fn session_handler_with_db_connection(
.await;
Response::DropDatabases(result)
}
Request::ListDatabases(database_names) => match database_names {
Some(database_names) => {
Request::ListDatabases(database_names) => {
if let Some(database_names) = database_names {
let result = list_databases(
database_names,
unix_user,
@@ -273,8 +273,7 @@ async fn session_handler_with_db_connection(
)
.await;
Response::ListDatabases(result)
}
None => {
} else {
let result = list_all_databases_for_user(
unix_user,
db_connection,
@@ -284,9 +283,9 @@ async fn session_handler_with_db_connection(
.await;
Response::ListAllDatabases(result)
}
},
Request::ListPrivileges(database_names) => match database_names {
Some(database_names) => {
}
Request::ListPrivileges(database_names) => {
if let Some(database_names) = database_names {
let privilege_data = get_databases_privilege_data(
database_names,
unix_user,
@@ -296,8 +295,7 @@ async fn session_handler_with_db_connection(
)
.await;
Response::ListPrivileges(privilege_data)
}
None => {
} else {
let privilege_data = get_all_database_privileges(
unix_user,
db_connection,
@@ -307,7 +305,7 @@ async fn session_handler_with_db_connection(
.await;
Response::ListAllPrivileges(privilege_data)
}
},
}
Request::ModifyPrivileges(database_privilege_diffs) => {
let result = apply_privilege_diffs(
BTreeSet::from_iter(database_privilege_diffs),
@@ -353,8 +351,8 @@ async fn session_handler_with_db_connection(
.await;
Response::SetUserPassword(result)
}
Request::ListUsers(db_users) => match db_users {
Some(db_users) => {
Request::ListUsers(db_users) => {
if let Some(db_users) = db_users {
let result = list_database_users(
db_users,
unix_user,
@@ -364,8 +362,7 @@ async fn session_handler_with_db_connection(
)
.await;
Response::ListUsers(result)
}
None => {
} else {
let result = list_all_database_users_for_unix_user(
unix_user,
db_connection,
@@ -375,7 +372,7 @@ async fn session_handler_with_db_connection(
.await;
Response::ListAllUsers(result)
}
},
}
Request::LockUsers(db_users) => {
let result = lock_database_users(
db_users,

View File

@@ -3,11 +3,13 @@ pub mod database_privilege_operations;
pub mod user_operations;
#[inline]
#[must_use]
pub fn quote_literal(s: &str) -> String {
format!("'{}'", s.replace('\'', r"\'"))
}
#[inline]
#[must_use]
pub fn quote_identifier(s: &str) -> String {
format!("`{}`", s.replace('`', r"\`"))
}

View File

@@ -53,16 +53,16 @@ pub async fn complete_database_name(
group_denylist: &GroupDenylist,
) -> CompleteDatabaseNameResponse {
let result = sqlx::query(
r#"
r"
SELECT CAST(`SCHEMA_NAME` AS CHAR(64)) AS `database`
FROM `information_schema`.`SCHEMATA`
WHERE `SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys')
AND `SCHEMA_NAME` REGEXP ?
AND `SCHEMA_NAME` LIKE ?
"#,
",
)
.bind(create_user_group_matching_regex(unix_user, group_denylist))
.bind(format!("{}%", database_prefix))
.bind(format!("{database_prefix}%"))
.fetch_all(connection)
.await;
@@ -103,21 +103,21 @@ pub async fn create_databases(
)
.map_err(CreateDatabaseError::ValidationError)
{
results.insert(database_name.to_owned(), Err(err));
results.insert(database_name.clone(), Err(err));
continue;
}
match unsafe_database_exists(&database_name, &mut *connection).await {
Ok(true) => {
results.insert(
database_name.to_owned(),
database_name.clone(),
Err(CreateDatabaseError::DatabaseAlreadyExists),
);
continue;
}
Err(err) => {
results.insert(
database_name.to_owned(),
database_name.clone(),
Err(CreateDatabaseError::MySqlError(err.to_string())),
);
continue;
@@ -159,21 +159,21 @@ pub async fn drop_databases(
)
.map_err(DropDatabaseError::ValidationError)
{
results.insert(database_name.to_owned(), Err(err));
results.insert(database_name.clone(), Err(err));
continue;
}
match unsafe_database_exists(&database_name, &mut *connection).await {
Ok(false) => {
results.insert(
database_name.to_owned(),
database_name.clone(),
Err(DropDatabaseError::DatabaseDoesNotExist),
);
continue;
}
Err(err) => {
results.insert(
database_name.to_owned(),
database_name.clone(),
Err(DropDatabaseError::MySqlError(err.to_string())),
);
continue;
@@ -218,7 +218,7 @@ impl FromRow<'_, sqlx::mysql::MySqlRow> for DatabaseRow {
if s.is_empty() {
None
} else {
Some(s.split(',').map(|s| s.to_owned()).collect())
Some(s.split(',').map(std::borrow::ToOwned::to_owned).collect())
}
})
.unwrap_or_default()
@@ -258,12 +258,12 @@ pub async fn list_databases(
)
.map_err(ListDatabasesError::ValidationError)
{
results.insert(database_name.to_owned(), Err(err));
results.insert(database_name.clone(), Err(err));
continue;
}
let result = sqlx::query_as::<_, DatabaseRow>(
r#"
r"
SELECT
CAST(`information_schema`.`SCHEMATA`.`SCHEMA_NAME` AS CHAR(64)) AS `database`,
GROUP_CONCAT(DISTINCT CAST(`information_schema`.`TABLES`.`TABLE_NAME` AS CHAR(64)) SEPARATOR ',') AS `tables`,
@@ -281,7 +281,7 @@ pub async fn list_databases(
ON `information_schema`.`SCHEMATA`.`SCHEMA_NAME` = `mysql`.`db`.`DB`
WHERE `information_schema`.`SCHEMATA`.`SCHEMA_NAME` = ?
GROUP BY `information_schema`.`SCHEMATA`.`SCHEMA_NAME`
"#,
",
)
.bind(database_name.to_string())
@@ -289,9 +289,7 @@ pub async fn list_databases(
.await
.map_err(|err| ListDatabasesError::MySqlError(err.to_string()))
.and_then(|database| {
database
.map(Ok)
.unwrap_or_else(|| Err(ListDatabasesError::DatabaseDoesNotExist))
database.map_or_else(|| Err(ListDatabasesError::DatabaseDoesNotExist), Ok)
});
if let Err(err) = &result {
@@ -313,7 +311,7 @@ pub async fn list_all_databases_for_user(
group_denylist: &GroupDenylist,
) -> ListAllDatabasesResponse {
let result = sqlx::query_as::<_, DatabaseRow>(
r#"
r"
SELECT
CAST(`information_schema`.`SCHEMATA`.`SCHEMA_NAME` AS CHAR(64)) AS `database`,
GROUP_CONCAT(DISTINCT CAST(`information_schema`.`TABLES`.`TABLE_NAME` AS CHAR(64)) SEPARATOR ',') AS `tables`,
@@ -332,7 +330,7 @@ pub async fn list_all_databases_for_user(
WHERE `information_schema`.`SCHEMATA`.`SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys')
AND `information_schema`.`SCHEMATA`.`SCHEMA_NAME` REGEXP ?
GROUP BY `information_schema`.`SCHEMATA`.`SCHEMA_NAME`
"#,
",
)
.bind(create_user_group_matching_regex(unix_user, group_denylist))
.fetch_all(connection)

View File

@@ -50,12 +50,11 @@ use crate::{
fn get_mysql_row_priv_field(row: &MySqlRow, position: usize) -> Result<bool, sqlx::Error> {
let field = DATABASE_PRIVILEGE_FIELDS[position];
let value = row.try_get(position)?;
match rev_yn(value) {
Some(val) => Ok(val),
_ => {
tracing::warn!(r#"Invalid value for privilege "{}": '{}'"#, field, value);
Ok(false)
}
if let Some(val) = rev_yn(value) {
Ok(val)
} else {
tracing::warn!(r#"Invalid value for privilege "{}": '{}'"#, field, value);
Ok(false)
}
}
@@ -147,7 +146,7 @@ pub async fn get_databases_privilege_data(
) -> ListPrivilegesResponse {
let mut results = BTreeMap::new();
for database_name in database_names.iter() {
for database_name in &database_names {
if let Err(err) = validate_db_or_user_request(
&DbOrUser::Database(database_name.clone()),
unix_user,
@@ -159,15 +158,22 @@ pub async fn get_databases_privilege_data(
continue;
}
if !unsafe_database_exists(database_name, connection)
.await
.unwrap()
{
results.insert(
database_name.to_owned(),
Err(ListPrivilegesError::DatabaseDoesNotExist),
);
continue;
match unsafe_database_exists(database_name, connection).await {
Ok(false) => {
results.insert(
database_name.to_owned(),
Err(ListPrivilegesError::DatabaseDoesNotExist),
);
continue;
}
Err(e) => {
results.insert(
database_name.to_owned(),
Err(ListPrivilegesError::MySqlError(e.to_string())),
);
continue;
}
Ok(true) => {}
}
let result = unsafe_get_database_privileges(database_name, connection)
@@ -185,13 +191,13 @@ pub async fn get_databases_privilege_data(
/// TODO: make this constant
fn get_all_db_privs_query() -> String {
format!(
indoc! {r#"
indoc! {r"
SELECT {} FROM `db` WHERE `db` IN
(SELECT DISTINCT CAST(`SCHEMA_NAME` AS CHAR(64)) AS `database`
FROM `information_schema`.`SCHEMATA`
WHERE `SCHEMA_NAME` NOT IN ('information_schema', 'performance_schema', 'mysql', 'sys')
AND `SCHEMA_NAME` REGEXP ?)
"#},
"},
DATABASE_PRIVILEGE_FIELDS
.iter()
.map(|field| quote_identifier(field))
@@ -234,25 +240,23 @@ async fn unsafe_apply_privilege_diff(
let question_marks =
std::iter::repeat_n("?", DATABASE_PRIVILEGE_FIELDS.len()).join(",");
sqlx::query(
format!("INSERT INTO `db` ({}) VALUES ({})", tables, question_marks).as_str(),
)
.bind(p.db.to_string())
.bind(p.user.to_string())
.bind(yn(p.select_priv))
.bind(yn(p.insert_priv))
.bind(yn(p.update_priv))
.bind(yn(p.delete_priv))
.bind(yn(p.create_priv))
.bind(yn(p.drop_priv))
.bind(yn(p.alter_priv))
.bind(yn(p.index_priv))
.bind(yn(p.create_tmp_table_priv))
.bind(yn(p.lock_tables_priv))
.bind(yn(p.references_priv))
.execute(connection)
.await
.map(|_| ())
sqlx::query(format!("INSERT INTO `db` ({tables}) VALUES ({question_marks})").as_str())
.bind(p.db.to_string())
.bind(p.user.to_string())
.bind(yn(p.select_priv))
.bind(yn(p.insert_priv))
.bind(yn(p.update_priv))
.bind(yn(p.delete_priv))
.bind(yn(p.create_priv))
.bind(yn(p.drop_priv))
.bind(yn(p.alter_priv))
.bind(yn(p.index_priv))
.bind(yn(p.create_tmp_table_priv))
.bind(yn(p.lock_tables_priv))
.bind(yn(p.references_priv))
.execute(connection)
.await
.map(|_| ())
}
DatabasePrivilegesDiff::Modified(p) => {
let changes = DATABASE_PRIVILEGE_FIELDS
@@ -274,25 +278,23 @@ async fn unsafe_apply_privilege_diff(
}
}
sqlx::query(
format!("UPDATE `db` SET {} WHERE `Db` = ? AND `User` = ?", changes).as_str(),
)
.bind(p.select_priv.map(change_to_yn))
.bind(p.insert_priv.map(change_to_yn))
.bind(p.update_priv.map(change_to_yn))
.bind(p.delete_priv.map(change_to_yn))
.bind(p.create_priv.map(change_to_yn))
.bind(p.drop_priv.map(change_to_yn))
.bind(p.alter_priv.map(change_to_yn))
.bind(p.index_priv.map(change_to_yn))
.bind(p.create_tmp_table_priv.map(change_to_yn))
.bind(p.lock_tables_priv.map(change_to_yn))
.bind(p.references_priv.map(change_to_yn))
.bind(p.db.to_string())
.bind(p.user.to_string())
.execute(connection)
.await
.map(|_| ())
sqlx::query(format!("UPDATE `db` SET {changes} WHERE `Db` = ? AND `User` = ?").as_str())
.bind(p.select_priv.map(change_to_yn))
.bind(p.insert_priv.map(change_to_yn))
.bind(p.update_priv.map(change_to_yn))
.bind(p.delete_priv.map(change_to_yn))
.bind(p.create_priv.map(change_to_yn))
.bind(p.drop_priv.map(change_to_yn))
.bind(p.alter_priv.map(change_to_yn))
.bind(p.index_priv.map(change_to_yn))
.bind(p.create_tmp_table_priv.map(change_to_yn))
.bind(p.lock_tables_priv.map(change_to_yn))
.bind(p.references_priv.map(change_to_yn))
.bind(p.db.to_string())
.bind(p.user.to_string())
.execute(connection)
.await
.map(|_| ())
}
DatabasePrivilegesDiff::Deleted(p) => {
sqlx::query("DELETE FROM `db` WHERE `Db` = ? AND `User` = ?")
@@ -433,23 +435,37 @@ pub async fn apply_privilege_diffs(
continue;
}
if !unsafe_database_exists(diff.get_database_name(), connection)
.await
.unwrap()
{
results.insert(
key,
Err(ModifyDatabasePrivilegesError::DatabaseDoesNotExist),
);
continue;
match unsafe_database_exists(diff.get_database_name(), connection).await {
Ok(false) => {
results.insert(
key,
Err(ModifyDatabasePrivilegesError::DatabaseDoesNotExist),
);
continue;
}
Err(e) => {
results.insert(
key,
Err(ModifyDatabasePrivilegesError::MySqlError(e.to_string())),
);
continue;
}
Ok(true) => {}
}
if !unsafe_user_exists(diff.get_user_name(), connection)
.await
.unwrap()
{
results.insert(key, Err(ModifyDatabasePrivilegesError::UserDoesNotExist));
continue;
match unsafe_user_exists(diff.get_user_name(), connection).await {
Ok(false) => {
results.insert(key, Err(ModifyDatabasePrivilegesError::UserDoesNotExist));
continue;
}
Err(e) => {
results.insert(
key,
Err(ModifyDatabasePrivilegesError::MySqlError(e.to_string())),
);
continue;
}
Ok(true) => {}
}
if let Err(err) = validate_diff(&diff, connection).await {

View File

@@ -34,13 +34,13 @@ pub(super) async fn unsafe_user_exists(
connection: &mut MySqlConnection,
) -> Result<bool, sqlx::Error> {
let result = sqlx::query(
r#"
r"
SELECT EXISTS(
SELECT 1
FROM `mysql`.`user`
WHERE `User` = ?
)
"#,
",
)
.bind(db_user)
.fetch_one(connection)
@@ -62,15 +62,15 @@ pub async fn complete_user_name(
group_denylist: &GroupDenylist,
) -> Vec<MySQLUser> {
let result = sqlx::query(
r#"
r"
SELECT `User` AS `user`
FROM `mysql`.`user`
WHERE `User` REGEXP ?
AND `User` LIKE ?
"#,
",
)
.bind(create_user_group_matching_regex(unix_user, group_denylist))
.bind(format!("{}%", user_prefix))
.bind(format!("{user_prefix}%"))
.fetch_all(connection)
.await;
@@ -236,12 +236,12 @@ const DATABASE_USER_LOCK_STATUS_QUERY_MARIADB: &str = r#"
AND `Host` = '%'
"#;
const DATABASE_USER_LOCK_STATUS_QUERY_MYSQL: &str = r#"
const DATABASE_USER_LOCK_STATUS_QUERY_MYSQL: &str = r"
SELECT `mysql`.`user`.`account_locked` = 'Y'
FROM `mysql`.`user`
WHERE `User` = ?
AND `Host` = '%'
"#;
";
// NOTE: this function is unsafe because it does no input validation.
async fn database_user_is_locked_unsafe(
@@ -430,14 +430,14 @@ JOIN `global_priv` ON
AND `user`.`Host` = `global_priv`.`Host`
"#;
const DB_USER_SELECT_STATEMENT_MYSQL: &str = r#"
const DB_USER_SELECT_STATEMENT_MYSQL: &str = r"
SELECT
`user`.`User`,
`user`.`Host`,
`user`.`authentication_string` != '' AS `has_password`,
`user`.`account_locked` = 'Y' AS `account_locked`
FROM `user`
"#;
";
pub async fn list_database_users(
db_users: Vec<MySQLUser>,
@@ -472,8 +472,10 @@ pub async fn list_database_users(
tracing::error!("Failed to list database user '{}': {:?}", &db_user, err);
}
if let Ok(Some(user)) = result.as_mut() {
append_databases_where_user_has_privileges(user, &mut *connection).await;
if let Ok(Some(user)) = result.as_mut()
&& let Err(err) = set_databases_where_user_has_privileges(user, &mut *connection).await
{
result = Err(err);
}
match result {
@@ -510,27 +512,33 @@ pub async fn list_all_database_users_for_unix_user(
if let Ok(users) = result.as_mut() {
for user in users {
append_databases_where_user_has_privileges(user, &mut *connection).await;
if let Err(mysql_error) =
set_databases_where_user_has_privileges(user, &mut *connection).await
{
return Err(ListAllUsersError::MySqlError(mysql_error.to_string()));
}
}
}
result
}
pub async fn append_databases_where_user_has_privileges(
/// This function sets the `databases` field of the given `DatabaseUser`
/// where the user has any privileges.
pub async fn set_databases_where_user_has_privileges(
db_user: &mut DatabaseUser,
connection: &mut MySqlConnection,
) {
) -> Result<(), sqlx::Error> {
let database_list = sqlx::query(
formatdoc!(
r#"
r"
SELECT `Db` AS `database`
FROM `db`
WHERE `User` = ? AND ({})
"#,
",
DATABASE_PRIVILEGE_FIELDS
.iter()
.map(|field| format!("`{}` = 'Y'", field))
.map(|field| format!("`{field}` = 'Y'"))
.join(" OR "),
)
.as_str(),
@@ -547,11 +555,11 @@ pub async fn append_databases_where_user_has_privileges(
);
}
db_user.databases = database_list
.map(|rows| {
rows.into_iter()
.map(|row| try_get_with_binary_fallback(&row, "database").unwrap())
.collect()
})
.unwrap_or_default();
db_user.databases = database_list.and_then(|rows| {
rows.into_iter()
.map(|row| try_get_with_binary_fallback(&row, "database"))
.collect::<Result<Vec<String>, sqlx::Error>>()
})?;
Ok(())
}

View File

@@ -71,21 +71,19 @@ impl Supervisor {
let config = ServerConfig::read_config_from_path(&config_path)
.context("Failed to read server configuration")?;
let group_deny_list = match &config.authorization.group_denylist_file {
Some(denylist_path) => {
let denylist = read_and_parse_group_denylist(denylist_path)
.context("Failed to read group denylist file")?;
tracing::debug!(
"Loaded group denylist with {} entries from {:?}",
denylist.len(),
denylist_path
);
Arc::new(RwLock::new(denylist))
}
None => {
tracing::debug!("No group denylist file specified, proceeding without a denylist");
Arc::new(RwLock::new(GroupDenylist::new()))
}
let group_deny_list = if let Some(denylist_path) = &config.authorization.group_denylist_file
{
let denylist = read_and_parse_group_denylist(denylist_path)
.context("Failed to read group denylist file")?;
tracing::debug!(
"Loaded group denylist with {} entries from {:?}",
denylist.len(),
denylist_path
);
Arc::new(RwLock::new(denylist))
} else {
tracing::debug!("No group denylist file specified, proceeding without a denylist");
Arc::new(RwLock::new(GroupDenylist::new()))
};
let mut watchdog_duration = None;
@@ -93,12 +91,13 @@ impl Supervisor {
#[cfg(target_os = "linux")]
let watchdog_task =
if systemd_mode && sd_notify::watchdog_enabled(true, &mut watchdog_micro_seconds) {
watchdog_duration = Some(Duration::from_micros(watchdog_micro_seconds));
let watchdog_duration_ = Duration::from_micros(watchdog_micro_seconds);
tracing::debug!(
"Systemd watchdog enabled with {} millisecond interval",
watchdog_micro_seconds.div_ceil(1000),
);
Some(spawn_watchdog_task(watchdog_duration.unwrap()))
watchdog_duration = Some(watchdog_duration_);
Some(spawn_watchdog_task(watchdog_duration_))
} else {
tracing::debug!("Systemd watchdog not enabled, skipping watchdog thread");
None
@@ -221,22 +220,20 @@ impl Supervisor {
let mut config = self.config.clone().lock_owned().await;
*config = new_config;
let group_deny_list = match &config.authorization.group_denylist_file {
Some(denylist_path) => {
let denylist = read_and_parse_group_denylist(denylist_path)
.context("Failed to read group denylist file")?;
let group_deny_list = if let Some(denylist_path) = &config.authorization.group_denylist_file
{
let denylist = read_and_parse_group_denylist(denylist_path)
.context("Failed to read group denylist file")?;
tracing::debug!(
"Loaded group denylist with {} entries from {:?}",
denylist.len(),
denylist_path
);
denylist
}
None => {
tracing::debug!("No group denylist file specified, proceeding without a denylist");
GroupDenylist::new()
}
tracing::debug!(
"Loaded group denylist with {} entries from {:?}",
denylist.len(),
denylist_path
);
denylist
} else {
tracing::debug!("No group denylist file specified, proceeding without a denylist");
GroupDenylist::new()
};
let mut group_deny_list_lock = self.group_deny_list.write().await;
*group_deny_list_lock = group_deny_list;
@@ -387,7 +384,7 @@ impl Supervisor {
}
}
_ = self.shutdown_cancel_token.cancelled() => {
() = self.shutdown_cancel_token.cancelled() => {
tracing::info!("Shutting down server");
self.shutdown().await?;
break;
@@ -427,7 +424,7 @@ fn spawn_status_notifier_task(task_tracker: TaskTracker) -> JoinHandle<()> {
let count = task_tracker.len();
let message = if count > 0 {
format!("Handling {} connections", count)
format!("Handling {count} connections")
} else {
"Waiting for connections".to_string()
};
@@ -453,7 +450,7 @@ async fn create_unix_listener_with_socket_path(
tracing::info!("Listening on socket {:?}", socket_path);
match fs::remove_file(socket_path.as_path()) {
Ok(_) => {}
Ok(()) => {}
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
Err(e) => return Err(e.into()),
}
@@ -470,7 +467,7 @@ async fn create_unix_listener_with_systemd_socket() -> anyhow::Result<TokioUnixL
.next()
.context("No file descriptors received from systemd")?;
debug_assert!(fd == 3, "Unexpected file descriptor from systemd: {}", fd);
debug_assert!(fd == 3, "Unexpected file descriptor from systemd: {fd}");
tracing::debug!(
"Received file descriptor from systemd with id: '{}', assuming socket",