#!/usr/bin/env python3
"""Tiered sandbox execution primitive for bounded code execution."""

from __future__ import annotations

import argparse
import json
import os
import pathlib
import shlex
import re
import signal
import subprocess
import sys
import time
import uuid
from dataclasses import dataclass
from enum import Enum
from typing import Any

from event_writer import write_event

try:
    from scrub_secrets import scrub
except ImportError:
    from bin.scrub_secrets import scrub

try:
    from state_handle import StateHandle
except ImportError:
    from bin.state_handle import StateHandle

try:
    from exec_sandbox_profiles import SCOPES
except ImportError:
    from bin.exec_sandbox_profiles import SCOPES

try:
    from sandbox_run_state import RunStateTracker, _cleanup_worktree
except ImportError:
    from bin.sandbox_run_state import RunStateTracker, _cleanup_worktree


# Nesting limit for sandbox runs: run() raises once ``depth`` reaches this value.
WORKER_MAX_DEPTH = 2
# Seconds granted to drain the child's pipes after a timed-out run was killed.
RECOVER_TIMEOUT_SEC = 1
# Default byte cap used by _truncate_for_event (100 KiB).
EVENT_OUTPUT_BYTES_CAP = 100 * 1024
# Byte cap for the stdout/stderr excerpts embedded in the event payload.
EVENT_PAYLOAD_BYTES_CAP = 200
# Exit code recorded by run() when the execution attempt raises an exception.
FORBIDDEN_EXIT_CODE = 3
# Matches labelled redaction markers of the form "[REDACTED:<label>]".
_REDACTED_TOKEN_RE = re.compile(r"\[REDACTED:[^\]]+\]")


# Actor kinds accepted by _coerce_actor; any other ``kind`` raises ValueError.
ALLOWED_ACTOR_KINDS = {
    "main_agent",
    "subagent",
    "background_agent",
    "mcp_server",
    "hook",
    "operator",
    "judge",
    "recommender",
    "arkiv_agent",
    "eval_judge",
}


class ExecScope(str, Enum):
    """Sandbox tiers accepted by run().

    Each member's value is used as the lookup key into the imported SCOPES
    profile table.
    """

    # Executed directly in the repository after the allowlist and path checks.
    READ = "read"
    # Every non-read tier is executed by run() inside a freshly added git worktree.
    WORKTREE = "worktree"
    EVAL = "eval"


@dataclass(frozen=True)
class ExecResult:
    """Immutable summary of a single sandbox run, as returned by run()."""

    # Exit status recorded for the run (FORBIDDEN_EXIT_CODE when the attempt raised).
    exit_code: int
    # Files inside the run directory holding the captured stdout / stderr text.
    stdout_path: pathlib.Path
    stderr_path: pathlib.Path
    # Wall-clock duration measured by run(), in milliseconds.
    duration_ms: int
    # Value returned by write_event for the emitted "exec_sandbox_run" event.
    event_id: str
    # Worktree the command ran in; None when no worktree was spawned (read scope).
    worktree_dir: pathlib.Path | None = None
    # True when the command was killed for exceeding its timeout.
    timed_out: bool = False
    # Identifier of the run; also the name of its output directory.
    run_id: str | None = None


def _coerce_scope(value: str | ExecScope) -> ExecScope:
    """Normalize *value* into an ``ExecScope`` member.

    Members are handed back untouched; anything else is looked up by enum value.

    Raises:
        ValueError: if *value* does not name a known scope.
    """
    if not isinstance(value, ExecScope):
        try:
            member = ExecScope(value)
        except ValueError as exc:
            raise ValueError(f"unsupported scope: {value}") from exc
        return member
    return value


def _coerce_actor(actor: dict[str, Any] | None) -> dict[str, str]:
    if actor is None:
        return {"kind": "operator", "name": "operator"}
    if not isinstance(actor, dict):
        raise ValueError("actor must be a dict")
    kind = actor.get("kind", "operator")
    name = actor.get("name")
    if not isinstance(kind, str) or kind not in ALLOWED_ACTOR_KINDS:
        raise ValueError(f"invalid actor kind: {kind}")
    if not isinstance(name, str) or not name.strip():
        raise ValueError("actor.name is required")
    result: dict[str, str] = {"kind": kind, "name": name.strip()}
    model = actor.get("model")
    if model is not None:
        if not isinstance(model, str):
            raise ValueError("actor.model must be a string")
        result["model"] = model.strip()
    return result


