ledgerdock/backend/app/services/model_runtime.py

"""Model runtime utilities for provider-bound LLM task execution."""
from dataclasses import dataclass
from typing import Any
from openai import APIConnectionError, APIError, APITimeoutError, OpenAI
from app.core.config import normalize_and_validate_provider_base_url
from app.services.app_settings import read_task_runtime_settings


class ModelTaskError(Exception):
    """Raised when a model task request fails."""


class ModelTaskTimeoutError(ModelTaskError):
    """Raised when a model task request times out."""


class ModelTaskDisabledError(ModelTaskError):
    """Raised when a model task is disabled in settings."""


@dataclass
class ModelTaskRuntime:
    """Resolved runtime configuration for one task and provider."""

    task_name: str
    provider_id: str
    provider_type: str
    base_url: str
    timeout_seconds: int
    api_key: str
    model: str
    prompt: str


def _normalize_base_url(raw_value: str) -> str:
    """Normalizes provider base URL and enforces SSRF protections before outbound calls."""
    return normalize_and_validate_provider_base_url(raw_value, resolve_dns=True)


def _should_fallback_to_chat(error: Exception) -> bool:
    """Determines whether a Responses API failure should fall back to Chat Completions."""
    status_code = getattr(error, "status_code", None)
    if isinstance(status_code, int) and status_code in {400, 404, 405, 415, 422, 501}:
        return True
    message = str(error).lower()
    fallback_markers = (
        "404",
        "not found",
        "unknown endpoint",
        "unsupported",
        "invalid url",
        "responses",
    )
    return any(marker in message for marker in fallback_markers)
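
# Illustrative decisions (hypothetical error messages, not from a specific
# provider): an APIError with status_code 404, or any error whose message
# mentions a marker such as "unknown endpoint: /v1/responses", triggers the
# Chat Completions fallback; an unrelated failure such as
# "rate limit exceeded" does not.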


def _extract_text_from_response(response: Any) -> str:
    """Extracts plain text from Responses API outputs."""
    output_text = getattr(response, "output_text", None)
    if isinstance(output_text, str) and output_text.strip():
        return output_text.strip()
    output_items = getattr(response, "output", None)
    if not isinstance(output_items, list):
        return ""
    chunks: list[str] = []
    for item in output_items:
        item_data = item.model_dump() if hasattr(item, "model_dump") else item
        if not isinstance(item_data, dict):
            continue
        item_type = item_data.get("type")
        if item_type == "output_text":
            text = str(item_data.get("text", "")).strip()
            if text:
                chunks.append(text)
        elif item_type == "message":
            for content in item_data.get("content", []) or []:
                if not isinstance(content, dict):
                    continue
                if content.get("type") in {"output_text", "text"}:
                    text = str(content.get("text", "")).strip()
                    if text:
                        chunks.append(text)
    return "\n".join(chunks).strip()


def _extract_text_from_chat_response(response: Any) -> str:
    """Extracts text from Chat Completions API outputs."""
    choices = getattr(response, "choices", None)
    if not choices:
        return ""
    message_content = choices[0].message.content
    if isinstance(message_content, str):
        return message_content.strip()
    if not isinstance(message_content, list):
        return ""
    chunks: list[str] = []
    for content in message_content:
        if not isinstance(content, dict):
            continue
        text = str(content.get("text", "")).strip()
        if text:
            chunks.append(text)
    return "\n".join(chunks).strip()


def resolve_task_runtime(task_name: str) -> ModelTaskRuntime:
    """Resolves one task runtime including provider endpoint, model, and prompt."""
    runtime_payload = read_task_runtime_settings(task_name)
    task_payload = runtime_payload["task"]
    provider_payload = runtime_payload["provider"]
    if not bool(task_payload.get("enabled", True)):
        raise ModelTaskDisabledError(f"task_disabled:{task_name}")
    provider_type = str(provider_payload.get("provider_type", "openai_compatible")).strip()
    if provider_type != "openai_compatible":
        raise ModelTaskError(f"unsupported_provider_type:{provider_type}")
    try:
        normalized_base_url = _normalize_base_url(
            str(provider_payload.get("base_url", "https://api.openai.com/v1"))
        )
    except ValueError as error:
        raise ModelTaskError(f"invalid_provider_base_url:{error}") from error
    return ModelTaskRuntime(
        task_name=task_name,
        provider_id=str(provider_payload.get("id", "")),
        provider_type=provider_type,
        base_url=normalized_base_url,
        timeout_seconds=int(provider_payload.get("timeout_seconds", 45)),
        api_key=str(provider_payload.get("api_key", "")).strip() or "no-key-required",
        model=str(task_payload.get("model", "")).strip(),
        prompt=str(task_payload.get("prompt", "")).strip(),
    )
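
# The settings payload consumed above has this shape (keys inferred from the
# lookups in resolve_task_runtime; the values shown are illustrative defaults):
#   {
#       "task": {"enabled": True, "model": "...", "prompt": "..."},
#       "provider": {
#           "id": "...",
#           "provider_type": "openai_compatible",
#           "base_url": "https://api.openai.com/v1",
#           "timeout_seconds": 45,
#           "api_key": "...",
#       },
#   }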


def _create_client(runtime: ModelTaskRuntime) -> OpenAI:
    """Builds an OpenAI SDK client for OpenAI-compatible provider endpoints."""
    return OpenAI(
        api_key=runtime.api_key,
        base_url=runtime.base_url,
        timeout=runtime.timeout_seconds,
    )
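
# Note: "no-key-required" is the placeholder substituted in resolve_task_runtime
# for providers without authentication (for example, a locally hosted
# OpenAI-compatible server), since the SDK client still expects an api_key value.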


def complete_text_task(task_name: str, user_text: str, prompt_override: str | None = None) -> str:
    """Runs a text-only task against the configured provider and returns plain output text."""
    runtime = resolve_task_runtime(task_name)
    client = _create_client(runtime)
    prompt = (prompt_override or runtime.prompt).strip() or runtime.prompt
    # Prefer the Responses API; fall back to Chat Completions when the provider
    # does not support it or returns no usable text.
    try:
        response = client.responses.create(
            model=runtime.model,
            input=[
                {
                    "role": "system",
                    "content": [{"type": "input_text", "text": prompt}],
                },
                {
                    "role": "user",
                    "content": [{"type": "input_text", "text": user_text}],
                },
            ],
        )
        text = _extract_text_from_response(response)
        if text:
            return text
    except APITimeoutError as error:
        raise ModelTaskTimeoutError(f"task_timeout:{task_name}") from error
    except APIConnectionError as error:
        raise ModelTaskError(f"task_error:{task_name}:{error}") from error
    except APIError as error:
        if not _should_fallback_to_chat(error):
            raise ModelTaskError(f"task_error:{task_name}:{error}") from error
    except Exception as error:
        if not _should_fallback_to_chat(error):
            raise ModelTaskError(f"task_error:{task_name}:{error}") from error
    try:
        fallback = client.chat.completions.create(
            model=runtime.model,
            messages=[
                {"role": "system", "content": prompt},
                {"role": "user", "content": user_text},
            ],
        )
        return _extract_text_from_chat_response(fallback)
    except APITimeoutError as error:
        raise ModelTaskTimeoutError(f"task_timeout:{task_name}") from error
    except (APIConnectionError, APIError) as error:
        raise ModelTaskError(f"task_error:{task_name}:{error}") from error
    except Exception as error:
        raise ModelTaskError(f"task_error:{task_name}:{error}") from error
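

# A minimal manual smoke test; the task name and input below are hypothetical
# examples, since real task names come from app settings.
if __name__ == "__main__":
    try:
        print(complete_text_task("summarize_statement", "ACME invoice 1042, net 30."))
    except ModelTaskDisabledError:
        print("task is disabled in settings")
    except ModelTaskTimeoutError:
        print("provider timed out")
    except ModelTaskError as error:
        print(f"task failed: {error}")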