xref: /llvm-project/clang/tools/clang-format/git-clang-format (revision e25c556abeb9ae5f82da42cd26b9dae8462a7197)
1#!/usr/bin/env python3
2#
3# ===- git-clang-format - ClangFormat Git Integration -------*- python -*--=== #
4#
5# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
6# See https://llvm.org/LICENSE.txt for license information.
7# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8#
9# ===----------------------------------------------------------------------=== #
10
11r"""
12clang-format git integration
13============================
14
15This file provides a clang-format integration for git. Put it somewhere in your
16path and ensure that it is executable. Then, "git clang-format" will invoke
17clang-format on the changes in current files or a specific commit.
18
19For further details, run:
20git clang-format -h
21
Requires Python 3.6 or later
23"""
24
25from __future__ import absolute_import, division, print_function
26import argparse
27import collections
28import contextlib
29import errno
30import os
31import re
32import subprocess
33import sys
34
usage = (
    "git clang-format [OPTIONS] [<commit>] [<commit>|--staged] [--] [<file>...]"
)

desc = """
If zero or one commits are given, run clang-format on all lines that differ
between the working directory and <commit>, which defaults to HEAD.  Changes are
only applied to the working directory, or in the stage/index.

Examples:
  To format staged changes, i.e. everything that's been `git add`ed:
    git clang-format

  To also format everything touched in the most recent commit:
    git clang-format HEAD~1

  If you're on a branch off main, to format everything touched on your branch:
    git clang-format main

If two commits are given (requires --diff), run clang-format on all lines in the
second <commit> that differ from the first <commit>.

The following git-config settings set the default of the corresponding option:
  clangFormat.binary
  clangFormat.commit
  clangFormat.extensions
  clangFormat.style
"""

# Name of the temporary index file in which to save the output of
# clang-format.  This file is created within the .git directory.
temp_index_basename = "clang-format-index"


# A run of lines in a file: `start` is the first (1-based) line number and
# `count` is the number of lines.
Range = collections.namedtuple("Range", "start, count")
70
71
def _build_argument_parser(config):
    """Create the command-line parser for git-clang-format.

    `config` is the dictionary returned by load_git_config(); its
    clangformat.* entries override the built-in defaults of the matching
    options."""
    default_extensions = ",".join(
        [
            # From clang/lib/Frontend/FrontendOptions.cpp, all lower case
            "c",
            "h",  # C
            "m",  # ObjC
            "mm",  # ObjC++
            "cc",
            "cp",
            "cpp",
            "c++",
            "cxx",
            "hh",
            "hpp",
            "hxx",
            "inc",  # C++
            "ccm",
            "cppm",
            "cxxm",
            "c++m",  # C++ Modules
            "cu",
            "cuh",  # CUDA
            # Other languages that clang-format supports
            "proto",
            "protodevel",  # Protocol Buffers
            "java",  # Java
            "js",
            "mjs",
            "cjs",  # JavaScript
            "ts",  # TypeScript
            "cs",  # C Sharp
            "json",  # Json
            "sv",
            "svh",
            "v",
            "vh",  # Verilog
            "td",  # TableGen
            "txtpb",
            "textpb",
            "pb.txt",
            "textproto",
            "asciipb",  # TextProto
        ]
    )

    p = argparse.ArgumentParser(
        usage=usage,
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=desc,
    )
    p.add_argument(
        "--binary",
        default=config.get("clangformat.binary", "clang-format"),
        help="path to clang-format",
    )
    p.add_argument(
        "--commit",
        default=config.get("clangformat.commit", "HEAD"),
        help="default commit to use if none is specified",
    )
    p.add_argument(
        "--diff",
        action="store_true",
        help="print a diff instead of applying the changes",
    )
    p.add_argument(
        "--diffstat",
        action="store_true",
        help="print a diffstat instead of applying the changes",
    )
    p.add_argument(
        "--extensions",
        default=config.get("clangformat.extensions", default_extensions),
        help=(
            "comma-separated list of file extensions to format, "
            "excluding the period and case-insensitive"
        ),
    )
    p.add_argument(
        "-f",
        "--force",
        action="store_true",
        help="allow changes to unstaged files",
    )
    p.add_argument(
        "-p", "--patch", action="store_true", help="select hunks interactively"
    )
    p.add_argument(
        "-q",
        "--quiet",
        action="count",
        default=0,
        help="print less information",
    )
    p.add_argument(
        "--staged",
        "--cached",
        action="store_true",
        help="format lines in the stage instead of the working dir",
    )
    p.add_argument(
        "--style",
        default=config.get("clangformat.style", None),
        help="passed to clang-format",
    )
    p.add_argument(
        "-v",
        "--verbose",
        action="count",
        default=0,
        help="print extra information",
    )
    p.add_argument(
        "--diff_from_common_commit",
        action="store_true",
        help=(
            "diff from the last common commit for commits in "
            "separate branches rather than the exact point of the "
            "commits"
        ),
    )
    # We gather all the remaining positional arguments into 'args' since we need
    # to use some heuristics to determine whether or not <commit> was present.
    # However, to print pretty messages, we make use of metavar and help.
    p.add_argument(
        "args",
        nargs="*",
        metavar="<commit>",
        help="revision from which to compute the diff",
    )
    p.add_argument(
        "ignored",
        nargs="*",
        metavar="<file>...",
        help="if specified, only consider differences in these files",
    )
    return p


