#!/usr/bin/env python3
"""Backfill transcript events from Claude and Codex transcript logs."""

from __future__ import annotations

import argparse
import datetime as dt
import os
import pathlib
import sys

from state_paths import resolve_state_dir
import event_writer
from transcript_parser import deterministic_event_id, parse_claude_transcript, parse_codex_transcript


def _state_root(state_dir: str | None) -> pathlib.Path:
    """Return the state directory chosen by ``state_paths.resolve_state_dir``.

    ``state_dir`` is the raw ``--state-dir`` value (None when the flag was not
    given); how the default is chosen is up to ``resolve_state_dir``.
    """
    resolved = resolve_state_dir(state_dir)
    return resolved


def _parse_since(raw: str) -> dt.timedelta:
    """Parse a ``--since`` duration such as ``7d``, ``12h``, ``30m`` or ``45s``.

    The value is a non-negative integer followed by one unit letter
    (``d`` days, ``h`` hours, ``m`` minutes, ``s`` seconds). Surrounding
    whitespace and letter case are ignored.

    Raises:
        ValueError: if the text is malformed, the unit is unknown, or the
            duration is too large to represent as a ``timedelta``.
    """
    text = raw.strip().lower()
    unit = text[-1:]
    value = text[:-1]
    # isdecimal() accepts exactly the digit characters int() can parse and
    # rejects signs, so negative or fractional durations are refused here.
    if not value.isdecimal():
        raise ValueError("invalid --since format")
    amount = int(value)

    field = {"d": "days", "h": "hours", "m": "minutes", "s": "seconds"}.get(unit)
    if field is None:
        raise ValueError(f"invalid --since unit {unit!r}")
    try:
        return dt.timedelta(**{field: amount})
    except OverflowError as exc:
        # timedelta caps at ~999999999 days; surface this as a ValueError so
        # callers that validate --since report it instead of crashing.
        raise ValueError(f"--since {raw.strip()!r} is too large") from exc


def _transcript_files(root: pathlib.Path) -> list[pathlib.Path]:
    """List transcript files under ``root``, oldest modification time first.

    - ``root`` missing: returns ``[]``.
    - ``root`` is a file: returns ``[root]`` regardless of extension.
    - Otherwise: every regular ``*.jsonl`` / ``*.json`` file found recursively,
      de-duplicated by resolved path (first occurrence wins, ``.jsonl`` scanned
      before ``.json``), sorted by ``(mtime, path string)``.

    Entries that are not regular files (directories named ``*.json``, broken
    symlinks) or that disappear or become unreadable while scanning are skipped
    rather than aborting the whole listing.
    """
    if not root.exists():
        return []
    if root.is_file():
        return [root]

    seen: set[pathlib.Path] = set()
    keyed: list[tuple[float, str, pathlib.Path]] = []
    for pattern in ("*.jsonl", "*.json"):
        for candidate in root.rglob(pattern):
            try:
                if not candidate.is_file():
                    continue
                resolved = candidate.resolve()
                # Stat once here instead of inside the sort key, so a file that
                # vanishes mid-scan is skipped instead of raising from sorted().
                mtime = candidate.stat().st_mtime
            except OSError:
                continue
            if resolved in seen:
                continue
            seen.add(resolved)
            keyed.append((mtime, str(candidate), candidate))

    keyed.sort(key=lambda item: (item[0], item[1]))
    return [path for _, _, path in keyed]


def _collect_rows(path: pathlib.Path, source: str) -> list[dict]:
    """Parse one transcript file into a list of row dicts.

    ``source == "claude"`` selects the Claude parser; any other value falls
    through to the Codex parser.
    """
    parser = parse_claude_transcript if source == "claude" else parse_codex_transcript
    return list(parser(path))


def _collect_recent_rows(directory: str, source: str, cutoff: dt.datetime) -> list[dict]:
    """Return rows from ``source`` transcripts under ``directory`` whose file mtime is >= ``cutoff``, each stamped with a deterministic ``event_id``."""
    rows: list[dict] = []
    for file in _transcript_files(pathlib.Path(directory).expanduser()):
        try:
            mtime = file.stat().st_mtime
        except OSError:
            # The file disappeared (or became unreadable) after it was listed.
            continue
        # Filtering is by file modification time, not by per-event timestamps.
        if dt.datetime.fromtimestamp(mtime, tz=dt.timezone.utc) < cutoff:
            continue
        for offset, row in enumerate(_collect_rows(file, source=source)):
            # The id is derived from (event name, file, position in file);
            # rows are mutated in place.
            row["event_id"] = deterministic_event_id(
                row.get("event", "transcript_event"),
                file,
                offset,
            )
            rows.append(row)
    return rows


