xref: /llvm-project/llvm/utils/revert_checker.py (revision 0a53f43c0c7e33cde07b24169e8f45db7eba2fea)
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3# ===----------------------------------------------------------------------===##
4#
5# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
6# See https://llvm.org/LICENSE.txt for license information.
7# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8#
9# ===----------------------------------------------------------------------===##
10"""Checks for reverts of commits across a given git commit.
11
12To clarify the meaning of 'across' with an example, if we had the following
13commit history (where `a -> b` notes that `b` is a direct child of `a`):
14
15123abc -> 223abc -> 323abc -> 423abc -> 523abc
16
17And where 423abc is a revert of 223abc, this revert is considered to be 'across'
18323abc. More generally, a revert A of a parent commit B is considered to be
19'across' a commit C if C is a parent of A and B is a parent of C.
20
21Please note that revert detection in general is really difficult, since merge
22conflicts/etc always introduce _some_ amount of fuzziness. This script just
23uses a bundle of heuristics, and is bound to ignore / incorrectly flag some
24reverts. The hope is that it'll easily catch the vast majority (>90%) of them,
25though.
26
27This is designed to be used in one of two ways: an import in Python, or run
28directly from a shell. If you want to import this, the `find_reverts`
29function is the thing to look at. If you'd rather use this from a shell, have a
30usage example:
31
32```
33./revert_checker.py c47f97169 origin/main origin/release/12.x
34```
35
36This checks for all reverts from the tip of origin/main to c47f97169, which are
37across the latter. It then does the same for origin/release/12.x to c47f97169.
38Duplicate reverts discovered when walking both roots (origin/main and
39origin/release/12.x) are deduplicated in output.
40"""
41
42import argparse
43import collections
44import logging
45import re
46import subprocess
47import sys
48from typing import Dict, Generator, Iterable, List, NamedTuple, Optional, Tuple
49
# Fail early with a readable message on interpreters that are too old: this
# file relies on Python 3.6 syntax such as f-strings.
assert sys.version_info >= (3, 6), "Only Python 3.6+ is supported."
51
52# People are creative with their reverts, and heuristics are a bit difficult.
53# At a glance, most reverts have "This reverts commit ${full_sha}". Many others
54# have `Reverts llvm/llvm-project#${PR_NUMBER}`.
55#
56# By their powers combined, we should be able to automatically catch something
57# like 80% of reverts with reasonable confidence. At some point, human
58# intervention will always be required (e.g., I saw
59# ```
60# This reverts commit ${commit_sha_1} and
61# also ${commit_sha_2_shorthand}
62# ```
63# during my sample)
64
class _CommitMessageReverts(NamedTuple):
    """Revert candidates scraped from a single commit message."""

    # SHAs (full or abbreviated) that the message claims to revert.
    potential_shas: List[str]
    # LLVM pull request numbers that the message claims to revert.
    potential_pr_numbers: List[int]


def _try_parse_reverts_from_commit_message(
    commit_message: str,
) -> _CommitMessageReverts:
    """Tries to parse revert SHAs and LLVM PR numbers from the commit message.

    Three textual patterns are recognized:
      - `This reverts commit <40-hex-digit sha>` anywhere in the message,
      - a subject line starting with `Revert <6+ hex digits> "`,
      - `Reverts llvm/llvm-project#<number>` anywhere in the message.

    Args:
        commit_message: the full commit message; may be empty.

    Returns:
        A namedtuple containing:
        - A list of potentially reverted SHAs
        - A list of potentially reverted LLVM PR numbers
    """
    if not commit_message:
        return _CommitMessageReverts([], [])

    candidate_shas = re.findall(
        r"This reverts commit ([a-f0-9]{40})\b", commit_message
    )

    # Only the subject line is checked for the `Revert <sha> "..."` form; the
    # SHA there is allowed to be abbreviated.
    subject = commit_message.splitlines()[0]
    subject_match = re.match(r'Revert ([a-f0-9]{6,}) "', subject)
    if subject_match is not None:
        candidate_shas.append(subject_match.group(1))

    candidate_prs = list(
        map(int, re.findall(r"Reverts llvm/llvm-project#(\d+)", commit_message))
    )

    return _CommitMessageReverts(candidate_shas, candidate_prs)
