"""Authentication and authorization dependencies for protected API routes.""" from dataclasses import dataclass from datetime import datetime from typing import Annotated from uuid import UUID import hmac from fastapi import Depends, HTTPException, Request, status from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from sqlalchemy.orm import Session from app.db.base import get_session from app.models.auth import UserRole from app.services.authentication import resolve_auth_session try: from fastapi import Cookie, Header except (ImportError, AttributeError): def Cookie(_default=None, **_kwargs): # type: ignore[no-untyped-def] """Compatibility fallback for environments that stub fastapi without request params.""" return None def Header(_default=None, **_kwargs): # type: ignore[no-untyped-def] """Compatibility fallback for environments that stub fastapi without request params.""" return None bearer_auth = HTTPBearer(auto_error=False) SESSION_COOKIE_NAME = "dcm_session" CSRF_COOKIE_NAME = "dcm_csrf" CSRF_HEADER_NAME = "x-csrf-token" CSRF_PROTECTED_METHODS = frozenset({"POST", "PATCH", "PUT", "DELETE"}) @dataclass(frozen=True) class AuthContext: """Carries authenticated identity and role details for one request.""" user_id: UUID username: str role: UserRole session_id: UUID expires_at: datetime def _requires_csrf_validation(method: str) -> bool: """Returns whether an HTTP method should be protected by cookie CSRF validation.""" return method.upper() in CSRF_PROTECTED_METHODS def _extract_cookie_values(request: Request, cookie_name: str) -> tuple[str, ...]: """Extracts all values for one cookie name from raw Cookie header order.""" request_headers = getattr(request, "headers", None) raw_cookie_header = request_headers.get("cookie", "") if request_headers is not None else "" if not raw_cookie_header: return () extracted_values: list[str] = [] for cookie_pair in raw_cookie_header.split(";"): normalized_pair = cookie_pair.strip() if not normalized_pair or "=" not in normalized_pair: continue key, value = normalized_pair.split("=", 1) if key.strip() != cookie_name: continue normalized_value = value.strip() if normalized_value: extracted_values.append(normalized_value) return tuple(extracted_values) def _raise_unauthorized() -> None: """Raises a 401 challenge response for missing or invalid auth sessions.""" raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired authentication session", headers={"WWW-Authenticate": "Bearer"}, ) def _raise_csrf_rejected() -> None: """Raises a forbidden response for CSRF validation failure.""" raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Invalid CSRF token", ) def get_request_auth_context( request: Request, credentials: HTTPAuthorizationCredentials | None = Depends(bearer_auth), csrf_header: str | None = Header(None, alias=CSRF_HEADER_NAME), csrf_cookie: str | None = Cookie(None, alias=CSRF_COOKIE_NAME), session_cookie: str | None = Cookie(None, alias=SESSION_COOKIE_NAME), session: Session = Depends(get_session), ) -> AuthContext: """Authenticates auth session token and validates CSRF for cookie sessions.""" token = credentials.credentials.strip() if credentials is not None and credentials.credentials else "" using_cookie_session = False if not token: token = (session_cookie or "").strip() using_cookie_session = True if not token: _raise_unauthorized() if _requires_csrf_validation(request.method) and using_cookie_session: normalized_csrf_header = (csrf_header or "").strip() csrf_candidates = [candidate for candidate in _extract_cookie_values(request, CSRF_COOKIE_NAME) if candidate] normalized_csrf_cookie = (csrf_cookie or "").strip() if normalized_csrf_cookie and normalized_csrf_cookie not in csrf_candidates: csrf_candidates.append(normalized_csrf_cookie) if ( not csrf_candidates or not normalized_csrf_header or not any(hmac.compare_digest(candidate, normalized_csrf_header) for candidate in csrf_candidates) ): _raise_csrf_rejected() resolved_session = resolve_auth_session(session, token=token) if resolved_session is None or resolved_session.user is None: _raise_unauthorized() return AuthContext( user_id=resolved_session.user.id, username=resolved_session.user.username, role=resolved_session.user.role, session_id=resolved_session.id, expires_at=resolved_session.expires_at, ) def require_user_or_admin(context: Annotated[AuthContext, Depends(get_request_auth_context)]) -> AuthContext: """Requires any authenticated user session and returns its request identity context.""" return context def require_admin(context: Annotated[AuthContext, Depends(get_request_auth_context)]) -> AuthContext: """Requires authenticated admin role and rejects standard user sessions.""" if context.role != UserRole.ADMIN: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Administrator role required", ) return context