# Files
# 2026-05-16 12:05:36 -03:00
#
# 254 lines
# 12 KiB
# Python

from __future__ import annotations
import json
import logging
import os
from pathlib import Path
from typing import Any
from openai import OpenAI
from pydantic import BaseModel, ValidationError
from app.config import Settings
from app.models import Alert
# Module-level logger; used for prompt-file, LLM-call, validation and fallback warnings.
logger = logging.getLogger(__name__)
# Default system message sent with every chat completion. Used when the file at
# settings.llm.system_prompt_path is not available.
SYSTEM_PROMPT = (
    "You are an expert email authentication and DMARC operations analyst. You explain deterministic DMARC telemetry "
    "to a business owner/admin. You must not invent facts. You must distinguish confirmed facts from likely "
    "interpretations. You must never claim an account is compromised solely from DMARC aggregate failures. You must "
    "provide practical next steps. Output only valid JSON matching the requested schema."
)
# Default "instruction" for single-alert explanations (fallback for settings.llm.alert_prompt_path).
# Names the four keys that normalize_alert_explanation expects.
ALERT_PROMPT = (
    "Explain this DMARC alert to a business owner/admin. Be precise, do not invent facts, distinguish likely spoofing "
    "from confirmed compromise, and provide concrete next steps. DMARC aggregate source IPs are observed transmitting "
    "IPs from the reporter's point of view and may be final-hop relays, forwarders, mailing lists, or gateways. If SPF "
    "fails but DKIM aligns and DMARC passes, do not frame the IP as a threat or as something to add to SPF; explain that "
    "forwarding commonly breaks SPF while DKIM can still prove authorization. If a source is not legitimate, say not to "
    "add it to known senders, keep it unauthorized, preserve or tighten DMARC enforcement after legitimate senders are "
    "aligned, and investigate whether any internal system is leaking mail through that source. Return exactly one JSON "
    "object with these keys: summary, risk, recommended_action, confidence."
)
# Default "instruction" for the daily posture digest (fallback for settings.llm.digest_prompt_path).
# Refers to the "required_json_schema" key that the caller places in the same payload.
POSTURE_DIGEST_PROMPT = (
    "Write a current DMARC posture report for the admin using all supplied deterministic telemetry and all open alerts. "
    "Base the report on unresolved/open risk, not only one report day. Mention exact counts/rates, important failing or "
    "unknown sources, relevant reporters, and concrete remediation. DMARC aggregate source IPs are observed transmitting "
    "IPs from the reporter's point of view and may be final-hop relays, forwarders, mailing lists, or gateways. For "
    "SPF-fail, DKIM-pass, DMARC-pass observations, explain that this commonly indicates forwarding or an intermediary "
    "relay and do not recommend adding those observed relay IPs to SPF solely because they appear in aggregate reports. "
    "For unknown failing sources, explain both branches: if legitimate, authorize/fix SPF/DKIM/alignment and classify; "
    "if not legitimate, do not authorize it, leave it unknown, and use DMARC enforcement such as quarantine/reject once "
    "legitimate senders are aligned. Do not claim mailbox compromise from aggregate data alone. Return only JSON "
    "matching required_json_schema."
)
# Default "instruction" for the weekly summary (fallback for settings.llm.weekly_prompt_path).
# It does not itself ask for JSON; that requirement comes from SYSTEM_PROMPT.
WEEKLY_PROMPT = (
    "Include high-level posture, trend changes, new senders, persistent failures, whether DMARC policy posture looks "
    "safe, and recommended operational actions. Only say consider stricter policy if the metrics support it."
)
class AlertExplanation(BaseModel):
    """Structured explanation of a single DMARC alert.

    Built from LLM output by normalize_alert_explanation, or from fixed text by
    fallback_alert_explanation when the LLM path cannot be used.
    """

    # One concise sentence describing what the alert's facts show.
    summary: str
    # Business/operational risk wording.
    risk: str
    # Concrete next step for the admin.
    recommended_action: str
    # Confidence label: the requested schema asks for "low|medium|high";
    # the fallback path stores "fallback".
    confidence: str = "medium"
class SummaryOutput(BaseModel):
    """Structured daily/weekly DMARC posture summary returned by LLMClient."""

    # Short title line for the report.
    headline: str
    # Narrative body of the report.
    summary: str
    # Individual remediation steps; empty list when none are given.
    # NOTE(review): relies on pydantic copying this [] default per instance;
    # Field(default_factory=list) would make that explicit.
    action_items: list[str] = []
    # Business-risk statement; the weekly fallback uses "Unknown".
    business_risk: str
