#!/usr/bin/env python3
"""Analyst score script: rank recommendation candidates from anomaly + correlation signals."""

from __future__ import annotations

import argparse
import datetime as dt
import json
import math
import pathlib
import sys
from typing import Any

from artifact_writer import write_artifact
from analyst_queries import (
    load_samples_rows,
    open_events_db,
    query_cache_hit_ratio,
    query_dag_parent_child_cost,
    query_gate_with_evidence,
    query_q1_longest_by_skill,
)
from state_handle import StateHandle


def _state_handle(repo: pathlib.Path | None, state: pathlib.Path | None):
    if repo:
        return StateHandle.for_repo(repo)
    if not state:
        raise ValueError("either --repo or --state required")

    class _State:
        pass

    handle = _State()
    handle.repo_state_dir = state.resolve()
    handle.events_sqlite = handle.repo_state_dir / "events.sqlite"
    handle.outcomes_json = handle.repo_state_dir / "outcomes.json"
    return handle


def _evidence_strength(sample_count: int) -> float:
    """Return the natural-log evidence weight for ``sample_count`` events.

    Counts of 1 or less (including 0 and negatives) yield 0.0, so a
    recommendation backed by at most one event contributes nothing to the
    multiplicative score built in ``_score_rows``.
    """
    if sample_count <= 1:
        return 0.0
    return math.log(sample_count)


def _load_outcome_weights(state_handle: Any) -> dict[str, float]:
    """Derive a per-kind score multiplier from recorded recommendation outcomes.

    Reads ``state_handle.outcomes_json`` as a JSON list of objects carrying a
    ``kind`` string and a ``verdict`` (or, failing that, ``result``) value.
    Verdicts are compared case-insensitively against fixed positive and
    negative vocabularies; rows with any other verdict still register their
    kind but count toward neither side. Non-object rows and rows without a
    non-empty string ``kind`` are skipped.

    For a kind with ``p`` positive and ``n`` negative outcomes the weight is
    ``(1 + p - n) / (1 + p + n)``. That is 1.0 when there are no negatives,
    0.0 when ``n == p + 1``, and negative once ``n > p + 1``. The negative
    range pushes a repeatedly rejected kind below every unweighted
    recommendation.

    Returns:
        Mapping of kind -> weight. Empty when the file is missing or
        unreadable, is not valid UTF-8, is not valid JSON, or its top-level
        value is not a list.
    """
    try:
        payload = json.loads(state_handle.outcomes_json.read_text(encoding="utf-8"))
    except (OSError, UnicodeDecodeError, json.JSONDecodeError):
        # A corrupt (non-UTF-8) file must degrade the same way as a missing or
        # malformed one instead of aborting the whole scoring run.
        return {}

    if not isinstance(payload, list):
        return {}

    totals: dict[str, dict[str, int]] = {}
    for row in payload:
        if not isinstance(row, dict):
            continue
        kind = row.get("kind")
        if not isinstance(kind, str) or not kind:
            continue
        verdict = str(row.get("verdict") or row.get("result") or "").strip().lower()
        # Register the kind even for unrecognised verdicts so it gets weight 1.0.
        bucket = totals.setdefault(kind, {"p": 0, "n": 0})
        if verdict in {"positive", "pass", "approve", "accepted", "ok", "yes", "true"}:
            bucket["p"] += 1
        elif verdict in {"negative", "reject", "failed", "fail", "bad", "no", "false"}:
            bucket["n"] += 1

    return {
        kind: (1 + counts["p"] - counts["n"]) / (1 + counts["p"] + counts["n"])
        for kind, counts in totals.items()
    }


def _score_rows(recommendations: list[dict[str, Any]], weights: dict[str, float]) -> list[dict[str, Any]]:
    """Score recommendations and return them ordered from highest to lowest.

    Each input dict is shallow-copied and given three extra keys:

    * ``outcome_weight`` -- ``weights[kind]`` as a float, or 1.0 for kinds
      absent from ``weights``;
    * ``evidence_strength`` -- ``_evidence_strength(supporting_events)``;
    * ``score`` -- ``impact * confidence * outcome_weight * evidence_strength``.

    Missing ``impact``, ``confidence`` and ``supporting_events`` default to 0.
    The input dicts are not modified; equal scores keep their input order.
    """

    def _with_score(source: dict[str, Any]) -> dict[str, Any]:
        multiplier = float(weights.get(str(source.get("kind") or ""), 1.0))
        strength = _evidence_strength(int(source.get("supporting_events", 0)))
        product = (
            float(source.get("impact", 0.0))
            * float(source.get("confidence", 0.0))
            * multiplier
            * strength
        )
        enriched = dict(source)
        enriched.update(outcome_weight=multiplier, evidence_strength=strength, score=product)
        return enriched

    ranked = [_with_score(item) for item in recommendations]
    # list.sort is stable even with reverse=True, so ties keep input order.
    ranked.sort(key=lambda entry: entry["score"], reverse=True)
    return ranked


