#!/usr/bin/env python3
"""Derive DAG edges from persisted events.

Reads ``events.jsonl`` and emits lightweight correlation rows via
``event_writer.write_events_batch(...)`` with ``source="correlation"``.
"""

from __future__ import annotations

import argparse
import datetime as dt
import json
import math
import pathlib
import re
import sys
from collections import defaultdict, deque

try:
    from state_handle import StateHandle
except ImportError:
    from bin.state_handle import StateHandle

from event_writer import write_events_batch
from event_schema import CORRELATION_CHAIN_MAX, EventV4


# Name of the event log file; ``main`` resolves it relative to the state root.
STATE_EVENTS_FILE = "events.jsonl"
# ``source`` tag passed to ``write_events_batch`` for every derived row.
SOURCE = "correlation"


def _normalize_event_name(value: object) -> str:
    """Convert an event name such as ``PreToolUse`` into ``pre_tool_use``.

    Falsy input (``None``, ``""``, ``0`` ...) becomes ``""``. An underscore is
    inserted at CamelCase boundaries, every run of non-alphanumeric characters
    collapses to one underscore, and the result is lower-cased with leading
    and trailing underscores stripped.
    """
    name = str(value) if value else ""
    for pattern, replacement in (
        (r"(.)([A-Z][a-z]+)", r"\1_\2"),  # "xFoo"  -> "x_Foo"
        (r"([a-z0-9])([A-Z])", r"\1_\2"),  # "aB"    -> "a_B"
        (r"[^A-Za-z0-9]+", "_"),  # separators -> "_"
    ):
        name = re.sub(pattern, replacement, name)
    return name.lower().strip("_")


def _parse_ts(value: object) -> dt.datetime | None:
    """Parse an ISO-8601 timestamp string into an aware UTC datetime.

    A ``Z`` suffix is rewritten to ``+00:00`` (``fromisoformat`` before
    Python 3.11 does not accept ``Z``). Returns ``None`` for non-strings,
    unparseable text, and naive timestamps that carry no UTC offset.
    """
    if isinstance(value, str):
        try:
            parsed = dt.datetime.fromisoformat(value.replace("Z", "+00:00"))
        except ValueError:
            parsed = None
        if parsed is not None and parsed.tzinfo is not None:
            return parsed.astimezone(dt.timezone.utc)
    return None


def _to_int(value: object) -> int | None:
    """Coerce a JSON number to ``int``; return ``None`` for anything else.

    ``bool`` is rejected even though it subclasses ``int``. Finite floats are
    truncated toward zero. Non-finite floats (``NaN``/``Infinity``, which
    ``json.loads`` accepts by default) return ``None`` rather than letting
    ``int()`` raise ``ValueError``/``OverflowError``.
    """
    if isinstance(value, bool) or value is None:
        return None
    if isinstance(value, int):
        return value
    if isinstance(value, float):
        # int() raises on NaN/inf; one bad telemetry value must not abort the
        # whole derivation run.
        return int(value) if math.isfinite(value) else None
    return None


def _to_float(value: object) -> float | None:
    """Coerce a finite JSON number to ``float``; return ``None`` otherwise.

    ``bool`` is rejected even though it subclasses ``int``. ``NaN`` and
    ``±Infinity`` are rejected because ``NaN`` propagates through addition
    and would poison any rolled-up sum. Integers too large for a float
    (``json.loads`` keeps arbitrary precision) also yield ``None``.
    """
    if isinstance(value, bool) or value is None:
        return None
    if isinstance(value, (int, float)):
        try:
            converted = float(value)
        except OverflowError:
            return None
        return converted if math.isfinite(converted) else None
    return None


def _sum_optional_int(values: list[object]) -> int | None:
    """Sum the entries of *values* that ``_to_int`` can coerce.

    Returns ``None`` when no entry is usable, so callers can tell "no data"
    apart from a genuine total of ``0``.
    """
    usable = [number for number in map(_to_int, values) if number is not None]
    return sum(usable) if usable else None


