from __future__ import annotations

from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from ipaddress import ip_address
from typing import Iterable

from defusedxml import ElementTree as ET


class DMARCParseError(Exception):
    """Raised when a DMARC aggregate report cannot be parsed or fails validation."""


@dataclass
class ParsedAuthResult:
    """One ``<dkim>`` or ``<spf>`` entry from a record's ``<auth_results>``."""

    auth_type: str  # "dkim" or "spf", set by _parse_auth_results
    domain: str | None = None
    selector: str | None = None  # only filled for DKIM entries
    scope: str | None = None  # only filled for SPF entries
    result: str | None = None
    human_result: str | None = None  # only filled for DKIM entries


@dataclass
class ParsedRecord:
    """One ``<record>`` element of an aggregate report."""

    source_ip: str  # validated with ipaddress.ip_address
    count: int  # 0 when <count> is missing or not an integer
    disposition: str | None
    policy_dkim: str | None  # raw policy_evaluated/dkim text
    policy_spf: str | None  # raw policy_evaluated/spf text
    dkim_aligned: bool  # policy_dkim == "pass"
    spf_aligned: bool  # policy_spf == "pass"
    dmarc_pass: bool  # dkim_aligned or spf_aligned
    header_from: str | None
    reason_type: str | None
    reason_comment: str | None
    auth_results: list[ParsedAuthResult] = field(default_factory=list)


@dataclass
class ParsedReport:
    """A parsed ``<feedback>`` document: metadata, published policy and records."""

    org_name: str | None
    org_email: str | None
    extra_contact_info: str | None
    report_id: str | None
    date_begin: datetime | None  # timezone-aware UTC when present
    date_end: datetime | None  # timezone-aware UTC when present
    domain: str
    adkim: str | None
    aspf: str | None
    policy_p: str | None
    policy_sp: str | None
    policy_pct: int | None
    fo: str | None
    records: list[ParsedRecord]


def _strip_namespace(tag: str) -> str:
    """Return *tag* without a leading ``{namespace}`` qualifier, if it has one."""
    return tag.rsplit("}", 1)[-1] if "}" in tag else tag


def _children(element: ET.Element, name: str) -> Iterable[ET.Element]:
    """Yield the direct children of *element* whose local (un-namespaced) tag is *name*."""
    for child in element:
        if _strip_namespace(child.tag) == name:
            yield child


def _child(element: ET.Element | None, path: str) -> ET.Element | None:
    """Walk a ``/``-separated *path* of local tag names, taking the first match at each step.

    Returns ``None`` if *element* is ``None`` or any step has no matching child.
    """
    current = element
    for piece in path.split("/"):
        if current is None:
            return None
        current = next(iter(_children(current, piece)), None)
    return current


def _text(element: ET.Element | None, path: str) -> str | None:
    """Return the stripped text at *path* under *element*, or ``None`` if absent or blank."""
    found = _child(element, path)
    if found is None or found.text is None:
        return None
    return found.text.strip() or None


def _int(value: str | None) -> int | None:
    """Parse *value* as an int; return ``None`` for missing, empty or non-numeric input."""
    if not value:
        return None
    try:
        return int(value)
    except ValueError:
        return None


def _dt(value: str | None) -> datetime | None:
    """Convert a Unix-timestamp string into a timezone-aware UTC datetime.

    Returns ``None`` when *value* is missing or not an integer.

    Raises:
        DMARCParseError: the integer is outside the range ``datetime`` can represent.
    """
    number = _int(value)
    if number is None:
        return None
    try:
        return datetime.fromtimestamp(number, tz=timezone.utc)
    except (OverflowError, OSError, ValueError) as exc:
        # The timestamp comes straight from the untrusted report, so an absurd value must
        # surface as the parser's own error type rather than a bare OverflowError/OSError.
        raise DMARCParseError(f"Report timestamp out of range: {number}") from exc


def _validate_report_dates(
    date_begin: datetime | None,
    date_end: datetime | None,
    max_future_days: int,
    max_past_days: int,
) -> None:
    """Check that the report dates fall inside an acceptable window around now.

    Missing dates are skipped. Raises DMARCParseError if a date is more than
    *max_past_days* in the past, more than *max_future_days* in the future,
    or if begin is after end.
    """
    now = datetime.now(timezone.utc)
    earliest = now - timedelta(days=max_past_days)
    latest = now + timedelta(days=max_future_days)
    for label, value in (("begin", date_begin), ("end", date_end)):
        if value is None:
            continue
        if value < earliest:
            raise DMARCParseError(f"Report {label} date is older than {max_past_days} days")
        if value > latest:
            raise DMARCParseError(
                f"Report {label} date is more than {max_future_days} days in the future"
            )
    if date_begin is not None and date_end is not None and date_begin > date_end:
        raise DMARCParseError("Report begin date is after end date")


