Initial commit
This commit is contained in:
1
backend/app/services/__init__.py
Normal file
1
backend/app/services/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Domain services package for storage, extraction, and classification logic."""
|
||||
885
backend/app/services/app_settings.py
Normal file
885
backend/app/services/app_settings.py
Normal file
@@ -0,0 +1,885 @@
|
||||
"""Persistent single-user application settings service backed by host-mounted storage."""
|
||||
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from app.core.config import get_settings
|
||||
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
# Task-binding keys used throughout this module and by callers.
TASK_OCR_HANDWRITING = "ocr_handwriting"
TASK_SUMMARY_GENERATION = "summary_generation"
TASK_ROUTING_CLASSIFICATION = "routing_classification"

# Top-level settings keys for style clustering and predefined catalogs.
HANDWRITING_STYLE_SETTINGS_KEY = "handwriting_style_clustering"
PREDEFINED_PATHS_SETTINGS_KEY = "predefined_paths"
PREDEFINED_TAGS_SETTINGS_KEY = "predefined_tags"
# Embedding model id used for handwriting-style similarity.
DEFAULT_HANDWRITING_STYLE_EMBED_MODEL = "ts/clip-vit-b-p32"


# Default system prompt for the handwriting OCR task.
# NOTE: the transcription output language is intentionally German.
DEFAULT_OCR_PROMPT = (
    "You are an expert at reading messy handwritten notes, including hard-to-read writing.\n"
    "Task: transcribe the handwriting as exactly as possible.\n\n"
    "Rules:\n"
    "- Output ONLY the transcription in German, no commentary.\n"
    "- Preserve original line breaks where they clearly exist.\n"
    "- Do NOT translate or correct grammar or spelling.\n"
    "- If a word or character is unclear, wrap your best guess in [[? ... ?]].\n"
    "- If something is unreadable, write [[?unleserlich?]] in its place."
)

# Default system prompt for the document-summary task.
DEFAULT_SUMMARY_PROMPT = (
    "You summarize documents for indexing and routing.\n"
    "Return concise markdown with key entities, purpose, and document category hints.\n"
    "Do not invent facts and do not include any explanation outside the summary."
)

# Default system prompt for routing classification; the model must reply
# with JSON in exactly the shape spelled out below.
DEFAULT_ROUTING_PROMPT = (
    "You classify one document into an existing logical path and tags.\n"
    "Prefer existing paths and tags when possible.\n"
    "If the evidence is weak, keep chosen_path as null and use suggestions instead.\n"
    "Return JSON only with this exact shape:\n"
    "{\n"
    " \"chosen_path\": string | null,\n"
    " \"chosen_tags\": string[],\n"
    " \"suggested_new_paths\": string[],\n"
    " \"suggested_new_tags\": string[],\n"
    " \"confidence\": number\n"
    "}\n"
    "Confidence must be between 0 and 1."
)
|
||||
|
||||
|
||||
def _default_settings() -> dict[str, Any]:
    """Builds default settings including providers and model task bindings."""

    return {
        # Destination path and tags pre-filled for new uploads.
        "upload_defaults": {
            "logical_path": "Inbox",
            "tags": [],
        },
        # Dashboard UI preferences.
        "display": {
            "cards_per_page": 12,
            "log_typing_animation_enabled": True,
        },
        PREDEFINED_PATHS_SETTINGS_KEY: [],
        PREDEFINED_TAGS_SETTINGS_KEY: [],
        # Handwriting-style clustering knobs; similarity values are in [0, 1].
        HANDWRITING_STYLE_SETTINGS_KEY: {
            "enabled": True,
            "embed_model": DEFAULT_HANDWRITING_STYLE_EMBED_MODEL,
            "neighbor_limit": 8,
            "match_min_similarity": 0.86,
            "bootstrap_match_min_similarity": 0.89,
            "bootstrap_sample_size": 3,
            "image_max_side": 1024,
        },
        # One built-in OpenAI-compatible provider seeded from environment config.
        "providers": [
            {
                "id": "openai-default",
                "label": "OpenAI Default",
                "provider_type": "openai_compatible",
                "base_url": settings.default_openai_base_url,
                "timeout_seconds": settings.default_openai_timeout_seconds,
                "api_key": settings.default_openai_api_key,
            }
        ],
        # Per-task model bindings; "provider_id" must reference an entry above.
        "tasks": {
            TASK_OCR_HANDWRITING: {
                "enabled": settings.default_openai_handwriting_enabled,
                "provider_id": "openai-default",
                "model": settings.default_openai_model,
                "prompt": DEFAULT_OCR_PROMPT,
            },
            TASK_SUMMARY_GENERATION: {
                "enabled": True,
                "provider_id": "openai-default",
                "model": settings.default_summary_model,
                "prompt": DEFAULT_SUMMARY_PROMPT,
                "max_input_tokens": 8000,
            },
            TASK_ROUTING_CLASSIFICATION: {
                "enabled": True,
                "provider_id": "openai-default",
                "model": settings.default_routing_model,
                "prompt": DEFAULT_ROUTING_PROMPT,
                "neighbor_count": 8,
                "neighbor_min_similarity": 0.84,
                "auto_apply_confidence_threshold": 0.78,
                "auto_apply_neighbor_similarity_threshold": 0.55,
                "neighbor_path_override_enabled": True,
                "neighbor_path_override_min_similarity": 0.86,
                "neighbor_path_override_min_gap": 0.04,
                "neighbor_path_override_max_confidence": 0.9,
            },
        },
    }
|
||||
|
||||
|
||||
def _settings_path() -> Path:
    """Absolute location of the persisted settings file under the storage root."""
    storage_root = settings.storage_root
    return storage_root / "settings.json"
|
||||
|
||||
|
||||
def _clamp_timeout(value: int) -> int:
|
||||
"""Clamps timeout values to a safe and practical range."""
|
||||
|
||||
return max(5, min(180, value))
|
||||
|
||||
|
||||
def _clamp_input_tokens(value: int) -> int:
|
||||
"""Clamps per-request summary input token budget values to practical bounds."""
|
||||
|
||||
return max(512, min(64000, value))
|
||||
|
||||
|
||||
def _clamp_neighbor_count(value: int) -> int:
|
||||
"""Clamps nearest-neighbor lookup count for routing classification."""
|
||||
|
||||
return max(1, min(40, value))
|
||||
|
||||
|
||||
def _clamp_cards_per_page(value: int) -> int:
|
||||
"""Clamps dashboard cards-per-page display setting to practical bounds."""
|
||||
|
||||
return max(1, min(200, value))
|
||||
|
||||
|
||||
def _clamp_predefined_entries_limit(value: int) -> int:
|
||||
"""Clamps maximum count for predefined tag/path catalog entries."""
|
||||
|
||||
return max(1, min(2000, value))
|
||||
|
||||
|
||||
def _clamp_handwriting_style_neighbor_limit(value: int) -> int:
|
||||
"""Clamps handwriting-style nearest-neighbor count used for style matching."""
|
||||
|
||||
return max(1, min(32, value))
|
||||
|
||||
|
||||
def _clamp_handwriting_style_sample_size(value: int) -> int:
|
||||
"""Clamps handwriting-style bootstrap sample size used for stricter matching."""
|
||||
|
||||
return max(1, min(30, value))
|
||||
|
||||
|
||||
def _clamp_handwriting_style_image_max_side(value: int) -> int:
|
||||
"""Clamps handwriting-style image normalization max-side pixel size."""
|
||||
|
||||
return max(256, min(4096, value))
|
||||
|
||||
|
||||
def _clamp_probability(value: float, fallback: float) -> float:
|
||||
"""Clamps probability-like numbers to the range [0, 1]."""
|
||||
|
||||
try:
|
||||
parsed = float(value)
|
||||
except (TypeError, ValueError):
|
||||
return fallback
|
||||
return max(0.0, min(1.0, parsed))
|
||||
|
||||
|
||||
def _safe_int(value: Any, fallback: int) -> int:
|
||||
"""Safely converts arbitrary values to integers with fallback handling."""
|
||||
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return fallback
|
||||
|
||||
|
||||
def _normalize_provider_id(value: str | None, fallback: str) -> str:
|
||||
"""Normalizes provider identifiers into stable lowercase slug values."""
|
||||
|
||||
candidate = (value or "").strip().lower()
|
||||
candidate = re.sub(r"[^a-z0-9_-]+", "-", candidate).strip("-")
|
||||
return candidate or fallback
|
||||
|
||||
|
||||
def _mask_api_key(value: str) -> str:
|
||||
"""Masks a secret API key while retaining enough characters for identification."""
|
||||
|
||||
if not value:
|
||||
return ""
|
||||
if len(value) <= 6:
|
||||
return "*" * len(value)
|
||||
return f"{value[:4]}...{value[-2:]}"
|
||||
|
||||
|
||||
def _normalize_provider(
    payload: dict[str, Any],
    fallback_id: str,
    fallback_values: dict[str, Any],
) -> dict[str, Any]:
    """Coerce one provider payload into the canonical provider shape.

    Missing fields fall back first to ``fallback_values`` and then to the
    built-in default provider; the timeout is clamped to a safe range.
    """
    defaults = _default_settings()["providers"][0]

    def _field(name: str) -> Any:
        # Lookup order: incoming payload -> fallback values -> built-in defaults.
        return payload.get(name, fallback_values.get(name, defaults[name]))

    provider_id = _normalize_provider_id(str(payload.get("id", fallback_id)), fallback_id)

    # Only the OpenAI-compatible provider type is supported today.
    provider_type = str(_field("provider_type")).strip()
    if provider_type != "openai_compatible":
        provider_type = "openai_compatible"

    raw_key = _field("api_key")
    api_key = "" if raw_key is None else str(raw_key).strip()

    label = str(payload.get("label", fallback_values.get("label", provider_id))).strip() or provider_id
    base_url = str(_field("base_url")).strip() or defaults["base_url"]
    timeout_seconds = _clamp_timeout(
        _safe_int(_field("timeout_seconds"), defaults["timeout_seconds"])
    )

    return {
        "id": provider_id,
        "label": label,
        "provider_type": provider_type,
        "base_url": base_url,
        "timeout_seconds": timeout_seconds,
        "api_key": api_key,
    }
|
||||
|
||||
|
||||
def _normalize_ocr_task(payload: dict[str, Any], provider_ids: list[str]) -> dict[str, Any]:
    """Coerce OCR task settings, forcing the provider reference to a known id."""
    defaults = _default_settings()["tasks"][TASK_OCR_HANDWRITING]

    provider_id = str(payload.get("provider_id", defaults["provider_id"])).strip()
    if provider_id not in provider_ids:
        # Unknown references collapse onto the first configured provider.
        provider_id = provider_ids[0]

    model = str(payload.get("model", defaults["model"])).strip() or defaults["model"]
    prompt = str(payload.get("prompt", defaults["prompt"])).strip() or defaults["prompt"]
    return {
        "enabled": bool(payload.get("enabled", defaults["enabled"])),
        "provider_id": provider_id,
        "model": model,
        "prompt": prompt,
    }
|
||||
|
||||
|
||||
def _normalize_summary_task(payload: dict[str, Any], provider_ids: list[str]) -> dict[str, Any]:
    """Coerce summary task settings, migrating the legacy character budget."""
    defaults = _default_settings()["tasks"][TASK_SUMMARY_GENERATION]

    provider_id = str(payload.get("provider_id", defaults["provider_id"])).strip()
    if provider_id not in provider_ids:
        provider_id = provider_ids[0]

    # Legacy payloads carried "max_source_chars"; approximate four chars per token.
    max_tokens = payload.get("max_input_tokens")
    if max_tokens is None:
        legacy_chars = _safe_int(payload.get("max_source_chars", 0), 0)
        if legacy_chars > 0:
            max_tokens = max(512, legacy_chars // 4)
        else:
            max_tokens = defaults["max_input_tokens"]

    model = str(payload.get("model", defaults["model"])).strip() or defaults["model"]
    prompt = str(payload.get("prompt", defaults["prompt"])).strip() or defaults["prompt"]
    return {
        "enabled": bool(payload.get("enabled", defaults["enabled"])),
        "provider_id": provider_id,
        "model": model,
        "prompt": prompt,
        "max_input_tokens": _clamp_input_tokens(_safe_int(max_tokens, defaults["max_input_tokens"])),
    }
|
||||
|
||||
|
||||
def _normalize_routing_task(payload: dict[str, Any], provider_ids: list[str]) -> dict[str, Any]:
    """Coerce routing task settings, forcing the provider reference to a known id.

    All similarity/confidence knobs are clamped to [0, 1]; the neighbor count
    is clamped to its supported range.
    """
    defaults = _default_settings()["tasks"][TASK_ROUTING_CLASSIFICATION]

    provider_id = str(payload.get("provider_id", defaults["provider_id"])).strip()
    if provider_id not in provider_ids:
        provider_id = provider_ids[0]

    def _prob(key: str) -> float:
        # Probability-like knob: payload value if present, clamped, else default.
        return _clamp_probability(payload.get(key, defaults[key]), defaults[key])

    return {
        "enabled": bool(payload.get("enabled", defaults["enabled"])),
        "provider_id": provider_id,
        "model": str(payload.get("model", defaults["model"])).strip() or defaults["model"],
        "prompt": str(payload.get("prompt", defaults["prompt"])).strip() or defaults["prompt"],
        "neighbor_count": _clamp_neighbor_count(
            _safe_int(payload.get("neighbor_count", defaults["neighbor_count"]), defaults["neighbor_count"])
        ),
        "neighbor_min_similarity": _prob("neighbor_min_similarity"),
        "auto_apply_confidence_threshold": _prob("auto_apply_confidence_threshold"),
        "auto_apply_neighbor_similarity_threshold": _prob("auto_apply_neighbor_similarity_threshold"),
        "neighbor_path_override_enabled": bool(
            payload.get("neighbor_path_override_enabled", defaults["neighbor_path_override_enabled"])
        ),
        "neighbor_path_override_min_similarity": _prob("neighbor_path_override_min_similarity"),
        "neighbor_path_override_min_gap": _prob("neighbor_path_override_min_gap"),
        "neighbor_path_override_max_confidence": _prob("neighbor_path_override_max_confidence"),
    }
|
||||
|
||||
|
||||
def _normalize_tasks(payload: dict[str, Any], provider_ids: list[str]) -> dict[str, Any]:
    """Normalize the per-task settings map (OCR, summary, routing)."""
    source = payload if isinstance(payload, dict) else {}
    normalizers = {
        TASK_OCR_HANDWRITING: _normalize_ocr_task,
        TASK_SUMMARY_GENERATION: _normalize_summary_task,
        TASK_ROUTING_CLASSIFICATION: _normalize_routing_task,
    }
    return {name: normalize(source.get(name, {}), provider_ids) for name, normalize in normalizers.items()}
|
||||
|
||||
|
||||
def _normalize_upload_defaults(payload: dict[str, Any], defaults: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Normalizes upload default destination path and tags."""
|
||||
|
||||
if not isinstance(payload, dict):
|
||||
payload = {}
|
||||
|
||||
default_path = str(defaults.get("logical_path", "Inbox")).strip() or "Inbox"
|
||||
raw_path = str(payload.get("logical_path", default_path)).strip()
|
||||
logical_path = raw_path or default_path
|
||||
|
||||
raw_tags = payload.get("tags", defaults.get("tags", []))
|
||||
tags: list[str] = []
|
||||
seen_lowered: set[str] = set()
|
||||
if isinstance(raw_tags, list):
|
||||
for raw_tag in raw_tags:
|
||||
normalized = str(raw_tag).strip()
|
||||
if not normalized:
|
||||
continue
|
||||
lowered = normalized.lower()
|
||||
if lowered in seen_lowered:
|
||||
continue
|
||||
seen_lowered.add(lowered)
|
||||
tags.append(normalized)
|
||||
if len(tags) >= 50:
|
||||
break
|
||||
|
||||
return {
|
||||
"logical_path": logical_path,
|
||||
"tags": tags,
|
||||
}
|
||||
|
||||
|
||||
def _normalize_display_settings(payload: dict[str, Any], defaults: dict[str, Any]) -> dict[str, Any]:
    """Normalize dashboard display preferences with clamped pagination."""
    source = payload if isinstance(payload, dict) else {}

    fallback_cards = _safe_int(defaults.get("cards_per_page", 12), 12)
    typing_default = defaults.get("log_typing_animation_enabled", True)

    return {
        "cards_per_page": _clamp_cards_per_page(
            _safe_int(source.get("cards_per_page", fallback_cards), fallback_cards)
        ),
        "log_typing_animation_enabled": bool(source.get("log_typing_animation_enabled", typing_default)),
    }
