484 lines
18 KiB
Python
484 lines
18 KiB
Python
"""Handwriting transcription service using OpenAI-compatible vision models."""
|
|
|
|
import base64
|
|
import io
|
|
import json
|
|
import re
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
from openai import APIConnectionError, APIError, APITimeoutError, OpenAI
|
|
from PIL import Image, ImageOps
|
|
|
|
from app.core.config import normalize_and_validate_provider_base_url
|
|
from app.services.app_settings import DEFAULT_OCR_PROMPT, read_handwriting_provider_settings
|
|
|
|
# Images whose longest side exceeds this are downscaled before re-encoding
# (see _normalize_image_bytes) to bound upload payload size.
MAX_IMAGE_SIDE = 2000

# Canonical labels produced by _normalize_image_text_type.
IMAGE_TEXT_TYPE_HANDWRITING = "handwriting"
IMAGE_TEXT_TYPE_PRINTED = "printed_text"
IMAGE_TEXT_TYPE_NO_TEXT = "no_text"
# Returned when the model emits a label outside the supported set.
IMAGE_TEXT_TYPE_UNKNOWN = "unknown"

# Prompt instructing the vision model to answer with strict JSON so the
# response can be parsed by _coerce_json_object.
IMAGE_TEXT_CLASSIFICATION_PROMPT = (
    "Classify the text content in this image.\n"
    "Choose exactly one label from: handwriting, printed_text, no_text.\n"
    "Definitions:\n"
    "- handwriting: text exists and most readable text is handwritten.\n"
    "- printed_text: text exists and most readable text is machine printed or typed.\n"
    "- no_text: no readable text is present.\n"
    "Return strict JSON only with shape:\n"
    "{\n"
    ' "label": "handwriting|printed_text|no_text",\n'
    ' "confidence": number\n'
    "}\n"
    "Confidence must be between 0 and 1."
)
|
|
|
|
|
|
class HandwritingTranscriptionError(Exception):
    """Base error for handwriting transcription failures other than timeouts."""
|
|
|
|
|
|
class HandwritingTranscriptionTimeoutError(HandwritingTranscriptionError):
    """Signals that a provider request ran past the configured timeout."""
|
|
|
|
|
|
class HandwritingTranscriptionNotConfiguredError(HandwritingTranscriptionError):
    """Signals that transcription is disabled or lacks usable credentials."""
|
|
|
|
|
|
@dataclass
class HandwritingTranscription:
    """Represents transcription output and uncertainty markers."""

    # Full transcribed text; [[?...?]] uncertainty markers are left in place.
    text: str
    # Non-empty fragments extracted from the [[?...?]] markers in `text`.
    uncertainties: list[str]
    # Provider family that produced the result (currently always "openai").
    provider: str
    # Model identifier used for the request.
    model: str
|
|
|
|
|
|
@dataclass
class ImageTextClassification:
    """Represents model classification of image text modality for one image."""

    # One of the IMAGE_TEXT_TYPE_* canonical labels.
    label: str
    # Model-reported confidence clamped to [0, 1].
    confidence: float
    # Provider family that produced the result (currently always "openai").
    provider: str
    # Model identifier used for the request.
    model: str
