xref: /llvm-project/llvm/utils/update_mir_test_checks.py (revision 102838d3f6e15a5c510257c2c70fe7faca5b59d6)
1#!/usr/bin/env python3
2
3"""Updates FileCheck checks in MIR tests.
4
5This script is a utility to update MIR based tests with new FileCheck
6patterns.
7
8The checks added by this script will cover the entire body of each
9function it handles. Virtual registers used are given names via
10FileCheck patterns, so if you do want to check a subset of the body it
11should be straightforward to trim out the irrelevant parts. None of
12the YAML metadata will be checked, other than function names, and fixedStack
13if the --print-fixed-stack option is used.
14
15If there are multiple llc commands in a test, the full set of checks
16will be repeated for each different check pattern. Checks for patterns
17that are common between different commands will be left as-is by
18default, or removed if the --remove-common-prefixes flag is provided.
19"""
20
21from __future__ import print_function
22
23import argparse
24import collections
25import glob
26import os
27import re
28import subprocess
29import sys
30
31from UpdateTestChecks import common
32
# Matches a MIR `name:` line (applied with .match, so only leading spaces may
# precede it) and captures the function name.
MIR_FUNC_NAME_RE = re.compile(r" *name: *(?P<func>[A-Za-z0-9_.-]+)")
# Matches the `body: |` line that opens a function's body block.
MIR_BODY_BEGIN_RE = re.compile(r" *body: *\|")
# Matches a basic block label line such as `bb.0:` or `bb.1.entry:`.
MIR_BASIC_BLOCK_RE = re.compile(r" *bb\.[0-9]+.*:$")
# Matches a numbered virtual register operand with its optional subregister
# index, register class/bank and LLT, e.g. `%3`, `%3.sub_lo`, `%3:gpr32`,
# `%3:_(s32)`. Group 1 is the bare `%N`. Named vregs (`%foo`) do not match.
VREG_RE = re.compile(r"(%[0-9]+)(?:\.[a-z0-9_]+)?(?::[a-z0-9_]+)?(?:\([<>a-z0-9 ]+\))?")
# Zero or more MachineInstr flags that may sit between `=` and the opcode.
MI_FLAGS_STR = (
    r"(frame-setup |frame-destroy |nnan |ninf |nsz |arcp |contract |afn "
    r"|reassoc |nuw |nsw |exact |nofpexcept |nomerge )*"
)
# Operand flags that may precede a vreg in a def position.
VREG_DEF_FLAGS_STR = r"(?:dead |undef )*"
# Matches an instruction line that defines one or more vregs, e.g.
# `%0:_(s32), %1:_(s32) = G_UNMERGE_VALUES %2`, capturing the comma-separated
# defs as `vregs` and the opcode as `opcode`. The opcode must start with an
# uppercase letter or `t` (presumably for lowercase-prefixed opcodes such as
# ARM Thumb's `tMOVr` -- TODO confirm).
VREG_DEF_RE = re.compile(
    r"^ *(?P<vregs>{2}{0}(?:, {2}{0})*) = "
    r"{1}(?P<opcode>[A-Zt][A-Za-z0-9_]+)".format(
        VREG_RE.pattern, MI_FLAGS_STR, VREG_DEF_FLAGS_STR
    )
)
# Matches lines that may precede the first instruction of a single-block
# body: comments, the block label, lowercase `key:` lines (e.g. `liveins:`),
# and blank lines.
# NOTE(review): the `.` after `bb` is unescaped, so it matches any character.
MIR_PREFIX_DATA_RE = re.compile(r"^ *(;|bb.[0-9].*: *$|[a-z]+:( |$)|$)")

# Matches an LLVM IR `define` line and captures the function name.
# NOTE(review): the name class omits `-`, `$` and quoted names, whereas
# MIR_FUNC_NAME_RE accepts `-`.
IR_FUNC_NAME_RE = re.compile(
    r"^\s*define\s+(?:internal\s+)?[^@]*@(?P<func>[A-Za-z0-9_.]+)\s*\("
)
# Matches IR comment lines and blank lines.
IR_PREFIX_DATA_RE = re.compile(r"^ *(;|$)")

