#!/usr/bin/env python3
"""Evaluate domain-classifier precision and recall against fixed fixtures."""

from __future__ import annotations

import argparse
import json
import pathlib
import sys
from collections import Counter
from collections.abc import Callable
from typing import Any

from distill_learning import classify_events, load_domain_rules, packaged_domain_rules_path


DEFAULT_FIXTURES = (
    pathlib.Path(__file__).resolve().parents[1]
    / "fixtures"
    / "eval-fixtures"
    / "classifier_precision.json"
)

# Lazy-loaded so an unrelated import (e.g. `from evaluate_classifier import
# evaluate`) doesn't crash if the packaged tm-norge preset is missing or
# malformed. The cached value is fine because domain-rules JSON is shipped
# inside the package and doesn't change at runtime.
_FIXTURE_DOMAIN_RULES: list[dict[str, Any]] | None = None


def fixture_domain_rules() -> list[dict[str, Any]]:
    """Return the packaged tm-norge domain rules, loading them on first use.

    The rules are loaded once via ``load_domain_rules`` and the same list
    object is handed back on every later call. If loading raises, nothing is
    cached and the next call retries.
    """
    global _FIXTURE_DOMAIN_RULES
    cached = _FIXTURE_DOMAIN_RULES
    if cached is not None:
        return cached
    rules_path = packaged_domain_rules_path("tm-norge")
    _FIXTURE_DOMAIN_RULES = load_domain_rules(rules_path)
    return _FIXTURE_DOMAIN_RULES


def load_fixtures(path: pathlib.Path) -> list[dict[str, Any]]:
    """Read the fixture file at *path* and check its overall shape.

    Args:
        path: UTF-8 JSON file whose top-level value is an array of fixture
            objects.

    Returns:
        The decoded list of fixture dicts.

    Raises:
        OSError: if the file cannot be read.
        ValueError: if the file is not valid JSON (``json.JSONDecodeError``),
            the top-level value is not an array, or any element is not an
            object.
    """
    payload = json.loads(path.read_text(encoding="utf-8"))
    if not isinstance(payload, list):
        raise ValueError("fixture file must contain a JSON array")
    # Reject non-object entries up front. Otherwise a stray string or number
    # in the array surfaces later as an AttributeError on `fixture.get(...)`.
    for index, fixture in enumerate(payload):
        if not isinstance(fixture, dict):
            raise ValueError(
                f"fixture at index {index} must be a JSON object, "
                f"got {type(fixture).__name__}"
            )
    return payload


def predicted_domains(lines: list[str]) -> set[str]:
    """Classify *lines* with the packaged tm-norge rules.

    Returns the set of distinct ``"domain"`` values found on the events that
    ``classify_events`` produces.
    """
    events = classify_events({}, lines, fixture_domain_rules())
    domains: set[str] = set()
    for event in events:
        domains.add(event["domain"])
    return domains


def evaluate(fixtures: list[dict[str, Any]], min_precision: float, min_recall: float) -> dict[str, Any]:
    true_positive = 0
    false_positive = 0
    false_negative = 0
    false_positive_domains: Counter[str] = Counter()
    false_negative_domains: Counter[str] = Counter()
    cases: list[dict[str, Any]] = []

    for fixture in fixtures:
        lines = fixture.get("lines", [])
        if not isinstance(lines, list) or not all(isinstance(line, str) for line in lines):
            raise ValueError(f"fixture {fixture.get('id', '<unknown>')} must contain string lines")

        expected = set(fixture.get("expected_domains", []))
        predicted = predicted_domains(lines)

        true_positive += len(predicted & expected)
        fp = predicted - expected
        fn = expected - predicted
        false_positive += len(fp)
        false_negative += len(fn)
        false_positive_domains.update(fp)
        false_negative_domains.update(fn)
        cases.append(
            {
                "id": fixture.get("id"),
                "expected_domains": sorted(expected),
                "predicted_domains": sorted(predicted),
                "false_positive_domains": sorted(fp),
                "false_negative_domains": sorted(fn),
            }
        )

    precision_denominator = true_positive + false_positive
    recall_denominator = true_positive + false_negative
    precision = true_positive / precision_denominator if precision_denominator else 1.0
    recall = true_positive / recall_denominator if recall_denominator else 1.0
    passed = precision >= min_precision and recall >= min_recall

    return {
        "precision": precision,
        "recall": recall,
        "thresholds": {"min_precision": min_precision, "min_recall": min_recall},
        "passed": passed,
        "true_positive": true_positive,
        "false_positive": false_positive,
        "false_negative": false_negative,
        "false_positive_domains": sorted(false_positive_domains),
        "false_negative_domains": sorted(false_negative_domains),
        "cases": cases,
    }


def parse_args(argv: list[str]) -> argparse.Namespace:
    """Parse command-line arguments (excluding the program name).

    The precision and recall thresholds must be fractions in [0, 1]. Rejected
    inputs include a percentage such as ``85`` or NaN, either of which makes
    the gate fail unconditionally, and a negative value, which makes it pass
    unconditionally.
    """

    def fraction(text: str) -> float:
        # A ValueError from float() is reported by argparse as
        # "invalid fraction value".
        value = float(text)
        if not 0.0 <= value <= 1.0:
            raise argparse.ArgumentTypeError(f"must be between 0 and 1, got {text!r}")
        return value

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--fixtures",
        type=pathlib.Path,
        default=DEFAULT_FIXTURES,
        help="fixture JSON file (default: %(default)s)",
    )
    parser.add_argument(
        "--min-precision",
        type=fraction,
        default=0.85,
        help="minimum precision in [0, 1] required to pass (default: %(default)s)",
    )
    parser.add_argument(
        "--min-recall",
        type=fraction,
        default=0.85,
        help="minimum recall in [0, 1] required to pass (default: %(default)s)",
    )
    return parser.parse_args(argv)


def main(argv: list[str] | None = None) -> int:
    """Run the evaluation and print the JSON report to stdout.

    Args:
        argv: Arguments excluding the program name. ``None`` means use
            ``sys.argv[1:]``. An explicit empty list means "no arguments".

    Returns:
        0 when both precision and recall meet their thresholds, else 1.
    """
    # `argv or sys.argv[1:]` would treat an explicit [] as "read sys.argv" and
    # leak the host process's arguments (e.g. pytest's) into the parser.
    args = parse_args(sys.argv[1:] if argv is None else argv)
    result = evaluate(load_fixtures(args.fixtures), args.min_precision, args.min_recall)
    print(json.dumps(result, indent=2, sort_keys=True))
    return 0 if result["passed"] else 1


if __name__ == "__main__":
    # sys.exit raises SystemExit with main()'s return code as the exit status.
    sys.exit(main())
