# File: ledgerdock/backend/tests/test_security_controls.py (Python, 1162 lines, 46 KiB)
"""Unit coverage for API auth, SSRF validation, and processing-log redaction controls."""
from __future__ import annotations
from datetime import UTC, datetime
import io
import socket
import sys
import uuid
from pathlib import Path
from types import ModuleType, SimpleNamespace
import unittest
from unittest.mock import patch
import zipfile
# Make the backend package root importable when this test file runs directly.
BACKEND_ROOT = Path(__file__).resolve().parents[1]
_backend_root_str = str(BACKEND_ROOT)
if _backend_root_str not in sys.path:
    sys.path.insert(0, _backend_root_str)
if "pydantic_settings" not in sys.modules:
    pydantic_settings_stub = ModuleType("pydantic_settings")

    class _BaseSettings:
        """Lightweight BaseSettings replacement so settings import without pydantic."""

        def __init__(self, **kwargs: object) -> None:
            # Mirror pydantic-settings behavior of exposing fields as attributes.
            for field_name, field_value in kwargs.items():
                setattr(self, field_name, field_value)

    def _settings_config_dict(**kwargs: object) -> dict[str, object]:
        """Returns the keyword arguments as a plain dict, matching SettingsConfigDict."""
        return kwargs

    pydantic_settings_stub.BaseSettings = _BaseSettings
    pydantic_settings_stub.SettingsConfigDict = _settings_config_dict
    sys.modules["pydantic_settings"] = pydantic_settings_stub
if "fastapi" not in sys.modules:
    fastapi_stub = ModuleType("fastapi")

    class _APIRouter:
        """Minimal APIRouter stand-in that records args and accepts route decorators."""

        def __init__(self, *args: object, **kwargs: object) -> None:
            self.args = args
            self.kwargs = kwargs

        def post(self, *_args: object, **_kwargs: object):  # type: ignore[no-untyped-def]
            """Returns a pass-through decorator for POST route declarations."""

            def passthrough(func):  # type: ignore[no-untyped-def]
                return func

            return passthrough

        def get(self, *_args: object, **_kwargs: object):  # type: ignore[no-untyped-def]
            """Returns a pass-through decorator for GET route declarations."""

            def passthrough(func):  # type: ignore[no-untyped-def]
                return func

            return passthrough

    class _Request:
        """Placeholder request type so route modules import cleanly."""

    class _HTTPException(Exception):
        """HTTPException replacement carrying status, detail, and optional headers."""

        def __init__(self, status_code: int, detail: str, headers: dict[str, str] | None = None) -> None:
            super().__init__(detail)
            self.status_code = status_code
            self.detail = detail
            self.headers = headers or {}

    class _Status:
        """Status-code namespace covering the codes exercised by auth tests."""

        HTTP_401_UNAUTHORIZED = 401
        HTTP_403_FORBIDDEN = 403
        HTTP_429_TOO_MANY_REQUESTS = 429
        HTTP_503_SERVICE_UNAVAILABLE = 503

    def _depends(dependency):  # type: ignore[no-untyped-def]
        """Returns the dependency unchanged, standing in for fastapi.Depends."""
        return dependency

    fastapi_stub.APIRouter = _APIRouter
    fastapi_stub.Depends = _depends
    fastapi_stub.HTTPException = _HTTPException
    fastapi_stub.Request = _Request
    fastapi_stub.status = _Status()
    sys.modules["fastapi"] = fastapi_stub
if "fastapi.security" not in sys.modules:
    fastapi_security_stub = ModuleType("fastapi.security")

    class _HTTPAuthorizationCredentials:
        """Bearer-credential container matching the fastapi.security interface."""

        def __init__(self, *, scheme: str, credentials: str) -> None:
            self.scheme = scheme
            self.credentials = credentials

    class _HTTPBearer:
        """HTTPBearer replacement that only records the auto_error flag."""

        def __init__(self, auto_error: bool = True) -> None:
            self.auto_error = auto_error

    fastapi_security_stub.HTTPAuthorizationCredentials = _HTTPAuthorizationCredentials
    fastapi_security_stub.HTTPBearer = _HTTPBearer
    sys.modules["fastapi.security"] = fastapi_security_stub
if "magic" not in sys.modules:
    magic_stub = ModuleType("magic")

    def _from_buffer(_data: bytes, mime: bool = True) -> str:
        """Returns a fixed MIME value so extractor imports work without libmagic."""
        return "application/octet-stream" if mime else ""

    magic_stub.from_buffer = _from_buffer
    sys.modules["magic"] = magic_stub
if "docx" not in sys.modules:
    docx_stub = ModuleType("docx")

    class _DocxDocument:
        """Document replacement exposing an empty paragraph list for extractor imports."""

        def __init__(self, *_args: object, **_kwargs: object) -> None:
            self.paragraphs: list[SimpleNamespace] = []

    docx_stub.Document = _DocxDocument
    sys.modules["docx"] = docx_stub
if "openpyxl" not in sys.modules:
    openpyxl_stub = ModuleType("openpyxl")

    class _Workbook:
        """Workbook replacement exposing an empty worksheet list."""

        worksheets: list[SimpleNamespace] = []

    def _load_workbook(*_args: object, **_kwargs: object) -> _Workbook:
        """Returns an empty placeholder workbook regardless of input."""
        return _Workbook()

    openpyxl_stub.load_workbook = _load_workbook
    sys.modules["openpyxl"] = openpyxl_stub
