#!/usr/bin/env python3
"""Incremental indexer for ``events.jsonl`` into ``events.sqlite``."""

from __future__ import annotations

import argparse
import json
import os
import pathlib
import sqlite3
import sys
from dataclasses import asdict
from typing import Any

try:
    from state_handle import StateHandle
except ImportError:  # pragma: no cover
    from bin.state_handle import StateHandle

try:
    from event_schema import EventV4
except ImportError:  # pragma: no cover
    from bin.event_schema import EventV4


# sqlite table holding key/value metadata for events.sqlite (e.g. schema_version).
META_TABLE = "events_meta"
# schema_version this indexer stamps into, and accepts from, events.sqlite.
EXPECTED_SCHEMA_VERSION = 4
# File inside the state dir recording the byte offset of events.jsonl already indexed.
CURSOR_FILE = "events.sqlite.cursor"


def _resolve_state_path(
    repo: pathlib.Path | None = None,
    state: pathlib.Path | None = None,
) -> pathlib.Path:
    if state is not None:
        return state.expanduser().resolve()
    if repo is not None:
        return StateHandle.for_repo(repo).repo_state_dir
    if env_state := os.environ.get("AGENT_LEARNING_STATE_DIR"):
        return pathlib.Path(env_state).expanduser().resolve()
    return pathlib.Path.cwd().resolve()


def _read_cursor(path: pathlib.Path) -> int:
    try:
        return int(path.read_text(encoding="utf-8").strip())
    except (OSError, ValueError):
        return 0


def _write_cursor(path: pathlib.Path, value: int) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(f"{value}\n", encoding="utf-8")


def _extract_session_id(raw: dict[str, Any]) -> str | None:
    for key in ("session_id", "sessionId", "session"):
        value = raw.get(key)
        if isinstance(value, str) and value.strip():
            return value

    for container_key in ("payload", "context"):
        container = raw.get(container_key)
        if not isinstance(container, dict):
            continue
        for key in ("session_id", "sessionId", "session"):
            value = container.get(key)
            if isinstance(value, str) and value.strip():
                return value
    return None


def _parse_event(raw: dict[str, Any]) -> EventV4:
    """Build an ``EventV4`` from a decoded JSONL row.

    ``schema_version`` is treated as 4 when it is missing or cannot be
    converted with ``int()``. Any non-zero version below 4 is routed through
    ``EventV4.upgrade_from``. Every other value (4, anything higher, or 0)
    goes through ``EventV4.from_dict``.
    """
    declared = raw.get("schema_version", 4)
    try:
        version = int(declared)
    except (TypeError, ValueError):
        version = 4

    # NOTE(review): a schema_version of 0 goes to from_dict rather than
    # upgrade_from; confirm that is intended.
    needs_upgrade = version != 0 and version < 4
    if needs_upgrade:
        return EventV4.upgrade_from(raw)
    return EventV4.from_dict(raw)


def _event_to_row(event: EventV4, session_id: str | None) -> tuple[Any, ...]:
    telemetry = asdict(event.telemetry)
    interrupted = telemetry.get("interrupted")
    telemetry_interrupted = (
        1 if interrupted is True else 0 if interrupted is False else None
    )
    return (
        event.event_id,
        event.ts,
        event.event,
        event.schema_version,
        event.actor.kind,
        event.actor.name,
        event.actor.model,
        event.actor.parent_actor_id,
        telemetry.get("duration_ms"),
        telemetry.get("tokens_in"),
        telemetry.get("tokens_out"),
        telemetry.get("cache_read_tokens"),
        telemetry.get("cache_creation_tokens"),
        telemetry.get("cost_usd"),
        telemetry_interrupted,
        json.dumps([asdict(link) for link in event.correlation_chain], sort_keys=True, separators=(",", ":")),
        event.parent_event_id,
        event.tool_server,
        event.error_class,
        session_id,
    )


