from __future__ import annotations

import json
import logging
from collections import Counter, defaultdict
from datetime import date, datetime, time, timedelta, timezone
from typing import Any
from zoneinfo import ZoneInfo

from apscheduler.schedulers.background import BackgroundScheduler
from sqlalchemy import case, desc, func, select
from sqlalchemy.orm import Session

from app.alerts import send_digest_email
from app.config import Settings
from app.db import session_scope
from app.inbox_locks import inbox_run_locks
from app.llm import LLMClient
from app.message_processor import process_inbox
from app.models import Alert, DailyStat, LLMReport, Record, Report, utcnow

logger = logging.getLogger(__name__)

scheduler: BackgroundScheduler | None = None

# LLMReport.domain value used for the cross-domain (portfolio) posture report.
_PORTFOLIO_DOMAIN = "__all__"


def _as_utc(value: datetime | str | None) -> datetime | None:
    """Coerce a DB timestamp (datetime or ISO-8601 string) to an aware UTC datetime.

    Naive datetimes are treated as already being UTC; aware datetimes are converted
    to UTC. Returns ``None`` for ``None`` or for a string that cannot be parsed.
    """
    if value is None:
        return None
    if isinstance(value, str):
        try:
            value = datetime.fromisoformat(value.replace("Z", "+00:00"))
        except ValueError:
            return None
    if value.tzinfo is None:
        return value.replace(tzinfo=timezone.utc)
    return value.astimezone(timezone.utc)


def _report_effective_at():
    """SQL expression for the timestamp a report is bucketed by.

    Falls back from the report's end date, to its begin date, to its import time.
    """
    return func.coalesce(Report.date_end, Report.date_begin, Report.created_at)


def _severity_rank():
    """SQL sort key placing critical alerts first, then warnings, then anything else.

    Ordering by ``Alert.severity.desc()`` sorts the strings alphabetically, which puts
    "warning" ahead of "critical" and drops critical alerts from truncated lists.
    """
    return case(
        (Alert.severity == "critical", 0),
        (Alert.severity == "warning", 1),
        else_=2,
    )


def _open_alerts(session: Session, domain: str | None = None, limit: int | None = None) -> list[Alert]:
    """Return open alerts, optionally for one domain, most severe and most recently updated first."""
    query = select(Alert).where(Alert.status == "open")
    if domain is not None:
        query = query.where(Alert.domain == domain)
    query = query.order_by(_severity_rank(), Alert.updated_at.desc())
    if limit is not None:
        query = query.limit(limit)
    return session.execute(query).scalars().all()


def _rate(part: int, total: int) -> float:
    """Return ``part`` as a percentage of ``total`` rounded to 2 decimals, or 0 when ``total`` is 0."""
    return round(part / total * 100, 2) if total else 0


def _load_json(raw: str | None, default: Any) -> Any:
    """Parse a stored JSON column, returning ``default`` when it is empty or malformed.

    A single corrupt row should not abort a whole summary run, so parse errors are
    logged and replaced by ``default``.
    """
    if not raw:
        return default
    try:
        return json.loads(raw)
    except (TypeError, ValueError):
        logger.warning("Ignoring malformed stored JSON: %.200s", raw)
        return default


def _alert_payload(alert: Alert, *, include_domain: bool = False) -> dict[str, Any]:
    """Serialize an alert for an LLM payload, optionally prefixed with its domain."""
    item: dict[str, Any] = {"domain": alert.domain} if include_domain else {}
    item.update(
        severity=alert.severity,
        type=alert.type,
        title=alert.title,
        summary=alert.summary,
        details=_load_json(alert.details_json, {}),
    )
    return item


def _plain_text(output: Any) -> str:
    """Render an LLM summary (headline, summary, action_items) as plain text for storage and email."""
    return f"{output.headline}\n\n{output.summary}\n\nActions: " + "; ".join(output.action_items)


def poll_all(settings: Settings) -> None:
    """Import new messages for every enabled inbox.

    Inboxes whose run lock is already held are skipped. Each inbox is processed in
    its own ``session_scope`` so that a failure in one inbox cannot leave a shared
    session in a failed state for the remaining inboxes. Failures are logged, not raised.
    """
    logger.info("Poll start")
    for inbox in settings.enabled_inboxes():
        lease = inbox_run_locks.acquire(inbox.id, blocking=False)
        if not lease:
            logger.info("Skipping inbox %s because another import is already running", inbox.id)
            continue
        try:
            with lease, session_scope() as session:
                process_inbox(session, settings, inbox, mode="new")
        except Exception as exc:
            # Top-level job boundary: keep polling the other inboxes, but keep the traceback.
            logger.warning("Polling inbox %s failed: %s", inbox.id, exc, exc_info=True)
    logger.info("Poll end")


