#!/usr/bin/env python3
"""Append a bounded, scrubbed hook event to agent-learning JSONL telemetry."""

from __future__ import annotations

import argparse
import contextlib
import datetime as dt
import fcntl
import json
import os
import pathlib
import re
import stat
import sys
from typing import Any

from agent_dispatch import (
    DEFAULT_TELEMETRY_CONFIG,
    bounded,
    normalize_agent_dispatch,
    normalize_path,
    telemetry_config_from_payload,
)
from scrub_secrets import scrub
from state_paths import repo_state_dir, resolve_state_dir


# Fallback rotation threshold in bytes. load_max_hook_event_bytes returns it
# when no config file supplies a positive `retention.max_hook_event_bytes`.
DEFAULT_MAX_HOOK_EVENT_BYTES = 5_000_000
# How many of the newest rotated `.bak` files rotate_if_needed keeps; older
# backups are unlinked after each rotation.
MAX_HOOK_EVENT_BACKUPS = 3

# Bumped when bounded optional fields are added. Phase 2B (gate effectiveness)
# and Phase 3B (causal probe) depend on the discriminator; downstream consumers
# branch on schema_version.
SCHEMA_VERSION = 3

# Caps for the list-valued passthrough fields. Tight bounds keep telemetry
# bounded even under accidental fan-out. Per-member cap drops (not truncates)
# oversize gate ids — a truncated id would collide with neighbours, which is
# worse than the id being absent.
MAX_GATE_LOADED_IDS = 64
MAX_GATE_LOADED_ID_LEN = 64
# normalize_event drops (does not truncate) a correlation_id longer than this.
MAX_CORRELATION_ID_LEN = 128

# Probe-decision passthrough (added P3B-B). Mirrors the gate_loaded_ids
# bound: 64 entries max, each gate_id capped at MAX_GATE_LOADED_ID_LEN.
# Decision values are a closed set; out-of-set values drop the entry.
MAX_PROBE_DECISIONS = 64
_VALID_PROBE_DECISIONS = {"load", "skip"}

# normalize_event uses an allowlist (constructs a fresh dict with only
# vetted keys), so raw-payload keys like prompt/tool_output/transcript are
# dropped by construction. No blocklist is needed at the entry point.

# Matches a UUID (versions 1-8, RFC variant nibble 8/9/a/b) anywhere inside a
# string, case-insensitively. safe_session uses it to replace UUID-bearing
# session identifiers with the generic "session" placeholder. Versions 6-8
# are included so newer time-ordered ids (e.g. v7) are caught like v1-v5.
UUID_RE = re.compile(r"\b[0-9a-f]{8}-[0-9a-f]{4}-[1-8][0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}\b", re.I)


def snake(value: str) -> str:
    """Convert an arbitrary label (CamelCase, spaced, punctuated) to snake_case.

    Returns "unknown" when nothing alphanumeric survives normalization.
    """
    rewrites = (
        # Split before a capitalised word: "PreTool" -> "Pre_Tool".
        (r"(.)([A-Z][a-z]+)", r"\1_\2"),
        # Split a lowercase/digit followed by a capital: "lU" -> "l_U".
        (r"([a-z0-9])([A-Z])", r"\1_\2"),
        # Collapse every run of non-alphanumerics into one underscore.
        (r"[^A-Za-z0-9]+", "_"),
    )
    text = value
    for pattern, replacement in rewrites:
        text = re.sub(pattern, replacement, text)
    normalized = text.strip("_").lower()
    if not normalized:
        return "unknown"
    return normalized


def slug(value: str) -> str:
    """Lower-case *value* and join its alphanumeric runs with single hyphens.

    Returns "unknown" when the input contains no ASCII letters or digits.
    """
    hyphenated = re.sub(r"[^a-z0-9]+", "-", value.lower())
    trimmed = hyphenated.strip("-")
    return trimmed if trimmed else "unknown"


def safe_session(value: Any) -> str:
    """Return a bounded, non-identifying session label.

    Falsy input and any text containing a UUID (per UUID_RE) both become the
    generic "session"; everything else is slugged and cut to 80 characters.
    """
    candidate = str(value) if value else "session"
    if UUID_RE.search(candidate) is not None:
        return "session"
    return slug(candidate)[:80]


