#!/usr/bin/env python3
"""Validate data-contracts manifests and writers."""

from __future__ import annotations

import argparse
import ast
import fnmatch
import json
import os
import pathlib
import sys
from typing import Any

import artifact_writer


def load_contract_entries(base_dir: pathlib.Path | None = None) -> list[dict[str, Any]]:
    """Return the raw contract manifest entries from ``artifact_writer.load_contracts``.

    *base_dir* is forwarded unchanged. When it is ``None``, the location is
    whatever ``artifact_writer`` picks by default. NOTE(review): main() also
    exports ALC_DATA_CONTRACTS_DIR; presumably artifact_writer reads it. Verify.
    Callers in this module index each entry by ``"id"`` and ``"_source"``.
    """
    loader = artifact_writer.load_contracts
    return loader(base_dir)


def merge_contracts(base_dir: pathlib.Path | None = None) -> tuple[dict[str, dict[str, Any]], dict[str, list[str]]]:
    """Merge all manifest entries into one registry keyed by artifact id.

    Ids are ``str(entry["id"]).strip()``. The first entry seen for an id is
    kept. Later entries with the same id are not merged, but their
    ``_source`` is still recorded.

    Returns:
        ``(merged, sources)``:
        - ``merged`` maps id -> a shallow copy of the winning entry without its
          ``_source`` key.
        - ``sources`` maps id -> every ``_source`` value declaring that id, in
          load order.
    """
    registry: dict[str, dict[str, Any]] = {}
    origins: dict[str, list[str]] = {}
    for record in load_contract_entries(base_dir):
        key = str(record["id"]).strip()
        origins.setdefault(key, []).append(str(record["_source"]))
        if key not in registry:
            registry[key] = {field: value for field, value in record.items() if field != "_source"}
    return registry, origins


def _template_to_paths(template: str, state_dir: pathlib.Path) -> list[pathlib.Path]:
    """Expand a contract ``path_template`` into candidate paths under *state_dir*.

    Leading slashes are stripped, so the template is always relative to
    *state_dir*.

    - A template containing ``*``, ``?`` or ``[`` is expanded with
      ``Path.glob``. Only existing paths come back, sorted so the order is
      deterministic.
    - Any other template yields its single joined path, whether or not it
      exists.
    """
    relative = template.lstrip("/")
    if _has_wildcard(relative):
        return sorted(state_dir.glob(relative))
    return [state_dir / relative]


def _has_wildcard(template: str) -> bool:
    return any(ch in template for ch in "*?[")


def _matches_entry(path: pathlib.Path, artifact: dict[str, Any], state_dir: pathlib.Path) -> bool:
    template = artifact["path_template"].lstrip("/")
    candidate = pathlib.Path(path)
    if candidate.is_relative_to(state_dir):
        relative = str(candidate.relative_to(state_dir))
    else:
        relative = str(candidate)
    return fnmatch.fnmatch(relative.replace("\\", "/"), template)


def _collect_scan_roots(artifact: dict[str, Any], state_dir: pathlib.Path) -> set[pathlib.Path]:
    template = str(artifact["path_template"]).lstrip("/")
    if _has_wildcard(template):
        marker_index = min((template.index(ch) for ch in ("*", "?", "[") if ch in template))
        prefix = template[:marker_index]
    else:
        prefix = pathlib.Path(template).parent.as_posix()

    candidate = (state_dir / prefix).resolve()
    if prefix and prefix != "." and candidate.exists() and candidate.is_dir():
        return {candidate}
    return {state_dir}


