436 lines
15 KiB
Python
436 lines
15 KiB
Python
"""Handwriting-style clustering and style-scoped path composition for image documents."""
|
|
|
|
import base64
|
|
import io
|
|
import re
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
from PIL import Image, ImageOps
|
|
from sqlalchemy import func, select
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.core.config import get_settings
|
|
from app.models.document import Document, DocumentStatus
|
|
from app.services.app_settings import (
|
|
DEFAULT_HANDWRITING_STYLE_EMBED_MODEL,
|
|
read_handwriting_style_settings,
|
|
)
|
|
from app.services.typesense_index import get_typesense_client
|
|
|
|
|
|
settings = get_settings()

# Value stored in the Typesense ``image_text_type`` field and the database column for handwriting documents.
IMAGE_TEXT_TYPE_HANDWRITING = "handwriting"

# Appended to the main Typesense collection name to build the style-vector collection name.
HANDWRITING_STYLE_COLLECTION_SUFFIX = "_handwriting_styles"

# Embedding model used when runtime settings do not supply a non-empty ``embed_model``.
HANDWRITING_STYLE_EMBED_MODEL = DEFAULT_HANDWRITING_STYLE_EMBED_MODEL

# Default similarity required to join a cluster that already has at least the bootstrap sample size.
HANDWRITING_STYLE_MATCH_MIN_SIMILARITY = 0.86

# Stricter default similarity required while a cluster is still below the bootstrap sample size.
HANDWRITING_STYLE_BOOTSTRAP_MIN_SIMILARITY = 0.89

# Default number of documents a cluster needs before the regular match threshold applies.
HANDWRITING_STYLE_BOOTSTRAP_SAMPLE_SIZE = 3

# Default number of nearest neighbors requested per style search.
HANDWRITING_STYLE_NEIGHBOR_LIMIT = 8

# Default longest-side pixel bound applied to images before they are sent for embedding.
HANDWRITING_STYLE_IMAGE_MAX_SIDE = 1024

# Cluster identifiers have the form ``<prefix><n>``. The pattern is derived from the prefix
# so the two definitions can never drift apart.
HANDWRITING_STYLE_ID_PREFIX = "hw_style_"
HANDWRITING_STYLE_ID_PATTERN = re.compile(rf"^{re.escape(HANDWRITING_STYLE_ID_PREFIX)}(\d+)$")
|
|
|
|
|
|
@dataclass()
class HandwritingStyleNeighbor:
    """A single nearest-neighbor hit from the handwriting-style vector search."""

    # Typesense document id of the neighbor.
    document_id: str
    # Style cluster the neighbor is currently indexed under.
    style_cluster_id: str
    # Vector distance reported by Typesense (the search falls back to 2.0 when it is unusable).
    vector_distance: float
    # The distance mapped into [0, 1] via _distance_to_similarity.
    similarity: float
|
|
|
|
|
|
@dataclass()
class HandwritingStyleAssignment:
    """Outcome of assigning one document to a handwriting-style cluster."""

    # Cluster id chosen for the document.
    style_cluster_id: str
    # True when the document joined its closest neighbor's cluster instead of keeping/allocating one.
    matched_existing: bool
    # Similarity of the closest neighbor (0.0 when the search returned no neighbors).
    similarity: float
    # Vector distance of the closest neighbor (2.0 when the search returned no neighbors).
    vector_distance: float
    # Number of neighbors the search returned.
    compared_neighbors: int
    # Threshold applied to clusters that have reached the bootstrap sample size.
    match_min_similarity: float
    # Stricter threshold applied to clusters still below the bootstrap sample size.
    bootstrap_match_min_similarity: float
|
|
|
|
|
|
def _style_collection_name() -> str:
    """Return the Typesense collection name used for handwriting-style vectors.

    The name is the configured main collection name with HANDWRITING_STYLE_COLLECTION_SUFFIX appended.
    """
    base_collection_name = settings.typesense_collection_name
    return f"{base_collection_name}{HANDWRITING_STYLE_COLLECTION_SUFFIX}"
