#!/usr/bin/env python3
"""Analyst pattern script: actor frequency, DAG co-occurrence, and time patterns."""

from __future__ import annotations

import argparse
import datetime as dt
import json
import pathlib
import sys
from collections import defaultdict
from typing import Any

from artifact_writer import write_artifact
from analyst_queries import load_samples_rows, open_events_db, query_q6_time_of_day_sessions
from state_handle import StateHandle


def _state_handle(repo: pathlib.Path | None, state: pathlib.Path | None):
    if repo:
        return StateHandle.for_repo(repo)
    if not state:
        raise ValueError("either --repo or --state required")

    class _State:
        pass

    root = state.resolve()
    handle = _State()
    handle.repo_state_dir = root
    handle.events_sqlite = root / "events.sqlite"
    return handle


def _safe_parse_ts(value: str | None) -> dt.datetime | None:
    if not value:
        return None
    try:
        return dt.datetime.fromisoformat(value.replace("Z", "+00:00"))
    except (TypeError, ValueError):
        return None


def _frequency_rows(conn) -> list[dict[str, Any]]:
    rows = [
        {
            "skill": row["skill"],
            "actor_kind": row["actor_kind"],
            "sample_count": int(row["sample_count"]),
            "evidence": {"event_ids": [item for item in str(row["event_ids"]).split(",") if item]},
        }
        for row in conn.execute(
            """
            SELECT
                actor_name AS skill,
                actor_kind,
                COUNT(*) AS sample_count,
                GROUP_CONCAT(event_id) AS event_ids
            FROM events
            WHERE actor_name IS NOT NULL
              AND actor_kind IS NOT NULL
            GROUP BY actor_name, actor_kind
            ORDER BY sample_count DESC
            """
        )
    ]
    return rows


def _co_occurrence_pairs(conn) -> list[dict[str, Any]]:
    rows = [dict(row) for row in conn.execute("SELECT event_id, actor_name, ts, correlation_chain FROM events")]
    by_event: dict[str, dict[str, Any]] = {
        row["event_id"]: {"actor": row["actor_name"], "ts": _safe_parse_ts(row["ts"])}
        for row in rows
        if row["actor_name"] and row["event_id"]
    }

    buckets: dict[tuple[str, str], list[str]] = defaultdict(list)
    for row in rows:
        child_id = row.get("event_id")
        child_actor = row.get("actor_name")
        child_ts = _safe_parse_ts(row.get("ts"))
        chain_raw = row.get("correlation_chain")
        if not child_id or not child_actor or child_ts is None or not chain_raw:
            continue

        try:
            chain = json.loads(chain_raw)
        except (TypeError, json.JSONDecodeError):
            continue
        if not isinstance(chain, list):
            continue

        for link in chain:
            if not isinstance(link, dict):
                continue
            parent_id = str(link.get("id", "")).strip()
            if not parent_id:
                continue
            parent = by_event.get(parent_id)
            if not parent or parent["actor"] is None:
                continue
            parent_ts = parent["ts"]
            if parent_ts is None:
                continue
            if child_ts < parent_ts:
                continue
            if (child_ts - parent_ts).total_seconds() > 10:
                continue

            buckets[(parent["actor"], child_actor)].append(child_id)

    output: list[dict[str, Any]] = []
    for (parent_actor, child_actor), pair_ids in sorted(
        buckets.items(), key=lambda item: len(item[1]), reverse=True
    ):
        output.append(
            {
                "parent_actor": parent_actor,
                "child_actor": child_actor,
                "pair_count": len(pair_ids),
                "evidence": {"event_ids": sorted(set(pair_ids))},
            }
        )
    return output


def _time_of_day_rows(conn) -> list[dict[str, Any]]:
    """Return the time-of-day session rows, each tagged with empty evidence.

    The rows come straight from ``query_q6_time_of_day_sessions``; every row
    is updated in place with an ``evidence`` entry whose ``event_ids`` list is
    empty, and the same list object is returned.
    """
    sessions = query_q6_time_of_day_sessions(conn)
    for session in sessions:
        # A fresh dict/list per row so no two rows share the same evidence.
        session["evidence"] = {"event_ids": []}
    return sessions


def _fallback_payload(state_handle: Any) -> dict[str, Any]:
    """Build the payload emitted when the events database is unavailable.

    Args:
        state_handle: Handle passed through to ``load_samples_rows``.

    Returns:
        A payload with ``fallback_mode`` set to ``True``, the number of sample
        rows that could be loaded (``0`` when loading raises ``ValueError``)
        and empty lists for all three pattern sections.
    """
    try:
        loaded = load_samples_rows(state_handle)
    except ValueError:
        loaded = []

    created_at = dt.datetime.now(dt.timezone.utc).isoformat()
    payload: dict[str, Any] = {
        "generated_at": created_at,
        "fallback_mode": True,
        "fallback_samples_count": len(loaded),
    }
    # Pattern sections stay present but empty so consumers see a stable shape.
    for section in ("frequency_by_skill", "co_occurrence_pairs", "time_of_day_patterns"):
        payload[section] = []
    return payload


def run(state_handle: Any) -> dict[str, Any]:
    """Compute the full patterns payload for ``state_handle``.

    Args:
        state_handle: Handle passed to ``open_events_db`` (and to the fallback
            builder when the database cannot be opened).

    Returns:
        The patterns payload. When ``open_events_db`` raises
        ``FileNotFoundError`` the fallback payload is returned instead;
        otherwise ``fallback_mode`` is ``False`` and the three pattern
        sections are populated from the events database.
    """
    try:
        conn = open_events_db(state_handle)
    except FileNotFoundError:
        return _fallback_payload(state_handle)

    try:
        with conn:
            return {
                "generated_at": dt.datetime.now(dt.timezone.utc).isoformat(),
                "fallback_mode": False,
                "fallback_samples_count": 0,
                "frequency_by_skill": _frequency_rows(conn),
                "co_occurrence_pairs": _co_occurrence_pairs(conn),
                "time_of_day_patterns": _time_of_day_rows(conn),
            }
    finally:
        # NOTE(review): conn is presumably a sqlite3.Connection. Its context
        # manager only commits or rolls back; it never closes the connection,
        # so release it explicitly here (closing twice is harmless).
        conn.close()


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--repo", type=pathlib.Path)
    parser.add_argument("--state", "--state-dir", dest="state", type=pathlib.Path)
    return parser.parse_args(argv)


def main(argv: list[str] | None = None) -> int:
    """Run the script end to end and return a process exit code.

    Args:
        argv: Optional argument list forwarded to ``parse_args``.

    Returns:
        ``0`` once the patterns artifact has been written; ``2`` when no state
        handle can be resolved or when computing/writing the artifact raises.
        In both failure cases the error text is printed to stderr.
    """
    options = parse_args(argv)

    try:
        handle = _state_handle(options.repo, options.state)
    except ValueError as err:
        print(str(err), file=sys.stderr)
        return 2

    try:
        patterns = run(handle)
        write_artifact("patterns", patterns, handle)
    except Exception as err:
        # Top-level boundary: report any failure and signal it via exit code.
        print(f"failed to write patterns artifact: {err}", file=sys.stderr)
        return 2
    else:
        return 0


if __name__ == "__main__":
    # Script entry point: main()'s return value becomes the exit status.
    sys.exit(main())
