Harden auth and security controls with session auth and docs
This commit is contained in:
@@ -11,6 +11,14 @@ import secrets
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
try:
|
||||
from cryptography.fernet import Fernet, InvalidToken
|
||||
except Exception: # pragma: no cover - dependency failures are surfaced at runtime usage.
|
||||
Fernet = None # type: ignore[assignment]
|
||||
|
||||
class InvalidToken(Exception):
|
||||
"""Fallback InvalidToken type used when cryptography dependency import fails."""
|
||||
|
||||
from app.core.config import get_settings, normalize_and_validate_provider_base_url
|
||||
|
||||
|
||||
@@ -63,12 +71,13 @@ DEFAULT_ROUTING_PROMPT = (
|
||||
"Confidence must be between 0 and 1."
|
||||
)
|
||||
|
||||
PROVIDER_API_KEY_CIPHERTEXT_PREFIX = "enc-v1"
|
||||
PROVIDER_API_KEY_CIPHERTEXT_PREFIX = "enc-v2"
|
||||
PROVIDER_API_KEY_LEGACY_CIPHERTEXT_PREFIX = "enc-v1"
|
||||
PROVIDER_API_KEY_KEYFILE_NAME = ".settings-api-key"
|
||||
PROVIDER_API_KEY_STREAM_CONTEXT = b"dcm-provider-api-key-stream"
|
||||
PROVIDER_API_KEY_AUTH_CONTEXT = b"dcm-provider-api-key-auth"
|
||||
PROVIDER_API_KEY_NONCE_BYTES = 16
|
||||
PROVIDER_API_KEY_TAG_BYTES = 32
|
||||
PROVIDER_API_KEY_LEGACY_STREAM_CONTEXT = b"dcm-provider-api-key-stream"
|
||||
PROVIDER_API_KEY_LEGACY_AUTH_CONTEXT = b"dcm-provider-api-key-auth"
|
||||
PROVIDER_API_KEY_LEGACY_NONCE_BYTES = 16
|
||||
PROVIDER_API_KEY_LEGACY_TAG_BYTES = 32
|
||||
|
||||
|
||||
def _settings_api_key_path() -> Path:
|
||||
@@ -128,14 +137,14 @@ def _derive_provider_api_key_key() -> bytes:
|
||||
return generated
|
||||
|
||||
|
||||
def _xor_bytes(left: bytes, right: bytes) -> bytes:
|
||||
"""Applies byte-wise XOR for equal-length byte sequences."""
|
||||
def _legacy_xor_bytes(left: bytes, right: bytes) -> bytes:
|
||||
"""Applies byte-wise XOR for equal-length byte sequences used by legacy ciphertext migration."""
|
||||
|
||||
return bytes(first ^ second for first, second in zip(left, right))
|
||||
|
||||
|
||||
def _derive_stream_cipher_bytes(master_key: bytes, nonce: bytes, length: int) -> bytes:
|
||||
"""Derives deterministic stream bytes from HMAC-SHA256 blocks for payload masking."""
|
||||
def _legacy_derive_stream_cipher_bytes(master_key: bytes, nonce: bytes, length: int) -> bytes:
|
||||
"""Derives legacy deterministic stream bytes from HMAC-SHA256 blocks for migration reads."""
|
||||
|
||||
stream = bytearray()
|
||||
counter = 0
|
||||
@@ -143,7 +152,7 @@ def _derive_stream_cipher_bytes(master_key: bytes, nonce: bytes, length: int) ->
|
||||
counter_bytes = counter.to_bytes(4, "big")
|
||||
block = hmac.new(
|
||||
master_key,
|
||||
PROVIDER_API_KEY_STREAM_CONTEXT + nonce + counter_bytes,
|
||||
PROVIDER_API_KEY_LEGACY_STREAM_CONTEXT + nonce + counter_bytes,
|
||||
hashlib.sha256,
|
||||
).digest()
|
||||
stream.extend(block)
|
||||
@@ -151,6 +160,33 @@ def _derive_stream_cipher_bytes(master_key: bytes, nonce: bytes, length: int) ->
|
||||
return bytes(stream[:length])
|
||||
|
||||
|
||||
def _provider_key_fernet(master_key: bytes) -> Fernet:
    """Return a Fernet cipher keyed from the leading 32 bytes of ``master_key``.

    Raises:
        AppSettingsValidationError: If ``Fernet`` is ``None``, i.e. the
            ``cryptography`` import at module load failed.
    """

    if Fernet is not None:
        # Fernet expects its key as URL-safe base64 text of the raw key bytes.
        key_material = master_key[:32]
        return Fernet(base64.urlsafe_b64encode(key_material))
    raise AppSettingsValidationError("cryptography dependency is not available")
