From 01d502337dcbf6981ade446da15ff3ea3ab2a477 Mon Sep 17 00:00:00 2001
From: h7x4 <h7x4@nani.wtf>
Date: Wed, 7 Aug 2024 16:16:46 +0200
Subject: [PATCH] Don't fail on erroneus db connection closure

---
 src/cli/database_command.rs                   |  6 +++---
 .../mysql_useradm.rs                          | 20 ++++++++++---------
 src/cli/user_command.rs                       |  6 +++---
 src/core/common.rs                            |  8 ++++++++
 4 files changed, 25 insertions(+), 15 deletions(-)

diff --git a/src/cli/database_command.rs b/src/cli/database_command.rs
index 23c9fc3..0eeb58b 100644
--- a/src/cli/database_command.rs
+++ b/src/cli/database_command.rs
@@ -3,11 +3,11 @@ use clap::Parser;
 use indoc::indoc;
 use itertools::Itertools;
 use prettytable::{Cell, Row, Table};
-use sqlx::{Connection, MySqlConnection};
+use sqlx::MySqlConnection;
 
 use crate::core::{
     self,
-    common::get_current_unix_user,
+    common::{close_database_connection, get_current_unix_user},
     database_operations::{
         apply_permission_diffs, db_priv_field_human_readable_name, diff_permissions, yn,
         DatabasePrivileges, DATABASE_PRIVILEGE_FIELDS,
@@ -155,7 +155,7 @@ pub async fn handle_command(
         DatabaseCommand::EditDbPerm(args) => edit_permissions(args, &mut conn).await,
     };
 
-    conn.close().await?;
+    close_database_connection(conn).await;
 
     result
 }
diff --git a/src/cli/mysql_admutils_compatibility/mysql_useradm.rs b/src/cli/mysql_admutils_compatibility/mysql_useradm.rs
index 5e8b3e0..f8893c6 100644
--- a/src/cli/mysql_admutils_compatibility/mysql_useradm.rs
+++ b/src/cli/mysql_admutils_compatibility/mysql_useradm.rs
@@ -7,7 +7,7 @@ use crate::{
         user_command,
     },
     core::{
-        common::get_current_unix_user,
+        common::{close_database_connection, get_current_unix_user},
         config::{get_config, mysql_connection_from_config, GlobalConfigArgs},
         user_operations::{
             create_database_user, delete_database_user, get_all_database_users_for_unix_user,
@@ -107,14 +107,16 @@ pub async fn main() -> anyhow::Result<()> {
                 delete_database_user(&name, &mut connection).await?;
             }
         }
-        Command::Passwd(args) => passwd(args, connection).await?,
-        Command::Show(args) => show(args, connection).await?,
+        Command::Passwd(args) => passwd(args, &mut connection).await?,
+        Command::Show(args) => show(args, &mut connection).await?,
     }
 
+    close_database_connection(connection).await;
+
     Ok(())
 }
 
-async fn passwd(args: PasswdArgs, mut connection: MySqlConnection) -> anyhow::Result<()> {
+async fn passwd(args: PasswdArgs, connection: &mut MySqlConnection) -> anyhow::Result<()> {
     let filtered_names = filter_db_or_user_names(args.name, DbOrUser::User)?;
 
     // NOTE: this gets doubly checked during the call to `set_password_for_database_user`.
@@ -123,7 +125,7 @@ async fn passwd(args: PasswdArgs, mut connection: MySqlConnection) -> anyhow::Re
     //       have entered the password twice.
     let mut better_filtered_names = Vec::with_capacity(filtered_names.len());
     for name in filtered_names.into_iter() {
-        if !user_exists(&name, &mut connection).await? {
+        if !user_exists(&name, connection).await? {
             println!(
                 "{}: User '{}' does not exist. You must create it first.",
                 std::env::args()
@@ -138,17 +140,17 @@ async fn passwd(args: PasswdArgs, mut connection: MySqlConnection) -> anyhow::Re
 
     for name in better_filtered_names {
         let password = user_command::read_password_from_stdin_with_double_check(&name)?;
-        set_password_for_database_user(&name, &password, &mut connection).await?;
+        set_password_for_database_user(&name, &password, connection).await?;
         println!("Password updated for user '{}'.", name);
     }
 
     Ok(())
 }
 
-async fn show(args: ShowArgs, mut connection: MySqlConnection) -> anyhow::Result<()> {
+async fn show(args: ShowArgs, connection: &mut MySqlConnection) -> anyhow::Result<()> {
     let users = if args.name.is_empty() {
         let unix_user = get_current_unix_user()?;
-        get_all_database_users_for_unix_user(&unix_user, &mut connection)
+        get_all_database_users_for_unix_user(&unix_user, connection)
             .await?
             .into_iter()
             .map(|u| u.user)
@@ -158,7 +160,7 @@ async fn show(args: ShowArgs, mut connection: MySqlConnection) -> anyhow::Result
     };
 
     for user in users {
-        let password_is_set = password_is_set_for_database_user(&user, &mut connection).await?;
+        let password_is_set = password_is_set_for_database_user(&user, connection).await?;
 
         match password_is_set {
             Some(true) => println!("User '{}': password set.", user),
diff --git a/src/cli/user_command.rs b/src/cli/user_command.rs
index e9778c2..0e395d1 100644
--- a/src/cli/user_command.rs
+++ b/src/cli/user_command.rs
@@ -2,9 +2,9 @@ use std::vec;
 
 use anyhow::Context;
 use clap::Parser;
-use sqlx::{Connection, MySqlConnection};
+use sqlx::MySqlConnection;
 
-use crate::core::user_operations::validate_user_name;
+use crate::core::{common::close_database_connection, user_operations::validate_user_name};
 
 #[derive(Parser)]
 pub struct UserArgs {
@@ -67,7 +67,7 @@ pub async fn handle_command(command: UserCommand, mut conn: MySqlConnection) ->
         UserCommand::ShowUser(args) => show_users(args, &mut conn).await,
     };
 
-    conn.close().await?;
+    close_database_connection(conn).await;
 
     result
 }
diff --git a/src/core/common.rs b/src/core/common.rs
index c641b14..881cb6d 100644
--- a/src/core/common.rs
+++ b/src/core/common.rs
@@ -2,6 +2,7 @@ use anyhow::Context;
 use indoc::indoc;
 use itertools::Itertools;
 use nix::unistd::{getuid, Group, User};
+use sqlx::{Connection, MySqlConnection};
 
 #[cfg(not(target_os = "macos"))]
 use std::ffi::CString;
@@ -140,6 +141,13 @@ pub fn validate_ownership_by_user_prefix<'a>(
     Ok(prefix)
 }
 
+pub async fn close_database_connection(conn: MySqlConnection) {
+    if let Err(e) = conn.close().await.context("Failed to close connection properly") {
+        eprintln!("{}", e);
+        eprintln!("Ignoring...");
+    }
+}
+
 pub fn quote_literal(s: &str) -> String {
     format!("'{}'", s.replace('\'', r"\'"))
 }