#!/usr/bin/env python3
import argparse
import fnmatch
import json
import subprocess
import sys
from pathlib import Path

# tomllib may be missing from the running interpreter; keep a None placeholder
# so load_toml can exit with a clear message instead of failing at import time.
try:
    import tomllib
except ModuleNotFoundError:
    tomllib = None


def relative_path(path: Path, root: Path) -> str:
    """Return ``path`` expressed relative to ``root`` as a POSIX-style string.

    The absolute (unresolved) forms are compared first. If ``path`` is not under
    ``root`` that way, both are passed through ``Path.resolve()`` and compared
    again; a ``ValueError`` from that second attempt propagates to the caller.
    An empty relative path is reported as ``"."``.
    """
    absolute_path, absolute_root = path.absolute(), root.absolute()
    try:
        relative = absolute_path.relative_to(absolute_root)
    except ValueError:
        relative = path.resolve().relative_to(root.resolve())

    if str(relative) == ".":
        return "."
    return relative.as_posix()


def find_cargo_root(project_root: Path) -> Path:
    """Locate the directory that holds the workspace's ``Cargo.toml``.

    ``project_root`` itself is returned if it contains a ``Cargo.toml`` file.
    Otherwise its direct children are examined in name order and the first
    directory containing a ``Cargo.toml`` file is returned.

    Raises:
        SystemExit: If neither ``project_root`` nor any direct child directory
            contains a ``Cargo.toml`` file.
    """
    manifest_name = "Cargo.toml"
    if project_root.joinpath(manifest_name).is_file():
        return project_root

    children = list(project_root.iterdir())
    children.sort(key=lambda entry: entry.name)
    match = next(
        (
            entry
            for entry in children
            if entry.is_dir() and entry.joinpath(manifest_name).is_file()
        ),
        None,
    )
    if match is None:
        raise SystemExit(f"Could not find Cargo.toml in {project_root} or its direct child directories.")
    return match


def load_toml(path: Path) -> dict:
    if tomllib is None:
        raise SystemExit("Python 3.11+ is required to parse Cargo.toml files.")
    return tomllib.loads(path.read_text())


