59 lines
1.4 KiB
Python
59 lines
1.4 KiB
Python
from __future__ import annotations

import logging
from contextlib import contextmanager
from pathlib import Path
from typing import Iterator

from sqlalchemy import create_engine, text
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker

from app.config import get_settings
|
|
|
|
|
|
class Base(DeclarativeBase):
    """Declarative base shared by the application's ORM models.

    Models subclass this so their tables are registered on ``Base.metadata``;
    ``init_db`` calls ``Base.metadata.create_all`` on that metadata.
    """
|
|
|
|
|
|
def _engine_url() -> str:
|
|
url = get_settings().app.database_url
|
|
if url.startswith("sqlite:///") and not url.startswith("sqlite:////"):
|
|
db_path = url.removeprefix("sqlite:///")
|
|
if db_path and db_path != ":memory:":
|
|
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
|
|
return url
|
|
|
|
|
|
# Process-wide engine. pool_pre_ping checks each pooled connection on checkout,
# so a connection dropped by the server is replaced instead of raising an error.
engine = create_engine(
    _engine_url(),
    pool_pre_ping=True,
    future=True,
)

# Session factory bound to the engine above. With autoflush disabled, pending
# changes are only flushed on an explicit flush() or on commit().
SessionLocal = sessionmaker(
    bind=engine,
    autocommit=False,
    autoflush=False,
    future=True,
)
|
|
|
|
|
|
def init_db() -> None:
    """Create the tables for every model registered on ``Base.metadata``.

    ``app.models`` is imported only for its side effect of defining the model
    classes, which registers their tables on ``Base.metadata``. The import is
    done inside the function, presumably to avoid a circular import with this
    module (NOTE(review): confirm).
    """
    from app import models  # noqa: F401  (imported for side effect only)

    Base.metadata.create_all(engine)
|
|
|
|
|
|
def get_db() -> Iterator[Session]:
    """Yield a single new session and close it when the generator finishes.

    Nothing is committed here; the consumer is responsible for calling
    ``commit()``. Presumably used as a per-request dependency (verify at call sites).
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        # Closing releases the connection and discards any uncommitted work.
        session.close()
|
|
|
|
|
|
@contextmanager
def session_scope() -> Iterator[Session]:
    """Provide a transactional session: commit on success, roll back on error.

    The session is always closed on exit. If the caller's block raises an
    ``Exception``, or the commit itself fails, the session is rolled back and
    the exception is re-raised unchanged.
    """
    db = SessionLocal()
    with db:
        try:
            yield db
            db.commit()
        except Exception:
            db.rollback()
            raise
|
|
|
|
|
|
def database_ok() -> bool:
    """Report whether a trivial ``select 1`` round-trip to the database succeeds.

    This is a best-effort probe. Any exception raised while connecting,
    executing or closing is caught, logged with its traceback at WARNING level,
    and reported as ``False`` rather than propagated.

    Returns:
        True if the query ran, False otherwise.
    """
    try:
        with engine.connect() as conn:
            conn.execute(text("select 1"))
    except Exception:
        # Deliberately broad: a health probe reports failure instead of raising.
        # The cause is logged so an unhealthy result can be diagnosed.
        logging.getLogger(__name__).warning("Database health check failed", exc_info=True)
        return False
    return True
|