244 lines
11 KiB
Python
244 lines
11 KiB
Python
"""API routes for managing persistent single-user application settings."""
|
|
|
|
from fastapi import APIRouter
|
|
|
|
from app.schemas.settings import (
|
|
AppSettingsUpdateRequest,
|
|
AppSettingsResponse,
|
|
DisplaySettingsResponse,
|
|
HandwritingSettingsResponse,
|
|
HandwritingStyleSettingsResponse,
|
|
HandwritingSettingsUpdateRequest,
|
|
OcrTaskSettingsResponse,
|
|
ProcessingLogRetentionSettingsResponse,
|
|
ProviderSettingsResponse,
|
|
RoutingTaskSettingsResponse,
|
|
SummaryTaskSettingsResponse,
|
|
TaskSettingsResponse,
|
|
UploadDefaultsResponse,
|
|
)
|
|
from app.services.app_settings import (
|
|
TASK_OCR_HANDWRITING,
|
|
TASK_ROUTING_CLASSIFICATION,
|
|
TASK_SUMMARY_GENERATION,
|
|
read_app_settings,
|
|
reset_app_settings,
|
|
update_app_settings,
|
|
update_handwriting_settings,
|
|
)
|
|
|
|
|
|
router = APIRouter()
|
|
|
|
|
|
def _predefined_entries(raw_items: list) -> list:
    """Normalizes predefined path/tag entries into ``value``/``global_shared`` dictionaries.

    Items that are not dictionaries, or whose stringified ``value`` is blank
    after trimming, are dropped.
    """

    entries = []
    for raw in raw_items:
        if not isinstance(raw, dict):
            continue
        value = str(raw.get("value", "")).strip()
        if not value:
            continue
        entries.append(
            {
                "value": value,
                "global_shared": bool(raw.get("global_shared", False)),
            }
        )
    return entries


