Initial commit
This commit is contained in:
435
backend/app/services/handwriting_style.py
Normal file
435
backend/app/services/handwriting_style.py
Normal file
@@ -0,0 +1,435 @@
|
||||
"""Handwriting-style clustering and style-scoped path composition for image documents."""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from PIL import Image, ImageOps
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.models.document import Document, DocumentStatus
|
||||
from app.services.app_settings import (
|
||||
DEFAULT_HANDWRITING_STYLE_EMBED_MODEL,
|
||||
read_handwriting_style_settings,
|
||||
)
|
||||
from app.services.typesense_index import get_typesense_client
|
||||
|
||||
|
||||
# Application settings are resolved once at import time and reused by the helpers below.
settings = get_settings()

# Value written to (and filtered on) the ``image_text_type`` field for handwriting documents.
IMAGE_TEXT_TYPE_HANDWRITING = "handwriting"

# Appended to the configured Typesense collection name to form the style collection name.
HANDWRITING_STYLE_COLLECTION_SUFFIX = "_handwriting_styles"

# Fallback values used when the runtime handwriting-style settings omit the matching key.
HANDWRITING_STYLE_EMBED_MODEL = DEFAULT_HANDWRITING_STYLE_EMBED_MODEL
HANDWRITING_STYLE_MATCH_MIN_SIMILARITY = 0.86
HANDWRITING_STYLE_BOOTSTRAP_MIN_SIMILARITY = 0.89
HANDWRITING_STYLE_BOOTSTRAP_SAMPLE_SIZE = 3
HANDWRITING_STYLE_NEIGHBOR_LIMIT = 8
HANDWRITING_STYLE_IMAGE_MAX_SIDE = 1024

# Cluster identifiers have the form ``hw_style_<number>``; the pattern captures the number.
HANDWRITING_STYLE_ID_PREFIX = "hw_style_"
HANDWRITING_STYLE_ID_PATTERN = re.compile(r"^hw_style_(\d+)$")
|
||||
|
||||
|
||||
@dataclass
class HandwritingStyleNeighbor:
    """One nearest-neighbor hit read back from the handwriting-style Typesense collection."""

    # Typesense document id of the neighboring document.
    document_id: str
    # Style cluster identifier stored on the neighboring document.
    style_cluster_id: str
    # Raw ``vector_distance`` value reported for the hit.
    vector_distance: float
    # Similarity derived from ``vector_distance`` by ``_distance_to_similarity``.
    similarity: float
|
||||
|
||||
|
||||
@dataclass
class HandwritingStyleAssignment:
    """Outcome of assigning one document to a handwriting-style cluster."""

    # Cluster identifier chosen for the document.
    style_cluster_id: str
    # True when the document joined the cluster of its best neighbor.
    matched_existing: bool
    # Similarity of the best neighbor (0.0 when no neighbor was found).
    similarity: float
    # Vector distance of the best neighbor (2.0 when no neighbor was found).
    vector_distance: float
    # Number of neighbors returned by the similarity search.
    compared_neighbors: int
    # Similarity threshold in effect for clusters past the bootstrap sample size.
    match_min_similarity: float
    # Similarity threshold in effect for clusters below the bootstrap sample size.
    bootstrap_match_min_similarity: float
|
||||
|
||||
|
||||
def _style_collection_name() -> str:
    """Return the name of the Typesense collection that holds handwriting-style vectors."""

    base_name = settings.typesense_collection_name
    return f"{base_name}{HANDWRITING_STYLE_COLLECTION_SUFFIX}"
|
||||
|
||||
|
||||
def _style_collection() -> Any:
    """Return the Typesense handle for the handwriting-style collection."""

    collections = get_typesense_client().collections
    return collections[_style_collection_name()]
|
||||
|
||||
|
||||
def _distance_to_similarity(vector_distance: float) -> float:
|
||||
"""Converts Typesense vector distance into conservative similarity in [0, 1]."""
|
||||
|
||||
return max(0.0, min(1.0, 1.0 - (vector_distance / 2.0)))
|
||||
|
||||
|
||||
def _encode_style_image_base64(image_data: bytes, image_max_side: int) -> str:
    """Re-encode image bytes as a size-limited RGB JPEG and return it as base64 text.

    The image is rotated according to its EXIF orientation, converted to RGB, and scaled
    down (never up) so that its longest side does not exceed ``image_max_side``.
    """

    with Image.open(io.BytesIO(image_data)) as source_image:
        canvas = ImageOps.exif_transpose(source_image).convert("RGB")
        width, height = canvas.size
        largest_side = max(width, height)
        if largest_side > image_max_side:
            ratio = image_max_side / largest_side
            # Each side is kept at one pixel or more so the resize target is never empty.
            target_size = (max(1, int(width * ratio)), max(1, int(height * ratio)))
            canvas = canvas.resize(target_size, Image.Resampling.LANCZOS)

        buffer = io.BytesIO()
        canvas.save(buffer, format="JPEG", quality=86, optimize=True)
        jpeg_bytes = buffer.getvalue()

    return base64.b64encode(jpeg_bytes).decode("ascii")
