Initial commit

This commit is contained in:
2026-02-21 09:44:18 -03:00
commit 5dfc2cbd85
65 changed files with 11989 additions and 0 deletions

1
backend/app/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Backend application package for the DMS service."""

View File

@@ -0,0 +1 @@
"""API package containing route modules and router registration."""

17
backend/app/api/router.py Normal file
View File

@@ -0,0 +1,17 @@
"""API router registration for all HTTP route modules."""
from fastapi import APIRouter
from app.api.routes_documents import router as documents_router
from app.api.routes_health import router as health_router
from app.api.routes_processing_logs import router as processing_logs_router
from app.api.routes_search import router as search_router
from app.api.routes_settings import router as settings_router
api_router = APIRouter()
api_router.include_router(health_router)
api_router.include_router(documents_router, prefix="/documents", tags=["documents"])
api_router.include_router(processing_logs_router, prefix="/processing/logs", tags=["processing-logs"])
api_router.include_router(search_router, prefix="/search", tags=["search"])
api_router.include_router(settings_router, prefix="/settings", tags=["settings"])

View File

@@ -0,0 +1,725 @@
"""Document CRUD, lifecycle, metadata, file access, and content export endpoints."""
import io
import re
import unicodedata
import zipfile
from datetime import datetime, time
from pathlib import Path
from typing import Annotated, Literal
from uuid import UUID
from fastapi import APIRouter, Depends, File, Form, HTTPException, Query, UploadFile
from fastapi.responses import FileResponse, Response, StreamingResponse
from sqlalchemy import or_, func, select
from sqlalchemy.orm import Session
from app.services.app_settings import read_predefined_paths_settings, read_predefined_tags_settings
from app.db.base import get_session
from app.models.document import Document, DocumentStatus
from app.schemas.documents import (
ContentExportRequest,
DocumentDetailResponse,
DocumentResponse,
DocumentsListResponse,
DocumentUpdateRequest,
UploadConflict,
UploadResponse,
)
from app.services.extractor import sniff_mime
from app.services.handwriting_style import delete_many_handwriting_style_documents
from app.services.processing_logs import log_processing_event, set_processing_log_autocommit
from app.services.storage import absolute_path, compute_sha256, store_bytes
from app.services.typesense_index import delete_many_documents_index, upsert_document_index
from app.worker.queue import get_processing_queue
router = APIRouter()
def _parse_csv(value: str | None) -> list[str]:
"""Parses comma-separated query values into a normalized non-empty list."""
if not value:
return []
return [part.strip() for part in value.split(",") if part.strip()]
def _parse_date(value: str | None) -> datetime | None:
"""Parses ISO date strings into UTC-naive midnight datetimes."""
if not value:
return None
try:
parsed = datetime.fromisoformat(value)
return parsed
except ValueError:
pass
try:
date_value = datetime.strptime(value, "%Y-%m-%d").date()
return datetime.combine(date_value, time.min)
except ValueError:
return None
def _apply_discovery_filters(
statement,
*,
path_filter: str | None,
tag_filter: str | None,
type_filter: str | None,
processed_from: str | None,
processed_to: str | None,
):
"""Applies optional path/tag/type/date filters to list and search statements."""
if path_filter and path_filter.strip():
statement = statement.where(Document.logical_path.ilike(f"{path_filter.strip()}%"))
tags = _parse_csv(tag_filter)
if tags:
statement = statement.where(Document.tags.overlap(tags))
types = _parse_csv(type_filter)
if types:
type_clauses = []
for value in types:
lowered = value.lower()
type_clauses.append(Document.extension.ilike(lowered))
type_clauses.append(Document.mime_type.ilike(lowered))
type_clauses.append(Document.image_text_type.ilike(lowered))
statement = statement.where(or_(*type_clauses))
processed_from_dt = _parse_date(processed_from)
if processed_from_dt is not None:
statement = statement.where(Document.processed_at.is_not(None), Document.processed_at >= processed_from_dt)
processed_to_dt = _parse_date(processed_to)
if processed_to_dt is not None:
statement = statement.where(Document.processed_at.is_not(None), Document.processed_at <= processed_to_dt)
return statement
def _summary_for_index(document: Document) -> str:
"""Resolves best-available summary text for semantic index updates outside worker pipeline."""
candidate = document.metadata_json.get("summary_text")
if isinstance(candidate, str) and candidate.strip():
return candidate.strip()
extracted = document.extracted_text.strip()
if extracted:
return extracted[:12000]
return f"{document.original_filename}\n{document.mime_type}\n{document.logical_path}"
def _normalize_tags(raw_tags: str | None) -> list[str]:
"""Parses comma-separated tags into a cleaned unique list."""
if not raw_tags:
return []
tags = [tag.strip() for tag in raw_tags.split(",") if tag.strip()]
return list(dict.fromkeys(tags))[:50]
def _sanitize_filename(filename: str) -> str:
"""Normalizes user-supplied filenames while preserving readability and extensions."""
base = filename.strip().replace("\\", " ").replace("/", " ")
base = re.sub(r"\s+", " ", base)
return base[:512] or "document"
def _slugify_segment(value: str) -> str:
"""Creates a filesystem-safe slug for path segments and markdown file names."""
normalized = unicodedata.normalize("NFKD", value)
ascii_text = normalized.encode("ascii", "ignore").decode("ascii")
cleaned = re.sub(r"[^a-zA-Z0-9._ -]+", "", ascii_text).strip()
compact = re.sub(r"\s+", "-", cleaned)
compact = compact.strip(".-_")
return compact[:120] or "document"
def _markdown_for_document(document: Document) -> str:
"""Builds a markdown representation of extracted document content and metadata."""
lines = [
f"# {document.original_filename}",
"",
f"- Document ID: `{document.id}`",
f"- Logical Path: `{document.logical_path}`",
f"- Source Path: `{document.source_relative_path}`",
f"- Tags: {', '.join(document.tags) if document.tags else '(none)' }",
"",
"## Extracted Content",
"",
]
if document.extracted_text.strip():
lines.append(document.extracted_text)
else:
lines.append("_No extracted text available for this document._")
return "\n".join(lines).strip() + "\n"
def _markdown_filename(document: Document) -> str:
"""Builds a deterministic markdown filename for a single document export."""
stem = Path(document.original_filename).stem or document.original_filename
slug = _slugify_segment(stem)
return f"{slug}-{str(document.id)[:8]}.md"
def _zip_entry_name(document: Document, used_names: set[str]) -> str:
"""Builds a unique zip entry path for a document markdown export."""
path_segments = [segment for segment in document.logical_path.split("/") if segment]
sanitized_segments = [_slugify_segment(segment) for segment in path_segments]
filename = _markdown_filename(document)
base_entry = "/".join([*sanitized_segments, filename]) if sanitized_segments else filename
entry = base_entry
suffix = 1
while entry in used_names:
stem = Path(filename).stem
ext = Path(filename).suffix
candidate = f"{stem}-{suffix}{ext}"
entry = "/".join([*sanitized_segments, candidate]) if sanitized_segments else candidate
suffix += 1
used_names.add(entry)
return entry
def _resolve_previous_status(metadata_json: dict, fallback_status: DocumentStatus) -> DocumentStatus:
"""Resolves the status to restore from trash using recorded metadata."""
raw_status = metadata_json.get("status_before_trash")
if isinstance(raw_status, str):
try:
parsed = DocumentStatus(raw_status)
if parsed != DocumentStatus.TRASHED:
return parsed
except ValueError:
pass
return fallback_status
def _build_document_list_statement(
only_trashed: bool,
include_trashed: bool,
path_prefix: str | None,
):
"""Builds a base SQLAlchemy select statement with lifecycle and path filters."""
statement = select(Document)
if only_trashed:
statement = statement.where(Document.status == DocumentStatus.TRASHED)
elif not include_trashed:
statement = statement.where(Document.status != DocumentStatus.TRASHED)
if path_prefix:
trimmed_prefix = path_prefix.strip()
if trimmed_prefix:
statement = statement.where(Document.logical_path.ilike(f"{trimmed_prefix}%"))
return statement
def _collect_document_tree(session: Session, root_document_id: UUID) -> list[tuple[int, Document]]:
"""Collects a document and all descendants for recursive permanent deletion."""
queue: list[tuple[UUID, int]] = [(root_document_id, 0)]
visited: set[UUID] = set()
collected: list[tuple[int, Document]] = []
while queue:
current_id, depth = queue.pop(0)
if current_id in visited:
continue
visited.add(current_id)
document = session.execute(select(Document).where(Document.id == current_id)).scalar_one_or_none()
if document is None:
continue
collected.append((depth, document))
child_ids = session.execute(
select(Document.id).where(Document.parent_document_id == current_id)
).scalars().all()
for child_id in child_ids:
queue.append((child_id, depth + 1))
collected.sort(key=lambda item: item[0], reverse=True)
return collected
@router.get("", response_model=DocumentsListResponse)
def list_documents(
offset: int = Query(default=0, ge=0),
limit: int = Query(default=50, ge=1, le=200),
include_trashed: bool = Query(default=False),
only_trashed: bool = Query(default=False),
path_prefix: str | None = Query(default=None),
path_filter: str | None = Query(default=None),
tag_filter: str | None = Query(default=None),
type_filter: str | None = Query(default=None),
processed_from: str | None = Query(default=None),
processed_to: str | None = Query(default=None),
session: Session = Depends(get_session),
) -> DocumentsListResponse:
"""Returns paginated documents ordered by newest upload timestamp."""
base_statement = _build_document_list_statement(
only_trashed=only_trashed,
include_trashed=include_trashed,
path_prefix=path_prefix,
)
base_statement = _apply_discovery_filters(
base_statement,
path_filter=path_filter,
tag_filter=tag_filter,
type_filter=type_filter,
processed_from=processed_from,
processed_to=processed_to,
)
statement = base_statement.order_by(Document.created_at.desc()).offset(offset).limit(limit)
items = session.execute(statement).scalars().all()
count_statement = select(func.count()).select_from(base_statement.subquery())
total = session.execute(count_statement).scalar_one()
return DocumentsListResponse(total=total, items=[DocumentResponse.model_validate(item) for item in items])
@router.get("/tags")
def list_tags(
include_trashed: bool = Query(default=False),
session: Session = Depends(get_session),
) -> dict[str, list[str]]:
"""Returns distinct tags currently assigned across all matching documents."""
statement = select(Document.tags)
if not include_trashed:
statement = statement.where(Document.status != DocumentStatus.TRASHED)
rows = session.execute(statement).scalars().all()
tags = {tag for row in rows for tag in row if tag}
tags.update(
str(item.get("value", "")).strip()
for item in read_predefined_tags_settings()
if str(item.get("value", "")).strip()
)
tags = sorted(tags)
return {"tags": tags}
@router.get("/paths")
def list_paths(
include_trashed: bool = Query(default=False),
session: Session = Depends(get_session),
) -> dict[str, list[str]]:
"""Returns distinct logical paths currently assigned across all matching documents."""
statement = select(Document.logical_path)
if not include_trashed:
statement = statement.where(Document.status != DocumentStatus.TRASHED)
rows = session.execute(statement).scalars().all()
paths = {row for row in rows if row}
paths.update(
str(item.get("value", "")).strip()
for item in read_predefined_paths_settings()
if str(item.get("value", "")).strip()
)
paths = sorted(paths)
return {"paths": paths}
@router.get("/types")
def list_types(
include_trashed: bool = Query(default=False),
session: Session = Depends(get_session),
) -> dict[str, list[str]]:
"""Returns distinct document type values from extension, MIME, and image text type."""
statement = select(Document.extension, Document.mime_type, Document.image_text_type)
if not include_trashed:
statement = statement.where(Document.status != DocumentStatus.TRASHED)
rows = session.execute(statement).all()
values: set[str] = set()
for extension, mime_type, image_text_type in rows:
for candidate in (extension, mime_type, image_text_type):
normalized = str(candidate).strip().lower() if isinstance(candidate, str) else ""
if normalized:
values.add(normalized)
return {"types": sorted(values)}
@router.post("/content-md/export")
def export_contents_markdown(
payload: ContentExportRequest,
session: Session = Depends(get_session),
) -> StreamingResponse:
"""Exports extracted contents for selected documents as individual markdown files in a ZIP archive."""
has_document_ids = len(payload.document_ids) > 0
has_path_prefix = bool(payload.path_prefix and payload.path_prefix.strip())
if not has_document_ids and not has_path_prefix:
raise HTTPException(status_code=400, detail="Provide document_ids or path_prefix for export")
statement = select(Document)
if has_document_ids:
statement = statement.where(Document.id.in_(payload.document_ids))
if has_path_prefix:
statement = statement.where(Document.logical_path.ilike(f"{payload.path_prefix.strip()}%"))
if payload.only_trashed:
statement = statement.where(Document.status == DocumentStatus.TRASHED)
elif not payload.include_trashed:
statement = statement.where(Document.status != DocumentStatus.TRASHED)
documents = session.execute(statement.order_by(Document.logical_path.asc(), Document.created_at.asc())).scalars().all()
if not documents:
raise HTTPException(status_code=404, detail="No matching documents found for export")
archive_buffer = io.BytesIO()
used_entries: set[str] = set()
with zipfile.ZipFile(archive_buffer, mode="w", compression=zipfile.ZIP_DEFLATED) as archive:
for document in documents:
entry_name = _zip_entry_name(document, used_entries)
archive.writestr(entry_name, _markdown_for_document(document))
archive_buffer.seek(0)
headers = {"Content-Disposition": 'attachment; filename="document-contents-md.zip"'}
return StreamingResponse(archive_buffer, media_type="application/zip", headers=headers)
@router.get("/{document_id}", response_model=DocumentDetailResponse)
def get_document(document_id: UUID, session: Session = Depends(get_session)) -> DocumentDetailResponse:
"""Returns one document by unique identifier."""
document = session.execute(select(Document).where(Document.id == document_id)).scalar_one_or_none()
if document is None:
raise HTTPException(status_code=404, detail="Document not found")
return DocumentDetailResponse.model_validate(document)
@router.get("/{document_id}/download")
def download_document(document_id: UUID, session: Session = Depends(get_session)) -> FileResponse:
"""Downloads original document bytes for the requested document identifier."""
document = session.execute(select(Document).where(Document.id == document_id)).scalar_one_or_none()
if document is None:
raise HTTPException(status_code=404, detail="Document not found")
file_path = absolute_path(document.stored_relative_path)
return FileResponse(path=file_path, filename=document.original_filename, media_type=document.mime_type)
@router.get("/{document_id}/preview")
def preview_document(document_id: UUID, session: Session = Depends(get_session)) -> FileResponse:
"""Streams the original document inline when browser rendering is supported."""
document = session.execute(select(Document).where(Document.id == document_id)).scalar_one_or_none()
if document is None:
raise HTTPException(status_code=404, detail="Document not found")
original_path = absolute_path(document.stored_relative_path)
return FileResponse(path=original_path, media_type=document.mime_type)
@router.get("/{document_id}/thumbnail")
def thumbnail_document(document_id: UUID, session: Session = Depends(get_session)) -> FileResponse:
"""Returns a generated thumbnail image for dashboard card previews."""
document = session.execute(select(Document).where(Document.id == document_id)).scalar_one_or_none()
if document is None:
raise HTTPException(status_code=404, detail="Document not found")
preview_relative_path = document.metadata_json.get("preview_relative_path")
if not preview_relative_path:
raise HTTPException(status_code=404, detail="Thumbnail not available")
preview_path = absolute_path(preview_relative_path)
if not preview_path.exists():
raise HTTPException(status_code=404, detail="Thumbnail file not found")
return FileResponse(path=preview_path)
@router.get("/{document_id}/content-md")
def download_document_content_markdown(document_id: UUID, session: Session = Depends(get_session)) -> Response:
"""Downloads extracted content for one document as a markdown file."""
document = session.execute(select(Document).where(Document.id == document_id)).scalar_one_or_none()
if document is None:
raise HTTPException(status_code=404, detail="Document not found")
markdown_content = _markdown_for_document(document)
filename = _markdown_filename(document)
headers = {"Content-Disposition": f'attachment; filename="{filename}"'}
return Response(content=markdown_content, media_type="text/markdown; charset=utf-8", headers=headers)
@router.post("/upload", response_model=UploadResponse)
async def upload_documents(
files: Annotated[list[UploadFile], File(description="Files to upload")],
relative_paths: Annotated[list[str] | None, Form()] = None,
logical_path: Annotated[str, Form()] = "Inbox",
tags: Annotated[str | None, Form()] = None,
conflict_mode: Annotated[Literal["ask", "replace", "duplicate"], Form()] = "ask",
session: Session = Depends(get_session),
) -> UploadResponse:
"""Uploads files, records metadata, and enqueues asynchronous extraction tasks."""
set_processing_log_autocommit(session, True)
normalized_tags = _normalize_tags(tags)
queue = get_processing_queue()
uploaded: list[DocumentResponse] = []
conflicts: list[UploadConflict] = []
indexed_relative_paths = relative_paths or []
prepared_uploads: list[dict[str, object]] = []
for idx, file in enumerate(files):
filename = file.filename or f"uploaded_{idx}"
data = await file.read()
sha256 = compute_sha256(data)
source_relative_path = indexed_relative_paths[idx] if idx < len(indexed_relative_paths) else filename
extension = Path(filename).suffix.lower()
detected_mime = sniff_mime(data)
log_processing_event(
session=session,
stage="upload",
event="Upload request received",
level="info",
document_filename=filename,
payload_json={
"source_relative_path": source_relative_path,
"logical_path": logical_path,
"tags": normalized_tags,
"mime_type": detected_mime,
"size_bytes": len(data),
"conflict_mode": conflict_mode,
},
)
prepared_uploads.append(
{
"filename": filename,
"data": data,
"sha256": sha256,
"source_relative_path": source_relative_path,
"extension": extension,
"mime_type": detected_mime,
}
)
existing = session.execute(select(Document).where(Document.sha256 == sha256)).scalar_one_or_none()
if existing and conflict_mode == "ask":
log_processing_event(
session=session,
stage="upload",
event="Upload conflict detected",
level="warning",
document_id=existing.id,
document_filename=filename,
payload_json={
"sha256": sha256,
"existing_document_id": str(existing.id),
},
)
conflicts.append(
UploadConflict(
original_filename=filename,
sha256=sha256,
existing_document_id=existing.id,
)
)
if conflicts and conflict_mode == "ask":
session.commit()
return UploadResponse(uploaded=[], conflicts=conflicts)
for prepared in prepared_uploads:
existing = session.execute(
select(Document).where(Document.sha256 == str(prepared["sha256"]))
).scalar_one_or_none()
replaces_document_id = existing.id if existing and conflict_mode == "replace" else None
stored_relative_path = store_bytes(str(prepared["filename"]), bytes(prepared["data"]))
document = Document(
original_filename=str(prepared["filename"]),
source_relative_path=str(prepared["source_relative_path"]),
stored_relative_path=stored_relative_path,
mime_type=str(prepared["mime_type"]),
extension=str(prepared["extension"]),
sha256=str(prepared["sha256"]),
size_bytes=len(bytes(prepared["data"])),
logical_path=logical_path,
tags=list(normalized_tags),
replaces_document_id=replaces_document_id,
metadata_json={"upload": "web"},
)
session.add(document)
session.flush()
queue.enqueue("app.worker.tasks.process_document_task", str(document.id))
log_processing_event(
session=session,
stage="upload",
event="Document record created and queued",
level="info",
document=document,
payload_json={
"source_relative_path": document.source_relative_path,
"stored_relative_path": document.stored_relative_path,
"logical_path": document.logical_path,
"tags": list(document.tags),
"replaces_document_id": str(replaces_document_id) if replaces_document_id is not None else None,
},
)
uploaded.append(DocumentResponse.model_validate(document))
session.commit()
return UploadResponse(uploaded=uploaded, conflicts=conflicts)
@router.patch("/{document_id}", response_model=DocumentResponse)
def update_document(
document_id: UUID,
payload: DocumentUpdateRequest,
session: Session = Depends(get_session),
) -> DocumentResponse:
"""Updates document metadata and refreshes semantic index representation."""
document = session.execute(select(Document).where(Document.id == document_id)).scalar_one_or_none()
if document is None:
raise HTTPException(status_code=404, detail="Document not found")
if payload.original_filename is not None:
document.original_filename = _sanitize_filename(payload.original_filename)
if payload.logical_path is not None:
document.logical_path = payload.logical_path.strip() or "Inbox"
if payload.tags is not None:
document.tags = list(dict.fromkeys([tag.strip() for tag in payload.tags if tag.strip()]))[:50]
try:
upsert_document_index(document=document, summary_text=_summary_for_index(document))
except Exception:
pass
session.commit()
session.refresh(document)
return DocumentResponse.model_validate(document)
@router.post("/{document_id}/trash", response_model=DocumentResponse)
def trash_document(document_id: UUID, session: Session = Depends(get_session)) -> DocumentResponse:
"""Marks a document as trashed without deleting files from storage."""
document = session.execute(select(Document).where(Document.id == document_id)).scalar_one_or_none()
if document is None:
raise HTTPException(status_code=404, detail="Document not found")
if document.status != DocumentStatus.TRASHED:
document.metadata_json = {
**document.metadata_json,
"status_before_trash": document.status.value,
}
document.status = DocumentStatus.TRASHED
try:
upsert_document_index(document=document, summary_text=_summary_for_index(document))
except Exception:
pass
session.commit()
session.refresh(document)
return DocumentResponse.model_validate(document)
@router.post("/{document_id}/restore", response_model=DocumentResponse)
def restore_document(document_id: UUID, session: Session = Depends(get_session)) -> DocumentResponse:
"""Restores a trashed document to its previous lifecycle status."""
document = session.execute(select(Document).where(Document.id == document_id)).scalar_one_or_none()
if document is None:
raise HTTPException(status_code=404, detail="Document not found")
if document.status == DocumentStatus.TRASHED:
fallback = DocumentStatus.PROCESSED if document.processed_at else DocumentStatus.QUEUED
restored_status = _resolve_previous_status(document.metadata_json, fallback)
document.status = restored_status
metadata_json = dict(document.metadata_json)
metadata_json.pop("status_before_trash", None)
document.metadata_json = metadata_json
try:
upsert_document_index(document=document, summary_text=_summary_for_index(document))
except Exception:
pass
session.commit()
session.refresh(document)
return DocumentResponse.model_validate(document)
@router.delete("/{document_id}")
def delete_document(document_id: UUID, session: Session = Depends(get_session)) -> dict[str, int]:
"""Permanently deletes a document and all descendant archive members including stored files."""
root = session.execute(select(Document).where(Document.id == document_id)).scalar_one_or_none()
if root is None:
raise HTTPException(status_code=404, detail="Document not found")
if root.status != DocumentStatus.TRASHED:
raise HTTPException(status_code=400, detail="Move document to trash before permanent deletion")
document_tree = _collect_document_tree(session=session, root_document_id=document_id)
document_ids = [document.id for _, document in document_tree]
try:
delete_many_documents_index([str(current_id) for current_id in document_ids])
except Exception:
pass
try:
delete_many_handwriting_style_documents([str(current_id) for current_id in document_ids])
except Exception:
pass
deleted_files = 0
for _, document in document_tree:
source_path = absolute_path(document.stored_relative_path)
if source_path.exists() and source_path.is_file():
source_path.unlink(missing_ok=True)
deleted_files += 1
preview_relative_path = document.metadata_json.get("preview_relative_path")
if isinstance(preview_relative_path, str):
preview_path = absolute_path(preview_relative_path)
if preview_path.exists() and preview_path.is_file():
preview_path.unlink(missing_ok=True)
session.delete(document)
session.commit()
return {"deleted_documents": len(document_tree), "deleted_files": deleted_files}
@router.post("/{document_id}/reprocess", response_model=DocumentResponse)
def reprocess_document(document_id: UUID, session: Session = Depends(get_session)) -> DocumentResponse:
"""Re-enqueues a document for extraction and suggestion processing."""
document = session.execute(select(Document).where(Document.id == document_id)).scalar_one_or_none()
if document is None:
raise HTTPException(status_code=404, detail="Document not found")
if document.status == DocumentStatus.TRASHED:
raise HTTPException(status_code=400, detail="Restore document before reprocessing")
queue = get_processing_queue()
document.status = DocumentStatus.QUEUED
try:
upsert_document_index(document=document, summary_text=_summary_for_index(document))
except Exception:
pass
session.commit()
queue.enqueue("app.worker.tasks.process_document_task", str(document.id))
session.refresh(document)
return DocumentResponse.model_validate(document)

View File

@@ -0,0 +1,13 @@
"""Health and readiness endpoints for orchestration and uptime checks."""
from fastapi import APIRouter
router = APIRouter(prefix="/health", tags=["health"])
@router.get("")
def health() -> dict[str, str]:
"""Returns service liveness status."""
return {"status": "ok"}

View File

@@ -0,0 +1,66 @@
"""Read-only API endpoints for processing pipeline event logs."""
from uuid import UUID
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from app.db.base import get_session
from app.schemas.processing_logs import ProcessingLogEntryResponse, ProcessingLogListResponse
from app.services.processing_logs import (
cleanup_processing_logs,
clear_processing_logs,
count_processing_logs,
list_processing_logs,
)
router = APIRouter()
@router.get("", response_model=ProcessingLogListResponse)
def get_processing_logs(
offset: int = Query(default=0, ge=0),
limit: int = Query(default=120, ge=1, le=400),
document_id: UUID | None = Query(default=None),
session: Session = Depends(get_session),
) -> ProcessingLogListResponse:
"""Returns paginated processing logs ordered from newest to oldest."""
items = list_processing_logs(
session=session,
limit=limit,
offset=offset,
document_id=document_id,
)
total = count_processing_logs(session=session, document_id=document_id)
return ProcessingLogListResponse(
total=total,
items=[ProcessingLogEntryResponse.model_validate(item) for item in items],
)
@router.post("/trim")
def trim_processing_logs(
keep_document_sessions: int = Query(default=2, ge=0, le=20),
keep_unbound_entries: int = Query(default=80, ge=0, le=400),
session: Session = Depends(get_session),
) -> dict[str, int]:
"""Deletes old processing logs while keeping recent document sessions and unbound events."""
result = cleanup_processing_logs(
session=session,
keep_document_sessions=keep_document_sessions,
keep_unbound_entries=keep_unbound_entries,
)
session.commit()
return result
@router.post("/clear")
def clear_all_processing_logs(session: Session = Depends(get_session)) -> dict[str, int]:
"""Deletes all processing logs to reset the diagnostics timeline."""
result = clear_processing_logs(session=session)
session.commit()
return result

View File

@@ -0,0 +1,84 @@
"""Search endpoints for full-text and metadata document discovery."""
from fastapi import APIRouter, Depends, Query
from sqlalchemy import Text, cast, func, select
from sqlalchemy.orm import Session
from app.api.routes_documents import _apply_discovery_filters
from app.db.base import get_session
from app.models.document import Document, DocumentStatus
from app.schemas.documents import DocumentResponse, SearchResponse
router = APIRouter()
@router.get("", response_model=SearchResponse)
def search_documents(
query: str = Query(min_length=2),
offset: int = Query(default=0, ge=0),
limit: int = Query(default=50, ge=1, le=200),
include_trashed: bool = Query(default=False),
only_trashed: bool = Query(default=False),
path_filter: str | None = Query(default=None),
tag_filter: str | None = Query(default=None),
type_filter: str | None = Query(default=None),
processed_from: str | None = Query(default=None),
processed_to: str | None = Query(default=None),
session: Session = Depends(get_session),
) -> SearchResponse:
"""Searches documents using PostgreSQL full-text ranking plus metadata matching."""
vector = func.to_tsvector(
"simple",
func.coalesce(Document.original_filename, "")
+ " "
+ func.coalesce(Document.logical_path, "")
+ " "
+ func.coalesce(Document.extracted_text, "")
+ " "
+ func.coalesce(cast(Document.tags, Text), ""),
)
ts_query = func.plainto_tsquery("simple", query)
rank = func.ts_rank_cd(vector, ts_query)
search_filter = (
vector.op("@@")(ts_query)
| Document.original_filename.ilike(f"%{query}%")
| Document.logical_path.ilike(f"%{query}%")
| cast(Document.tags, Text).ilike(f"%{query}%")
)
statement = select(Document).where(search_filter)
if only_trashed:
statement = statement.where(Document.status == DocumentStatus.TRASHED)
elif not include_trashed:
statement = statement.where(Document.status != DocumentStatus.TRASHED)
statement = _apply_discovery_filters(
statement,
path_filter=path_filter,
tag_filter=tag_filter,
type_filter=type_filter,
processed_from=processed_from,
processed_to=processed_to,
)
statement = statement.order_by(rank.desc(), Document.created_at.desc()).offset(offset).limit(limit)
items = session.execute(statement).scalars().all()
count_statement = select(func.count(Document.id)).where(search_filter)
if only_trashed:
count_statement = count_statement.where(Document.status == DocumentStatus.TRASHED)
elif not include_trashed:
count_statement = count_statement.where(Document.status != DocumentStatus.TRASHED)
count_statement = _apply_discovery_filters(
count_statement,
path_filter=path_filter,
tag_filter=tag_filter,
type_filter=type_filter,
processed_from=processed_from,
processed_to=processed_to,
)
total = session.execute(count_statement).scalar_one()
return SearchResponse(total=total, items=[DocumentResponse.model_validate(item) for item in items])

View File

@@ -0,0 +1,232 @@
"""API routes for managing persistent single-user application settings."""
from fastapi import APIRouter
from app.schemas.settings import (
AppSettingsUpdateRequest,
AppSettingsResponse,
DisplaySettingsResponse,
HandwritingSettingsResponse,
HandwritingStyleSettingsResponse,
HandwritingSettingsUpdateRequest,
OcrTaskSettingsResponse,
ProviderSettingsResponse,
RoutingTaskSettingsResponse,
SummaryTaskSettingsResponse,
TaskSettingsResponse,
UploadDefaultsResponse,
)
from app.services.app_settings import (
TASK_OCR_HANDWRITING,
TASK_ROUTING_CLASSIFICATION,
TASK_SUMMARY_GENERATION,
read_app_settings,
reset_app_settings,
update_app_settings,
update_handwriting_settings,
)
router = APIRouter()
def _build_response(payload: dict) -> AppSettingsResponse:
"""Converts internal settings dictionaries into API response models."""
upload_defaults_payload = payload.get("upload_defaults", {})
display_payload = payload.get("display", {})
providers_payload = payload.get("providers", [])
tasks_payload = payload.get("tasks", {})
handwriting_style_payload = payload.get("handwriting_style_clustering", {})
ocr_payload = tasks_payload.get(TASK_OCR_HANDWRITING, {})
summary_payload = tasks_payload.get(TASK_SUMMARY_GENERATION, {})
routing_payload = tasks_payload.get(TASK_ROUTING_CLASSIFICATION, {})
return AppSettingsResponse(
upload_defaults=UploadDefaultsResponse(
logical_path=str(upload_defaults_payload.get("logical_path", "Inbox")),
tags=[
str(tag).strip()
for tag in upload_defaults_payload.get("tags", [])
if isinstance(tag, str) and tag.strip()
],
),
display=DisplaySettingsResponse(
cards_per_page=int(display_payload.get("cards_per_page", 12)),
log_typing_animation_enabled=bool(display_payload.get("log_typing_animation_enabled", True)),
),
handwriting_style_clustering=HandwritingStyleSettingsResponse(
enabled=bool(handwriting_style_payload.get("enabled", True)),
embed_model=str(handwriting_style_payload.get("embed_model", "ts/clip-vit-b-p32")),
neighbor_limit=int(handwriting_style_payload.get("neighbor_limit", 8)),
match_min_similarity=float(handwriting_style_payload.get("match_min_similarity", 0.86)),
bootstrap_match_min_similarity=float(
handwriting_style_payload.get("bootstrap_match_min_similarity", 0.89)
),
bootstrap_sample_size=int(handwriting_style_payload.get("bootstrap_sample_size", 3)),
image_max_side=int(handwriting_style_payload.get("image_max_side", 1024)),
),
predefined_paths=[
{
"value": str(item.get("value", "")).strip(),
"global_shared": bool(item.get("global_shared", False)),
}
for item in payload.get("predefined_paths", [])
if isinstance(item, dict) and str(item.get("value", "")).strip()
],
predefined_tags=[
{
"value": str(item.get("value", "")).strip(),
"global_shared": bool(item.get("global_shared", False)),
}
for item in payload.get("predefined_tags", [])
if isinstance(item, dict) and str(item.get("value", "")).strip()
],
providers=[
ProviderSettingsResponse(
id=str(provider.get("id", "")),
label=str(provider.get("label", "")),
provider_type=str(provider.get("provider_type", "openai_compatible")),
base_url=str(provider.get("base_url", "https://api.openai.com/v1")),
timeout_seconds=int(provider.get("timeout_seconds", 45)),
api_key_set=bool(provider.get("api_key_set", False)),
api_key_masked=str(provider.get("api_key_masked", "")),
)
for provider in providers_payload
],
tasks=TaskSettingsResponse(
ocr_handwriting=OcrTaskSettingsResponse(
enabled=bool(ocr_payload.get("enabled", True)),
provider_id=str(ocr_payload.get("provider_id", "openai-default")),
model=str(ocr_payload.get("model", "gpt-4.1-mini")),
prompt=str(ocr_payload.get("prompt", "")),
),
summary_generation=SummaryTaskSettingsResponse(
enabled=bool(summary_payload.get("enabled", True)),
provider_id=str(summary_payload.get("provider_id", "openai-default")),
model=str(summary_payload.get("model", "gpt-4.1-mini")),
prompt=str(summary_payload.get("prompt", "")),
max_input_tokens=int(summary_payload.get("max_input_tokens", 8000)),
),
routing_classification=RoutingTaskSettingsResponse(
enabled=bool(routing_payload.get("enabled", True)),
provider_id=str(routing_payload.get("provider_id", "openai-default")),
model=str(routing_payload.get("model", "gpt-4.1-mini")),
prompt=str(routing_payload.get("prompt", "")),
neighbor_count=int(routing_payload.get("neighbor_count", 8)),
neighbor_min_similarity=float(routing_payload.get("neighbor_min_similarity", 0.84)),
auto_apply_confidence_threshold=float(routing_payload.get("auto_apply_confidence_threshold", 0.78)),
auto_apply_neighbor_similarity_threshold=float(
routing_payload.get("auto_apply_neighbor_similarity_threshold", 0.55)
),
neighbor_path_override_enabled=bool(routing_payload.get("neighbor_path_override_enabled", True)),
neighbor_path_override_min_similarity=float(
routing_payload.get("neighbor_path_override_min_similarity", 0.86)
),
neighbor_path_override_min_gap=float(routing_payload.get("neighbor_path_override_min_gap", 0.04)),
neighbor_path_override_max_confidence=float(
routing_payload.get("neighbor_path_override_max_confidence", 0.9)
),
),
),
)
@router.get("", response_model=AppSettingsResponse)
def get_app_settings() -> AppSettingsResponse:
"""Returns persisted provider and per-task settings configuration."""
return _build_response(read_app_settings())
@router.patch("", response_model=AppSettingsResponse)
def set_app_settings(payload: AppSettingsUpdateRequest) -> AppSettingsResponse:
"""Updates providers and task settings and returns resulting persisted configuration."""
providers_payload = None
if payload.providers is not None:
providers_payload = [provider.model_dump() for provider in payload.providers]
tasks_payload = None
if payload.tasks is not None:
tasks_payload = payload.tasks.model_dump(exclude_none=True)
upload_defaults_payload = None
if payload.upload_defaults is not None:
upload_defaults_payload = payload.upload_defaults.model_dump(exclude_none=True)
display_payload = None
if payload.display is not None:
display_payload = payload.display.model_dump(exclude_none=True)
handwriting_style_payload = None
if payload.handwriting_style_clustering is not None:
handwriting_style_payload = payload.handwriting_style_clustering.model_dump(exclude_none=True)
predefined_paths_payload = None
if payload.predefined_paths is not None:
predefined_paths_payload = [item.model_dump(exclude_none=True) for item in payload.predefined_paths]
predefined_tags_payload = None
if payload.predefined_tags is not None:
predefined_tags_payload = [item.model_dump(exclude_none=True) for item in payload.predefined_tags]
updated = update_app_settings(
providers=providers_payload,
tasks=tasks_payload,
upload_defaults=upload_defaults_payload,
display=display_payload,
handwriting_style=handwriting_style_payload,
predefined_paths=predefined_paths_payload,
predefined_tags=predefined_tags_payload,
)
return _build_response(updated)
@router.post("/reset", response_model=AppSettingsResponse)
def reset_settings_to_defaults() -> AppSettingsResponse:
"""Resets all persisted settings to default providers and task bindings."""
return _build_response(reset_app_settings())
@router.patch("/handwriting", response_model=AppSettingsResponse)
def set_handwriting_settings(payload: HandwritingSettingsUpdateRequest) -> AppSettingsResponse:
"""Updates handwriting transcription settings and returns the resulting configuration."""
updated = update_handwriting_settings(
enabled=payload.enabled,
openai_base_url=payload.openai_base_url,
openai_model=payload.openai_model,
openai_timeout_seconds=payload.openai_timeout_seconds,
openai_api_key=payload.openai_api_key,
clear_openai_api_key=payload.clear_openai_api_key,
)
return _build_response(updated)
@router.get("/handwriting", response_model=HandwritingSettingsResponse)
def get_handwriting_settings() -> HandwritingSettingsResponse:
"""Returns legacy handwriting response shape for compatibility with older clients."""
payload = _build_response(read_app_settings())
fallback_provider = ProviderSettingsResponse(
id="openai-default",
label="OpenAI Default",
provider_type="openai_compatible",
base_url="https://api.openai.com/v1",
timeout_seconds=45,
api_key_set=False,
api_key_masked="",
)
ocr = payload.tasks.ocr_handwriting
provider = next((item for item in payload.providers if item.id == ocr.provider_id), None)
if provider is None:
provider = payload.providers[0] if payload.providers else fallback_provider
return HandwritingSettingsResponse(
provider=provider.provider_type,
enabled=ocr.enabled,
openai_base_url=provider.base_url,
openai_model=ocr.model,
openai_timeout_seconds=provider.timeout_seconds,
openai_api_key_set=provider.api_key_set,
openai_api_key_masked=provider.api_key_masked,
)

View File

@@ -0,0 +1 @@
"""Core settings and shared configuration package."""

View File

@@ -0,0 +1,46 @@
"""Application settings and environment configuration."""
from functools import lru_cache
from pathlib import Path
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
"""Defines runtime configuration values loaded from environment variables."""
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")
app_name: str = "dcm-dms"
app_env: str = "development"
database_url: str = "postgresql+psycopg://dcm:dcm@db:5432/dcm"
redis_url: str = "redis://redis:6379/0"
storage_root: Path = Path("/data/storage")
upload_chunk_size: int = 4 * 1024 * 1024
max_zip_members: int = 250
max_zip_depth: int = 2
max_text_length: int = 500_000
default_openai_base_url: str = "https://api.openai.com/v1"
default_openai_model: str = "gpt-4.1-mini"
default_openai_timeout_seconds: int = 45
default_openai_handwriting_enabled: bool = True
default_openai_api_key: str = ""
default_summary_model: str = "gpt-4.1-mini"
default_routing_model: str = "gpt-4.1-mini"
typesense_protocol: str = "http"
typesense_host: str = "typesense"
typesense_port: int = 8108
typesense_api_key: str = "dcm-typesense-key"
typesense_collection_name: str = "documents"
typesense_timeout_seconds: int = 120
typesense_num_retries: int = 0
public_base_url: str = "http://localhost:8000"
cors_origins: list[str] = Field(default_factory=lambda: ["http://localhost:5173", "http://localhost:3000"])
@lru_cache(maxsize=1)
def get_settings() -> Settings:
"""Returns a cached settings object for dependency injection and service access."""
return Settings()

View File

@@ -0,0 +1 @@
"""Database package exposing engine and session utilities."""

53
backend/app/db/base.py Normal file
View File

@@ -0,0 +1,53 @@
"""Database engine and session utilities for persistence operations."""
from collections.abc import Generator
from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session, declarative_base, sessionmaker
from app.core.config import get_settings
Base = declarative_base()
settings = get_settings()
engine = create_engine(settings.database_url, pool_pre_ping=True)
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False, expire_on_commit=False)
def get_session() -> Generator[Session, None, None]:
"""Provides a transactional database session for FastAPI request handling."""
session = SessionLocal()
try:
yield session
finally:
session.close()
def init_db() -> None:
"""Initializes all ORM tables and search-related database extensions/indexes."""
from app import models # noqa: F401
Base.metadata.create_all(bind=engine)
with engine.begin() as connection:
connection.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm"))
connection.execute(
text(
"""
CREATE INDEX IF NOT EXISTS idx_documents_text_search
ON documents
USING GIN (
to_tsvector(
'simple',
coalesce(original_filename, '') || ' ' ||
coalesce(logical_path, '') || ' ' ||
coalesce(extracted_text, '')
)
)
"""
)
)
connection.execute(text("CREATE INDEX IF NOT EXISTS idx_documents_sha256 ON documents (sha256)"))

50
backend/app/main.py Normal file
View File

@@ -0,0 +1,50 @@
"""FastAPI entrypoint for the DMS backend service."""
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.api.router import api_router
from app.core.config import get_settings
from app.db.base import init_db
from app.services.app_settings import ensure_app_settings
from app.services.handwriting_style import ensure_handwriting_style_collection
from app.services.storage import ensure_storage
from app.services.typesense_index import ensure_typesense_collection
settings = get_settings()
def create_app() -> FastAPI:
"""Builds and configures the FastAPI application instance."""
app = FastAPI(title="DCM DMS API", version="0.1.0")
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.include_router(api_router, prefix="/api/v1")
@app.on_event("startup")
def startup_event() -> None:
"""Initializes storage directories and database schema on service startup."""
ensure_storage()
ensure_app_settings()
init_db()
try:
ensure_typesense_collection()
except Exception:
pass
try:
ensure_handwriting_style_collection()
except Exception:
pass
return app
app = create_app()

View File

@@ -0,0 +1,6 @@
"""Model exports for ORM metadata discovery."""
from app.models.document import Document, DocumentStatus
from app.models.processing_log import ProcessingLogEntry
__all__ = ["Document", "DocumentStatus", "ProcessingLogEntry"]

View File

@@ -0,0 +1,65 @@
"""Data model representing a stored and processed document."""
import uuid
from datetime import UTC, datetime
from enum import Enum
from sqlalchemy import Boolean, DateTime, Enum as SqlEnum, ForeignKey, Integer, String, Text
from sqlalchemy.dialects.postgresql import ARRAY, JSONB, UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.db.base import Base
class DocumentStatus(str, Enum):
"""Enumerates processing states for uploaded documents."""
QUEUED = "queued"
PROCESSED = "processed"
UNSUPPORTED = "unsupported"
ERROR = "error"
TRASHED = "trashed"
class Document(Base):
"""Stores file identity, storage paths, extracted content, and classification metadata."""
__tablename__ = "documents"
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
original_filename: Mapped[str] = mapped_column(String(512), nullable=False)
source_relative_path: Mapped[str] = mapped_column(String(1024), nullable=False, default="")
stored_relative_path: Mapped[str] = mapped_column(String(1024), nullable=False)
mime_type: Mapped[str] = mapped_column(String(255), nullable=False, default="application/octet-stream")
extension: Mapped[str] = mapped_column(String(32), nullable=False, default="")
sha256: Mapped[str] = mapped_column(String(128), nullable=False)
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
logical_path: Mapped[str] = mapped_column(String(1024), nullable=False, default="Inbox")
suggested_path: Mapped[str | None] = mapped_column(String(1024), nullable=True)
tags: Mapped[list[str]] = mapped_column(ARRAY(String), nullable=False, default=list)
suggested_tags: Mapped[list[str]] = mapped_column(ARRAY(String), nullable=False, default=list)
metadata_json: Mapped[dict] = mapped_column(JSONB, nullable=False, default=dict)
extracted_text: Mapped[str] = mapped_column(Text, nullable=False, default="")
image_text_type: Mapped[str | None] = mapped_column(String(64), nullable=True)
handwriting_style_id: Mapped[str | None] = mapped_column(String(64), nullable=True, index=True)
status: Mapped[DocumentStatus] = mapped_column(SqlEnum(DocumentStatus), nullable=False, default=DocumentStatus.QUEUED)
preview_available: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
archived_member_path: Mapped[str | None] = mapped_column(String(1024), nullable=True)
is_archive_member: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
parent_document_id: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), ForeignKey("documents.id"), nullable=True)
replaces_document_id: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), ForeignKey("documents.id"), nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(UTC))
processed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
default=lambda: datetime.now(UTC),
onupdate=lambda: datetime.now(UTC),
)
parent_document: Mapped["Document | None"] = relationship(
"Document",
remote_side="Document.id",
foreign_keys=[parent_document_id],
post_update=True,
)

View File

@@ -0,0 +1,33 @@
"""Data model representing one persisted processing pipeline log entry."""
import uuid
from datetime import UTC, datetime
from sqlalchemy import BigInteger, DateTime, ForeignKey, String, Text
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import Mapped, mapped_column
from app.db.base import Base
class ProcessingLogEntry(Base):
"""Stores a timestamped processing event with optional model prompt and response text."""
__tablename__ = "processing_logs"
id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(UTC))
level: Mapped[str] = mapped_column(String(16), nullable=False, default="info")
stage: Mapped[str] = mapped_column(String(64), nullable=False)
event: Mapped[str] = mapped_column(String(256), nullable=False)
document_id: Mapped[uuid.UUID | None] = mapped_column(
UUID(as_uuid=True),
ForeignKey("documents.id", ondelete="SET NULL"),
nullable=True,
)
document_filename: Mapped[str | None] = mapped_column(String(512), nullable=True)
provider_id: Mapped[str | None] = mapped_column(String(128), nullable=True)
model_name: Mapped[str | None] = mapped_column(String(256), nullable=True)
prompt_text: Mapped[str | None] = mapped_column(Text, nullable=True)
response_text: Mapped[str | None] = mapped_column(Text, nullable=True)
payload_json: Mapped[dict] = mapped_column(JSONB, nullable=False, default=dict)

View File

@@ -0,0 +1 @@
"""Pydantic schema package for API request and response models."""

View File

@@ -0,0 +1,92 @@
"""Pydantic schema definitions for document API payloads."""
from datetime import datetime
from uuid import UUID
from pydantic import BaseModel, Field
from app.models.document import DocumentStatus
class DocumentResponse(BaseModel):
"""Represents a document record returned by API endpoints."""
id: UUID
original_filename: str
source_relative_path: str
mime_type: str
extension: str
size_bytes: int
sha256: str
logical_path: str
suggested_path: str | None
image_text_type: str | None
handwriting_style_id: str | None
tags: list[str] = Field(default_factory=list)
suggested_tags: list[str] = Field(default_factory=list)
status: DocumentStatus
preview_available: bool
is_archive_member: bool
archived_member_path: str | None
parent_document_id: UUID | None
replaces_document_id: UUID | None
created_at: datetime
processed_at: datetime | None
class Config:
"""Enables ORM object parsing for SQLAlchemy model instances."""
from_attributes = True
class DocumentDetailResponse(DocumentResponse):
"""Represents a full document payload including extracted text content."""
extracted_text: str
metadata_json: dict
class DocumentsListResponse(BaseModel):
"""Represents a paginated document list response payload."""
total: int
items: list[DocumentResponse]
class UploadConflict(BaseModel):
"""Describes an upload conflict where a matching checksum already exists."""
original_filename: str
sha256: str
existing_document_id: UUID
class UploadResponse(BaseModel):
"""Represents the result of a batch upload request."""
uploaded: list[DocumentResponse] = Field(default_factory=list)
conflicts: list[UploadConflict] = Field(default_factory=list)
class DocumentUpdateRequest(BaseModel):
"""Captures document metadata changes."""
original_filename: str | None = None
logical_path: str | None = None
tags: list[str] | None = None
class SearchResponse(BaseModel):
"""Represents the result of a search query."""
total: int
items: list[DocumentResponse]
class ContentExportRequest(BaseModel):
"""Describes filters used to export extracted document contents as Markdown files."""
document_ids: list[UUID] = Field(default_factory=list)
path_prefix: str | None = None
include_trashed: bool = False
only_trashed: bool = False

View File

@@ -0,0 +1,35 @@
"""Pydantic schemas for processing pipeline log API payloads."""
from datetime import datetime
from uuid import UUID
from pydantic import BaseModel, Field
class ProcessingLogEntryResponse(BaseModel):
"""Represents one persisted processing log event returned by API endpoints."""
id: int
created_at: datetime
level: str
stage: str
event: str
document_id: UUID | None
document_filename: str | None
provider_id: str | None
model_name: str | None
prompt_text: str | None
response_text: str | None
payload_json: dict
class Config:
"""Enables ORM object parsing for SQLAlchemy model instances."""
from_attributes = True
class ProcessingLogListResponse(BaseModel):
"""Represents a paginated collection of processing log records."""
total: int
items: list[ProcessingLogEntryResponse] = Field(default_factory=list)

View File

@@ -0,0 +1,242 @@
"""Pydantic schemas for application-level runtime settings."""
from pydantic import BaseModel, Field
class ProviderSettingsResponse(BaseModel):
"""Represents a persisted model provider with non-secret connection metadata."""
id: str
label: str
provider_type: str = "openai_compatible"
base_url: str
timeout_seconds: int
api_key_set: bool
api_key_masked: str = ""
class ProviderSettingsUpdateRequest(BaseModel):
"""Represents a model provider create-or-update request."""
id: str
label: str
provider_type: str = "openai_compatible"
base_url: str
timeout_seconds: int = Field(default=45, ge=5, le=180)
api_key: str | None = None
clear_api_key: bool = False
class OcrTaskSettingsResponse(BaseModel):
"""Represents OCR task runtime settings and prompt configuration."""
enabled: bool
provider_id: str
model: str
prompt: str
class OcrTaskSettingsUpdateRequest(BaseModel):
"""Represents OCR task settings updates."""
enabled: bool | None = None
provider_id: str | None = None
model: str | None = None
prompt: str | None = None
class SummaryTaskSettingsResponse(BaseModel):
"""Represents summarization task runtime settings."""
enabled: bool
provider_id: str
model: str
prompt: str
max_input_tokens: int
class SummaryTaskSettingsUpdateRequest(BaseModel):
"""Represents summarization task settings updates."""
enabled: bool | None = None
provider_id: str | None = None
model: str | None = None
prompt: str | None = None
max_input_tokens: int | None = Field(default=None, ge=512, le=64000)
class RoutingTaskSettingsResponse(BaseModel):
"""Represents routing task runtime settings for path and tag classification."""
enabled: bool
provider_id: str
model: str
prompt: str
neighbor_count: int
neighbor_min_similarity: float
auto_apply_confidence_threshold: float
auto_apply_neighbor_similarity_threshold: float
neighbor_path_override_enabled: bool
neighbor_path_override_min_similarity: float
neighbor_path_override_min_gap: float
neighbor_path_override_max_confidence: float
class RoutingTaskSettingsUpdateRequest(BaseModel):
"""Represents routing task settings updates."""
enabled: bool | None = None
provider_id: str | None = None
model: str | None = None
prompt: str | None = None
neighbor_count: int | None = Field(default=None, ge=1, le=40)
neighbor_min_similarity: float | None = Field(default=None, ge=0.0, le=1.0)
auto_apply_confidence_threshold: float | None = Field(default=None, ge=0.0, le=1.0)
auto_apply_neighbor_similarity_threshold: float | None = Field(default=None, ge=0.0, le=1.0)
neighbor_path_override_enabled: bool | None = None
neighbor_path_override_min_similarity: float | None = Field(default=None, ge=0.0, le=1.0)
neighbor_path_override_min_gap: float | None = Field(default=None, ge=0.0, le=1.0)
neighbor_path_override_max_confidence: float | None = Field(default=None, ge=0.0, le=1.0)
class UploadDefaultsResponse(BaseModel):
"""Represents default upload destination and default tags."""
logical_path: str
tags: list[str] = Field(default_factory=list)
class UploadDefaultsUpdateRequest(BaseModel):
"""Represents updates for default upload destination and default tags."""
logical_path: str | None = None
tags: list[str] | None = None
class DisplaySettingsResponse(BaseModel):
"""Represents document-list display preferences."""
cards_per_page: int = Field(default=12, ge=1, le=200)
log_typing_animation_enabled: bool = True
class DisplaySettingsUpdateRequest(BaseModel):
"""Represents updates for document-list display preferences."""
cards_per_page: int | None = Field(default=None, ge=1, le=200)
log_typing_animation_enabled: bool | None = None
class PredefinedPathEntryResponse(BaseModel):
"""Represents one predefined logical path with global discoverability scope."""
value: str
global_shared: bool
class PredefinedPathEntryUpdateRequest(BaseModel):
"""Represents one predefined logical path create-or-update request."""
value: str
global_shared: bool = False
class PredefinedTagEntryResponse(BaseModel):
"""Represents one predefined tag with global discoverability scope."""
value: str
global_shared: bool
class PredefinedTagEntryUpdateRequest(BaseModel):
"""Represents one predefined tag create-or-update request."""
value: str
global_shared: bool = False
class HandwritingStyleSettingsResponse(BaseModel):
"""Represents handwriting-style clustering settings used by Typesense image embeddings."""
enabled: bool
embed_model: str
neighbor_limit: int
match_min_similarity: float
bootstrap_match_min_similarity: float
bootstrap_sample_size: int
image_max_side: int
class HandwritingStyleSettingsUpdateRequest(BaseModel):
"""Represents updates for handwriting-style clustering and match thresholds."""
enabled: bool | None = None
embed_model: str | None = None
neighbor_limit: int | None = Field(default=None, ge=1, le=32)
match_min_similarity: float | None = Field(default=None, ge=0.0, le=1.0)
bootstrap_match_min_similarity: float | None = Field(default=None, ge=0.0, le=1.0)
bootstrap_sample_size: int | None = Field(default=None, ge=1, le=30)
image_max_side: int | None = Field(default=None, ge=256, le=4096)
class TaskSettingsResponse(BaseModel):
"""Represents all task-level model bindings and prompt settings."""
ocr_handwriting: OcrTaskSettingsResponse
summary_generation: SummaryTaskSettingsResponse
routing_classification: RoutingTaskSettingsResponse
class TaskSettingsUpdateRequest(BaseModel):
"""Represents partial updates for task-level settings."""
ocr_handwriting: OcrTaskSettingsUpdateRequest | None = None
summary_generation: SummaryTaskSettingsUpdateRequest | None = None
routing_classification: RoutingTaskSettingsUpdateRequest | None = None
class AppSettingsResponse(BaseModel):
"""Represents all application settings exposed by the API."""
upload_defaults: UploadDefaultsResponse
display: DisplaySettingsResponse
handwriting_style_clustering: HandwritingStyleSettingsResponse
predefined_paths: list[PredefinedPathEntryResponse] = Field(default_factory=list)
predefined_tags: list[PredefinedTagEntryResponse] = Field(default_factory=list)
providers: list[ProviderSettingsResponse]
tasks: TaskSettingsResponse
class AppSettingsUpdateRequest(BaseModel):
"""Represents full settings update input for providers and task bindings."""
upload_defaults: UploadDefaultsUpdateRequest | None = None
display: DisplaySettingsUpdateRequest | None = None
handwriting_style_clustering: HandwritingStyleSettingsUpdateRequest | None = None
predefined_paths: list[PredefinedPathEntryUpdateRequest] | None = None
predefined_tags: list[PredefinedTagEntryUpdateRequest] | None = None
providers: list[ProviderSettingsUpdateRequest] | None = None
tasks: TaskSettingsUpdateRequest | None = None
class HandwritingSettingsResponse(BaseModel):
"""Represents legacy handwriting response shape kept for backward compatibility."""
provider: str = "openai_compatible"
enabled: bool
openai_base_url: str
openai_model: str
openai_timeout_seconds: int
openai_api_key_set: bool
openai_api_key_masked: str = ""
class HandwritingSettingsUpdateRequest(BaseModel):
"""Represents legacy handwriting update shape kept for backward compatibility."""
enabled: bool | None = None
openai_base_url: str | None = None
openai_model: str | None = None
openai_timeout_seconds: int | None = Field(default=None, ge=5, le=180)
openai_api_key: str | None = None
clear_openai_api_key: bool = False

View File

@@ -0,0 +1 @@
"""Domain services package for storage, extraction, and classification logic."""

View File

@@ -0,0 +1,885 @@
"""Persistent single-user application settings service backed by host-mounted storage."""
import json
import re
from pathlib import Path
from typing import Any
from app.core.config import get_settings
settings = get_settings()
TASK_OCR_HANDWRITING = "ocr_handwriting"
TASK_SUMMARY_GENERATION = "summary_generation"
TASK_ROUTING_CLASSIFICATION = "routing_classification"
HANDWRITING_STYLE_SETTINGS_KEY = "handwriting_style_clustering"
PREDEFINED_PATHS_SETTINGS_KEY = "predefined_paths"
PREDEFINED_TAGS_SETTINGS_KEY = "predefined_tags"
DEFAULT_HANDWRITING_STYLE_EMBED_MODEL = "ts/clip-vit-b-p32"
DEFAULT_OCR_PROMPT = (
"You are an expert at reading messy handwritten notes, including hard-to-read writing.\n"
"Task: transcribe the handwriting as exactly as possible.\n\n"
"Rules:\n"
"- Output ONLY the transcription in German, no commentary.\n"
"- Preserve original line breaks where they clearly exist.\n"
"- Do NOT translate or correct grammar or spelling.\n"
"- If a word or character is unclear, wrap your best guess in [[? ... ?]].\n"
"- If something is unreadable, write [[?unleserlich?]] in its place."
)
DEFAULT_SUMMARY_PROMPT = (
"You summarize documents for indexing and routing.\n"
"Return concise markdown with key entities, purpose, and document category hints.\n"
"Do not invent facts and do not include any explanation outside the summary."
)
DEFAULT_ROUTING_PROMPT = (
"You classify one document into an existing logical path and tags.\n"
"Prefer existing paths and tags when possible.\n"
"If the evidence is weak, keep chosen_path as null and use suggestions instead.\n"
"Return JSON only with this exact shape:\n"
"{\n"
" \"chosen_path\": string | null,\n"
" \"chosen_tags\": string[],\n"
" \"suggested_new_paths\": string[],\n"
" \"suggested_new_tags\": string[],\n"
" \"confidence\": number\n"
"}\n"
"Confidence must be between 0 and 1."
)
def _default_settings() -> dict[str, Any]:
"""Builds default settings including providers and model task bindings."""
return {
"upload_defaults": {
"logical_path": "Inbox",
"tags": [],
},
"display": {
"cards_per_page": 12,
"log_typing_animation_enabled": True,
},
PREDEFINED_PATHS_SETTINGS_KEY: [],
PREDEFINED_TAGS_SETTINGS_KEY: [],
HANDWRITING_STYLE_SETTINGS_KEY: {
"enabled": True,
"embed_model": DEFAULT_HANDWRITING_STYLE_EMBED_MODEL,
"neighbor_limit": 8,
"match_min_similarity": 0.86,
"bootstrap_match_min_similarity": 0.89,
"bootstrap_sample_size": 3,
"image_max_side": 1024,
},
"providers": [
{
"id": "openai-default",
"label": "OpenAI Default",
"provider_type": "openai_compatible",
"base_url": settings.default_openai_base_url,
"timeout_seconds": settings.default_openai_timeout_seconds,
"api_key": settings.default_openai_api_key,
}
],
"tasks": {
TASK_OCR_HANDWRITING: {
"enabled": settings.default_openai_handwriting_enabled,
"provider_id": "openai-default",
"model": settings.default_openai_model,
"prompt": DEFAULT_OCR_PROMPT,
},
TASK_SUMMARY_GENERATION: {
"enabled": True,
"provider_id": "openai-default",
"model": settings.default_summary_model,
"prompt": DEFAULT_SUMMARY_PROMPT,
"max_input_tokens": 8000,
},
TASK_ROUTING_CLASSIFICATION: {
"enabled": True,
"provider_id": "openai-default",
"model": settings.default_routing_model,
"prompt": DEFAULT_ROUTING_PROMPT,
"neighbor_count": 8,
"neighbor_min_similarity": 0.84,
"auto_apply_confidence_threshold": 0.78,
"auto_apply_neighbor_similarity_threshold": 0.55,
"neighbor_path_override_enabled": True,
"neighbor_path_override_min_similarity": 0.86,
"neighbor_path_override_min_gap": 0.04,
"neighbor_path_override_max_confidence": 0.9,
},
},
}
def _settings_path() -> Path:
"""Returns the absolute path of the persisted settings file."""
return settings.storage_root / "settings.json"
def _clamp_timeout(value: int) -> int:
"""Clamps timeout values to a safe and practical range."""
return max(5, min(180, value))
def _clamp_input_tokens(value: int) -> int:
"""Clamps per-request summary input token budget values to practical bounds."""
return max(512, min(64000, value))
def _clamp_neighbor_count(value: int) -> int:
"""Clamps nearest-neighbor lookup count for routing classification."""
return max(1, min(40, value))
def _clamp_cards_per_page(value: int) -> int:
"""Clamps dashboard cards-per-page display setting to practical bounds."""
return max(1, min(200, value))
def _clamp_predefined_entries_limit(value: int) -> int:
"""Clamps maximum count for predefined tag/path catalog entries."""
return max(1, min(2000, value))
def _clamp_handwriting_style_neighbor_limit(value: int) -> int:
"""Clamps handwriting-style nearest-neighbor count used for style matching."""
return max(1, min(32, value))
def _clamp_handwriting_style_sample_size(value: int) -> int:
"""Clamps handwriting-style bootstrap sample size used for stricter matching."""
return max(1, min(30, value))
def _clamp_handwriting_style_image_max_side(value: int) -> int:
"""Clamps handwriting-style image normalization max-side pixel size."""
return max(256, min(4096, value))
def _clamp_probability(value: float, fallback: float) -> float:
"""Clamps probability-like numbers to the range [0, 1]."""
try:
parsed = float(value)
except (TypeError, ValueError):
return fallback
return max(0.0, min(1.0, parsed))
def _safe_int(value: Any, fallback: int) -> int:
"""Safely converts arbitrary values to integers with fallback handling."""
try:
return int(value)
except (TypeError, ValueError):
return fallback
def _normalize_provider_id(value: str | None, fallback: str) -> str:
"""Normalizes provider identifiers into stable lowercase slug values."""
candidate = (value or "").strip().lower()
candidate = re.sub(r"[^a-z0-9_-]+", "-", candidate).strip("-")
return candidate or fallback
def _mask_api_key(value: str) -> str:
"""Masks a secret API key while retaining enough characters for identification."""
if not value:
return ""
if len(value) <= 6:
return "*" * len(value)
return f"{value[:4]}...{value[-2:]}"
def _normalize_provider(
payload: dict[str, Any],
fallback_id: str,
fallback_values: dict[str, Any],
) -> dict[str, Any]:
"""Normalizes one provider payload to a stable shape with bounds and defaults."""
defaults = _default_settings()["providers"][0]
provider_id = _normalize_provider_id(str(payload.get("id", fallback_id)), fallback_id)
provider_type = str(payload.get("provider_type", fallback_values.get("provider_type", defaults["provider_type"]))).strip()
if provider_type != "openai_compatible":
provider_type = "openai_compatible"
api_key_value = payload.get("api_key", fallback_values.get("api_key", defaults["api_key"]))
api_key = str(api_key_value).strip() if api_key_value is not None else ""
return {
"id": provider_id,
"label": str(payload.get("label", fallback_values.get("label", provider_id))).strip() or provider_id,
"provider_type": provider_type,
"base_url": str(payload.get("base_url", fallback_values.get("base_url", defaults["base_url"]))).strip()
or defaults["base_url"],
"timeout_seconds": _clamp_timeout(
_safe_int(
payload.get("timeout_seconds", fallback_values.get("timeout_seconds", defaults["timeout_seconds"])),
defaults["timeout_seconds"],
)
),
"api_key": api_key,
}
def _normalize_ocr_task(payload: dict[str, Any], provider_ids: list[str]) -> dict[str, Any]:
"""Normalizes OCR task settings while enforcing valid provider references."""
defaults = _default_settings()["tasks"][TASK_OCR_HANDWRITING]
provider_id = str(payload.get("provider_id", defaults["provider_id"])).strip()
if provider_id not in provider_ids:
provider_id = provider_ids[0]
return {
"enabled": bool(payload.get("enabled", defaults["enabled"])),
"provider_id": provider_id,
"model": str(payload.get("model", defaults["model"])).strip() or defaults["model"],
"prompt": str(payload.get("prompt", defaults["prompt"])).strip() or defaults["prompt"],
}
def _normalize_summary_task(payload: dict[str, Any], provider_ids: list[str]) -> dict[str, Any]:
"""Normalizes summary task settings while enforcing valid provider references."""
defaults = _default_settings()["tasks"][TASK_SUMMARY_GENERATION]
provider_id = str(payload.get("provider_id", defaults["provider_id"])).strip()
if provider_id not in provider_ids:
provider_id = provider_ids[0]
raw_max_tokens = payload.get("max_input_tokens")
if raw_max_tokens is None:
legacy_chars = _safe_int(payload.get("max_source_chars", 0), 0)
if legacy_chars > 0:
raw_max_tokens = max(512, legacy_chars // 4)
else:
raw_max_tokens = defaults["max_input_tokens"]
return {
"enabled": bool(payload.get("enabled", defaults["enabled"])),
"provider_id": provider_id,
"model": str(payload.get("model", defaults["model"])).strip() or defaults["model"],
"prompt": str(payload.get("prompt", defaults["prompt"])).strip() or defaults["prompt"],
"max_input_tokens": _clamp_input_tokens(
_safe_int(raw_max_tokens, defaults["max_input_tokens"])
),
}
def _normalize_routing_task(payload: dict[str, Any], provider_ids: list[str]) -> dict[str, Any]:
"""Normalizes routing task settings while enforcing valid provider references."""
defaults = _default_settings()["tasks"][TASK_ROUTING_CLASSIFICATION]
provider_id = str(payload.get("provider_id", defaults["provider_id"])).strip()
if provider_id not in provider_ids:
provider_id = provider_ids[0]
return {
"enabled": bool(payload.get("enabled", defaults["enabled"])),
"provider_id": provider_id,
"model": str(payload.get("model", defaults["model"])).strip() or defaults["model"],
"prompt": str(payload.get("prompt", defaults["prompt"])).strip() or defaults["prompt"],
"neighbor_count": _clamp_neighbor_count(
_safe_int(payload.get("neighbor_count", defaults["neighbor_count"]), defaults["neighbor_count"])
),
"neighbor_min_similarity": _clamp_probability(
payload.get("neighbor_min_similarity", defaults["neighbor_min_similarity"]),
defaults["neighbor_min_similarity"],
),
"auto_apply_confidence_threshold": _clamp_probability(
payload.get("auto_apply_confidence_threshold", defaults["auto_apply_confidence_threshold"]),
defaults["auto_apply_confidence_threshold"],
),
"auto_apply_neighbor_similarity_threshold": _clamp_probability(
payload.get(
"auto_apply_neighbor_similarity_threshold",
defaults["auto_apply_neighbor_similarity_threshold"],
),
defaults["auto_apply_neighbor_similarity_threshold"],
),
"neighbor_path_override_enabled": bool(
payload.get("neighbor_path_override_enabled", defaults["neighbor_path_override_enabled"])
),
"neighbor_path_override_min_similarity": _clamp_probability(
payload.get(
"neighbor_path_override_min_similarity",
defaults["neighbor_path_override_min_similarity"],
),
defaults["neighbor_path_override_min_similarity"],
),
"neighbor_path_override_min_gap": _clamp_probability(
payload.get("neighbor_path_override_min_gap", defaults["neighbor_path_override_min_gap"]),
defaults["neighbor_path_override_min_gap"],
),
"neighbor_path_override_max_confidence": _clamp_probability(
payload.get(
"neighbor_path_override_max_confidence",
defaults["neighbor_path_override_max_confidence"],
),
defaults["neighbor_path_override_max_confidence"],
),
}
def _normalize_tasks(payload: dict[str, Any], provider_ids: list[str]) -> dict[str, Any]:
"""Normalizes task settings map for OCR, summarization, and routing tasks."""
if not isinstance(payload, dict):
payload = {}
return {
TASK_OCR_HANDWRITING: _normalize_ocr_task(payload.get(TASK_OCR_HANDWRITING, {}), provider_ids),
TASK_SUMMARY_GENERATION: _normalize_summary_task(payload.get(TASK_SUMMARY_GENERATION, {}), provider_ids),
TASK_ROUTING_CLASSIFICATION: _normalize_routing_task(payload.get(TASK_ROUTING_CLASSIFICATION, {}), provider_ids),
}
def _normalize_upload_defaults(payload: dict[str, Any], defaults: dict[str, Any]) -> dict[str, Any]:
"""Normalizes upload default destination path and tags."""
if not isinstance(payload, dict):
payload = {}
default_path = str(defaults.get("logical_path", "Inbox")).strip() or "Inbox"
raw_path = str(payload.get("logical_path", default_path)).strip()
logical_path = raw_path or default_path
raw_tags = payload.get("tags", defaults.get("tags", []))
tags: list[str] = []
seen_lowered: set[str] = set()
if isinstance(raw_tags, list):
for raw_tag in raw_tags:
normalized = str(raw_tag).strip()
if not normalized:
continue
lowered = normalized.lower()
if lowered in seen_lowered:
continue
seen_lowered.add(lowered)
tags.append(normalized)
if len(tags) >= 50:
break
return {
"logical_path": logical_path,
"tags": tags,
}
def _normalize_display_settings(payload: dict[str, Any], defaults: dict[str, Any]) -> dict[str, Any]:
"""Normalizes display settings used by the document dashboard UI."""
if not isinstance(payload, dict):
payload = {}
default_cards_per_page = _safe_int(defaults.get("cards_per_page", 12), 12)
cards_per_page = _clamp_cards_per_page(
_safe_int(payload.get("cards_per_page", default_cards_per_page), default_cards_per_page)
)
return {
"cards_per_page": cards_per_page,
"log_typing_animation_enabled": bool(
payload.get("log_typing_animation_enabled", defaults.get("log_typing_animation_enabled", True))
),
}
def _normalize_predefined_paths(
payload: Any,
existing_items: list[dict[str, Any]] | None = None,
) -> list[dict[str, Any]]:
"""Normalizes predefined path entries and enforces irreversible global-sharing flag."""
existing_map: dict[str, dict[str, Any]] = {}
if isinstance(existing_items, list):
for item in existing_items:
if not isinstance(item, dict):
continue
value = str(item.get("value", "")).strip().strip("/")
if not value:
continue
existing_map[value.lower()] = {
"value": value,
"global_shared": bool(item.get("global_shared", False)),
}
if not isinstance(payload, list):
return list(existing_map.values())
normalized: list[dict[str, Any]] = []
seen: set[str] = set()
limit = _clamp_predefined_entries_limit(len(payload))
for item in payload:
if not isinstance(item, dict):
continue
value = str(item.get("value", "")).strip().strip("/")
if not value:
continue
lowered = value.lower()
if lowered in seen:
continue
seen.add(lowered)
existing = existing_map.get(lowered)
requested_global = bool(item.get("global_shared", False))
global_shared = bool(existing.get("global_shared", False) if existing else False) or requested_global
normalized.append(
{
"value": value,
"global_shared": global_shared,
}
)
if len(normalized) >= limit:
break
return normalized
def _normalize_predefined_tags(
payload: Any,
existing_items: list[dict[str, Any]] | None = None,
) -> list[dict[str, Any]]:
"""Normalizes predefined tag entries and enforces irreversible global-sharing flag."""
existing_map: dict[str, dict[str, Any]] = {}
if isinstance(existing_items, list):
for item in existing_items:
if not isinstance(item, dict):
continue
value = str(item.get("value", "")).strip()
if not value:
continue
existing_map[value.lower()] = {
"value": value,
"global_shared": bool(item.get("global_shared", False)),
}
if not isinstance(payload, list):
return list(existing_map.values())
normalized: list[dict[str, Any]] = []
seen: set[str] = set()
limit = _clamp_predefined_entries_limit(len(payload))
for item in payload:
if not isinstance(item, dict):
continue
value = str(item.get("value", "")).strip()
if not value:
continue
lowered = value.lower()
if lowered in seen:
continue
seen.add(lowered)
existing = existing_map.get(lowered)
requested_global = bool(item.get("global_shared", False))
global_shared = bool(existing.get("global_shared", False) if existing else False) or requested_global
normalized.append(
{
"value": value,
"global_shared": global_shared,
}
)
if len(normalized) >= limit:
break
return normalized
def _normalize_handwriting_style_settings(payload: dict[str, Any], defaults: dict[str, Any]) -> dict[str, Any]:
"""Normalizes handwriting-style clustering settings exposed in the settings UI."""
if not isinstance(payload, dict):
payload = {}
default_enabled = bool(defaults.get("enabled", True))
default_embed_model = str(defaults.get("embed_model", DEFAULT_HANDWRITING_STYLE_EMBED_MODEL)).strip()
default_neighbor_limit = _safe_int(defaults.get("neighbor_limit", 8), 8)
default_match_min = _clamp_probability(defaults.get("match_min_similarity", 0.86), 0.86)
default_bootstrap_match_min = _clamp_probability(defaults.get("bootstrap_match_min_similarity", 0.89), 0.89)
default_bootstrap_sample_size = _safe_int(defaults.get("bootstrap_sample_size", 3), 3)
default_image_max_side = _safe_int(defaults.get("image_max_side", 1024), 1024)
return {
"enabled": bool(payload.get("enabled", default_enabled)),
"embed_model": str(payload.get("embed_model", default_embed_model)).strip() or default_embed_model,
"neighbor_limit": _clamp_handwriting_style_neighbor_limit(
_safe_int(payload.get("neighbor_limit", default_neighbor_limit), default_neighbor_limit)
),
"match_min_similarity": _clamp_probability(
payload.get("match_min_similarity", default_match_min),
default_match_min,
),
"bootstrap_match_min_similarity": _clamp_probability(
payload.get("bootstrap_match_min_similarity", default_bootstrap_match_min),
default_bootstrap_match_min,
),
"bootstrap_sample_size": _clamp_handwriting_style_sample_size(
_safe_int(payload.get("bootstrap_sample_size", default_bootstrap_sample_size), default_bootstrap_sample_size)
),
"image_max_side": _clamp_handwriting_style_image_max_side(
_safe_int(payload.get("image_max_side", default_image_max_side), default_image_max_side)
),
}
def _sanitize_settings(payload: dict[str, Any]) -> dict[str, Any]:
"""Sanitizes all persisted settings into a stable normalized structure."""
if not isinstance(payload, dict):
payload = {}
defaults = _default_settings()
providers_payload = payload.get("providers")
normalized_providers: list[dict[str, Any]] = []
seen_provider_ids: set[str] = set()
if isinstance(providers_payload, list):
for index, provider_payload in enumerate(providers_payload):
if not isinstance(provider_payload, dict):
continue
fallback = defaults["providers"][0]
candidate = _normalize_provider(provider_payload, fallback_id=f"provider-{index + 1}", fallback_values=fallback)
if candidate["id"] in seen_provider_ids:
continue
seen_provider_ids.add(candidate["id"])
normalized_providers.append(candidate)
if not normalized_providers:
normalized_providers = [dict(defaults["providers"][0])]
provider_ids = [provider["id"] for provider in normalized_providers]
tasks_payload = payload.get("tasks", {})
normalized_tasks = _normalize_tasks(tasks_payload, provider_ids)
upload_defaults = _normalize_upload_defaults(payload.get("upload_defaults", {}), defaults["upload_defaults"])
display_settings = _normalize_display_settings(payload.get("display", {}), defaults["display"])
predefined_paths = _normalize_predefined_paths(
payload.get(PREDEFINED_PATHS_SETTINGS_KEY, []),
existing_items=payload.get(PREDEFINED_PATHS_SETTINGS_KEY, []),
)
predefined_tags = _normalize_predefined_tags(
payload.get(PREDEFINED_TAGS_SETTINGS_KEY, []),
existing_items=payload.get(PREDEFINED_TAGS_SETTINGS_KEY, []),
)
handwriting_style_settings = _normalize_handwriting_style_settings(
payload.get(HANDWRITING_STYLE_SETTINGS_KEY, {}),
defaults[HANDWRITING_STYLE_SETTINGS_KEY],
)
return {
"upload_defaults": upload_defaults,
"display": display_settings,
PREDEFINED_PATHS_SETTINGS_KEY: predefined_paths,
PREDEFINED_TAGS_SETTINGS_KEY: predefined_tags,
HANDWRITING_STYLE_SETTINGS_KEY: handwriting_style_settings,
"providers": normalized_providers,
"tasks": normalized_tasks,
}
def ensure_app_settings() -> None:
"""Creates a settings file with defaults when no persisted settings are present."""
path = _settings_path()
path.parent.mkdir(parents=True, exist_ok=True)
if path.exists():
return
defaults = _sanitize_settings(_default_settings())
path.write_text(json.dumps(defaults, indent=2), encoding="utf-8")
def _read_raw_settings() -> dict[str, Any]:
"""Reads persisted settings from disk and returns normalized values."""
ensure_app_settings()
path = _settings_path()
try:
payload = json.loads(path.read_text(encoding="utf-8"))
except (OSError, json.JSONDecodeError):
payload = {}
return _sanitize_settings(payload)
def _write_settings(payload: dict[str, Any]) -> None:
"""Persists sanitized settings payload to host-mounted storage."""
path = _settings_path()
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
def read_app_settings() -> dict[str, Any]:
"""Reads settings and returns a sanitized view safe for API responses."""
payload = _read_raw_settings()
providers_response: list[dict[str, Any]] = []
for provider in payload["providers"]:
api_key = str(provider.get("api_key", ""))
providers_response.append(
{
"id": provider["id"],
"label": provider["label"],
"provider_type": provider["provider_type"],
"base_url": provider["base_url"],
"timeout_seconds": int(provider["timeout_seconds"]),
"api_key_set": bool(api_key),
"api_key_masked": _mask_api_key(api_key),
}
)
return {
"upload_defaults": payload.get("upload_defaults", {"logical_path": "Inbox", "tags": []}),
"display": payload.get("display", {"cards_per_page": 12, "log_typing_animation_enabled": True}),
PREDEFINED_PATHS_SETTINGS_KEY: payload.get(PREDEFINED_PATHS_SETTINGS_KEY, []),
PREDEFINED_TAGS_SETTINGS_KEY: payload.get(PREDEFINED_TAGS_SETTINGS_KEY, []),
HANDWRITING_STYLE_SETTINGS_KEY: payload.get(HANDWRITING_STYLE_SETTINGS_KEY, {}),
"providers": providers_response,
"tasks": payload["tasks"],
}
def reset_app_settings() -> dict[str, Any]:
"""Resets persisted application settings to sanitized repository defaults."""
defaults = _sanitize_settings(_default_settings())
_write_settings(defaults)
return read_app_settings()
def read_task_runtime_settings(task_name: str) -> dict[str, Any]:
"""Returns runtime task settings and resolved provider including secret values."""
payload = _read_raw_settings()
tasks = payload["tasks"]
if task_name not in tasks:
raise KeyError(f"Unknown task settings key: {task_name}")
task = dict(tasks[task_name])
provider_map = {provider["id"]: provider for provider in payload["providers"]}
provider = provider_map.get(task.get("provider_id"))
if provider is None:
provider = payload["providers"][0]
task["provider_id"] = provider["id"]
return {
"task": task,
"provider": dict(provider),
}
def update_app_settings(
providers: list[dict[str, Any]] | None = None,
tasks: dict[str, dict[str, Any]] | None = None,
upload_defaults: dict[str, Any] | None = None,
display: dict[str, Any] | None = None,
handwriting_style: dict[str, Any] | None = None,
predefined_paths: list[dict[str, Any]] | None = None,
predefined_tags: list[dict[str, Any]] | None = None,
) -> dict[str, Any]:
"""Updates app settings, persists them, and returns API-safe values."""
current_payload = _read_raw_settings()
next_payload: dict[str, Any] = {
"upload_defaults": dict(current_payload.get("upload_defaults", {"logical_path": "Inbox", "tags": []})),
"display": dict(current_payload.get("display", {"cards_per_page": 12, "log_typing_animation_enabled": True})),
PREDEFINED_PATHS_SETTINGS_KEY: list(current_payload.get(PREDEFINED_PATHS_SETTINGS_KEY, [])),
PREDEFINED_TAGS_SETTINGS_KEY: list(current_payload.get(PREDEFINED_TAGS_SETTINGS_KEY, [])),
HANDWRITING_STYLE_SETTINGS_KEY: dict(
current_payload.get(HANDWRITING_STYLE_SETTINGS_KEY, _default_settings()[HANDWRITING_STYLE_SETTINGS_KEY])
),
"providers": list(current_payload["providers"]),
"tasks": dict(current_payload["tasks"]),
}
if providers is not None:
existing_provider_map = {provider["id"]: provider for provider in current_payload["providers"]}
next_providers: list[dict[str, Any]] = []
for index, provider_payload in enumerate(providers):
if not isinstance(provider_payload, dict):
continue
provider_id = _normalize_provider_id(
str(provider_payload.get("id", "")),
fallback=f"provider-{index + 1}",
)
existing_provider = existing_provider_map.get(provider_id, {})
merged_payload = dict(provider_payload)
merged_payload["id"] = provider_id
if bool(provider_payload.get("clear_api_key", False)):
merged_payload["api_key"] = ""
elif "api_key" in provider_payload and provider_payload.get("api_key") is not None:
merged_payload["api_key"] = str(provider_payload.get("api_key")).strip()
else:
merged_payload["api_key"] = str(existing_provider.get("api_key", ""))
normalized_provider = _normalize_provider(
merged_payload,
fallback_id=provider_id,
fallback_values=existing_provider,
)
next_providers.append(normalized_provider)
if next_providers:
next_payload["providers"] = next_providers
if tasks is not None:
merged_tasks = dict(current_payload["tasks"])
for task_name, task_update in tasks.items():
if task_name not in merged_tasks or not isinstance(task_update, dict):
continue
existing_task = dict(merged_tasks[task_name])
for key, value in task_update.items():
if value is None:
continue
existing_task[key] = value
merged_tasks[task_name] = existing_task
next_payload["tasks"] = merged_tasks
if upload_defaults is not None and isinstance(upload_defaults, dict):
next_upload_defaults = dict(next_payload.get("upload_defaults", {}))
for key in ("logical_path", "tags"):
if key in upload_defaults:
next_upload_defaults[key] = upload_defaults[key]
next_payload["upload_defaults"] = next_upload_defaults
if display is not None and isinstance(display, dict):
next_display = dict(next_payload.get("display", {}))
if "cards_per_page" in display:
next_display["cards_per_page"] = display["cards_per_page"]
if "log_typing_animation_enabled" in display:
next_display["log_typing_animation_enabled"] = bool(display["log_typing_animation_enabled"])
next_payload["display"] = next_display
if handwriting_style is not None and isinstance(handwriting_style, dict):
next_handwriting_style = dict(next_payload.get(HANDWRITING_STYLE_SETTINGS_KEY, {}))
for key in (
"enabled",
"embed_model",
"neighbor_limit",
"match_min_similarity",
"bootstrap_match_min_similarity",
"bootstrap_sample_size",
"image_max_side",
):
if key in handwriting_style:
next_handwriting_style[key] = handwriting_style[key]
next_payload[HANDWRITING_STYLE_SETTINGS_KEY] = next_handwriting_style
if predefined_paths is not None:
next_payload[PREDEFINED_PATHS_SETTINGS_KEY] = _normalize_predefined_paths(
predefined_paths,
existing_items=next_payload.get(PREDEFINED_PATHS_SETTINGS_KEY, []),
)
if predefined_tags is not None:
next_payload[PREDEFINED_TAGS_SETTINGS_KEY] = _normalize_predefined_tags(
predefined_tags,
existing_items=next_payload.get(PREDEFINED_TAGS_SETTINGS_KEY, []),
)
sanitized = _sanitize_settings(next_payload)
_write_settings(sanitized)
return read_app_settings()
def read_handwriting_provider_settings() -> dict[str, Any]:
"""Returns OCR settings in legacy shape for the handwriting transcription service."""
runtime = read_task_runtime_settings(TASK_OCR_HANDWRITING)
provider = runtime["provider"]
task = runtime["task"]
return {
"provider": provider["provider_type"],
"enabled": bool(task.get("enabled", True)),
"openai_base_url": str(provider.get("base_url", settings.default_openai_base_url)),
"openai_model": str(task.get("model", settings.default_openai_model)),
"openai_timeout_seconds": int(provider.get("timeout_seconds", settings.default_openai_timeout_seconds)),
"openai_api_key": str(provider.get("api_key", "")),
"prompt": str(task.get("prompt", DEFAULT_OCR_PROMPT)),
"provider_id": str(provider.get("id", "openai-default")),
}
def read_handwriting_style_settings() -> dict[str, Any]:
"""Returns handwriting-style clustering settings for Typesense style assignment logic."""
payload = _read_raw_settings()
defaults = _default_settings()[HANDWRITING_STYLE_SETTINGS_KEY]
return _normalize_handwriting_style_settings(
payload.get(HANDWRITING_STYLE_SETTINGS_KEY, {}),
defaults,
)
def read_predefined_paths_settings() -> list[dict[str, Any]]:
"""Returns normalized predefined logical path catalog entries."""
payload = _read_raw_settings()
return _normalize_predefined_paths(
payload.get(PREDEFINED_PATHS_SETTINGS_KEY, []),
existing_items=payload.get(PREDEFINED_PATHS_SETTINGS_KEY, []),
)
def read_predefined_tags_settings() -> list[dict[str, Any]]:
"""Returns normalized predefined tag catalog entries."""
payload = _read_raw_settings()
return _normalize_predefined_tags(
payload.get(PREDEFINED_TAGS_SETTINGS_KEY, []),
existing_items=payload.get(PREDEFINED_TAGS_SETTINGS_KEY, []),
)
def update_handwriting_settings(
enabled: bool | None = None,
openai_base_url: str | None = None,
openai_model: str | None = None,
openai_timeout_seconds: int | None = None,
openai_api_key: str | None = None,
clear_openai_api_key: bool = False,
) -> dict[str, Any]:
"""Updates OCR task and bound provider values using the legacy handwriting API contract."""
runtime = read_task_runtime_settings(TASK_OCR_HANDWRITING)
provider = runtime["provider"]
provider_update: dict[str, Any] = {
"id": provider["id"],
"label": provider["label"],
"provider_type": provider["provider_type"],
"base_url": openai_base_url if openai_base_url is not None else provider["base_url"],
"timeout_seconds": openai_timeout_seconds if openai_timeout_seconds is not None else provider["timeout_seconds"],
}
if clear_openai_api_key:
provider_update["clear_api_key"] = True
elif openai_api_key is not None:
provider_update["api_key"] = openai_api_key
tasks_update: dict[str, dict[str, Any]] = {TASK_OCR_HANDWRITING: {}}
if enabled is not None:
tasks_update[TASK_OCR_HANDWRITING]["enabled"] = enabled
if openai_model is not None:
tasks_update[TASK_OCR_HANDWRITING]["model"] = openai_model
return update_app_settings(
providers=[provider_update],
tasks=tasks_update,
)

View File

@@ -0,0 +1,315 @@
"""Document extraction service for text indexing, previews, and archive fan-out."""
import io
import re
import zipfile
from dataclasses import dataclass, field
from pathlib import Path
import magic
from docx import Document as DocxDocument
from openpyxl import load_workbook
from PIL import Image, ImageOps
from pypdf import PdfReader
import pymupdf
from app.core.config import get_settings
from app.services.handwriting import (
IMAGE_TEXT_TYPE_NO_TEXT,
IMAGE_TEXT_TYPE_UNKNOWN,
HandwritingTranscriptionError,
HandwritingTranscriptionNotConfiguredError,
HandwritingTranscriptionTimeoutError,
classify_image_text_bytes,
transcribe_handwriting_bytes,
)
settings = get_settings()
IMAGE_EXTENSIONS = {
".jpg",
".jpeg",
".png",
".tif",
".tiff",
".bmp",
".gif",
".webp",
".heic",
}
SUPPORTED_TEXT_EXTENSIONS = {
".txt",
".md",
".csv",
".json",
".xml",
".svg",
".pdf",
".docx",
".xlsx",
*IMAGE_EXTENSIONS,
}
@dataclass
class ExtractionResult:
"""Represents output generated during extraction for a single file."""
text: str
preview_bytes: bytes | None
preview_suffix: str | None
status: str
metadata_json: dict[str, object] = field(default_factory=dict)
@dataclass
class ArchiveMember:
"""Represents an extracted file entry from an archive."""
name: str
data: bytes
def sniff_mime(data: bytes) -> str:
"""Detects MIME type using libmagic for robust format handling."""
return magic.from_buffer(data, mime=True) or "application/octet-stream"
def is_supported_for_extraction(extension: str, mime_type: str) -> bool:
"""Determines if a file should be text-processed for indexing and classification."""
return extension in SUPPORTED_TEXT_EXTENSIONS or mime_type.startswith("text/")
def _normalize_text(text: str) -> str:
"""Normalizes extracted text by removing repeated form separators and controls."""
cleaned = text.replace("\r", "\n").replace("\x00", "")
lines: list[str] = []
for line in cleaned.split("\n"):
stripped = line.strip()
if stripped and re.fullmatch(r"[.\-_*=~\s]{4,}", stripped):
continue
lines.append(line)
normalized = "\n".join(lines)
normalized = re.sub(r"\n{3,}", "\n\n", normalized)
return normalized.strip()
def _extract_pdf_text(data: bytes) -> str:
"""Extracts text from PDF bytes using pypdf page parsing."""
reader = PdfReader(io.BytesIO(data))
pages: list[str] = []
for page in reader.pages:
pages.append(page.extract_text() or "")
return _normalize_text("\n".join(pages))
def _extract_pdf_preview(data: bytes) -> tuple[bytes | None, str | None]:
"""Creates a JPEG thumbnail preview from the first PDF page."""
try:
document = pymupdf.open(stream=data, filetype="pdf")
except Exception:
return None, None
try:
if document.page_count < 1:
return None, None
page = document.load_page(0)
pixmap = page.get_pixmap(matrix=pymupdf.Matrix(1.5, 1.5), alpha=False)
return pixmap.tobytes("jpeg"), ".jpg"
except Exception:
return None, None
finally:
document.close()
def _extract_docx_text(data: bytes) -> str:
"""Extracts paragraph text from DOCX content."""
document = DocxDocument(io.BytesIO(data))
return _normalize_text("\n".join(paragraph.text for paragraph in document.paragraphs if paragraph.text))
def _extract_xlsx_text(data: bytes) -> str:
"""Extracts cell text from XLSX workbook sheets for indexing."""
workbook = load_workbook(io.BytesIO(data), data_only=True, read_only=True)
chunks: list[str] = []
for sheet in workbook.worksheets:
chunks.append(sheet.title)
row_count = 0
for row in sheet.iter_rows(min_row=1, max_row=200):
row_values = [str(cell.value) for cell in row if cell.value is not None]
if row_values:
chunks.append(" ".join(row_values))
row_count += 1
if row_count >= 200:
break
return _normalize_text("\n".join(chunks))
def _build_image_preview(data: bytes) -> tuple[bytes | None, str | None]:
"""Builds a JPEG preview thumbnail for image files."""
try:
with Image.open(io.BytesIO(data)) as image:
preview = ImageOps.exif_transpose(image).convert("RGB")
preview.thumbnail((600, 600))
output = io.BytesIO()
preview.save(output, format="JPEG", optimize=True, quality=82)
return output.getvalue(), ".jpg"
except Exception:
return None, None
def _extract_handwriting_text(data: bytes, mime_type: str) -> ExtractionResult:
"""Extracts text from image bytes and records handwriting-vs-printed classification metadata."""
preview_bytes, preview_suffix = _build_image_preview(data)
metadata_json: dict[str, object] = {}
try:
text_type = classify_image_text_bytes(data, mime_type=mime_type)
metadata_json = {
"image_text_type": text_type.label,
"image_text_type_confidence": text_type.confidence,
"image_text_type_provider": text_type.provider,
"image_text_type_model": text_type.model,
}
except HandwritingTranscriptionNotConfiguredError as error:
return ExtractionResult(
text="",
preview_bytes=preview_bytes,
preview_suffix=preview_suffix,
status="unsupported",
metadata_json={"transcription_error": str(error), "image_text_type": IMAGE_TEXT_TYPE_UNKNOWN},
)
except HandwritingTranscriptionTimeoutError as error:
metadata_json = {
"image_text_type": IMAGE_TEXT_TYPE_UNKNOWN,
"image_text_type_error": str(error),
}
except HandwritingTranscriptionError as error:
metadata_json = {
"image_text_type": IMAGE_TEXT_TYPE_UNKNOWN,
"image_text_type_error": str(error),
}
if metadata_json.get("image_text_type") == IMAGE_TEXT_TYPE_NO_TEXT:
metadata_json["transcription_skipped"] = "no_text_detected"
return ExtractionResult(
text="",
preview_bytes=preview_bytes,
preview_suffix=preview_suffix,
status="processed",
metadata_json=metadata_json,
)
try:
transcription = transcribe_handwriting_bytes(data, mime_type=mime_type)
transcription_metadata: dict[str, object] = {
"transcription_provider": transcription.provider,
"transcription_model": transcription.model,
"transcription_uncertainties": transcription.uncertainties,
}
return ExtractionResult(
text=_normalize_text(transcription.text),
preview_bytes=preview_bytes,
preview_suffix=preview_suffix,
status="processed",
metadata_json={**metadata_json, **transcription_metadata},
)
except HandwritingTranscriptionNotConfiguredError as error:
return ExtractionResult(
text="",
preview_bytes=preview_bytes,
preview_suffix=preview_suffix,
status="unsupported",
metadata_json={**metadata_json, "transcription_error": str(error)},
)
except HandwritingTranscriptionTimeoutError as error:
return ExtractionResult(
text="",
preview_bytes=preview_bytes,
preview_suffix=preview_suffix,
status="error",
metadata_json={**metadata_json, "transcription_error": str(error)},
)
except HandwritingTranscriptionError as error:
return ExtractionResult(
text="",
preview_bytes=preview_bytes,
preview_suffix=preview_suffix,
status="error",
metadata_json={**metadata_json, "transcription_error": str(error)},
)
def extract_text_content(filename: str, data: bytes, mime_type: str) -> ExtractionResult:
"""Extracts text and optional preview bytes for supported file types."""
extension = Path(filename).suffix.lower()
text = ""
preview_bytes: bytes | None = None
preview_suffix: str | None = None
try:
if extension == ".pdf":
text = _extract_pdf_text(data)
preview_bytes, preview_suffix = _extract_pdf_preview(data)
elif extension in {".txt", ".md", ".csv", ".json", ".xml", ".svg"} or mime_type.startswith("text/"):
text = _normalize_text(data.decode("utf-8", errors="ignore"))
elif extension == ".docx":
text = _extract_docx_text(data)
elif extension == ".xlsx":
text = _extract_xlsx_text(data)
elif extension in IMAGE_EXTENSIONS:
return _extract_handwriting_text(data=data, mime_type=mime_type)
else:
return ExtractionResult(
text="",
preview_bytes=None,
preview_suffix=None,
status="unsupported",
metadata_json={"reason": "unsupported_format"},
)
except Exception as error:
return ExtractionResult(
text="",
preview_bytes=None,
preview_suffix=None,
status="error",
metadata_json={"reason": "extraction_exception", "error": str(error)},
)
return ExtractionResult(
text=text[: settings.max_text_length],
preview_bytes=preview_bytes,
preview_suffix=preview_suffix,
status="processed",
metadata_json={},
)
def extract_archive_members(data: bytes, depth: int = 0) -> list[ArchiveMember]:
"""Extracts processable members from zip archives with configurable depth limits."""
members: list[ArchiveMember] = []
if depth > settings.max_zip_depth:
return members
with zipfile.ZipFile(io.BytesIO(data)) as archive:
infos = [info for info in archive.infolist() if not info.is_dir()][: settings.max_zip_members]
for info in infos:
member_data = archive.read(info.filename)
members.append(ArchiveMember(name=info.filename, data=member_data))
return members

View File

@@ -0,0 +1,477 @@
"""Handwriting transcription service using OpenAI-compatible vision models."""
import base64
import io
import json
import re
from dataclasses import dataclass
from typing import Any
from openai import APIConnectionError, APIError, APITimeoutError, OpenAI
from PIL import Image, ImageOps
from app.services.app_settings import DEFAULT_OCR_PROMPT, read_handwriting_provider_settings
MAX_IMAGE_SIDE = 2000
IMAGE_TEXT_TYPE_HANDWRITING = "handwriting"
IMAGE_TEXT_TYPE_PRINTED = "printed_text"
IMAGE_TEXT_TYPE_NO_TEXT = "no_text"
IMAGE_TEXT_TYPE_UNKNOWN = "unknown"
IMAGE_TEXT_CLASSIFICATION_PROMPT = (
"Classify the text content in this image.\n"
"Choose exactly one label from: handwriting, printed_text, no_text.\n"
"Definitions:\n"
"- handwriting: text exists and most readable text is handwritten.\n"
"- printed_text: text exists and most readable text is machine printed or typed.\n"
"- no_text: no readable text is present.\n"
"Return strict JSON only with shape:\n"
"{\n"
' "label": "handwriting|printed_text|no_text",\n'
' "confidence": number\n'
"}\n"
"Confidence must be between 0 and 1."
)
class HandwritingTranscriptionError(Exception):
"""Raised when handwriting transcription fails for a non-timeout reason."""
class HandwritingTranscriptionTimeoutError(HandwritingTranscriptionError):
"""Raised when handwriting transcription exceeds the configured timeout."""
class HandwritingTranscriptionNotConfiguredError(HandwritingTranscriptionError):
"""Raised when handwriting transcription is disabled or missing credentials."""
@dataclass
class HandwritingTranscription:
"""Represents transcription output and uncertainty markers."""
text: str
uncertainties: list[str]
provider: str
model: str
@dataclass
class ImageTextClassification:
"""Represents model classification of image text modality for one image."""
label: str
confidence: float
provider: str
model: str
def _extract_uncertainties(text: str) -> list[str]:
"""Extracts uncertainty markers from transcription output."""
matches = re.findall(r"\[\[\?(.*?)\?\]\]", text)
return [match.strip() for match in matches if match.strip()]
def _coerce_json_object(payload: str) -> dict[str, Any]:
"""Parses and extracts a JSON object from raw model output text."""
text = payload.strip()
if not text:
return {}
try:
parsed = json.loads(text)
if isinstance(parsed, dict):
return parsed
except json.JSONDecodeError:
pass
fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, flags=re.DOTALL | re.IGNORECASE)
if fenced:
try:
parsed = json.loads(fenced.group(1))
if isinstance(parsed, dict):
return parsed
except json.JSONDecodeError:
pass
first_brace = text.find("{")
last_brace = text.rfind("}")
if first_brace >= 0 and last_brace > first_brace:
candidate = text[first_brace : last_brace + 1]
try:
parsed = json.loads(candidate)
if isinstance(parsed, dict):
return parsed
except json.JSONDecodeError:
return {}
return {}
def _clamp_probability(value: Any, fallback: float = 0.0) -> float:
"""Clamps confidence-like values to the inclusive [0, 1] range."""
try:
parsed = float(value)
except (TypeError, ValueError):
return fallback
return max(0.0, min(1.0, parsed))
def _normalize_image_text_type(label: str) -> str:
"""Normalizes classifier labels into one supported canonical image text type."""
normalized = label.strip().lower().replace("-", "_").replace(" ", "_")
if normalized in {IMAGE_TEXT_TYPE_HANDWRITING, "handwritten", "handwritten_text"}:
return IMAGE_TEXT_TYPE_HANDWRITING
if normalized in {IMAGE_TEXT_TYPE_PRINTED, "printed", "typed", "machine_text"}:
return IMAGE_TEXT_TYPE_PRINTED
if normalized in {IMAGE_TEXT_TYPE_NO_TEXT, "no-text", "none", "no readable text"}:
return IMAGE_TEXT_TYPE_NO_TEXT
return IMAGE_TEXT_TYPE_UNKNOWN
def _normalize_image_bytes(image_data: bytes) -> tuple[bytes, str]:
"""Applies EXIF rotation and scales large images down for efficient transcription."""
with Image.open(io.BytesIO(image_data)) as image:
rotated = ImageOps.exif_transpose(image)
prepared = rotated.convert("RGB")
long_side = max(prepared.width, prepared.height)
if long_side > MAX_IMAGE_SIDE:
scale = MAX_IMAGE_SIDE / long_side
resized_width = max(1, int(prepared.width * scale))
resized_height = max(1, int(prepared.height * scale))
prepared = prepared.resize((resized_width, resized_height), Image.Resampling.LANCZOS)
output = io.BytesIO()
prepared.save(output, format="JPEG", quality=90, optimize=True)
return output.getvalue(), "image/jpeg"
def _create_client(provider_settings: dict[str, Any]) -> OpenAI:
"""Creates an OpenAI client configured for compatible endpoints and timeouts."""
api_key = str(provider_settings.get("openai_api_key", "")).strip() or "no-key-required"
return OpenAI(
api_key=api_key,
base_url=str(provider_settings["openai_base_url"]),
timeout=int(provider_settings["openai_timeout_seconds"]),
)
def _extract_text_from_response(response: Any) -> str:
"""Extracts plain text from responses API output objects."""
output_text = getattr(response, "output_text", None)
if isinstance(output_text, str) and output_text.strip():
return output_text.strip()
output_items = getattr(response, "output", None)
if not isinstance(output_items, list):
return ""
texts: list[str] = []
for item in output_items:
item_data = item.model_dump() if hasattr(item, "model_dump") else item
if not isinstance(item_data, dict):
continue
item_type = item_data.get("type")
if item_type == "output_text":
text = str(item_data.get("text", "")).strip()
if text:
texts.append(text)
if item_type == "message":
for content in item_data.get("content", []) or []:
if not isinstance(content, dict):
continue
if content.get("type") in {"output_text", "text"}:
text = str(content.get("text", "")).strip()
if text:
texts.append(text)
return "\n".join(texts).strip()
def _transcribe_with_responses(client: OpenAI, model: str, prompt: str, image_data_url: str) -> str:
"""Transcribes handwriting using the responses API."""
response = client.responses.create(
model=model,
input=[
{
"role": "user",
"content": [
{
"type": "input_text",
"text": prompt,
},
{
"type": "input_image",
"image_url": image_data_url,
"detail": "high",
},
],
}
],
)
return _extract_text_from_response(response)
def _transcribe_with_chat(client: OpenAI, model: str, prompt: str, image_data_url: str) -> str:
"""Transcribes handwriting using chat completions for endpoint compatibility."""
response = client.chat.completions.create(
model=model,
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": prompt,
},
{
"type": "image_url",
"image_url": {
"url": image_data_url,
"detail": "high",
},
},
],
}
],
)
message_content = response.choices[0].message.content
if isinstance(message_content, str):
return message_content.strip()
if isinstance(message_content, list):
text_parts: list[str] = []
for part in message_content:
if isinstance(part, dict):
text = str(part.get("text", "")).strip()
if text:
text_parts.append(text)
return "\n".join(text_parts).strip()
return ""
def _classify_with_responses(client: OpenAI, model: str, prompt: str, image_data_url: str) -> str:
"""Classifies image text modality using the responses API."""
response = client.responses.create(
model=model,
input=[
{
"role": "user",
"content": [
{
"type": "input_text",
"text": prompt,
},
{
"type": "input_image",
"image_url": image_data_url,
"detail": "high",
},
],
}
],
)
return _extract_text_from_response(response)
def _classify_with_chat(client: OpenAI, model: str, prompt: str, image_data_url: str) -> str:
"""Classifies image text modality using chat completions for compatibility."""
response = client.chat.completions.create(
model=model,
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": prompt,
},
{
"type": "image_url",
"image_url": {
"url": image_data_url,
"detail": "high",
},
},
],
}
],
)
message_content = response.choices[0].message.content
if isinstance(message_content, str):
return message_content.strip()
if isinstance(message_content, list):
text_parts: list[str] = []
for part in message_content:
if isinstance(part, dict):
text = str(part.get("text", "")).strip()
if text:
text_parts.append(text)
return "\n".join(text_parts).strip()
return ""
def _classify_image_text_data_url(image_data_url: str) -> ImageTextClassification:
"""Classifies an image as handwriting, printed text, or no text."""
provider_settings = read_handwriting_provider_settings()
provider_type = str(provider_settings.get("provider", "openai_compatible")).strip()
if provider_type != "openai_compatible":
raise HandwritingTranscriptionError(f"unsupported_provider_type:{provider_type}")
if not bool(provider_settings.get("enabled", True)):
raise HandwritingTranscriptionNotConfiguredError("handwriting_transcription_disabled")
model = str(provider_settings.get("openai_model", "gpt-4.1-mini")).strip() or "gpt-4.1-mini"
client = _create_client(provider_settings)
try:
output_text = _classify_with_responses(
client=client,
model=model,
prompt=IMAGE_TEXT_CLASSIFICATION_PROMPT,
image_data_url=image_data_url,
)
if not output_text:
output_text = _classify_with_chat(
client=client,
model=model,
prompt=IMAGE_TEXT_CLASSIFICATION_PROMPT,
image_data_url=image_data_url,
)
except APITimeoutError as error:
raise HandwritingTranscriptionTimeoutError("openai_request_timeout") from error
except (APIConnectionError, APIError):
try:
output_text = _classify_with_chat(
client=client,
model=model,
prompt=IMAGE_TEXT_CLASSIFICATION_PROMPT,
image_data_url=image_data_url,
)
except APITimeoutError as timeout_error:
raise HandwritingTranscriptionTimeoutError("openai_request_timeout") from timeout_error
except Exception as fallback_error:
raise HandwritingTranscriptionError(str(fallback_error)) from fallback_error
except Exception as error:
raise HandwritingTranscriptionError(str(error)) from error
parsed = _coerce_json_object(output_text)
if not parsed:
raise HandwritingTranscriptionError("image_text_classification_parse_failed")
label = _normalize_image_text_type(str(parsed.get("label", "")))
confidence = _clamp_probability(parsed.get("confidence", 0.0), fallback=0.0)
return ImageTextClassification(
label=label,
confidence=confidence,
provider="openai",
model=model,
)
def _transcribe_image_data_url(image_data_url: str) -> HandwritingTranscription:
"""Transcribes a handwriting image data URL with configured OpenAI provider settings."""
provider_settings = read_handwriting_provider_settings()
provider_type = str(provider_settings.get("provider", "openai_compatible")).strip()
if provider_type != "openai_compatible":
raise HandwritingTranscriptionError(f"unsupported_provider_type:{provider_type}")
if not bool(provider_settings.get("enabled", True)):
raise HandwritingTranscriptionNotConfiguredError("handwriting_transcription_disabled")
model = str(provider_settings.get("openai_model", "gpt-4.1-mini")).strip() or "gpt-4.1-mini"
prompt = str(provider_settings.get("prompt", DEFAULT_OCR_PROMPT)).strip() or DEFAULT_OCR_PROMPT
client = _create_client(provider_settings)
try:
text = _transcribe_with_responses(client=client, model=model, prompt=prompt, image_data_url=image_data_url)
if not text:
text = _transcribe_with_chat(client=client, model=model, prompt=prompt, image_data_url=image_data_url)
except APITimeoutError as error:
raise HandwritingTranscriptionTimeoutError("openai_request_timeout") from error
except (APIConnectionError, APIError) as error:
try:
text = _transcribe_with_chat(client=client, model=model, prompt=prompt, image_data_url=image_data_url)
except APITimeoutError as timeout_error:
raise HandwritingTranscriptionTimeoutError("openai_request_timeout") from timeout_error
except Exception as fallback_error:
raise HandwritingTranscriptionError(str(fallback_error)) from fallback_error
except Exception as error:
raise HandwritingTranscriptionError(str(error)) from error
final_text = text.strip()
return HandwritingTranscription(
text=final_text,
uncertainties=_extract_uncertainties(final_text),
provider="openai",
model=model,
)
def transcribe_handwriting_base64(image_base64: str, mime_type: str = "image/jpeg") -> HandwritingTranscription:
"""Transcribes handwriting from a base64 payload without data URL prefix."""
normalized_mime = mime_type.strip().lower() if mime_type.strip() else "image/jpeg"
image_data_url = f"data:{normalized_mime};base64,{image_base64}"
return _transcribe_image_data_url(image_data_url)
def transcribe_handwriting_url(image_url: str) -> HandwritingTranscription:
"""Transcribes handwriting from a direct image URL."""
return _transcribe_image_data_url(image_url)
def transcribe_handwriting_bytes(image_data: bytes, mime_type: str = "image/jpeg") -> HandwritingTranscription:
"""Transcribes handwriting from raw image bytes after normalization."""
normalized_bytes, normalized_mime = _normalize_image_bytes(image_data)
encoded = base64.b64encode(normalized_bytes).decode("ascii")
return transcribe_handwriting_base64(encoded, mime_type=normalized_mime)
def classify_image_text_base64(image_base64: str, mime_type: str = "image/jpeg") -> ImageTextClassification:
"""Classifies image text type from a base64 payload without data URL prefix."""
normalized_mime = mime_type.strip().lower() if mime_type.strip() else "image/jpeg"
image_data_url = f"data:{normalized_mime};base64,{image_base64}"
return _classify_image_text_data_url(image_data_url)
def classify_image_text_url(image_url: str) -> ImageTextClassification:
"""Classifies image text type from a direct image URL."""
return _classify_image_text_data_url(image_url)
def classify_image_text_bytes(image_data: bytes, mime_type: str = "image/jpeg") -> ImageTextClassification:
"""Classifies image text type from raw image bytes after normalization."""
normalized_bytes, normalized_mime = _normalize_image_bytes(image_data)
encoded = base64.b64encode(normalized_bytes).decode("ascii")
return classify_image_text_base64(encoded, mime_type=normalized_mime)
def transcribe_handwriting(image: bytes | str, mime_type: str = "image/jpeg") -> HandwritingTranscription:
"""Transcribes handwriting from bytes, base64 text, or URL input."""
if isinstance(image, bytes):
return transcribe_handwriting_bytes(image, mime_type=mime_type)
stripped = image.strip()
if stripped.startswith("http://") or stripped.startswith("https://"):
return transcribe_handwriting_url(stripped)
return transcribe_handwriting_base64(stripped, mime_type=mime_type)

View File

@@ -0,0 +1,435 @@
"""Handwriting-style clustering and style-scoped path composition for image documents."""
import base64
import io
import re
from dataclasses import dataclass
from typing import Any
from PIL import Image, ImageOps
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from app.core.config import get_settings
from app.models.document import Document, DocumentStatus
from app.services.app_settings import (
DEFAULT_HANDWRITING_STYLE_EMBED_MODEL,
read_handwriting_style_settings,
)
from app.services.typesense_index import get_typesense_client
settings = get_settings()
IMAGE_TEXT_TYPE_HANDWRITING = "handwriting"
HANDWRITING_STYLE_COLLECTION_SUFFIX = "_handwriting_styles"
HANDWRITING_STYLE_EMBED_MODEL = DEFAULT_HANDWRITING_STYLE_EMBED_MODEL
HANDWRITING_STYLE_MATCH_MIN_SIMILARITY = 0.86
HANDWRITING_STYLE_BOOTSTRAP_MIN_SIMILARITY = 0.89
HANDWRITING_STYLE_BOOTSTRAP_SAMPLE_SIZE = 3
HANDWRITING_STYLE_NEIGHBOR_LIMIT = 8
HANDWRITING_STYLE_IMAGE_MAX_SIDE = 1024
HANDWRITING_STYLE_ID_PREFIX = "hw_style_"
HANDWRITING_STYLE_ID_PATTERN = re.compile(r"^hw_style_(\d+)$")
@dataclass
class HandwritingStyleNeighbor:
"""Represents one nearest handwriting-style neighbor returned from Typesense."""
document_id: str
style_cluster_id: str
vector_distance: float
similarity: float
@dataclass
class HandwritingStyleAssignment:
"""Represents the chosen handwriting-style cluster assignment for one document."""
style_cluster_id: str
matched_existing: bool
similarity: float
vector_distance: float
compared_neighbors: int
match_min_similarity: float
bootstrap_match_min_similarity: float
def _style_collection_name() -> str:
"""Builds the dedicated Typesense collection name used for handwriting-style vectors."""
return f"{settings.typesense_collection_name}{HANDWRITING_STYLE_COLLECTION_SUFFIX}"
def _style_collection() -> Any:
"""Returns the Typesense collection handle for handwriting-style indexing."""
client = get_typesense_client()
return client.collections[_style_collection_name()]
def _distance_to_similarity(vector_distance: float) -> float:
"""Converts Typesense vector distance into conservative similarity in [0, 1]."""
return max(0.0, min(1.0, 1.0 - (vector_distance / 2.0)))
def _encode_style_image_base64(image_data: bytes, image_max_side: int) -> str:
"""Normalizes and downsizes image bytes and returns a base64-encoded JPEG payload."""
with Image.open(io.BytesIO(image_data)) as image:
prepared = ImageOps.exif_transpose(image).convert("RGB")
longest_side = max(prepared.width, prepared.height)
if longest_side > image_max_side:
scale = image_max_side / longest_side
resized_width = max(1, int(prepared.width * scale))
resized_height = max(1, int(prepared.height * scale))
prepared = prepared.resize((resized_width, resized_height), Image.Resampling.LANCZOS)
output = io.BytesIO()
prepared.save(output, format="JPEG", quality=86, optimize=True)
return base64.b64encode(output.getvalue()).decode("ascii")
def ensure_handwriting_style_collection() -> None:
"""Creates the handwriting-style Typesense collection when it is not present."""
runtime_settings = read_handwriting_style_settings()
embed_model = str(runtime_settings.get("embed_model", HANDWRITING_STYLE_EMBED_MODEL)).strip() or HANDWRITING_STYLE_EMBED_MODEL
collection = _style_collection()
should_recreate_collection = False
try:
existing_schema = collection.retrieve()
if isinstance(existing_schema, dict):
existing_fields = existing_schema.get("fields", [])
if isinstance(existing_fields, list):
for field in existing_fields:
if not isinstance(field, dict):
continue
if str(field.get("name", "")).strip() != "embedding":
continue
embed_config = field.get("embed", {})
model_config = embed_config.get("model_config", {}) if isinstance(embed_config, dict) else {}
existing_model = str(model_config.get("model_name", "")).strip()
if existing_model and existing_model != embed_model:
should_recreate_collection = True
break
if not should_recreate_collection:
return
except Exception as error:
message = str(error).lower()
if "404" not in message and "not found" not in message:
raise
client = get_typesense_client()
if should_recreate_collection:
client.collections[_style_collection_name()].delete()
schema = {
"name": _style_collection_name(),
"fields": [
{
"name": "style_cluster_id",
"type": "string",
"facet": True,
},
{
"name": "image_text_type",
"type": "string",
"facet": True,
},
{
"name": "created_at",
"type": "int64",
},
{
"name": "image",
"type": "image",
"store": False,
},
{
"name": "embedding",
"type": "float[]",
"embed": {
"from": ["image"],
"model_config": {
"model_name": embed_model,
},
},
},
],
"default_sorting_field": "created_at",
}
client.collections.create(schema)
def _search_style_neighbors(
image_base64: str,
limit: int,
exclude_document_id: str | None = None,
) -> list[HandwritingStyleNeighbor]:
"""Returns nearest handwriting-style neighbors for one encoded image payload."""
ensure_handwriting_style_collection()
client = get_typesense_client()
filter_clauses = [f"image_text_type:={IMAGE_TEXT_TYPE_HANDWRITING}"]
if exclude_document_id:
filter_clauses.append(f"id:!={exclude_document_id}")
search_payload = {
"q": "*",
"query_by": "embedding",
"vector_query": f"embedding:([], image:{image_base64}, k:{max(1, limit)})",
"exclude_fields": "embedding,image",
"per_page": max(1, limit),
"filter_by": " && ".join(filter_clauses),
}
response = client.multi_search.perform(
{
"searches": [
{
"collection": _style_collection_name(),
**search_payload,
}
]
},
{},
)
results = response.get("results", []) if isinstance(response, dict) else []
first_result = results[0] if isinstance(results, list) and len(results) > 0 else {}
hits = first_result.get("hits", []) if isinstance(first_result, dict) else []
neighbors: list[HandwritingStyleNeighbor] = []
for hit in hits:
if not isinstance(hit, dict):
continue
document = hit.get("document")
if not isinstance(document, dict):
continue
document_id = str(document.get("id", "")).strip()
style_cluster_id = str(document.get("style_cluster_id", "")).strip()
if not document_id or not style_cluster_id:
continue
try:
vector_distance = float(hit.get("vector_distance", 2.0))
except (TypeError, ValueError):
vector_distance = 2.0
neighbors.append(
HandwritingStyleNeighbor(
document_id=document_id,
style_cluster_id=style_cluster_id,
vector_distance=vector_distance,
similarity=_distance_to_similarity(vector_distance),
)
)
if len(neighbors) >= limit:
break
return neighbors
def _next_style_cluster_id(session: Session) -> str:
"""Allocates the next stable handwriting-style folder identifier."""
existing_ids = session.execute(
select(Document.handwriting_style_id).where(Document.handwriting_style_id.is_not(None))
).scalars().all()
max_value = 0
for existing_id in existing_ids:
candidate = str(existing_id).strip()
match = HANDWRITING_STYLE_ID_PATTERN.fullmatch(candidate)
if not match:
continue
numeric_part = int(match.group(1))
max_value = max(max_value, numeric_part)
return f"{HANDWRITING_STYLE_ID_PREFIX}{max_value + 1}"
def _style_cluster_sample_size(session: Session, style_cluster_id: str) -> int:
"""Returns the number of indexed documents currently assigned to one style cluster."""
return int(
session.execute(
select(func.count())
.select_from(Document)
.where(Document.handwriting_style_id == style_cluster_id)
.where(Document.image_text_type == IMAGE_TEXT_TYPE_HANDWRITING)
).scalar_one()
)
def assign_handwriting_style(
session: Session,
document: Document,
image_data: bytes,
) -> HandwritingStyleAssignment:
"""Assigns a document to an existing handwriting-style cluster or creates a new one."""
runtime_settings = read_handwriting_style_settings()
image_max_side = int(runtime_settings.get("image_max_side", HANDWRITING_STYLE_IMAGE_MAX_SIDE))
neighbor_limit = int(runtime_settings.get("neighbor_limit", HANDWRITING_STYLE_NEIGHBOR_LIMIT))
match_min_similarity = float(runtime_settings.get("match_min_similarity", HANDWRITING_STYLE_MATCH_MIN_SIMILARITY))
bootstrap_match_min_similarity = float(
runtime_settings.get("bootstrap_match_min_similarity", HANDWRITING_STYLE_BOOTSTRAP_MIN_SIMILARITY)
)
bootstrap_sample_size = int(runtime_settings.get("bootstrap_sample_size", HANDWRITING_STYLE_BOOTSTRAP_SAMPLE_SIZE))
image_base64 = _encode_style_image_base64(image_data, image_max_side=image_max_side)
neighbors = _search_style_neighbors(
image_base64=image_base64,
limit=neighbor_limit,
exclude_document_id=str(document.id),
)
best_neighbor = neighbors[0] if neighbors else None
similarity = best_neighbor.similarity if best_neighbor else 0.0
vector_distance = best_neighbor.vector_distance if best_neighbor else 2.0
cluster_sample_size = 0
if best_neighbor:
cluster_sample_size = _style_cluster_sample_size(
session=session,
style_cluster_id=best_neighbor.style_cluster_id,
)
required_similarity = (
bootstrap_match_min_similarity
if cluster_sample_size < bootstrap_sample_size
else match_min_similarity
)
should_match_existing = (
best_neighbor is not None and similarity >= required_similarity
)
if should_match_existing and best_neighbor:
style_cluster_id = best_neighbor.style_cluster_id
matched_existing = True
else:
existing_style_cluster_id = (document.handwriting_style_id or "").strip()
if HANDWRITING_STYLE_ID_PATTERN.fullmatch(existing_style_cluster_id):
style_cluster_id = existing_style_cluster_id
else:
style_cluster_id = _next_style_cluster_id(session=session)
matched_existing = False
ensure_handwriting_style_collection()
collection = _style_collection()
payload = {
"id": str(document.id),
"style_cluster_id": style_cluster_id,
"image_text_type": IMAGE_TEXT_TYPE_HANDWRITING,
"created_at": int(document.created_at.timestamp()),
"image": image_base64,
}
collection.documents.upsert(payload)
return HandwritingStyleAssignment(
style_cluster_id=style_cluster_id,
matched_existing=matched_existing,
similarity=similarity,
vector_distance=vector_distance,
compared_neighbors=len(neighbors),
match_min_similarity=match_min_similarity,
bootstrap_match_min_similarity=bootstrap_match_min_similarity,
)
def delete_handwriting_style_document(document_id: str) -> None:
"""Deletes one document id from the handwriting-style Typesense collection."""
collection = _style_collection()
try:
collection.documents[document_id].delete()
except Exception as error:
message = str(error).lower()
if "404" in message or "not found" in message:
return
raise
def delete_many_handwriting_style_documents(document_ids: list[str]) -> None:
"""Deletes many document ids from the handwriting-style Typesense collection."""
for document_id in document_ids:
delete_handwriting_style_document(document_id)
def apply_handwriting_style_path(style_cluster_id: str | None, path_value: str | None) -> str | None:
"""Composes style-prefixed logical paths while preventing duplicate prefix nesting."""
if path_value is None:
return None
normalized_path = path_value.strip().strip("/")
if not normalized_path:
return None
normalized_style = (style_cluster_id or "").strip().strip("/")
if not normalized_style:
return normalized_path
segments = [segment for segment in normalized_path.split("/") if segment]
while segments and HANDWRITING_STYLE_ID_PATTERN.fullmatch(segments[0]):
segments.pop(0)
if segments and segments[0].strip().lower() == normalized_style.lower():
segments.pop(0)
if len(segments) == 0:
return normalized_style
sanitized_path = "/".join(segments)
return f"{normalized_style}/{sanitized_path}"
def resolve_handwriting_style_path_prefix(
session: Session,
style_cluster_id: str | None,
*,
exclude_document_id: str | None = None,
) -> str | None:
"""Resolves a stable path prefix for one style cluster, preferring known non-style root segments."""
normalized_style = (style_cluster_id or "").strip()
if not normalized_style:
return None
statement = select(Document.logical_path).where(
Document.handwriting_style_id == normalized_style,
Document.image_text_type == IMAGE_TEXT_TYPE_HANDWRITING,
Document.status != DocumentStatus.TRASHED,
)
if exclude_document_id:
statement = statement.where(Document.id != exclude_document_id)
rows = session.execute(statement).scalars().all()
segment_counts: dict[str, int] = {}
segment_labels: dict[str, str] = {}
for raw_path in rows:
if not isinstance(raw_path, str):
continue
segments = [segment.strip() for segment in raw_path.split("/") if segment.strip()]
if not segments:
continue
first_segment = segments[0]
lowered = first_segment.lower()
if lowered == "inbox":
continue
if HANDWRITING_STYLE_ID_PATTERN.fullmatch(first_segment):
continue
segment_counts[lowered] = segment_counts.get(lowered, 0) + 1
if lowered not in segment_labels:
segment_labels[lowered] = first_segment
if not segment_counts:
return normalized_style
winner = sorted(
segment_counts.items(),
key=lambda item: (-item[1], item[0]),
)[0][0]
return segment_labels.get(winner, normalized_style)

View File

@@ -0,0 +1,227 @@
"""Model runtime utilities for provider-bound LLM task execution."""
from dataclasses import dataclass
from typing import Any
from urllib.parse import urlparse, urlunparse
from openai import APIConnectionError, APIError, APITimeoutError, OpenAI
from app.services.app_settings import read_task_runtime_settings
class ModelTaskError(Exception):
"""Raised when a model task request fails."""
class ModelTaskTimeoutError(ModelTaskError):
"""Raised when a model task request times out."""
class ModelTaskDisabledError(ModelTaskError):
"""Raised when a model task is disabled in settings."""
@dataclass
class ModelTaskRuntime:
"""Resolved runtime configuration for one task and provider."""
task_name: str
provider_id: str
provider_type: str
base_url: str
timeout_seconds: int
api_key: str
model: str
prompt: str
def _normalize_base_url(raw_value: str) -> str:
"""Normalizes provider base URL and appends /v1 for OpenAI-compatible servers."""
trimmed = raw_value.strip().rstrip("/")
if not trimmed:
return "https://api.openai.com/v1"
parsed = urlparse(trimmed)
path = parsed.path or ""
if not path.endswith("/v1"):
path = f"{path}/v1" if path else "/v1"
return urlunparse(parsed._replace(path=path))
def _should_fallback_to_chat(error: Exception) -> bool:
"""Determines whether a responses API failure should fallback to chat completions."""
status_code = getattr(error, "status_code", None)
if isinstance(status_code, int) and status_code in {400, 404, 405, 415, 422, 501}:
return True
message = str(error).lower()
fallback_markers = (
"404",
"not found",
"unknown endpoint",
"unsupported",
"invalid url",
"responses",
)
return any(marker in message for marker in fallback_markers)
def _extract_text_from_response(response: Any) -> str:
"""Extracts plain text from Responses API outputs."""
output_text = getattr(response, "output_text", None)
if isinstance(output_text, str) and output_text.strip():
return output_text.strip()
output_items = getattr(response, "output", None)
if not isinstance(output_items, list):
return ""
chunks: list[str] = []
for item in output_items:
item_data = item.model_dump() if hasattr(item, "model_dump") else item
if not isinstance(item_data, dict):
continue
item_type = item_data.get("type")
if item_type == "output_text":
text = str(item_data.get("text", "")).strip()
if text:
chunks.append(text)
if item_type == "message":
for content in item_data.get("content", []) or []:
if not isinstance(content, dict):
continue
if content.get("type") in {"output_text", "text"}:
text = str(content.get("text", "")).strip()
if text:
chunks.append(text)
return "\n".join(chunks).strip()
def _extract_text_from_chat_response(response: Any) -> str:
"""Extracts text from Chat Completions API outputs."""
message_content = response.choices[0].message.content
if isinstance(message_content, str):
return message_content.strip()
if not isinstance(message_content, list):
return ""
chunks: list[str] = []
for content in message_content:
if not isinstance(content, dict):
continue
text = str(content.get("text", "")).strip()
if text:
chunks.append(text)
return "\n".join(chunks).strip()
def resolve_task_runtime(task_name: str) -> ModelTaskRuntime:
"""Resolves one task runtime including provider endpoint, model, and prompt."""
runtime_payload = read_task_runtime_settings(task_name)
task_payload = runtime_payload["task"]
provider_payload = runtime_payload["provider"]
if not bool(task_payload.get("enabled", True)):
raise ModelTaskDisabledError(f"task_disabled:{task_name}")
provider_type = str(provider_payload.get("provider_type", "openai_compatible")).strip()
if provider_type != "openai_compatible":
raise ModelTaskError(f"unsupported_provider_type:{provider_type}")
return ModelTaskRuntime(
task_name=task_name,
provider_id=str(provider_payload.get("id", "")),
provider_type=provider_type,
base_url=_normalize_base_url(str(provider_payload.get("base_url", "https://api.openai.com/v1"))),
timeout_seconds=int(provider_payload.get("timeout_seconds", 45)),
api_key=str(provider_payload.get("api_key", "")).strip() or "no-key-required",
model=str(task_payload.get("model", "")).strip(),
prompt=str(task_payload.get("prompt", "")).strip(),
)
def _create_client(runtime: ModelTaskRuntime) -> OpenAI:
"""Builds an OpenAI SDK client for OpenAI-compatible provider endpoints."""
return OpenAI(
api_key=runtime.api_key,
base_url=runtime.base_url,
timeout=runtime.timeout_seconds,
)
def complete_text_task(task_name: str, user_text: str, prompt_override: str | None = None) -> str:
"""Runs a text-only task against the configured provider and returns plain output text."""
runtime = resolve_task_runtime(task_name)
client = _create_client(runtime)
prompt = (prompt_override or runtime.prompt).strip() or runtime.prompt
try:
response = client.responses.create(
model=runtime.model,
input=[
{
"role": "system",
"content": [
{
"type": "input_text",
"text": prompt,
}
],
},
{
"role": "user",
"content": [
{
"type": "input_text",
"text": user_text,
}
],
},
],
)
text = _extract_text_from_response(response)
if text:
return text
except APITimeoutError as error:
raise ModelTaskTimeoutError(f"task_timeout:{task_name}") from error
except APIConnectionError as error:
raise ModelTaskError(f"task_error:{task_name}:{error}") from error
except APIError as error:
if not _should_fallback_to_chat(error):
raise ModelTaskError(f"task_error:{task_name}:{error}") from error
except Exception as error:
if not _should_fallback_to_chat(error):
raise ModelTaskError(f"task_error:{task_name}:{error}") from error
try:
fallback = client.chat.completions.create(
model=runtime.model,
messages=[
{
"role": "system",
"content": prompt,
},
{
"role": "user",
"content": user_text,
},
],
)
return _extract_text_from_chat_response(fallback)
except APITimeoutError as error:
raise ModelTaskTimeoutError(f"task_timeout:{task_name}") from error
except (APIConnectionError, APIError) as error:
raise ModelTaskError(f"task_error:{task_name}:{error}") from error
except Exception as error:
raise ModelTaskError(f"task_error:{task_name}:{error}") from error

View File

@@ -0,0 +1,192 @@
"""Persistence helpers for writing and querying processing pipeline log events."""
from typing import Any
from uuid import UUID
from sqlalchemy import delete, func, select
from sqlalchemy.orm import Session
from app.models.document import Document
from app.models.processing_log import ProcessingLogEntry
MAX_STAGE_LENGTH = 64
MAX_EVENT_LENGTH = 256
MAX_LEVEL_LENGTH = 16
MAX_PROVIDER_LENGTH = 128
MAX_MODEL_LENGTH = 256
MAX_DOCUMENT_FILENAME_LENGTH = 512
MAX_PROMPT_LENGTH = 200000
MAX_RESPONSE_LENGTH = 200000
DEFAULT_KEEP_DOCUMENT_SESSIONS = 2
DEFAULT_KEEP_UNBOUND_ENTRIES = 80
PROCESSING_LOG_AUTOCOMMIT_SESSION_KEY = "processing_log_autocommit"
def _trim(value: str | None, max_length: int) -> str | None:
"""Normalizes and truncates text values for safe log persistence."""
if value is None:
return None
normalized = value.strip()
if not normalized:
return None
if len(normalized) <= max_length:
return normalized
return normalized[: max_length - 3] + "..."
def _safe_payload(payload_json: dict[str, Any] | None) -> dict[str, Any]:
"""Ensures payload values are persisted as dictionaries."""
return payload_json if isinstance(payload_json, dict) else {}
def set_processing_log_autocommit(session: Session, enabled: bool) -> None:
"""Toggles per-session immediate commit behavior for processing log events."""
session.info[PROCESSING_LOG_AUTOCOMMIT_SESSION_KEY] = bool(enabled)
def is_processing_log_autocommit_enabled(session: Session) -> bool:
"""Returns whether processing logs are committed immediately for the current session."""
return bool(session.info.get(PROCESSING_LOG_AUTOCOMMIT_SESSION_KEY, False))
def log_processing_event(
session: Session,
stage: str,
event: str,
*,
level: str = "info",
document: Document | None = None,
document_id: UUID | None = None,
document_filename: str | None = None,
provider_id: str | None = None,
model_name: str | None = None,
prompt_text: str | None = None,
response_text: str | None = None,
payload_json: dict[str, Any] | None = None,
) -> None:
"""Persists one processing log entry linked to an optional document context."""
resolved_document_id = document.id if document is not None else document_id
resolved_document_filename = document.original_filename if document is not None else document_filename
entry = ProcessingLogEntry(
level=_trim(level, MAX_LEVEL_LENGTH) or "info",
stage=_trim(stage, MAX_STAGE_LENGTH) or "pipeline",
event=_trim(event, MAX_EVENT_LENGTH) or "event",
document_id=resolved_document_id,
document_filename=_trim(resolved_document_filename, MAX_DOCUMENT_FILENAME_LENGTH),
provider_id=_trim(provider_id, MAX_PROVIDER_LENGTH),
model_name=_trim(model_name, MAX_MODEL_LENGTH),
prompt_text=_trim(prompt_text, MAX_PROMPT_LENGTH),
response_text=_trim(response_text, MAX_RESPONSE_LENGTH),
payload_json=_safe_payload(payload_json),
)
session.add(entry)
if is_processing_log_autocommit_enabled(session):
session.commit()
def count_processing_logs(session: Session, document_id: UUID | None = None) -> int:
"""Counts persisted processing logs, optionally restricted to one document."""
statement = select(func.count()).select_from(ProcessingLogEntry)
if document_id is not None:
statement = statement.where(ProcessingLogEntry.document_id == document_id)
return int(session.execute(statement).scalar_one())
def list_processing_logs(
session: Session,
*,
limit: int,
offset: int,
document_id: UUID | None = None,
) -> list[ProcessingLogEntry]:
"""Lists processing logs ordered by newest-first with optional document filter."""
statement = select(ProcessingLogEntry)
if document_id is not None:
statement = statement.where(ProcessingLogEntry.document_id == document_id)
statement = statement.order_by(ProcessingLogEntry.created_at.desc(), ProcessingLogEntry.id.desc()).offset(offset).limit(limit)
return session.execute(statement).scalars().all()
def cleanup_processing_logs(
session: Session,
*,
keep_document_sessions: int = DEFAULT_KEEP_DOCUMENT_SESSIONS,
keep_unbound_entries: int = DEFAULT_KEEP_UNBOUND_ENTRIES,
) -> dict[str, int]:
"""Deletes old log entries while keeping recent document sessions and unbound events."""
normalized_keep_sessions = max(0, keep_document_sessions)
normalized_keep_unbound = max(0, keep_unbound_entries)
deleted_document_entries = 0
deleted_unbound_entries = 0
recent_document_rows = session.execute(
select(
ProcessingLogEntry.document_id,
func.max(ProcessingLogEntry.created_at).label("last_seen"),
)
.where(ProcessingLogEntry.document_id.is_not(None))
.group_by(ProcessingLogEntry.document_id)
.order_by(func.max(ProcessingLogEntry.created_at).desc())
.limit(normalized_keep_sessions)
).all()
keep_document_ids = [row[0] for row in recent_document_rows if row[0] is not None]
if keep_document_ids:
deleted_document_entries = int(
session.execute(
delete(ProcessingLogEntry).where(
ProcessingLogEntry.document_id.is_not(None),
ProcessingLogEntry.document_id.notin_(keep_document_ids),
)
).rowcount
or 0
)
else:
deleted_document_entries = int(
session.execute(delete(ProcessingLogEntry).where(ProcessingLogEntry.document_id.is_not(None))).rowcount or 0
)
keep_unbound_rows = session.execute(
select(ProcessingLogEntry.id)
.where(ProcessingLogEntry.document_id.is_(None))
.order_by(ProcessingLogEntry.created_at.desc(), ProcessingLogEntry.id.desc())
.limit(normalized_keep_unbound)
).all()
keep_unbound_ids = [row[0] for row in keep_unbound_rows]
if keep_unbound_ids:
deleted_unbound_entries = int(
session.execute(
delete(ProcessingLogEntry).where(
ProcessingLogEntry.document_id.is_(None),
ProcessingLogEntry.id.notin_(keep_unbound_ids),
)
).rowcount
or 0
)
else:
deleted_unbound_entries = int(
session.execute(delete(ProcessingLogEntry).where(ProcessingLogEntry.document_id.is_(None))).rowcount or 0
)
return {
"deleted_document_entries": deleted_document_entries,
"deleted_unbound_entries": deleted_unbound_entries,
}
def clear_processing_logs(session: Session) -> dict[str, int]:
"""Deletes all persisted processing log entries and returns deletion count."""
deleted_entries = int(session.execute(delete(ProcessingLogEntry)).rowcount or 0)
return {"deleted_entries": deleted_entries}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,59 @@
"""File storage utilities for persistence, retrieval, and checksum calculation."""
import hashlib
import uuid
from datetime import UTC, datetime
from pathlib import Path
from app.core.config import get_settings
settings = get_settings()
def ensure_storage() -> None:
"""Ensures required storage directories exist at service startup."""
for relative in ["originals", "derived/previews", "tmp"]:
(settings.storage_root / relative).mkdir(parents=True, exist_ok=True)
def compute_sha256(data: bytes) -> str:
"""Computes a SHA-256 hex digest for raw file bytes."""
return hashlib.sha256(data).hexdigest()
def store_bytes(filename: str, data: bytes) -> str:
"""Stores file content under a unique path and returns its storage-relative location."""
stamp = datetime.now(UTC).strftime("%Y/%m/%d")
safe_ext = Path(filename).suffix.lower()
target_dir = settings.storage_root / "originals" / stamp
target_dir.mkdir(parents=True, exist_ok=True)
target_name = f"{uuid.uuid4()}{safe_ext}"
target_path = target_dir / target_name
target_path.write_bytes(data)
return str(target_path.relative_to(settings.storage_root))
def read_bytes(relative_path: str) -> bytes:
"""Reads and returns bytes from a storage-relative path."""
return (settings.storage_root / relative_path).read_bytes()
def absolute_path(relative_path: str) -> Path:
"""Returns the absolute filesystem path for a storage-relative location."""
return settings.storage_root / relative_path
def write_preview(document_id: str, data: bytes, suffix: str = ".jpg") -> str:
"""Writes preview bytes and returns the preview path relative to storage root."""
target_dir = settings.storage_root / "derived" / "previews"
target_dir.mkdir(parents=True, exist_ok=True)
target_path = target_dir / f"{document_id}{suffix}"
target_path.write_bytes(data)
return str(target_path.relative_to(settings.storage_root))

View File

@@ -0,0 +1,257 @@
"""Typesense indexing and semantic-neighbor retrieval for document routing."""
from dataclasses import dataclass
from typing import Any
import typesense
from app.core.config import get_settings
from app.models.document import Document, DocumentStatus
settings = get_settings()
MAX_TYPESENSE_QUERY_CHARS = 600
@dataclass
class SimilarDocument:
"""Represents one nearest-neighbor document returned by Typesense semantic search."""
document_id: str
document_name: str
summary_text: str
logical_path: str
tags: list[str]
vector_distance: float
def _build_client() -> typesense.Client:
"""Builds a Typesense API client using configured host and credentials."""
return typesense.Client(
{
"nodes": [
{
"host": settings.typesense_host,
"port": str(settings.typesense_port),
"protocol": settings.typesense_protocol,
}
],
"api_key": settings.typesense_api_key,
"connection_timeout_seconds": settings.typesense_timeout_seconds,
"num_retries": settings.typesense_num_retries,
}
)
_client: typesense.Client | None = None
def get_typesense_client() -> typesense.Client:
"""Returns a cached Typesense client for repeated indexing and search operations."""
global _client
if _client is None:
_client = _build_client()
return _client
def _collection() -> Any:
"""Returns the configured Typesense collection handle."""
client = get_typesense_client()
return client.collections[settings.typesense_collection_name]
def ensure_typesense_collection() -> None:
"""Creates the document semantic collection when it does not already exist."""
collection = _collection()
try:
collection.retrieve()
return
except Exception as error:
message = str(error).lower()
if "404" not in message and "not found" not in message:
raise
schema = {
"name": settings.typesense_collection_name,
"fields": [
{
"name": "document_name",
"type": "string",
},
{
"name": "summary_text",
"type": "string",
},
{
"name": "logical_path",
"type": "string",
"facet": True,
},
{
"name": "tags",
"type": "string[]",
"facet": True,
},
{
"name": "status",
"type": "string",
"facet": True,
},
{
"name": "mime_type",
"type": "string",
"optional": True,
"facet": True,
},
{
"name": "extension",
"type": "string",
"optional": True,
"facet": True,
},
{
"name": "created_at",
"type": "int64",
},
{
"name": "has_labels",
"type": "bool",
"facet": True,
},
{
"name": "embedding",
"type": "float[]",
"embed": {
"from": [
"document_name",
"summary_text",
],
"model_config": {
"model_name": "ts/e5-small-v2",
"indexing_prefix": "passage:",
"query_prefix": "query:",
},
},
},
],
"default_sorting_field": "created_at",
}
client = get_typesense_client()
client.collections.create(schema)
def _has_labels(document: Document) -> bool:
"""Determines whether a document has usable human-assigned routing metadata."""
if document.logical_path.strip() and document.logical_path.strip().lower() != "inbox":
return True
return len([tag for tag in document.tags if tag.strip()]) > 0
def upsert_document_index(document: Document, summary_text: str) -> None:
"""Upserts one document into Typesense for semantic retrieval and routing examples."""
ensure_typesense_collection()
collection = _collection()
payload = {
"id": str(document.id),
"document_name": document.original_filename,
"summary_text": summary_text[:50000],
"logical_path": document.logical_path,
"tags": [tag for tag in document.tags if tag.strip()][:50],
"status": document.status.value,
"mime_type": document.mime_type,
"extension": document.extension,
"created_at": int(document.created_at.timestamp()),
"has_labels": _has_labels(document) and document.status != DocumentStatus.TRASHED,
}
collection.documents.upsert(payload)
def delete_document_index(document_id: str) -> None:
"""Deletes one document from Typesense by identifier."""
collection = _collection()
try:
collection.documents[document_id].delete()
except Exception as error:
message = str(error).lower()
if "404" in message or "not found" in message:
return
raise
def delete_many_documents_index(document_ids: list[str]) -> None:
"""Deletes many documents from Typesense by identifiers."""
for document_id in document_ids:
delete_document_index(document_id)
def query_similar_documents(summary_text: str, limit: int, exclude_document_id: str | None = None) -> list[SimilarDocument]:
"""Returns semantic nearest neighbors among labeled non-trashed indexed documents."""
ensure_typesense_collection()
collection = _collection()
normalized_query = " ".join(summary_text.strip().split())
query_text = normalized_query[:MAX_TYPESENSE_QUERY_CHARS] if normalized_query else "document"
search_payload = {
"q": query_text,
"query_by": "embedding",
"vector_query": f"embedding:([], k:{max(1, limit)})",
"exclude_fields": "embedding",
"per_page": max(1, limit),
"filter_by": "has_labels:=true && status:!=trashed",
}
try:
response = collection.documents.search(search_payload)
except Exception as error:
message = str(error).lower()
if "query string exceeds max allowed length" not in message:
raise
fallback_payload = dict(search_payload)
fallback_payload["q"] = "document"
response = collection.documents.search(fallback_payload)
hits = response.get("hits", []) if isinstance(response, dict) else []
neighbors: list[SimilarDocument] = []
for hit in hits:
if not isinstance(hit, dict):
continue
document = hit.get("document", {})
if not isinstance(document, dict):
continue
document_id = str(document.get("id", "")).strip()
if not document_id:
continue
if exclude_document_id and document_id == exclude_document_id:
continue
raw_tags = document.get("tags", [])
tags = [str(tag).strip() for tag in raw_tags if str(tag).strip()] if isinstance(raw_tags, list) else []
try:
distance = float(hit.get("vector_distance", 2.0))
except (TypeError, ValueError):
distance = 2.0
neighbors.append(
SimilarDocument(
document_id=document_id,
document_name=str(document.get("document_name", "")).strip(),
summary_text=str(document.get("summary_text", "")).strip(),
logical_path=str(document.get("logical_path", "")).strip(),
tags=tags,
vector_distance=distance,
)
)
if len(neighbors) >= limit:
break
return neighbors

View File

@@ -0,0 +1 @@
"""Background worker package for queueing and document processing tasks."""

View File

@@ -0,0 +1,21 @@
"""Queue connection helpers used by API and worker processes."""
from redis import Redis
from rq import Queue
from app.core.config import get_settings
settings = get_settings()
def get_redis() -> Redis:
"""Creates a Redis connection from configured URL."""
return Redis.from_url(settings.redis_url)
def get_processing_queue() -> Queue:
"""Returns the named queue for document processing jobs."""
return Queue("dcm", connection=get_redis())

544
backend/app/worker/tasks.py Normal file
View File

@@ -0,0 +1,544 @@
"""Background worker tasks for extraction, indexing, and archive fan-out."""
import uuid
from datetime import UTC, datetime
from pathlib import Path
from sqlalchemy import select
from app.db.base import SessionLocal
from app.models.document import Document, DocumentStatus
from app.services.app_settings import read_handwriting_provider_settings, read_handwriting_style_settings
from app.services.extractor import (
IMAGE_EXTENSIONS,
extract_archive_members,
extract_text_content,
is_supported_for_extraction,
sniff_mime,
)
from app.services.handwriting import IMAGE_TEXT_TYPE_HANDWRITING
from app.services.handwriting_style import (
assign_handwriting_style,
delete_handwriting_style_document,
)
from app.services.processing_logs import cleanup_processing_logs, log_processing_event, set_processing_log_autocommit
from app.services.routing_pipeline import (
apply_routing_decision,
classify_document_routing,
summarize_document,
upsert_semantic_index,
)
from app.services.storage import absolute_path, compute_sha256, store_bytes, write_preview
from app.worker.queue import get_processing_queue
def _create_archive_member_document(
parent: Document,
member_name: str,
member_data: bytes,
mime_type: str,
) -> Document:
"""Creates a child document entity for a file extracted from an uploaded archive."""
extension = Path(member_name).suffix.lower()
stored_relative_path = store_bytes(member_name, member_data)
return Document(
original_filename=Path(member_name).name,
source_relative_path=f"{parent.source_relative_path}/{member_name}".strip("/"),
stored_relative_path=stored_relative_path,
mime_type=mime_type,
extension=extension,
sha256=compute_sha256(member_data),
size_bytes=len(member_data),
logical_path=parent.logical_path,
tags=list(parent.tags),
metadata_json={"origin": "archive", "parent": str(parent.id)},
is_archive_member=True,
archived_member_path=member_name,
parent_document_id=parent.id,
)
def process_document_task(document_id: str) -> None:
"""Processes one queued document and updates extraction and suggestion fields."""
with SessionLocal() as session:
set_processing_log_autocommit(session, True)
queue = get_processing_queue()
document = session.execute(
select(Document).where(Document.id == uuid.UUID(document_id))
).scalar_one_or_none()
if document is None:
return
log_processing_event(
session=session,
stage="worker",
event="Document processing started",
level="info",
document=document,
payload_json={"status": document.status.value},
)
if document.status == DocumentStatus.TRASHED:
log_processing_event(
session=session,
stage="worker",
event="Document skipped because it is trashed",
level="warning",
document=document,
)
session.commit()
return
source_path = absolute_path(document.stored_relative_path)
data = source_path.read_bytes()
if document.extension == ".zip":
child_ids: list[str] = []
log_processing_event(
session=session,
stage="archive",
event="Archive extraction started",
level="info",
document=document,
payload_json={"size_bytes": len(data)},
)
try:
members = extract_archive_members(data)
for member in members:
mime_type = sniff_mime(member.data)
child = _create_archive_member_document(
parent=document,
member_name=member.name,
member_data=member.data,
mime_type=mime_type,
)
session.add(child)
session.flush()
child_ids.append(str(child.id))
log_processing_event(
session=session,
stage="archive",
event="Archive member extracted and queued",
level="info",
document=child,
payload_json={
"parent_document_id": str(document.id),
"member_name": member.name,
"member_size_bytes": len(member.data),
"mime_type": mime_type,
},
)
document.status = DocumentStatus.PROCESSED
document.extracted_text = f"archive with {len(members)} files"
log_processing_event(
session=session,
stage="archive",
event="Archive extraction completed",
level="info",
document=document,
payload_json={"member_count": len(members)},
)
except Exception as exc:
document.status = DocumentStatus.ERROR
document.metadata_json = {**document.metadata_json, "error": str(exc)}
log_processing_event(
session=session,
stage="archive",
event="Archive extraction failed",
level="error",
document=document,
response_text=str(exc),
)
if document.status == DocumentStatus.PROCESSED:
try:
summary_text = summarize_document(session=session, document=document)
metadata_json = dict(document.metadata_json)
metadata_json["summary_text"] = summary_text[:20000]
document.metadata_json = metadata_json
routing_decision = classify_document_routing(session=session, document=document, summary_text=summary_text)
apply_routing_decision(document=document, decision=routing_decision, session=session)
routing_metadata = document.metadata_json.get("routing", {})
log_processing_event(
session=session,
stage="routing",
event="Routing decision applied",
level="info",
document=document,
payload_json=routing_metadata if isinstance(routing_metadata, dict) else {},
)
log_processing_event(
session=session,
stage="indexing",
event="Typesense upsert started",
level="info",
document=document,
)
upsert_semantic_index(document=document, summary_text=summary_text)
log_processing_event(
session=session,
stage="indexing",
event="Typesense upsert completed",
level="info",
document=document,
)
except Exception as exc:
document.metadata_json = {
**document.metadata_json,
"routing_error": str(exc),
}
log_processing_event(
session=session,
stage="routing",
event="Routing or indexing failed for archive document",
level="error",
document=document,
response_text=str(exc),
)
document.processed_at = datetime.now(UTC)
log_processing_event(
session=session,
stage="worker",
event="Document processing completed",
level="info",
document=document,
payload_json={"status": document.status.value},
)
cleanup_processing_logs(session=session, keep_document_sessions=2, keep_unbound_entries=80)
session.commit()
for child_id in child_ids:
queue.enqueue("app.worker.tasks.process_document_task", child_id)
for child_id in child_ids:
log_processing_event(
session=session,
stage="archive",
event="Archive child job enqueued",
level="info",
document_id=uuid.UUID(child_id),
payload_json={"parent_document_id": str(document.id)},
)
session.commit()
return
if not is_supported_for_extraction(document.extension, document.mime_type):
document.status = DocumentStatus.UNSUPPORTED
document.processed_at = datetime.now(UTC)
log_processing_event(
session=session,
stage="extraction",
event="Document type unsupported for extraction",
level="warning",
document=document,
payload_json={"extension": document.extension, "mime_type": document.mime_type},
)
log_processing_event(
session=session,
stage="worker",
event="Document processing completed",
level="info",
document=document,
payload_json={"status": document.status.value},
)
cleanup_processing_logs(session=session, keep_document_sessions=2, keep_unbound_entries=80)
session.commit()
return
if document.extension in IMAGE_EXTENSIONS:
ocr_settings = read_handwriting_provider_settings()
log_processing_event(
session=session,
stage="ocr",
event="OCR request started",
level="info",
document=document,
provider_id=str(ocr_settings.get("provider_id", "")),
model_name=str(ocr_settings.get("openai_model", "")),
prompt_text=str(ocr_settings.get("prompt", "")),
payload_json={"mime_type": document.mime_type},
)
else:
log_processing_event(
session=session,
stage="extraction",
event="Text extraction started",
level="info",
document=document,
payload_json={"extension": document.extension, "mime_type": document.mime_type},
)
extraction = extract_text_content(document.original_filename, data, document.mime_type)
if extraction.preview_bytes and extraction.preview_suffix:
preview_relative_path = write_preview(str(document.id), extraction.preview_bytes, extraction.preview_suffix)
document.metadata_json = {**document.metadata_json, "preview_relative_path": preview_relative_path}
document.preview_available = True
log_processing_event(
session=session,
stage="extraction",
event="Preview generated",
level="info",
document=document,
payload_json={"preview_relative_path": preview_relative_path},
)
if extraction.metadata_json:
document.metadata_json = {**document.metadata_json, **extraction.metadata_json}
if document.extension in IMAGE_EXTENSIONS:
image_text_type = extraction.metadata_json.get("image_text_type")
if isinstance(image_text_type, str) and image_text_type.strip():
document.image_text_type = image_text_type.strip()
else:
document.image_text_type = None
else:
document.image_text_type = None
document.handwriting_style_id = None
if extraction.status == "error":
document.status = DocumentStatus.ERROR
document.metadata_json = {**document.metadata_json, "error": "extraction_failed"}
if document.extension in IMAGE_EXTENSIONS:
document.handwriting_style_id = None
metadata_json = dict(document.metadata_json)
metadata_json.pop("handwriting_style", None)
document.metadata_json = metadata_json
try:
delete_handwriting_style_document(str(document.id))
except Exception:
pass
document.processed_at = datetime.now(UTC)
log_processing_event(
session=session,
stage="extraction",
event="Extraction failed",
level="error",
document=document,
response_text=str(extraction.metadata_json.get("error", "extraction_failed")),
payload_json=extraction.metadata_json,
)
if "transcription_error" in extraction.metadata_json:
log_processing_event(
session=session,
stage="ocr",
event="OCR request failed",
level="error",
document=document,
response_text=str(extraction.metadata_json.get("transcription_error", "")),
)
log_processing_event(
session=session,
stage="worker",
event="Document processing completed",
level="info",
document=document,
payload_json={"status": document.status.value},
)
cleanup_processing_logs(session=session, keep_document_sessions=2, keep_unbound_entries=80)
session.commit()
return
if extraction.status == "unsupported":
document.status = DocumentStatus.UNSUPPORTED
if document.extension in IMAGE_EXTENSIONS:
document.handwriting_style_id = None
metadata_json = dict(document.metadata_json)
metadata_json.pop("handwriting_style", None)
document.metadata_json = metadata_json
try:
delete_handwriting_style_document(str(document.id))
except Exception:
pass
document.processed_at = datetime.now(UTC)
log_processing_event(
session=session,
stage="extraction",
event="Extraction returned unsupported",
level="warning",
document=document,
payload_json=extraction.metadata_json,
)
log_processing_event(
session=session,
stage="worker",
event="Document processing completed",
level="info",
document=document,
payload_json={"status": document.status.value},
)
cleanup_processing_logs(session=session, keep_document_sessions=2, keep_unbound_entries=80)
session.commit()
return
if document.extension in IMAGE_EXTENSIONS:
image_text_type = document.image_text_type or ""
if image_text_type == IMAGE_TEXT_TYPE_HANDWRITING:
style_settings = read_handwriting_style_settings()
if not bool(style_settings.get("enabled", True)):
document.handwriting_style_id = None
metadata_json = dict(document.metadata_json)
metadata_json.pop("handwriting_style", None)
metadata_json["handwriting_style_disabled"] = True
document.metadata_json = metadata_json
log_processing_event(
session=session,
stage="style",
event="Handwriting style clustering disabled",
level="warning",
document=document,
payload_json={
"enabled": False,
"embed_model": style_settings.get("embed_model"),
},
)
else:
try:
assignment = assign_handwriting_style(
session=session,
document=document,
image_data=data,
)
document.handwriting_style_id = assignment.style_cluster_id
metadata_json = dict(document.metadata_json)
metadata_json["handwriting_style"] = {
"style_cluster_id": assignment.style_cluster_id,
"matched_existing": assignment.matched_existing,
"similarity": assignment.similarity,
"vector_distance": assignment.vector_distance,
"compared_neighbors": assignment.compared_neighbors,
"match_min_similarity": assignment.match_min_similarity,
"bootstrap_match_min_similarity": assignment.bootstrap_match_min_similarity,
}
metadata_json.pop("handwriting_style_disabled", None)
document.metadata_json = metadata_json
log_processing_event(
session=session,
stage="style",
event="Handwriting style assigned",
level="info",
document=document,
payload_json=metadata_json["handwriting_style"],
)
except Exception as style_error:
document.handwriting_style_id = None
metadata_json = dict(document.metadata_json)
metadata_json["handwriting_style_error"] = str(style_error)
metadata_json.pop("handwriting_style", None)
metadata_json.pop("handwriting_style_disabled", None)
document.metadata_json = metadata_json
log_processing_event(
session=session,
stage="style",
event="Handwriting style assignment failed",
level="error",
document=document,
response_text=str(style_error),
)
else:
document.handwriting_style_id = None
metadata_json = dict(document.metadata_json)
metadata_json.pop("handwriting_style", None)
metadata_json.pop("handwriting_style_disabled", None)
document.metadata_json = metadata_json
try:
delete_handwriting_style_document(str(document.id))
except Exception:
pass
if document.extension in IMAGE_EXTENSIONS:
log_processing_event(
session=session,
stage="ocr",
event="OCR response received",
level="info",
document=document,
provider_id=str(
extraction.metadata_json.get(
"transcription_provider",
extraction.metadata_json.get("image_text_type_provider", ""),
)
),
model_name=str(
extraction.metadata_json.get(
"transcription_model",
extraction.metadata_json.get("image_text_type_model", ""),
)
),
response_text=extraction.text,
payload_json={
"image_text_type": document.image_text_type,
"image_text_type_confidence": extraction.metadata_json.get("image_text_type_confidence"),
"transcription_skipped": extraction.metadata_json.get("transcription_skipped"),
"uncertainty_count": len(
extraction.metadata_json.get("transcription_uncertainties", [])
if isinstance(extraction.metadata_json.get("transcription_uncertainties", []), list)
else []
)
},
)
else:
log_processing_event(
session=session,
stage="extraction",
event="Text extraction completed",
level="info",
document=document,
response_text=extraction.text,
payload_json={"text_length": len(extraction.text)},
)
document.extracted_text = extraction.text
try:
summary_text = summarize_document(session=session, document=document)
routing_decision = classify_document_routing(session=session, document=document, summary_text=summary_text)
apply_routing_decision(document=document, decision=routing_decision, session=session)
routing_metadata = document.metadata_json.get("routing", {})
log_processing_event(
session=session,
stage="routing",
event="Routing decision applied",
level="info",
document=document,
payload_json=routing_metadata if isinstance(routing_metadata, dict) else {},
)
log_processing_event(
session=session,
stage="indexing",
event="Typesense upsert started",
level="info",
document=document,
)
upsert_semantic_index(document=document, summary_text=summary_text)
log_processing_event(
session=session,
stage="indexing",
event="Typesense upsert completed",
level="info",
document=document,
)
metadata_json = dict(document.metadata_json)
metadata_json["summary_text"] = summary_text[:20000]
document.metadata_json = metadata_json
except Exception as exc:
document.metadata_json = {
**document.metadata_json,
"routing_error": str(exc),
}
log_processing_event(
session=session,
stage="routing",
event="Routing or indexing failed",
level="error",
document=document,
response_text=str(exc),
)
document.status = DocumentStatus.PROCESSED
document.processed_at = datetime.now(UTC)
log_processing_event(
session=session,
stage="worker",
event="Document processing completed",
level="info",
document=document,
payload_json={"status": document.status.value},
)
cleanup_processing_logs(session=session, keep_document_sessions=2, keep_unbound_entries=80)
session.commit()