server: make use of database connection pool
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
use std::{fs, path::PathBuf};
|
||||
use std::{fs, path::PathBuf, time::Duration};
|
||||
|
||||
use anyhow::Context;
|
||||
use anyhow::{Context, anyhow};
|
||||
use clap_verbosity_flag::Verbosity;
|
||||
use nix::libc::{EXIT_SUCCESS, exit};
|
||||
use sqlx::mysql::MySqlPoolOptions;
|
||||
use std::os::unix::net::UnixStream as StdUnixStream;
|
||||
use tokio::net::UnixStream as TokioUnixStream;
|
||||
|
||||
@@ -10,7 +11,10 @@ use crate::{
|
||||
core::common::{
|
||||
DEFAULT_CONFIG_PATH, DEFAULT_SOCKET_PATH, UnixUser, executable_is_suid_or_sgid,
|
||||
},
|
||||
server::{config::read_config_from_path, session_handler},
|
||||
server::{
|
||||
config::{MysqlConfig, read_config_from_path},
|
||||
session_handler,
|
||||
},
|
||||
};
|
||||
|
||||
/// Determine whether we will make a connection to an external server
|
||||
@@ -208,6 +212,31 @@ fn invoke_server_with_config(config_path: PathBuf) -> anyhow::Result<StdUnixStre
|
||||
}
|
||||
}
|
||||
|
||||
async fn construct_single_connection_mysql_pool(
|
||||
config: &MysqlConfig,
|
||||
) -> anyhow::Result<sqlx::MySqlPool> {
|
||||
let mysql_config = config.as_mysql_connect_options()?;
|
||||
|
||||
let pool_opts = MySqlPoolOptions::new()
|
||||
.max_connections(1)
|
||||
.min_connections(1);
|
||||
|
||||
config.log_connection_notice();
|
||||
|
||||
let pool = match tokio::time::timeout(
|
||||
Duration::from_secs(config.timeout),
|
||||
pool_opts.connect_with(mysql_config),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(connection) => connection.context("Failed to connect to the database"),
|
||||
Err(_) => Err(anyhow!("Timed out after {} seconds", config.timeout))
|
||||
.context("Failed to connect to the database"),
|
||||
}?;
|
||||
|
||||
Ok(pool)
|
||||
}
|
||||
|
||||
/// Run the server in the forked child process.
|
||||
/// This function will not return, but will exit the process with a success code.
|
||||
fn run_forked_server(
|
||||
@@ -223,7 +252,8 @@ fn run_forked_server(
|
||||
.unwrap()
|
||||
.block_on(async {
|
||||
let socket = TokioUnixStream::from_std(server_socket)?;
|
||||
session_handler::session_handler(socket, &unix_user, &config).await?;
|
||||
let db_pool = construct_single_connection_mysql_pool(&config.mysql).await?;
|
||||
session_handler::session_handler(socket, &unix_user, db_pool).await?;
|
||||
Ok(())
|
||||
});
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
use std::{fs, path::PathBuf, time::Duration};
|
||||
use std::{fs, path::PathBuf};
|
||||
|
||||
use anyhow::{Context, anyhow};
|
||||
use anyhow::Context;
|
||||
use clap::Parser;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::{ConnectOptions, MySqlConnection, mysql::MySqlConnectOptions};
|
||||
use sqlx::{ConnectOptions, mysql::MySqlConnectOptions};
|
||||
|
||||
use crate::core::common::DEFAULT_CONFIG_PATH;
|
||||
|
||||
@@ -37,6 +37,45 @@ pub struct MysqlConfig {
|
||||
pub timeout: u64,
|
||||
}
|
||||
|
||||
impl MysqlConfig {
|
||||
pub fn as_mysql_connect_options(&self) -> anyhow::Result<MySqlConnectOptions> {
|
||||
let mut options = MySqlConnectOptions::new()
|
||||
.database("mysql")
|
||||
.log_statements(log::LevelFilter::Trace);
|
||||
|
||||
if let Some(username) = &self.username {
|
||||
options = options.username(username);
|
||||
}
|
||||
|
||||
if let Some(password) = &self.password {
|
||||
options = options.password(password);
|
||||
}
|
||||
|
||||
if let Some(socket_path) = &self.socket_path {
|
||||
options = options.socket(socket_path);
|
||||
} else if let Some(host) = &self.host {
|
||||
options = options.host(host);
|
||||
options = options.port(self.port);
|
||||
} else {
|
||||
anyhow::bail!("No MySQL host or socket path provided");
|
||||
}
|
||||
|
||||
Ok(options)
|
||||
}
|
||||
|
||||
pub fn log_connection_notice(&self) {
|
||||
let mut display_config = self.to_owned();
|
||||
display_config.password = display_config
|
||||
.password
|
||||
.as_ref()
|
||||
.map(|_| "<REDACTED>".to_owned());
|
||||
log::debug!(
|
||||
"Connecting to MySQL server with parameters: {:#?}",
|
||||
display_config
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug, Clone)]
|
||||
pub struct ServerConfigArgs {
|
||||
/// Path where the server socket should be created.
|
||||
@@ -128,50 +167,3 @@ pub fn read_config_from_path(config_path: Option<PathBuf>) -> anyhow::Result<Ser
|
||||
.and_then(|c| toml::from_str(&c).context("Failed to parse config file"))
|
||||
.context(format!("Failed to parse config file at {:?}", &config_path))
|
||||
}
|
||||
|
||||
fn log_config(config: &MysqlConfig) {
|
||||
let mut display_config = config.to_owned();
|
||||
display_config.password = display_config
|
||||
.password
|
||||
.as_ref()
|
||||
.map(|_| "<REDACTED>".to_owned());
|
||||
log::debug!(
|
||||
"Connecting to MySQL server with parameters: {:#?}",
|
||||
display_config
|
||||
);
|
||||
}
|
||||
|
||||
/// Use the provided configuration to establish a connection to a MySQL server.
|
||||
pub async fn create_mysql_connection_from_config(
|
||||
config: &MysqlConfig,
|
||||
) -> anyhow::Result<MySqlConnection> {
|
||||
log_config(config);
|
||||
|
||||
let mut mysql_options = MySqlConnectOptions::new()
|
||||
.database("mysql")
|
||||
.log_statements(log::LevelFilter::Trace);
|
||||
|
||||
if let Some(username) = &config.username {
|
||||
mysql_options = mysql_options.username(username);
|
||||
}
|
||||
|
||||
if let Some(password) = &config.password {
|
||||
mysql_options = mysql_options.password(password);
|
||||
}
|
||||
|
||||
if let Some(socket_path) = &config.socket_path {
|
||||
mysql_options = mysql_options.socket(socket_path);
|
||||
} else if let Some(host) = &config.host {
|
||||
mysql_options = mysql_options.host(host);
|
||||
mysql_options = mysql_options.port(config.port);
|
||||
} else {
|
||||
anyhow::bail!("No MySQL host or socket path provided");
|
||||
}
|
||||
|
||||
match tokio::time::timeout(Duration::from_secs(config.timeout), mysql_options.connect()).await {
|
||||
Ok(connection) => connection.context("Failed to connect to the database"),
|
||||
Err(_) => {
|
||||
Err(anyhow!("Timed out after 2 seconds")).context("Failed to connect to the database")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,32 +2,28 @@ use std::collections::BTreeSet;
|
||||
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use indoc::concatdoc;
|
||||
use sqlx::{MySql, MySqlConnection, MySqlPool, pool::PoolConnection};
|
||||
use tokio::net::UnixStream;
|
||||
|
||||
use sqlx::MySqlConnection;
|
||||
use sqlx::prelude::*;
|
||||
|
||||
use crate::core::protocol::SetPasswordError;
|
||||
use crate::server::sql::database_operations::list_databases;
|
||||
use crate::{
|
||||
core::{
|
||||
common::UnixUser,
|
||||
protocol::{
|
||||
Request, Response, ServerToClientMessageStream, create_server_to_client_message_stream,
|
||||
Request, Response, ServerToClientMessageStream, SetPasswordError,
|
||||
create_server_to_client_message_stream,
|
||||
},
|
||||
},
|
||||
server::{
|
||||
config::{ServerConfig, create_mysql_connection_from_config},
|
||||
sql::{
|
||||
database_operations::{create_databases, drop_databases, list_all_databases_for_user},
|
||||
database_privilege_operations::{
|
||||
apply_privilege_diffs, get_all_database_privileges, get_databases_privilege_data,
|
||||
},
|
||||
user_operations::{
|
||||
create_database_users, drop_database_users, list_all_database_users_for_unix_user,
|
||||
list_database_users, lock_database_users, set_password_for_database_user,
|
||||
unlock_database_users,
|
||||
},
|
||||
server::sql::{
|
||||
database_operations::{
|
||||
create_databases, drop_databases, list_all_databases_for_user, list_databases,
|
||||
},
|
||||
database_privilege_operations::{
|
||||
apply_privilege_diffs, get_all_database_privileges, get_databases_privilege_data,
|
||||
},
|
||||
user_operations::{
|
||||
create_database_users, drop_database_users, list_all_database_users_for_unix_user,
|
||||
list_database_users, lock_database_users, set_password_for_database_user,
|
||||
unlock_database_users,
|
||||
},
|
||||
},
|
||||
};
|
||||
@@ -37,13 +33,13 @@ use crate::{
|
||||
pub async fn session_handler(
|
||||
socket: UnixStream,
|
||||
unix_user: &UnixUser,
|
||||
config: &ServerConfig,
|
||||
db_pool: MySqlPool,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut message_stream = create_server_to_client_message_stream(socket);
|
||||
|
||||
log::debug!("Opening connection to database");
|
||||
|
||||
let mut db_connection = match create_mysql_connection_from_config(&config.mysql).await {
|
||||
let mut db_connection = match db_pool.acquire().await {
|
||||
Ok(connection) => connection,
|
||||
Err(err) => {
|
||||
message_stream
|
||||
@@ -56,28 +52,10 @@ pub async fn session_handler(
|
||||
))
|
||||
.await?;
|
||||
message_stream.flush().await?;
|
||||
return Err(err);
|
||||
return Err(err.into());
|
||||
}
|
||||
};
|
||||
|
||||
log::debug!("Verifying that database connection is valid");
|
||||
|
||||
if let Err(e) = db_connection.ping().await {
|
||||
log::error!("Failed to ping database: {}", e);
|
||||
message_stream
|
||||
.send(Response::Error(
|
||||
(concatdoc! {
|
||||
"Server failed to connect to database\n",
|
||||
"Please check the server logs or contact the system administrators"
|
||||
})
|
||||
.to_string(),
|
||||
))
|
||||
.await?;
|
||||
message_stream.flush().await?;
|
||||
close_or_ignore_db_connection(db_connection).await;
|
||||
return Err(e.into());
|
||||
}
|
||||
|
||||
log::debug!("Successfully connected to database");
|
||||
|
||||
let result =
|
||||
@@ -215,7 +193,7 @@ async fn session_handler_with_db_connection(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn close_or_ignore_db_connection(db_connection: MySqlConnection) {
|
||||
async fn close_or_ignore_db_connection(db_connection: PoolConnection<MySql>) {
|
||||
if let Err(e) = db_connection.close().await {
|
||||
log::error!("Failed to close database connection: {}", e);
|
||||
log::error!("{}", e);
|
||||
|
||||
@@ -9,7 +9,7 @@ use std::{
|
||||
use anyhow::{Context, anyhow};
|
||||
use futures_util::SinkExt;
|
||||
use indoc::concatdoc;
|
||||
use sqlx::{MySqlPool, mysql::MySqlConnectOptions, prelude::*};
|
||||
use sqlx::MySqlPool;
|
||||
use tokio::{net::UnixListener as TokioUnixListener, task::JoinHandle, time::interval};
|
||||
use tokio_util::task::TaskTracker;
|
||||
// use tokio_util::sync::CancellationToken;
|
||||
@@ -80,11 +80,10 @@ impl Supervisor {
|
||||
|
||||
let listener_task = {
|
||||
let connection_counter = connection_counter.clone();
|
||||
let config_clone = config.clone();
|
||||
tokio::spawn(spawn_listener_task(
|
||||
listener,
|
||||
config_clone,
|
||||
connection_counter,
|
||||
db_connection_pool.clone(),
|
||||
))
|
||||
};
|
||||
|
||||
@@ -192,30 +191,13 @@ async fn create_unix_listener_with_systemd_socket() -> anyhow::Result<TokioUnixL
|
||||
}
|
||||
|
||||
async fn create_db_connection_pool(config: &MysqlConfig) -> anyhow::Result<MySqlPool> {
|
||||
let mut mysql_options = MySqlConnectOptions::new()
|
||||
.database("mysql")
|
||||
.log_statements(log::LevelFilter::Trace);
|
||||
let mysql_config = config.as_mysql_connect_options()?;
|
||||
|
||||
if let Some(username) = config.username.as_ref() {
|
||||
mysql_options = mysql_options.username(username);
|
||||
}
|
||||
|
||||
if let Some(password) = config.password.as_ref() {
|
||||
mysql_options = mysql_options.password(password);
|
||||
}
|
||||
|
||||
if let Some(socket_path) = config.socket_path.as_ref() {
|
||||
mysql_options = mysql_options.socket(socket_path);
|
||||
} else if let Some(host) = config.host.as_ref() {
|
||||
mysql_options = mysql_options.host(host);
|
||||
mysql_options = mysql_options.port(config.port);
|
||||
} else {
|
||||
anyhow::bail!("No MySQL host or socket path provided");
|
||||
}
|
||||
config.log_connection_notice();
|
||||
|
||||
match tokio::time::timeout(
|
||||
Duration::from_secs(config.timeout),
|
||||
MySqlPool::connect_with(mysql_options),
|
||||
MySqlPool::connect_with(mysql_config),
|
||||
)
|
||||
.await
|
||||
{
|
||||
@@ -253,8 +235,8 @@ async fn create_db_connection_pool(config: &MysqlConfig) -> anyhow::Result<MySql
|
||||
|
||||
async fn spawn_listener_task(
|
||||
listener: TokioUnixListener,
|
||||
config: ServerConfig,
|
||||
connection_counter: Arc<()>,
|
||||
db_pool: MySqlPool,
|
||||
) -> anyhow::Result<()> {
|
||||
sd_notify::notify(false, &[sd_notify::NotifyState::Ready])?;
|
||||
|
||||
@@ -305,7 +287,7 @@ async fn spawn_listener_task(
|
||||
|
||||
log::info!("Accepted connection from UNIX user: {}", unix_user.username);
|
||||
|
||||
match session_handler(conn, &unix_user, &config).await {
|
||||
match session_handler(conn, &unix_user, db_pool.clone()).await {
|
||||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
log::error!("Failed to run server: {}", e);
|
||||
|
||||
Reference in New Issue
Block a user