"""Extract DMARC aggregate-report XML payloads from email attachments.

Attachments may be plain XML, gzip-compressed XML, or zip archives holding
XML files.  Every path that processes a report candidate enforces size,
count and compression-ratio limits and rejects unsafe archive entries with
:class:`AttachmentExtractionError`.
"""

from __future__ import annotations

import gzip
import hashlib
import io
import zipfile
from dataclasses import dataclass
from email.message import Message
from pathlib import PurePosixPath


class AttachmentExtractionError(Exception):
    """Raised when an attachment violates an extraction limit or is unsafe to unpack."""


@dataclass(frozen=True)
class ExtractedReport:
    """One XML report extracted from an attachment."""

    filename: str  # archive entry name, or a name derived from the attachment filename
    payload: bytes  # the (decompressed) XML bytes
    sha256: str  # hex SHA-256 digest of ``payload``


ARCHIVE_SUFFIXES = (".zip", ".gz")
XML_MIME_HINTS = {"text/xml", "application/xml", "application/dmarc+xml"}
GZIP_MIME_HINTS = {"application/gzip", "application/x-gzip"}
ZIP_MIME_HINTS = {"application/zip", "application/x-zip-compressed"}


def _max_bytes(max_mb: int) -> int:
    """Convert a limit in megabytes to bytes."""
    return max_mb * 1024 * 1024


def _sha(payload: bytes) -> str:
    """Return the hex SHA-256 digest of *payload*."""
    return hashlib.sha256(payload).hexdigest()


def _ensure_size(payload: bytes, max_mb: int, filename: str) -> None:
    """Raise if the decompressed *payload* is larger than *max_mb* megabytes."""
    if len(payload) > _max_bytes(max_mb):
        raise AttachmentExtractionError(f"{filename} exceeds decompressed limit of {max_mb} MB")


def _ensure_compressed_size(payload: bytes, max_mb: int, filename: str) -> None:
    """Raise if the raw attachment *payload* is larger than *max_mb* megabytes."""
    if len(payload) > _max_bytes(max_mb):
        raise AttachmentExtractionError(f"{filename} exceeds compressed limit of {max_mb} MB")


def _ensure_ratio(compressed_size: int, decompressed_size: int, max_ratio: int, filename: str) -> None:
    """Raise if ``decompressed_size / compressed_size`` exceeds *max_ratio*.

    A non-positive *compressed_size* gives no meaningful ratio, so the check
    is skipped; callers in this module also enforce a decompressed-size cap.
    """
    if compressed_size <= 0:
        return
    ratio = decompressed_size / compressed_size
    if ratio > max_ratio:
        raise AttachmentExtractionError(f"{filename} exceeds compression ratio limit of {max_ratio}:1")


def _safe_zip_name(name: str) -> bool:
    """Return True if archive entry *name* is relative and cannot climb out of its root.

    Backslashes are treated as path separators too, so Windows-style names
    such as ``..\\evil.xml`` or ``\\evil.xml`` are rejected just like their
    ``/`` equivalents.
    """
    path = PurePosixPath(name.replace("\\", "/"))
    return not path.is_absolute() and ".." not in path.parts


def _extract_zip(filename: str, payload: bytes, max_mb: int, max_reports: int, max_ratio: int) -> list[ExtractedReport]:
    """Return every ``.xml`` entry of the zip archive *payload* as a report.

    Directory and non-XML entries are skipped.  The archive is rejected
    (AttachmentExtractionError) if an entry has an unsafe path or is itself an
    archive, if it holds more than *max_reports* XML entries, or if an XML
    entry breaks the decompressed-size or compression-ratio limit.
    """
    reports: list[ExtractedReport] = []
    # Read at most one byte past the limit: enough to detect an oversized
    # entry without inflating a zip bomb in full.
    read_limit = _max_bytes(max_mb) + 1
    with zipfile.ZipFile(io.BytesIO(payload)) as archive:
        for info in archive.infolist():
            if info.is_dir():
                continue
            if not _safe_zip_name(info.filename):
                raise AttachmentExtractionError(f"{filename} contains unsafe zip path {info.filename}")
            lower = info.filename.lower()
            if lower.endswith(ARCHIVE_SUFFIXES):
                raise AttachmentExtractionError(f"{filename} contains nested archive {info.filename}")
            if not lower.endswith(".xml"):
                continue
            if len(reports) >= max_reports:
                raise AttachmentExtractionError(f"{filename} exceeds archive XML report limit of {max_reports}")
            with archive.open(info) as handle:
                xml = handle.read(read_limit)
            _ensure_size(xml, max_mb, info.filename)
            _ensure_ratio(info.compress_size, len(xml), max_ratio, info.filename)
            reports.append(ExtractedReport(info.filename, xml, _sha(xml)))
    return reports


def _extract_gzip(filename: str, payload: bytes, max_mb: int, max_ratio: int) -> list[ExtractedReport]:
    """Decompress the gzip *payload* into a single report.

    The report is named after *filename* with a trailing ``.gz`` removed, or
    with ``.xml`` appended when *filename* has no ``.gz`` suffix.
    """
    with gzip.GzipFile(fileobj=io.BytesIO(payload)) as gz:
        # One byte past the limit is enough to detect an oversized stream.
        xml = gz.read(_max_bytes(max_mb) + 1)
    _ensure_size(xml, max_mb, filename)
    _ensure_ratio(len(payload), len(xml), max_ratio, filename)
    out_name = filename[:-3] if filename.lower().endswith(".gz") else f"{filename}.xml"
    return [ExtractedReport(out_name, xml, _sha(xml))]