|
||||
|
||||
|
||||
def _normalize_predefined_paths(
    payload: Any,
    existing_items: list[dict[str, Any]] | None = None,
) -> list[dict[str, Any]]:
    """Normalize predefined logical-path entries.

    ``global_shared`` is a one-way flag: an entry that was already shared in
    ``existing_items`` stays shared even if the incoming payload says otherwise.
    """
    previous: dict[str, dict[str, Any]] = {}
    if isinstance(existing_items, list):
        for entry in existing_items:
            if not isinstance(entry, dict):
                continue
            path_value = str(entry.get("value", "")).strip().strip("/")
            if not path_value:
                continue
            previous[path_value.lower()] = {
                "value": path_value,
                "global_shared": bool(entry.get("global_shared", False)),
            }

    if not isinstance(payload, list):
        # No new list supplied: keep the already-known entries untouched.
        return list(previous.values())

    limit = _clamp_predefined_entries_limit(len(payload))
    result: list[dict[str, Any]] = []
    taken: set[str] = set()
    for entry in payload:
        if not isinstance(entry, dict):
            continue
        path_value = str(entry.get("value", "")).strip().strip("/")
        if not path_value:
            continue
        key = path_value.lower()
        if key in taken:
            continue
        taken.add(key)
        known = previous.get(key)
        already_shared = bool(known.get("global_shared", False)) if known else False
        result.append(
            {
                "value": path_value,
                "global_shared": already_shared or bool(entry.get("global_shared", False)),
            }
        )
        if len(result) >= limit:
            break
    return result
|
||||
|
||||
|
||||
def _normalize_predefined_tags(
    payload: Any,
    existing_items: list[dict[str, Any]] | None = None,
) -> list[dict[str, Any]]:
    """Normalize predefined tag entries.

    ``global_shared`` is a one-way flag: an entry that was already shared in
    ``existing_items`` stays shared even if the incoming payload says otherwise.
    """
    previous: dict[str, dict[str, Any]] = {}
    if isinstance(existing_items, list):
        for entry in existing_items:
            if not isinstance(entry, dict):
                continue
            tag_value = str(entry.get("value", "")).strip()
            if not tag_value:
                continue
            previous[tag_value.lower()] = {
                "value": tag_value,
                "global_shared": bool(entry.get("global_shared", False)),
            }

    if not isinstance(payload, list):
        # No new list supplied: keep the already-known entries untouched.
        return list(previous.values())

    limit = _clamp_predefined_entries_limit(len(payload))
    result: list[dict[str, Any]] = []
    taken: set[str] = set()
    for entry in payload:
        if not isinstance(entry, dict):
            continue
        tag_value = str(entry.get("value", "")).strip()
        if not tag_value:
            continue
        key = tag_value.lower()
        if key in taken:
            continue
        taken.add(key)
        known = previous.get(key)
        already_shared = bool(known.get("global_shared", False)) if known else False
        result.append(
            {
                "value": tag_value,
                "global_shared": already_shared or bool(entry.get("global_shared", False)),
            }
        )
        if len(result) >= limit:
            break
    return result
|
||||
|
||||
|
||||
def _normalize_handwriting_style_settings(payload: dict[str, Any], defaults: dict[str, Any]) -> dict[str, Any]:
    """Normalize handwriting-style clustering settings exposed in the settings UI."""
    source = payload if isinstance(payload, dict) else {}

    # Resolve defaults first so malformed ``defaults`` dicts still yield sane values.
    enabled_default = bool(defaults.get("enabled", True))
    model_default = str(defaults.get("embed_model", DEFAULT_HANDWRITING_STYLE_EMBED_MODEL)).strip()
    neighbors_default = _safe_int(defaults.get("neighbor_limit", 8), 8)
    match_default = _clamp_probability(defaults.get("match_min_similarity", 0.86), 0.86)
    bootstrap_match_default = _clamp_probability(defaults.get("bootstrap_match_min_similarity", 0.89), 0.89)
    sample_default = _safe_int(defaults.get("bootstrap_sample_size", 3), 3)
    side_default = _safe_int(defaults.get("image_max_side", 1024), 1024)

    return {
        "enabled": bool(source.get("enabled", enabled_default)),
        "embed_model": str(source.get("embed_model", model_default)).strip() or model_default,
        "neighbor_limit": _clamp_handwriting_style_neighbor_limit(
            _safe_int(source.get("neighbor_limit", neighbors_default), neighbors_default)
        ),
        "match_min_similarity": _clamp_probability(
            source.get("match_min_similarity", match_default),
            match_default,
        ),
        "bootstrap_match_min_similarity": _clamp_probability(
            source.get("bootstrap_match_min_similarity", bootstrap_match_default),
            bootstrap_match_default,
        ),
        "bootstrap_sample_size": _clamp_handwriting_style_sample_size(
            _safe_int(source.get("bootstrap_sample_size", sample_default), sample_default)
        ),
        "image_max_side": _clamp_handwriting_style_image_max_side(
            _safe_int(source.get("image_max_side", side_default), side_default)
        ),
    }
|
||||
|
||||
|
||||
def _sanitize_settings(payload: dict[str, Any]) -> dict[str, Any]:
    """Sanitizes all persisted settings into a stable normalized structure."""

    if not isinstance(payload, dict):
        payload = {}

    defaults = _default_settings()

    providers_payload = payload.get("providers")
    normalized_providers: list[dict[str, Any]] = []
    seen_provider_ids: set[str] = set()

    if isinstance(providers_payload, list):
        for index, provider_payload in enumerate(providers_payload):
            if not isinstance(provider_payload, dict):
                continue
            fallback = defaults["providers"][0]
            candidate = _normalize_provider(provider_payload, fallback_id=f"provider-{index + 1}", fallback_values=fallback)
            # Duplicate provider ids keep their first occurrence only.
            if candidate["id"] in seen_provider_ids:
                continue
            seen_provider_ids.add(candidate["id"])
            normalized_providers.append(candidate)

    # Guarantee at least one provider so task bindings always resolve.
    if not normalized_providers:
        normalized_providers = [dict(defaults["providers"][0])]

    provider_ids = [provider["id"] for provider in normalized_providers]
    tasks_payload = payload.get("tasks", {})
    normalized_tasks = _normalize_tasks(tasks_payload, provider_ids)
    upload_defaults = _normalize_upload_defaults(payload.get("upload_defaults", {}), defaults["upload_defaults"])
    display_settings = _normalize_display_settings(payload.get("display", {}), defaults["display"])
    # The same payload doubles as existing_items so the one-way global_shared
    # flag survives sanitization.
    predefined_paths = _normalize_predefined_paths(
        payload.get(PREDEFINED_PATHS_SETTINGS_KEY, []),
        existing_items=payload.get(PREDEFINED_PATHS_SETTINGS_KEY, []),
    )
    predefined_tags = _normalize_predefined_tags(
        payload.get(PREDEFINED_TAGS_SETTINGS_KEY, []),
        existing_items=payload.get(PREDEFINED_TAGS_SETTINGS_KEY, []),
    )
    handwriting_style_settings = _normalize_handwriting_style_settings(
        payload.get(HANDWRITING_STYLE_SETTINGS_KEY, {}),
        defaults[HANDWRITING_STYLE_SETTINGS_KEY],
    )

    return {
        "upload_defaults": upload_defaults,
        "display": display_settings,
        PREDEFINED_PATHS_SETTINGS_KEY: predefined_paths,
        PREDEFINED_TAGS_SETTINGS_KEY: predefined_tags,
        HANDWRITING_STYLE_SETTINGS_KEY: handwriting_style_settings,
        "providers": normalized_providers,
        "tasks": normalized_tasks,
    }
|
||||
|
||||
|
||||
def ensure_app_settings() -> None:
    """Write sanitized defaults to disk unless a settings file already exists."""
    path = _settings_path()
    path.parent.mkdir(parents=True, exist_ok=True)
    if not path.exists():
        defaults = _sanitize_settings(_default_settings())
        path.write_text(json.dumps(defaults, indent=2), encoding="utf-8")
|
||||
|
||||
|
||||
def _read_raw_settings() -> dict[str, Any]:
    """Load the persisted settings file and return it fully sanitized."""
    ensure_app_settings()
    try:
        raw = json.loads(_settings_path().read_text(encoding="utf-8"))
    except (OSError, json.JSONDecodeError):
        # Unreadable or corrupt file: fall back to an empty payload,
        # which sanitization turns into full defaults.
        raw = {}
    return _sanitize_settings(raw)
|
||||
|
||||
|
||||
def _write_settings(payload: dict[str, Any]) -> None:
    """Serialize ``payload`` as pretty-printed JSON into the settings file."""
    target = _settings_path()
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text(json.dumps(payload, indent=2), encoding="utf-8")
|
||||
|
||||
|
||||
def read_app_settings() -> dict[str, Any]:
    """Read settings and return an API-safe view with provider secrets masked."""
    payload = _read_raw_settings()

    safe_providers: list[dict[str, Any]] = []
    for provider in payload["providers"]:
        secret = str(provider.get("api_key", ""))
        safe_providers.append(
            {
                "id": provider["id"],
                "label": provider["label"],
                "provider_type": provider["provider_type"],
                "base_url": provider["base_url"],
                "timeout_seconds": int(provider["timeout_seconds"]),
                # Secrets never leave the service: expose presence and a mask only.
                "api_key_set": bool(secret),
                "api_key_masked": _mask_api_key(secret),
            }
        )

    return {
        "upload_defaults": payload.get("upload_defaults", {"logical_path": "Inbox", "tags": []}),
        "display": payload.get("display", {"cards_per_page": 12, "log_typing_animation_enabled": True}),
        PREDEFINED_PATHS_SETTINGS_KEY: payload.get(PREDEFINED_PATHS_SETTINGS_KEY, []),
        PREDEFINED_TAGS_SETTINGS_KEY: payload.get(PREDEFINED_TAGS_SETTINGS_KEY, []),
        HANDWRITING_STYLE_SETTINGS_KEY: payload.get(HANDWRITING_STYLE_SETTINGS_KEY, {}),
        "providers": safe_providers,
        "tasks": payload["tasks"],
    }
|
||||
|
||||
|
||||
def reset_app_settings() -> dict[str, Any]:
    """Restore repository defaults on disk and return the API-safe view."""
    _write_settings(_sanitize_settings(_default_settings()))
    return read_app_settings()
|
||||
|
||||
|
||||
def read_task_runtime_settings(task_name: str) -> dict[str, Any]:
    """Return one task's settings plus its resolved provider, secrets included.

    Raises:
        KeyError: when ``task_name`` is not a known task settings key.
    """
    payload = _read_raw_settings()
    tasks = payload["tasks"]
    if task_name not in tasks:
        raise KeyError(f"Unknown task settings key: {task_name}")

    task = dict(tasks[task_name])
    providers_by_id = {entry["id"]: entry for entry in payload["providers"]}
    provider = providers_by_id.get(task.get("provider_id"))
    if provider is None:
        # Dangling reference: repoint the task at the first configured provider.
        provider = payload["providers"][0]
        task["provider_id"] = provider["id"]

    return {"task": task, "provider": dict(provider)}