|
|
|
|
|
|
def _extract_uncertainties(text: str) -> list[str]:
|
|
"""Extracts uncertainty markers from transcription output."""
|
|
|
|
matches = re.findall(r"\[\[\?(.*?)\?\]\]", text)
|
|
return [match.strip() for match in matches if match.strip()]
|
|
|
|
|
|
def _coerce_json_object(payload: str) -> dict[str, Any]:
|
|
"""Parses and extracts a JSON object from raw model output text."""
|
|
|
|
text = payload.strip()
|
|
if not text:
|
|
return {}
|
|
|
|
try:
|
|
parsed = json.loads(text)
|
|
if isinstance(parsed, dict):
|
|
return parsed
|
|
except json.JSONDecodeError:
|
|
pass
|
|
|
|
fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, flags=re.DOTALL | re.IGNORECASE)
|
|
if fenced:
|
|
try:
|
|
parsed = json.loads(fenced.group(1))
|
|
if isinstance(parsed, dict):
|
|
return parsed
|
|
except json.JSONDecodeError:
|
|
pass
|
|
|
|
first_brace = text.find("{")
|
|
last_brace = text.rfind("}")
|
|
if first_brace >= 0 and last_brace > first_brace:
|
|
candidate = text[first_brace : last_brace + 1]
|
|
try:
|
|
parsed = json.loads(candidate)
|
|
if isinstance(parsed, dict):
|
|
return parsed
|
|
except json.JSONDecodeError:
|
|
return {}
|
|
return {}
|
|
|
|
|
|
def _clamp_probability(value: Any, fallback: float = 0.0) -> float:
|
|
"""Clamps confidence-like values to the inclusive [0, 1] range."""
|
|
|
|
try:
|
|
parsed = float(value)
|
|
except (TypeError, ValueError):
|
|
return fallback
|
|
return max(0.0, min(1.0, parsed))
|
|
|
|
|
|
def _normalize_image_text_type(label: str) -> str:
    """Normalizes classifier labels into one supported canonical image text type.

    The label is lowercased and hyphens/spaces become underscores before alias
    matching, so every alias set must contain only underscore-normalized
    spellings. Unrecognized labels map to IMAGE_TEXT_TYPE_UNKNOWN.
    """

    normalized = label.strip().lower().replace("-", "_").replace(" ", "_")
    if normalized in {IMAGE_TEXT_TYPE_HANDWRITING, "handwritten", "handwritten_text"}:
        return IMAGE_TEXT_TYPE_HANDWRITING
    if normalized in {IMAGE_TEXT_TYPE_PRINTED, "printed", "typed", "machine_text"}:
        return IMAGE_TEXT_TYPE_PRINTED
    # The aliases "no-text" and "no readable text" previously listed here were
    # dead: normalization has already rewritten "-" and " " to "_", so only
    # underscore forms can ever match. "no_readable_text" covers the latter.
    if normalized in {IMAGE_TEXT_TYPE_NO_TEXT, "none", "no_readable_text"}:
        return IMAGE_TEXT_TYPE_NO_TEXT
    return IMAGE_TEXT_TYPE_UNKNOWN
|
|
|
|
|
|
def _normalize_image_bytes(image_data: bytes) -> tuple[bytes, str]:
    """Re-encodes an image as JPEG, honoring EXIF orientation and bounding its longest side."""

    buffer = io.BytesIO()
    with Image.open(io.BytesIO(image_data)) as source:
        oriented = ImageOps.exif_transpose(source)
        rgb = oriented.convert("RGB")
        longest = max(rgb.width, rgb.height)
        if longest > MAX_IMAGE_SIDE:
            # Downscale proportionally; max(1, ...) guards degenerate aspect ratios.
            ratio = MAX_IMAGE_SIDE / longest
            target = (max(1, int(rgb.width * ratio)), max(1, int(rgb.height * ratio)))
            rgb = rgb.resize(target, Image.Resampling.LANCZOS)
        rgb.save(buffer, format="JPEG", quality=90, optimize=True)
    return buffer.getvalue(), "image/jpeg"
|
|
|
|
|
|
def _create_client(provider_settings: dict[str, Any]) -> OpenAI:
    """Builds an OpenAI client against a DNS-revalidated base URL with the configured timeout."""

    configured_key = str(provider_settings.get("openai_api_key", "")).strip()
    base_url_setting = str(provider_settings.get("openai_base_url", "")).strip()
    try:
        base_url = normalize_and_validate_provider_base_url(base_url_setting, resolve_dns=True)
    except ValueError as error:
        raise HandwritingTranscriptionError(f"invalid_provider_base_url:{error}") from error
    timeout_seconds = int(provider_settings["openai_timeout_seconds"])
    return OpenAI(
        api_key=configured_key or "no-key-required",
        base_url=base_url,
        timeout=timeout_seconds,
    )
