#!/usr/bin/env python3
import argparse
import json
import re
import sys
from dataclasses import dataclass
from pathlib import Path


# One ASCII identifier; used to tokenize source and to validate item names.
IDENTIFIER = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")

# Item keywords that yield a public-surface entry, mapped to the reported kind.
ITEM_KINDS = {keyword: keyword for keyword in ("fn", "struct", "enum", "trait")}

# Qualifiers that may appear between `pub` and the item keyword.
ITEM_MODIFIERS = {"extern", "unsafe", "const", "async"}

# Keywords that open an item which is skipped as a whole when it is not
# recorded as a public declaration.
BLOCK_ITEM_KEYWORDS = {
    *ITEM_KINDS,
    "impl",
    "mod",
    "const",
    "static",
    "type",
    "union",
    "macro_rules",
}


@dataclass(frozen=True)
class Token:
    """One lexical token produced by ``scan_tokens``."""

    value: str  # token text: a whole identifier or a single character
    start: int  # offset of the token's first character in the source string
    end: int  # offset one past the token's last character
    line: int  # 1-based line number on which the token starts


@dataclass(frozen=True)
class PublicDecl:
    """A bare-``pub`` declaration located in a token list by ``public_decl_at``."""

    kind: str  # "mod" or one of the values of ITEM_KINDS
    name: str  # identifier that follows the item keyword
    pub_index: int  # token index of the `pub` keyword
    keyword_index: int  # token index of the item keyword (after any modifiers)
    name_index: int  # token index of the declared name


@dataclass(frozen=True)
class UseExport:
    """One leaf of a ``pub use`` tree as collected by ``parse_pub_use_exports``."""

    path: tuple[str, ...]  # path segments as written, without the `::` separators
    alias: str | None = None  # identifier after `as`, when one was given
    wildcard: bool = False  # True when the leaf is a `*` glob under `path`


def load_module_index(root: Path) -> dict:
    """Load the generated module index of the project rooted at *root*.

    Args:
        root: Project root directory.

    Returns:
        The parsed content of ``.ai/generated/module-index.json``.

    Raises:
        SystemExit: If the index file does not exist.
    """
    path = root / ".ai/generated/module-index.json"
    if not path.is_file():
        raise SystemExit("Missing .ai/generated/module-index.json; run .ai/tools/generate-module-index first.")
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding.
    return json.loads(path.read_text(encoding="utf-8"))


def raw_string_end(source: str, index: int) -> int | None:
    for prefix in ("br", "rb", "r"):
        if not source.startswith(prefix, index):
            continue
        cursor = index + len(prefix)
        hashes = 0
        while cursor < len(source) and source[cursor] == "#":
            hashes += 1
            cursor += 1
        if cursor >= len(source) or source[cursor] != '"':
            continue
        delimiter = '"' + ("#" * hashes)
        end = source.find(delimiter, cursor + 1)
        return len(source) if end == -1 else end + len(delimiter)
    return None


def quoted_literal_end(source: str, index: int) -> int:
    """Return the offset just past the quoted literal opening at *index*.

    The character at *index* is taken as the quote. A backslash escapes the
    character after it, so an escaped quote does not close the literal.
    ``len(source)`` is returned when no closing quote is found.
    """
    closing = source[index]
    length = len(source)
    position = index + 1
    while position < length:
        current = source[position]
        if current == "\\":
            # Jump over the backslash and the character it escapes.
            position += 2
            continue
        if current == closing:
            return position + 1
        position += 1
    return length


def block_comment_end(source: str, index: int) -> int:
    """Return the offset just past the block comment opening at *index*.

    Nested ``/* ... */`` pairs are tracked, so the comment ends only when the
    outermost one closes. For an unterminated comment the scan stops at the
    end of *source*.
    """
    position = index + 2  # step over the opening "/*"
    nesting = 1
    limit = len(source)
    while nesting > 0 and position < limit:
        pair = source[position:position + 2]
        if pair == "/*":
            nesting += 1
            position += 2
        elif pair == "*/":
            nesting -= 1
            position += 2
        else:
            position += 1
    return position


