#!/usr/bin/env python3
"""Evaluate recent recommendations with an archived judge agent."""

from __future__ import annotations

import argparse
import contextlib
import datetime as dt
import json
import os
import pathlib
import sys
from typing import Any

THIS_DIR = pathlib.Path(__file__).resolve().parent
if str(THIS_DIR) not in sys.path:
    sys.path.insert(0, str(THIS_DIR))

try:
    from alc_invoke import invoke
    from alc_query import get_recommendations
    from event_schema import EventV4
    from event_writer import write_event
    from state_handle import StateHandle
except Exception:  # pragma: no cover
    from bin.alc_invoke import invoke
    from bin.alc_query import get_recommendations
    from bin.event_schema import EventV4
    from bin.event_writer import write_event
    from bin.state_handle import StateHandle


# Verdict labels accepted from the judge agent. _parse_verdict falls back to
# "modify" for any judge output that does not carry one of these values.
VERDICTS = {"approve", "reject", "modify"}


@contextlib.contextmanager
def _event_writer_state(state: StateHandle):
    """Temporarily point ``AGENT_LEARNING_STATE_DIR`` at ``state.repo_state_dir``.

    The variable's prior value (or its absence) is captured on entry and put
    back on exit, including when the managed body raises.
    """
    # Per the original author's note: event_writer resolves
    # AGENT_LEARNING_STATE_DIR to <state_dir>/events.jsonl, and
    # StateHandle.events_jsonl is repo_state_dir/events.jsonl. Using
    # repo_state_dir (rather than state_root) is what lands the event in the
    # indexed surface that alc_query/index_events read.
    env_key = "AGENT_LEARNING_STATE_DIR"
    saved = os.environ.get(env_key)
    os.environ[env_key] = str(state.repo_state_dir)
    try:
        yield
    finally:
        if saved is not None:
            os.environ[env_key] = saved
        else:
            # The variable was unset before we touched it; leave it unset.
            os.environ.pop(env_key, None)


def _parse_window(value: str) -> dt.datetime:
    """Convert a lookback window such as ``7d`` or ``2w`` into a UTC cutoff.

    Args:
        value: A non-negative integer followed by ``d`` (days) or ``w``
            (weeks). The unit letter is matched case-insensitively.

    Returns:
        The timezone-aware UTC instant that many days before the current time.

    Raises:
        ValueError: If *value* does not have the ``<int><d|w>`` shape, or if
            the window reaches further back than ``datetime`` can represent.
    """
    digits, unit = value[:-1], value[-1:].lower()
    # isdecimal() (unlike isdigit()) accepts exactly the characters int() can
    # convert, so inputs such as "²d" get the usage message instead of an
    # unrelated "invalid literal for int()" error.
    if unit not in {"d", "w"} or not digits.isdecimal():
        raise ValueError("--window must look like 7d or 2w")
    try:
        days = int(digits) * (7 if unit == "w" else 1)
        return dt.datetime.now(dt.timezone.utc) - dt.timedelta(days=days)
    except (OverflowError, ValueError) as exc:
        # Huge windows overflow timedelta/datetime with OverflowError; surface
        # them as ValueError so callers that handle bad --window input (main)
        # report a usage error instead of a traceback.
        raise ValueError(f"--window {value!r} is out of range") from exc


def _resolve_judge(state: StateHandle, judge_ref: str) -> pathlib.Path | None:
    """Locate the judge agent's markdown file under the ``evals`` agents dir.

    A leading ``evals/`` component in *judge_ref* is dropped (when something
    follows it), the remainder is joined onto ``state.alc_agents_dirs["evals"]``
    and the suffix is forced to ``.md``.

    Returns:
        The resolved path when it is an existing file, otherwise ``None``.
    """
    # NOTE: str.lstrip takes a character set, so every leading "." and "/" is
    # removed here, not only a literal "./" prefix.
    relative = pathlib.Path(judge_ref.lstrip("./"))
    parts = relative.parts
    if len(parts) >= 2 and parts[0] == "evals":
        relative = pathlib.Path(*parts[1:])

    resolved = state.alc_agents_dirs["evals"] / relative
    if resolved.suffix != ".md":
        # with_suffix replaces any existing non-.md suffix rather than appending.
        resolved = resolved.with_suffix(".md")
    if resolved.is_file():
        return resolved
    return None