def _domain_records(session: Session, domain: str, start: datetime, end: datetime) -> list[Record]:
    """Return records of ``domain`` whose report falls in the half-open window ``[start, end)``."""
    effective_at = _report_effective_at()
    return session.execute(
        select(Record)
        .join(Report)
        .where(
            Report.domain == domain,
            effective_at >= start,
            effective_at < end,
        )
    ).scalars().all()


def aggregate_daily_stats(session: Session, domain: str, day: date) -> DailyStat:
    """Recompute (create or update) the DailyStat row for ``domain`` on the UTC day ``day``.

    The returned object is added to / already attached to ``session``; committing is
    left to the caller.
    """
    start = datetime.combine(day, time.min, tzinfo=timezone.utc)
    end = start + timedelta(days=1)
    records = _domain_records(session, domain, start, end)
    total = sum(row.count for row in records)

    effective_at = _report_effective_at()
    reporters = session.execute(
        select(Report.org_name, func.count(Report.id))
        .where(
            Report.domain == domain,
            effective_at >= start,
            effective_at < end,
        )
        .group_by(Report.org_name)
        .order_by(desc(func.count(Report.id)))
    ).all()

    # One IP can appear in many records (one per report/row); sum per IP before ranking
    # so the top list neither repeats an IP nor misses the real heaviest sources.
    volume_by_source: Counter = Counter()
    for row in records:
        volume_by_source[row.source_ip] += row.count
    top_sources = volume_by_source.most_common(10)

    stat = session.scalar(select(DailyStat).where(DailyStat.domain == domain, DailyStat.date == day))
    if not stat:
        stat = DailyStat(domain=domain, date=day)
        session.add(stat)

    stat.total_messages = total
    stat.dmarc_pass_count = sum(row.count for row in records if row.dmarc_pass)
    stat.dmarc_fail_count = total - stat.dmarc_pass_count
    stat.spf_aligned_count = sum(row.count for row in records if row.spf_aligned)
    stat.spf_failed_count = total - stat.spf_aligned_count
    stat.dkim_aligned_count = sum(row.count for row in records if row.dkim_aligned)
    stat.dkim_failed_count = total - stat.dkim_aligned_count
    stat.unknown_source_count = len({row.source_ip for row in records if not row.is_known_sender})
    stat.known_source_count = len({row.source_ip for row in records if row.is_known_sender})
    stat.quarantine_count = sum(row.count for row in records if row.disposition == "quarantine")
    stat.reject_count = sum(row.count for row in records if row.disposition == "reject")
    stat.top_reporters_json = json.dumps([{"org": org, "reports": count} for org, count in reporters if org])
    stat.top_sources_json = json.dumps([{"source_ip": ip, "count": count} for ip, count in top_sources])
    return stat


