#!/usr/bin/env python3
"""Semantically dedup the improvement-queue.jsonl using trigrams or embeddings.

Default backend: character-trigram Sørensen-Dice (stdlib only).
Optional backend: sentence-transformers (gracefully falls back if missing).
"""
from __future__ import annotations

import argparse
import fcntl
import json
import sys
from pathlib import Path

from collect_hook_event import assert_regular_file_destination


def trigrams(text: str) -> set:
    """Return the set of lowercase character trigrams in *text*.

    Texts shorter than three characters (after lowercasing) yield a
    one-element set containing the whole lowercased text, including the
    empty string for "".
    """
    lowered = text.lower()
    if len(lowered) < 3:
        return {lowered}
    return {lowered[start:start + 3] for start in range(len(lowered) - 2)}


def dice(a: set, b: set) -> float:
    """Return the Sørensen-Dice coefficient of two trigram sets.

    Defined as 2·|a ∩ b| / (|a| + |b|); returns 0.0 when either set is empty.

    Dice is used instead of Jaccard because it gives a more workable
    threshold (~0.80) for spotting near-paraphrases. Jaccard counts every
    non-shared trigram in both the numerator's complement and the
    denominator, which pushes realistic paraphrase pairs down to roughly
    0.65–0.70.
    """
    if not (a and b):
        return 0.0
    shared = len(a & b)
    total = len(a) + len(b)
    return 2 * shared / total


def embed_backend(texts):
    """Encode *texts* with a sentence-transformers model, or return None.

    If the optional ``sentence_transformers`` package cannot be imported,
    a note is printed to stderr and None is returned so the caller can fall
    back to trigrams. Otherwise the ``BAAI/bge-small-en-v1.5`` model's
    ``encode`` output is returned with ``normalize_embeddings=True``, so a
    plain dot product between two rows is their cosine similarity.
    """
    try:
        from sentence_transformers import SentenceTransformer
    except ImportError:
        SentenceTransformer = None

    if SentenceTransformer is None:
        print("queue_dedup: sentence-transformers not installed; falling back to trigram",
              file=sys.stderr)
        return None

    encoder = SentenceTransformer("BAAI/bge-small-en-v1.5")
    vectors = encoder.encode(texts, normalize_embeddings=True)
    return vectors


def cosine(u, v) -> float:
    """Return the dot product of *u* and *v* as a float.

    This equals cosine similarity only for unit-length vectors, which is what
    embed_backend requests via ``normalize_embeddings=True``; no
    normalisation happens here. If the inputs differ in length, the extra
    trailing components are ignored (``zip`` semantics).
    """
    products = (x * y for x, y in zip(u, v))
    return float(sum(products))


def _bucket_key(row):
    """Bucket by (kind, domain, skill) so dedup never collapses across them.

    A near-identical ``text`` can be a legitimately distinct candidate in a
    different domain or for a different skill (e.g. the same proposed gate
    wording posted against ``cloudflare`` vs ``rails``, or the same retirement
    candidate adjustment raised for two different skills). Use "" for absent
    keys so missing/null values still group consistently.
    """
    return (
        row.get("kind", "") or "",
        row.get("domain", "") or "",
        row.get("skill", "") or "",
    )


def find_duplicates(rows, backend, threshold):
    """Return the indices (into *rows*) of rows that duplicate an earlier row.

    *rows* is expected in priority order. Within each (kind, domain, skill)
    bucket, a row is dropped when its similarity to an earlier, still-kept
    row in the same bucket is >= *threshold*. Rows in different buckets are
    never compared.

    Similarity is Sørensen-Dice over character trigrams for
    ``backend="trigram"``. For ``backend="embed"`` it is the dot product of
    normalised sentence embeddings, falling back to trigrams when
    sentence-transformers is not installed.

    Rows whose ``text`` is missing, null or empty have nothing to compare.
    They are never dropped and never cause another row to be dropped, so
    text-less rows are not collapsed into one another.

    Raises:
        ValueError: if *backend* is neither "trigram" nor "embed".
    """
    if backend not in ("trigram", "embed"):
        # Fail clearly instead of hitting an unbound ``vectors`` below.
        raise ValueError(f"unknown backend: {backend!r}")

    # ``or ""`` maps an explicit JSON null to "" (the same treatment
    # _bucket_key gives its fields), so trigrams() never receives None.
    texts = [r.get("text") or "" for r in rows]

    vectors = embed_backend(texts) if backend == "embed" else None
    if vectors is None:
        # Trigram backend, or embed was requested but is unavailable.
        sigs = [trigrams(t) for t in texts]

        def sim(i, j):
            return dice(sigs[i], sigs[j])
    else:
        def sim(i, j):
            return cosine(vectors[i], vectors[j])

    # Group comparable row indices by (kind, domain, skill). Text-less rows
    # are left out entirely, so they always survive.
    buckets: dict[tuple[str, str, str], list[int]] = {}
    for idx, row in enumerate(rows):
        if not texts[idx]:
            continue
        buckets.setdefault(_bucket_key(row), []).append(idx)

    drop = set()
    for indices in buckets.values():
        for pos, i in enumerate(indices):
            if i in drop:
                # Already a duplicate of an earlier row; it cannot be a keeper.
                continue
            for j in indices[pos + 1:]:
                if j not in drop and sim(i, j) >= threshold:
                    drop.add(j)
    return drop