def parse_json_or_label(raw: str, source: str) -> dict[str, Any]:
    """Parse *raw* as a JSON object, or describe why it could not be used.

    A JSON object is returned unchanged. Unparseable text or a non-object
    JSON value yields a synthetic "MalformedJson" payload whose label names
    the failure kind and the *source* it came from.
    """

    def malformed(label: str) -> dict[str, Any]:
        # Synthetic payload so a bad input still produces a telemetry row.
        return {
            "event": "MalformedJson",
            "runtime": "collect_hook_event",
            "label": bounded(label, 80),
        }

    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        return malformed(f"malformed_json_{source}")
    if not isinstance(parsed, dict):
        return malformed(f"non_object_json_{source}")
    return parsed


def assert_regular_file_destination(path: pathlib.Path, *, label: str) -> None:
    """Refuse to write to *path* unless it is absent or a regular file.

    Uses lstat so a symlink is inspected itself rather than followed.

    Raises:
        ValueError: if *path* is a symlink or any other non-regular file.
    """
    try:
        info = path.lstat()
    except FileNotFoundError:
        # Nothing there yet: the caller is free to create it.
        return
    file_mode = info.st_mode
    if stat.S_ISLNK(file_mode):
        raise ValueError(f"{label} is a symlink; refusing to write output to {path}")
    if stat.S_ISREG(file_mode):
        return
    raise ValueError(f"{label} is not a regular file; refusing to write output to {path}")


def skill_from_path(path_value: str | None) -> str | None:
    """Return the path component that directly follows the first "skills" segment.

    Returns None for an empty path, or when no "skills" segment has a
    successor.
    """
    if not path_value:
        return None
    segments = pathlib.PurePath(path_value).parts
    # Walk adjacent (segment, next segment) pairs; a trailing "skills" has no pair.
    for current, following in zip(segments, segments[1:]):
        if current == "skills":
            return following
    return None


def _coerce_gate_id(value: Any) -> str | None:
    """Return a vetted gate id string, or None when *value* must be dropped."""
    # bool is a subclass of int; True/False are never meaningful gate ids and
    # would otherwise be recorded as the strings "True"/"False".
    if isinstance(value, bool) or not isinstance(value, (str, int)):
        return None
    text = str(value)
    # Drop (never truncate) oversize ids: a truncated id might collide with a
    # real one, which is worse than the id being absent.
    if not text or len(text) > MAX_GATE_LOADED_ID_LEN:
        return None
    # NOTE(review): relies on bounded() returning a falsy value for text it
    # rejects (e.g. secret-shaped content), the same assumption the
    # correlation_id handling makes — confirm against agent_dispatch.
    return bounded(text, MAX_GATE_LOADED_ID_LEN) or None


def _normalize_gate_loaded_ids(raw_ids: Any) -> list[str]:
    """Return at most MAX_GATE_LOADED_IDS vetted gate ids from a raw payload value."""
    if not isinstance(raw_ids, list):
        return []
    vetted: list[str] = []
    for member in raw_ids:
        gate_id = _coerce_gate_id(member)
        if gate_id is None:
            continue
        vetted.append(gate_id)
        if len(vetted) >= MAX_GATE_LOADED_IDS:
            break
    return vetted


def _normalize_probe_decisions(raw_decisions: Any) -> list[dict[str, str]]:
    """Return at most MAX_PROBE_DECISIONS vetted {gate_id, decision} entries."""
    if not isinstance(raw_decisions, list):
        return []
    vetted: list[dict[str, str]] = []
    for member in raw_decisions:
        if not isinstance(member, dict):
            continue
        gate_id = _coerce_gate_id(member.get("gate_id"))
        if gate_id is None:
            continue
        decision_raw = member.get("decision")
        if not isinstance(decision_raw, str):
            continue
        decision = bounded(decision_raw, 16)
        # Closed set: anything other than a known decision drops the entry.
        if decision not in _VALID_PROBE_DECISIONS:
            continue
        vetted.append({"gate_id": gate_id, "decision": decision})
        if len(vetted) >= MAX_PROBE_DECISIONS:
            break
    return vetted


