"""Model runtime utilities for provider-bound LLM task execution."""
from dataclasses import dataclass
from typing import Any
from urllib.parse import urlparse, urlunparse

from openai import APIConnectionError, APIError, APITimeoutError, OpenAI

from app.services.app_settings import read_task_runtime_settings
class ModelTaskError(Exception):
    """Base exception for any failure while executing a model task."""
class ModelTaskTimeoutError(ModelTaskError):
    """Signals that a model task request exceeded its configured timeout."""
class ModelTaskDisabledError(ModelTaskError):
    """Signals that the requested task is switched off in app settings."""
@dataclass
class ModelTaskRuntime:
    """Resolved runtime configuration for one task and provider."""

    # Name of the task these settings were resolved for.
    task_name: str
    # Identifier of the provider record in app settings ("" when absent).
    provider_id: str
    # Provider kind; resolve_task_runtime only accepts "openai_compatible".
    provider_type: str
    # Normalized endpoint URL whose path ends in /v1 (see _normalize_base_url).
    base_url: str
    # Request timeout in seconds, passed to the OpenAI SDK client.
    timeout_seconds: int
    # API key; set to the "no-key-required" placeholder when none configured.
    api_key: str
    # Model identifier requested from the provider (stripped; may be "").
    model: str
    # System prompt configured for the task (stripped; may be "").
    prompt: str
def _normalize_base_url(raw_value: str) -> str:
|
|
"""Normalizes provider base URL and appends /v1 for OpenAI-compatible servers."""
|
|
|
|
trimmed = raw_value.strip().rstrip("/")
|
|
if not trimmed:
|
|
return "https://api.openai.com/v1"
|
|
|
|
parsed = urlparse(trimmed)
|
|
path = parsed.path or ""
|
|
if not path.endswith("/v1"):
|
|
path = f"{path}/v1" if path else "/v1"
|
|
|
|
return urlunparse(parsed._replace(path=path))
|
|
|
|
|
|
def _should_fallback_to_chat(error: Exception) -> bool:
|
|
"""Determines whether a responses API failure should fallback to chat completions."""
|
|
|
|
status_code = getattr(error, "status_code", None)
|
|
if isinstance(status_code, int) and status_code in {400, 404, 405, 415, 422, 501}:
|
|
return True
|
|
|
|
message = str(error).lower()
|
|
fallback_markers = (
|
|
"404",
|
|
"not found",
|
|
"unknown endpoint",
|
|
"unsupported",
|
|
"invalid url",
|
|
"responses",
|
|
)
|
|
return any(marker in message for marker in fallback_markers)
|
|
|
|
|
|
def _extract_text_from_response(response: Any) -> str:
|
|
"""Extracts plain text from Responses API outputs."""
|
|
|
|
output_text = getattr(response, "output_text", None)
|
|
if isinstance(output_text, str) and output_text.strip():
|
|
return output_text.strip()
|
|
|
|
output_items = getattr(response, "output", None)
|
|
if not isinstance(output_items, list):
|
|
return ""
|
|
|
|
chunks: list[str] = []
|
|
for item in output_items:
|
|
item_data = item.model_dump() if hasattr(item, "model_dump") else item
|
|
if not isinstance(item_data, dict):
|
|
continue
|
|
|
|
item_type = item_data.get("type")
|
|
if item_type == "output_text":
|
|
text = str(item_data.get("text", "")).strip()
|
|
if text:
|
|
chunks.append(text)
|
|
|
|
if item_type == "message":
|
|
for content in item_data.get("content", []) or []:
|
|
if not isinstance(content, dict):
|
|
continue
|
|
if content.get("type") in {"output_text", "text"}:
|
|
text = str(content.get("text", "")).strip()
|
|
if text:
|
|
chunks.append(text)
|
|
|
|
return "\n".join(chunks).strip()
|
|
|
|
|
|
def _extract_text_from_chat_response(response: Any) -> str:
|
|
"""Extracts text from Chat Completions API outputs."""
|
|
|
|
message_content = response.choices[0].message.content
|
|
if isinstance(message_content, str):
|
|
return message_content.strip()
|
|
if not isinstance(message_content, list):
|
|
return ""
|
|
|
|
chunks: list[str] = []
|
|
for content in message_content:
|
|
if not isinstance(content, dict):
|
|
continue
|
|
text = str(content.get("text", "")).strip()
|
|
if text:
|
|
chunks.append(text)
|
|
return "\n".join(chunks).strip()
|
|
|
|
|
|
def resolve_task_runtime(task_name: str) -> ModelTaskRuntime:
    """Resolve one task's runtime configuration: endpoint, model, and prompt.

    Raises:
        ModelTaskDisabledError: when the task is switched off in settings.
        ModelTaskError: when the provider type is not OpenAI-compatible.
    """

    payload = read_task_runtime_settings(task_name)
    task_config = payload["task"]
    provider_config = payload["provider"]

    if not bool(task_config.get("enabled", True)):
        raise ModelTaskDisabledError(f"task_disabled:{task_name}")

    resolved_type = str(provider_config.get("provider_type", "openai_compatible")).strip()
    if resolved_type != "openai_compatible":
        raise ModelTaskError(f"unsupported_provider_type:{resolved_type}")

    raw_base_url = str(provider_config.get("base_url", "https://api.openai.com/v1"))
    configured_key = str(provider_config.get("api_key", "")).strip()
    return ModelTaskRuntime(
        task_name=task_name,
        provider_id=str(provider_config.get("id", "")),
        provider_type=resolved_type,
        base_url=_normalize_base_url(raw_base_url),
        timeout_seconds=int(provider_config.get("timeout_seconds", 45)),
        # Placeholder key keeps the SDK client happy for keyless local servers.
        api_key=configured_key or "no-key-required",
        model=str(task_config.get("model", "")).strip(),
        prompt=str(task_config.get("prompt", "")).strip(),
    )
