356 lines
13 KiB
Python
356 lines
13 KiB
Python
"""Application settings and environment configuration."""
|
|
|
|
from functools import lru_cache
|
|
import ipaddress
|
|
from pathlib import Path
|
|
import socket
|
|
from urllib.parse import urlparse, urlunparse
|
|
|
|
from pydantic import Field
|
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
|
|
|
|
class Settings(BaseSettings):
    """Runtime configuration read from environment variables and the configured ``.env`` file.

    Keys that do not correspond to a field are ignored (``extra="ignore"``), so
    unrelated variables in the environment do not cause validation errors.
    """

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        extra="ignore",
    )

    # --- Application identity ---------------------------------------------------
    app_name: str = "dcm-dms"
    app_env: str = "development"

    # --- Backing services -------------------------------------------------------
    database_url: str = "postgresql+psycopg://dcm:dcm@db:5432/dcm"
    redis_url: str = "redis://redis:6379/0"
    # Free-form strings; values outside REDIS_SECURITY_MODES / REDIS_TLS_MODES are
    # normalized to "auto" by the Redis mode normalizers in this module.
    redis_security_mode: str = "auto"
    redis_tls_mode: str = "auto"

    allow_development_anonymous_user_access: bool = True

    # --- Storage and upload limits (``X << 20`` == X MiB) -----------------------
    storage_root: Path = Path("/data/storage")
    upload_chunk_size: int = 4 << 20
    max_upload_files_per_request: int = 50
    max_upload_file_size_bytes: int = 25 << 20
    max_upload_request_size_bytes: int = 100 << 20

    # --- Archive extraction limits ----------------------------------------------
    max_zip_members: int = 250
    max_zip_depth: int = 2
    max_zip_descendants_per_root: int = 1000
    max_zip_member_uncompressed_bytes: int = 25 << 20
    max_zip_total_uncompressed_bytes: int = 150 << 20
    max_zip_compression_ratio: float = 120.0
    max_text_length: int = 500_000

    # --- API tokens (empty string = not configured) -----------------------------
    admin_api_token: str = ""
    user_api_token: str = ""

    # --- Provider base URL SSRF policy -------------------------------------------
    provider_base_url_allowlist: list[str] = Field(
        default_factory=lambda: ["api.openai.com"],
    )
    provider_base_url_allow_http: bool = False
    provider_base_url_allow_private_network: bool = False

    # --- Processing log retention and truncation ---------------------------------
    processing_log_max_document_sessions: int = 20
    processing_log_max_unbound_entries: int = 400
    processing_log_max_payload_chars: int = 4096
    processing_log_max_text_chars: int = 12_000

    # --- Model provider defaults -------------------------------------------------
    default_openai_base_url: str = "https://api.openai.com/v1"
    default_openai_model: str = "gpt-4.1-mini"
    default_openai_timeout_seconds: int = 45
    default_openai_handwriting_enabled: bool = True
    default_openai_api_key: str = ""
    app_settings_encryption_key: str = ""
    default_summary_model: str = "gpt-4.1-mini"
    default_routing_model: str = "gpt-4.1-mini"

    # --- Typesense search backend ------------------------------------------------
    typesense_protocol: str = "http"
    typesense_host: str = "typesense"
    typesense_port: int = 8108
    typesense_api_key: str = ""
    typesense_collection_name: str = "documents"
    typesense_timeout_seconds: int = 120
    typesense_num_retries: int = 0

    # --- Public URLs and CORS ----------------------------------------------------
    public_base_url: str = "http://localhost:8000"
    cors_origins: list[str] = Field(
        default_factory=lambda: ["http://localhost:5173", "http://localhost:3000"],
    )
|
|
|
|
|
|
# Hostname suffixes that denote local or internal networks. Provider base URLs whose
# host ends with one of these are rejected unless private-network access is allowed.
# ".localhost" is reserved for loopback by RFC 6761 (section 6.3), so a name such as
# "api.localhost" must be treated exactly like "localhost" itself.
LOCAL_HOSTNAME_SUFFIXES = (".local", ".internal", ".home.arpa", ".localhost")

# MIME types whose inline rendering may execute script; uploads with these types
# are never considered safe for inline preview.
SCRIPT_CAPABLE_INLINE_MIME_TYPES = frozenset(
    {
        "application/ecmascript",
        "application/javascript",
        "application/x-javascript",
        "application/xhtml+xml",
        "image/svg+xml",
        "text/ecmascript",
        "text/html",
        "text/javascript",
    }
)