def main():
    """Format the lines that changed relative to the given commit(s).

    Returns the process exit status: 0 when there is nothing to format or
    clang-format changed nothing, the `git diff` status for --diff and
    --diffstat, and 1 after changes have been applied to the working
    directory."""
    config = load_git_config()

    # In order to keep '--' yet allow options after positionals, we need to
    # check for '--' ourselves.  (Setting nargs='*' throws away the '--', while
    # nargs=argparse.REMAINDER disallows options after positionals.)
    argv = sys.argv[1:]
    try:
        idx = argv.index("--")
    except ValueError:
        dash_dash = []
    else:
        dash_dash = argv[idx:]
        argv = argv[:idx]

    opts = _build_argument_parser(config).parse_args(argv)

    opts.verbose -= opts.quiet
    del opts.quiet

    commits, files = interpret_args(opts.args, dash_dash, opts.commit)
    if len(commits) > 2:
        die("at most two commits allowed; %d given" % len(commits))
    if len(commits) == 2:
        if opts.staged:
            die("--staged is not allowed when two commits are given")
        if not opts.diff:
            die("--diff is required when two commits are given")
    elif opts.diff_from_common_commit:
        die(
            "--diff_from_common_commit is only allowed when two commits are "
            "given"
        )

    # Resolve a relative path to the binary now, before cd_to_toplevel()
    # changes the current directory.
    if os.path.dirname(opts.binary):
        opts.binary = os.path.abspath(opts.binary)

    changed_lines = compute_diff_and_extract_lines(
        commits, files, opts.staged, opts.diff_from_common_commit
    )
    # Remember every candidate so the filtered-out ones can be reported.
    candidate_files = set(changed_lines)
    filter_by_extension(changed_lines, opts.extensions.lower().split(","))
    # The diff reports paths relative to the top of the repository, so we
    # must cd there before accessing those files.
    cd_to_toplevel()
    filter_symlinks(changed_lines)
    filter_ignored_files(changed_lines, binary=opts.binary)
    if opts.verbose >= 1:
        # Sorted so the report does not depend on set iteration order.
        ignored_files = sorted(candidate_files.difference(changed_lines))
        if ignored_files:
            print(
                "Ignoring the following files (wrong extension, symlink, or "
                "ignored by clang-format):"
            )
            for filename in ignored_files:
                print("    %s" % filename)
        if changed_lines:
            print("Running clang-format on the following files:")
            for filename in changed_lines:
                print("    %s" % filename)

    if not changed_lines:
        if opts.verbose >= 0:
            print("no modified files to format")
        return 0

    if len(commits) > 1:
        # Format the second commit; its content is the "before" tree.
        old_tree = commits[1]
        revision = old_tree
    elif opts.staged:
        old_tree = create_tree_from_index(changed_lines)
        revision = ""
    else:
        old_tree = create_tree_from_workdir(changed_lines)
        revision = None
    new_tree = run_clang_format_and_save_to_tree(
        changed_lines, revision, binary=opts.binary, style=opts.style
    )
    if opts.verbose >= 1:
        print("old tree: %s" % old_tree)
        print("new tree: %s" % new_tree)

    if old_tree == new_tree:
        if opts.verbose >= 0:
            print("clang-format did not modify any files")
        return 0

    if opts.diff:
        return print_diff(old_tree, new_tree)
    if opts.diffstat:
        return print_diffstat(old_tree, new_tree)

    changed_files = apply_changes(
        old_tree, new_tree, force=opts.force, patch_mode=opts.patch
    )
    if (opts.verbose >= 0 and not opts.patch) or opts.verbose >= 1:
        print("changed files:")
        for filename in changed_files:
            print("    %s" % filename)

    return 1
