Initial commit

This commit is contained in:
2026-02-21 09:44:18 -03:00
commit 5dfc2cbd85
65 changed files with 11989 additions and 0 deletions

View File

@@ -0,0 +1 @@
"""Domain services package for storage, extraction, and classification logic."""

View File

@@ -0,0 +1,885 @@
"""Persistent single-user application settings service backed by host-mounted storage."""
import json
import math
import re
from pathlib import Path
from typing import Any

from app.core.config import get_settings
settings = get_settings()
# Task identifiers: keys of the persisted "tasks" settings map.
TASK_OCR_HANDWRITING = "ocr_handwriting"
TASK_SUMMARY_GENERATION = "summary_generation"
TASK_ROUTING_CLASSIFICATION = "routing_classification"
# Top-level settings-file keys for feature-specific sections.
HANDWRITING_STYLE_SETTINGS_KEY = "handwriting_style_clustering"
PREDEFINED_PATHS_SETTINGS_KEY = "predefined_paths"
PREDEFINED_TAGS_SETTINGS_KEY = "predefined_tags"
# Default embedding model slug used for handwriting-style clustering.
DEFAULT_HANDWRITING_STYLE_EMBED_MODEL = "ts/clip-vit-b-p32"
# Default system prompt for handwriting OCR (output is expected in German).
DEFAULT_OCR_PROMPT = (
    "You are an expert at reading messy handwritten notes, including hard-to-read writing.\n"
    "Task: transcribe the handwriting as exactly as possible.\n\n"
    "Rules:\n"
    "- Output ONLY the transcription in German, no commentary.\n"
    "- Preserve original line breaks where they clearly exist.\n"
    "- Do NOT translate or correct grammar or spelling.\n"
    "- If a word or character is unclear, wrap your best guess in [[? ... ?]].\n"
    "- If something is unreadable, write [[?unleserlich?]] in its place."
)
# Default system prompt for the document summarization task.
DEFAULT_SUMMARY_PROMPT = (
    "You summarize documents for indexing and routing.\n"
    "Return concise markdown with key entities, purpose, and document category hints.\n"
    "Do not invent facts and do not include any explanation outside the summary."
)
# Default system prompt for routing classification; the model must emit JSON only.
DEFAULT_ROUTING_PROMPT = (
    "You classify one document into an existing logical path and tags.\n"
    "Prefer existing paths and tags when possible.\n"
    "If the evidence is weak, keep chosen_path as null and use suggestions instead.\n"
    "Return JSON only with this exact shape:\n"
    "{\n"
    " \"chosen_path\": string | null,\n"
    " \"chosen_tags\": string[],\n"
    " \"suggested_new_paths\": string[],\n"
    " \"suggested_new_tags\": string[],\n"
    " \"confidence\": number\n"
    "}\n"
    "Confidence must be between 0 and 1."
)
def _default_settings() -> dict[str, Any]:
    """Builds default settings including providers and model task bindings.

    Returns the complete settings shape that the sanitizers treat as the
    source of truth for keys, nesting, and fallback values. Provider
    credentials and model names come from environment-backed config
    (``get_settings()``).
    """
    return {
        "upload_defaults": {
            "logical_path": "Inbox",
            "tags": [],
        },
        "display": {
            "cards_per_page": 12,
            "log_typing_animation_enabled": True,
        },
        PREDEFINED_PATHS_SETTINGS_KEY: [],
        PREDEFINED_TAGS_SETTINGS_KEY: [],
        HANDWRITING_STYLE_SETTINGS_KEY: {
            "enabled": True,
            "embed_model": DEFAULT_HANDWRITING_STYLE_EMBED_MODEL,
            "neighbor_limit": 8,
            # Similarity thresholds are cosine-style values in [0, 1];
            # bootstrap matching is stricter than steady-state matching.
            "match_min_similarity": 0.86,
            "bootstrap_match_min_similarity": 0.89,
            "bootstrap_sample_size": 3,
            "image_max_side": 1024,
        },
        # A single OpenAI-compatible provider is always present by default.
        "providers": [
            {
                "id": "openai-default",
                "label": "OpenAI Default",
                "provider_type": "openai_compatible",
                "base_url": settings.default_openai_base_url,
                "timeout_seconds": settings.default_openai_timeout_seconds,
                "api_key": settings.default_openai_api_key,
            }
        ],
        # Every task binds to a provider id defined above.
        "tasks": {
            TASK_OCR_HANDWRITING: {
                "enabled": settings.default_openai_handwriting_enabled,
                "provider_id": "openai-default",
                "model": settings.default_openai_model,
                "prompt": DEFAULT_OCR_PROMPT,
            },
            TASK_SUMMARY_GENERATION: {
                "enabled": True,
                "provider_id": "openai-default",
                "model": settings.default_summary_model,
                "prompt": DEFAULT_SUMMARY_PROMPT,
                "max_input_tokens": 8000,
            },
            TASK_ROUTING_CLASSIFICATION: {
                "enabled": True,
                "provider_id": "openai-default",
                "model": settings.default_routing_model,
                "prompt": DEFAULT_ROUTING_PROMPT,
                "neighbor_count": 8,
                "neighbor_min_similarity": 0.84,
                "auto_apply_confidence_threshold": 0.78,
                "auto_apply_neighbor_similarity_threshold": 0.55,
                "neighbor_path_override_enabled": True,
                "neighbor_path_override_min_similarity": 0.86,
                "neighbor_path_override_min_gap": 0.04,
                "neighbor_path_override_max_confidence": 0.9,
            },
        },
    }
def _settings_path() -> Path:
    """Returns the absolute path of the persisted settings file."""
    return settings.storage_root.joinpath("settings.json")
def _clamp_timeout(value: int) -> int:
"""Clamps timeout values to a safe and practical range."""
return max(5, min(180, value))
def _clamp_input_tokens(value: int) -> int:
"""Clamps per-request summary input token budget values to practical bounds."""
return max(512, min(64000, value))
def _clamp_neighbor_count(value: int) -> int:
"""Clamps nearest-neighbor lookup count for routing classification."""
return max(1, min(40, value))
def _clamp_cards_per_page(value: int) -> int:
"""Clamps dashboard cards-per-page display setting to practical bounds."""
return max(1, min(200, value))
def _clamp_predefined_entries_limit(value: int) -> int:
"""Clamps maximum count for predefined tag/path catalog entries."""
return max(1, min(2000, value))
def _clamp_handwriting_style_neighbor_limit(value: int) -> int:
"""Clamps handwriting-style nearest-neighbor count used for style matching."""
return max(1, min(32, value))
def _clamp_handwriting_style_sample_size(value: int) -> int:
"""Clamps handwriting-style bootstrap sample size used for stricter matching."""
return max(1, min(30, value))
def _clamp_handwriting_style_image_max_side(value: int) -> int:
"""Clamps handwriting-style image normalization max-side pixel size."""
return max(256, min(4096, value))
def _clamp_probability(value: float, fallback: float) -> float:
"""Clamps probability-like numbers to the range [0, 1]."""
try:
parsed = float(value)
except (TypeError, ValueError):
return fallback
return max(0.0, min(1.0, parsed))
def _safe_int(value: Any, fallback: int) -> int:
"""Safely converts arbitrary values to integers with fallback handling."""
try:
return int(value)
except (TypeError, ValueError):
return fallback
def _normalize_provider_id(value: str | None, fallback: str) -> str:
"""Normalizes provider identifiers into stable lowercase slug values."""
candidate = (value or "").strip().lower()
candidate = re.sub(r"[^a-z0-9_-]+", "-", candidate).strip("-")
return candidate or fallback
def _mask_api_key(value: str) -> str:
"""Masks a secret API key while retaining enough characters for identification."""
if not value:
return ""
if len(value) <= 6:
return "*" * len(value)
return f"{value[:4]}...{value[-2:]}"
def _normalize_provider(
    payload: dict[str, Any],
    fallback_id: str,
    fallback_values: dict[str, Any],
) -> dict[str, Any]:
    """Normalizes one provider payload to a stable shape with bounds and defaults.

    Field resolution order for every key: incoming ``payload`` value first,
    then ``fallback_values`` (typically the previously persisted provider),
    then the repository defaults from ``_default_settings()``.
    """
    defaults = _default_settings()["providers"][0]
    provider_id = _normalize_provider_id(str(payload.get("id", fallback_id)), fallback_id)
    provider_type = str(payload.get("provider_type", fallback_values.get("provider_type", defaults["provider_type"]))).strip()
    # Only OpenAI-compatible providers are supported; coerce anything else.
    if provider_type != "openai_compatible":
        provider_type = "openai_compatible"
    # A literal None API key is treated as "no key", not the string "None".
    api_key_value = payload.get("api_key", fallback_values.get("api_key", defaults["api_key"]))
    api_key = str(api_key_value).strip() if api_key_value is not None else ""
    return {
        "id": provider_id,
        # Empty labels fall back to the provider id so the UI always shows something.
        "label": str(payload.get("label", fallback_values.get("label", provider_id))).strip() or provider_id,
        "provider_type": provider_type,
        "base_url": str(payload.get("base_url", fallback_values.get("base_url", defaults["base_url"]))).strip()
        or defaults["base_url"],
        "timeout_seconds": _clamp_timeout(
            _safe_int(
                payload.get("timeout_seconds", fallback_values.get("timeout_seconds", defaults["timeout_seconds"])),
                defaults["timeout_seconds"],
            )
        ),
        "api_key": api_key,
    }
def _normalize_ocr_task(payload: dict[str, Any], provider_ids: list[str]) -> dict[str, Any]:
    """Normalizes OCR task settings while enforcing valid provider references."""
    defaults = _default_settings()["tasks"][TASK_OCR_HANDWRITING]
    requested_provider = str(payload.get("provider_id", defaults["provider_id"])).strip()
    provider_id = requested_provider if requested_provider in provider_ids else provider_ids[0]
    model = str(payload.get("model", defaults["model"])).strip() or defaults["model"]
    prompt = str(payload.get("prompt", defaults["prompt"])).strip() or defaults["prompt"]
    return {
        "enabled": bool(payload.get("enabled", defaults["enabled"])),
        "provider_id": provider_id,
        "model": model,
        "prompt": prompt,
    }
