254 lines
12 KiB
Python
254 lines
12 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
from openai import OpenAI
|
|
from pydantic import BaseModel, ValidationError
|
|
|
|
from app.config import Settings
|
|
from app.models import Alert
|
|
|
|
# Module-level logger, named after this module so log records can be filtered per module.
logger = logging.getLogger(__name__)
|
|
|
|
# Built-in system message, used when no system prompt file can be read.
# Written one sentence per fragment; adjacent literals concatenate into a single string.
SYSTEM_PROMPT = (
    "You are an expert email authentication and DMARC operations analyst. "
    "You explain deterministic DMARC telemetry to a business owner/admin. "
    "You must not invent facts. "
    "You must distinguish confirmed facts from likely interpretations. "
    "You must never claim an account is compromised solely from DMARC aggregate failures. "
    "You must provide practical next steps. "
    "Output only valid JSON matching the requested schema."
)
|
|
|
|
# Built-in instruction for explaining a single alert, used when no alert prompt file can be read.
# Fragments break at sentence/clause boundaries; the concatenated text is one paragraph.
ALERT_PROMPT = (
    "Explain this DMARC alert to a business owner/admin. "
    "Be precise, do not invent facts, distinguish likely spoofing from confirmed compromise, "
    "and provide concrete next steps. "
    "DMARC aggregate source IPs are observed transmitting IPs from the reporter's point of view "
    "and may be final-hop relays, forwarders, mailing lists, or gateways. "
    "If SPF fails but DKIM aligns and DMARC passes, "
    "do not frame the IP as a threat or as something to add to SPF; "
    "explain that forwarding commonly breaks SPF while DKIM can still prove authorization. "
    "If a source is not legitimate, say not to add it to known senders, keep it unauthorized, "
    "preserve or tighten DMARC enforcement after legitimate senders are aligned, "
    "and investigate whether any internal system is leaking mail through that source. "
    "Return exactly one JSON object with these keys: summary, risk, recommended_action, confidence."
)
|
|
|
|
# Built-in instruction for the posture digest, used when no digest prompt file can be read
# and the caller's payload carries no instruction of its own.
POSTURE_DIGEST_PROMPT = (
    "Write a current DMARC posture report for the admin "
    "using all supplied deterministic telemetry and all open alerts. "
    "Base the report on unresolved/open risk, not only one report day. "
    "Mention exact counts/rates, important failing or unknown sources, relevant reporters, "
    "and concrete remediation. "
    "DMARC aggregate source IPs are observed transmitting IPs from the reporter's point of view "
    "and may be final-hop relays, forwarders, mailing lists, or gateways. "
    "For SPF-fail, DKIM-pass, DMARC-pass observations, "
    "explain that this commonly indicates forwarding or an intermediary relay "
    "and do not recommend adding those observed relay IPs to SPF "
    "solely because they appear in aggregate reports. "
    "For unknown failing sources, explain both branches: "
    "if legitimate, authorize/fix SPF/DKIM/alignment and classify; "
    "if not legitimate, do not authorize it, leave it unknown, "
    "and use DMARC enforcement such as quarantine/reject once legitimate senders are aligned. "
    "Do not claim mailbox compromise from aggregate data alone. "
    "Return only JSON matching required_json_schema."
)
|
|
|
|
# Built-in instruction for the weekly summary, used when no weekly prompt file can be read
# and the caller's payload carries no instruction of its own.
WEEKLY_PROMPT = (
    "Include high-level posture, trend changes, new senders, persistent failures, "
    "whether DMARC policy posture looks safe, and recommended operational actions. "
    "Only say consider stricter policy if the metrics support it."
)
|
|
|
|
|
|
class AlertExplanation(BaseModel):
    """Structured explanation attached to a single DMARC alert.

    The field names are the JSON keys the alert prompt asks the model to return.
    """

    # Short description of what the alert is about.
    summary: str

    # Business/operational risk statement.
    risk: str

    # Next step for the admin, always a single string.
    recommended_action: str

    # Free-form label. The alert request asks the model for low|medium|high;
    # "fallback" marks explanations produced without a usable LLM reply.
    confidence: str = "medium"
|
|
|
|
|
|
class SummaryOutput(BaseModel):
    """Structured daily or weekly DMARC posture summary.

    The field names are the JSON keys requested from the model for summaries.
    """

    # Title line of the report.
    headline: str

    # Narrative body of the report.
    summary: str

    # Remediation steps, one string per step; empty when none were supplied.
    action_items: list[str] = []

    # Business-level risk statement.
    business_risk: str