|
|
|
|
|
|
def _style_collection() -> Any:
    """Return the Typesense collection handle for the handwriting-style collection.

    A fresh client is obtained from get_typesense_client() on every call.
    """
    return get_typesense_client().collections[_style_collection_name()]
|
|
|
|
|
|
def _distance_to_similarity(vector_distance: float) -> float:
|
|
"""Converts Typesense vector distance into conservative similarity in [0, 1]."""
|
|
|
|
return max(0.0, min(1.0, 1.0 - (vector_distance / 2.0)))
|
|
|
|
|
|
def _flatten_style_image_to_rgb(image: "Image.Image") -> "Image.Image":
    """Return an RGB copy of ``image``, compositing any transparency onto a white background.

    A plain ``convert("RGB")`` discards alpha, so fully transparent pixels keep whatever color
    they store (frequently black). Dark ink on a transparent background would then turn into a
    nearly uniform dark image before embedding.
    """
    has_transparency = "A" in image.getbands() or "transparency" in image.info
    if not has_transparency:
        return image.convert("RGB")
    rgba_image = image.convert("RGBA")
    flattened = Image.new("RGB", rgba_image.size, (255, 255, 255))
    flattened.paste(rgba_image, (0, 0), rgba_image.getchannel("A"))
    return flattened


def _encode_style_image_base64(image_data: bytes, image_max_side: int) -> str:
    """Normalize image bytes for style embedding and return them as a base64 JPEG string.

    Processing steps:
    1. Apply the EXIF orientation.
    2. Flatten to RGB; transparent areas become white.
    3. Downscale with LANCZOS so the longest side is at most ``image_max_side``. Aspect ratio
       is kept and each side stays at least 1 px.
    4. Encode as JPEG at quality 86.

    Args:
        image_data: Raw bytes of any image format Pillow can open.
        image_max_side: Longest-side bound in pixels. A non-positive value disables resizing
            (previously it collapsed the image to 1x1).

    Returns:
        ASCII base64 string of the JPEG bytes.
    """
    with Image.open(io.BytesIO(image_data)) as image:
        # convert()/Image.new() always produce a new image, so ``prepared`` remains usable
        # after the source image is closed by the ``with`` block.
        prepared = _flatten_style_image_to_rgb(ImageOps.exif_transpose(image))

    longest_side = max(prepared.width, prepared.height)
    if 0 < image_max_side < longest_side:
        scale = image_max_side / longest_side
        resized_width = max(1, int(prepared.width * scale))
        resized_height = max(1, int(prepared.height * scale))
        prepared = prepared.resize((resized_width, resized_height), Image.Resampling.LANCZOS)

    output = io.BytesIO()
    prepared.save(output, format="JPEG", quality=86, optimize=True)
    return base64.b64encode(output.getvalue()).decode("ascii")
|
|
|
|
|
|
def _is_typesense_error(error: Exception, *markers: str) -> bool:
    """Return True when the lower-cased message of ``error`` contains any of ``markers``."""
    message = str(error).lower()
    return any(marker in message for marker in markers)


def _existing_embedding_model_name(collection_schema: Any) -> str:
    """Return the model name configured on the ``embedding`` field of a retrieved schema.

    Returns an empty string when the schema, the field, its ``embed`` config, or its
    ``model_config`` is missing or malformed. Callers treat "" as "no model mismatch detected".
    """
    if not isinstance(collection_schema, dict):
        return ""
    fields = collection_schema.get("fields", [])
    if not isinstance(fields, list):
        return ""
    for field in fields:
        if not isinstance(field, dict):
            continue
        if str(field.get("name", "")).strip() != "embedding":
            continue
        embed_config = field.get("embed")
        model_config = embed_config.get("model_config") if isinstance(embed_config, dict) else None
        if not isinstance(model_config, dict):
            return ""
        model_name = model_config.get("model_name")
        # A None/empty model name means "unknown", not a model literally named "None".
        return str(model_name).strip() if model_name else ""
    return ""


