From af5710d663e27e1d8aba9e62e2e64c8b432c3ea3 Mon Sep 17 00:00:00 2001 From: h7x4 Date: Thu, 5 Feb 2026 01:39:40 +0900 Subject: [PATCH] verify database connection before starting --- dibbler/conf.py | 20 +------ dibbler/lib/check_db_health.py | 106 +++++++++++++++++++++++++++++++++ dibbler/lib/helpers.py | 11 ++++ dibbler/main.py | 8 +++ 4 files changed, 128 insertions(+), 17 deletions(-) create mode 100644 dibbler/lib/check_db_health.py diff --git a/dibbler/conf.py b/dibbler/conf.py index b5ad3a5..1c6f295 100644 --- a/dibbler/conf.py +++ b/dibbler/conf.py @@ -4,25 +4,11 @@ import tomllib from pathlib import Path from typing import Any +from dibbler.lib.helpers import file_is_submissive_and_readable + DEFAULT_CONFIG_PATH = Path("/etc/dibbler/dibbler.toml") -def default_config_path_submissive_and_readable() -> bool: - return DEFAULT_CONFIG_PATH.is_file() and any( - [ - ( - DEFAULT_CONFIG_PATH.stat().st_mode & 0o400 - and DEFAULT_CONFIG_PATH.stat().st_uid == os.getuid() - ), - ( - DEFAULT_CONFIG_PATH.stat().st_mode & 0o040 - and DEFAULT_CONFIG_PATH.stat().st_gid == os.getgid() - ), - (DEFAULT_CONFIG_PATH.stat().st_mode & 0o004), - ], - ) - - config: dict[str, dict[str, Any]] = {} @@ -31,7 +17,7 @@ def load_config(config_path: Path | None = None) -> None: if config_path is not None: with Path(config_path).open("rb") as file: config = tomllib.load(file) - elif default_config_path_submissive_and_readable(): + elif file_is_submissive_and_readable(DEFAULT_CONFIG_PATH): with DEFAULT_CONFIG_PATH.open("rb") as file: config = tomllib.load(file) else: diff --git a/dibbler/lib/check_db_health.py b/dibbler/lib/check_db_health.py new file mode 100644 index 0000000..20a3c92 --- /dev/null +++ b/dibbler/lib/check_db_health.py @@ -0,0 +1,106 @@ +import sys +from pathlib import Path + +from sqlalchemy import Engine, create_engine, inspect, select +from sqlalchemy.exc import DBAPIError, OperationalError +from sqlalchemy.orm import RelationshipProperty +from sqlalchemy.orm.clsregistry import _ModuleMarker + +from dibbler.lib.helpers import file_is_submissive_and_readable +from dibbler.models import Base + + +def check_db_health(engine: Engine, verify_table_existence: bool = False) -> None: + dialect_name = getattr(engine.dialect, "name", "").lower() + + if "postgres" in dialect_name: + check_postgres_ping(engine) + + elif dialect_name == "sqlite": + check_sqlite_file(engine) + + if verify_table_existence: + verify_tables_and_columns(engine) + + +def check_postgres_ping(engine: Engine) -> None: + try: + with engine.connect() as conn: + result = conn.execute(select(1)) + scalar = result.scalar() + if scalar != 1 and scalar is not None: + print( + "Unexpected response from Postgres when running 'SELECT 1'", + file=sys.stderr, + ) + sys.exit(1) + except (OperationalError, DBAPIError) as exc: + print(f"Failed to connect to Postgres database: {exc}", file=sys.stderr) + sys.exit(1) + + +def check_sqlite_file(engine: Engine) -> None: + db_path = engine.url.database + + # Don't verify in-memory databases or empty paths + if db_path in (None, "", ":memory:"): + return + + db_path = db_path.removeprefix("file:").removeprefix("sqlite:") + + # Strip query parameters + if "?" in db_path: + db_path = db_path.split("?", 1)[0] + + path = Path(db_path) + + if not path.exists(): + print(f"SQLite database file does not exist: {path}", file=sys.stderr) + sys.exit(1) + + if not path.is_file(): + print(f"SQLite database path is not a file: {path}", file=sys.stderr) + sys.exit(1) + + if not file_is_submissive_and_readable(path): + print(f"SQLite database file is not submissive and readable: {path}", file=sys.stderr) + sys.exit(1) + + return + + +def verify_tables_and_columns(engine: Engine) -> None: + iengine = inspect(engine) + errors = False + tables = iengine.get_table_names() + + for _name, klass in Base.registry._class_registry.items(): + if isinstance(klass, _ModuleMarker): + continue + + table = klass.__tablename__ + if table in tables: + columns = [c["name"] for c in iengine.get_columns(table)] + mapper = inspect(klass) + + for column_prop in mapper.attrs: + if isinstance(column_prop, RelationshipProperty): + pass + else: + for column in column_prop.columns: + if not column.key in columns: + print( + f"Model '{klass}' declares column '{column.key}' which does not exist in database {engine}", + file=sys.stderr, + ) + errors = True + else: + print( + f"Model '{klass}' declares table '{table}' which does not exist in database {engine}", + file=sys.stderr, + ) + errors = True + + if errors: + print("Have you remembered to run `dibbler create-db?", file=sys.stderr) + sys.exit(1) diff --git a/dibbler/lib/helpers.py b/dibbler/lib/helpers.py index 2f0b9fb..3cbb2c6 100644 --- a/dibbler/lib/helpers.py +++ b/dibbler/lib/helpers.py @@ -3,6 +3,7 @@ import pwd import signal import subprocess from collections.abc import Callable +from pathlib import Path from typing import Any, Literal from sqlalchemy import and_, not_, or_ @@ -152,3 +153,13 @@ def less(string: str) -> None: proc = subprocess.Popen("less", env=env, encoding="utf-8", stdin=subprocess.PIPE) proc.communicate(string) signal.signal(signal.SIGINT, int_handler) + + +def file_is_submissive_and_readable(file: Path) -> bool: + return file.is_file() and any( + [ + file.stat().st_mode & 0o400 and file.stat().st_uid == os.getuid(), + file.stat().st_mode & 0o040 and file.stat().st_gid == os.getgid(), + file.stat().st_mode & 0o004, + ], + ) diff --git a/dibbler/main.py b/dibbler/main.py index c7b91ea..5dc7841 100644 --- a/dibbler/main.py +++ b/dibbler/main.py @@ -6,6 +6,7 @@ from sqlalchemy import create_engine from sqlalchemy.orm import Session from dibbler.conf import config_db_string, load_config +from dibbler.lib.check_db_health import check_db_health parser = argparse.ArgumentParser() @@ -41,6 +42,7 @@ def main() -> None: if args.version: from ._version import commit_id, version + print(f"Dibbler version {version}, commit {commit_id if commit_id else ''}") return @@ -51,6 +53,7 @@ def main() -> None: load_config(args.config) engine = create_engine(config_db_string()) + sql_session = Session( engine, expire_on_commit=False, @@ -59,6 +62,11 @@ def main() -> None: close_resets_only=True, ) + check_db_health( + engine, + verify_table_existence=args.subcommand != "create-db", + ) + if args.subcommand == "loop": import dibbler.subcommands.loop as loop