#!/usr/bin/env python3
"""Invoke archived agents and emit spawn telemetry events.

Behavior is intentionally thin: validate the agent frontmatter, emit start/end events,
and dispatch either to mock Claude/Codex adapters or a sandboxed fallback.
"""

from __future__ import annotations

import argparse
import contextlib
import hashlib
import json
import os
import pathlib
import re
import shlex
import subprocess
import sys
import time
from dataclasses import dataclass
from typing import Any

import time

try:
    from state_handle import StateHandle
except Exception:  # pragma: no cover
    from bin.state_handle import StateHandle

try:
    from alc_apply_contracts import validate_agent_frontmatter
except Exception:  # pragma: no cover
    from bin.alc_apply_contracts import validate_agent_frontmatter

try:
    from event_writer import write_event
except Exception:  # pragma: no cover
    from bin.event_writer import write_event


# Parent of the directory containing this script (presumably the repo root when
# the script lives in e.g. bin/ — confirm against the install layout).
REPO_ROOT = pathlib.Path(__file__).resolve().parents[1]
# Manifest whose dev-agent archive lifecycle drives retention cleanup.
CONTRACT_PATH = REPO_ROOT.joinpath("data-contracts", "manifests", "u12-invoke.json")

MAX_TOKEN_CEILING = 4_096
ALLOWED_CACHE_BACKENDS = {"exec", "codex", "claude"}


@dataclass(frozen=True)
class DispatchResult:
    """Immutable outcome of a single backend dispatch.

    ``backend`` names the path that produced the result: ``"claude"`` or
    ``"codex"`` for the mock adapters, ``"exec"`` for the sandbox fallback,
    or ``"error"`` when ``invoke`` synthesises a result for a failed dispatch.
    """

    # Text produced by the backend (or a truncated error message).
    output: str
    # Elapsed time measured for the dispatch, in milliseconds.
    duration_ms: int
    # Token counts; heuristic estimates in this module, not billed usage.
    tokens_in: int
    tokens_out: int
    # None when the cost is unknown (error path).
    cost_usd: float | None
    backend: str


@contextlib.contextmanager
def _event_writer_state(state: StateHandle):
    # event_writer resolves AGENT_LEARNING_STATE_DIR → events.jsonl at <state_dir>/events.jsonl;
    # StateHandle.events_jsonl == repo_state_dir/events.jsonl, so we point at repo_state_dir
    # (not state_root) to land in the indexed surface alc_query/index_events read.
    previous = os.environ.get("AGENT_LEARNING_STATE_DIR")
    os.environ["AGENT_LEARNING_STATE_DIR"] = str(state.repo_state_dir)
    try:
        yield
    finally:
        if previous is None:
            os.environ.pop("AGENT_LEARNING_STATE_DIR", None)
        else:
            os.environ["AGENT_LEARNING_STATE_DIR"] = previous


def _read_frontmatter(content: str) -> tuple[dict[str, Any], str]:
    if not content.startswith("---"):
        return {}, content
    parts = content.split("---", 2)
    if len(parts) < 3:
        return {}, content
    front_text = parts[1].strip()
    body = parts[2]

    # Use YAML if available; otherwise do a small, permissive fallback parser.
    try:
        import yaml  # type: ignore

        loaded = yaml.safe_load(front_text)
        if isinstance(loaded, dict):
            return {str(k): v for k, v in loaded.items()}, body
    except Exception:
        pass

    frontmatter: dict[str, Any] = {}
    for line in front_text.splitlines():
        if ":" not in line:
            continue
        key, value = line.split(":", 1)
        frontmatter[key.strip()] = value.strip().strip("\"'")
    return frontmatter, body


def _resolve_state_handle(repo: pathlib.Path) -> StateHandle:
    """Return the ``StateHandle`` that ``StateHandle.for_repo`` builds for ``repo``."""
    handle = StateHandle.for_repo(repo)
    return handle