def normalize_event(
    raw: dict[str, Any],
    repo_arg: pathlib.Path | None = None,
    telemetry_config: dict[str, bool] | None = None,
) -> dict[str, Any]:
    """Build a fresh, allowlisted telemetry row from a raw hook payload.

    Only the keys assembled here reach the output, so unvetted payload keys
    are dropped by construction. Malformed optional fields are dropped
    silently so a misshapen payload still produces a row.

    Args:
        raw: Decoded hook payload.
        repo_arg: Repository root; when omitted, taken from the payload's
            "repo" or "cwd" key, falling back to ".".
        telemetry_config: Feature flags handed to normalize_agent_dispatch;
            DEFAULT_TELEMETRY_CONFIG is used when omitted or empty.

    Returns:
        The normalized event with None-valued keys removed and
        "schema_version" set to SCHEMA_VERSION.
    """
    repo = repo_arg or pathlib.Path(str(raw.get("repo") or raw.get("cwd") or ".")).expanduser()
    telemetry = telemetry_config or DEFAULT_TELEMETRY_CONFIG
    path = normalize_path(raw.get("path") or raw.get("file") or raw.get("skill_path"), repo)
    event = {
        "ts": dt.datetime.now(dt.timezone.utc).isoformat(),
        "event": snake(str(raw.get("event") or raw.get("type") or raw.get("hook_event") or "unknown")),
        "runtime": slug(str(raw.get("runtime") or raw.get("source") or "unknown")),
        "repo": str(repo.resolve()) if repo else None,
        "session_id": safe_session(raw.get("session_id") or raw.get("session") or raw.get("conversation_id")),
    }
    skill = bounded(raw.get("skill") or raw.get("skill_name")) or skill_from_path(path)
    if skill:
        event["skill"] = skill
    tool = bounded(raw.get("tool") or raw.get("tool_name"))
    if tool:
        event["tool"] = tool
    outcome = bounded(raw.get("outcome") or raw.get("status"))
    if outcome:
        event["outcome"] = outcome
    if path:
        event["path"] = path
    scope = bounded(raw.get("scope"))
    if scope:
        event["scope"] = scope
    label = bounded(raw.get("label") or raw.get("reason"))
    if label:
        event["label"] = label
    command = bounded(raw.get("command"), 80)
    if command:
        # Record only the first token (the executable), never the arguments.
        # A whitespace-only command has no tokens; skip it instead of
        # raising IndexError.
        tokens = command.split()
        if tokens:
            event["command_class"] = tokens[0]

    event.update(normalize_agent_dispatch(raw, repo, telemetry))

    # v2/v3 passthrough fields. Validation drops malformed values silently
    # rather than crashing, so a misshapen payload still emits a row.
    raw_correlation = raw.get("correlation_id")
    if isinstance(raw_correlation, str) and 0 < len(raw_correlation) <= MAX_CORRELATION_ID_LEN:
        # Run through bounded() so a secret-shaped correlation_id drops only
        # the field instead of triggering the whole-row scrubber reject.
        scrubbed = bounded(raw_correlation, MAX_CORRELATION_ID_LEN)
        if scrubbed:
            event["correlation_id"] = scrubbed

    # Gate ids get the same per-member vetting in both list fields, so one
    # bad member drops only itself. A field is omitted when nothing survives.
    gate_loaded_ids = _normalize_gate_loaded_ids(raw.get("gate_loaded_ids"))
    if gate_loaded_ids:
        event["gate_loaded_ids"] = gate_loaded_ids

    # probe_decisions: list of {gate_id, decision} dicts naming which cohort
    # this session falls into per gate.
    probe_decisions = _normalize_probe_decisions(raw.get("probe_decisions"))
    if probe_decisions:
        event["probe_decisions"] = probe_decisions

    event["schema_version"] = SCHEMA_VERSION
    return {key: value for key, value in event.items() if value is not None}


def default_output(repo: pathlib.Path, state_dir: str | None, personal: str | None) -> pathlib.Path:
    """Return the default log path: "hook-events.jsonl" inside the repo's state dir."""
    state_root = repo_state_dir(repo, state_dir, personal)
    return state_root / "hook-events.jsonl"


