from __future__ import annotations

import re
from dataclasses import dataclass, field
from typing import Callable

# Resolver callables: take a DNS name and return the record values as strings.
TxtLookup = Callable[[str], list[str]]
MxLookup = Callable[[str], list[str]]

# SPF "all" mechanism with an optional qualifier (+, ?, ~, -).
_SPF_ALL_RE = re.compile(r"[+?~-]?all", re.IGNORECASE)
# SPF "include" mechanism with an optional qualifier; group 1 is the target.
_SPF_INCLUDE_RE = re.compile(r"[+?~-]?include:(.*)", re.IGNORECASE)


@dataclass
class ParsedDmarcRecord:
    """Tag values extracted from a DMARC TXT record.

    Every field is ``None`` when no record was found or the tag is absent.
    """

    raw: str | None = None  # the stripped record text that was parsed
    p: str | None = None  # value of the "p" tag
    sp: str | None = None  # value of the "sp" tag
    pct: int | None = None  # "pct" tag as an int; None if missing or non-numeric
    adkim: str | None = None  # value of the "adkim" tag
    aspf: str | None = None  # value of the "aspf" tag
    fo: str | None = None  # value of the "fo" tag
    rua: str | None = None  # value of the "rua" tag
    ruf: str | None = None  # value of the "ruf" tag


@dataclass
class ParsedSpfRecord:
    """Summary of an SPF TXT record."""

    raw: str | None = None  # the stripped record text that was parsed
    includes: list[str] = field(default_factory=list)  # targets of include: terms
    all_mechanism: str | None = None  # first "all" term, qualifier included


@dataclass
class DkimRecord:
    """Outcome of one DKIM selector lookup."""

    selector: str
    domain: str
    query_name: str  # "<selector>._domainkey.<domain>"
    record: str | None = None  # first matching key record, if any
    error: str | None = None  # str() of the lookup exception, if it failed


@dataclass
class DomainDnsPolicy:
    """Aggregated DMARC / SPF / DKIM / MX findings for one domain."""

    domain: str
    dmarc: ParsedDmarcRecord = field(default_factory=ParsedDmarcRecord)
    spf: ParsedSpfRecord = field(default_factory=ParsedSpfRecord)
    dkim: list[DkimRecord] = field(default_factory=list)
    mx_records: list[str] = field(default_factory=list)
    errors: list[str] = field(default_factory=list)  # human-readable problems


def _default_txt_lookup(name: str) -> list[str]:
    """Resolve TXT records for *name* with dnspython.

    Raises:
        RuntimeError: if dnspython cannot be imported.
        Any exception raised by ``dns.resolver.resolve`` propagates unchanged.
    """
    try:
        import dns.resolver
    except ImportError as exc:
        raise RuntimeError("dnspython is not installed") from exc

    answers = dns.resolver.resolve(name, "TXT", lifetime=10)
    values = []
    for answer in answers:
        parts = getattr(answer, "strings", None)
        if parts is None:
            # No raw chunks exposed: fall back to the text form minus quotes.
            values.append(str(answer).strip('"'))
        else:
            # A TXT value may be split into several chunks; join them back.
            values.append("".join(part.decode("utf-8", errors="replace") for part in parts))
    return values


def _default_mx_lookup(name: str) -> list[str]:
    """Resolve MX records for *name* with dnspython.

    Returns:
        Strings of the form ``"<preference> <exchange>"`` with the trailing
        dot of the exchange removed.

    Raises:
        RuntimeError: if dnspython cannot be imported.
        Any exception raised by ``dns.resolver.resolve`` propagates unchanged.
    """
    try:
        import dns.resolver
    except ImportError as exc:
        raise RuntimeError("dnspython is not installed") from exc

    answers = dns.resolver.resolve(name, "MX", lifetime=10)
    return [f"{answer.preference} {str(answer.exchange).rstrip('.')}" for answer in answers]


def _tag_map(record: str) -> dict[str, str]:
    """Split a ``key=value; key=value`` record into a dict.

    Keys are stripped and lower-cased, values are stripped. Parts without
    ``=`` or with an empty key are ignored; a repeated key keeps its last value.
    """
    tags: dict[str, str] = {}
    for part in record.split(";"):
        if "=" not in part:
            continue
        key, value = part.split("=", 1)
        key = key.strip().lower()
        if key:
            tags[key] = value.strip()
    return tags


def _int(value: str | None) -> int | None:
    """Return ``int(value)``, or ``None`` when it is empty or not an integer."""
    if not value:
        return None
    try:
        return int(value)
    except ValueError:
        return None


def parse_dmarc_records(records: list[str]) -> tuple[ParsedDmarcRecord, list[str]]:
    """Pick the DMARC record out of *records* and parse its tags.

    Records are recognised by a case-insensitive ``v=dmarc1`` prefix. When
    several match, the first is parsed and an error is reported.

    Returns:
        ``(parsed, errors)``; *parsed* is an empty ``ParsedDmarcRecord`` when
        no DMARC record is present.
    """
    dmarc_records = [
        record.strip() for record in records if record.strip().lower().startswith("v=dmarc1")
    ]
    if not dmarc_records:
        return ParsedDmarcRecord(), ["DMARC TXT record not found"]

    errors = []
    if len(dmarc_records) > 1:
        errors.append("Multiple DMARC TXT records found")

    raw = dmarc_records[0]
    tags = _tag_map(raw)
    parsed = ParsedDmarcRecord(
        raw=raw,
        p=tags.get("p"),
        sp=tags.get("sp"),
        pct=_int(tags.get("pct")),
        adkim=tags.get("adkim"),
        aspf=tags.get("aspf"),
        fo=tags.get("fo"),
        rua=tags.get("rua"),
        ruf=tags.get("ruf"),
    )
    return parsed, errors


