#!/usr/bin/env python3
#
# Git command to transform staged files according to a command that accepts file
# content on stdin and produces output on stdout. This command is useful in
# combination with `git add -p` which allows you to stage specific changes in
# a file. This command runs a formatter on the file with staged changes while
# ignoring unstaged changes.
#
# Usage: git-format-staged [OPTION]... [FILE]...
# Example: git-format-staged --formatter 'prettier --stdin-filepath {}' '*.js'
#
# Tested with Python versions 3.8 - 3.15.
#
# Original author: Jesse Hallett <jesse@sitr.us>

from __future__ import annotations
import argparse
from collections.abc import Sequence
from fnmatch import fnmatch
from gettext import gettext as _
import re
import shlex
import signal
import subprocess
import sys
from typing import NoReturn, Protocol, cast


# The string 4.0.1 is replaced during the publish process.
VERSION = "4.0.1"
# First element of the process argument list; used as the prefix of warning and error messages.
PROG = sys.argv[0]


def info(msg: str) -> None:
    """Print an informational message, followed by a newline, to standard output."""
    # print() targets the current sys.stdout when no file is given.
    print(msg)


def info_stderr(msg: str) -> None:
    """Print an informational message, followed by a newline, to standard error."""
    stream = sys.stderr
    print(msg, file=stream)


def warn(msg: str) -> None:
    """Print a non-fatal warning to standard error, prefixed with the program name."""
    text = f"{PROG}: warning: {msg}"
    print(text, file=sys.stderr)


def fatal(msg: str) -> NoReturn:
    """Print an error to standard error, prefixed with the program name, and exit with status 1.

    Raises:
        SystemExit: Always, with exit code 1.
    """
    print(f"{PROG}: error: {msg}", file=sys.stderr)
    # Use sys.exit() rather than the bare exit() helper: the latter is injected by the `site`
    # module for interactive use and is not guaranteed to exist (e.g. under `python -S`).
    sys.exit(1)


def format_staged_files(
    file_patterns: Sequence[str],
    formatter: str,
    update_working_tree: bool = True,
    write: bool = True,
    verbose: bool = False,
) -> None:
    """Run the formatter over every staged file that matches `file_patterns`.

    Staged additions and modifications are listed with `git diff-index --cached`. Symlinks,
    files that do not match the patterns, and entries that look like intent-to-add placeholders
    are skipped; every remaining entry is passed to `format_file_in_index`. A "Reformatted ..."
    line is printed for each file whose staged content was replaced.

    Args:
        file_patterns: Signed glob patterns (a leading "!" negates) tested against staged paths.
        formatter: Shell command template used to format each file.
        update_working_tree: Also apply the formatting changes to the working tree files.
        write: When false, run the formatter but leave the index and working tree untouched.
        verbose: Print the formatter commands being run.

    Any exception raised while processing is reported through `fatal`, which terminates the
    program with exit status 1.
    """
    common_opts = [
        "--cached",
        "--diff-filter=AM",  # select only file additions and modifications
        "--no-renames",
        "HEAD",
    ]

    try:
        # Paths reported by `git diff --name-only`; used below to recognize entries that
        # diff-index lists but that are not really staged.
        staged_files = {
            path.decode("utf-8")
            for path in subprocess.check_output(
                ["git", "diff", "--name-only"] + common_opts
            ).splitlines()
        }

        output = subprocess.check_output(["git", "diff-index"] + common_opts)
        for line in output.splitlines():
            entry = DiffIndexEntry(line.decode("utf-8"))
            if entry.dst_mode == "120000":
                # Do not process symlinks
                continue
            if not matches_some_path(file_patterns, entry.src_path):
                continue
            # The set lookup is checked before object_is_empty() so that `git cat-file` is only
            # spawned for the few entries that are missing from `staged_files`, instead of once
            # for every newly added file. The overall condition is unchanged.
            if (
                entry.src_mode is None
                and entry.src_path not in staged_files
                and (not entry.dst_hash or object_is_empty(entry.dst_hash))
            ):
                # File is not staged, it's tracked only with `--intent-to-add` and won't get committed
                continue
            if format_file_in_index(
                formatter,
                entry,
                update_working_tree=update_working_tree,
                write=write,
                verbose=verbose,
            ):
                info(f"Reformatted {entry.src_path} with {formatter}")
    except Exception as err:
        fatal(str(err))


