#!/usr/bin/env python3
"""Extract text-bearing messages from agent JSONL transcripts."""

from __future__ import annotations

import argparse
import dataclasses
import json
import os
import pathlib
import re
import sys
import time

from scrub_secrets import scrub as scrub_secrets

# Default cap on transcript files processed per root path; used as the default
# for both extract() and the --max-sessions command-line option.
DEFAULT_MAX_SESSIONS = 50
# Label written into the "meta: sampled_sessions" row. It names the
# oldest/middle/newest split that sample_files produces at the default cap.
# NOTE(review): the same label is emitted for other --max-sessions values,
# where the actual split differs — confirm whether that is intended.
SAMPLING_STRATEGY = "oldest10_middle15_newest25"
# Matches UUID-shaped hex groups (8-4-4-4-12, case-insensitive); session_ref
# replaces each match with the word "session".
UUIDISH_RE = re.compile(r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}", re.I)


@dataclasses.dataclass(frozen=True)
class FileSelection:
    """Outcome of choosing which transcript files to read for one root path.

    ``frozen=True`` prevents reassigning attributes; the ``files`` list itself
    remains an ordinary mutable list.
    """

    # Transcript files chosen for processing, after any sampling.
    files: list[pathlib.Path]
    # Number of matching files before sampling was applied.
    total: int
    # True when fewer files were selected than were available.
    sampled: bool
    # The file or directory path the selection was made from.
    root: pathlib.Path


def session_ref(file_path: pathlib.Path) -> str:
    """Return a short reference string identifying a transcript file.

    The reference is built from the parent directory name and the file stem,
    with UUID-shaped substrings replaced by ``session`` and every run of
    characters outside ``A-Za-z0-9_.@/-`` collapsed to a single ``-``.
    """
    label = "/".join((file_path.parent.name, file_path.stem))
    anonymized = UUIDISH_RE.sub("session", label)
    sanitized = re.sub(r"[^A-Za-z0-9_.@/-]+", "-", anonymized)
    # Only the trailing 96 characters are kept, so long directory names are
    # cut from the front rather than the file stem from the back.
    return sanitized[-96:]