|
|
|
|
|
|
def _stringify_action(value: Any) -> str:
|
|
if isinstance(value, list):
|
|
return "; ".join(str(item) for item in value if item)
|
|
if value is None:
|
|
return ""
|
|
return str(value)
|
|
|
|
|
|
def normalize_alert_explanation(output: dict[str, Any], alert: Alert | Any) -> AlertExplanation:
    """Coerce a loosely shaped LLM reply into an ``AlertExplanation``.

    Every reply takes the same normalisation path, including replies that
    already carry ``summary``, ``risk`` and ``recommended_action``. There is
    deliberately no "validate as-is" shortcut: a reply with the right keys but
    a list-valued ``recommended_action`` or a ``null`` field would fail strict
    validation even though it can be repaired here.

    Args:
        output: Decoded JSON object returned by the model. An ``explanation``
            key may hold either a nested object or a plain string.
        alert: The alert being explained; only its ``summary`` attribute is
            read, as the last-resort summary text.

    Returns:
        An ``AlertExplanation`` whose fields are all strings. Missing risk or
        action text is replaced by conservative default wording, and a missing
        confidence becomes ``"medium"``.
    """
    nested = output.get("explanation")
    if isinstance(nested, dict):
        # Flatten the nested object; keys present at the top level win.
        top_level = {key: value for key, value in output.items() if key != "explanation"}
        source = {**nested, **top_level}
    else:
        source = dict(output)
        # A plain-string explanation stands in for the summary when none was given.
        if nested and "summary" not in source:
            source["summary"] = str(nested)

    # `or ""` keeps a None summary on the alert from being rendered as "None".
    summary = source.get("summary") or source.get("headline") or getattr(alert, "summary", None) or ""
    risk = source.get("risk") or source.get("business_risk") or source.get("impact")
    action = (
        source.get("recommended_action")
        or source.get("recommendation")
        or source.get("next_step")
        or source.get("next_steps")
        or source.get("action_items")
    )

    if not risk:
        risk = "Review the deterministic facts. DMARC aggregate data alone does not prove mailbox compromise."

    # Stringify first: a list containing only blank entries is truthy but
    # flattens to "", which must still trigger the default wording.
    action_text = _stringify_action(action)
    if not action_text:
        action_text = "Review the deterministic facts before changing DNS or sender classification; do not add relay or forwarding IPs to SPF solely because they appear in aggregate reports."

    return AlertExplanation(
        summary=str(summary),
        risk=str(risk),
        recommended_action=action_text,
        confidence=str(source.get("confidence") or "medium"),
    )
|
|
|
|
|
|
def normalize_summary_output(output: dict[str, Any], payload: dict[str, Any]) -> SummaryOutput:
    """Build a ``SummaryOutput`` from an LLM reply, filling gaps from the request payload.

    Alternative key names are accepted for each field (for example ``title``
    for the headline or ``next_steps`` for the action list). Any field the
    reply leaves empty is generated from ``payload``, so calling this with an
    empty ``output`` yields a fully deterministic summary.

    Args:
        output: Decoded JSON object returned by the model (may be empty).
        payload: The request payload; ``domain``, ``period``, ``metrics`` and
            ``top_sources`` are read when defaults are needed.

    Returns:
        A ``SummaryOutput`` with non-empty headline, summary and business risk,
        and at least one action item.
    """
    metrics = payload.get("metrics") or {}
    if not isinstance(metrics, dict):
        metrics = {}

    headline = output.get("headline") or output.get("title")
    summary = output.get("summary") or output.get("explanation") or output.get("analysis")
    risk = output.get("business_risk") or output.get("risk") or output.get("impact")

    raw_actions = (
        output.get("action_items")
        or output.get("recommended_actions")
        or output.get("recommendations")
        or output.get("next_steps")
        or []
    )
    if isinstance(raw_actions, str):
        # A single string is treated as a ";"-separated list of steps.
        raw_actions = raw_actions.split(";")
    elif not isinstance(raw_actions, list):
        raw_actions = []
    # Clean the entries *before* deciding whether the default action is needed,
    # so a reply such as [""] or [None] does not end up with no actions at all.
    actions = [text for text in (str(item).strip() for item in raw_actions if item is not None) if text]

    if not headline:
        headline = f"DMARC posture for {payload.get('domain', 'domain')} on {payload.get('period', 'the selected period')}"

    if not summary:
        total = metrics.get("total_messages", 0)
        pass_rate = metrics.get("dmarc_pass_rate", 0)
        failed = metrics.get("dmarc_failed", 0)
        unknown = metrics.get("unknown_sources", 0)
        summary = (
            f"{payload.get('domain', 'The domain')} processed {total} DMARC-observed messages with a {pass_rate}% "
            f"DMARC pass rate. {failed} messages failed DMARC and {unknown} unknown sources were observed."
        )

    if not risk:
        risk = "Review failures and unknown senders before changing policy. DMARC aggregate data alone does not prove mailbox compromise."

    if not actions:
        # Name the first listed source when it has an IP; otherwise stay generic.
        top_sources = payload.get("top_sources") or []
        first = top_sources[0] if isinstance(top_sources, (list, tuple)) and top_sources else None
        source = first.get("source_ip") if isinstance(first, dict) else None
        if not source:
            source = "the top unknown or failing sources"
        actions = [
            f"Review {source}; if legitimate, fix SPF/DKIM alignment and classify it as approved, and if not legitimate, do not authorize it and rely on DMARC enforcement after legitimate senders are aligned."
        ]

    return SummaryOutput(
        headline=str(headline),
        summary=str(summary),
        action_items=actions,
        business_risk=str(risk),
    )
