Harden auth, redaction, upload size checks, and compose token requirements
@@ -5,12 +5,16 @@ import re
 from pathlib import Path
 from typing import Any

-from app.core.config import get_settings
+from app.core.config import get_settings, normalize_and_validate_provider_base_url


 settings = get_settings()


+class AppSettingsValidationError(ValueError):
+    """Raised when user-provided settings values fail security or contract validation."""
+
+
 TASK_OCR_HANDWRITING = "ocr_handwriting"
 TASK_SUMMARY_GENERATION = "summary_generation"
 TASK_ROUTING_CLASSIFICATION = "routing_classification"
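
The validator is only imported here; its implementation in app.core.config is not part of this diff. A minimal sketch of what such a helper might look like, combining the /v1 path normalization removed later in this diff with scheme and address checks — every name, check, and error string below is an assumption, not the commit's actual code:

import ipaddress
import socket
from urllib.parse import urlparse, urlunparse


def normalize_and_validate_provider_base_url(raw_value: str, resolve_dns: bool = False) -> str:
    """Sketch: normalize an OpenAI-compatible base URL and reject unsafe targets."""
    parsed = urlparse(raw_value.strip().rstrip("/"))
    if parsed.scheme not in {"http", "https"}:
        raise ValueError(f"unsupported_url_scheme:{parsed.scheme or 'missing'}")
    if not parsed.hostname:
        raise ValueError("missing_host")
    if resolve_dns:
        # Resolve the hostname and refuse private, loopback, and link-local
        # addresses so outbound calls cannot be steered at internal services.
        for info in socket.getaddrinfo(parsed.hostname, None):
            address = ipaddress.ip_address(info[4][0])
            if address.is_private or address.is_loopback or address.is_link_local:
                raise ValueError(f"forbidden_host:{parsed.hostname}")
    path = parsed.path or ""
    if not path.endswith("/v1"):
        path = f"{path}/v1" if path else "/v1"
    return urlunparse(parsed._replace(path=path))
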
@@ -156,13 +160,13 @@ def _clamp_cards_per_page(value: int) -> int:
 def _clamp_processing_log_document_sessions(value: int) -> int:
     """Clamps the number of recent document log sessions kept during cleanup."""

-    return max(0, min(20, value))
+    return max(0, min(settings.processing_log_max_document_sessions, value))


 def _clamp_processing_log_unbound_entries(value: int) -> int:
     """Clamps retained unbound processing log events kept during cleanup."""

-    return max(0, min(400, value))
+    return max(0, min(settings.processing_log_max_unbound_entries, value))


 def _clamp_predefined_entries_limit(value: int) -> int:
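
Both helpers keep the same clamp shape; only the ceiling moves from a hardcoded literal into configuration. A quick illustration of that shape, assuming a configured ceiling of 20:

# Illustration only; 20 stands in for the configured ceiling.
ceiling = 20
assert max(0, min(ceiling, 35)) == 20   # values above the ceiling clamp down
assert max(0, min(ceiling, -5)) == 0    # negative values clamp up to zero
assert max(0, min(ceiling, 7)) == 7     # in-range values pass through unchanged
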
@@ -242,12 +246,19 @@ def _normalize_provider(
     api_key_value = payload.get("api_key", fallback_values.get("api_key", defaults["api_key"]))
     api_key = str(api_key_value).strip() if api_key_value is not None else ""

+    raw_base_url = str(payload.get("base_url", fallback_values.get("base_url", defaults["base_url"]))).strip()
+    if not raw_base_url:
+        raw_base_url = str(defaults["base_url"]).strip()
+    try:
+        normalized_base_url = normalize_and_validate_provider_base_url(raw_base_url)
+    except ValueError as error:
+        raise AppSettingsValidationError(str(error)) from error
+
     return {
         "id": provider_id,
         "label": str(payload.get("label", fallback_values.get("label", provider_id))).strip() or provider_id,
         "provider_type": provider_type,
-        "base_url": str(payload.get("base_url", fallback_values.get("base_url", defaults["base_url"]))).strip()
-        or defaults["base_url"],
+        "base_url": normalized_base_url,
         "timeout_seconds": _clamp_timeout(
             _safe_int(
                 payload.get("timeout_seconds", fallback_values.get("timeout_seconds", defaults["timeout_seconds"])),
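
The calling code is not in this diff. One plausible way a route layer could surface the new AppSettingsValidationError, assuming a FastAPI-style app (an assumption) and treating the positional arguments as placeholders for the real signature:

# Hypothetical route-layer handling; names and arguments are placeholders.
from fastapi import HTTPException

try:
    provider = _normalize_provider(provider_id, payload, fallback_values, defaults)
except AppSettingsValidationError as error:
    # Reject the settings write with a client error instead of persisting
    # an unvalidated base_url.
    raise HTTPException(status_code=422, detail=str(error)) from error
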
@@ -300,16 +300,39 @@ def extract_text_content(filename: str, data: bytes, mime_type: str) -> Extracti


 def extract_archive_members(data: bytes, depth: int = 0) -> list[ArchiveMember]:
-    """Extracts processable members from zip archives with configurable depth limits."""
+    """Extracts processable ZIP members within configured decompression safety budgets."""

     members: list[ArchiveMember] = []
     if depth > settings.max_zip_depth:
         return members

-    with zipfile.ZipFile(io.BytesIO(data)) as archive:
-        infos = [info for info in archive.infolist() if not info.is_dir()][: settings.max_zip_members]
-        for info in infos:
-            member_data = archive.read(info.filename)
-            members.append(ArchiveMember(name=info.filename, data=member_data))
+    total_uncompressed_bytes = 0
+    try:
+        with zipfile.ZipFile(io.BytesIO(data)) as archive:
+            infos = [info for info in archive.infolist() if not info.is_dir()][: settings.max_zip_members]
+            for info in infos:
+                if info.file_size <= 0:
+                    continue
+                if info.file_size > settings.max_zip_member_uncompressed_bytes:
+                    continue
+                if total_uncompressed_bytes + info.file_size > settings.max_zip_total_uncompressed_bytes:
+                    continue
+
+                compressed_size = max(1, int(info.compress_size))
+                compression_ratio = float(info.file_size) / float(compressed_size)
+                if compression_ratio > settings.max_zip_compression_ratio:
+                    continue
+
+                with archive.open(info, mode="r") as archive_member:
+                    member_data = archive_member.read(settings.max_zip_member_uncompressed_bytes + 1)
+                if len(member_data) > settings.max_zip_member_uncompressed_bytes:
+                    continue
+                if total_uncompressed_bytes + len(member_data) > settings.max_zip_total_uncompressed_bytes:
+                    continue
+
+                total_uncompressed_bytes += len(member_data)
+                members.append(ArchiveMember(name=info.filename, data=member_data))
+    except zipfile.BadZipFile:
+        return []

     return members
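
A self-check sketch for the hardened extractor. The limit names come from the code above; the concrete values behind them (say, a 10 MiB per-member cap and a 100x compression-ratio cap) are assumptions about the configuration:

# Sketch: a tiny honest member survives, a zip-bomb-style member is skipped.
import io
import zipfile

buffer = io.BytesIO()
with zipfile.ZipFile(buffer, mode="w", compression=zipfile.ZIP_DEFLATED) as archive:
    archive.writestr("ok.txt", b"hello")                # small honest member, kept
    archive.writestr("bomb.txt", b"\x00" * 50_000_000)  # 50 MB of zeros: trips the size and ratio guards

members = extract_archive_members(buffer.getvalue())
assert [member.name for member in members] == ["ok.txt"]  # the bomb member is dropped, not extracted

Because declared sizes are re-checked against the bytes actually read (via the size-capped read), a forged central directory cannot smuggle an oversized member past the declared-size checks.
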
@@ -2,10 +2,10 @@

 from dataclasses import dataclass
 from typing import Any
-from urllib.parse import urlparse, urlunparse

 from openai import APIConnectionError, APIError, APITimeoutError, OpenAI

+from app.core.config import normalize_and_validate_provider_base_url
 from app.services.app_settings import read_task_runtime_settings

@@ -36,18 +36,9 @@ class ModelTaskRuntime:


 def _normalize_base_url(raw_value: str) -> str:
-    """Normalizes provider base URL and appends /v1 for OpenAI-compatible servers."""
+    """Normalizes provider base URL and enforces SSRF protections before outbound calls."""

-    trimmed = raw_value.strip().rstrip("/")
-    if not trimmed:
-        return "https://api.openai.com/v1"
-
-    parsed = urlparse(trimmed)
-    path = parsed.path or ""
-    if not path.endswith("/v1"):
-        path = f"{path}/v1" if path else "/v1"
-
-    return urlunparse(parsed._replace(path=path))
+    return normalize_and_validate_provider_base_url(raw_value, resolve_dns=True)
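
A behavioral sketch of the swap, assuming the shared validator rejects private and link-local hosts when resolve_dns=True (as its name and the docstring suggest, but the implementation is not shown here):

_normalize_base_url("https://api.openai.com")         # -> "https://api.openai.com/v1", as before
_normalize_base_url("http://169.254.169.254/latest")  # now raises ValueError (link-local cloud metadata address)
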
@@ -137,11 +128,16 @@ def resolve_task_runtime(task_name: str) -> ModelTaskRuntime:
     if provider_type != "openai_compatible":
         raise ModelTaskError(f"unsupported_provider_type:{provider_type}")

+    try:
+        normalized_base_url = _normalize_base_url(str(provider_payload.get("base_url", "https://api.openai.com/v1")))
+    except ValueError as error:
+        raise ModelTaskError(f"invalid_provider_base_url:{error}") from error
+
     return ModelTaskRuntime(
         task_name=task_name,
         provider_id=str(provider_payload.get("id", "")),
         provider_type=provider_type,
-        base_url=_normalize_base_url(str(provider_payload.get("base_url", "https://api.openai.com/v1"))),
+        base_url=normalized_base_url,
         timeout_seconds=int(provider_payload.get("timeout_seconds", 45)),
         api_key=str(provider_payload.get("api_key", "")).strip() or "no-key-required",
         model=str(task_payload.get("model", "")).strip(),
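
Wrapping the ValueError means callers see one exception type for all task-resolution failures. A hypothetical caller-side sketch; resolve_task_runtime and ModelTaskError come from the code above, the task constant mirrors the ones added earlier in this diff, and the logging setup is assumed:

import logging

logger = logging.getLogger(__name__)

try:
    runtime = resolve_task_runtime(TASK_OCR_HANDWRITING)
except ModelTaskError as error:
    # e.g. "invalid_provider_base_url:..." when the configured URL is blocked
    logger.warning("task runtime unavailable: %s", error)
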