|
||||
|
||||
|
||||
def update_app_settings(
    providers: list[dict[str, Any]] | None = None,
    tasks: dict[str, dict[str, Any]] | None = None,
    upload_defaults: dict[str, Any] | None = None,
    display: dict[str, Any] | None = None,
    handwriting_style: dict[str, Any] | None = None,
    predefined_paths: list[dict[str, Any]] | None = None,
    predefined_tags: list[dict[str, Any]] | None = None,
) -> dict[str, Any]:
    """Updates app settings, persists them, and returns API-safe values.

    Each argument is an optional partial update; ``None`` leaves the matching
    section untouched. The merged payload is sanitized before being written.
    """

    current_payload = _read_raw_settings()
    # Start from shallow copies of every current section so the partial
    # updates below never mutate the freshly read payload in place.
    next_payload: dict[str, Any] = {
        "upload_defaults": dict(current_payload.get("upload_defaults", {"logical_path": "Inbox", "tags": []})),
        "display": dict(current_payload.get("display", {"cards_per_page": 12, "log_typing_animation_enabled": True})),
        PREDEFINED_PATHS_SETTINGS_KEY: list(current_payload.get(PREDEFINED_PATHS_SETTINGS_KEY, [])),
        PREDEFINED_TAGS_SETTINGS_KEY: list(current_payload.get(PREDEFINED_TAGS_SETTINGS_KEY, [])),
        HANDWRITING_STYLE_SETTINGS_KEY: dict(
            current_payload.get(HANDWRITING_STYLE_SETTINGS_KEY, _default_settings()[HANDWRITING_STYLE_SETTINGS_KEY])
        ),
        "providers": list(current_payload["providers"]),
        "tasks": dict(current_payload["tasks"]),
    }

    if providers is not None:
        existing_provider_map = {provider["id"]: provider for provider in current_payload["providers"]}
        next_providers: list[dict[str, Any]] = []
        for index, provider_payload in enumerate(providers):
            if not isinstance(provider_payload, dict):
                continue

            provider_id = _normalize_provider_id(
                str(provider_payload.get("id", "")),
                fallback=f"provider-{index + 1}",
            )
            existing_provider = existing_provider_map.get(provider_id, {})
            merged_payload = dict(provider_payload)
            merged_payload["id"] = provider_id

            # API-key precedence: an explicit clear wins, then an explicit new
            # value; otherwise the stored secret is preserved.
            if bool(provider_payload.get("clear_api_key", False)):
                merged_payload["api_key"] = ""
            elif "api_key" in provider_payload and provider_payload.get("api_key") is not None:
                merged_payload["api_key"] = str(provider_payload.get("api_key")).strip()
            else:
                merged_payload["api_key"] = str(existing_provider.get("api_key", ""))

            normalized_provider = _normalize_provider(
                merged_payload,
                fallback_id=provider_id,
                fallback_values=existing_provider,
            )
            next_providers.append(normalized_provider)

        # An empty update list is ignored; the previous providers remain.
        if next_providers:
            next_payload["providers"] = next_providers

    if tasks is not None:
        merged_tasks = dict(current_payload["tasks"])
        for task_name, task_update in tasks.items():
            # Unknown task names and non-dict updates are skipped silently.
            if task_name not in merged_tasks or not isinstance(task_update, dict):
                continue
            existing_task = dict(merged_tasks[task_name])
            for key, value in task_update.items():
                # A None value means "keep the current setting".
                if value is None:
                    continue
                existing_task[key] = value
            merged_tasks[task_name] = existing_task
        next_payload["tasks"] = merged_tasks

    if upload_defaults is not None and isinstance(upload_defaults, dict):
        next_upload_defaults = dict(next_payload.get("upload_defaults", {}))
        for key in ("logical_path", "tags"):
            if key in upload_defaults:
                next_upload_defaults[key] = upload_defaults[key]
        next_payload["upload_defaults"] = next_upload_defaults

    if display is not None and isinstance(display, dict):
        next_display = dict(next_payload.get("display", {}))
        if "cards_per_page" in display:
            next_display["cards_per_page"] = display["cards_per_page"]
        if "log_typing_animation_enabled" in display:
            next_display["log_typing_animation_enabled"] = bool(display["log_typing_animation_enabled"])
        next_payload["display"] = next_display

    if handwriting_style is not None and isinstance(handwriting_style, dict):
        next_handwriting_style = dict(next_payload.get(HANDWRITING_STYLE_SETTINGS_KEY, {}))
        # Only whitelisted keys are accepted; anything else is dropped.
        for key in (
            "enabled",
            "embed_model",
            "neighbor_limit",
            "match_min_similarity",
            "bootstrap_match_min_similarity",
            "bootstrap_sample_size",
            "image_max_side",
        ):
            if key in handwriting_style:
                next_handwriting_style[key] = handwriting_style[key]
        next_payload[HANDWRITING_STYLE_SETTINGS_KEY] = next_handwriting_style

    if predefined_paths is not None:
        # existing_items keeps the one-way global_shared flag sticky.
        next_payload[PREDEFINED_PATHS_SETTINGS_KEY] = _normalize_predefined_paths(
            predefined_paths,
            existing_items=next_payload.get(PREDEFINED_PATHS_SETTINGS_KEY, []),
        )

    if predefined_tags is not None:
        next_payload[PREDEFINED_TAGS_SETTINGS_KEY] = _normalize_predefined_tags(
            predefined_tags,
            existing_items=next_payload.get(PREDEFINED_TAGS_SETTINGS_KEY, []),
        )

    # Sanitize the merged result, persist it, and return the masked API view.
    sanitized = _sanitize_settings(next_payload)
    _write_settings(sanitized)
    return read_app_settings()
|
||||
|
||||
|
||||
def read_handwriting_provider_settings() -> dict[str, Any]:
    """Returns OCR settings in legacy shape for the handwriting transcription service."""

    runtime = read_task_runtime_settings(TASK_OCR_HANDWRITING)
    provider_config = runtime["provider"]
    task_config = runtime["task"]

    # Flatten provider + task config into the flat dict older callers expect.
    legacy_shape: dict[str, Any] = {"provider": provider_config["provider_type"]}
    legacy_shape["enabled"] = bool(task_config.get("enabled", True))
    legacy_shape["openai_base_url"] = str(provider_config.get("base_url", settings.default_openai_base_url))
    legacy_shape["openai_model"] = str(task_config.get("model", settings.default_openai_model))
    legacy_shape["openai_timeout_seconds"] = int(
        provider_config.get("timeout_seconds", settings.default_openai_timeout_seconds)
    )
    legacy_shape["openai_api_key"] = str(provider_config.get("api_key", ""))
    legacy_shape["prompt"] = str(task_config.get("prompt", DEFAULT_OCR_PROMPT))
    legacy_shape["provider_id"] = str(provider_config.get("id", "openai-default"))
    return legacy_shape
|
||||
|
||||
|
||||
def read_handwriting_style_settings() -> dict[str, Any]:
    """Returns handwriting-style clustering settings for Typesense style assignment logic."""

    raw_payload = _read_raw_settings()
    fallback = _default_settings()[HANDWRITING_STYLE_SETTINGS_KEY]
    stored = raw_payload.get(HANDWRITING_STYLE_SETTINGS_KEY, {})
    return _normalize_handwriting_style_settings(stored, fallback)
|
||||
|
||||
|
||||
def read_predefined_paths_settings() -> list[dict[str, Any]]:
    """Returns normalized predefined logical path catalog entries."""

    raw_payload = _read_raw_settings()
    stored_paths = raw_payload.get(PREDEFINED_PATHS_SETTINGS_KEY, [])
    # Pass the stored list as its own "existing" reference, exactly as writes do.
    return _normalize_predefined_paths(stored_paths, existing_items=stored_paths)
|
||||
|
||||
|
||||
def read_predefined_tags_settings() -> list[dict[str, Any]]:
    """Returns normalized predefined tag catalog entries."""

    raw_payload = _read_raw_settings()
    stored_tags = raw_payload.get(PREDEFINED_TAGS_SETTINGS_KEY, [])
    # Pass the stored list as its own "existing" reference, exactly as writes do.
    return _normalize_predefined_tags(stored_tags, existing_items=stored_tags)
|
||||
|
||||
|
||||
def update_handwriting_settings(
    enabled: bool | None = None,
    openai_base_url: str | None = None,
    openai_model: str | None = None,
    openai_timeout_seconds: int | None = None,
    openai_api_key: str | None = None,
    clear_openai_api_key: bool = False,
) -> dict[str, Any]:
    """Updates OCR task and bound provider values using the legacy handwriting API contract."""

    current_provider = read_task_runtime_settings(TASK_OCR_HANDWRITING)["provider"]

    # Keep current provider values unless an override was supplied.
    provider_update: dict[str, Any] = {
        "id": current_provider["id"],
        "label": current_provider["label"],
        "provider_type": current_provider["provider_type"],
        "base_url": current_provider["base_url"] if openai_base_url is None else openai_base_url,
        "timeout_seconds": (
            current_provider["timeout_seconds"] if openai_timeout_seconds is None else openai_timeout_seconds
        ),
    }
    # Clearing the key takes precedence over setting a new one.
    if clear_openai_api_key:
        provider_update["clear_api_key"] = True
    elif openai_api_key is not None:
        provider_update["api_key"] = openai_api_key

    task_update: dict[str, Any] = {}
    if enabled is not None:
        task_update["enabled"] = enabled
    if openai_model is not None:
        task_update["model"] = openai_model

    return update_app_settings(
        providers=[provider_update],
        tasks={TASK_OCR_HANDWRITING: task_update},
    )
|
||||
315
backend/app/services/extractor.py
Normal file
315
backend/app/services/extractor.py
Normal file
@@ -0,0 +1,315 @@
|
||||
"""Document extraction service for text indexing, previews, and archive fan-out."""
|
||||
|
||||
import io
|
||||
import re
|
||||
import zipfile
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
import magic
|
||||
from docx import Document as DocxDocument
|
||||
from openpyxl import load_workbook
|
||||
from PIL import Image, ImageOps
|
||||
from pypdf import PdfReader
|
||||
import pymupdf
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.services.handwriting import (
|
||||
IMAGE_TEXT_TYPE_NO_TEXT,
|
||||
IMAGE_TEXT_TYPE_UNKNOWN,
|
||||
HandwritingTranscriptionError,
|
||||
HandwritingTranscriptionNotConfiguredError,
|
||||
HandwritingTranscriptionTimeoutError,
|
||||
classify_image_text_bytes,
|
||||
transcribe_handwriting_bytes,
|
||||
)
|
||||
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
# Image extensions that extract_text_content routes through the handwriting /
# printed-text pipeline (_extract_handwriting_text) instead of text parsers.
IMAGE_EXTENSIONS = {
    ".jpg",
    ".jpeg",
    ".png",
    ".tif",
    ".tiff",
    ".bmp",
    ".gif",
    ".webp",
    ".heic",
}

# Extensions considered text-processable by is_supported_for_extraction;
# images are included because they may contain transcribable text.
SUPPORTED_TEXT_EXTENSIONS = {
    ".txt",
    ".md",
    ".csv",
    ".json",
    ".xml",
    ".svg",
    ".pdf",
    ".docx",
    ".xlsx",
    *IMAGE_EXTENSIONS,
}
|
||||
|
||||
|
||||
@dataclass
class ExtractionResult:
    """Represents output generated during extraction for a single file."""

    # Extracted plain text; empty for unsupported files or failed extraction.
    text: str
    # Optional preview image bytes and matching file suffix (e.g. ".jpg").
    preview_bytes: bytes | None
    preview_suffix: str | None
    # One of "processed", "unsupported", or "error" (see extract_text_content).
    status: str
    # Free-form extraction metadata (classification labels, error reasons).
    metadata_json: dict[str, object] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
class ArchiveMember:
    """Represents an extracted file entry from an archive."""

    # Member path inside the archive, as stored in the zip entry.
    name: str
    # Raw member content bytes.
    data: bytes
|
||||
|
||||
|
||||
def sniff_mime(data: bytes) -> str:
    """Detects MIME type using libmagic for robust format handling."""

    detected = magic.from_buffer(data, mime=True)
    if detected:
        return detected
    # libmagic returned an empty value; fall back to the generic binary type.
    return "application/octet-stream"
|
||||
|
||||
|
||||
def is_supported_for_extraction(extension: str, mime_type: str) -> bool:
    """Determines if a file should be text-processed for indexing and classification."""

    # Any text/* MIME type qualifies even when the extension is unknown.
    if mime_type.startswith("text/"):
        return True
    return extension in SUPPORTED_TEXT_EXTENSIONS
|
||||
|
||||
|
||||
def _normalize_text(text: str) -> str:
|
||||
"""Normalizes extracted text by removing repeated form separators and controls."""
|
||||
|
||||
cleaned = text.replace("\r", "\n").replace("\x00", "")
|
||||
lines: list[str] = []
|
||||
for line in cleaned.split("\n"):
|
||||
stripped = line.strip()
|
||||
if stripped and re.fullmatch(r"[.\-_*=~\s]{4,}", stripped):
|
||||
continue
|
||||
lines.append(line)
|
||||
|
||||
normalized = "\n".join(lines)
|
||||
normalized = re.sub(r"\n{3,}", "\n\n", normalized)
|
||||
return normalized.strip()
|
||||
|
||||
|
||||
def _extract_pdf_text(data: bytes) -> str:
    """Extracts text from PDF bytes using pypdf page parsing."""

    reader = PdfReader(io.BytesIO(data))
    page_texts = [page.extract_text() or "" for page in reader.pages]
    return _normalize_text("\n".join(page_texts))
|
||||
|
||||
|
||||
def _extract_pdf_preview(data: bytes) -> tuple[bytes | None, str | None]:
    """Creates a JPEG thumbnail preview from the first PDF page."""

    try:
        document = pymupdf.open(stream=data, filetype="pdf")
    except Exception:
        # Unparseable PDF bytes: previews are best-effort, so swallow and skip.
        return None, None

    try:
        if document.page_count >= 1:
            first_page = document.load_page(0)
            # 1.5x zoom renders a reasonably sharp thumbnail.
            rendered = first_page.get_pixmap(matrix=pymupdf.Matrix(1.5, 1.5), alpha=False)
            return rendered.tobytes("jpeg"), ".jpg"
        return None, None
    except Exception:
        return None, None
    finally:
        document.close()
|
||||
|
||||
|
||||
def _extract_docx_text(data: bytes) -> str:
    """Extracts paragraph text from DOCX content."""

    paragraphs = DocxDocument(io.BytesIO(data)).paragraphs
    non_empty = [paragraph.text for paragraph in paragraphs if paragraph.text]
    return _normalize_text("\n".join(non_empty))
|
||||
|
||||
|
||||
def _extract_xlsx_text(data: bytes) -> str:
    """Extracts cell text from XLSX workbook sheets for indexing.

    Reads at most 200 rows per sheet (and stops after 200 non-empty rows) to
    bound work on large workbooks; non-empty cell values are stringified and
    space-joined per row, prefixed by each sheet's title.
    """

    workbook = load_workbook(io.BytesIO(data), data_only=True, read_only=True)
    try:
        chunks: list[str] = []
        for sheet in workbook.worksheets:
            chunks.append(sheet.title)
            row_count = 0
            for row in sheet.iter_rows(min_row=1, max_row=200):
                row_values = [str(cell.value) for cell in row if cell.value is not None]
                if row_values:
                    chunks.append(" ".join(row_values))
                    row_count += 1
                if row_count >= 200:
                    break
        return _normalize_text("\n".join(chunks))
    finally:
        # Read-only workbooks hold the underlying stream open until closed;
        # the original never closed it (resource leak).
        workbook.close()
|
||||
|
||||
|
||||
def _build_image_preview(data: bytes) -> tuple[bytes | None, str | None]:
    """Builds a JPEG preview thumbnail for image files."""

    try:
        with Image.open(io.BytesIO(data)) as source:
            # Honor EXIF orientation before thumbnailing, and force RGB for JPEG.
            thumbnail = ImageOps.exif_transpose(source).convert("RGB")
            thumbnail.thumbnail((600, 600))
            buffer = io.BytesIO()
            thumbnail.save(buffer, format="JPEG", optimize=True, quality=82)
            return buffer.getvalue(), ".jpg"
    except Exception:
        # Previews are best-effort; any decode/encode failure yields no preview.
        return None, None