def _normalize_summary_task(payload: dict[str, Any], provider_ids: list[str]) -> dict[str, Any]:
    """Normalizes summary task settings while enforcing valid provider references."""
    defaults = _default_settings()["tasks"][TASK_SUMMARY_GENERATION]
    requested_provider = str(payload.get("provider_id", defaults["provider_id"])).strip()
    provider_id = requested_provider if requested_provider in provider_ids else provider_ids[0]
    raw_max_tokens = payload.get("max_input_tokens")
    if raw_max_tokens is None:
        # Migrate the legacy character budget (~4 chars per token) when present.
        legacy_chars = _safe_int(payload.get("max_source_chars", 0), 0)
        raw_max_tokens = max(512, legacy_chars // 4) if legacy_chars > 0 else defaults["max_input_tokens"]
    return {
        "enabled": bool(payload.get("enabled", defaults["enabled"])),
        "provider_id": provider_id,
        "model": str(payload.get("model", defaults["model"])).strip() or defaults["model"],
        "prompt": str(payload.get("prompt", defaults["prompt"])).strip() or defaults["prompt"],
        "max_input_tokens": _clamp_input_tokens(_safe_int(raw_max_tokens, defaults["max_input_tokens"])),
    }
def _normalize_routing_task(payload: dict[str, Any], provider_ids: list[str]) -> dict[str, Any]:
    """Normalizes routing task settings while enforcing valid provider references."""
    defaults = _default_settings()["tasks"][TASK_ROUTING_CLASSIFICATION]
    requested_provider = str(payload.get("provider_id", defaults["provider_id"])).strip()
    provider_id = requested_provider if requested_provider in provider_ids else provider_ids[0]

    def prob(key: str) -> float:
        # All similarity/confidence knobs share the same [0, 1] clamping rule.
        return _clamp_probability(payload.get(key, defaults[key]), defaults[key])

    return {
        "enabled": bool(payload.get("enabled", defaults["enabled"])),
        "provider_id": provider_id,
        "model": str(payload.get("model", defaults["model"])).strip() or defaults["model"],
        "prompt": str(payload.get("prompt", defaults["prompt"])).strip() or defaults["prompt"],
        "neighbor_count": _clamp_neighbor_count(
            _safe_int(payload.get("neighbor_count", defaults["neighbor_count"]), defaults["neighbor_count"])
        ),
        "neighbor_min_similarity": prob("neighbor_min_similarity"),
        "auto_apply_confidence_threshold": prob("auto_apply_confidence_threshold"),
        "auto_apply_neighbor_similarity_threshold": prob("auto_apply_neighbor_similarity_threshold"),
        "neighbor_path_override_enabled": bool(
            payload.get("neighbor_path_override_enabled", defaults["neighbor_path_override_enabled"])
        ),
        "neighbor_path_override_min_similarity": prob("neighbor_path_override_min_similarity"),
        "neighbor_path_override_min_gap": prob("neighbor_path_override_min_gap"),
        "neighbor_path_override_max_confidence": prob("neighbor_path_override_max_confidence"),
    }
def _normalize_tasks(payload: dict[str, Any], provider_ids: list[str]) -> dict[str, Any]:
    """Normalizes task settings map for OCR, summarization, and routing tasks."""
    source = payload if isinstance(payload, dict) else {}
    return {
        TASK_OCR_HANDWRITING: _normalize_ocr_task(source.get(TASK_OCR_HANDWRITING, {}), provider_ids),
        TASK_SUMMARY_GENERATION: _normalize_summary_task(source.get(TASK_SUMMARY_GENERATION, {}), provider_ids),
        TASK_ROUTING_CLASSIFICATION: _normalize_routing_task(source.get(TASK_ROUTING_CLASSIFICATION, {}), provider_ids),
    }
def _normalize_upload_defaults(payload: dict[str, Any], defaults: dict[str, Any]) -> dict[str, Any]:
"""Normalizes upload default destination path and tags."""
if not isinstance(payload, dict):
payload = {}
default_path = str(defaults.get("logical_path", "Inbox")).strip() or "Inbox"
raw_path = str(payload.get("logical_path", default_path)).strip()
logical_path = raw_path or default_path
raw_tags = payload.get("tags", defaults.get("tags", []))
tags: list[str] = []
seen_lowered: set[str] = set()
if isinstance(raw_tags, list):
for raw_tag in raw_tags:
normalized = str(raw_tag).strip()
if not normalized:
continue
lowered = normalized.lower()
if lowered in seen_lowered:
continue
seen_lowered.add(lowered)
tags.append(normalized)
if len(tags) >= 50:
break
return {
"logical_path": logical_path,
"tags": tags,
}
def _normalize_display_settings(payload: dict[str, Any], defaults: dict[str, Any]) -> dict[str, Any]:
    """Normalizes display settings used by the document dashboard UI."""
    source = payload if isinstance(payload, dict) else {}
    fallback_cards = _safe_int(defaults.get("cards_per_page", 12), 12)
    animation_default = defaults.get("log_typing_animation_enabled", True)
    return {
        "cards_per_page": _clamp_cards_per_page(
            _safe_int(source.get("cards_per_page", fallback_cards), fallback_cards)
        ),
        "log_typing_animation_enabled": bool(
            source.get("log_typing_animation_enabled", animation_default)
        ),
    }
def _normalize_predefined_paths(
    payload: Any,
    existing_items: list[dict[str, Any]] | None = None,
) -> list[dict[str, Any]]:
    """Normalizes predefined path entries and enforces irreversible global-sharing flag."""
    previous: dict[str, dict[str, Any]] = {}
    for prior in existing_items if isinstance(existing_items, list) else []:
        if not isinstance(prior, dict):
            continue
        prior_value = str(prior.get("value", "")).strip().strip("/")
        if prior_value:
            previous[prior_value.lower()] = {
                "value": prior_value,
                "global_shared": bool(prior.get("global_shared", False)),
            }
    if not isinstance(payload, list):
        return list(previous.values())
    entries: list[dict[str, Any]] = []
    taken: set[str] = set()
    max_entries = _clamp_predefined_entries_limit(len(payload))
    for candidate in payload:
        if not isinstance(candidate, dict):
            continue
        value = str(candidate.get("value", "")).strip().strip("/")
        if not value:
            continue
        key = value.lower()
        if key in taken:
            continue
        taken.add(key)
        # Once a path was shared globally it stays shared (irreversible flag).
        was_shared = bool(previous.get(key, {}).get("global_shared", False))
        entries.append(
            {
                "value": value,
                "global_shared": was_shared or bool(candidate.get("global_shared", False)),
            }
        )
        if len(entries) >= max_entries:
            break
    return entries
def _normalize_predefined_tags(
    payload: Any,
    existing_items: list[dict[str, Any]] | None = None,
) -> list[dict[str, Any]]:
    """Normalizes predefined tag entries and enforces irreversible global-sharing flag."""
    previous: dict[str, dict[str, Any]] = {}
    for prior in existing_items if isinstance(existing_items, list) else []:
        if not isinstance(prior, dict):
            continue
        prior_value = str(prior.get("value", "")).strip()
        if prior_value:
            previous[prior_value.lower()] = {
                "value": prior_value,
                "global_shared": bool(prior.get("global_shared", False)),
            }
    if not isinstance(payload, list):
        return list(previous.values())
    entries: list[dict[str, Any]] = []
    taken: set[str] = set()
    max_entries = _clamp_predefined_entries_limit(len(payload))
    for candidate in payload:
        if not isinstance(candidate, dict):
            continue
        value = str(candidate.get("value", "")).strip()
        if not value:
            continue
        key = value.lower()
        if key in taken:
            continue
        taken.add(key)
        # Once a tag was shared globally it stays shared (irreversible flag).
        was_shared = bool(previous.get(key, {}).get("global_shared", False))
        entries.append(
            {
                "value": value,
                "global_shared": was_shared or bool(candidate.get("global_shared", False)),
            }
        )
        if len(entries) >= max_entries:
            break
    return entries
def _normalize_handwriting_style_settings(payload: dict[str, Any], defaults: dict[str, Any]) -> dict[str, Any]:
    """Normalizes handwriting-style clustering settings exposed in the settings UI."""
    source = payload if isinstance(payload, dict) else {}
    embed_fallback = str(defaults.get("embed_model", DEFAULT_HANDWRITING_STYLE_EMBED_MODEL)).strip()
    neighbor_fallback = _safe_int(defaults.get("neighbor_limit", 8), 8)
    match_fallback = _clamp_probability(defaults.get("match_min_similarity", 0.86), 0.86)
    bootstrap_match_fallback = _clamp_probability(defaults.get("bootstrap_match_min_similarity", 0.89), 0.89)
    sample_fallback = _safe_int(defaults.get("bootstrap_sample_size", 3), 3)
    side_fallback = _safe_int(defaults.get("image_max_side", 1024), 1024)
    return {
        "enabled": bool(source.get("enabled", bool(defaults.get("enabled", True)))),
        "embed_model": str(source.get("embed_model", embed_fallback)).strip() or embed_fallback,
        "neighbor_limit": _clamp_handwriting_style_neighbor_limit(
            _safe_int(source.get("neighbor_limit", neighbor_fallback), neighbor_fallback)
        ),
        "match_min_similarity": _clamp_probability(
            source.get("match_min_similarity", match_fallback),
            match_fallback,
        ),
        "bootstrap_match_min_similarity": _clamp_probability(
            source.get("bootstrap_match_min_similarity", bootstrap_match_fallback),
            bootstrap_match_fallback,
        ),
        "bootstrap_sample_size": _clamp_handwriting_style_sample_size(
            _safe_int(source.get("bootstrap_sample_size", sample_fallback), sample_fallback)
        ),
        "image_max_side": _clamp_handwriting_style_image_max_side(
            _safe_int(source.get("image_max_side", side_fallback), side_fallback)
        ),
    }
def _sanitize_settings(payload: dict[str, Any]) -> dict[str, Any]:
    """Sanitizes all persisted settings into a stable normalized structure.

    Accepts arbitrary (possibly corrupt) payloads and always returns the full
    settings shape with at least one provider, valid task bindings, and
    clamped section values.
    """
    if not isinstance(payload, dict):
        payload = {}
    defaults = _default_settings()
    providers_payload = payload.get("providers")
    normalized_providers: list[dict[str, Any]] = []
    seen_provider_ids: set[str] = set()
    if isinstance(providers_payload, list):
        for index, provider_payload in enumerate(providers_payload):
            if not isinstance(provider_payload, dict):
                continue
            fallback = defaults["providers"][0]
            candidate = _normalize_provider(provider_payload, fallback_id=f"provider-{index + 1}", fallback_values=fallback)
            # Duplicate provider ids keep the first occurrence only.
            if candidate["id"] in seen_provider_ids:
                continue
            seen_provider_ids.add(candidate["id"])
            normalized_providers.append(candidate)
    # Guarantee at least one provider so task bindings always resolve.
    if not normalized_providers:
        normalized_providers = [dict(defaults["providers"][0])]
    provider_ids = [provider["id"] for provider in normalized_providers]
    tasks_payload = payload.get("tasks", {})
    normalized_tasks = _normalize_tasks(tasks_payload, provider_ids)
    upload_defaults = _normalize_upload_defaults(payload.get("upload_defaults", {}), defaults["upload_defaults"])
    display_settings = _normalize_display_settings(payload.get("display", {}), defaults["display"])
    # The same list doubles as existing_items so global_shared flags survive.
    predefined_paths = _normalize_predefined_paths(
        payload.get(PREDEFINED_PATHS_SETTINGS_KEY, []),
        existing_items=payload.get(PREDEFINED_PATHS_SETTINGS_KEY, []),
    )
    predefined_tags = _normalize_predefined_tags(
        payload.get(PREDEFINED_TAGS_SETTINGS_KEY, []),
        existing_items=payload.get(PREDEFINED_TAGS_SETTINGS_KEY, []),
    )
    handwriting_style_settings = _normalize_handwriting_style_settings(
        payload.get(HANDWRITING_STYLE_SETTINGS_KEY, {}),
        defaults[HANDWRITING_STYLE_SETTINGS_KEY],
    )
    return {
        "upload_defaults": upload_defaults,
        "display": display_settings,
        PREDEFINED_PATHS_SETTINGS_KEY: predefined_paths,
        PREDEFINED_TAGS_SETTINGS_KEY: predefined_tags,
        HANDWRITING_STYLE_SETTINGS_KEY: handwriting_style_settings,
        "providers": normalized_providers,
        "tasks": normalized_tasks,
    }
def ensure_app_settings() -> None:
    """Creates a settings file with defaults when no persisted settings are present."""
    path = _settings_path()
    path.parent.mkdir(parents=True, exist_ok=True)
    if not path.exists():
        payload = _sanitize_settings(_default_settings())
        path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
def _read_raw_settings() -> dict[str, Any]:
    """Reads persisted settings from disk and returns normalized values."""
    ensure_app_settings()
    try:
        raw = json.loads(_settings_path().read_text(encoding="utf-8"))
    except (OSError, json.JSONDecodeError):
        # Unreadable or corrupt files degrade to full defaults via sanitize.
        raw = {}
    return _sanitize_settings(raw)
def _write_settings(payload: dict[str, Any]) -> None:
    """Persists sanitized settings payload to host-mounted storage."""
    target = _settings_path()
    target.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(payload, indent=2)
    target.write_text(serialized, encoding="utf-8")
def read_app_settings() -> dict[str, Any]:
    """Reads settings and returns a sanitized view safe for API responses.

    Provider API keys are never returned verbatim; only a set-flag and a
    masked rendering are exposed.
    """
    payload = _read_raw_settings()
    providers_response = [
        {
            "id": provider["id"],
            "label": provider["label"],
            "provider_type": provider["provider_type"],
            "base_url": provider["base_url"],
            "timeout_seconds": int(provider["timeout_seconds"]),
            "api_key_set": bool(str(provider.get("api_key", ""))),
            "api_key_masked": _mask_api_key(str(provider.get("api_key", ""))),
        }
        for provider in payload["providers"]
    ]
    return {
        "upload_defaults": payload.get("upload_defaults", {"logical_path": "Inbox", "tags": []}),
        "display": payload.get("display", {"cards_per_page": 12, "log_typing_animation_enabled": True}),
        PREDEFINED_PATHS_SETTINGS_KEY: payload.get(PREDEFINED_PATHS_SETTINGS_KEY, []),
        PREDEFINED_TAGS_SETTINGS_KEY: payload.get(PREDEFINED_TAGS_SETTINGS_KEY, []),
        HANDWRITING_STYLE_SETTINGS_KEY: payload.get(HANDWRITING_STYLE_SETTINGS_KEY, {}),
        "providers": providers_response,
        "tasks": payload["tasks"],
    }
def reset_app_settings() -> dict[str, Any]:
    """Resets persisted application settings to sanitized repository defaults."""
    _write_settings(_sanitize_settings(_default_settings()))
    return read_app_settings()
def read_task_runtime_settings(task_name: str) -> dict[str, Any]:
    """Returns runtime task settings and resolved provider including secret values.

    Raises KeyError for unknown task names. Falls back to the first provider
    when the task references a provider id that no longer exists.
    """
    payload = _read_raw_settings()
    all_tasks = payload["tasks"]
    if task_name not in all_tasks:
        raise KeyError(f"Unknown task settings key: {task_name}")
    task = dict(all_tasks[task_name])
    providers = payload["providers"]
    by_id = {entry["id"]: entry for entry in providers}
    provider = by_id.get(task.get("provider_id"))
    if provider is None:
        provider = providers[0]
        task["provider_id"] = provider["id"]
    return {
        "task": task,
        "provider": dict(provider),
    }
def update_app_settings(
    providers: list[dict[str, Any]] | None = None,
    tasks: dict[str, dict[str, Any]] | None = None,
    upload_defaults: dict[str, Any] | None = None,
    display: dict[str, Any] | None = None,
    handwriting_style: dict[str, Any] | None = None,
    predefined_paths: list[dict[str, Any]] | None = None,
    predefined_tags: list[dict[str, Any]] | None = None,
) -> dict[str, Any]:
    """Updates app settings, persists them, and returns API-safe values.

    Every argument is a partial update: ``None`` leaves that section
    untouched. The merged payload is re-sanitized before being written, so
    invalid values are clamped or replaced rather than persisted.
    """
    current_payload = _read_raw_settings()
    # Start from shallow copies of every current section; the branches below
    # replace whole sections or merge into these copies.
    next_payload: dict[str, Any] = {
        "upload_defaults": dict(current_payload.get("upload_defaults", {"logical_path": "Inbox", "tags": []})),
        "display": dict(current_payload.get("display", {"cards_per_page": 12, "log_typing_animation_enabled": True})),
        PREDEFINED_PATHS_SETTINGS_KEY: list(current_payload.get(PREDEFINED_PATHS_SETTINGS_KEY, [])),
        PREDEFINED_TAGS_SETTINGS_KEY: list(current_payload.get(PREDEFINED_TAGS_SETTINGS_KEY, [])),
        HANDWRITING_STYLE_SETTINGS_KEY: dict(
            current_payload.get(HANDWRITING_STYLE_SETTINGS_KEY, _default_settings()[HANDWRITING_STYLE_SETTINGS_KEY])
        ),
        "providers": list(current_payload["providers"]),
        "tasks": dict(current_payload["tasks"]),
    }
    if providers is not None:
        existing_provider_map = {provider["id"]: provider for provider in current_payload["providers"]}
        next_providers: list[dict[str, Any]] = []
        for index, provider_payload in enumerate(providers):
            if not isinstance(provider_payload, dict):
                continue
            provider_id = _normalize_provider_id(
                str(provider_payload.get("id", "")),
                fallback=f"provider-{index + 1}",
            )
            existing_provider = existing_provider_map.get(provider_id, {})
            merged_payload = dict(provider_payload)
            merged_payload["id"] = provider_id
            # API-key precedence: explicit clear flag > new key in payload >
            # previously stored key (omitting the key keeps the old secret).
            if bool(provider_payload.get("clear_api_key", False)):
                merged_payload["api_key"] = ""
            elif "api_key" in provider_payload and provider_payload.get("api_key") is not None:
                merged_payload["api_key"] = str(provider_payload.get("api_key")).strip()
            else:
                merged_payload["api_key"] = str(existing_provider.get("api_key", ""))
            normalized_provider = _normalize_provider(
                merged_payload,
                fallback_id=provider_id,
                fallback_values=existing_provider,
            )
            next_providers.append(normalized_provider)
        # An update that normalizes to zero providers is ignored so at least
        # one provider always remains configured.
        if next_providers:
            next_payload["providers"] = next_providers
    if tasks is not None:
        merged_tasks = dict(current_payload["tasks"])
        for task_name, task_update in tasks.items():
            # Unknown task names and non-dict updates are silently skipped.
            if task_name not in merged_tasks or not isinstance(task_update, dict):
                continue
            existing_task = dict(merged_tasks[task_name])
            for key, value in task_update.items():
                # None means "leave this field unchanged".
                if value is None:
                    continue
                existing_task[key] = value
            merged_tasks[task_name] = existing_task
        next_payload["tasks"] = merged_tasks
    if upload_defaults is not None and isinstance(upload_defaults, dict):
        next_upload_defaults = dict(next_payload.get("upload_defaults", {}))
        for key in ("logical_path", "tags"):
            if key in upload_defaults:
                next_upload_defaults[key] = upload_defaults[key]
        next_payload["upload_defaults"] = next_upload_defaults
    if display is not None and isinstance(display, dict):
        next_display = dict(next_payload.get("display", {}))
        if "cards_per_page" in display:
            next_display["cards_per_page"] = display["cards_per_page"]
        if "log_typing_animation_enabled" in display:
            next_display["log_typing_animation_enabled"] = bool(display["log_typing_animation_enabled"])
        next_payload["display"] = next_display
    if handwriting_style is not None and isinstance(handwriting_style, dict):
        next_handwriting_style = dict(next_payload.get(HANDWRITING_STYLE_SETTINGS_KEY, {}))
        for key in (
            "enabled",
            "embed_model",
            "neighbor_limit",
            "match_min_similarity",
            "bootstrap_match_min_similarity",
            "bootstrap_sample_size",
            "image_max_side",
        ):
            if key in handwriting_style:
                next_handwriting_style[key] = handwriting_style[key]
        next_payload[HANDWRITING_STYLE_SETTINGS_KEY] = next_handwriting_style
    if predefined_paths is not None:
        # Passing the current list as existing_items preserves the
        # irreversible global_shared flags across updates.
        next_payload[PREDEFINED_PATHS_SETTINGS_KEY] = _normalize_predefined_paths(
            predefined_paths,
            existing_items=next_payload.get(PREDEFINED_PATHS_SETTINGS_KEY, []),
        )
    if predefined_tags is not None:
        next_payload[PREDEFINED_TAGS_SETTINGS_KEY] = _normalize_predefined_tags(
            predefined_tags,
            existing_items=next_payload.get(PREDEFINED_TAGS_SETTINGS_KEY, []),
        )
    sanitized = _sanitize_settings(next_payload)
    _write_settings(sanitized)
    return read_app_settings()
def read_handwriting_provider_settings() -> dict[str, Any]:
    """Returns OCR settings in legacy shape for the handwriting transcription service."""
    runtime = read_task_runtime_settings(TASK_OCR_HANDWRITING)
    provider = runtime["provider"]
    task = runtime["task"]
    base_url = str(provider.get("base_url", settings.default_openai_base_url))
    model = str(task.get("model", settings.default_openai_model))
    timeout = int(provider.get("timeout_seconds", settings.default_openai_timeout_seconds))
    return {
        "provider": provider["provider_type"],
        "enabled": bool(task.get("enabled", True)),
        "openai_base_url": base_url,
        "openai_model": model,
        "openai_timeout_seconds": timeout,
        "openai_api_key": str(provider.get("api_key", "")),
        "prompt": str(task.get("prompt", DEFAULT_OCR_PROMPT)),
        "provider_id": str(provider.get("id", "openai-default")),
    }
def read_handwriting_style_settings() -> dict[str, Any]:
    """Returns handwriting-style clustering settings for Typesense style assignment logic."""
    payload = _read_raw_settings()
    section = payload.get(HANDWRITING_STYLE_SETTINGS_KEY, {})
    return _normalize_handwriting_style_settings(
        section,
        _default_settings()[HANDWRITING_STYLE_SETTINGS_KEY],
    )
def read_predefined_paths_settings() -> list[dict[str, Any]]:
    """Returns normalized predefined logical path catalog entries."""
    raw = _read_raw_settings().get(PREDEFINED_PATHS_SETTINGS_KEY, [])
    return _normalize_predefined_paths(raw, existing_items=raw)
def read_predefined_tags_settings() -> list[dict[str, Any]]:
    """Returns normalized predefined tag catalog entries."""
    raw = _read_raw_settings().get(PREDEFINED_TAGS_SETTINGS_KEY, [])
    return _normalize_predefined_tags(raw, existing_items=raw)
def update_handwriting_settings(
    enabled: bool | None = None,
    openai_base_url: str | None = None,
    openai_model: str | None = None,
    openai_timeout_seconds: int | None = None,
    openai_api_key: str | None = None,
    clear_openai_api_key: bool = False,
) -> dict[str, Any]:
    """Updates OCR task and bound provider values using the legacy handwriting API contract.

    None-valued arguments keep the currently stored value; ``clear_openai_api_key``
    takes precedence over a supplied ``openai_api_key``.
    """
    runtime = read_task_runtime_settings(TASK_OCR_HANDWRITING)
    current = runtime["provider"]
    provider_update: dict[str, Any] = {
        "id": current["id"],
        "label": current["label"],
        "provider_type": current["provider_type"],
        "base_url": current["base_url"] if openai_base_url is None else openai_base_url,
        "timeout_seconds": current["timeout_seconds"] if openai_timeout_seconds is None else openai_timeout_seconds,
    }
    if clear_openai_api_key:
        provider_update["clear_api_key"] = True
    elif openai_api_key is not None:
        provider_update["api_key"] = openai_api_key
    task_update: dict[str, Any] = {}
    if enabled is not None:
        task_update["enabled"] = enabled
    if openai_model is not None:
        task_update["model"] = openai_model
    return update_app_settings(
        providers=[provider_update],
        tasks={TASK_OCR_HANDWRITING: task_update},
    )

View File

@@ -0,0 +1,315 @@
"""Document extraction service for text indexing, previews, and archive fan-out."""
import io
import re
import zipfile
from dataclasses import dataclass, field
from pathlib import Path
import magic
from docx import Document as DocxDocument
from openpyxl import load_workbook
from PIL import Image, ImageOps
from pypdf import PdfReader
import pymupdf
from app.core.config import get_settings
from app.services.handwriting import (
IMAGE_TEXT_TYPE_NO_TEXT,
IMAGE_TEXT_TYPE_UNKNOWN,
HandwritingTranscriptionError,
HandwritingTranscriptionNotConfiguredError,
HandwritingTranscriptionTimeoutError,
classify_image_text_bytes,
transcribe_handwriting_bytes,
)
settings = get_settings()
# Raster image formats eligible for preview generation and image-text handling.
IMAGE_EXTENSIONS = {
    ".jpg",
    ".jpeg",
    ".png",
    ".tif",
    ".tiff",
    ".bmp",
    ".gif",
    ".webp",
    ".heic",
}
# Extensions whose content can be text-extracted for indexing; includes all
# image formats (handled via OCR/classification) on top of document formats.
SUPPORTED_TEXT_EXTENSIONS = {
    ".txt",
    ".md",
    ".csv",
    ".json",
    ".xml",
    ".svg",
    ".pdf",
    ".docx",
    ".xlsx",
    *IMAGE_EXTENSIONS,
}
@dataclass
class ExtractionResult:
    """Represents output generated during extraction for a single file."""
    # Normalized plain text extracted from the file (may be empty).
    text: str
    # Optional preview image bytes and their file suffix (e.g. ".jpg"),
    # or None when no preview could be generated.
    preview_bytes: bytes | None
    preview_suffix: str | None
    # Extraction outcome marker; exact values are defined by the producers —
    # TODO confirm the full set against callers.
    status: str
    # Free-form metadata recorded during extraction (e.g. image-text classification).
    metadata_json: dict[str, object] = field(default_factory=dict)
@dataclass
class ArchiveMember:
    """Represents an extracted file entry from an archive."""
    # Member path/name as stored inside the archive.
    name: str
    # Raw file bytes of the member.
    data: bytes
def sniff_mime(data: bytes) -> str:
    """Detects MIME type using libmagic for robust format handling."""
    detected = magic.from_buffer(data, mime=True)
    return detected or "application/octet-stream"
def is_supported_for_extraction(extension: str, mime_type: str) -> bool:
    """Determines if a file should be text-processed for indexing and classification."""
    if extension in SUPPORTED_TEXT_EXTENSIONS:
        return True
    return mime_type.startswith("text/")
def _normalize_text(text: str) -> str:
"""Normalizes extracted text by removing repeated form separators and controls."""
cleaned = text.replace("\r", "\n").replace("\x00", "")
lines: list[str] = []
for line in cleaned.split("\n"):
stripped = line.strip()
if stripped and re.fullmatch(r"[.\-_*=~\s]{4,}", stripped):
continue
lines.append(line)
normalized = "\n".join(lines)
normalized = re.sub(r"\n{3,}", "\n\n", normalized)
return normalized.strip()
def _extract_pdf_text(data: bytes) -> str:
    """Extracts text from PDF bytes using pypdf page parsing."""
    reader = PdfReader(io.BytesIO(data))
    page_texts = [page.extract_text() or "" for page in reader.pages]
    return _normalize_text("\n".join(page_texts))
def _extract_pdf_preview(data: bytes) -> tuple[bytes | None, str | None]:
    """Creates a JPEG thumbnail preview from the first PDF page."""
    try:
        document = pymupdf.open(stream=data, filetype="pdf")
    except Exception:
        # Unparseable PDF: silently skip the preview.
        return None, None
    try:
        if document.page_count >= 1:
            first_page = document.load_page(0)
            # 1.5x zoom renders a reasonably sharp thumbnail.
            rendered = first_page.get_pixmap(matrix=pymupdf.Matrix(1.5, 1.5), alpha=False)
            return rendered.tobytes("jpeg"), ".jpg"
        return None, None
    except Exception:
        return None, None
    finally:
        document.close()
def _extract_docx_text(data: bytes) -> str:
    """Extracts paragraph text from DOCX content."""
    paragraphs = DocxDocument(io.BytesIO(data)).paragraphs
    non_empty = [paragraph.text for paragraph in paragraphs if paragraph.text]
    return _normalize_text("\n".join(non_empty))
def _extract_xlsx_text(data: bytes) -> str:
    """Extracts cell text from XLSX workbook sheets for indexing."""
    workbook = load_workbook(io.BytesIO(data), data_only=True, read_only=True)
    pieces: list[str] = []
    for worksheet in workbook.worksheets:
        pieces.append(worksheet.title)
        populated_rows = 0
        # Bounded scan: at most 200 rows read, at most 200 populated rows kept.
        for row in worksheet.iter_rows(min_row=1, max_row=200):
            values = [str(cell.value) for cell in row if cell.value is not None]
            if not values:
                continue
            pieces.append(" ".join(values))
            populated_rows += 1
            if populated_rows >= 200:
                break
    return _normalize_text("\n".join(pieces))
def _build_image_preview(data: bytes) -> tuple[bytes | None, str | None]:
    """Builds a JPEG preview thumbnail for image files."""
    try:
        with Image.open(io.BytesIO(data)) as source:
            # Honor EXIF orientation, force RGB for JPEG output.
            thumbnail = ImageOps.exif_transpose(source).convert("RGB")
            thumbnail.thumbnail((600, 600))
            buffer = io.BytesIO()
            thumbnail.save(buffer, format="JPEG", optimize=True, quality=82)
    except Exception:
        # Previews are best-effort; any decode/encode failure yields no preview.
        return None, None
    return buffer.getvalue(), ".jpg"
def _extract_handwriting_text(data: bytes, mime_type: str) -> ExtractionResult:
    """Extracts text from image bytes and records handwriting-vs-printed classification metadata.

    Classification runs first; transcription is skipped entirely when the
    classifier reports no readable text. Classification timeouts/errors are
    recorded in metadata but do not prevent the transcription attempt.
    """
    # Build the preview up front so every outcome (error or success) keeps it.
    preview_bytes, preview_suffix = _build_image_preview(data)
    metadata_json: dict[str, object] = {}
    try:
        text_type = classify_image_text_bytes(data, mime_type=mime_type)
        metadata_json = {
            "image_text_type": text_type.label,
            "image_text_type_confidence": text_type.confidence,
            "image_text_type_provider": text_type.provider,
            "image_text_type_model": text_type.model,
        }
    except HandwritingTranscriptionNotConfiguredError as error:
        # No provider configured: the image cannot be processed at all.
        return ExtractionResult(
            text="",
            preview_bytes=preview_bytes,
            preview_suffix=preview_suffix,
            status="unsupported",
            metadata_json={"transcription_error": str(error), "image_text_type": IMAGE_TEXT_TYPE_UNKNOWN},
        )
    except HandwritingTranscriptionTimeoutError as error:
        # Classification timed out; fall through and still attempt transcription.
        metadata_json = {
            "image_text_type": IMAGE_TEXT_TYPE_UNKNOWN,
            "image_text_type_error": str(error),
        }
    except HandwritingTranscriptionError as error:
        # Classification failed for another reason; still attempt transcription.
        metadata_json = {
            "image_text_type": IMAGE_TEXT_TYPE_UNKNOWN,
            "image_text_type_error": str(error),
        }
    if metadata_json.get("image_text_type") == IMAGE_TEXT_TYPE_NO_TEXT:
        # No readable text detected, so skip the costly transcription call.
        metadata_json["transcription_skipped"] = "no_text_detected"
        return ExtractionResult(
            text="",
            preview_bytes=preview_bytes,
            preview_suffix=preview_suffix,
            status="processed",
            metadata_json=metadata_json,
        )
    try:
        transcription = transcribe_handwriting_bytes(data, mime_type=mime_type)
        transcription_metadata: dict[str, object] = {
            "transcription_provider": transcription.provider,
            "transcription_model": transcription.model,
            "transcription_uncertainties": transcription.uncertainties,
        }
        return ExtractionResult(
            text=_normalize_text(transcription.text),
            preview_bytes=preview_bytes,
            preview_suffix=preview_suffix,
            status="processed",
            metadata_json={**metadata_json, **transcription_metadata},
        )
    except HandwritingTranscriptionNotConfiguredError as error:
        # Disabled/missing credentials surface as "unsupported", not "error".
        return ExtractionResult(
            text="",
            preview_bytes=preview_bytes,
            preview_suffix=preview_suffix,
            status="unsupported",
            metadata_json={**metadata_json, "transcription_error": str(error)},
        )
    except HandwritingTranscriptionTimeoutError as error:
        return ExtractionResult(
            text="",
            preview_bytes=preview_bytes,
            preview_suffix=preview_suffix,
            status="error",
            metadata_json={**metadata_json, "transcription_error": str(error)},
        )
    except HandwritingTranscriptionError as error:
        return ExtractionResult(
            text="",
            preview_bytes=preview_bytes,
            preview_suffix=preview_suffix,
            status="error",
            metadata_json={**metadata_json, "transcription_error": str(error)},
        )
def extract_text_content(filename: str, data: bytes, mime_type: str) -> ExtractionResult:
    """Extracts text and optional preview bytes for supported file types.

    Dispatch order matters: a text/* MIME type wins over an image extension,
    matching the original routing behavior.
    """
    suffix = Path(filename).suffix.lower()
    extracted = ""
    preview: bytes | None = None
    preview_ext: str | None = None
    try:
        if suffix == ".pdf":
            extracted = _extract_pdf_text(data)
            preview, preview_ext = _extract_pdf_preview(data)
        elif suffix in {".txt", ".md", ".csv", ".json", ".xml", ".svg"} or mime_type.startswith("text/"):
            extracted = _normalize_text(data.decode("utf-8", errors="ignore"))
        elif suffix == ".docx":
            extracted = _extract_docx_text(data)
        elif suffix == ".xlsx":
            extracted = _extract_xlsx_text(data)
        elif suffix in IMAGE_EXTENSIONS:
            # Images carry their own status/metadata; return directly.
            return _extract_handwriting_text(data=data, mime_type=mime_type)
        else:
            return ExtractionResult(
                text="",
                preview_bytes=None,
                preview_suffix=None,
                status="unsupported",
                metadata_json={"reason": "unsupported_format"},
            )
    except Exception as error:
        return ExtractionResult(
            text="",
            preview_bytes=None,
            preview_suffix=None,
            status="error",
            metadata_json={"reason": "extraction_exception", "error": str(error)},
        )
    return ExtractionResult(
        text=extracted[: settings.max_text_length],
        preview_bytes=preview,
        preview_suffix=preview_ext,
        status="processed",
        metadata_json={},
    )
def extract_archive_members(data: bytes, depth: int = 0) -> list[ArchiveMember]:
    """Extracts processable members from zip archives with configurable depth limits."""
    # Stop recursion once the configured nesting depth is exceeded.
    if depth > settings.max_zip_depth:
        return []
    extracted: list[ArchiveMember] = []
    with zipfile.ZipFile(io.BytesIO(data)) as archive:
        file_entries = [entry for entry in archive.infolist() if not entry.is_dir()]
        # Cap the member count to guard against oversized archives.
        for entry in file_entries[: settings.max_zip_members]:
            extracted.append(ArchiveMember(name=entry.filename, data=archive.read(entry.filename)))
    return extracted

View File

@@ -0,0 +1,477 @@
"""Handwriting transcription service using OpenAI-compatible vision models."""
import base64
import io
import json
import re
from dataclasses import dataclass
from typing import Any
from openai import APIConnectionError, APIError, APITimeoutError, OpenAI
from PIL import Image, ImageOps
from app.services.app_settings import DEFAULT_OCR_PROMPT, read_handwriting_provider_settings
MAX_IMAGE_SIDE = 2000
IMAGE_TEXT_TYPE_HANDWRITING = "handwriting"
IMAGE_TEXT_TYPE_PRINTED = "printed_text"
IMAGE_TEXT_TYPE_NO_TEXT = "no_text"
IMAGE_TEXT_TYPE_UNKNOWN = "unknown"
IMAGE_TEXT_CLASSIFICATION_PROMPT = (
"Classify the text content in this image.\n"
"Choose exactly one label from: handwriting, printed_text, no_text.\n"
"Definitions:\n"
"- handwriting: text exists and most readable text is handwritten.\n"
"- printed_text: text exists and most readable text is machine printed or typed.\n"
"- no_text: no readable text is present.\n"
"Return strict JSON only with shape:\n"
"{\n"
' "label": "handwriting|printed_text|no_text",\n'
' "confidence": number\n'
"}\n"
"Confidence must be between 0 and 1."
)
class HandwritingTranscriptionError(Exception):
    """Raised when handwriting transcription fails for a non-timeout reason.

    Base class of the transcription error hierarchy, so callers can catch it
    to handle every transcription failure mode at once.
    """


class HandwritingTranscriptionTimeoutError(HandwritingTranscriptionError):
    """Raised when handwriting transcription exceeds the configured timeout."""


class HandwritingTranscriptionNotConfiguredError(HandwritingTranscriptionError):
    """Raised when handwriting transcription is disabled or missing credentials."""
@dataclass
class HandwritingTranscription:
    """Represents transcription output and uncertainty markers."""

    # Full transcribed text (whitespace-stripped).
    text: str
    # Contents of [[? ... ?]] uncertainty markers found in the text.
    uncertainties: list[str]
    # Provider identifier (e.g. "openai") and the model that produced the text.
    provider: str
    model: str
@dataclass
class ImageTextClassification:
    """Represents model classification of image text modality for one image."""

    # Canonical label: handwriting, printed_text, no_text, or unknown.
    label: str
    # Model-reported confidence clamped to [0, 1].
    confidence: float
    # Provider identifier and model used for classification.
    provider: str
    model: str
def _extract_uncertainties(text: str) -> list[str]:
"""Extracts uncertainty markers from transcription output."""
matches = re.findall(r"\[\[\?(.*?)\?\]\]", text)
return [match.strip() for match in matches if match.strip()]
def _coerce_json_object(payload: str) -> dict[str, Any]:
"""Parses and extracts a JSON object from raw model output text."""
text = payload.strip()
if not text:
return {}
try:
parsed = json.loads(text)
if isinstance(parsed, dict):
return parsed
except json.JSONDecodeError:
pass
fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, flags=re.DOTALL | re.IGNORECASE)
if fenced:
try:
parsed = json.loads(fenced.group(1))
if isinstance(parsed, dict):
return parsed
except json.JSONDecodeError:
pass
first_brace = text.find("{")
last_brace = text.rfind("}")
if first_brace >= 0 and last_brace > first_brace:
candidate = text[first_brace : last_brace + 1]
try:
parsed = json.loads(candidate)
if isinstance(parsed, dict):
return parsed
except json.JSONDecodeError:
return {}
return {}
def _clamp_probability(value: Any, fallback: float = 0.0) -> float:
"""Clamps confidence-like values to the inclusive [0, 1] range."""
try:
parsed = float(value)
except (TypeError, ValueError):
return fallback
return max(0.0, min(1.0, parsed))
def _normalize_image_text_type(label: str) -> str:
    """Normalizes classifier labels into one supported canonical image text type.

    Args:
        label: Raw label text produced by the classifier model.

    Returns:
        One of the canonical IMAGE_TEXT_TYPE_* constants; unknown as fallback.
    """
    # Canonicalize before matching: lowercase and unify "-"/" " separators to "_".
    normalized = label.strip().lower().replace("-", "_").replace(" ", "_")
    if normalized in {IMAGE_TEXT_TYPE_HANDWRITING, "handwritten", "handwritten_text"}:
        return IMAGE_TEXT_TYPE_HANDWRITING
    if normalized in {IMAGE_TEXT_TYPE_PRINTED, "printed", "typed", "machine_text"}:
        return IMAGE_TEXT_TYPE_PRINTED
    # Bug fix: aliases must be in post-normalization form. The previous entries
    # "no-text" and "no readable text" could never match because "-" and " "
    # were already replaced with "_" above.
    if normalized in {IMAGE_TEXT_TYPE_NO_TEXT, "none", "no_readable_text"}:
        return IMAGE_TEXT_TYPE_NO_TEXT
    return IMAGE_TEXT_TYPE_UNKNOWN