def _retention_max_from_payload(data: dict[str, Any]) -> int | None:
    """Extract a positive integer `retention.max_hook_event_bytes` from a config payload.

    Returns None when the "retention" section is missing or not a mapping,
    or when the value is absent, not an integer, or not strictly positive.
    """
    retention = data.get("retention")
    if not isinstance(retention, dict):
        return None
    value = retention.get("max_hook_event_bytes")
    # bool is a subclass of int: reject it explicitly so a JSON `true` is not
    # interpreted as a 1-byte rotation threshold.
    if isinstance(value, bool) or not isinstance(value, int):
        return None
    return value if value > 0 else None


def _load_json_config(path: pathlib.Path) -> dict[str, Any] | None:
    """Read *path* as UTF-8 JSON and return it only if it decodes to an object.

    Any read, decode or parse failure, and any non-object JSON value, yields
    None.
    """
    try:
        text = path.read_text(encoding="utf-8")
        parsed = json.loads(text)
    except (OSError, ValueError, TypeError):
        return None
    if isinstance(parsed, dict):
        return parsed
    return None


def load_max_hook_event_bytes(repo: pathlib.Path, state_dir: str | None = None, personal: str | None = None) -> int:
    """Read `retention.max_hook_event_bytes` from .agent-learning.json.

    Lookup order, first positive integer wins:
      1. `<repo>/.agent-learning.json`
      2. `config.json` inside the `state_dir` named by that file, if any
      3. `config.json` inside the state dir from resolve_state_dir

    Returns DEFAULT_MAX_HOOK_EVENT_BYTES when no candidate supplies a value.
    A missing, malformed or unresolvable config path is skipped so it never
    breaks the hook adapter.
    """
    candidates: list[pathlib.Path] = []
    integration_payload = _load_json_config(repo / ".agent-learning.json")
    if integration_payload is not None:
        value = _retention_max_from_payload(integration_payload)
        if value is not None:
            return value
        configured_state_dir = integration_payload.get("state_dir")
        if isinstance(configured_state_dir, str) and configured_state_dir:
            # "~" expansion is deferred to the guarded loop below so a bad
            # home reference in the config cannot raise from here.
            candidates.append(pathlib.Path(configured_state_dir) / "config.json")

    candidates.append(resolve_state_dir(state_dir, personal, repo) / "config.json")

    seen: set[pathlib.Path] = set()
    for candidate in candidates:
        try:
            resolved = candidate.expanduser().resolve()
        except (OSError, RuntimeError, ValueError):
            # expanduser() raises RuntimeError when the home directory cannot
            # be determined; resolve() can raise OSError (RuntimeError on
            # older Pythons) for symlink loops and ValueError for an embedded
            # NUL. Treat all of them as "no config here".
            continue
        if resolved in seen:
            continue
        seen.add(resolved)
        payload = _load_json_config(resolved)
        if payload is None:
            continue
        value = _retention_max_from_payload(payload)
        if value is not None:
            return value
    return DEFAULT_MAX_HOOK_EVENT_BYTES


def load_telemetry_config(
    repo: pathlib.Path,
    state_dir: str | None = None,
    personal: str | None = None,
) -> dict[str, bool]:
    """Merge telemetry feature flags from defaults, state config and repo config.

    Layers are applied lowest precedence first: DEFAULT_TELEMETRY_CONFIG, the
    resolved state dir's `config.json`, the `config.json` of the `state_dir`
    named in `.agent-learning.json`, and finally `.agent-learning.json`
    itself, so the repo-local file always wins.
    """
    integration_payload = _load_json_config(repo / ".agent-learning.json")

    # Ordered highest precedence first; walked backwards when merging.
    state_configs: list[pathlib.Path] = []
    if integration_payload is not None:
        configured_state_dir = integration_payload.get("state_dir")
        if isinstance(configured_state_dir, str) and configured_state_dir:
            state_configs.append(pathlib.Path(configured_state_dir).expanduser() / "config.json")
    state_configs.append(resolve_state_dir(state_dir, personal, repo) / "config.json")

    merged = dict(DEFAULT_TELEMETRY_CONFIG)
    visited: set[pathlib.Path] = set()
    for config_path in state_configs[::-1]:
        try:
            real_path = config_path.expanduser().resolve()
        except OSError:
            continue
        if real_path in visited:
            # Both candidates point at the same file; apply it once.
            continue
        visited.add(real_path)
        payload = _load_json_config(real_path)
        if payload is None:
            continue
        merged.update(telemetry_config_from_payload(payload))

    if integration_payload is not None:
        merged.update(telemetry_config_from_payload(integration_payload))
    return merged


