#!/usr/bin/env python3
"""Analyst anomaly script: z-score and IQR outlier detection over events.sqlite."""

from __future__ import annotations

import argparse
import datetime as dt
import math
import pathlib
import sys
from typing import Any

from artifact_writer import write_artifact
from analyst_queries import load_samples_rows, open_events_db
from state_handle import StateHandle


def _state_handle(repo: pathlib.Path | None, state: pathlib.Path | None):
    if repo:
        return StateHandle.for_repo(repo)
    if not state:
        raise ValueError("either --repo or --state required")

    class _State:
        pass

    handle = _State()
    handle.repo_state_dir = state.resolve()
    handle.events_sqlite = handle.repo_state_dir / "events.sqlite"
    return handle


def _bucket_measurements(rows: list[dict[str, Any]], keys: tuple[str, ...]) -> dict[tuple[str, ...], list[dict[str, Any]]]:
    buckets: dict[tuple[str, ...], list[dict[str, Any]]] = {}
    for row in rows:
        bucket = tuple(str(row.get(key) or "") for key in keys)
        if any(not item for item in bucket):
            continue
        buckets.setdefault(bucket, []).append(row)
    return buckets


def _zscore_anomalies(rows: list[dict[str, Any]], min_n: int, threshold: float) -> list[dict[str, Any]]:
    """Flag values whose population z-score reaches *threshold* per (actor_name, event) bucket.

    A bucket is skipped when it has fewer than *min_n* rows, fewer than *min_n*
    non-None ``value`` entries, or a standard deviation of zero.

    Args:
        rows: dicts carrying ``actor_name``, ``event``, ``event_id`` and ``value``.
        min_n: minimum number of samples a bucket needs to be analysed.
        threshold: absolute z-score at or above which a row is flagged.

    Returns:
        One ``duration_zscore`` finding per bucket containing at least one outlier.
    """
    findings: list[dict[str, Any]] = []
    grouped = _bucket_measurements(rows, ("actor_name", "event"))
    for (actor_name, event_name), members in grouped.items():
        if len(members) < min_n:
            continue
        measured = [member for member in members if member.get("value") is not None]
        samples = [float(member["value"]) for member in measured]
        count = len(samples)
        if count < min_n:
            continue

        mean = sum(samples) / count
        variance = sum((sample - mean) ** 2 for sample in samples) / count
        spread = math.sqrt(max(0.0, variance))
        if not spread:
            continue

        flagged = [
            member["event_id"]
            for member, sample in zip(measured, samples)
            if abs((sample - mean) / spread) >= threshold
        ]
        if flagged:
            findings.append(
                {
                    "kind": "duration_zscore",
                    "actor_name": actor_name,
                    "event": event_name,
                    "sample_count": len(members),
                    "mean_duration_ms": mean,
                    "stdev_duration_ms": spread,
                    "z_threshold": threshold,
                    "supporting_event_ids": flagged,
                    "evidence": {"event_ids": flagged},
                }
            )
    return findings