def _normalize_image_bytes(image_data: bytes) -> tuple[bytes, str]:
    """Applies EXIF rotation and scales large images down for efficient transcription."""
    with Image.open(io.BytesIO(image_data)) as source:
        normalized = ImageOps.exif_transpose(source).convert("RGB")
        longest = max(normalized.width, normalized.height)
        if longest > MAX_IMAGE_SIDE:
            # Preserve aspect ratio; never collapse a dimension below 1 pixel.
            factor = MAX_IMAGE_SIDE / longest
            new_size = (
                max(1, int(normalized.width * factor)),
                max(1, int(normalized.height * factor)),
            )
            normalized = normalized.resize(new_size, Image.Resampling.LANCZOS)
        buffer = io.BytesIO()
        normalized.save(buffer, format="JPEG", quality=90, optimize=True)
        return buffer.getvalue(), "image/jpeg"
def _create_client(provider_settings: dict[str, Any]) -> OpenAI:
    """Creates an OpenAI client configured for compatible endpoints and timeouts."""
    configured_key = str(provider_settings.get("openai_api_key", "")).strip()
    return OpenAI(
        # Local OpenAI-compatible servers often require no key; send a stub.
        api_key=configured_key if configured_key else "no-key-required",
        base_url=str(provider_settings["openai_base_url"]),
        timeout=int(provider_settings["openai_timeout_seconds"]),
    )