def _stringify_action(value: Any) -> str:
if isinstance(value, list):
return "; ".join(str(item) for item in value if item)
if value is None:
return ""
return str(value)
def normalize_alert_explanation(output: dict[str, Any], alert: Alert | Any) -> AlertExplanation:
    """Coerce a loosely-shaped LLM response into an ``AlertExplanation``.

    The model is asked for ``summary``/``risk``/``recommended_action``/``confidence``.
    It may instead nest them under ``explanation`` or use synonyms such as
    ``headline``, ``business_risk`` or ``next_steps``. Every field is converted to a
    string (lists are joined with "; "). Empty fields are filled from
    ``alert.summary`` or from conservative default text.

    Args:
        output: Parsed JSON from the LLM. A non-dict value is treated as an empty
            response.
        alert: The alert being explained. Only its ``summary`` attribute is read.

    Returns:
        An ``AlertExplanation`` whose fields are all non-empty strings.
    """
    if not isinstance(output, dict):
        # json.loads can legally produce a list/str/number; treat it as empty.
        output = {}
    explanation = output.get("explanation")
    if isinstance(explanation, dict):
        # Nested shape: inner keys first, top-level keys override them.
        source = {**explanation, **{key: value for key, value in output.items() if key != "explanation"}}
    else:
        source = dict(output)
        # A plain-text explanation doubles as the summary when none is given.
        # Dict explanations are never stringified into the summary.
        if explanation and "summary" not in source:
            source["summary"] = str(explanation)
    # Always normalize, even when the canonical keys are present, so list or
    # number values are coerced instead of failing validation.
    summary = _stringify_action(
        source.get("summary") or source.get("headline") or getattr(alert, "summary", "")
    ) or "DMARC Sentinel created a deterministic alert."
    risk = _stringify_action(source.get("risk") or source.get("business_risk") or source.get("impact"))
    action = _stringify_action(
        source.get("recommended_action")
        or source.get("recommendation")
        or source.get("next_step")
        or source.get("next_steps")
        or source.get("action_items")
    )
    # Check after stringifying so values like [""] still get the defaults.
    if not risk:
        risk = "Review the deterministic facts. DMARC aggregate data alone does not prove mailbox compromise."
    if not action:
        action = "Review the deterministic facts before changing DNS or sender classification; do not add relay or forwarding IPs to SPF solely because they appear in aggregate reports."
    return AlertExplanation(
        summary=summary,
        risk=risk,
        recommended_action=action,
        confidence=str(source.get("confidence") or "medium"),
    )
def normalize_summary_output(output: dict[str, Any], payload: dict[str, Any]) -> SummaryOutput:
    """Coerce a loosely-shaped LLM summary response into a ``SummaryOutput``.

    Accepts common synonyms for each field (``title``, ``analysis``, ``risk``,
    ``recommendations``, ...). Any field the model left empty is rebuilt from the
    deterministic telemetry in *payload*, so the report is never blank.

    Args:
        output: Parsed JSON from the LLM. ``{}`` or any non-dict value produces a
            fully deterministic summary.
        payload: The telemetry that was sent to the LLM. Reads ``domain``,
            ``period``, ``metrics`` (``total_messages``, ``dmarc_pass_rate``,
            ``dmarc_failed``, ``unknown_sources``) and
            ``top_sources[0]["source_ip"]``.

    Returns:
        A ``SummaryOutput`` with string fields and a non-empty list of trimmed,
        non-blank action items.
    """
    if not isinstance(output, dict):
        output = {}
    metrics = payload.get("metrics")
    if not isinstance(metrics, dict):
        metrics = {}
    headline = output.get("headline") or output.get("title")
    summary = output.get("summary") or output.get("explanation") or output.get("analysis")
    risk = output.get("business_risk") or output.get("risk") or output.get("impact")
    raw_actions = (
        output.get("action_items")
        or output.get("recommended_actions")
        or output.get("recommendations")
        or output.get("next_steps")
        or []
    )
    if isinstance(raw_actions, str):
        # A single string is treated as a ";"-separated list of steps.
        raw_actions = raw_actions.split(";")
    elif isinstance(raw_actions, tuple):
        raw_actions = list(raw_actions)
    elif not isinstance(raw_actions, list):
        raw_actions = []
    # Clean before the emptiness check so a list of blanks/None still triggers
    # the deterministic fallback instead of yielding an empty list.
    actions: list[str] = []
    for item in raw_actions:
        if item is None:
            continue
        text = str(item).strip()
        if text:
            actions.append(text)
    if not headline:
        headline = f"DMARC posture for {payload.get('domain', 'domain')} on {payload.get('period', 'the selected period')}"
    if not summary:
        total = metrics.get("total_messages", 0)
        pass_rate = metrics.get("dmarc_pass_rate", 0)
        failed = metrics.get("dmarc_failed", 0)
        unknown = metrics.get("unknown_sources", 0)
        summary = (
            f"{payload.get('domain', 'The domain')} processed {total} DMARC-observed messages with a {pass_rate}% "
            f"DMARC pass rate. {failed} messages failed DMARC and {unknown} unknown sources were observed."
        )
    if not risk:
        risk = "Review failures and unknown senders before changing policy. DMARC aggregate data alone does not prove mailbox compromise."
    if not actions:
        top_sources = payload.get("top_sources") or []
        first = top_sources[0] if isinstance(top_sources, (list, tuple)) and top_sources else None
        source = (
            first["source_ip"]
            if isinstance(first, dict) and first.get("source_ip")
            else "the top unknown or failing sources"
        )
        actions = [
            f"Review {source}; if legitimate, fix SPF/DKIM alignment and classify it as approved, and if not legitimate, do not authorize it and rely on DMARC enforcement after legitimate senders are aligned."
        ]
    return SummaryOutput(
        headline=str(headline),
        summary=str(summary),
        action_items=actions,
        business_risk=str(risk),
    )