def format_file_in_index(
    formatter: str,
    diff_entry: "DiffIndexEntry",
    update_working_tree: bool = True,
    write: bool = True,
    verbose: bool = False,
):
    """Format one staged file and store the result back in the git index.

    The staged blob is run through the formatter, which creates a new git object. When that
    object differs from the staged one, the index entry is pointed at it and, optionally, the
    same change is patched into the working tree file.

    Args:
        formatter: Shell command template used to format the file.
        diff_entry: Parsed `git diff-index` record describing the staged file.
        update_working_tree: Also patch the working tree file with the formatting changes.
        write: When false, the formatter still runs but nothing is modified.
        verbose: Passed through to `format_object`.

    Returns:
        The hash of the new object if the index was updated, otherwise `None`.
    """
    staged_hash = diff_entry.dst_hash
    if not staged_hash:
        return None

    formatted_hash = format_object(
        formatter, staged_hash, diff_entry.src_path, verbose=verbose
    )

    # Nothing to record when writing is disabled.
    if not write:
        return None

    # An identical hash means the formatter did not make any changes.
    if formatted_hash == staged_hash:
        return None

    # An empty result means the formatter produced no output. Bail out instead of replacing the
    # file with an empty one.
    if object_is_empty(formatted_hash):
        return None

    replace_file_in_index(diff_entry, formatted_hash)

    if not update_working_tree:
        return formatted_hash

    try:
        patch_working_file(diff_entry.src_path, staged_hash, formatted_hash)
    except Exception as err:
        # Errors patching working tree files are not fatal
        warn(str(err))
    return formatted_hash


# Matches the `{}` placeholder in a formatter command. To avoid breaking the quoting added by
# shlex, the pattern also matches (and therefore removes) a pair of identical quote characters
# directly surrounding the placeholder. This is important for backward compatibility because
# previous versions of git-format-staged did not use shlex quoting, and required manual quoting.
file_path_placeholder = re.compile(r"(['\"]?)\{\}(\1)")