def _extract_text_from_response(response: Any) -> str:
"""Extracts plain text from responses API output objects."""
output_text = getattr(response, "output_text", None)
if isinstance(output_text, str) and output_text.strip():
return output_text.strip()
output_items = getattr(response, "output", None)
if not isinstance(output_items, list):
return ""
texts: list[str] = []
for item in output_items:
item_data = item.model_dump() if hasattr(item, "model_dump") else item
if not isinstance(item_data, dict):
continue
item_type = item_data.get("type")
if item_type == "output_text":
text = str(item_data.get("text", "")).strip()
if text:
texts.append(text)
if item_type == "message":
for content in item_data.get("content", []) or []:
if not isinstance(content, dict):
continue
if content.get("type") in {"output_text", "text"}:
text = str(content.get("text", "")).strip()
if text:
texts.append(text)
return "\n".join(texts).strip()
def _transcribe_with_responses(client: OpenAI, model: str, prompt: str, image_data_url: str) -> str:
    """Transcribes handwriting using the responses API."""
    user_content = [
        {"type": "input_text", "text": prompt},
        {"type": "input_image", "image_url": image_data_url, "detail": "high"},
    ]
    response = client.responses.create(
        model=model,
        input=[{"role": "user", "content": user_content}],
    )
    return _extract_text_from_response(response)