def _personal_archive_roots(state: StateHandle) -> list[pathlib.Path]:
    roots: list[pathlib.Path] = []
    env_personal = os.environ.get("AGENT_LEARNING_PERSONAL")
    if env_personal:
        roots.append(pathlib.Path(env_personal).expanduser().resolve() / "alc-agents")
        roots.append(pathlib.Path(env_personal).expanduser().resolve() / "alc-agents" / "personal")
    roots.append(state.state_root / "alc-agents" / "personal")
    roots.append(state.state_root / "alc-agents")

    ordered: list[pathlib.Path] = []
    for root in roots:
        if root not in ordered:
            ordered.append(root)
    return ordered


def _normalise_agent_ref(agent_ref: str) -> str:
    return str(pathlib.Path(agent_ref).as_posix().lstrip("./"))


def _resolve_agent_path(state: StateHandle, agent_ref: str) -> pathlib.Path:
    """Locate the archived agent file referenced by ``agent_ref``.

    The ref is normalised with ``_normalise_agent_ref``. A leading ``alc-agents``
    segment is dropped when at least two segments follow it. Candidates are then
    probed in order, and the first existing file is returned:

    1. ``<rest>`` inside ``state.alc_agents_dirs[<first segment>]`` when the
       first segment is ``dev``, ``test`` or ``evals``;
    2. for each personal archive root: ``<root>/<ref>``, then
       ``<root>/personal/<basename of ref>``;
    3. ``<ref>`` joined onto each of the ``dev``, ``test`` and ``evals`` dirs.

    NOTE(review): step 2's basename probe also applies to qualified refs, so a
    ``dev/foo.md`` missing from the dev archive can resolve to a personal
    ``foo.md``. This is kept unchanged; confirm whether it is intended.

    Raises:
        FileNotFoundError: if no candidate exists.
    """
    request = pathlib.Path(_normalise_agent_ref(agent_ref))
    parts = request.parts
    if parts and parts[0] == "alc-agents" and len(parts) >= 3:
        request = pathlib.Path(*parts[1:])
        parts = request.parts

    category_dirs: dict[str, pathlib.Path] = {
        name: state.alc_agents_dirs[name] for name in ("dev", "test", "evals")
    }

    # 1. Category-qualified ref.
    if len(parts) >= 2 and parts[0] in category_dirs:
        candidate = category_dirs[parts[0]] / pathlib.Path(*parts[1:])
        if candidate.is_file():
            return candidate

    # 2. Personal archives (any repo). Every ``<root>/<ref>`` is probed here,
    # so step 3 does not need to revisit these roots.
    for personal_root in _personal_archive_roots(state):
        candidate = personal_root / request
        if candidate.is_file():
            return candidate
        # An empty ref has no basename to look up under personal/.
        if request.name:
            by_basename = personal_root / "personal" / request.name
            if by_basename.is_file():
                return by_basename

    # 3. Unqualified agent name: try each archive category in order.
    for category_dir in category_dirs.values():
        candidate = category_dir / request
        if candidate.is_file():
            return candidate
    raise FileNotFoundError(f"agent not found at {agent_ref}")


def _relative_agent_path(state: StateHandle, resolved: pathlib.Path) -> str:
    """Express ``resolved`` relative to the first archive root that contains it.

    The roots are tried in order: the ``dev``, ``test`` and ``evals`` category
    dirs, then the personal archive roots. If none contains ``resolved``, only
    its file name is returned.
    """
    search_roots = [state.alc_agents_dirs[key] for key in ("dev", "test", "evals")]
    search_roots.extend(_personal_archive_roots(state))
    for base in search_roots:
        try:
            relative = resolved.relative_to(base)
        except ValueError:
            continue
        return str(relative)
    return resolved.name


def _hash_prompt(text: str) -> str:
    return hashlib.sha256(text.encode("utf-8")).hexdigest()[:16]