# Matches one whole MIR function document in llc's output, from `---` to
# `...`. Captures the function name, the optional fixedStack section (the
# text between `fixedStack:` and `stack:`; the group is None when the
# optional part does not match), and the body text.
MIR_FUNC_RE = re.compile(
    r"^---$"
    r"\n"
    r"^ *name: *(?P<func>[A-Za-z0-9_.-]+)$"
    r".*?"
    # Optional fixedStack section, ending at the `stack:` key.
    r"(?:^ *fixedStack: *(\[\])? *\n"
    r"(?P<fixedStack>.*?)\n?"
    r"^ *stack:"
    r".*?)?"
    # The body block scalar, up to the document end marker.
    r"^ *body: *\|\n"
    r"(?P<body>.*?)\n"
    r"^\.\.\.$",
    flags=(re.M | re.S),
)
69
70
class LLC:
    """Callable that runs an ``llc`` executable on a test file.

    argparse builds one of these from the ``--llc-binary`` value.
    """

    def __init__(self, bin):
        # Executable name or path, interpolated as-is into a shell command.
        self.bin = bin

    def __call__(self, args, ir):
        """Run ``<bin> <args>`` with the file *ir* on stdin and return stdout.

        ``-x mir`` is appended for ``.mir`` inputs. The command line is run
        through the shell. Output is decoded to text on Python 3, and CRLF
        line endings are converted to LF.
        """
        mir_flag = " -x mir" if ir.endswith(".mir") else ""
        command_line = "{} {}{}".format(self.bin, args, mir_flag)
        with open(ir) as stdin_source:
            captured = subprocess.check_output(
                command_line, shell=True, stdin=stdin_source
            )
        if sys.version_info[0] > 2:
            captured = captured.decode()
        # Fix line endings to unix CR style.
        return captured.replace("\r\n", "\n")
87
88
class Run:
    """One parsed RUN line: FileCheck prefixes, llc arguments and triple.

    Indexing (and therefore tuple unpacking) yields the fields in the order
    ``(prefixes, cmd_args, triple)``.
    """

    def __init__(self, prefixes, cmd_args, triple):
        self.prefixes = prefixes
        self.cmd_args = cmd_args
        self.triple = triple

    def __getitem__(self, index):
        ordered_fields = [self.prefixes, self.cmd_args, self.triple]
        return ordered_fields[index]
97
98
def log(msg, verbose=True):
    """Write *msg* to stderr, but only when *verbose* is true."""
    if not verbose:
        return
    print(msg, file=sys.stderr)
102
103
def find_triple_in_ir(lines, verbose=False):
    """Return the first target triple declared in *lines*, or None.

    Each line is matched against ``common.TRIPLE_IR_RE``, and group 1 of the
    first successful match is returned. *verbose* is accepted but unused.
    """
    candidates = (common.TRIPLE_IR_RE.match(line) for line in lines)
    first_match = next((match for match in candidates if match), None)
    if first_match is None:
        return None
    return first_match.group(1)
110
111
def build_run_list(test, run_lines, verbose=False):
    """Parse ``llc ... | FileCheck ...`` RUN lines into Run objects.

    RUN lines without a pipe, not starting with ``llc``, or not piping into
    ``FileCheck`` are skipped with a warning naming *test*. The triple comes
    from ``-mtriple`` or, failing that, ``-march`` (as ``<arch>--``). The
    input placeholder (``< %s`` / ``%s``) is removed from the llc arguments,
    because the test file is fed to llc on stdin.

    Each Run's prefixes are sorted so that prefixes used by more RUN lines
    come first. This makes shared prefixes preferred when checks are printed.
    *verbose* is accepted but unused.
    """
    run_list = []
    all_prefixes = []
    for l in run_lines:
        if "|" not in l:
            common.warn("Skipping unparsable RUN line: " + l, test_file=test)
            continue

        llc_cmd, _, filecheck_cmd = l.partition("|")
        llc_cmd = llc_cmd.strip()
        filecheck_cmd = filecheck_cmd.strip()
        common.verify_filecheck_prefixes(filecheck_cmd)

        if not llc_cmd.startswith("llc "):
            common.warn("Skipping non-llc RUN line: {}".format(l), test_file=test)
            continue
        if not filecheck_cmd.startswith("FileCheck "):
            common.warn(
                "Skipping non-FileChecked RUN line: {}".format(l), test_file=test
            )
            continue

        triple = None
        m = common.TRIPLE_ARG_RE.search(llc_cmd)
        if m:
            triple = m.group(1)
        # If we find -march but not -mtriple, use that.
        m = common.MARCH_ARG_RE.search(llc_cmd)
        if m and not triple:
            triple = "{}--".format(m.group(1))

        cmd_args = llc_cmd[len("llc") :].strip()
        cmd_args = cmd_args.replace("< %s", "").replace("%s", "").strip()
        check_prefixes = common.get_check_prefixes(filecheck_cmd)
        all_prefixes += check_prefixes

        run_list.append(Run(check_prefixes, cmd_args, triple))

    # Sort prefixes that are shared between run lines before unshared prefixes.
    # This causes us to prefer printing shared prefixes. Count once up front
    # instead of calling list.count() inside the sort key.
    prefix_counts = collections.Counter(all_prefixes)
    for run in run_list:
        run.prefixes.sort(key=lambda prefix: -prefix_counts[prefix])

    return run_list