def _summary_payload(session: Session, domain: str, day: date, stat: DailyStat) -> dict:
    """Build the LLM payload for the daily summary of ``domain`` on the UTC day ``day``.

    Metrics come from ``stat``; reporters are limited to the day; alert counts and the
    alert list cover all currently open alerts of the domain.
    """
    period_start = datetime.combine(day, time.min, tzinfo=timezone.utc)
    period_end = period_start + timedelta(days=1)

    open_alert_count = select(func.count(Alert.id)).where(Alert.domain == domain, Alert.status == "open")
    critical = session.scalar(open_alert_count.where(Alert.severity == "critical")) or 0
    warnings = session.scalar(open_alert_count.where(Alert.severity == "warning")) or 0
    alerts = _open_alerts(session, domain, limit=10)

    effective_at = _report_effective_at()
    reports = session.execute(
        select(Report.org_name, func.count(Report.id))
        .where(
            Report.domain == domain,
            effective_at >= period_start,
            effective_at < period_end,
        )
        .group_by(Report.org_name)
        .order_by(desc(func.count(Report.id)))
        .limit(10)
    ).all()

    total = stat.total_messages
    return {
        "task": "daily_dmarc_summary",
        "domain": domain,
        "period": day.isoformat(),
        "required_json_schema": {
            "headline": "string, specific concise headline for the report period",
            "summary": "string, 2-4 sentences using the supplied metrics, sources, reporters and alerts",
            "action_items": "array of specific action strings based on the telemetry",
            "business_risk": "string, concise risk statement; do not claim compromise from DMARC aggregate data alone",
        },
        "metrics": {
            "total_messages": total,
            "dmarc_passed": stat.dmarc_pass_count,
            "dmarc_failed": stat.dmarc_fail_count,
            "dmarc_pass_rate": _rate(stat.dmarc_pass_count, total),
            "spf_alignment_rate": _rate(stat.spf_aligned_count, total),
            "dkim_alignment_rate": _rate(stat.dkim_aligned_count, total),
            "unknown_sources": stat.unknown_source_count,
            "critical_alerts": critical,
            "warnings": warnings,
        },
        "top_sources": _load_json(stat.top_sources_json, []),
        "reporters": [{"org": org or "unknown", "reports": count} for org, count in reports],
        "alerts": [_alert_payload(alert) for alert in alerts],
        "instruction": (
            "Write an actual operational DMARC daily summary for the admin. Mention exact pass/fail counts and rates, "
            "important unknown or failing sources, relevant reporters, and concrete next actions. Do not provide generic "
            "advice if the telemetry supports a specific recommendation. Return only JSON matching required_json_schema."
        ),
    }


def _posture_payload(session: Session, domain: str) -> tuple[dict, datetime, datetime]:
    """Build the LLM payload for the current posture of ``domain`` over all imported telemetry.

    Returns ``(payload, period_start, period_end)``, where the period spans the earliest
    and latest report timestamps of the domain (falling back to "now" when there are none).
    """
    effective_at = _report_effective_at()
    bounds = session.execute(
        select(func.min(effective_at), func.max(effective_at)).where(Report.domain == domain)
    ).one()
    period_start = _as_utc(bounds[0]) or datetime.now(timezone.utc)
    period_end = _as_utc(bounds[1]) or period_start

    records = session.execute(select(Record).join(Report).where(Report.domain == domain)).scalars().all()
    reports = session.execute(
        select(Report.org_name, func.count(Report.id))
        .where(Report.domain == domain)
        .group_by(Report.org_name)
        .order_by(desc(func.count(Report.id)))
        .limit(10)
    ).all()
    alerts = _open_alerts(session, domain)

    total = sum(row.count for row in records)
    dmarc_pass = sum(row.count for row in records if row.dmarc_pass)
    spf_aligned = sum(row.count for row in records if row.spf_aligned)
    dkim_aligned = sum(row.count for row in records if row.dkim_aligned)
    unknown_records = [row for row in records if not row.is_known_sender]
    failing_unknown = [row for row in unknown_records if not row.dmarc_pass]
    top_sources = sorted(records, key=lambda row: row.count, reverse=True)[:12]

    payload = {
        "task": "current_dmarc_open_posture_summary",
        "domain": domain,
        "period": {"start": period_start.isoformat(), "end": period_end.isoformat()},
        "required_json_schema": {
            "headline": "string, specific concise headline for the current posture",
            "summary": "string, 2-5 sentences based on all imported telemetry and open alerts",
            "action_items": "array of specific action strings with if-legitimate and if-not-legitimate remediation where relevant",
            "business_risk": "string, concise risk statement; do not claim compromise from DMARC aggregate data alone",
        },
        "metrics": {
            "total_reports": session.scalar(select(func.count(Report.id)).where(Report.domain == domain)) or 0,
            "total_messages": total,
            "dmarc_passed": dmarc_pass,
            "dmarc_failed": total - dmarc_pass,
            "dmarc_pass_rate": _rate(dmarc_pass, total),
            "spf_alignment_rate": _rate(spf_aligned, total),
            "dkim_alignment_rate": _rate(dkim_aligned, total),
            "unknown_sources": len({row.source_ip for row in unknown_records}),
            "unknown_failing_sources": len({row.source_ip for row in failing_unknown}),
            "open_critical_alerts": sum(1 for alert in alerts if alert.severity == "critical"),
            "open_warnings": sum(1 for alert in alerts if alert.severity == "warning"),
        },
        "top_sources": [
            {
                "source_ip": row.source_ip,
                "count": row.count,
                "dmarc_pass": row.dmarc_pass,
                "known_sender": row.known_sender_name,
                "spf_aligned": row.spf_aligned,
                "dkim_aligned": row.dkim_aligned,
            }
            for row in top_sources
        ],
        "reporters": [{"org": org or "unknown", "reports": count} for org, count in reports],
        "open_alerts": [_alert_payload(alert) for alert in alerts[:25]],
        "instruction": (
            "Write a current DMARC posture report from all imported telemetry and all open alerts. Do not focus only "
            "on the latest report day. For failing unknown sources, state what to do if they are legitimate and what "
            "to do if they are not legitimate. Make the DMARC enforcement relationship explicit: quarantine/reject "
            "helps receivers handle unauthorized spoofing only after legitimate senders are aligned. Return only JSON."
        ),
    }
    return payload, period_start, period_end