|
||||
|
||||
|
||||
def _extract_handwriting_text(data: bytes, mime_type: str) -> ExtractionResult:
    """Extracts text from image bytes and records handwriting-vs-printed classification metadata.

    Two provider calls are made: a text-modality classification first, then a
    transcription unless the classifier reports no readable text. Provider
    failures are folded into ``metadata_json`` and the result ``status``.
    """

    preview_bytes, preview_suffix = _build_image_preview(data)
    metadata_json: dict[str, object] = {}

    try:
        text_type = classify_image_text_bytes(data, mime_type=mime_type)
        metadata_json = {
            "image_text_type": text_type.label,
            "image_text_type_confidence": text_type.confidence,
            "image_text_type_provider": text_type.provider,
            "image_text_type_model": text_type.model,
        }
    except HandwritingTranscriptionNotConfiguredError as error:
        # No configured provider: nothing further can be attempted.
        return ExtractionResult(
            text="",
            preview_bytes=preview_bytes,
            preview_suffix=preview_suffix,
            status="unsupported",
            metadata_json={"transcription_error": str(error), "image_text_type": IMAGE_TEXT_TYPE_UNKNOWN},
        )
    except HandwritingTranscriptionError as error:
        # Covers timeouts too (HandwritingTranscriptionTimeoutError subclasses
        # this); the original had two byte-identical handlers here. A failed
        # classification is non-fatal — transcription below is still attempted.
        metadata_json = {
            "image_text_type": IMAGE_TEXT_TYPE_UNKNOWN,
            "image_text_type_error": str(error),
        }

    if metadata_json.get("image_text_type") == IMAGE_TEXT_TYPE_NO_TEXT:
        metadata_json["transcription_skipped"] = "no_text_detected"
        return ExtractionResult(
            text="",
            preview_bytes=preview_bytes,
            preview_suffix=preview_suffix,
            status="processed",
            metadata_json=metadata_json,
        )

    try:
        transcription = transcribe_handwriting_bytes(data, mime_type=mime_type)
    except HandwritingTranscriptionNotConfiguredError as error:
        return ExtractionResult(
            text="",
            preview_bytes=preview_bytes,
            preview_suffix=preview_suffix,
            status="unsupported",
            metadata_json={**metadata_json, "transcription_error": str(error)},
        )
    except HandwritingTranscriptionError as error:
        # Timeout and generic transcription failures share identical handling.
        return ExtractionResult(
            text="",
            preview_bytes=preview_bytes,
            preview_suffix=preview_suffix,
            status="error",
            metadata_json={**metadata_json, "transcription_error": str(error)},
        )

    return ExtractionResult(
        text=_normalize_text(transcription.text),
        preview_bytes=preview_bytes,
        preview_suffix=preview_suffix,
        status="processed",
        metadata_json={
            **metadata_json,
            "transcription_provider": transcription.provider,
            "transcription_model": transcription.model,
            "transcription_uncertainties": transcription.uncertainties,
        },
    )
|
||||
|
||||
|
||||
def extract_text_content(filename: str, data: bytes, mime_type: str) -> ExtractionResult:
    """Extracts text and optional preview bytes for supported file types."""

    extension = Path(filename).suffix.lower()
    plain_text_extensions = {".txt", ".md", ".csv", ".json", ".xml", ".svg"}
    extracted = ""
    preview_bytes: bytes | None = None
    preview_suffix: str | None = None

    try:
        if extension == ".pdf":
            extracted = _extract_pdf_text(data)
            preview_bytes, preview_suffix = _extract_pdf_preview(data)
        elif extension in plain_text_extensions or mime_type.startswith("text/"):
            extracted = _normalize_text(data.decode("utf-8", errors="ignore"))
        elif extension == ".docx":
            extracted = _extract_docx_text(data)
        elif extension == ".xlsx":
            extracted = _extract_xlsx_text(data)
        elif extension in IMAGE_EXTENSIONS:
            # Images carry their own status/metadata handling.
            return _extract_handwriting_text(data=data, mime_type=mime_type)
        else:
            return ExtractionResult(
                text="",
                preview_bytes=None,
                preview_suffix=None,
                status="unsupported",
                metadata_json={"reason": "unsupported_format"},
            )
    except Exception as error:
        # Any parser failure is reported rather than raised to the caller.
        return ExtractionResult(
            text="",
            preview_bytes=None,
            preview_suffix=None,
            status="error",
            metadata_json={"reason": "extraction_exception", "error": str(error)},
        )

    return ExtractionResult(
        text=extracted[: settings.max_text_length],
        preview_bytes=preview_bytes,
        preview_suffix=preview_suffix,
        status="processed",
        metadata_json={},
    )
|
||||
|
||||
|
||||
def extract_archive_members(data: bytes, depth: int = 0) -> list[ArchiveMember]:
    """Extracts processable members from zip archives with configurable depth limits."""

    # Refuse nested archives beyond the configured depth.
    if depth > settings.max_zip_depth:
        return []

    extracted: list[ArchiveMember] = []
    with zipfile.ZipFile(io.BytesIO(data)) as archive:
        file_entries = [entry for entry in archive.infolist() if not entry.is_dir()]
        for entry in file_entries[: settings.max_zip_members]:
            extracted.append(ArchiveMember(name=entry.filename, data=archive.read(entry.filename)))
    return extracted
|
||||
477
backend/app/services/handwriting.py
Normal file
477
backend/app/services/handwriting.py
Normal file
@@ -0,0 +1,477 @@
|
||||
"""Handwriting transcription service using OpenAI-compatible vision models."""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from openai import APIConnectionError, APIError, APITimeoutError, OpenAI
|
||||
from PIL import Image, ImageOps
|
||||
|
||||
from app.services.app_settings import DEFAULT_OCR_PROMPT, read_handwriting_provider_settings
|
||||
|
||||
# Images are downscaled so their long side never exceeds this many pixels
# before being sent to the provider (see _normalize_image_bytes).
MAX_IMAGE_SIDE = 2000

# Canonical image text modality labels returned by the classifier.
IMAGE_TEXT_TYPE_HANDWRITING = "handwriting"
IMAGE_TEXT_TYPE_PRINTED = "printed_text"
IMAGE_TEXT_TYPE_NO_TEXT = "no_text"
IMAGE_TEXT_TYPE_UNKNOWN = "unknown"

# Prompt instructing the vision model to return a strict JSON classification
# that _coerce_json_object can parse.
IMAGE_TEXT_CLASSIFICATION_PROMPT = (
    "Classify the text content in this image.\n"
    "Choose exactly one label from: handwriting, printed_text, no_text.\n"
    "Definitions:\n"
    "- handwriting: text exists and most readable text is handwritten.\n"
    "- printed_text: text exists and most readable text is machine printed or typed.\n"
    "- no_text: no readable text is present.\n"
    "Return strict JSON only with shape:\n"
    "{\n"
    ' "label": "handwriting|printed_text|no_text",\n'
    ' "confidence": number\n'
    "}\n"
    "Confidence must be between 0 and 1."
)
|
||||
|
||||
|
||||
class HandwritingTranscriptionError(Exception):
    """Raised when handwriting transcription fails for a non-timeout reason.

    Base class for all transcription/classification errors in this module.
    """
|
||||
|
||||
|
||||
class HandwritingTranscriptionTimeoutError(HandwritingTranscriptionError):
    """Raised when handwriting transcription exceeds the configured timeout.

    Subclasses HandwritingTranscriptionError, so catch it first when the
    distinction matters.
    """
|
||||
|
||||
|
||||
class HandwritingTranscriptionNotConfiguredError(HandwritingTranscriptionError):
    """Raised when handwriting transcription is disabled or missing credentials.

    Subclasses HandwritingTranscriptionError, so catch it first when the
    distinction matters.
    """
|
||||
|
||||
|
||||
@dataclass
class HandwritingTranscription:
    """Represents transcription output and uncertainty markers."""

    # Transcribed text, stripped of surrounding whitespace.
    text: str
    # Contents of [[? ... ?]] uncertainty markers found in the text.
    uncertainties: list[str]
    # Provider identifier (currently always "openai" — see _transcribe_image_data_url).
    provider: str
    # Model name used for the transcription request.
    model: str
|
||||
|
||||
|
||||
@dataclass
class ImageTextClassification:
    """Represents model classification of image text modality for one image."""

    # One of the IMAGE_TEXT_TYPE_* canonical labels.
    label: str
    # Model-reported confidence clamped to [0, 1].
    confidence: float
    # Provider identifier (currently always "openai" — see _classify_image_text_data_url).
    provider: str
    # Model name used for the classification request.
    model: str
|
||||
|
||||
|
||||
def _extract_uncertainties(text: str) -> list[str]:
|
||||
"""Extracts uncertainty markers from transcription output."""
|
||||
|
||||
matches = re.findall(r"\[\[\?(.*?)\?\]\]", text)
|
||||
return [match.strip() for match in matches if match.strip()]
|
||||
|
||||
|
||||
def _coerce_json_object(payload: str) -> dict[str, Any]:
|
||||
"""Parses and extracts a JSON object from raw model output text."""
|
||||
|
||||
text = payload.strip()
|
||||
if not text:
|
||||
return {}
|
||||
|
||||
try:
|
||||
parsed = json.loads(text)
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, flags=re.DOTALL | re.IGNORECASE)
|
||||
if fenced:
|
||||
try:
|
||||
parsed = json.loads(fenced.group(1))
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
first_brace = text.find("{")
|
||||
last_brace = text.rfind("}")
|
||||
if first_brace >= 0 and last_brace > first_brace:
|
||||
candidate = text[first_brace : last_brace + 1]
|
||||
try:
|
||||
parsed = json.loads(candidate)
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
return {}
|
||||
|
||||
|
||||
def _clamp_probability(value: Any, fallback: float = 0.0) -> float:
|
||||
"""Clamps confidence-like values to the inclusive [0, 1] range."""
|
||||
|
||||
try:
|
||||
parsed = float(value)
|
||||
except (TypeError, ValueError):
|
||||
return fallback
|
||||
return max(0.0, min(1.0, parsed))
|
||||
|
||||
|
||||
def _normalize_image_text_type(label: str) -> str:
    """Normalizes classifier labels into one supported canonical image text type.

    The label is lowercased and hyphens/spaces are mapped to underscores
    before comparing against known synonyms; anything unrecognized becomes
    IMAGE_TEXT_TYPE_UNKNOWN.
    """

    normalized = label.strip().lower().replace("-", "_").replace(" ", "_")
    if normalized in {IMAGE_TEXT_TYPE_HANDWRITING, "handwritten", "handwritten_text"}:
        return IMAGE_TEXT_TYPE_HANDWRITING
    if normalized in {IMAGE_TEXT_TYPE_PRINTED, "printed", "typed", "machine_text"}:
        return IMAGE_TEXT_TYPE_PRINTED
    # Synonyms must be stored in normalized form: the original "no readable
    # text" entry could never match because spaces were already replaced with
    # underscores (and "no-text" normalized to the canonical value anyway).
    if normalized in {IMAGE_TEXT_TYPE_NO_TEXT, "no_readable_text", "none"}:
        return IMAGE_TEXT_TYPE_NO_TEXT
    return IMAGE_TEXT_TYPE_UNKNOWN
|
||||
|
||||
|
||||
def _normalize_image_bytes(image_data: bytes) -> tuple[bytes, str]:
    """Applies EXIF rotation and scales large images down for efficient transcription."""

    with Image.open(io.BytesIO(image_data)) as source:
        normalized = ImageOps.exif_transpose(source).convert("RGB")
        longest_side = max(normalized.width, normalized.height)
        if longest_side > MAX_IMAGE_SIDE:
            ratio = MAX_IMAGE_SIDE / longest_side
            target_size = (
                max(1, int(normalized.width * ratio)),
                max(1, int(normalized.height * ratio)),
            )
            normalized = normalized.resize(target_size, Image.Resampling.LANCZOS)

        # Re-encode as JPEG regardless of the input format.
        buffer = io.BytesIO()
        normalized.save(buffer, format="JPEG", quality=90, optimize=True)
        return buffer.getvalue(), "image/jpeg"
|
||||
|
||||
|
||||
def _create_client(provider_settings: dict[str, Any]) -> OpenAI:
    """Creates an OpenAI client configured for compatible endpoints and timeouts."""

    configured_key = str(provider_settings.get("openai_api_key", "")).strip()
    if not configured_key:
        # A placeholder key is supplied when none is configured — presumably
        # for local OpenAI-compatible servers that ignore the key; confirm.
        configured_key = "no-key-required"
    return OpenAI(
        api_key=configured_key,
        base_url=str(provider_settings["openai_base_url"]),
        timeout=int(provider_settings["openai_timeout_seconds"]),
    )
|
||||
|
||||
|
||||
def _extract_text_from_response(response: Any) -> str:
|
||||
"""Extracts plain text from responses API output objects."""
|
||||
|
||||
output_text = getattr(response, "output_text", None)
|
||||
if isinstance(output_text, str) and output_text.strip():
|
||||
return output_text.strip()
|
||||
|
||||
output_items = getattr(response, "output", None)
|
||||
if not isinstance(output_items, list):
|
||||
return ""
|
||||
|
||||
texts: list[str] = []
|
||||
for item in output_items:
|
||||
item_data = item.model_dump() if hasattr(item, "model_dump") else item
|
||||
if not isinstance(item_data, dict):
|
||||
continue
|
||||
item_type = item_data.get("type")
|
||||
if item_type == "output_text":
|
||||
text = str(item_data.get("text", "")).strip()
|
||||
if text:
|
||||
texts.append(text)
|
||||
if item_type == "message":
|
||||
for content in item_data.get("content", []) or []:
|
||||
if not isinstance(content, dict):
|
||||
continue
|
||||
if content.get("type") in {"output_text", "text"}:
|
||||
text = str(content.get("text", "")).strip()
|
||||
if text:
|
||||
texts.append(text)
|
||||
|
||||
return "\n".join(texts).strip()
|
||||
|
||||
|
||||
def _transcribe_with_responses(client: OpenAI, model: str, prompt: str, image_data_url: str) -> str:
    """Transcribes handwriting using the responses API."""

    user_content = [
        {"type": "input_text", "text": prompt},
        {"type": "input_image", "image_url": image_data_url, "detail": "high"},
    ]
    response = client.responses.create(
        model=model,
        input=[{"role": "user", "content": user_content}],
    )
    return _extract_text_from_response(response)