156
157
def find_functions_with_one_bb(lines, verbose=False):
    """Return the names of MIR functions in *lines* with exactly one block.

    A function starts at each line matching MIR_FUNC_NAME_RE. Basic block
    labels (MIR_BASIC_BLOCK_RE) are counted towards the most recent function.
    Labels seen before any ``name:`` line are attributed to None, so None is
    included when exactly one such label exists. *verbose* is unused.
    """
    # One [name, block_count] entry per function, in file order; the first
    # entry collects anything that precedes the first `name:` line.
    per_function = [[None, 0]]
    for line in lines:
        name_match = MIR_FUNC_NAME_RE.match(line)
        if name_match:
            per_function.append([name_match.group("func"), 0])
        if MIR_BASIC_BLOCK_RE.match(line):
            per_function[-1][1] += 1
    return [name for name, block_count in per_function if block_count == 1]
175
176
class FunctionInfo:
    """The checkable output of one function: mangled body and fixedStack text.

    Two instances are equal when both fields are equal. Comparing against any
    other type is never equal.
    """

    def __init__(self, body, fixedStack):
        # Function body text to be turned into check lines.
        self.body = body
        # Text of the fixedStack section, as captured by the caller.
        self.fixedStack = fixedStack

    def __eq__(self, other):
        return isinstance(other, FunctionInfo) and (
            (self.body, self.fixedStack) == (other.body, other.fixedStack)
        )
186
187
# A numbered virtual register reference `%N` not followed by another word
# character: exactly the positions a per-register `%N\b` pattern would match.
_VREG_USE_RE = re.compile(r"%[0-9]+\b")


def _substitute_vreg_uses(line, vreg_map):
    """Replace every `%N` in *line* that is a key of *vreg_map* with `[[NAME]]`."""

    def replace(use):
        number = use.group(0)
        if number in vreg_map:
            return "[[{}]]".format(vreg_map[number])
        return number

    return _VREG_USE_RE.sub(replace, line)


def build_function_info_dictionary(
    test, raw_tool_output, triple, prefixes, func_dict, verbose
):
    """Parse llc's MIR output and record each function's checkable body.

    For every document matched by MIR_FUNC_RE, numbered virtual registers in
    the body are rewritten into FileCheck variables: a def becomes
    ``[[NAME:%[0-9]+]]`` (NAME from mangle_vreg), and later references to an
    already-defined vreg become ``[[NAME]]``. The resulting FunctionInfo is
    stored in ``func_dict[prefix][func]`` for each of *prefixes*. If a prefix
    already holds a different FunctionInfo for that function, the entry is
    set to None to mark the conflict.

    *test* and *triple* are accepted for interface compatibility but unused.
    """
    for func_match in MIR_FUNC_RE.finditer(raw_tool_output):
        func = func_match.group("func")
        fixedStack = func_match.group("fixedStack")
        body = func_match.group("body")
        if verbose:
            log("Processing function: {}".format(func))
            for l in body.splitlines():
                log("  {}".format(l))

        # Vreg mangling
        mangled = []
        vreg_map = {}
        for func_line in body.splitlines(keepends=True):
            def_match = VREG_DEF_RE.match(func_line)
            if def_match:
                for vreg in VREG_RE.finditer(def_match.group("vregs")):
                    number = vreg.group(1)
                    if number in vreg_map:
                        name = vreg_map[number]
                    else:
                        name = mangle_vreg(
                            def_match.group("opcode"), vreg_map.values()
                        )
                        vreg_map[number] = name
                    # Turn the first occurrence (the def) into a variable
                    # definition.
                    func_line = func_line.replace(
                        number, "[[{}:%[0-9]+]]".format(name), 1
                    )
            # One regex pass per line for all known vregs, instead of one
            # re.sub per vreg (which was O(lines * vregs)).
            mangled.append(_substitute_vreg_uses(func_line, vreg_map))
        body = "".join(mangled)

        info = FunctionInfo(body, fixedStack)
        for prefix in prefixes:
            if func in func_dict[prefix]:
                if func_dict[prefix][func] != info:
                    # RUN lines sharing this prefix disagree on the output.
                    func_dict[prefix][func] = None
            else:
                func_dict[prefix][func] = info