|
|
|
|
|
|
def _extract_text_from_response(response: Any) -> str:
|
|
"""Extracts plain text from responses API output objects."""
|
|
|
|
output_text = getattr(response, "output_text", None)
|
|
if isinstance(output_text, str) and output_text.strip():
|
|
return output_text.strip()
|
|
|
|
output_items = getattr(response, "output", None)
|
|
if not isinstance(output_items, list):
|
|
return ""
|
|
|
|
texts: list[str] = []
|
|
for item in output_items:
|
|
item_data = item.model_dump() if hasattr(item, "model_dump") else item
|
|
if not isinstance(item_data, dict):
|
|
continue
|
|
item_type = item_data.get("type")
|
|
if item_type == "output_text":
|
|
text = str(item_data.get("text", "")).strip()
|
|
if text:
|
|
texts.append(text)
|
|
if item_type == "message":
|
|
for content in item_data.get("content", []) or []:
|
|
if not isinstance(content, dict):
|
|
continue
|
|
if content.get("type") in {"output_text", "text"}:
|
|
text = str(content.get("text", "")).strip()
|
|
if text:
|
|
texts.append(text)
|
|
|
|
return "\n".join(texts).strip()
|
|
|
|
|
|
def _transcribe_with_responses(client: OpenAI, model: str, prompt: str, image_data_url: str) -> str:
    """Transcribes handwriting via the responses API and returns the extracted plain text."""

    user_message = {
        "role": "user",
        "content": [
            {"type": "input_text", "text": prompt},
            {"type": "input_image", "image_url": image_data_url, "detail": "high"},
        ],
    }
    response = client.responses.create(model=model, input=[user_message])
    return _extract_text_from_response(response)
|
|
|
|
|
|
def _transcribe_with_chat(client: OpenAI, model: str, prompt: str, image_data_url: str) -> str:
    """Transcribes handwriting via chat completions for endpoints lacking the responses API."""

    user_content = [
        {"type": "text", "text": prompt},
        {"type": "image_url", "image_url": {"url": image_data_url, "detail": "high"}},
    ]
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": user_content}],
    )

    content = response.choices[0].message.content
    if isinstance(content, str):
        return content.strip()
    if not isinstance(content, list):
        return ""
    # Some servers return content as a list of typed parts; keep non-empty text.
    fragments: list[str] = []
    for part in content:
        if isinstance(part, dict) and (snippet := str(part.get("text", "")).strip()):
            fragments.append(snippet)
    return "\n".join(fragments).strip()
|
|
|
|
|
|
def _classify_with_responses(client: OpenAI, model: str, prompt: str, image_data_url: str) -> str:
    """Classifies image text modality via the responses API and returns the raw model text."""

    parts = [
        {"type": "input_text", "text": prompt},
        {"type": "input_image", "image_url": image_data_url, "detail": "high"},
    ]
    response = client.responses.create(
        model=model,
        input=[{"role": "user", "content": parts}],
    )
    return _extract_text_from_response(response)
|
|
|
|
|
|
def _classify_with_chat(client: OpenAI, model: str, prompt: str, image_data_url: str) -> str:
    """Classifies image text modality via chat completions for compatibility."""

    response = client.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image_url", "image_url": {"url": image_data_url, "detail": "high"}},
                ],
            }
        ],
    )

    content = response.choices[0].message.content
    if isinstance(content, str):
        return content.strip()
    if isinstance(content, list):
        # Typed-part content: keep only the non-empty text fields.
        pieces: list[str] = []
        for entry in content:
            if not isinstance(entry, dict):
                continue
            snippet = str(entry.get("text", "")).strip()
            if snippet:
                pieces.append(snippet)
        return "\n".join(pieces).strip()
    return ""