|
||||
|
||||
|
||||
def _transcribe_with_chat(client: OpenAI, model: str, prompt: str, image_data_url: str) -> str:
    """Transcribes handwriting using chat completions for endpoint compatibility."""

    user_content = [
        {"type": "text", "text": prompt},
        {"type": "image_url", "image_url": {"url": image_data_url, "detail": "high"}},
    ]
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": user_content}],
    )

    content = response.choices[0].message.content
    if isinstance(content, str):
        return content.strip()
    if isinstance(content, list):
        # Some endpoints return structured content parts instead of a string.
        fragments: list[str] = []
        for part in content:
            if isinstance(part, dict):
                fragment = str(part.get("text", "")).strip()
                if fragment:
                    fragments.append(fragment)
        return "\n".join(fragments).strip()
    return ""
|
||||
|
||||
|
||||
def _classify_with_responses(client: OpenAI, model: str, prompt: str, image_data_url: str) -> str:
    """Classifies image text modality using the responses API."""

    user_content = [
        {"type": "input_text", "text": prompt},
        {"type": "input_image", "image_url": image_data_url, "detail": "high"},
    ]
    response = client.responses.create(
        model=model,
        input=[{"role": "user", "content": user_content}],
    )
    return _extract_text_from_response(response)
|
||||
|
||||
|
||||
def _classify_with_chat(client: OpenAI, model: str, prompt: str, image_data_url: str) -> str:
    """Classifies image text modality using chat completions for compatibility."""

    user_content = [
        {"type": "text", "text": prompt},
        {"type": "image_url", "image_url": {"url": image_data_url, "detail": "high"}},
    ]
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": user_content}],
    )

    content = response.choices[0].message.content
    if isinstance(content, str):
        return content.strip()
    if isinstance(content, list):
        # Some endpoints return structured content parts instead of a string.
        fragments: list[str] = []
        for part in content:
            if isinstance(part, dict):
                fragment = str(part.get("text", "")).strip()
                if fragment:
                    fragments.append(fragment)
        return "\n".join(fragments).strip()
    return ""
|
||||
|
||||
|
||||
def _classify_image_text_data_url(image_data_url: str) -> ImageTextClassification:
    """Classifies an image as handwriting, printed text, or no text.

    Args:
        image_data_url: Data URL (base64 payload) or direct image URL.

    Returns:
        ImageTextClassification with a canonical label and clamped confidence.

    Raises:
        HandwritingTranscriptionNotConfiguredError: If the OCR task is disabled.
        HandwritingTranscriptionTimeoutError: If the provider request times out.
        HandwritingTranscriptionError: For unsupported providers, request
            failures, or unparseable model output.
    """

    provider_settings = read_handwriting_provider_settings()
    provider_type = str(provider_settings.get("provider", "openai_compatible")).strip()
    if provider_type != "openai_compatible":
        raise HandwritingTranscriptionError(f"unsupported_provider_type:{provider_type}")

    if not bool(provider_settings.get("enabled", True)):
        raise HandwritingTranscriptionNotConfiguredError("handwriting_transcription_disabled")

    model = str(provider_settings.get("openai_model", "gpt-4.1-mini")).strip() or "gpt-4.1-mini"
    client = _create_client(provider_settings)

    try:
        # Prefer the responses API; fall back to chat completions when it
        # yields no text (some OpenAI-compatible servers only implement chat).
        output_text = _classify_with_responses(
            client=client,
            model=model,
            prompt=IMAGE_TEXT_CLASSIFICATION_PROMPT,
            image_data_url=image_data_url,
        )
        if not output_text:
            output_text = _classify_with_chat(
                client=client,
                model=model,
                prompt=IMAGE_TEXT_CLASSIFICATION_PROMPT,
                image_data_url=image_data_url,
            )
    except APITimeoutError as error:
        raise HandwritingTranscriptionTimeoutError("openai_request_timeout") from error
    except (APIConnectionError, APIError):
        # Connection/API failures on the responses path get one retry via chat.
        try:
            output_text = _classify_with_chat(
                client=client,
                model=model,
                prompt=IMAGE_TEXT_CLASSIFICATION_PROMPT,
                image_data_url=image_data_url,
            )
        except APITimeoutError as timeout_error:
            raise HandwritingTranscriptionTimeoutError("openai_request_timeout") from timeout_error
        except Exception as fallback_error:
            raise HandwritingTranscriptionError(str(fallback_error)) from fallback_error
    except Exception as error:
        raise HandwritingTranscriptionError(str(error)) from error

    parsed = _coerce_json_object(output_text)
    if not parsed:
        raise HandwritingTranscriptionError("image_text_classification_parse_failed")

    label = _normalize_image_text_type(str(parsed.get("label", "")))
    confidence = _clamp_probability(parsed.get("confidence", 0.0), fallback=0.0)
    return ImageTextClassification(
        label=label,
        confidence=confidence,
        provider="openai",
        model=model,
    )
|
||||
|
||||
|
||||
def _transcribe_image_data_url(image_data_url: str) -> HandwritingTranscription:
    """Transcribes a handwriting image data URL with configured OpenAI provider settings.

    Args:
        image_data_url: Data URL (base64 payload) or direct image URL.

    Returns:
        HandwritingTranscription with the stripped text and extracted
        uncertainty markers.

    Raises:
        HandwritingTranscriptionNotConfiguredError: If the OCR task is disabled.
        HandwritingTranscriptionTimeoutError: If the provider request times out.
        HandwritingTranscriptionError: For unsupported providers or request failures.
    """

    provider_settings = read_handwriting_provider_settings()
    provider_type = str(provider_settings.get("provider", "openai_compatible")).strip()
    if provider_type != "openai_compatible":
        raise HandwritingTranscriptionError(f"unsupported_provider_type:{provider_type}")

    if not bool(provider_settings.get("enabled", True)):
        raise HandwritingTranscriptionNotConfiguredError("handwriting_transcription_disabled")

    model = str(provider_settings.get("openai_model", "gpt-4.1-mini")).strip() or "gpt-4.1-mini"
    prompt = str(provider_settings.get("prompt", DEFAULT_OCR_PROMPT)).strip() or DEFAULT_OCR_PROMPT
    client = _create_client(provider_settings)

    try:
        # Prefer the responses API; fall back to chat completions when it
        # yields no text (some OpenAI-compatible servers only implement chat).
        text = _transcribe_with_responses(client=client, model=model, prompt=prompt, image_data_url=image_data_url)
        if not text:
            text = _transcribe_with_chat(client=client, model=model, prompt=prompt, image_data_url=image_data_url)
    except APITimeoutError as error:
        raise HandwritingTranscriptionTimeoutError("openai_request_timeout") from error
    except (APIConnectionError, APIError) as error:
        # Connection/API failures on the responses path get one retry via chat.
        try:
            text = _transcribe_with_chat(client=client, model=model, prompt=prompt, image_data_url=image_data_url)
        except APITimeoutError as timeout_error:
            raise HandwritingTranscriptionTimeoutError("openai_request_timeout") from timeout_error
        except Exception as fallback_error:
            raise HandwritingTranscriptionError(str(fallback_error)) from fallback_error
    except Exception as error:
        raise HandwritingTranscriptionError(str(error)) from error

    final_text = text.strip()
    return HandwritingTranscription(
        text=final_text,
        uncertainties=_extract_uncertainties(final_text),
        provider="openai",
        model=model,
    )
|
||||
|
||||
|
||||
def transcribe_handwriting_base64(image_base64: str, mime_type: str = "image/jpeg") -> HandwritingTranscription:
    """Transcribes handwriting from a base64 payload without data URL prefix."""

    # Blank MIME types fall back to JPEG before the data URL is assembled.
    effective_mime = mime_type.strip().lower() or "image/jpeg"
    return _transcribe_image_data_url(f"data:{effective_mime};base64,{image_base64}")
|
||||
|
||||
|
||||
def transcribe_handwriting_url(image_url: str) -> HandwritingTranscription:
    """Transcribes handwriting from a direct image URL."""

    # The downstream helper accepts plain URLs as well as data URLs.
    return _transcribe_image_data_url(image_url)
|
||||
|
||||
|
||||
def transcribe_handwriting_bytes(image_data: bytes, mime_type: str = "image/jpeg") -> HandwritingTranscription:
    """Transcribes handwriting from raw image bytes after normalization."""

    prepared_bytes, prepared_mime = _normalize_image_bytes(image_data)
    payload = base64.b64encode(prepared_bytes).decode("ascii")
    return transcribe_handwriting_base64(payload, mime_type=prepared_mime)
|
||||
|
||||
|
||||
def classify_image_text_base64(image_base64: str, mime_type: str = "image/jpeg") -> ImageTextClassification:
    """Classifies image text type from a base64 payload without data URL prefix."""

    # Blank MIME types fall back to JPEG before the data URL is assembled.
    effective_mime = mime_type.strip().lower() or "image/jpeg"
    return _classify_image_text_data_url(f"data:{effective_mime};base64,{image_base64}")
|
||||
|
||||
|
||||
def classify_image_text_url(image_url: str) -> ImageTextClassification:
    """Classifies image text type from a direct image URL."""

    # The downstream helper accepts plain URLs as well as data URLs.
    return _classify_image_text_data_url(image_url)
|
||||
|
||||
|
||||
def classify_image_text_bytes(image_data: bytes, mime_type: str = "image/jpeg") -> ImageTextClassification:
    """Classifies image text type from raw image bytes after normalization."""

    prepared_bytes, prepared_mime = _normalize_image_bytes(image_data)
    payload = base64.b64encode(prepared_bytes).decode("ascii")
    return classify_image_text_base64(payload, mime_type=prepared_mime)
|
||||
|
||||
|
||||
def transcribe_handwriting(image: bytes | str, mime_type: str = "image/jpeg") -> HandwritingTranscription:
    """Transcribes handwriting from bytes, base64 text, or URL input."""

    if isinstance(image, bytes):
        return transcribe_handwriting_bytes(image, mime_type=mime_type)

    candidate = image.strip()
    # Strings are treated as URLs only with an explicit http(s) scheme;
    # everything else is assumed to be base64.
    if candidate.startswith(("http://", "https://")):
        return transcribe_handwriting_url(candidate)
    return transcribe_handwriting_base64(candidate, mime_type=mime_type)
|
||||
435
backend/app/services/handwriting_style.py
Normal file
435
backend/app/services/handwriting_style.py
Normal file
@@ -0,0 +1,435 @@
|
||||
"""Handwriting-style clustering and style-scoped path composition for image documents."""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from PIL import Image, ImageOps
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.models.document import Document, DocumentStatus
|
||||
from app.services.app_settings import (
|
||||
DEFAULT_HANDWRITING_STYLE_EMBED_MODEL,
|
||||
read_handwriting_style_settings,
|
||||
)
|
||||
from app.services.typesense_index import get_typesense_client
|
||||
|
||||
|
||||
settings = get_settings()

# Documents are only style-clustered when their detected text type equals this value.
IMAGE_TEXT_TYPE_HANDWRITING = "handwriting"
# Suffix appended to the main Typesense collection name for the style-vector collection.
HANDWRITING_STYLE_COLLECTION_SUFFIX = "_handwriting_styles"
# Default embedding model; runtime settings may override via "embed_model".
HANDWRITING_STYLE_EMBED_MODEL = DEFAULT_HANDWRITING_STYLE_EMBED_MODEL
# Similarity required to join an established cluster.
HANDWRITING_STYLE_MATCH_MIN_SIMILARITY = 0.86
# Stricter similarity required while a cluster still has few samples.
HANDWRITING_STYLE_BOOTSTRAP_MIN_SIMILARITY = 0.89
# Cluster size below which the bootstrap threshold applies.
HANDWRITING_STYLE_BOOTSTRAP_SAMPLE_SIZE = 3
# Default number of nearest neighbors fetched per vector lookup.
HANDWRITING_STYLE_NEIGHBOR_LIMIT = 8
# Images are downscaled so their longest side does not exceed this before embedding.
HANDWRITING_STYLE_IMAGE_MAX_SIDE = 1024
# Cluster identifiers look like "hw_style_<n>"; the pattern captures the number.
HANDWRITING_STYLE_ID_PREFIX = "hw_style_"
HANDWRITING_STYLE_ID_PATTERN = re.compile(r"^hw_style_(\d+)$")
|
||||
|
||||
|
||||
@dataclass
class HandwritingStyleNeighbor:
    """Represents one nearest handwriting-style neighbor returned from Typesense."""

    # Identifier of the neighboring document in the style collection.
    document_id: str
    # Cluster the neighbor is currently assigned to.
    style_cluster_id: str
    # Raw vector distance reported by Typesense (lower is closer).
    vector_distance: float
    # Distance converted to a similarity in [0, 1] via _distance_to_similarity.
    similarity: float
|
||||
|
||||
|
||||
@dataclass
class HandwritingStyleAssignment:
    """Represents the chosen handwriting-style cluster assignment for one document."""

    # Cluster the document ended up in (existing or newly allocated).
    style_cluster_id: str
    # True when the document joined an existing cluster instead of starting one.
    matched_existing: bool
    # Similarity against the best neighbor (0.0 when no neighbor existed).
    similarity: float
    # Raw distance against the best neighbor (2.0 when no neighbor existed).
    vector_distance: float
    # Number of neighbors that were compared during assignment.
    compared_neighbors: int
    # Threshold used for established clusters.
    match_min_similarity: float
    # Stricter threshold used for clusters below the bootstrap sample size.
    bootstrap_match_min_similarity: float
|
||||
|
||||
|
||||
def _style_collection_name() -> str:
    """Builds the dedicated Typesense collection name used for handwriting-style vectors."""

    base_name = settings.typesense_collection_name
    return base_name + HANDWRITING_STYLE_COLLECTION_SUFFIX
|
||||
|
||||
|
||||
def _style_collection() -> Any:
    """Returns the Typesense collection handle for handwriting-style indexing."""

    return get_typesense_client().collections[_style_collection_name()]
|
||||
|
||||
|
||||
def _distance_to_similarity(vector_distance: float) -> float:
|
||||
"""Converts Typesense vector distance into conservative similarity in [0, 1]."""
|
||||
|
||||
return max(0.0, min(1.0, 1.0 - (vector_distance / 2.0)))
|
||||
|
||||
|
||||
def _encode_style_image_base64(image_data: bytes, image_max_side: int) -> str:
    """Normalizes and downsizes image bytes and returns a base64-encoded JPEG payload."""

    with Image.open(io.BytesIO(image_data)) as source:
        # Apply EXIF orientation and force RGB so JPEG encoding always succeeds.
        normalized = ImageOps.exif_transpose(source).convert("RGB")
        largest_side = max(normalized.width, normalized.height)
        if largest_side > image_max_side:
            ratio = image_max_side / largest_side
            new_size = (
                max(1, int(normalized.width * ratio)),
                max(1, int(normalized.height * ratio)),
            )
            normalized = normalized.resize(new_size, Image.Resampling.LANCZOS)

        buffer = io.BytesIO()
        normalized.save(buffer, format="JPEG", quality=86, optimize=True)
        return base64.b64encode(buffer.getvalue()).decode("ascii")
