Harden auth, redaction, upload size checks, and compose token requirements

This commit is contained in:
2026-02-21 13:48:55 -03:00
parent 5792586a90
commit 3cbad053cc
21 changed files with 1168 additions and 85 deletions

View File

@@ -5,12 +5,16 @@ import re
from pathlib import Path
from typing import Any
from app.core.config import get_settings
from app.core.config import get_settings, normalize_and_validate_provider_base_url
settings = get_settings()
class AppSettingsValidationError(ValueError):
    """Signals that user-supplied settings input failed a security or contract check."""
TASK_OCR_HANDWRITING = "ocr_handwriting"
TASK_SUMMARY_GENERATION = "summary_generation"
TASK_ROUTING_CLASSIFICATION = "routing_classification"
@@ -156,13 +160,13 @@ def _clamp_cards_per_page(value: int) -> int:
def _clamp_processing_log_document_sessions(value: int) -> int:
    """Clamps the number of recent document log sessions kept during cleanup.

    Bounds *value* to the inclusive range
    [0, settings.processing_log_max_document_sessions].
    """
    # Merge residue left a hard-coded `min(20, value)` return ahead of this
    # one, making the configuration-driven bound unreachable; the configured
    # ceiling is the intended behavior, so only it is kept.
    return max(0, min(settings.processing_log_max_document_sessions, value))
def _clamp_processing_log_unbound_entries(value: int) -> int:
    """Clamps retained unbound processing log events kept during cleanup.

    Bounds *value* to the inclusive range
    [0, settings.processing_log_max_unbound_entries].
    """
    # Merge residue left a hard-coded `min(400, value)` return ahead of this
    # one, shadowing the configured ceiling; only the configuration-driven
    # bound is kept.
    return max(0, min(settings.processing_log_max_unbound_entries, value))
def _clamp_predefined_entries_limit(value: int) -> int:
@@ -242,12 +246,19 @@ def _normalize_provider(
api_key_value = payload.get("api_key", fallback_values.get("api_key", defaults["api_key"]))
api_key = str(api_key_value).strip() if api_key_value is not None else ""
raw_base_url = str(payload.get("base_url", fallback_values.get("base_url", defaults["base_url"]))).strip()
if not raw_base_url:
raw_base_url = str(defaults["base_url"]).strip()
try:
normalized_base_url = normalize_and_validate_provider_base_url(raw_base_url)
except ValueError as error:
raise AppSettingsValidationError(str(error)) from error
return {
"id": provider_id,
"label": str(payload.get("label", fallback_values.get("label", provider_id))).strip() or provider_id,
"provider_type": provider_type,
"base_url": str(payload.get("base_url", fallback_values.get("base_url", defaults["base_url"]))).strip()
or defaults["base_url"],
"base_url": normalized_base_url,
"timeout_seconds": _clamp_timeout(
_safe_int(
payload.get("timeout_seconds", fallback_values.get("timeout_seconds", defaults["timeout_seconds"])),

View File

@@ -300,16 +300,39 @@ def extract_text_content(filename: str, data: bytes, mime_type: str) -> Extracti
def extract_archive_members(data: bytes, depth: int = 0) -> list[ArchiveMember]:
    """Extracts processable ZIP members within configured decompression safety budgets.

    Guards against zip bombs by enforcing limits taken from ``settings``:
    maximum nesting ``depth``, a member-count cap, per-member and total
    uncompressed-size ceilings, and a maximum compression ratio. Members that
    exceed any budget are silently skipped; a corrupt archive yields ``[]``.

    Args:
        data: Raw bytes of a ZIP archive.
        depth: Current nesting level for recursive extraction.

    Returns:
        A list of ``ArchiveMember`` entries that passed every safety check.
    """
    # The old unguarded implementation (plain ``archive.read`` with no size or
    # ratio budgets) was left interleaved with this hardened one by a bad
    # merge; only the hardened path is kept.
    members: list[ArchiveMember] = []
    if depth > settings.max_zip_depth:
        return members
    total_uncompressed_bytes = 0
    try:
        with zipfile.ZipFile(io.BytesIO(data)) as archive:
            infos = [info for info in archive.infolist() if not info.is_dir()][: settings.max_zip_members]
            for info in infos:
                if info.file_size <= 0:
                    continue
                if info.file_size > settings.max_zip_member_uncompressed_bytes:
                    continue
                if total_uncompressed_bytes + info.file_size > settings.max_zip_total_uncompressed_bytes:
                    continue
                # Reject suspiciously dense members (classic zip-bomb
                # signature) before decompressing anything.
                compressed_size = max(1, int(info.compress_size))
                compression_ratio = float(info.file_size) / float(compressed_size)
                if compression_ratio > settings.max_zip_compression_ratio:
                    continue
                # Read one byte past the cap so an understated declared
                # file_size is still caught by the length check below.
                with archive.open(info, mode="r") as archive_member:
                    member_data = archive_member.read(settings.max_zip_member_uncompressed_bytes + 1)
                if len(member_data) > settings.max_zip_member_uncompressed_bytes:
                    continue
                if total_uncompressed_bytes + len(member_data) > settings.max_zip_total_uncompressed_bytes:
                    continue
                total_uncompressed_bytes += len(member_data)
                members.append(ArchiveMember(name=info.filename, data=member_data))
    except zipfile.BadZipFile:
        return []
    return members

View File

@@ -2,10 +2,10 @@
from dataclasses import dataclass
from typing import Any
from urllib.parse import urlparse, urlunparse
from openai import APIConnectionError, APIError, APITimeoutError, OpenAI
from app.core.config import normalize_and_validate_provider_base_url
from app.services.app_settings import read_task_runtime_settings
@@ -36,18 +36,9 @@ class ModelTaskRuntime:
def _normalize_base_url(raw_value: str) -> str:
    """Normalizes provider base URL and enforces SSRF protections before outbound calls.

    Delegates to the shared validator so every outbound provider call passes
    through the same normalization and host checks; ``resolve_dns=True``
    additionally vets the addresses the hostname resolves to.

    Raises:
        ValueError: If the URL fails validation.
    """
    # The previous hand-rolled urlparse/"/v1"-appending body was left
    # interleaved here by a bad merge (including an unconditional early
    # return of the OpenAI default); only the centralized validator call
    # is kept.
    return normalize_and_validate_provider_base_url(raw_value, resolve_dns=True)
def _should_fallback_to_chat(error: Exception) -> bool:
@@ -137,11 +128,16 @@ def resolve_task_runtime(task_name: str) -> ModelTaskRuntime:
if provider_type != "openai_compatible":
raise ModelTaskError(f"unsupported_provider_type:{provider_type}")
try:
normalized_base_url = _normalize_base_url(str(provider_payload.get("base_url", "https://api.openai.com/v1")))
except ValueError as error:
raise ModelTaskError(f"invalid_provider_base_url:{error}") from error
return ModelTaskRuntime(
task_name=task_name,
provider_id=str(provider_payload.get("id", "")),
provider_type=provider_type,
base_url=_normalize_base_url(str(provider_payload.get("base_url", "https://api.openai.com/v1"))),
base_url=normalized_base_url,
timeout_seconds=int(provider_payload.get("timeout_seconds", 45)),
api_key=str(provider_payload.get("api_key", "")).strip() or "no-key-required",
model=str(task_payload.get("model", "")).strip(),