def _portfolio_posture_payload(session: Session) -> tuple[dict, datetime, datetime] | None:
    """Build the LLM payload for the all-domain (portfolio) posture report.

    Returns ``(payload, period_start, period_end)`` spanning every imported report, or
    ``None`` when no reports have been imported yet.
    """
    domains = session.execute(select(Report.domain).distinct().order_by(Report.domain)).scalars().all()
    if not domains:
        return None

    effective_at = _report_effective_at()
    bounds = session.execute(select(func.min(effective_at), func.max(effective_at))).one()
    period_start = _as_utc(bounds[0]) or datetime.now(timezone.utc)
    period_end = _as_utc(bounds[1]) or period_start

    # Load every record once together with its report's domain, instead of loading all
    # records and then re-querying them domain by domain.
    records: list[Record] = []
    records_by_domain: dict[str, list[Record]] = defaultdict(list)
    for record, record_domain in session.execute(select(Record, Report.domain).join_from(Record, Report)).all():
        records.append(record)
        records_by_domain[record_domain].append(record)
    report_counts = dict(
        session.execute(select(Report.domain, func.count(Report.id)).group_by(Report.domain)).all()
    )
    alerts = _open_alerts(session)

    total = sum(row.count for row in records)
    dmarc_pass = sum(row.count for row in records if row.dmarc_pass)
    unknown_records = [row for row in records if not row.is_known_sender]
    failing_unknown = [row for row in unknown_records if not row.dmarc_pass]

    domain_rows = []
    for domain in domains:
        domain_records = records_by_domain.get(domain, [])
        domain_total = sum(row.count for row in domain_records)
        domain_pass = sum(row.count for row in domain_records if row.dmarc_pass)
        domain_alerts = [alert for alert in alerts if alert.domain == domain]
        domain_rows.append(
            {
                "domain": domain,
                "reports": report_counts.get(domain, 0),
                "messages": domain_total,
                "dmarc_pass_rate": _rate(domain_pass, domain_total),
                "unknown_sources": len({row.source_ip for row in domain_records if not row.is_known_sender}),
                "open_critical_alerts": sum(1 for alert in domain_alerts if alert.severity == "critical"),
                "open_warnings": sum(1 for alert in domain_alerts if alert.severity == "warning"),
            }
        )

    payload = {
        "task": "current_dmarc_portfolio_posture_summary",
        "scope": "all_domains",
        "domains": domains,
        "period": {"start": period_start.isoformat(), "end": period_end.isoformat()},
        "required_json_schema": {
            "headline": "string, concise portfolio headline across all domains",
            "summary": "string, 2-3 sentences covering all domains without per-record verbosity",
            "action_items": "array of 1-4 specific cross-domain or domain-named action strings",
            "business_risk": "string, concise portfolio-level risk statement",
        },
        "metrics": {
            "domains": len(domains),
            "total_reports": sum(report_counts.values()),
            "total_messages": total,
            "dmarc_passed": dmarc_pass,
            "dmarc_failed": total - dmarc_pass,
            "dmarc_pass_rate": _rate(dmarc_pass, total),
            "unknown_sources": len({row.source_ip for row in unknown_records}),
            "unknown_failing_sources": len({row.source_ip for row in failing_unknown}),
            "open_critical_alerts": sum(1 for alert in alerts if alert.severity == "critical"),
            "open_warnings": sum(1 for alert in alerts if alert.severity == "warning"),
        },
        "domain_posture": domain_rows,
        "top_open_alerts": [_alert_payload(alert, include_domain=True) for alert in alerts[:12]],
        "instruction": (
            "Write a compact all-domain DMARC portfolio posture for the overview page. Compare domains only where "
            "there is a meaningful difference. Keep it shorter than a single-domain detail report. Mention exact "
            "domain names only for domains that need attention. Return only JSON."
        ),
    }
    return payload, period_start, period_end


