Files
ledgerdock/backend/app/api/auth.py

170 lines
6.0 KiB
Python

"""Authentication and authorization dependencies for protected API routes."""
import hmac
from dataclasses import dataclass
from datetime import datetime
from typing import Annotated, NoReturn
from uuid import UUID

from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlalchemy.orm import Session

from app.db.base import get_session
from app.models.auth import UserRole
from app.services.authentication import resolve_auth_session
try:
    from fastapi import Cookie, Header
except (ImportError, AttributeError):

    def _missing_request_param(_default=None, **_kwargs):  # type: ignore[no-untyped-def]
        """Compatibility fallback for environments that stub fastapi without request params.

        Accepts the same call shape as ``Cookie``/``Header`` (a positional default plus
        keyword options such as ``alias``) and always yields ``None``. Parameter defaults
        built from it therefore degrade to plain ``None``.
        """
        return None

    # One stand-in serves both names because the two fallbacks behave identically.
    Cookie = _missing_request_param  # type: ignore[assignment]
    Header = _missing_request_param  # type: ignore[assignment]
# Bearer scheme configured with auto_error=False: when no credentials are sent the
# dependency yields None rather than rejecting, so cookie sessions can be tried instead.
bearer_auth = HTTPBearer(auto_error=False)

# Names of the session cookie and the double-submit CSRF cookie/header pair.
SESSION_COOKIE_NAME = "dcm_session"
CSRF_COOKIE_NAME = "dcm_csrf"
CSRF_HEADER_NAME = "x-csrf-token"

# Methods for which cookie-authenticated requests must pass CSRF validation.
CSRF_PROTECTED_METHODS = frozenset(("POST", "PATCH", "PUT", "DELETE"))
@dataclass(frozen=True)
class AuthContext:
    """Carries authenticated identity and role details for one request.

    Instances are immutable (``frozen=True``). In this module they are built by
    ``get_request_auth_context`` from the resolved auth session and its user.
    """

    # Identifier of the authenticated user (taken from the session's user record).
    user_id: UUID
    # Username of the authenticated user.
    username: str
    # Role of the authenticated user; ``require_admin`` compares it to UserRole.ADMIN.
    role: UserRole
    # Identifier of the auth session that authenticated this request.
    session_id: UUID
    # Expiry timestamp copied from the auth session (timezone handling not visible here).
    expires_at: datetime
def _requires_csrf_validation(method: str) -> bool:
    """Tells whether a cookie-authenticated request with this method needs a CSRF check.

    The method name is upper-cased first, so the lookup in
    ``CSRF_PROTECTED_METHODS`` is case-insensitive.
    """
    canonical_method = method.upper()
    return canonical_method in CSRF_PROTECTED_METHODS
def _extract_cookie_values(request: Request, cookie_name: str) -> tuple[str, ...]:
    """Returns every non-empty value sent for ``cookie_name``, in raw Cookie header order.

    Duplicate cookie names are all kept, so callers can try each value. Pairs without
    an ``=`` are ignored, and names and values are whitespace-trimmed before use. A
    request without a ``headers`` attribute, or without a Cookie header, yields ``()``.
    """
    headers = getattr(request, "headers", None)
    if headers is None:
        return ()
    cookie_header = headers.get("cookie", "")
    if not cookie_header:
        return ()

    matches: list[str] = []
    for raw_pair in cookie_header.split(";"):
        name, separator, raw_value = raw_pair.strip().partition("=")
        # No "=" (including blank segments) means this is not a name=value pair.
        if not separator or name.strip() != cookie_name:
            continue
        value = raw_value.strip()
        if value:
            matches.append(value)
    return tuple(matches)
def _raise_unauthorized() -> NoReturn:
    """Raises a 401 challenge response for missing or invalid auth sessions.

    Annotated ``NoReturn`` because it always raises. This lets type checkers narrow
    values after guards such as ``if x is None: _raise_unauthorized()``.

    Raises:
        HTTPException: always, with status 401 and a ``WWW-Authenticate: Bearer`` header.
    """
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid or expired authentication session",
        headers={"WWW-Authenticate": "Bearer"},
    )
def _raise_csrf_rejected() -> NoReturn:
    """Raises a forbidden response for CSRF validation failure.

    Annotated ``NoReturn`` because it always raises, so code after a call is
    known to be unreachable.

    Raises:
        HTTPException: always, with status 403 and detail ``"Invalid CSRF token"``.
    """
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Invalid CSRF token",
    )
