Fix authenticated media flows and upload preflight handling

This commit is contained in:
2026-02-21 15:53:02 -03:00
parent 1cb6bfee58
commit c3f34b38b4
12 changed files with 619 additions and 35 deletions

View File

@@ -1,6 +1,8 @@
"""FastAPI entrypoint for the DMS backend service."""
from fastapi import FastAPI, Request
from typing import Awaitable, Callable
from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
@@ -14,6 +16,18 @@ from app.services.typesense_index import ensure_typesense_collection
settings = get_settings()
UPLOAD_ENDPOINT_PATH = "/api/v1/documents/upload"
UPLOAD_ENDPOINT_METHOD = "POST"
def _is_upload_size_guard_target(request: Request) -> bool:
"""Returns whether upload request-size enforcement applies to this request.
Upload-size validation is intentionally scoped to the upload POST endpoint so CORS
preflight OPTIONS requests can pass through CORSMiddleware.
"""
return request.method.upper() == UPLOAD_ENDPOINT_METHOD and request.url.path == UPLOAD_ENDPOINT_PATH
def create_app() -> FastAPI:
@@ -30,10 +44,13 @@ def create_app() -> FastAPI:
app.include_router(api_router, prefix="/api/v1")
@app.middleware("http")
async def enforce_upload_request_size(request: Request, call_next):
"""Rejects upload requests without deterministic length or exceeding configured limits."""
async def enforce_upload_request_size(
request: Request,
call_next: Callable[[Request], Awaitable[Response]],
) -> Response:
"""Rejects only POST upload bodies without deterministic length or with oversized request totals."""
if request.url.path.endswith("/api/v1/documents/upload"):
if _is_upload_size_guard_target(request):
content_length = request.headers.get("content-length", "").strip()
if not content_length:
return JSONResponse(

View File

@@ -0,0 +1,270 @@
"""Regression tests for upload request-size middleware scope and preflight handling."""
from __future__ import annotations
import importlib
import sys
import unittest
from pathlib import Path
from types import ModuleType, SimpleNamespace
from typing import Any, Awaitable, Callable
BACKEND_ROOT = Path(__file__).resolve().parents[1]
if str(BACKEND_ROOT) not in sys.path:
sys.path.insert(0, str(BACKEND_ROOT))
def _install_main_import_stubs() -> dict[str, ModuleType | None]:
"""Installs lightweight module stubs required for importing app.main in isolation."""
previous_modules: dict[str, ModuleType | None] = {
name: sys.modules.get(name)
for name in [
"fastapi",
"fastapi.middleware",
"fastapi.middleware.cors",
"fastapi.responses",
"app.api.router",
"app.core.config",
"app.db.base",
"app.services.app_settings",
"app.services.handwriting_style",
"app.services.storage",
"app.services.typesense_index",
]
}
fastapi_stub = ModuleType("fastapi")
class _Response:
"""Minimal response base class for middleware typing compatibility."""
class _FastAPI:
"""Captures middleware registration behavior used by app.main tests."""
def __init__(self, *_args: object, **_kwargs: object) -> None:
self.http_middlewares: list[Any] = []
def add_middleware(self, *_args: object, **_kwargs: object) -> None:
"""Accepts middleware registrations without side effects."""
def include_router(self, *_args: object, **_kwargs: object) -> None:
"""Accepts router registration without side effects."""
def middleware(
self,
middleware_type: str,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""Registers request middleware functions for later invocation in tests."""
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
if middleware_type == "http":
self.http_middlewares.append(func)
return func
return decorator
def on_event(
self,
*_args: object,
**_kwargs: object,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""Returns no-op startup and shutdown decorators."""
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
return func
return decorator
fastapi_stub.FastAPI = _FastAPI
fastapi_stub.Request = object
fastapi_stub.Response = _Response
sys.modules["fastapi"] = fastapi_stub
fastapi_middleware_stub = ModuleType("fastapi.middleware")
sys.modules["fastapi.middleware"] = fastapi_middleware_stub
fastapi_middleware_cors_stub = ModuleType("fastapi.middleware.cors")
class _CORSMiddleware:
"""Placeholder CORS middleware class accepted by FastAPI.add_middleware."""
fastapi_middleware_cors_stub.CORSMiddleware = _CORSMiddleware
sys.modules["fastapi.middleware.cors"] = fastapi_middleware_cors_stub
fastapi_responses_stub = ModuleType("fastapi.responses")
class _JSONResponse:
"""Simple JSONResponse stand-in exposing status code and payload fields."""
def __init__(self, *, status_code: int, content: dict[str, Any]) -> None:
self.status_code = status_code
self.content = content
fastapi_responses_stub.JSONResponse = _JSONResponse
sys.modules["fastapi.responses"] = fastapi_responses_stub
api_router_stub = ModuleType("app.api.router")
api_router_stub.api_router = object()
sys.modules["app.api.router"] = api_router_stub
config_stub = ModuleType("app.core.config")
def get_settings() -> SimpleNamespace:
"""Returns minimal settings consumed by app.main during test import."""
return SimpleNamespace(
cors_origins=["http://localhost:5173"],
max_upload_request_size_bytes=1024,
)
config_stub.get_settings = get_settings
sys.modules["app.core.config"] = config_stub
db_base_stub = ModuleType("app.db.base")
def init_db() -> None:
"""No-op database initializer for middleware scope tests."""
db_base_stub.init_db = init_db
sys.modules["app.db.base"] = db_base_stub
app_settings_stub = ModuleType("app.services.app_settings")
def ensure_app_settings() -> None:
"""No-op settings initializer for middleware scope tests."""
app_settings_stub.ensure_app_settings = ensure_app_settings
sys.modules["app.services.app_settings"] = app_settings_stub
handwriting_style_stub = ModuleType("app.services.handwriting_style")
def ensure_handwriting_style_collection() -> None:
"""No-op handwriting collection initializer for middleware scope tests."""
handwriting_style_stub.ensure_handwriting_style_collection = ensure_handwriting_style_collection
sys.modules["app.services.handwriting_style"] = handwriting_style_stub
storage_stub = ModuleType("app.services.storage")
def ensure_storage() -> None:
"""No-op storage initializer for middleware scope tests."""
storage_stub.ensure_storage = ensure_storage
sys.modules["app.services.storage"] = storage_stub
typesense_stub = ModuleType("app.services.typesense_index")
def ensure_typesense_collection() -> None:
"""No-op Typesense collection initializer for middleware scope tests."""
typesense_stub.ensure_typesense_collection = ensure_typesense_collection
sys.modules["app.services.typesense_index"] = typesense_stub
return previous_modules
def _restore_main_import_stubs(previous_modules: dict[str, ModuleType | None]) -> None:
"""Restores module table entries captured before installing app.main test stubs."""
for module_name, previous in previous_modules.items():
if previous is None:
sys.modules.pop(module_name, None)
else:
sys.modules[module_name] = previous
class UploadRequestSizeMiddlewareTests(unittest.IsolatedAsyncioTestCase):
"""Verifies upload request-size middleware ignores preflight and guards only upload POST."""
@classmethod
def setUpClass(cls) -> None:
"""Installs import stubs and imports app.main once for middleware extraction."""
cls._previous_modules = _install_main_import_stubs()
cls.main_module = importlib.import_module("app.main")
@classmethod
def tearDownClass(cls) -> None:
"""Removes imported module and restores pre-existing module table entries."""
sys.modules.pop("app.main", None)
_restore_main_import_stubs(cls._previous_modules)
def _http_middleware(
self,
) -> Callable[[object, Callable[[object], Awaitable[object]]], Awaitable[object]]:
"""Returns the registered HTTP middleware callable from the stubbed FastAPI app."""
return self.main_module.app.http_middlewares[0]
async def test_options_preflight_skips_upload_content_length_guard(self) -> None:
"""OPTIONS preflight requests for upload endpoint continue without Content-Length enforcement."""
request = SimpleNamespace(
method="OPTIONS",
url=SimpleNamespace(path="/api/v1/documents/upload"),
headers={},
)
expected_response = object()
call_next_count = 0
async def call_next(_request: object) -> object:
nonlocal call_next_count
call_next_count += 1
return expected_response
response = await self._http_middleware()(request, call_next)
self.assertIs(response, expected_response)
self.assertEqual(call_next_count, 1)
async def test_post_upload_without_content_length_is_rejected(self) -> None:
"""Upload POST requests remain blocked when Content-Length is absent."""
request = SimpleNamespace(
method="POST",
url=SimpleNamespace(path="/api/v1/documents/upload"),
headers={},
)
call_next_count = 0
async def call_next(_request: object) -> object:
nonlocal call_next_count
call_next_count += 1
return object()
response = await self._http_middleware()(request, call_next)
self.assertEqual(response.status_code, 411)
self.assertEqual(
response.content,
{"detail": "Content-Length header is required for document uploads"},
)
self.assertEqual(call_next_count, 0)
async def test_post_non_upload_path_skips_upload_content_length_guard(self) -> None:
"""Content-Length enforcement does not run for non-upload POST requests."""
request = SimpleNamespace(
method="POST",
url=SimpleNamespace(path="/api/v1/documents"),
headers={},
)
expected_response = object()
call_next_count = 0
async def call_next(_request: object) -> object:
nonlocal call_next_count
call_next_count += 1
return expected_response
response = await self._http_middleware()(request, call_next)
self.assertIs(response, expected_response)
self.assertEqual(call_next_count, 1)
if __name__ == "__main__":
unittest.main()