def format_object(
    formatter: str, object_hash: str, file_path: str, verbose: bool = False
) -> str:
    """Run the formatter on a git blob and write the output to a new git blob.

    Builds the pipeline `git cat-file -p <object_hash>` -> formatter command ->
    `git hash-object -w --stdin`.

    Args:
        formatter: Shell command template. Every `{}` placeholder (optionally wrapped in a pair
            of quotes) is replaced with the shell-quoted `file_path`.
        object_hash: Hash of the blob whose content is piped to the formatter.
        file_path: Path substituted for the placeholder and used in messages.
        verbose: Print the expanded command and extra diagnostics to standard error.

    Returns:
        The hash of the new blob, as printed by `git hash-object`.

    Raises:
        Exception: If the formatter does not terminate in time, exits with a non-zero status, or
            the formatted content cannot be written to the object database.
        ValueError: If `git cat-file` fails to read `object_hash`.
    """
    get_content = subprocess.Popen(
        ["git", "cat-file", "-p", object_hash], stdout=subprocess.PIPE
    )

    # Substitute through a callable so that the quoted path is inserted verbatim. A plain
    # replacement string is treated by `re` as a template, so a backslash in the path would be
    # interpreted as an escape sequence or group reference (or rejected as a bad escape).
    quoted_path = shlex.quote(file_path)
    command = file_path_placeholder.sub(lambda _match: quoted_path, formatter)
    if verbose:
        info_stderr(command)
    format_content = subprocess.Popen(
        command, shell=True, stdin=get_content.stdout, stdout=subprocess.PIPE
    )

    write_object = subprocess.Popen(
        ["git", "hash-object", "-w", "--stdin"],
        stdin=format_content.stdout,
        stdout=subprocess.PIPE,
    )

    # Close the parent process reference to stdout, leaving only references in the child processes.
    # This way if the downstream process terminates while format_content is still running,
    # format_content will be terminated with a SIGPIPE signal.
    format_content.stdout.close()  # pyright: ignore[reportOptionalMemberAccess]

    # On the other hand we don't close get_content.stdout() so that we can check if there is unread
    # data left after the formatter has finished.

    # Read output from the last process in the pipe, and block until that process has completed.
    # It's important to block on the last process completing before waiting for the other sub
    # processes to finish.
    new_hash, _err = write_object.communicate()

    # The first two pipe processes should have completed by now. Block to verify that we have exit
    # statuses from them.
    try:
        # Use communicate() to check for any unread output from get_content.
        get_content_unread_stdout, _ = get_content.communicate(timeout=5)
        get_content.stdout.close()  # pyright: ignore[reportOptionalMemberAccess]
        format_content_exit_status = format_content.wait(timeout=5)
    except subprocess.TimeoutExpired as exception:
        # A timeout does not stop the children by itself. Kill and reap them (best effort) so
        # that stuck processes are not left running after the error is reported.
        for process in (get_content, format_content):
            process.kill()
            process.wait()
        raise Exception(
            "the formatter command did not terminate as expected"
        ) from exception

    # An error from format_content is most relevant to the user, so prioritize displaying this error
    # message in case multiple things went wrong.
    if format_content_exit_status != 0:
        raise Exception(
            f"formatter exited with non-zero status ({format_content_exit_status}) while processing {file_path}"
        )

    # If the formatter exited before reading all input then get_content might have been terminated
    # by a SIGPIPE signal. This is probably incorrect behavior from the formatter command, but is
    # allowed by design (for now). So we emit a warning, but will not fail.

    # If there was unread output from get_content that's an indication that the formatter command is
    # probably not configured correctly. But program design allows the formatter command to do what
    # it wants. So this is a warning, not a hard error.
    #
    # A SIGPIPE termination to get_content would also indicate unread output. This should not
    # happen, but it doesn't hurt to check.
    if len(get_content_unread_stdout) > 0 or get_content.returncode == -signal.SIGPIPE:
        warn(
            f"the formatter command exited before reading all content from {file_path}"
        )

    if get_content.returncode != 0 and get_content.returncode != -signal.SIGPIPE:
        if verbose:
            info_stderr(
                f"non-zero exit status from `git cat-file -p {object_hash}`\n"
                + f"exit status: {get_content.returncode}\n"
                + f"file path: {file_path}\n"
            )
        raise ValueError(
            f"unable to read file content for {file_path} from object database."
        )

    if write_object.returncode != 0:
        if verbose:
            info_stderr(
                f"non-zero exit status from `git hash-object -w --stdin`\n"
                + f"exit status: {write_object.returncode}\n"
                + f"file path: {file_path}\n"
            )
        raise Exception("unable to write formatted content to object database")

    return new_hash.decode("utf-8").rstrip()


def object_is_empty(object_hash: str) -> bool:
    """Report whether `git cat-file -p` prints nothing for the given object.

    Raises:
        Exception: If `git cat-file` exits with a non-zero status.
    """
    result = subprocess.run(
        ["git", "cat-file", "-p", object_hash], stdout=subprocess.PIPE
    )
    if result.returncode != 0:
        raise Exception("unable to verify content of formatted object")

    return len(result.stdout) == 0


def replace_file_in_index(diff_entry: "DiffIndexEntry", new_object_hash: str):
    """Point the index entry for the file at `new_object_hash`.

    Runs `git update-index --cacheinfo <mode>,<hash>,<path>` using the entry's destination mode
    and source path. `subprocess.check_call` raises `CalledProcessError` if git fails.
    """
    cache_info = f"{diff_entry.dst_mode},{new_object_hash},{diff_entry.src_path}"
    _ = subprocess.check_call(["git", "update-index", "--cacheinfo", cache_info])


