from __future__ import annotations from contextlib import contextmanager from pathlib import Path from typing import Iterator from sqlalchemy import create_engine, event, text from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker from app.config import get_settings class Base(DeclarativeBase): pass def _engine_url() -> str: url = get_settings().app.database_url if url.startswith("sqlite:///") and not url.startswith("sqlite:////"): db_path = url.removeprefix("sqlite:///") if db_path and db_path != ":memory:": Path(db_path).parent.mkdir(parents=True, exist_ok=True) return url def _connect_args(url: str) -> dict: if url.startswith("sqlite:"): return {"timeout": 60} return {} engine_url = _engine_url() engine = create_engine(engine_url, future=True, pool_pre_ping=True, connect_args=_connect_args(engine_url)) @event.listens_for(engine, "connect") def _configure_sqlite(dbapi_connection, connection_record) -> None: if engine.url.get_backend_name() != "sqlite": return cursor = dbapi_connection.cursor() cursor.execute("PRAGMA busy_timeout=60000") cursor.execute("PRAGMA journal_mode=WAL") cursor.close() SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False, future=True) def init_db() -> None: from app import models # noqa: F401 Base.metadata.create_all(bind=engine) def get_db() -> Iterator[Session]: with SessionLocal() as session: yield session @contextmanager def session_scope() -> Iterator[Session]: with SessionLocal() as session: try: yield session session.commit() except Exception: session.rollback() raise def database_ok() -> bool: try: with engine.connect() as conn: conn.execute(text("select 1")) return True except Exception: return False