def _transcribe_with_chat(client: OpenAI, model: str, prompt: str, image_data_url: str) -> str:
    """Transcribes handwriting using chat completions for endpoint compatibility."""
    user_content = [
        {"type": "text", "text": prompt},
        {"type": "image_url", "image_url": {"url": image_data_url, "detail": "high"}},
    ]
    completion = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": user_content}],
    )
    content = completion.choices[0].message.content
    if isinstance(content, str):
        return content.strip()
    if isinstance(content, list):
        # Some endpoints return structured parts; keep only non-empty text.
        fragments = [str(part.get("text", "")).strip() for part in content if isinstance(part, dict)]
        return "\n".join(fragment for fragment in fragments if fragment).strip()
    return ""
def _classify_with_responses(client: OpenAI, model: str, prompt: str, image_data_url: str) -> str:
    """Classifies image text modality using the responses API."""
    request_content = [
        {"type": "input_text", "text": prompt},
        {"type": "input_image", "image_url": image_data_url, "detail": "high"},
    ]
    response = client.responses.create(
        model=model,
        input=[{"role": "user", "content": request_content}],
    )
    return _extract_text_from_response(response)
def _classify_with_chat(client: OpenAI, model: str, prompt: str, image_data_url: str) -> str:
    """Classifies image text modality using chat completions for compatibility."""
    request_content = [
        {"type": "text", "text": prompt},
        {"type": "image_url", "image_url": {"url": image_data_url, "detail": "high"}},
    ]
    completion = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": request_content}],
    )
    content = completion.choices[0].message.content
    if isinstance(content, str):
        return content.strip()
    if isinstance(content, list):
        # Some endpoints return structured parts; keep only non-empty text.
        fragments = [str(part.get("text", "")).strip() for part in content if isinstance(part, dict)]
        return "\n".join(fragment for fragment in fragments if fragment).strip()
    return ""
def _classify_image_text_data_url(image_data_url: str) -> ImageTextClassification:
    """Classifies an image as handwriting, printed text, or no text.

    Tries the responses API first and falls back to chat completions when the
    responses call yields no text or fails with a non-timeout API error.

    Raises:
        HandwritingTranscriptionNotConfiguredError: Transcription is disabled.
        HandwritingTranscriptionTimeoutError: A provider request timed out.
        HandwritingTranscriptionError: Any other provider or parsing failure.
    """
    provider_settings = read_handwriting_provider_settings()
    provider_type = str(provider_settings.get("provider", "openai_compatible")).strip()
    if provider_type != "openai_compatible":
        raise HandwritingTranscriptionError(f"unsupported_provider_type:{provider_type}")
    if not bool(provider_settings.get("enabled", True)):
        raise HandwritingTranscriptionNotConfiguredError("handwriting_transcription_disabled")
    model = str(provider_settings.get("openai_model", "gpt-4.1-mini")).strip() or "gpt-4.1-mini"
    client = _create_client(provider_settings)
    try:
        output_text = _classify_with_responses(
            client=client,
            model=model,
            prompt=IMAGE_TEXT_CLASSIFICATION_PROMPT,
            image_data_url=image_data_url,
        )
        if not output_text:
            # Empty responses-API output: retry via chat completions.
            output_text = _classify_with_chat(
                client=client,
                model=model,
                prompt=IMAGE_TEXT_CLASSIFICATION_PROMPT,
                image_data_url=image_data_url,
            )
    except APITimeoutError as error:
        raise HandwritingTranscriptionTimeoutError("openai_request_timeout") from error
    except (APIConnectionError, APIError):
        # The endpoint may not support the responses API; fall back to chat.
        try:
            output_text = _classify_with_chat(
                client=client,
                model=model,
                prompt=IMAGE_TEXT_CLASSIFICATION_PROMPT,
                image_data_url=image_data_url,
            )
        except APITimeoutError as timeout_error:
            raise HandwritingTranscriptionTimeoutError("openai_request_timeout") from timeout_error
        except Exception as fallback_error:
            raise HandwritingTranscriptionError(str(fallback_error)) from fallback_error
    except Exception as error:
        raise HandwritingTranscriptionError(str(error)) from error
    parsed = _coerce_json_object(output_text)
    if not parsed:
        raise HandwritingTranscriptionError("image_text_classification_parse_failed")
    label = _normalize_image_text_type(str(parsed.get("label", "")))
    confidence = _clamp_probability(parsed.get("confidence", 0.0), fallback=0.0)
    return ImageTextClassification(
        label=label,
        confidence=confidence,
        provider="openai",
        model=model,
    )
def _transcribe_image_data_url(image_data_url: str) -> HandwritingTranscription:
    """Transcribes a handwriting image data URL with configured OpenAI provider settings.

    Mirrors the classification flow: responses API first, chat completions as
    fallback on empty output or non-timeout API errors.

    Raises:
        HandwritingTranscriptionNotConfiguredError: Transcription is disabled.
        HandwritingTranscriptionTimeoutError: A provider request timed out.
        HandwritingTranscriptionError: Any other provider failure.
    """
    provider_settings = read_handwriting_provider_settings()
    provider_type = str(provider_settings.get("provider", "openai_compatible")).strip()
    if provider_type != "openai_compatible":
        raise HandwritingTranscriptionError(f"unsupported_provider_type:{provider_type}")
    if not bool(provider_settings.get("enabled", True)):
        raise HandwritingTranscriptionNotConfiguredError("handwriting_transcription_disabled")
    model = str(provider_settings.get("openai_model", "gpt-4.1-mini")).strip() or "gpt-4.1-mini"
    prompt = str(provider_settings.get("prompt", DEFAULT_OCR_PROMPT)).strip() or DEFAULT_OCR_PROMPT
    client = _create_client(provider_settings)
    try:
        text = _transcribe_with_responses(client=client, model=model, prompt=prompt, image_data_url=image_data_url)
        if not text:
            # Empty responses-API output: retry via chat completions.
            text = _transcribe_with_chat(client=client, model=model, prompt=prompt, image_data_url=image_data_url)
    except APITimeoutError as error:
        raise HandwritingTranscriptionTimeoutError("openai_request_timeout") from error
    except (APIConnectionError, APIError) as error:
        # The endpoint may not support the responses API; fall back to chat.
        try:
            text = _transcribe_with_chat(client=client, model=model, prompt=prompt, image_data_url=image_data_url)
        except APITimeoutError as timeout_error:
            raise HandwritingTranscriptionTimeoutError("openai_request_timeout") from timeout_error
        except Exception as fallback_error:
            raise HandwritingTranscriptionError(str(fallback_error)) from fallback_error
    except Exception as error:
        raise HandwritingTranscriptionError(str(error)) from error
    final_text = text.strip()
    return HandwritingTranscription(
        text=final_text,
        uncertainties=_extract_uncertainties(final_text),
        provider="openai",
        model=model,
    )
def transcribe_handwriting_base64(image_base64: str, mime_type: str = "image/jpeg") -> HandwritingTranscription:
    """Transcribes handwriting from a base64 payload without data URL prefix."""
    # Blank/whitespace MIME types fall back to JPEG.
    effective_mime = mime_type.strip().lower() or "image/jpeg"
    return _transcribe_image_data_url(f"data:{effective_mime};base64,{image_base64}")
def transcribe_handwriting_url(image_url: str) -> HandwritingTranscription:
    """Transcribes handwriting from a direct image URL.

    The URL is forwarded to the provider as-is; nothing is downloaded locally.
    """
    return _transcribe_image_data_url(image_url)
def transcribe_handwriting_bytes(image_data: bytes, mime_type: str = "image/jpeg") -> HandwritingTranscription:
    """Transcribes handwriting from raw image bytes after normalization."""
    prepared_bytes, prepared_mime = _normalize_image_bytes(image_data)
    return transcribe_handwriting_base64(
        base64.b64encode(prepared_bytes).decode("ascii"),
        mime_type=prepared_mime,
    )
def classify_image_text_base64(image_base64: str, mime_type: str = "image/jpeg") -> ImageTextClassification:
    """Classifies image text type from a base64 payload without data URL prefix."""
    # Blank/whitespace MIME types fall back to JPEG.
    effective_mime = mime_type.strip().lower() or "image/jpeg"
    return _classify_image_text_data_url(f"data:{effective_mime};base64,{image_base64}")
def classify_image_text_url(image_url: str) -> ImageTextClassification:
    """Classifies image text type from a direct image URL.

    The URL is forwarded to the provider as-is; nothing is downloaded locally.
    """
    return _classify_image_text_data_url(image_url)
def classify_image_text_bytes(image_data: bytes, mime_type: str = "image/jpeg") -> ImageTextClassification:
    """Classifies image text type from raw image bytes after normalization."""
    prepared_bytes, prepared_mime = _normalize_image_bytes(image_data)
    return classify_image_text_base64(
        base64.b64encode(prepared_bytes).decode("ascii"),
        mime_type=prepared_mime,
    )
def transcribe_handwriting(image: bytes | str, mime_type: str = "image/jpeg") -> HandwritingTranscription:
    """Transcribes handwriting from bytes, base64 text, or URL input.

    Args:
        image: Raw image bytes, a base64 payload, or an http(s) URL.
        mime_type: MIME type applied to bytes/base64 inputs; ignored for URLs.

    Returns:
        The HandwritingTranscription produced by the matching input handler.
    """
    if isinstance(image, bytes):
        return transcribe_handwriting_bytes(image, mime_type=mime_type)
    stripped = image.strip()
    # str.startswith accepts a tuple of prefixes — one call instead of an
    # "or" chain of two startswith calls.
    if stripped.startswith(("http://", "https://")):
        return transcribe_handwriting_url(stripped)
    return transcribe_handwriting_base64(stripped, mime_type=mime_type)

View File

@@ -0,0 +1,435 @@
"""Handwriting-style clustering and style-scoped path composition for image documents."""
import base64
import io
import re
from dataclasses import dataclass
from typing import Any
from PIL import Image, ImageOps
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from app.core.config import get_settings
from app.models.document import Document, DocumentStatus
from app.services.app_settings import (
DEFAULT_HANDWRITING_STYLE_EMBED_MODEL,
read_handwriting_style_settings,
)
from app.services.typesense_index import get_typesense_client
# Cached application settings (provides the Typesense collection name).
settings = get_settings()
IMAGE_TEXT_TYPE_HANDWRITING = "handwriting"
# Suffix appended to the main collection name to form the style collection.
HANDWRITING_STYLE_COLLECTION_SUFFIX = "_handwriting_styles"
HANDWRITING_STYLE_EMBED_MODEL = DEFAULT_HANDWRITING_STYLE_EMBED_MODEL
# Similarity thresholds: the bootstrap threshold is stricter and applies while
# a candidate cluster still has few indexed samples.
HANDWRITING_STYLE_MATCH_MIN_SIMILARITY = 0.86
HANDWRITING_STYLE_BOOTSTRAP_MIN_SIMILARITY = 0.89
HANDWRITING_STYLE_BOOTSTRAP_SAMPLE_SIZE = 3
HANDWRITING_STYLE_NEIGHBOR_LIMIT = 8
HANDWRITING_STYLE_IMAGE_MAX_SIDE = 1024
# Style cluster ids have the form "hw_style_<n>"; the pattern captures <n>.
HANDWRITING_STYLE_ID_PREFIX = "hw_style_"
HANDWRITING_STYLE_ID_PATTERN = re.compile(r"^hw_style_(\d+)$")
@dataclass
class HandwritingStyleNeighbor:
    """Represents one nearest handwriting-style neighbor returned from Typesense."""

    # Id of the neighboring document in the style collection.
    document_id: str
    # Style cluster the neighbor currently belongs to.
    style_cluster_id: str
    # Raw vector distance from Typesense and its derived similarity in [0, 1].
    vector_distance: float
    similarity: float
@dataclass
class HandwritingStyleAssignment:
    """Represents the chosen handwriting-style cluster assignment for one document."""

    # Final cluster id (existing cluster or newly minted).
    style_cluster_id: str
    # True when an existing cluster met the applicable similarity threshold.
    matched_existing: bool
    # Best-neighbor similarity and raw vector distance used for the decision.
    similarity: float
    vector_distance: float
    # Number of neighbors retrieved and the thresholds in effect at decision time.
    compared_neighbors: int
    match_min_similarity: float
    bootstrap_match_min_similarity: float