def scan_tokens(source: str) -> list[Token]:
    """Split Rust *source* into a flat list of tokens.

    Whitespace, comments, and string and character literals produce no
    tokens. Each identifier or keyword becomes one token; every other
    character becomes its own single-character token.

    Args:
        source: Rust source text.

    Returns:
        Tokens in source order, each carrying its offsets and 1-based line.
    """
    tokens: list[Token] = []
    cursor = 0
    line = 1

    while cursor < len(source):
        char = source[cursor]

        if char.isspace():
            if char == "\n":
                line += 1
            cursor += 1
            continue

        if source.startswith("//", cursor):
            end = source.find("\n", cursor + 2)
            # Stop on the newline itself so the whitespace branch counts it.
            cursor = len(source) if end == -1 else end
            continue

        if source.startswith("/*", cursor):
            end = block_comment_end(source, cursor)
            line += source[cursor:end].count("\n")
            cursor = end
            continue

        raw_end = raw_string_end(source, cursor)
        if raw_end is not None:
            line += source[cursor:raw_end].count("\n")
            cursor = raw_end
            continue

        if char == '"' or (char == "b" and cursor + 1 < len(source) and source[cursor + 1] == '"'):
            start = cursor + 1 if char == "b" else cursor
            end = quoted_literal_end(source, start)
            line += source[cursor:end].count("\n")
            cursor = end
            continue

        # A single quote opens a character literal only when it is followed by
        # an escape, or by exactly one character and a closing quote.
        # Anything else ('a, 'static, '_) is a lifetime or loop label and must
        # not be scanned as a literal, or code up to the next quote is lost.
        if char == "'" and (
            source.startswith("\\", cursor + 1) or source.startswith("'", cursor + 2)
        ):
            end = quoted_literal_end(source, cursor)
            line += source[cursor:end].count("\n")
            cursor = end
            continue

        match = IDENTIFIER.match(source, cursor)
        if match:
            tokens.append(Token(match.group(0), cursor, match.end(), line))
            cursor = match.end()
            continue

        tokens.append(Token(char, cursor, cursor + 1, line))
        cursor += 1

    return tokens


def remove_comments_preserve_literals(text: str) -> str:
    """Return *text* with comments removed and literals copied verbatim.

    A line comment is replaced by a single newline. A block comment is
    replaced by as many newlines as it contained, so line structure is kept.
    A line comment that runs to the end of *text* without a newline ends the
    output. String, byte-string, raw-string and character literals are
    copied unchanged, so comment markers inside them are not touched.
    """
    output: list[str] = []
    cursor = 0
    while cursor < len(text):
        if text.startswith("//", cursor):
            end = text.find("\n", cursor + 2)
            if end == -1:
                break
            output.append("\n")
            cursor = end + 1
            continue

        if text.startswith("/*", cursor):
            end = block_comment_end(text, cursor)
            output.append("\n" * text[cursor:end].count("\n"))
            cursor = end
            continue

        raw_end = raw_string_end(text, cursor)
        if raw_end is not None:
            output.append(text[cursor:raw_end])
            cursor = raw_end
            continue

        char = text[cursor]
        if char == '"' or (char == "b" and cursor + 1 < len(text) and text[cursor + 1] == '"'):
            start = cursor + 1 if char == "b" else cursor
            end = quoted_literal_end(text, start)
            output.append(text[cursor:end])
            cursor = end
            continue

        # A single quote opens a character literal only when it is followed by
        # an escape, or by exactly one character and a closing quote.
        # Lifetimes such as 'a or '_ fall through and are copied one character
        # at a time, so comments that follow them are still stripped.
        if char == "'" and (
            text.startswith("\\", cursor + 1) or text.startswith("'", cursor + 2)
        ):
            end = quoted_literal_end(text, cursor)
            output.append(text[cursor:end])
            cursor = end
            continue

        output.append(char)
        cursor += 1

    return "".join(output)


def normalize_signature(signature: str) -> str:
    """Strip comments from *signature* and collapse whitespace runs to one space."""
    without_comments = remove_comments_preserve_literals(signature)
    words = without_comments.split()
    return " ".join(words)