def _sum_optional_float(values: list[object]) -> float | None:
    """Sum the entries of *values* that ``_to_float`` can coerce.

    Returns ``None`` when no entry is usable. Accumulates left to right,
    starting from ``0.0``.
    """
    running: float | None = None
    for raw in values:
        number = _to_float(raw)
        if number is None:
            continue
        # Explicit left fold: built-in sum() uses compensated float summation
        # on newer Pythons and could yield a different result.
        running = (0.0 if running is None else running) + number
    return running


def _coerce_chain(raw_chain: object) -> list[dict[str, str]]:
    """Sanitize a ``correlation_chain`` value into ``[{"role", "id"}, ...]``.

    Non-list input yields ``[]``. A member is dropped if it is not a dict, or
    if its stripped ``role`` or ``id`` is empty. A surviving member's role is
    clipped to 200 characters and its id to 128.
    """
    if not isinstance(raw_chain, list):
        return []

    def _field(member: dict, key: str) -> str:
        return str(member.get(key) or "").strip()

    pairs = (
        (_field(member, "role"), _field(member, "id"))
        for member in raw_chain
        if isinstance(member, dict)
    )
    return [
        {"role": role[:200], "id": link_id[:128]}
        for role, link_id in pairs
        if role and link_id
    ]


def _correlation_chain(seed_chain: object, actor_parent_id: object) -> list[dict[str, str]]:
    """Build the chain for a derived row.

    The chain is the sanitized *seed_chain* followed, when *actor_parent_id*
    is non-blank, by an ``{"role": "actor_parent", "id": ...}`` link (id
    clipped to 128 characters). The result is capped at
    ``CORRELATION_CHAIN_MAX`` entries.
    """
    links = _coerce_chain(seed_chain)
    parent_id = str(actor_parent_id or "").strip()
    if parent_id:
        links = [*links, {"role": "actor_parent", "id": parent_id[:128]}]
    # The cap is applied after the append, so a seed chain already at the
    # cap leaves no room for the actor_parent link.
    return links[:CORRELATION_CHAIN_MAX]


def _telemetry_rollup(start: dict, end: dict) -> dict[str, object]:
    """Combine the telemetry of a start/end event pair into one dict.

    Returns ``{}`` unless both events carry a parseable, offset-aware ``ts``.
    Otherwise the result always contains ``duration_ms`` (end minus start,
    floored at 0). It also contains each token counter and ``cost_usd``
    summed across both sides when at least one side reports it, and
    ``interrupted: True`` when either side's ``interrupted`` is truthy.
    """

    def _telemetry_of(event: dict) -> dict:
        payload = event.get("telemetry")
        return payload if isinstance(payload, dict) else {}

    begin = _telemetry_of(start)
    finish = _telemetry_of(end)
    begin_ts = _parse_ts(start.get("ts"))
    finish_ts = _parse_ts(end.get("ts"))
    if begin_ts is None or finish_ts is None:
        return {}

    elapsed_ms = int((finish_ts - begin_ts).total_seconds() * 1000)
    rollup: dict[str, object] = {"duration_ms": max(0, elapsed_ms)}
    for counter in ("tokens_in", "tokens_out", "cache_read_tokens", "cache_creation_tokens"):
        combined = _sum_optional_int([begin.get(counter), finish.get(counter)])
        if combined is not None:
            rollup[counter] = combined
    cost = _sum_optional_float([begin.get("cost_usd"), finish.get("cost_usd")])
    if cost is not None:
        rollup["cost_usd"] = cost
    if begin.get("interrupted") or finish.get("interrupted"):
        rollup["interrupted"] = True
    return rollup


def _correlation_key(event: dict[str, object]) -> tuple[str, str, str, str, str]:
    """Return the key used to match a start event with its end event.

    The key is ``(actor_kind, actor_name, parent_actor_id, correlation_id,
    tool)``, all as strings:

    * ``actor_kind`` defaults to ``"main_agent"``.
    * ``correlation_id`` is the first truthy value among ``correlation_id``,
      ``session_id``, ``dispatch_id`` and ``agent_id``.
    * ``tool`` falls back to ``command_class``.

    A missing or non-dict ``actor`` is treated as ``{}``.
    """
    raw_actor = event.get("actor")
    actor = raw_actor if isinstance(raw_actor, dict) else {}

    correlation: object = ""
    for field in ("correlation_id", "session_id", "dispatch_id", "agent_id"):
        candidate = event.get(field)
        if candidate:
            correlation = candidate
            break

    return (
        str(actor.get("kind") or "main_agent"),
        str(actor.get("name") or ""),
        str(actor.get("parent_actor_id") or ""),
        str(correlation),
        str(event.get("tool") or event.get("command_class") or ""),
    )