|
||||
|
||||
|
||||
def ensure_handwriting_style_collection() -> None:
    """Creates the handwriting-style Typesense collection when it is not present.

    Also recreates the collection from scratch when the configured embedding
    model differs from the model recorded in the existing schema, because
    vectors produced by different models are not comparable.
    """

    runtime_settings = read_handwriting_style_settings()
    # Blank runtime overrides fall back to the module default model.
    embed_model = str(runtime_settings.get("embed_model", HANDWRITING_STYLE_EMBED_MODEL)).strip() or HANDWRITING_STYLE_EMBED_MODEL
    collection = _style_collection()
    should_recreate_collection = False
    try:
        existing_schema = collection.retrieve()
        if isinstance(existing_schema, dict):
            existing_fields = existing_schema.get("fields", [])
            if isinstance(existing_fields, list):
                for field in existing_fields:
                    if not isinstance(field, dict):
                        continue
                    # Only the auto-embedding field carries the model configuration.
                    if str(field.get("name", "")).strip() != "embedding":
                        continue
                    embed_config = field.get("embed", {})
                    model_config = embed_config.get("model_config", {}) if isinstance(embed_config, dict) else {}
                    existing_model = str(model_config.get("model_name", "")).strip()
                    if existing_model and existing_model != embed_model:
                        should_recreate_collection = True
                        break
        # Collection exists and the embed model matches: nothing to do.
        if not should_recreate_collection:
            return
    except Exception as error:
        # A missing collection surfaces as a 404; anything else is a real failure.
        message = str(error).lower()
        if "404" not in message and "not found" not in message:
            raise

    client = get_typesense_client()
    if should_recreate_collection:
        # Drop the stale collection so entries are re-embedded with the new model.
        client.collections[_style_collection_name()].delete()

    schema = {
        "name": _style_collection_name(),
        "fields": [
            {
                "name": "style_cluster_id",
                "type": "string",
                "facet": True,
            },
            {
                "name": "image_text_type",
                "type": "string",
                "facet": True,
            },
            {
                "name": "created_at",
                "type": "int64",
            },
            {
                # Source image payload; not stored, only used to compute the embedding.
                "name": "image",
                "type": "image",
                "store": False,
            },
            {
                # Vector auto-embedded server-side from the image field.
                "name": "embedding",
                "type": "float[]",
                "embed": {
                    "from": ["image"],
                    "model_config": {
                        "model_name": embed_model,
                    },
                },
            },
        ],
        "default_sorting_field": "created_at",
    }
    client.collections.create(schema)
|
||||
|
||||
|
||||
def _search_style_neighbors(
    image_base64: str,
    limit: int,
    exclude_document_id: str | None = None,
) -> list[HandwritingStyleNeighbor]:
    """Returns nearest handwriting-style neighbors for one encoded image payload.

    Args:
        image_base64: Base64-encoded JPEG produced by _encode_style_image_base64.
        limit: Maximum number of neighbors to return (coerced to at least 1
            for the query itself).
        exclude_document_id: Optional document id to omit from the results,
            typically the document being (re)assigned itself.

    Returns:
        Neighbors in the order returned by Typesense, capped at ``limit``.
    """

    ensure_handwriting_style_collection()
    client = get_typesense_client()

    # Restrict to handwriting entries; optionally drop the query document itself.
    filter_clauses = [f"image_text_type:={IMAGE_TEXT_TYPE_HANDWRITING}"]
    if exclude_document_id:
        filter_clauses.append(f"id:!={exclude_document_id}")

    search_payload = {
        "q": "*",
        "query_by": "embedding",
        # Server-side embedding of the query image followed by a k-NN search.
        "vector_query": f"embedding:([], image:{image_base64}, k:{max(1, limit)})",
        # Raw vectors and image payloads are large and unused by callers.
        "exclude_fields": "embedding,image",
        "per_page": max(1, limit),
        "filter_by": " && ".join(filter_clauses),
    }
    response = client.multi_search.perform(
        {
            "searches": [
                {
                    "collection": _style_collection_name(),
                    **search_payload,
                }
            ]
        },
        {},
    )

    # Defensive unwrapping: every level of the response may deviate in shape.
    results = response.get("results", []) if isinstance(response, dict) else []
    first_result = results[0] if isinstance(results, list) and len(results) > 0 else {}
    hits = first_result.get("hits", []) if isinstance(first_result, dict) else []

    neighbors: list[HandwritingStyleNeighbor] = []
    for hit in hits:
        if not isinstance(hit, dict):
            continue
        document = hit.get("document")
        if not isinstance(document, dict):
            continue

        document_id = str(document.get("id", "")).strip()
        style_cluster_id = str(document.get("style_cluster_id", "")).strip()
        if not document_id or not style_cluster_id:
            continue

        try:
            vector_distance = float(hit.get("vector_distance", 2.0))
        except (TypeError, ValueError):
            # Treat malformed distances as maximally distant.
            vector_distance = 2.0

        neighbors.append(
            HandwritingStyleNeighbor(
                document_id=document_id,
                style_cluster_id=style_cluster_id,
                vector_distance=vector_distance,
                similarity=_distance_to_similarity(vector_distance),
            )
        )

        if len(neighbors) >= limit:
            break

    return neighbors
|
||||
|
||||
|
||||
def _next_style_cluster_id(session: Session) -> str:
    """Allocates the next stable handwriting-style folder identifier."""

    assigned_ids = session.execute(
        select(Document.handwriting_style_id).where(Document.handwriting_style_id.is_not(None))
    ).scalars().all()

    # Scan every known id and remember the highest numeric suffix seen.
    highest = 0
    for raw_id in assigned_ids:
        parsed = HANDWRITING_STYLE_ID_PATTERN.fullmatch(str(raw_id).strip())
        if parsed:
            highest = max(highest, int(parsed.group(1)))
    return f"{HANDWRITING_STYLE_ID_PREFIX}{highest + 1}"
|
||||
|
||||
|
||||
def _style_cluster_sample_size(session: Session, style_cluster_id: str) -> int:
    """Returns the number of indexed documents currently assigned to one style cluster."""

    count_statement = (
        select(func.count())
        .select_from(Document)
        .where(Document.handwriting_style_id == style_cluster_id)
        .where(Document.image_text_type == IMAGE_TEXT_TYPE_HANDWRITING)
    )
    return int(session.execute(count_statement).scalar_one())
|
||||
|
||||
|
||||
def assign_handwriting_style(
    session: Session,
    document: Document,
    image_data: bytes,
) -> HandwritingStyleAssignment:
    """Assigns a document to an existing handwriting-style cluster or creates a new one.

    The document image is embedded and compared against already-indexed
    handwriting samples. Clusters that still have few samples must clear a
    stricter bootstrap similarity threshold before accepting new members.
    The document is always upserted into the style collection under the
    chosen cluster id.

    Args:
        session: Active database session used for cluster bookkeeping.
        document: Document being assigned; its id keys the vector entry.
        image_data: Raw image bytes of the document.

    Returns:
        The chosen assignment including match diagnostics.
    """

    # Thresholds and sizing are runtime-configurable with module defaults.
    runtime_settings = read_handwriting_style_settings()
    image_max_side = int(runtime_settings.get("image_max_side", HANDWRITING_STYLE_IMAGE_MAX_SIDE))
    neighbor_limit = int(runtime_settings.get("neighbor_limit", HANDWRITING_STYLE_NEIGHBOR_LIMIT))
    match_min_similarity = float(runtime_settings.get("match_min_similarity", HANDWRITING_STYLE_MATCH_MIN_SIMILARITY))
    bootstrap_match_min_similarity = float(
        runtime_settings.get("bootstrap_match_min_similarity", HANDWRITING_STYLE_BOOTSTRAP_MIN_SIMILARITY)
    )
    bootstrap_sample_size = int(runtime_settings.get("bootstrap_sample_size", HANDWRITING_STYLE_BOOTSTRAP_SAMPLE_SIZE))

    image_base64 = _encode_style_image_base64(image_data, image_max_side=image_max_side)
    # Exclude the document itself so re-processing cannot self-match.
    neighbors = _search_style_neighbors(
        image_base64=image_base64,
        limit=neighbor_limit,
        exclude_document_id=str(document.id),
    )

    best_neighbor = neighbors[0] if neighbors else None
    similarity = best_neighbor.similarity if best_neighbor else 0.0
    vector_distance = best_neighbor.vector_distance if best_neighbor else 2.0
    cluster_sample_size = 0
    if best_neighbor:
        cluster_sample_size = _style_cluster_sample_size(
            session=session,
            style_cluster_id=best_neighbor.style_cluster_id,
        )
    # Sparse clusters require the stricter bootstrap threshold to accept members.
    required_similarity = (
        bootstrap_match_min_similarity
        if cluster_sample_size < bootstrap_sample_size
        else match_min_similarity
    )
    should_match_existing = (
        best_neighbor is not None and similarity >= required_similarity
    )

    if should_match_existing and best_neighbor:
        style_cluster_id = best_neighbor.style_cluster_id
        matched_existing = True
    else:
        # Keep a previously assigned valid cluster id instead of minting a new one.
        existing_style_cluster_id = (document.handwriting_style_id or "").strip()
        if HANDWRITING_STYLE_ID_PATTERN.fullmatch(existing_style_cluster_id):
            style_cluster_id = existing_style_cluster_id
        else:
            style_cluster_id = _next_style_cluster_id(session=session)
        matched_existing = False

    ensure_handwriting_style_collection()
    collection = _style_collection()
    payload = {
        "id": str(document.id),
        "style_cluster_id": style_cluster_id,
        "image_text_type": IMAGE_TEXT_TYPE_HANDWRITING,
        "created_at": int(document.created_at.timestamp()),
        "image": image_base64,
    }
    collection.documents.upsert(payload)

    return HandwritingStyleAssignment(
        style_cluster_id=style_cluster_id,
        matched_existing=matched_existing,
        similarity=similarity,
        vector_distance=vector_distance,
        compared_neighbors=len(neighbors),
        match_min_similarity=match_min_similarity,
        bootstrap_match_min_similarity=bootstrap_match_min_similarity,
    )
|
||||
|
||||
|
||||
def delete_handwriting_style_document(document_id: str) -> None:
    """Deletes one document id from the handwriting-style Typesense collection."""

    try:
        _style_collection().documents[document_id].delete()
    except Exception as error:
        # Missing entries are fine: deletion is idempotent.
        lowered = str(error).lower()
        if "404" not in lowered and "not found" not in lowered:
            raise
|
||||
|
||||
|
||||
def delete_many_handwriting_style_documents(document_ids: list[str]) -> None:
    """Deletes many document ids from the handwriting-style Typesense collection."""

    # Sequential per-id deletes; each one is idempotent on its own.
    for single_id in document_ids:
        delete_handwriting_style_document(single_id)
|
||||
|
||||
|
||||
def apply_handwriting_style_path(style_cluster_id: str | None, path_value: str | None) -> str | None:
    """Composes style-prefixed logical paths while preventing duplicate prefix nesting."""

    if path_value is None:
        return None

    trimmed_path = path_value.strip().strip("/")
    if not trimmed_path:
        return None

    style_prefix = (style_cluster_id or "").strip().strip("/")
    if not style_prefix:
        return trimmed_path

    parts = [part for part in trimmed_path.split("/") if part]
    # Drop any leading style-id segments left over from earlier compositions.
    while parts and HANDWRITING_STYLE_ID_PATTERN.fullmatch(parts[0]):
        parts.pop(0)
    # Drop an explicit duplicate of the target style prefix as well.
    if parts and parts[0].strip().lower() == style_prefix.lower():
        parts.pop(0)

    if not parts:
        return style_prefix

    return f"{style_prefix}/" + "/".join(parts)
|
||||
|
||||
|
||||
def resolve_handwriting_style_path_prefix(
    session: Session,
    style_cluster_id: str | None,
    *,
    exclude_document_id: str | None = None,
) -> str | None:
    """Resolves a stable path prefix for one style cluster, preferring known non-style root segments.

    Inspects the logical paths of other non-trashed handwriting documents in
    the same cluster and picks the most common first path segment, ignoring
    "inbox" and raw style-id segments. Falls back to the style id itself when
    no usable segment exists.

    Args:
        session: Active database session.
        style_cluster_id: Cluster whose path prefix should be resolved.
        exclude_document_id: Optional document id to leave out of the vote.

    Returns:
        The winning root segment, the style id as fallback, or None when no
        cluster id was supplied.
    """

    normalized_style = (style_cluster_id or "").strip()
    if not normalized_style:
        return None

    statement = select(Document.logical_path).where(
        Document.handwriting_style_id == normalized_style,
        Document.image_text_type == IMAGE_TEXT_TYPE_HANDWRITING,
        Document.status != DocumentStatus.TRASHED,
    )
    if exclude_document_id:
        statement = statement.where(Document.id != exclude_document_id)
    rows = session.execute(statement).scalars().all()

    # Vote on first segments case-insensitively, remembering original casing.
    segment_counts: dict[str, int] = {}
    segment_labels: dict[str, str] = {}
    for raw_path in rows:
        if not isinstance(raw_path, str):
            continue
        segments = [segment.strip() for segment in raw_path.split("/") if segment.strip()]
        if not segments:
            continue
        first_segment = segments[0]
        lowered = first_segment.lower()
        if lowered == "inbox":
            continue
        if HANDWRITING_STYLE_ID_PATTERN.fullmatch(first_segment):
            continue
        segment_counts[lowered] = segment_counts.get(lowered, 0) + 1
        if lowered not in segment_labels:
            segment_labels[lowered] = first_segment

    if not segment_counts:
        return normalized_style

    # Highest count wins; alphabetical order breaks ties deterministically.
    winner = sorted(
        segment_counts.items(),
        key=lambda item: (-item[1], item[0]),
    )[0][0]
    return segment_labels.get(winner, normalized_style)