|
|
|
|
|
|
def _classify_image_text_data_url(image_data_url: str) -> ImageTextClassification:
    """Classifies an image as handwriting, printed text, or no text.

    Args:
        image_data_url: A data URL (or direct URL) for the image to classify.

    Returns:
        ImageTextClassification with a canonical label and clamped confidence.

    Raises:
        HandwritingTranscriptionNotConfiguredError: If the feature is disabled.
        HandwritingTranscriptionTimeoutError: If a provider request times out.
        HandwritingTranscriptionError: For any other provider or parsing failure.
    """

    provider_settings = read_handwriting_provider_settings()
    provider_type = str(provider_settings.get("provider", "openai_compatible")).strip()
    # Only OpenAI-compatible endpoints are supported by this service.
    if provider_type != "openai_compatible":
        raise HandwritingTranscriptionError(f"unsupported_provider_type:{provider_type}")

    if not bool(provider_settings.get("enabled", True)):
        raise HandwritingTranscriptionNotConfiguredError("handwriting_transcription_disabled")

    model = str(provider_settings.get("openai_model", "gpt-4.1-mini")).strip() or "gpt-4.1-mini"
    client = _create_client(provider_settings)

    try:
        # Prefer the responses API; fall back to chat completions when it
        # returns no text (some compatible servers implement only chat).
        output_text = _classify_with_responses(
            client=client,
            model=model,
            prompt=IMAGE_TEXT_CLASSIFICATION_PROMPT,
            image_data_url=image_data_url,
        )
        if not output_text:
            output_text = _classify_with_chat(
                client=client,
                model=model,
                prompt=IMAGE_TEXT_CLASSIFICATION_PROMPT,
                image_data_url=image_data_url,
            )
    except APITimeoutError as error:
        raise HandwritingTranscriptionTimeoutError("openai_request_timeout") from error
    except (APIConnectionError, APIError):
        # The responses API call itself failed; retry once via chat
        # completions before giving up.
        try:
            output_text = _classify_with_chat(
                client=client,
                model=model,
                prompt=IMAGE_TEXT_CLASSIFICATION_PROMPT,
                image_data_url=image_data_url,
            )
        except APITimeoutError as timeout_error:
            raise HandwritingTranscriptionTimeoutError("openai_request_timeout") from timeout_error
        except Exception as fallback_error:
            raise HandwritingTranscriptionError(str(fallback_error)) from fallback_error
    except Exception as error:
        raise HandwritingTranscriptionError(str(error)) from error

    parsed = _coerce_json_object(output_text)
    if not parsed:
        raise HandwritingTranscriptionError("image_text_classification_parse_failed")

    label = _normalize_image_text_type(str(parsed.get("label", "")))
    confidence = _clamp_probability(parsed.get("confidence", 0.0), fallback=0.0)
    return ImageTextClassification(
        label=label,
        confidence=confidence,
        provider="openai",
        model=model,
    )
|
|
|
|
|
|
def _transcribe_image_data_url(image_data_url: str) -> HandwritingTranscription:
    """Transcribes a handwriting image data URL with configured OpenAI provider settings.

    Args:
        image_data_url: A data URL (or direct URL) for the image to transcribe.

    Returns:
        HandwritingTranscription with the text and extracted uncertainty markers.

    Raises:
        HandwritingTranscriptionNotConfiguredError: If the feature is disabled.
        HandwritingTranscriptionTimeoutError: If a provider request times out.
        HandwritingTranscriptionError: For any other provider failure.
    """

    provider_settings = read_handwriting_provider_settings()
    provider_type = str(provider_settings.get("provider", "openai_compatible")).strip()
    # Only OpenAI-compatible endpoints are supported by this service.
    if provider_type != "openai_compatible":
        raise HandwritingTranscriptionError(f"unsupported_provider_type:{provider_type}")

    if not bool(provider_settings.get("enabled", True)):
        raise HandwritingTranscriptionNotConfiguredError("handwriting_transcription_disabled")

    model = str(provider_settings.get("openai_model", "gpt-4.1-mini")).strip() or "gpt-4.1-mini"
    prompt = str(provider_settings.get("prompt", DEFAULT_OCR_PROMPT)).strip() or DEFAULT_OCR_PROMPT
    client = _create_client(provider_settings)

    try:
        # Prefer the responses API; fall back to chat completions when it
        # returns no text (some compatible servers implement only chat).
        text = _transcribe_with_responses(client=client, model=model, prompt=prompt, image_data_url=image_data_url)
        if not text:
            text = _transcribe_with_chat(client=client, model=model, prompt=prompt, image_data_url=image_data_url)
    except APITimeoutError as error:
        raise HandwritingTranscriptionTimeoutError("openai_request_timeout") from error
    except (APIConnectionError, APIError):
        # The responses API call itself failed; retry once via chat
        # completions before giving up. The caught error is never inspected,
        # so it is deliberately not bound to a name.
        try:
            text = _transcribe_with_chat(client=client, model=model, prompt=prompt, image_data_url=image_data_url)
        except APITimeoutError as timeout_error:
            raise HandwritingTranscriptionTimeoutError("openai_request_timeout") from timeout_error
        except Exception as fallback_error:
            raise HandwritingTranscriptionError(str(fallback_error)) from fallback_error
    except Exception as error:
        raise HandwritingTranscriptionError(str(error)) from error

    final_text = text.strip()
    return HandwritingTranscription(
        text=final_text,
        uncertainties=_extract_uncertainties(final_text),
        provider="openai",
        model=model,
    )