def signature_from_source(source: str, start: int) -> str:
    """Return the normalized declaration text that begins at offset *start*.

    The signature runs up to, but not including, the first ``{`` or ``;``
    found outside angle brackets, parentheses and square brackets. Comments
    and literals are skipped while searching. The result is passed through
    ``normalize_signature``.

    Args:
        source: Full source text.
        start: Offset of the first character of the declaration.

    Returns:
        The normalized signature, or the normalized remainder of *source*
        when no terminator is found.
    """
    cursor = start
    depth_angle = 0
    depth_paren = 0
    depth_bracket = 0

    while cursor < len(source):
        if source.startswith("//", cursor):
            end = source.find("\n", cursor + 2)
            cursor = len(source) if end == -1 else end
            continue

        if source.startswith("/*", cursor):
            cursor = block_comment_end(source, cursor)
            continue

        raw_end = raw_string_end(source, cursor)
        if raw_end is not None:
            cursor = raw_end
            continue

        char = source[cursor]
        if char == '"' or (char == "b" and cursor + 1 < len(source) and source[cursor + 1] == '"'):
            start_quote = cursor + 1 if char == "b" else cursor
            cursor = quoted_literal_end(source, start_quote)
            continue

        # A single quote opens a character literal only when it is followed by
        # an escape, or by exactly one character and a closing quote.
        # Treating a lifetime such as '_ as a literal would skip brackets up
        # to the next quote and leave the depth counters unbalanced.
        if char == "'" and (
            source.startswith("\\", cursor + 1) or source.startswith("'", cursor + 2)
        ):
            cursor = quoted_literal_end(source, cursor)
            continue

        if char == "<":
            depth_angle += 1
        elif char == ">":
            # The `>` of a return arrow `->` does not close an angle bracket.
            is_arrow = cursor > start and source[cursor - 1] == "-"
            if depth_angle and not is_arrow:
                depth_angle -= 1
        elif char == "(":
            depth_paren += 1
        elif char == ")":
            if depth_paren:
                depth_paren -= 1
        elif char == "[":
            depth_bracket += 1
        elif char == "]":
            if depth_bracket:
                depth_bracket -= 1

        at_top_level = depth_angle == 0 and depth_paren == 0 and depth_bracket == 0
        if at_top_level and char in "{;":
            return normalize_signature(source[start:cursor])

        cursor += 1

    return normalize_signature(source[start:])


def child_module_candidates(current_file: Path, module_name: str) -> list[Path]:
    """List the files that may define child module *module_name*.

    For ``lib.rs``, ``main.rs`` and ``mod.rs`` the child is looked up next to
    the file. For any other file it is looked up in the directory named after
    the file without its suffix. The flat ``<name>.rs`` candidate comes first,
    followed by ``<name>/mod.rs``.
    """
    owns_directory = current_file.name in ("lib.rs", "main.rs", "mod.rs")
    search_dir = current_file.parent if owns_directory else current_file.with_suffix("")
    flat_file = search_dir / f"{module_name}.rs"
    nested_file = search_dir / module_name / "mod.rs"
    return [flat_file, nested_file]


def is_bare_pub(tokens: list[Token], index: int) -> bool:
    """Tell whether the token at *index* is ``pub`` not followed by ``(``.

    A following ``(`` marks a restricted visibility such as ``pub(crate)``.
    """
    if tokens[index].value != "pub":
        return False
    following = index + 1
    if following < len(tokens) and tokens[following].value == "(":
        return False
    return True


def public_decl_at(tokens: list[Token], index: int) -> PublicDecl | None:
    """Describe the public declaration that starts at token *index*, if any.

    A declaration is a bare ``pub``, optional modifiers from
    ``ITEM_MODIFIERS``, then ``mod`` or a keyword from ``ITEM_KINDS``, then an
    identifier. ``None`` is returned for anything else.
    """
    if not is_bare_pub(tokens, index):
        return None

    total = len(tokens)
    keyword_at = index + 1
    while keyword_at < total and tokens[keyword_at].value in ITEM_MODIFIERS:
        keyword_at += 1
    if keyword_at >= total:
        return None

    keyword = tokens[keyword_at].value
    if keyword == "mod":
        reported_kind = "mod"
    elif keyword in ITEM_KINDS:
        reported_kind = ITEM_KINDS[keyword]
    else:
        return None

    name_at = keyword_at + 1
    if name_at >= total:
        return None
    name = tokens[name_at].value
    if not IDENTIFIER.fullmatch(name):
        return None
    return PublicDecl(reported_kind, name, index, keyword_at, name_at)