def message_text(content) -> str:
    """Return the plain text carried by a message ``content`` value.

    Supported shapes:

    * ``str`` -- returned unchanged.
    * ``dict`` -- its ``"text"`` string if present, otherwise the string
      entries of its ``"parts"`` list joined with newlines.
    * ``list`` -- bare strings plus the ``"text"`` of dict blocks whose
      ``"type"`` is ``text``, ``input_text`` or ``output_text``, joined with
      newlines; empty pieces are dropped.

    Any other shape yields an empty string.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, dict):
        text = content.get("text")
        if isinstance(text, str):
            return text
        parts = content.get("parts")
        if isinstance(parts, list):
            return "\n".join(part for part in parts if isinstance(part, str))
        return ""
    if isinstance(content, list):
        # A tuple is used for membership so an unhashable "type" value (a
        # list or dict in malformed input) compares unequal instead of
        # raising TypeError, which a set lookup would do.
        text_types = ("text", "input_text", "output_text")
        pieces = []
        for block in content:
            if isinstance(block, str):
                pieces.append(block)
            elif isinstance(block, dict) and block.get("type") in text_types:
                text = block.get("text")
                # An explicit JSON null must not become the literal "None".
                if text is not None:
                    pieces.append(str(text))
        return "\n".join(piece for piece in pieces if piece)
    return ""


def normalize_message(event: dict) -> dict:
    """Reduce a transcript event to ``{"role": ..., "content": ...}``.

    Wrapper layers are peeled off first: a dict under ``"message"``, or a dict
    under ``"payload"`` whose ``"type"`` is ``"message"``. On the innermost
    dict the role is taken from ``author.role`` when that is set, otherwise
    from ``role``. An empty dict is returned when no role is found.
    """
    current = event
    while True:
        wrapped = current.get("message")
        if isinstance(wrapped, dict):
            current = wrapped
            continue
        payload = current.get("payload")
        if isinstance(payload, dict) and payload.get("type") == "message":
            current = payload
            continue
        break

    author = current.get("author")
    role = author.get("role") if isinstance(author, dict) else None
    if not role:
        # Fall back to the top-level role when the author carries none.
        role = current.get("role")
    if not role:
        return {}
    return {"role": role, "content": current.get("content")}


def transcript_files(path: pathlib.Path, days: int | None) -> list[pathlib.Path]:
    """Return transcript files under ``path`` ordered oldest to newest.

    Args:
        path: A single transcript file, or a directory searched recursively
            for ``*.jsonl`` and ``*.json`` entries.
        days: When not ``None``, files last modified more than this many days
            ago are excluded.

    Returns:
        Regular files with a ``.jsonl`` or ``.json`` suffix, sorted by
        modification time and then by path string.
    """
    cutoff = None if days is None else time.time() - days * 86400
    if path.is_file():
        candidates = [path]
    else:
        candidates = [*path.rglob("*.jsonl"), *path.rglob("*.json")]

    dated: list[tuple[float, str, pathlib.Path]] = []
    for file_path in candidates:
        if file_path.suffix not in {".jsonl", ".json"}:
            continue
        try:
            # rglob also matches directories whose names end in .json/.jsonl;
            # those cannot be read as transcripts.
            if not file_path.is_file():
                continue
            # Stat once and keep the result so the filter and the sort agree
            # and a file that disappears mid-scan is skipped, not fatal.
            modified = file_path.stat().st_mtime
        except OSError:
            continue
        if cutoff is not None and modified < cutoff:
            continue
        dated.append((modified, str(file_path), file_path))

    dated.sort(key=lambda entry: (entry[0], entry[1]))
    return [entry[2] for entry in dated]


def walk_json(value):
    """Yield every dict found in a decoded JSON value, parents before children.

    Dicts are yielded in depth-first, document order. Lists are traversed but
    not yielded themselves; scalars are ignored.
    """
    pending = [value]
    while pending:
        node = pending.pop()
        if isinstance(node, dict):
            yield node
            children = list(node.values())
        elif isinstance(node, list):
            children = node
        else:
            continue
        # Pushed in reverse so the first child is the next one popped.
        pending.extend(reversed(children))


def iter_events(file_path: pathlib.Path):
    """Yield dict events from one transcript file, skipping unreadable input.

    ``.jsonl`` files are read line by line: each line that parses to a JSON
    object is yielded, and lines that fail to parse or decode to a non-dict
    are skipped. Any other file is parsed as a single JSON document and every
    dict inside it is yielded via ``walk_json``.

    A file that cannot be opened or read yields nothing instead of raising.
    """
    if file_path.suffix == ".jsonl":
        try:
            handle = file_path.open("r", encoding="utf-8", errors="replace")
        except OSError:
            # Same best-effort policy as the single-document branch below:
            # one unreadable transcript must not abort the whole extraction.
            return
        with handle:
            for raw in handle:
                try:
                    event = json.loads(raw)
                except json.JSONDecodeError:
                    continue
                if isinstance(event, dict):
                    yield event
        return

    try:
        root = json.loads(file_path.read_text(encoding="utf-8", errors="replace"))
    except (json.JSONDecodeError, OSError):
        return
    yield from walk_json(root)


def pick_middle(files: list[pathlib.Path], count: int) -> list[pathlib.Path]:
    """Pick up to ``count`` roughly evenly spaced entries from ``files``.

    Returns an empty list when ``count`` is not positive or ``files`` is
    empty, and ``files`` itself when it holds no more than ``count`` entries.
    Otherwise the picks are returned sorted by modification time, then path.
    """
    total = len(files)
    if count <= 0 or total == 0:
        return []
    if total <= count:
        return files

    stride = total / (count + 1)
    last_index = total - 1
    # A dict is used as an insertion-ordered set of the chosen paths.
    chosen: dict[pathlib.Path, None] = {}
    for slot in range(1, count + 1):
        position = min(last_index, max(0, round(slot * stride) - 1))
        chosen.setdefault(files[position], None)

    # If rounding produced duplicates, top up from the start of the list.
    for candidate in files:
        if len(chosen) >= count:
            break
        chosen.setdefault(candidate, None)

    return sorted(chosen, key=lambda item: (item.stat().st_mtime, str(item)))


def sample_files(files: list[pathlib.Path], max_sessions: int | None) -> list[pathlib.Path]:
    """Reduce ``files`` to at most ``max_sessions`` entries.

    ``files`` is returned untouched when ``max_sessions`` is ``None`` or the
    list already fits. Otherwise the sample takes up to 10 entries from the
    front, up to 25 from the back, and fills the remaining budget from the
    span in between via ``pick_middle``. The result is sorted by modification
    time, then path.
    """
    total = len(files)
    if max_sessions is None or total <= max_sessions:
        return files
    if max_sessions < 1:
        return []

    head_size = min(10, max_sessions)
    tail_size = min(25, max_sessions - head_size)
    middle_size = max_sessions - head_size - tail_size

    # total > max_sessions here, so the head and tail slices never overlap.
    tail_start = total - tail_size
    head = files[:head_size]
    tail = files[tail_start:]
    middle = pick_middle(files[head_size:tail_start], middle_size)

    unique = list(dict.fromkeys([*head, *middle, *tail]))
    unique.sort(key=lambda item: (item.stat().st_mtime, str(item)))
    return unique[:max_sessions]


def file_matches_cwd(file_path: pathlib.Path, cwd: str | None) -> bool:
    """Report whether a transcript file belongs to working directory ``cwd``.

    With ``cwd`` set to ``None`` every file matches. Otherwise the file
    matches as soon as any event records exactly ``cwd``. When no event
    records a working directory at all, the decision falls back to checking
    whether ``cwd`` (slashes turned into dashes) appears in the file path.
    """
    if cwd is None:
        return True

    marker_seen = False
    for event in iter_events(file_path):
        recorded = session_cwd(event)
        if recorded is None:
            continue
        if recorded == cwd:
            return True
        marker_seen = True

    if marker_seen:
        # The file names working directories, and none of them is ours.
        return False

    slug = cwd.strip("/").replace("/", "-")
    return bool(slug) and slug in str(file_path)


def select_jsonl_files(path: pathlib.Path, days: int | None, max_sessions: int | None, cwd: str | None = None) -> FileSelection:
    """Collect, filter, and sample the transcript files under ``path``.

    Files are gathered with ``transcript_files``, narrowed to those matching
    ``cwd``, then reduced with ``sample_files``. The returned selection
    records the pre-sampling count and whether sampling dropped any file.
    """
    matching = [
        candidate
        for candidate in transcript_files(path, days)
        if file_matches_cwd(candidate, cwd)
    ]
    chosen = sample_files(matching, max_sessions)
    available = len(matching)
    return FileSelection(
        files=chosen,
        total=available,
        sampled=len(chosen) < available,
        root=path,
    )


def session_cwd(event: dict) -> str | None:
    """Return the working directory recorded on an event, if any.

    For a ``session_meta`` event with a dict payload, only ``payload["cwd"]``
    is consulted. For every other event the first truthy value among the
    top-level ``cwd`` and ``project`` keys is used. Values are returned as
    strings; ``None`` means nothing was recorded.
    """
    payload = event.get("payload")
    if event.get("type") == "session_meta" and isinstance(payload, dict):
        recorded = payload.get("cwd")
        if recorded:
            return str(recorded)
        # Top-level keys are deliberately not consulted for session_meta.
        return None

    candidates = (event.get("cwd"), event.get("project"))
    return next((str(value) for value in candidates if value), None)


def extract(path: pathlib.Path, days: int | None, cwd: str | None = None, max_sessions: int | None = DEFAULT_MAX_SESSIONS) -> list[str]:
    """Return one output row per distinct user/assistant message under ``path``.

    Args:
        path: Transcript file or directory to read.
        days: Optional age limit in days, passed to file selection.
        cwd: When set, events that record a different working directory are
            skipped (events recording none are kept).
        max_sessions: Cap on files read; ``None`` disables sampling.

    Returns:
        Rows of the form ``"<role>: <text> [session_ref=<ref>]"``. When
        sampling dropped files, a ``meta: sampled_sessions`` row comes first.
        Within one file, a repeated (role, text) pair is emitted only once.
    """
    rows: list[str] = []
    selection = select_jsonl_files(path, days, max_sessions, cwd)
    if selection.sampled:
        rows.append(
            "meta: sampled_sessions "
            f"root={selection.root} selected={len(selection.files)} total={selection.total} "
            f"strategy={SAMPLING_STRATEGY}"
        )
    # A tuple, not a set: membership then compares with == and never hashes,
    # so a malformed role such as a list or dict is skipped instead of
    # raising TypeError.
    wanted_roles = ("user", "assistant")
    for file_path in selection.files:
        ref = session_ref(file_path)
        seen_messages: set[tuple[str, str]] = set()
        for event in iter_events(file_path):
            event_cwd = session_cwd(event)
            if cwd is not None and event_cwd is not None and event_cwd != cwd:
                continue
            message = normalize_message(event)
            role = message.get("role")
            if role not in wanted_roles:
                continue
            # Secrets are scrubbed before de-duplication and output.
            text = scrub_secrets(message_text(message.get("content"))).strip()
            key = (str(role), text)
            if text and key not in seen_messages:
                seen_messages.add(key)
                rows.append(f"{role}: {text} [session_ref={ref}]")
    return rows


def main(argv: list[str] | None = None) -> int:
    """Run the command-line interface and return the process exit code.

    Parses ``argv`` (argparse falls back to ``sys.argv`` when it is
    ``None``), extracts rows from every ``--path`` in order, and writes them
    newline-terminated to ``--output`` if given, otherwise to standard
    output. A negative ``--max-sessions`` is reported as a usage error.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--path",
        action="append",
        default=None,
        help="Transcript file or directory. Repeat for cross-agent corpora.",
    )
    parser.add_argument("--days", type=int, default=None)
    parser.add_argument("--cwd", help="Only include sessions whose session_meta.payload.cwd matches this path")
    parser.add_argument(
        "--max-sessions",
        type=int,
        default=DEFAULT_MAX_SESSIONS,
        help="Maximum JSONL transcript files to process per path. Use 0 to disable reading.",
    )
    parser.add_argument(
        "--no-sampling",
        action="store_true",
        help="Process every matching transcript file. Intended only for focused debugging.",
    )
    parser.add_argument("--output")
    args = parser.parse_args(argv)
    if args.max_sessions < 0:
        parser.error("--max-sessions must be >= 0")

    roots = args.path if args.path else [os.path.expanduser("~/.codex/sessions")]
    # --no-sampling lifts the cap entirely by passing None downstream.
    session_cap = None if args.no_sampling else args.max_sessions

    rows: list[str] = []
    for root in roots:
        rows += extract(pathlib.Path(root).expanduser(), args.days, args.cwd, max_sessions=session_cap)

    rendered = "\n".join(rows)
    if rendered:
        rendered = f"{rendered}\n"

    if not args.output:
        sys.stdout.write(rendered)
        return 0
    pathlib.Path(args.output).write_text(rendered, encoding="utf-8")
    return 0


# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())