def _emit_start(
    *,
    state: StateHandle,
    actor: dict[str, Any],
    agent_path: str,
    task: str,
) -> str:
    """Write a ``subagent_invoke_start`` event and return the id from ``write_event``.

    Only a truncated SHA-256 of ``task`` is recorded, not the prompt text.
    The telemetry counters are zeroed placeholders.
    """
    prompt_digest = _hash_prompt(task)
    start_telemetry = dict(duration_ms=0, tokens_in=0, tokens_out=0, cost_usd=None)
    start_row = dict(
        event="subagent_invoke_start",
        actor=actor,
        payload=dict(agent_path=agent_path, task_prompt_hash=prompt_digest),
        telemetry=start_telemetry,
    )
    with _event_writer_state(state):
        event_id = write_event(start_row, source="background", auto_id_fallback=True)
    return event_id


def _parse_age_days(value: str) -> int | None:
    if not value:
        return None
    m = re.fullmatch(r"(\d+)([smhdw])", value.strip(), re.I)
    if not m:
        return None
    amount = int(m.group(1))
    unit = m.group(2).lower()
    if unit == "s":
        return max(1, int(amount / 86400) + (1 if amount % 86400 else 0))
    if unit == "m":
        return max(1, int(amount / 1440) + (1 if amount % 1440 else 0))
    if unit == "h":
        return max(1, int(amount / 24) + (1 if amount % 24 else 0))
    if unit == "d":
        return amount
    if unit == "w":
        return amount * 7
    return None


def _load_dev_agent_retention_days() -> int:
    """Return the retention window, in days, for the dev-agent archive.

    Reads ``lifecycle.max_age`` of the ``u12-invoke-dev-agent-archive`` artifact
    in ``CONTRACT_PATH``. The default of 30 is returned when any of these hold:
    the manifest file is missing; reading or parsing raises; no matching
    artifact carries a usable age; or the age parses to zero.
    """
    default_days = 30
    if not CONTRACT_PATH.is_file():
        return default_days
    try:
        manifest = json.loads(CONTRACT_PATH.read_text(encoding="utf-8"))
        matching = (
            entry
            for entry in manifest.get("artifacts", [])
            if entry.get("id") == "u12-invoke-dev-agent-archive"
        )
        for entry in matching:
            raw_age = str(entry.get("lifecycle", {}).get("max_age", "")).strip()
            days = _parse_age_days(raw_age) if raw_age else None
            if days:
                return days
    except Exception:
        # Best effort: any malformed manifest falls back to the default.
        return default_days
    return default_days


def _cleanup_stale_dev_agents(state: StateHandle) -> None:
    """Delete ``*.md`` files in the dev archive whose mtime is at or before the cutoff.

    The cutoff is "now" minus the retention window from
    ``_load_dev_agent_retention_days``. Per-file ``OSError`` (from stat or
    unlink) is ignored, so cleanup is best effort.
    """
    retention_seconds = _load_dev_agent_retention_days() * 24 * 3600
    oldest_allowed = time.time() - retention_seconds
    dev_dir = state.alc_agents_dirs["dev"]
    for agent_file in sorted(dev_dir.glob("*.md")):
        try:
            is_stale = agent_file.stat().st_mtime <= oldest_allowed
            if is_stale:
                agent_file.unlink()
        except OSError:
            pass


def _emit_end(
    *,
    state: StateHandle,
    actor: dict[str, Any],
    parent_event_id: str,
    outcome: dict[str, Any],
    telemetry: dict[str, Any],
) -> str:
    """Write a ``subagent_invoke_end`` event linked to ``parent_event_id``.

    The caller's ``telemetry`` is stored as-is and ``outcome`` is nested under
    ``payload``. Returns the event id from ``write_event``.
    """
    end_row = dict(
        event="subagent_invoke_end",
        actor=actor,
        telemetry=telemetry,
        parent_event_id=parent_event_id,
        payload=dict(outcome=outcome),
    )
    with _event_writer_state(state):
        event_id = write_event(end_row, source="background", auto_id_fallback=True)
    return event_id