|
|
|
|
|
|
def fallback_alert_explanation(alert: Alert | Any) -> AlertExplanation:
    """Return a fixed, LLM-free explanation for ``alert``.

    This is the last resort when no usable model reply is available, so it
    must not raise: a missing, ``None`` or empty ``summary`` on the alert is
    replaced by generic text instead of being handed to the model validator.

    Args:
        alert: The alert being explained; only its ``summary`` attribute is read.

    Returns:
        An ``AlertExplanation`` with conservative wording and
        ``confidence="fallback"``.
    """
    summary = getattr(alert, "summary", None) or "DMARC Sentinel created a deterministic alert."
    return AlertExplanation(
        summary=str(summary),
        risk="Review the deterministic facts. DMARC aggregate data alone does not prove mailbox compromise.",
        recommended_action="Review the deterministic facts before changing DNS or sender classification; do not add relay or forwarding IPs to SPF solely because they appear in aggregate reports.",
        confidence="fallback",
    )
|
|
|
|
|
|
class LLMClient:
    """Thin wrapper around the OpenAI chat API for DMARC explanations and summaries.

    Every public method returns a validated model and falls back to
    deterministic wording when the LLM is unavailable or returns unusable
    output.
    """

    # Keys requested from the model for daily and weekly summaries.
    _SUMMARY_SCHEMA = {
        "headline": "string",
        "summary": "string",
        "action_items": "array of strings",
        "business_risk": "string",
    }

    def __init__(self, settings: Settings):
        """Create the client; ``self.client`` stays ``None`` unless OpenAI is configured.

        Args:
            settings: Application settings. ``settings.llm`` supplies the
                provider, the name of the API-key environment variable, the
                timeout, model parameters and prompt file paths.
        """
        self.settings = settings
        self.client = None
        if settings.llm.provider == "openai":
            api_key = os.getenv(settings.llm.api_key_env)
            if api_key:
                self.client = OpenAI(api_key=api_key, timeout=settings.llm.timeout_seconds)

    def _prompt(self, path: str, fallback: str) -> str:
        """Return the stripped text of the prompt file at ``path``, or ``fallback``.

        ``fallback`` is returned when ``path`` is empty, does not exist, or
        cannot be read or decoded as UTF-8.
        """
        if not path:
            # An empty path would resolve to the current directory, not a file.
            return fallback
        try:
            prompt_path = Path(path)
            if prompt_path.exists():
                return prompt_path.read_text(encoding="utf-8").strip()
        except (OSError, UnicodeDecodeError) as exc:
            # UnicodeDecodeError is a ValueError, not an OSError; a badly
            # encoded prompt file must not take down the caller.
            logger.warning("Could not read prompt file %s: %s", path, exc)
        return fallback

    def _json_call(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Send ``payload`` to the chat model and return the decoded JSON object.

        Args:
            payload: JSON-serialisable request data, sent as the user message.

        Returns:
            The model's reply decoded as a ``dict``.

        Raises:
            RuntimeError: If no client is configured, the request cannot be
                built, or every attempt (``max_retries + 1`` in total) fails
                or returns something other than a JSON object.
        """
        if self.client is None:
            raise RuntimeError("OpenAI client is not configured")

        llm = self.settings.llm
        # The messages are identical for every attempt, so build them once
        # instead of re-reading the prompt file and re-serialising per retry.
        try:
            messages = [
                {"role": "system", "content": self._prompt(llm.system_prompt_path, SYSTEM_PROMPT)},
                {"role": "user", "content": json.dumps(payload, sort_keys=True)},
            ]
        except (TypeError, ValueError) as exc:
            # Retrying cannot fix a payload that does not serialise.
            raise RuntimeError(f"LLM call failed: {exc}") from exc

        last_error: Exception | None = None
        for attempt in range(llm.max_retries + 1):
            try:
                response = self.client.chat.completions.create(
                    model=llm.model,
                    temperature=llm.temperature,
                    response_format={"type": "json_object"},
                    messages=messages,
                )
                text = response.choices[0].message.content or "{}"
                parsed = json.loads(text)
                # json.loads can return a list, string or number; callers use
                # dict methods on the result, so reject anything else here.
                if not isinstance(parsed, dict):
                    raise ValueError(f"expected a JSON object, got {type(parsed).__name__}")
                return parsed
            except Exception as exc:  # boundary: any SDK/transport/parse failure is retried
                last_error = exc
                logger.warning("LLM call failed on attempt %s: %s", attempt + 1, exc)
        raise RuntimeError(f"LLM call failed: {last_error}") from last_error

    def explain_alert(self, alert: Alert) -> AlertExplanation:
        """Explain ``alert`` with the LLM, falling back to fixed wording on failure.

        The LLM request is made up to two times. The fallback explanation is
        returned when both attempts fail, or when the alert's ``details_json``
        cannot be parsed (no facts means nothing trustworthy to explain).
        """
        try:
            facts = json.loads(alert.details_json or "{}")
        except (TypeError, ValueError) as exc:
            logger.warning("Using fallback LLM alert explanation for %s: %s", alert.fingerprint, exc)
            return fallback_alert_explanation(alert)

        payload = {
            "task": "explain_dmarc_alert",
            "domain": alert.domain,
            "severity": alert.severity,
            "alert_type": alert.type,
            "facts": facts,
            "required_json_schema": {
                "summary": "string, one concise sentence based only on the supplied facts",
                "risk": "string, business/operational risk without claiming compromise from aggregate data alone",
                "recommended_action": "string, concrete next step for the admin",
                "confidence": "low|medium|high",
            },
            "instruction": self._prompt(self.settings.llm.alert_prompt_path, ALERT_PROMPT),
        }

        last_error: Exception | None = None
        for _ in range(2):
            try:
                output = self._json_call(payload)
                return normalize_alert_explanation(output, alert)
            except (RuntimeError, ValidationError, ValueError, TypeError, AttributeError) as exc:
                # ValueError covers json.JSONDecodeError; TypeError/AttributeError
                # cover a reply whose shape the normaliser cannot handle.
                last_error = exc
                logger.warning("LLM alert explanation validation failed for %s: %s", alert.fingerprint, exc)
        logger.warning("Using fallback LLM alert explanation for %s: %s", alert.fingerprint, last_error)
        return fallback_alert_explanation(alert)

    def daily_summary(self, payload: dict[str, Any]) -> SummaryOutput:
        """Produce the daily posture summary for ``payload``.

        A ``required_json_schema`` or ``instruction`` already present in
        ``payload`` is kept; otherwise the defaults are added. On any failure
        the summary is generated from ``payload`` alone.
        """
        try:
            enriched = {
                **payload,
                "required_json_schema": payload.get("required_json_schema") or dict(self._SUMMARY_SCHEMA),
                "instruction": payload.get("instruction")
                or self._prompt(self.settings.llm.digest_prompt_path, POSTURE_DIGEST_PROMPT),
            }
            output = self._json_call(enriched)
            return normalize_summary_output(output, enriched)
        except Exception as exc:
            logger.warning("Using fallback daily summary: %s", exc)
            return normalize_summary_output({}, payload)

    def weekly_summary(self, payload: dict[str, Any]) -> SummaryOutput:
        """Produce the weekly posture summary for ``payload``.

        The reply is validated strictly against ``SummaryOutput``; on any
        failure a fixed generic summary is returned.
        """
        try:
            request = {
                **payload,
                # Tell the model which keys to return, as daily_summary does;
                # without it the strict validation below has little chance.
                "required_json_schema": payload.get("required_json_schema") or dict(self._SUMMARY_SCHEMA),
                "instruction": payload.get("instruction")
                or self._prompt(self.settings.llm.weekly_prompt_path, WEEKLY_PROMPT),
            }
            output = self._json_call(request)
            return SummaryOutput.model_validate(output)
        except Exception as exc:
            logger.warning("Using fallback weekly summary: %s", exc)
            return SummaryOutput(
                headline="Weekly DMARC posture summary generated from deterministic telemetry.",
                summary="Review trend changes, new senders, and persistent failures before changing DMARC policy.",
                action_items=["Classify legitimate new senders.", "Investigate persistent failures."],
                business_risk="Unknown",
            )
|