109
110
def _stream_stdout(
    command: List[str], cwd: Optional[str] = None
) -> Generator[str, None, None]:
    """Runs `command` and lazily yields the lines it writes to stdout.

    Output is decoded as UTF-8, with undecodable bytes replaced rather than
    raising. The process's exit status is not inspected here.

    Args:
        command: argv of the process to run.
        cwd: working directory for the process; None keeps the current one.

    Yields:
        Each line of the process's stdout, as produced by iterating the pipe.
    """
    process = subprocess.Popen(
        command,
        cwd=cwd,
        stdout=subprocess.PIPE,
        encoding="utf-8",
        errors="replace",
    )
    # The context manager closes the pipe and waits for the process once the
    # generator is exhausted or closed.
    with process:
        stdout = process.stdout
        assert stdout is not None  # for mypy's happiness.
        for line in stdout:
            yield line
123
124
def _resolve_sha(git_dir: str, sha: str) -> str:
    """Expands `sha` to a full 40-character SHA.

    A 40-character input is returned untouched, without consulting git.
    Anything else is handed to `git rev-parse` in `git_dir`.

    Raises:
        subprocess.CalledProcessError: if `git rev-parse` exits non-zero.
    """
    if len(sha) != 40:
        rev_parse = ["git", "-C", git_dir, "rev-parse", sha]
        resolved = subprocess.check_output(
            rev_parse,
            encoding="utf-8",
            stderr=subprocess.DEVNULL,
        )
        return resolved.strip()

    return sha
134
135
class _LogEntry(NamedTuple):
    """One commit as read from `git log` output."""

    # Full SHA of the commit.
    sha: str
    # The commit's message.
    commit_message: str
143
144
def _log_stream(git_dir: str, root_sha: str, end_at_sha: str) -> Iterable[_LogEntry]:
    """Yields a `_LogEntry` for each commit reachable from `root_sha` but not
    from `end_at_sha`, in the order `git log` prints them.

    `git log` is asked to print, per commit: a separator line, the full SHA,
    then the raw message. The separator is what delimits entries, since
    messages span an arbitrary number of lines.
    """
    separator = "<>" * 50
    lines = iter(
        _stream_stdout(
            [
                "git",
                "-C",
                git_dir,
                "log",
                "^" + end_at_sha,
                root_sha,
                "--format=" + separator + "%n%H%n%B%n",
            ]
        )
    )

    # Consume everything up to and including the first separator line. If
    # there's nothing to log, it may not exist. It might not be the first line
    # if git feels complainy.
    at_commit_header = any(line.rstrip() == separator for line in lines)

    while at_commit_header:
        sha_line = next(lines, None)
        assert sha_line is not None, "git died?"

        # Everything up to the next separator (or EOF) is this commit's
        # message.
        message_lines = []
        at_commit_header = False
        for raw_line in lines:
            stripped = raw_line.rstrip()
            if stripped == separator:
                at_commit_header = True
                break
            message_lines.append(stripped)

        yield _LogEntry(sha_line.rstrip(), "\n".join(message_lines).rstrip())
183
184
def _shas_between(git_dir: str, base_ref: str, head_ref: str) -> Iterable[str]:
    """Lazily yields the SHAs printed by
    `git rev-list --first-parent base_ref..head_ref`, run in `git_dir`.
    """
    commit_range = f"{base_ref}..{head_ref}"
    command = ["git", "-C", git_dir, "rev-list", "--first-parent", commit_range]
    # Each output line is a SHA followed by a newline; strip() drops the
    # newline.
    return map(str.strip, _stream_stdout(command))
195
196
def _rev_parse(git_dir: str, ref: str) -> str:
    """Returns the stripped output of `git rev-parse ref`, run in `git_dir`.

    Raises:
        subprocess.CalledProcessError: if git exits non-zero.
    """
    command = ["git", "-C", git_dir, "rev-parse", ref]
    output = subprocess.check_output(command, encoding="utf-8")
    return output.strip()
202
203
class Revert(NamedTuple):
    """A detected revert: the reverting commit paired with what it reverts."""

    # SHA of the commit that performs the revert.
    sha: str
    # SHA of the commit that was reverted.
    reverted_sha: str