|
||||
|
||||
|
||||
def _encrypt_provider_api_key_fallback(value: str) -> str:
    """Encrypt a provider key with the legacy HMAC stream construction.

    Used when the ``cryptography`` dependency is unavailable. The encoded
    payload is ``nonce || ciphertext || tag`` in unpadded URL-safe base64,
    where the tag is an HMAC-SHA256 over the auth context, nonce and
    ciphertext.

    Args:
        value: Plaintext provider API key.

    Returns:
        ``"<legacy prefix>:<payload>"``.
    """

    plaintext = value.encode("utf-8")
    master_key = _derive_provider_api_key_key()
    nonce = secrets.token_bytes(PROVIDER_API_KEY_LEGACY_NONCE_BYTES)
    keystream = _legacy_derive_stream_cipher_bytes(master_key, nonce, len(plaintext))
    ciphertext = _legacy_xor_bytes(plaintext, keystream)
    tag = hmac.new(
        master_key,
        PROVIDER_API_KEY_LEGACY_AUTH_CONTEXT + nonce + ciphertext,
        hashlib.sha256,
    ).digest()
    encoded = _urlsafe_b64encode_no_padding(nonce + ciphertext + tag)
    # The payload is in the legacy format, so it must carry the legacy prefix.
    # Tagging it with the current prefix makes the decryptor hand it to Fernet
    # as soon as cryptography is importable, where it fails the integrity
    # check. Legacy-prefixed values are always decrypted with the legacy
    # routine, whether or not Fernet is available.
    return f"{PROVIDER_API_KEY_LEGACY_CIPHERTEXT_PREFIX}:{encoded}"