def _generate_llm_report(
    session: Session,
    settings: Settings,
    llm: LLMClient,
    *,
    domain: str,
    report_type: str,
    payload: dict,
    period_start: datetime,
    period_end: datetime,
    force: bool,
    subject: str,
) -> LLMReport:
    """Return the stored LLMReport for this domain/type/period, or generate and email a new one.

    An existing report is returned untouched (and no email is sent) unless ``force`` is
    set, in which case it is regenerated in place. The LLM input/output are only written
    to the report, and a new report only added to ``session``, when
    ``settings.llm.store_llm_outputs`` is enabled.
    """
    existing = session.scalar(
        select(LLMReport).where(
            LLMReport.domain == domain,
            LLMReport.report_type == report_type,
            LLMReport.period_start == period_start,
            LLMReport.period_end == period_end,
        )
    )
    if existing and not force:
        return existing

    store = settings.llm.store_llm_outputs
    report = existing or LLMReport(
        domain=domain,
        period_start=period_start,
        period_end=period_end,
        report_type=report_type,
        input_json="{}",
        output_json="{}",
        plain_text="",
    )
    if store and not existing:
        session.add(report)

    output = llm.daily_summary(payload)
    plain = _plain_text(output)
    if store:
        report.input_json = json.dumps(payload, sort_keys=True, default=str)
        report.output_json = output.model_dump_json()
        report.plain_text = plain
    send_digest_email(settings, subject, plain)
    return report


def generate_open_posture_summaries(settings: Settings, *, force: bool = True) -> list[LLMReport]:
    """Generate the portfolio posture report plus one posture report per domain.

    Skipped entirely (returns ``[]``) when daily LLM summaries are disabled. With
    ``force=False`` reports already stored for the same period are reused.
    """
    if not settings.llm.generate_daily_summary:
        logger.info("Open posture summaries skipped because daily LLM summaries are disabled")
        return []
    llm = LLMClient(settings)
    generated: list[LLMReport] = []
    with session_scope() as session:
        portfolio = _portfolio_posture_payload(session)
        if portfolio:
            payload, period_start, period_end = portfolio
            generated.append(
                _generate_llm_report(
                    session,
                    settings,
                    llm,
                    domain=_PORTFOLIO_DOMAIN,
                    report_type="posture",
                    payload=payload,
                    period_start=period_start,
                    period_end=period_end,
                    force=force,
                    subject="DMARC Sentinel portfolio posture summary",
                )
            )

        domains = session.execute(select(Report.domain).distinct()).scalars().all()
        for domain in domains:
            payload, period_start, period_end = _posture_payload(session, domain)
            generated.append(
                _generate_llm_report(
                    session,
                    settings,
                    llm,
                    domain=domain,
                    report_type="posture",
                    payload=payload,
                    period_start=period_start,
                    period_end=period_end,
                    force=force,
                    subject=f"DMARC Sentinel posture summary for {domain}",
                )
            )
    logger.info("Open posture summaries generated")
    return generated


def generate_daily_summaries(settings: Settings, target_day: date | None = None, *, force: bool = False) -> list[LLMReport]:
    """Refresh daily stats and generate one daily LLM summary per domain for ``target_day``.

    ``target_day`` defaults to the previous UTC day. The UTC date is used because daily
    stats are bucketed by UTC day; the server-local ``date.today()`` could select a
    day that is not complete yet.
    """
    if not settings.llm.generate_daily_summary:
        logger.info("Daily summaries skipped because daily LLM summaries are disabled")
        return []
    target_day = target_day or (datetime.now(timezone.utc).date() - timedelta(days=1))
    period_start = datetime.combine(target_day, time.min, tzinfo=timezone.utc)
    period_end = period_start + timedelta(days=1)
    llm = LLMClient(settings)
    generated: list[LLMReport] = []
    with session_scope() as session:
        domains = session.execute(select(Report.domain).distinct()).scalars().all()
        for domain in domains:
            # Stats are always recomputed, even when a stored summary is reused.
            stat = aggregate_daily_stats(session, domain, target_day)
            payload = _summary_payload(session, domain, target_day, stat)
            generated.append(
                _generate_llm_report(
                    session,
                    settings,
                    llm,
                    domain=domain,
                    report_type="daily",
                    payload=payload,
                    period_start=period_start,
                    period_end=period_end,
                    force=force,
                    subject=f"DMARC Sentinel daily summary for {domain}",
                )
            )
    logger.info("Daily summaries generated")
    return generated


