76 lines
1.9 KiB
Python
76 lines
1.9 KiB
Python
from __future__ import annotations
|
|
|
|
from contextlib import contextmanager
|
|
from pathlib import Path
|
|
from typing import Iterator
|
|
|
|
from sqlalchemy import create_engine, event, text
|
|
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
|
|
|
|
from app.config import get_settings
|
|
|
|
|
|
class Base(DeclarativeBase):
    """Declarative base class; its ``metadata`` is what ``init_db`` creates tables from."""
|
|
|
|
|
|
def _engine_url() -> str:
    """Return the configured database URL, preparing the SQLite directory if needed.

    The URL is read from ``get_settings().app.database_url`` and returned
    unchanged. As a side effect, when it is a three-slash SQLite file URL
    (``sqlite:///some/dir/app.db``), the parent directory of the database
    file is created so the driver can create the file on first connect.

    Four-slash URLs (``sqlite:////...``), empty paths and ``:memory:`` are
    left alone: no directory is created for them.
    """
    configured = get_settings().app.database_url

    three_slash_sqlite = configured.startswith("sqlite:///")
    four_slash_sqlite = configured.startswith("sqlite:////")
    if not three_slash_sqlite or four_slash_sqlite:
        return configured

    file_part = configured.removeprefix("sqlite:///")
    # Nothing to create for an empty path or the in-memory database.
    if file_part not in ("", ":memory:"):
        Path(file_part).parent.mkdir(parents=True, exist_ok=True)
    return configured
|
|
|
|
|
|
def _connect_args(url: str) -> dict:
|
|
if url.startswith("sqlite:"):
|
|
return {"timeout": 60}
|
|
return {}
|
|
|
|
|
|
# Module-level engine shared by the session factory and helpers below.
# The URL is resolved once at import time; connect_args are derived from it.
engine_url = _engine_url()
engine = create_engine(
    engine_url,
    connect_args=_connect_args(engine_url),
    pool_pre_ping=True,
    future=True,
)
|
|
|
|
|
|
@event.listens_for(engine, "connect")
def _configure_sqlite(dbapi_connection, connection_record) -> None:
    """Apply SQLite PRAGMAs to every new DBAPI connection of the engine.

    Registered as a ``connect`` event listener on ``engine``. It does
    nothing unless the engine's backend is SQLite. For SQLite it sets a
    60000 ms busy timeout and switches the journal mode to WAL.

    Args:
        dbapi_connection: The raw DBAPI connection that was just opened.
        connection_record: Pool record for the connection; unused, but
            part of the listener signature.
    """
    if engine.url.get_backend_name() != "sqlite":
        return

    cursor = dbapi_connection.cursor()
    try:
        cursor.execute("PRAGMA busy_timeout=60000")
        cursor.execute("PRAGMA journal_mode=WAL")
    finally:
        # Close the cursor even if a PRAGMA fails, so it is not leaked.
        cursor.close()
|
|
# Session factory bound to the module engine; sessions neither autoflush
# nor autocommit, so callers flush/commit explicitly.
SessionLocal = sessionmaker(
    bind=engine,
    autocommit=False,
    autoflush=False,
    future=True,
)
|
|
|
|
|
|
def init_db() -> None:
    """Create all tables registered on ``Base.metadata`` using the module engine."""
    # Imported here, not at module top, purely for its import side effect;
    # presumably this registers the model classes on Base.metadata.
    from app import models  # noqa: F401

    metadata = Base.metadata
    metadata.create_all(engine)
|
|
|
|
|
|
def get_db() -> Iterator[Session]:
    """Yield a new session and close it once the consumer is finished.

    This is a plain generator (not wrapped in ``contextmanager``). It does
    not commit; committing is left to the caller.
    """
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()
|
|
|
|
|
|
@contextmanager
def session_scope() -> Iterator[Session]:
    """Provide a transactional scope around a block of work.

    Commits when the ``with`` body finishes normally. If the body or the
    commit itself raises an ``Exception``, the session is rolled back and
    the exception is re-raised. The session is always closed on exit.
    """
    db = SessionLocal()
    try:
        yield db
        # Commit stays inside the try so a failing commit is rolled back too.
        db.commit()
    except Exception:
        db.rollback()
        raise
    finally:
        db.close()
|
|
|
|
|
|
def database_ok() -> bool:
    """Report whether the database answers a trivial query.

    Opens a connection from the module engine and runs ``select 1``.

    Returns:
        ``True`` if the query succeeds, ``False`` if any ``Exception`` is
        raised while connecting, executing or releasing the connection.
    """
    try:
        with engine.connect() as probe:
            probe.execute(text("select 1"))
    except Exception:
        # Health-check boundary: any failure is reported as "not ok".
        return False
    return True
|