229
230
def add_checks_for_function(
    test, output_lines, run_list, func_dict, func_name, single_bb, args
):
    """Append the check lines for *func_name* to *output_lines*.

    For each RUN line, its prefixes are tried in order. If one was already
    printed for this function, the RUN line is covered and nothing more is
    emitted for it. Otherwise the first prefix that holds a FunctionInfo for
    the function is printed via add_check_lines. Empty lines separate
    successive prefixes.

    If no prefix of a RUN line can be used, a warning is emitted. It reports
    conflicting output when some prefix recorded the function (as None,
    i.e. differing bodies). Otherwise it reports that llc produced no output
    for the function.

    Returns *output_lines*, which is also modified in place.
    """
    printed_prefixes = set()
    for run in run_list:
        for prefix in run.prefixes:
            if prefix in printed_prefixes:
                break
            # The function may be missing from llc's output for this prefix.
            func_info = func_dict[prefix].get(func_name)
            if not func_info:
                continue
            if printed_prefixes:
                # Add some space between different check prefixes.
                indent = len(output_lines[-1]) - len(output_lines[-1].lstrip(" "))
                output_lines.append(" " * indent + ";")
            printed_prefixes.add(prefix)
            log("Adding {} lines for {}".format(prefix, func_name), args.verbose)
            add_check_lines(
                test,
                output_lines,
                prefix,
                func_name,
                single_bb,
                func_info,
                args,
            )
            break
        else:
            if any(func_name in func_dict[p] for p in run.prefixes):
                common.warn(
                    "Found conflicting asm for function: {}".format(func_name),
                    test_file=test,
                )
            else:
                common.warn(
                    "No llc output found for function: {}".format(func_name),
                    test_file=test,
                )
    return output_lines
263
264
def add_check_lines(
    test, output_lines, prefix, func_name, single_bb, func_info: "FunctionInfo", args
):
    """Append FileCheck lines for one function body under one prefix.

    Emits ``<prefix>-LABEL: name: <func>``. When ``args.print_fixed_stack``
    is set, it then emits ``<prefix>: fixedStack:`` followed by one ``-NEXT``
    line per fixedStack line; a None fixedStack is treated as empty. Then it
    emits one check per body line: the first non-blank line uses
    ``<prefix>:``, every later line uses ``-NEXT``, and whitespace-only lines
    are matched with a ``{{...$}}`` regex. For single-block functions the
    leading basic block label is skipped. Check comments are indented like
    the body's first line, and that indentation is stripped from checked
    text. An empty body only produces a warning.
    """
    func_body = func_info.body.splitlines()
    if single_bb and func_body:
        # Don't bother checking the basic block label for a single BB
        func_body.pop(0)

    if not func_body:
        common.warn(
            "Function has no instructions to check: {}".format(func_name),
            test_file=test,
        )
        return

    first_line = func_body[0]
    indent = len(first_line) - len(first_line.lstrip(" "))
    # A check comment, indented the appropriate amount
    check = "{:>{}}; {}".format("", indent, prefix)

    output_lines.append("{}-LABEL: name: {}".format(check, func_name))

    if args.print_fixed_stack:
        output_lines.append("{}: fixedStack:".format(check))
        # The fixedStack section is optional in llc output, so this may be None.
        for stack_line in (func_info.fixedStack or "").splitlines():
            output_lines.append("{}-NEXT: {}".format(check, stack_line))

    first_check = True
    for func_line in func_body:
        if not func_line.strip():
            # The mir printer prints leading whitespace so we can't use CHECK-EMPTY:
            output_lines.append(check + "-NEXT: {{" + func_line + "$}}")
            continue
        filecheck_directive = check if first_check else check + "-NEXT"
        first_check = False
        check_line = "{}: {}".format(filecheck_directive, func_line[indent:]).rstrip()
        output_lines.append(check_line)
