Initial commit

backend/app/services/model_runtime.py (new file, 227 lines)

@@ -0,0 +1,227 @@
"""Model runtime utilities for provider-bound LLM task execution."""

from dataclasses import dataclass
from typing import Any
from urllib.parse import urlparse, urlunparse

from openai import APIConnectionError, APIError, APITimeoutError, OpenAI

from app.services.app_settings import read_task_runtime_settings


class ModelTaskError(Exception):
    """Raised when a model task request fails."""


class ModelTaskTimeoutError(ModelTaskError):
    """Raised when a model task request times out."""


class ModelTaskDisabledError(ModelTaskError):
    """Raised when a model task is disabled in settings."""


@dataclass
class ModelTaskRuntime:
    """Resolved runtime configuration for one task and provider."""

    task_name: str
    provider_id: str
    provider_type: str
    base_url: str
    timeout_seconds: int
    api_key: str
    model: str
    prompt: str


def _normalize_base_url(raw_value: str) -> str:
    """Normalizes provider base URL and appends /v1 for OpenAI-compatible servers."""

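    # Examples of the normalization below (illustrative inputs, not from settings):
    #   ""                        -> "https://api.openai.com/v1"  (default endpoint)
    #   "http://localhost:11434/" -> "http://localhost:11434/v1"
    #   "https://host/api/v1"     -> unchanged, /v1 suffix already present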
trimmed = raw_value.strip().rstrip("/")
|
||||
if not trimmed:
|
||||
return "https://api.openai.com/v1"
|
||||
|
||||
parsed = urlparse(trimmed)
|
||||
path = parsed.path or ""
|
||||
if not path.endswith("/v1"):
|
||||
path = f"{path}/v1" if path else "/v1"
|
||||
|
||||
return urlunparse(parsed._replace(path=path))
|
||||
|
||||
|
||||
def _should_fallback_to_chat(error: Exception) -> bool:
    """Determines whether a Responses API failure should fall back to Chat Completions."""

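    # Status codes that usually mean the provider rejected or does not implement the
    # Responses endpoint, rather than a transient failure worth surfacing as-is.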
status_code = getattr(error, "status_code", None)
|
||||
if isinstance(status_code, int) and status_code in {400, 404, 405, 415, 422, 501}:
|
||||
return True
|
||||
|
||||
message = str(error).lower()
|
||||
fallback_markers = (
|
||||
"404",
|
||||
"not found",
|
||||
"unknown endpoint",
|
||||
"unsupported",
|
||||
"invalid url",
|
||||
"responses",
|
||||
)
|
||||
return any(marker in message for marker in fallback_markers)
|
||||
|
||||
|
||||
def _extract_text_from_response(response: Any) -> str:
    """Extracts plain text from Responses API outputs."""

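    # Prefer the aggregated output_text convenience field; fall back to walking the
    # structured output items when it is absent or empty.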
output_text = getattr(response, "output_text", None)
|
||||
if isinstance(output_text, str) and output_text.strip():
|
||||
return output_text.strip()
|
||||
|
||||
output_items = getattr(response, "output", None)
|
||||
if not isinstance(output_items, list):
|
||||
return ""
|
||||
|
||||
chunks: list[str] = []
|
||||
for item in output_items:
|
||||
item_data = item.model_dump() if hasattr(item, "model_dump") else item
|
||||
if not isinstance(item_data, dict):
|
||||
continue
|
||||
|
||||
item_type = item_data.get("type")
|
||||
if item_type == "output_text":
|
||||
text = str(item_data.get("text", "")).strip()
|
||||
if text:
|
||||
chunks.append(text)
|
||||
|
||||
if item_type == "message":
|
||||
for content in item_data.get("content", []) or []:
|
||||
if not isinstance(content, dict):
|
||||
continue
|
||||
if content.get("type") in {"output_text", "text"}:
|
||||
text = str(content.get("text", "")).strip()
|
||||
if text:
|
||||
chunks.append(text)
|
||||
|
||||
return "\n".join(chunks).strip()
|
||||
|
||||
|
||||
def _extract_text_from_chat_response(response: Any) -> str:
    """Extracts text from Chat Completions API outputs."""

    message_content = response.choices[0].message.content
    if isinstance(message_content, str):
        return message_content.strip()
    if not isinstance(message_content, list):
        return ""

    chunks: list[str] = []
    for content in message_content:
        if not isinstance(content, dict):
            continue
        text = str(content.get("text", "")).strip()
        if text:
            chunks.append(text)
    return "\n".join(chunks).strip()


def resolve_task_runtime(task_name: str) -> ModelTaskRuntime:
    """Resolves one task runtime including provider endpoint, model, and prompt."""

    runtime_payload = read_task_runtime_settings(task_name)
    task_payload = runtime_payload["task"]
    provider_payload = runtime_payload["provider"]

    if not bool(task_payload.get("enabled", True)):
        raise ModelTaskDisabledError(f"task_disabled:{task_name}")

    provider_type = str(provider_payload.get("provider_type", "openai_compatible")).strip()
    if provider_type != "openai_compatible":
        raise ModelTaskError(f"unsupported_provider_type:{provider_type}")

    return ModelTaskRuntime(
        task_name=task_name,
        provider_id=str(provider_payload.get("id", "")),
        provider_type=provider_type,
        base_url=_normalize_base_url(str(provider_payload.get("base_url", "https://api.openai.com/v1"))),
        timeout_seconds=int(provider_payload.get("timeout_seconds", 45)),
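        # Keyless local providers still get a placeholder value here, presumably so the
        # client always has a non-empty credential to send.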
api_key=str(provider_payload.get("api_key", "")).strip() or "no-key-required",
|
||||
model=str(task_payload.get("model", "")).strip(),
|
||||
prompt=str(task_payload.get("prompt", "")).strip(),
|
||||
)
|
||||
|
||||
|
||||
def _create_client(runtime: ModelTaskRuntime) -> OpenAI:
    """Builds an OpenAI SDK client for OpenAI-compatible provider endpoints."""

    return OpenAI(
        api_key=runtime.api_key,
        base_url=runtime.base_url,
        timeout=runtime.timeout_seconds,
    )


def complete_text_task(task_name: str, user_text: str, prompt_override: str | None = None) -> str:
    """Runs a text-only task against the configured provider and returns plain output text."""

    runtime = resolve_task_runtime(task_name)
    client = _create_client(runtime)
    prompt = (prompt_override or runtime.prompt).strip() or runtime.prompt

    try:
        response = client.responses.create(
            model=runtime.model,
            input=[
                {
                    "role": "system",
                    "content": [
                        {
                            "type": "input_text",
                            "text": prompt,
                        }
                    ],
                },
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "input_text",
                            "text": user_text,
                        }
                    ],
                },
            ],
        )
        text = _extract_text_from_response(response)
        if text:
            return text
    except APITimeoutError as error:
        raise ModelTaskTimeoutError(f"task_timeout:{task_name}") from error
    except APIConnectionError as error:
        raise ModelTaskError(f"task_error:{task_name}:{error}") from error
    except APIError as error:
        if not _should_fallback_to_chat(error):
            raise ModelTaskError(f"task_error:{task_name}:{error}") from error
    except Exception as error:
        if not _should_fallback_to_chat(error):
            raise ModelTaskError(f"task_error:{task_name}:{error}") from error

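    # Reaching this point means the Responses call was either unsupported by the
    # provider or returned no usable text, so retry once via Chat Completions.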
    try:
        fallback = client.chat.completions.create(
            model=runtime.model,
            messages=[
                {
                    "role": "system",
                    "content": prompt,
                },
                {
                    "role": "user",
                    "content": user_text,
                },
            ],
        )
        return _extract_text_from_chat_response(fallback)
    except APITimeoutError as error:
        raise ModelTaskTimeoutError(f"task_timeout:{task_name}") from error
    except (APIConnectionError, APIError) as error:
        raise ModelTaskError(f"task_error:{task_name}:{error}") from error
    except Exception as error:
        raise ModelTaskError(f"task_error:{task_name}:{error}") from error
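
A minimal usage sketch for this module. The task name "summarize" and the caller function are illustrative only; they assume a matching task has been configured in app settings and are not part of this commit.

    from app.services.model_runtime import (
        ModelTaskDisabledError,
        ModelTaskError,
        ModelTaskTimeoutError,
        complete_text_task,
    )

    def summarize_note(note_text: str) -> str | None:
        # Hypothetical caller: runs the configured "summarize" task and degrades
        # gracefully when the task is disabled, times out, or otherwise fails.
        try:
            return complete_text_task("summarize", note_text)
        except ModelTaskDisabledError:
            return None
        except ModelTaskTimeoutError:
            return None
        except ModelTaskError:
            return None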