#!/usr/bin/env python3
"""Analyst correlation script: gate effectiveness and DAG-derived patterns."""

from __future__ import annotations

import argparse
import contextlib
import datetime as dt
import pathlib
import sys
from typing import Any

from analyst_queries import (
    load_samples_rows,
    open_events_db,
    query_cache_hit_ratio,
    query_dag_parent_child_cost,
    query_gate_with_evidence,
    query_q10_eval_verdict_roi,
    query_q8_frustration_pairs,
)
from artifact_writer import write_artifact
from state_handle import StateHandle


def _state_handle(repo: pathlib.Path | None, state: pathlib.Path | None):
    if repo:
        return StateHandle.for_repo(repo)
    if not state:
        raise ValueError("either --repo or --state required")

    class _State:
        pass

    handle = _State()
    handle.repo_state_dir = state.resolve()
    handle.events_sqlite = handle.repo_state_dir / "events.sqlite"
    return handle


def _attach_event_ids(row: dict[str, Any]) -> None:
    """Move a row's raw ``event_ids`` column into ``row["evidence"]``, in place.

    Rows without an ``event_ids`` key are left untouched. Otherwise the key is
    removed and ``row["evidence"]`` is replaced by ``{"event_ids": [...]}``:

    * a string is treated as a comma-separated list; each item is stripped
      of surrounding whitespace and empty items are dropped;
    * a list or tuple is copied into a new list;
    * any other value (including ``None``) yields an empty list.

    Args:
        row: Mutable result row; modified in place.
    """
    if "event_ids" not in row:
        return
    # Pop first, then assign: the final key order matches the original
    # "assign evidence, then drop event_ids" sequence.
    raw = row.pop("event_ids")
    if isinstance(raw, str):
        stripped = (item.strip() for item in raw.split(","))
        event_ids = [item for item in stripped if item]
    elif isinstance(raw, (list, tuple)):
        # Copy so the evidence list does not alias the caller's sequence.
        event_ids = list(raw)
    else:
        event_ids = []
    row["evidence"] = {"event_ids": event_ids}


def run(state_handle: Any) -> dict[str, Any]:
    """Build the correlations payload for the given state handle.

    Opens the events database with ``open_events_db``. When that raises
    ``FileNotFoundError`` the result is produced in fallback mode. Only the
    number of rows returned by ``load_samples_rows`` is reported; a
    ``ValueError`` from it counts as zero samples. Every analysis section is an
    empty list.

    Otherwise the analyst queries run against the database. Gate and DAG rows
    have their raw ``event_ids`` column moved into
    ``row["evidence"]["event_ids"]``. Every gate row is guaranteed to carry an
    ``evidence`` dict containing an ``event_ids`` list. The database
    connection is closed before returning.

    ``time_of_day_sessions`` is always empty; nothing here populates it.

    Args:
        state_handle: Handle passed through to the query helpers.

    Returns:
        Dict with ``generated_at`` (UTC ISO-8601 timestamp), ``fallback_mode``,
        ``fallback_samples_count`` and the six analysis sections.
    """
    try:
        conn = open_events_db(state_handle)
    except FileNotFoundError:
        try:
            samples = load_samples_rows(state_handle)
        except ValueError:
            samples = []
        return {
            "generated_at": dt.datetime.now(dt.timezone.utc).isoformat(),
            "fallback_mode": True,
            "fallback_samples_count": len(samples),
            "gate_effectiveness": [],
            "dag_cost_attribution": [],
            "cache_hit_ratio": [],
            "time_to_stop_patterns": [],
            "eval_roi": [],
            "time_of_day_sessions": [],
        }

    # On a sqlite3 connection ``with conn`` only commits or rolls back the
    # transaction; it does NOT close the connection. ``closing`` is entered
    # first and exited last, so the connection is closed after the
    # transaction block ends, including when a query raises.
    # NOTE(review): assumes open_events_db returns a sqlite3.Connection (or any
    # object with close()); confirm.
    with contextlib.closing(conn), conn:
        gate_rows = query_gate_with_evidence(conn)
        for row in gate_rows:
            _attach_event_ids(row)
            # _attach_event_ids skips rows lacking an ``event_ids`` column, so
            # ``evidence`` may be missing or not a dict. Guarantee an
            # ``evidence.event_ids`` list instead of raising KeyError, and keep
            # any other keys of an existing evidence dict.
            evidence = row.get("evidence")
            if isinstance(evidence, dict):
                evidence.setdefault("event_ids", [])
            else:
                row["evidence"] = {"event_ids": []}

        dag_rows = query_dag_parent_child_cost(conn)
        for row in dag_rows:
            _attach_event_ids(row)

        return {
            "generated_at": dt.datetime.now(dt.timezone.utc).isoformat(),
            "fallback_mode": False,
            "fallback_samples_count": 0,
            "gate_effectiveness": gate_rows,
            "dag_cost_attribution": dag_rows,
            "cache_hit_ratio": query_cache_hit_ratio(conn),
            "time_to_stop_patterns": query_q8_frustration_pairs(conn),
            "eval_roi": query_q10_eval_verdict_roi(conn),
            "time_of_day_sessions": [],
        }


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--repo", type=pathlib.Path)
    parser.add_argument("--state", "--state-dir", dest="state", type=pathlib.Path)
    return parser.parse_args(argv)


def main(argv: list[str] | None = None) -> int:
    """CLI entry point: resolve state, compute correlations, write the artifact.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        0 on success. 2 when the state handle cannot be resolved (a
        ``ValueError``, e.g. neither ``--repo`` nor ``--state`` given) or when
        writing the ``correlations`` artifact fails. Both errors are reported
        on stderr.
    """
    options = parse_args(argv)

    try:
        handle = _state_handle(options.repo, options.state)
    except ValueError as err:
        print(str(err), file=sys.stderr)
        return 2

    payload = run(handle)

    try:
        write_artifact("correlations", payload, handle)
    except Exception as err:  # top-level boundary: report any writer failure
        print(f"failed to write correlations artifact: {err}", file=sys.stderr)
        return 2

    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())
