#!/usr/bin/env python3
"""Export a compact gate registry from a validated agent-learning report."""

from __future__ import annotations

import argparse
import dataclasses
import datetime as dt
import hashlib
import json
import pathlib
import re
import sys

from state_paths import atomic_rewrite
from validate_outputs import validate


def _gate_id(domain: str, category: str, gate_text: str) -> str:
    """Deterministic 12-hex identifier for (domain, category, gate text)."""
    h = hashlib.sha256()
    h.update(f"{domain}|{category}|{gate_text.strip()}".encode("utf-8"))
    return h.hexdigest()[:12]


# Level-two markdown heading (`## name`); `###` lines do not match because
# `\s+` must follow the two hashes. Group 1 is the heading text.
SECTION_RE = re.compile(r"^##\s+(.+?)\s*$", re.M)
# `### domain: <name>` sub-heading (case-insensitive); group 1 is the raw
# domain name, later passed through strip_value().
DOMAIN_RE = re.compile(r"^###\s+domain:\s*(.+?)\s*$", re.M | re.I)


@dataclasses.dataclass(frozen=True)
class Gate:
    """One gate parsed from the report's agent_compensation section."""

    domain: str
    gate_category: str
    gate: str  # the gate text itself, as rendered on the `gate:` line
    # Optional level copied from the surrounding report text by the parsers.
    level: str | None = None
    # Evidence counter: unit name ("matching_lines", "baseline_sources" or
    # "sessions", per find_evidence_count) and the count, kept as the
    # captured string rather than an int.
    evidence_unit: str | None = None
    evidence_count: str | None = None


def strip_value(value: str) -> str:
    """Trim whitespace and remove one pair of matching surrounding quotes.

    Only a single layer of identical single or double quotes is removed, and
    the content inside them is trimmed again. Any other value is returned
    merely stripped.
    """
    trimmed = value.strip()
    if len(trimmed) < 2:
        return trimmed
    opener, closer = trimmed[0], trimmed[-1]
    if opener != closer or opener not in ("'", '"'):
        return trimmed
    return trimmed[1:-1].strip()


def section_text(report: str, heading: str) -> str:
    """Return the body of the first `## <heading>` section, or "" if absent.

    The heading comparison is case-insensitive. The body runs from the end of
    the matching heading up to the start of the next SECTION_RE heading (or
    the end of the report) and is stripped of surrounding whitespace.
    """
    wanted = heading.lower()
    headings = list(SECTION_RE.finditer(report))
    # Each section stops where the following heading starts; the last one
    # stops at the end of the report.
    boundaries = [m.start() for m in headings[1:]] + [len(report)]
    for match, stop in zip(headings, boundaries):
        if match.group(1).strip().lower() != wanted:
            continue
        return report[match.end():stop].strip()
    return ""


def find_level(text: str) -> str | None:
    """Return the first `level:` value found in *text*, or None.

    Patterns are tried in priority order (case-insensitive):
    a bold markdown bullet (`- **level:** X`), a plain bullet
    (`- level: X`), then a bare `level: X` line. The value must start on
    the same line as its key. Matching uses `[ \\t]` rather than `\\s`
    because `\\s` also matches newlines, which let an empty `level:` capture
    the next line (for example `gate: ...`) as its value.
    """
    patterns = (
        r"^[ \t]*-[ \t]+\*\*level:\*\*[ \t]*(\S.*?)\s*$",
        r"^[ \t]*-[ \t]+level:[ \t]*(\S.*?)\s*$",
        r"^[ \t]*level:[ \t]*(\S.*?)\s*$",
    )
    for pattern in patterns:
        match = re.search(pattern, text, re.M | re.I)
        if match:
            # Trailing `\s*$` still absorbs a CRLF "\r"; strip_value cleans
            # any remaining padding and a surrounding quote pair.
            return strip_value(match.group(1))
    return None


def find_evidence_count(text: str) -> tuple[str, str] | tuple[None, None]:
    """Locate the first recognised evidence counter in *text*.

    Returns ``(unit, count)`` with the count as a digit string. The forms
    are tried in priority order, case-insensitively:
    `matching_lines: N`, `baseline_sources: N`,
    `evidence: N matching user lines` and `evidence: N of M sessions`.
    Returns ``(None, None)`` when none of them is present.
    """
    candidates = [
        ("matching_lines", r"\bmatching_lines:\s*(\d+)\b"),
        ("baseline_sources", r"\bbaseline_sources:\s*(\d+)\b"),
        ("matching_lines", r"\bevidence:\s*(\d+)\s+matching user lines\b"),
        ("sessions", r"\bevidence:\s*(\d+)\s+of\s+\d+\s+sessions\b"),
    ]
    # Lazily evaluated, so later patterns are only searched when earlier
    # ones miss.
    attempts = ((unit, re.search(rx, text, re.I)) for unit, rx in candidates)
    for unit, hit in attempts:
        if hit is not None:
            return unit, hit.group(1)
    return None, None


