verify database connection before starting
This commit is contained in:
@@ -4,25 +4,11 @@ import tomllib
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from dibbler.lib.helpers import file_is_submissive_and_readable
|
||||
|
||||
DEFAULT_CONFIG_PATH = Path("/etc/dibbler/dibbler.toml")
|
||||
|
||||
|
||||
def default_config_path_submissive_and_readable() -> bool:
|
||||
return DEFAULT_CONFIG_PATH.is_file() and any(
|
||||
[
|
||||
(
|
||||
DEFAULT_CONFIG_PATH.stat().st_mode & 0o400
|
||||
and DEFAULT_CONFIG_PATH.stat().st_uid == os.getuid()
|
||||
),
|
||||
(
|
||||
DEFAULT_CONFIG_PATH.stat().st_mode & 0o040
|
||||
and DEFAULT_CONFIG_PATH.stat().st_gid == os.getgid()
|
||||
),
|
||||
(DEFAULT_CONFIG_PATH.stat().st_mode & 0o004),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
config: dict[str, dict[str, Any]] = {}
|
||||
|
||||
|
||||
@@ -31,7 +17,7 @@ def load_config(config_path: Path | None = None) -> None:
|
||||
if config_path is not None:
|
||||
with Path(config_path).open("rb") as file:
|
||||
config = tomllib.load(file)
|
||||
elif default_config_path_submissive_and_readable():
|
||||
elif file_is_submissive_and_readable(DEFAULT_CONFIG_PATH):
|
||||
with DEFAULT_CONFIG_PATH.open("rb") as file:
|
||||
config = tomllib.load(file)
|
||||
else:
|
||||
|
||||
106
dibbler/lib/check_db_health.py
Normal file
106
dibbler/lib/check_db_health.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import Engine, create_engine, inspect, select
|
||||
from sqlalchemy.exc import DBAPIError, OperationalError
|
||||
from sqlalchemy.orm import RelationshipProperty
|
||||
from sqlalchemy.orm.clsregistry import _ModuleMarker
|
||||
|
||||
from dibbler.lib.helpers import file_is_submissive_and_readable
|
||||
from dibbler.models import Base
|
||||
|
||||
|
||||
def check_db_health(engine: Engine, verify_table_existence: bool = False) -> None:
|
||||
dialect_name = getattr(engine.dialect, "name", "").lower()
|
||||
|
||||
if "postgres" in dialect_name:
|
||||
check_postgres_ping(engine)
|
||||
|
||||
elif dialect_name == "sqlite":
|
||||
check_sqlite_file(engine)
|
||||
|
||||
if verify_table_existence:
|
||||
verify_tables_and_columns(engine)
|
||||
|
||||
|
||||
def check_postgres_ping(engine: Engine) -> None:
|
||||
try:
|
||||
with engine.connect() as conn:
|
||||
result = conn.execute(select(1))
|
||||
scalar = result.scalar()
|
||||
if scalar != 1 and scalar is not None:
|
||||
print(
|
||||
"Unexpected response from Postgres when running 'SELECT 1'",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
except (OperationalError, DBAPIError) as exc:
|
||||
print(f"Failed to connect to Postgres database: {exc}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def check_sqlite_file(engine: Engine) -> None:
|
||||
db_path = engine.url.database
|
||||
|
||||
# Don't verify in-memory databases or empty paths
|
||||
if db_path in (None, "", ":memory:"):
|
||||
return
|
||||
|
||||
db_path = db_path.removeprefix("file:").removeprefix("sqlite:")
|
||||
|
||||
# Strip query parameters
|
||||
if "?" in db_path:
|
||||
db_path = db_path.split("?", 1)[0]
|
||||
|
||||
path = Path(db_path)
|
||||
|
||||
if not path.exists():
|
||||
print(f"SQLite database file does not exist: {path}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
if not path.is_file():
|
||||
print(f"SQLite database path is not a file: {path}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
if not file_is_submissive_and_readable(path):
|
||||
print(f"SQLite database file is not submissive and readable: {path}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
return
|
||||
|
||||
|
||||
def verify_tables_and_columns(engine: Engine) -> None:
|
||||
iengine = inspect(engine)
|
||||
errors = False
|
||||
tables = iengine.get_table_names()
|
||||
|
||||
for _name, klass in Base.registry._class_registry.items():
|
||||
if isinstance(klass, _ModuleMarker):
|
||||
continue
|
||||
|
||||
table = klass.__tablename__
|
||||
if table in tables:
|
||||
columns = [c["name"] for c in iengine.get_columns(table)]
|
||||
mapper = inspect(klass)
|
||||
|
||||
for column_prop in mapper.attrs:
|
||||
if isinstance(column_prop, RelationshipProperty):
|
||||
pass
|
||||
else:
|
||||
for column in column_prop.columns:
|
||||
if not column.key in columns:
|
||||
print(
|
||||
f"Model '{klass}' declares column '{column.key}' which does not exist in database {engine}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
errors = True
|
||||
else:
|
||||
print(
|
||||
f"Model '{klass}' declares table '{table}' which does not exist in database {engine}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
errors = True
|
||||
|
||||
if errors:
|
||||
print("Have you remembered to run `dibbler create-db?", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
@@ -3,6 +3,7 @@ import pwd
|
||||
import signal
|
||||
import subprocess
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
from sqlalchemy import and_, not_, or_
|
||||
@@ -152,3 +153,13 @@ def less(string: str) -> None:
|
||||
proc = subprocess.Popen("less", env=env, encoding="utf-8", stdin=subprocess.PIPE)
|
||||
proc.communicate(string)
|
||||
signal.signal(signal.SIGINT, int_handler)
|
||||
|
||||
|
||||
def file_is_submissive_and_readable(file: Path) -> bool:
|
||||
return file.is_file() and any(
|
||||
[
|
||||
file.stat().st_mode & 0o400 and file.stat().st_uid == os.getuid(),
|
||||
file.stat().st_mode & 0o040 and file.stat().st_gid == os.getgid(),
|
||||
file.stat().st_mode & 0o004,
|
||||
],
|
||||
)
|
||||
|
||||
@@ -6,6 +6,7 @@ from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dibbler.conf import config_db_string, load_config
|
||||
from dibbler.lib.check_db_health import check_db_health
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
@@ -41,6 +42,7 @@ def main() -> None:
|
||||
|
||||
if args.version:
|
||||
from ._version import commit_id, version
|
||||
|
||||
print(f"Dibbler version {version}, commit {commit_id if commit_id else '<unknown>'}")
|
||||
return
|
||||
|
||||
@@ -51,6 +53,7 @@ def main() -> None:
|
||||
load_config(args.config)
|
||||
|
||||
engine = create_engine(config_db_string())
|
||||
|
||||
sql_session = Session(
|
||||
engine,
|
||||
expire_on_commit=False,
|
||||
@@ -59,6 +62,11 @@ def main() -> None:
|
||||
close_resets_only=True,
|
||||
)
|
||||
|
||||
check_db_health(
|
||||
engine,
|
||||
verify_table_existence=args.subcommand != "create-db",
|
||||
)
|
||||
|
||||
if args.subcommand == "loop":
|
||||
import dibbler.subcommands.loop as loop
|
||||
|
||||
|
||||
Reference in New Issue
Block a user