def _rec_id(rec: dict[str, Any]) -> str:
    """Return a string identifier for a recommendation row.

    The first truthy value among ``rec_id``, ``recommendation_id``, ``id`` and
    ``patch_id`` (checked in that order) wins. When none is present, an id is
    derived via ``EventV4.deterministic_id`` from the row's key-sorted JSON.
    """
    id_fields = ("rec_id", "recommendation_id", "id", "patch_id")
    explicit = next((rec.get(name) for name in id_fields if rec.get(name)), None)
    if explicit is not None:
        return str(explicit)
    canonical = json.dumps(rec, sort_keys=True)
    return EventV4.deterministic_id("recommendation", "synthetic", canonical)


def _rec_ts(rec: dict[str, Any]) -> dt.datetime | None:
    value = rec.get("ts") or rec.get("created_at") or rec.get("timestamp")
    if not value:
        return None
    try:
        return dt.datetime.fromisoformat(str(value).replace("Z", "+00:00"))
    except ValueError:
        return None


def _recent_recommendations(state: StateHandle, window: str, limit: int) -> list[dict[str, Any]]:
    """Collect recommendations that fall inside the lookback *window*.

    Rows whose timestamp is missing or unparseable are kept; dated rows are
    kept only when they are at or after the cutoff. A *limit* of 0 returns an
    empty list without parsing *window* or querying; a negative *limit*
    returns every match; otherwise the first *limit* matches are returned.
    """
    if limit == 0:
        return []
    cutoff = _parse_window(window)

    def _in_window(rec: dict[str, Any]) -> bool:
        stamp = _rec_ts(rec)
        return stamp is None or stamp >= cutoff

    selected = [rec for rec in get_recommendations(state) if _in_window(rec)]
    if limit < 0:
        return selected
    return selected[:limit]


def _existing_event_ids(state: StateHandle) -> set[str]:
    """Return the ``event_id`` of every parseable row in ``state.events_jsonl``.

    Used for dedup: lines that are not valid JSON, are not JSON objects, or
    carry no truthy ``event_id`` are skipped. A missing file yields an empty
    set.
    """
    path = state.events_jsonl
    if not path.is_file():
        return set()
    ids: set[str] = set()
    # Iterate the file handle instead of read_text().splitlines():
    # str.splitlines() also breaks on characters such as U+2028/U+2029/U+0085,
    # which may appear unescaped inside JSON strings. That would cut a valid
    # row in two and silently drop its event_id from the dedup set. Streaming
    # also avoids holding the whole log in memory.
    with path.open(encoding="utf-8") as handle:
        for line in handle:
            try:
                value = json.loads(line)
            except json.JSONDecodeError:
                continue
            # JSONL rows are normally dicts, but a stray null/list/number/string
            # line shouldn't crash dedup - without isinstance, .get on those
            # raises AttributeError and the whole eval batch aborts.
            if not isinstance(value, dict):
                continue
            event_id = value.get("event_id")
            if event_id:
                ids.add(str(event_id))
    return ids


def _parse_verdict(output: str) -> tuple[str, str, bool]:
    """Decode the judge's JSON reply into ``(verdict, reason, malformed)``.

    A usable reply is a JSON object whose ``verdict`` (trimmed, lower-cased)
    is one of ``VERDICTS``. The reason comes from ``judge_reason`` or
    ``reason``, is trimmed and capped at 180 characters, and is replaced by a
    placeholder when empty. Any other reply - invalid JSON, a non-object, an
    unknown verdict - yields the ``modify`` fallback with ``malformed=True``.
    """
    fallback = ("modify", "Judge returned malformed JSON; defaulted to modify.", True)
    try:
        data = json.loads(output)
        label = str(data.get("verdict", "")).strip().lower()
        if label not in VERDICTS:
            return fallback
        why = str(data.get("judge_reason") or data.get("reason") or "").strip()
    except Exception:
        # Deliberately broad: any failure while decoding or reading the
        # reply is treated as a malformed judge response.
        return fallback
    return label, why[:180] or "No judge reason supplied.", False