def patch_working_file(path: str, orig_object_hash: str, new_object_hash: str):
    """Apply the difference between two git objects to the working tree file at `path`.

    A patch is produced with `git diff` between the two objects, rewritten so that it refers to
    `path`, and fed to `git apply` on standard input.

    Raises:
        subprocess.CalledProcessError: If `git diff` fails.
        Exception: If `git apply` exits with a non-zero status.
    """
    diff_command = [
        "git",
        "diff",
        "--no-ext-diff",
        "--color=never",
        orig_object_hash,
        new_object_hash,
    ]
    raw_patch = subprocess.check_output(diff_command)

    # Substitute object hashes in patch header with path to working tree file. The replacement
    # is applied to every occurrence of either hash in the patch bytes.
    patch_b = raw_patch.replace(orig_object_hash.encode(), path.encode())
    patch_b = patch_b.replace(new_object_hash.encode(), path.encode())

    # Output of `git apply` is captured (and discarded) rather than shown to the user.
    applied = subprocess.run(
        ["git", "apply", "-"],
        input=patch_b,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    if applied.returncode != 0:
        raise Exception(
            f"could not apply formatting changes to working tree file {path}"
        )


# Matches one line of `git diff-index` output.
# Format: src_mode dst_mode src_hash dst_hash status/score? src_path dst_path?
# The line starts with ":"; modes, hashes and the status letter are separated by single spaces,
# and each path is preceded by a tab. The numeric score and the second path are optional.
diff_pat = re.compile(
    r"^:(?P<src_mode>\d+) (?P<dst_mode>\d+) (?P<src_hash>[a-f0-9]+) (?P<dst_hash>[a-f0-9]+) (?P<status>[A-Z])(?P<score>\d+)?\t(?P<src_path>[^\t]+)(?:\t(?P<dst_path>[^\t]+))?$"
)


class DiffIndexEntry:
    """A single parsed record from `git diff-index` output.

    Modes and hashes that consist only of zeroes in the raw line are stored as `None`.
    `score` and `dst_path` are `None` when the line does not include them.
    """

    src_mode: str | None
    dst_mode: str | None
    src_hash: str | None
    dst_hash: str | None
    status: str
    score: int | None
    src_path: str
    dst_path: str | None

    def __init__(self, diff_index_output_line: str):
        """Parse a line of output from `git diff-index`.

        Raises:
            ValueError: If the line does not have the expected format.
        """
        parsed = diff_pat.match(diff_index_output_line)
        if parsed is None:
            raise ValueError(
                f"Failed to parse diff-index line: {diff_index_output_line}"
            )

        fields = parsed.groupdict()
        raw_score = fields["score"]

        self.src_mode = unless_zeroed(fields["src_mode"])
        self.dst_mode = unless_zeroed(fields["dst_mode"])
        self.src_hash = unless_zeroed(fields["src_hash"])
        self.dst_hash = unless_zeroed(fields["dst_hash"])
        self.status = fields["status"]
        self.score = None if not raw_score else int(raw_score)
        self.src_path = fields["src_path"]
        self.dst_path = fields["dst_path"]


zeroed_pat = re.compile(r"^0+$")


def unless_zeroed(s: str) -> str | None:
    """Return `s` unchanged, or `None` when `s` is a string of zeroes."""
    if zeroed_pat.match(s) is not None:
        return None
    return s


def matches_some_path(patterns: Sequence[str], target: str) -> bool:
    """Decide whether `target` is selected by a list of signed glob patterns.

    Patterns are evaluated left-to-right with `fnmatch`, and the last pattern that matches
    decides the outcome: a positive pattern selects the path, a negated one excludes it.
    Returns `False` when no pattern matches.
    """
    # Sign of every pattern that matches, in the order the patterns were given.
    verdicts = [
        is_positive
        for is_positive, glob in map(from_signed_pattern, patterns)
        if fnmatch(target, glob)
    ]
    return verdicts[-1] if verdicts else False


def from_signed_pattern(pattern: str) -> tuple[bool, str]:
    """Split a signed pattern into its sign and the bare pattern.

    A '!' as the first character marks the pattern as negative and is stripped from the
    returned pattern. The result takes the form (is_pattern_positive, pattern). For example:

        from_signed_pattern('!pat') == (False, 'pat')
        from_signed_pattern('pat') == (True, 'pat')

    An empty pattern is treated as a positive pattern.
    """
    # startswith() is used instead of indexing so that an empty pattern does not raise IndexError.
    if pattern.startswith("!"):
        return (False, pattern[1:])
    return (True, pattern)


class CustomArgumentParser(argparse.ArgumentParser):
    """Argument parser that adds a quoting hint to "unrecognized arguments" errors."""

    def error(  # pyright: ignore[reportImplicitOverride]
        self, message: str
    ) -> NoReturn:
        """Report a usage error through the base class, appending a hint when arguments were
        not recognized (a common symptom of an unquoted formatter command)."""
        hint = (
            " Do you need to quote your formatter command?"
            if message.startswith("unrecognized arguments:")
            else ""
        )
        super().error(message + hint)


class Args(Protocol):
    """Static type describing the namespace produced by this script's argument parser."""

    # Value of `--formatter` / `-f`: the shell command used to format files.
    formatter: str
    # True when `--no-update-working-tree` was given.
    no_update_working_tree: bool
    # True when `--no-write` was given.
    no_write: bool
    # True when `--verbose` was given.
    verbose: bool
    # Positional file patterns (at least one).
    files: Sequence[str]


def main(argv: Sequence[str] | None = None) -> None:
    """Parse command-line arguments and format the matching staged files.

    Keeping this logic in a function (rather than at module level) means the throwaway `_`
    assignments below stay local and do not rebind the module-level `_` name.

    Args:
        argv: Arguments to parse. When `None`, argparse reads them from `sys.argv`.
    """
    parser = CustomArgumentParser(
        description="Transform staged files using a formatting command that accepts content via stdin and produces a result via stdout.",
        epilog='Example: %(prog)s --formatter "prettier --stdin-filepath {}" "src/*.js" "test/*.js"',
    )
    _ = parser.add_argument(
        "--formatter",
        "-f",
        required=True,
        help='Shell command to format files, will run once per file. Occurrences of the placeholder `{}` will be replaced with a path to the file being formatted (with appropriate quoting). (Example: "prettier --stdin-filepath {}")',
    )
    _ = parser.add_argument(
        "--no-update-working-tree",
        action="store_true",
        help="By default formatting changes made to staged file content will also be applied to working tree files via a patch. This option disables that behavior, leaving working tree files untouched.",
    )
    _ = parser.add_argument(
        "--no-write",
        action="store_true",
        help='Prevents %(prog)s from modifying staged or working tree files. You can use this option to check staged changes with a linter instead of formatting. With this option stdout from the formatter command is ignored. Example: %(prog)s --no-write -f "eslint --stdin --stdin-filename {} >&2" "*.js"',
    )
    _ = parser.add_argument(
        "--version",
        action="version",
        version="%(prog)s version {}".format(VERSION),
        help="Display version of %(prog)s",
    )
    _ = parser.add_argument(
        "--verbose",
        help="Show the formatting commands that are running",
        action="store_true",
    )
    _ = parser.add_argument(
        "files",
        nargs="+",
        help='Patterns that specify files to format. The formatter will only transform staged files that are given here. Patterns may be literal file paths, or globs which will be tested against staged file paths using Python\'s fnmatch function. Patterns must be relative to the git repository root. For example "src/*.js" will match all files with a .js extension in src/ and its subdirectories. Patterns may be negated to exclude files using a "!" character. Patterns are evaluated left-to-right. (Example: "main.js" "src/*.js" "test/*.js" "!test/todo/*")',
    )
    args = cast(Args, parser.parse_args(argv))  # pyright: ignore[reportInvalidCast]
    # The command line exposes negative flags; the formatting function takes the positive form.
    format_staged_files(
        file_patterns=args.files,
        formatter=args.formatter,
        update_working_tree=not args.no_update_working_tree,
        write=not args.no_write,
        verbose=args.verbose,
    )


if __name__ == "__main__":
    main()
