"""Unit coverage for API auth, SSRF validation, and processing-log redaction controls."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from datetime import UTC, datetime
|
|
import io
|
|
import socket
|
|
import sys
|
|
import uuid
|
|
from pathlib import Path
|
|
from types import ModuleType, SimpleNamespace
|
|
import unittest
|
|
from unittest.mock import patch
|
|
import zipfile
|
|
|
|
|
|
# Make the backend package importable regardless of the directory pytest runs from.
BACKEND_ROOT = Path(__file__).resolve().parents[1]
_backend_root_entry = str(BACKEND_ROOT)
if _backend_root_entry not in sys.path:
    sys.path.insert(0, _backend_root_entry)
|
|
|
|
if "pydantic_settings" not in sys.modules:
|
|
pydantic_settings_stub = ModuleType("pydantic_settings")
|
|
|
|
class _BaseSettings:
|
|
"""Minimal BaseSettings replacement for dependency-light unit test execution."""
|
|
|
|
def __init__(self, **kwargs: object) -> None:
|
|
for key, value in kwargs.items():
|
|
setattr(self, key, value)
|
|
|
|
def _settings_config_dict(**kwargs: object) -> dict[str, object]:
|
|
"""Returns configuration values using dict semantics expected by settings module."""
|
|
|
|
return kwargs
|
|
|
|
pydantic_settings_stub.BaseSettings = _BaseSettings
|
|
pydantic_settings_stub.SettingsConfigDict = _settings_config_dict
|
|
sys.modules["pydantic_settings"] = pydantic_settings_stub
|
|
|
|
if "fastapi" not in sys.modules:
|
|
fastapi_stub = ModuleType("fastapi")
|
|
|
|
class _APIRouter:
|
|
"""Minimal APIRouter stand-in supporting decorator registration."""
|
|
|
|
def __init__(self, *args: object, **kwargs: object) -> None:
|
|
self.args = args
|
|
self.kwargs = kwargs
|
|
|
|
def post(self, *_args: object, **_kwargs: object): # type: ignore[no-untyped-def]
|
|
"""Returns no-op decorator for POST route declarations."""
|
|
|
|
def decorator(func): # type: ignore[no-untyped-def]
|
|
return func
|
|
|
|
return decorator
|
|
|
|
def get(self, *_args: object, **_kwargs: object): # type: ignore[no-untyped-def]
|
|
"""Returns no-op decorator for GET route declarations."""
|
|
|
|
def decorator(func): # type: ignore[no-untyped-def]
|
|
return func
|
|
|
|
return decorator
|
|
|
|
def patch(self, *_args: object, **_kwargs: object): # type: ignore[no-untyped-def]
|
|
"""Returns no-op decorator for PATCH route declarations."""
|
|
|
|
def decorator(func): # type: ignore[no-untyped-def]
|
|
return func
|
|
|
|
return decorator
|
|
|
|
def delete(self, *_args: object, **_kwargs: object): # type: ignore[no-untyped-def]
|
|
"""Returns no-op decorator for DELETE route declarations."""
|
|
|
|
def decorator(func): # type: ignore[no-untyped-def]
|
|
return func
|
|
|
|
return decorator
|
|
|
|
class _Request:
|
|
"""Minimal request placeholder for route function import compatibility."""
|
|
|
|
class _HTTPException(Exception):
|
|
"""Minimal HTTPException compatible with route dependency tests."""
|
|
|
|
def __init__(self, status_code: int, detail: str, headers: dict[str, str] | None = None) -> None:
|
|
super().__init__(detail)
|
|
self.status_code = status_code
|
|
self.detail = detail
|
|
self.headers = headers or {}
|
|
|
|
class _Status:
|
|
"""Minimal status namespace for auth unit tests."""
|
|
|
|
HTTP_401_UNAUTHORIZED = 401
|
|
HTTP_403_FORBIDDEN = 403
|
|
HTTP_429_TOO_MANY_REQUESTS = 429
|
|
HTTP_503_SERVICE_UNAVAILABLE = 503
|
|
|
|
def _depends(dependency): # type: ignore[no-untyped-def]
|
|
"""Returns provided dependency unchanged for unit testing."""
|
|
|
|
return dependency
|
|
|
|
def _query(default=None, **_kwargs): # type: ignore[no-untyped-def]
|
|
"""Returns FastAPI-like query defaults for dependency-light route imports."""
|
|
|
|
return default
|
|
|
|
def _file(default=None, **_kwargs): # type: ignore[no-untyped-def]
|
|
"""Returns FastAPI-like file defaults for dependency-light route imports."""
|
|
|
|
return default
|
|
|
|
def _form(default=None, **_kwargs): # type: ignore[no-untyped-def]
|
|
"""Returns FastAPI-like form defaults for dependency-light route imports."""
|
|
|
|
return default
|
|
|
|
class _UploadFile:
|
|
"""Minimal UploadFile placeholder for route import compatibility."""
|
|
|
|
fastapi_stub.APIRouter = _APIRouter
|
|
fastapi_stub.Depends = _depends
|
|
fastapi_stub.File = _file
|
|
fastapi_stub.Form = _form
|
|
fastapi_stub.HTTPException = _HTTPException
|
|
fastapi_stub.Query = _query
|
|
fastapi_stub.Request = _Request
|
|
fastapi_stub.UploadFile = _UploadFile
|
|
fastapi_stub.status = _Status()
|
|
sys.modules["fastapi"] = fastapi_stub
|
|
|
|
if "fastapi.responses" not in sys.modules:
|
|
fastapi_responses_stub = ModuleType("fastapi.responses")
|
|
|
|
class _Response:
|
|
"""Minimal response placeholder for route import compatibility."""
|
|
|
|
class _FileResponse(_Response):
|
|
"""Minimal file response placeholder for route import compatibility."""
|
|
|
|
class _StreamingResponse(_Response):
|
|
"""Minimal streaming response placeholder for route import compatibility."""
|
|
|
|
fastapi_responses_stub.Response = _Response
|
|
fastapi_responses_stub.FileResponse = _FileResponse
|
|
fastapi_responses_stub.StreamingResponse = _StreamingResponse
|
|
sys.modules["fastapi.responses"] = fastapi_responses_stub
|
|
|
|
if "fastapi.security" not in sys.modules:
|
|
fastapi_security_stub = ModuleType("fastapi.security")
|
|
|
|
class _HTTPAuthorizationCredentials:
|
|
"""Minimal bearer credential object used by auth dependency tests."""
|
|
|
|
def __init__(self, *, scheme: str, credentials: str) -> None:
|
|
self.scheme = scheme
|
|
self.credentials = credentials
|
|
|
|
class _HTTPBearer:
|
|
"""Minimal HTTPBearer stand-in for dependency construction."""
|
|
|
|
def __init__(self, auto_error: bool = True) -> None:
|
|
self.auto_error = auto_error
|
|
|
|
fastapi_security_stub.HTTPAuthorizationCredentials = _HTTPAuthorizationCredentials
|
|
fastapi_security_stub.HTTPBearer = _HTTPBearer
|
|
sys.modules["fastapi.security"] = fastapi_security_stub
|
|
|
|
if "magic" not in sys.modules:
|
|
magic_stub = ModuleType("magic")
|
|
|
|
def _from_buffer(_data: bytes, mime: bool = True) -> str:
|
|
"""Returns deterministic fallback MIME values for extractor import stubs."""
|
|
|
|
return "application/octet-stream" if mime else ""
|
|
|
|
magic_stub.from_buffer = _from_buffer
|
|
sys.modules["magic"] = magic_stub
|
|
|
|
if "docx" not in sys.modules:
|
|
docx_stub = ModuleType("docx")
|
|
|
|
class _DocxDocument:
|
|
"""Minimal docx document stub for extractor import compatibility."""
|
|
|
|
def __init__(self, *_args: object, **_kwargs: object) -> None:
|
|
self.paragraphs: list[SimpleNamespace] = []
|
|
|
|
docx_stub.Document = _DocxDocument
|
|
sys.modules["docx"] = docx_stub
|
|
|
|
if "openpyxl" not in sys.modules:
|
|
openpyxl_stub = ModuleType("openpyxl")
|
|
|
|
class _Workbook:
|
|
"""Minimal workbook stub for extractor import compatibility."""
|
|
|
|
worksheets: list[SimpleNamespace] = []
|
|
|
|
def _load_workbook(*_args: object, **_kwargs: object) -> _Workbook:
|
|
"""Returns deterministic workbook placeholder for extractor import stubs."""
|
|
|
|
return _Workbook()
|
|
|
|
openpyxl_stub.load_workbook = _load_workbook
|
|
sys.modules["openpyxl"] = openpyxl_stub
|
|
|
|
if "PIL" not in sys.modules:
|
|
pil_stub = ModuleType("PIL")
|
|
|
|
class _Image:
|
|
"""Minimal PIL.Image replacement for extractor and handwriting import stubs."""
|
|
|
|
class Resampling:
|
|
"""Minimal enum-like namespace used by handwriting image resize path."""
|
|
|
|
LANCZOS = 1
|
|
|
|
@staticmethod
|
|
def open(*_args: object, **_kwargs: object) -> "_Image":
|
|
"""Raises for unsupported image operations in dependency-light tests."""
|
|
|
|
raise RuntimeError("Image.open is not available in stub")
|
|
|
|
class _ImageOps:
|
|
"""Minimal PIL.ImageOps replacement for import compatibility."""
|
|
|
|
@staticmethod
|
|
def exif_transpose(image: object) -> object:
|
|
"""Returns original image object unchanged in dependency-light tests."""
|
|
|
|
return image
|
|
|
|
pil_stub.Image = _Image
|
|
pil_stub.ImageOps = _ImageOps
|
|
sys.modules["PIL"] = pil_stub
|
|
|
|
if "pypdf" not in sys.modules:
|
|
pypdf_stub = ModuleType("pypdf")
|
|
|
|
class _PdfReader:
|
|
"""Minimal PdfReader replacement for extractor import compatibility."""
|
|
|
|
def __init__(self, *_args: object, **_kwargs: object) -> None:
|
|
self.pages: list[SimpleNamespace] = []
|
|
|
|
pypdf_stub.PdfReader = _PdfReader
|
|
sys.modules["pypdf"] = pypdf_stub
|
|
|
|
if "pymupdf" not in sys.modules:
|
|
pymupdf_stub = ModuleType("pymupdf")
|
|
|
|
class _Matrix:
|
|
"""Minimal matrix placeholder for extractor import compatibility."""
|
|
|
|
def __init__(self, *_args: object, **_kwargs: object) -> None:
|
|
pass
|
|
|
|
def _open(*_args: object, **_kwargs: object) -> object:
|
|
"""Raises when preview rendering is invoked in dependency-light tests."""
|
|
|
|
raise RuntimeError("pymupdf is not available in stub")
|
|
|
|
pymupdf_stub.Matrix = _Matrix
|
|
pymupdf_stub.open = _open
|
|
sys.modules["pymupdf"] = pymupdf_stub
|
|
|
|
if "app.services.handwriting" not in sys.modules:
|
|
handwriting_stub = ModuleType("app.services.handwriting")
|
|
|
|
class _HandwritingError(Exception):
|
|
"""Minimal base error class for extractor import compatibility."""
|
|
|
|
class _HandwritingNotConfiguredError(_HandwritingError):
|
|
"""Minimal not-configured error class for extractor import compatibility."""
|
|
|
|
class _HandwritingTimeoutError(_HandwritingError):
|
|
"""Minimal timeout error class for extractor import compatibility."""
|
|
|
|
def _classify_image_text_bytes(*_args: object, **_kwargs: object) -> SimpleNamespace:
|
|
"""Returns deterministic image text classification fallback."""
|
|
|
|
return SimpleNamespace(label="unknown", confidence=0.0, provider="stub", model="stub")
|
|
|
|
def _transcribe_handwriting_bytes(*_args: object, **_kwargs: object) -> SimpleNamespace:
|
|
"""Returns deterministic handwriting transcription fallback."""
|
|
|
|
return SimpleNamespace(text="", uncertainties=[], provider="stub", model="stub")
|
|
|
|
handwriting_stub.IMAGE_TEXT_TYPE_NO_TEXT = "no_text"
|
|
handwriting_stub.IMAGE_TEXT_TYPE_UNKNOWN = "unknown"
|
|
handwriting_stub.IMAGE_TEXT_TYPE_HANDWRITING = "handwriting"
|
|
handwriting_stub.HandwritingTranscriptionError = _HandwritingError
|
|
handwriting_stub.HandwritingTranscriptionNotConfiguredError = _HandwritingNotConfiguredError
|
|
handwriting_stub.HandwritingTranscriptionTimeoutError = _HandwritingTimeoutError
|
|
handwriting_stub.classify_image_text_bytes = _classify_image_text_bytes
|
|
handwriting_stub.transcribe_handwriting_bytes = _transcribe_handwriting_bytes
|
|
sys.modules["app.services.handwriting"] = handwriting_stub
|
|
|
|
if "app.services.handwriting_style" not in sys.modules:
|
|
handwriting_style_stub = ModuleType("app.services.handwriting_style")
|
|
|
|
def _assign_handwriting_style(*_args: object, **_kwargs: object) -> SimpleNamespace:
|
|
"""Returns deterministic style assignment payload for worker import compatibility."""
|
|
|
|
return SimpleNamespace(
|
|
style_cluster_id="cluster-1",
|
|
matched_existing=False,
|
|
similarity=0.0,
|
|
vector_distance=0.0,
|
|
compared_neighbors=0,
|
|
match_min_similarity=0.0,
|
|
bootstrap_match_min_similarity=0.0,
|
|
)
|
|
|
|
def _delete_handwriting_style_document(*_args: object, **_kwargs: object) -> None:
|
|
"""No-op style document delete stub for worker import compatibility."""
|
|
|
|
return None
|
|
|
|
def _delete_many_handwriting_style_documents(*_args: object, **_kwargs: object) -> None:
|
|
"""No-op bulk style document delete stub for route import compatibility."""
|
|
|
|
return None
|
|
|
|
handwriting_style_stub.assign_handwriting_style = _assign_handwriting_style
|
|
handwriting_style_stub.delete_handwriting_style_document = _delete_handwriting_style_document
|
|
handwriting_style_stub.delete_many_handwriting_style_documents = _delete_many_handwriting_style_documents
|
|
sys.modules["app.services.handwriting_style"] = handwriting_style_stub
|
|
|
|
if "app.services.routing_pipeline" not in sys.modules:
|
|
routing_pipeline_stub = ModuleType("app.services.routing_pipeline")
|
|
|
|
def _apply_routing_decision(*_args: object, **_kwargs: object) -> None:
|
|
"""No-op routing application stub for worker import compatibility."""
|
|
|
|
return None
|
|
|
|
def _classify_document_routing(*_args: object, **_kwargs: object) -> dict[str, object]:
|
|
"""Returns deterministic routing decision payload for worker import compatibility."""
|
|
|
|
return {"chosen_path": None, "chosen_tags": []}
|
|
|
|
def _summarize_document(*_args: object, **_kwargs: object) -> str:
|
|
"""Returns deterministic summary text for worker import compatibility."""
|
|
|
|
return "summary"
|
|
|
|
def _upsert_semantic_index(*_args: object, **_kwargs: object) -> None:
|
|
"""No-op semantic index update stub for worker import compatibility."""
|
|
|
|
return None
|
|
|
|
routing_pipeline_stub.apply_routing_decision = _apply_routing_decision
|
|
routing_pipeline_stub.classify_document_routing = _classify_document_routing
|
|
routing_pipeline_stub.summarize_document = _summarize_document
|
|
routing_pipeline_stub.upsert_semantic_index = _upsert_semantic_index
|
|
sys.modules["app.services.routing_pipeline"] = routing_pipeline_stub
|
|
|
|
from fastapi import HTTPException
|
|
|
|
from app.api.auth import AuthContext, require_admin
|
|
from app.api import auth as auth_dependency_module
|
|
from app.api import routes_auth as auth_routes_module
|
|
from app.api import routes_documents as documents_routes_module
|
|
from app.core import config as config_module
|
|
from app.models.auth import UserRole
|
|
from app.models.processing_log import sanitize_processing_log_payload_value, sanitize_processing_log_text
|
|
from app.schemas.processing_logs import ProcessingLogEntryResponse
|
|
from app.services import auth_login_throttle as auth_login_throttle_module
|
|
from app.services import extractor as extractor_module
|
|
from app.worker import tasks as worker_tasks_module
|
|
|
|
|
|
def _security_settings(
|
|
*,
|
|
allowlist: list[str] | None = None,
|
|
allow_http: bool = False,
|
|
allow_private_network: bool = False,
|
|
) -> SimpleNamespace:
|
|
"""Builds lightweight settings object for provider URL validation tests."""
|
|
|
|
return SimpleNamespace(
|
|
provider_base_url_allowlist=allowlist if allowlist is not None else ["api.openai.com"],
|
|
provider_base_url_allow_http=allow_http,
|
|
provider_base_url_allow_private_network=allow_private_network,
|
|
)
|
|
|
|
|
|
class AuthDependencyTests(unittest.TestCase):
    """Verifies role-based admin authorization behavior."""

    def test_require_admin_rejects_user_role(self) -> None:
        """User role cannot access admin-only endpoints."""

        auth_context = AuthContext(
            user_id=uuid.uuid4(),
            username="user",
            role=UserRole.USER,
            session_id=uuid.uuid4(),
            expires_at=datetime.now(UTC),
        )
        with self.assertRaises(HTTPException) as raised:
            require_admin(context=auth_context)
        # Non-admin callers are rejected with HTTP 403 (forbidden, not unauthenticated).
        self.assertEqual(raised.exception.status_code, 403)

    def test_require_admin_accepts_admin_role(self) -> None:
        """Admin role is accepted for admin-only endpoints."""

        auth_context = AuthContext(
            user_id=uuid.uuid4(),
            username="admin",
            role=UserRole.ADMIN,
            session_id=uuid.uuid4(),
            expires_at=datetime.now(UTC),
        )
        # The dependency returns the same context it validated.
        resolved = require_admin(context=auth_context)
        self.assertEqual(resolved.role, UserRole.ADMIN)

    def test_csrf_validation_accepts_matching_token_among_duplicate_cookie_values(self) -> None:
        """PATCH CSRF validation accepts header token matching any duplicate csrf cookie value."""

        # The raw cookie header intentionally carries two dcm_csrf values to
        # model a browser sending a stale cookie alongside a fresh one.
        request = SimpleNamespace(
            method="PATCH",
            headers={"cookie": "dcm_session=session-token; dcm_csrf=stale-token; dcm_csrf=fresh-token"},
        )
        resolved_session = SimpleNamespace(
            id=uuid.uuid4(),
            expires_at=datetime.now(UTC),
            user=SimpleNamespace(
                id=uuid.uuid4(),
                username="admin",
                role=UserRole.ADMIN,
            ),
        )
        with patch.object(auth_dependency_module, "resolve_auth_session", return_value=resolved_session):
            # The header token matches the second duplicate cookie value even
            # though the parsed csrf_cookie argument holds the stale one.
            context = auth_dependency_module.get_request_auth_context(
                request=request,
                credentials=None,
                csrf_header="fresh-token",
                csrf_cookie="stale-token",
                session_cookie="session-token",
                session=SimpleNamespace(),
            )
        self.assertEqual(context.username, "admin")
        self.assertEqual(context.role, UserRole.ADMIN)

    def test_csrf_validation_rejects_when_header_does_not_match_any_cookie_value(self) -> None:
        """PATCH CSRF validation rejects requests when header token matches no csrf cookie values."""

        request = SimpleNamespace(
            method="PATCH",
            headers={"cookie": "dcm_session=session-token; dcm_csrf=stale-token; dcm_csrf=fresh-token"},
        )
        resolved_session = SimpleNamespace(
            id=uuid.uuid4(),
            expires_at=datetime.now(UTC),
            user=SimpleNamespace(
                id=uuid.uuid4(),
                username="admin",
                role=UserRole.ADMIN,
            ),
        )
        with patch.object(auth_dependency_module, "resolve_auth_session", return_value=resolved_session):
            with self.assertRaises(HTTPException) as raised:
                # "unknown-token" matches neither duplicate csrf cookie value.
                auth_dependency_module.get_request_auth_context(
                    request=request,
                    credentials=None,
                    csrf_header="unknown-token",
                    csrf_cookie="stale-token",
                    session_cookie="session-token",
                    session=SimpleNamespace(),
                )
        self.assertEqual(raised.exception.status_code, 403)
        self.assertEqual(raised.exception.detail, "Invalid CSRF token")

    def test_cookie_auth_accepts_matching_session_among_duplicate_cookie_values(self) -> None:
        """Cookie auth accepts the first valid session token among duplicate cookie values."""

        request = SimpleNamespace(
            method="GET",
            headers={"cookie": "dcm_session=stale-token; dcm_session=fresh-token"},
        )
        resolved_session = SimpleNamespace(
            id=uuid.uuid4(),
            expires_at=datetime.now(UTC),
            user=SimpleNamespace(
                id=uuid.uuid4(),
                username="admin",
                role=UserRole.ADMIN,
            ),
        )
        # side_effect order matters: the first (stale) token resolves to None,
        # forcing the dependency to try the second duplicate cookie value.
        with patch.object(
            auth_dependency_module,
            "resolve_auth_session",
            side_effect=[None, resolved_session],
        ) as resolve_mock:
            context = auth_dependency_module.get_request_auth_context(
                request=request,
                credentials=None,
                csrf_header=None,
                csrf_cookie=None,
                session_cookie="stale-token",
                session=SimpleNamespace(),
            )
        self.assertEqual(context.username, "admin")
        self.assertEqual(context.role, UserRole.ADMIN)
        # Both duplicate cookie values must have been attempted exactly once each.
        self.assertEqual(resolve_mock.call_count, 2)
|
|
|
|
|
|
class DocumentCatalogVisibilityTests(unittest.TestCase):
    """Verifies predefined tag and path discovery visibility by caller role."""

    class _ScalarSequence:
        """Provides SQLAlchemy-like scalar result chaining for route unit tests."""

        def __init__(self, values: list[object]) -> None:
            # Values returned by every .scalars().all() chain.
            self._values = values

        def scalars(self) -> "DocumentCatalogVisibilityTests._ScalarSequence":
            """Returns self to emulate `.scalars().all()` call chains."""

            return self

        def all(self) -> list[object]:
            """Returns deterministic sequence values for route helper queries."""

            # Copy so route code cannot mutate the fixture in place.
            return list(self._values)

    class _SessionStub:
        """Returns a fixed scalar sequence for route metadata queries."""

        def __init__(self, values: list[object]) -> None:
            self._values = values

        def execute(self, _statement: object) -> "DocumentCatalogVisibilityTests._ScalarSequence":
            """Ignores query details and returns deterministic scalar sequence results."""

            return DocumentCatalogVisibilityTests._ScalarSequence(self._values)

    @staticmethod
    def _auth_context(role: UserRole) -> AuthContext:
        """Builds deterministic auth context fixtures for document discovery tests."""

        return AuthContext(
            user_id=uuid.uuid4(),
            username=f"{role.value}-user",
            role=role,
            session_id=uuid.uuid4(),
            expires_at=datetime.now(UTC),
        )

    def test_non_admin_only_receives_global_shared_predefined_tags_and_paths(self) -> None:
        """User role receives only globally shared predefined values in discovery responses."""

        # Tag rows are lists of tag values; the blank entry does not appear in
        # the expected output asserted below.
        session = self._SessionStub(
            values=[
                ["owner-tag", ""],
                ["owner-only-tag"],
            ]
        )
        predefined_tags = [
            {"value": "SharedTag", "global_shared": True},
            {"value": "InternalTag", "global_shared": False},
            # Missing global_shared key models an implicitly private entry.
            {"value": "ImplicitPrivateTag"},
        ]
        predefined_paths = [
            {"value": "Shared/Path", "global_shared": True},
            {"value": "Internal/Path", "global_shared": False},
            {"value": "Implicit/Private"},
        ]

        with (
            patch.object(documents_routes_module, "read_predefined_tags_settings", return_value=predefined_tags),
            patch.object(documents_routes_module, "read_predefined_paths_settings", return_value=predefined_paths),
        ):
            tags_response = documents_routes_module.list_tags(
                include_trashed=False,
                auth_context=self._auth_context(UserRole.USER),
                session=session,
            )
            paths_response = documents_routes_module.list_paths(
                include_trashed=False,
                auth_context=self._auth_context(UserRole.USER),
                session=self._SessionStub(values=["Owner/Path"]),
            )

        # Only the globally shared predefined values plus the caller's own
        # stored values are visible to a non-admin.
        self.assertEqual(tags_response["tags"], ["SharedTag", "owner-only-tag", "owner-tag"])
        self.assertEqual(paths_response["paths"], ["Owner/Path", "Shared/Path"])

    def test_admin_receives_full_predefined_tags_and_paths_catalog(self) -> None:
        """Admin role receives full predefined values regardless of global-sharing scope."""

        predefined_tags = [
            {"value": "SharedTag", "global_shared": True},
            {"value": "InternalTag", "global_shared": False},
            {"value": "ImplicitPrivateTag"},
        ]
        predefined_paths = [
            {"value": "Shared/Path", "global_shared": True},
            {"value": "Internal/Path", "global_shared": False},
            {"value": "Implicit/Private"},
        ]

        with (
            patch.object(documents_routes_module, "read_predefined_tags_settings", return_value=predefined_tags),
            patch.object(documents_routes_module, "read_predefined_paths_settings", return_value=predefined_paths),
        ):
            tags_response = documents_routes_module.list_tags(
                include_trashed=False,
                auth_context=self._auth_context(UserRole.ADMIN),
                session=self._SessionStub(values=[["admin-tag"]]),
            )
            paths_response = documents_routes_module.list_paths(
                include_trashed=False,
                auth_context=self._auth_context(UserRole.ADMIN),
                session=self._SessionStub(values=["Admin/Path"]),
            )

        # Admins see every predefined value, shared or not, sorted together
        # with values discovered from stored documents.
        self.assertEqual(
            tags_response["tags"],
            ["ImplicitPrivateTag", "InternalTag", "SharedTag", "admin-tag"],
        )
        self.assertEqual(
            paths_response["paths"],
            ["Admin/Path", "Implicit/Private", "Internal/Path", "Shared/Path"],
        )
|
|
|
|
|
|
class _FakeRedisPipeline:
|
|
"""Provides deterministic Redis pipeline behavior for login throttle tests."""
|
|
|
|
def __init__(self, redis_client: "_FakeRedis") -> None:
|
|
self._redis_client = redis_client
|
|
self._operations: list[tuple[str, tuple[object, ...]]] = []
|
|
|
|
def incr(self, key: str, amount: int) -> "_FakeRedisPipeline":
|
|
"""Queues one counter increment operation for pipeline execution."""
|
|
|
|
self._operations.append(("incr", (key, amount)))
|
|
return self
|
|
|
|
def expire(self, key: str, ttl_seconds: int) -> "_FakeRedisPipeline":
|
|
"""Queues one key expiration operation for pipeline execution."""
|
|
|
|
self._operations.append(("expire", (key, ttl_seconds)))
|
|
return self
|
|
|
|
def execute(self) -> list[object]:
|
|
"""Executes queued operations in order and returns Redis-like result values."""
|
|
|
|
results: list[object] = []
|
|
for operation, arguments in self._operations:
|
|
if operation == "incr":
|
|
key, amount = arguments
|
|
previous = int(self._redis_client.values.get(str(key), 0))
|
|
updated = previous + int(amount)
|
|
self._redis_client.values[str(key)] = updated
|
|
results.append(updated)
|
|
elif operation == "expire":
|
|
key, ttl_seconds = arguments
|
|
self._redis_client.ttl_seconds[str(key)] = int(ttl_seconds)
|
|
results.append(True)
|
|
return results
|
|
|
|
|
|
class _FakeRedis:
|
|
"""In-memory Redis replacement with TTL behavior needed by throttle tests."""
|
|
|
|
def __init__(self) -> None:
|
|
self.values: dict[str, object] = {}
|
|
self.ttl_seconds: dict[str, int] = {}
|
|
|
|
def pipeline(self, transaction: bool = True) -> _FakeRedisPipeline:
|
|
"""Creates a fake transaction pipeline for grouped increment operations."""
|
|
|
|
_ = transaction
|
|
return _FakeRedisPipeline(self)
|
|
|
|
def set(self, key: str, value: str, ex: int | None = None) -> bool:
|
|
"""Stores key values and optional TTL metadata used by lockout keys."""
|
|
|
|
self.values[key] = value
|
|
if ex is not None:
|
|
self.ttl_seconds[key] = int(ex)
|
|
return True
|
|
|
|
def ttl(self, key: str) -> int:
|
|
"""Returns TTL for existing keys or Redis-compatible missing-key indicator."""
|
|
|
|
return int(self.ttl_seconds.get(key, -2))
|
|
|
|
def delete(self, *keys: str) -> int:
|
|
"""Deletes keys and returns number of removed entries."""
|
|
|
|
removed_count = 0
|
|
for key in keys:
|
|
if key in self.values:
|
|
self.values.pop(key, None)
|
|
removed_count += 1
|
|
self.ttl_seconds.pop(key, None)
|
|
return removed_count
|
|
|
|
|
|
def _login_throttle_settings(
|
|
*,
|
|
failure_limit: int = 2,
|
|
failure_window_seconds: int = 60,
|
|
lockout_base_seconds: int = 10,
|
|
lockout_max_seconds: int = 40,
|
|
) -> SimpleNamespace:
|
|
"""Builds deterministic login-throttle settings for service-level unit coverage."""
|
|
|
|
return SimpleNamespace(
|
|
auth_login_failure_limit=failure_limit,
|
|
auth_login_failure_window_seconds=failure_window_seconds,
|
|
auth_login_lockout_base_seconds=lockout_base_seconds,
|
|
auth_login_lockout_max_seconds=lockout_max_seconds,
|
|
)
|
|
|
|
|
|
class AuthLoginThrottleServiceTests(unittest.TestCase):
    """Verifies login throttle lockout progression, cap behavior, and clear semantics."""

    def test_failed_attempts_trigger_lockout_after_limit(self) -> None:
        """Failed attempts beyond configured limit activate login lockouts."""

        fake_redis = _FakeRedis()
        with (
            patch.object(
                auth_login_throttle_module,
                "get_settings",
                return_value=_login_throttle_settings(failure_limit=2, lockout_base_seconds=12),
            ),
            patch.object(auth_login_throttle_module, "get_redis", return_value=fake_redis),
        ):
            # The first two failures stay within the limit of 2 and report no lockout.
            self.assertEqual(
                auth_login_throttle_module.record_failed_login_attempt(
                    username="admin",
                    ip_address="203.0.113.10",
                ),
                0,
            )
            self.assertEqual(
                auth_login_throttle_module.record_failed_login_attempt(
                    username="admin",
                    ip_address="203.0.113.10",
                ),
                0,
            )
            lockout_seconds = auth_login_throttle_module.record_failed_login_attempt(
                username="admin",
                ip_address="203.0.113.10",
            )
            status = auth_login_throttle_module.check_login_throttle(
                username="admin",
                ip_address="203.0.113.10",
            )

        # The third failure crosses the limit and applies the 12s base lockout.
        self.assertEqual(lockout_seconds, 12)
        self.assertTrue(status.is_throttled)
        self.assertEqual(status.retry_after_seconds, 12)

    def test_lockout_duration_escalates_and_respects_max_cap(self) -> None:
        """Repeated failures after threshold double lockout duration up to configured maximum."""

        fake_redis = _FakeRedis()
        with (
            patch.object(
                auth_login_throttle_module,
                "get_settings",
                return_value=_login_throttle_settings(
                    failure_limit=1,
                    lockout_base_seconds=10,
                    lockout_max_seconds=25,
                ),
            ),
            patch.object(auth_login_throttle_module, "get_redis", return_value=fake_redis),
        ):
            first_lockout = auth_login_throttle_module.record_failed_login_attempt(
                username="admin",
                ip_address="198.51.100.15",
            )
            second_lockout = auth_login_throttle_module.record_failed_login_attempt(
                username="admin",
                ip_address="198.51.100.15",
            )
            third_lockout = auth_login_throttle_module.record_failed_login_attempt(
                username="admin",
                ip_address="198.51.100.15",
            )
            fourth_lockout = auth_login_throttle_module.record_failed_login_attempt(
                username="admin",
                ip_address="198.51.100.15",
            )

        # Progression: under limit (0), base (10), doubled (20), capped at 25.
        self.assertEqual(first_lockout, 0)
        self.assertEqual(second_lockout, 10)
        self.assertEqual(third_lockout, 20)
        self.assertEqual(fourth_lockout, 25)

    def test_clear_login_throttle_removes_active_lockout_state(self) -> None:
        """Successful login clears active lockout keys for username and IP subjects."""

        fake_redis = _FakeRedis()
        with (
            patch.object(
                auth_login_throttle_module,
                "get_settings",
                return_value=_login_throttle_settings(
                    failure_limit=1,
                    lockout_base_seconds=15,
                    lockout_max_seconds=30,
                ),
            ),
            patch.object(auth_login_throttle_module, "get_redis", return_value=fake_redis),
        ):
            # Two failures with failure_limit=1 guarantee an active lockout.
            auth_login_throttle_module.record_failed_login_attempt(
                username="admin",
                ip_address="192.0.2.20",
            )
            auth_login_throttle_module.record_failed_login_attempt(
                username="admin",
                ip_address="192.0.2.20",
            )
            throttled_before_clear = auth_login_throttle_module.check_login_throttle(
                username="admin",
                ip_address="192.0.2.20",
            )
            auth_login_throttle_module.clear_login_throttle(
                username="admin",
                ip_address="192.0.2.20",
            )
            throttled_after_clear = auth_login_throttle_module.check_login_throttle(
                username="admin",
                ip_address="192.0.2.20",
            )

        self.assertTrue(throttled_before_clear.is_throttled)
        self.assertFalse(throttled_after_clear.is_throttled)
        self.assertEqual(throttled_after_clear.retry_after_seconds, 0)

    def test_backend_errors_raise_runtime_error(self) -> None:
        """Redis backend failures are surfaced as RuntimeError for caller fail-closed handling."""

        class _BrokenRedis:
            """Raises RedisError for all Redis interactions used by login throttle service."""

            def ttl(self, _key: str) -> int:
                raise auth_login_throttle_module.RedisError("redis unavailable")

        # The service must translate backend errors rather than leak RedisError.
        with patch.object(auth_login_throttle_module, "get_redis", return_value=_BrokenRedis()):
            with self.assertRaises(RuntimeError):
                auth_login_throttle_module.check_login_throttle(
                    username="admin",
                    ip_address="203.0.113.88",
                )
|
|
|
|
|
|
class AuthLoginRouteThrottleTests(unittest.TestCase):
|
|
"""Verifies `/auth/login` route throttle responses and success-flow state clearing."""
|
|
|
|
    class _SessionStub:
        """Tracks commit calls for route-level login tests without database dependencies."""

        def __init__(self) -> None:
            # Number of times route code committed the stubbed DB session.
            self.commit_count = 0

        def commit(self) -> None:
            """Records one commit invocation."""

            self.commit_count += 1
|
|
|
|
    class _ResponseStub:
        """Captures response cookie calls for direct route invocation tests."""

        def __init__(self) -> None:
            # Each entry records the (args, kwargs) of one cookie operation.
            self.set_cookie_calls: list[tuple[tuple[object, ...], dict[str, object]]] = []
            self.delete_cookie_calls: list[tuple[tuple[object, ...], dict[str, object]]] = []

        def set_cookie(self, *args: object, **kwargs: object) -> None:
            """Records one set-cookie call."""

            self.set_cookie_calls.append((args, kwargs))

        def delete_cookie(self, *args: object, **kwargs: object) -> None:
            """Records one delete-cookie call."""

            self.delete_cookie_calls.append((args, kwargs))
|
|
|
|
    @classmethod
    def _response_stub(cls) -> "AuthLoginRouteThrottleTests._ResponseStub":
        """Builds a minimal response object for direct route invocation."""

        return cls._ResponseStub()
|
|
|
|
    @staticmethod
    def _request_stub(
        ip_address: str = "203.0.113.2",
        user_agent: str = "unit-test",
        origin: str | None = None,
    ) -> SimpleNamespace:
        """Builds request-like object containing client host and user-agent header fields."""

        headers = {"user-agent": user_agent}
        # The origin header is only present when a caller explicitly supplies one.
        if origin:
            headers["origin"] = origin
        return SimpleNamespace(
            client=SimpleNamespace(host=ip_address),
            headers=headers,
            url=SimpleNamespace(hostname="api.docs.lan"),
        )
|
|
|
|
def test_login_rejects_when_precheck_reports_active_throttle(self) -> None:
|
|
"""Pre-auth throttle checks return a stable 429 response without credential lookup."""
|
|
|
|
payload = auth_routes_module.AuthLoginRequest(username="admin", password="bad-password")
|
|
session = self._SessionStub()
|
|
throttled = auth_login_throttle_module.LoginThrottleStatus(
|
|
is_throttled=True,
|
|
retry_after_seconds=21,
|
|
)
|
|
with (
|
|
patch.object(auth_routes_module, "check_login_throttle", return_value=throttled),
|
|
patch.object(auth_routes_module, "authenticate_user") as authenticate_mock,
|
|
):
|
|
with self.assertRaises(HTTPException) as raised:
|
|
auth_routes_module.login(
|
|
payload=payload,
|
|
request=self._request_stub(),
|
|
response=self._response_stub(),
|
|
session=session,
|
|
)
|
|
self.assertEqual(raised.exception.status_code, 429)
|
|
self.assertEqual(raised.exception.detail, auth_routes_module.LOGIN_THROTTLED_DETAIL)
|
|
self.assertEqual(raised.exception.headers.get("Retry-After"), "21")
|
|
authenticate_mock.assert_not_called()
|
|
self.assertEqual(session.commit_count, 0)
|
|
|
|
def test_login_returns_throttle_response_when_failure_crosses_limit(self) -> None:
|
|
"""Failed credentials return stable 429 response once lockout threshold is crossed."""
|
|
|
|
payload = auth_routes_module.AuthLoginRequest(username="admin", password="bad-password")
|
|
session = self._SessionStub()
|
|
with (
|
|
patch.object(
|
|
auth_routes_module,
|
|
"check_login_throttle",
|
|
return_value=auth_login_throttle_module.LoginThrottleStatus(
|
|
is_throttled=False,
|
|
retry_after_seconds=0,
|
|
),
|
|
),
|
|
patch.object(auth_routes_module, "authenticate_user", return_value=None),
|
|
patch.object(auth_routes_module, "record_failed_login_attempt", return_value=30),
|
|
):
|
|
with self.assertRaises(HTTPException) as raised:
|
|
auth_routes_module.login(
|
|
payload=payload,
|
|
request=self._request_stub(),
|
|
response=self._response_stub(),
|
|
session=session,
|
|
)
|
|
self.assertEqual(raised.exception.status_code, 429)
|
|
self.assertEqual(raised.exception.detail, auth_routes_module.LOGIN_THROTTLED_DETAIL)
|
|
self.assertEqual(raised.exception.headers.get("Retry-After"), "30")
|
|
self.assertEqual(session.commit_count, 0)
|
|
|
|
def test_login_clears_throttle_state_after_successful_authentication(self) -> None:
|
|
"""Successful login clears throttle state and commits issued session token."""
|
|
|
|
payload = auth_routes_module.AuthLoginRequest(username="admin", password="correct-password")
|
|
session = self._SessionStub()
|
|
fake_user = SimpleNamespace(
|
|
id=uuid.uuid4(),
|
|
username="admin",
|
|
role=UserRole.ADMIN,
|
|
)
|
|
fake_session = SimpleNamespace(
|
|
token="session-token",
|
|
expires_at=datetime.now(UTC),
|
|
)
|
|
with (
|
|
patch.object(
|
|
auth_routes_module,
|
|
"check_login_throttle",
|
|
return_value=auth_login_throttle_module.LoginThrottleStatus(
|
|
is_throttled=False,
|
|
retry_after_seconds=0,
|
|
),
|
|
),
|
|
patch.object(auth_routes_module, "authenticate_user", return_value=fake_user),
|
|
patch.object(auth_routes_module, "clear_login_throttle") as clear_mock,
|
|
patch.object(auth_routes_module, "issue_user_session", return_value=fake_session),
|
|
):
|
|
response = auth_routes_module.login(
|
|
payload=payload,
|
|
request=self._request_stub(),
|
|
response=self._response_stub(),
|
|
session=session,
|
|
)
|
|
self.assertEqual(response.access_token, "session-token")
|
|
self.assertEqual(response.user.username, "admin")
|
|
clear_mock.assert_called_once()
|
|
self.assertEqual(session.commit_count, 1)
|
|
|
|
def test_login_returns_503_when_throttle_backend_is_unavailable(self) -> None:
|
|
"""Throttle backend errors fail closed with a deterministic 503 login response."""
|
|
|
|
payload = auth_routes_module.AuthLoginRequest(username="admin", password="password")
|
|
session = self._SessionStub()
|
|
with patch.object(auth_routes_module, "check_login_throttle", side_effect=RuntimeError("redis down")):
|
|
with self.assertRaises(HTTPException) as raised:
|
|
auth_routes_module.login(
|
|
payload=payload,
|
|
request=self._request_stub(),
|
|
response=self._response_stub(),
|
|
session=session,
|
|
)
|
|
self.assertEqual(raised.exception.status_code, 503)
|
|
self.assertEqual(raised.exception.detail, auth_routes_module.LOGIN_RATE_LIMITER_UNAVAILABLE_DETAIL)
|
|
self.assertEqual(session.commit_count, 0)
|
|
|
|
def test_login_sets_host_only_and_parent_domain_cookie_variants(self) -> None:
|
|
"""Successful login sets a host-only cookie and an optional parent-domain mirror."""
|
|
|
|
payload = auth_routes_module.AuthLoginRequest(username="admin", password="correct-password")
|
|
session = self._SessionStub()
|
|
response_stub = self._response_stub()
|
|
fake_user = SimpleNamespace(
|
|
id=uuid.uuid4(),
|
|
username="admin",
|
|
role=UserRole.ADMIN,
|
|
)
|
|
fake_session = SimpleNamespace(
|
|
token="session-token",
|
|
expires_at=datetime.now(UTC),
|
|
)
|
|
fake_settings = SimpleNamespace(
|
|
auth_cookie_domain="docs.lan",
|
|
auth_cookie_samesite="none",
|
|
public_base_url="https://api.docs.lan",
|
|
)
|
|
with (
|
|
patch.object(
|
|
auth_routes_module,
|
|
"check_login_throttle",
|
|
return_value=auth_login_throttle_module.LoginThrottleStatus(
|
|
is_throttled=False,
|
|
retry_after_seconds=0,
|
|
),
|
|
),
|
|
patch.object(auth_routes_module, "authenticate_user", return_value=fake_user),
|
|
patch.object(auth_routes_module, "clear_login_throttle"),
|
|
patch.object(auth_routes_module, "issue_user_session", return_value=fake_session),
|
|
patch.object(auth_routes_module, "get_settings", return_value=fake_settings),
|
|
patch.object(auth_routes_module.secrets, "token_urlsafe", return_value="csrf-token"),
|
|
):
|
|
auth_routes_module.login(
|
|
payload=payload,
|
|
request=self._request_stub(origin="https://docs.lan"),
|
|
response=response_stub,
|
|
session=session,
|
|
)
|
|
|
|
session_cookie_calls = [call for call in response_stub.set_cookie_calls if call[0][0] == auth_routes_module.SESSION_COOKIE_NAME]
|
|
csrf_cookie_calls = [call for call in response_stub.set_cookie_calls if call[0][0] == auth_routes_module.CSRF_COOKIE_NAME]
|
|
self.assertEqual(len(session_cookie_calls), 2)
|
|
self.assertEqual(len(csrf_cookie_calls), 2)
|
|
self.assertFalse(any("domain" in kwargs and kwargs["domain"] is None for _args, kwargs in session_cookie_calls))
|
|
self.assertIn("domain", session_cookie_calls[1][1])
|
|
self.assertEqual(session_cookie_calls[1][1]["domain"], "docs.lan")
|
|
self.assertEqual(session_cookie_calls[0][1]["samesite"], "lax")
|
|
|
|
|
|
class ProviderBaseUrlValidationTests(unittest.TestCase):
    """Covers allowlist, scheme, and private-network SSRF protections for provider URLs."""

    def setUp(self) -> None:
        """Resets the URL validation cache so every test starts from a clean slate."""
        config_module._normalize_and_validate_provider_base_url_cached.cache_clear()

    def test_validation_accepts_allowlisted_https_url(self) -> None:
        """An allowlisted HTTPS URL is normalized to carry the /v1 suffix."""
        settings = _security_settings(allowlist=["api.openai.com"])
        with patch.object(config_module, "get_settings", return_value=settings):
            result = config_module.normalize_and_validate_provider_base_url("https://api.openai.com")
        self.assertEqual(result, "https://api.openai.com/v1")

    def test_validation_rejects_non_allowlisted_host(self) -> None:
        """A host outside the configured allowlist is rejected."""
        settings = _security_settings(allowlist=["api.openai.com"])
        with (
            patch.object(config_module, "get_settings", return_value=settings),
            self.assertRaises(ValueError),
        ):
            config_module.normalize_and_validate_provider_base_url("https://example.org/v1")

    def test_validation_rejects_private_ip_literal(self) -> None:
        """Private and loopback IP literals are blocked outright."""
        settings = _security_settings(allowlist=[])
        with (
            patch.object(config_module, "get_settings", return_value=settings),
            self.assertRaises(ValueError),
        ):
            config_module.normalize_and_validate_provider_base_url("https://127.0.0.1/v1")

    def test_validation_rejects_private_ip_after_dns_resolution(self) -> None:
        """DNS rebind protection blocks public hostnames resolving to private addresses."""
        private_resolution = [
            (socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", ("127.0.0.1", 443)),
        ]
        with (
            patch.object(config_module, "get_settings", return_value=_security_settings(allowlist=["api.openai.com"])),
            patch.object(config_module.socket, "getaddrinfo", return_value=private_resolution),
            self.assertRaises(ValueError),
        ):
            config_module.normalize_and_validate_provider_base_url(
                "https://api.openai.com/v1",
                resolve_dns=True,
            )

    def test_resolve_dns_validation_revalidates_each_call(self) -> None:
        """DNS-resolved validation skips the cache and re-checks resolution per call."""
        public_resolution = [
            (socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", ("8.8.8.8", 443)),
        ]
        with (
            patch.object(config_module, "get_settings", return_value=_security_settings(allowlist=["api.openai.com"])),
            patch.object(config_module.socket, "getaddrinfo", return_value=public_resolution) as getaddrinfo_mock,
        ):
            initial = config_module.normalize_and_validate_provider_base_url(
                "https://api.openai.com/v1",
                resolve_dns=True,
            )
            repeat = config_module.normalize_and_validate_provider_base_url(
                "https://api.openai.com/v1",
                resolve_dns=True,
            )
        self.assertEqual(initial, "https://api.openai.com/v1")
        self.assertEqual(repeat, "https://api.openai.com/v1")
        # Two calls must mean two resolutions, proving no caching on the DNS path.
        self.assertEqual(getaddrinfo_mock.call_count, 2)
|
|
|
|
|
|
class RedisQueueSecurityTests(unittest.TestCase):
    """Covers Redis URL security policy behavior across compatibility and strict modes."""

    @staticmethod
    def _validate(url: str, app_env: str, security_mode: str, tls_mode: str) -> str:
        """Invokes the Redis URL security validator with explicit policy knobs."""
        return config_module.validate_redis_url_security(
            url,
            app_env=app_env,
            security_mode=security_mode,
            tls_mode=tls_mode,
        )

    def test_auto_mode_allows_insecure_redis_url_in_development(self) -> None:
        """Development auto mode remains compatible with local unauthenticated Redis URLs."""
        result = self._validate(
            "redis://redis:6379/0",
            app_env="development",
            security_mode="auto",
            tls_mode="auto",
        )
        self.assertEqual(result, "redis://redis:6379/0")

    def test_auto_mode_rejects_missing_auth_in_production(self) -> None:
        """Production auto mode fails closed when the Redis URL omits authentication."""
        with self.assertRaises(ValueError):
            self._validate(
                "rediss://redis:6379/0",
                app_env="production",
                security_mode="auto",
                tls_mode="auto",
            )

    def test_auto_mode_rejects_plaintext_redis_in_production(self) -> None:
        """Production auto mode requires a TLS transport for Redis URLs."""
        with self.assertRaises(ValueError):
            self._validate(
                "redis://:password@redis:6379/0",
                app_env="production",
                security_mode="auto",
                tls_mode="auto",
            )

    def test_strict_mode_enforces_auth_and_tls_outside_production(self) -> None:
        """Strict mode applies production-grade Redis controls in every environment."""
        with self.assertRaises(ValueError):
            self._validate(
                "redis://redis:6379/0",
                app_env="development",
                security_mode="strict",
                tls_mode="auto",
            )

        result = self._validate(
            "rediss://:password@redis:6379/0",
            app_env="development",
            security_mode="strict",
            tls_mode="auto",
        )
        self.assertEqual(result, "rediss://:password@redis:6379/0")

    def test_compat_mode_allows_insecure_redis_in_production_for_safe_migration(self) -> None:
        """Compatibility mode keeps legacy production Redis URLs usable during migration."""
        result = self._validate(
            "redis://redis:6379/0",
            app_env="production",
            security_mode="compat",
            tls_mode="allow_insecure",
        )
        self.assertEqual(result, "redis://redis:6379/0")
|
|
|
|
|
|
class PreviewMimeSafetyTests(unittest.TestCase):
    """Covers inline preview MIME safety decisions for uploaded document responses."""

    def test_preview_blocks_script_capable_html_and_svg_types(self) -> None:
        """Script-capable HTML and SVG MIME types are refused for inline preview."""
        for mime_type in ("text/html", "image/svg+xml"):
            self.assertFalse(config_module.is_inline_preview_mime_type_safe(mime_type))

    def test_preview_allows_pdf_and_safe_image_types(self) -> None:
        """PDF and raster image MIME types remain eligible for inline preview."""
        for mime_type in ("application/pdf", "image/png"):
            self.assertTrue(config_module.is_inline_preview_mime_type_safe(mime_type))
|
|
|
|
|
|
def _build_zip_bytes(entries: dict[str, bytes]) -> bytes:
|
|
"""Builds in-memory ZIP bytes for archive extraction guardrail tests."""
|
|
|
|
output = io.BytesIO()
|
|
with zipfile.ZipFile(output, mode="w", compression=zipfile.ZIP_DEFLATED) as archive:
|
|
for filename, payload in entries.items():
|
|
archive.writestr(filename, payload)
|
|
return output.getvalue()
|
|
|
|
|
|
class ArchiveExtractionGuardrailTests(unittest.TestCase):
    """Covers depth-aware archive extraction and per-call member cap enforcement."""

    @staticmethod
    def _guardrail_settings(max_depth: int) -> SimpleNamespace:
        """Builds extractor settings with standard caps and the given depth ceiling."""
        return SimpleNamespace(
            max_zip_depth=max_depth,
            max_zip_members=250,
            max_zip_member_uncompressed_bytes=25 * 1024 * 1024,
            max_zip_total_uncompressed_bytes=150 * 1024 * 1024,
            max_zip_compression_ratio=120.0,
        )

    def test_extract_archive_members_rejects_depth_at_configured_limit(self) -> None:
        """Extraction returns nothing at or beyond the configured depth ceiling."""
        archive_bytes = _build_zip_bytes({"sample.txt": b"sample"})
        with patch.object(extractor_module, "settings", self._guardrail_settings(2)):
            extracted = extractor_module.extract_archive_members(archive_bytes, depth=2)
        self.assertEqual(extracted, [])

    def test_extract_archive_members_respects_member_cap_argument(self) -> None:
        """A caller-provided member cap truncates results below the archive's member count."""
        archive_bytes = _build_zip_bytes(
            {
                "one.txt": b"1",
                "two.txt": b"2",
                "three.txt": b"3",
            }
        )
        with patch.object(extractor_module, "settings", self._guardrail_settings(3)):
            extracted = extractor_module.extract_archive_members(archive_bytes, depth=0, max_members=1)
        self.assertEqual(len(extracted), 1)