def skip_brace_block(tokens: list[Token], index: int) -> int:
    """Return the index just past the brace block found from *index* onward.

    Braces are counted from *index*; the block ends at the ``}`` that brings
    the count back to zero. If that never happens, the index past the last
    token is returned.
    """
    open_braces = 0
    for position in range(index, len(tokens)):
        text = tokens[position].value
        if text == "{":
            open_braces += 1
        elif text == "}":
            open_braces -= 1
            if open_braces == 0:
                return position + 1
    # Unbalanced input: report the position where the scan ran out.
    return max(index, len(tokens))


def skip_item(tokens: list[Token], index: int) -> int:
    """Return the index just past the item that starts at *index*.

    The item ends at its first ``;``, or at the end of the brace block opened
    by its first ``{``, whichever comes first.
    """
    for position in range(index, len(tokens)):
        text = tokens[position].value
        if text == ";":
            return position + 1
        if text == "{":
            return skip_brace_block(tokens, position)
    return max(index, len(tokens))


def file_module_parts(source_file: str, module: dict) -> list[str]:
    """Return the module path segments implied by a source file's location.

    Args:
        source_file: Source path, using ``/`` separators.
        module: Module index entry; its ``"path"`` (default ``"."``) locates
            the crate directory whose ``src/`` prefix is stripped.

    Returns:
        ``[]`` for a crate root (``lib.rs`` / ``main.rs``). Otherwise the
        path segments below ``src/`` with the ``.rs`` suffix and any ``mod``
        segment removed.
    """
    module_path = module.get("path", ".")
    prefix = "src/" if module_path == "." else f"{module_path}/src/"
    # A root-level crate's "src/lib.rs" has no leading "/", so compare against
    # the expected prefix as well as the generic "/src/..." suffix.
    crate_roots = (f"{prefix}lib.rs", f"{prefix}main.rs")
    if source_file in crate_roots or source_file.endswith(("/src/lib.rs", "/src/main.rs")):
        return []
    relative = source_file.removeprefix(prefix).removesuffix(".rs")
    return [part for part in relative.split("/") if part and part != "mod"]


def public_path(parts: list[str], item_name: str) -> str:
    """Join module segments and an item name into a ``::``-separated path."""
    segments = list(parts)
    segments.append(item_name)
    return "::".join(segments)


def join_path(parts: list[str] | tuple[str, ...]) -> str:
    return "::".join(parts)


def is_path_separator(tokens: list[Token], index: int) -> bool:
    """Tell whether the tokens at *index* and *index* + 1 are both ``:``."""
    if index + 1 >= len(tokens):
        return False
    first = tokens[index].value
    second = tokens[index + 1].value
    return first == ":" and second == ":"


def public_file_module_names(source: str) -> list[str]:
    """List the names declared as ``pub mod <name>;`` in *source*.

    Only declarations terminated by ``;`` are reported; inline modules with a
    brace body are not.
    """
    tokens = scan_tokens(source)
    total = len(tokens)
    found: list[str] = []

    for position in range(total):
        if tokens[position].value != "pub":
            continue
        decl = public_decl_at(tokens, position)
        if decl is None or decl.kind != "mod":
            continue
        terminator = decl.name_index + 1
        if terminator < total and tokens[terminator].value == ";":
            found.append(decl.name)

    return found