def _estimate_tokens(text: str) -> int:
    return max(1, len(text) // 4)


def _mock_dispatch(backend: str, *, agent_name: str, model: str, task_prompt_hash: str) -> DispatchResult:
    """Return a canned ``DispatchResult`` for the mock Claude/Codex adapters.

    The output is a colon-separated fingerprint of the inputs. Duration (4 ms),
    input tokens (0) and cost (0.0) are fixed placeholders.
    """
    fingerprint = f"{backend}:mock:{agent_name}:{model}:{task_prompt_hash}"
    estimated_out = _estimate_tokens(fingerprint)
    return DispatchResult(
        backend=backend,
        output=fingerprint,
        duration_ms=4,
        tokens_in=0,
        tokens_out=estimated_out,
        cost_usd=0.0,
    )

def _exec_sandbox_dispatch(
    *,
    state: StateHandle,
    repo: pathlib.Path,
    agent_name: str,
    model: str,
    task_prompt_hash: str,
    sandbox_depth: int,
) -> DispatchResult:
    """Run a mock prompt through the sibling ``exec_sandbox`` script in read scope.

    A JSON payload is written to
    ``<repo>/.agent-learning/.alc_invoke_prompt_<hash>.txt``. The sandbox then
    ``cat``s that file at depth ``max(0, sandbox_depth) + 1``. The prompt file is
    removed afterwards, even if the subprocess could not be started.

    The sandbox's stdout is expected to be a JSON object. If its ``"stdout"`` key
    names a readable, non-empty file, that file's text is the result; otherwise
    the JSON itself is. Token counts are heuristic estimates.

    ``state`` is not used here; it is part of the shared dispatcher signature.

    Raises:
        RuntimeError: the sandbox exits non-zero, or prints output that is not
            a JSON object.
    """
    start = time.perf_counter()
    depth = max(0, sandbox_depth) + 1
    exec_path = pathlib.Path(__file__).with_name("exec_sandbox")
    payload = {
        "system_prompt": "Arkiv mock dispatch",
        "agent": agent_name,
        "model": model,
        "hash": task_prompt_hash,
    }
    prompt_path = repo / ".agent-learning" / f".alc_invoke_prompt_{task_prompt_hash}.txt"
    prompt_path.parent.mkdir(parents=True, exist_ok=True)
    prompt_path.write_text(json.dumps(payload), encoding="utf-8")
    command = "cat " + shlex.quote(prompt_path.as_posix())
    try:
        proc = subprocess.run(
            [
                sys.executable,
                str(exec_path),
                "--scope",
                "read",
                "--cmd",
                command,
                "--repo",
                str(repo),
                "--depth",
                str(depth),
                "--actor-kind",
                "arkiv_agent",
                "--actor-name",
                agent_name,
            ],
            text=True,
            capture_output=True,
        )
    finally:
        # Never leave the prompt file behind, even if the launch itself raised.
        with contextlib.suppress(OSError):
            prompt_path.unlink()
    duration_ms = int((time.perf_counter() - start) * 1000)
    if proc.returncode != 0:
        raise RuntimeError(f"subprocess failed: {proc.stderr.strip() or proc.stdout.strip()}")

    # Whitespace-only stdout counts as "no report" rather than invalid JSON.
    raw_stdout = proc.stdout.strip() if proc.stdout else ""
    try:
        data = json.loads(raw_stdout) if raw_stdout else {}
    except json.JSONDecodeError as exc:
        raise RuntimeError(f"exec_sandbox returned non-JSON output: {raw_stdout[:100]}") from exc
    if not isinstance(data, dict):
        raise RuntimeError(f"exec_sandbox returned non-object JSON: {raw_stdout[:100]}")

    output_text = ""
    stdout_path = data.get("stdout")
    if isinstance(stdout_path, str) and stdout_path:
        try:
            output_text = pathlib.Path(stdout_path).read_text(encoding="utf-8").strip()
        except OSError:
            # Missing or unreadable capture file: fall back to the raw JSON below.
            output_text = ""
    if not output_text:
        output_text = json.dumps(data)

    return DispatchResult(
        output=output_text,
        duration_ms=duration_ms,
        tokens_in=_estimate_tokens(task_prompt_hash),
        tokens_out=_estimate_tokens(output_text),
        cost_usd=0.0,
        backend="exec",
    )


def _dispatch(
    *,
    state: StateHandle,
    repo: pathlib.Path,
    agent_name: str,
    model: str,
    task: str,
    task_prompt_hash: str,
    sandbox_depth: int,
) -> DispatchResult:
    """Pick a backend from the environment and dispatch the agent through it.

    A non-empty ``CLAUDE_PLUGIN_ROOT`` selects the mock Claude adapter. Failing
    that, a non-empty ``CODEX_PLUGIN_ROOT`` or ``CODEX_HOME`` selects the mock
    Codex adapter. Otherwise the exec sandbox fallback is used. ``task`` itself
    is not forwarded; backends receive only ``task_prompt_hash``.
    """
    env = os.environ
    mock_backend = None
    if env.get("CLAUDE_PLUGIN_ROOT"):
        mock_backend = "claude"
    elif env.get("CODEX_PLUGIN_ROOT") or env.get("CODEX_HOME"):
        mock_backend = "codex"

    if mock_backend is not None:
        return _mock_dispatch(
            mock_backend,
            agent_name=agent_name,
            model=model,
            task_prompt_hash=task_prompt_hash,
        )
    return _exec_sandbox_dispatch(
        state=state,
        repo=repo,
        agent_name=agent_name,
        model=model,
        task_prompt_hash=task_prompt_hash,
        sandbox_depth=sandbox_depth,
    )


def invoke(
    *,
    repo: pathlib.Path,
    agent_ref: str,
    task: str,
    output_path: str | None = None,
    model_override: str | None = None,
    sandbox_depth: int = 0,
) -> dict[str, Any]:
    """Resolve, validate and run an archived agent, emitting start/end telemetry.

    Steps:
    1. Prune stale dev agents.
    2. Resolve ``agent_ref`` and validate its frontmatter.
    3. Emit ``subagent_invoke_start`` and dispatch.
    4. Emit ``subagent_invoke_end``; this always happens, including when the
       dispatch fails.
    5. On success, optionally copy the output to ``output_path``.

    Args:
        repo: Repository whose state handle is used.
        agent_ref: Archive-relative agent path or bare agent name.
        task: Prompt text; only its hash reaches the events and backends.
        output_path: Optional file to receive the agent output.
        model_override: Used instead of the frontmatter ``model`` when truthy.
        sandbox_depth: Forwarded to the exec sandbox dispatcher.

    Returns:
        A dict with ``agent``, ``task``, ``output``, ``duration_s``,
        ``model_used``, ``event_ids`` (start, end) and, when known, ``cost``.

    Raises:
        RuntimeError: the agent is not found, its frontmatter is invalid or
            lacks ``name``, or dispatch failed. In the dispatch case the end
            event has already been written, and the original exception is
            chained as ``__cause__``.
    """
    state = _resolve_state_handle(repo)
    state.repo_state_dir.mkdir(parents=True, exist_ok=True)
    # NOTE(review): pruning runs before resolution, so a dev agent past its
    # retention window is deleted here and then reported as not found.
    _cleanup_stale_dev_agents(state)

    try:
        agent_path = _resolve_agent_path(state, agent_ref)
    except FileNotFoundError as exc:
        raise RuntimeError(f"agent not found at {agent_ref}") from exc

    content = agent_path.read_text(encoding="utf-8")
    errors = validate_agent_frontmatter(content)
    if errors:
        raise RuntimeError("; ".join(errors))

    frontmatter, _ = _read_frontmatter(content)
    agent_name = str(frontmatter.get("name") or "").strip()
    if not agent_name:
        raise RuntimeError("agent name missing in frontmatter")

    # A present-but-empty ``model:`` key parses to None under YAML. Treat it as
    # absent instead of letting str() turn it into the model name "None".
    raw_model = frontmatter.get("model")
    frontmatter_model = "" if raw_model is None else str(raw_model).strip()
    model = model_override or frontmatter_model or "inherit"
    actor = {
        "kind": "arkiv_agent",
        "name": agent_name,
        "model": model,
    }

    agent_path_for_event = _relative_agent_path(state, agent_path)
    start_event_id = _emit_start(
        state=state,
        actor=actor,
        agent_path=agent_path_for_event,
        task=task,
    )

    dispatch_error: Exception | None = None
    overall_start = time.perf_counter()
    try:
        result = _dispatch(
            state=state,
            repo=state.repo,
            agent_name=agent_name,
            model=model,
            task=task,
            task_prompt_hash=_hash_prompt(task),
            sandbox_depth=sandbox_depth,
        )
        outcome = {
            "status": "ok",
            "backend": result.backend,
            "output": result.output,
        }
        cost = result.cost_usd
    except Exception as exc:
        # Record the failure in the end event first; it is re-raised below.
        dispatch_error = exc
        message = str(exc)
        result = DispatchResult(
            output=message[:128],
            duration_ms=0,
            tokens_in=0,
            tokens_out=0,
            cost_usd=None,
            backend="error",
        )
        outcome = {
            "status": "error",
            "backend": "error",
            "message": message[:128],
        }
        cost = None
    finally:
        total_ms = int((time.perf_counter() - overall_start) * 1000)

    end_event_id = _emit_end(
        state=state,
        actor=actor,
        parent_event_id=start_event_id,
        outcome=outcome,
        telemetry={
            "duration_ms": total_ms,
            "tokens_in": result.tokens_in,
            "tokens_out": result.tokens_out,
            "cost_usd": cost,
        },
    )

    if outcome.get("status") != "ok":
        raise RuntimeError(outcome.get("message") or "invoke failed") from dispatch_error

    if output_path:
        output_file = pathlib.Path(output_path)
        output_file.parent.mkdir(parents=True, exist_ok=True)
        output_file.write_text(result.output, encoding="utf-8")

    response: dict[str, Any] = {
        "agent": agent_name,
        "task": task,
        "output": result.output,
        "duration_s": total_ms / 1000.0,
        "model_used": model,
        "event_ids": [start_event_id, end_event_id],
    }
    if cost is not None:
        response["cost"] = cost
    return response


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line arguments for an agent invocation.

    ``--agent`` and ``--task`` are required. ``--output`` and ``--model``
    default to ``None``. ``--alc-sandbox-depth`` is an int defaulting to 0.
    ``argv`` of ``None`` means ``sys.argv[1:]``.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    required_flags = (
        ("--agent", "Archive-relative agent path"),
        ("--task", "Prompt for the agent"),
    )
    for flag, help_text in required_flags:
        cli.add_argument(flag, required=True, help=help_text)
    optional_flags = (
        ("--output", "Optional path for copied output"),
        ("--model", "Override agent model"),
    )
    for flag, help_text in optional_flags:
        cli.add_argument(flag, default=None, help=help_text)
    cli.add_argument(
        "--alc-sandbox-depth",
        dest="alc_sandbox_depth",
        type=int,
        default=0,
        help="Depth forwarded to nested exec_sandbox as +1",
    )
    return cli.parse_args(argv)


def main(argv: list[str] | None = None) -> int:
    """CLI entry point: run ``invoke`` in the current directory and print sorted JSON.

    Returns:
        0 on success.
        1 when the invocation fails; the message goes to stderr, prefixed with
          ``error:`` except for "agent not found" messages.
        2 when ``--alc-sandbox-depth`` is negative.
    """
    args = parse_args(argv)
    if args.alc_sandbox_depth < 0:
        print("--alc-sandbox-depth must be >=0", file=sys.stderr)
        return 2

    # OSError/UnicodeError cover an unreadable or non-UTF-8 agent file and an
    # unwritable --output path. Without them these would escape as a traceback
    # instead of the CLI's error message.
    try:
        result = invoke(
            repo=pathlib.Path.cwd().resolve(),
            agent_ref=args.agent,
            task=args.task,
            output_path=args.output,
            model_override=args.model,
            sandbox_depth=args.alc_sandbox_depth,
        )
    except (RuntimeError, OSError, UnicodeError) as exc:
        message = str(exc)
        if message.startswith("agent not found at"):
            print(message, file=sys.stderr)
        else:
            print(f"error: {message}", file=sys.stderr)
        return 1

    print(json.dumps(result, sort_keys=True))
    return 0


if __name__ == "__main__":
    # sys.exit raises SystemExit carrying main()'s return code.
    sys.exit(main())