|
||||
227
backend/app/services/model_runtime.py
Normal file
227
backend/app/services/model_runtime.py
Normal file
@@ -0,0 +1,227 @@
|
||||
"""Model runtime utilities for provider-bound LLM task execution."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
from openai import APIConnectionError, APIError, APITimeoutError, OpenAI
|
||||
|
||||
from app.services.app_settings import read_task_runtime_settings
|
||||
|
||||
|
||||
class ModelTaskError(Exception):
    """Raised when a model task request fails.

    Base class for the more specific model-runtime errors in this module.
    """
|
||||
|
||||
|
||||
class ModelTaskTimeoutError(ModelTaskError):
    """Raised when a model task request times out against the provider."""
|
||||
|
||||
|
||||
class ModelTaskDisabledError(ModelTaskError):
    """Raised when a model task is disabled in settings and must not run."""
|
||||
|
||||
|
||||
@dataclass
class ModelTaskRuntime:
    """Resolved runtime configuration for one task and provider."""

    # Name of the task this runtime was resolved for.
    task_name: str
    # Identifier of the configured provider entry.
    provider_id: str
    # Provider kind; currently only "openai_compatible" is supported.
    provider_type: str
    # Normalized endpoint base URL ending in /v1.
    base_url: str
    # Request timeout applied to the SDK client.
    timeout_seconds: int
    # API key, or "no-key-required" for keyless local servers.
    api_key: str
    # Model name configured for the task.
    model: str
    # System prompt configured for the task.
    prompt: str
|
||||
|
||||
def _normalize_base_url(raw_value: str) -> str:
|
||||
"""Normalizes provider base URL and appends /v1 for OpenAI-compatible servers."""
|
||||
|
||||
trimmed = raw_value.strip().rstrip("/")
|
||||
if not trimmed:
|
||||
return "https://api.openai.com/v1"
|
||||
|
||||
parsed = urlparse(trimmed)
|
||||
path = parsed.path or ""
|
||||
if not path.endswith("/v1"):
|
||||
path = f"{path}/v1" if path else "/v1"
|
||||
|
||||
return urlunparse(parsed._replace(path=path))
|
||||
|
||||
|
||||
def _should_fallback_to_chat(error: Exception) -> bool:
|
||||
"""Determines whether a responses API failure should fallback to chat completions."""
|
||||
|
||||
status_code = getattr(error, "status_code", None)
|
||||
if isinstance(status_code, int) and status_code in {400, 404, 405, 415, 422, 501}:
|
||||
return True
|
||||
|
||||
message = str(error).lower()
|
||||
fallback_markers = (
|
||||
"404",
|
||||
"not found",
|
||||
"unknown endpoint",
|
||||
"unsupported",
|
||||
"invalid url",
|
||||
"responses",
|
||||
)
|
||||
return any(marker in message for marker in fallback_markers)
|
||||
|
||||
|
||||
def _extract_text_from_response(response: Any) -> str:
|
||||
"""Extracts plain text from Responses API outputs."""
|
||||
|
||||
output_text = getattr(response, "output_text", None)
|
||||
if isinstance(output_text, str) and output_text.strip():
|
||||
return output_text.strip()
|
||||
|
||||
output_items = getattr(response, "output", None)
|
||||
if not isinstance(output_items, list):
|
||||
return ""
|
||||
|
||||
chunks: list[str] = []
|
||||
for item in output_items:
|
||||
item_data = item.model_dump() if hasattr(item, "model_dump") else item
|
||||
if not isinstance(item_data, dict):
|
||||
continue
|
||||
|
||||
item_type = item_data.get("type")
|
||||
if item_type == "output_text":
|
||||
text = str(item_data.get("text", "")).strip()
|
||||
if text:
|
||||
chunks.append(text)
|
||||
|
||||
if item_type == "message":
|
||||
for content in item_data.get("content", []) or []:
|
||||
if not isinstance(content, dict):
|
||||
continue
|
||||
if content.get("type") in {"output_text", "text"}:
|
||||
text = str(content.get("text", "")).strip()
|
||||
if text:
|
||||
chunks.append(text)
|
||||
|
||||
return "\n".join(chunks).strip()
|
||||
|
||||
|
||||
def _extract_text_from_chat_response(response: Any) -> str:
|
||||
"""Extracts text from Chat Completions API outputs."""
|
||||
|
||||
message_content = response.choices[0].message.content
|
||||
if isinstance(message_content, str):
|
||||
return message_content.strip()
|
||||
if not isinstance(message_content, list):
|
||||
return ""
|
||||
|
||||
chunks: list[str] = []
|
||||
for content in message_content:
|
||||
if not isinstance(content, dict):
|
||||
continue
|
||||
text = str(content.get("text", "")).strip()
|
||||
if text:
|
||||
chunks.append(text)
|
||||
return "\n".join(chunks).strip()
|
||||
|
||||
|
||||
def resolve_task_runtime(task_name: str) -> ModelTaskRuntime:
    """Resolves one task runtime including provider endpoint, model, and prompt."""

    runtime_payload = read_task_runtime_settings(task_name)
    task_config = runtime_payload["task"]
    provider_config = runtime_payload["provider"]

    # Disabled tasks must never reach a provider.
    if not bool(task_config.get("enabled", True)):
        raise ModelTaskDisabledError(f"task_disabled:{task_name}")

    provider_type = str(provider_config.get("provider_type", "openai_compatible")).strip()
    if provider_type != "openai_compatible":
        raise ModelTaskError(f"unsupported_provider_type:{provider_type}")

    return ModelTaskRuntime(
        task_name=task_name,
        provider_id=str(provider_config.get("id", "")),
        provider_type=provider_type,
        base_url=_normalize_base_url(str(provider_config.get("base_url", "https://api.openai.com/v1"))),
        timeout_seconds=int(provider_config.get("timeout_seconds", 45)),
        api_key=str(provider_config.get("api_key", "")).strip() or "no-key-required",
        model=str(task_config.get("model", "")).strip(),
        prompt=str(task_config.get("prompt", "")).strip(),
    )
|
||||
|
||||
|
||||
def _create_client(runtime: ModelTaskRuntime) -> OpenAI:
    """Builds an OpenAI SDK client for OpenAI-compatible provider endpoints.

    Args:
        runtime: Resolved task runtime carrying credentials and endpoint data.

    Returns:
        A client bound to the provider base URL with the configured timeout.
    """

    return OpenAI(
        api_key=runtime.api_key,
        base_url=runtime.base_url,
        timeout=runtime.timeout_seconds,
    )
|
||||
|
||||
|
||||
def complete_text_task(task_name: str, user_text: str, prompt_override: str | None = None) -> str:
    """Runs a text-only task against the configured provider and returns plain output text.

    Tries the Responses API first; when the provider does not support it
    (see _should_fallback_to_chat) or the response is empty, retries once
    via Chat Completions.

    Args:
        task_name: Registered task whose runtime settings should be used.
        user_text: User-role input text for the model.
        prompt_override: Optional system prompt replacing the configured one.

    Returns:
        The extracted plain-text model output (may be empty).

    Raises:
        ModelTaskTimeoutError: When either API call times out.
        ModelTaskError: For connection errors, non-fallback API errors, or
            any failure during the chat-completions fallback.
        ModelTaskDisabledError: Propagated from resolve_task_runtime when the
            task is disabled.
    """

    runtime = resolve_task_runtime(task_name)
    client = _create_client(runtime)
    # An all-whitespace override falls back to the configured prompt.
    prompt = (prompt_override or runtime.prompt).strip() or runtime.prompt

    try:
        response = client.responses.create(
            model=runtime.model,
            input=[
                {
                    "role": "system",
                    "content": [
                        {
                            "type": "input_text",
                            "text": prompt,
                        }
                    ],
                },
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "input_text",
                            "text": user_text,
                        }
                    ],
                },
            ],
        )
        text = _extract_text_from_response(response)
        if text:
            return text
        # Empty output falls through to the chat-completions fallback below.
    except APITimeoutError as error:
        raise ModelTaskTimeoutError(f"task_timeout:{task_name}") from error
    except APIConnectionError as error:
        raise ModelTaskError(f"task_error:{task_name}:{error}") from error
    except APIError as error:
        if not _should_fallback_to_chat(error):
            raise ModelTaskError(f"task_error:{task_name}:{error}") from error
    except Exception as error:
        if not _should_fallback_to_chat(error):
            raise ModelTaskError(f"task_error:{task_name}:{error}") from error

    try:
        fallback = client.chat.completions.create(
            model=runtime.model,
            messages=[
                {
                    "role": "system",
                    "content": prompt,
                },
                {
                    "role": "user",
                    "content": user_text,
                },
            ],
        )
        return _extract_text_from_chat_response(fallback)
    except APITimeoutError as error:
        raise ModelTaskTimeoutError(f"task_timeout:{task_name}") from error
    except (APIConnectionError, APIError) as error:
        raise ModelTaskError(f"task_error:{task_name}:{error}") from error
    except Exception as error:
        raise ModelTaskError(f"task_error:{task_name}:{error}") from error
|
||||
192
backend/app/services/processing_logs.py
Normal file
192
backend/app/services/processing_logs.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""Persistence helpers for writing and querying processing pipeline log events."""
|
||||
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.document import Document
|
||||
from app.models.processing_log import ProcessingLogEntry
|
||||
|
||||
|
||||
# Maximum persisted lengths; longer values are truncated with a "..." suffix by _trim.
MAX_STAGE_LENGTH = 64
MAX_EVENT_LENGTH = 256
MAX_LEVEL_LENGTH = 16
MAX_PROVIDER_LENGTH = 128
MAX_MODEL_LENGTH = 256
MAX_DOCUMENT_FILENAME_LENGTH = 512
MAX_PROMPT_LENGTH = 200000
MAX_RESPONSE_LENGTH = 200000
# Cleanup defaults: how many most-recent document log sessions and
# document-unbound entries survive cleanup_processing_logs.
DEFAULT_KEEP_DOCUMENT_SESSIONS = 2
DEFAULT_KEEP_UNBOUND_ENTRIES = 80
# Session.info key toggling immediate commit of log entries.
PROCESSING_LOG_AUTOCOMMIT_SESSION_KEY = "processing_log_autocommit"
|
||||
|
||||
|
||||
def _trim(value: str | None, max_length: int) -> str | None:
|
||||
"""Normalizes and truncates text values for safe log persistence."""
|
||||
|
||||
if value is None:
|
||||
return None
|
||||
normalized = value.strip()
|
||||
if not normalized:
|
||||
return None
|
||||
if len(normalized) <= max_length:
|
||||
return normalized
|
||||
return normalized[: max_length - 3] + "..."
|
||||
|
||||
|
||||
def _safe_payload(payload_json: dict[str, Any] | None) -> dict[str, Any]:
|
||||
"""Ensures payload values are persisted as dictionaries."""
|
||||
|
||||
return payload_json if isinstance(payload_json, dict) else {}
|
||||
|
||||
|
||||
def set_processing_log_autocommit(session: Session, enabled: bool) -> None:
    """Toggles per-session immediate commit behavior for processing log events.

    Args:
        session: Session whose ``info`` mapping carries the flag.
        enabled: True commits each log entry right after it is added.
    """

    session.info[PROCESSING_LOG_AUTOCOMMIT_SESSION_KEY] = bool(enabled)
|
||||
|
||||
|
||||
def is_processing_log_autocommit_enabled(session: Session) -> bool:
    """Returns whether processing logs are committed immediately for the current session.

    Defaults to False when the flag was never set on this session.
    """

    return bool(session.info.get(PROCESSING_LOG_AUTOCOMMIT_SESSION_KEY, False))
|
||||
|
||||
|
||||
def log_processing_event(
    session: Session,
    stage: str,
    event: str,
    *,
    level: str = "info",
    document: Document | None = None,
    document_id: UUID | None = None,
    document_filename: str | None = None,
    provider_id: str | None = None,
    model_name: str | None = None,
    prompt_text: str | None = None,
    response_text: str | None = None,
    payload_json: dict[str, Any] | None = None,
) -> None:
    """Persists one processing log entry linked to an optional document context.

    All text fields are normalized and truncated (see _trim) before storage.
    When ``document`` is given it takes precedence over ``document_id`` and
    ``document_filename``. The entry is committed immediately only when the
    session enabled autocommit via set_processing_log_autocommit.

    Args:
        session: Active database session receiving the entry.
        stage: Pipeline stage name (falls back to "pipeline").
        event: Event name within the stage (falls back to "event").
        level: Log severity label (falls back to "info").
        document: Optional document providing id and filename context.
        document_id: Explicit document id when no document object is at hand.
        document_filename: Explicit filename when no document object is at hand.
        provider_id: Optional model provider identifier.
        model_name: Optional model name used for the logged call.
        prompt_text: Optional prompt sent to the model.
        response_text: Optional response returned by the model.
        payload_json: Optional structured extras; non-dicts are stored as {}.
    """

    resolved_document_id = document.id if document is not None else document_id
    resolved_document_filename = document.original_filename if document is not None else document_filename

    entry = ProcessingLogEntry(
        level=_trim(level, MAX_LEVEL_LENGTH) or "info",
        stage=_trim(stage, MAX_STAGE_LENGTH) or "pipeline",
        event=_trim(event, MAX_EVENT_LENGTH) or "event",
        document_id=resolved_document_id,
        document_filename=_trim(resolved_document_filename, MAX_DOCUMENT_FILENAME_LENGTH),
        provider_id=_trim(provider_id, MAX_PROVIDER_LENGTH),
        model_name=_trim(model_name, MAX_MODEL_LENGTH),
        prompt_text=_trim(prompt_text, MAX_PROMPT_LENGTH),
        response_text=_trim(response_text, MAX_RESPONSE_LENGTH),
        payload_json=_safe_payload(payload_json),
    )
    session.add(entry)
    if is_processing_log_autocommit_enabled(session):
        session.commit()
|
||||
|
||||
|
||||
def count_processing_logs(session: Session, document_id: UUID | None = None) -> int:
    """Counts persisted processing logs, optionally restricted to one document."""

    count_statement = select(func.count()).select_from(ProcessingLogEntry)
    if document_id is not None:
        count_statement = count_statement.where(ProcessingLogEntry.document_id == document_id)
    return int(session.execute(count_statement).scalar_one())
|
||||
|
||||
|
||||
def list_processing_logs(
    session: Session,
    *,
    limit: int,
    offset: int,
    document_id: UUID | None = None,
) -> list[ProcessingLogEntry]:
    """Lists processing logs ordered by newest-first with optional document filter.

    Args:
        session: Active database session.
        limit: Maximum number of entries to return.
        offset: Number of newest entries to skip for pagination.
        document_id: When given, restricts results to that document's entries.

    Returns:
        Matching entries, newest first (id descending breaks created_at ties).
    """

    statement = select(ProcessingLogEntry)
    if document_id is not None:
        statement = statement.where(ProcessingLogEntry.document_id == document_id)
    statement = (
        statement.order_by(ProcessingLogEntry.created_at.desc(), ProcessingLogEntry.id.desc())
        .offset(offset)
        .limit(limit)
    )
    # ScalarResult.all() is typed as Sequence; materialize to match the declared list return.
    return list(session.execute(statement).scalars())
