Harden auth, redaction, upload size checks, and compose token requirements

This commit is contained in:
2026-02-21 13:48:55 -03:00
parent 5792586a90
commit 3cbad053cc
21 changed files with 1168 additions and 85 deletions

View File

@@ -2,6 +2,17 @@ APP_ENV=development
DATABASE_URL=postgresql+psycopg://dcm:dcm@db:5432/dcm
REDIS_URL=redis://redis:6379/0
STORAGE_ROOT=/data/storage
ADMIN_API_TOKEN=replace-with-random-admin-token
USER_API_TOKEN=replace-with-random-user-token
MAX_UPLOAD_FILES_PER_REQUEST=50
MAX_UPLOAD_FILE_SIZE_BYTES=26214400
MAX_UPLOAD_REQUEST_SIZE_BYTES=104857600
MAX_ZIP_MEMBER_UNCOMPRESSED_BYTES=26214400
MAX_ZIP_TOTAL_UNCOMPRESSED_BYTES=157286400
MAX_ZIP_COMPRESSION_RATIO=120
PROVIDER_BASE_URL_ALLOWLIST=["api.openai.com"]
PROVIDER_BASE_URL_ALLOW_HTTP=false
PROVIDER_BASE_URL_ALLOW_PRIVATE_NETWORK=false
DEFAULT_OPENAI_BASE_URL=https://api.openai.com/v1
DEFAULT_OPENAI_MODEL=gpt-4.1-mini
DEFAULT_OPENAI_TIMEOUT_SECONDS=45

View File

@@ -12,6 +12,11 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
COPY requirements.txt /app/requirements.txt
RUN pip install --no-cache-dir -r /app/requirements.txt
COPY app /app/app
RUN addgroup --system appgroup && adduser --system --ingroup appgroup --uid 10001 appuser
RUN mkdir -p /data/storage && chown -R appuser:appgroup /app /data
COPY --chown=appuser:appgroup app /app/app
USER appuser
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]

87
backend/app/api/auth.py Normal file
View File

@@ -0,0 +1,87 @@
"""Token-based authentication and authorization dependencies for privileged API routes."""
import hmac
from typing import Annotated
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from app.core.config import Settings, get_settings
bearer_auth = HTTPBearer(auto_error=False)
class AuthRole:
    """Declares supported authorization roles for privileged API operations."""
    # Role resolved from the admin token; accepted by both require_admin
    # and require_user_or_admin.
    ADMIN = "admin"
    # Role resolved from the user token; rejected by require_admin.
    USER = "user"
def _raise_unauthorized() -> None:
    """Aborts the request with HTTP 401 and a bearer-scheme challenge header."""
    challenge = {"WWW-Authenticate": "Bearer"}
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid or missing API token",
        headers=challenge,
    )
def _configured_admin_token(settings: Settings) -> str:
    """Returns the configured admin token, failing with HTTP 503 when unset."""
    configured = settings.admin_api_token.strip()
    if not configured:
        # A blank token means the deployment is misconfigured; refuse service
        # rather than authenticating anyone.
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Admin API token is not configured",
        )
    return configured
def _resolve_token_role(token: str, settings: Settings) -> str:
    """Resolves the authorization role for a bearer token.

    Comparisons use hmac.compare_digest over UTF-8 bytes: comparing str
    operands raises TypeError when either side contains non-ASCII
    characters, which would turn a malformed client token into an
    unhandled 500 instead of a clean 401.

    Raises:
        HTTPException: 401 when the token matches neither configured token;
            503 (via _configured_admin_token) when no admin token is set.
    """
    token_bytes = token.encode("utf-8")
    admin_token = _configured_admin_token(settings)
    if hmac.compare_digest(token_bytes, admin_token.encode("utf-8")):
        return AuthRole.ADMIN
    user_token = settings.user_api_token.strip()
    if user_token and hmac.compare_digest(token_bytes, user_token.encode("utf-8")):
        return AuthRole.USER
    _raise_unauthorized()
    # _raise_unauthorized always raises; this line only satisfies the
    # declared -> str contract for static analyzers.
    raise AssertionError("unreachable")
def get_request_role(
    credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(bearer_auth)],
    settings: Annotated[Settings, Depends(get_settings)],
) -> str:
    """Authenticates the request's bearer token and returns its role."""
    # Missing credentials and a blank token are both treated as unauthorized.
    token = "" if credentials is None else credentials.credentials.strip()
    if not token:
        _raise_unauthorized()
    return _resolve_token_role(token=token, settings=settings)
def require_user_or_admin(role: Annotated[str, Depends(get_request_role)]) -> str:
    """Permits any authenticated caller, passing the resolved role through."""
    return role
def require_admin(role: Annotated[str, Depends(get_request_role)]) -> str:
    """Permits only admin-authenticated callers; everyone else gets HTTP 403."""
    if role == AuthRole.ADMIN:
        return role
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Admin token required",
    )

View File

@@ -1,7 +1,8 @@
"""API router registration for all HTTP route modules."""
from fastapi import APIRouter
from fastapi import APIRouter, Depends
from app.api.auth import require_admin, require_user_or_admin
from app.api.routes_documents import router as documents_router
from app.api.routes_health import router as health_router
from app.api.routes_processing_logs import router as processing_logs_router
@@ -11,7 +12,27 @@ from app.api.routes_settings import router as settings_router
api_router = APIRouter()
api_router.include_router(health_router)
api_router.include_router(documents_router, prefix="/documents", tags=["documents"])
api_router.include_router(processing_logs_router, prefix="/processing/logs", tags=["processing-logs"])
api_router.include_router(search_router, prefix="/search", tags=["search"])
api_router.include_router(settings_router, prefix="/settings", tags=["settings"])
api_router.include_router(
documents_router,
prefix="/documents",
tags=["documents"],
dependencies=[Depends(require_user_or_admin)],
)
api_router.include_router(
processing_logs_router,
prefix="/processing/logs",
tags=["processing-logs"],
dependencies=[Depends(require_admin)],
)
api_router.include_router(
search_router,
prefix="/search",
tags=["search"],
dependencies=[Depends(require_user_or_admin)],
)
api_router.include_router(
settings_router,
prefix="/settings",
tags=["settings"],
dependencies=[Depends(require_admin)],
)

View File