def _derive_actor(seed: dict[str, object]) -> dict[str, object]:
    """Build the ``actor`` dict for a derived row from the *seed* event.

    ``kind`` defaults to ``"main_agent"`` and ``name`` to ``"correlation"``
    when missing or falsy. ``model`` and ``parent_actor_id`` are copied
    verbatim and may be ``None``.
    """
    source_actor = seed.get("actor")
    if not isinstance(source_actor, dict):
        source_actor = {}
    return dict(
        kind=source_actor.get("kind") or "main_agent",
        name=source_actor.get("name") or "correlation",
        model=source_actor.get("model"),
        parent_actor_id=source_actor.get("parent_actor_id"),
    )


def _derive_pair_row(start: dict, end: dict, event_name: str) -> dict[str, object]:
    """Build the derived schema-v4 row that links *start* to its matching *end*.

    * ``event_id`` comes from ``EventV4.deterministic_id``. It is keyed on the
      start actor kind, *event_name* and both source event ids. That id is
      presumably stable for identical inputs; the existing-id check in
      ``main`` relies on this.
    * ``ts`` is the end event's timestamp.
    * ``parent_event_id`` points back at the start event.
    * ``tool_server`` and ``error_class`` take the start event's value and
      fall back to the end event's when the start's is falsy.
    """
    raw_actor = start.get("actor")
    start_actor = raw_actor if isinstance(raw_actor, dict) else {}
    source_ids = f"{start.get('event_id')}:{end.get('event_id')}"
    row_id = EventV4.deterministic_id(
        actor_kind=str(start_actor.get("kind") or "main_agent"),
        event_type=event_name,
        payload_key=f"derive:{source_ids}",
    )

    row: dict[str, object] = {
        "event_id": row_id,
        "ts": end.get("ts"),
        "event": event_name,
        "schema_version": 4,
    }
    row["actor"] = _derive_actor(start)
    row["telemetry"] = _telemetry_rollup(start, end)
    row["correlation_chain"] = _correlation_chain(
        start.get("correlation_chain"),
        start_actor.get("parent_actor_id"),
    )
    row["parent_event_id"] = start.get("event_id")
    row["tool_server"] = start.get("tool_server") or end.get("tool_server")
    row["error_class"] = start.get("error_class") or end.get("error_class")
    return row


def _iter_rows(path: pathlib.Path):
    """Lazily yield each JSON object stored one-per-line in *path*.

    Blank lines are skipped silently. Lines that are not valid JSON, or that
    decode to something other than an object, are skipped with a warning on
    stderr. The file stays open until iteration finishes or the generator is
    closed.
    """
    with path.open(encoding="utf-8") as handle:
        for raw_line in handle:
            text = raw_line.strip()
            if not text:
                continue
            try:
                record = json.loads(text)
            except json.JSONDecodeError:
                print(f"warn.invalid_json ignored path={path}", file=sys.stderr)
            else:
                if isinstance(record, dict):
                    yield record
                else:
                    print(f"warn.non_object ignored path={path}", file=sys.stderr)