def fallback_alert_explanation(alert: Alert | Any) -> AlertExplanation:
    """Build a conservative explanation when the LLM path cannot be used.

    This is the last resort of the alert-explanation flow, so it must not raise.
    A missing, ``None`` or empty ``alert.summary`` falls back to generic text
    instead of failing model validation.

    Args:
        alert: The alert being explained. Only its ``summary`` attribute is read.

    Returns:
        An ``AlertExplanation`` with ``confidence="fallback"``.
    """
    alert_summary = getattr(alert, "summary", None)
    return AlertExplanation(
        summary=str(alert_summary) if alert_summary else "DMARC Sentinel created a deterministic alert.",
        risk="Review the deterministic facts. DMARC aggregate data alone does not prove mailbox compromise.",
        recommended_action="Review the deterministic facts before changing DNS or sender classification; do not add relay or forwarding IPs to SPF solely because they appear in aggregate reports.",
        confidence="fallback",
    )
class LLMClient:
    """Generate DMARC alert explanations and posture summaries with an OpenAI model.

    Every public method returns a pydantic model. It falls back to deterministic
    text instead of raising when no client is configured, the API call fails,
    or the response cannot be normalized.
    """

    # Default JSON shape requested for daily/weekly summaries (matches SummaryOutput).
    _SUMMARY_SCHEMA: dict[str, str] = {
        "headline": "string",
        "summary": "string",
        "action_items": "array of strings",
        "business_risk": "string",
    }

    def __init__(self, settings: Settings):
        """Store *settings* and build an OpenAI client if one is configured.

        A client is created only when ``settings.llm.provider`` is ``"openai"``
        and the environment variable named by ``settings.llm.api_key_env`` is
        non-empty. Otherwise ``self.client`` stays ``None`` and every public
        method returns its fallback.
        """
        self.settings = settings
        self.client = None
        if settings.llm.provider == "openai":
            # Read the key once rather than calling os.getenv twice.
            api_key = os.getenv(settings.llm.api_key_env)
            if api_key:
                self.client = OpenAI(api_key=api_key, timeout=settings.llm.timeout_seconds)

    def _prompt(self, path: str | None, fallback: str) -> str:
        """Return the stripped text of the prompt file at *path*, or *fallback*.

        Falls back when *path* is empty/``None``, does not name a regular file,
        cannot be read or decoded as UTF-8, or contains only whitespace.
        """
        if not path:
            # Path(None) raises TypeError, and Path("") resolves to the current directory.
            return fallback
        try:
            prompt_path = Path(path)
            if prompt_path.is_file():
                text = prompt_path.read_text(encoding="utf-8").strip()
                if text:
                    return text
        except (OSError, UnicodeDecodeError) as exc:
            logger.warning("Could not read prompt file %s: %s", path, exc)
        return fallback

    def _summary_payload(self, payload: dict[str, Any], prompt_path: str | None, default_prompt: str) -> dict[str, Any]:
        """Return a copy of *payload* with a JSON schema and an instruction added.

        Keeps any ``required_json_schema`` or ``instruction`` the caller already supplied.
        """
        return {
            **payload,
            "required_json_schema": payload.get("required_json_schema") or dict(self._SUMMARY_SCHEMA),
            "instruction": payload.get("instruction") or self._prompt(prompt_path, default_prompt),
        }

    def _json_call(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Send *payload* as the user message and return the parsed JSON object.

        Makes up to ``settings.llm.max_retries + 1`` attempts. An attempt fails on
        any API error, invalid JSON, or a JSON value that is not an object.

        Raises:
            RuntimeError: if no client is configured, the request cannot be built,
                or every attempt fails. The last error is chained as the cause.
        """
        if self.client is None:
            raise RuntimeError("OpenAI client is not configured")
        try:
            # Built once: neither the prompt file nor the payload changes between attempts.
            messages = [
                {"role": "system", "content": self._prompt(self.settings.llm.system_prompt_path, SYSTEM_PROMPT)},
                {"role": "user", "content": json.dumps(payload, sort_keys=True)},
            ]
        except Exception as exc:
            # Keep the contract that every failure surfaces as RuntimeError.
            raise RuntimeError(f"LLM call failed: {exc}") from exc
        attempts = max(0, self.settings.llm.max_retries) + 1
        last_error: Exception | None = None
        for attempt in range(1, attempts + 1):
            try:
                response = self.client.chat.completions.create(
                    model=self.settings.llm.model,
                    temperature=self.settings.llm.temperature,
                    response_format={"type": "json_object"},
                    messages=messages,
                )
                text = response.choices[0].message.content or "{}"
                parsed = json.loads(text)
                if not isinstance(parsed, dict):
                    # Callers index the result like a dict; reject other JSON values.
                    raise ValueError(f"expected a JSON object, got {type(parsed).__name__}")
                return parsed
            except Exception as exc:
                last_error = exc
                logger.warning("LLM call failed on attempt %s: %s", attempt, exc)
        raise RuntimeError(f"LLM call failed: {last_error}") from last_error

    def explain_alert(self, alert: Alert) -> AlertExplanation:
        """Ask the LLM to explain *alert*, returning a fallback explanation on failure.

        Reads ``domain``, ``severity``, ``type``, ``details_json`` and
        ``fingerprint`` from *alert*. Transport/configuration failures fall back
        immediately, because ``_json_call`` has already retried them. A response
        that fails validation is re-requested once before falling back.
        """
        try:
            facts = json.loads(alert.details_json or "{}")
        except (TypeError, ValueError) as exc:
            # Malformed stored details must not crash the explanation path.
            logger.warning("Could not parse details_json for %s: %s", alert.fingerprint, exc)
            facts = {"unparsed_details": str(alert.details_json)}
        payload = {
            "task": "explain_dmarc_alert",
            "domain": alert.domain,
            "severity": alert.severity,
            "alert_type": alert.type,
            "facts": facts,
            "required_json_schema": {
                "summary": "string, one concise sentence based only on the supplied facts",
                "risk": "string, business/operational risk without claiming compromise from aggregate data alone",
                "recommended_action": "string, concrete next step for the admin",
                "confidence": "low|medium|high",
            },
            "instruction": self._prompt(self.settings.llm.alert_prompt_path, ALERT_PROMPT),
        }
        last_error: Exception | None = None
        for _ in range(2):
            try:
                output = self._json_call(payload)
            except RuntimeError as exc:
                # Another round would only repeat the retries _json_call already made.
                last_error = exc
                break
            try:
                return normalize_alert_explanation(output, alert)
            except (ValidationError, ValueError) as exc:
                last_error = exc
                logger.warning("LLM alert explanation validation failed for %s: %s", alert.fingerprint, exc)
        logger.warning("Using fallback LLM alert explanation for %s: %s", alert.fingerprint, last_error)
        return fallback_alert_explanation(alert)

    def daily_summary(self, payload: dict[str, Any]) -> SummaryOutput:
        """Produce the posture digest for *payload*, falling back to deterministic text.

        The LLM is given the default summary schema and the digest prompt unless
        the caller supplied its own ``required_json_schema``/``instruction``.
        """
        try:
            enriched = self._summary_payload(payload, self.settings.llm.digest_prompt_path, POSTURE_DIGEST_PROMPT)
            output = self._json_call(enriched)
            return normalize_summary_output(output, enriched)
        except Exception as exc:
            # Top-level boundary: a summary must always be produced.
            logger.warning("Using fallback daily summary: %s", exc)
            return normalize_summary_output({}, payload)

    def weekly_summary(self, payload: dict[str, Any]) -> SummaryOutput:
        """Produce the weekly summary for *payload*, falling back to a generic summary.

        Like ``daily_summary``, this requests the summary schema and tolerates
        synonym keys in the response. Previously any key mismatch discarded the
        whole LLM output.
        """
        try:
            enriched = self._summary_payload(payload, self.settings.llm.weekly_prompt_path, WEEKLY_PROMPT)
            output = self._json_call(enriched)
            return normalize_summary_output(output, enriched)
        except Exception as exc:
            # Top-level boundary: a summary must always be produced.
            logger.warning("Using fallback weekly summary: %s", exc)
            return SummaryOutput(
                headline="Weekly DMARC posture summary generated from deterministic telemetry.",
                summary="Review trend changes, new senders, and persistent failures before changing DMARC policy.",
                action_items=["Classify legitimate new senders.", "Investigate persistent failures."],
                business_risk="Unknown",
            )