def parse_gate_lines(domain: str, text: str) -> list[Gate]:
    """Collect category/gate pairs from one `### domain:` section body.

    A `category:` (or `gate_category:`) line arms the parser, and the next
    `gate:` line consumes it and emits a Gate. A `gate:` line with no armed
    category (or an empty one) is ignored. A second category line seen
    before any gate replaces the first. The first level and evidence counter
    found anywhere in *text* are attached to every gate emitted.
    """
    shared_level = find_level(text)
    unit, count = find_evidence_count(text)
    collected: list[Gate] = []
    armed: str | None = None

    for raw in text.splitlines():
        cat = re.match(r"^\s*-?\s*(?:gate_)?category:\s*(.+?)\s*$", raw, re.I)
        if cat:
            armed = strip_value(cat.group(1))
            continue
        if not armed:
            continue
        body = re.match(r"^\s*-?\s*gate:\s*(.+?)\s*$", raw, re.I)
        if body is None:
            continue
        collected.append(
            Gate(
                domain=domain,
                gate_category=armed,
                gate=strip_value(body.group(1)),
                level=shared_level,
                evidence_unit=unit,
                evidence_count=count,
            )
        )
        # Each category pairs with exactly one gate.
        armed = None

    return collected


def parse_domain_sections(agent_compensation: str) -> list[Gate]:
    """Parse gates grouped under `### domain: <name>` sub-headings.

    Each sub-heading's body extends to the next DOMAIN_RE heading (or the end
    of the text) and is handed to parse_gate_lines with the heading's
    quote-stripped domain name.
    """
    headers = list(DOMAIN_RE.finditer(agent_compensation))
    stops = [h.start() for h in headers[1:]]
    stops.append(len(agent_compensation))
    result: list[Gate] = []
    for header, stop in zip(headers, stops):
        name = strip_value(header.group(1))
        body = agent_compensation[header.end():stop]
        result += parse_gate_lines(name, body)
    return result


def parse_yaml_style_blocks(agent_compensation: str) -> list[Gate]:
    """Parse YAML-like `- domain: <name>` list items, one gate per item.

    Each item runs until the next `- domain:` marker. The first `category:`
    or `gate_category:` line in the item is used, matching what
    parse_gate_lines accepts and what the exported registry itself writes.
    The first `gate:` line is used as well. Items missing a domain, a
    category or a gate are skipped. Level and evidence counters are looked
    up within the item only.

    Field values must start on the same line as their key (`[ \\t]` instead
    of `\\s`), so an empty `domain:` / `category:` / `gate:` cannot swallow
    the following line as its value.
    """
    gates: list[Gate] = []
    blocks = re.split(r"(?m)^[ \t]*-[ \t]+domain:[ \t]*", agent_compensation)
    # blocks[0] is whatever precedes the first `- domain:` marker.
    for block in blocks[1:]:
        first_line, _, rest = block.partition("\n")
        domain = strip_value(first_line)
        if not domain:
            continue
        category_match = re.search(
            r"^[ \t]*(?:gate_)?category:[ \t]*(\S.*?)\s*$", rest, re.M | re.I
        )
        gate_match = re.search(r"^[ \t]*gate:[ \t]*(\S.*?)\s*$", rest, re.M | re.I)
        if not category_match or not gate_match:
            continue
        evidence_unit, evidence_count = find_evidence_count(rest)
        gates.append(
            Gate(
                domain=domain,
                gate_category=strip_value(category_match.group(1)),
                gate=strip_value(gate_match.group(1)),
                level=find_level(rest),
                evidence_unit=evidence_unit,
                evidence_count=evidence_count,
            )
        )
    return gates