def _style_collection_schema(embed_model: str) -> dict[str, Any]:
    """Build the Typesense schema of the handwriting-style collection for ``embed_model``."""
    return {
        "name": _style_collection_name(),
        "fields": [
            {
                "name": "style_cluster_id",
                "type": "string",
                "facet": True,
            },
            {
                "name": "image_text_type",
                "type": "string",
                "facet": True,
            },
            {
                "name": "created_at",
                "type": "int64",
            },
            {
                "name": "image",
                "type": "image",
                "store": False,
            },
            {
                "name": "embedding",
                "type": "float[]",
                "embed": {
                    "from": ["image"],
                    "model_config": {
                        "model_name": embed_model,
                    },
                },
            },
        ],
        "default_sorting_field": "created_at",
    }


def ensure_handwriting_style_collection() -> None:
    """Create the handwriting-style Typesense collection, or recreate it when its model changed.

    The embedding model comes from runtime settings, falling back to
    HANDWRITING_STYLE_EMBED_MODEL when the setting is missing or blank. If the existing
    collection's ``embedding`` field names a different model, the collection is deleted and
    recreated. This drops every previously indexed style vector.

    Races with other workers are tolerated: a not-found error while deleting (the collection was
    already removed) and a 409 / "already exists" error while creating (another worker created it
    first) are ignored.

    Raises:
        Any Typesense client error other than the tolerated not-found / already-exists cases.
    """
    runtime_settings = read_handwriting_style_settings()
    embed_model = (
        str(runtime_settings.get("embed_model", HANDWRITING_STYLE_EMBED_MODEL)).strip()
        or HANDWRITING_STYLE_EMBED_MODEL
    )
    client = get_typesense_client()
    collection_name = _style_collection_name()

    collection_exists = True
    existing_schema: Any = None
    try:
        existing_schema = client.collections[collection_name].retrieve()
    except Exception as error:
        if not _is_typesense_error(error, "404", "not found"):
            raise
        collection_exists = False

    if collection_exists:
        existing_model = _existing_embedding_model_name(existing_schema)
        if not existing_model or existing_model == embed_model:
            return
        try:
            client.collections[collection_name].delete()
        except Exception as error:
            if not _is_typesense_error(error, "404", "not found"):
                raise

    try:
        client.collections.create(_style_collection_schema(embed_model))
    except Exception as error:
        # Another worker may have created the collection between our retrieve and create.
        if not _is_typesense_error(error, "409", "already exists"):
            raise
|
|
|
|
|
|
def _search_style_neighbors(
    image_base64: str,
    limit: int,
    exclude_document_id: str | None = None,
) -> list[HandwritingStyleNeighbor]:
    """Return the nearest handwriting-style neighbors of one encoded image, closest first.

    Args:
        image_base64: Base64-encoded image used as the vector query.
        limit: Maximum number of neighbors to return; values below 1 are treated as 1.
        exclude_document_id: Optional document id filtered out of the results, e.g. the
            document being re-assigned, so it cannot match its own stored vector.

    Returns:
        Up to ``max(1, limit)`` neighbors sorted by ascending vector distance. Hits without an
        id or style cluster are skipped. A missing, unparsable, or non-finite distance is
        replaced by 2.0.

    Raises:
        RuntimeError: when the multi-search response carries an error for this search.
            Multi-search reports per-search failures inside the response body. Without this
            check, a failed search would be indistinguishable from "no neighbors", and every
            document would then be put into a brand-new cluster.
    """
    import math

    ensure_handwriting_style_collection()
    client = get_typesense_client()
    effective_limit = max(1, limit)

    filter_clauses = [f"image_text_type:={IMAGE_TEXT_TYPE_HANDWRITING}"]
    if exclude_document_id:
        filter_clauses.append(f"id:!={exclude_document_id}")

    search_payload = {
        "q": "*",
        "query_by": "embedding",
        "vector_query": f"embedding:([], image:{image_base64}, k:{effective_limit})",
        "exclude_fields": "embedding,image",
        "per_page": effective_limit,
        "filter_by": " && ".join(filter_clauses),
    }
    # Presumably multi_search is used so the large base64 payload travels in the request
    # body rather than the URL — NOTE(review): confirm.
    response = client.multi_search.perform(
        {
            "searches": [
                {
                    "collection": _style_collection_name(),
                    **search_payload,
                }
            ]
        },
        {},
    )

    results = response.get("results", []) if isinstance(response, dict) else []
    first_result = results[0] if isinstance(results, list) and results else {}
    if not isinstance(first_result, dict):
        first_result = {}
    search_error = first_result.get("error")
    if search_error:
        raise RuntimeError(f"Handwriting style search failed: {search_error}")
    hits = first_result.get("hits", [])
    if not isinstance(hits, list):
        hits = []

    neighbors: list[HandwritingStyleNeighbor] = []
    for hit in hits:
        if not isinstance(hit, dict):
            continue
        document = hit.get("document")
        if not isinstance(document, dict):
            continue

        document_id = str(document.get("id", "")).strip()
        style_cluster_id = str(document.get("style_cluster_id", "")).strip()
        if not document_id or not style_cluster_id:
            continue

        try:
            vector_distance = float(hit.get("vector_distance", 2.0))
        except (TypeError, ValueError):
            vector_distance = 2.0
        if not math.isfinite(vector_distance):
            # NaN/inf would otherwise produce a meaningless (possibly maximal) similarity.
            vector_distance = 2.0

        neighbors.append(
            HandwritingStyleNeighbor(
                document_id=document_id,
                style_cluster_id=style_cluster_id,
                vector_distance=vector_distance,
                similarity=_distance_to_similarity(vector_distance),
            )
        )

    # Guarantee closest-first order instead of relying on the server's hit ordering.
    neighbors.sort(key=lambda neighbor: neighbor.vector_distance)
    return neighbors[:effective_limit]