def _run_id() -> str:
    """Return a fresh random run identifier (the hex form of a UUID4)."""
    fresh = uuid.uuid4()
    return fresh.hex


def _stateful_handle(repo: pathlib.Path, state_dir: str | None = None) -> StateHandle:
    """Build the ``StateHandle`` for *repo*.

    A non-empty *state_dir* is first exported through the
    ``AGENT_LEARNING_STATE_DIR`` environment variable; this mutates the
    process-wide environment.
    """
    if not state_dir:
        return StateHandle.for_repo(repo)
    os.environ["AGENT_LEARNING_STATE_DIR"] = state_dir
    return StateHandle.for_repo(repo)


def _resolve_timeout(scope: ExecScope, timeout_s: int | None) -> int:
    """Pick the effective timeout (seconds) for *scope*.

    A missing timeout falls back to the profile default. A usable positive
    integer is clamped to the profile maximum; anything else yields the
    profile default.
    """
    profile = SCOPES[scope.value]
    requested = profile["default_timeout_s"] if timeout_s is None else timeout_s
    if isinstance(requested, int) and requested > 0:
        return min(requested, profile["max_timeout_s"])
    return profile["default_timeout_s"]


def _run_tokens(command: str) -> list[str]:
    """Split *command* with shell-like quoting rules.

    When the string cannot be parsed (e.g. an unterminated quote) the whole
    command is returned as a single token.
    """
    try:
        parts = shlex.split(command)
    except ValueError:
        parts = [command]
    return parts


# Shell metacharacters that allow command chaining / substitution / expansion /
# redirection / backgrounding / globbing. Read tier runs the original string via
# shell=True, so allowlisting the leading tokens is not enough: any of these in
# the command enables an escape. Single `&` covers `&&` (job-background is also a
# chain); single `|` covers `||` (pipe is also alternation). A bare `$` covers
# command substitution `$(`, `${VAR}` and plain `$VAR` expansion; blocking only
# `$(` would let e.g. `$HOME/...` reach paths the token checks never see. Parens
# enable subshells. Glob chars enable filename expansion outside the allowlist.
# Tilde + backslash enable home expansion + escape sequences. Newlines split
# commands.
_READ_TIER_SHELL_METACHARS = (
    ";", "&", "|", "$", "`", ">", "<",
    "(", ")", "*", "?", "[", "]", "{", "}", "~", "\\",
    "\n", "\r",
)


def _is_allowed_read_command(command: str) -> bool:
    """Decide whether *command* may run in the read tier.

    The command is rejected when it contains any read-tier shell metacharacter
    or tokenizes to nothing. Otherwise it is accepted only if its leading
    tokens equal one of the allowlisted token prefixes of the read profile.
    """
    for meta in _READ_TIER_SHELL_METACHARS:
        if meta in command:
            return False

    allowlist = SCOPES[ExecScope.READ.value]["allowlist_tokens"]
    tokens = _run_tokens(command)
    if not tokens:
        return False
    return any(tokens[: len(prefix)] == list(prefix) for prefix in allowlist)


def _contains_parent_path_traversal(repo_root: pathlib.Path, command: str) -> bool:
    """Return True when a ``..``-bearing token of *command* escapes *repo_root*.

    Each token is inspected independently:

    * plain flags (``-x``, ``--flag``) carry no path and are skipped, but the
      value of a ``--opt=value`` token is checked like any other token;
    * tokens without ``..`` are ignored (this includes revisions like ``HEAD~1``);
    * an absolute token containing ``..`` is always treated as an escape;
    * anything else is resolved relative to *repo_root* and must land on the
      root itself or below it. Git ranges such as ``a..b`` resolve to a single
      path component inside the root and therefore pass.

    A token whose path cannot be resolved is treated as an escape.
    """
    # Resolve once; the root does not change between tokens.
    root = repo_root.resolve()
    for token in _run_tokens(command):
        if token.startswith("-"):
            # ``--opt=../../x`` smuggles a path inside an option; check the value.
            _, has_value, token = token.partition("=")
            if not has_value:
                continue
        if ".." not in token:
            continue
        if token.startswith("/"):
            return True
        try:
            candidate = (repo_root / token).resolve()
        except (OSError, RuntimeError):
            # RuntimeError: symlink loops on older Python versions.
            return True
        # Component-wise containment: the root itself is inside the root, and a
        # sibling directory sharing the root's name prefix is not.
        if candidate != root and root not in candidate.parents:
            return True
    return False