def _task_for(rec: dict[str, Any]) -> str:
    """Serialise *rec* as the judge task: compact, key-sorted JSON.

    The recommendation is wrapped under a single ``evaluate_recommendation``
    key.
    """
    envelope = {"evaluate_recommendation": rec}
    return json.dumps(envelope, separators=(",", ":"), sort_keys=True)


def run(*, repo: pathlib.Path, window: str, limit: int, judge: str) -> tuple[int, int]:
    """Judge recent recommendations and record one ``eval_verdict`` event each.

    Args:
        repo: Repository whose state is loaded via ``StateHandle.for_repo``.
        window: Lookback window string handed to ``_recent_recommendations``.
        limit: Maximum number of recommendations to consider.
        judge: Judge agent reference; checked with ``_resolve_judge`` and
            passed to ``invoke`` as ``agent_ref``.

    Returns:
        ``(exit_code, events_written)``. The exit code is 1 (with nothing
        written) when the judge agent cannot be resolved, otherwise 0.
    """
    state = StateHandle.for_repo(repo)
    if _resolve_judge(state, judge) is None:
        print(f"judge agent not found: {judge}", file=sys.stderr)
        return 1, 0

    seen_ids = _existing_event_ids(state)
    emitted = 0
    for rec in _recent_recommendations(state, window, limit):
        rec_id = _rec_id(rec)
        # Only "ts"/"created_at" feed the event id (not "timestamp").
        stamp = str(rec.get("ts") or rec.get("created_at") or "")
        event_id = EventV4.deterministic_id("eval_judge", "eval_verdict", f"{rec_id}:{stamp}")
        if event_id in seen_ids:
            # Already judged on a previous run (or earlier in this batch).
            continue

        reply = invoke(repo=repo, agent_ref=judge, task=_task_for(rec))
        raw_output = str(reply.get("output", ""))
        verdict, reason, malformed = _parse_verdict(raw_output)
        if malformed:
            print(f"warning: malformed judge JSON for {rec_id}; using modify", file=sys.stderr)

        kind = str(rec.get("kind") or rec.get("recommendation_kind") or "unknown")
        payload = {
            "verdict": verdict,
            "judge_reason": reason,
            "recommendation_kind": kind,
        }
        event = {
            "event_id": event_id,
            "event": "eval_verdict",
            "actor": {"kind": "eval_judge", "name": "rec-quality-judge"},
            "correlation_chain": [{"role": "evaluated_rec", "id": rec_id}],
            "payload": payload,
        }
        with _event_writer_state(state):
            write_event(event, source="eval", auto_id_fallback=False)
        seen_ids.add(event_id)
        emitted += 1
    return 0, emitted


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--window", required=True, help="Recommendation lookback window, e.g. 7d")
    parser.add_argument("--limit", type=int, default=20, help="Maximum recommendations to evaluate")
    parser.add_argument("--judge", default="evals/rec-quality-judge", help="Archive-relative judge agent")
    return parser.parse_args(argv)


def main(argv: list[str] | None = None) -> int:
    """CLI entry point; returns the process exit status.

    Returns 2 for a negative ``--limit`` or when ``run`` raises ``ValueError``
    (the message goes to stderr). Otherwise returns ``run``'s exit code,
    printing a JSON summary of the number of events written when it is 0.
    """
    options = parse_args(argv)
    if options.limit < 0:
        print("--limit must be >=0", file=sys.stderr)
        return 2

    repo_root = pathlib.Path.cwd().resolve()
    try:
        code, written = run(
            repo=repo_root,
            window=options.window,
            limit=options.limit,
            judge=options.judge,
        )
    except ValueError as exc:
        print(str(exc), file=sys.stderr)
        return 2

    if code == 0:
        summary = {"eval_verdict_events": written}
        print(json.dumps(summary, sort_keys=True))
    return code


if __name__ == "__main__":
    # Script entry: use main()'s integer return value as the process exit status.
    raise SystemExit(main())