def _classify_payload(lower_name: str, mime: str) -> str | None:
    """Return ``"zip"``, ``"gzip"``, ``"xml"`` or None for a lower-cased name/MIME pair.

    The filename suffix is authoritative and the MIME type is consulted only
    when the name is inconclusive.  Otherwise a mislabelled part (for example
    ``report.xml.gz`` sent as ``application/zip``) would be handed to the
    wrong decoder.
    """
    if lower_name.endswith(".zip"):
        return "zip"
    if lower_name.endswith(".gz"):
        return "gzip"
    if lower_name.endswith(".xml"):
        return "xml"
    if mime in ZIP_MIME_HINTS:
        return "zip"
    if mime in GZIP_MIME_HINTS:
        return "gzip"
    if mime in XML_MIME_HINTS:
        return "xml"
    return None


def extract_payload(
    filename: str,
    content_type: str | None,
    payload: bytes,
    max_mb: int,
    *,
    max_compressed_mb: int = 10,
    max_reports_per_archive: int = 20,
    max_compression_ratio: int = 100,
) -> list[ExtractedReport]:
    """Extract DMARC XML reports from a single attachment payload.

    Args:
        filename: Attachment filename; its suffix selects the decoder.
        content_type: MIME type, used when the filename is inconclusive.
        payload: Raw (transfer-decoded) attachment bytes.
        max_mb: Per-report decompressed size limit in megabytes.
        max_compressed_mb: Size limit in megabytes for *payload* itself.
        max_reports_per_archive: Maximum XML entries accepted from one zip.
        max_compression_ratio: Maximum decompressed/compressed size ratio.

    Returns:
        The extracted reports, or an empty list when the payload is not a
        recognised XML, gzip or zip attachment.

    Raises:
        AttachmentExtractionError: if a recognised payload breaks a limit or
            contains an unsafe archive entry.  Decoding errors raised by
            ``zipfile``/``gzip`` are not caught here and propagate as-is.
    """
    kind = _classify_payload(filename.lower(), (content_type or "").lower())
    if kind is None:
        # Parts that will never be decoded (message bodies, unrelated
        # attachments) must not trip the report size limits.
        return []
    _ensure_compressed_size(payload, max_compressed_mb, filename)
    if kind == "zip":
        return _extract_zip(filename, payload, max_mb, max_reports_per_archive, max_compression_ratio)
    if kind == "gzip":
        return _extract_gzip(filename, payload, max_mb, max_compression_ratio)
    _ensure_size(payload, max_mb, filename)
    return [ExtractedReport(filename, payload, _sha(payload))]


def message_has_candidate_attachment(message: Message) -> bool:
    """Return True if any part of *message* looks like it could carry a DMARC report.

    A part qualifies by filename suffix (``.xml``, ``.gz``, ``.zip``) or by an
    XML, gzip or zip MIME type.  No payload is decoded.
    """
    candidate_mimes = XML_MIME_HINTS | GZIP_MIME_HINTS | ZIP_MIME_HINTS
    return any(
        (part.get_filename() or "").lower().endswith((".xml", ".gz", ".zip"))
        or (part.get_content_type() or "").lower() in candidate_mimes
        for part in message.walk()
    )


def extract_dmarc_attachments(
    message: Message,
    max_mb: int,
    *,
    max_compressed_mb: int = 10,
    max_attachments: int = 20,
    max_reports_per_message: int = 20,
    max_reports_per_archive: int = 20,
    max_compression_ratio: int = 100,
) -> list[ExtractedReport]:
    """Extract every DMARC XML report carried by *message*.

    Each non-multipart part with a non-empty decoded payload is passed to
    :func:`extract_payload`; parts that are not report candidates yield
    nothing.  Every such part, including plain message bodies, counts toward
    *max_attachments*.

    Raises:
        AttachmentExtractionError: if the message has more than
            *max_attachments* non-empty parts, more than
            *max_reports_per_message* reports are extracted, or a candidate
            part violates a per-attachment limit.
    """
    reports: list[ExtractedReport] = []
    attachment_count = 0
    for part in message.walk():
        if part.is_multipart():
            continue
        payload = part.get_payload(decode=True)
        if not payload:
            continue
        attachment_count += 1
        if attachment_count > max_attachments:
            raise AttachmentExtractionError(f"message exceeds attachment limit of {max_attachments}")
        reports.extend(
            extract_payload(
                part.get_filename() or "attachment",
                part.get_content_type(),
                payload,
                max_mb,
                max_compressed_mb=max_compressed_mb,
                max_reports_per_archive=max_reports_per_archive,
                max_compression_ratio=max_compression_ratio,
            )
        )
        if len(reports) > max_reports_per_message:
            raise AttachmentExtractionError(f"message exceeds extracted report limit of {max_reports_per_message}")
    return reports