211
212
def _find_common_parent_commit(git_dir: str, ref_a: str, ref_b: str) -> str:
    """Finds the closest common parent commit between `ref_a` and `ref_b`.

    This is the stripped output of `git merge-base ref_a ref_b`, run in
    `git_dir`.

    Raises:
        subprocess.CalledProcessError: if git exits non-zero.
    """
    command = ["git", "-C", git_dir, "merge-base", ref_a, ref_b]
    output = subprocess.check_output(command, encoding="utf-8")
    return output.strip()
219
220
def _load_pr_commit_mappings(
    git_dir: str, root: str, min_ref: str
) -> Dict[int, List[str]]:
    """Maps PR numbers to the SHAs of commits in `min_ref..root` whose subject
    line ends with ` (#<number>)`.

    Args:
        git_dir: git directory to run `git log` in.
        root: the newest ref to include.
        min_ref: the ref at which to stop; commits reachable from it are
            excluded.

    Returns:
        A dict of PR number -> SHAs, each list in `git log` order.
    """
    pr_suffix = re.compile(r"\s\(#(\d+)\)$")
    shas_by_pr = collections.defaultdict(list)

    # Each log line is `<full sha> <subject>`.
    log_lines = _stream_stdout(
        ["git", "log", "--format=%H %s", f"{min_ref}..{root}"],
        cwd=git_dir,
    )
    for log_line in log_lines:
        pr_match = pr_suffix.search(log_line)
        if pr_match is None:
            continue

        commit_sha = log_line.split(maxsplit=1)[0]
        # N.B., these are kept in log (read: reverse chronological) order,
        # which is what's expected by `find_reverts`.
        shas_by_pr[int(pr_match.group(1))].append(commit_sha)
    return shas_by_pr
238
239
240# N.B., max_pr_lookback's default of 20K commits is arbitrary, but should be
241# enough for the 99% case of reverts: rarely should someone land a cleanish
242# revert of a >6 month old change...
def find_reverts(
    git_dir: str, across_ref: str, root: str, max_pr_lookback: int = 20000
) -> List[Revert]:
    """Finds reverts across `across_ref` in `git_dir`, starting from `root`.

    These reverts are returned in order of oldest reverts first.

    Args:
        git_dir: git directory to find reverts in.
        across_ref: the ref to find reverts across.
        root: the 'main' ref to look for reverts on.
        max_pr_lookback: this function uses heuristics to map PR numbers to
            SHAs. These heuristics require that commit history from `root` to
            `some_parent_of_root` is loaded in memory. `max_pr_lookback` is how
            many commits behind `across_ref` should be loaded in memory.

    Returns:
        One `Revert` per (reverting commit, reverted commit) pair where the
        reverting commit lies between `across_ref` and `root`, and the
        reverted commit is a known commit object that is not one of the
        first-parent commits between `across_ref` and `root`.

    Raises:
        ValueError: if `across_ref` is not an ancestor of `root`.
        subprocess.CalledProcessError: if git fails to resolve `across_ref`
            or `root`, or to compute their merge base.
    """
    across_sha = _rev_parse(git_dir, across_ref)
    root_sha = _rev_parse(git_dir, root)

    common_ancestor = _find_common_parent_commit(git_dir, across_sha, root_sha)
    if common_ancestor != across_sha:
        # Both literals need the `f` prefix: implicit string concatenation
        # does not carry it from one literal over to the next.
        raise ValueError(
            f"{across_sha} isn't an ancestor of {root_sha} "
            f"(common ancestor: {common_ancestor})"
        )

    intermediate_commits = set(_shas_between(git_dir, across_sha, root_sha))
    assert across_sha not in intermediate_commits

    logging.debug(
        "%d commits appear between %s and %s",
        len(intermediate_commits),
        across_sha,
        root_sha,
    )

    all_reverts = []
    # Lazily load PR <-> commit mappings, since it can be expensive.
    pr_commit_mappings = None
    for sha, commit_message in _log_stream(git_dir, root_sha, across_sha):
        reverts, pr_reverts = _try_parse_reverts_from_commit_message(
            commit_message,
        )
        if pr_reverts:
            if pr_commit_mappings is None:
                logging.info(
                    "Loading PR <-> commit mappings. This may take a moment..."
                )
                # NOTE(review): if fewer than `max_pr_lookback` commits exist
                # behind `across_sha`, this revision presumably can't be
                # resolved by git and no mappings would be loaded -- confirm.
                pr_commit_mappings = _load_pr_commit_mappings(
                    git_dir, root_sha, f"{across_sha}~{max_pr_lookback}"
                )
                logging.info(
                    "Loaded %d PR <-> commit mappings", len(pr_commit_mappings)
                )

            for reverted_pr_number in pr_reverts:
                reverted_shas = pr_commit_mappings.get(reverted_pr_number)
                if not reverted_shas:
                    logging.warning(
                        "No SHAs for reverted PR %d (commit %s)",
                        reverted_pr_number,
                        sha,
                    )
                    continue
                logging.debug(
                    "Inferred SHAs %s for reverted PR %d (commit %s)",
                    reverted_shas,
                    reverted_pr_number,
                    sha,
                )
                reverts.extend(reverted_shas)

        if not reverts:
            continue

        # Expand abbreviated SHAs. A SHA that git cannot resolve is reported
        # and skipped -- the same treatment unresolvable full SHAs get below
        # -- rather than aborting the whole scan over one odd commit message.
        resolved_reverts = set()
        for candidate in reverts:
            try:
                resolved_reverts.add(_resolve_sha(git_dir, candidate))
            except subprocess.CalledProcessError:
                logging.warning(
                    "Failed to resolve reverted object %s (claimed to be reverted "
                    "by sha %s)",
                    candidate,
                    sha,
                )

        # Sorted so that output for a single reverting commit is
        # deterministic.
        for reverted_sha in sorted(resolved_reverts):
            if reverted_sha in intermediate_commits:
                # Both the revert and its target landed after `across_sha`,
                # so this revert is not 'across' it.
                logging.debug(
                    "Commit %s reverts %s, which happened after %s",
                    sha,
                    reverted_sha,
                    across_sha,
                )
                continue

            try:
                object_type = subprocess.check_output(
                    ["git", "-C", git_dir, "cat-file", "-t", reverted_sha],
                    encoding="utf-8",
                    stderr=subprocess.DEVNULL,
                ).strip()
            except subprocess.CalledProcessError:
                logging.warning(
                    "Failed to resolve reverted object %s (claimed to be reverted "
                    "by sha %s)",
                    reverted_sha,
                    sha,
                )
                continue

            if object_type == "commit":
                all_reverts.append(Revert(sha, reverted_sha))
                continue

            logging.error(
                "%s claims to revert %s -- which isn't a commit -- %s",
                sha,
                object_type,
                reverted_sha,
            )

    # Since `all_reverts` contains reverts in log order (e.g., newer comes before
    # older), we need to reverse this to keep with our guarantee of older =
    # earlier in the result.
    all_reverts.reverse()
    return all_reverts