|
|
|
|
|
|
def _next_style_cluster_id(session: Session) -> str:
    """Allocate the next style cluster id: one above the highest numbered id in use.

    Only distinct ``handwriting_style_id`` values starting with HANDWRITING_STYLE_ID_PREFIX are
    loaded. Values that do not fully match HANDWRITING_STYLE_ID_PATTERN (after stripping
    whitespace) are ignored. When no numbered id exists, the suffix starts at 1.

    NOTE(review): this reads then allocates without any visible locking, so concurrent callers
    can presumably receive the same id.
    """
    statement = (
        select(Document.handwriting_style_id)
        .where(Document.handwriting_style_id.is_not(None))
        # autoescape stops the "_" characters in the prefix from acting as LIKE wildcards.
        .where(Document.handwriting_style_id.startswith(HANDWRITING_STYLE_ID_PREFIX, autoescape=True))
        .distinct()
    )
    existing_ids = session.execute(statement).scalars().all()

    matches = (HANDWRITING_STYLE_ID_PATTERN.fullmatch(str(existing_id).strip()) for existing_id in existing_ids)
    highest = max((int(match.group(1)) for match in matches if match), default=0)
    return f"{HANDWRITING_STYLE_ID_PREFIX}{highest + 1}"
|
|
|
|
|
|
def _style_cluster_sample_size(session: Session, style_cluster_id: str) -> int:
    """Count handwriting documents whose ``handwriting_style_id`` equals ``style_cluster_id``.

    The count comes from the database, not from Typesense, and no status filter is applied.
    NOTE(review): this means trashed documents appear to be counted too.
    """
    count_query = (
        select(func.count())
        .select_from(Document)
        .where(
            Document.handwriting_style_id == style_cluster_id,
            Document.image_text_type == IMAGE_TEXT_TYPE_HANDWRITING,
        )
    )
    matching_rows = session.execute(count_query).scalar_one()
    return int(matching_rows)