if "PIL" not in sys.modules:
    pil_stub = ModuleType("PIL")

    class _Image:
        """PIL.Image replacement covering the names used by extractor/handwriting code."""

        class Resampling:
            """Enum-like namespace used by the handwriting image resize path."""

            LANCZOS = 1

        @staticmethod
        def open(*_args: object, **_kwargs: object) -> "_Image":
            """Raises because image decoding is unavailable in dependency-light tests."""
            raise RuntimeError("Image.open is not available in stub")

    class _ImageOps:
        """PIL.ImageOps replacement for import compatibility."""

        @staticmethod
        def exif_transpose(image: object) -> object:
            """Returns the image unchanged; EXIF handling is out of scope here."""
            return image

    pil_stub.Image = _Image
    pil_stub.ImageOps = _ImageOps
    sys.modules["PIL"] = pil_stub
if "pypdf" not in sys.modules:
    pypdf_stub = ModuleType("pypdf")

    class _PdfReader:
        """PdfReader replacement exposing an empty page list for extractor imports."""

        def __init__(self, *_args: object, **_kwargs: object) -> None:
            self.pages: list[SimpleNamespace] = []

    pypdf_stub.PdfReader = _PdfReader
    sys.modules["pypdf"] = pypdf_stub
if "pymupdf" not in sys.modules:
    pymupdf_stub = ModuleType("pymupdf")

    class _Matrix:
        """Matrix placeholder accepted wherever pymupdf.Matrix is constructed."""

        def __init__(self, *_args: object, **_kwargs: object) -> None:
            pass

    def _open(*_args: object, **_kwargs: object) -> object:
        """Raises because preview rendering is unavailable in dependency-light tests."""
        raise RuntimeError("pymupdf is not available in stub")

    pymupdf_stub.Matrix = _Matrix
    pymupdf_stub.open = _open
    sys.modules["pymupdf"] = pymupdf_stub
if "app.services.handwriting" not in sys.modules:
    handwriting_stub = ModuleType("app.services.handwriting")

    class _HandwritingError(Exception):
        """Base transcription error for extractor import compatibility."""

    class _HandwritingNotConfiguredError(_HandwritingError):
        """Not-configured transcription error for extractor import compatibility."""

    class _HandwritingTimeoutError(_HandwritingError):
        """Timeout transcription error for extractor import compatibility."""

    def _classify_image_text_bytes(*_args: object, **_kwargs: object) -> SimpleNamespace:
        """Returns a fixed image-text classification result."""
        return SimpleNamespace(label="unknown", confidence=0.0, provider="stub", model="stub")

    def _transcribe_handwriting_bytes(*_args: object, **_kwargs: object) -> SimpleNamespace:
        """Returns a fixed empty handwriting transcription result."""
        return SimpleNamespace(text="", uncertainties=[], provider="stub", model="stub")

    handwriting_stub.IMAGE_TEXT_TYPE_NO_TEXT = "no_text"
    handwriting_stub.IMAGE_TEXT_TYPE_UNKNOWN = "unknown"
    handwriting_stub.IMAGE_TEXT_TYPE_HANDWRITING = "handwriting"
    handwriting_stub.HandwritingTranscriptionError = _HandwritingError
    handwriting_stub.HandwritingTranscriptionNotConfiguredError = _HandwritingNotConfiguredError
    handwriting_stub.HandwritingTranscriptionTimeoutError = _HandwritingTimeoutError
    handwriting_stub.classify_image_text_bytes = _classify_image_text_bytes
    handwriting_stub.transcribe_handwriting_bytes = _transcribe_handwriting_bytes
    sys.modules["app.services.handwriting"] = handwriting_stub
if "app.services.handwriting_style" not in sys.modules:
    handwriting_style_stub = ModuleType("app.services.handwriting_style")

    def _assign_handwriting_style(*_args: object, **_kwargs: object) -> SimpleNamespace:
        """Returns a fixed style assignment payload for worker import compatibility."""
        return SimpleNamespace(
            style_cluster_id="cluster-1",
            matched_existing=False,
            similarity=0.0,
            vector_distance=0.0,
            compared_neighbors=0,
            match_min_similarity=0.0,
            bootstrap_match_min_similarity=0.0,
        )

    def _delete_handwriting_style_document(*_args: object, **_kwargs: object) -> None:
        """No-op style-document delete for worker import compatibility."""
        return None

    handwriting_style_stub.assign_handwriting_style = _assign_handwriting_style
    handwriting_style_stub.delete_handwriting_style_document = _delete_handwriting_style_document
    sys.modules["app.services.handwriting_style"] = handwriting_style_stub