def load_cargo_metadata(cargo_root: Path) -> dict | None:
    result = subprocess.run(
        ["cargo", "metadata", "--format-version", "1", "--no-deps"],
        cwd=cargo_root,
        check=False,
        text=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    if result.returncode != 0:
        return None
    return json.loads(result.stdout)


def path_is_excluded(relative: str, excludes: list[str]) -> bool:
    """Report whether ``relative`` is covered by any entry in ``excludes``.

    Trailing slashes are ignored on both the path and each entry. An entry
    covers the path when it equals the path, is a directory prefix of it, or
    matches it as an ``fnmatch`` pattern.
    """
    target = relative.rstrip("/")

    def covers(entry: str) -> bool:
        trimmed = entry.rstrip("/")
        if target == trimmed:
            return True
        if target.startswith(trimmed + "/"):
            return True
        return fnmatch.fnmatch(target, trimmed)

    return any(covers(entry) for entry in excludes)


def valid_workspace_member_manifests(cargo_root: Path) -> list[Path]:
    """List the ``Cargo.toml`` manifests of the workspace rooted at ``cargo_root``.

    The root manifest comes first when it declares a named ``[package]``. Each
    ``workspace.members`` entry is then expanded, in order: entries containing
    glob metacharacters (``*``, ``?`` or ``[``) are expanded relative to
    ``cargo_root``, and the literal path is used when the pattern matches
    nothing or the entry has no metacharacters. Candidates that are not
    directories, are covered by ``workspace.exclude``, have no ``Cargo.toml``
    file, or were already listed are skipped.
    """
    root_manifest = cargo_root / "Cargo.toml"
    data = load_toml(root_manifest)
    workspace = data.get("workspace", {})
    members = workspace.get("members", [])
    excludes = workspace.get("exclude", [])

    manifests: list[Path] = []
    seen: set[Path] = set()

    if data.get("package", {}).get("name"):
        manifests.append(root_manifest)
        seen.add(root_manifest)

    for member in members:
        literal = cargo_root / member
        if any(char in member for char in "*?["):
            # Fall back to the literal path so a directory whose name merely
            # contains a metacharacter is still found.
            candidates = sorted(cargo_root.glob(member)) or [literal]
        else:
            candidates = [literal]

        for candidate in candidates:
            if not candidate.is_dir():
                continue
            if path_is_excluded(relative_path(candidate, cargo_root), excludes):
                continue
            manifest = candidate / "Cargo.toml"
            if not manifest.is_file() or manifest in seen:
                continue
            manifests.append(manifest)
            seen.add(manifest)

    return manifests


def package_name(manifest: Path) -> str | None:
    """Return the ``name`` from the manifest's ``package`` table if it is a string.

    ``None`` is returned when the table or key is absent, or when the value is
    not a string.
    """
    data = load_toml(manifest)
    candidate = data.get("package", {}).get("name")
    if isinstance(candidate, str):
        return candidate
    return None


def dependency_names_from_table(table: dict) -> set[str]:
    """Collect the package names referenced by one dependency table.

    For each entry the table key is used as the name, unless the entry's value
    is a dict whose ``package`` field is a string, in which case that string is
    used instead.
    """

    def resolve(alias: str, spec: object) -> str:
        if not isinstance(spec, dict):
            return alias
        renamed = spec.get("package")
        return renamed if isinstance(renamed, str) else alias

    return {resolve(alias, spec) for alias, spec in table.items()}


def manifest_dependency_names(manifest: Path) -> set[str]:
    """Return every dependency package name declared in ``manifest``.

    The ``dependencies``, ``dev-dependencies`` and ``build-dependencies`` tables
    are read from the top level of the manifest and from each dict found under
    the ``target`` table. Values that are not dicts are ignored.
    """
    sections = ("dependencies", "dev-dependencies", "build-dependencies")
    data = load_toml(manifest)

    # Start with the manifest itself, then add every per-target sub-table.
    scopes = [data]
    per_target = data.get("target", {})
    if isinstance(per_target, dict):
        scopes.extend(entry for entry in per_target.values() if isinstance(entry, dict))

    collected: set[str] = set()
    for scope in scopes:
        for section in sections:
            table = scope.get(section, {})
            if isinstance(table, dict):
                collected |= dependency_names_from_table(table)
    return collected


def infer_layer(module_cargo_path: str, cargo_root: Path, project_root: Path) -> tuple[str, str]:
    """Derive a module's layer name and layer path.

    ``module_cargo_path`` is the module directory relative to ``cargo_root``.
    A path of ``"."`` maps to the layer ``"root"`` located at ``cargo_root``;
    any other path maps to a layer named after its first segment, located at
    that directory under ``cargo_root``. The returned layer path is relative to
    ``project_root``.
    """
    if module_cargo_path == ".":
        layer_name, layer_dir = "root", cargo_root
    else:
        layer_name, _, _ = module_cargo_path.partition("/")
        layer_dir = cargo_root / layer_name
    return (layer_name, relative_path(layer_dir, project_root))


def rust_files_under(module_dir: Path, project_root: Path, child: str) -> list[str]:
    """List the ``*.rs`` files found recursively under ``module_dir / child``.

    Paths are returned relative to ``project_root`` in sorted order. An empty
    list is returned when ``module_dir / child`` is not a directory.
    """
    search_dir = module_dir.joinpath(child)
    if not search_dir.is_dir():
        return []

    found = [
        relative_path(entry, project_root)
        for entry in search_dir.rglob("*.rs")
        if entry.is_file()
    ]
    found.sort()
    return found


def package_records_from_metadata(project_root: Path, cargo_root: Path, metadata: dict) -> list[dict]:
    """Build one record per workspace member from parsed ``cargo metadata`` output.

    Only packages whose ``id`` appears in ``metadata["workspace_members"]`` are
    kept, ordered as they appear in that list. Each record has the package
    ``name``, its ``manifest`` as a ``Path``, and ``workspaceDependencies``: the
    sorted names of its dependencies that are themselves workspace members.

    ``project_root`` and ``cargo_root`` are not used here; they are accepted so
    the existing call signature keeps working.
    """
    member_order = {
        member_id: index
        for index, member_id in enumerate(metadata["workspace_members"])
    }
    packages = sorted(
        (package for package in metadata["packages"] if package["id"] in member_order),
        key=lambda package: member_order[package["id"]],
    )
    workspace_names = {package["name"] for package in packages}

    return [
        {
            "name": package["name"],
            "manifest": Path(package["manifest_path"]),
            "workspaceDependencies": sorted(
                {
                    dependency["name"]
                    for dependency in package.get("dependencies", [])
                    if dependency.get("name") in workspace_names
                }
            ),
        }
        for package in packages
    ]


def package_records_from_manifests(cargo_root: Path) -> list[dict]:
    """Build package records by reading the workspace manifests directly.

    Manifests without a string package name are dropped. Each remaining record
    has the package ``name``, its ``manifest`` path, and ``workspaceDependencies``:
    the sorted dependency names that are also names of workspace packages.
    """
    # First pass: learn every package name so dependencies can be filtered.
    names_by_manifest: dict[Path, str] = {}
    for manifest in valid_workspace_member_manifests(cargo_root):
        name = package_name(manifest)
        if name is not None:
            names_by_manifest[manifest] = name
    workspace_names = set(names_by_manifest.values())

    # Second pass: keep only dependencies that point back into the workspace.
    return [
        {
            "name": name,
            "manifest": manifest,
            "workspaceDependencies": sorted(manifest_dependency_names(manifest) & workspace_names),
        }
        for manifest, name in names_by_manifest.items()
    ]


def build_index(project_root: Path) -> dict:
    """Assemble the module index for the Cargo workspace under ``project_root``.

    Package records come from ``cargo metadata`` when it yields a result and
    from direct manifest parsing otherwise. Each package becomes a module entry
    grouped into the layer chosen by ``infer_layer``; layers are listed in the
    order they are first encountered. All paths in the result are relative to
    ``project_root``.
    """
    cargo_root = find_cargo_root(project_root)
    metadata = load_cargo_metadata(cargo_root)
    if metadata is None:
        records = package_records_from_manifests(cargo_root)
    else:
        records = package_records_from_metadata(project_root, cargo_root, metadata)

    ordered_layers: list[dict] = []
    layers_by_name: dict[str, dict] = {}

    for record in records:
        manifest_path = record["manifest"]
        module_dir = manifest_path.parent
        module_path = relative_path(module_dir, project_root)
        manifest_rel = relative_path(manifest_path, project_root)
        layer_name, layer_path = infer_layer(
            relative_path(module_dir, cargo_root), cargo_root, project_root
        )

        layer = layers_by_name.get(layer_name)
        if layer is None:
            layer = {"name": layer_name, "path": layer_path, "modules": []}
            layers_by_name[layer_name] = layer
            ordered_layers.append(layer)

        # The root package keeps its architecture doc at the project root.
        if module_path == ".":
            architecture_doc = "ARCHITECTURE.md"
        else:
            architecture_doc = f"{module_path}/ARCHITECTURE.md"

        module_entry = {
            "name": record["name"],
            "path": module_path,
            "manifest": manifest_rel,
            "architectureDoc": architecture_doc,
            "workspaceDependencies": record["workspaceDependencies"],
            "files": {
                "source": rust_files_under(module_dir, project_root, "src"),
                "tests": rust_files_under(module_dir, project_root, "tests"),
            },
        }
        layer["modules"].append(module_entry)

    workspace_info = {
        "root": relative_path(cargo_root, project_root),
        "manifest": relative_path(cargo_root / "Cargo.toml", project_root),
    }
    layer_inference = {
        "method": "path-first-segment-from-cargo-root",
        "rootPackageLayer": "root",
    }
    return {
        "schemaVersion": 1,
        "kind": "module-index",
        "generatedBy": ".ai/tools/generate-module-index",
        "workspace": workspace_info,
        "layerInference": layer_inference,
        "layers": ordered_layers,
    }


def _display_path(path: Path, root: Path) -> str:
    """Return ``path`` relative to ``root`` for messages, or ``path`` itself if it is outside ``root``."""
    try:
        return str(path.relative_to(root))
    except ValueError:
        return str(path)


def main() -> int:
    """Generate the module index and write, print, or verify it.

    The project root is taken to be two directories above this script's
    directory. With ``--check`` the generated JSON is compared with the
    existing output file; with ``--print`` it is written to stdout; otherwise
    it is written to the output path, creating parent directories as needed.

    Returns:
        ``0`` on success, ``1`` when ``--check`` finds the output file missing
        or different from the freshly generated content.
    """
    parser = argparse.ArgumentParser(description="Generate .ai/generated/module-index.json from cargo metadata.")
    parser.add_argument("--check", action="store_true", help="Fail if the generated module index differs from the current file.")
    parser.add_argument("--print", action="store_true", help="Print generated JSON instead of writing it.")
    parser.add_argument("--output", default=".ai/generated/module-index.json", help="Output path relative to the project root.")
    args = parser.parse_args()

    root = Path(__file__).resolve().parents[2]
    generated = json.dumps(build_index(root), indent=2, sort_keys=False) + "\n"
    # An absolute --output replaces root entirely when joined, so the result
    # may lie outside root; _display_path copes with that when reporting.
    output_path = root / args.output
    shown = _display_path(output_path, root)

    if args.check:
        if not output_path.is_file():
            sys.stderr.write(f"Missing generated module index: {shown}\n")
            return 1
        current = output_path.read_text(encoding="utf-8")
        if current != generated:
            sys.stderr.write(f"Stale generated module index: {shown}\n")
            return 1
        print("generate-module-index check passed")
        return 0

    if args.print:
        sys.stdout.write(generated)
        return 0

    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(generated, encoding="utf-8")
    print(f"wrote {shown}")
    return 0


# Run the generator when executed as a script and use main()'s return value
# as the process exit status.
if __name__ == "__main__":
    sys.exit(main())