|
|
|
|
|
|
def assign_handwriting_style(
    session: Session,
    document: Document,
    image_data: bytes,
) -> HandwritingStyleAssignment:
    """Assign ``document`` to an existing handwriting-style cluster or allocate a new one.

    Steps:
    1. Read limits and thresholds from runtime settings, using the module constants as fallbacks.
    2. Encode the image and search for its nearest style neighbors, excluding the document itself.
    3. Join the closest neighbor's cluster when its similarity reaches the required threshold.
       Clusters with fewer than ``bootstrap_sample_size`` documents require the stricter
       ``bootstrap_match_min_similarity``; larger clusters require ``match_min_similarity``.
    4. Otherwise, keep the document's current id if it matches HANDWRITING_STYLE_ID_PATTERN,
       or allocate the next id.
    5. Upsert the document's image into the style collection under the chosen cluster.

    The ``document`` object is only read here. Callers apply the returned ``style_cluster_id``.
    If ``document.created_at`` is None, the current UTC time is used as the indexed timestamp.
    """
    from datetime import datetime, timezone

    runtime_settings = read_handwriting_style_settings()
    image_max_side = int(runtime_settings.get("image_max_side", HANDWRITING_STYLE_IMAGE_MAX_SIDE))
    neighbor_limit = int(runtime_settings.get("neighbor_limit", HANDWRITING_STYLE_NEIGHBOR_LIMIT))
    match_min_similarity = float(
        runtime_settings.get("match_min_similarity", HANDWRITING_STYLE_MATCH_MIN_SIMILARITY)
    )
    bootstrap_match_min_similarity = float(
        runtime_settings.get("bootstrap_match_min_similarity", HANDWRITING_STYLE_BOOTSTRAP_MIN_SIMILARITY)
    )
    bootstrap_sample_size = int(
        runtime_settings.get("bootstrap_sample_size", HANDWRITING_STYLE_BOOTSTRAP_SAMPLE_SIZE)
    )

    image_base64 = _encode_style_image_base64(image_data, image_max_side=image_max_side)
    neighbors = _search_style_neighbors(
        image_base64=image_base64,
        limit=neighbor_limit,
        exclude_document_id=str(document.id),
    )

    # Pick the closest neighbor explicitly instead of trusting the response order.
    best_neighbor = min(neighbors, key=lambda neighbor: neighbor.vector_distance, default=None)
    similarity = best_neighbor.similarity if best_neighbor is not None else 0.0
    vector_distance = best_neighbor.vector_distance if best_neighbor is not None else 2.0

    cluster_sample_size = (
        _style_cluster_sample_size(session=session, style_cluster_id=best_neighbor.style_cluster_id)
        if best_neighbor is not None
        else 0
    )
    required_similarity = (
        bootstrap_match_min_similarity
        if cluster_sample_size < bootstrap_sample_size
        else match_min_similarity
    )

    if best_neighbor is not None and similarity >= required_similarity:
        style_cluster_id = best_neighbor.style_cluster_id
        matched_existing = True
    else:
        existing_style_cluster_id = (document.handwriting_style_id or "").strip()
        if HANDWRITING_STYLE_ID_PATTERN.fullmatch(existing_style_cluster_id):
            style_cluster_id = existing_style_cluster_id
        else:
            style_cluster_id = _next_style_cluster_id(session=session)
        matched_existing = False

    # No second ensure_handwriting_style_collection() call: _search_style_neighbors already
    # made sure the collection exists just before the search.
    created_at = document.created_at or datetime.now(timezone.utc)
    payload = {
        "id": str(document.id),
        "style_cluster_id": style_cluster_id,
        "image_text_type": IMAGE_TEXT_TYPE_HANDWRITING,
        "created_at": int(created_at.timestamp()),
        "image": image_base64,
    }
    _style_collection().documents.upsert(payload)

    return HandwritingStyleAssignment(
        style_cluster_id=style_cluster_id,
        matched_existing=matched_existing,
        similarity=similarity,
        vector_distance=vector_distance,
        compared_neighbors=len(neighbors),
        match_min_similarity=match_min_similarity,
        bootstrap_match_min_similarity=bootstrap_match_min_similarity,
    )
|
|
|
|
|
|
def delete_handwriting_style_document(document_id: str) -> None:
    """Remove ``document_id`` from the handwriting-style collection; absence is not an error.

    Errors whose message contains "404" or "not found" (case-insensitive) are swallowed.
    Every other error propagates to the caller.
    """
    style_collection = _style_collection()
    try:
        style_collection.documents[document_id].delete()
    except Exception as error:
        lowered_message = str(error).lower()
        already_gone = "404" in lowered_message or "not found" in lowered_message
        if not already_gone:
            raise