if "app.services.routing_pipeline" not in sys.modules:
    routing_pipeline_stub = ModuleType("app.services.routing_pipeline")

    def _apply_routing_decision(*_args: object, **_kwargs: object) -> None:
        """No-op routing application for worker import compatibility."""
        return None

    def _classify_document_routing(*_args: object, **_kwargs: object) -> dict[str, object]:
        """Returns a fixed routing decision payload for worker import compatibility."""
        return {"chosen_path": None, "chosen_tags": []}

    def _summarize_document(*_args: object, **_kwargs: object) -> str:
        """Returns fixed summary text for worker import compatibility."""
        return "summary"

    def _upsert_semantic_index(*_args: object, **_kwargs: object) -> None:
        """No-op semantic index update for worker import compatibility."""
        return None

    routing_pipeline_stub.apply_routing_decision = _apply_routing_decision
    routing_pipeline_stub.classify_document_routing = _classify_document_routing
    routing_pipeline_stub.summarize_document = _summarize_document
    routing_pipeline_stub.upsert_semantic_index = _upsert_semantic_index
    sys.modules["app.services.routing_pipeline"] = routing_pipeline_stub
from fastapi import HTTPException
from app.api.auth import AuthContext, require_admin
from app.api import routes_auth as auth_routes_module
from app.core import config as config_module
from app.models.auth import UserRole
from app.models.processing_log import sanitize_processing_log_payload_value, sanitize_processing_log_text
from app.schemas.processing_logs import ProcessingLogEntryResponse
from app.services import auth_login_throttle as auth_login_throttle_module
from app.services import extractor as extractor_module
from app.worker import tasks as worker_tasks_module
def _security_settings(
*,
allowlist: list[str] | None = None,
allow_http: bool = False,
allow_private_network: bool = False,
) -> SimpleNamespace:
"""Builds lightweight settings object for provider URL validation tests."""
return SimpleNamespace(
provider_base_url_allowlist=allowlist if allowlist is not None else ["api.openai.com"],
provider_base_url_allow_http=allow_http,
provider_base_url_allow_private_network=allow_private_network,
)
class AuthDependencyTests(unittest.TestCase):
    """Verifies role-based admin authorization behavior."""

    def test_require_admin_rejects_user_role(self) -> None:
        """User role cannot access admin-only endpoints."""
        non_admin_context = AuthContext(
            user_id=uuid.uuid4(),
            username="user",
            role=UserRole.USER,
            session_id=uuid.uuid4(),
            expires_at=datetime.now(UTC),
        )
        with self.assertRaises(HTTPException) as raised:
            require_admin(context=non_admin_context)
        self.assertEqual(raised.exception.status_code, 403)

    def test_require_admin_accepts_admin_role(self) -> None:
        """Admin role is accepted for admin-only endpoints."""
        admin_context = AuthContext(
            user_id=uuid.uuid4(),
            username="admin",
            role=UserRole.ADMIN,
            session_id=uuid.uuid4(),
            expires_at=datetime.now(UTC),
        )
        resolved = require_admin(context=admin_context)
        self.assertEqual(resolved.role, UserRole.ADMIN)
