Harden auth and security controls with session auth and docs

This commit is contained in:
2026-03-01 15:29:09 -03:00
parent 7a19f22f41
commit 0242e061c2
36 changed files with 1794 additions and 505 deletions

View File

@@ -11,6 +11,14 @@ import secrets
from pathlib import Path
from typing import Any
try:
from cryptography.fernet import Fernet, InvalidToken
except Exception: # pragma: no cover - dependency failures are surfaced at runtime usage.
Fernet = None # type: ignore[assignment]
class InvalidToken(Exception):
"""Fallback InvalidToken type used when cryptography dependency import fails."""
from app.core.config import get_settings, normalize_and_validate_provider_base_url
@@ -63,12 +71,13 @@ DEFAULT_ROUTING_PROMPT = (
"Confidence must be between 0 and 1."
)
PROVIDER_API_KEY_CIPHERTEXT_PREFIX = "enc-v1"
PROVIDER_API_KEY_CIPHERTEXT_PREFIX = "enc-v2"
PROVIDER_API_KEY_LEGACY_CIPHERTEXT_PREFIX = "enc-v1"
PROVIDER_API_KEY_KEYFILE_NAME = ".settings-api-key"
PROVIDER_API_KEY_STREAM_CONTEXT = b"dcm-provider-api-key-stream"
PROVIDER_API_KEY_AUTH_CONTEXT = b"dcm-provider-api-key-auth"
PROVIDER_API_KEY_NONCE_BYTES = 16
PROVIDER_API_KEY_TAG_BYTES = 32
PROVIDER_API_KEY_LEGACY_STREAM_CONTEXT = b"dcm-provider-api-key-stream"
PROVIDER_API_KEY_LEGACY_AUTH_CONTEXT = b"dcm-provider-api-key-auth"
PROVIDER_API_KEY_LEGACY_NONCE_BYTES = 16
PROVIDER_API_KEY_LEGACY_TAG_BYTES = 32
def _settings_api_key_path() -> Path:
@@ -128,14 +137,14 @@ def _derive_provider_api_key_key() -> bytes:
return generated
def _xor_bytes(left: bytes, right: bytes) -> bytes:
"""Applies byte-wise XOR for equal-length byte sequences."""
def _legacy_xor_bytes(left: bytes, right: bytes) -> bytes:
"""Applies byte-wise XOR for equal-length byte sequences used by legacy ciphertext migration."""
return bytes(first ^ second for first, second in zip(left, right))
def _derive_stream_cipher_bytes(master_key: bytes, nonce: bytes, length: int) -> bytes:
"""Derives deterministic stream bytes from HMAC-SHA256 blocks for payload masking."""
def _legacy_derive_stream_cipher_bytes(master_key: bytes, nonce: bytes, length: int) -> bytes:
"""Derives legacy deterministic stream bytes from HMAC-SHA256 blocks for migration reads."""
stream = bytearray()
counter = 0
@@ -143,7 +152,7 @@ def _derive_stream_cipher_bytes(master_key: bytes, nonce: bytes, length: int) ->
counter_bytes = counter.to_bytes(4, "big")
block = hmac.new(
master_key,
PROVIDER_API_KEY_STREAM_CONTEXT + nonce + counter_bytes,
PROVIDER_API_KEY_LEGACY_STREAM_CONTEXT + nonce + counter_bytes,
hashlib.sha256,
).digest()
stream.extend(block)
@@ -151,6 +160,33 @@ def _derive_stream_cipher_bytes(master_key: bytes, nonce: bytes, length: int) ->
return bytes(stream[:length])
def _provider_key_fernet(master_key: bytes) -> Fernet:
    """Return a Fernet cipher keyed from the first 32 bytes of ``master_key``.

    Raises:
        AppSettingsValidationError: When the cryptography import failed and
            ``Fernet`` is therefore ``None``.
    """
    if Fernet is None:
        raise AppSettingsValidationError("cryptography dependency is not available")

    # Fernet expects its key as URL-safe base64 text of 32 raw bytes.
    raw_key = master_key[:32]
    encoded_key = base64.urlsafe_b64encode(raw_key)
    return Fernet(encoded_key)