def _build_response(payload: dict) -> AppSettingsResponse:
    """Maps the raw settings dictionary onto the typed API response model.

    Each section is looked up with ``dict.get`` so that a missing key falls
    back to the literal default written next to it, and every value is coerced
    to the primitive type handed to the matching response field. Sections are
    assembled one at a time, in the same order the response model lists them.
    """

    upload_section = payload.get("upload_defaults", {})
    display_section = payload.get("display", {})
    retention_section = payload.get("processing_log_retention", {})
    provider_entries = payload.get("providers", [])
    task_sections = payload.get("tasks", {})
    style_section = payload.get("handwriting_style_clustering", {})
    ocr_section = task_sections.get(TASK_OCR_HANDWRITING, {})
    summary_section = task_sections.get(TASK_SUMMARY_GENERATION, {})
    routing_section = task_sections.get(TASK_ROUTING_CLASSIFICATION, {})

    # Upload defaults: only non-blank string tags survive, trimmed of whitespace.
    default_logical_path = str(upload_section.get("logical_path", "Inbox"))
    default_tags = []
    for tag in upload_section.get("tags", []):
        if isinstance(tag, str) and tag.strip():
            default_tags.append(str(tag).strip())
    upload_defaults = UploadDefaultsResponse(
        logical_path=default_logical_path,
        tags=default_tags,
    )

    display = DisplaySettingsResponse(
        cards_per_page=int(display_section.get("cards_per_page", 12)),
        log_typing_animation_enabled=bool(
            display_section.get("log_typing_animation_enabled", True)
        ),
    )

    processing_log_retention = ProcessingLogRetentionSettingsResponse(
        keep_document_sessions=int(retention_section.get("keep_document_sessions", 2)),
        keep_unbound_entries=int(retention_section.get("keep_unbound_entries", 80)),
    )

    handwriting_style = HandwritingStyleSettingsResponse(
        enabled=bool(style_section.get("enabled", True)),
        embed_model=str(style_section.get("embed_model", "ts/clip-vit-b-p32")),
        neighbor_limit=int(style_section.get("neighbor_limit", 8)),
        match_min_similarity=float(style_section.get("match_min_similarity", 0.86)),
        bootstrap_match_min_similarity=float(
            style_section.get("bootstrap_match_min_similarity", 0.89)
        ),
        bootstrap_sample_size=int(style_section.get("bootstrap_sample_size", 3)),
        image_max_side=int(style_section.get("image_max_side", 1024)),
    )

    predefined_paths = _predefined_entries(payload.get("predefined_paths", []))
    predefined_tags = _predefined_entries(payload.get("predefined_tags", []))

    providers = []
    for entry in provider_entries:
        providers.append(
            ProviderSettingsResponse(
                id=str(entry.get("id", "")),
                label=str(entry.get("label", "")),
                provider_type=str(entry.get("provider_type", "openai_compatible")),
                base_url=str(entry.get("base_url", "https://api.openai.com/v1")),
                timeout_seconds=int(entry.get("timeout_seconds", 45)),
                api_key_set=bool(entry.get("api_key_set", False)),
                api_key_masked=str(entry.get("api_key_masked", "")),
            )
        )

    ocr_task = OcrTaskSettingsResponse(
        enabled=bool(ocr_section.get("enabled", True)),
        provider_id=str(ocr_section.get("provider_id", "openai-default")),
        model=str(ocr_section.get("model", "gpt-4.1-mini")),
        prompt=str(ocr_section.get("prompt", "")),
    )

    summary_task = SummaryTaskSettingsResponse(
        enabled=bool(summary_section.get("enabled", True)),
        provider_id=str(summary_section.get("provider_id", "openai-default")),
        model=str(summary_section.get("model", "gpt-4.1-mini")),
        prompt=str(summary_section.get("prompt", "")),
        max_input_tokens=int(summary_section.get("max_input_tokens", 8000)),
    )

    routing_task = RoutingTaskSettingsResponse(
        enabled=bool(routing_section.get("enabled", True)),
        provider_id=str(routing_section.get("provider_id", "openai-default")),
        model=str(routing_section.get("model", "gpt-4.1-mini")),
        prompt=str(routing_section.get("prompt", "")),
        neighbor_count=int(routing_section.get("neighbor_count", 8)),
        neighbor_min_similarity=float(routing_section.get("neighbor_min_similarity", 0.84)),
        auto_apply_confidence_threshold=float(
            routing_section.get("auto_apply_confidence_threshold", 0.78)
        ),
        auto_apply_neighbor_similarity_threshold=float(
            routing_section.get("auto_apply_neighbor_similarity_threshold", 0.55)
        ),
        neighbor_path_override_enabled=bool(
            routing_section.get("neighbor_path_override_enabled", True)
        ),
        neighbor_path_override_min_similarity=float(
            routing_section.get("neighbor_path_override_min_similarity", 0.86)
        ),
        neighbor_path_override_min_gap=float(
            routing_section.get("neighbor_path_override_min_gap", 0.04)
        ),
        neighbor_path_override_max_confidence=float(
            routing_section.get("neighbor_path_override_max_confidence", 0.9)
        ),
    )

    return AppSettingsResponse(
        upload_defaults=upload_defaults,
        display=display,
        processing_log_retention=processing_log_retention,
        handwriting_style_clustering=handwriting_style,
        predefined_paths=predefined_paths,
        predefined_tags=predefined_tags,
        providers=providers,
        tasks=TaskSettingsResponse(
            ocr_handwriting=ocr_task,
            summary_generation=summary_task,
            routing_classification=routing_task,
        ),
    )
|
|
|
|
|
|
@router.get("", response_model=AppSettingsResponse)
def get_app_settings() -> AppSettingsResponse:
    """Reads the stored application settings and serves them as the typed API response."""

    stored_settings = read_app_settings()
    return _build_response(stored_settings)
