#!/usr/bin/env python3
"""Mine corpus chunks for n-gram terms that co-occur with corrections.

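Corpus format: one line per chunk, leading "[session=<id> outcome=<state>]"
header followed by chunk text. <state> is `correction` or `clean`. If no
line carries that header, lines of the form
"<role>: <text> [session_ref=<ref>]" are grouped per session_ref instead and
each session's outcome is inferred from correction phrases in its user turns.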

Scoring: tf_correction(term) / (tf_clean(term) + 1).

Stop words and short tokens are filtered. Default n-gram range: 1..2.
Output: JSON proposals sorted by score descending, capped by --top-k.
"""
from __future__ import annotations

import argparse
import json
import re
import sys
from collections import Counter
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parent))
from collect_hook_event import assert_regular_file_destination  # noqa: E402

# Common English function words dropped during tokenization.
STOP_WORDS = {
    # articles and coordinating conjunctions
    "the", "a", "an", "and", "or", "but",
    # forms of "to be"
    "is", "are", "was", "were", "be",
    # prepositions and particles
    "to", "of", "in", "on", "for", "at", "by", "from", "with",
    "into", "out", "up", "down", "over", "under",
    # demonstratives and connectives
    "this", "that", "these", "those", "it", "as", "if", "then", "than",
    # personal and possessive pronouns
    "you", "we", "they", "i", "me", "my",
    "your", "their", "our", "his", "her", "its",
    # auxiliaries and modals
    "have", "has", "had", "do", "does", "did",
    "can", "could", "should", "would", "may", "might", "must", "shall", "will",
}

# Strict corpus line: groups are (session id, outcome, remaining chunk text).
HEADER_RE = re.compile(r"^\[session=(\S+)\s+outcome=(\w+)\]\s*(.*)$")
# extract_sessions line: groups are (role, body, session_ref).
EXTRACT_SESSION_RE = re.compile(r"^(\w+):\s*(.*?)\s*\[session_ref=([^\]]+)\]\s*$")
# A token is a lowercase letter followed by at least two more characters from
# [a-z0-9_-], so every token is three or more characters long.
TOKEN_RE = re.compile(r"[a-z][a-z0-9_-]{2,}")

# Phrases that hint the user is correcting the agent. They are looked up as
# plain substrings of lowercased user-turn text, which makes the match
# case-insensitive; this is a heuristic, not a classifier.
CORRECTION_PATTERNS = (
    "wait,",
    "actually,",
    "no,",
    "wrong",
    "revert",
    "undo",
    "let me try",
    "fix the",
    "instead",
    "i meant",
    "should be",
    "let's redo",
    "stop,",
)


def parse_chunks(corpus_path: Path):
    """Yield ``(outcome, text)`` pairs parsed from the corpus file.

    Two input formats are accepted:

    1. Strict: each line starts with ``[session=<id> outcome=<state>] <text>``.
       Used by hand-built fixtures and the integration tests.

    2. extract_sessions output: each line is
       ``<role>: <text> [session_ref=<ref>]``. Lines sharing a session_ref are
       grouped; the session is tagged ``correction`` when the joined,
       lowercased text of its ``user`` turns contains any entry of
       CORRECTION_PATTERNS, and ``clean`` otherwise.

    Format 1 takes priority: if any line matches the strict header, only the
    strict lines are used for the entire corpus. Otherwise format 2 is
    attempted. Lines matching neither format are silently skipped.

    Args:
        corpus_path: Path of the corpus file; it is read as UTF-8.

    Yields:
        ``(outcome, text)`` tuples with ``text`` lowercased. In strict mode
        ``outcome`` is the header's state exactly as written; in fallback mode
        it is ``"correction"`` or ``"clean"``, one tuple per session_ref in
        first-seen order.
    """
    raw = corpus_path.read_text(encoding="utf-8")
    # Split on "\n" only. read_text() uses universal-newline text mode, so
    # "\r\n" and "\r" have already been translated to "\n". str.splitlines()
    # would additionally break on \x0b, \x0c, \x1c-\x1e, \x85, \u2028 and
    # \u2029, which can legitimately occur inside chunk text; that would cut a
    # chunk in two and silently drop the part that no longer matches a format.
    lines = raw.split("\n")

    strict_chunks = []
    for line in lines:
        m = HEADER_RE.match(line)
        if m:
            strict_chunks.append((m.group(2), m.group(3).lower()))
    if strict_chunks:
        yield from strict_chunks
        return

    # Fallback: extract_sessions format. Group by session_ref, keeping the
    # user turns separately because only they are scanned for corrections.
    sessions: dict[str, dict[str, list[str]]] = {}
    for line in lines:
        m = EXTRACT_SESSION_RE.match(line)
        if not m:
            continue
        role, body, ref = m.group(1), m.group(2), m.group(3)
        bucket = sessions.setdefault(ref, {"user": [], "all": []})
        bucket["all"].append(body)
        if role == "user":
            bucket["user"].append(body)

    for buckets in sessions.values():
        joined_user = " ".join(buckets["user"]).lower()
        is_correction = any(p in joined_user for p in CORRECTION_PATTERNS)
        outcome = "correction" if is_correction else "clean"
        # The emitted text covers every role's turns, not just the user's.
        yield outcome, " ".join(buckets["all"]).lower()


def tokens(text):
    """Return the TOKEN_RE matches in *text*, in order, minus stop words.

    The text is used as given; no case folding happens here, so callers that
    want case-insensitive tokens must lowercase first.
    """
    kept = []
    for candidate in TOKEN_RE.findall(text):
        if candidate in STOP_WORDS:
            continue
        kept.append(candidate)
    return kept


def ngrams(toks, n):
    """Return every contiguous n-gram of *toks* as a space-joined string.

    Args:
        toks: Sequence of token strings (must support ``len`` and slicing).
        n: Number of tokens per gram.

    Returns:
        The grams in left-to-right order. The list is empty when *toks* holds
        fewer than *n* tokens, and also when *n* is less than 1, since a
        non-positive width has no meaningful grams.
    """
    if n < 1:
        # Without this guard n == 0 yields empty-string "grams" and negative n
        # yields arbitrary fragments produced by negative slice bounds.
        return []
    # range() is empty when len(toks) < n, so short inputs need no extra check.
    return [" ".join(toks[start:start + n]) for start in range(len(toks) - n + 1)]


def score_terms(chunks, n_min=1, n_max=2):
    """Score each n-gram seen in correction chunks against clean chunks.

    Args:
        chunks: Iterable of ``(outcome, text)`` pairs. Only the outcomes
            ``"correction"`` and ``"clean"`` contribute counts; any other
            outcome is ignored.
        n_min: Smallest n-gram size counted (inclusive).
        n_max: Largest n-gram size counted (inclusive).

    Returns:
        A list of ``(term, correction_count, clean_count, score)`` tuples, one
        per term that occurs in at least one correction chunk, where
        ``score = correction_count / (clean_count + 1)``. The list is not
        sorted; it follows the order in which terms were first counted.
    """
    correction_tf = Counter()
    clean_tf = Counter()
    sizes = range(n_min, n_max + 1)

    for outcome, text in chunks:
        chunk_tokens = tokens(text)
        # Pick the tally this chunk feeds; unknown outcomes feed none.
        if outcome == "correction":
            tally = correction_tf
        elif outcome == "clean":
            tally = clean_tf
        else:
            tally = None
        for size in sizes:
            grams = ngrams(chunk_tokens, size)
            if tally is not None:
                tally.update(grams)

    # The +1 keeps the ratio finite for terms never seen in clean chunks.
    return [
        (term, hits, clean_tf.get(term, 0), hits / (clean_tf.get(term, 0) + 1))
        for term, hits in correction_tf.items()
    ]


def parse_args():
    """Parse the command line and return the resulting namespace.

    Options:
        --corpus (required): corpus file path, converted to ``Path``.
        --output (required): proposals file path, converted to ``Path``.
        --top-k: integer, default 10.
        --min-score: float, default 2.0.

    The module docstring doubles as the ``--help`` description.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    option_table = (
        ("--corpus", {"required": True, "type": Path}),
        ("--output", {"required": True, "type": Path}),
        ("--top-k", {"type": int, "default": 10}),
        ("--min-score", {"type": float, "default": 2.0}),
    )
    for flag, settings in option_table:
        parser.add_argument(flag, **settings)
    return parser.parse_args()


def main():
    """Run the miner end to end and return a process exit status.

    Steps: validate the arguments, parse the corpus, score terms, keep those
    with a score of at least ``--min-score``, sort by score descending (term
    ascending as tie-break), cap at ``--top-k`` and write the result as JSON
    under the ``"proposals"`` key.

    Returns:
        0 on success. 2 when the corpus is not a regular file, the output
        destination check raises ``ValueError``, ``--top-k`` is negative, the
        corpus cannot be read or decoded as UTF-8, or the output cannot be
        written. Each failure prints a one-line message to stderr.
    """
    args = parse_args()
    if not args.corpus.is_file():
        print(f"corpus not a regular file: {args.corpus}", file=sys.stderr)
        return 2
    try:
        assert_regular_file_destination(args.output, label="Proposals output")
    except ValueError as exc:
        print(str(exc), file=sys.stderr)
        return 2
    if args.top_k < 0:
        # A negative slice bound would not cap the list; it would silently
        # drop entries from the END of the ranking instead.
        print(f"--top-k must be >= 0: {args.top_k}", file=sys.stderr)
        return 2

    try:
        chunks = list(parse_chunks(args.corpus))
    except (OSError, UnicodeDecodeError) as exc:
        # Report unreadable / non-UTF-8 input like the other input errors
        # rather than letting a traceback escape.
        print(f"cannot read corpus {args.corpus}: {exc}", file=sys.stderr)
        return 2

    scores = score_terms(chunks)
    scores = [s for s in scores if s[3] >= args.min_score]
    scores.sort(key=lambda x: (-x[3], x[0]))  # primary: score desc, secondary: term asc for stability
    scores = scores[:args.top_k]

    payload = json.dumps({
        "proposals": [
            {"term": term, "correction_count": cc, "clean_count": cl, "score": sc}
            for term, cc, cl, sc in scores
        ]
    }, indent=2, sort_keys=True)
    try:
        args.output.write_text(payload, encoding="utf-8")
    except OSError as exc:
        print(f"cannot write proposals {args.output}: {exc}", file=sys.stderr)
        return 2
    return 0


# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
