#!/usr/bin/env python3
"""Compute effectiveness signals per gate_id from hook events.

Reads hook events JSONL. Pairs instructions_loaded with session_end via
correlation_id. For each gate_id, computes correction_rate among sessions
that loaded the gate (cohort A) and among sessions that did not (cohort B).
delta = correction_rate(B) - correction_rate(A); positive delta means
loading the gate correlates with fewer corrections.

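The correlational ``label`` never claims causality. A separate
``causal_signal`` is derived only from per-gate probe load/skip decisions
(whose skip rate the operator controls) and uses its own minimum cohort
size of 5. Below --min-n in either correlational cohort, ``label`` is
needs_review.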
"""
from __future__ import annotations

import argparse
import json
import sys
from collections import defaultdict
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parent))
from collect_hook_event import assert_regular_file_destination  # noqa: E402


def _parse_event_line(line: str):
    """Decode one JSONL line into a v2+ event dict, or return None to skip it.

    Skips blank lines, undecodable JSON, JSON values that are not objects,
    and rows whose ``schema_version`` is missing (treated as 1), below 2, or
    not a number.
    """
    line = line.strip()
    if not line:
        return None
    try:
        row = json.loads(line)
    except ValueError:
        # JSONDecodeError subclasses ValueError; catching the base means any
        # decode failure skips the line instead of aborting the pass.
        return None
    if not isinstance(row, dict):
        return None
    version = row.get("schema_version", 1)
    # A non-numeric version would make the comparison raise TypeError.
    if not isinstance(version, (int, float)) or version < 2:
        return None
    return row


def _merge_instructions_loaded(session, row):
    """Fold one instructions_loaded row's gate ids and probe decisions into *session*."""
    loaded = row.get("gate_loaded_ids") or []
    if isinstance(loaded, list):
        # Non-string ids are dropped: mixed with strings they would also make
        # sorted() over the gate universe in evaluate() raise TypeError.
        session["gates"].update(gid for gid in loaded if isinstance(gid, str))
    probes = row.get("probe_decisions") or []
    if isinstance(probes, list):
        for entry in probes:
            if not isinstance(entry, dict):
                continue
            gid = entry.get("gate_id")
            decision = entry.get("decision")
            if isinstance(gid, str) and decision in ("load", "skip"):
                # Later rows overwrite earlier ones: last decision read wins.
                session["probe_decisions"][gid] = decision


def load_sessions(events_path: Path, *, include_rotated: bool = True):
    """Group v2+ hook events by ``correlation_id`` into per-session records.

    Returns a dict ``cid -> {"gates": set[str], "outcome": str | None,
    "probe_decisions": dict[gate_id, "load" | "skip"]}``. A session entry is
    created only by an ``instructions_loaded`` or ``session_end`` row.

    A session's per-gate probe decision (and its ``outcome``) is whichever
    value was read last. Files are read oldest-first so that "read last"
    means "most recent": rotated backups first, then the live file.

    Reads the live hook events file plus, when ``include_rotated`` is true,
    every ``<name>.*.bak`` sibling produced by ``collect_hook_event.rotate_if_needed``.
    Sessions are keyed by ``correlation_id``, which is stable across files,
    so a session that straddles a rotation (instructions_loaded in a .bak
    and session_end in the live file) still pairs up. Without this merge,
    every rotation throws away the cohort window down to whatever fits in
    ``DEFAULT_MAX_HOOK_EVENT_BYTES`` and gates with N>min_n yesterday flip
    to ``needs_review`` overnight.

    Per-line resilience (one bad line never aborts the scoring pass):
      - Malformed lines are skipped (matches the policy in
        ``refresh_learning_state.queue_candidate_adjustments``). This covers
        invalid JSON, invalid UTF-8 from a line torn mid-character, and JSON
        values that are not objects.
      - v1 rows (``schema_version < 2``) are skipped: they predate
        ``gate_loaded_ids`` and ``probe_decisions``, and counting them
        toward cohorts would silently inflate ``n_absent`` for every gate
        and bias delta toward correlated_with_failure during a v1->v2
        migration window. Rows with a non-numeric ``schema_version`` are
        skipped as malformed.
      - Rows with a missing/empty or list/object ``correlation_id`` are
        skipped; non-list ``gate_loaded_ids`` / ``probe_decisions`` and
        non-string gate ids are ignored.
    """
    sessions = defaultdict(lambda: {"gates": set(), "outcome": None, "probe_decisions": {}})
    # Oldest data first so newer events overwrite older ones.
    # NOTE(review): assumes the rotation suffix sorts chronologically
    # (oldest first) -- confirm against collect_hook_event.rotate_if_needed.
    # Either way, the live file is always the newest and is read last.
    paths = sorted(events_path.parent.glob(f"{events_path.name}.*.bak")) if include_rotated else []
    paths.append(events_path)
    for path in paths:
        # is_file(): a directory matching the glob must not crash open().
        if not path.is_file():
            continue
        # errors="replace": a line torn mid multi-byte character decodes to
        # U+FFFD, then fails JSON parsing and is skipped, instead of raising
        # UnicodeDecodeError and aborting the whole pass.
        with path.open(encoding="utf-8", errors="replace") as fh:
            for raw in fh:
                row = _parse_event_line(raw)
                if row is None:
                    continue
                cid = row.get("correlation_id")
                # JSON lists/objects are unhashable and cannot key the dict.
                if not cid or isinstance(cid, (list, dict)):
                    continue
                evt = row.get("event")
                if evt == "instructions_loaded":
                    _merge_instructions_loaded(sessions[cid], row)
                elif evt == "session_end":
                    sessions[cid]["outcome"] = row.get("outcome")
    return dict(sessions)


PROBE_COHORT_MIN_N = 5

# Delta thresholds shared by the correlational label and the probe-based
# causal_signal (delta = rate(absent/skipped) - rate(loaded)).
SUCCESS_DELTA_THRESHOLD = 0.20
FAILURE_DELTA_THRESHOLD = -0.10
# Rates are k/n floats, so a delta that is exactly on a threshold in exact
# arithmetic can land a few ULPs short (3/10 - 1/10 == 0.19999999999999998).
# Compare with this tolerance so boundary cases classify as intended.
# NOTE(review): safe while cohort sizes stay far below ~3e4 each, where
# distinct k/n deltas are always more than 1e-9 apart.
_DELTA_TOLERANCE = 1e-9


def _correction_rate(outcomes):
    """Return the fraction of *outcomes* equal to "correction", or None if empty."""
    if not outcomes:
        return None
    return sum(1 for o in outcomes if o == "correction") / len(outcomes)


def _classify_delta(delta, prefix=""):
    """Map a rate delta to ``<prefix>correlated_with_success|failure`` or ``<prefix>no_signal``."""
    if delta >= SUCCESS_DELTA_THRESHOLD - _DELTA_TOLERANCE:
        return prefix + "correlated_with_success"
    if delta <= FAILURE_DELTA_THRESHOLD + _DELTA_TOLERANCE:
        return prefix + "correlated_with_failure"
    return prefix + "no_signal"


def evaluate(sessions, min_n=10):
    """For each gate_id, compute cohort stats. Returns dict with 'gates' list.

    ``sessions`` is the mapping produced by ``load_sessions``. Each row in
    ``gates`` (sorted by gate_id) holds the loaded/absent cohort sizes and
    correction rates, ``delta = rate(absent) - rate(loaded)`` (None below
    ``min_n``), the correlational ``label`` and the probe-based
    ``causal_signal``. Only sessions with a truthy ``outcome`` are counted.

    Additionally splits by probe_decision when present, emitting a
    causal_signal field. Probe cohorts use a lower min-N (5) than the
    correlational cohorts (10): probes are deliberately N-bounded by the
    operator-controlled skip rate, so the same threshold would suppress
    causal signal indefinitely.

    Threshold comparisons allow a 1e-9 tolerance so that deltas exactly on
    +0.20 / -0.10 are not lost to float rounding.
    """
    all_gate_ids = set()
    for s in sessions.values():
        all_gate_ids.update(s["gates"])
        # Include gates that were probe-skipped this session too — they
        # belong to the gate's universe even though the session itself
        # never had the gate loaded.
        all_gate_ids.update(s.get("probe_decisions", {}).keys())

    # Only sessions with an outcome can be scored; filter once, not per gate.
    finished = [s for s in sessions.values() if s["outcome"]]

    rows = []
    for gid in sorted(all_gate_ids):
        loaded_outcomes = [s["outcome"] for s in finished if gid in s["gates"]]
        absent_outcomes = [s["outcome"] for s in finished if gid not in s["gates"]]

        # Probe cohorts: per-gate decisions, regardless of whether the gate
        # was loaded in the absence-of-probe case. The probe decision *is*
        # the contract — load=loaded, skip=not loaded.
        probe_loaded_outcomes = [
            s["outcome"] for s in finished
            if s.get("probe_decisions", {}).get(gid) == "load"
        ]
        probe_skipped_outcomes = [
            s["outcome"] for s in finished
            if s.get("probe_decisions", {}).get(gid) == "skip"
        ]

        a = _correction_rate(loaded_outcomes)
        b = _correction_rate(absent_outcomes)
        n_loaded = len(loaded_outcomes)
        n_absent = len(absent_outcomes)

        if n_loaded < min_n or n_absent < min_n or a is None or b is None:
            label = "needs_review"
            delta = None
        else:
            delta = b - a
            label = _classify_delta(delta)

        # Causal signal from probe cohorts. Uses a smaller min-N (5) because
        # probes are intentionally rate-limited; demanding 10 would gate the
        # signal indefinitely on low-traffic gates.
        probe_loaded_rate = _correction_rate(probe_loaded_outcomes)
        probe_skipped_rate = _correction_rate(probe_skipped_outcomes)

        if (len(probe_loaded_outcomes) < PROBE_COHORT_MIN_N
                or len(probe_skipped_outcomes) < PROBE_COHORT_MIN_N
                or probe_loaded_rate is None
                or probe_skipped_rate is None):
            causal_signal = "needs_review"
        else:
            causal_signal = _classify_delta(
                probe_skipped_rate - probe_loaded_rate, prefix="causal_"
            )

        rows.append({
            "gate_id": gid,
            "n_loaded": n_loaded,
            "n_absent": n_absent,
            "correction_rate_loaded": a,
            "correction_rate_absent": b,
            "delta": delta,
            "label": label,
            "causal_signal": causal_signal,
        })
    return {"gates": rows}


def parse_args(argv=None):
    """Parse the command-line options for this script.

    ``argv`` is passed straight to ``ArgumentParser.parse_args``; the default
    None means ``sys.argv[1:]``, so existing ``parse_args()`` callers are
    unaffected while tests can pass an explicit list.
    """
    p = argparse.ArgumentParser(
        description=__doc__,
        # Keep the module docstring's paragraph breaks in --help output
        # instead of reflowing it into a single block of text.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    p.add_argument(
        "--events", required=True, type=Path,
        help="Live hook events JSONL file; rotated <name>.*.bak siblings are read too.",
    )
    p.add_argument(
        "--output", required=True, type=Path,
        help="Destination path for the JSON effectiveness report.",
    )
    p.add_argument(
        "--min-n", type=int, default=10,
        help="Minimum scored sessions in each correlational cohort (default: %(default)s).",
    )
    return p.parse_args(argv)


def main():
    """Run the CLI: validate paths, score sessions, write the JSON report.

    Returns the process exit status: 0 on success, 2 when ``--events`` is
    not a regular file or ``--output`` is rejected by
    ``assert_regular_file_destination`` (the reason goes to stderr).
    """
    opts = parse_args()
    events_path, output_path = opts.events, opts.output

    # Guard clauses: bail out with status 2 before doing any work.
    if not events_path.is_file():
        print(f"events not a regular file: {events_path}", file=sys.stderr)
        return 2
    try:
        assert_regular_file_destination(output_path, label="Effectiveness output")
    except ValueError as err:
        print(str(err), file=sys.stderr)
        return 2

    report = evaluate(load_sessions(events_path), min_n=opts.min_n)
    payload = json.dumps(report, indent=2, sort_keys=True)
    output_path.write_text(payload, encoding="utf-8")
    return 0


if __name__ == "__main__":
    # Use main()'s return value (0 on success, 2 on bad paths) as the exit status.
    raise SystemExit(main())