def _ensure_schema(conn: sqlite3.Connection) -> bool:
    """Create or validate the events schema and ensure ``events.session_id`` exists.

    The stored ``schema_version`` in ``META_TABLE`` is checked *before* any
    events DDL runs. A mismatched database is therefore rejected untouched,
    instead of first having DDL applied to it.

    Call this outside an explicit transaction. ``executescript`` below
    implicitly COMMITs any pending transaction first, which would silently
    release a lock such as ``BEGIN IMMEDIATE``. This function commits its own
    changes before returning.

    Returns:
        Always True, because ``session_id`` is guaranteed to exist on return.
        The bool is kept for callers that check it.

    Raises:
        RuntimeError: the stored schema_version differs from
            ``EXPECTED_SCHEMA_VERSION``.
    """
    conn.execute(
        f"CREATE TABLE IF NOT EXISTS {META_TABLE}(key TEXT PRIMARY KEY, value TEXT NOT NULL)"
    )
    stored = conn.execute(f"SELECT value FROM {META_TABLE} WHERE key='schema_version'").fetchone()

    if stored is not None:
        try:
            stored_version = int(stored[0])
        except (TypeError, ValueError):
            stored_version = -1
        if stored_version != EXPECTED_SCHEMA_VERSION:
            raise RuntimeError(
                f"events.sqlite schema mismatch: expected schema_version={EXPECTED_SCHEMA_VERSION}, "
                f"found {stored[0]}; re-index required"
            )

    # sqlite_ddl returns CREATE TABLE + CREATE INDEX (multi-statement), so use executescript.
    conn.executescript(EventV4.sqlite_ddl())

    if stored is None:
        conn.execute(
            f"INSERT INTO {META_TABLE}(key, value) VALUES (?, ?)",
            ("schema_version", str(EXPECTED_SCHEMA_VERSION)),
        )

    # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk).
    # Indexing by position works whether or not the caller set sqlite3.Row.
    existing = {info[1] for info in conn.execute("PRAGMA table_info(events)").fetchall()}
    if "session_id" not in existing:
        conn.execute("ALTER TABLE events ADD COLUMN session_id TEXT")

    conn.commit()
    return True


def _insert_sql(has_session_id: bool) -> str:
    """Build the ``INSERT INTO events`` statement for ``_event_to_row`` tuples.

    Column order matches the row tuple. ``session_id`` (last) is included
    only when ``has_session_id`` is true.
    """
    columns = [
        "event_id",
        "ts",
        "event",
        "schema_version",
        "actor_kind",
        "actor_name",
        "actor_model",
        "actor_parent_actor_id",
        "telemetry_duration_ms",
        "telemetry_tokens_in",
        "telemetry_tokens_out",
        "telemetry_cache_read_tokens",
        "telemetry_cache_creation_tokens",
        "telemetry_cost_usd",
        "telemetry_interrupted",
        "correlation_chain",
        "parent_event_id",
        "tool_server",
        "error_class",
    ]
    if has_session_id:
        columns.append("session_id")
    placeholders = ",".join(["?"] * len(columns))
    return f"INSERT INTO events ({','.join(columns)}) VALUES ({placeholders})"


def _collect_rows(
    events_jsonl: pathlib.Path, start_offset: int
) -> tuple[list[tuple[Any, ...]], int]:
    """Parse events from ``events_jsonl`` starting at byte ``start_offset``.

    Returns ``(rows, cursor)``. ``rows`` holds ``_event_to_row`` tuples, and
    ``cursor`` is the byte offset just past the last line consumed.

    Bad rows are quarantined. A malformed newline-terminated line is logged
    to stderr with its *start* offset, counted, and stepped over, so one bad
    row can never wedge the indexer.

    Torn writes are handled separately. A final line without a trailing
    newline that fails to parse is most likely still being appended. It is
    left unconsumed (the cursor stays at its start) so the next run re-reads
    it once complete. Otherwise the head and the tail of that event would
    each be quarantined and the event lost.
    """
    rows: list[tuple[Any, ...]] = []
    cursor = start_offset
    skipped = 0
    with open(events_jsonl, "rb") as handle:
        handle.seek(start_offset)
        for line in handle:
            line_start = cursor
            terminated = line.endswith(b"\n")
            text = line.strip()
            if not text:
                cursor += len(line)
                continue
            # _event_to_row is INSIDE the try so an _extract_session_id /
            # asdict failure on a malformed-but-parsed event still quarantines.
            try:
                payload = json.loads(text.decode("utf-8"))
                if not isinstance(payload, dict):
                    raise ValueError("event row is not a JSON object")
                event = _parse_event(payload)
                row = _event_to_row(event, _extract_session_id(payload))
            except (json.JSONDecodeError, ValueError, UnicodeDecodeError, TypeError, KeyError) as exc:
                if not terminated:
                    # Probably a write still in progress; retry it next run.
                    break
                print(f"warn.index_events_skipped offset={line_start} reason={exc!s}", file=sys.stderr)
                skipped += 1
                cursor += len(line)
                continue
            rows.append(row)
            cursor += len(line)
    if skipped:
        print(f"warn.index_events_total_skipped count={skipped}", file=sys.stderr)
    return rows, cursor


