Harden auth and security controls with session auth and docs
This commit is contained in:
@@ -3,11 +3,17 @@ DATABASE_URL=postgresql+psycopg://dcm:dcm@db:5432/dcm
|
||||
REDIS_URL=redis://:replace-with-redis-password@redis:6379/0
|
||||
REDIS_SECURITY_MODE=auto
|
||||
REDIS_TLS_MODE=auto
|
||||
ALLOW_DEVELOPMENT_ANONYMOUS_USER_ACCESS=true
|
||||
STORAGE_ROOT=/data/storage
|
||||
ADMIN_API_TOKEN=replace-with-random-admin-token
|
||||
USER_API_TOKEN=replace-with-random-user-token
|
||||
AUTH_BOOTSTRAP_ADMIN_USERNAME=admin
|
||||
AUTH_BOOTSTRAP_ADMIN_PASSWORD=replace-with-random-admin-password
|
||||
AUTH_BOOTSTRAP_USER_USERNAME=user
|
||||
AUTH_BOOTSTRAP_USER_PASSWORD=replace-with-random-user-password
|
||||
APP_SETTINGS_ENCRYPTION_KEY=replace-with-random-settings-encryption-key
|
||||
PROCESSING_LOG_STORE_MODEL_IO_TEXT=false
|
||||
PROCESSING_LOG_STORE_PAYLOAD_TEXT=false
|
||||
CONTENT_EXPORT_MAX_DOCUMENTS=250
|
||||
CONTENT_EXPORT_MAX_TOTAL_BYTES=52428800
|
||||
CONTENT_EXPORT_RATE_LIMIT_PER_MINUTE=6
|
||||
MAX_UPLOAD_FILES_PER_REQUEST=50
|
||||
MAX_UPLOAD_FILE_SIZE_BYTES=26214400
|
||||
MAX_UPLOAD_REQUEST_SIZE_BYTES=104857600
|
||||
@@ -31,3 +37,4 @@ TYPESENSE_PORT=8108
|
||||
TYPESENSE_API_KEY=replace-with-random-typesense-api-key
|
||||
TYPESENSE_COLLECTION_NAME=documents
|
||||
PUBLIC_BASE_URL=http://localhost:8000
|
||||
CORS_ALLOW_CREDENTIALS=false
|
||||
|
||||
@@ -1,95 +1,81 @@
|
||||
"""Token-based authentication and authorization dependencies for privileged API routes."""
|
||||
"""Authentication and authorization dependencies for protected API routes."""
|
||||
|
||||
import hmac
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import Settings, get_settings
|
||||
from app.db.base import get_session
|
||||
from app.models.auth import UserRole
|
||||
from app.services.authentication import resolve_auth_session
|
||||
|
||||
|
||||
bearer_auth = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
class AuthRole:
|
||||
"""Declares supported authorization roles for privileged API operations."""
|
||||
@dataclass(frozen=True)
|
||||
class AuthContext:
|
||||
"""Carries authenticated identity and role details for one request."""
|
||||
|
||||
ADMIN = "admin"
|
||||
USER = "user"
|
||||
user_id: UUID
|
||||
username: str
|
||||
role: UserRole
|
||||
session_id: UUID
|
||||
expires_at: datetime
|
||||
|
||||
|
||||
def _raise_unauthorized() -> None:
|
||||
"""Raises an HTTP 401 response with bearer authentication challenge headers."""
|
||||
"""Raises a 401 challenge response for missing or invalid bearer sessions."""
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or missing API token",
|
||||
detail="Invalid or expired authentication session",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
def _configured_admin_token(settings: Settings) -> str:
|
||||
"""Returns required admin token or raises configuration error when unset."""
|
||||
|
||||
token = settings.admin_api_token.strip()
|
||||
if token:
|
||||
return token
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Admin API token is not configured",
|
||||
)
|
||||
|
||||
|
||||
def _resolve_token_role(token: str, settings: Settings) -> str:
|
||||
"""Resolves role from a bearer token using constant-time comparisons."""
|
||||
|
||||
admin_token = _configured_admin_token(settings)
|
||||
if hmac.compare_digest(token, admin_token):
|
||||
return AuthRole.ADMIN
|
||||
|
||||
user_token = settings.user_api_token.strip()
|
||||
if user_token and hmac.compare_digest(token, user_token):
|
||||
return AuthRole.USER
|
||||
|
||||
_raise_unauthorized()
|
||||
|
||||
|
||||
def get_request_role(
|
||||
def get_request_auth_context(
|
||||
credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(bearer_auth)],
|
||||
settings: Annotated[Settings, Depends(get_settings)],
|
||||
) -> str:
|
||||
"""Authenticates request token and returns its authorization role.
|
||||
|
||||
Development environments can optionally allow tokenless user access for non-admin routes to
|
||||
preserve local workflow compatibility while production remains token-enforced.
|
||||
"""
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
) -> AuthContext:
|
||||
"""Authenticates bearer session token and returns role-bound request identity context."""
|
||||
|
||||
if credentials is None:
|
||||
if settings.allow_development_anonymous_user_access and settings.app_env.strip().lower() in {"development", "dev"}:
|
||||
return AuthRole.USER
|
||||
_raise_unauthorized()
|
||||
|
||||
token = credentials.credentials.strip()
|
||||
if not token:
|
||||
if settings.allow_development_anonymous_user_access and settings.app_env.strip().lower() in {"development", "dev"}:
|
||||
return AuthRole.USER
|
||||
_raise_unauthorized()
|
||||
return _resolve_token_role(token=token, settings=settings)
|
||||
|
||||
resolved_session = resolve_auth_session(session, token=token)
|
||||
if resolved_session is None or resolved_session.user is None:
|
||||
_raise_unauthorized()
|
||||
|
||||
return AuthContext(
|
||||
user_id=resolved_session.user.id,
|
||||
username=resolved_session.user.username,
|
||||
role=resolved_session.user.role,
|
||||
session_id=resolved_session.id,
|
||||
expires_at=resolved_session.expires_at,
|
||||
)
|
||||
|
||||
|
||||
def require_user_or_admin(role: Annotated[str, Depends(get_request_role)]) -> str:
|
||||
"""Requires a valid user or admin token and returns resolved role."""
|
||||
def require_user_or_admin(context: Annotated[AuthContext, Depends(get_request_auth_context)]) -> AuthContext:
|
||||
"""Requires any authenticated user session and returns its request identity context."""
|
||||
|
||||
return role
|
||||
return context
|
||||
|
||||
|
||||
def require_admin(role: Annotated[str, Depends(get_request_role)]) -> str:
|
||||
"""Requires admin role and rejects requests authenticated as regular users."""
|
||||
def require_admin(context: Annotated[AuthContext, Depends(get_request_auth_context)]) -> AuthContext:
|
||||
"""Requires authenticated admin role and rejects standard user sessions."""
|
||||
|
||||
if role != AuthRole.ADMIN:
|
||||
if context.role != UserRole.ADMIN:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Admin token required",
|
||||
detail="Administrator role required",
|
||||
)
|
||||
return role
|
||||
return context
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from app.api.auth import require_admin, require_user_or_admin
|
||||
from app.api.auth import require_admin
|
||||
from app.api.routes_auth import router as auth_router
|
||||
from app.api.routes_documents import router as documents_router
|
||||
from app.api.routes_health import router as health_router
|
||||
from app.api.routes_processing_logs import router as processing_logs_router
|
||||
@@ -12,11 +13,11 @@ from app.api.routes_settings import router as settings_router
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(health_router)
|
||||
api_router.include_router(auth_router)
|
||||
api_router.include_router(
|
||||
documents_router,
|
||||
prefix="/documents",
|
||||
tags=["documents"],
|
||||
dependencies=[Depends(require_user_or_admin)],
|
||||
)
|
||||
api_router.include_router(
|
||||
processing_logs_router,
|
||||
@@ -28,7 +29,6 @@ api_router.include_router(
|
||||
search_router,
|
||||
prefix="/search",
|
||||
tags=["search"],
|
||||
dependencies=[Depends(require_user_or_admin)],
|
||||
)
|
||||
api_router.include_router(
|
||||
settings_router,
|
||||
|
||||
94
backend/app/api/routes_auth.py
Normal file
94
backend/app/api/routes_auth.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""Authentication endpoints for credential login, session introspection, and logout."""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.auth import AuthContext, require_user_or_admin
|
||||
from app.db.base import get_session
|
||||
from app.schemas.auth import (
|
||||
AuthLoginRequest,
|
||||
AuthLoginResponse,
|
||||
AuthLogoutResponse,
|
||||
AuthSessionResponse,
|
||||
AuthUserResponse,
|
||||
)
|
||||
from app.services.authentication import authenticate_user, issue_user_session, revoke_auth_session
|
||||
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
def _request_ip_address(request: Request) -> str | None:
|
||||
"""Returns best-effort client IP extracted from the request transport context."""
|
||||
|
||||
return request.client.host if request.client is not None else None
|
||||
|
||||
|
||||
def _request_user_agent(request: Request) -> str | None:
|
||||
"""Returns best-effort user-agent metadata for created auth sessions."""
|
||||
|
||||
user_agent = request.headers.get("user-agent", "").strip()
|
||||
return user_agent[:512] if user_agent else None
|
||||
|
||||
|
||||
@router.post("/login", response_model=AuthLoginResponse)
|
||||
def login(
|
||||
payload: AuthLoginRequest,
|
||||
request: Request,
|
||||
session: Session = Depends(get_session),
|
||||
) -> AuthLoginResponse:
|
||||
"""Authenticates username and password and returns an issued bearer session token."""
|
||||
|
||||
user = authenticate_user(
|
||||
session,
|
||||
username=payload.username,
|
||||
password=payload.password,
|
||||
)
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid username or password",
|
||||
)
|
||||
|
||||
issued_session = issue_user_session(
|
||||
session,
|
||||
user=user,
|
||||
user_agent=_request_user_agent(request),
|
||||
ip_address=_request_ip_address(request),
|
||||
)
|
||||
session.commit()
|
||||
return AuthLoginResponse(
|
||||
access_token=issued_session.token,
|
||||
expires_at=issued_session.expires_at,
|
||||
user=AuthUserResponse.model_validate(user),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/me", response_model=AuthSessionResponse)
|
||||
def me(context: AuthContext = Depends(require_user_or_admin)) -> AuthSessionResponse:
|
||||
"""Returns current authenticated session identity and expiration metadata."""
|
||||
|
||||
return AuthSessionResponse(
|
||||
expires_at=context.expires_at,
|
||||
user=AuthUserResponse(
|
||||
id=context.user_id,
|
||||
username=context.username,
|
||||
role=context.role,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/logout", response_model=AuthLogoutResponse)
|
||||
def logout(
|
||||
context: AuthContext = Depends(require_user_or_admin),
|
||||
session: Session = Depends(get_session),
|
||||
) -> AuthLogoutResponse:
|
||||
"""Revokes current bearer session token and confirms logout state."""
|
||||
|
||||
revoked = revoke_auth_session(
|
||||
session,
|
||||
session_id=context.session_id,
|
||||
)
|
||||
if revoked:
|
||||
session.commit()
|
||||
return AuthLogoutResponse(revoked=revoked)
|
||||
@@ -1,12 +1,12 @@
|
||||
"""Authenticated document CRUD, lifecycle, metadata, file access, and content export endpoints."""
|
||||
|
||||
import io
|
||||
import re
|
||||
import tempfile
|
||||
import unicodedata
|
||||
import zipfile
|
||||
from datetime import datetime, time
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Literal
|
||||
from typing import Annotated, BinaryIO, Iterator, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, Query, UploadFile
|
||||
@@ -14,8 +14,10 @@ from fastapi.responses import FileResponse, Response, StreamingResponse
|
||||
from sqlalchemy import or_, func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.auth import AuthContext, require_user_or_admin
|
||||
from app.core.config import get_settings, is_inline_preview_mime_type_safe
|
||||
from app.db.base import get_session
|
||||
from app.models.auth import UserRole
|
||||
from app.models.document import Document, DocumentStatus
|
||||
from app.schemas.documents import (
|
||||
ContentExportRequest,
|
||||
@@ -30,6 +32,7 @@ from app.services.app_settings import read_predefined_paths_settings, read_prede
|
||||
from app.services.extractor import sniff_mime
|
||||
from app.services.handwriting_style import delete_many_handwriting_style_documents
|
||||
from app.services.processing_logs import log_processing_event, set_processing_log_autocommit
|
||||
from app.services.rate_limiter import increment_rate_limit
|
||||
from app.services.storage import absolute_path, compute_sha256, store_bytes
|
||||
from app.services.typesense_index import delete_many_documents_index, upsert_document_index
|
||||
from app.worker.queue import get_processing_queue
|
||||
@@ -39,6 +42,59 @@ router = APIRouter()
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
def _scope_document_statement_for_auth_context(statement, auth_context: AuthContext):
|
||||
"""Restricts document statements to caller-owned rows for non-admin users."""
|
||||
|
||||
if auth_context.role == UserRole.ADMIN:
|
||||
return statement
|
||||
return statement.where(Document.owner_user_id == auth_context.user_id)
|
||||
|
||||
|
||||
def _ensure_document_access(document: Document, auth_context: AuthContext) -> None:
|
||||
"""Enforces owner-level access for non-admin users and raises not-found on violations."""
|
||||
|
||||
if auth_context.role == UserRole.ADMIN:
|
||||
return
|
||||
if document.owner_user_id != auth_context.user_id:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
|
||||
def _stream_binary_file_chunks(handle: BinaryIO, *, chunk_bytes: int) -> Iterator[bytes]:
|
||||
"""Streams binary file-like content in bounded chunks and closes handle after completion."""
|
||||
|
||||
try:
|
||||
while True:
|
||||
chunk = handle.read(chunk_bytes)
|
||||
if not chunk:
|
||||
break
|
||||
yield chunk
|
||||
finally:
|
||||
handle.close()
|
||||
|
||||
|
||||
def _enforce_content_export_rate_limit(auth_context: AuthContext) -> None:
|
||||
"""Applies per-user fixed-window rate limiting for markdown export requests."""
|
||||
|
||||
try:
|
||||
current_count, limit = increment_rate_limit(
|
||||
scope="content-md-export",
|
||||
subject=str(auth_context.user_id),
|
||||
limit=settings.content_export_rate_limit_per_minute,
|
||||
window_seconds=60,
|
||||
)
|
||||
except RuntimeError as error:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Rate limiter backend unavailable",
|
||||
) from error
|
||||
|
||||
if limit > 0 and current_count > limit:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"Export rate limit exceeded ({limit} requests per minute)",
|
||||
)
|
||||
|
||||
|
||||
def _parse_csv(value: str | None) -> list[str]:
|
||||
"""Parses comma-separated query values into a normalized non-empty list."""
|
||||
|
||||
@@ -296,6 +352,7 @@ def list_documents(
|
||||
type_filter: str | None = Query(default=None),
|
||||
processed_from: str | None = Query(default=None),
|
||||
processed_to: str | None = Query(default=None),
|
||||
auth_context: AuthContext = Depends(require_user_or_admin),
|
||||
session: Session = Depends(get_session),
|
||||
) -> DocumentsListResponse:
|
||||
"""Returns paginated documents ordered by newest upload timestamp."""
|
||||
@@ -305,6 +362,7 @@ def list_documents(
|
||||
include_trashed=include_trashed,
|
||||
path_prefix=path_prefix,
|
||||
)
|
||||
base_statement = _scope_document_statement_for_auth_context(base_statement, auth_context)
|
||||
base_statement = _apply_discovery_filters(
|
||||
base_statement,
|
||||
path_filter=path_filter,
|
||||
@@ -326,11 +384,13 @@ def list_documents(
|
||||
@router.get("/tags")
|
||||
def list_tags(
|
||||
include_trashed: bool = Query(default=False),
|
||||
auth_context: AuthContext = Depends(require_user_or_admin),
|
||||
session: Session = Depends(get_session),
|
||||
) -> dict[str, list[str]]:
|
||||
"""Returns distinct tags currently assigned across all matching documents."""
|
||||
|
||||
statement = select(Document.tags)
|
||||
statement = _scope_document_statement_for_auth_context(statement, auth_context)
|
||||
if not include_trashed:
|
||||
statement = statement.where(Document.status != DocumentStatus.TRASHED)
|
||||
|
||||
@@ -348,11 +408,13 @@ def list_tags(
|
||||
@router.get("/paths")
|
||||
def list_paths(
|
||||
include_trashed: bool = Query(default=False),
|
||||
auth_context: AuthContext = Depends(require_user_or_admin),
|
||||
session: Session = Depends(get_session),
|
||||
) -> dict[str, list[str]]:
|
||||
"""Returns distinct logical paths currently assigned across all matching documents."""
|
||||
|
||||
statement = select(Document.logical_path)
|
||||
statement = _scope_document_statement_for_auth_context(statement, auth_context)
|
||||
if not include_trashed:
|
||||
statement = statement.where(Document.status != DocumentStatus.TRASHED)
|
||||
|
||||
@@ -370,11 +432,13 @@ def list_paths(
|
||||
@router.get("/types")
|
||||
def list_types(
|
||||
include_trashed: bool = Query(default=False),
|
||||
auth_context: AuthContext = Depends(require_user_or_admin),
|
||||
session: Session = Depends(get_session),
|
||||
) -> dict[str, list[str]]:
|
||||
"""Returns distinct document type values from extension, MIME, and image text type."""
|
||||
|
||||
statement = select(Document.extension, Document.mime_type, Document.image_text_type)
|
||||
statement = _scope_document_statement_for_auth_context(statement, auth_context)
|
||||
if not include_trashed:
|
||||
statement = statement.where(Document.status != DocumentStatus.TRASHED)
|
||||
rows = session.execute(statement).all()
|
||||
@@ -390,16 +454,20 @@ def list_types(
|
||||
@router.post("/content-md/export")
|
||||
def export_contents_markdown(
|
||||
payload: ContentExportRequest,
|
||||
auth_context: AuthContext = Depends(require_user_or_admin),
|
||||
session: Session = Depends(get_session),
|
||||
) -> StreamingResponse:
|
||||
"""Exports extracted contents for selected documents as individual markdown files in a ZIP archive."""
|
||||
|
||||
_enforce_content_export_rate_limit(auth_context)
|
||||
|
||||
has_document_ids = len(payload.document_ids) > 0
|
||||
has_path_prefix = bool(payload.path_prefix and payload.path_prefix.strip())
|
||||
if not has_document_ids and not has_path_prefix:
|
||||
raise HTTPException(status_code=400, detail="Provide document_ids or path_prefix for export")
|
||||
|
||||
statement = select(Document)
|
||||
statement = _scope_document_statement_for_auth_context(statement, auth_context)
|
||||
if has_document_ids:
|
||||
statement = statement.where(Document.id.in_(payload.document_ids))
|
||||
if has_path_prefix:
|
||||
@@ -409,37 +477,82 @@ def export_contents_markdown(
|
||||
elif not payload.include_trashed:
|
||||
statement = statement.where(Document.status != DocumentStatus.TRASHED)
|
||||
|
||||
documents = session.execute(statement.order_by(Document.logical_path.asc(), Document.created_at.asc())).scalars().all()
|
||||
max_documents = max(1, int(settings.content_export_max_documents))
|
||||
ordered_statement = statement.order_by(Document.logical_path.asc(), Document.created_at.asc()).limit(max_documents + 1)
|
||||
documents = session.execute(ordered_statement).scalars().all()
|
||||
if len(documents) > max_documents:
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=f"Export exceeds maximum document count ({len(documents)} > {max_documents})",
|
||||
)
|
||||
if not documents:
|
||||
raise HTTPException(status_code=404, detail="No matching documents found for export")
|
||||
|
||||
archive_buffer = io.BytesIO()
|
||||
max_total_bytes = max(1, int(settings.content_export_max_total_bytes))
|
||||
max_spool_memory = max(64 * 1024, int(settings.content_export_spool_max_memory_bytes))
|
||||
archive_file = tempfile.SpooledTemporaryFile(max_size=max_spool_memory, mode="w+b")
|
||||
total_export_bytes = 0
|
||||
used_entries: set[str] = set()
|
||||
with zipfile.ZipFile(archive_buffer, mode="w", compression=zipfile.ZIP_DEFLATED) as archive:
|
||||
for document in documents:
|
||||
entry_name = _zip_entry_name(document, used_entries)
|
||||
archive.writestr(entry_name, _markdown_for_document(document))
|
||||
try:
|
||||
with zipfile.ZipFile(archive_file, mode="w", compression=zipfile.ZIP_DEFLATED) as archive:
|
||||
for document in documents:
|
||||
markdown_bytes = _markdown_for_document(document).encode("utf-8")
|
||||
total_export_bytes += len(markdown_bytes)
|
||||
if total_export_bytes > max_total_bytes:
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=(
|
||||
"Export exceeds total markdown size limit "
|
||||
f"({total_export_bytes} > {max_total_bytes} bytes)"
|
||||
),
|
||||
)
|
||||
entry_name = _zip_entry_name(document, used_entries)
|
||||
archive.writestr(entry_name, markdown_bytes)
|
||||
archive_file.seek(0)
|
||||
except Exception:
|
||||
archive_file.close()
|
||||
raise
|
||||
|
||||
archive_buffer.seek(0)
|
||||
chunk_bytes = max(4 * 1024, int(settings.content_export_stream_chunk_bytes))
|
||||
headers = {"Content-Disposition": 'attachment; filename="document-contents-md.zip"'}
|
||||
return StreamingResponse(archive_buffer, media_type="application/zip", headers=headers)
|
||||
return StreamingResponse(
|
||||
_stream_binary_file_chunks(archive_file, chunk_bytes=chunk_bytes),
|
||||
media_type="application/zip",
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{document_id}", response_model=DocumentDetailResponse)
|
||||
def get_document(document_id: UUID, session: Session = Depends(get_session)) -> DocumentDetailResponse:
|
||||
def get_document(
|
||||
document_id: UUID,
|
||||
auth_context: AuthContext = Depends(require_user_or_admin),
|
||||
session: Session = Depends(get_session),
|
||||
) -> DocumentDetailResponse:
|
||||
"""Returns one document by unique identifier."""
|
||||
|
||||
document = session.execute(select(Document).where(Document.id == document_id)).scalar_one_or_none()
|
||||
statement = _scope_document_statement_for_auth_context(
|
||||
select(Document).where(Document.id == document_id),
|
||||
auth_context,
|
||||
)
|
||||
document = session.execute(statement).scalar_one_or_none()
|
||||
if document is None:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
return DocumentDetailResponse.model_validate(document)
|
||||
|
||||
|
||||
@router.get("/{document_id}/download")
|
||||
def download_document(document_id: UUID, session: Session = Depends(get_session)) -> FileResponse:
|
||||
def download_document(
|
||||
document_id: UUID,
|
||||
auth_context: AuthContext = Depends(require_user_or_admin),
|
||||
session: Session = Depends(get_session),
|
||||
) -> FileResponse:
|
||||
"""Downloads original document bytes for the requested document identifier."""
|
||||
|
||||
document = session.execute(select(Document).where(Document.id == document_id)).scalar_one_or_none()
|
||||
statement = _scope_document_statement_for_auth_context(
|
||||
select(Document).where(Document.id == document_id),
|
||||
auth_context,
|
||||
)
|
||||
document = session.execute(statement).scalar_one_or_none()
|
||||
if document is None:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
file_path = absolute_path(document.stored_relative_path)
|
||||
@@ -447,10 +560,18 @@ def download_document(document_id: UUID, session: Session = Depends(get_session)
|
||||
|
||||
|
||||
@router.get("/{document_id}/preview")
|
||||
def preview_document(document_id: UUID, session: Session = Depends(get_session)) -> FileResponse:
|
||||
def preview_document(
|
||||
document_id: UUID,
|
||||
auth_context: AuthContext = Depends(require_user_or_admin),
|
||||
session: Session = Depends(get_session),
|
||||
) -> FileResponse:
|
||||
"""Streams trusted-safe MIME types inline and forces attachment for active script-capable types."""
|
||||
|
||||
document = session.execute(select(Document).where(Document.id == document_id)).scalar_one_or_none()
|
||||
statement = _scope_document_statement_for_auth_context(
|
||||
select(Document).where(Document.id == document_id),
|
||||
auth_context,
|
||||
)
|
||||
document = session.execute(statement).scalar_one_or_none()
|
||||
if document is None:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
@@ -467,10 +588,18 @@ def preview_document(document_id: UUID, session: Session = Depends(get_session))
|
||||
|
||||
|
||||
@router.get("/{document_id}/thumbnail")
|
||||
def thumbnail_document(document_id: UUID, session: Session = Depends(get_session)) -> FileResponse:
|
||||
def thumbnail_document(
|
||||
document_id: UUID,
|
||||
auth_context: AuthContext = Depends(require_user_or_admin),
|
||||
session: Session = Depends(get_session),
|
||||
) -> FileResponse:
|
||||
"""Returns a generated thumbnail image for dashboard card previews."""
|
||||
|
||||
document = session.execute(select(Document).where(Document.id == document_id)).scalar_one_or_none()
|
||||
statement = _scope_document_statement_for_auth_context(
|
||||
select(Document).where(Document.id == document_id),
|
||||
auth_context,
|
||||
)
|
||||
document = session.execute(statement).scalar_one_or_none()
|
||||
if document is None:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
@@ -485,10 +614,18 @@ def thumbnail_document(document_id: UUID, session: Session = Depends(get_session
|
||||
|
||||
|
||||
@router.get("/{document_id}/content-md")
|
||||
def download_document_content_markdown(document_id: UUID, session: Session = Depends(get_session)) -> Response:
|
||||
def download_document_content_markdown(
|
||||
document_id: UUID,
|
||||
auth_context: AuthContext = Depends(require_user_or_admin),
|
||||
session: Session = Depends(get_session),
|
||||
) -> Response:
|
||||
"""Downloads extracted content for one document as a markdown file."""
|
||||
|
||||
document = session.execute(select(Document).where(Document.id == document_id)).scalar_one_or_none()
|
||||
statement = _scope_document_statement_for_auth_context(
|
||||
select(Document).where(Document.id == document_id),
|
||||
auth_context,
|
||||
)
|
||||
document = session.execute(statement).scalar_one_or_none()
|
||||
if document is None:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
@@ -505,6 +642,7 @@ async def upload_documents(
|
||||
logical_path: Annotated[str, Form()] = "Inbox",
|
||||
tags: Annotated[str | None, Form()] = None,
|
||||
conflict_mode: Annotated[Literal["ask", "replace", "duplicate"], Form()] = "ask",
|
||||
auth_context: AuthContext = Depends(require_user_or_admin),
|
||||
session: Session = Depends(get_session),
|
||||
) -> UploadResponse:
|
||||
"""Uploads files, records metadata, and enqueues asynchronous extraction tasks."""
|
||||
@@ -562,7 +700,11 @@ async def upload_documents(
|
||||
}
|
||||
)
|
||||
|
||||
existing = session.execute(select(Document).where(Document.sha256 == sha256)).scalar_one_or_none()
|
||||
existing_statement = _scope_document_statement_for_auth_context(
|
||||
select(Document).where(Document.sha256 == sha256),
|
||||
auth_context,
|
||||
)
|
||||
existing = session.execute(existing_statement).scalar_one_or_none()
|
||||
if existing and conflict_mode == "ask":
|
||||
log_processing_event(
|
||||
session=session,
|
||||
@@ -589,9 +731,11 @@ async def upload_documents(
|
||||
return UploadResponse(uploaded=[], conflicts=conflicts)
|
||||
|
||||
for prepared in prepared_uploads:
|
||||
existing = session.execute(
|
||||
select(Document).where(Document.sha256 == str(prepared["sha256"]))
|
||||
).scalar_one_or_none()
|
||||
existing_statement = _scope_document_statement_for_auth_context(
|
||||
select(Document).where(Document.sha256 == str(prepared["sha256"])),
|
||||
auth_context,
|
||||
)
|
||||
existing = session.execute(existing_statement).scalar_one_or_none()
|
||||
replaces_document_id = existing.id if existing and conflict_mode == "replace" else None
|
||||
|
||||
stored_relative_path = store_bytes(str(prepared["filename"]), bytes(prepared["data"]))
|
||||
@@ -606,6 +750,7 @@ async def upload_documents(
|
||||
size_bytes=len(bytes(prepared["data"])),
|
||||
logical_path=logical_path,
|
||||
tags=list(normalized_tags),
|
||||
owner_user_id=auth_context.user_id,
|
||||
replaces_document_id=replaces_document_id,
|
||||
metadata_json={"upload": "web"},
|
||||
)
|
||||
@@ -637,11 +782,16 @@ async def upload_documents(
|
||||
def update_document(
|
||||
document_id: UUID,
|
||||
payload: DocumentUpdateRequest,
|
||||
auth_context: AuthContext = Depends(require_user_or_admin),
|
||||
session: Session = Depends(get_session),
|
||||
) -> DocumentResponse:
|
||||
"""Updates document metadata and refreshes semantic index representation."""
|
||||
|
||||
document = session.execute(select(Document).where(Document.id == document_id)).scalar_one_or_none()
|
||||
statement = _scope_document_statement_for_auth_context(
|
||||
select(Document).where(Document.id == document_id),
|
||||
auth_context,
|
||||
)
|
||||
document = session.execute(statement).scalar_one_or_none()
|
||||
if document is None:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
@@ -663,10 +813,18 @@ def update_document(
|
||||
|
||||
|
||||
@router.post("/{document_id}/trash", response_model=DocumentResponse)
|
||||
def trash_document(document_id: UUID, session: Session = Depends(get_session)) -> DocumentResponse:
|
||||
def trash_document(
|
||||
document_id: UUID,
|
||||
auth_context: AuthContext = Depends(require_user_or_admin),
|
||||
session: Session = Depends(get_session),
|
||||
) -> DocumentResponse:
|
||||
"""Marks a document as trashed without deleting files from storage."""
|
||||
|
||||
document = session.execute(select(Document).where(Document.id == document_id)).scalar_one_or_none()
|
||||
statement = _scope_document_statement_for_auth_context(
|
||||
select(Document).where(Document.id == document_id),
|
||||
auth_context,
|
||||
)
|
||||
document = session.execute(statement).scalar_one_or_none()
|
||||
if document is None:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
@@ -687,10 +845,18 @@ def trash_document(document_id: UUID, session: Session = Depends(get_session)) -
|
||||
|
||||
|
||||
@router.post("/{document_id}/restore", response_model=DocumentResponse)
|
||||
def restore_document(document_id: UUID, session: Session = Depends(get_session)) -> DocumentResponse:
|
||||
def restore_document(
|
||||
document_id: UUID,
|
||||
auth_context: AuthContext = Depends(require_user_or_admin),
|
||||
session: Session = Depends(get_session),
|
||||
) -> DocumentResponse:
|
||||
"""Restores a trashed document to its previous lifecycle status."""
|
||||
|
||||
document = session.execute(select(Document).where(Document.id == document_id)).scalar_one_or_none()
|
||||
statement = _scope_document_statement_for_auth_context(
|
||||
select(Document).where(Document.id == document_id),
|
||||
auth_context,
|
||||
)
|
||||
document = session.execute(statement).scalar_one_or_none()
|
||||
if document is None:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
@@ -712,16 +878,27 @@ def restore_document(document_id: UUID, session: Session = Depends(get_session))
|
||||
|
||||
|
||||
@router.delete("/{document_id}")
|
||||
def delete_document(document_id: UUID, session: Session = Depends(get_session)) -> dict[str, int]:
|
||||
def delete_document(
|
||||
document_id: UUID,
|
||||
auth_context: AuthContext = Depends(require_user_or_admin),
|
||||
session: Session = Depends(get_session),
|
||||
) -> dict[str, int]:
|
||||
"""Permanently deletes a document and all descendant archive members including stored files."""
|
||||
|
||||
root = session.execute(select(Document).where(Document.id == document_id)).scalar_one_or_none()
|
||||
root_statement = _scope_document_statement_for_auth_context(
|
||||
select(Document).where(Document.id == document_id),
|
||||
auth_context,
|
||||
)
|
||||
root = session.execute(root_statement).scalar_one_or_none()
|
||||
if root is None:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
if root.status != DocumentStatus.TRASHED:
|
||||
raise HTTPException(status_code=400, detail="Move document to trash before permanent deletion")
|
||||
|
||||
document_tree = _collect_document_tree(session=session, root_document_id=document_id)
|
||||
if auth_context.role != UserRole.ADMIN:
|
||||
for _, document in document_tree:
|
||||
_ensure_document_access(document, auth_context)
|
||||
document_ids = [document.id for _, document in document_tree]
|
||||
try:
|
||||
delete_many_documents_index([str(current_id) for current_id in document_ids])
|
||||
@@ -752,10 +929,18 @@ def delete_document(document_id: UUID, session: Session = Depends(get_session))
|
||||
|
||||
|
||||
@router.post("/{document_id}/reprocess", response_model=DocumentResponse)
|
||||
def reprocess_document(document_id: UUID, session: Session = Depends(get_session)) -> DocumentResponse:
|
||||
def reprocess_document(
|
||||
document_id: UUID,
|
||||
auth_context: AuthContext = Depends(require_user_or_admin),
|
||||
session: Session = Depends(get_session),
|
||||
) -> DocumentResponse:
|
||||
"""Re-enqueues a document for extraction and suggestion processing."""
|
||||
|
||||
document = session.execute(select(Document).where(Document.id == document_id)).scalar_one_or_none()
|
||||
statement = _scope_document_statement_for_auth_context(
|
||||
select(Document).where(Document.id == document_id),
|
||||
auth_context,
|
||||
)
|
||||
document = session.execute(statement).scalar_one_or_none()
|
||||
if document is None:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
if document.status == DocumentStatus.TRASHED:
|
||||
|
||||
@@ -4,7 +4,8 @@ from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy import Text, cast, func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.routes_documents import _apply_discovery_filters
|
||||
from app.api.auth import AuthContext, require_user_or_admin
|
||||
from app.api.routes_documents import _apply_discovery_filters, _scope_document_statement_for_auth_context
|
||||
from app.db.base import get_session
|
||||
from app.models.document import Document, DocumentStatus
|
||||
from app.schemas.documents import DocumentResponse, SearchResponse
|
||||
@@ -25,6 +26,7 @@ def search_documents(
|
||||
type_filter: str | None = Query(default=None),
|
||||
processed_from: str | None = Query(default=None),
|
||||
processed_to: str | None = Query(default=None),
|
||||
auth_context: AuthContext = Depends(require_user_or_admin),
|
||||
session: Session = Depends(get_session),
|
||||
) -> SearchResponse:
|
||||
"""Searches documents using PostgreSQL full-text ranking plus metadata matching."""
|
||||
@@ -50,6 +52,7 @@ def search_documents(
|
||||
)
|
||||
|
||||
statement = select(Document).where(search_filter)
|
||||
statement = _scope_document_statement_for_auth_context(statement, auth_context)
|
||||
if only_trashed:
|
||||
statement = statement.where(Document.status == DocumentStatus.TRASHED)
|
||||
elif not include_trashed:
|
||||
@@ -67,6 +70,7 @@ def search_documents(
|
||||
items = session.execute(statement).scalars().all()
|
||||
|
||||
count_statement = select(func.count(Document.id)).where(search_filter)
|
||||
count_statement = _scope_document_statement_for_auth_context(count_statement, auth_context)
|
||||
if only_trashed:
|
||||
count_statement = count_statement.where(Document.status == DocumentStatus.TRASHED)
|
||||
elif not include_trashed:
|
||||
|
||||
@@ -21,12 +21,24 @@ class Settings(BaseSettings):
|
||||
redis_url: str = "redis://redis:6379/0"
|
||||
redis_security_mode: str = "auto"
|
||||
redis_tls_mode: str = "auto"
|
||||
allow_development_anonymous_user_access: bool = True
|
||||
auth_bootstrap_admin_username: str = "admin"
|
||||
auth_bootstrap_admin_password: str = ""
|
||||
auth_bootstrap_user_username: str = ""
|
||||
auth_bootstrap_user_password: str = ""
|
||||
auth_session_ttl_minutes: int = 720
|
||||
auth_password_pbkdf2_iterations: int = 390000
|
||||
auth_session_token_bytes: int = 32
|
||||
auth_session_pepper: str = ""
|
||||
storage_root: Path = Path("/data/storage")
|
||||
upload_chunk_size: int = 4 * 1024 * 1024
|
||||
max_upload_files_per_request: int = 50
|
||||
max_upload_file_size_bytes: int = 25 * 1024 * 1024
|
||||
max_upload_request_size_bytes: int = 100 * 1024 * 1024
|
||||
content_export_max_documents: int = 250
|
||||
content_export_max_total_bytes: int = 50 * 1024 * 1024
|
||||
content_export_rate_limit_per_minute: int = 6
|
||||
content_export_stream_chunk_bytes: int = 256 * 1024
|
||||
content_export_spool_max_memory_bytes: int = 2 * 1024 * 1024
|
||||
max_zip_members: int = 250
|
||||
max_zip_depth: int = 2
|
||||
max_zip_descendants_per_root: int = 1000
|
||||
@@ -34,8 +46,6 @@ class Settings(BaseSettings):
|
||||
max_zip_total_uncompressed_bytes: int = 150 * 1024 * 1024
|
||||
max_zip_compression_ratio: float = 120.0
|
||||
max_text_length: int = 500_000
|
||||
admin_api_token: str = ""
|
||||
user_api_token: str = ""
|
||||
provider_base_url_allowlist: list[str] = Field(default_factory=lambda: ["api.openai.com"])
|
||||
provider_base_url_allow_http: bool = False
|
||||
provider_base_url_allow_private_network: bool = False
|
||||
@@ -43,6 +53,8 @@ class Settings(BaseSettings):
|
||||
processing_log_max_unbound_entries: int = 400
|
||||
processing_log_max_payload_chars: int = 4096
|
||||
processing_log_max_text_chars: int = 12000
|
||||
processing_log_store_model_io_text: bool = False
|
||||
processing_log_store_payload_text: bool = False
|
||||
default_openai_base_url: str = "https://api.openai.com/v1"
|
||||
default_openai_model: str = "gpt-4.1-mini"
|
||||
default_openai_timeout_seconds: int = 45
|
||||
@@ -60,6 +72,7 @@ class Settings(BaseSettings):
|
||||
typesense_num_retries: int = 0
|
||||
public_base_url: str = "http://localhost:8000"
|
||||
cors_origins: list[str] = Field(default_factory=lambda: ["http://localhost:5173", "http://localhost:3000"])
|
||||
cors_allow_credentials: bool = False
|
||||
|
||||
|
||||
LOCAL_HOSTNAME_SUFFIXES = (".local", ".internal", ".home.arpa")
|
||||
|
||||
@@ -10,6 +10,7 @@ from app.api.router import api_router
|
||||
from app.core.config import get_settings
|
||||
from app.db.base import init_db
|
||||
from app.services.app_settings import ensure_app_settings
|
||||
from app.services.authentication import ensure_bootstrap_users
|
||||
from app.services.handwriting_style import ensure_handwriting_style_collection
|
||||
from app.services.storage import ensure_storage
|
||||
from app.services.typesense_index import ensure_typesense_collection
|
||||
@@ -18,7 +19,6 @@ from app.services.typesense_index import ensure_typesense_collection
|
||||
settings = get_settings()
|
||||
UPLOAD_ENDPOINT_PATH = "/api/v1/documents/upload"
|
||||
UPLOAD_ENDPOINT_METHOD = "POST"
|
||||
CORS_HTTP_ORIGIN_REGEX = r"^https?://[^/]+$"
|
||||
|
||||
|
||||
def _is_upload_size_guard_target(request: Request) -> bool:
|
||||
@@ -35,11 +35,11 @@ def create_app() -> FastAPI:
|
||||
"""Builds and configures the FastAPI application instance."""
|
||||
|
||||
app = FastAPI(title="DCM DMS API", version="0.1.0")
|
||||
allowed_origins = [origin.strip() for origin in settings.cors_origins if isinstance(origin, str) and origin.strip()]
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.cors_origins,
|
||||
allow_origin_regex=CORS_HTTP_ORIGIN_REGEX,
|
||||
allow_credentials=True,
|
||||
allow_origins=allowed_origins,
|
||||
allow_credentials=bool(getattr(settings, "cors_allow_credentials", False)),
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
@@ -82,8 +82,9 @@ def create_app() -> FastAPI:
|
||||
"""Initializes storage directories and database schema on service startup."""
|
||||
|
||||
ensure_storage()
|
||||
ensure_app_settings()
|
||||
init_db()
|
||||
ensure_bootstrap_users()
|
||||
ensure_app_settings()
|
||||
try:
|
||||
ensure_typesense_collection()
|
||||
except Exception:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Model exports for ORM metadata discovery."""
|
||||
|
||||
from app.models.auth import AppUser, AuthSession, UserRole
|
||||
from app.models.document import Document, DocumentStatus
|
||||
from app.models.processing_log import ProcessingLogEntry
|
||||
|
||||
__all__ = ["Document", "DocumentStatus", "ProcessingLogEntry"]
|
||||
__all__ = ["AppUser", "AuthSession", "Document", "DocumentStatus", "ProcessingLogEntry", "UserRole"]
|
||||
|
||||
66
backend/app/models/auth.py
Normal file
66
backend/app/models/auth.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""Data models for authenticated users and issued API sessions."""
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, Enum as SqlEnum, ForeignKey, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.db.base import Base
|
||||
|
||||
|
||||
class UserRole(str, Enum):
|
||||
"""Declares authorization roles used for API route access control."""
|
||||
|
||||
ADMIN = "admin"
|
||||
USER = "user"
|
||||
|
||||
|
||||
class AppUser(Base):
|
||||
"""Stores one authenticatable user account with role-bound authorization."""
|
||||
|
||||
__tablename__ = "app_users"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
username: Mapped[str] = mapped_column(String(128), nullable=False, unique=True, index=True)
|
||||
password_hash: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||
role: Mapped[UserRole] = mapped_column(SqlEnum(UserRole), nullable=False, default=UserRole.USER)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(UTC))
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(UTC),
|
||||
onupdate=lambda: datetime.now(UTC),
|
||||
)
|
||||
|
||||
sessions: Mapped[list["AuthSession"]] = relationship(
|
||||
"AuthSession",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
|
||||
class AuthSession(Base):
|
||||
"""Stores one issued bearer session token for a specific authenticated user."""
|
||||
|
||||
__tablename__ = "auth_sessions"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
user_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("app_users.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
token_hash: Mapped[str] = mapped_column(String(128), nullable=False, unique=True, index=True)
|
||||
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True)
|
||||
revoked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
user_agent: Mapped[str | None] = mapped_column(String(512), nullable=True)
|
||||
ip_address: Mapped[str | None] = mapped_column(String(64), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(UTC))
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(UTC),
|
||||
onupdate=lambda: datetime.now(UTC),
|
||||
)
|
||||
|
||||
user: Mapped[AppUser] = relationship("AppUser", back_populates="sessions")
|
||||
@@ -38,6 +38,12 @@ class Document(Base):
|
||||
suggested_path: Mapped[str | None] = mapped_column(String(1024), nullable=True)
|
||||
tags: Mapped[list[str]] = mapped_column(ARRAY(String), nullable=False, default=list)
|
||||
suggested_tags: Mapped[list[str]] = mapped_column(ARRAY(String), nullable=False, default=list)
|
||||
owner_user_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("app_users.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
metadata_json: Mapped[dict] = mapped_column(JSONB, nullable=False, default=dict)
|
||||
extracted_text: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
||||
image_text_type: Mapped[str | None] = mapped_column(String(64), nullable=True)
|
||||
@@ -63,3 +69,4 @@ class Document(Base):
|
||||
foreign_keys=[parent_document_id],
|
||||
post_update=True,
|
||||
)
|
||||
owner_user: Mapped["AppUser | None"] = relationship("AppUser", foreign_keys=[owner_user_id], post_update=True)
|
||||
|
||||
48
backend/app/schemas/auth.py
Normal file
48
backend/app/schemas/auth.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""Pydantic schemas for authentication and session API payloads."""
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.models.auth import UserRole
|
||||
|
||||
|
||||
class AuthLoginRequest(BaseModel):
|
||||
"""Represents credential input used to create one authenticated API session."""
|
||||
|
||||
username: str = Field(min_length=1, max_length=128)
|
||||
password: str = Field(min_length=1, max_length=256)
|
||||
|
||||
|
||||
class AuthUserResponse(BaseModel):
|
||||
"""Represents one authenticated user identity and authorization role."""
|
||||
|
||||
id: UUID
|
||||
username: str
|
||||
role: UserRole
|
||||
|
||||
class Config:
|
||||
"""Enables ORM object parsing for SQLAlchemy model instances."""
|
||||
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class AuthSessionResponse(BaseModel):
|
||||
"""Represents active session metadata for one authenticated user."""
|
||||
|
||||
user: AuthUserResponse
|
||||
expires_at: datetime
|
||||
|
||||
|
||||
class AuthLoginResponse(AuthSessionResponse):
|
||||
"""Represents one newly issued bearer token and associated user context."""
|
||||
|
||||
access_token: str
|
||||
token_type: str = "bearer"
|
||||
|
||||
|
||||
class AuthLogoutResponse(BaseModel):
|
||||
"""Represents logout outcome after current session revocation attempt."""
|
||||
|
||||
revoked: bool
|
||||
@@ -11,6 +11,14 @@ import secrets
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
try:
|
||||
from cryptography.fernet import Fernet, InvalidToken
|
||||
except Exception: # pragma: no cover - dependency failures are surfaced at runtime usage.
|
||||
Fernet = None # type: ignore[assignment]
|
||||
|
||||
class InvalidToken(Exception):
|
||||
"""Fallback InvalidToken type used when cryptography dependency import fails."""
|
||||
|
||||
from app.core.config import get_settings, normalize_and_validate_provider_base_url
|
||||
|
||||
|
||||
@@ -63,12 +71,13 @@ DEFAULT_ROUTING_PROMPT = (
|
||||
"Confidence must be between 0 and 1."
|
||||
)
|
||||
|
||||
PROVIDER_API_KEY_CIPHERTEXT_PREFIX = "enc-v1"
|
||||
PROVIDER_API_KEY_CIPHERTEXT_PREFIX = "enc-v2"
|
||||
PROVIDER_API_KEY_LEGACY_CIPHERTEXT_PREFIX = "enc-v1"
|
||||
PROVIDER_API_KEY_KEYFILE_NAME = ".settings-api-key"
|
||||
PROVIDER_API_KEY_STREAM_CONTEXT = b"dcm-provider-api-key-stream"
|
||||
PROVIDER_API_KEY_AUTH_CONTEXT = b"dcm-provider-api-key-auth"
|
||||
PROVIDER_API_KEY_NONCE_BYTES = 16
|
||||
PROVIDER_API_KEY_TAG_BYTES = 32
|
||||
PROVIDER_API_KEY_LEGACY_STREAM_CONTEXT = b"dcm-provider-api-key-stream"
|
||||
PROVIDER_API_KEY_LEGACY_AUTH_CONTEXT = b"dcm-provider-api-key-auth"
|
||||
PROVIDER_API_KEY_LEGACY_NONCE_BYTES = 16
|
||||
PROVIDER_API_KEY_LEGACY_TAG_BYTES = 32
|
||||
|
||||
|
||||
def _settings_api_key_path() -> Path:
|
||||
@@ -128,14 +137,14 @@ def _derive_provider_api_key_key() -> bytes:
|
||||
return generated
|
||||
|
||||
|
||||
def _xor_bytes(left: bytes, right: bytes) -> bytes:
|
||||
"""Applies byte-wise XOR for equal-length byte sequences."""
|
||||
def _legacy_xor_bytes(left: bytes, right: bytes) -> bytes:
|
||||
"""Applies byte-wise XOR for equal-length byte sequences used by legacy ciphertext migration."""
|
||||
|
||||
return bytes(first ^ second for first, second in zip(left, right))
|
||||
|
||||
|
||||
def _derive_stream_cipher_bytes(master_key: bytes, nonce: bytes, length: int) -> bytes:
|
||||
"""Derives deterministic stream bytes from HMAC-SHA256 blocks for payload masking."""
|
||||
def _legacy_derive_stream_cipher_bytes(master_key: bytes, nonce: bytes, length: int) -> bytes:
|
||||
"""Derives legacy deterministic stream bytes from HMAC-SHA256 blocks for migration reads."""
|
||||
|
||||
stream = bytearray()
|
||||
counter = 0
|
||||
@@ -143,7 +152,7 @@ def _derive_stream_cipher_bytes(master_key: bytes, nonce: bytes, length: int) ->
|
||||
counter_bytes = counter.to_bytes(4, "big")
|
||||
block = hmac.new(
|
||||
master_key,
|
||||
PROVIDER_API_KEY_STREAM_CONTEXT + nonce + counter_bytes,
|
||||
PROVIDER_API_KEY_LEGACY_STREAM_CONTEXT + nonce + counter_bytes,
|
||||
hashlib.sha256,
|
||||
).digest()
|
||||
stream.extend(block)
|
||||
@@ -151,6 +160,33 @@ def _derive_stream_cipher_bytes(master_key: bytes, nonce: bytes, length: int) ->
|
||||
return bytes(stream[:length])
|
||||
|
||||
|
||||
def _provider_key_fernet(master_key: bytes) -> Fernet:
|
||||
"""Builds Fernet instance from 32-byte symmetric key material."""
|
||||
|
||||
if Fernet is None:
|
||||
raise AppSettingsValidationError("cryptography dependency is not available")
|
||||
fernet_key = base64.urlsafe_b64encode(master_key[:32])
|
||||
return Fernet(fernet_key)
|
||||
|
||||
|
||||
def _encrypt_provider_api_key_fallback(value: str) -> str:
|
||||
"""Encrypts provider keys with legacy HMAC stream construction when cryptography is unavailable."""
|
||||
|
||||
plaintext = value.encode("utf-8")
|
||||
master_key = _derive_provider_api_key_key()
|
||||
nonce = secrets.token_bytes(PROVIDER_API_KEY_LEGACY_NONCE_BYTES)
|
||||
keystream = _legacy_derive_stream_cipher_bytes(master_key, nonce, len(plaintext))
|
||||
ciphertext = _legacy_xor_bytes(plaintext, keystream)
|
||||
tag = hmac.new(
|
||||
master_key,
|
||||
PROVIDER_API_KEY_LEGACY_AUTH_CONTEXT + nonce + ciphertext,
|
||||
hashlib.sha256,
|
||||
).digest()
|
||||
payload = nonce + ciphertext + tag
|
||||
encoded = _urlsafe_b64encode_no_padding(payload)
|
||||
return f"{PROVIDER_API_KEY_CIPHERTEXT_PREFIX}:{encoded}"
|
||||
|
||||
|
||||
def _encrypt_provider_api_key(value: str) -> str:
|
||||
"""Encrypts one provider API key for at-rest JSON persistence."""
|
||||
|
||||
@@ -158,19 +194,52 @@ def _encrypt_provider_api_key(value: str) -> str:
|
||||
if not normalized:
|
||||
return ""
|
||||
|
||||
plaintext = normalized.encode("utf-8")
|
||||
if Fernet is None:
|
||||
return _encrypt_provider_api_key_fallback(normalized)
|
||||
master_key = _derive_provider_api_key_key()
|
||||
nonce = secrets.token_bytes(PROVIDER_API_KEY_NONCE_BYTES)
|
||||
keystream = _derive_stream_cipher_bytes(master_key, nonce, len(plaintext))
|
||||
ciphertext = _xor_bytes(plaintext, keystream)
|
||||
tag = hmac.new(
|
||||
token = _provider_key_fernet(master_key).encrypt(normalized.encode("utf-8")).decode("ascii")
|
||||
return f"{PROVIDER_API_KEY_CIPHERTEXT_PREFIX}:{token}"
|
||||
|
||||
|
||||
def _decrypt_provider_api_key_legacy_payload(encoded_payload: str) -> str:
|
||||
"""Decrypts legacy stream-cipher payload bytes used for migration and fallback reads."""
|
||||
|
||||
if not encoded_payload:
|
||||
raise AppSettingsValidationError("Provider API key ciphertext is missing payload bytes")
|
||||
try:
|
||||
payload = _urlsafe_b64decode_no_padding(encoded_payload)
|
||||
except (binascii.Error, ValueError) as error:
|
||||
raise AppSettingsValidationError("Provider API key ciphertext is not valid base64") from error
|
||||
|
||||
minimum_length = PROVIDER_API_KEY_LEGACY_NONCE_BYTES + PROVIDER_API_KEY_LEGACY_TAG_BYTES
|
||||
if len(payload) < minimum_length:
|
||||
raise AppSettingsValidationError("Provider API key ciphertext payload is truncated")
|
||||
|
||||
nonce = payload[:PROVIDER_API_KEY_LEGACY_NONCE_BYTES]
|
||||
ciphertext = payload[PROVIDER_API_KEY_LEGACY_NONCE_BYTES:-PROVIDER_API_KEY_LEGACY_TAG_BYTES]
|
||||
received_tag = payload[-PROVIDER_API_KEY_LEGACY_TAG_BYTES:]
|
||||
master_key = _derive_provider_api_key_key()
|
||||
expected_tag = hmac.new(
|
||||
master_key,
|
||||
PROVIDER_API_KEY_AUTH_CONTEXT + nonce + ciphertext,
|
||||
PROVIDER_API_KEY_LEGACY_AUTH_CONTEXT + nonce + ciphertext,
|
||||
hashlib.sha256,
|
||||
).digest()
|
||||
payload = nonce + ciphertext + tag
|
||||
encoded = _urlsafe_b64encode_no_padding(payload)
|
||||
return f"{PROVIDER_API_KEY_CIPHERTEXT_PREFIX}:{encoded}"
|
||||
if not hmac.compare_digest(received_tag, expected_tag):
|
||||
raise AppSettingsValidationError("Provider API key ciphertext integrity check failed")
|
||||
|
||||
keystream = _legacy_derive_stream_cipher_bytes(master_key, nonce, len(ciphertext))
|
||||
plaintext = _legacy_xor_bytes(ciphertext, keystream)
|
||||
try:
|
||||
return plaintext.decode("utf-8").strip()
|
||||
except UnicodeDecodeError as error:
|
||||
raise AppSettingsValidationError("Provider API key ciphertext is not valid UTF-8") from error
|
||||
|
||||
|
||||
def _decrypt_provider_api_key_legacy(value: str) -> str:
|
||||
"""Decrypts legacy `enc-v1` payloads to support non-breaking key migration."""
|
||||
|
||||
encoded_payload = value.split(":", 1)[1]
|
||||
return _decrypt_provider_api_key_legacy_payload(encoded_payload)
|
||||
|
||||
|
||||
def _decrypt_provider_api_key(value: str) -> str:
|
||||
@@ -179,35 +248,23 @@ def _decrypt_provider_api_key(value: str) -> str:
|
||||
normalized = value.strip()
|
||||
if not normalized:
|
||||
return ""
|
||||
if not normalized.startswith(f"{PROVIDER_API_KEY_CIPHERTEXT_PREFIX}:"):
|
||||
if not normalized.startswith(f"{PROVIDER_API_KEY_CIPHERTEXT_PREFIX}:") and not normalized.startswith(
|
||||
f"{PROVIDER_API_KEY_LEGACY_CIPHERTEXT_PREFIX}:"
|
||||
):
|
||||
return normalized
|
||||
|
||||
encoded_payload = normalized.split(":", 1)[1]
|
||||
if not encoded_payload:
|
||||
if normalized.startswith(f"{PROVIDER_API_KEY_LEGACY_CIPHERTEXT_PREFIX}:"):
|
||||
return _decrypt_provider_api_key_legacy(normalized)
|
||||
|
||||
token = normalized.split(":", 1)[1].strip()
|
||||
if not token:
|
||||
raise AppSettingsValidationError("Provider API key ciphertext is missing payload bytes")
|
||||
if Fernet is None:
|
||||
return _decrypt_provider_api_key_legacy_payload(token)
|
||||
try:
|
||||
payload = _urlsafe_b64decode_no_padding(encoded_payload)
|
||||
except (binascii.Error, ValueError) as error:
|
||||
raise AppSettingsValidationError("Provider API key ciphertext is not valid base64") from error
|
||||
|
||||
minimum_length = PROVIDER_API_KEY_NONCE_BYTES + PROVIDER_API_KEY_TAG_BYTES
|
||||
if len(payload) < minimum_length:
|
||||
raise AppSettingsValidationError("Provider API key ciphertext payload is truncated")
|
||||
|
||||
nonce = payload[:PROVIDER_API_KEY_NONCE_BYTES]
|
||||
ciphertext = payload[PROVIDER_API_KEY_NONCE_BYTES:-PROVIDER_API_KEY_TAG_BYTES]
|
||||
received_tag = payload[-PROVIDER_API_KEY_TAG_BYTES:]
|
||||
master_key = _derive_provider_api_key_key()
|
||||
expected_tag = hmac.new(
|
||||
master_key,
|
||||
PROVIDER_API_KEY_AUTH_CONTEXT + nonce + ciphertext,
|
||||
hashlib.sha256,
|
||||
).digest()
|
||||
if not hmac.compare_digest(received_tag, expected_tag):
|
||||
raise AppSettingsValidationError("Provider API key ciphertext integrity check failed")
|
||||
|
||||
keystream = _derive_stream_cipher_bytes(master_key, nonce, len(ciphertext))
|
||||
plaintext = _xor_bytes(ciphertext, keystream)
|
||||
plaintext = _provider_key_fernet(_derive_provider_api_key_key()).decrypt(token.encode("ascii"))
|
||||
except (InvalidToken, ValueError, UnicodeEncodeError) as error:
|
||||
raise AppSettingsValidationError("Provider API key ciphertext integrity check failed") from error
|
||||
try:
|
||||
return plaintext.decode("utf-8").strip()
|
||||
except UnicodeDecodeError as error:
|
||||
|
||||
289
backend/app/services/authentication.py
Normal file
289
backend/app/services/authentication.py
Normal file
@@ -0,0 +1,289 @@
|
||||
"""Authentication services for user credential validation and session issuance."""
|
||||
|
||||
import base64
|
||||
import binascii
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
import hashlib
|
||||
import hmac
|
||||
import secrets
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import Settings, get_settings
|
||||
from app.db.base import SessionLocal
|
||||
from app.models.auth import AppUser, AuthSession, UserRole
|
||||
|
||||
|
||||
PASSWORD_HASH_SCHEME = "pbkdf2_sha256"
|
||||
DEFAULT_AUTH_FALLBACK_SECRET = "dcm-session-secret"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class IssuedSession:
|
||||
"""Represents one newly issued bearer session token and expiration timestamp."""
|
||||
|
||||
token: str
|
||||
expires_at: datetime
|
||||
|
||||
|
||||
def normalize_username(username: str) -> str:
|
||||
"""Normalizes usernames to a stable lowercase identity key."""
|
||||
|
||||
return username.strip().lower()
|
||||
|
||||
|
||||
def _urlsafe_b64encode_no_padding(data: bytes) -> str:
|
||||
"""Encodes bytes to compact URL-safe base64 without padding."""
|
||||
|
||||
return base64.urlsafe_b64encode(data).decode("ascii").rstrip("=")
|
||||
|
||||
|
||||
def _urlsafe_b64decode_no_padding(data: str) -> bytes:
|
||||
"""Decodes URL-safe base64 values that may omit trailing padding characters."""
|
||||
|
||||
padded = data + "=" * (-len(data) % 4)
|
||||
return base64.urlsafe_b64decode(padded.encode("ascii"))
|
||||
|
||||
|
||||
def _password_iterations(settings: Settings) -> int:
|
||||
"""Returns PBKDF2 iteration count clamped to a secure operational range."""
|
||||
|
||||
return max(200_000, min(1_200_000, int(settings.auth_password_pbkdf2_iterations)))
|
||||
|
||||
|
||||
def hash_password(password: str, settings: Settings | None = None) -> str:
|
||||
"""Derives and formats a PBKDF2-SHA256 password hash for persisted user credentials."""
|
||||
|
||||
resolved_settings = settings or get_settings()
|
||||
normalized_password = password.strip()
|
||||
if not normalized_password:
|
||||
raise ValueError("Password must not be empty")
|
||||
|
||||
iterations = _password_iterations(resolved_settings)
|
||||
salt = secrets.token_bytes(16)
|
||||
derived = hashlib.pbkdf2_hmac(
|
||||
"sha256",
|
||||
normalized_password.encode("utf-8"),
|
||||
salt,
|
||||
iterations,
|
||||
dklen=32,
|
||||
)
|
||||
return (
|
||||
f"{PASSWORD_HASH_SCHEME}$"
|
||||
f"{iterations}$"
|
||||
f"{_urlsafe_b64encode_no_padding(salt)}$"
|
||||
f"{_urlsafe_b64encode_no_padding(derived)}"
|
||||
)
|
||||
|
||||
|
||||
def verify_password(password: str, stored_hash: str, settings: Settings | None = None) -> bool:
|
||||
"""Verifies one plaintext password against persisted PBKDF2-SHA256 hash material."""
|
||||
|
||||
resolved_settings = settings or get_settings()
|
||||
normalized_password = password.strip()
|
||||
if not normalized_password:
|
||||
return False
|
||||
|
||||
parts = stored_hash.strip().split("$")
|
||||
if len(parts) != 4:
|
||||
return False
|
||||
scheme, iterations_text, salt_text, digest_text = parts
|
||||
if scheme != PASSWORD_HASH_SCHEME:
|
||||
return False
|
||||
try:
|
||||
iterations = int(iterations_text)
|
||||
except ValueError:
|
||||
return False
|
||||
if iterations < 200_000 or iterations > 2_000_000:
|
||||
return False
|
||||
try:
|
||||
salt = _urlsafe_b64decode_no_padding(salt_text)
|
||||
expected_digest = _urlsafe_b64decode_no_padding(digest_text)
|
||||
except (binascii.Error, ValueError):
|
||||
return False
|
||||
derived_digest = hashlib.pbkdf2_hmac(
|
||||
"sha256",
|
||||
normalized_password.encode("utf-8"),
|
||||
salt,
|
||||
iterations,
|
||||
dklen=len(expected_digest),
|
||||
)
|
||||
if not hmac.compare_digest(expected_digest, derived_digest):
|
||||
return False
|
||||
return iterations >= _password_iterations(resolved_settings)
|
||||
|
||||
|
||||
def _auth_session_secret(settings: Settings) -> bytes:
|
||||
"""Resolves a stable secret used to hash issued bearer session tokens."""
|
||||
|
||||
candidate = settings.auth_session_pepper.strip() or settings.app_settings_encryption_key.strip()
|
||||
if not candidate:
|
||||
candidate = DEFAULT_AUTH_FALLBACK_SECRET
|
||||
return hashlib.sha256(candidate.encode("utf-8")).digest()
|
||||
|
||||
|
||||
def _hash_session_token(token: str, settings: Settings | None = None) -> str:
|
||||
"""Derives a deterministic SHA256 token hash guarded by secret pepper material."""
|
||||
|
||||
resolved_settings = settings or get_settings()
|
||||
secret = _auth_session_secret(resolved_settings)
|
||||
digest = hmac.new(secret, token.encode("utf-8"), hashlib.sha256).hexdigest()
|
||||
return digest
|
||||
|
||||
|
||||
def _new_session_token(settings: Settings) -> str:
|
||||
"""Creates a random URL-safe bearer token for one API session."""
|
||||
|
||||
token_bytes = max(24, min(128, int(settings.auth_session_token_bytes)))
|
||||
return secrets.token_urlsafe(token_bytes)
|
||||
|
||||
|
||||
def _resolve_optional_user_credentials(username: str, password: str) -> tuple[str, str] | None:
|
||||
"""Returns optional user credentials only when both username and password are configured."""
|
||||
|
||||
normalized_username = normalize_username(username)
|
||||
normalized_password = password.strip()
|
||||
if not normalized_username and not normalized_password:
|
||||
return None
|
||||
if not normalized_username or not normalized_password:
|
||||
raise ValueError("Optional bootstrap user requires both username and password")
|
||||
return normalized_username, normalized_password
|
||||
|
||||
|
||||
def _upsert_bootstrap_user(session: Session, *, username: str, password: str, role: UserRole) -> AppUser:
|
||||
"""Creates or updates one bootstrap account with deterministic role assignment."""
|
||||
|
||||
existing = session.execute(select(AppUser).where(AppUser.username == username)).scalar_one_or_none()
|
||||
password_hash = hash_password(password)
|
||||
if existing is None:
|
||||
user = AppUser(
|
||||
username=username,
|
||||
password_hash=password_hash,
|
||||
role=role,
|
||||
is_active=True,
|
||||
)
|
||||
session.add(user)
|
||||
return user
|
||||
|
||||
existing.password_hash = password_hash
|
||||
existing.role = role
|
||||
existing.is_active = True
|
||||
return existing
|
||||
|
||||
|
||||
def ensure_bootstrap_users() -> None:
    """Create or refresh bootstrap accounts from environment credentials.

    The admin account is mandatory; the regular user account is optional and
    only provisioned when both its username and password are configured.

    Raises:
        RuntimeError: When admin credentials are missing, or the optional
            user shares the admin username.
    """
    settings = get_settings()
    admin_username = normalize_username(settings.auth_bootstrap_admin_username)
    admin_password = settings.auth_bootstrap_admin_password.strip()
    if not admin_username:
        raise RuntimeError("AUTH_BOOTSTRAP_ADMIN_USERNAME must not be empty")
    if not admin_password:
        raise RuntimeError("AUTH_BOOTSTRAP_ADMIN_PASSWORD must not be empty")

    user_credentials = _resolve_optional_user_credentials(
        username=settings.auth_bootstrap_user_username,
        password=settings.auth_bootstrap_user_password,
    )

    with SessionLocal() as session:
        _upsert_bootstrap_user(
            session,
            username=admin_username,
            password=admin_password,
            role=UserRole.ADMIN,
        )
        if user_credentials is not None:
            user_username, user_password = user_credentials
            # Admin and user bootstrap accounts must be distinct records.
            if user_username == admin_username:
                raise RuntimeError("AUTH_BOOTSTRAP_USER_USERNAME must differ from admin username")
            _upsert_bootstrap_user(
                session,
                username=user_username,
                password=user_password,
                role=UserRole.USER,
            )
        session.commit()
|
||||
|
||||
|
||||
def authenticate_user(session: Session, *, username: str, password: str) -> AppUser | None:
    """Validate one username/password pair against stored accounts.

    Returns the matching active account on success and ``None`` for every
    failure mode (unknown user, disabled account, wrong password) so callers
    cannot distinguish — and therefore cannot leak — the failure reason.
    """
    candidate = normalize_username(username)
    if not candidate:
        return None
    account = session.execute(select(AppUser).where(AppUser.username == candidate)).scalar_one_or_none()
    if account is None or not account.is_active:
        return None
    return account if verify_password(password, account.password_hash) else None
|
||||
|
||||
|
||||
def issue_user_session(
    session: Session,
    *,
    user: AppUser,
    user_agent: str | None = None,
    ip_address: str | None = None,
) -> IssuedSession:
    """Issue a fresh bearer-token session for an already-authenticated user.

    The configured TTL is clamped to [5 minutes, 7 days]. Only the HMAC hash
    of the token is persisted; the plaintext token is returned to the caller
    exactly once. Expired sessions of the same user are purged as a side
    effect. The caller is responsible for committing.
    """
    settings = get_settings()
    now = datetime.now(UTC)
    ttl_minutes = min(7 * 24 * 60, max(5, int(settings.auth_session_ttl_minutes)))
    expires_at = now + timedelta(minutes=ttl_minutes)
    token = _new_session_token(settings)
    token_hash = _hash_session_token(token, settings)

    # Opportunistic cleanup: drop this user's already-expired sessions.
    session.execute(
        delete(AuthSession).where(
            AuthSession.user_id == user.id,
            AuthSession.expires_at <= now,
        )
    )
    record = AuthSession(
        user_id=user.id,
        token_hash=token_hash,
        expires_at=expires_at,
        user_agent=(user_agent or "").strip()[:512] or None,
        ip_address=(ip_address or "").strip()[:64] or None,
    )
    session.add(record)
    return IssuedSession(token=token, expires_at=expires_at)
|
||||
|
||||
|
||||
def resolve_auth_session(session: Session, *, token: str) -> AuthSession | None:
    """Resolve a bearer token to its live session, or ``None``.

    A session is live when it is neither revoked nor expired and belongs to
    an active user. Lookup is by HMAC hash, so plaintext tokens are never
    stored in (or compared against) the database.
    """
    candidate = token.strip()
    if not candidate:
        return None
    hashed = _hash_session_token(candidate)
    now = datetime.now(UTC)
    record = session.execute(
        select(AuthSession).where(
            AuthSession.token_hash == hashed,
            AuthSession.revoked_at.is_(None),
            AuthSession.expires_at > now,
        )
    ).scalar_one_or_none()
    if record is None or record.user is None or not record.user.is_active:
        return None
    return record
|
||||
|
||||
|
||||
def revoke_auth_session(session: Session, *, session_id: uuid.UUID) -> bool:
    """Mark one session revoked by identifier.

    Returns ``True`` when an un-revoked session was found and stamped with a
    revocation time, and ``False`` when the id is unknown or the session was
    already revoked (idempotent for repeated calls). The caller commits.
    """
    record = session.execute(select(AuthSession).where(AuthSession.id == session_id)).scalar_one_or_none()
    if record is None or record.revoked_at is not None:
        return False
    record.revoked_at = datetime.now(UTC)
    return True
|
||||
@@ -6,10 +6,13 @@ from uuid import UUID
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.models.document import Document
|
||||
from app.models.processing_log import ProcessingLogEntry
|
||||
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
MAX_STAGE_LENGTH = 64
|
||||
MAX_EVENT_LENGTH = 256
|
||||
MAX_LEVEL_LENGTH = 16
|
||||
@@ -37,9 +40,49 @@ def _trim(value: str | None, max_length: int) -> str | None:
|
||||
|
||||
|
||||
def _safe_payload(payload_json: dict[str, Any] | None) -> dict[str, Any]:
    """Normalize a payload for persistence, defaulting to metadata-only.

    Non-dict inputs collapse to an empty dict. The raw payload is only kept
    verbatim when PROCESSING_LOG_STORE_PAYLOAD_TEXT is enabled; otherwise it
    is reduced to non-sensitive metadata descriptors so free text never
    reaches the processing log table.
    """
    if not isinstance(payload_json, dict):
        return {}
    if settings.processing_log_store_payload_text:
        return payload_json
    return _metadata_only_payload(payload_json)
|
||||
|
||||
|
||||
def _metadata_only_payload(payload_json: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Converts payload content into metadata descriptors without persisting raw text values."""
|
||||
|
||||
metadata: dict[str, Any] = {}
|
||||
for index, (raw_key, raw_value) in enumerate(payload_json.items()):
|
||||
if index >= 80:
|
||||
break
|
||||
key = str(raw_key)
|
||||
metadata[key] = _metadata_only_payload_value(raw_value)
|
||||
return metadata
|
||||
|
||||
|
||||
def _metadata_only_payload_value(value: Any) -> Any:
|
||||
"""Converts one payload value into non-sensitive metadata representation."""
|
||||
|
||||
if isinstance(value, dict):
|
||||
return _metadata_only_payload(value)
|
||||
if isinstance(value, (list, tuple)):
|
||||
items = list(value)
|
||||
return {
|
||||
"item_count": len(items),
|
||||
"items_preview": [_metadata_only_payload_value(item) for item in items[:20]],
|
||||
}
|
||||
if isinstance(value, str):
|
||||
normalized = value.strip()
|
||||
return {
|
||||
"text_chars": len(normalized),
|
||||
"text_omitted": bool(normalized),
|
||||
}
|
||||
if isinstance(value, bytes):
|
||||
return {"binary_bytes": len(value)}
|
||||
if isinstance(value, (int, float, bool)) or value is None:
|
||||
return value
|
||||
return {"value_type": type(value).__name__}
|
||||
|
||||
|
||||
def set_processing_log_autocommit(session: Session, enabled: bool) -> None:
|
||||
@@ -82,8 +125,8 @@ def log_processing_event(
|
||||
document_filename=_trim(resolved_document_filename, MAX_DOCUMENT_FILENAME_LENGTH),
|
||||
provider_id=_trim(provider_id, MAX_PROVIDER_LENGTH),
|
||||
model_name=_trim(model_name, MAX_MODEL_LENGTH),
|
||||
prompt_text=_trim(prompt_text, MAX_PROMPT_LENGTH),
|
||||
response_text=_trim(response_text, MAX_RESPONSE_LENGTH),
|
||||
prompt_text=_trim(prompt_text, MAX_PROMPT_LENGTH) if settings.processing_log_store_model_io_text else None,
|
||||
response_text=_trim(response_text, MAX_RESPONSE_LENGTH) if settings.processing_log_store_model_io_text else None,
|
||||
payload_json=_safe_payload(payload_json),
|
||||
)
|
||||
session.add(entry)
|
||||
|
||||
42
backend/app/services/rate_limiter.py
Normal file
42
backend/app/services/rate_limiter.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""Redis-backed fixed-window rate limiter helpers for sensitive API operations."""
|
||||
|
||||
import time
|
||||
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from app.worker.queue import get_redis
|
||||
|
||||
|
||||
def _rate_limit_key(*, scope: str, subject: str, window_id: int) -> str:
|
||||
"""Builds a stable Redis key for one scope, subject, and fixed time window."""
|
||||
|
||||
return f"dcm:rate-limit:{scope}:{subject}:{window_id}"
|
||||
|
||||
|
||||
def increment_rate_limit(
    *,
    scope: str,
    subject: str,
    limit: int,
    window_seconds: int = 60,
) -> tuple[int, int]:
    """Increment one fixed-window rate bucket and report its usage.

    Returns ``(current_count, effective_limit)``. A non-positive limit
    disables the scope entirely and short-circuits to ``(0, 0)`` without
    touching Redis.

    Raises:
        RuntimeError: When the Redis backend is unreachable.
    """
    effective_limit = max(0, int(limit))
    if not effective_limit:
        return (0, 0)

    window = max(1, int(window_seconds))
    window_id = int(time.time() // window)
    bucket_key = _rate_limit_key(scope=scope, subject=subject, window_id=window_id)

    client = get_redis()
    try:
        pipe = client.pipeline(transaction=True)
        pipe.incr(bucket_key, 1)
        # Small grace period so the key comfortably outlives its window.
        pipe.expire(bucket_key, window + 5)
        current, _ = pipe.execute()
    except RedisError as error:
        raise RuntimeError("Rate limiter backend unavailable") from error

    return (int(current), effective_limit)
|
||||
26
backend/app/worker/run_worker.py
Normal file
26
backend/app/worker/run_worker.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""Worker entrypoint that enforces Redis URL security checks before queue consumption."""
|
||||
|
||||
from redis import Redis
|
||||
from rq import Worker
|
||||
|
||||
from app.core.config import get_settings, validate_redis_url_security
|
||||
|
||||
|
||||
def _build_worker_connection() -> Redis:
    """Build the validated Redis connection used by the RQ worker runtime."""
    # Reject insecure Redis URLs before any queue traffic starts.
    secure_url = validate_redis_url_security(get_settings().redis_url)
    return Redis.from_url(secure_url)
|
||||
|
||||
|
||||
def run_worker() -> None:
    """Run the RQ worker loop for the configured DCM processing queue."""
    worker = Worker(["dcm"], connection=_build_worker_connection())
    worker.work()
|
||||
|
||||
|
||||
# Allow launching the worker directly, e.g. `python -m app.worker.run_worker`.
if __name__ == "__main__":
    run_worker()
|
||||
@@ -143,6 +143,7 @@ def _create_archive_member_document(
|
||||
size_bytes=len(member_data),
|
||||
logical_path=parent.logical_path,
|
||||
tags=list(parent.tags),
|
||||
owner_user_id=parent.owner_user_id,
|
||||
metadata_json={
|
||||
"origin": "archive",
|
||||
"parent": str(parent.id),
|
||||
|
||||
@@ -16,3 +16,4 @@ orjson==3.11.3
|
||||
openai==1.107.2
|
||||
typesense==1.1.1
|
||||
tiktoken==0.11.0
|
||||
cryptography==46.0.1
|
||||
|
||||
@@ -272,10 +272,10 @@ if "app.services.routing_pipeline" not in sys.modules:
|
||||
sys.modules["app.services.routing_pipeline"] = routing_pipeline_stub
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi.security import HTTPAuthorizationCredentials
|
||||
|
||||
from app.api.auth import AuthRole, get_request_role, require_admin
|
||||
from app.api.auth import AuthContext, require_admin
|
||||
from app.core import config as config_module
|
||||
from app.models.auth import UserRole
|
||||
from app.models.processing_log import sanitize_processing_log_payload_value, sanitize_processing_log_text
|
||||
from app.schemas.processing_logs import ProcessingLogEntryResponse
|
||||
from app.services import extractor as extractor_module
|
||||
@@ -298,52 +298,34 @@ def _security_settings(
|
||||
|
||||
|
||||
class AuthDependencyTests(unittest.TestCase):
    """Verifies role-based admin authorization behavior."""

    def test_require_admin_rejects_user_role(self) -> None:
        """User role cannot access admin-only endpoints."""

        auth_context = AuthContext(
            user_id=uuid.uuid4(),
            username="user",
            role=UserRole.USER,
            session_id=uuid.uuid4(),
            expires_at=datetime.now(UTC),
        )
        with self.assertRaises(HTTPException) as raised:
            require_admin(context=auth_context)
        # Authenticated but unauthorized callers get 403, not 401.
        self.assertEqual(raised.exception.status_code, 403)

    def test_require_admin_accepts_admin_role(self) -> None:
        """Admin role is accepted for admin-only endpoints."""

        auth_context = AuthContext(
            user_id=uuid.uuid4(),
            username="admin",
            role=UserRole.ADMIN,
            session_id=uuid.uuid4(),
            expires_at=datetime.now(UTC),
        )
        resolved = require_admin(context=auth_context)
        # The dependency returns the same context for downstream handlers.
        self.assertEqual(resolved.role, UserRole.ADMIN)
||||
|
||||
|
||||
class ProviderBaseUrlValidationTests(unittest.TestCase):
|
||||
@@ -559,6 +541,7 @@ class ArchiveLineagePropagationTests(unittest.TestCase):
|
||||
source_relative_path="uploads/root.zip",
|
||||
logical_path="Inbox",
|
||||
tags=["finance"],
|
||||
owner_user_id=uuid.uuid4(),
|
||||
)
|
||||
|
||||
with (
|
||||
@@ -578,6 +561,7 @@ class ArchiveLineagePropagationTests(unittest.TestCase):
|
||||
self.assertEqual(child.metadata_json.get(worker_tasks_module.ARCHIVE_ROOT_ID_METADATA_KEY), str(parent_id))
|
||||
self.assertEqual(child.metadata_json.get(worker_tasks_module.ARCHIVE_DEPTH_METADATA_KEY), 1)
|
||||
self.assertTrue(child.is_archive_member)
|
||||
self.assertEqual(child.owner_user_id, parent.owner_user_id)
|
||||
|
||||
def test_resolve_archive_lineage_prefers_existing_metadata(self) -> None:
|
||||
"""Existing archive lineage metadata is reused without traversing parent relationships."""
|
||||
|
||||
Reference in New Issue
Block a user