def _has_disallowed_path_tokens(scope: ExecScope, repo_root: pathlib.Path, command: str) -> bool:
    """Return True when *command* references a path outside *repo_root*.

    For the read scope the ``..`` traversal check runs first. Every scope then
    rejects ``~/``-prefixed tokens and absolute paths that do not resolve to
    *repo_root* or a location below it. The absolute-path check applies to the
    read scope as well: otherwise ``/etc/passwd`` would be accepted while
    ``../../etc/passwd`` is refused. The value part of ``--opt=value`` tokens
    is inspected too. Paths that cannot be resolved count as disallowed.
    """
    if scope == ExecScope.READ and _contains_parent_path_traversal(repo_root, command):
        return True

    root = repo_root.resolve()
    for token in _run_tokens(command):
        if token.startswith("-") and "=" in token:
            # Inspect the value of ``--opt=/some/path`` rather than the flag.
            token = token.split("=", 1)[1]
        if token.startswith("~/"):
            return True
        if not token.startswith("/"):
            continue
        try:
            candidate = pathlib.Path(token).resolve()
        except (OSError, RuntimeError):
            return True
        # The root itself is inside the execution root; compare by components
        # so a sibling like ``/repo-other`` is not mistaken for ``/repo``.
        if candidate != root and root not in candidate.parents:
            return True
    return False


def _sanitize_env_for_sandbox() -> dict[str, str]:
    """Return a copy of the process environment prepared for sandboxed children.

    All proxy-related variables (upper- and lower-case spellings) are left out
    and ``NO_NETWORK`` is set to ``"1"``. ``os.environ`` itself is not modified.
    """
    proxy_vars = {
        "HTTP_PROXY",
        "HTTPS_PROXY",
        "http_proxy",
        "https_proxy",
        "ALL_PROXY",
        "all_proxy",
        "NO_PROXY",
        "no_proxy",
    }
    sandbox_env = {
        name: value for name, value in os.environ.items() if name not in proxy_vars
    }
    sandbox_env["NO_NETWORK"] = "1"
    return sandbox_env


def _truncate_for_event(value: str, max_bytes: int = EVENT_OUTPUT_BYTES_CAP) -> str:
    """Return *value* cut down so its UTF-8 encoding fits in *max_bytes*.

    Strings already within the cap are returned unchanged. Otherwise the
    encoded form is sliced at *max_bytes* and decoded again; a multi-byte
    character split by the slice is dropped. Decoding with ``errors="replace"``
    would turn that partial character into U+FFFD, which re-encodes to three
    bytes and can push the result back over the cap.
    """
    # Encode once; "replace" keeps unencodable code points (lone surrogates)
    # from raising, matching how the byte counts in the payload are computed.
    encoded = value.encode("utf-8", errors="replace")
    if len(encoded) <= max_bytes:
        return value
    # The only invalid sequence possible here is the truncated tail.
    return encoded[:max_bytes].decode("utf-8", errors="ignore")


def _event_scrub(value: str) -> str:
    """Scrub secrets from *value* and collapse labelled redaction markers.

    The text is passed through ``scrub``; every ``[REDACTED:<label>]`` marker in
    the result is then replaced by the plain ``[REDACTED]`` token.
    """
    scrubbed = scrub(value)
    return _REDACTED_TOKEN_RE.sub("[REDACTED]", scrubbed)


def _spawn_worktree(
    handle: StateHandle,
    run_id: str,
    base_ref: str | None,
    tracker: RunStateTracker,
) -> pathlib.Path:
    """Create a detached git worktree for *run_id* and return its directory.

    The run is marked ``"running"`` in *tracker* (with this process's pid)
    before git is invoked. The worktree is checked out at *base_ref*, or at
    ``HEAD`` when no ref is given.

    Raises:
        subprocess.CalledProcessError: if ``git worktree add`` fails.
    """
    target_dir = tracker.ensure_worktree_dir(run_id)
    tracker.set_status(
        run_id,
        status="running",
        worktree_dir=target_dir,
        pid=os.getpid(),
    )

    checkout_ref = base_ref if base_ref else "HEAD"
    git_argv = ["git", "-C", str(handle.repo)]
    git_argv += ["worktree", "add", "--detach", str(target_dir), checkout_ref]
    subprocess.run(
        git_argv,
        check=True,
        text=True,
        capture_output=True,
        env=_sanitize_env_for_sandbox(),
    )
    return target_dir


