Files
DMARC-Sentinel/app/db.py
T

76 lines
1.9 KiB
Python

from __future__ import annotations
from contextlib import contextmanager
from pathlib import Path
from typing import Iterator
from sqlalchemy import create_engine, event, text
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
from app.config import get_settings
class Base(DeclarativeBase):
    """Declarative base for the application's ORM models.

    Model classes (presumably those in ``app.models``) subclass this, and
    ``init_db`` creates their tables from ``Base.metadata``.
    """
def _engine_url() -> str:
    """Return the configured database URL, preparing a local SQLite directory.

    For SQLite URLs that use a relative file path (``sqlite:///data/app.db``),
    including driver-qualified schemes such as ``sqlite+pysqlite:///...``,
    the parent directory of the database file is created so it exists before
    the first connection. Absolute SQLite paths (``sqlite:////abs/app.db``),
    in-memory databases and non-SQLite URLs are not touched.

    Returns:
        The ``database_url`` from settings, unchanged.
    """
    url = get_settings().app.database_url
    scheme, sep, rest = url.partition(":///")
    # SQLAlchemy schemes are "dialect[+driver]"; only the dialect matters here.
    backend = scheme.partition("+")[0]
    # A fourth slash ("sqlite:////...") marks an absolute path, which is left alone.
    if backend != "sqlite" or not sep or rest.startswith("/"):
        return url
    # Drop any "?key=value" query so it is not mistaken for part of the file path.
    db_path = rest.partition("?")[0]
    if db_path and db_path != ":memory:":
        Path(db_path).parent.mkdir(parents=True, exist_ok=True)
    return url
def _connect_args(url: str) -> dict:
if url.startswith("sqlite:"):
return {"timeout": 60}
return {}
# Resolve the URL once: it is used both to build the engine and to pick
# backend-specific DBAPI connect arguments.
engine_url = _engine_url()
engine = create_engine(
    engine_url,
    connect_args=_connect_args(engine_url),
    pool_pre_ping=True,
    future=True,
)
@event.listens_for(engine, "connect")
def _configure_sqlite(dbapi_connection, connection_record) -> None:
    """Apply per-connection SQLite PRAGMAs whenever ``engine`` opens a connection.

    Registered as a SQLAlchemy ``"connect"`` event listener. It returns
    immediately for non-SQLite backends.

    Args:
        dbapi_connection: The raw DBAPI connection that was just opened.
        connection_record: SQLAlchemy's pool record for the connection (unused).
    """
    if engine.url.get_backend_name() != "sqlite":
        return
    cursor = dbapi_connection.cursor()
    try:
        # Wait up to 60000 ms on a locked database instead of failing at once;
        # same 60 s figure as the SQLite "timeout" connect argument.
        cursor.execute("PRAGMA busy_timeout=60000")
        # Write-ahead logging so readers are not blocked by a concurrent writer.
        cursor.execute("PRAGMA journal_mode=WAL")
    finally:
        # Close the cursor even if a PRAGMA raises, so it is never leaked.
        cursor.close()
# Session factory bound to the module engine. autoflush is off, so pending
# changes are only sent to the database on explicit flush/commit.
SessionLocal = sessionmaker(
    bind=engine,
    future=True,
    autocommit=False,
    autoflush=False,
)
def init_db() -> None:
    """Create any missing tables described by ``Base.metadata`` on ``engine``.

    ``app.models`` is imported first only for its import-time side effects,
    presumably so the model classes are declared on ``Base`` before tables
    are created.
    """
    # Deliberately unused name; the import itself is the point.
    from app import models  # noqa: F401

    Base.metadata.create_all(engine, checkfirst=True)
def get_db() -> Iterator[Session]:
    """Yield a new ``Session`` and close it when the consumer is finished.

    This helper never commits; the session is closed on exit whether the
    consumer finished normally or raised.
    """
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()
@contextmanager
def session_scope() -> Iterator[Session]:
    """Provide a transactional ``Session`` for use in a ``with`` block.

    On normal exit the session is committed. If the block, or the commit
    itself, raises an ``Exception``, the session is rolled back and the
    exception is re-raised. The session is closed in every case.
    """
    db = SessionLocal()
    try:
        yield db
        # Commit sits inside the try so a failing commit is also rolled back.
        db.commit()
    except Exception:
        db.rollback()
        raise
    finally:
        db.close()
def database_ok() -> bool:
    """Report whether the database answers a trivial query.

    Opens a connection from ``engine`` and executes ``select 1``.

    Returns:
        True if the round trip succeeds. False if any exception is raised;
        the failure is logged with its traceback instead of being discarded.
    """
    import logging

    try:
        with engine.connect() as conn:
            conn.execute(text("select 1"))
    except Exception:
        # Health-check boundary: never propagate, but keep the cause visible.
        logging.getLogger(__name__).warning(
            "Database health check failed", exc_info=True
        )
        return False
    return True