def _style_collection_name() -> str:
    """Builds the dedicated Typesense collection name used for handwriting-style vectors."""
    base_name = settings.typesense_collection_name
    return base_name + HANDWRITING_STYLE_COLLECTION_SUFFIX
def _style_collection() -> Any:
    """Returns the Typesense collection handle for handwriting-style indexing."""
    # The handle is resolved per call from the shared Typesense client.
    client = get_typesense_client()
    return client.collections[_style_collection_name()]
def _distance_to_similarity(vector_distance: float) -> float:
"""Converts Typesense vector distance into conservative similarity in [0, 1]."""
return max(0.0, min(1.0, 1.0 - (vector_distance / 2.0)))
def _encode_style_image_base64(image_data: bytes, image_max_side: int) -> str:
    """Normalizes and downsizes image bytes and returns a base64-encoded JPEG payload."""
    with Image.open(io.BytesIO(image_data)) as source:
        canvas = ImageOps.exif_transpose(source).convert("RGB")
        longest = max(canvas.width, canvas.height)
        if longest > image_max_side:
            # Preserve aspect ratio; never collapse a dimension below 1 pixel.
            factor = image_max_side / longest
            target_size = (
                max(1, int(canvas.width * factor)),
                max(1, int(canvas.height * factor)),
            )
            canvas = canvas.resize(target_size, Image.Resampling.LANCZOS)
        encoded = io.BytesIO()
        canvas.save(encoded, format="JPEG", quality=86, optimize=True)
        return base64.b64encode(encoded.getvalue()).decode("ascii")
def ensure_handwriting_style_collection() -> None:
    """Creates the handwriting-style Typesense collection when it is not present.

    Also drops and recreates the collection when the configured embedding model
    differs from the one recorded in the existing schema, since vectors from
    different models are not comparable.
    """
    runtime_settings = read_handwriting_style_settings()
    embed_model = str(runtime_settings.get("embed_model", HANDWRITING_STYLE_EMBED_MODEL)).strip() or HANDWRITING_STYLE_EMBED_MODEL
    collection = _style_collection()
    should_recreate_collection = False
    try:
        existing_schema = collection.retrieve()
        if isinstance(existing_schema, dict):
            existing_fields = existing_schema.get("fields", [])
            if isinstance(existing_fields, list):
                # Locate the "embedding" field and compare its configured model.
                for field in existing_fields:
                    if not isinstance(field, dict):
                        continue
                    if str(field.get("name", "")).strip() != "embedding":
                        continue
                    embed_config = field.get("embed", {})
                    model_config = embed_config.get("model_config", {}) if isinstance(embed_config, dict) else {}
                    existing_model = str(model_config.get("model_name", "")).strip()
                    if existing_model and existing_model != embed_model:
                        # A model change invalidates stored vectors; rebuild.
                        should_recreate_collection = True
                        break
        if not should_recreate_collection:
            # Collection already exists with a compatible model; nothing to do.
            return
    except Exception as error:
        message = str(error).lower()
        # Only a missing collection is expected here; re-raise anything else.
        if "404" not in message and "not found" not in message:
            raise
    client = get_typesense_client()
    if should_recreate_collection:
        client.collections[_style_collection_name()].delete()
    schema = {
        "name": _style_collection_name(),
        "fields": [
            {
                "name": "style_cluster_id",
                "type": "string",
                "facet": True,
            },
            {
                "name": "image_text_type",
                "type": "string",
                "facet": True,
            },
            {
                "name": "created_at",
                "type": "int64",
            },
            {
                # Source image payload is embedded server-side; not stored.
                "name": "image",
                "type": "image",
                "store": False,
            },
            {
                # Auto-embedded vector derived from the "image" field.
                "name": "embedding",
                "type": "float[]",
                "embed": {
                    "from": ["image"],
                    "model_config": {
                        "model_name": embed_model,
                    },
                },
            },
        ],
        "default_sorting_field": "created_at",
    }
    client.collections.create(schema)
def _search_style_neighbors(
    image_base64: str,
    limit: int,
    exclude_document_id: str | None = None,
) -> list[HandwritingStyleNeighbor]:
    """Returns nearest handwriting-style neighbors for one encoded image payload.

    Args:
        image_base64: Base64 JPEG produced by _encode_style_image_base64.
        limit: Maximum number of neighbors to return (a minimum of 1 is enforced).
        exclude_document_id: Optional document id to omit from the results.
    """
    ensure_handwriting_style_collection()
    client = get_typesense_client()
    # Only compare against documents already classified as handwriting.
    filter_clauses = [f"image_text_type:={IMAGE_TEXT_TYPE_HANDWRITING}"]
    if exclude_document_id:
        filter_clauses.append(f"id:!={exclude_document_id}")
    search_payload = {
        "q": "*",
        "query_by": "embedding",
        # Typesense embeds the query image server-side and runs k-NN search.
        "vector_query": f"embedding:([], image:{image_base64}, k:{max(1, limit)})",
        "exclude_fields": "embedding,image",
        "per_page": max(1, limit),
        "filter_by": " && ".join(filter_clauses),
    }
    response = client.multi_search.perform(
        {
            "searches": [
                {
                    "collection": _style_collection_name(),
                    **search_payload,
                }
            ]
        },
        {},
    )
    # Defensive unwrapping: tolerate unexpected response shapes from the server.
    results = response.get("results", []) if isinstance(response, dict) else []
    first_result = results[0] if isinstance(results, list) and len(results) > 0 else {}
    hits = first_result.get("hits", []) if isinstance(first_result, dict) else []
    neighbors: list[HandwritingStyleNeighbor] = []
    for hit in hits:
        if not isinstance(hit, dict):
            continue
        document = hit.get("document")
        if not isinstance(document, dict):
            continue
        document_id = str(document.get("id", "")).strip()
        style_cluster_id = str(document.get("style_cluster_id", "")).strip()
        if not document_id or not style_cluster_id:
            continue
        try:
            # Missing/invalid distances are treated as maximal (2.0 -> similarity 0).
            vector_distance = float(hit.get("vector_distance", 2.0))
        except (TypeError, ValueError):
            vector_distance = 2.0
        neighbors.append(
            HandwritingStyleNeighbor(
                document_id=document_id,
                style_cluster_id=style_cluster_id,
                vector_distance=vector_distance,
                similarity=_distance_to_similarity(vector_distance),
            )
        )
        if len(neighbors) >= limit:
            break
    return neighbors
def _next_style_cluster_id(session: Session) -> str:
    """Allocates the next stable handwriting-style folder identifier."""
    assigned_ids = session.execute(
        select(Document.handwriting_style_id).where(Document.handwriting_style_id.is_not(None))
    ).scalars().all()
    highest = 0
    for assigned_id in assigned_ids:
        # Ignore ids that do not follow the hw_style_<n> convention.
        matched = HANDWRITING_STYLE_ID_PATTERN.fullmatch(str(assigned_id).strip())
        if matched:
            highest = max(highest, int(matched.group(1)))
    return f"{HANDWRITING_STYLE_ID_PREFIX}{highest + 1}"
def _style_cluster_sample_size(session: Session, style_cluster_id: str) -> int:
    """Returns the number of indexed documents currently assigned to one style cluster."""
    count_statement = (
        select(func.count())
        .select_from(Document)
        .where(Document.handwriting_style_id == style_cluster_id)
        .where(Document.image_text_type == IMAGE_TEXT_TYPE_HANDWRITING)
    )
    return int(session.execute(count_statement).scalar_one())
def assign_handwriting_style(
    session: Session,
    document: Document,
    image_data: bytes,
) -> HandwritingStyleAssignment:
    """Assigns a document to an existing handwriting-style cluster or creates a new one.

    The match threshold is stricter while the best-matching cluster still has
    fewer than bootstrap_sample_size members, to avoid polluting young
    clusters. The document's image is always (re)indexed into the style
    collection regardless of the match outcome.
    """
    runtime_settings = read_handwriting_style_settings()
    image_max_side = int(runtime_settings.get("image_max_side", HANDWRITING_STYLE_IMAGE_MAX_SIDE))
    neighbor_limit = int(runtime_settings.get("neighbor_limit", HANDWRITING_STYLE_NEIGHBOR_LIMIT))
    match_min_similarity = float(runtime_settings.get("match_min_similarity", HANDWRITING_STYLE_MATCH_MIN_SIMILARITY))
    bootstrap_match_min_similarity = float(
        runtime_settings.get("bootstrap_match_min_similarity", HANDWRITING_STYLE_BOOTSTRAP_MIN_SIMILARITY)
    )
    bootstrap_sample_size = int(runtime_settings.get("bootstrap_sample_size", HANDWRITING_STYLE_BOOTSTRAP_SAMPLE_SIZE))
    image_base64 = _encode_style_image_base64(image_data, image_max_side=image_max_side)
    # Exclude the document itself so a re-run does not match its own vector.
    neighbors = _search_style_neighbors(
        image_base64=image_base64,
        limit=neighbor_limit,
        exclude_document_id=str(document.id),
    )
    best_neighbor = neighbors[0] if neighbors else None
    similarity = best_neighbor.similarity if best_neighbor else 0.0
    vector_distance = best_neighbor.vector_distance if best_neighbor else 2.0
    cluster_sample_size = 0
    if best_neighbor:
        cluster_sample_size = _style_cluster_sample_size(
            session=session,
            style_cluster_id=best_neighbor.style_cluster_id,
        )
    # Young clusters require the higher bootstrap threshold to accept a match.
    required_similarity = (
        bootstrap_match_min_similarity
        if cluster_sample_size < bootstrap_sample_size
        else match_min_similarity
    )
    should_match_existing = (
        best_neighbor is not None and similarity >= required_similarity
    )
    if should_match_existing and best_neighbor:
        style_cluster_id = best_neighbor.style_cluster_id
        matched_existing = True
    else:
        # Keep a previously assigned valid id; otherwise mint a new one.
        existing_style_cluster_id = (document.handwriting_style_id or "").strip()
        if HANDWRITING_STYLE_ID_PATTERN.fullmatch(existing_style_cluster_id):
            style_cluster_id = existing_style_cluster_id
        else:
            style_cluster_id = _next_style_cluster_id(session=session)
        matched_existing = False
    ensure_handwriting_style_collection()
    collection = _style_collection()
    payload = {
        "id": str(document.id),
        "style_cluster_id": style_cluster_id,
        "image_text_type": IMAGE_TEXT_TYPE_HANDWRITING,
        "created_at": int(document.created_at.timestamp()),
        "image": image_base64,
    }
    collection.documents.upsert(payload)
    return HandwritingStyleAssignment(
        style_cluster_id=style_cluster_id,
        matched_existing=matched_existing,
        similarity=similarity,
        vector_distance=vector_distance,
        compared_neighbors=len(neighbors),
        match_min_similarity=match_min_similarity,
        bootstrap_match_min_similarity=bootstrap_match_min_similarity,
    )
def delete_handwriting_style_document(document_id: str) -> None:
    """Deletes one document id from the handwriting-style Typesense collection."""
    try:
        _style_collection().documents[document_id].delete()
    except Exception as error:
        # Missing documents count as already deleted; other errors propagate.
        lowered = str(error).lower()
        if "404" not in lowered and "not found" not in lowered:
            raise
def delete_many_handwriting_style_documents(document_ids: list[str]) -> None:
    """Deletes many document ids from the handwriting-style Typesense collection.

    Deletion is sequential; missing documents are ignored by the per-id helper,
    while any other error propagates and stops the loop.
    """
    for document_id in document_ids:
        delete_handwriting_style_document(document_id)
def apply_handwriting_style_path(style_cluster_id: str | None, path_value: str | None) -> str | None:
    """Composes style-prefixed logical paths while preventing duplicate prefix nesting."""
    if path_value is None:
        return None
    trimmed_path = path_value.strip().strip("/")
    if not trimmed_path:
        return None
    trimmed_style = (style_cluster_id or "").strip().strip("/")
    if not trimmed_style:
        return trimmed_path
    parts = [part for part in trimmed_path.split("/") if part]
    # Drop any leading style-cluster segments left over from earlier prefixing.
    while parts and HANDWRITING_STYLE_ID_PATTERN.fullmatch(parts[0]):
        parts.pop(0)
    # Drop one leading segment that duplicates the target style prefix.
    if parts and parts[0].strip().lower() == trimmed_style.lower():
        parts.pop(0)
    if not parts:
        return trimmed_style
    return f"{trimmed_style}/" + "/".join(parts)