def check_contracts(state_dir: pathlib.Path, base_dir: pathlib.Path | None = None) -> list[str]:
    """Report files in *state_dir* that no registered artifact contract covers.

    Each merged contract's ``path_template`` is expanded against *state_dir*;
    existing files found this way are "expected". *state_dir* and the
    per-contract roots from ``_collect_scan_roots`` are then scanned
    recursively.

    A scanned file is an orphan when all of the following hold:
    - it is not expected;
    - it matches no contract template;
    - it has one of the tracked data suffixes below.

    Housekeeping files are skipped: ``.gitkeep``, ``README.md``, dot-files,
    and ``.lock`` / ``.bak`` / ``.tmp`` / ``.pending`` files.

    Args:
        state_dir: Runtime state directory to audit; resolved internally.
        base_dir: Optional contracts directory, forwarded to ``merge_contracts``.

    Returns:
        One ``"orphan artifact file: <path>"`` message per orphan, sorted by
        path. Paths are shown relative to *state_dir*. A file that resolves
        outside it (e.g. via a symlinked scan root) is shown by its absolute
        path instead of raising.
    """
    merged, _ = merge_contracts(base_dir)
    state_dir = pathlib.Path(state_dir).resolve()
    contracts = list(merged.values())

    ignored_names = {".gitkeep", "README.md"}
    ignored_suffixes = (".lock", ".bak", ".tmp", ".pending")
    tracked_suffixes = (".json", ".jsonl", ".sqlite", ".md", ".txt", ".yaml", ".yml")

    expected: set[pathlib.Path] = set()
    roots: set[pathlib.Path] = {state_dir}
    for artifact in contracts:
        template = str(artifact["path_template"]).lstrip("/")
        expected.update(
            path.resolve() for path in _template_to_paths(template, state_dir) if path.is_file()
        )
        roots.update(_collect_scan_roots(artifact, state_dir))

    observed: set[pathlib.Path] = set()
    for root in sorted(roots):
        if not root.is_dir():
            continue
        for path in root.rglob("*"):
            name = path.name
            if name in ignored_names or name.startswith(".") or name.endswith(ignored_suffixes):
                continue
            if path.is_file():
                observed.add(path.resolve())

    # Files a template's own glob found are tracked by definition. This also
    # covers cases fnmatch disagrees with glob on, such as "**" matching zero
    # directories or a template that is a symlink to its resolved target.
    orphans = sorted(
        path
        for path in observed - expected
        # keep only files that could plausibly be contract-tracked artifacts
        if str(path).endswith(tracked_suffixes)
        and not any(_matches_entry(path, artifact, state_dir) for artifact in contracts)
    )

    messages: list[str] = []
    for path in orphans:
        shown = path.relative_to(state_dir) if path.is_relative_to(state_dir) else path
        messages.append(f"orphan artifact file: {shown}")
    return messages


def _find_cycles(graph: dict[str, list[str]]) -> list[list[str]]:
    visiting: set[str] = set()
    visited: set[str] = set()
    stack: list[str] = []
    cycles: list[list[str]] = []

    def dfs(node: str) -> None:
        visiting.add(node)
        stack.append(node)
        for nxt in graph.get(node, []):
            if nxt not in graph:
                continue
            if nxt not in visiting:
                if nxt not in visited:
                    dfs(nxt)
            elif nxt in stack:
                start = stack.index(nxt)
                cycles.append(stack[start:] + [nxt])
        visiting.remove(node)
        visited.add(node)
        stack.pop()

    for node in sorted(graph):
        if node not in visited:
            dfs(node)
    return cycles


def check_manifest_merge(base_dir: pathlib.Path | None = None) -> list[str]:
    """Run pre-merge consistency checks over all contract manifests.

    Errors are reported in this order:
    1. For each artifact id (sorted) declared by more than one entry: one
       error listing the declaring sources. A second error follows if those
       entries disagree on ``lifecycle`` (compared as canonical JSON; a
       missing ``lifecycle`` counts as ``{}``).
    2. Every producer -> consumer cycle among registered ids. An edge runs
       from an artifact to each string in its ``consumers`` list; consumers
       that are not registered ids are ignored by ``_find_cycles``.

    Args:
        base_dir: Optional contracts directory, forwarded to the loaders.

    Returns:
        Human-readable error strings; empty when no problems are found.
    """
    merged, sources = merge_contracts(base_dir)
    errors: list[str] = []

    duplicated = {
        artifact_id: source_list
        for artifact_id, source_list in sources.items()
        if len(source_list) > 1
    }

    # merge_contracts keeps only the first definition per id, so re-read the
    # raw entries -- once, rather than once per duplicated id.
    definitions: dict[str, list[dict[str, Any]]] = {}
    if duplicated:
        for entry in load_contract_entries(base_dir):
            artifact_id = str(entry["id"]).strip()
            if artifact_id in duplicated:
                definitions.setdefault(artifact_id, []).append(entry)

    for artifact_id in sorted(duplicated):
        errors.append(
            f"artifact '{artifact_id}' declared in multiple manifests: {', '.join(duplicated[artifact_id])}"
        )
        lifecycles = sorted({
            json.dumps(entry.get("lifecycle", {}), sort_keys=True, separators=(",", ":"))
            for entry in definitions.get(artifact_id, [])
        })
        if len(lifecycles) > 1:
            errors.append(
                f"artifact '{artifact_id}' has conflicting lifecycle declarations: "
                f"{', '.join(lifecycles)}"
            )

    graph = {
        artifact_id: [c for c in artifact.get("consumers", []) if isinstance(c, str)]
        for artifact_id, artifact in merged.items()
    }
    for cycle in _find_cycles(graph):
        errors.append("producer-consumer cycle: " + " -> ".join(cycle))

    return errors