class _FakeRedisPipeline:
"""Provides deterministic Redis pipeline behavior for login throttle tests."""
def __init__(self, redis_client: "_FakeRedis") -> None:
self._redis_client = redis_client
self._operations: list[tuple[str, tuple[object, ...]]] = []
def incr(self, key: str, amount: int) -> "_FakeRedisPipeline":
"""Queues one counter increment operation for pipeline execution."""
self._operations.append(("incr", (key, amount)))
return self
def expire(self, key: str, ttl_seconds: int) -> "_FakeRedisPipeline":
"""Queues one key expiration operation for pipeline execution."""
self._operations.append(("expire", (key, ttl_seconds)))
return self
def execute(self) -> list[object]:
"""Executes queued operations in order and returns Redis-like result values."""
results: list[object] = []
for operation, arguments in self._operations:
if operation == "incr":
key, amount = arguments
previous = int(self._redis_client.values.get(str(key), 0))
updated = previous + int(amount)
self._redis_client.values[str(key)] = updated
results.append(updated)
elif operation == "expire":
key, ttl_seconds = arguments
self._redis_client.ttl_seconds[str(key)] = int(ttl_seconds)
results.append(True)
return results
class _FakeRedis:
"""In-memory Redis replacement with TTL behavior needed by throttle tests."""
def __init__(self) -> None:
self.values: dict[str, object] = {}
self.ttl_seconds: dict[str, int] = {}
def pipeline(self, transaction: bool = True) -> _FakeRedisPipeline:
"""Creates a fake transaction pipeline for grouped increment operations."""
_ = transaction
return _FakeRedisPipeline(self)
def set(self, key: str, value: str, ex: int | None = None) -> bool:
"""Stores key values and optional TTL metadata used by lockout keys."""
self.values[key] = value
if ex is not None:
self.ttl_seconds[key] = int(ex)
return True
def ttl(self, key: str) -> int:
"""Returns TTL for existing keys or Redis-compatible missing-key indicator."""
return int(self.ttl_seconds.get(key, -2))
def delete(self, *keys: str) -> int:
"""Deletes keys and returns number of removed entries."""
removed_count = 0
for key in keys:
if key in self.values:
self.values.pop(key, None)
removed_count += 1
self.ttl_seconds.pop(key, None)
return removed_count
def _login_throttle_settings(
*,
failure_limit: int = 2,
failure_window_seconds: int = 60,
lockout_base_seconds: int = 10,
lockout_max_seconds: int = 40,
) -> SimpleNamespace:
"""Builds deterministic login-throttle settings for service-level unit coverage."""
return SimpleNamespace(
auth_login_failure_limit=failure_limit,
auth_login_failure_window_seconds=failure_window_seconds,
auth_login_lockout_base_seconds=lockout_base_seconds,
auth_login_lockout_max_seconds=lockout_max_seconds,
)
class AuthLoginThrottleServiceTests(unittest.TestCase):
    """Verifies login throttle lockout progression, cap behavior, and clear semantics."""

    def test_failed_attempts_trigger_lockout_after_limit(self) -> None:
        """Failed attempts beyond configured limit activate login lockouts."""
        fake_redis = _FakeRedis()
        with (
            patch.object(
                auth_login_throttle_module,
                "get_settings",
                return_value=_login_throttle_settings(failure_limit=2, lockout_base_seconds=12),
            ),
            patch.object(auth_login_throttle_module, "get_redis", return_value=fake_redis),
        ):
            # First two failures stay under the limit and produce no lockout.
            self.assertEqual(
                auth_login_throttle_module.record_failed_login_attempt(
                    username="admin",
                    ip_address="203.0.113.10",
                ),
                0,
            )
            self.assertEqual(
                auth_login_throttle_module.record_failed_login_attempt(
                    username="admin",
                    ip_address="203.0.113.10",
                ),
                0,
            )
            lockout_seconds = auth_login_throttle_module.record_failed_login_attempt(
                username="admin",
                ip_address="203.0.113.10",
            )
            throttle_status = auth_login_throttle_module.check_login_throttle(
                username="admin",
                ip_address="203.0.113.10",
            )
        self.assertEqual(lockout_seconds, 12)
        self.assertTrue(throttle_status.is_throttled)
        self.assertEqual(throttle_status.retry_after_seconds, 12)

    def test_lockout_duration_escalates_and_respects_max_cap(self) -> None:
        """Repeated failures after threshold double lockout duration up to configured maximum."""
        fake_redis = _FakeRedis()
        with (
            patch.object(
                auth_login_throttle_module,
                "get_settings",
                return_value=_login_throttle_settings(
                    failure_limit=1,
                    lockout_base_seconds=10,
                    lockout_max_seconds=25,
                ),
            ),
            patch.object(auth_login_throttle_module, "get_redis", return_value=fake_redis),
        ):
            observed_lockouts = [
                auth_login_throttle_module.record_failed_login_attempt(
                    username="admin",
                    ip_address="198.51.100.15",
                )
                for _ in range(4)
            ]
        # 0 below threshold, then base, doubled, and finally capped at the max.
        self.assertEqual(observed_lockouts, [0, 10, 20, 25])

    def test_clear_login_throttle_removes_active_lockout_state(self) -> None:
        """Successful login clears active lockout keys for username and IP subjects."""
        fake_redis = _FakeRedis()
        with (
            patch.object(
                auth_login_throttle_module,
                "get_settings",
                return_value=_login_throttle_settings(
                    failure_limit=1,
                    lockout_base_seconds=15,
                    lockout_max_seconds=30,
                ),
            ),
            patch.object(auth_login_throttle_module, "get_redis", return_value=fake_redis),
        ):
            auth_login_throttle_module.record_failed_login_attempt(
                username="admin",
                ip_address="192.0.2.20",
            )
            auth_login_throttle_module.record_failed_login_attempt(
                username="admin",
                ip_address="192.0.2.20",
            )
            throttled_before_clear = auth_login_throttle_module.check_login_throttle(
                username="admin",
                ip_address="192.0.2.20",
            )
            auth_login_throttle_module.clear_login_throttle(
                username="admin",
                ip_address="192.0.2.20",
            )
            throttled_after_clear = auth_login_throttle_module.check_login_throttle(
                username="admin",
                ip_address="192.0.2.20",
            )
        self.assertTrue(throttled_before_clear.is_throttled)
        self.assertFalse(throttled_after_clear.is_throttled)
        self.assertEqual(throttled_after_clear.retry_after_seconds, 0)

    def test_backend_errors_raise_runtime_error(self) -> None:
        """Redis backend failures are surfaced as RuntimeError for caller fail-closed handling."""

        class _BrokenRedis:
            """Raises RedisError for all Redis interactions used by login throttle service."""

            def ttl(self, _key: str) -> int:
                raise auth_login_throttle_module.RedisError("redis unavailable")

        with patch.object(auth_login_throttle_module, "get_redis", return_value=_BrokenRedis()):
            with self.assertRaises(RuntimeError):
                auth_login_throttle_module.check_login_throttle(
                    username="admin",
                    ip_address="203.0.113.88",
                )