def _as_int(value: Any) -> int:
    """Coerce a query cell to ``int``; ``None`` (SQL NULL) and other falsy values become 0."""
    return int(value or 0)


def _as_float(value: Any) -> float:
    """Coerce a query cell to ``float``; ``None`` (SQL NULL) and other falsy values become 0.0."""
    return float(value or 0.0)


def _duration_recommendations(rows: Any) -> list[dict[str, Any]]:
    """Map per-skill duration rows to ``anomaly_duration_spike`` recommendations."""
    recommendations: list[dict[str, Any]] = []
    for row in rows:
        sample_count = _as_int(row.get("sample_count"))
        recommendations.append(
            {
                "recommendation_id": f"rec-duration-{row.get('skill')}-{row.get('event_name')}",
                "kind": "anomaly_duration_spike",
                "title": f"Duration spike for skill {row.get('skill')}",
                "skill": row.get("skill"),
                "event": row.get("event_name"),
                "impact": _as_float(row.get("avg_duration_ms")),
                # 0.2 floor plus 0.05 per sample, capped at 1.0.
                "confidence": min(1.0, 0.2 + sample_count * 0.05),
                "supporting_events": sample_count,
                "evidence": {"event_ids": []},
            }
        )
    return recommendations


def _gate_recommendations(rows: Any) -> list[dict[str, Any]]:
    """Map gate rows to ``correlation_gate_effectiveness`` recommendations, skipping rows with < 2 samples."""
    recommendations: list[dict[str, Any]] = []
    for row in rows:
        # NOTE(review): the sample count (and the confidence denominator) is
        # n_loaded_tools, not an event count -- confirm this is intended.
        sample_count = _as_int(row.get("n_loaded_tools"))
        if sample_count < 2:
            continue
        recommendations.append(
            {
                "recommendation_id": f"rec-gate-{row.get('gate_id')}",
                "kind": "correlation_gate_effectiveness",
                "title": f"Gate effectiveness regression for {row.get('gate_id')}",
                "gate_id": row.get("gate_id"),
                "impact": 1.0 - _as_float(row.get("pass_rate")),
                "confidence": min(1.0, _as_int(row.get("n_error")) / max(1, sample_count)),
                "supporting_events": sample_count,
                "evidence": {"event_ids": row.get("event_ids", [])},
            }
        )
    return recommendations


def _dag_recommendations(rows: Any) -> list[dict[str, Any]]:
    """Map parent/child cost rows to ``correlation_dag_cost`` recommendations."""
    recommendations: list[dict[str, Any]] = []
    for row in rows:
        sample_count = _as_int(row.get("event_count"))
        recommendations.append(
            {
                "recommendation_id": f"rec-dag-{row.get('parent_actor')}-{row.get('child_actor')}",
                "kind": "correlation_dag_cost",
                "title": f"High-cost parent-child pair {row.get('parent_actor')} -> {row.get('child_actor')}",
                "parent_actor": row.get("parent_actor"),
                "child_actor": row.get("child_actor"),
                "impact": _as_float(row.get("total_child_cost_usd")) / 10.0,
                # 0.2 floor plus 0.1 per event, capped at 1.0.
                "confidence": min(1.0, 0.2 + sample_count * 0.1),
                "supporting_events": sample_count,
                "evidence": {"event_ids": row.get("event_ids", [])},
            }
        )
    return recommendations


def _cache_recommendations(rows: Any) -> list[dict[str, Any]]:
    """Map per-session cache rows to ``correlation_cache_hit`` recommendations, skipping ratios >= 0.9."""
    recommendations: list[dict[str, Any]] = []
    for row in rows:
        sample_count = _as_int(row.get("sample_count"))
        cache_hit_ratio = _as_float(row.get("cache_hit_ratio"))
        if cache_hit_ratio >= 0.9:
            continue
        recommendations.append(
            {
                "recommendation_id": f"rec-cache-{row.get('session_id')}",
                "kind": "correlation_cache_hit",
                "title": f"Improve cache-hit ratio for session {row.get('session_id')}",
                "session_id": row.get("session_id"),
                "impact": max(0.0, 1.0 - cache_hit_ratio),
                # Full confidence is reached at 20 samples.
                "confidence": min(1.0, sample_count / 20.0),
                "supporting_events": sample_count,
                "evidence": {"event_ids": []},
            }
        )
    return recommendations