303
304
def mangle_vreg(opcode, current_names):
    """Derive a FileCheck variable name for a vreg defined by *opcode*.

    The base name is the opcode with these changes:
    - a ``G_`` prefix and a ``_PSEUDO`` suffix are dropped;
    - a few long generic opcodes are abbreviated;
    - ``_`` is appended when the result ends in a digit, so the uniquifying
      counter cannot merge into it.

    If *current_names* (any iterable of names) already holds N names with
    that base, N is appended: ``ADD``, ``ADD1``, ``ADD2``, ...
    """
    base = opcode
    # Simplify some common prefixes and suffixes
    if base.startswith("G_"):
        base = base[len("G_") :]
    # Drop the suffix (keep what precedes it), but never reduce the name to
    # an empty string.
    if base.endswith("_PSEUDO") and len(base) > len("_PSEUDO"):
        base = base[: -len("_PSEUDO")]
    # Shorten some common opcodes with long-ish names
    base = dict(
        IMPLICIT_DEF="DEF",
        GLOBAL_VALUE="GV",
        CONSTANT="C",
        FCONSTANT="C",
        MERGE_VALUES="MV",
        UNMERGE_VALUES="UV",
        INTRINSIC="INT",
        INTRINSIC_W_SIDE_EFFECTS="INT",
        INSERT_VECTOR_ELT="IVEC",
        EXTRACT_VECTOR_ELT="EVEC",
        SHUFFLE_VECTOR="SHUF",
    ).get(base, base)
    # Avoid ambiguity when opcodes end in numbers
    if len(base.rstrip("0123456789")) < len(base):
        base += "_"

    i = sum(1 for name in current_names if name.rstrip("0123456789") == base)
    if i:
        return "{}{}".format(base, i)
    return base
337
338
def should_add_line_to_output(input_line, prefix_set):
    """Decide whether *input_line* of the original test is copied through.

    Lines are dropped when they are check lines for one of the prefixes in
    *prefix_set* (which are being regenerated), or when they begin, after
    optional spaces/tabs, with a ``;`` comment marker. All others are kept.
    """
    check_match = common.CHECK_RE.match(input_line)
    if check_match is not None and check_match.group(1) in prefix_set:
        return False
    is_comment = re.search("^[ \t]*;", input_line) is not None
    return not is_comment