|
|
|
|
|
|
def delete_many_handwriting_style_documents(document_ids: list[str]) -> None:
    """Remove each id in ``document_ids`` from the handwriting-style collection, one request per id.

    Already-missing ids are skipped by delete_handwriting_style_document. Any other error
    propagates and stops processing of the remaining ids.
    """
    for style_document_id in document_ids:
        delete_handwriting_style_document(style_document_id)
|
|
|
|
|
|
def apply_handwriting_style_path(style_cluster_id: str | None, path_value: str | None) -> str | None:
|
|
"""Composes style-prefixed logical paths while preventing duplicate prefix nesting."""
|
|
|
|
if path_value is None:
|
|
return None
|
|
|
|
normalized_path = path_value.strip().strip("/")
|
|
if not normalized_path:
|
|
return None
|
|
|
|
normalized_style = (style_cluster_id or "").strip().strip("/")
|
|
if not normalized_style:
|
|
return normalized_path
|
|
|
|
segments = [segment for segment in normalized_path.split("/") if segment]
|
|
while segments and HANDWRITING_STYLE_ID_PATTERN.fullmatch(segments[0]):
|
|
segments.pop(0)
|
|
if segments and segments[0].strip().lower() == normalized_style.lower():
|
|
segments.pop(0)
|
|
|
|
if len(segments) == 0:
|
|
return normalized_style
|
|
|
|
sanitized_path = "/".join(segments)
|
|
return f"{normalized_style}/{sanitized_path}"
|
|
|
|
|
|
def resolve_handwriting_style_path_prefix(
    session: Session,
    style_cluster_id: str | None,
    *,
    exclude_document_id: str | None = None,
) -> str | None:
    """Resolve a stable path prefix for one style cluster, preferring known non-style roots.

    Looks at the logical paths of the cluster's non-trashed handwriting documents, optionally
    excluding one document. Roots named "inbox" and roots that are style ids are ignored.

    Args:
        session: Database session used for the lookup.
        style_cluster_id: Cluster whose documents are inspected; blank or None yields None.
        exclude_document_id: Optional document id left out of the vote.

    Returns:
        None for a blank ``style_cluster_id``.
        The stripped cluster id when no eligible root segment exists.
        Otherwise the most frequent root segment (compared case-insensitively; ties go to the
        alphabetically first root). Its displayed spelling is also chosen deterministically:
        the most common spelling, ties broken alphabetically. The result therefore no longer
        depends on the order of rows returned by the database.
    """
    from collections import Counter

    normalized_style = (style_cluster_id or "").strip()
    if not normalized_style:
        return None

    statement = select(Document.logical_path).where(
        Document.handwriting_style_id == normalized_style,
        Document.image_text_type == IMAGE_TEXT_TYPE_HANDWRITING,
        Document.status != DocumentStatus.TRASHED,
    )
    if exclude_document_id:
        statement = statement.where(Document.id != exclude_document_id)
    rows = session.execute(statement).scalars().all()

    root_counts: Counter[str] = Counter()
    spelling_counts: dict[str, Counter[str]] = {}
    for raw_path in rows:
        if not isinstance(raw_path, str):
            continue
        first_segment = next((segment.strip() for segment in raw_path.split("/") if segment.strip()), None)
        if first_segment is None:
            continue
        lowered = first_segment.lower()
        if lowered == "inbox" or HANDWRITING_STYLE_ID_PATTERN.fullmatch(first_segment):
            continue
        root_counts[lowered] += 1
        spelling_counts.setdefault(lowered, Counter())[first_segment] += 1

    if not root_counts:
        return normalized_style

    winner = min(root_counts.items(), key=lambda item: (-item[1], item[0]))[0]
    spellings = spelling_counts.get(winner)
    if not spellings:
        return normalized_style
    return min(spellings.items(), key=lambda item: (-item[1], item[0]))[0]
|