#!/usr/bin/env python3
"""Refresh compact agent-learning state and queue candidate skill improvements."""

from __future__ import annotations

import argparse
import contextlib
import datetime as dt
import fcntl
import hashlib
import json
import os
import pathlib
import re
import sys
import time
from typing import Any

from build_repo_baseline import build as build_baseline
from collect_hook_event import assert_regular_file_destination
from evaluate_gate_effectiveness import evaluate as evaluate_gates
from evaluate_gate_effectiveness import load_sessions
from evaluate_skill_impact import evaluate as evaluate_impact
from export_skill_context import write_context
from extract_skill_usage import build_usage, read_events
from map_active_skills import build_map
from propose_domain_rules import parse_chunks, score_terms
from queue_dedup import find_duplicates, order_by_keep
from scrub_secrets import scrub
from state_paths import atomic_rewrite, atomic_write_text, repo_state_dir, resolve_state_dir


def write_json(path: pathlib.Path, payload: dict[str, Any]) -> pathlib.Path:
    """Serialize ``payload`` as indented, key-sorted JSON and write it to ``path``.

    The text (with a trailing newline) is handed to ``atomic_write_text``.
    Returns ``path`` so callers can collect the list of touched files.
    """
    document = json.dumps(payload, sort_keys=True, indent=2)
    atomic_write_text(path, f"{document}\n")
    return path


def stable_queue_id(row: dict[str, Any]) -> str:
    """Return a deterministic 16-hex-character id for a skill-impact row.

    The id is the SHA-256 prefix of ``skill|impact_signal|candidate_adjustment``.
    Missing keys contribute an empty string, so the same candidate always maps
    to the same queue id across refreshes.
    """
    id_fields = ("skill", "impact_signal", "candidate_adjustment")
    joined = "|".join(str(row.get(field, "")) for field in id_fields)
    digest = hashlib.sha256(joined.encode("utf-8"))
    return digest.hexdigest()[:16]