def _collect_cookie_candidates(request: Request, cookie_name: str, parsed_value: str | None) -> list[str]:
    """Merges raw Cookie-header values for ``cookie_name`` with the framework-parsed value."""
    candidates = [value for value in _extract_cookie_values(request, cookie_name) if value]
    normalized_parsed = (parsed_value or "").strip()
    # The parsed value goes last, and only if new, so raw header order is tried first.
    if normalized_parsed and normalized_parsed not in candidates:
        candidates.append(normalized_parsed)
    return candidates


def _csrf_token_matches(candidate: str, provided: str) -> bool:
    """Compares two CSRF tokens in constant time without failing on non-ASCII input.

    ``hmac.compare_digest`` raises ``TypeError`` for ``str`` arguments that contain
    non-ASCII characters. That would turn a malformed header or cookie into a 500 error.
    Comparing UTF-8 bytes makes such input an ordinary mismatch (and hence a 403).
    """
    return hmac.compare_digest(
        candidate.encode("utf-8", "surrogatepass"),
        provided.encode("utf-8", "surrogatepass"),
    )


def _enforce_csrf(request: Request, csrf_header: str | None, csrf_cookie: str | None) -> None:
    """Rejects with 403 unless the CSRF header equals one of the CSRF cookie values."""
    normalized_header = (csrf_header or "").strip()
    csrf_candidates = _collect_cookie_candidates(request, CSRF_COOKIE_NAME, csrf_cookie)
    if not csrf_candidates or not normalized_header:
        _raise_csrf_rejected()
    if not any(_csrf_token_matches(candidate, normalized_header) for candidate in csrf_candidates):
        _raise_csrf_rejected()


def get_request_auth_context(
    request: Request,
    credentials: HTTPAuthorizationCredentials | None = Depends(bearer_auth),
    csrf_header: str | None = Header(None, alias=CSRF_HEADER_NAME),
    csrf_cookie: str | None = Cookie(None, alias=CSRF_COOKIE_NAME),
    session_cookie: str | None = Cookie(None, alias=SESSION_COOKIE_NAME),
    session: Session = Depends(get_session),
) -> AuthContext:
    """Authenticates auth session token and validates CSRF for cookie sessions.

    A non-blank bearer token takes precedence. Otherwise every session-cookie value
    (``SESSION_COOKIE_NAME``) is tried in turn: raw header values first, then the
    framework-parsed value. For cookie-authenticated requests whose method is in
    ``CSRF_PROTECTED_METHODS``, the ``CSRF_HEADER_NAME`` header must equal one of the
    ``CSRF_COOKIE_NAME`` cookie values. CSRF is validated before any session lookup.

    Returns:
        The ``AuthContext`` for the first candidate token that resolves to a session
        with a user.

    Raises:
        HTTPException: 401 when no session token is supplied or none resolves to a
            session with a user; 403 when CSRF validation fails.
    """
    token = credentials.credentials.strip() if credentials is not None and credentials.credentials else ""
    if token:
        # CSRF validation applies only to cookie sessions; bearer requests skip it.
        session_candidates = [token]
    else:
        session_candidates = _collect_cookie_candidates(request, SESSION_COOKIE_NAME, session_cookie)
        if not session_candidates:
            _raise_unauthorized()
        if _requires_csrf_validation(request.method):
            _enforce_csrf(request, csrf_header, csrf_cookie)

    resolved_session = None
    for candidate in session_candidates:
        resolved_session = resolve_auth_session(session, token=candidate)
        if resolved_session is not None and resolved_session.user is not None:
            break
    if resolved_session is None or resolved_session.user is None:
        _raise_unauthorized()

    return AuthContext(
        user_id=resolved_session.user.id,
        username=resolved_session.user.username,
        role=resolved_session.user.role,
        session_id=resolved_session.id,
        expires_at=resolved_session.expires_at,
    )
def require_user_or_admin(context: Annotated[AuthContext, Depends(get_request_auth_context)]) -> AuthContext:
    """Requires any authenticated user session and returns its request identity context.

    No role check happens here. Authentication (plus CSRF validation for cookie
    sessions) is done by the ``get_request_auth_context`` dependency, which raises
    before this function runs if it fails. Any role is therefore accepted.

    Returns:
        The ``AuthContext`` produced by ``get_request_auth_context``, unchanged.
    """
    return context
def require_admin(context: Annotated[AuthContext, Depends(get_request_auth_context)]) -> AuthContext:
    """Lets only administrator sessions through and returns their identity context.

    Authentication itself is delegated to ``get_request_auth_context``. This function
    only checks the role.

    Raises:
        HTTPException: 403 "Administrator role required" when the role is not
            ``UserRole.ADMIN``.
    """
    if context.role == UserRole.ADMIN:
        return context
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Administrator role required",
    )