|
||||
|
||||
|
||||
def ensure_handwriting_style_collection() -> None:
    """Make sure the handwriting-style collection exists with the configured embed model.

    Nothing is changed when the collection can be retrieved and its ``embedding`` field does
    not name a different model. When the stored model name differs from the configured one
    the collection is deleted and created again. A retrieval error whose message contains
    "404" or "not found" leads to creation; any other retrieval error propagates.
    """

    configured_model = read_handwriting_style_settings().get("embed_model", HANDWRITING_STYLE_EMBED_MODEL)
    embed_model = str(configured_model).strip() or HANDWRITING_STYLE_EMBED_MODEL
    collection = _style_collection()
    model_changed = False
    try:
        schema_snapshot = collection.retrieve()
        fields = schema_snapshot.get("fields", []) if isinstance(schema_snapshot, dict) else []
        if not isinstance(fields, list):
            fields = []
        for entry in fields:
            if not isinstance(entry, dict) or str(entry.get("name", "")).strip() != "embedding":
                continue
            embed_section = entry.get("embed", {})
            model_section = embed_section.get("model_config", {}) if isinstance(embed_section, dict) else {}
            stored_model = str(model_section.get("model_name", "")).strip()
            if stored_model and stored_model != embed_model:
                model_changed = True
                break
        if not model_changed:
            # The collection exists and no conflicting model was found: nothing to do.
            return
    except Exception as error:
        error_text = str(error).lower()
        if not ("404" in error_text or "not found" in error_text):
            raise

    client = get_typesense_client()
    collection_name = _style_collection_name()
    if model_changed:
        client.collections[collection_name].delete()

    embedding_field = {
        "name": "embedding",
        "type": "float[]",
        "embed": {
            "from": ["image"],
            "model_config": {
                "model_name": embed_model,
            },
        },
    }
    client.collections.create(
        {
            "name": collection_name,
            "fields": [
                {"name": "style_cluster_id", "type": "string", "facet": True},
                {"name": "image_text_type", "type": "string", "facet": True},
                {"name": "created_at", "type": "int64"},
                {"name": "image", "type": "image", "store": False},
                embedding_field,
            ],
            "default_sorting_field": "created_at",
        }
    )
|
||||
|
||||
|
||||
def _search_style_neighbors(
    image_base64: str,
    limit: int,
    exclude_document_id: str | None = None,
) -> list[HandwritingStyleNeighbor]:
    """Return the nearest handwriting-style neighbors for one encoded image payload.

    Args:
        image_base64: Base64 image payload embedded into the vector query.
        limit: Maximum number of neighbors wanted; values below 1 are treated as 1.
        exclude_document_id: Optional document id filtered out of the results.

    Returns:
        Neighbors in the order Typesense returned them. The list is empty when there are no
        usable hits or when the search itself reported an error (which is logged).
    """

    import logging  # function-scope import: this module has no file-level logging import

    ensure_handwriting_style_collection()
    client = get_typesense_client()

    # One clamped value drives ``k``, ``per_page`` and the loop bound so they cannot disagree.
    effective_limit = max(1, limit)

    filter_clauses = [f"image_text_type:={IMAGE_TEXT_TYPE_HANDWRITING}"]
    if exclude_document_id:
        filter_clauses.append(f"id:!={exclude_document_id}")

    search_payload = {
        "q": "*",
        "query_by": "embedding",
        "vector_query": f"embedding:([], image:{image_base64}, k:{effective_limit})",
        "exclude_fields": "embedding,image",
        "per_page": effective_limit,
        "filter_by": " && ".join(filter_clauses),
    }
    response = client.multi_search.perform(
        {
            "searches": [
                {
                    "collection": _style_collection_name(),
                    **search_payload,
                }
            ]
        },
        {},
    )

    results = response.get("results", []) if isinstance(response, dict) else []
    first_result = results[0] if isinstance(results, list) and len(results) > 0 else {}
    if isinstance(first_result, dict) and "error" in first_result:
        # A failed search can come back as an error object inside ``results`` rather than as
        # an exception. Without this it looks like "no neighbors" and goes unnoticed.
        logging.getLogger(__name__).warning(
            "Handwriting-style neighbor search failed (code=%s): %s",
            first_result.get("code"),
            first_result.get("error"),
        )
        return []
    hits = first_result.get("hits", []) if isinstance(first_result, dict) else []

    neighbors: list[HandwritingStyleNeighbor] = []
    for hit in hits:
        if not isinstance(hit, dict):
            continue
        document = hit.get("document")
        if not isinstance(document, dict):
            continue

        document_id = str(document.get("id", "")).strip()
        style_cluster_id = str(document.get("style_cluster_id", "")).strip()
        if not document_id or not style_cluster_id:
            continue

        try:
            vector_distance = float(hit.get("vector_distance", 2.0))
        except (TypeError, ValueError):
            # An unreadable distance falls back to 2.0, which maps to a similarity of 0.0.
            vector_distance = 2.0

        neighbors.append(
            HandwritingStyleNeighbor(
                document_id=document_id,
                style_cluster_id=style_cluster_id,
                vector_distance=vector_distance,
                similarity=_distance_to_similarity(vector_distance),
            )
        )
        if len(neighbors) >= effective_limit:
            break

    return neighbors