# Generic XML types; rejected together with every "+xml" suffix type.
SCRIPT_CAPABLE_XML_MIME_TYPES = frozenset({"application/xml", "text/xml"})

# Accepted values for Settings.redis_security_mode; other values normalize to "auto".
REDIS_SECURITY_MODES = frozenset({"auto", "strict", "compat"})

# Accepted values for Settings.redis_tls_mode; other values normalize to "auto".
REDIS_TLS_MODES = frozenset({"auto", "required", "allow_insecure"})
|
|
|
|
|
|
def _is_production_environment(app_env: str) -> bool:
    """Tell whether *app_env* names a production deployment.

    The comparison ignores case and surrounding whitespace; only "production"
    and its short form "prod" qualify.
    """
    return app_env.strip().lower() in ("production", "prod")
|
|
|
|
|
|
def _normalize_redis_security_mode(raw_mode: str) -> str:
    """Map a configured Redis security mode onto a member of REDIS_SECURITY_MODES.

    The value is trimmed and lower-cased; anything unrecognized falls back to "auto".
    """
    candidate = raw_mode.strip().lower()
    return candidate if candidate in REDIS_SECURITY_MODES else "auto"
|
|
|
|
|
|
def _normalize_redis_tls_mode(raw_mode: str) -> str:
    """Map a configured Redis TLS mode onto a member of REDIS_TLS_MODES.

    Input is trimmed and lower-cased; unrecognized values become "auto".
    """
    mode = raw_mode.strip().lower()
    if mode in REDIS_TLS_MODES:
        return mode
    return "auto"
|
|
|
|
|
|
def validate_redis_url_security(
    redis_url: str,
    *,
    app_env: str | None = None,
    security_mode: str | None = None,
    tls_mode: str | None = None,
) -> str:
    """Validate a Redis URL against the Redis security policy, failing closed in production.

    Each keyword argument overrides the matching setting (``app_env``,
    ``redis_security_mode``, ``redis_tls_mode``). Settings are loaded only when at
    least one override is omitted. Fully specified calls therefore do not depend on
    environment / ``.env`` parsing and cannot fail because of unrelated settings.

    Policy:
        * security mode ``"strict"`` -- or ``"auto"`` when the environment is
          ``production``/``prod`` -- requires a non-blank password in the URL;
        * TLS mode ``"required"`` -- or ``"auto"`` while strict security applies --
          requires the ``rediss://`` scheme.
        Unrecognized mode values are normalized to ``"auto"``.

    Args:
        redis_url: URL to check; surrounding whitespace is ignored.
        app_env: Environment name override.
        security_mode: Security mode override.
        tls_mode: TLS mode override.

    Returns:
        The URL with surrounding whitespace removed.

    Raises:
        ValueError: If the URL is empty, does not use ``redis://``/``rediss://``,
            has no hostname, or violates the authentication/TLS requirements above.
    """
    # Only touch the settings object when something actually needs a default.
    if app_env is None or security_mode is None or tls_mode is None:
        settings = get_settings()
        app_env = settings.app_env if app_env is None else app_env
        security_mode = settings.redis_security_mode if security_mode is None else security_mode
        tls_mode = settings.redis_tls_mode if tls_mode is None else tls_mode

    resolved_security_mode = _normalize_redis_security_mode(security_mode)
    resolved_tls_mode = _normalize_redis_tls_mode(tls_mode)

    candidate = redis_url.strip()
    if not candidate:
        raise ValueError("Redis URL must not be empty")

    parsed = urlparse(candidate)
    scheme = parsed.scheme.lower()
    if scheme not in {"redis", "rediss"}:
        raise ValueError("Redis URL must use redis:// or rediss://")
    if not parsed.hostname:
        raise ValueError("Redis URL must include a hostname")

    # "auto" escalates to strict only in production; "compat" never does.
    strict_security = resolved_security_mode == "strict" or (
        resolved_security_mode == "auto" and _is_production_environment(app_env)
    )
    # "auto" TLS follows the security decision; "allow_insecure" never requires TLS.
    require_tls = resolved_tls_mode == "required" or (resolved_tls_mode == "auto" and strict_security)
    has_password = bool(parsed.password and parsed.password.strip())
    uses_tls = scheme == "rediss"

    if strict_security and not has_password:
        raise ValueError("Redis URL must include authentication when security mode is strict")
    if require_tls and not uses_tls:
        raise ValueError("Redis URL must use rediss:// when TLS is required")

    return candidate
