"""Handwriting transcription service using OpenAI-compatible vision models.""" import base64 import io import json import re from dataclasses import dataclass from typing import Any from openai import APIConnectionError, APIError, APITimeoutError, OpenAI from PIL import Image, ImageOps from app.services.app_settings import DEFAULT_OCR_PROMPT, read_handwriting_provider_settings MAX_IMAGE_SIDE = 2000 IMAGE_TEXT_TYPE_HANDWRITING = "handwriting" IMAGE_TEXT_TYPE_PRINTED = "printed_text" IMAGE_TEXT_TYPE_NO_TEXT = "no_text" IMAGE_TEXT_TYPE_UNKNOWN = "unknown" IMAGE_TEXT_CLASSIFICATION_PROMPT = ( "Classify the text content in this image.\n" "Choose exactly one label from: handwriting, printed_text, no_text.\n" "Definitions:\n" "- handwriting: text exists and most readable text is handwritten.\n" "- printed_text: text exists and most readable text is machine printed or typed.\n" "- no_text: no readable text is present.\n" "Return strict JSON only with shape:\n" "{\n" ' "label": "handwriting|printed_text|no_text",\n' ' "confidence": number\n' "}\n" "Confidence must be between 0 and 1." ) class HandwritingTranscriptionError(Exception): """Raised when handwriting transcription fails for a non-timeout reason.""" class HandwritingTranscriptionTimeoutError(HandwritingTranscriptionError): """Raised when handwriting transcription exceeds the configured timeout.""" class HandwritingTranscriptionNotConfiguredError(HandwritingTranscriptionError): """Raised when handwriting transcription is disabled or missing credentials.""" @dataclass class HandwritingTranscription: """Represents transcription output and uncertainty markers.""" text: str uncertainties: list[str] provider: str model: str @dataclass class ImageTextClassification: """Represents model classification of image text modality for one image.""" label: str confidence: float provider: str model: str def _extract_uncertainties(text: str) -> list[str]: """Extracts uncertainty markers from transcription output.""" matches = re.findall(r"\[\[\?(.*?)\?\]\]", text) return [match.strip() for match in matches if match.strip()] def _coerce_json_object(payload: str) -> dict[str, Any]: """Parses and extracts a JSON object from raw model output text.""" text = payload.strip() if not text: return {} try: parsed = json.loads(text) if isinstance(parsed, dict): return parsed except json.JSONDecodeError: pass fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, flags=re.DOTALL | re.IGNORECASE) if fenced: try: parsed = json.loads(fenced.group(1)) if isinstance(parsed, dict): return parsed except json.JSONDecodeError: pass first_brace = text.find("{") last_brace = text.rfind("}") if first_brace >= 0 and last_brace > first_brace: candidate = text[first_brace : last_brace + 1] try: parsed = json.loads(candidate) if isinstance(parsed, dict): return parsed except json.JSONDecodeError: return {} return {} def _clamp_probability(value: Any, fallback: float = 0.0) -> float: """Clamps confidence-like values to the inclusive [0, 1] range.""" try: parsed = float(value) except (TypeError, ValueError): return fallback return max(0.0, min(1.0, parsed)) def _normalize_image_text_type(label: str) -> str: """Normalizes classifier labels into one supported canonical image text type.""" normalized = label.strip().lower().replace("-", "_").replace(" ", "_") if normalized in {IMAGE_TEXT_TYPE_HANDWRITING, "handwritten", "handwritten_text"}: return IMAGE_TEXT_TYPE_HANDWRITING if normalized in {IMAGE_TEXT_TYPE_PRINTED, "printed", "typed", "machine_text"}: return IMAGE_TEXT_TYPE_PRINTED if normalized in {IMAGE_TEXT_TYPE_NO_TEXT, "no-text", "none", "no readable text"}: return IMAGE_TEXT_TYPE_NO_TEXT return IMAGE_TEXT_TYPE_UNKNOWN def _normalize_image_bytes(image_data: bytes) -> tuple[bytes, str]: """Applies EXIF rotation and scales large images down for efficient transcription.""" with Image.open(io.BytesIO(image_data)) as image: rotated = ImageOps.exif_transpose(image) prepared = rotated.convert("RGB") long_side = max(prepared.width, prepared.height) if long_side > MAX_IMAGE_SIDE: scale = MAX_IMAGE_SIDE / long_side resized_width = max(1, int(prepared.width * scale)) resized_height = max(1, int(prepared.height * scale)) prepared = prepared.resize((resized_width, resized_height), Image.Resampling.LANCZOS) output = io.BytesIO() prepared.save(output, format="JPEG", quality=90, optimize=True) return output.getvalue(), "image/jpeg" def _create_client(provider_settings: dict[str, Any]) -> OpenAI: """Creates an OpenAI client configured for compatible endpoints and timeouts.""" api_key = str(provider_settings.get("openai_api_key", "")).strip() or "no-key-required" return OpenAI( api_key=api_key, base_url=str(provider_settings["openai_base_url"]), timeout=int(provider_settings["openai_timeout_seconds"]), ) def _extract_text_from_response(response: Any) -> str: """Extracts plain text from responses API output objects.""" output_text = getattr(response, "output_text", None) if isinstance(output_text, str) and output_text.strip(): return output_text.strip() output_items = getattr(response, "output", None) if not isinstance(output_items, list): return "" texts: list[str] = [] for item in output_items: item_data = item.model_dump() if hasattr(item, "model_dump") else item if not isinstance(item_data, dict): continue item_type = item_data.get("type") if item_type == "output_text": text = str(item_data.get("text", "")).strip() if text: texts.append(text) if item_type == "message": for content in item_data.get("content", []) or []: if not isinstance(content, dict): continue if content.get("type") in {"output_text", "text"}: text = str(content.get("text", "")).strip() if text: texts.append(text) return "\n".join(texts).strip() def _transcribe_with_responses(client: OpenAI, model: str, prompt: str, image_data_url: str) -> str: """Transcribes handwriting using the responses API.""" response = client.responses.create( model=model, input=[ { "role": "user", "content": [ { "type": "input_text", "text": prompt, }, { "type": "input_image", "image_url": image_data_url, "detail": "high", }, ], } ], ) return _extract_text_from_response(response) def _transcribe_with_chat(client: OpenAI, model: str, prompt: str, image_data_url: str) -> str: """Transcribes handwriting using chat completions for endpoint compatibility.""" response = client.chat.completions.create( model=model, messages=[ { "role": "user", "content": [ { "type": "text", "text": prompt, }, { "type": "image_url", "image_url": { "url": image_data_url, "detail": "high", }, }, ], } ], ) message_content = response.choices[0].message.content if isinstance(message_content, str): return message_content.strip() if isinstance(message_content, list): text_parts: list[str] = [] for part in message_content: if isinstance(part, dict): text = str(part.get("text", "")).strip() if text: text_parts.append(text) return "\n".join(text_parts).strip() return "" def _classify_with_responses(client: OpenAI, model: str, prompt: str, image_data_url: str) -> str: """Classifies image text modality using the responses API.""" response = client.responses.create( model=model, input=[ { "role": "user", "content": [ { "type": "input_text", "text": prompt, }, { "type": "input_image", "image_url": image_data_url, "detail": "high", }, ], } ], ) return _extract_text_from_response(response) def _classify_with_chat(client: OpenAI, model: str, prompt: str, image_data_url: str) -> str: """Classifies image text modality using chat completions for compatibility.""" response = client.chat.completions.create( model=model, messages=[ { "role": "user", "content": [ { "type": "text", "text": prompt, }, { "type": "image_url", "image_url": { "url": image_data_url, "detail": "high", }, }, ], } ], ) message_content = response.choices[0].message.content if isinstance(message_content, str): return message_content.strip() if isinstance(message_content, list): text_parts: list[str] = [] for part in message_content: if isinstance(part, dict): text = str(part.get("text", "")).strip() if text: text_parts.append(text) return "\n".join(text_parts).strip() return "" def _classify_image_text_data_url(image_data_url: str) -> ImageTextClassification: """Classifies an image as handwriting, printed text, or no text.""" provider_settings = read_handwriting_provider_settings() provider_type = str(provider_settings.get("provider", "openai_compatible")).strip() if provider_type != "openai_compatible": raise HandwritingTranscriptionError(f"unsupported_provider_type:{provider_type}") if not bool(provider_settings.get("enabled", True)): raise HandwritingTranscriptionNotConfiguredError("handwriting_transcription_disabled") model = str(provider_settings.get("openai_model", "gpt-4.1-mini")).strip() or "gpt-4.1-mini" client = _create_client(provider_settings) try: output_text = _classify_with_responses( client=client, model=model, prompt=IMAGE_TEXT_CLASSIFICATION_PROMPT, image_data_url=image_data_url, ) if not output_text: output_text = _classify_with_chat( client=client, model=model, prompt=IMAGE_TEXT_CLASSIFICATION_PROMPT, image_data_url=image_data_url, ) except APITimeoutError as error: raise HandwritingTranscriptionTimeoutError("openai_request_timeout") from error except (APIConnectionError, APIError): try: output_text = _classify_with_chat( client=client, model=model, prompt=IMAGE_TEXT_CLASSIFICATION_PROMPT, image_data_url=image_data_url, ) except APITimeoutError as timeout_error: raise HandwritingTranscriptionTimeoutError("openai_request_timeout") from timeout_error except Exception as fallback_error: raise HandwritingTranscriptionError(str(fallback_error)) from fallback_error except Exception as error: raise HandwritingTranscriptionError(str(error)) from error parsed = _coerce_json_object(output_text) if not parsed: raise HandwritingTranscriptionError("image_text_classification_parse_failed") label = _normalize_image_text_type(str(parsed.get("label", ""))) confidence = _clamp_probability(parsed.get("confidence", 0.0), fallback=0.0) return ImageTextClassification( label=label, confidence=confidence, provider="openai", model=model, ) def _transcribe_image_data_url(image_data_url: str) -> HandwritingTranscription: """Transcribes a handwriting image data URL with configured OpenAI provider settings.""" provider_settings = read_handwriting_provider_settings() provider_type = str(provider_settings.get("provider", "openai_compatible")).strip() if provider_type != "openai_compatible": raise HandwritingTranscriptionError(f"unsupported_provider_type:{provider_type}") if not bool(provider_settings.get("enabled", True)): raise HandwritingTranscriptionNotConfiguredError("handwriting_transcription_disabled") model = str(provider_settings.get("openai_model", "gpt-4.1-mini")).strip() or "gpt-4.1-mini" prompt = str(provider_settings.get("prompt", DEFAULT_OCR_PROMPT)).strip() or DEFAULT_OCR_PROMPT client = _create_client(provider_settings) try: text = _transcribe_with_responses(client=client, model=model, prompt=prompt, image_data_url=image_data_url) if not text: text = _transcribe_with_chat(client=client, model=model, prompt=prompt, image_data_url=image_data_url) except APITimeoutError as error: raise HandwritingTranscriptionTimeoutError("openai_request_timeout") from error except (APIConnectionError, APIError) as error: try: text = _transcribe_with_chat(client=client, model=model, prompt=prompt, image_data_url=image_data_url) except APITimeoutError as timeout_error: raise HandwritingTranscriptionTimeoutError("openai_request_timeout") from timeout_error except Exception as fallback_error: raise HandwritingTranscriptionError(str(fallback_error)) from fallback_error except Exception as error: raise HandwritingTranscriptionError(str(error)) from error final_text = text.strip() return HandwritingTranscription( text=final_text, uncertainties=_extract_uncertainties(final_text), provider="openai", model=model, ) def transcribe_handwriting_base64(image_base64: str, mime_type: str = "image/jpeg") -> HandwritingTranscription: """Transcribes handwriting from a base64 payload without data URL prefix.""" normalized_mime = mime_type.strip().lower() if mime_type.strip() else "image/jpeg" image_data_url = f"data:{normalized_mime};base64,{image_base64}" return _transcribe_image_data_url(image_data_url) def transcribe_handwriting_url(image_url: str) -> HandwritingTranscription: """Transcribes handwriting from a direct image URL.""" return _transcribe_image_data_url(image_url) def transcribe_handwriting_bytes(image_data: bytes, mime_type: str = "image/jpeg") -> HandwritingTranscription: """Transcribes handwriting from raw image bytes after normalization.""" normalized_bytes, normalized_mime = _normalize_image_bytes(image_data) encoded = base64.b64encode(normalized_bytes).decode("ascii") return transcribe_handwriting_base64(encoded, mime_type=normalized_mime) def classify_image_text_base64(image_base64: str, mime_type: str = "image/jpeg") -> ImageTextClassification: """Classifies image text type from a base64 payload without data URL prefix.""" normalized_mime = mime_type.strip().lower() if mime_type.strip() else "image/jpeg" image_data_url = f"data:{normalized_mime};base64,{image_base64}" return _classify_image_text_data_url(image_data_url) def classify_image_text_url(image_url: str) -> ImageTextClassification: """Classifies image text type from a direct image URL.""" return _classify_image_text_data_url(image_url) def classify_image_text_bytes(image_data: bytes, mime_type: str = "image/jpeg") -> ImageTextClassification: """Classifies image text type from raw image bytes after normalization.""" normalized_bytes, normalized_mime = _normalize_image_bytes(image_data) encoded = base64.b64encode(normalized_bytes).decode("ascii") return classify_image_text_base64(encoded, mime_type=normalized_mime) def transcribe_handwriting(image: bytes | str, mime_type: str = "image/jpeg") -> HandwritingTranscription: """Transcribes handwriting from bytes, base64 text, or URL input.""" if isinstance(image, bytes): return transcribe_handwriting_bytes(image, mime_type=mime_type) stripped = image.strip() if stripped.startswith("http://") or stripped.startswith("https://"): return transcribe_handwriting_url(stripped) return transcribe_handwriting_base64(stripped, mime_type=mime_type)