|
||||
|
||||
|
||||
def _next_style_cluster_id(session: Session) -> str:
    """Allocate the next ``hw_style_<n>`` identifier.

    The result is one greater than the largest numeric suffix among the stored
    ``handwriting_style_id`` values that match ``HANDWRITING_STYLE_ID_PATTERN``; identifiers
    in any other form are ignored. With no matching identifier the result ends in ``1``.

    Args:
        session: Database session used to read the existing identifiers.

    Returns:
        The new identifier, e.g. ``hw_style_4``.
    """

    # DISTINCT keeps the result at one row per cluster instead of one row per document.
    known_ids = session.execute(
        select(Document.handwriting_style_id)
        .where(Document.handwriting_style_id.is_not(None))
        .distinct()
    ).scalars()

    matches = (HANDWRITING_STYLE_ID_PATTERN.fullmatch(str(known_id).strip()) for known_id in known_ids)
    highest = max((int(match.group(1)) for match in matches if match), default=0)

    # NOTE(review): nothing here serializes allocation, so two concurrent callers could be
    # handed the same identifier. Confirm whether callers guard against this.
    return f"{HANDWRITING_STYLE_ID_PREFIX}{highest + 1}"
|
||||
|
||||
|
||||
def _style_cluster_sample_size(session: Session, style_cluster_id: str) -> int:
    """Count the handwriting documents currently assigned to the given style cluster."""

    count_query = (
        select(func.count())
        .select_from(Document)
        .where(
            Document.handwriting_style_id == style_cluster_id,
            Document.image_text_type == IMAGE_TEXT_TYPE_HANDWRITING,
        )
    )
    return int(session.execute(count_query).scalar_one())
|
||||
|
||||
|
||||
def assign_handwriting_style(
    session: Session,
    document: Document,
    image_data: bytes,
) -> HandwritingStyleAssignment:
    """Place a document in a handwriting-style cluster and index its image.

    The document joins the cluster of its nearest neighbor when the similarity reaches the
    applicable threshold: the bootstrap threshold while that cluster has fewer members than
    the bootstrap sample size, the regular threshold otherwise. Without a match the document
    keeps its own well-formed style id if it has one, else a new id is allocated. The image
    is then upserted into the style collection under the chosen cluster.
    """

    overrides = read_handwriting_style_settings()
    image_max_side = int(overrides.get("image_max_side", HANDWRITING_STYLE_IMAGE_MAX_SIDE))
    neighbor_limit = int(overrides.get("neighbor_limit", HANDWRITING_STYLE_NEIGHBOR_LIMIT))
    match_min_similarity = float(overrides.get("match_min_similarity", HANDWRITING_STYLE_MATCH_MIN_SIMILARITY))
    bootstrap_match_min_similarity = float(
        overrides.get("bootstrap_match_min_similarity", HANDWRITING_STYLE_BOOTSTRAP_MIN_SIMILARITY)
    )
    bootstrap_sample_size = int(overrides.get("bootstrap_sample_size", HANDWRITING_STYLE_BOOTSTRAP_SAMPLE_SIZE))

    image_base64 = _encode_style_image_base64(image_data, image_max_side=image_max_side)
    neighbors = _search_style_neighbors(
        image_base64=image_base64,
        limit=neighbor_limit,
        exclude_document_id=str(document.id),
    )

    closest = neighbors[0] if neighbors else None
    if closest is None:
        # No neighbor: report the worst distance and similarity and an empty cluster.
        similarity, vector_distance, member_count = 0.0, 2.0, 0
    else:
        similarity = closest.similarity
        vector_distance = closest.vector_distance
        member_count = _style_cluster_sample_size(
            session=session,
            style_cluster_id=closest.style_cluster_id,
        )

    if member_count >= bootstrap_sample_size:
        threshold = match_min_similarity
    else:
        threshold = bootstrap_match_min_similarity

    matched_existing = closest is not None and similarity >= threshold
    if matched_existing:
        style_cluster_id = closest.style_cluster_id
    else:
        current_style_id = (document.handwriting_style_id or "").strip()
        if HANDWRITING_STYLE_ID_PATTERN.fullmatch(current_style_id):
            style_cluster_id = current_style_id
        else:
            style_cluster_id = _next_style_cluster_id(session=session)

    ensure_handwriting_style_collection()
    _style_collection().documents.upsert(
        {
            "id": str(document.id),
            "style_cluster_id": style_cluster_id,
            "image_text_type": IMAGE_TEXT_TYPE_HANDWRITING,
            "created_at": int(document.created_at.timestamp()),
            "image": image_base64,
        }
    )

    return HandwritingStyleAssignment(
        style_cluster_id=style_cluster_id,
        matched_existing=matched_existing,
        similarity=similarity,
        vector_distance=vector_distance,
        compared_neighbors=len(neighbors),
        match_min_similarity=match_min_similarity,
        bootstrap_match_min_similarity=bootstrap_match_min_similarity,
    )