360
361
def _main() -> None:
    """Command-line entry point.

    Parses arguments, finds reverts across `base_ref` for every given root,
    and prints one line per unique revert.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("base_ref", help="Git ref or sha to check for reverts around.")
    parser.add_argument("-C", "--git_dir", default=".", help="Git directory to use.")
    parser.add_argument("root", nargs="+", help="Root(s) to search for commits from.")
    parser.add_argument("--debug", action="store_true")
    parser.add_argument(
        "-u",
        "--review_url",
        action="store_true",
        help="Format SHAs as llvm review URLs",
    )
    opts = parser.parse_args()

    log_level = logging.INFO
    if opts.debug:
        log_level = logging.DEBUG
    logging.basicConfig(
        format="%(asctime)s: %(levelname)s: %(filename)s:%(lineno)d: %(message)s",
        level=log_level,
    )

    # `root`s can have related history, so we want to filter duplicate commits
    # out. The overwhelmingly common case is also to have one root, and it's way
    # easier to reason about output that comes in an order that's meaningful to
    # git.
    already_reported = set()
    unique_reverts = []
    for root in opts.root:
        for revert in find_reverts(opts.git_dir, opts.base_ref, root):
            if revert in already_reported:
                continue
            already_reported.add(revert)
            unique_reverts.append(revert)

    if opts.review_url:
        sha_prefix = "https://github.com/llvm/llvm-project/commit/"
    else:
        sha_prefix = ""

    for revert in unique_reverts:
        reverting = f"{sha_prefix}{revert.sha}"
        reverted = f"{sha_prefix}{revert.reverted_sha}"
        print(f"{reverting} claims to revert {reverted}")


if __name__ == "__main__":
    _main()
406