|
|
|
|
|
|
def is_inline_preview_mime_type_safe(mime_type: str) -> bool:
    """Decide whether an uploaded document's MIME type may be rendered inline.

    Parameters after ``;`` are ignored, and the remaining type is compared
    case-insensitively. The following are all treated as unsafe:
    an empty or missing type, a known script-capable type, and every XML flavour
    (generic XML plus any ``+xml`` suffix type).
    """
    essence = ""
    if mime_type:
        essence = mime_type.partition(";")[0].strip().lower()
    if not essence:
        return False

    is_script_capable = (
        essence in SCRIPT_CAPABLE_INLINE_MIME_TYPES
        or essence in SCRIPT_CAPABLE_XML_MIME_TYPES
        or essence.endswith("+xml")
    )
    return not is_script_capable
|
|
|
|
|
|
def _normalize_allowlist(allowlist: object) -> tuple[str, ...]:
    """Canonicalize configured allowlist host entries.

    Returns a sorted, de-duplicated tuple of lower-cased entries. Surrounding
    whitespace and trailing dots are removed. Non-string and whitespace-only
    entries are skipped. Any container other than a list, tuple, or set yields an
    empty tuple.
    """
    if not isinstance(allowlist, (list, tuple, set)):
        return ()

    cleaned: set[str] = set()
    for entry in allowlist:
        if not isinstance(entry, str):
            continue
        trimmed = entry.strip()
        if trimmed:
            cleaned.add(trimmed.lower().rstrip("."))
    return tuple(sorted(cleaned))
|
|
|
|
|
|
def _host_matches_allowlist(hostname: str, allowlist: tuple[str, ...]) -> bool:
    """Check *hostname* against allowlist entries, accepting exact matches and subdomains.

    The hostname is lower-cased and stripped of trailing dots before comparison.
    Allowlist entries are compared as given, so callers should pass entries that are
    already normalized. An empty allowlist matches nothing.
    """
    if not allowlist:
        return False
    host = hostname.lower().rstrip(".")
    return any(host == entry or host.endswith("." + entry) for entry in allowlist)