def _encrypt_provider_api_key_fallback(value: str) -> str:
    """Encrypt a provider API key with the legacy HMAC-SHA256 stream construction.

    This path is used when the cryptography dependency is unavailable. The
    output is ``<legacy-prefix>:<base64(nonce + ciphertext + tag)>`` using
    URL-safe base64 without padding.

    The legacy prefix is used on purpose: the payload layout produced here is
    the legacy one, and ``_decrypt_provider_api_key`` always routes the legacy
    prefix to the legacy decoder. Labelling this payload with the current
    prefix would make it undecryptable as soon as cryptography becomes
    importable, because current-prefix values are then handed to Fernet.

    Args:
        value: Plaintext API key; encoded as UTF-8 before encryption.

    Returns:
        Prefixed ciphertext string suitable for persistence.
    """
    plaintext = value.encode("utf-8")
    master_key = _derive_provider_api_key_key()

    # A fresh random nonce per call gives each ciphertext its own keystream.
    nonce = secrets.token_bytes(PROVIDER_API_KEY_LEGACY_NONCE_BYTES)
    keystream = _legacy_derive_stream_cipher_bytes(master_key, nonce, len(plaintext))
    ciphertext = _legacy_xor_bytes(plaintext, keystream)

    # The tag authenticates nonce and ciphertext together (encrypt-then-MAC).
    tag = hmac.new(
        master_key,
        PROVIDER_API_KEY_LEGACY_AUTH_CONTEXT + nonce + ciphertext,
        hashlib.sha256,
    ).digest()

    encoded = _urlsafe_b64encode_no_padding(nonce + ciphertext + tag)
    return f"{PROVIDER_API_KEY_LEGACY_CIPHERTEXT_PREFIX}:{encoded}"
