#!/usr/bin/env python3
"""Start one VCM long-running validation job with a worker-enforced ceiling.

Usage:
  .ai/tools/run-long-check [--timeout <duration>] -- <command> [args...]

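The command runs under a detached worker, in its own process group. The
worker itself enforces the job ceiling (--timeout, default and max 60m) and a
supervision lease: if no foreground watcher creates
.ai/vcm/jobs/<job-id>/lease within the start grace period (default 300s), or
the lease then goes unrenewed for the renew grace period (default 120s), the
worker kills the command process group and records the job as orphaned. The
grace periods can be overridden with VCM_JOB_LEASE_START_GRACE_SECONDS and
VCM_JOB_LEASE_RENEW_GRACE_SECONDS. Watch the job with .ai/tools/watch-job in
the same turn.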

Only one validation job may be active at a time.
"""
import json
import math
import os
import signal
import subprocess
import sys
import time
import uuid
from datetime import datetime, timezone
from pathlib import Path

# Hard per-job ceiling (60 minutes); main() rejects any --timeout above it.
MAX_TIMEOUT_SECONDS = 3600
# Without --timeout a job gets the full ceiling.
DEFAULT_TIMEOUT_SECONDS = MAX_TIMEOUT_SECONDS
# A job with no recorded pid is treated as never-started once its status.json
# is this many seconds old.
QUEUED_STALE_SECONDS = 60
# Statuses that count toward the one-active-job limit.
ACTIVE_STATUSES = {"running", "starting", "queued"}
# Lease grace periods, overridable through VCM_JOB_LEASE_START_GRACE_SECONDS and
# VCM_JOB_LEASE_RENEW_GRACE_SECONDS: how long a watcher has to first create the
# lease file, and how old the lease may get once it exists.
DEFAULT_LEASE_START_GRACE_SECONDS = 300
DEFAULT_LEASE_RENEW_GRACE_SECONDS = 120


def root_dir() -> Path:
    """Return the repository root: two directories above this script's own folder.

    The script lives at <root>/.ai/tools/, so parents[2] of its resolved path is
    <root>.
    """
    script = Path(__file__).resolve()
    return script.parents[2]


def jobs_root() -> Path:
    """Return the directory holding one subdirectory per job (<root>/.ai/vcm/jobs)."""
    return root_dir().joinpath(".ai", "vcm", "jobs")


def now_iso() -> str:
    """Return the current UTC time in ISO 8601 form with a trailing "Z"."""
    stamp = datetime.now(tz=timezone.utc).isoformat()
    return stamp.replace("+00:00", "Z")


def write_json(path: Path, data: dict) -> None:
    """Replace *path* with *data* serialized as indented, key-sorted JSON.

    The text is first written to a sibling file (the path's suffix plus
    ".tmp") and then renamed over *path*. On POSIX filesystems the rename is
    atomic, so readers see either the old or the new content, never a partial
    write.
    """
    staging = path.with_suffix(path.suffix + ".tmp")
    serialized = json.dumps(data, indent=2, sort_keys=True)
    staging.write_text(f"{serialized}\n")
    staging.replace(path)


def read_optional_json(path: Path) -> dict | None:
    try:
        return json.loads(path.read_text())
    except (OSError, ValueError):
        return None


def parse_duration(value: str) -> float:
    """Parse a duration such as "500ms", "90s", "5m", "1h" or "45" into seconds.

    Matching is case-insensitive and ignores surrounding whitespace; a bare
    number is taken as seconds.

    Raises:
        ValueError: if the number part is not a float, or the result is not
            finite. "nan" in particular would compare False against every
            deadline and silently disable the worker-enforced ceiling.
    """
    text = value.strip().lower()
    # "ms" must be tested before "s" because it also ends with "s".
    if text.endswith("ms"):
        seconds = float(text[:-2]) / 1000
    elif text.endswith("s"):
        seconds = float(text[:-1])
    elif text.endswith("m"):
        seconds = float(text[:-1]) * 60
    elif text.endswith("h"):
        seconds = float(text[:-1]) * 3600
    else:
        seconds = float(text)
    if not math.isfinite(seconds):
        raise ValueError(f"duration must be finite: {value!r}")
    return seconds


def env_seconds(name: str, default: float) -> float:
    """Read a seconds value from environment variable *name*.

    Unset, empty, or non-numeric values fall back to *default*; parsed values
    are floored at one second.
    """
    raw = os.environ.get(name, "")
    if raw:
        try:
            parsed = float(raw)
        except ValueError:
            return default
        return max(1.0, parsed)
    return default


def job_id() -> str:
    """Return a new job id: UTC timestamp (second resolution) plus 8 random hex digits.

    The timestamp prefix makes lexical order match creation order.
    """
    stamp = datetime.now(tz=timezone.utc).strftime("%Y%m%dT%H%M%SZ")
    suffix = uuid.uuid4().hex[:8]
    return "-".join((stamp, suffix))


def process_exists(pid: int) -> bool:
    """Report whether a process with *pid* exists.

    Sends signal 0, which performs the existence/permission check without
    delivering anything. EPERM means the process exists but belongs to someone
    we may not signal, so that still counts as existing.
    """
    try:
        os.kill(pid, 0)
    except ProcessLookupError:
        return False
    except PermissionError:
        return True
    return True


def stop_command(process: subprocess.Popen) -> str:
    """Stop *process* (and its process group), escalating SIGTERM -> SIGKILL.

    Signals the group whose id equals process.pid, which assumes the process
    was started as a group leader (run_worker uses start_new_session=True).
    If signalling the group is not permitted, only the process itself is
    signalled.

    Returns:
        "not-running" if the group no longer exists; otherwise
        "terminated-<mode>" when the process exits within 2 seconds of
        SIGTERM, or "killed-<mode>" after escalating to SIGKILL. <mode> is
        "process-group" or "process". The SIGKILL path does not wait; the
        caller is expected to reap the process.
    """
    group = process.pid
    try:
        os.killpg(group, signal.SIGTERM)
    except ProcessLookupError:
        return "not-running"
    except PermissionError:
        process.terminate()
        mode = "process"
    else:
        mode = "process-group"

    try:
        process.wait(timeout=2)
    except subprocess.TimeoutExpired:
        try:
            os.killpg(group, signal.SIGKILL)
        except (ProcessLookupError, PermissionError):
            process.kill()
        return f"killed-{mode}"
    return f"terminated-{mode}"


def lease_age_seconds(lease_path: Path) -> float | None:
    try:
        return max(0.0, time.time() - lease_path.stat().st_mtime)
    except OSError:
        return None


def lease_iso(lease_path: Path) -> str | None:
    try:
        mtime = lease_path.stat().st_mtime
    except OSError:
        return None
    return datetime.fromtimestamp(mtime, timezone.utc).isoformat().replace("+00:00", "Z")


def mark_stale(status_path: Path, status: dict, reason: str) -> None:
    """Rewrite a job's status file as "stale", recording *reason* and a finish time.

    The caller's *status* dict is left untouched; a merged copy with status
    "stale", a fresh finishedAt and staleReason is written to *status_path*.
    """
    write_json(
        status_path,
        {**status, "status": "stale", "finishedAt": now_iso(), "staleReason": reason},
    )


def find_active_job() -> dict | None:
    """Return the status dict of the first still-active job, or None.

    Job directories are scanned in sorted name order (oldest first, given the
    timestamp prefix of job ids). A job whose status.json has a status in
    ACTIVE_STATUSES is active when:

    * it records a pid (processId, falling back to workerPid) and that
      process still exists; or
    * it records no pid and status.json was written less than
      QUEUED_STALE_SECONDS ago.

    Side effect: active-looking jobs failing those checks are marked stale.
    """
    root = jobs_root()
    if not root.is_dir():
        return None
    candidates = (entry for entry in sorted(root.iterdir()) if entry.is_dir())
    for job_dir in candidates:
        status_file = job_dir / "status.json"
        record = read_optional_json(status_file)
        if not record or record.get("status") not in ACTIVE_STATUSES:
            continue
        pid = record.get("processId") or record.get("workerPid")
        if not isinstance(pid, int):
            # No pid yet: judge by how long the job has sat in this state.
            try:
                waited = time.time() - status_file.stat().st_mtime
            except OSError:
                continue
            if waited < QUEUED_STALE_SECONDS:
                return record
            mark_stale(status_file, record, "queued job never started")
        elif process_exists(pid):
            return record
        else:
            mark_stale(status_file, record, "job process not running")
    return None


def start_job(command: list[str], timeout_seconds: float) -> int:
    """Create a job directory for *command* and launch its detached worker.

    Refuses to start (returns 3) while find_active_job() reports another job.
    Otherwise it does the following under .ai/vcm/jobs/<job-id>/:

    * writes command.json and a "queued" status.json;
    * spawns this script in --worker mode in a new session;
    * prints the job id, ceiling and watch hint, and returns 0.

    If the worker cannot be spawned, the job is marked stale so it does not
    block the next attempt, an error is printed, and 1 is returned.
    """
    active = find_active_job()
    if active:
        print(
            f"error: validation job {active.get('jobId')} is already {active.get('status')}",
            file=sys.stderr,
        )
        print(f"watch it: .ai/tools/watch-job {active.get('jobId')}", file=sys.stderr)
        print("VCM allows one validation job at a time.", file=sys.stderr)
        return 3

    job = job_id()
    directory = jobs_root() / job
    directory.mkdir(parents=True, exist_ok=False)

    # write_json replaces the file atomically, so the worker never reads a
    # half-written command.json.
    write_json(
        directory / "command.json",
        {
            "command": command,
            "cwd": ".",
            "timeoutSeconds": timeout_seconds,
            "createdAt": now_iso(),
        },
    )
    status_path = directory / "status.json"
    queued = {
        "jobId": job,
        "status": "queued",
        "command": command,
        "cwd": ".",
        "timeoutSeconds": timeout_seconds,
        "startedAt": None,
        "finishedAt": None,
        "exitCode": None,
        "durationSeconds": None,
        "workerPid": None,
        "processId": None,
    }
    write_json(status_path, queued)

    try:
        subprocess.Popen(
            [sys.executable, str(Path(__file__).resolve()), "--worker", job],
            cwd=root_dir(),
            stdin=subprocess.DEVNULL,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            start_new_session=True,
        )
    except OSError as exc:
        # Without this, the "queued" record would block every new job until
        # it aged past QUEUED_STALE_SECONDS.
        mark_stale(status_path, queued, f"worker failed to start: {exc}")
        print(f"error: could not start worker for job {job}: {exc}", file=sys.stderr)
        return 1

    print(f"job: {job}")
    print(f"ceiling: {int(timeout_seconds)}s (worker-enforced)")
    print(f"watch: .ai/tools/watch-job {job}")
    return 0


def _supervision_verdict(
    lease_path: Path,
    elapsed: float,
    timeout_seconds: float,
    worker_elapsed: float,
    start_grace: float,
    renew_grace: float,
) -> tuple[str, dict] | None:
    """Return ("timeout" | "orphaned", details) if the command must be stopped, else None."""
    if elapsed >= timeout_seconds:
        return "timeout", {}
    lease_age = lease_age_seconds(lease_path)
    if lease_age is None:
        # No lease file yet: a watcher has start_grace (from worker start) to create it.
        if worker_elapsed >= start_grace:
            return "orphaned", {
                "orphanReason": f"no watcher within {int(start_grace)}s",
                "lastWatchedAt": None,
            }
        return None
    if lease_age >= renew_grace:
        return "orphaned", {
            "orphanReason": f"lease not renewed for {int(lease_age)}s",
            "lastWatchedAt": lease_iso(lease_path),
        }
    return None


def run_worker(job: str) -> int:
    """Run *job*'s command under the ceiling and lease, then record the outcome.

    Started by start_job as ``run-long-check --worker <job-id>``. The worker:

    * reads command.json and marks the job "starting";
    * launches the command in its own session, with output captured to
      stdout.log / stderr.log, and marks it "running";
    * polls about once per second until the command exits, the ceiling
      expires ("timeout"), or the lease lapses ("orphaned"); in the last two
      cases the command is stopped with stop_command.

    The final status goes to status.json. If status.json already holds a
    timeout/orphaned/stale verdict, the worker keeps it and only adds
    processExitCode, processFinishedAt and processDurationSeconds.

    Returns 0 once a result is recorded. Returns 1 if the command could not
    be launched; that case is recorded as status "failed" with startError.
    """
    directory = jobs_root() / job
    command_path = directory / "command.json"
    status_path = directory / "status.json"
    lease_path = directory / "lease"

    payload = json.loads(command_path.read_text())
    command = payload["command"]
    timeout_seconds = min(
        float(payload.get("timeoutSeconds") or DEFAULT_TIMEOUT_SECONDS),
        MAX_TIMEOUT_SECONDS,
    )
    # command.json is re-read from disk, so re-validate: NaN compares False
    # against every deadline and would silently disable the ceiling.
    if not timeout_seconds > 0:
        timeout_seconds = DEFAULT_TIMEOUT_SECONDS
    start_grace = env_seconds("VCM_JOB_LEASE_START_GRACE_SECONDS", DEFAULT_LEASE_START_GRACE_SECONDS)
    renew_grace = env_seconds("VCM_JOB_LEASE_RENEW_GRACE_SECONDS", DEFAULT_LEASE_RENEW_GRACE_SECONDS)

    # The ceiling, start grace and duration use the monotonic clock, so
    # wall-clock adjustments cannot stretch or cut them. Lease age still
    # compares file mtimes with time.time() inside lease_age_seconds.
    worker_started = time.monotonic()
    base = {
        "jobId": job,
        "command": command,
        "cwd": ".",
        "timeoutSeconds": timeout_seconds,
        "workerPid": os.getpid(),
    }
    write_json(
        status_path,
        {
            **base,
            "status": "starting",
            "startedAt": None,
            "finishedAt": None,
            "exitCode": None,
            "durationSeconds": None,
            "processId": None,
        },
    )

    started = time.monotonic()
    started_at = now_iso()
    verdict = None
    with (directory / "stdout.log").open("wb") as stdout, (directory / "stderr.log").open("wb") as stderr:
        try:
            process = subprocess.Popen(
                command,
                cwd=root_dir(),
                stdout=stdout,
                stderr=stderr,
                start_new_session=True,
            )
        except OSError as exc:
            # e.g. executable not found. Record it instead of dying with the
            # status stuck at "starting".
            stderr.write(f"run-long-check: cannot start command: {exc}\n".encode())
            write_json(
                status_path,
                {
                    **base,
                    "status": "failed",
                    "startedAt": started_at,
                    "finishedAt": now_iso(),
                    "exitCode": None,
                    "durationSeconds": round(time.monotonic() - started, 3),
                    "processId": None,
                    "startError": str(exc),
                },
            )
            return 1

        write_json(
            status_path,
            {
                **base,
                "status": "running",
                "startedAt": started_at,
                "finishedAt": None,
                "exitCode": None,
                "durationSeconds": None,
                "processId": process.pid,
            },
        )

        try:
            while True:
                exit_code = process.poll()
                if exit_code is not None:
                    break
                now = time.monotonic()
                verdict = _supervision_verdict(
                    lease_path,
                    now - started,
                    timeout_seconds,
                    now - worker_started,
                    start_grace,
                    renew_grace,
                )
                if verdict is not None:
                    verdict[1]["processStopResult"] = stop_command(process)
                    exit_code = process.wait()
                    break
                time.sleep(1)
        except BaseException:
            # Never leave the command running unsupervised if the supervision
            # loop itself fails.
            stop_command(process)
            raise

    duration = round(time.monotonic() - started, 3)
    current = read_optional_json(status_path) or {}
    if current.get("status") in {"timeout", "orphaned", "stale"}:
        # A verdict was already recorded in status.json; keep it and only
        # attach the process's actual outcome.
        current["processExitCode"] = exit_code
        current["processFinishedAt"] = now_iso()
        current["processDurationSeconds"] = duration
        write_json(status_path, current)
        return 0

    final = {
        **base,
        "startedAt": started_at,
        "finishedAt": now_iso(),
        "durationSeconds": duration,
        "processId": process.pid,
    }
    if verdict is None:
        final.update({"status": "success" if exit_code == 0 else "failed", "exitCode": exit_code})
    else:
        kind, extra = verdict
        final.update({"status": kind, "exitCode": None, "processExitCode": exit_code, **extra})
    write_json(status_path, final)
    return 0


def main() -> int:
    """Entry point: run in worker mode, or parse options and start a job.

    ``--worker <job-id>`` hands off to run_worker. Otherwise the accepted
    form is ``[--timeout <duration> | --timeout=<duration>]... -- <command>``.
    Returns 2 on a usage or validation error, otherwise start_job's result.
    """
    args = sys.argv[1:]
    if len(args) >= 2 and args[0] == "--worker":
        return run_worker(args[1])

    timeout_seconds = DEFAULT_TIMEOUT_SECONDS
    pos = 0
    while pos < len(args) and args[pos] != "--":
        option = args[pos]
        if option == "--timeout":
            if pos + 1 >= len(args):
                print("error: --timeout needs a duration", file=sys.stderr)
                return 2
            text = args[pos + 1]
            pos += 2
        elif option.startswith("--timeout="):
            text = option.removeprefix("--timeout=")
            pos += 1
        else:
            print(f"error: unknown option: {option}", file=sys.stderr)
            return 2
        try:
            timeout_seconds = parse_duration(text)
        except ValueError:
            print(f"error: invalid duration: {text}", file=sys.stderr)
            return 2

    # The loop stops at "--" or at the end of args; either way, everything
    # after position pos is the command, which must be non-empty.
    command = args[pos + 1:]
    if not command:
        print(
            "Usage: .ai/tools/run-long-check [--timeout <duration>] -- <command> [args...]",
            file=sys.stderr,
        )
        return 2
    if timeout_seconds <= 0:
        print("error: timeout must be positive", file=sys.stderr)
        return 2
    if timeout_seconds > MAX_TIMEOUT_SECONDS:
        print(
            "error: timeout exceeds maximum 60m; split the work or get user approval",
            file=sys.stderr,
        )
        return 2

    return start_job(command, timeout_seconds)


if __name__ == "__main__":
    # main()'s return value becomes the process exit status.
    sys.exit(main())