|
||||
|
||||
|
||||
def _encrypt_provider_api_key(value: str) -> str:
|
||||
"""Encrypts one provider API key for at-rest JSON persistence."""
|
||||
|
||||
@@ -158,19 +194,52 @@ def _encrypt_provider_api_key(value: str) -> str:
|
||||
if not normalized:
|
||||
return ""
|
||||
|
||||
plaintext = normalized.encode("utf-8")
|
||||
if Fernet is None:
|
||||
return _encrypt_provider_api_key_fallback(normalized)
|
||||
master_key = _derive_provider_api_key_key()
|
||||
nonce = secrets.token_bytes(PROVIDER_API_KEY_NONCE_BYTES)
|
||||
keystream = _derive_stream_cipher_bytes(master_key, nonce, len(plaintext))
|
||||
ciphertext = _xor_bytes(plaintext, keystream)
|
||||
tag = hmac.new(
|
||||
token = _provider_key_fernet(master_key).encrypt(normalized.encode("utf-8")).decode("ascii")
|
||||
return f"{PROVIDER_API_KEY_CIPHERTEXT_PREFIX}:{token}"
|
||||
|
||||
|
||||
def _decrypt_provider_api_key_legacy_payload(encoded_payload: str) -> str:
    """Authenticate and decrypt one legacy stream-cipher payload.

    The decoded payload is laid out as ``nonce || ciphertext || tag``. The
    HMAC-SHA256 tag is verified before any keystream is derived.

    Args:
        encoded_payload: Unpadded URL-safe base64 text, without any prefix.

    Returns:
        The decrypted key, decoded as UTF-8 and stripped of surrounding
        whitespace.

    Raises:
        AppSettingsValidationError: If the payload is empty, not valid base64,
            too short, fails the integrity check, or is not valid UTF-8.
    """

    if not encoded_payload:
        raise AppSettingsValidationError("Provider API key ciphertext is missing payload bytes")
    try:
        raw = _urlsafe_b64decode_no_padding(encoded_payload)
    except (binascii.Error, ValueError) as error:
        raise AppSettingsValidationError("Provider API key ciphertext is not valid base64") from error

    nonce_size = PROVIDER_API_KEY_LEGACY_NONCE_BYTES
    tag_size = PROVIDER_API_KEY_LEGACY_TAG_BYTES
    if len(raw) < nonce_size + tag_size:
        raise AppSettingsValidationError("Provider API key ciphertext payload is truncated")

    nonce = raw[:nonce_size]
    body = raw[nonce_size:-tag_size]
    received_tag = raw[-tag_size:]

    key = _derive_provider_api_key_key()
    mac = hmac.new(key, PROVIDER_API_KEY_LEGACY_AUTH_CONTEXT + nonce + body, hashlib.sha256)
    # compare_digest keeps the tag comparison constant-time.
    if not hmac.compare_digest(received_tag, mac.digest()):
        raise AppSettingsValidationError("Provider API key ciphertext integrity check failed")

    keystream = _legacy_derive_stream_cipher_bytes(key, nonce, len(body))
    decrypted = _legacy_xor_bytes(body, keystream)
    try:
        return decrypted.decode("utf-8").strip()
    except UnicodeDecodeError as error:
        raise AppSettingsValidationError("Provider API key ciphertext is not valid UTF-8") from error
|
||||
|
||||
|
||||
def _decrypt_provider_api_key_legacy(value: str) -> str:
    """Decrypt a legacy ``enc-v1`` value to support non-breaking key migration.

    Args:
        value: Full stored value in ``"<prefix>:<payload>"`` form.

    Returns:
        The decrypted provider API key.

    Raises:
        AppSettingsValidationError: Propagated from the payload decryptor,
            including when ``value`` carries no payload after the separator.
    """

    # partition() never raises. A value without ":" yields an empty payload,
    # which the payload decryptor rejects with a validation error, instead of
    # the bare IndexError that split(":", 1)[1] would raise.
    _prefix, _separator, encoded_payload = value.partition(":")
    return _decrypt_provider_api_key_legacy_payload(encoded_payload)
|
||||
|
||||
|
||||
def _decrypt_provider_api_key(value: str) -> str:
|
||||
@@ -179,35 +248,23 @@ def _decrypt_provider_api_key(value: str) -> str:
|
||||
normalized = value.strip()
|
||||
if not normalized:
|
||||
return ""
|
||||
if not normalized.startswith(f"{PROVIDER_API_KEY_CIPHERTEXT_PREFIX}:"):
|
||||
if not normalized.startswith(f"{PROVIDER_API_KEY_CIPHERTEXT_PREFIX}:") and not normalized.startswith(
|
||||
f"{PROVIDER_API_KEY_LEGACY_CIPHERTEXT_PREFIX}:"
|
||||
):
|
||||
return normalized
|
||||
|
||||
encoded_payload = normalized.split(":", 1)[1]
|
||||
if not encoded_payload:
|
||||
if normalized.startswith(f"{PROVIDER_API_KEY_LEGACY_CIPHERTEXT_PREFIX}:"):
|
||||
return _decrypt_provider_api_key_legacy(normalized)
|
||||
|
||||
token = normalized.split(":", 1)[1].strip()
|
||||
if not token:
|
||||
raise AppSettingsValidationError("Provider API key ciphertext is missing payload bytes")
|
||||
if Fernet is None:
|
||||
return _decrypt_provider_api_key_legacy_payload(token)
|
||||
try:
|
||||
payload = _urlsafe_b64decode_no_padding(encoded_payload)
|
||||
except (binascii.Error, ValueError) as error:
|
||||
raise AppSettingsValidationError("Provider API key ciphertext is not valid base64") from error
|
||||
|
||||
minimum_length = PROVIDER_API_KEY_NONCE_BYTES + PROVIDER_API_KEY_TAG_BYTES
|
||||
if len(payload) < minimum_length:
|
||||
raise AppSettingsValidationError("Provider API key ciphertext payload is truncated")
|
||||
|
||||
nonce = payload[:PROVIDER_API_KEY_NONCE_BYTES]
|
||||
ciphertext = payload[PROVIDER_API_KEY_NONCE_BYTES:-PROVIDER_API_KEY_TAG_BYTES]
|
||||
received_tag = payload[-PROVIDER_API_KEY_TAG_BYTES:]
|
||||
master_key = _derive_provider_api_key_key()
|
||||
expected_tag = hmac.new(
|
||||
master_key,
|
||||
PROVIDER_API_KEY_AUTH_CONTEXT + nonce + ciphertext,
|
||||
hashlib.sha256,
|
||||
).digest()
|
||||
if not hmac.compare_digest(received_tag, expected_tag):
|
||||
raise AppSettingsValidationError("Provider API key ciphertext integrity check failed")
|
||||
|
||||
keystream = _derive_stream_cipher_bytes(master_key, nonce, len(ciphertext))
|
||||
plaintext = _xor_bytes(ciphertext, keystream)
|
||||
plaintext = _provider_key_fernet(_derive_provider_api_key_key()).decrypt(token.encode("ascii"))
|
||||
except (InvalidToken, ValueError, UnicodeEncodeError) as error:
|
||||
raise AppSettingsValidationError("Provider API key ciphertext integrity check failed") from error
|
||||
try:
|
||||
return plaintext.decode("utf-8").strip()
|
||||
except UnicodeDecodeError as error:
|
||||
|
||||
289
backend/app/services/authentication.py
Normal file
289
backend/app/services/authentication.py
Normal file
@@ -0,0 +1,289 @@
|
||||
"""Authentication services for user credential validation and session issuance."""
|
||||
|
||||
import base64
|
||||
import binascii
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
import hashlib
|
||||
import hmac
|
||||
import secrets
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import Settings, get_settings
|
||||
from app.db.base import SessionLocal
|
||||
from app.models.auth import AppUser, AuthSession, UserRole
|
||||
|
||||
|
||||
PASSWORD_HASH_SCHEME = "pbkdf2_sha256"
|
||||
DEFAULT_AUTH_FALLBACK_SECRET = "dcm-session-secret"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class IssuedSession:
    """Immutable pairing of a freshly issued bearer token with its expiry."""

    # Bearer token value handed back to the caller.
    token: str
    # Moment at which the session stops being valid.
    expires_at: datetime
|
||||
|
||||
|
||||
def normalize_username(username: str) -> str:
    """Return the canonical identity key for a username: trimmed, then lowercased."""

    trimmed = username.strip()
    return trimmed.lower()
|
||||
|
||||
|
||||
def _urlsafe_b64encode_no_padding(data: bytes) -> str:
|
||||
"""Encodes bytes to compact URL-safe base64 without padding."""
|
||||
|
||||
return base64.urlsafe_b64encode(data).decode("ascii").rstrip("=")
|
||||
|
||||
|
||||
def _urlsafe_b64decode_no_padding(data: str) -> bytes:
|
||||
"""Decodes URL-safe base64 values that may omit trailing padding characters."""
|
||||
|
||||
padded = data + "=" * (-len(data) % 4)
|
||||
return base64.urlsafe_b64decode(padded.encode("ascii"))
|
||||
|
||||
|
||||
def _password_iterations(settings: Settings) -> int:
|
||||
"""Returns PBKDF2 iteration count clamped to a secure operational range."""
|
||||
|
||||
return max(200_000, min(1_200_000, int(settings.auth_password_pbkdf2_iterations)))
|
||||
|
||||
|
||||
def hash_password(password: str, settings: Settings | None = None) -> str:
    """Derive a salted PBKDF2-SHA256 hash string for storing a user credential.

    The result has four ``$``-separated fields: scheme, iteration count,
    salt and digest. Salt and digest are unpadded URL-safe base64.

    Args:
        password: Plaintext password; surrounding whitespace is stripped.
        settings: Optional settings object; loaded via ``get_settings()`` when omitted.

    Raises:
        ValueError: If the password is empty after stripping.
    """

    resolved_settings = settings or get_settings()
    secret = password.strip()
    if not secret:
        raise ValueError("Password must not be empty")

    rounds = _password_iterations(resolved_settings)
    salt = secrets.token_bytes(16)
    digest = hashlib.pbkdf2_hmac("sha256", secret.encode("utf-8"), salt, rounds, dklen=32)
    fields = (
        PASSWORD_HASH_SCHEME,
        str(rounds),
        _urlsafe_b64encode_no_padding(salt),
        _urlsafe_b64encode_no_padding(digest),
    )
    return "$".join(fields)
|
||||
|
||||
|
||||
def verify_password(password: str, stored_hash: str, settings: Settings | None = None) -> bool:
    """Check a plaintext password against a stored PBKDF2-SHA256 hash string.

    Args:
        password: Plaintext candidate; surrounding whitespace is stripped.
        stored_hash: Persisted value with four ``$``-separated fields:
            scheme, iteration count, salt and digest.
        settings: Optional settings object; loaded via ``get_settings()`` when omitted.

    Returns:
        ``True`` only when the digest matches and the stored iteration count
        is at least the currently configured (clamped) count. Malformed or
        unsupported hash strings yield ``False`` rather than raising.
    """

    resolved_settings = settings or get_settings()
    normalized_password = password.strip()
    if not normalized_password:
        return False

    parts = stored_hash.strip().split("$")
    if len(parts) != 4:
        return False
    scheme, iterations_text, salt_text, digest_text = parts
    if scheme != PASSWORD_HASH_SCHEME:
        return False

    try:
        iterations = int(iterations_text)
    except ValueError:
        return False
    # Reject counts outside the supported band before doing any expensive work.
    if not 200_000 <= iterations <= 2_000_000:
        return False

    try:
        salt = _urlsafe_b64decode_no_padding(salt_text)
        expected_digest = _urlsafe_b64decode_no_padding(digest_text)
    except (binascii.Error, ValueError):
        return False
    # An empty digest field would request dklen=0, which pbkdf2_hmac rejects
    # with ValueError; treat it as a malformed hash like the cases above.
    if not expected_digest:
        return False

    derived_digest = hashlib.pbkdf2_hmac(
        "sha256",
        normalized_password.encode("utf-8"),
        salt,
        iterations,
        dklen=len(expected_digest),
    )
    if not hmac.compare_digest(expected_digest, derived_digest):
        return False
    # A correct password is still refused when the hash was derived with fewer
    # iterations than the current configuration requires.
    return iterations >= _password_iterations(resolved_settings)
|
||||
|
||||
|
||||
def _auth_session_secret(settings: Settings) -> bytes:
|
||||
"""Resolves a stable secret used to hash issued bearer session tokens."""
|
||||
|
||||
candidate = settings.auth_session_pepper.strip() or settings.app_settings_encryption_key.strip()
|
||||
if not candidate:
|
||||
candidate = DEFAULT_AUTH_FALLBACK_SECRET
|
||||
return hashlib.sha256(candidate.encode("utf-8")).digest()
|
||||
|
||||
|
||||
def _hash_session_token(token: str, settings: Settings | None = None) -> str:
    """Return the hex HMAC-SHA256 of ``token`` keyed with the session secret.

    The result is deterministic for a given token and secret. Settings are
    loaded via ``get_settings()`` when omitted.
    """

    active_settings = settings or get_settings()
    key = _auth_session_secret(active_settings)
    return hmac.new(key, token.encode("utf-8"), hashlib.sha256).hexdigest()
|
||||
|
||||
|
||||
def _new_session_token(settings: Settings) -> str:
|
||||
"""Creates a random URL-safe bearer token for one API session."""
|
||||
|
||||
token_bytes = max(24, min(128, int(settings.auth_session_token_bytes)))
|
||||
return secrets.token_urlsafe(token_bytes)
|
||||
|
||||
|
||||
def _resolve_optional_user_credentials(username: str, password: str) -> tuple[str, str] | None:
    """Return normalized optional bootstrap credentials, or ``None`` when unset.

    Raises:
        ValueError: If exactly one of username and password is provided.
    """

    name = normalize_username(username)
    secret = password.strip()
    if name and secret:
        return name, secret
    if name or secret:
        # Exactly one half was supplied: refuse the partial configuration.
        raise ValueError("Optional bootstrap user requires both username and password")
    return None
|
||||
|
||||
|
||||
def _upsert_bootstrap_user(session: Session, *, username: str, password: str, role: UserRole) -> AppUser:
    """Create the named bootstrap account, or reset an existing one.

    An existing account gets its password hash and role overwritten and is
    re-activated. A new account is added to ``session``. This function does
    not commit.
    """

    lookup = select(AppUser).where(AppUser.username == username)
    account = session.execute(lookup).scalar_one_or_none()
    hashed = hash_password(password)

    if account is not None:
        account.password_hash = hashed
        account.role = role
        account.is_active = True
        return account

    account = AppUser(
        username=username,
        password_hash=hashed,
        role=role,
        is_active=True,
    )
    session.add(account)
    return account
|
||||
|
||||
|
||||
def ensure_bootstrap_users() -> None:
    """Create or refresh bootstrap accounts from the configured credentials.

    The admin account is mandatory. A second, regular-role account is
    provisioned only when both of its credentials are configured. Both
    upserts are committed together.

    Raises:
        RuntimeError: If the admin username or password is blank, or the
            optional user's name equals the admin's.
        ValueError: If only one of the optional user's credentials is set.
    """

    config = get_settings()
    admin_name = normalize_username(config.auth_bootstrap_admin_username)
    admin_secret = config.auth_bootstrap_admin_password.strip()
    if not admin_name:
        raise RuntimeError("AUTH_BOOTSTRAP_ADMIN_USERNAME must not be empty")
    if not admin_secret:
        raise RuntimeError("AUTH_BOOTSTRAP_ADMIN_PASSWORD must not be empty")

    secondary = _resolve_optional_user_credentials(
        username=config.auth_bootstrap_user_username,
        password=config.auth_bootstrap_user_password,
    )

    with SessionLocal() as db:
        _upsert_bootstrap_user(
            db,
            username=admin_name,
            password=admin_secret,
            role=UserRole.ADMIN,
        )
        if secondary is not None:
            member_name, member_secret = secondary
            # Raising here exits the block before commit() is reached.
            if member_name == admin_name:
                raise RuntimeError("AUTH_BOOTSTRAP_USER_USERNAME must differ from admin username")
            _upsert_bootstrap_user(
                db,
                username=member_name,
                password=member_secret,
                role=UserRole.USER,
            )
        db.commit()
|
||||
|
||||
|
||||
def authenticate_user(session: Session, *, username: str, password: str) -> AppUser | None:
    """Return the active account matching the credentials, or ``None``.

    ``None`` covers a blank username, an unknown or inactive account, and a
    failed password verification alike.
    """

    lookup_name = normalize_username(username)
    if not lookup_name:
        return None

    query = select(AppUser).where(AppUser.username == lookup_name)
    account = session.execute(query).scalar_one_or_none()
    # NOTE(review): unknown or inactive accounts return before any password
    # hashing happens, so their response timing may differ from a
    # wrong-password attempt.
    if account is not None and account.is_active and verify_password(password, account.password_hash):
        return account
    return None
|
||||
|
||||
|
||||
def issue_user_session(
    session: Session,
    *,
    user: AppUser,
    user_agent: str | None = None,
    ip_address: str | None = None,
) -> IssuedSession:
    """Issue a new bearer-token session for an already validated account.

    Only the token's hash is stored on the ``AuthSession`` row; the token
    itself is returned to the caller. The user's already-expired sessions are
    deleted first. The session lifetime is the configured TTL clamped to
    between 5 minutes and 7 days. This function does not commit.

    Args:
        session: Database session used for the delete and the insert.
        user: Account the session belongs to.
        user_agent: Optional client user agent, stripped and cut to 512 characters.
        ip_address: Optional client address, stripped and cut to 64 characters.
    """

    config = get_settings()
    issued_at = datetime.now(UTC)
    lifetime_minutes = max(5, min(7 * 24 * 60, int(config.auth_session_ttl_minutes)))
    expiry = issued_at + timedelta(minutes=lifetime_minutes)
    bearer = _new_session_token(config)
    bearer_hash = _hash_session_token(bearer, config)

    purge_expired = delete(AuthSession).where(
        AuthSession.user_id == user.id,
        AuthSession.expires_at <= issued_at,
    )
    session.execute(purge_expired)

    # Blank or missing client metadata is stored as NULL, not as an empty string.
    agent = (user_agent or "").strip()[:512]
    address = (ip_address or "").strip()[:64]
    session.add(
        AuthSession(
            user_id=user.id,
            token_hash=bearer_hash,
            expires_at=expiry,
            user_agent=agent or None,
            ip_address=address or None,
        )
    )
    return IssuedSession(token=bearer, expires_at=expiry)
|
||||
|
||||
|
||||
def resolve_auth_session(session: Session, *, token: str) -> AuthSession | None:
|
||||
"""Resolves one non-revoked and non-expired session from a bearer token value."""
|
||||
|
||||
normalized = token.strip()
|
||||
if not normalized:
|
||||
return None
|
||||
token_hash = _hash_session_token(normalized)
|
||||
now = datetime.now(UTC)
|
||||
session_entry = session.execute(
|
||||
select(AuthSession).where(
|
||||
AuthSession.token_hash == token_hash,
|
||||
AuthSession.revoked_at.is_(None),
|
||||
AuthSession.expires_at > now,
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
if session_entry is None or session_entry.user is None:
|
||||
return None
|
||||
if not session_entry.user.is_active:
|
||||
return None
|
||||
return session_entry
|
||||
|
||||
|
||||
def revoke_auth_session(session: Session, *, session_id: uuid.UUID) -> bool:
    """Mark one session as revoked.

    Returns ``True`` when a revocation timestamp was set, and ``False`` when
    the session does not exist or was already revoked. This function does not
    commit.
    """

    query = select(AuthSession).where(AuthSession.id == session_id)
    entry = session.execute(query).scalar_one_or_none()
    if entry is not None and entry.revoked_at is None:
        entry.revoked_at = datetime.now(UTC)
        return True
    return False
|
||||
@@ -6,10 +6,13 @@ from uuid import UUID
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.models.document import Document
|
||||
from app.models.processing_log import ProcessingLogEntry
|
||||
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
MAX_STAGE_LENGTH = 64
|
||||
MAX_EVENT_LENGTH = 256
|
||||
MAX_LEVEL_LENGTH = 16
|
||||
@@ -37,9 +40,49 @@ def _trim(value: str | None, max_length: int) -> str | None:
|
||||
|
||||
|
||||
def _safe_payload(payload_json: dict[str, Any] | None) -> dict[str, Any]:
|
||||
"""Ensures payload values are persisted as dictionaries."""
|
||||
"""Normalizes payload persistence mode using metadata-only defaults for sensitive content."""
|
||||
|
||||
return payload_json if isinstance(payload_json, dict) else {}
|
||||
if not isinstance(payload_json, dict):
|
||||
return {}
|
||||
if settings.processing_log_store_payload_text:
|
||||
return payload_json
|
||||
return _metadata_only_payload(payload_json)
|
||||
|
||||
|
||||
def _metadata_only_payload(payload_json: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Converts payload content into metadata descriptors without persisting raw text values."""
|
||||
|
||||
metadata: dict[str, Any] = {}
|
||||
for index, (raw_key, raw_value) in enumerate(payload_json.items()):
|
||||
if index >= 80:
|
||||
break
|
||||
key = str(raw_key)
|
||||
metadata[key] = _metadata_only_payload_value(raw_value)
|
||||
return metadata
|
||||
|
||||
|
||||
def _metadata_only_payload_value(value: Any) -> Any:
|
||||
"""Converts one payload value into non-sensitive metadata representation."""
|
||||
|
||||
if isinstance(value, dict):
|
||||
return _metadata_only_payload(value)
|
||||
if isinstance(value, (list, tuple)):
|
||||
items = list(value)
|
||||
return {
|
||||
"item_count": len(items),
|
||||
"items_preview": [_metadata_only_payload_value(item) for item in items[:20]],
|
||||
}
|
||||
if isinstance(value, str):
|
||||
normalized = value.strip()
|
||||
return {
|
||||
"text_chars": len(normalized),
|
||||
"text_omitted": bool(normalized),
|
||||
}
|
||||
if isinstance(value, bytes):
|
||||
return {"binary_bytes": len(value)}
|
||||
if isinstance(value, (int, float, bool)) or value is None:
|
||||
return value
|
||||
return {"value_type": type(value).__name__}
|
||||
|
||||
|
||||
def set_processing_log_autocommit(session: Session, enabled: bool) -> None:
|
||||
@@ -82,8 +125,8 @@ def log_processing_event(
|
||||
document_filename=_trim(resolved_document_filename, MAX_DOCUMENT_FILENAME_LENGTH),
|
||||
provider_id=_trim(provider_id, MAX_PROVIDER_LENGTH),
|
||||
model_name=_trim(model_name, MAX_MODEL_LENGTH),
|
||||
prompt_text=_trim(prompt_text, MAX_PROMPT_LENGTH),
|
||||
response_text=_trim(response_text, MAX_RESPONSE_LENGTH),
|
||||
prompt_text=_trim(prompt_text, MAX_PROMPT_LENGTH) if settings.processing_log_store_model_io_text else None,
|
||||
response_text=_trim(response_text, MAX_RESPONSE_LENGTH) if settings.processing_log_store_model_io_text else None,
|
||||
payload_json=_safe_payload(payload_json),
|
||||
)
|
||||
session.add(entry)
|
||||
|
||||
42
backend/app/services/rate_limiter.py
Normal file
42
backend/app/services/rate_limiter.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""Redis-backed fixed-window rate limiter helpers for sensitive API operations."""
|
||||
|
||||
import time
|
||||
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from app.worker.queue import get_redis
|
||||
|
||||
|
||||
def _rate_limit_key(*, scope: str, subject: str, window_id: int) -> str:
|
||||
"""Builds a stable Redis key for one scope, subject, and fixed time window."""
|
||||
|
||||
return f"dcm:rate-limit:{scope}:{subject}:{window_id}"
|
||||
|
||||
|
||||
def increment_rate_limit(
    *,
    scope: str,
    subject: str,
    limit: int,
    window_seconds: int = 60,
) -> tuple[int, int]:
    """Count one hit in the current fixed window and report it with the limit.

    This function only counts; it does not decide whether the limit is exceeded.

    Args:
        scope: Name of the limited operation.
        subject: Identity being limited.
        limit: Maximum allowed hits per window. Zero or negative returns
            ``(0, 0)`` without contacting Redis.
        window_seconds: Window length in seconds; values below 1 are raised to 1.

    Returns:
        ``(current_count, effective_limit)``.

    Raises:
        RuntimeError: If Redis raises ``RedisError`` during the increment.
    """

    effective_limit = max(0, int(limit))
    if not effective_limit:
        return (0, 0)

    window = max(1, int(window_seconds))
    window_id = int(time.time() // window)
    bucket_key = _rate_limit_key(scope=scope, subject=subject, window_id=window_id)

    client = get_redis()
    try:
        pipe = client.pipeline(transaction=True)
        pipe.incr(bucket_key, 1)
        # The key outlives its window by a few seconds, then expires on its own.
        pipe.expire(bucket_key, window + 5)
        results = pipe.execute()
        count_value, _ = results
    except RedisError as error:
        raise RuntimeError("Rate limiter backend unavailable") from error

    return (int(count_value), effective_limit)
|
||||
Reference in New Issue
Block a user