def parse_pub_use_exports(source: str) -> list[UseExport]:
    """Collect the leaves of every bare ``pub use`` declaration in *source*.

    Each named leaf becomes a ``UseExport`` with its full path and optional
    ``as`` alias; each ``*`` becomes a wildcard export of the path before it.
    A ``self`` segment adds nothing to the path. Two kinds of leaf introduce
    no item name and are not reported: a bare ``self`` (with or without an
    alias) and an underscore import (``as _``).

    Args:
        source: Rust source text.

    Returns:
        Exports in source order.
    """
    tokens = scan_tokens(source)
    exports: list[UseExport] = []

    def parse_group(index: int, prefix: list[str]) -> int:
        """Parse a ``{...}`` group opening at *index*; return the index after it."""
        cursor = index + 1
        while cursor < len(tokens):
            if tokens[cursor].value == "}":
                return cursor + 1
            if tokens[cursor].value == ",":
                cursor += 1
                continue
            cursor = parse_tree(cursor, prefix)
            if cursor < len(tokens) and tokens[cursor].value == ",":
                cursor += 1
        return cursor

    def parse_tree(index: int, prefix: list[str]) -> int:
        """Parse one use tree under *prefix*; return the index after it."""
        cursor = index
        parts = list(prefix)

        while cursor < len(tokens):
            token = tokens[cursor]

            if token.value == "{":
                return parse_group(cursor, parts)

            if token.value == "*":
                exports.append(UseExport(tuple(parts), wildcard=True))
                return cursor + 1

            if not IDENTIFIER.fullmatch(token.value):
                return cursor + 1

            name = token.value
            if name == "self":
                cursor += 1
                if is_path_separator(tokens, cursor):
                    cursor += 2
                    continue
                # A bare `self` names the module itself and is not reported.
                # Consume an optional `as alias` so the alias is not parsed
                # as a further path inside the enclosing group.
                if (
                    cursor + 1 < len(tokens)
                    and tokens[cursor].value == "as"
                    and IDENTIFIER.fullmatch(tokens[cursor + 1].value)
                ):
                    cursor += 2
                return cursor

            parts.append(name)
            cursor += 1

            if is_path_separator(tokens, cursor):
                cursor += 2
                continue

            alias = None
            if cursor + 1 < len(tokens) and tokens[cursor].value == "as":
                if IDENTIFIER.fullmatch(tokens[cursor + 1].value):
                    alias = tokens[cursor + 1].value
                    cursor += 2

            # `use path as _;` binds no name, so there is nothing to export.
            if alias != "_":
                exports.append(UseExport(tuple(parts), alias=alias))
            return cursor

        return cursor

    index = 0
    while index < len(tokens):
        if (
            is_bare_pub(tokens, index)
            and index + 1 < len(tokens)
            and tokens[index + 1].value == "use"
        ):
            cursor = parse_tree(index + 2, [])
            # Resume after the terminating `;` whatever the tree parser consumed.
            while cursor < len(tokens) and tokens[cursor].value != ";":
                cursor += 1
            index = cursor + 1
            continue
        index += 1

    return exports


def resolve_use_path(parts: tuple[str, ...], current_parts: list[str], definitions: dict[str, dict]) -> tuple[str, ...] | None:
    if not parts:
        return None

    if parts[0] == "crate":
        return parts[1:]

    if parts[0] == "self":
        return tuple([*current_parts, *parts[1:]])

    if parts[0] == "super":
        cursor = 0
        base = list(current_parts)
        while cursor < len(parts) and parts[cursor] == "super":
            if base:
                base.pop()
            cursor += 1
        return tuple([*base, *parts[cursor:]])

    candidate = parts
    if join_path(candidate) in definitions or any(path.startswith(f"{join_path(candidate)}::") for path in definitions):
        return candidate

    relative = tuple([*current_parts, *parts])
    if join_path(relative) in definitions or any(path.startswith(f"{join_path(relative)}::") for path in definitions):
        return relative

    return None


def reexported_items_from_source(source: str, source_file: str, module: dict, definitions: dict[str, dict]) -> list[dict]:
    """Build item records for the ``pub use`` re-exports found in *source*.

    Each export is resolved against *definitions*. A named export copies the
    matching definition under this file's module path, using the alias as the
    name when one is given. A wildcard export copies every definition that
    sits directly under the target path. Exports that do not resolve to a
    known definition are ignored.
    """
    current_parts = file_module_parts(source_file, module)
    items: list[dict] = []

    for export in parse_pub_use_exports(source):
        resolved = resolve_use_path(export.path, current_parts, definitions)
        if resolved is None:
            continue
        target_path = join_path(resolved)

        if not export.wildcard:
            definition = definitions.get(target_path)
            if definition is not None:
                exported_name = export.alias or definition["name"]
                items.append(
                    {
                        **definition,
                        "path": public_path(current_parts, exported_name),
                        "name": exported_name,
                    }
                )
            continue

        prefix = f"{target_path}::"
        items.extend(
            {**definition, "path": public_path(current_parts, definition["name"])}
            for definition_path, definition in definitions.items()
            # Direct children only: nothing nested deeper under the target.
            if definition_path.startswith(prefix)
            and "::" not in definition_path.removeprefix(prefix)
        )

    return items


