verify database connection before starting

This commit is contained in:
2026-02-05 01:39:40 +09:00
parent 4d88409e97
commit af5710d663
4 changed files with 128 additions and 17 deletions

View File

@@ -4,25 +4,11 @@ import tomllib
from pathlib import Path
from typing import Any
from dibbler.lib.helpers import file_is_submissive_and_readable
DEFAULT_CONFIG_PATH = Path("/etc/dibbler/dibbler.toml")
def default_config_path_submissive_and_readable() -> bool:
return DEFAULT_CONFIG_PATH.is_file() and any(
[
(
DEFAULT_CONFIG_PATH.stat().st_mode & 0o400
and DEFAULT_CONFIG_PATH.stat().st_uid == os.getuid()
),
(
DEFAULT_CONFIG_PATH.stat().st_mode & 0o040
and DEFAULT_CONFIG_PATH.stat().st_gid == os.getgid()
),
(DEFAULT_CONFIG_PATH.stat().st_mode & 0o004),
],
)
config: dict[str, dict[str, Any]] = {}
@@ -31,7 +17,7 @@ def load_config(config_path: Path | None = None) -> None:
if config_path is not None:
with Path(config_path).open("rb") as file:
config = tomllib.load(file)
elif default_config_path_submissive_and_readable():
elif file_is_submissive_and_readable(DEFAULT_CONFIG_PATH):
with DEFAULT_CONFIG_PATH.open("rb") as file:
config = tomllib.load(file)
else:

View File

@@ -0,0 +1,106 @@
import sys
from pathlib import Path
from sqlalchemy import Engine, create_engine, inspect, select
from sqlalchemy.exc import DBAPIError, OperationalError
from sqlalchemy.orm import RelationshipProperty
from sqlalchemy.orm.clsregistry import _ModuleMarker
from dibbler.lib.helpers import file_is_submissive_and_readable
from dibbler.models import Base
def check_db_health(engine: Engine, verify_table_existence: bool = False) -> None:
dialect_name = getattr(engine.dialect, "name", "").lower()
if "postgres" in dialect_name:
check_postgres_ping(engine)
elif dialect_name == "sqlite":
check_sqlite_file(engine)
if verify_table_existence:
verify_tables_and_columns(engine)
def check_postgres_ping(engine: Engine) -> None:
try:
with engine.connect() as conn:
result = conn.execute(select(1))
scalar = result.scalar()
if scalar != 1 and scalar is not None:
print(
"Unexpected response from Postgres when running 'SELECT 1'",
file=sys.stderr,
)
sys.exit(1)
except (OperationalError, DBAPIError) as exc:
print(f"Failed to connect to Postgres database: {exc}", file=sys.stderr)
sys.exit(1)
def check_sqlite_file(engine: Engine) -> None:
db_path = engine.url.database
# Don't verify in-memory databases or empty paths
if db_path in (None, "", ":memory:"):
return
db_path = db_path.removeprefix("file:").removeprefix("sqlite:")
# Strip query parameters
if "?" in db_path:
db_path = db_path.split("?", 1)[0]
path = Path(db_path)
if not path.exists():
print(f"SQLite database file does not exist: {path}", file=sys.stderr)
sys.exit(1)
if not path.is_file():
print(f"SQLite database path is not a file: {path}", file=sys.stderr)
sys.exit(1)
if not file_is_submissive_and_readable(path):
print(f"SQLite database file is not submissive and readable: {path}", file=sys.stderr)
sys.exit(1)
return
def verify_tables_and_columns(engine: Engine) -> None:
iengine = inspect(engine)
errors = False
tables = iengine.get_table_names()
for _name, klass in Base.registry._class_registry.items():
if isinstance(klass, _ModuleMarker):
continue
table = klass.__tablename__
if table in tables:
columns = [c["name"] for c in iengine.get_columns(table)]
mapper = inspect(klass)
for column_prop in mapper.attrs:
if isinstance(column_prop, RelationshipProperty):
pass
else:
for column in column_prop.columns:
if not column.key in columns:
print(
f"Model '{klass}' declares column '{column.key}' which does not exist in database {engine}",
file=sys.stderr,
)
errors = True
else:
print(
f"Model '{klass}' declares table '{table}' which does not exist in database {engine}",
file=sys.stderr,
)
errors = True
if errors:
print("Have you remembered to run `dibbler create-db?", file=sys.stderr)
sys.exit(1)

View File

@@ -3,6 +3,7 @@ import pwd
import signal
import subprocess
from collections.abc import Callable
from pathlib import Path
from typing import Any, Literal
from sqlalchemy import and_, not_, or_
@@ -152,3 +153,13 @@ def less(string: str) -> None:
proc = subprocess.Popen("less", env=env, encoding="utf-8", stdin=subprocess.PIPE)
proc.communicate(string)
signal.signal(signal.SIGINT, int_handler)
def file_is_submissive_and_readable(file: Path) -> bool:
return file.is_file() and any(
[
file.stat().st_mode & 0o400 and file.stat().st_uid == os.getuid(),
file.stat().st_mode & 0o040 and file.stat().st_gid == os.getgid(),
file.stat().st_mode & 0o004,
],
)

View File

@@ -6,6 +6,7 @@ from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from dibbler.conf import config_db_string, load_config
from dibbler.lib.check_db_health import check_db_health
parser = argparse.ArgumentParser()
@@ -41,6 +42,7 @@ def main() -> None:
if args.version:
from ._version import commit_id, version
print(f"Dibbler version {version}, commit {commit_id if commit_id else '<unknown>'}")
return
@@ -51,6 +53,7 @@ def main() -> None:
load_config(args.config)
engine = create_engine(config_db_string())
sql_session = Session(
engine,
expire_on_commit=False,
@@ -59,6 +62,11 @@ def main() -> None:
close_resets_only=True,
)
check_db_health(
engine,
verify_table_existence=args.subcommand != "create-db",
)
if args.subcommand == "loop":
import dibbler.subcommands.loop as loop