def _iqr_bounds(values: list[float]) -> tuple[float, float]:
    sorted_values = sorted(values)
    n = len(sorted_values)
    if n == 0:
        return (0.0, 0.0)
    mid = n // 2
    q1 = sorted_values[(n - 1) // 4]
    q3 = sorted_values[(3 * (n - 1)) // 4]
    iqr = q3 - q1
    return q1 - 1.5 * iqr, q3 + 1.5 * iqr


def _iqr_anomalies(rows: list[dict[str, Any]], key_fields: tuple[str, ...], min_n: int, outlier_name: str) -> list[dict[str, Any]]:
    """Flag values outside the Tukey IQR fences within each *key_fields* bucket.

    Args:
        rows: dicts carrying ``event_id``, ``value`` and every field in *key_fields*.
        key_fields: row keys whose (non-blank) values define a bucket.
        min_n: minimum rows / non-None values a bucket needs to be analysed.
        outlier_name: label stored under ``outlier_type`` in each finding.

    Returns:
        One finding per bucket with at least one outlier; the bucket's key
        values are copied into the finding under their field names.
    """
    findings: list[dict[str, Any]] = []
    for bucket_key, members in _bucket_measurements(rows, key_fields).items():
        measured = [member for member in members if member.get("value") is not None]
        if len(members) < min_n or len(measured) < min_n:
            continue
        samples = [float(member["value"]) for member in measured]
        low, high = _iqr_bounds(samples)
        flagged = [
            member["event_id"]
            for member, sample in zip(measured, samples)
            if sample < low or sample > high
        ]
        if not flagged:
            continue

        finding: dict[str, Any] = {
            "outlier_type": outlier_name,
            "sample_count": len(members),
            "supporting_event_ids": flagged,
            "evidence": {"event_ids": flagged},
        }
        finding.update(zip(key_fields, bucket_key))
        findings.append(finding)
    return findings


def _session_rollups(conn) -> list[dict[str, Any]]:
    rows = [
        {
            "session_id": str(row["session_id"]),
            "event_count": int(row["event_count"]),
            "total_cost_usd": float(row["total_cost_usd"] or 0.0),
            "total_tokens": float(row["total_tokens"] or 0.0),
            "total_duration_ms": float(row["total_duration_ms"] or 0.0),
            "error_count": int(row["error_count"] or 0),
            "evidence": {"event_ids": [item for item in str(row["event_ids"]).split(",") if item]},
        }
        for row in conn.execute(
            """
            SELECT
                session_id,
                COUNT(*) AS event_count,
            ROUND(SUM(COALESCE(telemetry_cost_usd, 0.0)), 3) AS total_cost_usd,
            ROUND(SUM(COALESCE(telemetry_tokens_in, 0) + COALESCE(telemetry_tokens_out, 0)), 3) AS total_tokens,
                ROUND(SUM(COALESCE(telemetry_duration_ms, 0)), 3) AS total_duration_ms,
                SUM(CASE WHEN COALESCE(telemetry_interrupted, 0) = 1 OR error_class IS NOT NULL THEN 1 ELSE 0 END) AS error_count,
                GROUP_CONCAT(event_id) AS event_ids
            FROM events
            WHERE session_id IS NOT NULL
            GROUP BY session_id
            ORDER BY total_cost_usd DESC
            """
        )
    ]
    return rows


def _utc_now_iso() -> str:
    """Return the current time in UTC as an ISO-8601 string."""
    return dt.datetime.now(dt.timezone.utc).isoformat()


def _fallback_result(state_handle: Any) -> dict[str, Any]:
    """Build the empty result used when the events database file is missing.

    Only the count of rows from ``load_samples_rows`` is reported; a
    ``ValueError`` from that loader is treated as "no samples".
    """
    try:
        samples = load_samples_rows(state_handle)
    except ValueError:
        samples = []
    return {
        "generated_at": _utc_now_iso(),
        "fallback_mode": True,
        "fallback_samples_count": len(samples),
        "duration_anomalies": [],
        "token_anomalies": [],
        "cost_anomalies": [],
        "session_rollups": [],
    }


def _token_total(tokens_in: Any, tokens_out: Any) -> Any:
    """Return input + output tokens, or None when neither count was recorded.

    Returning None rather than 0 lets the token analysis skip events that carry
    no token telemetry instead of counting them as zero-token samples.
    """
    if tokens_in is None and tokens_out is None:
        return None
    return (tokens_in or 0) + (tokens_out or 0)


def _load_event_rows(conn) -> list[dict[str, Any]]:
    """Read every row of ``events`` into plain dicts with normalized field names."""
    return [
        {
            "event_id": row["event_id"],
            "actor_name": row["actor_name"],
            "actor_model": row["actor_model"],
            "actor_kind": row["actor_kind"],
            "event": row["event"],
            "duration_ms": row["telemetry_duration_ms"],
            "token_total": _token_total(row["telemetry_tokens_in"], row["telemetry_tokens_out"]),
            "cost_usd": row["telemetry_cost_usd"],
        }
        for row in conn.execute(
            """
            SELECT
                event_id,
                actor_name,
                actor_model,
                actor_kind,
                event,
                telemetry_duration_ms,
                telemetry_tokens_in,
                telemetry_tokens_out,
                telemetry_cost_usd
            FROM events
            """
        )
    ]


def run(state_handle: Any, *, min_n: int = 4, z_threshold: float = 4.0) -> dict[str, Any]:
    """Compute duration, token and cost anomalies plus per-session rollups.

    Args:
        state_handle: object passed to ``open_events_db`` / ``load_samples_rows``.
        min_n: minimum samples a bucket needs before it is analysed.
        z_threshold: absolute z-score at or above which a duration is flagged.

    Returns:
        The anomalies report. When ``open_events_db`` raises
        ``FileNotFoundError``, a fallback report with ``fallback_mode=True`` and
        empty anomaly lists is returned instead.
    """
    try:
        conn = open_events_db(state_handle)
    except FileNotFoundError:
        return _fallback_result(state_handle)

    try:
        with conn:
            rows = _load_event_rows(conn)

            duration_rows = _zscore_anomalies(
                [
                    {
                        "event_id": row["event_id"],
                        "actor_name": row["actor_name"],
                        "event": row["event"],
                        "value": row["duration_ms"],
                    }
                    for row in rows
                    if row.get("duration_ms") is not None
                ],
                min_n=min_n,
                threshold=z_threshold,
            )

            token_rows = _iqr_anomalies(
                [
                    {
                        "event_id": row["event_id"],
                        "actor_model": row["actor_model"],
                        "skill": row["actor_name"],
                        "event": row["event"],
                        "value": row["token_total"],
                    }
                    for row in rows
                    if row.get("token_total") is not None and row.get("actor_model")
                ],
                key_fields=("actor_model", "skill"),
                min_n=min_n,
                outlier_name="token_iqr",
            )

            cost_rows = _iqr_anomalies(
                [
                    {
                        "event_id": row["event_id"],
                        "actor_kind": row["actor_kind"],
                        "event": row["event"],
                        "value": row["cost_usd"],
                    }
                    for row in rows
                    if row.get("cost_usd") is not None
                ],
                key_fields=("actor_kind",),
                min_n=min_n,
                outlier_name="cost_iqr",
            )

            session_rollups = _session_rollups(conn)
    finally:
        # An sqlite3.Connection used as a context manager only commits or rolls
        # back; it never closes. Close explicitly so the DB handle is released.
        close = getattr(conn, "close", None)
        if callable(close):
            close()

    return {
        "generated_at": _utc_now_iso(),
        "fallback_mode": False,
        "fallback_samples_count": 0,
        "duration_anomalies": duration_rows,
        "token_anomalies": token_rows,
        "cost_anomalies": cost_rows,
        "session_rollups": session_rollups,
    }


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options (``sys.argv[1:]`` when *argv* is None)."""
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument("--repo", type=pathlib.Path)
    cli.add_argument("--state", "--state-dir", dest="state", type=pathlib.Path)
    numeric_options = (
        ("--min-n", int, 4),
        ("--z-threshold", float, 4.0),
    )
    for flag, converter, fallback in numeric_options:
        cli.add_argument(flag, type=converter, default=fallback)
    return cli.parse_args(argv)


def main(argv: list[str] | None = None) -> int:
    """CLI entry point: run the analysis and write the ``anomalies`` artifact.

    Returns 0 on success and 2 when no state location was given or the
    artifact could not be written (the reason is printed to stderr).
    """
    options = parse_args(argv)
    try:
        handle = _state_handle(options.repo, options.state)
    except ValueError as err:
        print(str(err), file=sys.stderr)
        return 2

    report = run(handle, min_n=options.min_n, z_threshold=options.z_threshold)
    try:
        write_artifact("anomalies", report, handle)
    except Exception as err:  # top-level boundary: report any writer failure
        print(f"failed to write anomalies artifact: {err}", file=sys.stderr)
        return 2
    else:
        return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())