def reachable_public_source_files(root: Path, module: dict) -> list[str]:
    """List indexed source files reachable from the crate root via ``pub mod``.

    The walk starts at ``src/lib.rs`` or, when that is not indexed,
    ``src/main.rs`` under the module's path. It follows each
    ``pub mod <name>;`` declaration to the first candidate file that is
    listed in the module's indexed sources.

    Args:
        root: Project root directory.
        module: Module index entry with ``"path"`` and ``"files"["source"]``.

    Returns:
        Source paths as written in the index, in visiting order. Empty when
        no crate root is indexed.
    """
    indexed_sources = {
        (root / source_file).resolve(): source_file
        for source_file in module.get("files", {}).get("source", [])
    }
    module_path = module.get("path", ".")
    module_root = root if module_path == "." else root / module_path
    crate_roots = [
        module_root / "src/lib.rs",
        module_root / "src/main.rs",
    ]

    reachable: list[str] = []
    visited: set[Path] = set()

    def visit(path: Path) -> None:
        resolved = path.resolve()
        if resolved in visited or resolved not in indexed_sources:
            return
        visited.add(resolved)
        reachable.append(indexed_sources[resolved])

        # Rust source files are UTF-8; do not depend on the locale encoding.
        for module_name in public_file_module_names(path.read_text(encoding="utf-8")):
            for candidate in child_module_candidates(path, module_name):
                if candidate.resolve() in indexed_sources:
                    visit(candidate)
                    break

    for crate_root in crate_roots:
        if crate_root.resolve() in indexed_sources:
            visit(crate_root)
            break

    return reachable


def extract_items_from_source(source: str, source_file: str, module: dict) -> list[dict]:
    """Collect the bare-``pub`` items declared in *source*.

    Items are recorded at file level and inside inline ``pub mod name { ... }``
    blocks, whose name extends the item path. Other block items (for example
    ``impl`` or private ``mod``) and stray brace blocks are skipped whole.
    Each record holds the item's path, kind, name, source location and
    normalized signature.
    """
    tokens = scan_tokens(source)
    total = len(tokens)
    base_parts = file_module_parts(source_file, module)
    collected: list[dict] = []

    def record(decl: PublicDecl, scope: list[str]) -> None:
        """Append the item record for *decl* declared inside *scope*."""
        pub_token = tokens[decl.pub_index]
        collected.append(
            {
                "path": public_path(scope, decl.name),
                "kind": decl.kind,
                "name": decl.name,
                "source": {
                    "path": source_file,
                    "line": pub_token.line,
                },
                "signature": signature_from_source(source, pub_token.start),
            }
        )

    def parse_scope(start_index: int, path_parts: list[str], stop_at_brace: bool) -> int:
        """Scan one module scope; return the index after it."""
        position = start_index
        while position < total:
            text = tokens[position].value

            if text == "}" and stop_at_brace:
                return position + 1

            decl = public_decl_at(tokens, position) if text == "pub" else None
            if decl is not None:
                body_start = decl.name_index + 1
                is_inline_module = (
                    decl.kind == "mod"
                    and body_start < total
                    and tokens[body_start].value == "{"
                )
                if is_inline_module:
                    position = parse_scope(body_start + 1, [*path_parts, decl.name], True)
                    continue
                # `pub mod name;` is only skipped; every other kind is recorded.
                if decl.kind != "mod":
                    record(decl, path_parts)
                position = skip_item(tokens, position)
                continue

            if text in BLOCK_ITEM_KEYWORDS:
                position = skip_item(tokens, position)
            elif text == "{":
                position = skip_brace_block(tokens, position)
            else:
                position += 1

        return position

    parse_scope(0, base_parts, False)
    return collected