|
|
|
|
|
|
@router.patch("", response_model=AppSettingsResponse)
def set_app_settings(payload: AppSettingsUpdateRequest) -> AppSettingsResponse:
    """Applies a partial settings update and returns the configuration produced by it.

    Sections left as ``None`` on the request are forwarded as ``None`` to
    ``update_app_settings``. Provider entries are dumped in full, whereas every
    other section is dumped with ``exclude_none=True``.
    """

    # Keys mirror the keyword names accepted by ``update_app_settings``.
    changes = {
        "providers": (
            None
            if payload.providers is None
            else [entry.model_dump() for entry in payload.providers]
        ),
        "tasks": (
            None
            if payload.tasks is None
            else payload.tasks.model_dump(exclude_none=True)
        ),
        "upload_defaults": (
            None
            if payload.upload_defaults is None
            else payload.upload_defaults.model_dump(exclude_none=True)
        ),
        "display": (
            None
            if payload.display is None
            else payload.display.model_dump(exclude_none=True)
        ),
        "processing_log_retention": (
            None
            if payload.processing_log_retention is None
            else payload.processing_log_retention.model_dump(exclude_none=True)
        ),
        "handwriting_style": (
            None
            if payload.handwriting_style_clustering is None
            else payload.handwriting_style_clustering.model_dump(exclude_none=True)
        ),
        "predefined_paths": (
            None
            if payload.predefined_paths is None
            else [entry.model_dump(exclude_none=True) for entry in payload.predefined_paths]
        ),
        "predefined_tags": (
            None
            if payload.predefined_tags is None
            else [entry.model_dump(exclude_none=True) for entry in payload.predefined_tags]
        ),
    }

    return _build_response(update_app_settings(**changes))
|
|
|
|
|
|
@router.post("/reset", response_model=AppSettingsResponse)
def reset_settings_to_defaults() -> AppSettingsResponse:
    """Triggers the settings reset and responds with the configuration it returns."""

    restored_settings = reset_app_settings()
    return _build_response(restored_settings)
|
|
|
|
|
|
@router.patch("/handwriting", response_model=AppSettingsResponse)
def set_handwriting_settings(payload: HandwritingSettingsUpdateRequest) -> AppSettingsResponse:
    """Forwards the handwriting fields of the request to the settings service.

    The value returned by ``update_handwriting_settings`` is converted into the
    full application settings response.
    """

    # Field names on the request match the service's keyword names one-to-one.
    handwriting_changes = {
        "enabled": payload.enabled,
        "openai_base_url": payload.openai_base_url,
        "openai_model": payload.openai_model,
        "openai_timeout_seconds": payload.openai_timeout_seconds,
        "openai_api_key": payload.openai_api_key,
        "clear_openai_api_key": payload.clear_openai_api_key,
    }
    return _build_response(update_handwriting_settings(**handwriting_changes))
|
|
|
|
|
|
@router.get("/handwriting", response_model=HandwritingSettingsResponse)
def get_handwriting_settings() -> HandwritingSettingsResponse:
    """Serves the legacy handwriting settings shape kept for older clients.

    The provider reported is the one whose id matches the OCR task's
    ``provider_id``; failing that, the first configured provider; failing that,
    a built-in OpenAI-compatible placeholder.
    """

    settings = _build_response(read_app_settings())
    default_provider = ProviderSettingsResponse(
        id="openai-default",
        label="OpenAI Default",
        provider_type="openai_compatible",
        base_url="https://api.openai.com/v1",
        timeout_seconds=45,
        api_key_set=False,
        api_key_masked="",
    )
    ocr_task = settings.tasks.ocr_handwriting

    for candidate in settings.providers:
        if candidate.id == ocr_task.provider_id:
            selected = candidate
            break
    else:
        # No provider carries the OCR task's id: use the first one, or the placeholder.
        if settings.providers:
            selected = settings.providers[0]
        else:
            selected = default_provider

    return HandwritingSettingsResponse(
        provider=selected.provider_type,
        enabled=ocr_task.enabled,
        openai_base_url=selected.base_url,
        openai_model=ocr_task.model,
        openai_timeout_seconds=selected.timeout_seconds,
        openai_api_key_set=selected.api_key_set,
        openai_api_key_masked=selected.api_key_masked,
    )
|