#!/usr/bin/env python3
"""Watch one VCM long-running validation job in the foreground.

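Usage:
  .ai/tools/watch-job <job-id> [--window <duration>] [--interval <duration>]

Durations are a number with an optional ms, s, m, or h suffix (a bare number
means seconds). --window defaults to 8m, which is also the maximum allowed;
--interval (how often the job status is polled) defaults to 1s.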

watch-job renews the job supervision lease while it runs. When the watch
window elapses and the job is still running, it exits 125 WITHOUT stopping
the job; call watch-job again immediately in the same turn. The job ceiling
itself is enforced by the run-long-check worker, not by watch-job.

Exit codes:
  0    success
  1    failed
  124  timeout (job hit its ceiling and was killed by the worker)
  125  still running; call watch-job again now
  4    orphaned or stale (job lost supervision and was killed, or its worker died)
  2    usage error or unknown job id
"""
import argparse
import json
import sys
import time
from pathlib import Path

# Longest single watch window; longer waits must be split across repeated
# watch-job invocations.
MAX_WINDOW_SECONDS = 480  # 8 minutes
DEFAULT_WINDOW = "8m"
# How long to wait for status.json to appear before reporting an unknown job id.
STATUS_WAIT_SECONDS = 10
# Terminal job states mapped to the exit code watch-job returns for each.
TERMINAL_EXIT_CODES = dict(
    success=0,
    failed=1,
    timeout=124,
    orphaned=4,
    stale=4,
)


def root_dir() -> Path:
    """Return the repository root directory.

    The script sits two levels below the root (under ``.ai/tools/``), so the
    root is the grandparent of the directory holding this file.
    """
    script_dir = Path(__file__).resolve().parent
    return script_dir.parents[1]


def parse_duration(value: str) -> float:
    """Parse a duration string into seconds.

    Accepts a number with an optional case-insensitive unit suffix: ``ms``
    (milliseconds), ``s`` (seconds), ``m`` (minutes) or ``h`` (hours). A bare
    number is taken as seconds. Surrounding whitespace is ignored. The sign is
    not checked here; callers validate the range.

    Args:
        value: Duration text such as ``"500ms"``, ``"30s"`` or ``"8m"``.

    Returns:
        The duration in seconds.

    Raises:
        ValueError: If the numeric part cannot be parsed, or if the result is
            not finite (``nan``/``inf``, or an overflowing value). Non-finite
            durations would otherwise slip past range checks (NaN compares
            False against everything) and break deadline arithmetic.
    """
    import math

    text = value.strip().lower()
    # "ms" must be tested before "s", since it also ends with "s".
    if text.endswith("ms"):
        seconds = float(text[:-2]) / 1000
    elif text.endswith("s"):
        seconds = float(text[:-1])
    elif text.endswith("m"):
        seconds = float(text[:-1]) * 60
    elif text.endswith("h"):
        seconds = float(text[:-1]) * 3600
    else:
        seconds = float(text)
    if not math.isfinite(seconds):
        raise ValueError(f"duration must be finite: {value!r}")
    return seconds


def read_optional_json(path: Path) -> dict | None:
    try:
        return json.loads(path.read_text())
    except (OSError, ValueError):
        return None


def renew_lease(lease_path: Path) -> None:
    """Best-effort refresh of the job's supervision lease file.

    Creates the lease file if it is missing and otherwise updates its
    modification time. Filesystem errors are deliberately ignored so that a
    failed renewal never aborts the caller's watch loop.
    """
    try:
        lease_path.touch(exist_ok=True)
    except OSError:
        return


def tail(path: Path, lines: int = 40, max_bytes: int = 65536) -> str:
    """Return up to the last *lines* lines of a text file, joined by newlines.

    Only the final *max_bytes* bytes of the file are read, so when the file is
    larger than that the first returned line may be cut off. Bytes are decoded
    as UTF-8 with replacement, so a multi-byte character split by the byte cut
    shows up as U+FFFD instead of raising.

    Returns an empty string if the file cannot be opened or read, or if
    *lines* is not positive.
    """
    if lines <= 0:
        # Without this guard, splitlines()[-0:] would return every line.
        return ""
    try:
        with path.open("rb") as handle:
            # Measure the size on the handle we actually read (whence 2 is
            # SEEK_END) instead of a separate stat() that can race with it.
            size = handle.seek(0, 2)
            handle.seek(max(0, size - max_bytes))
            data = handle.read(max_bytes)
    except OSError:
        return ""
    text = data.decode("utf-8", errors="replace")
    return "\n".join(text.splitlines()[-lines:])


def elapsed_seconds(status: dict) -> float | None:
    started_at = status.get("startedAt")
    if not started_at:
        return None
    from datetime import datetime, timezone

    try:
        started = datetime.fromisoformat(str(started_at).replace("Z", "+00:00"))
    except ValueError:
        return None
    return round((datetime.now(timezone.utc) - started).total_seconds(), 3)


def print_summary(job_id: str, status: dict, directory: Path) -> None:
    """Print the final report for a job.

    Always prints the job id, state, exit code and duration. Timeout, orphaned
    and stale states add state-specific fields from ``status``. Every
    non-success terminal state also prints the tails of ``stdout.log`` and
    ``stderr.log`` from *directory* when they are non-empty.
    """
    state = status.get("status")
    for text in (
        f"job: {job_id}",
        f"status: {state}",
        f"exitCode: {status.get('exitCode')}",
        f"durationSeconds: {status.get('durationSeconds')}",
    ):
        print(text)

    # Extra status fields to surface for specific states, in print order.
    extra_fields = {
        "timeout": ("timeoutSeconds", "processStopResult"),
        "orphaned": ("orphanReason", "processStopResult"),
        "stale": ("staleReason",),
    }
    for key in extra_fields.get(state, ()):
        print(f"{key}: {status.get(key)}")

    if state not in {"failed", "timeout", "orphaned", "stale"}:
        return
    # Read both logs before printing either.
    log_tails = (
        ("\nstdout tail:", tail(directory / "stdout.log")),
        ("\nstderr tail:", tail(directory / "stderr.log")),
    )
    for heading, body in log_tails:
        if body:
            print(heading)
            print(body)


def print_progress(job_id: str, status: dict, directory: Path, window: float) -> None:
    """Report a job that is still running when the watch window elapses.

    *status* is the most recent status document observed (possibly empty).
    Elapsed time is printed only when ``startedAt`` is parseable. The time
    left before the job ceiling is printed only when ``timeoutSeconds`` is
    also numeric. Ends with the instruction to call watch-job again.
    """
    for text in (
        f"job: {job_id}",
        "status: still-running",
        f"jobStatus: {status.get('status')}",
    ):
        print(text)

    elapsed = elapsed_seconds(status)
    if elapsed is not None:
        print(f"elapsedSeconds: {elapsed}")
        ceiling = status.get("timeoutSeconds")
        has_ceiling = isinstance(ceiling, (int, float))
        if has_ceiling:
            remaining = round(ceiling - elapsed)
            print(f"ceilingRemainingSeconds: {max(0, remaining)}")

    print(f"watchWindowSeconds: {int(window)}")
    recent_output = tail(directory / "stdout.log", lines=5)
    if recent_output:
        print("\nstdout tail:")
        print(recent_output)
    print("\njob still running - call .ai/tools/watch-job again now; do not end the turn.")


def main() -> int:
    """Parse arguments, then watch one job until it ends or the window elapses.

    Returns the process exit code. When the job reaches a terminal state, the
    code comes from TERMINAL_EXIT_CODES for that state. Otherwise it is 125
    when the watch window elapses with the job still running, or 2 for usage
    errors and unknown job ids.
    """
    parser = argparse.ArgumentParser(
        description="Watch a file-backed long-running validation job."
    )
    parser.add_argument("job_id")
    parser.add_argument("--window", default=DEFAULT_WINDOW)
    parser.add_argument("--interval", default="1s")
    # Hidden flag, accepted only so the obsolete option yields a helpful error.
    parser.add_argument("--timeout", dest="legacy_timeout", default=None, help=argparse.SUPPRESS)
    args = parser.parse_args()

    if args.legacy_timeout is not None:
        print(
            "error: watch-job no longer takes --timeout; the job ceiling is set by"
            " run-long-check --timeout. Use --window (max 8m) and call watch-job"
            " repeatedly until it reports a terminal result.",
            file=sys.stderr,
        )
        return 2

    try:
        window = parse_duration(args.window)
        interval = parse_duration(args.interval)
    except ValueError as exc:
        print(f"error: invalid duration: {exc}", file=sys.stderr)
        return 2
    # Phrased as "not > 0" so NaN (which compares False to everything) is
    # rejected instead of producing a deadline that never expires.
    if not (window > 0 and interval > 0):
        print("error: window and interval must be positive", file=sys.stderr)
        return 2
    if window > MAX_WINDOW_SECONDS:
        print("error: window exceeds maximum 8m; call watch-job repeatedly instead", file=sys.stderr)
        return 2

    directory = root_dir() / ".ai/vcm/jobs" / args.job_id
    status_path = directory / "status.json"
    lease_path = directory / "lease"

    # Deadlines use the monotonic clock so wall-clock adjustments cannot
    # stretch or cut short the waits.
    found_deadline = time.monotonic() + STATUS_WAIT_SECONDS
    while not status_path.is_file():
        if time.monotonic() >= found_deadline:
            print(f"error: unknown job id: {args.job_id} (no {status_path})", file=sys.stderr)
            return 2
        time.sleep(0.2)

    deadline = time.monotonic() + window
    last_status: dict = {}
    while True:
        renew_lease(lease_path)
        status = read_optional_json(status_path)
        # Skip unreadable, empty, or non-object documents and keep the last
        # good status for the progress report.
        if isinstance(status, dict) and status:
            last_status = status
            state = status.get("status")
            if isinstance(state, str) and state in TERMINAL_EXIT_CODES:
                print_summary(args.job_id, status, directory)
                return TERMINAL_EXIT_CODES[state]

        remaining = deadline - time.monotonic()
        if remaining <= 0:
            print_progress(args.job_id, last_status, directory, window)
            return 125

        # Never sleep past the end of the window. An --interval longer than
        # the remaining window would otherwise delay the exit and leave the
        # lease unrenewed for the whole sleep.
        time.sleep(min(interval, remaining))


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())