class AuthLoginRouteThrottleTests(unittest.TestCase):
    """Verifies `/auth/login` route throttle responses and success-flow state clearing."""

    class _SessionStub:
        """Tracks commit calls for route-level login tests without database dependencies."""

        def __init__(self) -> None:
            self.commit_count = 0

        def commit(self) -> None:
            """Records one commit invocation."""
            self.commit_count += 1

    @staticmethod
    def _request_stub(ip_address: str = "203.0.113.2", user_agent: str = "unit-test") -> SimpleNamespace:
        """Builds request-like object containing client host and user-agent header fields."""
        return SimpleNamespace(
            client=SimpleNamespace(host=ip_address),
            headers={"user-agent": user_agent},
        )

    def test_login_rejects_when_precheck_reports_active_throttle(self) -> None:
        """Pre-auth throttle checks return a stable 429 response without credential lookup."""
        payload = auth_routes_module.AuthLoginRequest(username="admin", password="bad-password")
        session = self._SessionStub()
        throttled_status = auth_login_throttle_module.LoginThrottleStatus(
            is_throttled=True,
            retry_after_seconds=21,
        )
        with (
            patch.object(auth_routes_module, "check_login_throttle", return_value=throttled_status),
            patch.object(auth_routes_module, "authenticate_user") as authenticate_mock,
        ):
            with self.assertRaises(HTTPException) as raised:
                auth_routes_module.login(
                    payload=payload,
                    request=self._request_stub(),
                    session=session,
                )
        self.assertEqual(raised.exception.status_code, 429)
        self.assertEqual(raised.exception.detail, auth_routes_module.LOGIN_THROTTLED_DETAIL)
        self.assertEqual(raised.exception.headers.get("Retry-After"), "21")
        authenticate_mock.assert_not_called()
        self.assertEqual(session.commit_count, 0)

    def test_login_returns_throttle_response_when_failure_crosses_limit(self) -> None:
        """Failed credentials return stable 429 response once lockout threshold is crossed."""
        payload = auth_routes_module.AuthLoginRequest(username="admin", password="bad-password")
        session = self._SessionStub()
        with (
            patch.object(
                auth_routes_module,
                "check_login_throttle",
                return_value=auth_login_throttle_module.LoginThrottleStatus(
                    is_throttled=False,
                    retry_after_seconds=0,
                ),
            ),
            patch.object(auth_routes_module, "authenticate_user", return_value=None),
            patch.object(auth_routes_module, "record_failed_login_attempt", return_value=30),
        ):
            with self.assertRaises(HTTPException) as raised:
                auth_routes_module.login(
                    payload=payload,
                    request=self._request_stub(),
                    session=session,
                )
        self.assertEqual(raised.exception.status_code, 429)
        self.assertEqual(raised.exception.detail, auth_routes_module.LOGIN_THROTTLED_DETAIL)
        self.assertEqual(raised.exception.headers.get("Retry-After"), "30")
        self.assertEqual(session.commit_count, 0)

    def test_login_clears_throttle_state_after_successful_authentication(self) -> None:
        """Successful login clears throttle state and commits issued session token."""
        payload = auth_routes_module.AuthLoginRequest(username="admin", password="correct-password")
        session = self._SessionStub()
        fake_user = SimpleNamespace(
            id=uuid.uuid4(),
            username="admin",
            role=UserRole.ADMIN,
        )
        fake_session = SimpleNamespace(
            token="session-token",
            expires_at=datetime.now(UTC),
        )
        with (
            patch.object(
                auth_routes_module,
                "check_login_throttle",
                return_value=auth_login_throttle_module.LoginThrottleStatus(
                    is_throttled=False,
                    retry_after_seconds=0,
                ),
            ),
            patch.object(auth_routes_module, "authenticate_user", return_value=fake_user),
            patch.object(auth_routes_module, "clear_login_throttle") as clear_mock,
            patch.object(auth_routes_module, "issue_user_session", return_value=fake_session),
        ):
            response = auth_routes_module.login(
                payload=payload,
                request=self._request_stub(),
                session=session,
            )
        self.assertEqual(response.access_token, "session-token")
        self.assertEqual(response.user.username, "admin")
        clear_mock.assert_called_once()
        self.assertEqual(session.commit_count, 1)

    def test_login_returns_503_when_throttle_backend_is_unavailable(self) -> None:
        """Throttle backend errors fail closed with a deterministic 503 login response."""
        payload = auth_routes_module.AuthLoginRequest(username="admin", password="password")
        session = self._SessionStub()
        with patch.object(auth_routes_module, "check_login_throttle", side_effect=RuntimeError("redis down")):
            with self.assertRaises(HTTPException) as raised:
                auth_routes_module.login(
                    payload=payload,
                    request=self._request_stub(),
                    session=session,
                )
        self.assertEqual(raised.exception.status_code, 503)
        self.assertEqual(raised.exception.detail, auth_routes_module.LOGIN_RATE_LIMITER_UNAVAILABLE_DETAIL)
        self.assertEqual(session.commit_count, 0)