def _create_client(runtime: ModelTaskRuntime) -> OpenAI:
    """Build an OpenAI SDK client bound to the runtime's endpoint and timeout."""

    client_settings = {
        "api_key": runtime.api_key,
        "base_url": runtime.base_url,
        "timeout": runtime.timeout_seconds,
    }
    return OpenAI(**client_settings)
def complete_text_task(task_name: str, user_text: str, prompt_override: str | None = None) -> str:
    """Runs a text-only task against the configured provider and returns plain output text.

    Tries the Responses API first, then falls back to Chat Completions when
    the provider appears not to support /responses (see _should_fallback_to_chat).

    Args:
        task_name: Key of the task to resolve from app settings.
        user_text: User content sent to the model.
        prompt_override: Optional replacement system prompt; a blank or
            whitespace-only override falls back to the configured prompt.

    Returns:
        The model's plain-text output ("" if the fallback also yields none).

    Raises:
        ModelTaskTimeoutError: if either API call times out.
        ModelTaskDisabledError: via resolve_task_runtime, if the task is disabled.
        ModelTaskError: for any other request failure.
    """

    runtime = resolve_task_runtime(task_name)
    client = _create_client(runtime)
    # Whitespace-only overrides collapse to "" after strip(), so the configured
    # prompt wins in that case too.
    prompt = (prompt_override or runtime.prompt).strip() or runtime.prompt

    # First attempt: the Responses API.
    try:
        response = client.responses.create(
            model=runtime.model,
            input=[
                {
                    "role": "system",
                    "content": [
                        {
                            "type": "input_text",
                            "text": prompt,
                        }
                    ],
                },
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "input_text",
                            "text": user_text,
                        }
                    ],
                },
            ],
        )
        text = _extract_text_from_response(response)
        if text:
            return text
        # Empty text (no exception) deliberately falls through to the
        # chat-completions fallback below.
    except APITimeoutError as error:
        # Listed first so a timeout is not swallowed by the broader
        # connection/API handlers below.
        raise ModelTaskTimeoutError(f"task_timeout:{task_name}") from error
    except APIConnectionError as error:
        raise ModelTaskError(f"task_error:{task_name}:{error}") from error
    except APIError as error:
        # Only fall back when the failure looks like a missing/unsupported
        # /responses endpoint; real API errors are surfaced immediately.
        if not _should_fallback_to_chat(error):
            raise ModelTaskError(f"task_error:{task_name}:{error}") from error
    except Exception as error:
        # Non-SDK failures get the same fallback screening as API errors.
        if not _should_fallback_to_chat(error):
            raise ModelTaskError(f"task_error:{task_name}:{error}") from error

    # Second attempt: Chat Completions, the lowest-common-denominator API
    # for OpenAI-compatible servers.
    try:
        fallback = client.chat.completions.create(
            model=runtime.model,
            messages=[
                {
                    "role": "system",
                    "content": prompt,
                },
                {
                    "role": "user",
                    "content": user_text,
                },
            ],
        )
        return _extract_text_from_chat_response(fallback)
    except APITimeoutError as error:
        raise ModelTaskTimeoutError(f"task_timeout:{task_name}") from error
    except (APIConnectionError, APIError) as error:
        raise ModelTaskError(f"task_error:{task_name}:{error}") from error
    except Exception as error:
        # No further fallback exists; any remaining failure is a task error.
        raise ModelTaskError(f"task_error:{task_name}:{error}") from error