def run(state_dir: pathlib.Path) -> int:
    """Index new lines of ``state_dir/events.jsonl`` into ``state_dir/events.sqlite``.

    Reading starts at the byte offset stored in ``CURSOR_FILE``. That offset
    is reset to 0 if the log is now shorter than it (truncated or rewritten)
    or if it is negative. The parsed events are inserted and the cursor is
    advanced.

    Consistency:
      * Schema setup runs *before* ``BEGIN IMMEDIATE``. ``_ensure_schema``
        uses ``executescript``, which COMMITs any open transaction. Running
        it inside the lock would silently drop the lock and let concurrent
        indexers read the same cursor and insert the same rows twice.
      * The cursor file is written while the write lock is still held, so
        the next indexer to acquire the lock sees it. If COMMIT then fails,
        the previous cursor is restored so those events are retried rather
        than lost.

    Returns:
        The number of rows inserted.
    """
    state_dir.mkdir(parents=True, exist_ok=True)
    events_jsonl = state_dir / "events.jsonl"
    events_sqlite = state_dir / "events.sqlite"
    cursor_path = state_dir / CURSOR_FILE

    conn = sqlite3.connect(events_sqlite)
    conn.row_factory = sqlite3.Row
    try:
        has_session_id = _ensure_schema(conn)
        conn.commit()  # ensure no transaction is open before BEGIN (no-op otherwise)
        conn.execute("BEGIN IMMEDIATE")

        if not events_jsonl.exists():
            _write_cursor(cursor_path, 0)
            return 0

        saved_offset = _read_cursor(cursor_path)
        start_offset = saved_offset
        if start_offset < 0 or events_jsonl.stat().st_size < start_offset:
            start_offset = 0

        rows, cursor = _collect_rows(events_jsonl, start_offset)
        if rows:
            conn.executemany(_insert_sql(has_session_id), rows)

        _write_cursor(cursor_path, cursor)
        try:
            conn.commit()
        except sqlite3.Error:
            # Rows were not persisted: roll the cursor back so they are re-read.
            _write_cursor(cursor_path, saved_offset)
            raise
        return len(rows)
    finally:
        conn.close()


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options.

    ``--state`` (alias ``--state-dir``, stored as ``state``) and ``--repo``
    each take a path and default to None. Passing ``argv=None`` makes
    argparse read ``sys.argv[1:]``.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    path_option = {"type": pathlib.Path, "default": None}
    cli.add_argument("--state", "--state-dir", dest="state", **path_option)
    cli.add_argument("--repo", **path_option)
    namespace = cli.parse_args(argv)
    return namespace


def main(argv: list[str] | None = None) -> int:
    """Run the CLI and return a process exit status.

    Returns 0 on success and prints ``indexed N events`` when rows were
    added. Failures are reported on stderr as ``ERROR: <message>``. The
    status is 1 for a ``RuntimeError`` (e.g. the schema mismatch raised by
    ``_ensure_schema``) and 2 for any other ``Exception``.
    """
    options = parse_args(argv)
    try:
        target_dir = _resolve_state_path(repo=options.repo, state=options.state)
        added = run(target_dir)
    except Exception as exc:
        print(f"ERROR: {exc}", file=sys.stderr)
        return 1 if isinstance(exc, RuntimeError) else 2

    if added:
        print(f"indexed {added} events")
    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
