xref: /llvm-project/llvm/utils/update_any_test_checks.py (revision 33786b62bae4cec1c921f7c017aa6cdc644ced39)
1#!/usr/bin/env python3
2
3"""Dispatch to update_*_test_checks.py scripts automatically in bulk
4
5Given a list of test files, this script will invoke the correct
6update_test_checks-style script, skipping any tests which have not previously
had assertions autogenerated. If a test name starts with '@', it is treated as
the name of a file containing a list of tests, one per line.
9"""
10
11from __future__ import print_function
12
13import argparse
14import os
15import re
16import subprocess
17import sys
18from concurrent.futures import ThreadPoolExecutor
19
# Matches the header line that marks a test's assertions as autogenerated.
# Group 1 captures the generating script's name (the run of non-whitespace
# after "by"); group 2, when present, captures the trailing " UTC_ARGS:..."
# text. Adjacent raw literals are concatenated into a single pattern.
RE_ASSERTIONS = re.compile(
    r"NOTE: Assertions have been autogenerated by "
    r"([^\s]+)"
    r"( UTC_ARGS:.*)?$"
)
23
24
def find_utc_tool(search_path, utc_name):
    """
    Locate a UTC tool by name.

    Each directory of *search_path* is tried in order; the first
    ``os.path.join(directory, utc_name)`` for which ``os.path.isfile`` is
    true is returned. Returns None if no directory yields such a file.
    """
    candidates = (os.path.join(directory, utc_name) for directory in search_path)
    return next(filter(os.path.isfile, candidates), None)
35
36
def run_utc_tool(utc_name, utc_tool, testname):
    """
    Run *utc_tool* on *testname* and capture its result.

    ``utc_name`` is accepted for interface compatibility but is not used.

    Returns a ``(returncode, stdout, stderr)`` tuple whose output streams are
    raw bytes. If the tool cannot be launched at all (missing, not
    executable, ...), the OSError is converted into a return code of 1 and an
    explanatory message on stderr instead of propagating, so one unusable
    tool does not abort reporting for every other test.
    """
    try:
        result = subprocess.run(
            [utc_tool, testname], stdout=subprocess.PIPE, stderr=subprocess.PIPE
        )
    except OSError as err:
        message = f"error: failed to run {utc_tool}: {err}\n"
        return (1, b"", message.encode(errors="replace"))
    return (result.returncode, result.stdout, result.stderr)
42
43
def read_arguments_from_file(filename):
    """
    Return the test names listed in *filename*, one per line.

    Trailing whitespace is stripped from each line, and lines that are empty
    after stripping are skipped (otherwise they would become "" test names).
    On failure to read the file, an error is printed to stderr and the
    process exits with status 1.
    """
    try:
        with open(filename, "r") as file:
            lines = [line.rstrip() for line in file]
    except FileNotFoundError:
        print(f"Error: File '{filename}' not found.", file=sys.stderr)
        sys.exit(1)
    except OSError as err:
        # e.g. permission denied, or the path is a directory.
        print(f"Error: Cannot read file '{filename}': {err}", file=sys.stderr)
        sys.exit(1)
    return [line for line in lines if line]
51
52
def expand_listfile_args(arg_list):
    """
    Expand '@file' arguments into the entries listed in that file.

    Arguments beginning with '@' are replaced, in place, by the entries that
    read_arguments_from_file returns for the rest of the argument; all other
    arguments pass through unchanged. Order is preserved, and entries read
    from a list file are not themselves expanded.
    """
    expanded = []
    for arg in arg_list:
        if not arg.startswith("@"):
            expanded.append(arg)
            continue
        expanded.extend(read_arguments_from_file(arg[1:]))
    return expanded
61
62
def _read_header(testname):
    """
    Return the first line of the file *testname*, stripped of surrounding
    whitespace. Undecodable bytes are replaced rather than raising.
    """
    with open(testname, "r", errors="replace") as f:
        return f.readline().strip()


def _print_tool_output(data):
    """
    Decode captured tool output (bytes) and print it to stdout.

    Undecodable bytes are replaced. Empty output prints nothing; otherwise a
    newline is appended when the output does not already end with one, so
    the next line of the report starts on its own line.
    """
    text = data.decode(errors="replace")
    if not text:
        return
    print(text, end="")
    if not text.endswith("\n"):
        print()


def main():
    """
    Parse the command line, run each autogenerated test's UTC tool, and
    report the results.

    Exits with status 1 if any test could not be read, any referenced tool
    could not be found, or any tool invocation returned non-zero.
    """
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument(
        "--jobs",
        "-j",
        default=1,
        type=int,
        help="Run the given number of jobs in parallel",
    )
    parser.add_argument(
        "--utc-dir",
        nargs="*",
        help="Additional directories to scan for update_*_test_checks scripts",
    )
    parser.add_argument("tests", nargs="+")
    config = parser.parse_args()
    if config.jobs < 1:
        # ThreadPoolExecutor would otherwise fail with a ValueError traceback.
        parser.error("--jobs must be at least 1")

    # User-supplied directories are searched first; the parent of this
    # script's directory is always searched last.
    utc_search_path = list(config.utc_dir or [])
    script_dir = os.path.dirname(os.path.abspath(__file__))
    utc_search_path.append(os.path.join(script_dir, os.path.pardir))

    not_autogenerated = []
    # Tool name -> resolved path, or None when it was not found. Caching None
    # avoids searching again for every test that names the same tool.
    utc_tools = {}
    have_error = False

    tests = expand_listfile_args(config.tests)

    with ThreadPoolExecutor(max_workers=config.jobs) as executor:
        jobs = []

        for testname in tests:
            try:
                header = _read_header(testname)
            except OSError as err:
                print(f"{testname}: cannot read test: {err}", file=sys.stderr)
                have_error = True
                continue

            # Only the first line is checked for the autogenerated NOTE.
            m = RE_ASSERTIONS.search(header)
            if m is None:
                not_autogenerated.append(testname)
                continue

            utc_name = m.group(1)
            if utc_name not in utc_tools:
                utc_tools[utc_name] = find_utc_tool(utc_search_path, utc_name)
            utc_tool = utc_tools[utc_name]
            if not utc_tool:
                # Checked on every lookup, not only the first: a cached None
                # must never be passed on to run_utc_tool.
                print(
                    f"{utc_name}: not found (used in {testname})",
                    file=sys.stderr,
                )
                have_error = True
                continue

            future = executor.submit(run_utc_tool, utc_name, utc_tool, testname)
            jobs.append((testname, future))

        # Report in submission order, regardless of completion order.
        for testname, future in jobs:
            return_code, stdout, stderr = future.result()

            print(f"Update {testname}")
            _print_tool_output(stdout)
            _print_tool_output(stderr)
            if return_code != 0:
                print(f"Return code: {return_code}")
                have_error = True

    # List skipped tests even when errors occurred, so that information is
    # not lost behind the non-zero exit.
    if not_autogenerated:
        print("Tests without autogenerated assertions:")
        for testname in not_autogenerated:
            print(f"  {testname}")

    if have_error:
        sys.exit(1)


if __name__ == "__main__":
    main()
154