"""Regression tests for upload request-size middleware scope and preflight handling.""" from __future__ import annotations import importlib import sys import unittest from pathlib import Path from types import ModuleType, SimpleNamespace from typing import Any, Awaitable, Callable BACKEND_ROOT = Path(__file__).resolve().parents[1] if str(BACKEND_ROOT) not in sys.path: sys.path.insert(0, str(BACKEND_ROOT)) def _install_main_import_stubs() -> dict[str, ModuleType | None]: """Installs lightweight module stubs required for importing app.main in isolation.""" previous_modules: dict[str, ModuleType | None] = { name: sys.modules.get(name) for name in [ "fastapi", "fastapi.middleware", "fastapi.middleware.cors", "fastapi.responses", "app.api.router", "app.core.config", "app.db.base", "app.services.app_settings", "app.services.handwriting_style", "app.services.storage", "app.services.typesense_index", ] } fastapi_stub = ModuleType("fastapi") class _Response: """Minimal response base class for middleware typing compatibility.""" class _FastAPI: """Captures middleware registration behavior used by app.main tests.""" def __init__(self, *_args: object, **_kwargs: object) -> None: self.http_middlewares: list[Any] = [] def add_middleware(self, *_args: object, **_kwargs: object) -> None: """Accepts middleware registrations without side effects.""" def include_router(self, *_args: object, **_kwargs: object) -> None: """Accepts router registration without side effects.""" def middleware( self, middleware_type: str, ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """Registers request middleware functions for later invocation in tests.""" def decorator(func: Callable[..., Any]) -> Callable[..., Any]: if middleware_type == "http": self.http_middlewares.append(func) return func return decorator def on_event( self, *_args: object, **_kwargs: object, ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """Returns no-op startup and shutdown decorators.""" def decorator(func: Callable[..., Any]) -> Callable[..., Any]: return func return decorator fastapi_stub.FastAPI = _FastAPI fastapi_stub.Request = object fastapi_stub.Response = _Response sys.modules["fastapi"] = fastapi_stub fastapi_middleware_stub = ModuleType("fastapi.middleware") sys.modules["fastapi.middleware"] = fastapi_middleware_stub fastapi_middleware_cors_stub = ModuleType("fastapi.middleware.cors") class _CORSMiddleware: """Placeholder CORS middleware class accepted by FastAPI.add_middleware.""" fastapi_middleware_cors_stub.CORSMiddleware = _CORSMiddleware sys.modules["fastapi.middleware.cors"] = fastapi_middleware_cors_stub fastapi_responses_stub = ModuleType("fastapi.responses") class _JSONResponse: """Simple JSONResponse stand-in exposing status code and payload fields.""" def __init__(self, *, status_code: int, content: dict[str, Any]) -> None: self.status_code = status_code self.content = content fastapi_responses_stub.JSONResponse = _JSONResponse sys.modules["fastapi.responses"] = fastapi_responses_stub api_router_stub = ModuleType("app.api.router") api_router_stub.api_router = object() sys.modules["app.api.router"] = api_router_stub config_stub = ModuleType("app.core.config") def get_settings() -> SimpleNamespace: """Returns minimal settings consumed by app.main during test import.""" return SimpleNamespace( app_env="development", cors_origins=["http://localhost:5173"], max_upload_request_size_bytes=1024, ) config_stub.get_settings = get_settings sys.modules["app.core.config"] = config_stub db_base_stub = ModuleType("app.db.base") def init_db() -> None: """No-op database initializer for middleware scope tests.""" db_base_stub.init_db = init_db sys.modules["app.db.base"] = db_base_stub app_settings_stub = ModuleType("app.services.app_settings") def ensure_app_settings() -> None: """No-op settings initializer for middleware scope tests.""" app_settings_stub.ensure_app_settings = ensure_app_settings sys.modules["app.services.app_settings"] = app_settings_stub handwriting_style_stub = ModuleType("app.services.handwriting_style") def ensure_handwriting_style_collection() -> None: """No-op handwriting collection initializer for middleware scope tests.""" handwriting_style_stub.ensure_handwriting_style_collection = ensure_handwriting_style_collection sys.modules["app.services.handwriting_style"] = handwriting_style_stub storage_stub = ModuleType("app.services.storage") def ensure_storage() -> None: """No-op storage initializer for middleware scope tests.""" storage_stub.ensure_storage = ensure_storage sys.modules["app.services.storage"] = storage_stub typesense_stub = ModuleType("app.services.typesense_index") def ensure_typesense_collection() -> None: """No-op Typesense collection initializer for middleware scope tests.""" typesense_stub.ensure_typesense_collection = ensure_typesense_collection sys.modules["app.services.typesense_index"] = typesense_stub return previous_modules def _restore_main_import_stubs(previous_modules: dict[str, ModuleType | None]) -> None: """Restores module table entries captured before installing app.main test stubs.""" for module_name, previous in previous_modules.items(): if previous is None: sys.modules.pop(module_name, None) else: sys.modules[module_name] = previous class UploadRequestSizeMiddlewareTests(unittest.IsolatedAsyncioTestCase): """Verifies upload request-size middleware ignores preflight and guards only upload POST.""" @classmethod def setUpClass(cls) -> None: """Installs import stubs and imports app.main once for middleware extraction.""" cls._previous_modules = _install_main_import_stubs() cls.main_module = importlib.import_module("app.main") @classmethod def tearDownClass(cls) -> None: """Removes imported module and restores pre-existing module table entries.""" sys.modules.pop("app.main", None) _restore_main_import_stubs(cls._previous_modules) def _http_middleware( self, ) -> Callable[[object, Callable[[object], Awaitable[object]]], Awaitable[object]]: """Returns the registered HTTP middleware callable from the stubbed FastAPI app.""" return self.main_module.app.http_middlewares[0] async def test_options_preflight_skips_upload_content_length_guard(self) -> None: """OPTIONS preflight requests for upload endpoint continue without Content-Length enforcement.""" request = SimpleNamespace( method="OPTIONS", url=SimpleNamespace(path="/api/v1/documents/upload"), headers={}, ) expected_response = object() call_next_count = 0 async def call_next(_request: object) -> object: nonlocal call_next_count call_next_count += 1 return expected_response response = await self._http_middleware()(request, call_next) self.assertIs(response, expected_response) self.assertEqual(call_next_count, 1) async def test_post_upload_without_content_length_is_rejected(self) -> None: """Upload POST requests remain blocked when Content-Length is absent.""" request = SimpleNamespace( method="POST", url=SimpleNamespace(path="/api/v1/documents/upload"), headers={}, ) call_next_count = 0 async def call_next(_request: object) -> object: nonlocal call_next_count call_next_count += 1 return object() response = await self._http_middleware()(request, call_next) self.assertEqual(response.status_code, 411) self.assertEqual( response.content, {"detail": "Content-Length header is required for document uploads"}, ) self.assertEqual(call_next_count, 0) async def test_post_non_upload_path_skips_upload_content_length_guard(self) -> None: """Content-Length enforcement does not run for non-upload POST requests.""" request = SimpleNamespace( method="POST", url=SimpleNamespace(path="/api/v1/documents"), headers={}, ) expected_response = object() call_next_count = 0 async def call_next(_request: object) -> object: nonlocal call_next_count call_next_count += 1 return expected_response response = await self._http_middleware()(request, call_next) self.assertIs(response, expected_response) self.assertEqual(call_next_count, 1) if __name__ == "__main__": unittest.main()