def generate_weekly_summaries(settings: Settings) -> list[LLMReport]:
    """Generate one weekly LLM summary per domain for the 7 UTC days ending at today's UTC midnight.

    Domains that already have a weekly report for that exact period are skipped.
    """
    if not settings.llm.generate_weekly_summary:
        logger.info("Weekly summaries skipped because weekly LLM summaries are disabled")
        return []
    end = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0)
    start = end - timedelta(days=7)
    llm = LLMClient(settings)
    generated: list[LLMReport] = []
    with session_scope() as session:
        domains = session.execute(select(Report.domain).distinct()).scalars().all()
        for domain in domains:
            records = _domain_records(session, domain, start, end)
            existing = session.scalar(
                select(LLMReport).where(
                    LLMReport.domain == domain,
                    LLMReport.report_type == "weekly",
                    LLMReport.period_start == start,
                    LLMReport.period_end == end,
                )
            )
            if existing:
                continue
            total = sum(row.count for row in records)
            pass_count = sum(row.count for row in records if row.dmarc_pass)
            payload = {
                "task": "weekly_dmarc_summary",
                "domain": domain,
                "period": {"start": start.isoformat(), "end": end.isoformat()},
                "metrics": {
                    "total_messages": total,
                    "dmarc_pass_rate": _rate(pass_count, total),
                    "new_senders": len({row.source_ip for row in records if not row.is_known_sender and row.dmarc_pass}),
                    "persistent_failures": len({row.source_ip for row in records if not row.dmarc_pass}),
                    "critical_known_sender_failures": len(
                        {row.known_sender_id for row in records if row.is_known_sender and not row.dmarc_pass}
                    ),
                },
                "instruction": (
                    "Include high-level posture, trend changes, new senders, persistent failures, whether DMARC policy "
                    "posture looks safe, and recommended operational actions. Only say consider stricter policy if the "
                    "metrics support it."
                ),
            }
            output = llm.weekly_summary(payload)
            plain = _plain_text(output)
            report = LLMReport(
                domain=domain,
                period_start=start,
                period_end=end,
                report_type="weekly",
                input_json=json.dumps(payload, sort_keys=True),
                output_json=output.model_dump_json(),
                plain_text=plain,
            )
            if settings.llm.store_llm_outputs:
                session.add(report)
            generated.append(report)
            send_digest_email(settings, f"DMARC Sentinel weekly summary for {domain}", plain)
    logger.info("Weekly summaries generated")
    return generated


def start_scheduler(settings: Settings) -> BackgroundScheduler:
    """Create, configure and start the module-level background scheduler.

    A scheduler started by an earlier call is shut down first so jobs are never
    registered (and run) twice. Cron jobs run in ``settings.app.timezone``.
    """
    global scheduler
    if scheduler is not None and scheduler.running:
        logger.info("Shutting down previously started scheduler before restarting")
        scheduler.shutdown(wait=False)
    tz = ZoneInfo(settings.app.timezone)
    scheduler = BackgroundScheduler(timezone=tz)
    scheduler.add_job(
        poll_all, "interval", minutes=settings.app.poll_interval_minutes, args=[settings], id="poll", replace_existing=True
    )
    if settings.llm.generate_daily_summary:
        scheduler.add_job(
            generate_daily_summaries, "cron", hour=7, minute=0, args=[settings], id="daily", replace_existing=True
        )
    if settings.llm.generate_weekly_summary:
        scheduler.add_job(
            generate_weekly_summaries,
            "cron",
            day_of_week="mon",
            hour=7,
            minute=30,
            args=[settings],
            id="weekly",
            replace_existing=True,
        )
    scheduler.start()
    return scheduler


def scheduler_ok() -> bool:
    """Return True when the module-level scheduler exists and is running."""
    return bool(scheduler and scheduler.running)