@@ -1,4 +1,4 @@
"""Document CRUD, lifecycle, metadata, file access, and content export endpoints."""
"""Authenticated document CRUD, lifecycle, metadata, file access, and content export endpoints."""
import io
import re
@@ -14,7 +14,7 @@ from fastapi.responses import FileResponse, Response, StreamingResponse
from sqlalchemy import or_, func, select
from sqlalchemy.orm import Session
from app.services.app_settings import read_predefined_paths_settings, read_predefined_tags_settings
from app.core.config import get_settings
from app.db.base import get_session
from app.models.document import Document, DocumentStatus
from app.schemas.documents import (
@@ -26,6 +26,7 @@ from app.schemas.documents import (
UploadConflict,
UploadResponse,
)
from app.services.app_settings import read_predefined_paths_settings, read_predefined_tags_settings
from app.services.extractor import sniff_mime
from app.services.handwriting_style import delete_many_handwriting_style_documents
from app.services.processing_logs import log_processing_event, set_processing_log_autocommit
@@ -35,6 +36,7 @@ from app.worker.queue import get_processing_queue
router = APIRouter()
settings = get_settings()
def _parse_csv(value: str | None) -> list[str]:
@@ -227,6 +229,33 @@ def _build_document_list_statement(
return statement
def _enforce_upload_shape(files: list[UploadFile]) -> None:
    """Rejects empty upload requests and requests above the file-count limit."""
    if not files:
        raise HTTPException(status_code=400, detail="Upload request must include at least one file")
    limit = settings.max_upload_files_per_request
    if len(files) > limit:
        detail = (
            "Upload request exceeds file count limit "
            f"({len(files)} > {limit})"
        )
        raise HTTPException(status_code=413, detail=detail)
async def _read_upload_bytes(file: UploadFile, max_bytes: int) -> bytes:
    """Reads one upload file, failing with HTTP 413 once it exceeds max_bytes."""
    # Request one byte past the limit so oversized files are detected without
    # buffering the entire remaining stream.
    payload = await file.read(max_bytes + 1)
    if len(payload) <= max_bytes:
        return payload
    raise HTTPException(
        status_code=413,
        detail=f"File '{file.filename or 'upload'}' exceeds per-file limit of {max_bytes} bytes",
    )
def _collect_document_tree(session: Session, root_document_id: UUID) -> list[tuple[int, Document]]:
"""Collects a document and all descendants for recursive permanent deletion."""
@@ -472,18 +501,29 @@ async def upload_documents(
) -> UploadResponse:
"""Uploads files, records metadata, and enqueues asynchronous extraction tasks."""
_enforce_upload_shape(files)
set_processing_log_autocommit(session, True)
normalized_tags = _normalize_tags(tags)
queue = get_processing_queue()
uploaded: list[DocumentResponse] = []
conflicts: list[UploadConflict] = []
total_request_bytes = 0
indexed_relative_paths = relative_paths or []
prepared_uploads: list[dict[str, object]] = []
for idx, file in enumerate(files):
filename = file.filename or f"uploaded_{idx}"
data = await file.read()
data = await _read_upload_bytes(file, settings.max_upload_file_size_bytes)
total_request_bytes += len(data)
if total_request_bytes > settings.max_upload_request_size_bytes:
raise HTTPException(
status_code=413,
detail=(
"Upload request exceeds total size limit "
f"({total_request_bytes} > {settings.max_upload_request_size_bytes} bytes)"
),
)
sha256 = compute_sha256(data)
source_relative_path = indexed_relative_paths[idx] if idx < len(indexed_relative_paths) else filename
extension = Path(filename).suffix.lower()

View File

@@ -1,10 +1,11 @@
"""Read-only API endpoints for processing pipeline event logs."""
"""Admin-only API endpoints for processing pipeline event logs."""
from uuid import UUID
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from app.core.config import get_settings
from app.db.base import get_session
from app.schemas.processing_logs import ProcessingLogEntryResponse, ProcessingLogListResponse
from app.services.app_settings import read_processing_log_retention_settings
@@ -17,12 +18,13 @@ from app.services.processing_logs import (
router = APIRouter()
settings = get_settings()
@router.get("", response_model=ProcessingLogListResponse)
def get_processing_logs(
offset: int = Query(default=0, ge=0),
limit: int = Query(default=120, ge=1, le=400),
limit: int = Query(default=120, ge=1, le=settings.processing_log_max_unbound_entries),
document_id: UUID | None = Query(default=None),
session: Session = Depends(get_session),
) -> ProcessingLogListResponse:
@@ -43,8 +45,8 @@ def get_processing_logs(
@router.post("/trim")
def trim_processing_logs(
keep_document_sessions: int | None = Query(default=None, ge=0, le=20),
keep_unbound_entries: int | None = Query(default=None, ge=0, le=400),
keep_document_sessions: int | None = Query(default=None, ge=0, le=settings.processing_log_max_document_sessions),
keep_unbound_entries: int | None = Query(default=None, ge=0, le=settings.processing_log_max_unbound_entries),
session: Session = Depends(get_session),
) -> dict[str, int]:
"""Deletes old processing logs using query values or persisted retention defaults."""
@@ -61,10 +63,19 @@ def trim_processing_logs(
else int(retention_defaults.get("keep_unbound_entries", 80))
)
capped_keep_document_sessions = min(
settings.processing_log_max_document_sessions,
max(0, int(resolved_keep_document_sessions)),
)
capped_keep_unbound_entries = min(
settings.processing_log_max_unbound_entries,
max(0, int(resolved_keep_unbound_entries)),
)
result = cleanup_processing_logs(
session=session,
keep_document_sessions=resolved_keep_document_sessions,
keep_unbound_entries=resolved_keep_unbound_entries,
keep_document_sessions=capped_keep_document_sessions,
keep_unbound_entries=capped_keep_unbound_entries,
)
session.commit()
return result

View File

@@ -1,6 +1,6 @@
"""API routes for managing persistent single-user application settings."""
"""Admin-only API routes for managing persistent single-user application settings."""
from fastapi import APIRouter
from fastapi import APIRouter, HTTPException
from app.schemas.settings import (
AppSettingsUpdateRequest,
@@ -18,6 +18,7 @@ from app.schemas.settings import (
UploadDefaultsResponse,
)
from app.services.app_settings import (
AppSettingsValidationError,
TASK_OCR_HANDWRITING,
TASK_ROUTING_CLASSIFICATION,
TASK_SUMMARY_GENERATION,
@@ -179,16 +180,19 @@ def set_app_settings(payload: AppSettingsUpdateRequest) -> AppSettingsResponse:
if payload.predefined_tags is not None:
predefined_tags_payload = [item.model_dump(exclude_none=True) for item in payload.predefined_tags]
updated = update_app_settings(
providers=providers_payload,
tasks=tasks_payload,
upload_defaults=upload_defaults_payload,
display=display_payload,
processing_log_retention=processing_log_retention_payload,
handwriting_style=handwriting_style_payload,
predefined_paths=predefined_paths_payload,
predefined_tags=predefined_tags_payload,
)
try:
updated = update_app_settings(
providers=providers_payload,
tasks=tasks_payload,
upload_defaults=upload_defaults_payload,
display=display_payload,
processing_log_retention=processing_log_retention_payload,
handwriting_style=handwriting_style_payload,
predefined_paths=predefined_paths_payload,
predefined_tags=predefined_tags_payload,
)
except AppSettingsValidationError as error:
raise HTTPException(status_code=400, detail=str(error)) from error
return _build_response(updated)
@@ -203,14 +207,17 @@ def reset_settings_to_defaults() -> AppSettingsResponse:
def set_handwriting_settings(payload: HandwritingSettingsUpdateRequest) -> AppSettingsResponse:
"""Updates handwriting transcription settings and returns the resulting configuration."""
updated = update_handwriting_settings(
enabled=payload.enabled,
openai_base_url=payload.openai_base_url,
openai_model=payload.openai_model,
openai_timeout_seconds=payload.openai_timeout_seconds,
openai_api_key=payload.openai_api_key,
clear_openai_api_key=payload.clear_openai_api_key,
)
try:
updated = update_handwriting_settings(
enabled=payload.enabled,
openai_base_url=payload.openai_base_url,
openai_model=payload.openai_model,
openai_timeout_seconds=payload.openai_timeout_seconds,
openai_api_key=payload.openai_api_key,
clear_openai_api_key=payload.clear_openai_api_key,
)
except AppSettingsValidationError as error:
raise HTTPException(status_code=400, detail=str(error)) from error
return _build_response(updated)

View File

@@ -1,7 +1,10 @@
"""Application settings and environment configuration."""
from functools import lru_cache
import ipaddress
from pathlib import Path
import socket
from urllib.parse import urlparse, urlunparse
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
@@ -18,9 +21,24 @@ class Settings(BaseSettings):
redis_url: str = "redis://redis:6379/0"
storage_root: Path = Path("/data/storage")
upload_chunk_size: int = 4 * 1024 * 1024
max_upload_files_per_request: int = 50
max_upload_file_size_bytes: int = 25 * 1024 * 1024
max_upload_request_size_bytes: int = 100 * 1024 * 1024
max_zip_members: int = 250
max_zip_depth: int = 2
max_zip_member_uncompressed_bytes: int = 25 * 1024 * 1024
max_zip_total_uncompressed_bytes: int = 150 * 1024 * 1024
max_zip_compression_ratio: float = 120.0
max_text_length: int = 500_000
admin_api_token: str = ""
user_api_token: str = ""
provider_base_url_allowlist: list[str] = Field(default_factory=lambda: ["api.openai.com"])
provider_base_url_allow_http: bool = False
provider_base_url_allow_private_network: bool = False
processing_log_max_document_sessions: int = 20
processing_log_max_unbound_entries: int = 400
processing_log_max_payload_chars: int = 4096
processing_log_max_text_chars: int = 12000
default_openai_base_url: str = "https://api.openai.com/v1"
default_openai_model: str = "gpt-4.1-mini"
default_openai_timeout_seconds: int = 45
@@ -39,6 +57,187 @@ class Settings(BaseSettings):
cors_origins: list[str] = Field(default_factory=lambda: ["http://localhost:5173", "http://localhost:3000"])
LOCAL_HOSTNAME_SUFFIXES = (".local", ".internal", ".home.arpa")
def _normalize_allowlist(allowlist: object) -> tuple[str, ...]:
"""Normalizes host allowlist entries to lowercase DNS labels."""
if not isinstance(allowlist, (list, tuple, set)):
return ()
normalized = {
candidate.strip().lower().rstrip(".")
for candidate in allowlist
if isinstance(candidate, str) and candidate.strip()
}
return tuple(sorted(normalized))
def _host_matches_allowlist(hostname: str, allowlist: tuple[str, ...]) -> bool:
"""Returns whether a hostname is included by an exact or subdomain allowlist rule."""
if not allowlist:
return False
candidate = hostname.lower().rstrip(".")
for allowed_host in allowlist:
if candidate == allowed_host or candidate.endswith(f".{allowed_host}"):
return True
return False
def _is_private_or_special_ip(value: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
"""Returns whether an IP belongs to private, loopback, link-local, or reserved ranges."""
return (
value.is_private
or value.is_loopback
or value.is_link_local
or value.is_multicast
or value.is_reserved
or value.is_unspecified
)
def _validate_resolved_host_ips(hostname: str, port: int, allow_private_network: bool) -> None:
    """Resolves a hostname via DNS and rejects private or special addresses.

    Raises ValueError when resolution fails, yields no usable IPs, or (with
    private networking disallowed) any resolved IP is private/special.
    """
    try:
        entries = socket.getaddrinfo(hostname, port, type=socket.SOCK_STREAM)
    except socket.gaierror as error:
        raise ValueError(f"Provider base URL host cannot be resolved: {hostname}") from error
    resolved: set[ipaddress.IPv4Address | ipaddress.IPv6Address] = set()
    for entry in entries:
        sockaddr = entry[4]
        if not sockaddr:
            continue
        try:
            resolved.add(ipaddress.ip_address(sockaddr[0]))
        except ValueError:
            # Non-IP sockaddr text (should not occur for INET families); skip it.
            continue
    if not resolved:
        raise ValueError(f"Provider base URL host resolved without usable IP addresses: {hostname}")
    if allow_private_network:
        return
    blocked = [ip for ip in resolved if _is_private_or_special_ip(ip)]
    if blocked:
        blocked_text = ", ".join(str(ip) for ip in blocked)
        raise ValueError(f"Provider base URL resolves to private or special IP addresses: {blocked_text}")
def _normalize_and_validate_provider_base_url(
    raw_value: str,
    allowlist: tuple[str, ...],
    allow_http: bool,
    allow_private_network: bool,
    resolve_dns: bool,
) -> str:
    """Normalizes and validates provider base URLs with SSRF-safe scheme and host checks.

    Args:
        raw_value: Candidate base URL as supplied by configuration or an operator.
        allowlist: Normalized lowercase hostnames; empty disables the allowlist check.
        allow_http: Permits plain-http URLs when True.
        allow_private_network: Permits local/private/special targets when True.
        resolve_dns: When True, resolves non-IP hostnames and vets every resolved address.

    Returns:
        The normalized URL, always ending in a "/v1" path segment, with no
        query, fragment, or credentials.

    Raises:
        ValueError: On any failed validation rule.
    """
    trimmed = raw_value.strip().rstrip("/")
    if not trimmed:
        raise ValueError("Provider base URL must not be empty")
    parsed = urlparse(trimmed)
    scheme = parsed.scheme.lower()
    if scheme not in {"http", "https"}:
        raise ValueError("Provider base URL must use http or https")
    if scheme == "http" and not allow_http:
        raise ValueError("Provider base URL must use https")
    # Query strings, fragments, and userinfo have no legitimate use in a base
    # URL and are common smuggling vectors, so reject them outright.
    if parsed.query or parsed.fragment:
        raise ValueError("Provider base URL must not include query strings or fragments")
    if parsed.username or parsed.password:
        raise ValueError("Provider base URL must not include embedded credentials")
    # urlparse strips IPv6 brackets from .hostname; trailing dots are removed
    # so "host." and "host" compare equal.
    hostname = (parsed.hostname or "").lower().rstrip(".")
    if not hostname:
        raise ValueError("Provider base URL must include a hostname")
    if allowlist and not _host_matches_allowlist(hostname, allowlist):
        allowed_hosts = ", ".join(allowlist)
        raise ValueError(f"Provider base URL host is not in allowlist: {hostname}. Allowed hosts: {allowed_hosts}")
    # Block well-known local-only hostnames regardless of DNS resolution.
    if hostname == "localhost" or hostname.endswith(LOCAL_HOSTNAME_SUFFIXES):
        if not allow_private_network:
            raise ValueError("Provider base URL must not target local or internal hostnames")
    try:
        ip_host = ipaddress.ip_address(hostname)
    except ValueError:
        # Hostname is not an IP literal; DNS-based checks apply instead.
        ip_host = None
    if ip_host is not None:
        if not allow_private_network and _is_private_or_special_ip(ip_host):
            raise ValueError("Provider base URL must not target private or special IP addresses")
    elif resolve_dns:
        resolved_port = parsed.port
        if resolved_port is None:
            # Default port follows the scheme when none is given explicitly.
            resolved_port = 443 if scheme == "https" else 80
        _validate_resolved_host_ips(
            hostname=hostname,
            port=resolved_port,
            allow_private_network=allow_private_network,
        )
    # Canonical form always ends in "/v1" for OpenAI-compatible servers.
    path = (parsed.path or "").rstrip("/")
    if not path.endswith("/v1"):
        path = f"{path}/v1" if path else "/v1"
    normalized_hostname = hostname
    # Re-bracket bare IPv6 literals (urlparse removed the brackets above).
    if ":" in normalized_hostname and not normalized_hostname.startswith("["):
        normalized_hostname = f"[{normalized_hostname}]"
    netloc = f"{normalized_hostname}:{parsed.port}" if parsed.port is not None else normalized_hostname
    return urlunparse((scheme, netloc, path, "", "", ""))
@lru_cache(maxsize=256)
def _normalize_and_validate_provider_base_url_cached(
    raw_value: str,
    allowlist: tuple[str, ...],
    allow_http: bool,
    allow_private_network: bool,
) -> str:
    """Memoized URL validation for the non-DNS path; all arguments are hashable."""
    return _normalize_and_validate_provider_base_url(
        raw_value,
        allowlist,
        allow_http,
        allow_private_network,
        resolve_dns=False,
    )
def normalize_and_validate_provider_base_url(raw_value: str, *, resolve_dns: bool = False) -> str:
    """Validates and normalizes a provider base URL under the configured SSRF policy."""
    settings = get_settings()
    allowlist = _normalize_allowlist(settings.provider_base_url_allowlist)
    # Settings values may be overridden with non-bool types; fail closed.
    raw_allow_http = settings.provider_base_url_allow_http
    allow_http = raw_allow_http if isinstance(raw_allow_http, bool) else False
    raw_allow_private = settings.provider_base_url_allow_private_network
    allow_private_network = raw_allow_private if isinstance(raw_allow_private, bool) else False
    if not resolve_dns:
        # DNS-free validation is deterministic, so the cached variant applies.
        return _normalize_and_validate_provider_base_url_cached(
            raw_value=raw_value,
            allowlist=allowlist,
            allow_http=allow_http,
            allow_private_network=allow_private_network,
        )
    return _normalize_and_validate_provider_base_url(
        raw_value=raw_value,
        allowlist=allowlist,
        allow_http=allow_http,
        allow_private_network=allow_private_network,
        resolve_dns=True,
    )
@lru_cache(maxsize=1)
def get_settings() -> Settings:
"""Returns a cached settings object for dependency injection and service access."""

View File

@@ -1,7 +1,8 @@
"""FastAPI entrypoint for the DMS backend service."""
from fastapi import FastAPI
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from app.api.router import api_router
from app.core.config import get_settings
@@ -28,6 +29,35 @@ def create_app() -> FastAPI:
)
app.include_router(api_router, prefix="/api/v1")
@app.middleware("http")
async def enforce_upload_request_size(request: Request, call_next):
    """Pre-screens document uploads by declared Content-Length before body streaming.

    Responses: 411 when the header is absent, 400 when it is malformed or
    non-positive, 413 when the declared size exceeds the configured limit.
    Non-upload paths pass through untouched.
    """
    if request.url.path.endswith("/api/v1/documents/upload"):
        content_length = request.headers.get("content-length", "").strip()
        if not content_length:
            return JSONResponse(
                status_code=411,
                content={"detail": "Content-Length header is required for document uploads"},
            )
        # int() alone is too lenient: it accepts "+5", "1_0", unicode digits,
        # and surrounding whitespace. RFC 9110 defines Content-Length as a
        # run of ASCII digits, so require exactly that before converting.
        if not (content_length.isascii() and content_length.isdigit()):
            return JSONResponse(status_code=400, content={"detail": "Invalid Content-Length header"})
        content_length_value = int(content_length)
        if content_length_value <= 0:
            return JSONResponse(status_code=400, content={"detail": "Content-Length must be a positive integer"})
        if content_length_value > settings.max_upload_request_size_bytes:
            return JSONResponse(
                status_code=413,
                content={
                    "detail": (
                        "Upload request exceeds total size limit "
                        f"({content_length_value} > {settings.max_upload_request_size_bytes} bytes)"
                    )
                },
            )
    return await call_next(request)
@app.on_event("startup")
def startup_event() -> None:
"""Initializes storage directories and database schema on service startup."""

View File

@@ -2,14 +2,118 @@
import uuid
from datetime import UTC, datetime
import re
from typing import Any
from sqlalchemy import BigInteger, DateTime, ForeignKey, String, Text
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.orm import Mapped, mapped_column, validates
from app.core.config import get_settings
from app.db.base import Base
settings = get_settings()
# Substrings that mark a payload key as credential-bearing; values under
# matching keys are replaced wholesale by the payload sanitizer.
SENSITIVE_KEY_MARKERS = (
    "api_key",
    "apikey",
    "authorization",
    "bearer",
    "token",
    "secret",
    "password",
    "credential",
    "cookie",
)
# Regexes that redact token-shaped substrings from free text, in order:
# 1) "Authorization: Bearer <token>" header fragments
# 2) bare "Bearer <token>" fragments
# 3) three dot-separated base64url-ish segments (JWT-shaped strings)
# 4) "sk-..." style API keys
# 5) key/token/secret/password assignments ("name = value", "name: 'value'")
SENSITIVE_TEXT_PATTERNS = (
    re.compile(r"(?i)\bauthorization\b\s*[:=]\s*bearer\s+[a-z0-9._~+/\-]+=*"),
    re.compile(r"(?i)\bbearer\s+[a-z0-9._~+/\-]+=*"),
    re.compile(r"\b[a-z0-9_-]{8,}\.[a-z0-9_-]{8,}\.[a-z0-9_-]{8,}\b", flags=re.IGNORECASE),
    re.compile(r"(?i)\bsk-[a-z0-9]{16,}\b"),
    re.compile(r"(?i)\b(api[_-]?key|token|secret|password)\b\s*[:=]\s*['\"]?[^\s,'\";]+['\"]?"),
)
# Replacement text substituted for every redacted span.
REDACTED_TEXT = "[REDACTED]"
# Structural caps on sanitized payloads, independent of per-string char limits.
MAX_PAYLOAD_KEYS = 80
MAX_PAYLOAD_LIST_ITEMS = 80
def _truncate(value: str, limit: int) -> str:
"""Truncates long log fields to configured bounds with stable suffix marker."""
normalized = value.strip()
if len(normalized) <= limit:
return normalized
return normalized[: max(0, limit - 3)] + "..."
def _is_sensitive_key(key: str) -> bool:
    """Returns True when a payload key name suggests credential content."""
    lowered = key.strip().lower()
    for marker in SENSITIVE_KEY_MARKERS:
        if marker in lowered:
            return True
    return False
def _redact_sensitive_text(value: str) -> str:
    """Replaces token-like substrings with the redaction marker, keeping other text."""
    result = value
    for pattern in SENSITIVE_TEXT_PATTERNS:
        # Callable replacement avoids any backreference interpretation.
        result = pattern.sub(lambda match: REDACTED_TEXT, result)
    return result
def sanitize_processing_log_payload_value(value: Any, *, parent_key: str | None = None) -> Any:
    """Sanitizes payload structures by redacting sensitive fields and bounding size.

    Recursively walks dicts/lists/tuples: values under credential-like keys are
    replaced with REDACTED_TEXT, dicts are capped at MAX_PAYLOAD_KEYS entries,
    sequences at MAX_PAYLOAD_LIST_ITEMS items, and strings are redacted then
    truncated to the configured payload character limit.

    Args:
        value: Arbitrary payload fragment (scalar, string, dict, list, tuple, or other).
        parent_key: Key under which this value is stored, used for sensitivity checks.

    Returns:
        A JSON-safe sanitized equivalent (tuples become lists; unknown objects
        become redacted, truncated strings).
    """
    # A sensitive parent key redacts the entire value, whatever its type.
    if parent_key and _is_sensitive_key(parent_key):
        return REDACTED_TEXT
    if isinstance(value, dict):
        sanitized: dict[str, Any] = {}
        for index, (raw_key, raw_value) in enumerate(value.items()):
            if index >= MAX_PAYLOAD_KEYS:
                break
            key = str(raw_key)
            sanitized[key] = sanitize_processing_log_payload_value(raw_value, parent_key=key)
        return sanitized
    if isinstance(value, list):
        # List items inherit the parent key so sensitivity propagates through nesting.
        return [
            sanitize_processing_log_payload_value(item, parent_key=parent_key)
            for item in value[:MAX_PAYLOAD_LIST_ITEMS]
        ]
    if isinstance(value, tuple):
        # Tuples normalize to lists for JSON serialization.
        return [
            sanitize_processing_log_payload_value(item, parent_key=parent_key)
            for item in list(value)[:MAX_PAYLOAD_LIST_ITEMS]
        ]
    if isinstance(value, str):
        redacted = _redact_sensitive_text(value)
        return _truncate(redacted, settings.processing_log_max_payload_chars)
    if isinstance(value, (int, float, bool)) or value is None:
        return value
    # Any other object is stringified, bounded, then redacted as free text.
    as_text = _truncate(str(value), settings.processing_log_max_payload_chars)
    return _redact_sensitive_text(as_text)
def sanitize_processing_log_text(value: str | None) -> str | None:
    """Redacts credentials from prompt/response text and clamps its length.

    Returns None for missing or whitespace-only input.
    """
    if value is None:
        return None
    stripped = value.strip()
    if not stripped:
        return None
    redacted = _redact_sensitive_text(stripped)
    return _truncate(redacted, settings.processing_log_max_text_chars)
class ProcessingLogEntry(Base):
"""Stores a timestamped processing event with optional model prompt and response text."""
@@ -31,3 +135,17 @@ class ProcessingLogEntry(Base):
prompt_text: Mapped[str | None] = mapped_column(Text, nullable=True)
response_text: Mapped[str | None] = mapped_column(Text, nullable=True)
payload_json: Mapped[dict] = mapped_column(JSONB, nullable=False, default=dict)
@validates("prompt_text", "response_text")
def _validate_text_fields(self, key: str, value: str | None) -> str | None:
    """Applies redaction and length clamping to free-text columns on assignment."""
    # `key` is required by the SQLAlchemy validates signature but unused here.
    return sanitize_processing_log_text(value)
@validates("payload_json")
def _validate_payload_json(self, key: str, value: dict[str, Any] | None) -> dict[str, Any]:
    """Sanitizes structured payloads on assignment; non-dict input collapses to {}."""
    if isinstance(value, dict):
        return sanitize_processing_log_payload_value(value)
    return {}

View File

@@ -1,13 +1,16 @@
"""Pydantic schemas for processing pipeline log API payloads."""
from datetime import datetime
from typing import Any
from uuid import UUID
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from app.models.processing_log import sanitize_processing_log_payload_value, sanitize_processing_log_text
class ProcessingLogEntryResponse(BaseModel):
"""Represents one persisted processing log event returned by API endpoints."""
"""Represents one persisted processing log event with already-redacted sensitive fields."""
id: int
created_at: datetime
@@ -20,7 +23,26 @@ class ProcessingLogEntryResponse(BaseModel):
model_name: str | None
prompt_text: str | None
response_text: str | None
payload_json: dict
payload_json: dict[str, Any]
@field_validator("prompt_text", "response_text", mode="before")
@classmethod
def _sanitize_text_fields(cls, value: Any) -> str | None:
    """Re-applies text redaction so API responses never expose raw secrets."""
    return None if value is None else sanitize_processing_log_text(str(value))
@field_validator("payload_json", mode="before")
@classmethod
def _sanitize_payload_field(cls, value: Any) -> dict[str, Any]:
    """Re-applies payload redaction; anything non-dict becomes an empty dict."""
    if not isinstance(value, dict):
        return {}
    result = sanitize_processing_log_payload_value(value)
    if isinstance(result, dict):
        return result
    return {}
class Config:
"""Enables ORM object parsing for SQLAlchemy model instances."""

View File

@@ -5,12 +5,16 @@ import re
from pathlib import Path
from typing import Any
from app.core.config import get_settings
from app.core.config import get_settings, normalize_and_validate_provider_base_url
settings = get_settings()
class AppSettingsValidationError(ValueError):
    """Raised when user-provided settings values fail security or contract validation."""
    # Subclasses ValueError so callers that already catch ValueError keep working.
TASK_OCR_HANDWRITING = "ocr_handwriting"
TASK_SUMMARY_GENERATION = "summary_generation"
TASK_ROUTING_CLASSIFICATION = "routing_classification"
@@ -156,13 +160,13 @@ def _clamp_cards_per_page(value: int) -> int:
def _clamp_processing_log_document_sessions(value: int) -> int:
"""Clamps the number of recent document log sessions kept during cleanup."""
return max(0, min(20, value))
return max(0, min(settings.processing_log_max_document_sessions, value))
def _clamp_processing_log_unbound_entries(value: int) -> int:
"""Clamps retained unbound processing log events kept during cleanup."""
return max(0, min(400, value))
return max(0, min(settings.processing_log_max_unbound_entries, value))
def _clamp_predefined_entries_limit(value: int) -> int:
@@ -242,12 +246,19 @@ def _normalize_provider(
api_key_value = payload.get("api_key", fallback_values.get("api_key", defaults["api_key"]))
api_key = str(api_key_value).strip() if api_key_value is not None else ""
raw_base_url = str(payload.get("base_url", fallback_values.get("base_url", defaults["base_url"]))).strip()
if not raw_base_url:
raw_base_url = str(defaults["base_url"]).strip()
try:
normalized_base_url = normalize_and_validate_provider_base_url(raw_base_url)
except ValueError as error:
raise AppSettingsValidationError(str(error)) from error
return {
"id": provider_id,
"label": str(payload.get("label", fallback_values.get("label", provider_id))).strip() or provider_id,
"provider_type": provider_type,
"base_url": str(payload.get("base_url", fallback_values.get("base_url", defaults["base_url"]))).strip()
or defaults["base_url"],
"base_url": normalized_base_url,
"timeout_seconds": _clamp_timeout(
_safe_int(
payload.get("timeout_seconds", fallback_values.get("timeout_seconds", defaults["timeout_seconds"])),

View File

@@ -300,16 +300,39 @@ def extract_text_content(filename: str, data: bytes, mime_type: str) -> Extracti
def extract_archive_members(data: bytes, depth: int = 0) -> list[ArchiveMember]:
"""Extracts processable members from zip archives with configurable depth limits."""
"""Extracts processable ZIP members within configured decompression safety budgets."""
members: list[ArchiveMember] = []
if depth > settings.max_zip_depth:
return members
with zipfile.ZipFile(io.BytesIO(data)) as archive:
infos = [info for info in archive.infolist() if not info.is_dir()][: settings.max_zip_members]
for info in infos:
member_data = archive.read(info.filename)
members.append(ArchiveMember(name=info.filename, data=member_data))
total_uncompressed_bytes = 0
try:
with zipfile.ZipFile(io.BytesIO(data)) as archive:
infos = [info for info in archive.infolist() if not info.is_dir()][: settings.max_zip_members]
for info in infos:
if info.file_size <= 0:
continue
if info.file_size > settings.max_zip_member_uncompressed_bytes:
continue
if total_uncompressed_bytes + info.file_size > settings.max_zip_total_uncompressed_bytes:
continue
compressed_size = max(1, int(info.compress_size))
compression_ratio = float(info.file_size) / float(compressed_size)
if compression_ratio > settings.max_zip_compression_ratio:
continue
with archive.open(info, mode="r") as archive_member:
member_data = archive_member.read(settings.max_zip_member_uncompressed_bytes + 1)
if len(member_data) > settings.max_zip_member_uncompressed_bytes:
continue
if total_uncompressed_bytes + len(member_data) > settings.max_zip_total_uncompressed_bytes:
continue
total_uncompressed_bytes += len(member_data)
members.append(ArchiveMember(name=info.filename, data=member_data))
except zipfile.BadZipFile:
return []
return members

View File

@@ -2,10 +2,10 @@
from dataclasses import dataclass
from typing import Any
from urllib.parse import urlparse, urlunparse
from openai import APIConnectionError, APIError, APITimeoutError, OpenAI
from app.core.config import normalize_and_validate_provider_base_url
from app.services.app_settings import read_task_runtime_settings
@@ -36,18 +36,9 @@ class ModelTaskRuntime:
def _normalize_base_url(raw_value: str) -> str:
    """Normalizes provider base URL and enforces SSRF protections before outbound calls.

    Delegates to the shared validator, which applies the configured host
    allowlist, scheme checks, private-network blocking, and DNS-resolution
    (rebind) protection.

    Args:
        raw_value: Provider base URL as stored in settings.

    Returns:
        The normalized, validated base URL (``/v1`` suffix applied by the validator).

    Raises:
        ValueError: If the URL fails allowlist, scheme, or network validation.
    """
    return normalize_and_validate_provider_base_url(raw_value, resolve_dns=True)
def _should_fallback_to_chat(error: Exception) -> bool:
@@ -137,11 +128,16 @@ def resolve_task_runtime(task_name: str) -> ModelTaskRuntime:
if provider_type != "openai_compatible":
raise ModelTaskError(f"unsupported_provider_type:{provider_type}")
try:
normalized_base_url = _normalize_base_url(str(provider_payload.get("base_url", "https://api.openai.com/v1")))
except ValueError as error:
raise ModelTaskError(f"invalid_provider_base_url:{error}") from error
return ModelTaskRuntime(
task_name=task_name,
provider_id=str(provider_payload.get("id", "")),
provider_type=provider_type,
base_url=_normalize_base_url(str(provider_payload.get("base_url", "https://api.openai.com/v1"))),
base_url=normalized_base_url,
timeout_seconds=int(provider_payload.get("timeout_seconds", 45)),
api_key=str(provider_payload.get("api_key", "")).strip() or "no-key-required",
model=str(task_payload.get("model", "")).strip(),

View File

@@ -0,0 +1,273 @@
"""Unit coverage for API auth, SSRF validation, and processing-log redaction controls."""
from __future__ import annotations
from datetime import UTC, datetime
import socket
import sys
from pathlib import Path
from types import ModuleType, SimpleNamespace
import unittest
from unittest.mock import patch
BACKEND_ROOT = Path(__file__).resolve().parents[1]
if str(BACKEND_ROOT) not in sys.path:
sys.path.insert(0, str(BACKEND_ROOT))
if "pydantic_settings" not in sys.modules:
pydantic_settings_stub = ModuleType("pydantic_settings")
class _BaseSettings:
"""Minimal BaseSettings replacement for dependency-light unit test execution."""
def __init__(self, **kwargs: object) -> None:
for key, value in kwargs.items():
setattr(self, key, value)
def _settings_config_dict(**kwargs: object) -> dict[str, object]:
"""Returns configuration values using dict semantics expected by settings module."""
return kwargs
pydantic_settings_stub.BaseSettings = _BaseSettings
pydantic_settings_stub.SettingsConfigDict = _settings_config_dict
sys.modules["pydantic_settings"] = pydantic_settings_stub
if "fastapi" not in sys.modules:
fastapi_stub = ModuleType("fastapi")
class _HTTPException(Exception):
"""Minimal HTTPException compatible with route dependency tests."""
def __init__(self, status_code: int, detail: str, headers: dict[str, str] | None = None) -> None:
super().__init__(detail)
self.status_code = status_code
self.detail = detail
self.headers = headers or {}
class _Status:
"""Minimal status namespace for auth unit tests."""
HTTP_401_UNAUTHORIZED = 401
HTTP_403_FORBIDDEN = 403
HTTP_503_SERVICE_UNAVAILABLE = 503
def _depends(dependency): # type: ignore[no-untyped-def]
"""Returns provided dependency unchanged for unit testing."""
return dependency
fastapi_stub.Depends = _depends
fastapi_stub.HTTPException = _HTTPException
fastapi_stub.status = _Status()
sys.modules["fastapi"] = fastapi_stub
if "fastapi.security" not in sys.modules:
fastapi_security_stub = ModuleType("fastapi.security")
class _HTTPAuthorizationCredentials:
"""Minimal bearer credential object used by auth dependency tests."""
def __init__(self, *, scheme: str, credentials: str) -> None:
self.scheme = scheme
self.credentials = credentials
class _HTTPBearer:
"""Minimal HTTPBearer stand-in for dependency construction."""
def __init__(self, auto_error: bool = True) -> None:
self.auto_error = auto_error
fastapi_security_stub.HTTPAuthorizationCredentials = _HTTPAuthorizationCredentials
fastapi_security_stub.HTTPBearer = _HTTPBearer
sys.modules["fastapi.security"] = fastapi_security_stub
from fastapi import HTTPException
from fastapi.security import HTTPAuthorizationCredentials
from app.api.auth import AuthRole, get_request_role, require_admin
from app.core import config as config_module
from app.models.processing_log import sanitize_processing_log_payload_value, sanitize_processing_log_text
from app.schemas.processing_logs import ProcessingLogEntryResponse
def _security_settings(
*,
allowlist: list[str] | None = None,
allow_http: bool = False,
allow_private_network: bool = False,
) -> SimpleNamespace:
"""Builds lightweight settings object for provider URL validation tests."""
return SimpleNamespace(
provider_base_url_allowlist=allowlist if allowlist is not None else ["api.openai.com"],
provider_base_url_allow_http=allow_http,
provider_base_url_allow_private_network=allow_private_network,
)
class AuthDependencyTests(unittest.TestCase):
    """Verifies token authentication and admin authorization behavior."""

    def test_get_request_role_accepts_admin_token(self) -> None:
        """Admin token resolves admin role."""
        settings = SimpleNamespace(admin_api_token="admin-token", user_api_token="user-token")
        credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials="admin-token")
        role = get_request_role(credentials=credentials, settings=settings)
        self.assertEqual(role, AuthRole.ADMIN)

    def test_get_request_role_rejects_missing_credentials(self) -> None:
        """Missing bearer credentials return 401."""
        settings = SimpleNamespace(admin_api_token="admin-token", user_api_token="user-token")
        # credentials=None models a request with no Authorization header.
        with self.assertRaises(HTTPException) as context:
            get_request_role(credentials=None, settings=settings)
        self.assertEqual(context.exception.status_code, 401)

    def test_require_admin_rejects_user_role(self) -> None:
        """User role cannot access admin-only endpoints."""
        with self.assertRaises(HTTPException) as context:
            require_admin(role=AuthRole.USER)
        self.assertEqual(context.exception.status_code, 403)
class ProviderBaseUrlValidationTests(unittest.TestCase):
    """Verifies allowlist, scheme, and private-network SSRF protections."""

    def setUp(self) -> None:
        """Clears URL validation cache to keep tests independent."""
        config_module._normalize_and_validate_provider_base_url_cached.cache_clear()

    def test_validation_accepts_allowlisted_https_url(self) -> None:
        """Allowlisted HTTPS URLs are normalized with /v1 suffix."""
        with patch.object(config_module, "get_settings", return_value=_security_settings(allowlist=["api.openai.com"])):
            normalized = config_module.normalize_and_validate_provider_base_url("https://api.openai.com")
        self.assertEqual(normalized, "https://api.openai.com/v1")

    def test_validation_rejects_non_allowlisted_host(self) -> None:
        """Hosts outside configured allowlist are rejected."""
        with patch.object(config_module, "get_settings", return_value=_security_settings(allowlist=["api.openai.com"])):
            with self.assertRaises(ValueError):
                config_module.normalize_and_validate_provider_base_url("https://example.org/v1")

    def test_validation_rejects_private_ip_literal(self) -> None:
        """Private and loopback IP literals are blocked."""
        # Empty allowlist so the IP check (not the allowlist) is what rejects.
        with patch.object(config_module, "get_settings", return_value=_security_settings(allowlist=[])):
            with self.assertRaises(ValueError):
                config_module.normalize_and_validate_provider_base_url("https://127.0.0.1/v1")

    def test_validation_rejects_private_ip_after_dns_resolution(self) -> None:
        """DNS rebind protection blocks public hostnames resolving to private addresses."""
        # getaddrinfo tuple shape: (family, type, proto, canonname, sockaddr).
        mocked_dns_response = [
            (socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", ("127.0.0.1", 443)),
        ]
        with (
            patch.object(config_module, "get_settings", return_value=_security_settings(allowlist=["api.openai.com"])),
            patch.object(config_module.socket, "getaddrinfo", return_value=mocked_dns_response),
        ):
            with self.assertRaises(ValueError):
                config_module.normalize_and_validate_provider_base_url(
                    "https://api.openai.com/v1",
                    resolve_dns=True,
                )

    def test_resolve_dns_validation_revalidates_each_call(self) -> None:
        """DNS-resolved validation is not cached and re-checks host resolution each call."""
        mocked_dns_response = [
            (socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", ("8.8.8.8", 443)),
        ]
        with (
            patch.object(config_module, "get_settings", return_value=_security_settings(allowlist=["api.openai.com"])),
            patch.object(config_module.socket, "getaddrinfo", return_value=mocked_dns_response) as getaddrinfo_mock,
        ):
            first = config_module.normalize_and_validate_provider_base_url(
                "https://api.openai.com/v1",
                resolve_dns=True,
            )
            second = config_module.normalize_and_validate_provider_base_url(
                "https://api.openai.com/v1",
                resolve_dns=True,
            )
        self.assertEqual(first, "https://api.openai.com/v1")
        self.assertEqual(second, "https://api.openai.com/v1")
        # Two calls imply no caching on the resolve_dns path.
        self.assertEqual(getaddrinfo_mock.call_count, 2)
class ProcessingLogRedactionTests(unittest.TestCase):
    """Verifies sensitive processing-log values are redacted for persistence and responses."""

    def test_payload_redacts_sensitive_keys(self) -> None:
        """Sensitive payload keys are replaced with redaction marker."""
        sanitized = sanitize_processing_log_payload_value(
            {
                "api_key": "secret-value",
                "nested": {
                    "authorization": "Bearer sample-token",
                },
            }
        )
        # Redaction must apply recursively, not only at the top level.
        self.assertEqual(sanitized["api_key"], "[REDACTED]")
        self.assertEqual(sanitized["nested"]["authorization"], "[REDACTED]")

    def test_text_redaction_removes_bearer_and_jwt_values(self) -> None:
        """Bearer and JWT token substrings are fully removed from log text."""
        bearer_token = "super-secret-token-123"
        # Structurally valid JWT (header.payload.signature) with dummy content.
        jwt_token = (
            "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9."
            "eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4ifQ."
            "signaturevalue123456789"
        )
        sanitized = sanitize_processing_log_text(
            f"Authorization: Bearer {bearer_token}\nraw_jwt={jwt_token}"
        )
        self.assertIsNotNone(sanitized)
        sanitized_text = sanitized or ""
        self.assertIn("[REDACTED]", sanitized_text)
        self.assertNotIn(bearer_token, sanitized_text)
        self.assertNotIn(jwt_token, sanitized_text)

    def test_response_schema_applies_redaction_to_existing_entries(self) -> None:
        """API schema validators redact sensitive fields from legacy stored rows."""
        bearer_token = "abc123token"
        jwt_token = (
            "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9."
            "eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4ifQ."
            "signaturevalue123456789"
        )
        # Simulates a row persisted before redaction-at-write existed.
        response = ProcessingLogEntryResponse.model_validate(
            {
                "id": 1,
                "created_at": datetime.now(UTC),
                "level": "info",
                "stage": "summary",
                "event": "response",
                "document_id": None,
                "document_filename": "sample.txt",
                "provider_id": "provider",
                "model_name": "model",
                "prompt_text": f"Authorization: Bearer {bearer_token}",
                "response_text": f"token={jwt_token}",
                "payload_json": {"password": "secret", "trace_id": "trace-1"},
            }
        )
        self.assertEqual(response.payload_json["password"], "[REDACTED]")
        self.assertIn("[REDACTED]", response.prompt_text or "")
        self.assertIn("[REDACTED]", response.response_text or "")
        self.assertNotIn(bearer_token, response.prompt_text or "")
        self.assertNotIn(jwt_token, response.response_text or "")
if __name__ == "__main__":
unittest.main()