#!/usr/bin/env python3
"""Replay hook event JSONL, upgrading old rows to the latest schema.

Reads --input JSONL, normalizes each row through the collector's allowlist,
stamps the current collector schema_version, and writes to --output.
Preserves the original ``ts`` so downstream time-series consumers (Phase 2B
effectiveness, 3B causal probe, Phase 4 federation) keep their ordering.
"""
from __future__ import annotations

import argparse
import json
import os
import sys
import tempfile
from pathlib import Path

from collect_hook_event import (
    assert_regular_file_destination,
    bounded,
    normalize_event,
)


def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Returns an ``argparse.Namespace`` holding ``input`` and ``output`` as
    ``Path`` objects (both mandatory) plus the boolean flags
    ``skip_malformed`` and ``dry_run``.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    # The two mandatory path options share an identical definition.
    for flag in ("--input", "--output"):
        parser.add_argument(flag, required=True, type=Path)
    parser.add_argument(
        "--skip-malformed",
        action="store_true",
        help="Skip lines that fail JSON decode instead of erroring",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Report what would be written; do not write output",
    )
    return parser.parse_args()


def iter_rows(path: Path, skip_malformed: bool):
    """Yield the decoded JSON value of every non-blank line in *path*.

    The file is read as bytes, and each line is decoded as UTF-8 explicitly.
    Results therefore do not depend on the platform's locale encoding; the
    replay output is always written as UTF-8. Lines are split on newline
    bytes. Surrounding whitespace, including a trailing carriage return from
    CRLF files, is stripped before decoding.

    Args:
        path: JSONL file to read.
        skip_malformed: when true, a line that is not valid UTF-8 or not
            valid JSON is reported on stderr as ``skip lineno=N`` and
            skipped. When false, the decode error propagates.

    Yields:
        Whatever ``json.loads`` produced for the line. This is not
        necessarily a dict.

    Raises:
        json.JSONDecodeError: a line is not valid JSON and
            ``skip_malformed`` is false.
        UnicodeDecodeError: a line is not valid UTF-8 and
            ``skip_malformed`` is false.
    """
    with path.open("rb") as fh:
        for lineno, raw in enumerate(fh, 1):
            # Decoding per line (rather than via a text-mode file) lets a
            # single corrupt line be skipped instead of aborting the replay.
            try:
                line = raw.decode("utf-8").strip()
                if not line:
                    continue
                row = json.loads(line)
            except (UnicodeDecodeError, json.JSONDecodeError):
                if not skip_malformed:
                    raise
                print(f"skip lineno={lineno}", file=sys.stderr)
                continue
            # Yield outside the try so exceptions raised by the consumer
            # are never mistaken for a malformed line.
            yield row


def replay_normalize(row):
    """Run *row* through the live collector normalizer, restoring its ts.

    ``normalize_event`` applies the collector's allowlist and, per the
    collector, stamps ``ts`` with the current time. For replay, the row's
    original ``ts`` is carried over instead, so time-series consumers keep
    the recorded ordering rather than seeing replay wall-clock time.

    The original ``ts`` is reused only if it is a non-empty string that
    ``bounded()`` accepts. ``bounded()`` is the same validator the collector
    applies to other string fields. When ``bounded()`` rejects the value
    (returns something falsy), the original value is not written anywhere.
    The row keeps the ``ts`` that ``normalize_event`` produced.
    """
    incoming_ts = None
    if isinstance(row, dict):
        incoming_ts = row.get("ts")

    result = normalize_event(row)

    # Guard clause: nothing to restore unless we have a non-empty string.
    if not (isinstance(incoming_ts, str) and incoming_ts):
        return result

    safe_ts = bounded(incoming_ts)
    if safe_ts:
        result["ts"] = safe_ts
    return result


def main():
    """CLI entry point: replay ``--input`` rows into ``--output``.

    Every row from ``iter_rows`` is passed through ``replay_normalize``. Each
    result is written as one JSON object per line with sorted keys, UTF-8
    encoded.

    Output handling: rows are written to a temporary file (created with
    O_EXCL and mode 0o600) in the destination's directory. That file is moved
    onto ``--output`` with ``os.replace`` only after every row has been
    written. This has three consequences:

    * A failure part-way through leaves any existing output untouched,
      instead of truncated.
    * ``--output`` may name the same file as ``--input``. Previously the
      ``O_TRUNC`` open emptied the input before it was read.
    * The destination always ends up with mode 0o600, even if a file with
      looser permissions already existed there. The destination directory
      must be writable.

    Returns:
        Process exit code: 0 on success (or after a dry run), 2 when the
        input is not a regular file or the output destination is rejected.

    Raises:
        json.JSONDecodeError / UnicodeDecodeError: from ``iter_rows`` on a
            malformed line when ``--skip-malformed`` is not given. The
            temporary file is removed first.
    """
    args = parse_args()
    if not args.input.is_file():
        print(f"input not a regular file: {args.input}", file=sys.stderr)
        return 2

    if args.dry_run:
        count = sum(1 for _ in iter_rows(args.input, args.skip_malformed))
        print(f"would_write_rows={count}")
        return 0

    try:
        assert_regular_file_destination(args.output, label="Replay output")
    except ValueError as exc:
        print(str(exc), file=sys.stderr)
        return 2

    # Same directory as the destination so os.replace is a same-filesystem
    # rename. mkstemp's O_EXCL creation also cannot be redirected through a
    # symlink planted after the destination check above.
    fd, tmp_name = tempfile.mkstemp(
        dir=args.output.parent,
        prefix=f".{args.output.name}.",
        suffix=".tmp",
    )
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as out:
            for raw in iter_rows(args.input, args.skip_malformed):
                record = replay_normalize(raw)
                out.write(json.dumps(record, sort_keys=True) + "\n")
        os.replace(tmp_name, args.output)
    except BaseException:
        # Best-effort removal of the partial temp file. The original error
        # is what the caller needs to see, so it is re-raised unchanged.
        try:
            os.unlink(tmp_name)
        except OSError:
            pass
        raise
    return 0


if __name__ == "__main__":
    # Propagate main()'s integer return value as the process exit status.
    raise SystemExit(main())