|
|
|
|
|
|
def transcribe_handwriting_base64(image_base64: str, mime_type: str = "image/jpeg") -> HandwritingTranscription:
    """Transcribes handwriting from a base64 payload without data URL prefix."""

    # An empty/blank MIME type falls back to JPEG.
    cleaned_mime = mime_type.strip().lower() or "image/jpeg"
    return _transcribe_image_data_url(f"data:{cleaned_mime};base64,{image_base64}")
|
|
|
|
|
|
def transcribe_handwriting_url(image_url: str) -> HandwritingTranscription:
    """Transcribes handwriting from a direct image URL.

    The URL is forwarded to the provider as-is; no local fetching or image
    normalization is performed.
    """

    return _transcribe_image_data_url(image_url)
|
|
|
|
|
|
def transcribe_handwriting_bytes(image_data: bytes, mime_type: str = "image/jpeg") -> HandwritingTranscription:
    """Transcribes handwriting from raw image bytes after normalization.

    Note: normalization re-encodes to JPEG, so the resulting MIME type comes
    from _normalize_image_bytes rather than the `mime_type` argument.
    """

    prepared_bytes, prepared_mime = _normalize_image_bytes(image_data)
    payload = base64.b64encode(prepared_bytes).decode("ascii")
    return transcribe_handwriting_base64(payload, mime_type=prepared_mime)
|
|
|
|
|
|
def classify_image_text_base64(image_base64: str, mime_type: str = "image/jpeg") -> ImageTextClassification:
    """Classifies image text type from a base64 payload without data URL prefix."""

    # An empty/blank MIME type falls back to JPEG.
    cleaned_mime = mime_type.strip().lower() or "image/jpeg"
    return _classify_image_text_data_url(f"data:{cleaned_mime};base64,{image_base64}")
|
|
|
|
|
|
def classify_image_text_url(image_url: str) -> ImageTextClassification:
    """Classifies image text type from a direct image URL.

    The URL is forwarded to the provider as-is; no local fetching or image
    normalization is performed.
    """

    return _classify_image_text_data_url(image_url)
|
|
|
|
|
|
def classify_image_text_bytes(image_data: bytes, mime_type: str = "image/jpeg") -> ImageTextClassification:
    """Classifies image text type from raw image bytes after normalization.

    Note: normalization re-encodes to JPEG, so the resulting MIME type comes
    from _normalize_image_bytes rather than the `mime_type` argument.
    """

    prepared_bytes, prepared_mime = _normalize_image_bytes(image_data)
    payload = base64.b64encode(prepared_bytes).decode("ascii")
    return classify_image_text_base64(payload, mime_type=prepared_mime)
|
|
|
|
|
|
def transcribe_handwriting(image: bytes | str, mime_type: str = "image/jpeg") -> HandwritingTranscription:
    """Transcribes handwriting from bytes, base64 text, data URL, or http(s) URL input.

    Args:
        image: Raw image bytes, a bare base64 payload, a data: URL, or an
            http(s) URL pointing at the image.
        mime_type: MIME type used only when `image` is bytes or bare base64.

    Returns:
        HandwritingTranscription produced by the configured provider.
    """

    if isinstance(image, bytes):
        return transcribe_handwriting_bytes(image, mime_type=mime_type)

    stripped = image.strip()
    # Pass complete URLs straight through. Previously a "data:" URL fell into
    # the base64 branch and was double-prefixed into an invalid data URL.
    if stripped.startswith(("http://", "https://", "data:")):
        return transcribe_handwriting_url(stripped)
    return transcribe_handwriting_base64(stripped, mime_type=mime_type)
|