def _run_in_shell(
    *,
    command: str,
    cwd: pathlib.Path,
    timeout_s: int,
    env: dict[str, str],
    set_umask: bool,
) -> tuple[int, str, str, bool, int]:
    """Run *command* through the shell and capture its output.

    The child is started in its own session so the whole process group can be
    killed on timeout.

    Args:
        command: Shell command line, executed with ``shell=True``.
        cwd: Working directory for the child.
        timeout_s: Seconds to wait before the process group is killed.
        env: Complete environment for the child.
        set_umask: When true, the child runs with umask ``0o222`` so that
            anything it creates is not writable.

    Returns:
        ``(exit_code, stdout, stderr, timed_out, duration_ms)``. On timeout the
        exit code is 124 and the output is whatever could be collected.
    """
    timed_out = False

    start = time.monotonic()
    popen_kwargs: dict[str, Any] = {
        "args": command,
        "shell": True,
        "cwd": str(cwd),
        "env": env,
        "stdout": subprocess.PIPE,
        "stderr": subprocess.PIPE,
        "text": True,
        "start_new_session": True,
    }

    if set_umask:
        def _preexec() -> None:
            # A umask lists the permission bits to CLEAR. 0o222 strips the write
            # bits (new files become read-only); 0o444 would strip the read bits
            # instead and leave write-only files behind.
            os.umask(0o222)
        popen_kwargs["preexec_fn"] = _preexec

    def _as_text(chunk: Any) -> str:
        # TimeoutExpired may carry raw bytes (or None) even in text mode.
        if chunk is None:
            return ""
        if isinstance(chunk, bytes):
            return chunk.decode("utf-8", errors="replace")
        return chunk

    proc = subprocess.Popen(**popen_kwargs)
    try:
        stdout, stderr = proc.communicate(timeout=timeout_s)
        code = proc.returncode
    except subprocess.TimeoutExpired:
        timed_out = True
        code = 124
        try:
            # start_new_session=True makes the child a group leader, so its pid
            # doubles as the process-group id.
            os.killpg(proc.pid, signal.SIGKILL)
        except OSError:
            proc.kill()
        try:
            stdout, stderr = proc.communicate(timeout=RECOVER_TIMEOUT_SEC)
        except subprocess.TimeoutExpired as exc:
            # A descendant that left the session can keep the pipes open. Keep
            # the timeout result and the partial output instead of raising.
            stdout, stderr = exc.stdout, exc.stderr
            for pipe in (proc.stdout, proc.stderr):
                if pipe is not None:
                    pipe.close()
            proc.poll()  # reap the killed shell if it has already exited
    finally:
        duration_ms = int((time.monotonic() - start) * 1000)

    return code or 0, _as_text(stdout), _as_text(stderr), timed_out, duration_ms


def _build_event_payload(
    *,
    run_id: str,
    scope: ExecScope,
    command: str,
    exit_code: int,
    duration_ms: int,
    timed_out: bool,
    worktree_dir: pathlib.Path | None,
    stdout: str,
    stderr: str,
) -> dict[str, Any]:
    """Assemble the ``payload`` section of the sandbox-run event.

    The command and the output excerpts are scrubbed of secrets; excerpts are
    additionally capped at ``EVENT_PAYLOAD_BYTES_CAP`` bytes. Byte counts refer
    to the full, unscrubbed output. ``worktree_dir`` is included only when a
    worktree was used, and entries whose value is ``None`` are omitted.
    """
    scrubbed_command = _event_scrub(command)
    stdout_size = len(stdout.encode("utf-8", errors="replace"))
    stderr_size = len(stderr.encode("utf-8", errors="replace"))
    stdout_excerpt = _truncate_for_event(_event_scrub(stdout), max_bytes=EVENT_PAYLOAD_BYTES_CAP)
    stderr_excerpt = _truncate_for_event(_event_scrub(stderr), max_bytes=EVENT_PAYLOAD_BYTES_CAP)

    fields: dict[str, Any] = {
        "run_id": run_id,
        "scope": scope.value,
        "command": scrubbed_command,
        "exit_code": exit_code,
        "duration_ms": duration_ms,
        "timeout": timed_out,
        "stdout_bytes": stdout_size,
        "stderr_bytes": stderr_size,
        "stdout_excerpt": stdout_excerpt,
        "stderr_excerpt": stderr_excerpt,
    }
    if worktree_dir is not None:
        fields.update(worktree_dir=str(worktree_dir))

    cleaned: dict[str, Any] = {}
    for key, value in fields.items():
        if value is not None:
            cleaned[key] = value
    return cleaned