|
||||
|
||||
|
||||
def delete_handwriting_style_document(document_id: str) -> None:
    """Remove one document from the handwriting-style collection.

    An error whose message contains "404" or "not found" is treated as "already gone" and
    ignored; every other error propagates.
    """

    style_collection = _style_collection()
    try:
        style_collection.documents[document_id].delete()
    except Exception as error:
        error_text = str(error).lower()
        if "404" not in error_text and "not found" not in error_text:
            raise
|
||||
|
||||
|
||||
def delete_many_handwriting_style_documents(document_ids: list[str]) -> None:
    """Remove each listed document from the handwriting-style collection, one at a time."""

    for stale_id in document_ids:
        delete_handwriting_style_document(stale_id)
|
||||
|
||||
|
||||
def apply_handwriting_style_path(style_cluster_id: str | None, path_value: str | None) -> str | None:
    """Prefix a logical path with a style segment without nesting style prefixes.

    Returns ``None`` for a missing or empty path and the trimmed path unchanged when no
    style is given. Otherwise any leading ``hw_style_<n>`` segments, and then one leading
    segment equal to the style (ignoring case), are dropped before the style is prepended.
    """

    trimmed_path = path_value.strip().strip("/") if path_value is not None else ""
    if not trimmed_path:
        return None

    prefix = (style_cluster_id or "").strip().strip("/")
    if not prefix:
        return trimmed_path

    parts = [part for part in trimmed_path.split("/") if part]

    # Advance past stale style-id segments, then past one copy of the requested prefix.
    start = 0
    while start < len(parts) and HANDWRITING_STYLE_ID_PATTERN.fullmatch(parts[start]):
        start += 1
    if start < len(parts) and parts[start].strip().lower() == prefix.lower():
        start += 1

    return "/".join([prefix, *parts[start:]])
|
||||
|
||||
|
||||
def resolve_handwriting_style_path_prefix(
    session: Session,
    style_cluster_id: str | None,
    *,
    exclude_document_id: str | None = None,
) -> str | None:
    """Pick the path prefix to use for one style cluster.

    Looks at the logical paths of the cluster's non-trashed handwriting documents and
    returns the most common first path segment, compared case-insensitively and ignoring
    ``inbox`` and ``hw_style_<n>`` segments. Ties go to the alphabetically first lowered
    segment, and the first-seen spelling is returned. When no such segment exists the style
    id itself is returned; an empty style id yields ``None``.
    """

    cluster = (style_cluster_id or "").strip()
    if not cluster:
        return None

    conditions = [
        Document.handwriting_style_id == cluster,
        Document.image_text_type == IMAGE_TEXT_TYPE_HANDWRITING,
        Document.status != DocumentStatus.TRASHED,
    ]
    if exclude_document_id:
        conditions.append(Document.id != exclude_document_id)
    logical_paths = session.execute(select(Document.logical_path).where(*conditions)).scalars().all()

    tally: dict[str, int] = {}
    display_names: dict[str, str] = {}
    for logical_path in logical_paths:
        if not isinstance(logical_path, str):
            continue
        # First non-blank segment of the path, or None when the path has none.
        root = next((piece.strip() for piece in logical_path.split("/") if piece.strip()), None)
        if root is None:
            continue
        key = root.lower()
        if key == "inbox" or HANDWRITING_STYLE_ID_PATTERN.fullmatch(root):
            continue
        tally[key] = tally.get(key, 0) + 1
        display_names.setdefault(key, root)

    if not tally:
        return cluster

    top_key, _ = min(tally.items(), key=lambda pair: (-pair[1], pair[0]))
    return display_names.get(top_key, cluster)
|
||||
Reference in New Issue
Block a user