def order_by_keep(rows, keep):
    """Return the indices of *rows* sorted by their ``ts`` value.

    ``keep="oldest"`` sorts ascending; any other value sorts descending.
    A missing ``ts`` sorts as "". Because Python's sort is stable (also
    with ``reverse=True``), rows with equal ``ts`` keep their original
    relative order.
    """
    newest_first = keep != "oldest"
    return sorted(
        range(len(rows)),
        key=lambda idx: rows[idx].get("ts", ""),
        reverse=newest_first,
    )


def parse_args():
    """Parse ``sys.argv`` into an ``argparse.Namespace``.

    Options:
      --queue      required path to the JSONL queue (``Path``)
      --backend    "trigram" (default) or "embed"
      --threshold  similarity cut-off, float, default 0.80
      --keep       "oldest" (default) or "newest"
      --dry-run    report counts without rewriting the file
    """
    parser = argparse.ArgumentParser(description=__doc__)
    option_specs = [
        ("--queue", dict(required=True, type=Path)),
        ("--backend", dict(choices=["trigram", "embed"], default="trigram")),
        ("--threshold", dict(type=float, default=0.80)),
        ("--keep", dict(choices=["oldest", "newest"], default="oldest")),
        ("--dry-run", dict(action="store_true")),
    ]
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()


def main():
    """Dedup the queue file in place and return a process exit code.

    Returns 2 in three cases: the destination check fails, the path is not
    a regular file, or a non-blank line is not valid JSON. Otherwise returns
    0, including for an empty queue and for ``--dry-run``.

    Surviving rows are written back in their original file order. The
    ``--keep`` priority only decides which member of a duplicate group
    survives; it does not reorder the queue.
    """
    args = parse_args()
    try:
        assert_regular_file_destination(args.queue, label="Dedup queue")
    except ValueError as exc:
        print(str(exc), file=sys.stderr)
        return 2
    if not args.queue.is_file():
        print(f"queue not a regular file: {args.queue}", file=sys.stderr)
        return 2

    with args.queue.open("r+", encoding="utf-8") as fh:
        # Hold an exclusive lock across the whole read-modify-write. flock is
        # advisory, so this only excludes writers that take the same lock.
        fcntl.flock(fh, fcntl.LOCK_EX)

        rows = []
        # Split on "\n" only: str.splitlines() also breaks on U+0085, U+2028
        # and U+2029, which may appear unescaped inside JSON strings.
        for lineno, line in enumerate(fh.read().split("\n"), start=1):
            if not line.strip():
                continue
            try:
                rows.append(json.loads(line))
            except json.JSONDecodeError as exc:
                print(f"queue_dedup: {args.queue}:{lineno}: invalid JSON ({exc.msg})",
                      file=sys.stderr)
                return 2
        if not rows:
            return 0

        # find_duplicates keeps the first row of each near-duplicate group, so
        # feed it rows in keep-priority order, then map the dropped positions
        # back to original file indices.
        priority = order_by_keep(rows, args.keep)
        drop_positions = find_duplicates(
            [rows[i] for i in priority], args.backend, args.threshold
        )
        dropped = {priority[pos] for pos in drop_positions}
        kept = [row for idx, row in enumerate(rows) if idx not in dropped]

        if args.dry_run:
            print(f"would_remove={len(dropped)} kept={len(kept)} backend={args.backend}")
            return 0

        # NOTE(review): truncate-then-write loses data if interrupted. A
        # temp-file + rename swap would change the inode other lock holders
        # use, so confirm how producers lock before changing this.
        fh.seek(0)
        fh.truncate()
        fh.writelines(json.dumps(r, sort_keys=True) + "\n" for r in kept)
    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())
