diff --git a/src/core/bootstrap.rs b/src/core/bootstrap.rs index cabfbac..9b9fee4 100644 --- a/src/core/bootstrap.rs +++ b/src/core/bootstrap.rs @@ -277,8 +277,23 @@ fn run_forked_server( .block_on(async { let socket = TokioUnixStream::from_std(server_socket)?; let db_pool = construct_single_connection_mysql_pool(&config.mysql).await?; + let db_is_mariadb = { + let mut conn = db_pool.acquire().await?; + let version_row: String = sqlx::query_scalar("SELECT VERSION()") + .fetch_one(&mut *conn) + .await + .context("Failed to query MySQL version")?; + version_row.to_lowercase().contains("mariadb") + }; + let db_pool = Arc::new(RwLock::new(db_pool)); - session_handler::session_handler_with_unix_user(socket, &unix_user, db_pool).await?; + session_handler::session_handler_with_unix_user( + socket, + &unix_user, + db_pool, + db_is_mariadb, + ) + .await?; Ok(()) }); diff --git a/src/core/protocol/commands/modify_privileges.rs b/src/core/protocol/commands/modify_privileges.rs index 7d9312a..49d10e5 100644 --- a/src/core/protocol/commands/modify_privileges.rs +++ b/src/core/protocol/commands/modify_privileges.rs @@ -102,9 +102,7 @@ impl ModifyDatabasePrivilegesError { ModifyDatabasePrivilegesError::DatabaseDoesNotExist => { "database-does-not-exist".to_string() } - ModifyDatabasePrivilegesError::UserDoesNotExist => { - "user-does-not-exist".to_string() - } + ModifyDatabasePrivilegesError::UserDoesNotExist => "user-does-not-exist".to_string(), ModifyDatabasePrivilegesError::DiffDoesNotApply(err) => { format!("diff-does-not-apply/{}", err.error_type()) } diff --git a/src/server/session_handler.rs b/src/server/session_handler.rs index ac18497..55a061b 100644 --- a/src/server/session_handler.rs +++ b/src/server/session_handler.rs @@ -38,6 +38,7 @@ use crate::{ pub async fn session_handler( socket: UnixStream, db_pool: Arc>, + db_is_mariadb: bool, ) -> anyhow::Result<()> { let uid = match socket.peer_cred() { Ok(cred) => cred.uid(), @@ -84,7 +85,8 @@ pub async fn session_handler( (async move { tracing::info!("Accepted connection from user: {}", unix_user); - let result = session_handler_with_unix_user(socket, &unix_user, db_pool).await; + let result = + session_handler_with_unix_user(socket, &unix_user, db_pool, db_is_mariadb).await; tracing::info!( "Finished handling requests for connection from user: {}", @@ -101,6 +103,7 @@ pub async fn session_handler_with_unix_user( socket: UnixStream, unix_user: &UnixUser, db_pool: Arc>, + db_is_mariadb: bool, ) -> anyhow::Result<()> { let mut message_stream = create_server_to_client_message_stream(socket); @@ -123,8 +126,13 @@ pub async fn session_handler_with_unix_user( }; tracing::debug!("Successfully acquired database connection from pool"); - let result = - session_handler_with_db_connection(message_stream, unix_user, &mut db_connection).await; + let result = session_handler_with_db_connection( + message_stream, + unix_user, + &mut db_connection, + db_is_mariadb, + ) + .await; tracing::debug!("Releasing database connection back to pool"); @@ -138,6 +146,7 @@ async fn session_handler_with_db_connection( mut stream: ServerToClientMessageStream, unix_user: &UnixUser, db_connection: &mut MySqlConnection, + db_is_mariadb: bool, ) -> anyhow::Result<()> { stream.send(Response::Ready).await?; loop { @@ -180,9 +189,13 @@ async fn session_handler_with_db_connection( { Response::CompleteDatabaseName(vec![]) } else { - let result = - complete_database_name(partial_database_name, unix_user, db_connection) - .await; + let result = complete_database_name( + partial_database_name, + unix_user, + db_connection, + db_is_mariadb, + ) + .await; Response::CompleteDatabaseName(result) } } @@ -194,39 +207,54 @@ async fn session_handler_with_db_connection( { Response::CompleteUserName(vec![]) } else { - let result = - complete_user_name(partial_user_name, unix_user, db_connection).await; + let result = complete_user_name( + partial_user_name, + unix_user, + db_connection, + db_is_mariadb, + ) + .await; Response::CompleteUserName(result) } } Request::CreateDatabases(databases_names) => { - let result = create_databases(databases_names, unix_user, db_connection).await; + let result = + create_databases(databases_names, unix_user, db_connection, db_is_mariadb) + .await; Response::CreateDatabases(result) } Request::DropDatabases(databases_names) => { - let result = drop_databases(databases_names, unix_user, db_connection).await; + let result = + drop_databases(databases_names, unix_user, db_connection, db_is_mariadb).await; Response::DropDatabases(result) } Request::ListDatabases(database_names) => match database_names { Some(database_names) => { - let result = list_databases(database_names, unix_user, db_connection).await; + let result = + list_databases(database_names, unix_user, db_connection, db_is_mariadb) + .await; Response::ListDatabases(result) } None => { - let result = list_all_databases_for_user(unix_user, db_connection).await; + let result = + list_all_databases_for_user(unix_user, db_connection, db_is_mariadb).await; Response::ListAllDatabases(result) } }, Request::ListPrivileges(database_names) => match database_names { Some(database_names) => { - let privilege_data = - get_databases_privilege_data(database_names, unix_user, db_connection) - .await; + let privilege_data = get_databases_privilege_data( + database_names, + unix_user, + db_connection, + db_is_mariadb, + ) + .await; Response::ListPrivileges(privilege_data) } None => { let privilege_data = - get_all_database_privileges(unix_user, db_connection).await; + get_all_database_privileges(unix_user, db_connection, db_is_mariadb).await; Response::ListAllPrivileges(privilege_data) } }, @@ -235,41 +263,57 @@ async fn session_handler_with_db_connection( BTreeSet::from_iter(database_privilege_diffs), unix_user, db_connection, + db_is_mariadb, ) .await; Response::ModifyPrivileges(result) } Request::CreateUsers(db_users) => { - let result = create_database_users(db_users, unix_user, db_connection).await; + let result = + create_database_users(db_users, unix_user, db_connection, db_is_mariadb).await; Response::CreateUsers(result) } Request::DropUsers(db_users) => { - let result = drop_database_users(db_users, unix_user, db_connection).await; + let result = + drop_database_users(db_users, unix_user, db_connection, db_is_mariadb).await; Response::DropUsers(result) } Request::PasswdUser((db_user, password)) => { - let result = - set_password_for_database_user(&db_user, &password, unix_user, db_connection) - .await; + let result = set_password_for_database_user( + &db_user, + &password, + unix_user, + db_connection, + db_is_mariadb, + ) + .await; Response::SetUserPassword(result) } Request::ListUsers(db_users) => match db_users { Some(db_users) => { - let result = list_database_users(db_users, unix_user, db_connection).await; + let result = + list_database_users(db_users, unix_user, db_connection, db_is_mariadb) + .await; Response::ListUsers(result) } None => { - let result = - list_all_database_users_for_unix_user(unix_user, db_connection).await; + let result = list_all_database_users_for_unix_user( + unix_user, + db_connection, + db_is_mariadb, + ) + .await; Response::ListAllUsers(result) } }, Request::LockUsers(db_users) => { - let result = lock_database_users(db_users, unix_user, db_connection).await; + let result = + lock_database_users(db_users, unix_user, db_connection, db_is_mariadb).await; Response::LockUsers(result) } Request::UnlockUsers(db_users) => { - let result = unlock_database_users(db_users, unix_user, db_connection).await; + let result = + unlock_database_users(db_users, unix_user, db_connection, db_is_mariadb).await; Response::UnlockUsers(result) } Request::Exit => { diff --git a/src/server/sql/database_operations.rs b/src/server/sql/database_operations.rs index d5fe151..129b6f5 100644 --- a/src/server/sql/database_operations.rs +++ b/src/server/sql/database_operations.rs @@ -48,6 +48,7 @@ pub async fn complete_database_name( database_prefix: String, unix_user: &UnixUser, connection: &mut MySqlConnection, + _db_is_mariadb: bool, ) -> CompleteDatabaseNameResponse { let result = sqlx::query( r#" @@ -87,6 +88,7 @@ pub async fn create_databases( database_names: Vec, unix_user: &UnixUser, connection: &mut MySqlConnection, + _db_is_mariadb: bool, ) -> CreateDatabasesResponse { let mut results = BTreeMap::new(); @@ -146,6 +148,7 @@ pub async fn drop_databases( database_names: Vec, unix_user: &UnixUser, connection: &mut MySqlConnection, + _db_is_mariadb: bool, ) -> DropDatabasesResponse { let mut results = BTreeMap::new(); @@ -218,6 +221,7 @@ pub async fn list_databases( database_names: Vec, unix_user: &UnixUser, connection: &mut MySqlConnection, + _db_is_mariadb: bool, ) -> ListDatabasesResponse { let mut results = BTreeMap::new(); @@ -268,6 +272,7 @@ pub async fn list_databases( pub async fn list_all_databases_for_user( unix_user: &UnixUser, connection: &mut MySqlConnection, + _db_is_mariadb: bool, ) -> ListAllDatabasesResponse { let result = sqlx::query_as::<_, DatabaseRow>( r#" diff --git a/src/server/sql/database_privilege_operations.rs b/src/server/sql/database_privilege_operations.rs index cc5b444..c17833e 100644 --- a/src/server/sql/database_privilege_operations.rs +++ b/src/server/sql/database_privilege_operations.rs @@ -140,6 +140,7 @@ pub async fn get_databases_privilege_data( database_names: Vec, unix_user: &UnixUser, connection: &mut MySqlConnection, + _db_is_mariadb: bool, ) -> ListPrivilegesResponse { let mut results = BTreeMap::new(); @@ -187,6 +188,7 @@ pub async fn get_databases_privilege_data( pub async fn get_all_database_privileges( unix_user: &UnixUser, connection: &mut MySqlConnection, + _db_is_mariadb: bool, ) -> ListAllPrivilegesResponse { let result = sqlx::query_as::<_, DatabasePrivilegeRow>(&format!( indoc! {r#" @@ -394,6 +396,7 @@ pub async fn apply_privilege_diffs( database_privilege_diffs: BTreeSet, unix_user: &UnixUser, connection: &mut MySqlConnection, + _db_is_mariadb: bool, ) -> ModifyPrivilegesResponse { let mut results: BTreeMap<(MySQLDatabase, MySQLUser), _> = BTreeMap::new(); @@ -451,10 +454,7 @@ pub async fn apply_privilege_diffs( .await .unwrap() { - results.insert( - key, - Err(ModifyDatabasePrivilegesError::UserDoesNotExist), - ); + results.insert(key, Err(ModifyDatabasePrivilegesError::UserDoesNotExist)); continue; } diff --git a/src/server/sql/user_operations.rs b/src/server/sql/user_operations.rs index 584f96a..afe728b 100644 --- a/src/server/sql/user_operations.rs +++ b/src/server/sql/user_operations.rs @@ -55,6 +55,7 @@ pub async fn complete_user_name( user_prefix: String, unix_user: &UnixUser, connection: &mut MySqlConnection, + _db_is_mariadb: bool, ) -> Vec { let result = sqlx::query( r#" @@ -93,6 +94,7 @@ pub async fn create_database_users( db_users: Vec, unix_user: &UnixUser, connection: &mut MySqlConnection, + _db_is_mariadb: bool, ) -> CreateUsersResponse { let mut results = BTreeMap::new(); @@ -139,6 +141,7 @@ pub async fn drop_database_users( db_users: Vec, unix_user: &UnixUser, connection: &mut MySqlConnection, + _db_is_mariadb: bool, ) -> DropUsersResponse { let mut results = BTreeMap::new(); @@ -186,6 +189,7 @@ pub async fn set_password_for_database_user( password: &str, unix_user: &UnixUser, connection: &mut MySqlConnection, + _db_is_mariadb: bool, ) -> SetUserPasswordResponse { if let Err(err) = validate_name(db_user) { return Err(SetPasswordError::SanitizationError(err)); @@ -224,26 +228,39 @@ pub async fn set_password_for_database_user( result } +const DATABASE_USER_LOCK_STATUS_QUERY_MARIADB: &str = r#" + SELECT COALESCE( + JSON_EXTRACT(`mysql`.`global_priv`.`priv`, "$.account_locked"), + 'false' + ) != 'false' + FROM `mysql`.`global_priv` + WHERE `User` = ? + AND `Host` = '%' +"#; + +const DATABASE_USER_LOCK_STATUS_QUERY_MYSQL: &str = r#" + SELECT `mysql`.`user`.`account_locked` = 'Y' + FROM `mysql`.`user` + WHERE `User` = ? + AND `Host` = '%' +"#; + // NOTE: this function is unsafe because it does no input validation. async fn database_user_is_locked_unsafe( db_user: &str, connection: &mut MySqlConnection, + db_is_mariadb: bool, ) -> Result { - let result = sqlx::query( - r#" - SELECT COALESCE( - JSON_EXTRACT(`mysql`.`global_priv`.`priv`, "$.account_locked"), - 'false' - ) != 'false' - FROM `mysql`.`global_priv` - WHERE `User` = ? - AND `Host` = '%' - "#, - ) + let result = sqlx::query(if db_is_mariadb { + DATABASE_USER_LOCK_STATUS_QUERY_MARIADB + } else { + DATABASE_USER_LOCK_STATUS_QUERY_MYSQL + }) .bind(db_user) .fetch_one(connection) .await - .map(|row| row.get::(0)); + .map(|row| row.try_get(0)) + .and_then(|res| res); if let Err(err) = &result { tracing::error!( @@ -260,6 +277,7 @@ pub async fn lock_database_users( db_users: Vec, unix_user: &UnixUser, connection: &mut MySqlConnection, + db_is_mariadb: bool, ) -> LockUsersResponse { let mut results = BTreeMap::new(); @@ -286,7 +304,7 @@ pub async fn lock_database_users( } } - match database_user_is_locked_unsafe(&db_user, &mut *connection).await { + match database_user_is_locked_unsafe(&db_user, &mut *connection, db_is_mariadb).await { Ok(false) => {} Ok(true) => { results.insert(db_user, Err(LockUserError::UserIsAlreadyLocked)); @@ -320,6 +338,7 @@ pub async fn unlock_database_users( db_users: Vec, unix_user: &UnixUser, connection: &mut MySqlConnection, + db_is_mariadb: bool, ) -> UnlockUsersResponse { let mut results = BTreeMap::new(); @@ -346,7 +365,7 @@ pub async fn unlock_database_users( _ => {} } - match database_user_is_locked_unsafe(&db_user, &mut *connection).await { + match database_user_is_locked_unsafe(&db_user, &mut *connection, db_is_mariadb).await { Ok(false) => { results.insert(db_user, Err(UnlockUserError::UserIsAlreadyUnlocked)); continue; @@ -394,13 +413,13 @@ impl FromRow<'_, sqlx::mysql::MySqlRow> for DatabaseUser { user: try_get_with_binary_fallback(row, "User")?.into(), host: try_get_with_binary_fallback(row, "Host")?, has_password: row.try_get("has_password")?, - is_locked: row.try_get("is_locked")?, + is_locked: row.try_get("account_locked")?, databases: Vec::new(), }) } } -const DB_USER_SELECT_STATEMENT: &str = r#" +const DB_USER_SELECT_STATEMENT_MARIADB: &str = r#" SELECT `user`.`User`, `user`.`Host`, @@ -408,17 +427,27 @@ SELECT COALESCE( JSON_EXTRACT(`global_priv`.`priv`, "$.account_locked"), 'false' - ) != 'false' AS `is_locked` + ) != 'false' AS `account_locked` FROM `user` JOIN `global_priv` ON `user`.`User` = `global_priv`.`User` AND `user`.`Host` = `global_priv`.`Host` "#; +const DB_USER_SELECT_STATEMENT_MYSQL: &str = r#" +SELECT + `user`.`User`, + `user`.`Host`, + `user`.`authentication_string` != '' AS `has_password`, + `user`.`account_locked` = 'Y' AS `account_locked` +FROM `user` +"#; + pub async fn list_database_users( db_users: Vec, unix_user: &UnixUser, connection: &mut MySqlConnection, + db_is_mariadb: bool, ) -> ListUsersResponse { let mut results = BTreeMap::new(); @@ -434,7 +463,11 @@ pub async fn list_database_users( } let mut result = sqlx::query_as::<_, DatabaseUser>( - &(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `mysql`.`user`.`User` = ?"), + &(if db_is_mariadb { + DB_USER_SELECT_STATEMENT_MARIADB.to_string() + } else { + DB_USER_SELECT_STATEMENT_MYSQL.to_string() + } + "WHERE `mysql`.`user`.`User` = ?"), ) .bind(db_user.as_str()) .fetch_optional(&mut *connection) @@ -461,9 +494,14 @@ pub async fn list_database_users( pub async fn list_all_database_users_for_unix_user( unix_user: &UnixUser, connection: &mut MySqlConnection, + db_is_mariadb: bool, ) -> ListAllUsersResponse { let mut result = sqlx::query_as::<_, DatabaseUser>( - &(DB_USER_SELECT_STATEMENT.to_string() + "WHERE `user`.`User` REGEXP ?"), + &(if db_is_mariadb { + DB_USER_SELECT_STATEMENT_MARIADB.to_string() + } else { + DB_USER_SELECT_STATEMENT_MYSQL.to_string() + } + "WHERE `user`.`User` REGEXP ?"), ) .bind(create_user_group_matching_regex(unix_user)) .fetch_all(&mut *connection) diff --git a/src/server/supervisor.rs b/src/server/supervisor.rs index 97d8760..7aac28b 100644 --- a/src/server/supervisor.rs +++ b/src/server/supervisor.rs @@ -43,6 +43,7 @@ pub struct Supervisor { signal_handler_task: JoinHandle<()>, db_connection_pool: Arc>, + db_is_mariadb: Arc>, listener: Arc>, listener_task: JoinHandle>, handler_task_tracker: TaskTracker, @@ -83,6 +84,22 @@ impl Supervisor { let db_connection_pool = Arc::new(RwLock::new(create_db_connection_pool(&config.mysql).await?)); + let db_is_mariadb = { + let connection = db_connection_pool.read().await; + let version: String = sqlx::query_scalar("SELECT VERSION()") + .fetch_one(&*connection) + .await + .context("Failed to query database version")?; + + let result = version.to_lowercase().contains("mariadb"); + tracing::debug!( + "Connected to {} database server", + if result { "MariaDB" } else { "MySQL" } + ); + + Arc::new(RwLock::new(result)) + }; + let task_tracker = TaskTracker::new(); let status_notifier_task = if systemd_mode { @@ -112,6 +129,7 @@ impl Supervisor { task_tracker_clone, db_connection_pool.clone(), rx, + db_is_mariadb.clone(), )) }; @@ -123,6 +141,7 @@ impl Supervisor { shutdown_cancel_token, signal_handler_task, db_connection_pool, + db_is_mariadb, listener, listener_task, handler_task_tracker: task_tracker, @@ -165,8 +184,26 @@ impl Supervisor { async fn restart_db_connection_pool(&self) -> anyhow::Result<()> { let config = self.config.lock().await; let mut connection_pool = self.db_connection_pool.clone().write_owned().await; + let mut db_is_mariadb_lock = self.db_is_mariadb.write().await; + let new_db_pool = create_db_connection_pool(&config.mysql).await?; + let db_is_mariadb = { + let version: String = sqlx::query_scalar("SELECT VERSION()") + .fetch_one(&new_db_pool) + .await + .context("Failed to query database version")?; + + let result = version.to_lowercase().contains("mariadb"); + tracing::debug!( + "Connected to {} database server", + if result { "MariaDB" } else { "MySQL" } + ); + + result + }; + *connection_pool = new_db_pool; + *db_is_mariadb_lock = db_is_mariadb; Ok(()) } @@ -429,6 +466,7 @@ async fn listener_task( task_tracker: TaskTracker, db_pool: Arc>, mut supervisor_message_receiver: broadcast::Receiver, + db_is_mariadb: Arc>, ) -> anyhow::Result<()> { sd_notify::notify(false, &[sd_notify::NotifyState::Ready])?; @@ -464,8 +502,9 @@ async fn listener_task( tracing::debug!("Got new connection"); let db_pool_clone = db_pool.clone(); - task_tracker.spawn(async { - match session_handler(conn, db_pool_clone).await { + let db_is_mariadb_clone = *db_is_mariadb.read().await; + task_tracker.spawn(async move { + match session_handler(conn, db_pool_clone, db_is_mariadb_clone).await { Ok(()) => {} Err(e) => { tracing::error!("Failed to run server: {}", e);