class ProviderBaseUrlValidationTests(unittest.TestCase):
    """Verifies allowlist, scheme, and private-network SSRF protections."""

    def setUp(self) -> None:
        """Clears URL validation cache to keep tests independent."""
        config_module._normalize_and_validate_provider_base_url_cached.cache_clear()

    def test_validation_accepts_allowlisted_https_url(self) -> None:
        """Allowlisted HTTPS URLs are normalized with /v1 suffix."""
        with patch.object(config_module, "get_settings", return_value=_security_settings(allowlist=["api.openai.com"])):
            normalized = config_module.normalize_and_validate_provider_base_url("https://api.openai.com")
        self.assertEqual(normalized, "https://api.openai.com/v1")

    def test_validation_rejects_non_allowlisted_host(self) -> None:
        """Hosts outside configured allowlist are rejected."""
        with patch.object(config_module, "get_settings", return_value=_security_settings(allowlist=["api.openai.com"])):
            with self.assertRaises(ValueError):
                config_module.normalize_and_validate_provider_base_url("https://example.org/v1")

    def test_validation_rejects_private_ip_literal(self) -> None:
        """Private and loopback IP literals are blocked."""
        with patch.object(config_module, "get_settings", return_value=_security_settings(allowlist=[])):
            with self.assertRaises(ValueError):
                config_module.normalize_and_validate_provider_base_url("https://127.0.0.1/v1")

    def test_validation_rejects_private_ip_after_dns_resolution(self) -> None:
        """DNS rebind protection blocks public hostnames resolving to private addresses."""
        mocked_dns_response = [
            (socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", ("127.0.0.1", 443)),
        ]
        with (
            patch.object(config_module, "get_settings", return_value=_security_settings(allowlist=["api.openai.com"])),
            patch.object(config_module.socket, "getaddrinfo", return_value=mocked_dns_response),
        ):
            with self.assertRaises(ValueError):
                config_module.normalize_and_validate_provider_base_url(
                    "https://api.openai.com/v1",
                    resolve_dns=True,
                )

    def test_resolve_dns_validation_revalidates_each_call(self) -> None:
        """DNS-resolved validation is not cached and re-checks host resolution each call."""
        mocked_dns_response = [
            (socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", ("8.8.8.8", 443)),
        ]
        with (
            patch.object(config_module, "get_settings", return_value=_security_settings(allowlist=["api.openai.com"])),
            patch.object(config_module.socket, "getaddrinfo", return_value=mocked_dns_response) as getaddrinfo_mock,
        ):
            first = config_module.normalize_and_validate_provider_base_url(
                "https://api.openai.com/v1",
                resolve_dns=True,
            )
            second = config_module.normalize_and_validate_provider_base_url(
                "https://api.openai.com/v1",
                resolve_dns=True,
            )
        self.assertEqual(first, "https://api.openai.com/v1")
        self.assertEqual(second, "https://api.openai.com/v1")
        # Two calls must hit DNS twice, proving the resolving path bypasses the cache.
        self.assertEqual(getaddrinfo_mock.call_count, 2)
class RedisQueueSecurityTests(unittest.TestCase):
    """Verifies Redis URL security policy behavior for compatibility and strict environments."""

    def test_auto_mode_allows_insecure_redis_url_in_development(self) -> None:
        """Development mode stays backward-compatible with local unauthenticated redis URLs."""
        normalized = config_module.validate_redis_url_security(
            "redis://redis:6379/0",
            app_env="development",
            security_mode="auto",
            tls_mode="auto",
        )
        self.assertEqual(normalized, "redis://redis:6379/0")

    def test_auto_mode_rejects_missing_auth_in_production(self) -> None:
        """Production auto mode fails closed when Redis URL omits authentication."""
        with self.assertRaises(ValueError):
            config_module.validate_redis_url_security(
                "rediss://redis:6379/0",
                app_env="production",
                security_mode="auto",
                tls_mode="auto",
            )

    def test_auto_mode_rejects_plaintext_redis_in_production(self) -> None:
        """Production auto mode requires TLS transport for Redis URLs."""
        with self.assertRaises(ValueError):
            config_module.validate_redis_url_security(
                "redis://:password@redis:6379/0",
                app_env="production",
                security_mode="auto",
                tls_mode="auto",
            )

    def test_strict_mode_enforces_auth_and_tls_outside_production(self) -> None:
        """Strict mode enforces production-grade Redis controls in all environments."""
        with self.assertRaises(ValueError):
            config_module.validate_redis_url_security(
                "redis://redis:6379/0",
                app_env="development",
                security_mode="strict",
                tls_mode="auto",
            )
        normalized = config_module.validate_redis_url_security(
            "rediss://:password@redis:6379/0",
            app_env="development",
            security_mode="strict",
            tls_mode="auto",
        )
        self.assertEqual(normalized, "rediss://:password@redis:6379/0")

    def test_compat_mode_allows_insecure_redis_in_production_for_safe_migration(self) -> None:
        """Compatibility mode keeps legacy production Redis URLs usable during migration windows."""
        normalized = config_module.validate_redis_url_security(
            "redis://redis:6379/0",
            app_env="production",
            security_mode="compat",
            tls_mode="allow_insecure",
        )
        self.assertEqual(normalized, "redis://redis:6379/0")
class PreviewMimeSafetyTests(unittest.TestCase):
    """Verifies inline preview MIME safety checks for uploaded document responses."""

    def test_preview_blocks_script_capable_html_and_svg_types(self) -> None:
        """HTML and SVG MIME types are rejected for inline preview responses."""
        for blocked_mime in ("text/html", "image/svg+xml"):
            self.assertFalse(config_module.is_inline_preview_mime_type_safe(blocked_mime))

    def test_preview_allows_pdf_and_safe_image_types(self) -> None:
        """PDF and raster image MIME types stay eligible for inline preview responses."""
        for allowed_mime in ("application/pdf", "image/png"):
            self.assertTrue(config_module.is_inline_preview_mime_type_safe(allowed_mime))
def _build_zip_bytes(entries: dict[str, bytes]) -> bytes:
"""Builds in-memory ZIP bytes for archive extraction guardrail tests."""
output = io.BytesIO()
with zipfile.ZipFile(output, mode="w", compression=zipfile.ZIP_DEFLATED) as archive:
for filename, payload in entries.items():
archive.writestr(filename, payload)
return output.getvalue()
class ArchiveExtractionGuardrailTests(unittest.TestCase):
    """Verifies depth-aware archive extraction and per-call member cap enforcement."""

    @staticmethod
    def _guardrail_settings(max_depth: int) -> SimpleNamespace:
        """Builds extractor settings with the standard ZIP-bomb limits and given depth."""
        return SimpleNamespace(
            max_zip_depth=max_depth,
            max_zip_members=250,
            max_zip_member_uncompressed_bytes=25 * 1024 * 1024,
            max_zip_total_uncompressed_bytes=150 * 1024 * 1024,
            max_zip_compression_ratio=120.0,
        )

    def test_extract_archive_members_rejects_depth_at_configured_limit(self) -> None:
        """Archive member extraction is disabled at or beyond configured depth ceiling."""
        archive_bytes = _build_zip_bytes({"sample.txt": b"sample"})
        with patch.object(extractor_module, "settings", self._guardrail_settings(max_depth=2)):
            members = extractor_module.extract_archive_members(archive_bytes, depth=2)
        self.assertEqual(members, [])

    def test_extract_archive_members_respects_member_cap_argument(self) -> None:
        """Archive extraction truncates results when caller-provided member cap is lower than archive size."""
        archive_bytes = _build_zip_bytes(
            {
                "one.txt": b"1",
                "two.txt": b"2",
                "three.txt": b"3",
            }
        )
        with patch.object(extractor_module, "settings", self._guardrail_settings(max_depth=3)):
            members = extractor_module.extract_archive_members(archive_bytes, depth=0, max_members=1)
        self.assertEqual(len(members), 1)
class ArchiveLineagePropagationTests(unittest.TestCase):
    """Verifies archive lineage metadata propagation helpers used by worker descendant queueing."""

    def test_create_archive_member_document_persists_lineage_metadata(self) -> None:
        """Child archive documents include root id and incremented depth metadata."""
        parent_id = uuid.uuid4()
        parent = SimpleNamespace(
            id=parent_id,
            source_relative_path="uploads/root.zip",
            logical_path="Inbox",
            tags=["finance"],
            owner_user_id=uuid.uuid4(),
        )
        with (
            patch.object(worker_tasks_module, "store_bytes", return_value="stored/path/child.zip"),
            patch.object(worker_tasks_module, "compute_sha256", return_value="deadbeef"),
        ):
            child = worker_tasks_module._create_archive_member_document(
                parent=parent,
                member_name="nested/child.zip",
                member_data=b"zip-bytes",
                mime_type="application/zip",
                archive_root_document_id=parent_id,
                archive_depth=1,
            )
        self.assertEqual(child.parent_document_id, parent_id)
        self.assertEqual(child.metadata_json.get(worker_tasks_module.ARCHIVE_ROOT_ID_METADATA_KEY), str(parent_id))
        self.assertEqual(child.metadata_json.get(worker_tasks_module.ARCHIVE_DEPTH_METADATA_KEY), 1)
        self.assertTrue(child.is_archive_member)
        self.assertEqual(child.owner_user_id, parent.owner_user_id)

    def test_resolve_archive_lineage_prefers_existing_metadata(self) -> None:
        """Existing archive lineage metadata is reused without traversing parent relationships."""
        root_id = uuid.uuid4()
        document = SimpleNamespace(
            id=uuid.uuid4(),
            metadata_json={
                worker_tasks_module.ARCHIVE_ROOT_ID_METADATA_KEY: str(root_id),
                worker_tasks_module.ARCHIVE_DEPTH_METADATA_KEY: 3,
            },
            is_archive_member=True,
            parent_document_id=uuid.uuid4(),
        )

        class _SessionShouldNotBeUsed:
            """Fails test if lineage helper performs unnecessary parent traversals."""

            def execute(self, _statement: object) -> object:
                raise AssertionError("session.execute should not be called when metadata is present")

        resolved_root, resolved_depth = worker_tasks_module._resolve_archive_lineage(
            session=_SessionShouldNotBeUsed(),
            document=document,
        )
        self.assertEqual(resolved_root, root_id)
        self.assertEqual(resolved_depth, 3)

    def test_resolve_archive_lineage_walks_parent_chain_when_metadata_missing(self) -> None:
        """Lineage fallback traverses parent references to recover root id and depth."""
        root_id = uuid.uuid4()
        parent_id = uuid.uuid4()
        root_document = SimpleNamespace(id=root_id, parent_document_id=None)
        parent_document = SimpleNamespace(id=parent_id, parent_document_id=root_id)
        document = SimpleNamespace(
            id=uuid.uuid4(),
            metadata_json={},
            is_archive_member=True,
            parent_document_id=parent_id,
        )

        class _ScalarResult:
            """Wraps scalar ORM results for deterministic worker helper tests."""

            def __init__(self, value: object) -> None:
                self._value = value

            def scalar_one_or_none(self) -> object:
                return self._value

        class _SequenceSession:
            """Returns predetermined parent rows in traversal order."""

            def __init__(self, values: list[object]) -> None:
                self._values = values

            def execute(self, _statement: object) -> _ScalarResult:
                next_value = self._values.pop(0) if self._values else None
                return _ScalarResult(next_value)

        resolved_root, resolved_depth = worker_tasks_module._resolve_archive_lineage(
            session=_SequenceSession([parent_document, root_document]),
            document=document,
        )
        self.assertEqual(resolved_root, root_id)
        self.assertEqual(resolved_depth, 2)
class ProcessingLogRedactionTests(unittest.TestCase):
    """Verifies sensitive processing-log values are redacted for persistence and responses."""

    def test_payload_redacts_sensitive_keys(self) -> None:
        """Sensitive payload keys are replaced with redaction marker."""
        raw_payload = {
            "api_key": "secret-value",
            "nested": {"authorization": "Bearer sample-token"},
        }
        sanitized = sanitize_processing_log_payload_value(raw_payload)
        self.assertEqual(sanitized["api_key"], "[REDACTED]")
        self.assertEqual(sanitized["nested"]["authorization"], "[REDACTED]")

    def test_text_redaction_removes_bearer_and_jwt_values(self) -> None:
        """Bearer and JWT token substrings are fully removed from log text."""
        bearer_value = "super-secret-token-123"
        jwt_value = (
            "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9."
            "eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4ifQ."
            "signaturevalue123456789"
        )
        raw_text = f"Authorization: Bearer {bearer_value}\nraw_jwt={jwt_value}"
        sanitized = sanitize_processing_log_text(raw_text)
        self.assertIsNotNone(sanitized)
        cleaned = sanitized or ""
        self.assertIn("[REDACTED]", cleaned)
        for secret in (bearer_value, jwt_value):
            self.assertNotIn(secret, cleaned)

    def test_text_redaction_removes_json_formatted_secret_values(self) -> None:
        """JSON-formatted quoted secrets are fully removed from redacted log text."""
        json_text = (
            '{"api_key":"json-api-key-secret",'
            '"token":"json-token-secret",'
            '"authorization":"Bearer json-auth-secret",'
            '"bearer":"json-bearer-secret"}'
        )
        sanitized = sanitize_processing_log_text(json_text)
        self.assertIsNotNone(sanitized)
        cleaned = sanitized or ""
        self.assertIn("[REDACTED]", cleaned)
        for secret in (
            "json-api-key-secret",
            "json-token-secret",
            "json-auth-secret",
            "json-bearer-secret",
        ):
            self.assertNotIn(secret, cleaned)

    def test_response_schema_applies_redaction_to_existing_entries(self) -> None:
        """API schema validators redact sensitive fields from legacy stored rows."""
        bearer_value = "abc123token"
        jwt_value = (
            "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9."
            "eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4ifQ."
            "signaturevalue123456789"
        )
        # Simulates a legacy row stored before redaction-at-write was enforced.
        legacy_row = {
            "id": 1,
            "created_at": datetime.now(UTC),
            "level": "info",
            "stage": "summary",
            "event": "response",
            "document_id": None,
            "document_filename": "sample.txt",
            "provider_id": "provider",
            "model_name": "model",
            "prompt_text": f"Authorization: Bearer {bearer_value}",
            "response_text": f"token={jwt_value}",
            "payload_json": {"password": "secret", "trace_id": "trace-1"},
        }
        response = ProcessingLogEntryResponse.model_validate(legacy_row)
        self.assertEqual(response.payload_json["password"], "[REDACTED]")
        redacted_prompt = response.prompt_text or ""
        redacted_response = response.response_text or ""
        self.assertIn("[REDACTED]", redacted_prompt)
        self.assertIn("[REDACTED]", redacted_response)
        self.assertNotIn(bearer_value, redacted_prompt)
        self.assertNotIn(jwt_value, redacted_response)

    def test_response_schema_redacts_json_formatted_secret_values(self) -> None:
        """Response schema redacts quoted JSON secret forms from legacy text fields."""
        prompt_text = (
            '{"api_key":"legacy-json-api-key",'
            '"token":"legacy-json-token"}'
        )
        response_text = (
            '{"authorization":"Bearer legacy-json-auth",'
            '"bearer":"legacy-json-bearer"}'
        )
        legacy_row = {
            "id": 2,
            "created_at": datetime.now(UTC),
            "level": "info",
            "stage": "summary",
            "event": "response",
            "document_id": None,
            "document_filename": "sample-json.txt",
            "provider_id": "provider",
            "model_name": "model",
            "prompt_text": prompt_text,
            "response_text": response_text,
            "payload_json": {"trace_id": "trace-2"},
        }
        response = ProcessingLogEntryResponse.model_validate(legacy_row)
        redacted_prompt = response.prompt_text or ""
        redacted_response = response.response_text or ""
        self.assertIn("[REDACTED]", redacted_prompt)
        self.assertIn("[REDACTED]", redacted_response)
        self.assertNotIn("legacy-json-api-key", redacted_prompt)
        self.assertNotIn("legacy-json-token", redacted_prompt)
        self.assertNotIn("legacy-json-auth", redacted_response)
        self.assertNotIn("legacy-json-bearer", redacted_response)
# Allows running this test module directly; unittest discovers the TestCase classes above.
if __name__ == "__main__":
    unittest.main()