310
311
def load_git_config(non_string_options=None):
    """Read `git config --list` into a dictionary keyed by option name.

    Every value is kept as the string git printed, except for the options
    named in `non_string_options`: a mapping from lower-case option name to a
    type flag ("--bool" or "--int") that is passed to a follow-up
    `git config` call, whose output becomes that option's value."""
    typed_options = {} if non_string_options is None else non_string_options
    config = {}
    listing = run("git", "config", "--list", "--null")
    for entry in filter(None, listing.split("\0")):
        # With --null every entry is "<name>\n<value>".  An entry without the
        # newline is a key with no '=', which is implicitly 'true'.
        name, separator, value = entry.partition("\n")
        if not separator:
            value = "true"
        if name in typed_options:
            value = run("git", "config", typed_options[name], name)
        config[name] = value
    return config
333
334
def interpret_args(args, dash_dash, default_commit):
    """Interpret `args` as "[commits] [--] [files]" and return (commits, files).

    It is assumed that "--" and everything that follows has been removed from
    args and placed in `dash_dash`.

    If "--" is present (i.e., `dash_dash` is non-empty), the arguments to its
    left (if present) are taken as commits.  Otherwise, the arguments are
    checked from left to right if they are commits or files.  If commits are not
    given, a list with `default_commit` is used.

    `args` is not modified; both returned lists are new objects.  Exits with
    an error if an argument that must be a commit names something else."""
    if dash_dash:
        commits = list(args) if args else [default_commit]
        for commit in commits:
            object_type = get_object_type(commit)
            if object_type in ("commit", "tag"):
                continue
            if object_type is None:
                die("'%s' is not a commit" % commit)
            else:
                die(
                    "'%s' is a %s, but a commit was expected"
                    % (commit, object_type)
                )
        files = dash_dash[1:]
    elif args:
        # Leading arguments that name revisions are commits; the first one
        # that does not starts the file list.  Walk by index instead of
        # popping so the caller's list is left untouched.
        num_commits = 0
        while num_commits < len(args):
            if not disambiguate_revision(args[num_commits]):
                break
            num_commits += 1
        commits = list(args[:num_commits]) or [default_commit]
        files = list(args[num_commits:])
    else:
        commits = [default_commit]
        files = []
    return commits, files
374
375
def disambiguate_revision(value):
    """Classify a positional argument as a revision or a file name.

    Returns True when `value` names a commit or a tag, and False when it is
    not a git object at all (so it is taken to be a file).  Any other kind
    of object is a fatal error."""
    # `git rev-parse` fails (and run() exits) with an appropriate message
    # when `value` is ambiguous, i.e. neither a commit nor a file.
    run("git", "rev-parse", value, verbose=False)
    kind = get_object_type(value)
    if kind in ("commit", "tag"):
        return True
    if kind is None:
        return False
    die(
        "`%s` is a %s, but a commit or filename was expected"
        % (value, kind)
    )