def _pair_events(rows):
    """Match start events with their end events, processing *rows* in order.

    ``pre_tool_use`` rows pair with ``post_tool_use`` rows, producing a
    ``tool_use_pair`` row. ``subagent_start`` rows pair with ``subagent_end``
    rows, producing a ``subagent_run`` row.

    Starts are queued per ``_correlation_key`` and matched FIFO. A row with
    an unparseable ``ts`` or a non-string ``event_id`` is skipped with a
    stderr warning, whatever its event type. An end with no queued start is
    reported and dropped.

    Returns ``(derived_rows, pending_tools, pending_subagents)``. The two
    pending maps hold the starts that were never matched; they may also hold
    empty queues for keys that were only looked up.
    """
    pending_tools: dict[tuple[str, str, str, str, str], deque[dict]] = defaultdict(deque)
    pending_subagents: dict[tuple[str, str, str, str, str], deque[dict]] = defaultdict(deque)
    derived: list[dict[str, object]] = []

    def _take_start(pending, closing_row):
        # Indexing the defaultdict creates an empty queue for unseen keys.
        waiting = pending[_correlation_key(closing_row)]
        return waiting.popleft() if waiting else None

    for row in rows:
        name = _normalize_event_name(row.get("event"))
        if _parse_ts(row.get("ts")) is None:
            print(f"warn.invalid_ts event_id={row.get('event_id')}", file=sys.stderr)
            continue
        event_id = row.get("event_id")
        if not isinstance(event_id, str):
            print("warn.missing_event_id ignored", file=sys.stderr)
            continue

        if name == "pre_tool_use":
            pending_tools[_correlation_key(row)].append(row)
        elif name == "subagent_start":
            pending_subagents[_correlation_key(row)].append(row)
        elif name == "post_tool_use":
            opener = _take_start(pending_tools, row)
            if opener is None:
                print(f"warn.unmatched_post_tool_use event_id={event_id}", file=sys.stderr)
            else:
                derived.append(_derive_pair_row(opener, row, "tool_use_pair"))
        elif name == "subagent_end":
            opener = _take_start(pending_subagents, row)
            if opener is None:
                print(f"warn.unmatched_subagent_end event_id={event_id}", file=sys.stderr)
            else:
                derived.append(_derive_pair_row(opener, row, "subagent_run"))

    return derived, pending_tools, pending_subagents


def _load_existing_ids(path: pathlib.Path) -> set[str]:
    """Return the set of string ``event_id`` values already present in *path*.

    Rows whose ``event_id`` is missing or not a string are ignored.
    """
    return {
        record["event_id"]
        for record in _iter_rows(path)
        if isinstance(record.get("event_id"), str)
    }


def main(argv: list[str] | None = None) -> int:
    """Derive pair rows from ``events.jsonl`` and append the ones not yet present.

    Steps:

    1. Resolve the state root (``--state-dir`` overrides it).
    2. Read the events file once and pair start/end events.
    3. Write only derived rows whose ``event_id`` is not already in the file
       and not already queued in this run.
    4. Report unmatched start events on stderr.

    Returns 0; errors raised by the helpers propagate.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--state-dir", help="State root override")
    args = parser.parse_args(argv)

    state_root, _ = StateHandle.resolve_state_root(state_dir=args.state_dir)
    state_root = pathlib.Path(state_root).expanduser().resolve()
    events_path = state_root / STATE_EVENTS_FILE

    if not events_path.exists():
        print(f"warn.events_jsonl_missing path={events_path}", file=sys.stderr)
        return 0

    # Read the file exactly once. A second pass would print every parse
    # warning twice and could see rows appended between the two reads.
    rows = list(_iter_rows(events_path))
    if not rows:
        return 0
    seen_ids = {
        row["event_id"] for row in rows if isinstance(row.get("event_id"), str)
    }

    derived, pending_tools, pending_subagents = _pair_events(rows)

    to_write: list[dict[str, object]] = []
    for row in derived:
        row_id = row.get("event_id")
        if not isinstance(row_id, str) or row_id in seen_ids:
            continue
        # Also prevents writing the same derived id twice in one run, e.g.
        # when a source row is duplicated in the file.
        seen_ids.add(row_id)
        to_write.append(row)

    if to_write:
        write_events_batch(to_write, source=SOURCE, auto_id_fallback=False)

    # A pending map may contain empty queues; any() only fires on real leftovers.
    if any(pending_tools.values()):
        print("warn.subtool_start_without_post", file=sys.stderr)
    if any(pending_subagents.values()):
        print("warn.subagent_start_without_end", file=sys.stderr)

    return 0


if __name__ == "__main__":
    raise SystemExit(main())