def _is_spf_record(record: str) -> bool:
    """Return True when the first whitespace-delimited term is exactly ``v=spf1``."""
    terms = record.split(maxsplit=1)
    # A bare prefix test would also accept e.g. "v=spf10", which is not SPF.
    return bool(terms) and terms[0].lower() == "v=spf1"


def parse_spf_records(records: list[str]) -> tuple[ParsedSpfRecord, list[str]]:
    """Pick the SPF record out of *records* and summarise it.

    A record qualifies only when its version term is exactly ``v=spf1``
    (case-insensitive). When several match, the first is parsed and an error
    is reported. ``include`` terms are collected whether or not they carry a
    qualifier (``+``, ``?``, ``~``, ``-``).

    Returns:
        ``(parsed, errors)``; *parsed* is an empty ``ParsedSpfRecord`` when no
        SPF record is present.
    """
    spf_records = [record.strip() for record in records if _is_spf_record(record)]
    if not spf_records:
        return ParsedSpfRecord(), ["SPF TXT record not found"]

    errors = []
    if len(spf_records) > 1:
        errors.append("Multiple SPF TXT records found")

    raw = spf_records[0]
    tokens = raw.split()
    include_matches = (_SPF_INCLUDE_RE.fullmatch(token) for token in tokens)
    includes = [match.group(1) for match in include_matches if match]
    all_mechanism = next((token for token in tokens if _SPF_ALL_RE.fullmatch(token)), None)
    return ParsedSpfRecord(raw=raw, includes=includes, all_mechanism=all_mechanism), errors


def _is_dkim_key_record(record: str) -> bool:
    """Return True when *record* looks like a DKIM key record.

    Accepts anything starting with ``v=dkim1`` (case-insensitive), and also
    records that omit the optional ``v=`` tag but carry a ``p=`` tag.
    """
    if record.strip().lower().startswith("v=dkim1"):
        return True
    tags = _tag_map(record)
    return "v" not in tags and "p" in tags


def _normalize_selectors(
    selectors: list[str | tuple[str, str]] | None, domain: str
) -> list[tuple[str, str]]:
    """Return sorted, de-duplicated ``(selector, domain)`` pairs.

    Bare selectors are paired with *domain*; tuple items supply their own
    domain, falling back to *domain* when it is empty. Pairs with an empty
    selector or domain after normalisation are dropped.
    """
    pairs: set[tuple[str, str]] = set()
    for item in selectors or []:
        if isinstance(item, tuple):
            selector, dkim_domain = item
        else:
            selector, dkim_domain = item, domain
        selector = (selector or "").strip().lower()
        dkim_domain = (dkim_domain or domain).strip().lower().rstrip(".")
        if selector and dkim_domain:
            pairs.add((selector, dkim_domain))
    return sorted(pairs)


def collect_domain_dns_policy(
    domain: str,
    *,
    selectors: list[str | tuple[str, str]] | None = None,
    txt_lookup: TxtLookup | None = None,
    mx_lookup: MxLookup | None = None,
) -> DomainDnsPolicy:
    """Look up and parse the DMARC, SPF, MX and DKIM records of *domain*.

    Each lookup is best-effort: an exception from a resolver callable is
    recorded in ``policy.errors`` and the remaining lookups still run.

    Args:
        domain: Domain to inspect; stripped, lower-cased, trailing dot removed.
        selectors: DKIM selectors to query, either bare selector strings (for
            *domain*) or ``(selector, domain)`` tuples.
        txt_lookup: Callable resolving TXT records; defaults to dnspython.
        mx_lookup: Callable resolving MX records; defaults to dnspython.

    Returns:
        A ``DomainDnsPolicy`` holding the findings and any error messages.
    """
    txt_lookup = txt_lookup or _default_txt_lookup
    mx_lookup = mx_lookup or _default_mx_lookup
    # Strip whitespace as well, mirroring how DKIM domains are normalised.
    domain = domain.strip().lower().rstrip(".")
    policy = DomainDnsPolicy(domain=domain)

    try:
        policy.dmarc, errors = parse_dmarc_records(txt_lookup(f"_dmarc.{domain}"))
        policy.errors.extend(errors)
    except Exception as exc:  # best-effort: record the failure and continue
        policy.errors.append(f"DMARC lookup failed: {exc}")

    try:
        policy.spf, errors = parse_spf_records(txt_lookup(domain))
        policy.errors.extend(errors)
    except Exception as exc:  # best-effort: record the failure and continue
        policy.errors.append(f"SPF lookup failed: {exc}")

    try:
        policy.mx_records = mx_lookup(domain)
    except Exception as exc:  # best-effort: record the failure and continue
        policy.errors.append(f"MX lookup failed: {exc}")

    for selector, dkim_domain in _normalize_selectors(selectors, domain):
        query_name = f"{selector}._domainkey.{dkim_domain}"
        try:
            records = txt_lookup(query_name)
            dkim_records = [record for record in records if _is_dkim_key_record(record)]
            policy.dkim.append(
                DkimRecord(
                    selector=selector,
                    domain=dkim_domain,
                    query_name=query_name,
                    record=dkim_records[0] if dkim_records else None,
                )
            )
            if not dkim_records:
                policy.errors.append(
                    f"DKIM record not found for selector {selector} on {dkim_domain}"
                )
        except Exception as exc:  # best-effort: record the failure and continue
            policy.dkim.append(
                DkimRecord(
                    selector=selector,
                    domain=dkim_domain,
                    query_name=query_name,
                    error=str(exc),
                )
            )
            policy.errors.append(
                f"DKIM lookup failed for selector {selector} on {dkim_domain}: {exc}"
            )

    return policy