390
391
def get_object_type(value):
    """Ask git which type of object `value` names.

    Returns the type reported by `git cat-file -t` (e.g. "commit"), or None
    when that command fails because `value` is not a valid git object."""
    result = subprocess.run(
        ["git", "cat-file", "-t", value],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    if result.returncode:
        return None
    return convert_string(result.stdout.strip())
401
402
def compute_diff_and_extract_lines(commits, files, staged, diff_common_commit):
    """Produce the diff described by the arguments and parse its line ranges.

    This wraps compute_diff(), which starts the diff, and extract_lines(),
    which parses it.  If the diff command fails, exits with status 2; the
    command's stderr is not captured, so its error message is already
    visible to the user."""
    process = compute_diff(commits, files, staged, diff_common_commit)
    ranges_by_file = extract_lines(process.stdout)
    process.stdout.close()
    status = process.wait()
    if status != 0:
        sys.exit(2)
    return ranges_by_file
413
414
def compute_diff(commits, files, staged, diff_common_commit):
    """Return a subprocess object producing the diff from `commits`.

    The return value's `stdout` file object will produce a patch with the
    differences between the working directory (or the stage if `staged`) and
    the first commit if a single one was specified, or the difference between
    both specified commits, filtered on `files` (if non-empty).  When two
    commits are given and `diff_common_commit` is set, they are combined as
    "<first>...<second>" so the diff starts from their last common commit.
    Zero context lines are used in the patch.

    The caller is responsible for reading and closing `stdout` and for
    waiting on the process."""
    git_tool = "diff-index"
    extra_args = []
    if len(commits) == 2:
        git_tool = "diff-tree"
        if diff_common_commit:
            commits = [f"{commits[0]}...{commits[1]}"]
    elif staged:
        extra_args += ["--cached"]

    cmd = ["git", git_tool, "-p", "-U0"] + extra_args + commits + ["--"]
    cmd.extend(files)
    # Nothing is ever written to the command's stdin, so give it an empty one
    # rather than a pipe that would have to be closed by hand.
    return subprocess.Popen(
        cmd, stdin=subprocess.DEVNULL, stdout=subprocess.PIPE
    )
437
438
def extract_lines(patch_file):
    """Extract the changed lines in `patch_file`.

    The input must have been produced with ``-U0``, meaning unidiff format with
    zero lines of context.  The return value is a dict mapping each filename
    (taken from the ``+++`` header, with its leading ``b/``-style prefix
    removed) to a list of line `Range`s in the new version of the file.

    A hunk that only deletes lines is recorded as a one-line range at the
    deletion point, so the code around it still gets formatted.  Hunks whose
    new range starts at line 0 are skipped, as are hunks that appear before
    any recognizable ``+++`` header."""
    filename_re = re.compile(r"^\+\+\+\ [^/]+/(.*)")
    hunk_re = re.compile(r"^@@ -[0-9,]+ \+(\d+)(,(\d+))?")
    matches = {}
    # Stays None until a file header has been seen, so a malformed patch
    # cannot reach the bookkeeping below with an unbound name.
    filename = None
    for line in patch_file:
        line = convert_string(line)
        match = filename_re.search(line)
        if match:
            filename = match.group(1).rstrip("\r\n\t")
            continue
        match = hunk_re.search(line)
        if not match or filename is None:
            continue
        start_line = int(match.group(1))
        if start_line == 0:
            continue
        # An omitted count means a single line; a zero count (pure deletion)
        # still formats the one line at that position.
        line_count = int(match.group(3)) if match.group(3) else 1
        if line_count == 0:
            line_count = 1
        matches.setdefault(filename, []).append(Range(start_line, line_count))
    return matches
468
469
def filter_by_extension(dictionary, allowed_extensions):
    """Delete every key in `dictionary` that doesn't have an allowed extension.

    `allowed_extensions` must be a collection of lowercase file extensions,
    excluding the period.  Only the last path component is inspected, so
    dots in directory names are ignored.  An extension may span several
    dot-separated parts (e.g. "pb.txt" matches "foo.pb.txt"), and an empty
    string allows files that have no extension at all."""
    allowed_extensions = frozenset(allowed_extensions)
    for filename in list(dictionary):
        # git reports paths with "/" separators.
        basename = filename.rsplit("/", 1)[-1]
        parts = basename.lower().split(".")
        if len(parts) == 1:
            allowed = "" in allowed_extensions
        else:
            # Try every dot-separated suffix: "a.pb.txt" -> "pb.txt", "txt".
            allowed = any(
                ".".join(parts[i:]) in allowed_extensions
                for i in range(1, len(parts))
            )
        if not allowed:
            del dictionary[filename]
482
483
def filter_symlinks(dictionary):
    """Remove, in place, every key of `dictionary` that names a symlink."""
    symlinks = [name for name in dictionary if os.path.islink(name)]
    for name in symlinks:
        del dictionary[name]
489
490
def filter_ignored_files(dictionary, binary):
    """Delete every key in `dictionary` that is ignored by clang-format.

    `binary` is the clang-format executable.  Its `-list-ignored` output (one
    path per line) names the files to drop; reported paths that are not keys
    of `dictionary` are skipped rather than raising KeyError."""
    if not dictionary:
        # Nothing to check; don't run clang-format without any file names.
        return
    ignored_files = run(binary, "-list-ignored", *dictionary)
    # splitlines() also copes with CRLF line endings.
    for filename in ignored_files.splitlines():
        if filename:
            dictionary.pop(filename, None)
499
500
def cd_to_toplevel():
    """Make the top-level directory of the git repository the current one."""
    os.chdir(run("git", "rev-parse", "--show-toplevel"))
505
506
def create_tree_from_workdir(filenames):
    """Record `filenames`, as found in the working directory, in a new tree.

    The names are handed to create_tree() in its '--stdin' mode.  Returns the
    object ID (SHA-1) of the created tree."""
    tree_id = create_tree(filenames, "--stdin")
    return tree_id
512
513
def create_tree_from_index(filenames):
    """Copy the index entries of `filenames` into a new git tree.

    Each file's `git ls-files --stage` record is fed to create_tree() in its
    '--index-info' mode.  Returns the object ID (SHA-1) of the created tree."""
    # Snapshot the environment now: the generator below only runs inside
    # create_tree(), after GIT_INDEX_FILE has been pointed at a temporary
    # index, but the entries must be read from the original index.
    original_env = os.environ.copy()

    def index_entries():
        for path in filenames:
            ls_files = subprocess.Popen(
                ["git", "ls-files", "--stage", "-z", "--", path],
                env=original_env,
                stdin=subprocess.PIPE,
                stdout=subprocess.PIPE,
            )
            output, _ = ls_files.communicate()
            # With -z the records are NUL-terminated; keep the first one.
            yield convert_string(output.split(b"\0")[0])

    return create_tree(index_entries(), "--index-info")
539
540
def run_clang_format_and_save_to_tree(
    changed_lines, revision=None, binary="clang-format", style=None
):
    """Run clang-format on each file and save the result to a git tree.

    `changed_lines` maps each filename to the line `Range`s to format.
    `revision` selects where file contents and modes are read from: None for
    the working directory, "" for the index, or any other string for that
    revision.  `binary` and `style` are passed on to clang-format.

    Returns the object ID (SHA-1) of the created tree."""
    # Copy the environment when formatting the files in the index, because the
    # files have to be read from the original index (create_tree() runs the
    # generator below with GIT_INDEX_FILE pointing at a temporary index).
    env = os.environ.copy() if revision == "" else None

    def get_mode(filename):
        """Return the mode of `filename` as the octal string git expects."""
        if revision is None:
            mode = oct(os.stat(filename).st_mode)
        else:
            if revision:
                git_metadata_cmd = [
                    "git",
                    "ls-tree",
                    "%s:%s" % (revision, os.path.dirname(filename)),
                    os.path.basename(filename),
                ]
            else:
                git_metadata_cmd = [
                    "git",
                    "ls-files",
                    "--stage",
                    "--",
                    filename,
                ]
            git_metadata = subprocess.Popen(
                git_metadata_cmd,
                env=env,
                stdin=subprocess.PIPE,
                stdout=subprocess.PIPE,
            )
            fields = git_metadata.communicate()[0].split()
            # Report a failed or empty lookup instead of crashing with an
            # IndexError on the missing first field.
            if git_metadata.returncode != 0 or not fields:
                die(
                    "cannot determine the mode of '%s' (`%s`)"
                    % (filename, " ".join(git_metadata_cmd))
                )
            # The first field of the ls-tree / ls-files output is the mode.
            mode = oct(int(fields[0], 8))
        # Adjust python3 octal format so that it matches what git expects
        if mode.startswith("0o"):
            mode = "0" + mode[2:]
        return mode

    def index_info_generator():
        for filename, line_ranges in changed_lines.items():
            mode = get_mode(filename)
            blob_id = clang_format_to_blob(
                filename,
                line_ranges,
                revision=revision,
                binary=binary,
                style=style,
                env=env,
            )
            yield "%s %s\t%s" % (mode, blob_id, filename)

    return create_tree(index_info_generator(), "--index-info")
599
600
def create_tree(input_lines, mode):
    """Build a git tree object from `input_lines` and return its object ID.

    `mode` selects how `git update-index` interprets each entry:
      * '--stdin': every entry is a filename.
      * '--index-info': every entry is a record suitable for
        "git update-index --index-info", such as
        "<mode> <SP> <sha1> <TAB> <filename>".
    Any other mode is invalid.  Entries are added to a fresh, empty temporary
    index, so the user's own index is left untouched."""
    assert mode in ("--stdin", "--index-info")
    update_index_cmd = ["git", "update-index", "--add", "-z", mode]
    with temporary_index_file():
        update_index = subprocess.Popen(
            update_index_cmd, stdin=subprocess.PIPE
        )
        # `input_lines` may be a lazy generator; it is consumed here, while
        # GIT_INDEX_FILE points at the temporary index.
        write = update_index.stdin.write
        for entry in input_lines:
            write(to_bytes("%s\0" % entry))
        update_index.stdin.close()
        if update_index.wait() != 0:
            die("`%s` failed" % " ".join(update_index_cmd))
        return run("git", "write-tree")
619
620
def clang_format_to_blob(
    filename,
    line_ranges,
    revision=None,
    binary="clang-format",
    style=None,
    env=None,
):
    """Run clang-format on the given file and save the result to a git blob.

    Runs on the file in `revision` if not None, or on the file in the working
    directory if `revision` is None. Revision can be set to an empty string to
    run clang-format on the file in the index.

    Args:
      filename: path of the file to format; also used as the path inside
        `revision` when reading from git.
      line_ranges: iterable of (start_line, line_count) pairs; each one is
        passed to clang-format as an inclusive `--lines=first:last` range.
      revision: where to read the file from (see above).
      binary: clang-format executable to run.
      style: if set, passed to clang-format as `--style=<style>`.
      env: environment for the `git cat-file` process that reads the file
        from `revision`; unused when reading from the working directory.

    Returns the object ID (SHA-1) of the created blob.  Exits via die() if
    the clang-format executable cannot be found or if any of the spawned
    processes fails."""
    clang_format_cmd = [binary]
    if style:
        clang_format_cmd.extend(["--style=" + style])
    # Convert each (start, count) pair into clang-format's inclusive
    # --lines=<first>:<last> form.
    clang_format_cmd.extend(
        [
            "--lines=%s:%s" % (start_line, start_line + line_count - 1)
            for start_line, line_count in line_ranges
        ]
    )
    if revision is not None:
        # The content arrives on stdin, so tell clang-format which file it is
        # formatting.
        clang_format_cmd.extend(["--assume-filename=" + filename])
        # Stream the file out of git.  With revision == "" the object name is
        # ":<filename>", i.e. the copy of the file in the index.
        git_show_cmd = [
            "git",
            "cat-file",
            "blob",
            "%s:%s" % (revision, filename),
        ]
        git_show = subprocess.Popen(
            git_show_cmd, env=env, stdin=subprocess.PIPE, stdout=subprocess.PIPE
        )
        git_show.stdin.close()
        clang_format_stdin = git_show.stdout
    else:
        # Working-directory mode: clang-format reads the file by name.
        clang_format_cmd.extend([filename])
        git_show = None
        clang_format_stdin = subprocess.PIPE
    try:
        clang_format = subprocess.Popen(
            clang_format_cmd, stdin=clang_format_stdin, stdout=subprocess.PIPE
        )
        # In working-directory mode, pick up the pipe Popen just created so
        # that it is closed below; clang-format gets no input on it.
        if clang_format_stdin == subprocess.PIPE:
            clang_format_stdin = clang_format.stdin
    except OSError as e:
        if e.errno == errno.ENOENT:
            die('cannot find executable "%s"' % binary)
        else:
            raise
    # Release our end of clang-format's input: either our copy of git's output
    # pipe (clang-format now holds the read side) or clang-format's unused
    # stdin pipe.
    clang_format_stdin.close()
    # Store clang-format's output (read from stdin) as a blob.
    hash_object_cmd = [
        "git",
        "hash-object",
        "-w",
        "--path=" + filename,
        "--stdin",
    ]
    hash_object = subprocess.Popen(
        hash_object_cmd, stdin=clang_format.stdout, stdout=subprocess.PIPE
    )
    # Drop our copy of the read end so that only hash-object consumes
    # clang-format's output (the usual subprocess pipeline idiom).
    clang_format.stdout.close()
    stdout = hash_object.communicate()[0]
    # Check every stage of the pipeline, from the last one back to the first.
    if hash_object.returncode != 0:
        die("`%s` failed" % " ".join(hash_object_cmd))
    if clang_format.wait() != 0:
        die("`%s` failed" % " ".join(clang_format_cmd))
    if git_show and git_show.wait() != 0:
        die("`%s` failed" % " ".join(git_show_cmd))
    # hash-object prints the new blob's object ID.
    return convert_string(stdout).rstrip("\r\n")
693
694
@contextlib.contextmanager
def temporary_index_file(tree=None):
    """Point GIT_INDEX_FILE at a temporary index while the `with` body runs.

    The index is created by create_temporary_index(tree).  On exit, whatever
    value GIT_INDEX_FILE had before (or its absence) is restored and the
    temporary file is deleted, even if the body raised."""
    index_path = create_temporary_index(tree)
    saved_index_path = os.environ.get("GIT_INDEX_FILE")
    os.environ["GIT_INDEX_FILE"] = index_path
    try:
        yield
    finally:
        if saved_index_path is not None:
            os.environ["GIT_INDEX_FILE"] = saved_index_path
        else:
            del os.environ["GIT_INDEX_FILE"]
        os.remove(index_path)
710
711
def create_temporary_index(tree=None):
    """Write a temporary index file inside the git directory.

    The index is read from `tree` when one is given; otherwise it starts out
    empty.  Returns the path of the created file."""
    git_dir = run("git", "rev-parse", "--git-dir")
    index_path = os.path.join(git_dir, temp_index_basename)
    tree_arg = "--empty" if tree is None else tree
    run("git", "read-tree", "--index-output=" + index_path, tree_arg)
    return index_path
723
724
def _diff_trees(old_tree, new_tree, *extra_args):
    """Run `git diff` between two trees and return its exit status.

    We use the porcelain 'diff' and not plumbing 'diff-tree' because the
    output is expected to be viewed by the user, and only the former does
    nice things like color and pagination.

    We also only show modified files (--diff-filter=M) since `new_tree` only
    contains the files that were modified, so unmodified files would show as
    deleted without the filter.  With --exit-code, git's exit status reports
    whether any differences were found; it is returned unchanged.
    `extra_args` are inserted before the two tree arguments."""
    cmd = ["git", "diff", "--diff-filter=M", "--exit-code"]
    cmd.extend(extra_args)
    cmd.extend([old_tree, new_tree])
    return subprocess.run(cmd).returncode


def print_diff(old_tree, new_tree):
    """Print the diff between the two trees to stdout."""
    return _diff_trees(old_tree, new_tree)


def print_diffstat(old_tree, new_tree):
    """Print the diffstat between the two trees to stdout."""
    return _diff_trees(old_tree, new_tree, "--stat")
759
760
def apply_changes(old_tree, new_tree, force=False, patch_mode=False):
    """Apply the changes in `new_tree` to the working directory.

    Only files modified between `old_tree` and `new_tree` are touched.  Bails
    (exit status 2) if there are local unstaged changes in those files and
    not `force`.  If `patch_mode`, runs `git checkout --patch` to select
    hunks interactively.

    Returns the list of files that differ between the two trees."""
    diff_output = run(
        "git",
        "diff-tree",
        "--diff-filter=M",
        "-r",
        "-z",
        "--name-only",
        old_tree,
        new_tree,
    )
    # With -z the names are NUL-terminated.  Dropping empty entries also
    # yields [] (rather than [""]) when nothing changed.
    changed_files = [name for name in diff_output.split("\0") if name]
    if not changed_files:
        # Nothing to apply.  Returning early also avoids running
        # `git diff-files` without pathspecs, which would report unrelated
        # unstaged files.
        return changed_files
    if not force:
        # "--" keeps file names that start with "-" from being parsed as
        # options.
        unstaged_files = run(
            "git", "diff-files", "--name-status", "--", *changed_files
        )
        if unstaged_files:
            print(
                "The following files would be modified but have unstaged "
                "changes:",
                file=sys.stderr,
            )
            print(unstaged_files, file=sys.stderr)
            print("Please commit, stage, or stash them first.", file=sys.stderr)
            sys.exit(2)
    if patch_mode:
        # In patch mode, we could just as well create an index from the new tree
        # and checkout from that, but then the user will be presented with a
        # message saying "Discard ... from worktree".  Instead, we use the old
        # tree as the index and checkout from new_tree, which gives the slightly
        # better message, "Apply ... to index and worktree".  This is not quite
        # right, since it won't be applied to the user's index, but oh well.
        with temporary_index_file(old_tree):
            subprocess.run(["git", "checkout", "--patch", new_tree], check=True)
    else:
        with temporary_index_file(new_tree):
            run("git", "checkout-index", "-f", "--", *changed_files)
    return changed_files
807
808
def run(*args, **kwargs):
    """Run the command `args` and return its stdout as a string.

    Keyword arguments:
      stdin   -- text or bytes fed to the command's stdin (default: empty).
      verbose -- when true (default), label stderr diagnostics with the
                 command line.
      strip   -- when true (default), strip trailing CR/LF from the output.

    Anything the command prints to stderr is echoed to our stderr.  If the
    command fails, or its executable cannot be found, exits with status 2."""
    stdin = kwargs.pop("stdin", "")
    verbose = kwargs.pop("verbose", True)
    strip = kwargs.pop("strip", True)
    if kwargs:
        name = next(iter(kwargs))
        raise TypeError("run() got an unexpected keyword argument '%s'" % name)
    if stdin is not None:
        # The pipes are in binary mode, so text input must be encoded first.
        stdin = to_bytes(stdin)
    try:
        p = subprocess.Popen(
            args,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            stdin=subprocess.PIPE,
        )
    except OSError as e:
        # Same message clang_format_to_blob() uses, instead of a traceback.
        if e.errno == errno.ENOENT:
            die('cannot find executable "%s"' % args[0])
        raise
    stdout, stderr = p.communicate(input=stdin)

    stdout = convert_string(stdout)
    stderr = convert_string(stderr)

    if p.returncode == 0:
        if stderr:
            if verbose:
                print(
                    "`%s` printed to stderr:" % " ".join(args), file=sys.stderr
                )
            print(stderr.rstrip(), file=sys.stderr)
        if strip:
            stdout = stdout.rstrip("\r\n")
        return stdout
    if verbose:
        print(
            "`%s` returned %s" % (" ".join(args), p.returncode), file=sys.stderr
        )
    if stderr:
        print(stderr.rstrip(), file=sys.stderr)
    sys.exit(2)
843
844
def die(message):
    """Report `message` on stderr, prefixed with "error:", and exit with 2."""
    print("error: %s" % (message,), file=sys.stderr)
    raise SystemExit(2)
848
849
def to_bytes(str_input):
    """Return `str_input` as bytes: text is encoded as UTF-8, and bytes are
    returned unchanged."""
    if not isinstance(str_input, bytes):
        str_input = str_input.encode("utf-8")
    return str_input
855
856
def to_string(bytes_input):
    """Return `bytes_input` as a `str`, decoding bytes as UTF-8.

    A `str` is returned unchanged.  (Bytes used to be passed to `.encode`,
    which only works on Python 2 and raised AttributeError on Python 3.)"""
    if isinstance(bytes_input, str):
        return bytes_input
    return bytes_input.decode("utf-8")
861
862
def convert_string(bytes_input):
    """Return `bytes_input` as a `str`.

    Bytes are decoded as UTF-8.  Bytes that are not valid UTF-8 (e.g. a diff
    hunk header quoting a Latin-1 source line) are decoded with replacement
    characters instead of being turned into their "b'...'" repr, so callers
    that pattern-match the text still see it.  Objects without a `decode`
    method (such as a `str`) are passed through `str()`."""
    try:
        return bytes_input.decode("utf-8")
    except AttributeError:  # 'str' object has no attribute 'decode'.
        return str(bytes_input)
    except UnicodeError:
        return bytes_input.decode("utf-8", errors="replace")
870
871
if __name__ == "__main__":
    # main() returns the exit status; SystemExit hands it to the interpreter.
    raise SystemExit(main())
874