|
||||
|
||||
|
||||
def cleanup_processing_logs(
    session: Session,
    *,
    keep_document_sessions: int = DEFAULT_KEEP_DOCUMENT_SESSIONS,
    keep_unbound_entries: int = DEFAULT_KEEP_UNBOUND_ENTRIES,
) -> dict[str, int]:
    """Deletes old log entries while keeping recent document sessions and unbound events.

    Document-bound entries are kept for the ``keep_document_sessions`` most
    recently active documents; all other document-bound entries are removed.
    Document-unbound entries are kept newest-first up to ``keep_unbound_entries``.

    Args:
        session: Active database session (caller is responsible for committing).
        keep_document_sessions: Number of most recent document log sessions to keep.
        keep_unbound_entries: Number of newest unbound entries to keep.

    Returns:
        Counts of deleted document-bound and unbound entries.
    """

    # Negative inputs are treated as zero (delete everything in that category).
    normalized_keep_sessions = max(0, keep_document_sessions)
    normalized_keep_unbound = max(0, keep_unbound_entries)
    deleted_document_entries = 0
    deleted_unbound_entries = 0

    # Identify the documents with the most recent log activity.
    recent_document_rows = session.execute(
        select(
            ProcessingLogEntry.document_id,
            func.max(ProcessingLogEntry.created_at).label("last_seen"),
        )
        .where(ProcessingLogEntry.document_id.is_not(None))
        .group_by(ProcessingLogEntry.document_id)
        .order_by(func.max(ProcessingLogEntry.created_at).desc())
        .limit(normalized_keep_sessions)
    ).all()
    keep_document_ids = [row[0] for row in recent_document_rows if row[0] is not None]

    if keep_document_ids:
        deleted_document_entries = int(
            session.execute(
                delete(ProcessingLogEntry).where(
                    ProcessingLogEntry.document_id.is_not(None),
                    ProcessingLogEntry.document_id.notin_(keep_document_ids),
                )
            ).rowcount
            or 0
        )
    else:
        # Nothing to keep: drop every document-bound entry.
        deleted_document_entries = int(
            session.execute(delete(ProcessingLogEntry).where(ProcessingLogEntry.document_id.is_not(None))).rowcount or 0
        )

    # Same retention pattern for entries without a document: keep the newest N ids.
    keep_unbound_rows = session.execute(
        select(ProcessingLogEntry.id)
        .where(ProcessingLogEntry.document_id.is_(None))
        .order_by(ProcessingLogEntry.created_at.desc(), ProcessingLogEntry.id.desc())
        .limit(normalized_keep_unbound)
    ).all()
    keep_unbound_ids = [row[0] for row in keep_unbound_rows]

    if keep_unbound_ids:
        deleted_unbound_entries = int(
            session.execute(
                delete(ProcessingLogEntry).where(
                    ProcessingLogEntry.document_id.is_(None),
                    ProcessingLogEntry.id.notin_(keep_unbound_ids),
                )
            ).rowcount
            or 0
        )
    else:
        deleted_unbound_entries = int(
            session.execute(delete(ProcessingLogEntry).where(ProcessingLogEntry.document_id.is_(None))).rowcount or 0
        )

    return {
        "deleted_document_entries": deleted_document_entries,
        "deleted_unbound_entries": deleted_unbound_entries,
    }
|
||||
|
||||
|
||||
def clear_processing_logs(session: Session) -> dict[str, int]:
    """Deletes all persisted processing log entries and returns deletion count."""

    result = session.execute(delete(ProcessingLogEntry))
    return {"deleted_entries": int(result.rowcount or 0)}
|
||||
1129
backend/app/services/routing_pipeline.py
Normal file
1129
backend/app/services/routing_pipeline.py
Normal file
File diff suppressed because it is too large
Load Diff
59
backend/app/services/storage.py
Normal file
59
backend/app/services/storage.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""File storage utilities for persistence, retrieval, and checksum calculation."""
|
||||
|
||||
import hashlib
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
|
||||
from app.core.config import get_settings
|
||||
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
def ensure_storage() -> None:
    """Ensures required storage directories exist at service startup."""

    required_subdirs = ("originals", "derived/previews", "tmp")
    for subdir in required_subdirs:
        target = settings.storage_root / subdir
        target.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def compute_sha256(data: bytes) -> str:
    """Returns the hex-encoded SHA-256 digest of the given raw bytes."""

    digest = hashlib.sha256()
    digest.update(data)
    return digest.hexdigest()
|
||||
|
||||
|
||||
def store_bytes(filename: str, data: bytes) -> str:
    """Persists file content under originals/<date>/<uuid><ext> and returns the storage-relative path.

    The original filename contributes only its (lower-cased) extension; the
    stored name is a fresh UUID so uploads can never collide.
    """

    date_folder = datetime.now(UTC).strftime("%Y/%m/%d")
    extension = Path(filename).suffix.lower()
    destination_dir = settings.storage_root / "originals" / date_folder
    destination_dir.mkdir(parents=True, exist_ok=True)
    destination = destination_dir / f"{uuid.uuid4()}{extension}"
    destination.write_bytes(data)
    return str(destination.relative_to(settings.storage_root))
|
||||
|
||||
|
||||
def read_bytes(relative_path: str) -> bytes:
    """Loads the full contents of a file stored relative to the storage root."""

    full_path = settings.storage_root / relative_path
    return full_path.read_bytes()
|
||||
|
||||
|
||||
def absolute_path(relative_path: str) -> Path:
    """Resolves a storage-relative location to its absolute filesystem path."""

    root: Path = settings.storage_root
    return root / relative_path
|
||||
|
||||
|
||||
def write_preview(document_id: str, data: bytes, suffix: str = ".jpg") -> str:
    """Writes preview bytes for a document and returns the path relative to the storage root.

    Previews are keyed by document id, so writing again for the same document
    overwrites the previous preview file.
    """

    previews_dir = settings.storage_root / "derived" / "previews"
    previews_dir.mkdir(parents=True, exist_ok=True)
    preview_path = previews_dir / f"{document_id}{suffix}"
    preview_path.write_bytes(data)
    return str(preview_path.relative_to(settings.storage_root))
|
||||
257
backend/app/services/typesense_index.py
Normal file
257
backend/app/services/typesense_index.py
Normal file
@@ -0,0 +1,257 @@
|
||||
"""Typesense indexing and semantic-neighbor retrieval for document routing."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import typesense
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.models.document import Document, DocumentStatus
|
||||
|
||||
|
||||
settings = get_settings()
|
||||
MAX_TYPESENSE_QUERY_CHARS = 600
|
||||
|
||||
|
||||
@dataclass
class SimilarDocument:
    """Represents one nearest-neighbor document returned by Typesense semantic search."""

    # Identifier of the matched document in the index ("id" field).
    document_id: str
    # Indexed original filename of the matched document.
    document_name: str
    # Stored summary text used for the embedding match.
    summary_text: str
    # Human-assigned logical folder path of the matched document.
    logical_path: str
    # Human-assigned tags of the matched document.
    tags: list[str]
    # Cosine-style distance reported by Typesense; lower means more similar.
    vector_distance: float
|
||||
|
||||
|
||||
def _build_client() -> typesense.Client:
    """Constructs a Typesense API client from the configured node and credentials."""

    node = {
        "host": settings.typesense_host,
        "port": str(settings.typesense_port),
        "protocol": settings.typesense_protocol,
    }
    client_config = {
        "nodes": [node],
        "api_key": settings.typesense_api_key,
        "connection_timeout_seconds": settings.typesense_timeout_seconds,
        "num_retries": settings.typesense_num_retries,
    }
    return typesense.Client(client_config)
|
||||
|
||||
|
||||
# Module-level cache: the client is built lazily on first use.
_client: typesense.Client | None = None


def get_typesense_client() -> typesense.Client:
    """Lazily builds and memoizes one Typesense client for indexing and search."""

    global _client
    if _client is not None:
        return _client
    _client = _build_client()
    return _client
|
||||
|
||||
|
||||
def _collection() -> Any:
    """Returns the handle for the configured Typesense document collection."""

    return get_typesense_client().collections[settings.typesense_collection_name]
|
||||
|
||||
|
||||
def ensure_typesense_collection() -> None:
    """Creates the document semantic collection when it does not already exist.

    A successful retrieve means the collection is present and nothing is done.
    Only a 404/not-found error triggers creation; any other error is re-raised.
    """

    collection = _collection()
    try:
        collection.retrieve()
    except Exception as error:
        lowered = str(error).lower()
        # Anything other than a missing collection is unexpected.
        if "404" not in lowered and "not found" not in lowered:
            raise
    else:
        return

    schema = {
        "name": settings.typesense_collection_name,
        "fields": [
            {"name": "document_name", "type": "string"},
            {"name": "summary_text", "type": "string"},
            {"name": "logical_path", "type": "string", "facet": True},
            {"name": "tags", "type": "string[]", "facet": True},
            {"name": "status", "type": "string", "facet": True},
            {"name": "mime_type", "type": "string", "optional": True, "facet": True},
            {"name": "extension", "type": "string", "optional": True, "facet": True},
            {"name": "created_at", "type": "int64"},
            {"name": "has_labels", "type": "bool", "facet": True},
            {
                # Auto-embedding field: Typesense computes the vector from the
                # listed source fields using the configured built-in model.
                "name": "embedding",
                "type": "float[]",
                "embed": {
                    "from": ["document_name", "summary_text"],
                    "model_config": {
                        "model_name": "ts/e5-small-v2",
                        "indexing_prefix": "passage:",
                        "query_prefix": "query:",
                    },
                },
            },
        ],
        "default_sorting_field": "created_at",
    }
    get_typesense_client().collections.create(schema)
|
||||
|
||||
|
||||
def _has_labels(document: Document) -> bool:
    """Determines whether a document has usable human-assigned routing metadata.

    A document counts as labeled when its logical path is non-empty and not the
    default "inbox", or when it carries at least one non-blank tag.
    """

    path = document.logical_path.strip()
    if path and path.lower() != "inbox":
        return True
    return any(tag.strip() for tag in document.tags)
|
||||
|
||||
|
||||
def upsert_document_index(document: Document, summary_text: str) -> None:
    """Upserts one document into Typesense for semantic retrieval and routing examples.

    Ensures the collection exists first, then writes the document's metadata
    and (capped) summary text; the summary is truncated to 50k characters and
    tags are limited to the first 50 non-blank entries.
    """

    ensure_typesense_collection()
    cleaned_tags = [tag for tag in document.tags if tag.strip()]
    # Trashed documents are indexed but never offered as labeled examples.
    labeled = _has_labels(document) and document.status != DocumentStatus.TRASHED
    record = {
        "id": str(document.id),
        "document_name": document.original_filename,
        "summary_text": summary_text[:50000],
        "logical_path": document.logical_path,
        "tags": cleaned_tags[:50],
        "status": document.status.value,
        "mime_type": document.mime_type,
        "extension": document.extension,
        "created_at": int(document.created_at.timestamp()),
        "has_labels": labeled,
    }
    _collection().documents.upsert(record)
|
||||
|
||||
|
||||
def delete_document_index(document_id: str) -> None:
    """Deletes one document from Typesense by identifier, tolerating missing entries."""

    try:
        _collection().documents[document_id].delete()
    except Exception as error:
        lowered = str(error).lower()
        # A 404 means the document was never indexed (or already removed).
        if "404" not in lowered and "not found" not in lowered:
            raise
|
||||
|
||||
|
||||
def delete_many_documents_index(document_ids: list[str]) -> None:
    """Removes each of the given documents from the Typesense index."""

    for identifier in document_ids:
        delete_document_index(identifier)
|
||||
|
||||
|
||||
def query_similar_documents(summary_text: str, limit: int, exclude_document_id: str | None = None) -> list[SimilarDocument]:
    """Returns semantic nearest neighbors among labeled, non-trashed indexed documents.

    Args:
        summary_text: Free text whose embedding drives the nearest-neighbor search.
        limit: Maximum number of neighbors to return; non-positive returns [].
        exclude_document_id: Optional identifier dropped from the results
            (typically the query document itself).

    Returns:
        Up to ``limit`` SimilarDocument entries in Typesense ranking order.
    """

    # Guard: the break check below runs after appending, so without this a
    # non-positive limit would still return one neighbor.
    if limit <= 0:
        return []

    ensure_typesense_collection()
    collection = _collection()
    normalized_query = " ".join(summary_text.strip().split())
    query_text = normalized_query[:MAX_TYPESENSE_QUERY_CHARS] if normalized_query else "document"
    # Fetch one extra candidate when excluding a document so the post-fetch
    # exclusion cannot silently shrink the result set below `limit`.
    fetch_count = limit + (1 if exclude_document_id else 0)
    search_payload = {
        "q": query_text,
        "query_by": "embedding",
        "vector_query": f"embedding:([], k:{fetch_count})",
        "exclude_fields": "embedding",
        "per_page": fetch_count,
        "filter_by": "has_labels:=true && status:!=trashed",
    }

    try:
        response = collection.documents.search(search_payload)
    except Exception as error:
        message = str(error).lower()
        if "query string exceeds max allowed length" not in message:
            raise
        # Retry with a minimal placeholder query; ranking still comes from the vector.
        fallback_payload = dict(search_payload)
        fallback_payload["q"] = "document"
        response = collection.documents.search(fallback_payload)
    hits = response.get("hits", []) if isinstance(response, dict) else []

    neighbors: list[SimilarDocument] = []
    for hit in hits:
        if not isinstance(hit, dict):
            continue
        document = hit.get("document", {})
        if not isinstance(document, dict):
            continue

        document_id = str(document.get("id", "")).strip()
        if not document_id:
            continue
        if exclude_document_id and document_id == exclude_document_id:
            continue

        raw_tags = document.get("tags", [])
        tags = [str(tag).strip() for tag in raw_tags if str(tag).strip()] if isinstance(raw_tags, list) else []
        try:
            distance = float(hit.get("vector_distance", 2.0))
        except (TypeError, ValueError):
            # Missing or malformed distance: treat as maximally distant.
            distance = 2.0

        neighbors.append(
            SimilarDocument(
                document_id=document_id,
                document_name=str(document.get("document_name", "")).strip(),
                summary_text=str(document.get("summary_text", "")).strip(),
                logical_path=str(document.get("logical_path", "")).strip(),
                tags=tags,
                vector_distance=distance,
            )
        )

        if len(neighbors) >= limit:
            break

    return neighbors
|
||||
Reference in New Issue
Block a user