def _parse_auth_results(record: ET.Element) -> list[ParsedAuthResult]:
    """Collect the DKIM entries, then the SPF entries, from ``record/auth_results``."""
    auth = _child(record, "auth_results")
    if auth is None:
        return []
    results = [
        ParsedAuthResult(
            auth_type="dkim",
            domain=_text(dkim, "domain"),
            selector=_text(dkim, "selector"),
            result=_text(dkim, "result"),
            human_result=_text(dkim, "human_result"),
        )
        for dkim in _children(auth, "dkim")
    ]
    results.extend(
        ParsedAuthResult(
            auth_type="spf",
            domain=_text(spf, "domain"),
            scope=_text(spf, "scope"),
            result=_text(spf, "result"),
        )
        for spf in _children(auth, "spf")
    )
    return results


def _parse_record(record: ET.Element, max_record_count: int | None) -> ParsedRecord | None:
    """Parse one ``<record>``; return ``None`` if it has no ``<row>`` or no ``source_ip``.

    Raises DMARCParseError for an invalid source IP, a negative count, or a count
    above *max_record_count*.
    """
    row = _child(record, "row")
    if row is None:
        return None
    source_ip = _text(row, "source_ip")
    if not source_ip:
        return None

    try:
        ip_address(source_ip)
    except ValueError as exc:
        raise DMARCParseError(f"Invalid source IP: {source_ip}") from exc

    count = _int(_text(row, "count")) or 0
    if count < 0:
        raise DMARCParseError(f"Negative message count for source {source_ip}")
    if max_record_count is not None and count > max_record_count:
        raise DMARCParseError(f"Record count {count} exceeds limit of {max_record_count}")

    # policy_evaluated may be absent; _child/_text propagate None in that case.
    policy_eval = _child(row, "policy_evaluated")
    policy_dkim = _text(policy_eval, "dkim")
    policy_spf = _text(policy_eval, "spf")
    dkim_aligned = policy_dkim == "pass"
    spf_aligned = policy_spf == "pass"
    reason = _child(policy_eval, "reason")

    return ParsedRecord(
        source_ip=source_ip,
        count=count,
        disposition=_text(policy_eval, "disposition"),
        policy_dkim=policy_dkim,
        policy_spf=policy_spf,
        dkim_aligned=dkim_aligned,
        spf_aligned=spf_aligned,
        dmarc_pass=dkim_aligned or spf_aligned,
        header_from=_text(record, "identifiers/header_from"),
        reason_type=_text(reason, "type"),
        reason_comment=_text(reason, "comment"),
        auth_results=_parse_auth_results(record),
    )


def parse_dmarc_xml(
    payload: bytes,
    *,
    max_records: int | None = None,
    max_record_count: int | None = None,
    max_future_days: int = 3,
    max_past_days: int = 3650,
) -> ParsedReport:
    """Parse a DMARC aggregate report XML document into a ParsedReport.

    Args:
        payload: raw XML bytes; parsed with defusedxml.
        max_records: if set, reject reports with more than this many kept records.
            Records skipped for lacking ``<row>``/``source_ip`` do not count.
        max_record_count: if set, reject any record whose ``<count>`` exceeds it.
        max_future_days: how far past "now" a report date may lie.
        max_past_days: how far before "now" a report date may lie.

    Returns:
        The parsed report with at least one record.

    Raises:
        DMARCParseError: invalid XML, wrong root element, missing metadata/policy/domain,
            out-of-range or inconsistent dates, invalid record data, exceeded limits,
            or no usable records.
    """
    try:
        root = ET.fromstring(payload)
    except Exception as exc:
        # defusedxml rejects some input (e.g. forbidden entities/DTDs) with exception types
        # other than ParseError, so every parser failure is normalized to DMARCParseError.
        raise DMARCParseError(f"Invalid XML: {exc}") from exc

    if _strip_namespace(root.tag) != "feedback":
        raise DMARCParseError("Root element is not feedback")

    metadata = _child(root, "report_metadata")
    policy = _child(root, "policy_published")
    if metadata is None or policy is None:
        raise DMARCParseError("Missing report_metadata or policy_published")

    domain = _text(policy, "domain")
    if not domain:
        raise DMARCParseError("Missing policy domain")

    date_begin = _dt(_text(metadata, "date_range/begin"))
    date_end = _dt(_text(metadata, "date_range/end"))
    _validate_report_dates(date_begin, date_end, max_future_days, max_past_days)

    parsed_records: list[ParsedRecord] = []
    for record in _children(root, "record"):
        # Checked before parsing the next record so the limit triggers on record max+1.
        if max_records is not None and len(parsed_records) >= max_records:
            raise DMARCParseError(f"Report exceeds record limit of {max_records}")
        parsed = _parse_record(record, max_record_count)
        if parsed is not None:
            parsed_records.append(parsed)

    if not parsed_records:
        raise DMARCParseError("No valid DMARC records found")

    return ParsedReport(
        org_name=_text(metadata, "org_name"),
        org_email=_text(metadata, "email"),
        extra_contact_info=_text(metadata, "extra_contact_info"),
        report_id=_text(metadata, "report_id"),
        date_begin=date_begin,
        date_end=date_end,
        domain=domain,
        adkim=_text(policy, "adkim"),
        aspf=_text(policy, "aspf"),
        policy_p=_text(policy, "p"),
        policy_sp=_text(policy, "sp"),
        policy_pct=_int(_text(policy, "pct")),
        fo=_text(policy, "fo"),
        records=parsed_records,
    )