def _chmod_quiet(path: pathlib.Path, mode: int) -> None:
    """Set *path*'s permission bits to *mode*, ignoring any OSError.

    Best-effort by design: a failed permission change must not break the
    hook flow.
    """
    try:
        os.chmod(path, mode)
    except OSError:
        return


def _lock_path_for(output: pathlib.Path) -> pathlib.Path:
    """Return the sidecar lock path for *output*: same directory, ".lock" appended.

    Deriving it from the log path alone means every caller that agrees on
    the log path also agrees on the lock target.
    """
    lock_name = f"{output.name}.lock"
    return output.with_name(lock_name)


@contextlib.contextmanager
def _rotation_lock(output: pathlib.Path, *, exclusive: bool):
    """Hold an flock on the sidecar lock file for the duration of the `with` body.

    Takes LOCK_EX when *exclusive* is true, LOCK_SH otherwise. The lock file
    is created with mode 0o600 if it does not exist and is never deleted
    here; deleting it would race with concurrent lockers.
    """
    operation = fcntl.LOCK_EX if exclusive else fcntl.LOCK_SH
    sidecar = _lock_path_for(output)
    # O_CREAT lets the first caller in a fresh state dir create the lock
    # file; later callers just open it. 0o600 keeps it from being
    # world-visible, matching the log's own permissions.
    handle = os.open(str(sidecar), os.O_RDWR | os.O_CREAT, 0o600)
    try:
        fcntl.flock(handle, operation)
        try:
            yield
        finally:
            fcntl.flock(handle, fcntl.LOCK_UN)
    finally:
        os.close(handle)


def rotate_if_needed(output: pathlib.Path, max_bytes: int) -> None:
    """If *output* exceeds *max_bytes*, rotate it to a timestamped .bak file.

    Holds LOCK_EX on the sidecar lock for the entire stat+rename+prune
    window. Appenders take the same lock exclusively around their
    open+write, so none can keep writing into the renamed inode.

    A missing/unreadable log, or a failed rename, leaves everything as it
    was. After a successful rotation only the newest MAX_HOOK_EVENT_BACKUPS
    backups (by file name) are kept.

    Raises:
        OSError: if the sidecar lock file cannot be opened or locked.
    """
    with _rotation_lock(output, exclusive=True):
        try:
            size = output.stat().st_size
        except OSError:
            return
        if size <= max_bytes:
            return
        stamp = dt.datetime.now(dt.timezone.utc).strftime("%Y%m%dT%H%M%SZ")
        backup = output.with_name(f"{output.name}.{stamp}.bak")
        # The stamp has one-second resolution and rename() silently replaces
        # an existing target on POSIX, so a second rotation within the same
        # second would destroy the earlier backup. Pick a free name instead.
        # "_NNN" sorts after the plain name ("_" > "."), which keeps the
        # name-ordered pruning below chronological.
        collision = 0
        while os.path.lexists(backup):
            collision += 1
            backup = output.with_name(f"{output.name}.{stamp}_{collision:03d}.bak")
        try:
            output.rename(backup)
        except OSError:
            return
        _chmod_quiet(backup, 0o600)
        # Cap old backups: keep the newest MAX_HOOK_EVENT_BACKUPS.
        # The glob excludes the sidecar lock file because *.lock does not
        # match *.bak.
        try:
            backups = sorted(
                output.parent.glob(f"{output.name}.*.bak"),
                key=lambda p: p.name,
            )
        except OSError:
            return
        # Explicit count rather than a negative slice: `[:-0]` would select
        # nothing if the cap were ever set to zero.
        excess = max(len(backups) - MAX_HOOK_EVENT_BACKUPS, 0)
        for stale in backups[:excess]:
            try:
                stale.unlink()
            except OSError:
                # Best-effort pruning: a backup we cannot delete is left behind.
                pass