def parse_gates(report: str) -> list[Gate]:
    """Extract unique gates from the report's `## agent_compensation` section.

    Both supported layouts are parsed: `### domain:` sub-sections first,
    then YAML-style `- domain:` items. The combined list is de-duplicated on
    (domain, gate_category, gate), keeping the first occurrence so the
    original order is preserved.
    """
    section = section_text(report, "agent_compensation")
    candidates = parse_domain_sections(section) + parse_yaml_style_blocks(section)
    # dicts keep insertion order and setdefault never overwrites, so the
    # first occurrence of each key wins.
    unique: dict[tuple[str, str, str], Gate] = {}
    for candidate in candidates:
        unique.setdefault(
            (candidate.domain, candidate.gate_category, candidate.gate), candidate
        )
    return list(unique.values())


def _load_probes(path: pathlib.Path | None) -> dict:
    """Return the registered probes dict, or {} when no path is supplied or
    the file is missing. Operator-controlled local state, so a malformed
    file surfaces a JSONDecodeError rather than being silently swallowed."""
    if path is None:
        return {}
    if not path.exists():
        return {}
    data = json.loads(path.read_text(encoding="utf-8"))
    return data if isinstance(data, dict) else {}


def _select_for_render(gates: list[Gate], max_domains: int | None) -> list[Gate]:
    """Apply the --max-domains cap. Pulled out so main() can identify which
    gate_ids the next render will actually cover (for inherited-block dedup)
    without re-implementing the selection rule."""
    selected_domains: list[str] = []
    selected: list[Gate] = []
    for gate in gates:
        if gate.domain not in selected_domains:
            if max_domains is not None and len(selected_domains) >= max_domains:
                continue
            selected_domains.append(gate.domain)
        selected.append(gate)
    return selected


def preserved_inherited_blocks(existing_text: str) -> dict[str, str]:
    """Return {gate_id: block_text} for inherited blocks in `existing_text`.

    A block counts as inherited when it carries an indented `derived_from:`
    field with a non-empty value on that same line. Local (non-inherited)
    blocks are not preserved because they get re-rendered from the report
    on every export. Blocks without a 12-hex-digit `gate_id:` are skipped.
    If two inherited blocks share a gate_id, the later one wins.

    The returned block text starts with `- domain:` and contains its
    indented field lines joined with "\\n" (CRLF input is normalised to LF),
    with no trailing newline. Callers append them after the freshly-rendered
    local blocks.
    """
    if not existing_text:
        return {}
    blocks_by_id: dict[str, str] = {}
    # Split BEFORE each "- domain:" marker (zero-width lookahead) so the
    # marker stays attached to the block. Tolerates CRLF and a file with
    # no leading newline before the first block, matching the pattern
    # gates_promote/gates_inherit use elsewhere.
    parts = re.split(r"(?m)^(?=-\s+domain:)", existing_text)
    for part in parts:
        if not re.match(r"-\s+domain:", part):
            continue
        # Anchored to an indented field on its own line, so a local gate
        # whose `gate:` text merely mentions "derived_from:" is not treated
        # as inherited and pinned across re-exports. `[ \t]` instead of `\s`
        # keeps the match on one line: `\s` also matches "\n", which let an
        # empty `derived_from:` borrow the next field's first character as
        # its value, and let a column-0 line after a blank line pass the
        # indentation check.
        if not re.search(r"^[ \t]+derived_from:[ \t]*\S", part, re.MULTILINE):
            continue
        match = re.search(r"^[ \t]*gate_id:[ \t]*([a-f0-9]{12})\s*$", part, re.MULTILINE)
        if not match:
            continue
        blocks_by_id[match.group(1)] = part.rstrip().replace("\r\n", "\n")
    return blocks_by_id