|
|
|
|
|
|
class ArchiveLineagePropagationTests(unittest.TestCase):
    """Covers archive lineage metadata propagation helpers used by worker descendant queueing."""

    def test_create_archive_member_document_persists_lineage_metadata(self) -> None:
        """A child archive document records the root id and an incremented depth."""
        parent_id = uuid.uuid4()
        parent = SimpleNamespace(
            id=parent_id,
            source_relative_path="uploads/root.zip",
            logical_path="Inbox",
            tags=["finance"],
            owner_user_id=uuid.uuid4(),
        )

        with (
            patch.object(worker_tasks_module, "store_bytes", return_value="stored/path/child.zip"),
            patch.object(worker_tasks_module, "compute_sha256", return_value="deadbeef"),
        ):
            child = worker_tasks_module._create_archive_member_document(
                parent=parent,
                member_name="nested/child.zip",
                member_data=b"zip-bytes",
                mime_type="application/zip",
                archive_root_document_id=parent_id,
                archive_depth=1,
            )

        self.assertEqual(child.parent_document_id, parent_id)
        self.assertEqual(
            child.metadata_json.get(worker_tasks_module.ARCHIVE_ROOT_ID_METADATA_KEY),
            str(parent_id),
        )
        self.assertEqual(
            child.metadata_json.get(worker_tasks_module.ARCHIVE_DEPTH_METADATA_KEY),
            1,
        )
        self.assertTrue(child.is_archive_member)
        self.assertEqual(child.owner_user_id, parent.owner_user_id)

    def test_resolve_archive_lineage_prefers_existing_metadata(self) -> None:
        """Existing lineage metadata is reused without any parent traversal queries."""
        root_id = uuid.uuid4()
        document = SimpleNamespace(
            id=uuid.uuid4(),
            metadata_json={
                worker_tasks_module.ARCHIVE_ROOT_ID_METADATA_KEY: str(root_id),
                worker_tasks_module.ARCHIVE_DEPTH_METADATA_KEY: 3,
            },
            is_archive_member=True,
            parent_document_id=uuid.uuid4(),
        )

        class _ForbiddenSession:
            """Fails the test if the helper issues an unnecessary parent lookup."""

            def execute(self, _statement: object) -> object:
                raise AssertionError("session.execute should not be called when metadata is present")

        resolved_root, resolved_depth = worker_tasks_module._resolve_archive_lineage(
            session=_ForbiddenSession(),
            document=document,
        )
        self.assertEqual(resolved_root, root_id)
        self.assertEqual(resolved_depth, 3)

    def test_resolve_archive_lineage_walks_parent_chain_when_metadata_missing(self) -> None:
        """The fallback path walks parent references to recover root id and depth."""
        root_id = uuid.uuid4()
        parent_id = uuid.uuid4()
        root_document = SimpleNamespace(id=root_id, parent_document_id=None)
        parent_document = SimpleNamespace(id=parent_id, parent_document_id=root_id)
        document = SimpleNamespace(
            id=uuid.uuid4(),
            metadata_json={},
            is_archive_member=True,
            parent_document_id=parent_id,
        )

        class _ScalarResult:
            """Wraps one scalar ORM-style result for deterministic traversal tests."""

            def __init__(self, value: object) -> None:
                self._value = value

            def scalar_one_or_none(self) -> object:
                return self._value

        class _QueuedSession:
            """Hands back queued parent rows in traversal order, then None."""

            def __init__(self, rows: list[object]) -> None:
                self._rows = rows

            def execute(self, _statement: object) -> _ScalarResult:
                return _ScalarResult(self._rows.pop(0) if self._rows else None)

        resolved_root, resolved_depth = worker_tasks_module._resolve_archive_lineage(
            session=_QueuedSession([parent_document, root_document]),
            document=document,
        )
        self.assertEqual(resolved_root, root_id)
        self.assertEqual(resolved_depth, 2)