def resolve_handwriting_style_path_prefix(
    session: Session,
    style_cluster_id: str | None,
    *,
    exclude_document_id: str | None = None,
) -> str | None:
    """Picks a stable path prefix for one style cluster from peers' existing root segments."""
    cleaned_style = (style_cluster_id or "").strip()
    if not cleaned_style:
        return None
    query = select(Document.logical_path).where(
        Document.handwriting_style_id == cleaned_style,
        Document.image_text_type == IMAGE_TEXT_TYPE_HANDWRITING,
        Document.status != DocumentStatus.TRASHED,
    )
    if exclude_document_id:
        query = query.where(Document.id != exclude_document_id)
    counts: dict[str, int] = {}
    labels: dict[str, str] = {}
    for logical_path in session.execute(query).scalars().all():
        if not isinstance(logical_path, str):
            continue
        parts = [part.strip() for part in logical_path.split("/") if part.strip()]
        if not parts:
            continue
        root = parts[0]
        key = root.lower()
        # Skip generic inbox roots and segments that are themselves style ids.
        if key == "inbox" or HANDWRITING_STYLE_ID_PATTERN.fullmatch(root):
            continue
        counts[key] = counts.get(key, 0) + 1
        labels.setdefault(key, root)
    if not counts:
        return cleaned_style
    # Highest count wins; ties break alphabetically on the lowered key.
    winner_key = min(counts, key=lambda key: (-counts[key], key))
    return labels.get(winner_key, cleaned_style)

View File

@@ -0,0 +1,227 @@
"""Model runtime utilities for provider-bound LLM task execution."""
from dataclasses import dataclass
from typing import Any
from urllib.parse import urlparse, urlunparse
from openai import APIConnectionError, APIError, APITimeoutError, OpenAI
from app.services.app_settings import read_task_runtime_settings
class ModelTaskError(Exception):
    """Base error raised when a model task request fails for any reason."""
class ModelTaskTimeoutError(ModelTaskError):
    """Raised when a model task request exceeds the provider timeout."""
class ModelTaskDisabledError(ModelTaskError):
    """Raised when a model task is disabled in its settings payload."""
@dataclass
class ModelTaskRuntime:
    """Resolved runtime configuration for one task and provider."""
    task_name: str  # logical task identifier, e.g. "ocr_handwriting"
    provider_id: str  # identifier of the configured provider entry
    provider_type: str  # only "openai_compatible" is accepted downstream
    base_url: str  # normalized endpoint URL ending in /v1
    timeout_seconds: int  # request timeout handed to the SDK client
    api_key: str  # provider key; "no-key-required" placeholder when blank
    model: str  # model name sent with each request
    prompt: str  # system prompt configured for the task
def _normalize_base_url(raw_value: str) -> str:
"""Normalizes provider base URL and appends /v1 for OpenAI-compatible servers."""
trimmed = raw_value.strip().rstrip("/")
if not trimmed:
return "https://api.openai.com/v1"
parsed = urlparse(trimmed)
path = parsed.path or ""
if not path.endswith("/v1"):
path = f"{path}/v1" if path else "/v1"
return urlunparse(parsed._replace(path=path))
def _should_fallback_to_chat(error: Exception) -> bool:
"""Determines whether a responses API failure should fallback to chat completions."""
status_code = getattr(error, "status_code", None)
if isinstance(status_code, int) and status_code in {400, 404, 405, 415, 422, 501}:
return True
message = str(error).lower()
fallback_markers = (
"404",
"not found",
"unknown endpoint",
"unsupported",
"invalid url",
"responses",
)
return any(marker in message for marker in fallback_markers)
def _extract_text_from_response(response: Any) -> str:
"""Extracts plain text from Responses API outputs."""
output_text = getattr(response, "output_text", None)
if isinstance(output_text, str) and output_text.strip():
return output_text.strip()
output_items = getattr(response, "output", None)
if not isinstance(output_items, list):
return ""
chunks: list[str] = []
for item in output_items:
item_data = item.model_dump() if hasattr(item, "model_dump") else item
if not isinstance(item_data, dict):
continue
item_type = item_data.get("type")
if item_type == "output_text":
text = str(item_data.get("text", "")).strip()
if text:
chunks.append(text)
if item_type == "message":
for content in item_data.get("content", []) or []:
if not isinstance(content, dict):
continue
if content.get("type") in {"output_text", "text"}:
text = str(content.get("text", "")).strip()
if text:
chunks.append(text)
return "\n".join(chunks).strip()
def _extract_text_from_chat_response(response: Any) -> str:
"""Extracts text from Chat Completions API outputs."""
message_content = response.choices[0].message.content
if isinstance(message_content, str):
return message_content.strip()
if not isinstance(message_content, list):
return ""
chunks: list[str] = []
for content in message_content:
if not isinstance(content, dict):
continue
text = str(content.get("text", "")).strip()
if text:
chunks.append(text)
return "\n".join(chunks).strip()
def resolve_task_runtime(task_name: str) -> ModelTaskRuntime:
    """Builds the runtime configuration (endpoint, model, prompt) for one enabled task."""
    payload = read_task_runtime_settings(task_name)
    task_config = payload["task"]
    provider_config = payload["provider"]
    if not bool(task_config.get("enabled", True)):
        raise ModelTaskDisabledError(f"task_disabled:{task_name}")
    provider_type = str(provider_config.get("provider_type", "openai_compatible")).strip()
    if provider_type != "openai_compatible":
        raise ModelTaskError(f"unsupported_provider_type:{provider_type}")
    # A blank API key still needs a placeholder for servers that ignore auth.
    api_key = str(provider_config.get("api_key", "")).strip()
    return ModelTaskRuntime(
        task_name=task_name,
        provider_id=str(provider_config.get("id", "")),
        provider_type=provider_type,
        base_url=_normalize_base_url(str(provider_config.get("base_url", "https://api.openai.com/v1"))),
        timeout_seconds=int(provider_config.get("timeout_seconds", 45)),
        api_key=api_key or "no-key-required",
        model=str(task_config.get("model", "")).strip(),
        prompt=str(task_config.get("prompt", "")).strip(),
    )
def _create_client(runtime: ModelTaskRuntime) -> OpenAI:
    """Builds an OpenAI SDK client for OpenAI-compatible provider endpoints."""
    # base_url has already been normalized to end in /v1 by resolve_task_runtime.
    return OpenAI(
        api_key=runtime.api_key,
        base_url=runtime.base_url,
        timeout=runtime.timeout_seconds,
    )
def complete_text_task(task_name: str, user_text: str, prompt_override: str | None = None) -> str:
    """Runs a text-only task against the configured provider and returns plain output text.

    The Responses API is attempted first; when the provider rejects that
    endpoint (or yields no extractable text), the call falls back to the
    Chat Completions API.

    Args:
        task_name: Task identifier used to resolve provider, model, and prompt.
        user_text: User message content sent to the model.
        prompt_override: Optional system prompt replacing the configured one.

    Returns:
        The stripped model output text (may be empty if the fallback yields none).

    Raises:
        ModelTaskDisabledError: From resolve_task_runtime when the task is disabled.
        ModelTaskTimeoutError: When either request times out.
        ModelTaskError: For connection/API failures that are not fallback-eligible.
    """
    runtime = resolve_task_runtime(task_name)
    client = _create_client(runtime)
    # A whitespace-only override falls back to the configured prompt.
    prompt = (prompt_override or runtime.prompt).strip() or runtime.prompt
    try:
        response = client.responses.create(
            model=runtime.model,
            input=[
                {
                    "role": "system",
                    "content": [
                        {
                            "type": "input_text",
                            "text": prompt,
                        }
                    ],
                },
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "input_text",
                            "text": user_text,
                        }
                    ],
                },
            ],
        )
        text = _extract_text_from_response(response)
        if text:
            return text
        # Empty extraction falls through to the chat-completions fallback below.
    except APITimeoutError as error:
        raise ModelTaskTimeoutError(f"task_timeout:{task_name}") from error
    except APIConnectionError as error:
        raise ModelTaskError(f"task_error:{task_name}:{error}") from error
    except APIError as error:
        # Endpoint-style failures (404/unsupported/etc.) trigger the fallback instead.
        if not _should_fallback_to_chat(error):
            raise ModelTaskError(f"task_error:{task_name}:{error}") from error
    except Exception as error:
        if not _should_fallback_to_chat(error):
            raise ModelTaskError(f"task_error:{task_name}:{error}") from error
    # Fallback path: Chat Completions for providers without a Responses endpoint.
    try:
        fallback = client.chat.completions.create(
            model=runtime.model,
            messages=[
                {
                    "role": "system",
                    "content": prompt,
                },
                {
                    "role": "user",
                    "content": user_text,
                },
            ],
        )
        return _extract_text_from_chat_response(fallback)
    except APITimeoutError as error:
        raise ModelTaskTimeoutError(f"task_timeout:{task_name}") from error
    except (APIConnectionError, APIError) as error:
        raise ModelTaskError(f"task_error:{task_name}:{error}") from error
    except Exception as error:
        raise ModelTaskError(f"task_error:{task_name}:{error}") from error

View File

@@ -0,0 +1,192 @@
"""Persistence helpers for writing and querying processing pipeline log events."""
from typing import Any
from uuid import UUID
from sqlalchemy import delete, func, select
from sqlalchemy.orm import Session
from app.models.document import Document
from app.models.processing_log import ProcessingLogEntry
# Maximum stored lengths for each text column; longer values are truncated by _trim.
MAX_STAGE_LENGTH = 64
MAX_EVENT_LENGTH = 256
MAX_LEVEL_LENGTH = 16
MAX_PROVIDER_LENGTH = 128
MAX_MODEL_LENGTH = 256
MAX_DOCUMENT_FILENAME_LENGTH = 512
MAX_PROMPT_LENGTH = 200000
MAX_RESPONSE_LENGTH = 200000
# Cleanup defaults: how many recent document sessions and unbound entries survive.
DEFAULT_KEEP_DOCUMENT_SESSIONS = 2
DEFAULT_KEEP_UNBOUND_ENTRIES = 80
# Session.info key that toggles immediate commit of log entries.
PROCESSING_LOG_AUTOCOMMIT_SESSION_KEY = "processing_log_autocommit"
def _trim(value: str | None, max_length: int) -> str | None:
"""Normalizes and truncates text values for safe log persistence."""
if value is None:
return None
normalized = value.strip()
if not normalized:
return None
if len(normalized) <= max_length:
return normalized
return normalized[: max_length - 3] + "..."
def _safe_payload(payload_json: dict[str, Any] | None) -> dict[str, Any]:
"""Ensures payload values are persisted as dictionaries."""
return payload_json if isinstance(payload_json, dict) else {}
def set_processing_log_autocommit(session: Session, enabled: bool) -> None:
    """Turns immediate per-event commits on or off for the given session."""
    flag = bool(enabled)
    session.info[PROCESSING_LOG_AUTOCOMMIT_SESSION_KEY] = flag
def is_processing_log_autocommit_enabled(session: Session) -> bool:
    """Reports whether log entries are committed immediately on this session."""
    current = session.info.get(PROCESSING_LOG_AUTOCOMMIT_SESSION_KEY, False)
    return bool(current)
def log_processing_event(
    session: Session,
    stage: str,
    event: str,
    *,
    level: str = "info",
    document: Document | None = None,
    document_id: UUID | None = None,
    document_filename: str | None = None,
    provider_id: str | None = None,
    model_name: str | None = None,
    prompt_text: str | None = None,
    response_text: str | None = None,
    payload_json: dict[str, Any] | None = None,
) -> None:
    """Writes one processing log row, optionally tied to a document context."""
    if document is not None:
        # An explicit document object wins over the loose id/filename arguments.
        target_id = document.id
        target_filename = document.original_filename
    else:
        target_id = document_id
        target_filename = document_filename
    session.add(
        ProcessingLogEntry(
            level=_trim(level, MAX_LEVEL_LENGTH) or "info",
            stage=_trim(stage, MAX_STAGE_LENGTH) or "pipeline",
            event=_trim(event, MAX_EVENT_LENGTH) or "event",
            document_id=target_id,
            document_filename=_trim(target_filename, MAX_DOCUMENT_FILENAME_LENGTH),
            provider_id=_trim(provider_id, MAX_PROVIDER_LENGTH),
            model_name=_trim(model_name, MAX_MODEL_LENGTH),
            prompt_text=_trim(prompt_text, MAX_PROMPT_LENGTH),
            response_text=_trim(response_text, MAX_RESPONSE_LENGTH),
            payload_json=_safe_payload(payload_json),
        )
    )
    if is_processing_log_autocommit_enabled(session):
        session.commit()
def count_processing_logs(session: Session, document_id: UUID | None = None) -> int:
    """Returns how many processing log rows exist, optionally for one document."""
    query = select(func.count()).select_from(ProcessingLogEntry)
    if document_id is not None:
        query = query.where(ProcessingLogEntry.document_id == document_id)
    return int(session.execute(query).scalar_one())
def list_processing_logs(
    session: Session,
    *,
    limit: int,
    offset: int,
    document_id: UUID | None = None,
) -> list[ProcessingLogEntry]:
    """Lists processing logs ordered newest-first with an optional document filter.

    Args:
        session: Active database session.
        limit: Maximum number of rows returned.
        offset: Number of rows skipped before collecting results.
        document_id: When given, restricts results to one document.

    Returns:
        Log entries ordered by created_at (then id) descending.
    """
    statement = select(ProcessingLogEntry)
    if document_id is not None:
        statement = statement.where(ProcessingLogEntry.document_id == document_id)
    statement = (
        statement.order_by(ProcessingLogEntry.created_at.desc(), ProcessingLogEntry.id.desc())
        .offset(offset)
        .limit(limit)
    )
    # list(...) guarantees the annotated list type: SQLAlchemy 2.0's
    # ScalarResult.all() returns a Sequence, not necessarily a list.
    return list(session.execute(statement).scalars().all())