|
|
|
|
|
|
def _is_private_or_special_ip(value: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
|
|
"""Returns whether an IP belongs to private, loopback, link-local, or reserved ranges."""
|
|
|
|
return (
|
|
value.is_private
|
|
or value.is_loopback
|
|
or value.is_link_local
|
|
or value.is_multicast
|
|
or value.is_reserved
|
|
or value.is_unspecified
|
|
)
|
|
|
|
|
|
def _validate_resolved_host_ips(hostname: str, port: int, allow_private_network: bool) -> None:
    """Resolve *hostname* and reject it if any resolved address is private or special.

    Resolution uses ``socket.getaddrinfo`` for TCP (``SOCK_STREAM``). Address strings
    that ``ipaddress`` cannot parse are skipped.

    Raises:
        ValueError: If the name cannot be resolved, if resolution yields no parseable
            IP address, or -- unless *allow_private_network* is true -- if at least
            one resolved address is private or special.
    """
    try:
        infos = socket.getaddrinfo(hostname, port, type=socket.SOCK_STREAM)
    except socket.gaierror as exc:
        raise ValueError(f"Provider base URL host cannot be resolved: {hostname}") from exc

    resolved: set[ipaddress.IPv4Address | ipaddress.IPv6Address] = set()
    for info in infos:
        sockaddr = info[4]
        if not sockaddr:
            continue
        try:
            parsed_ip = ipaddress.ip_address(sockaddr[0])
        except ValueError:
            continue
        resolved.add(parsed_ip)

    if not resolved:
        raise ValueError(f"Provider base URL host resolved without usable IP addresses: {hostname}")
    if allow_private_network:
        return

    offending = [ip for ip in resolved if _is_private_or_special_ip(ip)]
    if offending:
        raise ValueError(
            "Provider base URL resolves to private or special IP addresses: "
            + ", ".join(map(str, offending))
        )
|
|
|
|
|
|
def _normalize_and_validate_provider_base_url(
|
|
raw_value: str,
|
|
allowlist: tuple[str, ...],
|
|
allow_http: bool,
|
|
allow_private_network: bool,
|
|
resolve_dns: bool,
|
|
) -> str:
|
|
"""Normalizes and validates provider base URLs with SSRF-safe scheme and host checks."""
|
|
|
|
trimmed = raw_value.strip().rstrip("/")
|
|
if not trimmed:
|
|
raise ValueError("Provider base URL must not be empty")
|
|
|
|
parsed = urlparse(trimmed)
|
|
scheme = parsed.scheme.lower()
|
|
if scheme not in {"http", "https"}:
|
|
raise ValueError("Provider base URL must use http or https")
|
|
if scheme == "http" and not allow_http:
|
|
raise ValueError("Provider base URL must use https")
|
|
if parsed.query or parsed.fragment:
|
|
raise ValueError("Provider base URL must not include query strings or fragments")
|
|
if parsed.username or parsed.password:
|
|
raise ValueError("Provider base URL must not include embedded credentials")
|
|
|
|
hostname = (parsed.hostname or "").lower().rstrip(".")
|
|
if not hostname:
|
|
raise ValueError("Provider base URL must include a hostname")
|
|
if allowlist and not _host_matches_allowlist(hostname, allowlist):
|
|
allowed_hosts = ", ".join(allowlist)
|
|
raise ValueError(f"Provider base URL host is not in allowlist: {hostname}. Allowed hosts: {allowed_hosts}")
|
|
|
|
if hostname == "localhost" or hostname.endswith(LOCAL_HOSTNAME_SUFFIXES):
|
|
if not allow_private_network:
|
|
raise ValueError("Provider base URL must not target local or internal hostnames")
|
|
|
|
try:
|
|
ip_host = ipaddress.ip_address(hostname)
|
|
except ValueError:
|
|
ip_host = None
|
|
|
|
if ip_host is not None:
|
|
if not allow_private_network and _is_private_or_special_ip(ip_host):
|
|
raise ValueError("Provider base URL must not target private or special IP addresses")
|
|
elif resolve_dns:
|
|
resolved_port = parsed.port
|
|
if resolved_port is None:
|
|
resolved_port = 443 if scheme == "https" else 80
|
|
_validate_resolved_host_ips(
|
|
hostname=hostname,
|
|
port=resolved_port,
|
|
allow_private_network=allow_private_network,
|
|
)
|
|
|
|
path = (parsed.path or "").rstrip("/")
|
|
if not path.endswith("/v1"):
|
|
path = f"{path}/v1" if path else "/v1"
|
|
|
|
normalized_hostname = hostname
|
|
if ":" in normalized_hostname and not normalized_hostname.startswith("["):
|
|
normalized_hostname = f"[{normalized_hostname}]"
|
|
netloc = f"{normalized_hostname}:{parsed.port}" if parsed.port is not None else normalized_hostname
|
|
return urlunparse((scheme, netloc, path, "", "", ""))
|
|
|
|
|
|
@lru_cache(maxsize=256)
def _normalize_and_validate_provider_base_url_cached(
    raw_value: str,
    allowlist: tuple[str, ...],
    allow_http: bool,
    allow_private_network: bool,
) -> str:
    """Memoized provider URL validation that never performs DNS lookups.

    Successful results are cached per distinct argument combination, up to 256
    entries. ``lru_cache`` does not store raised exceptions, so rejected URLs are
    validated again on every call.
    """
    return _normalize_and_validate_provider_base_url(
        raw_value,
        allowlist,
        allow_http,
        allow_private_network,
        resolve_dns=False,
    )
|
|
|
|
|
|
def normalize_and_validate_provider_base_url(raw_value: str, *, resolve_dns: bool = False) -> str:
    """Validate *raw_value* as a provider base URL under the configured SSRF policy.

    The policy comes from ``get_settings()``: the normalized host allowlist plus
    the ``provider_base_url_allow_http`` and
    ``provider_base_url_allow_private_network`` flags. A flag that is not a real
    ``bool`` counts as ``False``.

    Without *resolve_dns*, results go through the memoized, DNS-free validator.
    With it, the hostname is resolved and checked on every call.
    """
    settings = get_settings()
    allowlist = _normalize_allowlist(settings.provider_base_url_allowlist)
    allow_http = isinstance(settings.provider_base_url_allow_http, bool) and settings.provider_base_url_allow_http
    allow_private_network = (
        isinstance(settings.provider_base_url_allow_private_network, bool)
        and settings.provider_base_url_allow_private_network
    )

    if not resolve_dns:
        return _normalize_and_validate_provider_base_url_cached(
            raw_value=raw_value,
            allowlist=allowlist,
            allow_http=allow_http,
            allow_private_network=allow_private_network,
        )
    return _normalize_and_validate_provider_base_url(
        raw_value=raw_value,
        allowlist=allowlist,
        allow_http=allow_http,
        allow_private_network=allow_private_network,
        resolve_dns=True,
    )
|
|
|
|
|
|
@lru_cache(maxsize=1)
def get_settings() -> Settings:
    """Build ``Settings`` once and return that same instance on every later call.

    ``get_settings.cache_clear()`` discards the cached instance, so the next call
    constructs a fresh ``Settings`` from the current environment.
    """
    settings = Settings()
    return settings
|