def queue_candidate_adjustments(path: pathlib.Path, impact: dict[str, Any]) -> dict[str, int]:
    """Append one ``candidate`` row per actionable skill-impact signal.

    Rows are keyed by ``stable_queue_id``. A candidate whose id is already in
    the queue (from this or an earlier refresh) is never appended twice.

    Skipped rows:
      * ``impact_signal == "needs_review"`` -- counted in
        ``suppressed_needs_review``;
      * an empty ``candidate_adjustment`` or a ``correlated_with_success``
        signal -- skipped without counting;
      * rows whose serialized JSON contains ``[REDACTED`` after ``scrub`` --
        counted in ``suppressed_redacted``, so secret-like content never
        reaches the queue.

    Args:
        path: JSONL queue file; it and its parent directories are created if
            missing.
        impact: skill-impact report; only its ``skills`` list is read.

    Returns:
        Counts under ``queued``, ``suppressed_needs_review`` and
        ``suppressed_redacted``.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    assert_regular_file_destination(path, label="Improvement queue")
    queued = 0
    suppressed_needs_review = 0
    suppressed_redacted = 0
    now = dt.datetime.now(dt.timezone.utc).isoformat()
    # Open + flock the queue file so concurrent refreshes cannot interleave
    # writes or each see a stale view of existing ids. Ids are re-read AFTER
    # acquiring the exclusive lock, so a sibling's earlier append is visible.
    with path.open("a+", encoding="utf-8") as handle:
        fcntl.flock(handle.fileno(), fcntl.LOCK_EX)
        try:
            handle.seek(0)
            seen: set[str] = set()
            for line in handle.read().splitlines():
                if not line.strip():
                    continue
                try:
                    payload = json.loads(line)
                except json.JSONDecodeError:
                    continue
                # A valid-JSON but non-object line (e.g. a bare list) would
                # otherwise crash on .get() and abort the whole refresh.
                if isinstance(payload, dict) and payload.get("id"):
                    seen.add(str(payload["id"]))
            handle.seek(0, 2)  # back to end for appends
            for row in impact.get("skills", []):
                signal = str(row.get("impact_signal", ""))
                candidate = str(row.get("candidate_adjustment", "")).strip()
                if signal == "needs_review":
                    suppressed_needs_review += 1
                    continue
                if not candidate or signal == "correlated_with_success":
                    continue
                row_id = stable_queue_id(row)
                if row_id in seen:
                    continue
                item = {
                    "schema_version": 1,
                    "id": row_id,
                    "queued_at": now,
                    "status": "candidate",
                    "source": "refresh_learning_state.py",
                    "skill": row.get("skill"),
                    "impact_signal": signal,
                    "confidence": row.get("confidence"),
                    "candidate_adjustment": candidate,
                    "evidence": {
                        "expected_sessions": row.get("expected_sessions", 0),
                        "loaded_sessions": row.get("loaded_sessions", 0),
                        "missed_sessions": row.get("missed_sessions", 0),
                        "corrections_after_loaded": row.get("corrections_after_loaded", 0),
                        "corrections_after_missed": row.get("corrections_after_missed", 0),
                    },
                }
                rendered = scrub(json.dumps(item, sort_keys=True, separators=(",", ":")))
                if "[REDACTED" in rendered:
                    suppressed_redacted += 1
                    continue
                handle.write(rendered + "\n")
                seen.add(row_id)
                queued += 1
        finally:
            # Push buffered rows to the OS BEFORE releasing the lock. This
            # finally runs before the `with` closes (and flushes) the file, so
            # without an explicit flush a sibling could take the lock, read a
            # queue missing our rows, and re-append the same ids.
            try:
                handle.flush()
            finally:
                fcntl.flock(handle.fileno(), fcntl.LOCK_UN)
    return {
        "queued": queued,
        "suppressed_needs_review": suppressed_needs_review,
        "suppressed_redacted": suppressed_redacted,
    }


def _post_dedup(
    queue_path: pathlib.Path,
    backend: str = "trigram",
    threshold: float = 0.80,
) -> int:
    """Dedup queue in place using trigram-Dice.

    Called after queue_candidate_adjustments so duplicate near-paraphrases
    from prior runs are collapsed in-process. Holds the sidecar lock for
    queue_path across read+compute+commit so a sibling refresh cannot
    interleave appends, and writes the new content via tmp+fsync+rename
    so SIGKILL mid-write leaves the queue at its prior content rather
    than truncating it.

    Persisted queue rows written by refresh use ``candidate_adjustment``
    rather than ``text``, so we synthesize a view-row carrying ``text``
    for the dedup algorithm without mutating what we write back.
    """
    if not queue_path.exists():
        return 0
    removed = 0
    with atomic_rewrite(queue_path) as (current_text, commit):
        lines = [ln for ln in current_text.splitlines() if ln]
        rows = []
        for ln in lines:
            try:
                rows.append(json.loads(ln))
            except json.JSONDecodeError:
                continue
        if len(rows) < 2:
            return 0
        view = [
            dict(r, text=r.get("text") or str(r.get("candidate_adjustment", "")))
            for r in rows
        ]
        priority = order_by_keep(view, "oldest")
        reordered_view = [view[i] for i in priority]
        reordered_rows = [rows[i] for i in priority]
        drop = find_duplicates(reordered_view, backend, threshold)
        if not drop:
            return 0
        kept = [r for i, r in enumerate(reordered_rows) if i not in drop]
        commit("\n".join(json.dumps(r, sort_keys=True) for r in kept) + "\n")
        removed = len(drop)
    return removed


def _inherited_gates(gates_md_path: pathlib.Path) -> dict[str, str]:
    """Return dict gate_id -> derived_from value for gates with a derived_from line.

    Parses ``latest-approved-gates.md`` block-by-block; a block is considered
    inherited when it carries both a ``gate_id:`` and a ``derived_from:`` line.
    Returns an empty dict if the file is missing.
    """
    if not gates_md_path.is_file():
        return {}
    text = gates_md_path.read_text(encoding="utf-8")
    inherited: dict[str, str] = {}
    # CRLF-tolerant splitter (matches the pattern in gates_inherit and
    # export_gates). The previous text.split("\n- domain:") returned the
    # whole file as one block on CRLF files or files that start directly
    # with "- domain:", and the per-block walk below would cross block
    # boundaries and report a stale derived_from.
    blocks = re.split(r"(?m)^-\s+domain:\s*", text)
    for block in blocks[1:]:
        gate_id = None
        derived_from = None
        for raw_line in block.splitlines():
            line = raw_line.strip()
            if line.startswith("gate_id:"):
                gate_id = line.split(":", 1)[1].strip()
            elif line.startswith("derived_from:"):
                derived_from = line.split(":", 1)[1].strip()
        if not gate_id:
            continue
        # Fail closed: a gate_id without derived_from is treated as
        # inherited-with-unknown-origin. Pre-fix this block was silently
        # dropped from inherited_map, and the retirement filter then
        # queued it as gate_retirement_candidate instead of
        # inherited_gate_demote_candidate -- an operator acting on the
        # queue would retire a gate that may affect sibling repos. The
        # "unknown" sentinel makes the missing provenance visible in the
        # queue row text.
        inherited[gate_id] = derived_from or "unknown"
    return inherited


# causal_signal values that allow _queue_retirement_candidates to queue a gate
# for retirement or demotion. needs_review and causal_correlated_with_success
# are deliberately absent: no probe data, or evidence that the gate helps.
_RETIRE_CAUSAL_OK = ("causal_correlated_with_failure", "causal_no_signal")


def _retirement_row_id(gate_id: str, kind: str) -> str:
    """Stable per-(gate_id, kind) row id. Pre-fix the row id embedded the
    refresh's wall-clock second and a within-batch counter, so two refreshes
    one second apart produced different ids for the same candidate and the
    downstream trigram dedup could not always collapse the drift (delta
    varies between refreshes, the text differs, trigram-Dice can drop below
    the 0.80 threshold). Stable ids let queue_candidate_adjustments-style
    `seen` membership checks prevent duplicates outright.
    """
    return hashlib.sha256(f"{gate_id}|{kind}".encode("utf-8")).hexdigest()[:16]


def _queue_retirement_candidates(
    queue_path: pathlib.Path,
    events_path: pathlib.Path,
    min_n_retire: int = 20,
    inherited: dict[str, str] | None = None,
) -> tuple[int, int]:
    """Run effectiveness scoring; append retirement / demote candidates.

    A candidate gate is one where:
      - label is correlated_with_failure OR no_signal
      - n_loaded >= min_n_retire AND n_absent >= min_n_retire (stricter
        than the standalone scorer's min-N=10 because retirement is more
        disruptive; M5 closes the asymmetry where n_absent could be as
        low as 10 while n_loaded was capped at min_n_retire).
      - causal_signal is causal_correlated_with_failure OR causal_no_signal.
        H2 requires causal evidence before the action that actually
        disrupts production. needs_review (no probe data) and
        causal_correlated_with_success are excluded -- if the causal probe
        says loading the gate helps, retiring it would be the wrong move.

    For each candidate, if the gate appears in ``inherited`` (i.e. carries a
    ``derived_from:`` line in approved-gates.md), queue an
    ``inherited_gate_demote_candidate`` row carrying the derived_from value.
    Otherwise queue the original ``gate_retirement_candidate``. Each
    candidate produces exactly one row (demote OR retire, never both), and
    re-running across refreshes produces idempotent rows keyed on the
    stable (gate_id, kind) hash -- not the wall-clock timestamp.

    Returns (retirement_count, demote_count).
    """
    if not events_path.is_file():
        return (0, 0)
    inherited_map = inherited or {}
    sessions = load_sessions(events_path)
    result = evaluate_gates(sessions, min_n=10)
    candidates = [
        row for row in result["gates"]
        if row["label"] in ("correlated_with_failure", "no_signal")
        and row["n_loaded"] >= min_n_retire
        and row["n_absent"] >= min_n_retire
        and row.get("causal_signal") in _RETIRE_CAUSAL_OK
    ]
    if not candidates:
        return (0, 0)
    retire_count = 0
    demote_count = 0
    ts = dt.datetime.now(dt.timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
    queue_path.parent.mkdir(parents=True, exist_ok=True)
    assert_regular_file_destination(queue_path, label="Improvement queue")
    queue_path.touch(exist_ok=True)
    # Open r+ so we can read the existing rows for `seen` membership AND
    # append new ones, all under the same LOCK_EX. Matches the pattern in
    # queue_candidate_adjustments.
    with queue_path.open("r+", encoding="utf-8") as handle:
        fcntl.flock(handle.fileno(), fcntl.LOCK_EX)
        try:
            handle.seek(0)
            seen: set[str] = set()
            for line in handle.read().splitlines():
                if not line.strip():
                    continue
                try:
                    payload = json.loads(line)
                except json.JSONDecodeError:
                    continue
                if payload.get("id"):
                    seen.add(str(payload["id"]))
            handle.seek(0, 2)  # to end for appends
            for c in candidates:
                gate_id = c["gate_id"]
                if gate_id in inherited_map:
                    kind = "inherited_gate_demote_candidate"
                else:
                    kind = "gate_retirement_candidate"
                row_id = _retirement_row_id(gate_id, kind)
                if row_id in seen:
                    continue
                if kind == "inherited_gate_demote_candidate":
                    derived_from = inherited_map[gate_id]
                    row = {
                        "id": row_id,
                        "kind": kind,
                        "text": (
                            f"Demote inherited gate {gate_id} (origin={derived_from}): "
                            f"delta={c['delta']:.3f} after n_loaded={c['n_loaded']}, "
                            f"label={c['label']}."
                        ),
                        "gate_id": gate_id,
                        "derived_from": derived_from,
                        "evidence": {
                            "n_loaded": c["n_loaded"],
                            "n_absent": c["n_absent"],
                            "delta": c["delta"],
                            "label": c["label"],
                            "causal_signal": c.get("causal_signal"),
                        },
                        "ts": ts,
                    }
                    handle.write(json.dumps(row, sort_keys=True) + "\n")
                    seen.add(row_id)
                    demote_count += 1
                else:
                    row = {
                        "id": row_id,
                        "kind": kind,
                        "text": (
                            f"Retire low-impact gate {gate_id}: "
                            f"delta={c['delta']:.3f} after n_loaded={c['n_loaded']}, "
                            f"label={c['label']}."
                        ),
                        "gate_id": gate_id,
                        "evidence": {
                            "n_loaded": c["n_loaded"],
                            "n_absent": c["n_absent"],
                            "delta": c["delta"],
                            "label": c["label"],
                            "causal_signal": c.get("causal_signal"),
                        },
                        "ts": ts,
                    }
                    handle.write(json.dumps(row, sort_keys=True) + "\n")
                    seen.add(row_id)
                    retire_count += 1
        finally:
            fcntl.flock(handle.fileno(), fcntl.LOCK_UN)
    return (retire_count, demote_count)


def _queue_domain_rule_candidates(
    queue_path: pathlib.Path,
    corpus_path: pathlib.Path,
    top_k: int = 10,
    min_score: float = 2.0,
) -> int:
    """Mine corpus for correction-correlated terms; queue them as domain_rule_candidate.

    Returns count appended. Corpus may not exist if extract_sessions hasn't run yet.
    """
    if not corpus_path.is_file():
        return 0
    chunks = list(parse_chunks(corpus_path))
    if not chunks:
        return 0
    scores = score_terms(chunks)
    scores = [s for s in scores if s[3] >= min_score]
    scores.sort(key=lambda x: (-x[3], x[0]))
    scores = scores[:top_k]
    if not scores:
        return 0
    ts = dt.datetime.now(dt.timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
    now_unix = int(time.time())
    queue_path.parent.mkdir(parents=True, exist_ok=True)
    assert_regular_file_destination(queue_path, label="Improvement queue")
    with queue_path.open("a", encoding="utf-8") as fh:
        fcntl.flock(fh.fileno(), fcntl.LOCK_EX)
        try:
            for i, (term, cc, cl, score) in enumerate(scores):
                row = {
                    "id": f"domain-{term.replace(' ', '_')}-{now_unix}-{i}",
                    "kind": "domain_rule_candidate",
                    "text": (
                        f"Consider adding domain seed for '{term}' "
                        f"(correction_count={cc}, clean_count={cl}, score={score:.2f})."
                    ),
                    "term": term,
                    "evidence": {
                        "correction_count": cc,
                        "clean_count": cl,
                        "score": score,
                    },
                    "ts": ts,
                }
                fh.write(json.dumps(row, sort_keys=True) + "\n")
        finally:
            fcntl.flock(fh.fileno(), fcntl.LOCK_UN)
    return len(scores)


def has_event_rows(path: pathlib.Path) -> bool:
    """Return True when ``path`` exists and contains at least one non-blank line.

    Every line separator recognised by ``str.splitlines`` is also whitespace,
    so "some non-blank line" is the same as "some non-whitespace character".
    """
    if path.exists():
        return bool(path.read_text(encoding="utf-8").strip())
    return False


def refresh(
    repo: pathlib.Path,
    state_dir: str | pathlib.Path | None = None,
    personal: str | pathlib.Path | None = None,
    events: pathlib.Path | None = None,
    queue: pathlib.Path | None = None,
    corpus: pathlib.Path | None = None,
) -> dict[str, Any]:
    """Rebuild the learning state for ``repo`` and return a summary dict.

    The real work is done by ``_refresh_locked``. It runs while this process
    holds an exclusive ``flock`` on ``<repo_state>/.refresh.lock``, so
    concurrent refreshes of the same repo are serialized end-to-end.
    Per-file locks alone would release between writes and let a sibling
    interleave, e.g. its baseline.json with this run's skill-map.json. The
    sidecar is never renamed, so its lock stays a stable mutex.

    Best-effort ``refresh_start`` / ``refresh_end`` telemetry is emitted via
    ``event_emit``; any failure there (including a missing module) is ignored.
    ``refresh_end`` carries ``duration_ms``, which includes time spent waiting
    for the lock, and is only emitted when the refresh succeeds.

    Raises:
        Whatever ``_refresh_locked`` raises, after the lock is released and
        the lock fd closed.
    """
    repo = repo.expanduser().resolve()
    repo_state = repo_state_dir(repo, state_dir, personal)
    reports = repo_state / "reports"
    for directory in (repo_state, reports):
        directory.mkdir(parents=True, exist_ok=True)

    start_event_id: str | None = None
    with contextlib.suppress(Exception):
        from event_emit import event_emit as emit_background_event

        start_event_id = emit_background_event(
            kind="refresh_start",
            actor_name="refresh_learning_state",
            actor_kind="background_agent",
        )

    started_at = time.time()

    lock_fd = os.open(str(repo_state / ".refresh.lock"), os.O_RDWR | os.O_CREAT, 0o644)
    # The outer finally closes the fd even if flock itself raises; the inner
    # finally releases the lock before any exception propagates.
    try:
        fcntl.flock(lock_fd, fcntl.LOCK_EX)
        try:
            result = _refresh_locked(
                repo,
                repo_state,
                reports,
                state_dir,
                personal,
                events,
                queue,
                corpus,
            )
        finally:
            with contextlib.suppress(OSError):
                fcntl.flock(lock_fd, fcntl.LOCK_UN)
    finally:
        os.close(lock_fd)

    duration_ms = int((time.time() - started_at) * 1000)
    with contextlib.suppress(Exception):
        from event_emit import event_emit as emit_background_event

        emit_background_event(
            kind="refresh_end",
            actor_name="refresh_learning_state",
            actor_kind="background_agent",
            parent_event_id=start_event_id,
            payload={"telemetry": {"duration_ms": duration_ms}},
        )

    return result


def _refresh_locked(
    repo: pathlib.Path,
    repo_state: pathlib.Path,
    reports: pathlib.Path,
    state_dir: str | pathlib.Path | None,
    personal: str | pathlib.Path | None,
    events: pathlib.Path | None,
    queue: pathlib.Path | None,
    corpus: pathlib.Path | None,
) -> dict[str, Any]:
    """Rebuild every state artifact and extend the improvement queue.

    Called by ``refresh`` while it holds the ``.refresh.lock`` mutex. Steps:

    1. Pick the runtime from ``<state root>/config.json`` (``"auto"`` when
       the file, the key, or valid JSON is missing).
    2. Build the baseline and skill map. Derive usage/impact from the hook
       event log when it has rows; otherwise use empty reports.
    3. Write baseline/skill-map/skill-usage/skill-impact JSON and the skill
       context report.
    4. Queue skill adjustments, then dedup.
    5. Queue gate retirement/demote and domain-rule candidates, and dedup
       again if any were appended.
    6. Print non-zero counters to stderr and return a summary dict.
    """
    # 1. Runtime from the installer-written config; silently "auto" otherwise.
    try:
        config_path = resolve_state_dir(state_dir, personal, repo) / "config.json"
        cfg = (
            json.loads(config_path.read_text(encoding="utf-8"))
            if config_path.exists()
            else None
        )
    except (OSError, json.JSONDecodeError):
        cfg = None
    runtime = (
        cfg["runtime"]
        if isinstance(cfg, dict) and isinstance(cfg.get("runtime"), str)
        else "auto"
    )

    # 2. Core analyses.
    baseline = build_baseline(repo, runtime=runtime)
    skill_map = build_map(repo, runtime=runtime)
    event_log = events or repo_state / "hook-events.jsonl"
    event_log_present = has_event_rows(event_log)
    if not event_log_present:
        usage, impact = {}, {}
        print(
            "refresh: hook-events.jsonl is empty; skill-context will be empty",
            file=sys.stderr,
        )
    else:
        usage = build_usage(read_events(event_log), skill_map)
        impact = evaluate_impact(usage)

    # 3. Persist state artifacts (order matters only for the touched list).
    json_outputs = (
        ("baseline.json", baseline),
        ("skill-map.json", skill_map),
        ("skill-usage.json", usage),
        ("skill-impact.json", impact),
    )
    touched = [write_json(repo_state / name, data) for name, data in json_outputs]
    touched.append(write_context(reports / "latest-skill-context.md", skill_map, usage, impact))

    # 4. Skill-adjustment candidates; each helper takes its own queue lock.
    queue_path = queue or repo_state / "improvement-queue.jsonl"
    queue_stats = queue_candidate_adjustments(queue_path, impact)
    queue_path.touch(exist_ok=True)
    dedup_removed = _post_dedup(queue_path)

    # 5. Gate retirement / inherited-gate demotion, then domain-rule mining.
    # The corpus defaults to repo_state/"session-corpus.txt"; when absent the
    # domain helper is a no-op. extract_sessions does not yet emit the
    # [session=<id> outcome=<state>] headers the proposer parses, so this
    # wiring is forward-looking and degrades gracefully.
    # TODO: seed a corpus.txt with proper headers in an integration test and
    # assert domain_rule_candidate rows appear in the queue.
    inherited_map = _inherited_gates(reports / "latest-approved-gates.md")
    retirement_count, demote_count = _queue_retirement_candidates(
        queue_path, event_log, inherited=inherited_map
    )
    corpus_path = corpus or repo_state / "session-corpus.txt"
    domain_count = _queue_domain_rule_candidates(queue_path, corpus_path)
    if any((retirement_count, demote_count, domain_count)):
        dedup_removed += _post_dedup(queue_path)
    touched.append(queue_path)

    # 6. Report non-zero counters, in a fixed order.
    queued = queue_stats["queued"]
    suppressed = queue_stats["suppressed_needs_review"]
    suppressed_redacted = queue_stats["suppressed_redacted"]
    notices = (
        (suppressed, f"refresh: suppressed {suppressed} needs_review rows"),
        (
            suppressed_redacted,
            f"refresh: suppressed {suppressed_redacted} rows containing secret-like content",
        ),
        (dedup_removed, f"refresh: dedup_removed={dedup_removed} near-duplicate rows"),
        (
            retirement_count,
            f"refresh: queued {retirement_count} gate_retirement_candidate row(s)",
        ),
        (
            demote_count,
            f"refresh: queued {demote_count} inherited_gate_demote_candidate row(s)",
        ),
        (domain_count, f"refresh: queued {domain_count} domain_rule_candidate row(s)"),
    )
    for count, message in notices:
        if count:
            print(message, file=sys.stderr)

    return {
        "repo": str(repo),
        "repo_state_dir": str(repo_state),
        "event_log": str(event_log),
        "event_log_present": event_log_present,
        "queued_candidates": queued,
        "suppressed_needs_review": suppressed,
        "suppressed_redacted": suppressed_redacted,
        "dedup_removed": dedup_removed,
        "retirement_candidates_queued": retirement_count,
        "inherited_demote_candidates_queued": demote_count,
        "domain_rule_candidates_queued": domain_count,
        "touched": [str(p) for p in touched],
    }


def main(argv: list[str] | None = None) -> int:
    """Command-line entry point.

    Parses flags, runs ``refresh``, and emits the summary as indented JSON to
    ``--output`` (if given) or stdout. Returns 0 on success. A ``ValueError``
    from ``refresh`` is printed to stderr and yields 1.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--repo", default=".")
    for flag in ("--state-dir", "--personal", "--events", "--queue", "--corpus", "--output"):
        parser.add_argument(flag)
    args = parser.parse_args(argv)

    def optional_path(value: str | None) -> pathlib.Path | None:
        # Empty or missing values mean "use refresh's default location".
        return pathlib.Path(value) if value else None

    try:
        summary = refresh(
            pathlib.Path(args.repo),
            args.state_dir,
            args.personal,
            optional_path(args.events),
            optional_path(args.queue),
            optional_path(args.corpus),
        )
    except ValueError as error:
        print(str(error), file=sys.stderr)
        return 1

    rendered = json.dumps(summary, indent=2, sort_keys=True) + "\n"
    if not args.output:
        sys.stdout.write(rendered)
        return 0
    pathlib.Path(args.output).write_text(rendered, encoding="utf-8")
    return 0


if __name__ == "__main__":
    # Script entry point: main()'s integer return value becomes the exit status.
    raise SystemExit(main())