def render_registry(
    report_path: pathlib.Path,
    gates: list[Gate],
    max_domains: int | None = None,
    probes: dict | None = None,
    preserve_blocks: list[str] | None = None,
) -> str:
    """Render the markdown gate registry.

    Args:
        report_path: Report the gates came from; echoed in the header.
        gates: Parsed gates; the --max-domains cap is applied here via
            _select_for_render.
        max_domains: Optional cap on distinct domains (None means no cap).
        probes: Optional mapping of gate_id to a probe dict. Gates whose
            entry is a dict containing "rate" get probe_status/probe_rate
            lines.
        preserve_blocks: Already-rendered block texts (for example inherited
            blocks), appended after the local gates.

    Returns:
        The registry text, ending with a newline.
    """
    now = dt.datetime.now(dt.timezone.utc)
    selected_gates = _select_for_render(gates, max_domains)
    # dict.fromkeys de-duplicates while keeping first-appearance order.
    selected_domains = list(dict.fromkeys(gate.domain for gate in selected_gates))

    lines = [
        "# Approved Agent Gates",
        "",
        f"- generated_at: {now.isoformat()}",
        f"- date: {now.date().isoformat()}",
        f"- source_report: {report_path}",
        f"- domains: {', '.join(selected_domains) if selected_domains else 'none'}",
        "",
        "## gates",
        "",
    ]
    probes = probes or {}
    for gate in selected_gates:
        gate_id = _gate_id(gate.domain, gate.gate_category, gate.gate)
        lines.extend(
            [
                f"- domain: {gate.domain}",
                f"  gate_id: {gate_id}",
                f"  gate_category: {gate.gate_category}",
                f"  gate: {gate.gate}",
            ]
        )
        probe = probes.get(gate_id)
        if isinstance(probe, dict) and "rate" in probe:
            # Emit probe metadata so downstream loaders know this gate is
            # under an A/B probe and at what skip rate. Unregistered gates
            # stay silent (status omitted entirely).
            lines.append("  probe_status: active")
            lines.append(f"  probe_rate: {probe['rate']}")
        if gate.level:
            lines.append(f"  level: {gate.level}")
        if gate.evidence_unit and gate.evidence_count:
            lines.append(f"  {gate.evidence_unit}: {gate.evidence_count}")
    # Preserve federation-inherited blocks across re-exports. Without
    # this, every export wipes blocks containing `derived_from:` (written
    # by gates_inherit) and refresh._inherited_gates() sees them as gone,
    # queueing federated gates for retirement instead of demotion.
    # Inherited blocks intentionally live outside --max-domains: they were
    # chosen by a sibling repo's operator, not by this repo's report.
    if preserve_blocks:
        for block in preserve_blocks:
            # Split on "\r?\n" so a block captured from a CRLF file does not
            # leak stray "\r" characters into the LF-joined output.
            lines.extend(re.split(r"\r?\n", block.rstrip()))
    lines.append("")
    return "\n".join(lines)


def main(argv: list[str] | None = None) -> int:
    """CLI entry point: validate the report, export its gates, print a summary.

    Returns 0 on success. Returns 1 when the report fails validation (errors
    are printed to stderr) or contains no agent_compensation gates. The
    success message counts the gates actually rendered after the
    --max-domains cap, not every gate parsed from the report.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--report", required=True, help="Validated agent-learning report.md to export from.")
    parser.add_argument("--output", required=True, help="Markdown registry output path.")
    parser.add_argument("--max-domains", type=int, default=None, help="Maximum number of domains to export.")
    parser.add_argument(
        "--probes",
        default=None,
        help="Optional path to causal_probe probes.json; emits probe_status/probe_rate "
        "lines for any gate_id registered there.",
    )
    args = parser.parse_args(argv)

    if args.max_domains is not None and args.max_domains < 1:
        parser.error("--max-domains must be at least 1")

    report_path = pathlib.Path(args.report).resolve()
    output_path = pathlib.Path(args.output)
    report = report_path.read_text(encoding="utf-8")

    errors = validate(report)
    if errors:
        for error in errors:
            print(error, file=sys.stderr)
        return 1

    gates = parse_gates(report)
    if not gates:
        print("no agent_compensation gates found", file=sys.stderr)
        return 1

    probes_path = pathlib.Path(args.probes) if args.probes else None
    probes = _load_probes(probes_path)

    # Apply the --max-domains cap once. The same selection drives the
    # inherited-block dedup below and the count reported to the operator.
    # Previously the message printed len(gates) even when gates were capped.
    selected = _select_for_render(gates, args.max_domains)

    # Read + preserve + write under one sidecar lock so a concurrent
    # gates_inherit cannot land an inherited block between the read and
    # the write. Without the lock, export could render from a pre-inherit
    # snapshot and silently overwrite the freshly-appended block.
    # gates_inherit uses the same `atomic_rewrite` on the same path, so the
    # lockfile `<output>.lock` is the shared mutex.
    with atomic_rewrite(output_path) as (existing_text, commit):
        inherited = preserved_inherited_blocks(existing_text)
        new_gate_ids = {
            _gate_id(g.domain, g.gate_category, g.gate) for g in selected
        }
        # A locally re-rendered gate_id wins over an inherited copy.
        preserve = [block for gid, block in inherited.items() if gid not in new_gate_ids]
        commit(render_registry(
            report_path,
            gates,
            args.max_domains,
            probes=probes,
            preserve_blocks=preserve,
        ))
    print(f"exported {len(selected)} gates to {output_path}")
    return 0


if __name__ == "__main__":
    # Propagate main()'s integer return code as the process exit status.
    raise SystemExit(main())