def _run_from_events(state_handle: Any, limit: int) -> dict[str, Any]:
    """Build the recommendations payload from the events database.

    Collects candidates from four analyst queries: duration anomalies, gate
    effectiveness, parent/child cost, and cache-hit ratio. They are scored
    with the weights loaded from ``state_handle.outcomes_json`` and the
    payload keeps the top ``limit`` rows.

    Query rows are read with ``.get``; NULL/missing numeric cells are treated
    as 0. Exceptions from ``open_events_db`` propagate; ``run`` relies on a
    FileNotFoundError from here to switch to the sample fallback.
    """
    conn = open_events_db(state_handle)
    try:
        # For a sqlite3.Connection (presumably what open_events_db returns),
        # ``with conn`` only commits/rolls back the transaction -- it does not
        # close the connection, hence the explicit close in ``finally``.
        with conn:
            recommendations = _duration_recommendations(query_q1_longest_by_skill(conn, min_events=2))
            recommendations += _gate_recommendations(query_gate_with_evidence(conn))
            recommendations += _dag_recommendations(query_dag_parent_child_cost(conn))
            recommendations += _cache_recommendations(query_cache_hit_ratio(conn))
    finally:
        # Guarded because the concrete connection type comes from analyst_queries.
        close = getattr(conn, "close", None)
        if callable(close):
            close()

    weights = _load_outcome_weights(state_handle)
    scored = _score_rows(recommendations, weights)

    return {
        "generated_at": dt.datetime.now(dt.timezone.utc).isoformat(),
        "fallback_mode": False,
        "fallback_samples_count": 0,
        "outcome_weights": weights,
        "recommendations": scored[:limit],
    }


def _run_fallback(state_handle: Any, limit: int) -> dict[str, Any]:
    """Build a placeholder recommendations payload from raw sample rows.

    ``run`` uses this when ``_run_from_events`` raises FileNotFoundError. It
    emits one generic ``sample_aggregate`` recommendation per sample, up to
    ``limit``; the sample contents themselves are not inspected. A
    ``ValueError`` from ``load_samples_rows`` is treated as "no samples".
    No outcome weights are applied in this mode.
    """
    try:
        samples = load_samples_rows(state_handle)
    except ValueError:
        samples = []

    total_samples = 0
    placeholders: list[dict[str, Any]] = []
    if samples:
        total_samples = len(samples)
        placeholders = [
            {
                "recommendation_id": f"rec-sample-{ordinal}",
                "kind": "sample_aggregate",
                "title": "Fallback recommendation from samples",
                "supporting_events": 1,
                "impact": 0.2,
                "confidence": 0.2,
                "evidence": {"event_ids": []},
            }
            for ordinal in range(1, len(samples[:limit]) + 1)
        ]

    return {
        "generated_at": dt.datetime.now(dt.timezone.utc).isoformat(),
        "fallback_mode": True,
        "fallback_samples_count": total_samples,
        "outcome_weights": {},
        "recommendations": placeholders,
    }


def run(state_handle: Any, *, limit: int = 25) -> dict[str, Any]:
    """Build the recommendations payload, falling back to samples if needed.

    Tries the events database first (``_run_from_events``). If that raises
    FileNotFoundError, it builds the sample-based fallback payload instead.

    Args:
        state_handle: object exposing the state paths used by the helpers.
        limit: maximum number of recommendations to return. Negative values
            are treated as 0, because a negative ``[:limit]`` bound would
            drop rows from the end instead of limiting the output.

    Returns:
        The payload dict produced by the chosen path.
    """
    limit = max(0, limit)
    try:
        return _run_from_events(state_handle, limit=limit)
    except FileNotFoundError:
        return _run_fallback(state_handle, limit=limit)


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--repo", type=pathlib.Path)
    parser.add_argument("--state", "--state-dir", dest="state", type=pathlib.Path)
    parser.add_argument("--limit", type=int, default=25)
    return parser.parse_args(argv)


def main(argv: list[str] | None = None) -> int:
    """CLI entry point: compute recommendations and persist them as an artifact.

    Returns:
        0 on success. Returns 2 when neither --repo nor --state is given, or
        when ``write_artifact`` raises. Both failures are reported on stderr.
    """
    options = parse_args(argv)

    try:
        handle = _state_handle(options.repo, options.state)
    except ValueError as err:
        print(str(err), file=sys.stderr)
        return 2

    recommendations_payload = run(handle, limit=options.limit)

    try:
        write_artifact("recommendations", recommendations_payload, handle)
    except Exception as err:  # top-level CLI boundary: report any write failure
        print(f"failed to write recommendations artifact: {err}", file=sys.stderr)
        return 2
    else:
        return 0


if __name__ == "__main__":
    # sys.exit raises SystemExit with main()'s return code as the exit status.
    sys.exit(main())
