Fix authenticated media flows and upload preflight handling
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
"""FastAPI entrypoint for the DMS backend service."""
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
from fastapi import FastAPI, Request, Response
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
@@ -14,6 +16,18 @@ from app.services.typesense_index import ensure_typesense_collection
|
||||
|
||||
|
||||
settings = get_settings()
|
||||
UPLOAD_ENDPOINT_PATH = "/api/v1/documents/upload"
|
||||
UPLOAD_ENDPOINT_METHOD = "POST"
|
||||
|
||||
|
||||
def _is_upload_size_guard_target(request: Request) -> bool:
|
||||
"""Returns whether upload request-size enforcement applies to this request.
|
||||
|
||||
Upload-size validation is intentionally scoped to the upload POST endpoint so CORS
|
||||
preflight OPTIONS requests can pass through CORSMiddleware.
|
||||
"""
|
||||
|
||||
return request.method.upper() == UPLOAD_ENDPOINT_METHOD and request.url.path == UPLOAD_ENDPOINT_PATH
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
@@ -30,10 +44,13 @@ def create_app() -> FastAPI:
|
||||
app.include_router(api_router, prefix="/api/v1")
|
||||
|
||||
@app.middleware("http")
|
||||
async def enforce_upload_request_size(request: Request, call_next):
|
||||
"""Rejects upload requests without deterministic length or exceeding configured limits."""
|
||||
async def enforce_upload_request_size(
|
||||
request: Request,
|
||||
call_next: Callable[[Request], Awaitable[Response]],
|
||||
) -> Response:
|
||||
"""Rejects only POST upload bodies without deterministic length or with oversized request totals."""
|
||||
|
||||
if request.url.path.endswith("/api/v1/documents/upload"):
|
||||
if _is_upload_size_guard_target(request):
|
||||
content_length = request.headers.get("content-length", "").strip()
|
||||
if not content_length:
|
||||
return JSONResponse(
|
||||
|
||||
270
backend/tests/test_upload_request_size_middleware.py
Normal file
270
backend/tests/test_upload_request_size_middleware.py
Normal file
@@ -0,0 +1,270 @@
|
||||
"""Regression tests for upload request-size middleware scope and preflight handling."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
|
||||
BACKEND_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(BACKEND_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(BACKEND_ROOT))
|
||||
|
||||
|
||||
def _install_main_import_stubs() -> dict[str, ModuleType | None]:
|
||||
"""Installs lightweight module stubs required for importing app.main in isolation."""
|
||||
|
||||
previous_modules: dict[str, ModuleType | None] = {
|
||||
name: sys.modules.get(name)
|
||||
for name in [
|
||||
"fastapi",
|
||||
"fastapi.middleware",
|
||||
"fastapi.middleware.cors",
|
||||
"fastapi.responses",
|
||||
"app.api.router",
|
||||
"app.core.config",
|
||||
"app.db.base",
|
||||
"app.services.app_settings",
|
||||
"app.services.handwriting_style",
|
||||
"app.services.storage",
|
||||
"app.services.typesense_index",
|
||||
]
|
||||
}
|
||||
|
||||
fastapi_stub = ModuleType("fastapi")
|
||||
|
||||
class _Response:
|
||||
"""Minimal response base class for middleware typing compatibility."""
|
||||
|
||||
class _FastAPI:
|
||||
"""Captures middleware registration behavior used by app.main tests."""
|
||||
|
||||
def __init__(self, *_args: object, **_kwargs: object) -> None:
|
||||
self.http_middlewares: list[Any] = []
|
||||
|
||||
def add_middleware(self, *_args: object, **_kwargs: object) -> None:
|
||||
"""Accepts middleware registrations without side effects."""
|
||||
|
||||
def include_router(self, *_args: object, **_kwargs: object) -> None:
|
||||
"""Accepts router registration without side effects."""
|
||||
|
||||
def middleware(
|
||||
self,
|
||||
middleware_type: str,
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
"""Registers request middleware functions for later invocation in tests."""
|
||||
|
||||
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
if middleware_type == "http":
|
||||
self.http_middlewares.append(func)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def on_event(
|
||||
self,
|
||||
*_args: object,
|
||||
**_kwargs: object,
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
"""Returns no-op startup and shutdown decorators."""
|
||||
|
||||
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
fastapi_stub.FastAPI = _FastAPI
|
||||
fastapi_stub.Request = object
|
||||
fastapi_stub.Response = _Response
|
||||
sys.modules["fastapi"] = fastapi_stub
|
||||
|
||||
fastapi_middleware_stub = ModuleType("fastapi.middleware")
|
||||
sys.modules["fastapi.middleware"] = fastapi_middleware_stub
|
||||
|
||||
fastapi_middleware_cors_stub = ModuleType("fastapi.middleware.cors")
|
||||
|
||||
class _CORSMiddleware:
|
||||
"""Placeholder CORS middleware class accepted by FastAPI.add_middleware."""
|
||||
|
||||
fastapi_middleware_cors_stub.CORSMiddleware = _CORSMiddleware
|
||||
sys.modules["fastapi.middleware.cors"] = fastapi_middleware_cors_stub
|
||||
|
||||
fastapi_responses_stub = ModuleType("fastapi.responses")
|
||||
|
||||
class _JSONResponse:
|
||||
"""Simple JSONResponse stand-in exposing status code and payload fields."""
|
||||
|
||||
def __init__(self, *, status_code: int, content: dict[str, Any]) -> None:
|
||||
self.status_code = status_code
|
||||
self.content = content
|
||||
|
||||
fastapi_responses_stub.JSONResponse = _JSONResponse
|
||||
sys.modules["fastapi.responses"] = fastapi_responses_stub
|
||||
|
||||
api_router_stub = ModuleType("app.api.router")
|
||||
api_router_stub.api_router = object()
|
||||
sys.modules["app.api.router"] = api_router_stub
|
||||
|
||||
config_stub = ModuleType("app.core.config")
|
||||
|
||||
def get_settings() -> SimpleNamespace:
|
||||
"""Returns minimal settings consumed by app.main during test import."""
|
||||
|
||||
return SimpleNamespace(
|
||||
cors_origins=["http://localhost:5173"],
|
||||
max_upload_request_size_bytes=1024,
|
||||
)
|
||||
|
||||
config_stub.get_settings = get_settings
|
||||
sys.modules["app.core.config"] = config_stub
|
||||
|
||||
db_base_stub = ModuleType("app.db.base")
|
||||
|
||||
def init_db() -> None:
|
||||
"""No-op database initializer for middleware scope tests."""
|
||||
|
||||
db_base_stub.init_db = init_db
|
||||
sys.modules["app.db.base"] = db_base_stub
|
||||
|
||||
app_settings_stub = ModuleType("app.services.app_settings")
|
||||
|
||||
def ensure_app_settings() -> None:
|
||||
"""No-op settings initializer for middleware scope tests."""
|
||||
|
||||
app_settings_stub.ensure_app_settings = ensure_app_settings
|
||||
sys.modules["app.services.app_settings"] = app_settings_stub
|
||||
|
||||
handwriting_style_stub = ModuleType("app.services.handwriting_style")
|
||||
|
||||
def ensure_handwriting_style_collection() -> None:
|
||||
"""No-op handwriting collection initializer for middleware scope tests."""
|
||||
|
||||
handwriting_style_stub.ensure_handwriting_style_collection = ensure_handwriting_style_collection
|
||||
sys.modules["app.services.handwriting_style"] = handwriting_style_stub
|
||||
|
||||
storage_stub = ModuleType("app.services.storage")
|
||||
|
||||
def ensure_storage() -> None:
|
||||
"""No-op storage initializer for middleware scope tests."""
|
||||
|
||||
storage_stub.ensure_storage = ensure_storage
|
||||
sys.modules["app.services.storage"] = storage_stub
|
||||
|
||||
typesense_stub = ModuleType("app.services.typesense_index")
|
||||
|
||||
def ensure_typesense_collection() -> None:
|
||||
"""No-op Typesense collection initializer for middleware scope tests."""
|
||||
|
||||
typesense_stub.ensure_typesense_collection = ensure_typesense_collection
|
||||
sys.modules["app.services.typesense_index"] = typesense_stub
|
||||
|
||||
return previous_modules
|
||||
|
||||
|
||||
def _restore_main_import_stubs(previous_modules: dict[str, ModuleType | None]) -> None:
|
||||
"""Restores module table entries captured before installing app.main test stubs."""
|
||||
|
||||
for module_name, previous in previous_modules.items():
|
||||
if previous is None:
|
||||
sys.modules.pop(module_name, None)
|
||||
else:
|
||||
sys.modules[module_name] = previous
|
||||
|
||||
|
||||
class UploadRequestSizeMiddlewareTests(unittest.IsolatedAsyncioTestCase):
|
||||
"""Verifies upload request-size middleware ignores preflight and guards only upload POST."""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
"""Installs import stubs and imports app.main once for middleware extraction."""
|
||||
|
||||
cls._previous_modules = _install_main_import_stubs()
|
||||
cls.main_module = importlib.import_module("app.main")
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls) -> None:
|
||||
"""Removes imported module and restores pre-existing module table entries."""
|
||||
|
||||
sys.modules.pop("app.main", None)
|
||||
_restore_main_import_stubs(cls._previous_modules)
|
||||
|
||||
def _http_middleware(
|
||||
self,
|
||||
) -> Callable[[object, Callable[[object], Awaitable[object]]], Awaitable[object]]:
|
||||
"""Returns the registered HTTP middleware callable from the stubbed FastAPI app."""
|
||||
|
||||
return self.main_module.app.http_middlewares[0]
|
||||
|
||||
async def test_options_preflight_skips_upload_content_length_guard(self) -> None:
|
||||
"""OPTIONS preflight requests for upload endpoint continue without Content-Length enforcement."""
|
||||
|
||||
request = SimpleNamespace(
|
||||
method="OPTIONS",
|
||||
url=SimpleNamespace(path="/api/v1/documents/upload"),
|
||||
headers={},
|
||||
)
|
||||
expected_response = object()
|
||||
call_next_count = 0
|
||||
|
||||
async def call_next(_request: object) -> object:
|
||||
nonlocal call_next_count
|
||||
call_next_count += 1
|
||||
return expected_response
|
||||
|
||||
response = await self._http_middleware()(request, call_next)
|
||||
|
||||
self.assertIs(response, expected_response)
|
||||
self.assertEqual(call_next_count, 1)
|
||||
|
||||
async def test_post_upload_without_content_length_is_rejected(self) -> None:
|
||||
"""Upload POST requests remain blocked when Content-Length is absent."""
|
||||
|
||||
request = SimpleNamespace(
|
||||
method="POST",
|
||||
url=SimpleNamespace(path="/api/v1/documents/upload"),
|
||||
headers={},
|
||||
)
|
||||
call_next_count = 0
|
||||
|
||||
async def call_next(_request: object) -> object:
|
||||
nonlocal call_next_count
|
||||
call_next_count += 1
|
||||
return object()
|
||||
|
||||
response = await self._http_middleware()(request, call_next)
|
||||
|
||||
self.assertEqual(response.status_code, 411)
|
||||
self.assertEqual(
|
||||
response.content,
|
||||
{"detail": "Content-Length header is required for document uploads"},
|
||||
)
|
||||
self.assertEqual(call_next_count, 0)
|
||||
|
||||
async def test_post_non_upload_path_skips_upload_content_length_guard(self) -> None:
|
||||
"""Content-Length enforcement does not run for non-upload POST requests."""
|
||||
|
||||
request = SimpleNamespace(
|
||||
method="POST",
|
||||
url=SimpleNamespace(path="/api/v1/documents"),
|
||||
headers={},
|
||||
)
|
||||
expected_response = object()
|
||||
call_next_count = 0
|
||||
|
||||
async def call_next(_request: object) -> object:
|
||||
nonlocal call_next_count
|
||||
call_next_count += 1
|
||||
return expected_response
|
||||
|
||||
response = await self._http_middleware()(request, call_next)
|
||||
|
||||
self.assertIs(response, expected_response)
|
||||
self.assertEqual(call_next_count, 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user