|
|
|
|
|
|
class ProcessingLogRedactionTests(unittest.TestCase):
    """Covers redaction of sensitive processing-log values for persistence and responses."""

    def test_payload_redacts_sensitive_keys(self) -> None:
        """Sensitive payload keys are replaced by the redaction marker."""
        payload = {
            "api_key": "secret-value",
            "nested": {
                "authorization": "Bearer sample-token",
            },
        }
        cleaned = sanitize_processing_log_payload_value(payload)
        self.assertEqual(cleaned["api_key"], "[REDACTED]")
        self.assertEqual(cleaned["nested"]["authorization"], "[REDACTED]")

    def test_text_redaction_removes_bearer_and_jwt_values(self) -> None:
        """Bearer and JWT token substrings are fully removed from log text."""
        bearer_token = "super-secret-token-123"
        jwt_token = (
            "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9."
            "eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4ifQ."
            "signaturevalue123456789"
        )
        cleaned = sanitize_processing_log_text(
            f"Authorization: Bearer {bearer_token}\nraw_jwt={jwt_token}"
        )
        self.assertIsNotNone(cleaned)
        cleaned_text = cleaned or ""
        self.assertIn("[REDACTED]", cleaned_text)
        for secret in (bearer_token, jwt_token):
            self.assertNotIn(secret, cleaned_text)

    def test_text_redaction_removes_json_formatted_secret_values(self) -> None:
        """JSON-formatted quoted secrets are fully removed from redacted log text."""
        api_key_secret = "json-api-key-secret"
        token_secret = "json-token-secret"
        authorization_secret = "json-auth-secret"
        bearer_secret = "json-bearer-secret"
        json_text = (
            "{"
            f'"api_key":"{api_key_secret}",'
            f'"token":"{token_secret}",'
            f'"authorization":"Bearer {authorization_secret}",'
            f'"bearer":"{bearer_secret}"'
            "}"
        )
        cleaned = sanitize_processing_log_text(json_text)
        self.assertIsNotNone(cleaned)
        cleaned_text = cleaned or ""
        self.assertIn("[REDACTED]", cleaned_text)
        for secret in (api_key_secret, token_secret, authorization_secret, bearer_secret):
            self.assertNotIn(secret, cleaned_text)

    def test_response_schema_applies_redaction_to_existing_entries(self) -> None:
        """The API schema validators redact sensitive fields from legacy stored rows."""
        bearer_token = "abc123token"
        jwt_token = (
            "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9."
            "eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4ifQ."
            "signaturevalue123456789"
        )
        row = {
            "id": 1,
            "created_at": datetime.now(UTC),
            "level": "info",
            "stage": "summary",
            "event": "response",
            "document_id": None,
            "document_filename": "sample.txt",
            "provider_id": "provider",
            "model_name": "model",
            "prompt_text": f"Authorization: Bearer {bearer_token}",
            "response_text": f"token={jwt_token}",
            "payload_json": {"password": "secret", "trace_id": "trace-1"},
        }
        response = ProcessingLogEntryResponse.model_validate(row)
        self.assertEqual(response.payload_json["password"], "[REDACTED]")
        self.assertIn("[REDACTED]", response.prompt_text or "")
        self.assertIn("[REDACTED]", response.response_text or "")
        self.assertNotIn(bearer_token, response.prompt_text or "")
        self.assertNotIn(jwt_token, response.response_text or "")

    def test_response_schema_redacts_json_formatted_secret_values(self) -> None:
        """The response schema redacts quoted JSON secret forms from legacy text fields."""
        api_key_secret = "legacy-json-api-key"
        token_secret = "legacy-json-token"
        authorization_secret = "legacy-json-auth"
        bearer_secret = "legacy-json-bearer"
        prompt_text = (
            "{"
            f'"api_key":"{api_key_secret}",'
            f'"token":"{token_secret}"'
            "}"
        )
        response_text = (
            "{"
            f'"authorization":"Bearer {authorization_secret}",'
            f'"bearer":"{bearer_secret}"'
            "}"
        )

        row = {
            "id": 2,
            "created_at": datetime.now(UTC),
            "level": "info",
            "stage": "summary",
            "event": "response",
            "document_id": None,
            "document_filename": "sample-json.txt",
            "provider_id": "provider",
            "model_name": "model",
            "prompt_text": prompt_text,
            "response_text": response_text,
            "payload_json": {"trace_id": "trace-2"},
        }
        response = ProcessingLogEntryResponse.model_validate(row)

        self.assertIn("[REDACTED]", response.prompt_text or "")
        self.assertIn("[REDACTED]", response.response_text or "")
        self.assertNotIn(api_key_secret, response.prompt_text or "")
        self.assertNotIn(token_secret, response.prompt_text or "")
        self.assertNotIn(authorization_secret, response.response_text or "")
        self.assertNotIn(bearer_secret, response.response_text or "")
|
|
|
|
|
|
# Allow running this suite directly without an external test runner.
if __name__ == "__main__":

    unittest.main()
|