verify database connection before starting

This commit is contained in:
2026-02-05 01:39:40 +09:00
parent 4d88409e97
commit af5710d663
4 changed files with 128 additions and 17 deletions

View File

@@ -4,25 +4,11 @@ import tomllib
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from dibbler.lib.helpers import file_is_submissive_and_readable
DEFAULT_CONFIG_PATH = Path("/etc/dibbler/dibbler.toml") DEFAULT_CONFIG_PATH = Path("/etc/dibbler/dibbler.toml")
def default_config_path_submissive_and_readable() -> bool:
return DEFAULT_CONFIG_PATH.is_file() and any(
[
(
DEFAULT_CONFIG_PATH.stat().st_mode & 0o400
and DEFAULT_CONFIG_PATH.stat().st_uid == os.getuid()
),
(
DEFAULT_CONFIG_PATH.stat().st_mode & 0o040
and DEFAULT_CONFIG_PATH.stat().st_gid == os.getgid()
),
(DEFAULT_CONFIG_PATH.stat().st_mode & 0o004),
],
)
config: dict[str, dict[str, Any]] = {} config: dict[str, dict[str, Any]] = {}
@@ -31,7 +17,7 @@ def load_config(config_path: Path | None = None) -> None:
if config_path is not None: if config_path is not None:
with Path(config_path).open("rb") as file: with Path(config_path).open("rb") as file:
config = tomllib.load(file) config = tomllib.load(file)
elif default_config_path_submissive_and_readable(): elif file_is_submissive_and_readable(DEFAULT_CONFIG_PATH):
with DEFAULT_CONFIG_PATH.open("rb") as file: with DEFAULT_CONFIG_PATH.open("rb") as file:
config = tomllib.load(file) config = tomllib.load(file)
else: else:

View File

@@ -0,0 +1,106 @@
import sys
from pathlib import Path
from sqlalchemy import Engine, create_engine, inspect, select
from sqlalchemy.exc import DBAPIError, OperationalError
from sqlalchemy.orm import RelationshipProperty
from sqlalchemy.orm.clsregistry import _ModuleMarker
from dibbler.lib.helpers import file_is_submissive_and_readable
from dibbler.models import Base
def check_db_health(engine: Engine, verify_table_existence: bool = False) -> None:
dialect_name = getattr(engine.dialect, "name", "").lower()
if "postgres" in dialect_name:
check_postgres_ping(engine)
elif dialect_name == "sqlite":
check_sqlite_file(engine)
if verify_table_existence:
verify_tables_and_columns(engine)
def check_postgres_ping(engine: Engine) -> None:
try:
with engine.connect() as conn:
result = conn.execute(select(1))
scalar = result.scalar()
if scalar != 1 and scalar is not None:
print(
"Unexpected response from Postgres when running 'SELECT 1'",
file=sys.stderr,
)
sys.exit(1)
except (OperationalError, DBAPIError) as exc:
print(f"Failed to connect to Postgres database: {exc}", file=sys.stderr)
sys.exit(1)
def check_sqlite_file(engine: Engine) -> None:
db_path = engine.url.database
# Don't verify in-memory databases or empty paths
if db_path in (None, "", ":memory:"):
return
db_path = db_path.removeprefix("file:").removeprefix("sqlite:")
# Strip query parameters
if "?" in db_path:
db_path = db_path.split("?", 1)[0]
path = Path(db_path)
if not path.exists():
print(f"SQLite database file does not exist: {path}", file=sys.stderr)
sys.exit(1)
if not path.is_file():
print(f"SQLite database path is not a file: {path}", file=sys.stderr)
sys.exit(1)
if not file_is_submissive_and_readable(path):
print(f"SQLite database file is not submissive and readable: {path}", file=sys.stderr)
sys.exit(1)
return
def verify_tables_and_columns(engine: Engine) -> None:
iengine = inspect(engine)
errors = False
tables = iengine.get_table_names()
for _name, klass in Base.registry._class_registry.items():
if isinstance(klass, _ModuleMarker):
continue
table = klass.__tablename__
if table in tables:
columns = [c["name"] for c in iengine.get_columns(table)]
mapper = inspect(klass)
for column_prop in mapper.attrs:
if isinstance(column_prop, RelationshipProperty):
pass
else:
for column in column_prop.columns:
if not column.key in columns:
print(
f"Model '{klass}' declares column '{column.key}' which does not exist in database {engine}",
file=sys.stderr,
)
errors = True
else:
print(
f"Model '{klass}' declares table '{table}' which does not exist in database {engine}",
file=sys.stderr,
)
errors = True
if errors:
print("Have you remembered to run `dibbler create-db?", file=sys.stderr)
sys.exit(1)

View File

@@ -3,6 +3,7 @@ import pwd
import signal import signal
import subprocess import subprocess
from collections.abc import Callable from collections.abc import Callable
from pathlib import Path
from typing import Any, Literal from typing import Any, Literal
from sqlalchemy import and_, not_, or_ from sqlalchemy import and_, not_, or_
@@ -152,3 +153,13 @@ def less(string: str) -> None:
proc = subprocess.Popen("less", env=env, encoding="utf-8", stdin=subprocess.PIPE) proc = subprocess.Popen("less", env=env, encoding="utf-8", stdin=subprocess.PIPE)
proc.communicate(string) proc.communicate(string)
signal.signal(signal.SIGINT, int_handler) signal.signal(signal.SIGINT, int_handler)
def file_is_submissive_and_readable(file: Path) -> bool:
return file.is_file() and any(
[
file.stat().st_mode & 0o400 and file.stat().st_uid == os.getuid(),
file.stat().st_mode & 0o040 and file.stat().st_gid == os.getgid(),
file.stat().st_mode & 0o004,
],
)

View File

@@ -6,6 +6,7 @@ from sqlalchemy import create_engine
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from dibbler.conf import config_db_string, load_config from dibbler.conf import config_db_string, load_config
from dibbler.lib.check_db_health import check_db_health
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@@ -41,6 +42,7 @@ def main() -> None:
if args.version: if args.version:
from ._version import commit_id, version from ._version import commit_id, version
print(f"Dibbler version {version}, commit {commit_id if commit_id else '<unknown>'}") print(f"Dibbler version {version}, commit {commit_id if commit_id else '<unknown>'}")
return return
@@ -51,6 +53,7 @@ def main() -> None:
load_config(args.config) load_config(args.config)
engine = create_engine(config_db_string()) engine = create_engine(config_db_string())
sql_session = Session( sql_session = Session(
engine, engine,
expire_on_commit=False, expire_on_commit=False,
@@ -59,6 +62,11 @@ def main() -> None:
close_resets_only=True, close_resets_only=True,
) )
check_db_health(
engine,
verify_table_existence=args.subcommand != "create-db",
)
if args.subcommand == "loop": if args.subcommand == "loop":
import dibbler.subcommands.loop as loop import dibbler.subcommands.loop as loop