def _extract_writer_ids(module_path: pathlib.Path) -> list[str]:
    tree = ast.parse(module_path.read_text(encoding="utf-8"))
    values: set[str] = set()
    for node in ast.walk(tree):
        if not isinstance(node, ast.Call):
            continue
        func = node.func
        func_name = None
        if isinstance(func, ast.Name):
            func_name = func.id
        elif isinstance(func, ast.Attribute):
            func_name = func.attr
        if func_name != "write_artifact":
            continue
        if not node.args:
            continue
        first = node.args[0]
        if isinstance(first, ast.Constant) and isinstance(first.value, str):
            values.add(first.value)
    return sorted(values)


def check_pending_writes(writer_module: str, base_dir: pathlib.Path | None = None) -> list[str]:
    """Check that every literal ``write_artifact`` id in a writer is registered.

    Args:
        writer_module: Path to the writer ``.py`` file to inspect. It is
            parsed, never imported.
        base_dir: Optional contracts directory, forwarded to ``merge_contracts``.

    Returns:
        Error strings, empty when every referenced id is registered. A single
        error is returned when one of these holds:
        - the path is missing, is not a regular file, or is not ``.py``;
        - it cannot be decoded or parsed;
        - it contains no literal ``write_artifact`` ids.
    """
    path = pathlib.Path(writer_module)
    if not path.exists():
        return [f"writer module not found: {path}"]
    if not path.is_file():
        # e.g. a directory named "*.py", which would otherwise crash in read_text.
        return [f"writer module is not a file: {path}"]
    if path.suffix != ".py":
        return [f"writer module must be a .py file: {path}"]

    try:
        artifact_ids = _extract_writer_ids(path)
    except (SyntaxError, ValueError) as exc:
        # ValueError covers UnicodeDecodeError and, on older Pythons, null bytes in source.
        return [f"could not parse writer module {path}: {exc}"]
    if not artifact_ids:
        return [f"no write_artifact(...) calls found in {path}"]

    # Load the registry only once there is something to compare against.
    registry = set(merge_contracts(base_dir)[0])
    return [
        f"writer references unregistered artifact id: {artifact_id}"
        for artifact_id in artifact_ids
        if artifact_id not in registry
    ]


def show_registry(base_dir: pathlib.Path | None = None) -> None:
    """Print the merged registry (id -> entry) to stdout as indented, key-sorted JSON."""
    registry = merge_contracts(base_dir)[0]
    rendered = json.dumps(registry, indent=2, sort_keys=True)
    print(rendered)


def main(argv: list[str] | None = None) -> int:
    """Command-line entry point; returns the process exit code.

    Exit codes:
    - ``0``: every requested action succeeded.
    - ``1``: a requested check reported errors (printed to stderr).
    - ``2``: usage error, i.e. no action requested, or ``--check-contracts``
      without ``--state-dir``.

    Checks run in order (contracts, manifest merge, pending writes) and stop
    at the first one that reports errors. ``--show-registry`` runs only if no
    requested check failed. ``--contracts-dir`` is passed to every check and
    is also exported as the ALC_DATA_CONTRACTS_DIR environment variable.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--state-dir", help="State directory for post-hoc contract checks")
    parser.add_argument("--check-contracts", action="store_true", help="check runtime state against manifest")
    parser.add_argument("--check-manifest-merge", action="store_true", help="pre-merge manifest validation")
    parser.add_argument("--check-pending-writes", metavar="WRITER", help="check writer module registrations")
    parser.add_argument("--show-registry", action="store_true", help="print merged registry")
    parser.add_argument("--contracts-dir", help="override contracts directory")
    args = parser.parse_args(argv)

    contracts_dir = args.contracts_dir
    data_dir = None
    if contracts_dir:
        data_dir = pathlib.Path(contracts_dir)
        os.environ["ALC_DATA_CONTRACTS_DIR"] = contracts_dir

    def report(errors: list[str]) -> bool:
        # Print each error on stderr; True means the check failed.
        for message in errors:
            print(message, file=sys.stderr)
        return bool(errors)

    if args.check_contracts:
        if not args.state_dir:
            print("--check-contracts requires --state-dir", file=sys.stderr)
            return 2
        if report(check_contracts(pathlib.Path(args.state_dir), data_dir)):
            return 1

    if args.check_manifest_merge and report(check_manifest_merge(data_dir)):
        return 1

    if args.check_pending_writes and report(check_pending_writes(args.check_pending_writes, data_dir)):
        return 1

    if args.show_registry:
        show_registry(data_dir)

    requested = (
        args.check_contracts,
        args.check_manifest_merge,
        args.check_pending_writes,
        args.show_registry,
    )
    if any(requested):
        return 0
    parser.print_usage()
    return 2


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
