Harden auth login against brute-force and refresh security docs
This commit is contained in:
@@ -40,6 +40,32 @@ if "pydantic_settings" not in sys.modules:
|
||||
if "fastapi" not in sys.modules:
|
||||
fastapi_stub = ModuleType("fastapi")
|
||||
|
||||
class _APIRouter:
|
||||
"""Minimal APIRouter stand-in supporting decorator registration."""
|
||||
|
||||
def __init__(self, *args: object, **kwargs: object) -> None:
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
def post(self, *_args: object, **_kwargs: object): # type: ignore[no-untyped-def]
|
||||
"""Returns no-op decorator for POST route declarations."""
|
||||
|
||||
def decorator(func): # type: ignore[no-untyped-def]
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
def get(self, *_args: object, **_kwargs: object): # type: ignore[no-untyped-def]
|
||||
"""Returns no-op decorator for GET route declarations."""
|
||||
|
||||
def decorator(func): # type: ignore[no-untyped-def]
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
class _Request:
|
||||
"""Minimal request placeholder for route function import compatibility."""
|
||||
|
||||
class _HTTPException(Exception):
|
||||
"""Minimal HTTPException compatible with route dependency tests."""
|
||||
|
||||
@@ -54,6 +80,7 @@ if "fastapi" not in sys.modules:
|
||||
|
||||
HTTP_401_UNAUTHORIZED = 401
|
||||
HTTP_403_FORBIDDEN = 403
|
||||
HTTP_429_TOO_MANY_REQUESTS = 429
|
||||
HTTP_503_SERVICE_UNAVAILABLE = 503
|
||||
|
||||
def _depends(dependency): # type: ignore[no-untyped-def]
|
||||
@@ -61,8 +88,10 @@ if "fastapi" not in sys.modules:
|
||||
|
||||
return dependency
|
||||
|
||||
fastapi_stub.APIRouter = _APIRouter
|
||||
fastapi_stub.Depends = _depends
|
||||
fastapi_stub.HTTPException = _HTTPException
|
||||
fastapi_stub.Request = _Request
|
||||
fastapi_stub.status = _Status()
|
||||
sys.modules["fastapi"] = fastapi_stub
|
||||
|
||||
@@ -274,10 +303,12 @@ if "app.services.routing_pipeline" not in sys.modules:
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.api.auth import AuthContext, require_admin
|
||||
from app.api import routes_auth as auth_routes_module
|
||||
from app.core import config as config_module
|
||||
from app.models.auth import UserRole
|
||||
from app.models.processing_log import sanitize_processing_log_payload_value, sanitize_processing_log_text
|
||||
from app.schemas.processing_logs import ProcessingLogEntryResponse
|
||||
from app.services import auth_login_throttle as auth_login_throttle_module
|
||||
from app.services import extractor as extractor_module
|
||||
from app.worker import tasks as worker_tasks_module
|
||||
|
||||
@@ -328,6 +359,366 @@ class AuthDependencyTests(unittest.TestCase):
|
||||
self.assertEqual(resolved.role, UserRole.ADMIN)
|
||||
|
||||
|
||||
class _FakeRedisPipeline:
|
||||
"""Provides deterministic Redis pipeline behavior for login throttle tests."""
|
||||
|
||||
def __init__(self, redis_client: "_FakeRedis") -> None:
|
||||
self._redis_client = redis_client
|
||||
self._operations: list[tuple[str, tuple[object, ...]]] = []
|
||||
|
||||
def incr(self, key: str, amount: int) -> "_FakeRedisPipeline":
|
||||
"""Queues one counter increment operation for pipeline execution."""
|
||||
|
||||
self._operations.append(("incr", (key, amount)))
|
||||
return self
|
||||
|
||||
def expire(self, key: str, ttl_seconds: int) -> "_FakeRedisPipeline":
|
||||
"""Queues one key expiration operation for pipeline execution."""
|
||||
|
||||
self._operations.append(("expire", (key, ttl_seconds)))
|
||||
return self
|
||||
|
||||
def execute(self) -> list[object]:
|
||||
"""Executes queued operations in order and returns Redis-like result values."""
|
||||
|
||||
results: list[object] = []
|
||||
for operation, arguments in self._operations:
|
||||
if operation == "incr":
|
||||
key, amount = arguments
|
||||
previous = int(self._redis_client.values.get(str(key), 0))
|
||||
updated = previous + int(amount)
|
||||
self._redis_client.values[str(key)] = updated
|
||||
results.append(updated)
|
||||
elif operation == "expire":
|
||||
key, ttl_seconds = arguments
|
||||
self._redis_client.ttl_seconds[str(key)] = int(ttl_seconds)
|
||||
results.append(True)
|
||||
return results
|
||||
|
||||
|
||||
class _FakeRedis:
|
||||
"""In-memory Redis replacement with TTL behavior needed by throttle tests."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.values: dict[str, object] = {}
|
||||
self.ttl_seconds: dict[str, int] = {}
|
||||
|
||||
def pipeline(self, transaction: bool = True) -> _FakeRedisPipeline:
|
||||
"""Creates a fake transaction pipeline for grouped increment operations."""
|
||||
|
||||
_ = transaction
|
||||
return _FakeRedisPipeline(self)
|
||||
|
||||
def set(self, key: str, value: str, ex: int | None = None) -> bool:
|
||||
"""Stores key values and optional TTL metadata used by lockout keys."""
|
||||
|
||||
self.values[key] = value
|
||||
if ex is not None:
|
||||
self.ttl_seconds[key] = int(ex)
|
||||
return True
|
||||
|
||||
def ttl(self, key: str) -> int:
|
||||
"""Returns TTL for existing keys or Redis-compatible missing-key indicator."""
|
||||
|
||||
return int(self.ttl_seconds.get(key, -2))
|
||||
|
||||
def delete(self, *keys: str) -> int:
|
||||
"""Deletes keys and returns number of removed entries."""
|
||||
|
||||
removed_count = 0
|
||||
for key in keys:
|
||||
if key in self.values:
|
||||
self.values.pop(key, None)
|
||||
removed_count += 1
|
||||
self.ttl_seconds.pop(key, None)
|
||||
return removed_count
|
||||
|
||||
|
||||
def _login_throttle_settings(
|
||||
*,
|
||||
failure_limit: int = 2,
|
||||
failure_window_seconds: int = 60,
|
||||
lockout_base_seconds: int = 10,
|
||||
lockout_max_seconds: int = 40,
|
||||
) -> SimpleNamespace:
|
||||
"""Builds deterministic login-throttle settings for service-level unit coverage."""
|
||||
|
||||
return SimpleNamespace(
|
||||
auth_login_failure_limit=failure_limit,
|
||||
auth_login_failure_window_seconds=failure_window_seconds,
|
||||
auth_login_lockout_base_seconds=lockout_base_seconds,
|
||||
auth_login_lockout_max_seconds=lockout_max_seconds,
|
||||
)
|
||||
|
||||
|
||||
class AuthLoginThrottleServiceTests(unittest.TestCase):
|
||||
"""Verifies login throttle lockout progression, cap behavior, and clear semantics."""
|
||||
|
||||
def test_failed_attempts_trigger_lockout_after_limit(self) -> None:
|
||||
"""Failed attempts beyond configured limit activate login lockouts."""
|
||||
|
||||
fake_redis = _FakeRedis()
|
||||
with (
|
||||
patch.object(
|
||||
auth_login_throttle_module,
|
||||
"get_settings",
|
||||
return_value=_login_throttle_settings(failure_limit=2, lockout_base_seconds=12),
|
||||
),
|
||||
patch.object(auth_login_throttle_module, "get_redis", return_value=fake_redis),
|
||||
):
|
||||
self.assertEqual(
|
||||
auth_login_throttle_module.record_failed_login_attempt(
|
||||
username="admin",
|
||||
ip_address="203.0.113.10",
|
||||
),
|
||||
0,
|
||||
)
|
||||
self.assertEqual(
|
||||
auth_login_throttle_module.record_failed_login_attempt(
|
||||
username="admin",
|
||||
ip_address="203.0.113.10",
|
||||
),
|
||||
0,
|
||||
)
|
||||
lockout_seconds = auth_login_throttle_module.record_failed_login_attempt(
|
||||
username="admin",
|
||||
ip_address="203.0.113.10",
|
||||
)
|
||||
status = auth_login_throttle_module.check_login_throttle(
|
||||
username="admin",
|
||||
ip_address="203.0.113.10",
|
||||
)
|
||||
|
||||
self.assertEqual(lockout_seconds, 12)
|
||||
self.assertTrue(status.is_throttled)
|
||||
self.assertEqual(status.retry_after_seconds, 12)
|
||||
|
||||
def test_lockout_duration_escalates_and_respects_max_cap(self) -> None:
|
||||
"""Repeated failures after threshold double lockout duration up to configured maximum."""
|
||||
|
||||
fake_redis = _FakeRedis()
|
||||
with (
|
||||
patch.object(
|
||||
auth_login_throttle_module,
|
||||
"get_settings",
|
||||
return_value=_login_throttle_settings(
|
||||
failure_limit=1,
|
||||
lockout_base_seconds=10,
|
||||
lockout_max_seconds=25,
|
||||
),
|
||||
),
|
||||
patch.object(auth_login_throttle_module, "get_redis", return_value=fake_redis),
|
||||
):
|
||||
first_lockout = auth_login_throttle_module.record_failed_login_attempt(
|
||||
username="admin",
|
||||
ip_address="198.51.100.15",
|
||||
)
|
||||
second_lockout = auth_login_throttle_module.record_failed_login_attempt(
|
||||
username="admin",
|
||||
ip_address="198.51.100.15",
|
||||
)
|
||||
third_lockout = auth_login_throttle_module.record_failed_login_attempt(
|
||||
username="admin",
|
||||
ip_address="198.51.100.15",
|
||||
)
|
||||
fourth_lockout = auth_login_throttle_module.record_failed_login_attempt(
|
||||
username="admin",
|
||||
ip_address="198.51.100.15",
|
||||
)
|
||||
|
||||
self.assertEqual(first_lockout, 0)
|
||||
self.assertEqual(second_lockout, 10)
|
||||
self.assertEqual(third_lockout, 20)
|
||||
self.assertEqual(fourth_lockout, 25)
|
||||
|
||||
def test_clear_login_throttle_removes_active_lockout_state(self) -> None:
|
||||
"""Successful login clears active lockout keys for username and IP subjects."""
|
||||
|
||||
fake_redis = _FakeRedis()
|
||||
with (
|
||||
patch.object(
|
||||
auth_login_throttle_module,
|
||||
"get_settings",
|
||||
return_value=_login_throttle_settings(
|
||||
failure_limit=1,
|
||||
lockout_base_seconds=15,
|
||||
lockout_max_seconds=30,
|
||||
),
|
||||
),
|
||||
patch.object(auth_login_throttle_module, "get_redis", return_value=fake_redis),
|
||||
):
|
||||
auth_login_throttle_module.record_failed_login_attempt(
|
||||
username="admin",
|
||||
ip_address="192.0.2.20",
|
||||
)
|
||||
auth_login_throttle_module.record_failed_login_attempt(
|
||||
username="admin",
|
||||
ip_address="192.0.2.20",
|
||||
)
|
||||
throttled_before_clear = auth_login_throttle_module.check_login_throttle(
|
||||
username="admin",
|
||||
ip_address="192.0.2.20",
|
||||
)
|
||||
auth_login_throttle_module.clear_login_throttle(
|
||||
username="admin",
|
||||
ip_address="192.0.2.20",
|
||||
)
|
||||
throttled_after_clear = auth_login_throttle_module.check_login_throttle(
|
||||
username="admin",
|
||||
ip_address="192.0.2.20",
|
||||
)
|
||||
|
||||
self.assertTrue(throttled_before_clear.is_throttled)
|
||||
self.assertFalse(throttled_after_clear.is_throttled)
|
||||
self.assertEqual(throttled_after_clear.retry_after_seconds, 0)
|
||||
|
||||
def test_backend_errors_raise_runtime_error(self) -> None:
|
||||
"""Redis backend failures are surfaced as RuntimeError for caller fail-closed handling."""
|
||||
|
||||
class _BrokenRedis:
|
||||
"""Raises RedisError for all Redis interactions used by login throttle service."""
|
||||
|
||||
def ttl(self, _key: str) -> int:
|
||||
raise auth_login_throttle_module.RedisError("redis unavailable")
|
||||
|
||||
with patch.object(auth_login_throttle_module, "get_redis", return_value=_BrokenRedis()):
|
||||
with self.assertRaises(RuntimeError):
|
||||
auth_login_throttle_module.check_login_throttle(
|
||||
username="admin",
|
||||
ip_address="203.0.113.88",
|
||||
)
|
||||
|
||||
|
||||
class AuthLoginRouteThrottleTests(unittest.TestCase):
|
||||
"""Verifies `/auth/login` route throttle responses and success-flow state clearing."""
|
||||
|
||||
class _SessionStub:
|
||||
"""Tracks commit calls for route-level login tests without database dependencies."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.commit_count = 0
|
||||
|
||||
def commit(self) -> None:
|
||||
"""Records one commit invocation."""
|
||||
|
||||
self.commit_count += 1
|
||||
|
||||
@staticmethod
|
||||
def _request_stub(ip_address: str = "203.0.113.2", user_agent: str = "unit-test") -> SimpleNamespace:
|
||||
"""Builds request-like object containing client host and user-agent header fields."""
|
||||
|
||||
return SimpleNamespace(
|
||||
client=SimpleNamespace(host=ip_address),
|
||||
headers={"user-agent": user_agent},
|
||||
)
|
||||
|
||||
def test_login_rejects_when_precheck_reports_active_throttle(self) -> None:
|
||||
"""Pre-auth throttle checks return a stable 429 response without credential lookup."""
|
||||
|
||||
payload = auth_routes_module.AuthLoginRequest(username="admin", password="bad-password")
|
||||
session = self._SessionStub()
|
||||
throttled = auth_login_throttle_module.LoginThrottleStatus(
|
||||
is_throttled=True,
|
||||
retry_after_seconds=21,
|
||||
)
|
||||
with (
|
||||
patch.object(auth_routes_module, "check_login_throttle", return_value=throttled),
|
||||
patch.object(auth_routes_module, "authenticate_user") as authenticate_mock,
|
||||
):
|
||||
with self.assertRaises(HTTPException) as raised:
|
||||
auth_routes_module.login(
|
||||
payload=payload,
|
||||
request=self._request_stub(),
|
||||
session=session,
|
||||
)
|
||||
self.assertEqual(raised.exception.status_code, 429)
|
||||
self.assertEqual(raised.exception.detail, auth_routes_module.LOGIN_THROTTLED_DETAIL)
|
||||
self.assertEqual(raised.exception.headers.get("Retry-After"), "21")
|
||||
authenticate_mock.assert_not_called()
|
||||
self.assertEqual(session.commit_count, 0)
|
||||
|
||||
def test_login_returns_throttle_response_when_failure_crosses_limit(self) -> None:
|
||||
"""Failed credentials return stable 429 response once lockout threshold is crossed."""
|
||||
|
||||
payload = auth_routes_module.AuthLoginRequest(username="admin", password="bad-password")
|
||||
session = self._SessionStub()
|
||||
with (
|
||||
patch.object(
|
||||
auth_routes_module,
|
||||
"check_login_throttle",
|
||||
return_value=auth_login_throttle_module.LoginThrottleStatus(
|
||||
is_throttled=False,
|
||||
retry_after_seconds=0,
|
||||
),
|
||||
),
|
||||
patch.object(auth_routes_module, "authenticate_user", return_value=None),
|
||||
patch.object(auth_routes_module, "record_failed_login_attempt", return_value=30),
|
||||
):
|
||||
with self.assertRaises(HTTPException) as raised:
|
||||
auth_routes_module.login(
|
||||
payload=payload,
|
||||
request=self._request_stub(),
|
||||
session=session,
|
||||
)
|
||||
self.assertEqual(raised.exception.status_code, 429)
|
||||
self.assertEqual(raised.exception.detail, auth_routes_module.LOGIN_THROTTLED_DETAIL)
|
||||
self.assertEqual(raised.exception.headers.get("Retry-After"), "30")
|
||||
self.assertEqual(session.commit_count, 0)
|
||||
|
||||
def test_login_clears_throttle_state_after_successful_authentication(self) -> None:
|
||||
"""Successful login clears throttle state and commits issued session token."""
|
||||
|
||||
payload = auth_routes_module.AuthLoginRequest(username="admin", password="correct-password")
|
||||
session = self._SessionStub()
|
||||
fake_user = SimpleNamespace(
|
||||
id=uuid.uuid4(),
|
||||
username="admin",
|
||||
role=UserRole.ADMIN,
|
||||
)
|
||||
fake_session = SimpleNamespace(
|
||||
token="session-token",
|
||||
expires_at=datetime.now(UTC),
|
||||
)
|
||||
with (
|
||||
patch.object(
|
||||
auth_routes_module,
|
||||
"check_login_throttle",
|
||||
return_value=auth_login_throttle_module.LoginThrottleStatus(
|
||||
is_throttled=False,
|
||||
retry_after_seconds=0,
|
||||
),
|
||||
),
|
||||
patch.object(auth_routes_module, "authenticate_user", return_value=fake_user),
|
||||
patch.object(auth_routes_module, "clear_login_throttle") as clear_mock,
|
||||
patch.object(auth_routes_module, "issue_user_session", return_value=fake_session),
|
||||
):
|
||||
response = auth_routes_module.login(
|
||||
payload=payload,
|
||||
request=self._request_stub(),
|
||||
session=session,
|
||||
)
|
||||
self.assertEqual(response.access_token, "session-token")
|
||||
self.assertEqual(response.user.username, "admin")
|
||||
clear_mock.assert_called_once()
|
||||
self.assertEqual(session.commit_count, 1)
|
||||
|
||||
def test_login_returns_503_when_throttle_backend_is_unavailable(self) -> None:
|
||||
"""Throttle backend errors fail closed with a deterministic 503 login response."""
|
||||
|
||||
payload = auth_routes_module.AuthLoginRequest(username="admin", password="password")
|
||||
session = self._SessionStub()
|
||||
with patch.object(auth_routes_module, "check_login_throttle", side_effect=RuntimeError("redis down")):
|
||||
with self.assertRaises(HTTPException) as raised:
|
||||
auth_routes_module.login(
|
||||
payload=payload,
|
||||
request=self._request_stub(),
|
||||
session=session,
|
||||
)
|
||||
self.assertEqual(raised.exception.status_code, 503)
|
||||
self.assertEqual(raised.exception.detail, auth_routes_module.LOGIN_RATE_LIMITER_UNAVAILABLE_DETAIL)
|
||||
self.assertEqual(session.commit_count, 0)
|
||||
|
||||
|
||||
class ProviderBaseUrlValidationTests(unittest.TestCase):
|
||||
"""Verifies allowlist, scheme, and private-network SSRF protections."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user