Harden auth, redaction, upload size checks, and compose token requirements

2026-02-21 13:48:55 -03:00
parent 5792586a90
commit 3cbad053cc
21 changed files with 1168 additions and 85 deletions

backend/app/api/auth.py Normal file

@@ -0,0 +1,87 @@
"""Token-based authentication and authorization dependencies for privileged API routes."""
import hmac
from typing import Annotated
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from app.core.config import Settings, get_settings
bearer_auth = HTTPBearer(auto_error=False)
class AuthRole:
"""Declares supported authorization roles for privileged API operations."""
ADMIN = "admin"
USER = "user"
def _raise_unauthorized() -> None:
"""Raises an HTTP 401 response with bearer authentication challenge headers."""
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or missing API token",
headers={"WWW-Authenticate": "Bearer"},
)
def _configured_admin_token(settings: Settings) -> str:
"""Returns required admin token or raises configuration error when unset."""
token = settings.admin_api_token.strip()
if token:
return token
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Admin API token is not configured",
)
def _resolve_token_role(token: str, settings: Settings) -> str:
"""Resolves role from a bearer token using constant-time comparisons."""
admin_token = _configured_admin_token(settings)
if hmac.compare_digest(token, admin_token):
return AuthRole.ADMIN
user_token = settings.user_api_token.strip()
if user_token and hmac.compare_digest(token, user_token):
return AuthRole.USER
_raise_unauthorized()
def get_request_role(
credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(bearer_auth)],
settings: Annotated[Settings, Depends(get_settings)],
) -> str:
"""Authenticates request token and returns its authorization role."""
if credentials is None:
_raise_unauthorized()
token = credentials.credentials.strip()
if not token:
_raise_unauthorized()
return _resolve_token_role(token=token, settings=settings)
def require_user_or_admin(role: Annotated[str, Depends(get_request_role)]) -> str:
"""Requires a valid user or admin token and returns resolved role."""
return role
def require_admin(role: Annotated[str, Depends(get_request_role)]) -> str:
"""Requires admin role and rejects requests authenticated as regular users."""
if role != AuthRole.ADMIN:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Admin token required",
)
return role
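
A minimal sketch of the resulting behavior from a client's point of view. The application factory and route paths are assumed from this repo (`app.main.create_app` plus the router registrations below), and the token value is a placeholder that would have to match the configured `user_api_token`:

from fastapi.testclient import TestClient

from app.main import create_app  # assumed application factory for this service

client = TestClient(create_app())

# No bearer token: rejected with 401 and a WWW-Authenticate: Bearer challenge
# before any route handler runs.
assert client.get("/api/v1/settings").status_code == 401

# A token matching user_api_token clears require_user_or_admin routes such as
# /api/v1/documents, while require_admin routes answer 403 "Admin token required":
# the request authenticated, it just lacks the admin role.
response = client.get(
    "/api/v1/settings",
    headers={"Authorization": "Bearer <user-token-placeholder>"},
)
# -> 403 when the placeholder matches user_api_token, 401 otherwise.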

@@ -1,7 +1,8 @@
"""API router registration for all HTTP route modules."""
-from fastapi import APIRouter
+from fastapi import APIRouter, Depends
+from app.api.auth import require_admin, require_user_or_admin
from app.api.routes_documents import router as documents_router
from app.api.routes_health import router as health_router
from app.api.routes_processing_logs import router as processing_logs_router
@@ -11,7 +12,27 @@ from app.api.routes_settings import router as settings_router
api_router = APIRouter()
api_router.include_router(health_router)
-api_router.include_router(documents_router, prefix="/documents", tags=["documents"])
-api_router.include_router(processing_logs_router, prefix="/processing/logs", tags=["processing-logs"])
-api_router.include_router(search_router, prefix="/search", tags=["search"])
-api_router.include_router(settings_router, prefix="/settings", tags=["settings"])
+api_router.include_router(
+    documents_router,
+    prefix="/documents",
+    tags=["documents"],
+    dependencies=[Depends(require_user_or_admin)],
+)
+api_router.include_router(
+    processing_logs_router,
+    prefix="/processing/logs",
+    tags=["processing-logs"],
+    dependencies=[Depends(require_admin)],
+)
+api_router.include_router(
+    search_router,
+    prefix="/search",
+    tags=["search"],
+    dependencies=[Depends(require_user_or_admin)],
+)
+api_router.include_router(
+    settings_router,
+    prefix="/settings",
+    tags=["settings"],
+    dependencies=[Depends(require_admin)],
+)
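
Attaching `dependencies=[Depends(...)]` at `include_router` time guards every route in the included router without touching handler signatures. When a single endpoint also needs the resolved role, the same dependency can be injected directly; a sketch with a hypothetical route:

from fastapi import APIRouter, Depends

from app.api.auth import require_admin

router = APIRouter()

@router.post("/maintenance/reindex")  # hypothetical endpoint, for illustration only
def trigger_reindex(role: str = Depends(require_admin)) -> dict[str, str]:
    # require_admin has already raised 401/403 before this body runs.
    return {"status": "queued", "role": role}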

@@ -1,4 +1,4 @@
"""Document CRUD, lifecycle, metadata, file access, and content export endpoints."""
"""Authenticated document CRUD, lifecycle, metadata, file access, and content export endpoints."""
import io
import re
@@ -14,7 +14,7 @@ from fastapi.responses import FileResponse, Response, StreamingResponse
from sqlalchemy import or_, func, select
from sqlalchemy.orm import Session
-from app.services.app_settings import read_predefined_paths_settings, read_predefined_tags_settings
+from app.core.config import get_settings
from app.db.base import get_session
from app.models.document import Document, DocumentStatus
from app.schemas.documents import (
@@ -26,6 +26,7 @@ from app.schemas.documents import (
UploadConflict,
UploadResponse,
)
+from app.services.app_settings import read_predefined_paths_settings, read_predefined_tags_settings
from app.services.extractor import sniff_mime
from app.services.handwriting_style import delete_many_handwriting_style_documents
from app.services.processing_logs import log_processing_event, set_processing_log_autocommit
@@ -35,6 +36,7 @@ from app.worker.queue import get_processing_queue
router = APIRouter()
settings = get_settings()
def _parse_csv(value: str | None) -> list[str]:
@@ -227,6 +229,33 @@ def _build_document_list_statement(
return statement
def _enforce_upload_shape(files: list[UploadFile]) -> None:
"""Validates upload request shape against configured file-count bounds."""
if not files:
raise HTTPException(status_code=400, detail="Upload request must include at least one file")
if len(files) > settings.max_upload_files_per_request:
raise HTTPException(
status_code=413,
detail=(
"Upload request exceeds file count limit "
f"({len(files)} > {settings.max_upload_files_per_request})"
),
)
async def _read_upload_bytes(file: UploadFile, max_bytes: int) -> bytes:
"""Reads one upload file while enforcing per-file byte limits."""
data = await file.read(max_bytes + 1)
if len(data) > max_bytes:
raise HTTPException(
status_code=413,
detail=f"File '{file.filename or 'upload'}' exceeds per-file limit of {max_bytes} bytes",
)
return data
def _collect_document_tree(session: Session, root_document_id: UUID) -> list[tuple[int, Document]]:
"""Collects a document and all descendants for recursive permanent deletion."""
@@ -472,18 +501,29 @@ async def upload_documents(
) -> UploadResponse:
"""Uploads files, records metadata, and enqueues asynchronous extraction tasks."""
_enforce_upload_shape(files)
set_processing_log_autocommit(session, True)
normalized_tags = _normalize_tags(tags)
queue = get_processing_queue()
uploaded: list[DocumentResponse] = []
conflicts: list[UploadConflict] = []
total_request_bytes = 0
indexed_relative_paths = relative_paths or []
prepared_uploads: list[dict[str, object]] = []
for idx, file in enumerate(files):
filename = file.filename or f"uploaded_{idx}"
-        data = await file.read()
+        data = await _read_upload_bytes(file, settings.max_upload_file_size_bytes)
total_request_bytes += len(data)
if total_request_bytes > settings.max_upload_request_size_bytes:
raise HTTPException(
status_code=413,
detail=(
"Upload request exceeds total size limit "
f"({total_request_bytes} > {settings.max_upload_request_size_bytes} bytes)"
),
)
sha256 = compute_sha256(data)
source_relative_path = indexed_relative_paths[idx] if idx < len(indexed_relative_paths) else filename
extension = Path(filename).suffix.lower()
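
The `max_bytes + 1` read is the key trick in `_read_upload_bytes`: pulling one byte past the limit proves a file is oversized without buffering an unbounded stream, while the running `total_request_bytes` counter bounds the whole request even when every individual file fits. The per-file pattern in isolation (self-contained, nothing repo-specific):

import io

def read_bounded(stream: io.BufferedIOBase, max_bytes: int) -> bytes:
    # Reading max_bytes + 1 distinguishes "exactly at the limit" from "over it".
    data = stream.read(max_bytes + 1)
    if len(data) > max_bytes:
        raise ValueError(f"payload exceeds per-file limit of {max_bytes} bytes")
    return data

assert read_bounded(io.BytesIO(b"12345678"), 8) == b"12345678"  # at the limit: accepted
try:
    read_bounded(io.BytesIO(b"123456789"), 8)  # one byte over: rejected
except ValueError:
    pass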

@@ -1,10 +1,11 @@
"""Read-only API endpoints for processing pipeline event logs."""
"""Admin-only API endpoints for processing pipeline event logs."""
from uuid import UUID
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from app.core.config import get_settings
from app.db.base import get_session
from app.schemas.processing_logs import ProcessingLogEntryResponse, ProcessingLogListResponse
from app.services.app_settings import read_processing_log_retention_settings
@@ -17,12 +18,13 @@ from app.services.processing_logs import (
router = APIRouter()
settings = get_settings()
@router.get("", response_model=ProcessingLogListResponse)
def get_processing_logs(
offset: int = Query(default=0, ge=0),
-    limit: int = Query(default=120, ge=1, le=400),
+    limit: int = Query(default=120, ge=1, le=settings.processing_log_max_unbound_entries),
document_id: UUID | None = Query(default=None),
session: Session = Depends(get_session),
) -> ProcessingLogListResponse:
@@ -43,8 +45,8 @@ def get_processing_logs(
@router.post("/trim")
def trim_processing_logs(
-    keep_document_sessions: int | None = Query(default=None, ge=0, le=20),
-    keep_unbound_entries: int | None = Query(default=None, ge=0, le=400),
+    keep_document_sessions: int | None = Query(default=None, ge=0, le=settings.processing_log_max_document_sessions),
+    keep_unbound_entries: int | None = Query(default=None, ge=0, le=settings.processing_log_max_unbound_entries),
session: Session = Depends(get_session),
) -> dict[str, int]:
"""Deletes old processing logs using query values or persisted retention defaults."""
@@ -61,10 +63,19 @@ def trim_processing_logs(
else int(retention_defaults.get("keep_unbound_entries", 80))
)
capped_keep_document_sessions = min(
settings.processing_log_max_document_sessions,
max(0, int(resolved_keep_document_sessions)),
)
capped_keep_unbound_entries = min(
settings.processing_log_max_unbound_entries,
max(0, int(resolved_keep_unbound_entries)),
)
result = cleanup_processing_logs(
session=session,
-        keep_document_sessions=resolved_keep_document_sessions,
-        keep_unbound_entries=resolved_keep_unbound_entries,
+        keep_document_sessions=capped_keep_document_sessions,
+        keep_unbound_entries=capped_keep_unbound_entries,
)
session.commit()
return result
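
The capping logic is a min/max sandwich that forces either the query value or the persisted default into [0, configured ceiling]. The arithmetic on its own, with illustrative numbers:

def cap(value: int, ceiling: int) -> int:
    # Floor negatives at zero, then apply the configured ceiling.
    return min(ceiling, max(0, int(value)))

assert cap(-5, 400) == 0       # negative input floors at zero
assert cap(120, 400) == 120    # in-range values pass through unchanged
assert cap(9_999, 400) == 400  # oversized values are clamped to the ceiling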

@@ -1,6 +1,6 @@
"""API routes for managing persistent single-user application settings."""
"""Admin-only API routes for managing persistent single-user application settings."""
from fastapi import APIRouter
from fastapi import APIRouter, HTTPException
from app.schemas.settings import (
AppSettingsUpdateRequest,
@@ -18,6 +18,7 @@ from app.schemas.settings import (
UploadDefaultsResponse,
)
from app.services.app_settings import (
AppSettingsValidationError,
TASK_OCR_HANDWRITING,
TASK_ROUTING_CLASSIFICATION,
TASK_SUMMARY_GENERATION,
@@ -179,16 +180,19 @@ def set_app_settings(payload: AppSettingsUpdateRequest) -> AppSettingsResponse:
if payload.predefined_tags is not None:
predefined_tags_payload = [item.model_dump(exclude_none=True) for item in payload.predefined_tags]
-    updated = update_app_settings(
-        providers=providers_payload,
-        tasks=tasks_payload,
-        upload_defaults=upload_defaults_payload,
-        display=display_payload,
-        processing_log_retention=processing_log_retention_payload,
-        handwriting_style=handwriting_style_payload,
-        predefined_paths=predefined_paths_payload,
-        predefined_tags=predefined_tags_payload,
-    )
+    try:
+        updated = update_app_settings(
+            providers=providers_payload,
+            tasks=tasks_payload,
+            upload_defaults=upload_defaults_payload,
+            display=display_payload,
+            processing_log_retention=processing_log_retention_payload,
+            handwriting_style=handwriting_style_payload,
+            predefined_paths=predefined_paths_payload,
+            predefined_tags=predefined_tags_payload,
+        )
+    except AppSettingsValidationError as error:
+        raise HTTPException(status_code=400, detail=str(error)) from error
return _build_response(updated)
@@ -203,14 +207,17 @@ def reset_settings_to_defaults() -> AppSettingsResponse:
def set_handwriting_settings(payload: HandwritingSettingsUpdateRequest) -> AppSettingsResponse:
"""Updates handwriting transcription settings and returns the resulting configuration."""
-    updated = update_handwriting_settings(
-        enabled=payload.enabled,
-        openai_base_url=payload.openai_base_url,
-        openai_model=payload.openai_model,
-        openai_timeout_seconds=payload.openai_timeout_seconds,
-        openai_api_key=payload.openai_api_key,
-        clear_openai_api_key=payload.clear_openai_api_key,
-    )
+    try:
+        updated = update_handwriting_settings(
+            enabled=payload.enabled,
+            openai_base_url=payload.openai_base_url,
+            openai_model=payload.openai_model,
+            openai_timeout_seconds=payload.openai_timeout_seconds,
+            openai_api_key=payload.openai_api_key,
+            clear_openai_api_key=payload.clear_openai_api_key,
+        )
+    except AppSettingsValidationError as error:
+        raise HTTPException(status_code=400, detail=str(error)) from error
return _build_response(updated)
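
Both handlers now translate `AppSettingsValidationError` into an HTTP 400 at the boundary, so the service layer can raise one domain error type without knowing anything about HTTP. The pattern reduced to a self-contained sketch (all names here are stand-ins):

from fastapi import FastAPI, HTTPException

app = FastAPI()

class DomainValidationError(ValueError):
    """Stand-in for AppSettingsValidationError."""

def update_value(value: int) -> int:
    # Stand-in service function: domain validation lives here, not in the route.
    if value < 0:
        raise DomainValidationError("value must be non-negative")
    return value

@app.post("/example")
def set_value(value: int) -> dict[str, int]:
    try:
        updated = update_value(value)
    except DomainValidationError as error:
        # Surface the domain failure as a client error instead of a 500.
        raise HTTPException(status_code=400, detail=str(error)) from error
    return {"value": updated}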

@@ -1,7 +1,10 @@
"""Application settings and environment configuration."""
from functools import lru_cache
import ipaddress
from pathlib import Path
import socket
from urllib.parse import urlparse, urlunparse
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
@@ -18,9 +21,24 @@ class Settings(BaseSettings):
redis_url: str = "redis://redis:6379/0"
storage_root: Path = Path("/data/storage")
upload_chunk_size: int = 4 * 1024 * 1024
max_upload_files_per_request: int = 50
max_upload_file_size_bytes: int = 25 * 1024 * 1024
max_upload_request_size_bytes: int = 100 * 1024 * 1024
max_zip_members: int = 250
max_zip_depth: int = 2
max_zip_member_uncompressed_bytes: int = 25 * 1024 * 1024
max_zip_total_uncompressed_bytes: int = 150 * 1024 * 1024
max_zip_compression_ratio: float = 120.0
max_text_length: int = 500_000
admin_api_token: str = ""
user_api_token: str = ""
provider_base_url_allowlist: list[str] = Field(default_factory=lambda: ["api.openai.com"])
provider_base_url_allow_http: bool = False
provider_base_url_allow_private_network: bool = False
processing_log_max_document_sessions: int = 20
processing_log_max_unbound_entries: int = 400
processing_log_max_payload_chars: int = 4096
processing_log_max_text_chars: int = 12000
default_openai_base_url: str = "https://api.openai.com/v1"
default_openai_model: str = "gpt-4.1-mini"
default_openai_timeout_seconds: int = 45
@@ -39,6 +57,187 @@ class Settings(BaseSettings):
cors_origins: list[str] = Field(default_factory=lambda: ["http://localhost:5173", "http://localhost:3000"])
LOCAL_HOSTNAME_SUFFIXES = (".local", ".internal", ".home.arpa")
def _normalize_allowlist(allowlist: object) -> tuple[str, ...]:
"""Normalizes host allowlist entries to lowercase DNS labels."""
if not isinstance(allowlist, (list, tuple, set)):
return ()
normalized = {
candidate.strip().lower().rstrip(".")
for candidate in allowlist
if isinstance(candidate, str) and candidate.strip()
}
return tuple(sorted(normalized))
def _host_matches_allowlist(hostname: str, allowlist: tuple[str, ...]) -> bool:
"""Returns whether a hostname is included by an exact or subdomain allowlist rule."""
if not allowlist:
return False
candidate = hostname.lower().rstrip(".")
for allowed_host in allowlist:
if candidate == allowed_host or candidate.endswith(f".{allowed_host}"):
return True
return False
def _is_private_or_special_ip(value: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
"""Returns whether an IP belongs to private, loopback, link-local, or reserved ranges."""
return (
value.is_private
or value.is_loopback
or value.is_link_local
or value.is_multicast
or value.is_reserved
or value.is_unspecified
)
def _validate_resolved_host_ips(hostname: str, port: int, allow_private_network: bool) -> None:
"""Resolves hostnames and rejects private or special addresses when private network access is disabled."""
try:
addresses = socket.getaddrinfo(hostname, port, type=socket.SOCK_STREAM)
except socket.gaierror as error:
raise ValueError(f"Provider base URL host cannot be resolved: {hostname}") from error
resolved_ips: set[ipaddress.IPv4Address | ipaddress.IPv6Address] = set()
for entry in addresses:
sockaddr = entry[4]
if not sockaddr:
continue
ip_text = sockaddr[0]
try:
resolved_ips.add(ipaddress.ip_address(ip_text))
except ValueError:
continue
if not resolved_ips:
raise ValueError(f"Provider base URL host resolved without usable IP addresses: {hostname}")
if allow_private_network:
return
blocked = [ip for ip in resolved_ips if _is_private_or_special_ip(ip)]
if blocked:
blocked_text = ", ".join(str(ip) for ip in blocked)
raise ValueError(f"Provider base URL resolves to private or special IP addresses: {blocked_text}")
def _normalize_and_validate_provider_base_url(
raw_value: str,
allowlist: tuple[str, ...],
allow_http: bool,
allow_private_network: bool,
resolve_dns: bool,
) -> str:
"""Normalizes and validates provider base URLs with SSRF-safe scheme and host checks."""
trimmed = raw_value.strip().rstrip("/")
if not trimmed:
raise ValueError("Provider base URL must not be empty")
parsed = urlparse(trimmed)
scheme = parsed.scheme.lower()
if scheme not in {"http", "https"}:
raise ValueError("Provider base URL must use http or https")
if scheme == "http" and not allow_http:
raise ValueError("Provider base URL must use https")
if parsed.query or parsed.fragment:
raise ValueError("Provider base URL must not include query strings or fragments")
if parsed.username or parsed.password:
raise ValueError("Provider base URL must not include embedded credentials")
hostname = (parsed.hostname or "").lower().rstrip(".")
if not hostname:
raise ValueError("Provider base URL must include a hostname")
if allowlist and not _host_matches_allowlist(hostname, allowlist):
allowed_hosts = ", ".join(allowlist)
raise ValueError(f"Provider base URL host is not in allowlist: {hostname}. Allowed hosts: {allowed_hosts}")
if hostname == "localhost" or hostname.endswith(LOCAL_HOSTNAME_SUFFIXES):
if not allow_private_network:
raise ValueError("Provider base URL must not target local or internal hostnames")
try:
ip_host = ipaddress.ip_address(hostname)
except ValueError:
ip_host = None
if ip_host is not None:
if not allow_private_network and _is_private_or_special_ip(ip_host):
raise ValueError("Provider base URL must not target private or special IP addresses")
elif resolve_dns:
resolved_port = parsed.port
if resolved_port is None:
resolved_port = 443 if scheme == "https" else 80
_validate_resolved_host_ips(
hostname=hostname,
port=resolved_port,
allow_private_network=allow_private_network,
)
path = (parsed.path or "").rstrip("/")
if not path.endswith("/v1"):
path = f"{path}/v1" if path else "/v1"
normalized_hostname = hostname
if ":" in normalized_hostname and not normalized_hostname.startswith("["):
normalized_hostname = f"[{normalized_hostname}]"
netloc = f"{normalized_hostname}:{parsed.port}" if parsed.port is not None else normalized_hostname
return urlunparse((scheme, netloc, path, "", "", ""))
@lru_cache(maxsize=256)
def _normalize_and_validate_provider_base_url_cached(
raw_value: str,
allowlist: tuple[str, ...],
allow_http: bool,
allow_private_network: bool,
) -> str:
"""Caches provider URL validation results for non-DNS-resolved checks."""
return _normalize_and_validate_provider_base_url(
raw_value=raw_value,
allowlist=allowlist,
allow_http=allow_http,
allow_private_network=allow_private_network,
resolve_dns=False,
)
def normalize_and_validate_provider_base_url(raw_value: str, *, resolve_dns: bool = False) -> str:
"""Validates and normalizes provider base URL values using configured SSRF protections."""
settings = get_settings()
allowlist = _normalize_allowlist(settings.provider_base_url_allowlist)
allow_http = settings.provider_base_url_allow_http if isinstance(settings.provider_base_url_allow_http, bool) else False
allow_private_network = (
settings.provider_base_url_allow_private_network
if isinstance(settings.provider_base_url_allow_private_network, bool)
else False
)
if resolve_dns:
return _normalize_and_validate_provider_base_url(
raw_value=raw_value,
allowlist=allowlist,
allow_http=allow_http,
allow_private_network=allow_private_network,
resolve_dns=True,
)
return _normalize_and_validate_provider_base_url_cached(
raw_value=raw_value,
allowlist=allowlist,
allow_http=allow_http,
allow_private_network=allow_private_network,
)
@lru_cache(maxsize=1)
def get_settings() -> Settings:
"""Returns a cached settings object for dependency injection and service access."""

@@ -1,7 +1,8 @@
"""FastAPI entrypoint for the DMS backend service."""
-from fastapi import FastAPI
+from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from app.api.router import api_router
from app.core.config import get_settings
@@ -28,6 +29,35 @@ def create_app() -> FastAPI:
)
app.include_router(api_router, prefix="/api/v1")
@app.middleware("http")
async def enforce_upload_request_size(request: Request, call_next):
"""Rejects upload requests without deterministic length or exceeding configured limits."""
if request.url.path.endswith("/api/v1/documents/upload"):
content_length = request.headers.get("content-length", "").strip()
if not content_length:
return JSONResponse(
status_code=411,
content={"detail": "Content-Length header is required for document uploads"},
)
try:
content_length_value = int(content_length)
except ValueError:
return JSONResponse(status_code=400, content={"detail": "Invalid Content-Length header"})
if content_length_value <= 0:
return JSONResponse(status_code=400, content={"detail": "Content-Length must be a positive integer"})
if content_length_value > settings.max_upload_request_size_bytes:
return JSONResponse(
status_code=413,
content={
"detail": (
"Upload request exceeds total size limit "
f"({content_length_value} > {settings.max_upload_request_size_bytes} bytes)"
)
},
)
return await call_next(request)
@app.on_event("startup")
def startup_event() -> None:
"""Initializes storage directories and database schema on service startup."""

@@ -2,14 +2,118 @@
import uuid
from datetime import UTC, datetime
import re
from typing import Any
from sqlalchemy import BigInteger, DateTime, ForeignKey, String, Text
from sqlalchemy.dialects.postgresql import JSONB, UUID
-from sqlalchemy.orm import Mapped, mapped_column
+from sqlalchemy.orm import Mapped, mapped_column, validates
from app.core.config import get_settings
from app.db.base import Base
settings = get_settings()
SENSITIVE_KEY_MARKERS = (
"api_key",
"apikey",
"authorization",
"bearer",
"token",
"secret",
"password",
"credential",
"cookie",
)
SENSITIVE_TEXT_PATTERNS = (
re.compile(r"(?i)\bauthorization\b\s*[:=]\s*bearer\s+[a-z0-9._~+/\-]+=*"),
re.compile(r"(?i)\bbearer\s+[a-z0-9._~+/\-]+=*"),
re.compile(r"\b[a-z0-9_-]{8,}\.[a-z0-9_-]{8,}\.[a-z0-9_-]{8,}\b", flags=re.IGNORECASE),
re.compile(r"(?i)\bsk-[a-z0-9]{16,}\b"),
re.compile(r"(?i)\b(api[_-]?key|token|secret|password)\b\s*[:=]\s*['\"]?[^\s,'\";]+['\"]?"),
)
REDACTED_TEXT = "[REDACTED]"
MAX_PAYLOAD_KEYS = 80
MAX_PAYLOAD_LIST_ITEMS = 80
def _truncate(value: str, limit: int) -> str:
"""Truncates long log fields to configured bounds with stable suffix marker."""
normalized = value.strip()
if len(normalized) <= limit:
return normalized
return normalized[: max(0, limit - 3)] + "..."
def _is_sensitive_key(key: str) -> bool:
"""Returns whether a payload key likely contains sensitive credential data."""
normalized = key.strip().lower()
return any(marker in normalized for marker in SENSITIVE_KEY_MARKERS)
def _redact_sensitive_text(value: str) -> str:
"""Redacts token-like segments from log text while retaining non-sensitive context."""
redacted = value
for pattern in SENSITIVE_TEXT_PATTERNS:
redacted = pattern.sub(lambda _: REDACTED_TEXT, redacted)
return redacted
def sanitize_processing_log_payload_value(value: Any, *, parent_key: str | None = None) -> Any:
"""Sanitizes payload structures by redacting sensitive fields and bounding size."""
if parent_key and _is_sensitive_key(parent_key):
return REDACTED_TEXT
if isinstance(value, dict):
sanitized: dict[str, Any] = {}
for index, (raw_key, raw_value) in enumerate(value.items()):
if index >= MAX_PAYLOAD_KEYS:
break
key = str(raw_key)
sanitized[key] = sanitize_processing_log_payload_value(raw_value, parent_key=key)
return sanitized
if isinstance(value, list):
return [
sanitize_processing_log_payload_value(item, parent_key=parent_key)
for item in value[:MAX_PAYLOAD_LIST_ITEMS]
]
if isinstance(value, tuple):
return [
sanitize_processing_log_payload_value(item, parent_key=parent_key)
for item in list(value)[:MAX_PAYLOAD_LIST_ITEMS]
]
if isinstance(value, str):
redacted = _redact_sensitive_text(value)
return _truncate(redacted, settings.processing_log_max_payload_chars)
if isinstance(value, (int, float, bool)) or value is None:
return value
as_text = _truncate(str(value), settings.processing_log_max_payload_chars)
return _redact_sensitive_text(as_text)
def sanitize_processing_log_text(value: str | None) -> str | None:
"""Sanitizes prompt and response fields by redacting credentials and clamping length."""
if value is None:
return None
normalized = value.strip()
if not normalized:
return None
redacted = _redact_sensitive_text(normalized)
return _truncate(redacted, settings.processing_log_max_text_chars)
class ProcessingLogEntry(Base):
"""Stores a timestamped processing event with optional model prompt and response text."""
@@ -31,3 +135,17 @@ class ProcessingLogEntry(Base):
prompt_text: Mapped[str | None] = mapped_column(Text, nullable=True)
response_text: Mapped[str | None] = mapped_column(Text, nullable=True)
payload_json: Mapped[dict] = mapped_column(JSONB, nullable=False, default=dict)
@validates("prompt_text", "response_text")
def _validate_text_fields(self, key: str, value: str | None) -> str | None:
"""Redacts and bounds free-text log fields before persistence."""
return sanitize_processing_log_text(value)
@validates("payload_json")
def _validate_payload_json(self, key: str, value: dict[str, Any] | None) -> dict[str, Any]:
"""Redacts and bounds structured payload fields before persistence."""
if not isinstance(value, dict):
return {}
return sanitize_processing_log_payload_value(value)
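
A quick sanity sketch of the sanitizer in action (made-up values; the assertions restate the rules above rather than a captured test run):

from app.models.processing_log import sanitize_processing_log_payload_value

payload = {
    "model": "gpt-4.1-mini",
    "api_key": "sk-abcdefghijklmnop0123",  # sensitive key name: value dropped wholesale
    "headers": {"Authorization": "Bearer abcdefgh.ijklmnop.qrstuvwx"},
    "note": "retrying after 401, token=xyz123secret",  # inline secret: pattern-redacted
}

clean = sanitize_processing_log_payload_value(payload)
assert clean["api_key"] == "[REDACTED]"
assert clean["headers"]["Authorization"] == "[REDACTED]"
assert "xyz123secret" not in clean["note"]
assert clean["model"] == "gpt-4.1-mini"  # non-sensitive values pass through untouched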

@@ -1,13 +1,16 @@
"""Pydantic schemas for processing pipeline log API payloads."""
from datetime import datetime
from typing import Any
from uuid import UUID
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, field_validator
+from app.models.processing_log import sanitize_processing_log_payload_value, sanitize_processing_log_text
class ProcessingLogEntryResponse(BaseModel):
"""Represents one persisted processing log event returned by API endpoints."""
"""Represents one persisted processing log event with already-redacted sensitive fields."""
id: int
created_at: datetime
@@ -20,7 +23,26 @@ class ProcessingLogEntryResponse(BaseModel):
model_name: str | None
prompt_text: str | None
response_text: str | None
-    payload_json: dict
+    payload_json: dict[str, Any]
@field_validator("prompt_text", "response_text", mode="before")
@classmethod
def _sanitize_text_fields(cls, value: Any) -> str | None:
"""Ensures log text fields are redacted in API responses."""
if value is None:
return None
return sanitize_processing_log_text(str(value))
@field_validator("payload_json", mode="before")
@classmethod
def _sanitize_payload_field(cls, value: Any) -> dict[str, Any]:
"""Ensures payload fields are redacted in API responses."""
if not isinstance(value, dict):
return {}
sanitized = sanitize_processing_log_payload_value(value)
return sanitized if isinstance(sanitized, dict) else {}
class Config:
"""Enables ORM object parsing for SQLAlchemy model instances."""

@@ -5,12 +5,16 @@ import re
from pathlib import Path
from typing import Any
-from app.core.config import get_settings
+from app.core.config import get_settings, normalize_and_validate_provider_base_url
settings = get_settings()
class AppSettingsValidationError(ValueError):
"""Raised when user-provided settings values fail security or contract validation."""
TASK_OCR_HANDWRITING = "ocr_handwriting"
TASK_SUMMARY_GENERATION = "summary_generation"
TASK_ROUTING_CLASSIFICATION = "routing_classification"
@@ -156,13 +160,13 @@ def _clamp_cards_per_page(value: int) -> int:
def _clamp_processing_log_document_sessions(value: int) -> int:
"""Clamps the number of recent document log sessions kept during cleanup."""
-    return max(0, min(20, value))
+    return max(0, min(settings.processing_log_max_document_sessions, value))
def _clamp_processing_log_unbound_entries(value: int) -> int:
"""Clamps retained unbound processing log events kept during cleanup."""
-    return max(0, min(400, value))
+    return max(0, min(settings.processing_log_max_unbound_entries, value))
def _clamp_predefined_entries_limit(value: int) -> int:
@@ -242,12 +246,19 @@ def _normalize_provider(
api_key_value = payload.get("api_key", fallback_values.get("api_key", defaults["api_key"]))
api_key = str(api_key_value).strip() if api_key_value is not None else ""
raw_base_url = str(payload.get("base_url", fallback_values.get("base_url", defaults["base_url"]))).strip()
if not raw_base_url:
raw_base_url = str(defaults["base_url"]).strip()
try:
normalized_base_url = normalize_and_validate_provider_base_url(raw_base_url)
except ValueError as error:
raise AppSettingsValidationError(str(error)) from error
return {
"id": provider_id,
"label": str(payload.get("label", fallback_values.get("label", provider_id))).strip() or provider_id,
"provider_type": provider_type,
"base_url": str(payload.get("base_url", fallback_values.get("base_url", defaults["base_url"]))).strip()
or defaults["base_url"],
"base_url": normalized_base_url,
"timeout_seconds": _clamp_timeout(
_safe_int(
payload.get("timeout_seconds", fallback_values.get("timeout_seconds", defaults["timeout_seconds"])),

@@ -300,16 +300,39 @@ def extract_text_content(filename: str, data: bytes, mime_type: str) -> Extracti
def extract_archive_members(data: bytes, depth: int = 0) -> list[ArchiveMember]:
"""Extracts processable members from zip archives with configurable depth limits."""
"""Extracts processable ZIP members within configured decompression safety budgets."""
members: list[ArchiveMember] = []
if depth > settings.max_zip_depth:
return members
-    with zipfile.ZipFile(io.BytesIO(data)) as archive:
-        infos = [info for info in archive.infolist() if not info.is_dir()][: settings.max_zip_members]
-        for info in infos:
-            member_data = archive.read(info.filename)
-            members.append(ArchiveMember(name=info.filename, data=member_data))
+    total_uncompressed_bytes = 0
+    try:
+        with zipfile.ZipFile(io.BytesIO(data)) as archive:
+            infos = [info for info in archive.infolist() if not info.is_dir()][: settings.max_zip_members]
+            for info in infos:
+                if info.file_size <= 0:
+                    continue
+                if info.file_size > settings.max_zip_member_uncompressed_bytes:
+                    continue
+                if total_uncompressed_bytes + info.file_size > settings.max_zip_total_uncompressed_bytes:
+                    continue
+                compressed_size = max(1, int(info.compress_size))
+                compression_ratio = float(info.file_size) / float(compressed_size)
+                if compression_ratio > settings.max_zip_compression_ratio:
+                    continue
+                with archive.open(info, mode="r") as archive_member:
+                    member_data = archive_member.read(settings.max_zip_member_uncompressed_bytes + 1)
+                if len(member_data) > settings.max_zip_member_uncompressed_bytes:
+                    continue
+                if total_uncompressed_bytes + len(member_data) > settings.max_zip_total_uncompressed_bytes:
+                    continue
+                total_uncompressed_bytes += len(member_data)
+                members.append(ArchiveMember(name=info.filename, data=member_data))
+    except zipfile.BadZipFile:
+        return []
return members
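
To see the ratio guard fire, it's enough to zip a megabyte of zeros: deflate shrinks it to roughly a kilobyte, far beyond a 120:1 ceiling. A self-contained check (the 120.0 mirrors the default above; nothing else is repo-specific):

import io
import zipfile

MAX_RATIO = 120.0  # mirrors max_zip_compression_ratio's default

buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w", compression=zipfile.ZIP_DEFLATED) as archive:
    archive.writestr("zeros.bin", b"\x00" * 1_000_000)  # highly compressible payload

with zipfile.ZipFile(io.BytesIO(buffer.getvalue())) as archive:
    info = archive.infolist()[0]
    ratio = info.file_size / max(1, info.compress_size)
    # ~1 MB declared vs ~1 KB stored: the member would be skipped by the guard.
    assert ratio > MAX_RATIO

The bounded archive.open(...).read(limit + 1) re-check matters because a crafted central directory can understate file_size; the guard trusts bytes actually read, not declared sizes.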

@@ -2,10 +2,10 @@
from dataclasses import dataclass
from typing import Any
-from urllib.parse import urlparse, urlunparse
from openai import APIConnectionError, APIError, APITimeoutError, OpenAI
+from app.core.config import normalize_and_validate_provider_base_url
from app.services.app_settings import read_task_runtime_settings
@@ -36,18 +36,9 @@ class ModelTaskRuntime:
def _normalize_base_url(raw_value: str) -> str:
"""Normalizes provider base URL and appends /v1 for OpenAI-compatible servers."""
"""Normalizes provider base URL and enforces SSRF protections before outbound calls."""
trimmed = raw_value.strip().rstrip("/")
if not trimmed:
return "https://api.openai.com/v1"
parsed = urlparse(trimmed)
path = parsed.path or ""
if not path.endswith("/v1"):
path = f"{path}/v1" if path else "/v1"
return urlunparse(parsed._replace(path=path))
return normalize_and_validate_provider_base_url(raw_value, resolve_dns=True)
def _should_fallback_to_chat(error: Exception) -> bool:
@@ -137,11 +128,16 @@ def resolve_task_runtime(task_name: str) -> ModelTaskRuntime:
if provider_type != "openai_compatible":
raise ModelTaskError(f"unsupported_provider_type:{provider_type}")
try:
normalized_base_url = _normalize_base_url(str(provider_payload.get("base_url", "https://api.openai.com/v1")))
except ValueError as error:
raise ModelTaskError(f"invalid_provider_base_url:{error}") from error
return ModelTaskRuntime(
task_name=task_name,
provider_id=str(provider_payload.get("id", "")),
provider_type=provider_type,
-        base_url=_normalize_base_url(str(provider_payload.get("base_url", "https://api.openai.com/v1"))),
+        base_url=normalized_base_url,
timeout_seconds=int(provider_payload.get("timeout_seconds", 45)),
api_key=str(provider_payload.get("api_key", "")).strip() or "no-key-required",
model=str(task_payload.get("model", "")).strip(),
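
One behavioral consequence worth noting: the old `_normalize_base_url` silently mapped empty input to `https://api.openai.com/v1` and appended `/v1` to anything else, while the new path validates (and, with `resolve_dns=True`, resolves) the URL first, so a misconfigured provider now fails fast as `ModelTaskError("invalid_provider_base_url:...")` instead of surfacing later as a connection error. A hedged sketch of that failure mode under the default allowlist:

from app.services.model_tasks import _normalize_base_url

try:
    _normalize_base_url("https://evil.example.com/v1")  # host outside the default allowlist
except ValueError as error:
    # resolve_task_runtime wraps this as ModelTaskError("invalid_provider_base_url:...").
    print(f"rejected before any outbound call: {error}")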