345
346
def update_test_file(args, test, autogenerated_note):
    """Regenerate the FileCheck lines of one test file in place.

    *args* is the per-test parsed command line (``llc_binary``, ``verbose``,
    ``print_fixed_stack`` are used). *test* is the test's path, and
    *autogenerated_note* is the marker line written at the top of the output.

    The file is read as UTF-8, the same encoding used to write it back. llc
    is run once per usable RUN line. The file is then rewritten: stale check
    lines for the handled prefixes are dropped, and fresh checks are inserted
    at the start of each MIR function body (inside the block for
    single-block functions) and each IR function. If neither a RUN line nor
    the IR module names a target triple, a warning is printed and the file is
    left untouched.
    """
    with open(test, encoding="utf-8") as fd:
        input_lines = [l.rstrip() for l in fd]

    triple_in_ir = find_triple_in_ir(input_lines, args.verbose)
    run_lines = common.find_run_lines(test, input_lines)
    run_list = build_run_list(test, run_lines, args.verbose)

    simple_functions = find_functions_with_one_bb(input_lines, args.verbose)

    func_dict = {prefix: dict() for run in run_list for prefix in run.prefixes}
    for prefixes, llc_args, triple_in_cmd in run_list:
        log("Extracted LLC cmd: llc {}".format(llc_args), args.verbose)
        log("Extracted FileCheck prefixes: {}".format(prefixes), args.verbose)

        # Bail out before spending time running llc when there is no triple.
        if not triple_in_cmd and not triple_in_ir:
            common.warn("No triple found: skipping file", test_file=test)
            return

        raw_tool_output = args.llc_binary(llc_args, test)
        build_function_info_dictionary(
            test,
            raw_tool_output,
            triple_in_cmd or triple_in_ir,
            prefixes,
            func_dict,
            args.verbose,
        )

    state = "toplevel"
    func_name = None
    prefix_set = {prefix for run in run_list for prefix in run.prefixes}
    log("Rewriting FileCheck prefixes: {}".format(prefix_set), args.verbose)

    output_lines = []
    output_lines.append(autogenerated_note)

    for input_line in input_lines:
        if input_line == autogenerated_note:
            continue

        if state == "toplevel":
            # Outside any function: look for an IR define or a MIR document.
            m = IR_FUNC_NAME_RE.match(input_line)
            if m:
                state = "ir function prefix"
                func_name = m.group("func")
            if input_line.rstrip("| \r\n") == "---":
                state = "document"
            output_lines.append(input_line)
        elif state == "document":
            # Inside a MIR YAML document, before its function name.
            m = MIR_FUNC_NAME_RE.match(input_line)
            if m:
                state = "mir function metadata"
                func_name = m.group("func")
            if input_line.strip() == "...":
                state = "toplevel"
                func_name = None
            if should_add_line_to_output(input_line, prefix_set):
                output_lines.append(input_line)
        elif state == "mir function metadata":
            # Between `name:` and `body: |`.
            if should_add_line_to_output(input_line, prefix_set):
                output_lines.append(input_line)
            m = MIR_BODY_BEGIN_RE.match(input_line)
            if m:
                if func_name in simple_functions:
                    # If there's only one block, put the checks inside it
                    state = "mir function prefix"
                    continue
                state = "mir function body"
                add_checks_for_function(
                    test,
                    output_lines,
                    run_list,
                    func_dict,
                    func_name,
                    single_bb=False,
                    args=args,
                )
        elif state == "mir function prefix":
            # Single-block body: emit checks before its first instruction.
            m = MIR_PREFIX_DATA_RE.match(input_line)
            if not m:
                state = "mir function body"
                add_checks_for_function(
                    test,
                    output_lines,
                    run_list,
                    func_dict,
                    func_name,
                    single_bb=True,
                    args=args,
                )

            if should_add_line_to_output(input_line, prefix_set):
                output_lines.append(input_line)
        elif state == "mir function body":
            if input_line.strip() == "...":
                state = "toplevel"
                func_name = None
            if should_add_line_to_output(input_line, prefix_set):
                output_lines.append(input_line)
        elif state == "ir function prefix":
            # Emit checks before the first non-comment, non-blank IR line.
            m = IR_PREFIX_DATA_RE.match(input_line)
            if not m:
                state = "ir function body"
                add_checks_for_function(
                    test,
                    output_lines,
                    run_list,
                    func_dict,
                    func_name,
                    single_bb=False,
                    args=args,
                )

            if should_add_line_to_output(input_line, prefix_set):
                output_lines.append(input_line)
        elif state == "ir function body":
            if input_line.strip() == "}":
                state = "toplevel"
                func_name = None
            if should_add_line_to_output(input_line, prefix_set):
                output_lines.append(input_line)

    log("Writing {} lines to {}...".format(len(output_lines), test), args.verbose)

    with open(test, "wb") as fd:
        fd.writelines(["{}\n".format(l).encode("utf-8") for l in output_lines])
477
478
def _make_arg_parser():
    """Build the command-line parser for this script."""
    arg_parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawTextHelpFormatter
    )
    arg_parser.add_argument(
        "--llc-binary",
        default="llc",
        type=LLC,
        help='The "llc" binary to generate the test case with',
    )
    arg_parser.add_argument(
        "--print-fixed-stack",
        action="store_true",
        help="Add check lines for fixedStack",
    )
    arg_parser.add_argument("tests", nargs="+")
    return arg_parser


def main():
    """Entry point: regenerate the checks of every test named on the command line."""
    parser = _make_arg_parser()
    args = common.parse_commandline_args(parser)

    this_script = "utils/" + os.path.basename(__file__)
    for test_info in common.itertests(args.tests, parser, script_name=this_script):
        try:
            update_test_file(
                test_info.args, test_info.path, test_info.test_autogenerated_note
            )
        except Exception:
            # Name the offending file, then re-raise with the original traceback.
            common.warn("Error processing file", test_file=test_info.path)
            raise


if __name__ == "__main__":
    main()
508