def backfill(since: str, *, claude_dir: str, codex_dir: str, state_dir: str | None) -> int:
    """Write transcript events from recently modified Claude and Codex logs.

    Args:
        since: duration window accepted by ``_parse_since`` (e.g. ``7d``); only
            files modified within this window before now (UTC) are read.
        claude_dir: directory (or single file) of Claude transcripts; ``~`` is expanded.
        codex_dir: directory (or single file) of Codex transcripts; ``~`` is expanded.
        state_dir: accepted for interface compatibility but not read here.

    Returns:
        Number of rows passed to ``event_writer.write_events_batch`` (0 when
        nothing matched, in which case the writer is not called).

    Raises:
        ValueError: if ``since`` is not a valid duration.
    """
    # NOTE(review): `state_dir` is unused; main() exports it through
    # AGENT_LEARNING_STATE_DIR before calling, presumably because event_writer
    # reads that variable. Confirm before calling backfill() directly.
    now = dt.datetime.now(dt.timezone.utc)
    window = _parse_since(since)
    try:
        cutoff = now - window
    except OverflowError:
        # Window reaches before year 1: treat it as "no lower bound".
        cutoff = dt.datetime.min.replace(tzinfo=dt.timezone.utc)

    # Claude rows come first, then Codex rows.
    rows = [
        *_collect_recent_rows(claude_dir, "claude", cutoff),
        *_collect_recent_rows(codex_dir, "codex", cutoff),
    ]
    if not rows:
        return 0
    event_writer.write_events_batch(rows, source="transcript")
    return len(rows)


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command line (``sys.argv[1:]`` when ``argv`` is None).

    ``--since`` is mandatory; the transcript directories default to the
    user's ``~/.claude/projects`` and ``~/.codex/sessions``.
    """
    home = os.path.expanduser
    cli = argparse.ArgumentParser(description=__doc__)
    options = (
        ("--since", {"required": True, "help": "Duration window, e.g. 7d, 30d, 12h"}),
        ("--claude-dir", {"default": home("~/.claude/projects")}),
        ("--codex-dir", {"default": home("~/.codex/sessions")}),
        ("--state-dir", {"default": None, "help": "Override AGENT_LEARNING_STATE_DIR"}),
    )
    for flag, settings in options:
        cli.add_argument(flag, **settings)
    return cli.parse_args(argv)


def main(argv: list[str] | None = None) -> int:
    """CLI entry point; returns the process exit code.

    Exit codes: 0 on success, 1 when the state directory cannot be resolved
    or the backfill fails, 2 when ``--since`` is invalid.
    """
    args = parse_args(argv)
    try:
        _parse_since(args.since)
    except (ValueError, OverflowError) as exc:
        # OverflowError is caught too so an oversized window is reported,
        # not turned into a traceback.
        print(f"invalid --since: {exc}", file=sys.stderr)
        return 2

    # Validate state dir early and keep behavior aligned with event_writer writes.
    try:
        state = _state_root(args.state_dir)
    except Exception as exc:  # noqa: BLE001 - top-level CLI boundary
        print(f"backfill_transcripts failed: cannot resolve state dir: {exc}", file=sys.stderr)
        return 1
    if not state:
        print("backfill_transcripts failed: no state dir resolved", file=sys.stderr)
        return 1
    os.environ["AGENT_LEARNING_STATE_DIR"] = str(state)

    try:
        count = backfill(args.since, claude_dir=args.claude_dir, codex_dir=args.codex_dir, state_dir=args.state_dir)
    except Exception as exc:  # noqa: BLE001 - top-level CLI boundary
        print(f"backfill_transcripts failed: {exc}", file=sys.stderr)
        return 1
    print(f"wrote {count} events")
    return 0


if __name__ == "__main__":
    # sys.exit raises SystemExit with main()'s return value as the exit status.
    sys.exit(main())
