"""Model runtime utilities for provider-bound LLM task execution.""" from dataclasses import dataclass from typing import Any from openai import APIConnectionError, APIError, APITimeoutError, OpenAI from app.core.config import normalize_and_validate_provider_base_url from app.services.app_settings import read_task_runtime_settings class ModelTaskError(Exception): """Raised when a model task request fails.""" class ModelTaskTimeoutError(ModelTaskError): """Raised when a model task request times out.""" class ModelTaskDisabledError(ModelTaskError): """Raised when a model task is disabled in settings.""" @dataclass class ModelTaskRuntime: """Resolved runtime configuration for one task and provider.""" task_name: str provider_id: str provider_type: str base_url: str timeout_seconds: int api_key: str model: str prompt: str def _normalize_base_url(raw_value: str) -> str: """Normalizes provider base URL and enforces SSRF protections before outbound calls.""" return normalize_and_validate_provider_base_url(raw_value, resolve_dns=True) def _should_fallback_to_chat(error: Exception) -> bool: """Determines whether a responses API failure should fallback to chat completions.""" status_code = getattr(error, "status_code", None) if isinstance(status_code, int) and status_code in {400, 404, 405, 415, 422, 501}: return True message = str(error).lower() fallback_markers = ( "404", "not found", "unknown endpoint", "unsupported", "invalid url", "responses", ) return any(marker in message for marker in fallback_markers) def _extract_text_from_response(response: Any) -> str: """Extracts plain text from Responses API outputs.""" output_text = getattr(response, "output_text", None) if isinstance(output_text, str) and output_text.strip(): return output_text.strip() output_items = getattr(response, "output", None) if not isinstance(output_items, list): return "" chunks: list[str] = [] for item in output_items: item_data = item.model_dump() if hasattr(item, "model_dump") else item if not isinstance(item_data, dict): continue item_type = item_data.get("type") if item_type == "output_text": text = str(item_data.get("text", "")).strip() if text: chunks.append(text) if item_type == "message": for content in item_data.get("content", []) or []: if not isinstance(content, dict): continue if content.get("type") in {"output_text", "text"}: text = str(content.get("text", "")).strip() if text: chunks.append(text) return "\n".join(chunks).strip() def _extract_text_from_chat_response(response: Any) -> str: """Extracts text from Chat Completions API outputs.""" message_content = response.choices[0].message.content if isinstance(message_content, str): return message_content.strip() if not isinstance(message_content, list): return "" chunks: list[str] = [] for content in message_content: if not isinstance(content, dict): continue text = str(content.get("text", "")).strip() if text: chunks.append(text) return "\n".join(chunks).strip() def resolve_task_runtime(task_name: str) -> ModelTaskRuntime: """Resolves one task runtime including provider endpoint, model, and prompt.""" runtime_payload = read_task_runtime_settings(task_name) task_payload = runtime_payload["task"] provider_payload = runtime_payload["provider"] if not bool(task_payload.get("enabled", True)): raise ModelTaskDisabledError(f"task_disabled:{task_name}") provider_type = str(provider_payload.get("provider_type", "openai_compatible")).strip() if provider_type != "openai_compatible": raise ModelTaskError(f"unsupported_provider_type:{provider_type}") try: 
        normalized_base_url = _normalize_base_url(
            str(provider_payload.get("base_url", "https://api.openai.com/v1"))
        )
    except ValueError as error:
        raise ModelTaskError(f"invalid_provider_base_url:{error}") from error
    return ModelTaskRuntime(
        task_name=task_name,
        provider_id=str(provider_payload.get("id", "")),
        provider_type=provider_type,
        base_url=normalized_base_url,
        timeout_seconds=int(provider_payload.get("timeout_seconds", 45)),
        api_key=str(provider_payload.get("api_key", "")).strip() or "no-key-required",
        model=str(task_payload.get("model", "")).strip(),
        prompt=str(task_payload.get("prompt", "")).strip(),
    )


def _create_client(runtime: ModelTaskRuntime) -> OpenAI:
    """Builds an OpenAI SDK client for OpenAI-compatible provider endpoints."""
    return OpenAI(
        api_key=runtime.api_key,
        base_url=runtime.base_url,
        timeout=runtime.timeout_seconds,
    )


def complete_text_task(task_name: str, user_text: str, prompt_override: str | None = None) -> str:
    """Runs a text-only task against the configured provider and returns plain output text."""
    runtime = resolve_task_runtime(task_name)
    client = _create_client(runtime)
    prompt = (prompt_override or runtime.prompt).strip() or runtime.prompt
    try:
        response = client.responses.create(
            model=runtime.model,
            input=[
                {
                    "role": "system",
                    "content": [
                        {
                            "type": "input_text",
                            "text": prompt,
                        }
                    ],
                },
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "input_text",
                            "text": user_text,
                        }
                    ],
                },
            ],
        )
        text = _extract_text_from_response(response)
        if text:
            return text
    except APITimeoutError as error:
        raise ModelTaskTimeoutError(f"task_timeout:{task_name}") from error
    except APIConnectionError as error:
        raise ModelTaskError(f"task_error:{task_name}:{error}") from error
    except APIError as error:
        if not _should_fallback_to_chat(error):
            raise ModelTaskError(f"task_error:{task_name}:{error}") from error
    except Exception as error:
        if not _should_fallback_to_chat(error):
            raise ModelTaskError(f"task_error:{task_name}:{error}") from error
    try:
        fallback = client.chat.completions.create(
            model=runtime.model,
            messages=[
                {
                    "role": "system",
                    "content": prompt,
                },
                {
                    "role": "user",
                    "content": user_text,
                },
            ],
        )
        return _extract_text_from_chat_response(fallback)
    except APITimeoutError as error:
        raise ModelTaskTimeoutError(f"task_timeout:{task_name}") from error
    except (APIConnectionError, APIError) as error:
        raise ModelTaskError(f"task_error:{task_name}:{error}") from error
    except Exception as error:
        raise ModelTaskError(f"task_error:{task_name}:{error}") from error
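

if __name__ == "__main__":
    # Illustrative smoke test only: "summarize" is a hypothetical task name and
    # assumes a matching task/provider entry exists in app settings. Running this
    # makes a real request against the configured provider endpoint.
    try:
        print(complete_text_task("summarize", "One short paragraph to summarize."))
    except ModelTaskDisabledError:
        print("task disabled in settings")
    except ModelTaskTimeoutError:
        print("task timed out")
    except ModelTaskError as error:
        print(f"task failed: {error}")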