def load_input(args: argparse.Namespace) -> dict[str, Any]:
    """Load the raw hook payload from --event, --input, or stdin (in that order).

    Unparseable or non-object JSON is turned into a labelled placeholder
    payload by parse_json_or_label. Empty stdin is treated as "{}".

    Raises:
        ValueError: if the --input file cannot be read or is not valid UTF-8.
    """
    if args.event:
        return parse_json_or_label(args.event, "event")
    if args.input:
        try:
            payload_text = pathlib.Path(args.input).read_text(encoding="utf-8")
        except (OSError, UnicodeError) as error:
            # UnicodeError covers a file that exists but is not UTF-8, so the
            # message names the offending --input instead of surfacing a bare
            # codec error.
            raise ValueError(f"cannot read --input {args.input}: {error}") from error
        return parse_json_or_label(payload_text, "input")
    data = sys.stdin.read().strip()
    return parse_json_or_label(data or "{}", "stdin")


def main(argv: list[str] | None = None) -> int:
    """CLI entry point: normalize one hook event and append it to the JSONL log.

    Args:
        argv: Argument list; None lets argparse read sys.argv.

    Returns:
        0 after a successful append. 1 (with a message on stderr) when the
        input or destination is unusable, when the normalized event still
        contains secret-like content, or when the log directory, lock file
        or log file cannot be written.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--event", help="JSON object to append. Defaults to stdin.")
    parser.add_argument("--input", help="Path containing one JSON object.")
    parser.add_argument("--repo", default=".")
    parser.add_argument("--state-dir")
    parser.add_argument("--personal")
    parser.add_argument("--output")
    args = parser.parse_args(argv)

    repo = pathlib.Path(args.repo).expanduser().resolve()
    output = pathlib.Path(args.output) if args.output else default_output(repo, args.state_dir, args.personal)
    try:
        output.parent.mkdir(parents=True, exist_ok=True)
    except OSError as error:
        # e.g. permission denied, or a path component that is a regular file.
        print(f"failed to create hook event log directory {output.parent}: {error}", file=sys.stderr)
        return 1
    try:
        event = normalize_event(load_input(args), repo, load_telemetry_config(repo, args.state_dir, args.personal))
        assert_regular_file_destination(output, label="Hook event output")
    except (ValueError, OSError) as error:
        # OSError: the lstat() inside the destination check can fail for
        # reasons other than "not found" (e.g. an unreadable parent).
        print(str(error), file=sys.stderr)
        return 1
    rendered = scrub(json.dumps(event, sort_keys=True, separators=(",", ":")))
    if "[REDACTED" in rendered:
        print("event contains secret-like content after normalization", file=sys.stderr)
        return 1
    max_bytes = load_max_hook_event_bytes(repo, args.state_dir, args.personal)
    # os.open with mode=0o600 narrows the file-creation window so a new log
    # file is never readable by group/other, even briefly. mode is ignored on
    # existing files; _chmod_quiet below upgrades pre-existing logs that
    # predate this hardening.
    #
    # LOCK_EX around the open+write window: cooperates with rotate_if_needed
    # (which also takes LOCK_EX), AND serializes concurrent appenders. The
    # previous LOCK_SH permitted concurrent appenders, relying on O_APPEND
    # being atomic up to PIPE_BUF (4096 on Linux). With a full v2 event --
    # MAX_GATE_LOADED_IDS * MAX_GATE_LOADED_ID_LEN = 4096B for
    # gate_loaded_ids alone, plus probe_decisions and headers -- one line
    # routinely exceeds PIPE_BUF, two concurrent appenders interleave bytes
    # mid-line, evaluate_gate_effectiveness.load_sessions sees torn JSON,
    # and (combined with the unguarded json.loads pre-C7) the entire
    # scoring pass aborts.
    try:
        # Rotation sits inside the same handler: opening the sidecar lock can
        # fail (e.g. read-only state dir) and must yield a clean error too.
        rotate_if_needed(output, max_bytes)
        with _rotation_lock(output, exclusive=True):
            fd = os.open(str(output), os.O_WRONLY | os.O_APPEND | os.O_CREAT, 0o600)
            try:
                os.write(fd, (rendered + "\n").encode("utf-8"))
            finally:
                os.close(fd)
    except OSError as error:
        print(f"failed to write hook event log to {output}: {error}", file=sys.stderr)
        return 1
    _chmod_quiet(output, 0o600)
    print(f"appended hook event to {output}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