def cleanup_processing_logs(
    session: Session,
    *,
    keep_document_sessions: int = DEFAULT_KEEP_DOCUMENT_SESSIONS,
    keep_unbound_entries: int = DEFAULT_KEEP_UNBOUND_ENTRIES,
) -> dict[str, int]:
    """Deletes old log entries while keeping recent document sessions and unbound events.

    Args:
        session: Active database session.
        keep_document_sessions: How many most-recently-active documents keep their logs.
        keep_unbound_entries: How many newest document-less entries are retained.

    Returns:
        Counts of deleted document-bound and unbound entries.
    """
    normalized_keep_sessions = max(0, keep_document_sessions)
    normalized_keep_unbound = max(0, keep_unbound_entries)
    deleted_document_entries = 0
    deleted_unbound_entries = 0
    # Find the N most recently active documents, judged by their newest log timestamp.
    recent_document_rows = session.execute(
        select(
            ProcessingLogEntry.document_id,
            func.max(ProcessingLogEntry.created_at).label("last_seen"),
        )
        .where(ProcessingLogEntry.document_id.is_not(None))
        .group_by(ProcessingLogEntry.document_id)
        .order_by(func.max(ProcessingLogEntry.created_at).desc())
        .limit(normalized_keep_sessions)
    ).all()
    keep_document_ids = [row[0] for row in recent_document_rows if row[0] is not None]
    if keep_document_ids:
        # Drop document-bound entries outside the keep set.
        deleted_document_entries = int(
            session.execute(
                delete(ProcessingLogEntry).where(
                    ProcessingLogEntry.document_id.is_not(None),
                    ProcessingLogEntry.document_id.notin_(keep_document_ids),
                )
            ).rowcount
            or 0
        )
    else:
        # Nothing to keep (e.g. keep_document_sessions == 0): drop all document-bound entries.
        deleted_document_entries = int(
            session.execute(delete(ProcessingLogEntry).where(ProcessingLogEntry.document_id.is_not(None))).rowcount or 0
        )
    # Keep only the newest N entries that carry no document association.
    keep_unbound_rows = session.execute(
        select(ProcessingLogEntry.id)
        .where(ProcessingLogEntry.document_id.is_(None))
        .order_by(ProcessingLogEntry.created_at.desc(), ProcessingLogEntry.id.desc())
        .limit(normalized_keep_unbound)
    ).all()
    keep_unbound_ids = [row[0] for row in keep_unbound_rows]
    if keep_unbound_ids:
        deleted_unbound_entries = int(
            session.execute(
                delete(ProcessingLogEntry).where(
                    ProcessingLogEntry.document_id.is_(None),
                    ProcessingLogEntry.id.notin_(keep_unbound_ids),
                )
            ).rowcount
            or 0
        )
    else:
        deleted_unbound_entries = int(
            session.execute(delete(ProcessingLogEntry).where(ProcessingLogEntry.document_id.is_(None))).rowcount or 0
        )
    return {
        "deleted_document_entries": deleted_document_entries,
        "deleted_unbound_entries": deleted_unbound_entries,
    }
def clear_processing_logs(session: Session) -> dict[str, int]:
    """Removes every processing log row and reports how many were deleted."""
    result = session.execute(delete(ProcessingLogEntry))
    return {"deleted_entries": int(result.rowcount or 0)}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,59 @@
"""File storage utilities for persistence, retrieval, and checksum calculation."""
import hashlib
import uuid
from datetime import UTC, datetime
from pathlib import Path
from app.core.config import get_settings
# Resolved application settings; storage_root anchors every path built below.
settings = get_settings()
def ensure_storage() -> None:
    """Creates the required storage subdirectories when they are missing."""
    for subdirectory in ("originals", "derived/previews", "tmp"):
        (settings.storage_root / subdirectory).mkdir(parents=True, exist_ok=True)
def compute_sha256(data: bytes) -> str:
    """Returns the hex-encoded SHA-256 digest of the given bytes."""
    digest = hashlib.sha256(data)
    return digest.hexdigest()
def store_bytes(filename: str, data: bytes) -> str:
    """Writes content under a unique dated path and returns its storage-relative location."""
    date_folder = datetime.now(UTC).strftime("%Y/%m/%d")
    # Only the (lowercased) extension of the caller's filename is preserved.
    extension = Path(filename).suffix.lower()
    destination_dir = settings.storage_root / "originals" / date_folder
    destination_dir.mkdir(parents=True, exist_ok=True)
    destination = destination_dir / f"{uuid.uuid4()}{extension}"
    destination.write_bytes(data)
    return str(destination.relative_to(settings.storage_root))
def read_bytes(relative_path: str) -> bytes:
    """Loads the full contents of a file addressed relative to the storage root."""
    full_path = settings.storage_root / relative_path
    return full_path.read_bytes()
def absolute_path(relative_path: str) -> Path:
    """Resolves a storage-relative location to an absolute filesystem path."""
    root = settings.storage_root
    return root / relative_path
def write_preview(document_id: str, data: bytes, suffix: str = ".jpg") -> str:
    """Stores preview bytes for a document and returns the storage-relative preview path."""
    previews_dir = settings.storage_root / "derived" / "previews"
    previews_dir.mkdir(parents=True, exist_ok=True)
    preview_path = previews_dir / f"{document_id}{suffix}"
    preview_path.write_bytes(data)
    return str(preview_path.relative_to(settings.storage_root))

View File

@@ -0,0 +1,257 @@
"""Typesense indexing and semantic-neighbor retrieval for document routing."""
from dataclasses import dataclass
from typing import Any
import typesense
from app.core.config import get_settings
from app.models.document import Document, DocumentStatus
# Resolved application settings for Typesense connection and collection naming.
settings = get_settings()
# Query strings longer than this are truncated before searching; the server
# rejects overly long queries (see the length-error fallback in query_similar_documents).
MAX_TYPESENSE_QUERY_CHARS = 600
@dataclass
class SimilarDocument:
    """Represents one nearest-neighbor document returned by Typesense semantic search."""
    document_id: str  # indexed document id (stringified UUID)
    document_name: str  # original filename stored at index time
    summary_text: str  # indexed summary used for the embedding
    logical_path: str  # routing path assigned to the neighbor
    tags: list[str]  # stripped, non-empty tag values
    vector_distance: float  # distance reported by Typesense; 2.0 when unparsable
def _build_client() -> typesense.Client:
    """Builds a Typesense API client using configured host and credentials."""
    # Single-node setup; timeout and retry counts come from application settings.
    return typesense.Client(
        {
            "nodes": [
                {
                    "host": settings.typesense_host,
                    "port": str(settings.typesense_port),
                    "protocol": settings.typesense_protocol,
                }
            ],
            "api_key": settings.typesense_api_key,
            "connection_timeout_seconds": settings.typesense_timeout_seconds,
            "num_retries": settings.typesense_num_retries,
        }
    )
# Process-wide cached client instance; created lazily by get_typesense_client().
_client: typesense.Client | None = None
def get_typesense_client() -> typesense.Client:
    """Returns the shared Typesense client, creating it on first use."""
    global _client
    client = _client
    if client is None:
        client = _build_client()
        _client = client
    return client
def _collection() -> Any:
    """Returns the handle for the configured Typesense document collection."""
    return get_typesense_client().collections[settings.typesense_collection_name]
def ensure_typesense_collection() -> None:
    """Creates the document semantic collection when it does not already exist."""
    collection = _collection()
    try:
        # Probe for the collection; success means there is nothing to create.
        collection.retrieve()
        return
    except Exception as error:
        # Only 404/not-found means "missing"; any other failure propagates.
        message = str(error).lower()
        if "404" not in message and "not found" not in message:
            raise
    schema = {
        "name": settings.typesense_collection_name,
        "fields": [
            {
                "name": "document_name",
                "type": "string",
            },
            {
                "name": "summary_text",
                "type": "string",
            },
            {
                "name": "logical_path",
                "type": "string",
                "facet": True,
            },
            {
                "name": "tags",
                "type": "string[]",
                "facet": True,
            },
            {
                "name": "status",
                "type": "string",
                "facet": True,
            },
            {
                "name": "mime_type",
                "type": "string",
                "optional": True,
                "facet": True,
            },
            {
                "name": "extension",
                "type": "string",
                "optional": True,
                "facet": True,
            },
            {
                "name": "created_at",
                "type": "int64",
            },
            {
                "name": "has_labels",
                "type": "bool",
                "facet": True,
            },
            {
                # Server-side auto-embedding built from name + summary using a
                # Typesense-bundled model; prefixes follow the e5 convention.
                "name": "embedding",
                "type": "float[]",
                "embed": {
                    "from": [
                        "document_name",
                        "summary_text",
                    ],
                    "model_config": {
                        "model_name": "ts/e5-small-v2",
                        "indexing_prefix": "passage:",
                        "query_prefix": "query:",
                    },
                },
            },
        ],
        "default_sorting_field": "created_at",
    }
    client = get_typesense_client()
    client.collections.create(schema)
def _has_labels(document: Document) -> bool:
"""Determines whether a document has usable human-assigned routing metadata."""
if document.logical_path.strip() and document.logical_path.strip().lower() != "inbox":
return True
return len([tag for tag in document.tags if tag.strip()]) > 0
def upsert_document_index(document: Document, summary_text: str) -> None:
    """Upserts one document into Typesense for semantic retrieval and routing examples."""
    ensure_typesense_collection()
    # Trashed documents never qualify as labeled routing examples.
    labeled = _has_labels(document) and document.status != DocumentStatus.TRASHED
    record = {
        "id": str(document.id),
        "document_name": document.original_filename,
        "summary_text": summary_text[:50000],
        "logical_path": document.logical_path,
        "tags": [tag for tag in document.tags if tag.strip()][:50],
        "status": document.status.value,
        "mime_type": document.mime_type,
        "extension": document.extension,
        "created_at": int(document.created_at.timestamp()),
        "has_labels": labeled,
    }
    _collection().documents.upsert(record)
def delete_document_index(document_id: str) -> None:
    """Deletes one document from Typesense by identifier, tolerating missing entries."""
    try:
        _collection().documents[document_id].delete()
    except Exception as error:
        # A missing document counts as already deleted; anything else propagates.
        lowered = str(error).lower()
        if "404" not in lowered and "not found" not in lowered:
            raise
def delete_many_documents_index(document_ids: list[str]) -> None:
    """Deletes many documents from Typesense, one identifier at a time."""
    for identifier in document_ids:
        delete_document_index(identifier)
def query_similar_documents(summary_text: str, limit: int, exclude_document_id: str | None = None) -> list[SimilarDocument]:
    """Returns semantic nearest neighbors among labeled non-trashed indexed documents.

    Args:
        summary_text: Text whose embedding drives the nearest-neighbor search.
        limit: Maximum number of neighbors to return.
        exclude_document_id: Optional id to omit (typically the query document itself).
    """
    ensure_typesense_collection()
    collection = _collection()
    # Collapse whitespace and truncate to stay under the Typesense query limit.
    normalized_query = " ".join(summary_text.strip().split())
    query_text = normalized_query[:MAX_TYPESENSE_QUERY_CHARS] if normalized_query else "document"
    search_payload = {
        "q": query_text,
        "query_by": "embedding",
        "vector_query": f"embedding:([], k:{max(1, limit)})",
        "exclude_fields": "embedding",
        "per_page": max(1, limit),
        # Only human-labeled, non-trashed documents qualify as routing examples.
        "filter_by": "has_labels:=true && status:!=trashed",
    }
    try:
        response = collection.documents.search(search_payload)
    except Exception as error:
        message = str(error).lower()
        if "query string exceeds max allowed length" not in message:
            raise
        # Server still rejected the query length: retry with a neutral query term.
        fallback_payload = dict(search_payload)
        fallback_payload["q"] = "document"
        response = collection.documents.search(fallback_payload)
    hits = response.get("hits", []) if isinstance(response, dict) else []
    neighbors: list[SimilarDocument] = []
    for hit in hits:
        if not isinstance(hit, dict):
            continue
        document = hit.get("document", {})
        if not isinstance(document, dict):
            continue
        document_id = str(document.get("id", "")).strip()
        if not document_id:
            continue
        if exclude_document_id and document_id == exclude_document_id:
            continue
        raw_tags = document.get("tags", [])
        tags = [str(tag).strip() for tag in raw_tags if str(tag).strip()] if isinstance(raw_tags, list) else []
        try:
            distance = float(hit.get("vector_distance", 2.0))
        except (TypeError, ValueError):
            # Default to 2.0 when the hit carries no parsable distance.
            distance = 2.0
        neighbors.append(
            SimilarDocument(
                document_id=document_id,
                document_name=str(document.get("document_name", "")).strip(),
                summary_text=str(document.get("summary_text", "")).strip(),
                logical_path=str(document.get("logical_path", "")).strip(),
                tags=tags,
                vector_distance=distance,
            )
        )
        if len(neighbors) >= limit:
            break
    return neighbors