def _write_outputs(run_dir: pathlib.Path, *, stdout: str, stderr: str, exit_code: int) -> tuple[pathlib.Path, pathlib.Path]:
    """Persist a run's captured output under *run_dir*.

    Creates the directory if needed, then writes the ``stdout``, ``stderr`` and
    ``exit_code`` files as UTF-8 (the exit code followed by a newline).

    Returns:
        The paths of the stdout and stderr files, in that order.
    """
    run_dir.mkdir(parents=True, exist_ok=True)
    stdout_path = run_dir / "stdout"
    stderr_path = run_dir / "stderr"

    artifacts = (
        (stdout_path, stdout),
        (stderr_path, stderr),
        (run_dir / "exit_code", f"{exit_code}\n"),
    )
    for target, text in artifacts:
        target.write_text(text, encoding="utf-8")
    return stdout_path, stderr_path


def run(
    *,
    scope: ExecScope | str,
    command: str,
    repo: pathlib.Path,
    base_ref: str | None = None,
    timeout_s: int | None = None,
    actor: dict[str, Any] | None = None,
    parent_event_id: str | None = None,
    depth: int = 0,
) -> ExecResult:
    """Execute *command* in the sandbox tier *scope* and record the outcome.

    Read-scope commands must pass the allowlist and path checks and run in the
    repository itself; every other scope runs inside a freshly added git
    worktree. The captured output is written to the run directory and an
    ``exec_sandbox_run`` event is emitted whether or not the command succeeded.

    Failures while preparing or executing the command (rejected command,
    worktree creation error, ...) are not raised: they are reported as exit
    code ``FORBIDDEN_EXIT_CODE`` with the error text as stderr.

    Args:
        scope: Tier to run in (member or its string value).
        command: Shell command line; surrounding whitespace is stripped.
        repo: Repository the run belongs to.
        base_ref: Git ref for the worktree; ignored for the read scope.
        timeout_s: Requested timeout, resolved against the scope profile.
        actor: Actor description; ``None`` means the default operator.
        parent_event_id: Event that triggered this run, if any.
        depth: Current nesting depth of sandbox runs.

    Raises:
        ValueError: if *scope* or *actor* is invalid.
        RuntimeError: if *depth* has reached ``WORKER_MAX_DEPTH``.
    """
    scope = _coerce_scope(scope)
    if depth >= WORKER_MAX_DEPTH:
        raise RuntimeError(f"max sandbox depth reached: {depth}")

    handle = _stateful_handle(repo.resolve(), os.environ.get("AGENT_LEARNING_STATE_DIR"))
    actor_payload = _coerce_actor(actor)

    tracker = RunStateTracker(handle)
    # Everything that uses the tracker sits inside this try so that close() runs
    # even when recovery, output writing, cleanup or event emission raises.
    try:
        tracker.recover_stale(actor_payload)

        run_id = _run_id()
        timeout = _resolve_timeout(scope, timeout_s)
        run_dir = handle.repo_state_dir / "sandbox-runs" / run_id

        command_for_exec = command.strip()
        use_worktree = SCOPES[scope.value].get("require_worktree", False)
        started_at = time.monotonic()
        worktree_dir: pathlib.Path | None = None
        timed_out = False
        exit_code = 0
        stdout_value = ""
        stderr_value = ""

        try:
            if scope == ExecScope.READ:
                if not _is_allowed_read_command(command_for_exec):
                    raise RuntimeError("command is not in scope allowlist")
                if _has_disallowed_path_tokens(scope=scope, repo_root=handle.repo, command=command_for_exec):
                    raise RuntimeError("path leaves execution root")
                exec_cwd = handle.repo
            else:
                worktree_dir = _spawn_worktree(handle, run_id, base_ref, tracker)
                exec_cwd = worktree_dir
            exit_code, stdout_value, stderr_value, timed_out, _ = _run_in_shell(
                command=command_for_exec,
                cwd=exec_cwd,
                timeout_s=timeout,
                env=_sanitize_env_for_sandbox(),
                # Only the read tier runs with a restrictive umask.
                set_umask=scope == ExecScope.READ,
            )
        except Exception as exc:
            # Report the failure through the result instead of raising.
            timed_out = False
            exit_code = FORBIDDEN_EXIT_CODE
            stdout_value = ""
            if not stderr_value:
                stderr_value = f"{exc}\n"
        finally:
            duration_ms = int((time.monotonic() - started_at) * 1000)
            stdout_path, stderr_path = _write_outputs(
                run_dir, stdout=stdout_value, stderr=stderr_value, exit_code=exit_code
            )

            if use_worktree and worktree_dir is not None:
                tracker.set_status(run_id, status="finished", worktree_dir=worktree_dir, pid=0)
                _cleanup_worktree(handle.repo, worktree_dir)
                tracker.delete_status(run_id)

            # The payload already carries the byte counts and the timeout flag.
            payload = _build_event_payload(
                run_id=run_id,
                scope=scope,
                command=command_for_exec,
                exit_code=exit_code,
                duration_ms=duration_ms,
                timed_out=timed_out,
                worktree_dir=worktree_dir,
                stdout=stdout_value,
                stderr=stderr_value,
            )
            correlation_chain: list[dict[str, str]] = []
            if parent_event_id:
                correlation_chain.append({"role": "triggered_by", "id": parent_event_id})

            chain_payload: dict[str, Any] = {
                "event": "exec_sandbox_run",
                "actor": actor_payload,
                "payload": payload,
                "telemetry": {
                    "duration_ms": duration_ms,
                    "tokens_in": None,
                    "tokens_out": None,
                },
                "correlation_chain": correlation_chain,
            }
            # An empty-string parent id is still recorded here although it adds
            # no correlation entry above.
            if parent_event_id is not None:
                chain_payload["parent_event_id"] = parent_event_id
            event_id = write_event(chain_payload, source="eval")
    finally:
        tracker.close()

    return ExecResult(
        exit_code=exit_code,
        stdout_path=stdout_path,
        stderr_path=stderr_path,
        duration_ms=duration_ms,
        event_id=event_id,
        worktree_dir=worktree_dir,
        timed_out=timed_out,
        run_id=run_id,
    )


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the sandbox runner.

    ``--scope``, ``--cmd`` and ``--repo`` are mandatory; the remaining options
    default to ``None``, ``0`` or the ``operator`` actor. With ``argv=None``
    argparse reads ``sys.argv``.
    """
    option_table: tuple[tuple[str, dict[str, Any]], ...] = (
        ("--scope", {"required": True, "choices": tuple(SCOPES)}),
        ("--cmd", {"required": True, "help": "Shell command to execute"}),
        ("--repo", {"required": True, "type": pathlib.Path}),
        ("--base-ref", {"dest": "base_ref", "default": None}),
        ("--timeout", {"type": int, "default": None}),
        ("--depth", {"type": int, "default": 0}),
        ("--parent-event-id", {"dest": "parent_event_id", "default": None}),
        ("--actor-kind", {"default": "operator"}),
        ("--actor-name", {"default": "operator"}),
    )

    parser = argparse.ArgumentParser(description=__doc__)
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args(argv)


def main(argv: list[str] | None = None) -> int:
    """CLI entry point: run one sandboxed command and print a JSON summary.

    Returns the command's exit code on success, ``1`` when ``run`` raises
    ``RuntimeError`` and ``2`` when it raises ``ValueError``; in both error
    cases the message is printed to stderr.
    """
    args = parse_args(argv)
    actor = {"kind": args.actor_kind, "name": args.actor_name}
    try:
        result = run(
            scope=args.scope,
            command=args.cmd,
            repo=args.repo,
            base_ref=args.base_ref,
            timeout_s=args.timeout,
            actor=actor,
            parent_event_id=args.parent_event_id,
            depth=args.depth,
        )
    except RuntimeError as exc:
        print(f"error: {exc}", file=sys.stderr)
        return 1
    except ValueError as exc:
        print(f"error: {exc}", file=sys.stderr)
        return 2

    worktree_text = str(result.worktree_dir) if result.worktree_dir else None
    summary = {
        "event_id": result.event_id,
        "exit_code": result.exit_code,
        "run_id": result.run_id,
        "stdout": str(result.stdout_path),
        "stderr": str(result.stderr_path),
        "worktree_dir": worktree_text,
    }
    print(json.dumps(summary))
    return result.exit_code


# Script entry point: exit with the status code returned by main().
if __name__ == "__main__":
    sys.exit(main())