def extract_public_items(root: Path, module: dict) -> list[dict]:
    """Return the public items of *module*, including re-exports.

    Every indexed source file is scanned first to build a table of
    definitions keyed by path (the first definition of a path wins). Then
    only the files reachable from the crate root contribute items: their own
    declarations plus their ``pub use`` re-exports resolved against that
    table. Duplicates with the same path, kind and source location are
    dropped.

    Args:
        root: Project root directory.
        module: Module index entry with ``"files"["source"]``.

    Raises:
        SystemExit: If an indexed source file is missing on disk.
    """
    definitions: dict[str, dict] = {}
    for source_file in module.get("files", {}).get("source", []):
        path = root / source_file
        if not path.is_file():
            raise SystemExit(f"Missing source file from module index: {source_file}")
        # Rust source files are UTF-8; do not depend on the locale encoding.
        for item in extract_items_from_source(path.read_text(encoding="utf-8"), source_file, module):
            definitions.setdefault(item["path"], item)

    items: list[dict] = []

    for source_file in reachable_public_source_files(root, module):
        path = root / source_file
        if not path.is_file():
            raise SystemExit(f"Missing source file from module index: {source_file}")
        source = path.read_text(encoding="utf-8")
        items.extend(extract_items_from_source(source, source_file, module))
        items.extend(reexported_items_from_source(source, source_file, module, definitions))

    seen: set[tuple[str, str, str, int]] = set()
    unique_items: list[dict] = []
    for item in items:
        key = (item["path"], item["kind"], item["source"]["path"], item["source"]["line"])
        if key in seen:
            continue
        seen.add(key)
        unique_items.append(item)

    return unique_items


def build_surface(root: Path) -> dict:
    """Assemble the public-surface document for the project at *root*.

    Every module of every layer in the module index contributes one entry
    with its name and its public items, in index order.
    """
    module_index = load_module_index(root)

    def describe(module: dict) -> dict:
        # Extract first so a failing extraction surfaces before the name lookup.
        public_items = extract_public_items(root, module)
        return {"name": module["name"], "items": public_items}

    modules = [
        describe(module)
        for layer in module_index.get("layers", [])
        for module in layer.get("modules", [])
    ]

    return {
        "schemaVersion": 1,
        "kind": "public-surface",
        "generatedBy": ".ai/tools/generate-public-surface",
        "visibility": "crate-external",
        "modules": modules,
    }


def main() -> int:
    """Generate, print or check the public-surface file.

    The project root is two directories above this script. With ``--check``
    the generated JSON is compared with the existing file. With ``--print``
    it is written to stdout. Otherwise it is written to ``--output``.

    Returns:
        ``0`` on success, ``1`` when ``--check`` finds the file missing or stale.
    """
    parser = argparse.ArgumentParser(description="Generate .ai/generated/public-surface.json from Rust source files.")
    parser.add_argument("--check", action="store_true", help="Fail if the generated public surface differs from the current file.")
    parser.add_argument("--print", action="store_true", help="Print generated JSON instead of writing it.")
    parser.add_argument("--output", default=".ai/generated/public-surface.json", help="Output path relative to the project root.")
    args = parser.parse_args()

    root = Path(__file__).resolve().parents[2]
    generated = json.dumps(build_surface(root), indent=2, sort_keys=False) + "\n"
    output_path = root / args.output

    # An absolute --output outside the project cannot be made relative to the
    # root; show the full path in messages rather than crash with ValueError.
    try:
        display_path = output_path.relative_to(root)
    except ValueError:
        display_path = output_path

    if args.check:
        if not output_path.is_file():
            sys.stderr.write(f"Missing generated public surface: {display_path}\n")
            return 1
        current = output_path.read_text(encoding="utf-8")
        if current != generated:
            sys.stderr.write(f"Stale generated public surface: {display_path}\n")
            return 1
        print("generate-public-surface check passed")
        return 0

    if args.print:
        sys.stdout.write(generated)
        return 0

    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(generated, encoding="utf-8")
    print(f"wrote {display_path}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