def _encrypt_provider_api_key(value: str) -> str:
"""Encrypts one provider API key for at-rest JSON persistence."""
@@ -158,19 +194,52 @@ def _encrypt_provider_api_key(value: str) -> str:
if not normalized:
return ""
plaintext = normalized.encode("utf-8")
if Fernet is None:
return _encrypt_provider_api_key_fallback(normalized)
master_key = _derive_provider_api_key_key()
nonce = secrets.token_bytes(PROVIDER_API_KEY_NONCE_BYTES)
keystream = _derive_stream_cipher_bytes(master_key, nonce, len(plaintext))
ciphertext = _xor_bytes(plaintext, keystream)
tag = hmac.new(
token = _provider_key_fernet(master_key).encrypt(normalized.encode("utf-8")).decode("ascii")
return f"{PROVIDER_API_KEY_CIPHERTEXT_PREFIX}:{token}"
def _decrypt_provider_api_key_legacy_payload(encoded_payload: str) -> str:
"""Decrypts legacy stream-cipher payload bytes used for migration and fallback reads."""
if not encoded_payload:
raise AppSettingsValidationError("Provider API key ciphertext is missing payload bytes")
try:
payload = _urlsafe_b64decode_no_padding(encoded_payload)
except (binascii.Error, ValueError) as error:
raise AppSettingsValidationError("Provider API key ciphertext is not valid base64") from error
minimum_length = PROVIDER_API_KEY_LEGACY_NONCE_BYTES + PROVIDER_API_KEY_LEGACY_TAG_BYTES
if len(payload) < minimum_length:
raise AppSettingsValidationError("Provider API key ciphertext payload is truncated")
nonce = payload[:PROVIDER_API_KEY_LEGACY_NONCE_BYTES]
ciphertext = payload[PROVIDER_API_KEY_LEGACY_NONCE_BYTES:-PROVIDER_API_KEY_LEGACY_TAG_BYTES]
received_tag = payload[-PROVIDER_API_KEY_LEGACY_TAG_BYTES:]
master_key = _derive_provider_api_key_key()
expected_tag = hmac.new(
master_key,
PROVIDER_API_KEY_AUTH_CONTEXT + nonce + ciphertext,
PROVIDER_API_KEY_LEGACY_AUTH_CONTEXT + nonce + ciphertext,
hashlib.sha256,
).digest()
payload = nonce + ciphertext + tag
encoded = _urlsafe_b64encode_no_padding(payload)
return f"{PROVIDER_API_KEY_CIPHERTEXT_PREFIX}:{encoded}"
if not hmac.compare_digest(received_tag, expected_tag):
raise AppSettingsValidationError("Provider API key ciphertext integrity check failed")
keystream = _legacy_derive_stream_cipher_bytes(master_key, nonce, len(ciphertext))
plaintext = _legacy_xor_bytes(ciphertext, keystream)
try:
return plaintext.decode("utf-8").strip()
except UnicodeDecodeError as error:
raise AppSettingsValidationError("Provider API key ciphertext is not valid UTF-8") from error
def _decrypt_provider_api_key_legacy(value: str) -> str:
    """Decrypt a legacy ``enc-v1`` value by delegating its payload to the legacy decoder.

    Everything after the first ``:`` is treated as the encoded payload; the
    caller is expected to have checked the prefix already.
    """
    prefix_and_payload = value.split(":", 1)
    return _decrypt_provider_api_key_legacy_payload(prefix_and_payload[1])
def _decrypt_provider_api_key(value: str) -> str:
@@ -179,35 +248,23 @@ def _decrypt_provider_api_key(value: str) -> str:
normalized = value.strip()
if not normalized:
return ""
if not normalized.startswith(f"{PROVIDER_API_KEY_CIPHERTEXT_PREFIX}:"):
if not normalized.startswith(f"{PROVIDER_API_KEY_CIPHERTEXT_PREFIX}:") and not normalized.startswith(
f"{PROVIDER_API_KEY_LEGACY_CIPHERTEXT_PREFIX}:"
):
return normalized
encoded_payload = normalized.split(":", 1)[1]
if not encoded_payload:
if normalized.startswith(f"{PROVIDER_API_KEY_LEGACY_CIPHERTEXT_PREFIX}:"):
return _decrypt_provider_api_key_legacy(normalized)
token = normalized.split(":", 1)[1].strip()
if not token:
raise AppSettingsValidationError("Provider API key ciphertext is missing payload bytes")
if Fernet is None:
return _decrypt_provider_api_key_legacy_payload(token)
try:
payload = _urlsafe_b64decode_no_padding(encoded_payload)
except (binascii.Error, ValueError) as error:
raise AppSettingsValidationError("Provider API key ciphertext is not valid base64") from error
minimum_length = PROVIDER_API_KEY_NONCE_BYTES + PROVIDER_API_KEY_TAG_BYTES
if len(payload) < minimum_length:
raise AppSettingsValidationError("Provider API key ciphertext payload is truncated")
nonce = payload[:PROVIDER_API_KEY_NONCE_BYTES]
ciphertext = payload[PROVIDER_API_KEY_NONCE_BYTES:-PROVIDER_API_KEY_TAG_BYTES]
received_tag = payload[-PROVIDER_API_KEY_TAG_BYTES:]
master_key = _derive_provider_api_key_key()
expected_tag = hmac.new(
master_key,
PROVIDER_API_KEY_AUTH_CONTEXT + nonce + ciphertext,
hashlib.sha256,
).digest()
if not hmac.compare_digest(received_tag, expected_tag):
raise AppSettingsValidationError("Provider API key ciphertext integrity check failed")
keystream = _derive_stream_cipher_bytes(master_key, nonce, len(ciphertext))
plaintext = _xor_bytes(ciphertext, keystream)
plaintext = _provider_key_fernet(_derive_provider_api_key_key()).decrypt(token.encode("ascii"))
except (InvalidToken, ValueError, UnicodeEncodeError) as error:
raise AppSettingsValidationError("Provider API key ciphertext integrity check failed") from error
try:
return plaintext.decode("utf-8").strip()
except UnicodeDecodeError as error:

View File

@@ -0,0 +1,289 @@
"""Authentication services for user credential validation and session issuance."""
import base64
import binascii
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
import hashlib
import hmac
import secrets
import uuid
from sqlalchemy import delete, select
from sqlalchemy.orm import Session
from app.core.config import Settings, get_settings
from app.db.base import SessionLocal
from app.models.auth import AppUser, AuthSession, UserRole
PASSWORD_HASH_SCHEME = "pbkdf2_sha256"
DEFAULT_AUTH_FALLBACK_SECRET = "dcm-session-secret"
@dataclass(frozen=True)
class IssuedSession:
    """Immutable result of issuing a session: the bearer token plus its expiry."""

    # Plaintext bearer token handed back to the caller.
    token: str
    # Timestamp at which the session stops being valid.
    expires_at: datetime
def normalize_username(username: str) -> str:
    """Return the canonical identity key for a username: trimmed and lowercased."""
    trimmed = username.strip()
    return trimmed.lower()
def _urlsafe_b64encode_no_padding(data: bytes) -> str:
"""Encodes bytes to compact URL-safe base64 without padding."""
return base64.urlsafe_b64encode(data).decode("ascii").rstrip("=")
def _urlsafe_b64decode_no_padding(data: str) -> bytes:
"""Decodes URL-safe base64 values that may omit trailing padding characters."""
padded = data + "=" * (-len(data) % 4)
return base64.urlsafe_b64decode(padded.encode("ascii"))
def _password_iterations(settings: Settings) -> int:
"""Returns PBKDF2 iteration count clamped to a secure operational range."""
return max(200_000, min(1_200_000, int(settings.auth_password_pbkdf2_iterations)))
def hash_password(password: str, settings: Settings | None = None) -> str:
    """Hash a password with PBKDF2-SHA256 and format it for persistence.

    The result has the form ``scheme$iterations$salt$digest`` where salt and
    digest are URL-safe base64 without padding.

    Args:
        password: Plaintext password; surrounding whitespace is stripped first.
        settings: Optional settings override; ``get_settings()`` is used when omitted.

    Raises:
        ValueError: When the password is empty after stripping.
    """
    resolved_settings = settings or get_settings()

    secret = password.strip()
    if not secret:
        raise ValueError("Password must not be empty")

    rounds = _password_iterations(resolved_settings)
    salt = secrets.token_bytes(16)  # fresh random salt for every hash
    digest = hashlib.pbkdf2_hmac("sha256", secret.encode("utf-8"), salt, rounds, dklen=32)

    fields = (
        PASSWORD_HASH_SCHEME,
        str(rounds),
        _urlsafe_b64encode_no_padding(salt),
        _urlsafe_b64encode_no_padding(digest),
    )
    return "$".join(fields)
def verify_password(password: str, stored_hash: str, settings: Settings | None = None) -> bool:
    """Verify a plaintext password against a stored PBKDF2-SHA256 hash string.

    The stored value must have the form ``scheme$iterations$salt$digest`` with
    salt and digest in URL-safe base64 (padding optional). A hash string that
    fails any parsing or range check yields ``False``.

    Args:
        password: Candidate plaintext; surrounding whitespace is stripped before hashing.
        stored_hash: Persisted hash string to compare against.
        settings: Optional settings override; ``get_settings()`` is used when omitted.

    Returns:
        ``True`` only when the derived digest matches the stored digest and the
        stored iteration count is at least the currently configured (clamped)
        iteration count; ``False`` otherwise.
    """
    resolved_settings = settings or get_settings()
    normalized_password = password.strip()
    if not normalized_password:
        return False

    parts = stored_hash.strip().split("$")
    if len(parts) != 4:
        return False
    scheme, iterations_text, salt_text, digest_text = parts
    if scheme != PASSWORD_HASH_SCHEME:
        return False

    try:
        iterations = int(iterations_text)
    except ValueError:
        return False
    # Bound the work factor so a corrupted record cannot trigger a trivial or
    # an excessively expensive derivation.
    if iterations < 200_000 or iterations > 2_000_000:
        return False

    try:
        salt = _urlsafe_b64decode_no_padding(salt_text)
        expected_digest = _urlsafe_b64decode_no_padding(digest_text)
    except (binascii.Error, ValueError):
        return False
    if not expected_digest:
        # An empty digest field would request dklen=0, which pbkdf2_hmac
        # rejects with ValueError; treat it as a malformed hash instead.
        return False

    derived_digest = hashlib.pbkdf2_hmac(
        "sha256",
        normalized_password.encode("utf-8"),
        salt,
        iterations,
        dklen=len(expected_digest),
    )
    # Constant-time comparison avoids leaking how many digest bytes matched.
    if not hmac.compare_digest(expected_digest, derived_digest):
        return False

    # A correct password is still rejected when the stored hash used fewer
    # iterations than the currently configured count.
    return iterations >= _password_iterations(resolved_settings)
def _auth_session_secret(settings: Settings) -> bytes:
"""Resolves a stable secret used to hash issued bearer session tokens."""
candidate = settings.auth_session_pepper.strip() or settings.app_settings_encryption_key.strip()
if not candidate:
candidate = DEFAULT_AUTH_FALLBACK_SECRET
return hashlib.sha256(candidate.encode("utf-8")).digest()
def _hash_session_token(token: str, settings: Settings | None = None) -> str:
    """Return the hex HMAC-SHA256 of a bearer token keyed by the session secret.

    The same token and settings always produce the same hash, which is what
    allows a lookup by hash value.
    """
    resolved_settings = settings or get_settings()
    mac = hmac.new(
        _auth_session_secret(resolved_settings),
        token.encode("utf-8"),
        hashlib.sha256,
    )
    return mac.hexdigest()
def _new_session_token(settings: Settings) -> str:
"""Creates a random URL-safe bearer token for one API session."""
token_bytes = max(24, min(128, int(settings.auth_session_token_bytes)))
return secrets.token_urlsafe(token_bytes)
def _resolve_optional_user_credentials(username: str, password: str) -> tuple[str, str] | None:
    """Resolve the optional bootstrap user's credentials.

    Returns:
        ``(normalized_username, stripped_password)`` when both are provided,
        or ``None`` when both are blank.

    Raises:
        ValueError: When only one of the two values is provided.
    """
    name = normalize_username(username)
    secret = password.strip()
    if name and secret:
        return name, secret
    if name or secret:
        # Exactly one half is configured, which is treated as a misconfiguration.
        raise ValueError("Optional bootstrap user requires both username and password")
    return None
def _upsert_bootstrap_user(session: Session, *, username: str, password: str, role: UserRole) -> AppUser:
    """Create the bootstrap account, or reset an existing one with the same username.

    An existing account gets a freshly hashed password, the given role, and is
    re-activated. A new account is added to the session. No commit happens here.

    Returns:
        The created or updated ``AppUser``.
    """
    lookup = select(AppUser).where(AppUser.username == username)
    account = session.execute(lookup).scalar_one_or_none()
    hashed = hash_password(password)

    if account is not None:
        account.password_hash = hashed
        account.role = role
        account.is_active = True
        return account

    account = AppUser(
        username=username,
        password_hash=hashed,
        role=role,
        is_active=True,
    )
    session.add(account)
    return account
def ensure_bootstrap_users() -> None:
    """Create or refresh the bootstrap accounts described by runtime settings.

    The admin account is mandatory. A second, non-admin account is upserted
    only when both its username and password are configured. All changes are
    committed together at the end.

    Raises:
        RuntimeError: When the admin username or password is empty, or when the
            optional user shares the admin's username.
        ValueError: When only one of the optional user's username/password is set.
    """
    settings = get_settings()

    admin_name = normalize_username(settings.auth_bootstrap_admin_username)
    admin_secret = settings.auth_bootstrap_admin_password.strip()
    if not admin_name:
        raise RuntimeError("AUTH_BOOTSTRAP_ADMIN_USERNAME must not be empty")
    if not admin_secret:
        raise RuntimeError("AUTH_BOOTSTRAP_ADMIN_PASSWORD must not be empty")

    secondary = _resolve_optional_user_credentials(
        username=settings.auth_bootstrap_user_username,
        password=settings.auth_bootstrap_user_password,
    )

    with SessionLocal() as session:
        _upsert_bootstrap_user(
            session,
            username=admin_name,
            password=admin_secret,
            role=UserRole.ADMIN,
        )
        if secondary is not None:
            secondary_name, secondary_secret = secondary
            if secondary_name == admin_name:
                # Raising before commit leaves the admin upsert above uncommitted.
                raise RuntimeError("AUTH_BOOTSTRAP_USER_USERNAME must differ from admin username")
            _upsert_bootstrap_user(
                session,
                username=secondary_name,
                password=secondary_secret,
                role=UserRole.USER,
            )
        session.commit()
def authenticate_user(session: Session, *, username: str, password: str) -> AppUser | None:
    """Return the active account matching the credentials, or ``None`` on any failure.

    Failure covers a blank username, an unknown or inactive account, and a
    password that does not verify.
    """
    lookup_name = normalize_username(username)
    if not lookup_name:
        return None

    query = select(AppUser).where(AppUser.username == lookup_name)
    account = session.execute(query).scalar_one_or_none()

    # NOTE(review): unknown or inactive accounts return before any password
    # hashing, so response timing may differ from a wrong-password attempt.
    if account is not None and account.is_active and verify_password(password, account.password_hash):
        return account
    return None
def issue_user_session(
    session: Session,
    *,
    user: AppUser,
    user_agent: str | None = None,
    ip_address: str | None = None,
) -> IssuedSession:
    """Create a new bearer-token session for an already authenticated user.

    Only the token's hash is stored; the plaintext token is returned to the
    caller. The user's already expired sessions are deleted as part of the
    same unit of work. Nothing is committed here.

    Args:
        session: Database session used for the delete and the insert.
        user: Account the session belongs to.
        user_agent: Optional client user agent; stripped and cut to 512 characters.
        ip_address: Optional client address; stripped and cut to 64 characters.

    Returns:
        The plaintext token together with its expiration timestamp.
    """
    settings = get_settings()
    issued_at = datetime.now(UTC)

    # Lifetime is clamped between 5 minutes and 7 days.
    longest_ttl = 7 * 24 * 60
    lifetime_minutes = max(5, min(longest_ttl, int(settings.auth_session_ttl_minutes)))
    expires_at = issued_at + timedelta(minutes=lifetime_minutes)

    token = _new_session_token(settings)
    token_hash = _hash_session_token(token, settings)

    purge_expired = delete(AuthSession).where(
        AuthSession.user_id == user.id,
        AuthSession.expires_at <= issued_at,
    )
    session.execute(purge_expired)

    # Blank metadata is stored as NULL rather than as an empty string.
    agent = (user_agent or "").strip()[:512]
    address = (ip_address or "").strip()[:64]
    session.add(
        AuthSession(
            user_id=user.id,
            token_hash=token_hash,
            expires_at=expires_at,
            user_agent=agent or None,
            ip_address=address or None,
        )
    )
    return IssuedSession(token=token, expires_at=expires_at)
def resolve_auth_session(session: Session, *, token: str) -> AuthSession | None:
    """Look up the live session for a bearer token.

    Returns ``None`` when the token is blank, when no unrevoked and unexpired
    session matches its hash, or when the owning user is missing or inactive.
    """
    bearer = token.strip()
    if not bearer:
        return None

    digest = _hash_session_token(bearer)
    now = datetime.now(UTC)
    query = select(AuthSession).where(
        AuthSession.token_hash == digest,
        AuthSession.revoked_at.is_(None),
        AuthSession.expires_at > now,
    )
    entry = session.execute(query).scalar_one_or_none()
    if entry is None:
        return None

    owner = entry.user
    if owner is None or not owner.is_active:
        return None
    return entry
def revoke_auth_session(session: Session, *, session_id: uuid.UUID) -> bool:
    """Mark a session as revoked.

    Returns:
        ``True`` when the session existed and was not yet revoked; ``False``
        when it is unknown or already revoked. Nothing is committed here.
    """
    query = select(AuthSession).where(AuthSession.id == session_id)
    target = session.execute(query).scalar_one_or_none()
    if target is not None and target.revoked_at is None:
        target.revoked_at = datetime.now(UTC)
        return True
    return False

View File

@@ -6,10 +6,13 @@ from uuid import UUID
from sqlalchemy import delete, func, select
from sqlalchemy.orm import Session
from app.core.config import get_settings
from app.models.document import Document
from app.models.processing_log import ProcessingLogEntry
settings = get_settings()
MAX_STAGE_LENGTH = 64
MAX_EVENT_LENGTH = 256
MAX_LEVEL_LENGTH = 16
@@ -37,9 +40,49 @@ def _trim(value: str | None, max_length: int) -> str | None:
def _safe_payload(payload_json: dict[str, Any] | None) -> dict[str, Any]:
"""Ensures payload values are persisted as dictionaries."""
"""Normalizes payload persistence mode using metadata-only defaults for sensitive content."""
return payload_json if isinstance(payload_json, dict) else {}
if not isinstance(payload_json, dict):
return {}
if settings.processing_log_store_payload_text:
return payload_json
return _metadata_only_payload(payload_json)
def _metadata_only_payload(payload_json: dict[str, Any]) -> dict[str, Any]:
"""Converts payload content into metadata descriptors without persisting raw text values."""
metadata: dict[str, Any] = {}
for index, (raw_key, raw_value) in enumerate(payload_json.items()):
if index >= 80:
break
key = str(raw_key)
metadata[key] = _metadata_only_payload_value(raw_value)
return metadata
def _metadata_only_payload_value(value: Any) -> Any:
"""Converts one payload value into non-sensitive metadata representation."""
if isinstance(value, dict):
return _metadata_only_payload(value)
if isinstance(value, (list, tuple)):
items = list(value)
return {
"item_count": len(items),
"items_preview": [_metadata_only_payload_value(item) for item in items[:20]],
}
if isinstance(value, str):
normalized = value.strip()
return {
"text_chars": len(normalized),
"text_omitted": bool(normalized),
}
if isinstance(value, bytes):
return {"binary_bytes": len(value)}
if isinstance(value, (int, float, bool)) or value is None:
return value
return {"value_type": type(value).__name__}
def set_processing_log_autocommit(session: Session, enabled: bool) -> None:
@@ -82,8 +125,8 @@ def log_processing_event(
document_filename=_trim(resolved_document_filename, MAX_DOCUMENT_FILENAME_LENGTH),
provider_id=_trim(provider_id, MAX_PROVIDER_LENGTH),
model_name=_trim(model_name, MAX_MODEL_LENGTH),
prompt_text=_trim(prompt_text, MAX_PROMPT_LENGTH),
response_text=_trim(response_text, MAX_RESPONSE_LENGTH),
prompt_text=_trim(prompt_text, MAX_PROMPT_LENGTH) if settings.processing_log_store_model_io_text else None,
response_text=_trim(response_text, MAX_RESPONSE_LENGTH) if settings.processing_log_store_model_io_text else None,
payload_json=_safe_payload(payload_json),
)
session.add(entry)

View File

@@ -0,0 +1,42 @@
"""Redis-backed fixed-window rate limiter helpers for sensitive API operations."""
import time
from redis.exceptions import RedisError
from app.worker.queue import get_redis
def _rate_limit_key(*, scope: str, subject: str, window_id: int) -> str:
"""Builds a stable Redis key for one scope, subject, and fixed time window."""
return f"dcm:rate-limit:{scope}:{subject}:{window_id}"
def increment_rate_limit(
    *,
    scope: str,
    subject: str,
    limit: int,
    window_seconds: int = 60,
) -> tuple[int, int]:
    """Count one hit against a fixed-window bucket.

    Args:
        scope: Logical operation being limited.
        subject: Identity the limit applies to.
        limit: Allowed hits per window; zero or negative returns ``(0, 0)``
            without contacting Redis.
        window_seconds: Window length in seconds, at least 1.

    Returns:
        ``(current_count, limit)`` for the active window.

    Raises:
        RuntimeError: When the Redis pipeline fails with a ``RedisError``.
    """
    effective_limit = int(limit)
    if effective_limit <= 0:
        return (0, 0)

    window = max(1, int(window_seconds))
    # Every timestamp inside the same window maps to the same integer id.
    window_id = int(time.time() // window)
    bucket_key = _rate_limit_key(scope=scope, subject=subject, window_id=window_id)

    client = get_redis()
    try:
        pipe = client.pipeline(transaction=True)
        pipe.incr(bucket_key, 1)
        # The key outlives its window by a few seconds before Redis drops it.
        pipe.expire(bucket_key, window + 5)
        hits, _ = pipe.execute()
    except RedisError as error:
        raise RuntimeError("Rate limiter backend unavailable") from error

    return (int(hits), effective_limit)