#!/usr/bin/env python3

"""A test case update script.

This script is a utility to update LLVM 'llvm-mca' based test cases with new
FileCheck patterns.
"""

import argparse
from collections import defaultdict
import os
import sys
import warnings

from UpdateTestChecks import common


COMMENT_CHAR = "#"
ADVERT_PREFIX = "{} NOTE: Assertions have been autogenerated by ".format(COMMENT_CHAR)
ADVERT = "{}utils/{}".format(ADVERT_PREFIX, os.path.basename(__file__))
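# With this file name, the advert line expands to:
#   # NOTE: Assertions have been autogenerated by utils/update_mca_test_checks.py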


class Error(Exception):
    """Generic Error that can be raised without printing a traceback."""

    pass


def _warn(msg):
    """Log a user warning to stderr."""
    warnings.warn(msg, Warning, stacklevel=2)


def _configure_warnings(args):
    warnings.resetwarnings()
    if args.w:
        warnings.simplefilter("ignore")
    if args.Werror:
        warnings.simplefilter("error")


def _showwarning(message, category, filename, lineno, file=None, line=None):
    """Version of warnings.showwarning that won't attempt to print out the
    line at the location of the warning if the line text is not explicitly
    specified.
    """
    if file is None:
        file = sys.stderr
    if line is None:
        line = ""
    file.write(warnings.formatwarning(message, category, filename, lineno, line))


def _get_parser():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("-w", action="store_true", help="suppress warnings")
    parser.add_argument(
        "-Werror", action="store_true", help="promote warnings to errors"
    )
    parser.add_argument(
        "--llvm-mca-binary",
        metavar="<path>",
        default="llvm-mca",
        help="the binary to use to generate the test case (default: llvm-mca)",
    )
    parser.add_argument("tests", metavar="<test-path>", nargs="+")
    return parser


def _get_run_infos(run_lines, args):
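    # Illustrative sketch of what this parsing yields. A hypothetical RUN line
    # (already stripped of its 'RUN:' prefix by common.find_run_lines) such as
    #   llvm-mca -mcpu=btver2 < %s | FileCheck %s --check-prefix=BTVER2
    # would contribute (['BTVER2'], '-mcpu=btver2') to the returned list,
    # assuming common.get_check_prefixes extracts ['BTVER2'] from the
    # FileCheck command.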
    run_infos = []
    for run_line in run_lines:
        try:
            (tool_cmd, filecheck_cmd) = tuple(
                [cmd.strip() for cmd in run_line.split("|", 1)]
            )
        except ValueError:
            _warn("could not split tool and filecheck commands: {}".format(run_line))
            continue

        common.verify_filecheck_prefixes(filecheck_cmd)
        tool_basename = os.path.splitext(os.path.basename(args.llvm_mca_binary))[0]

        if not tool_cmd.startswith(tool_basename + " "):
            _warn("skipping non-{} RUN line: {}".format(tool_basename, run_line))
            continue

        if not filecheck_cmd.startswith("FileCheck "):
            _warn("skipping non-FileCheck RUN line: {}".format(run_line))
            continue

        tool_cmd_args = tool_cmd[len(tool_basename) :].strip()
        tool_cmd_args = tool_cmd_args.replace("< %s", "").replace("%s", "").strip()

        check_prefixes = common.get_check_prefixes(filecheck_cmd)

        run_infos.append((check_prefixes, tool_cmd_args))

    return run_infos


def _break_down_block(block_info, common_prefix):
    """Given a block_info, see if we can analyze it further to let us break it
    down by prefix per-line rather than per-block.
    """
    texts = block_info.keys()
    prefixes = list(block_info.values())
    # Split the lines from each of the incoming block_texts and zip them so that
    # each element contains the corresponding lines from each text.  E.g.
    #
    # block_text_1: A   # line 1
    #               B   # line 2
    #
    # block_text_2: A   # line 1
    #               C   # line 2
    #
    # would become:
    #
    # [(A, A),   # line 1
    #  (B, C)]   # line 2
    #
    line_tuples = list(zip(*list((text.splitlines() for text in texts))))

    # To simplify output, we'll only proceed if the very first line of the block
    # texts is common to each of them.
    if len(set(line_tuples[0])) != 1:
        return []

    result = []
    lresult = defaultdict(list)
    for line in line_tuples:
        if len(set(line)) == 1:
            # We're about to output a line with the common prefix.  This is a sync
            # point so flush any batched-up lines one prefix at a time to the output
            # first.
            for prefix in sorted(lresult):
                result.extend(lresult[prefix])
            lresult = defaultdict(list)

            # The line is common to each block so output with the common prefix.
            result.append((common_prefix, line[0]))
        else:
            # The line is not common to each block, or we don't have a common prefix.
            # If there are no prefixes available, warn and bail out.
            if not prefixes[0]:
                _warn(
                    "multiple lines not disambiguated by prefixes:\n{}\n"
                    "Some blocks may be skipped entirely as a result.".format(
                        "\n".join("  - {}".format(l) for l in line)
                    )
                )
                return []

            # Iterate through the line from each of the blocks and add the line with
            # the corresponding prefix to the current batch of results so that we can
            # later output them per-prefix.
            for i, l in enumerate(line):
                for prefix in prefixes[i]:
                    lresult[prefix].append((prefix, l))

    # Flush any remaining batched-up lines one prefix at a time to the output.
    for prefix in sorted(lresult):
        result.extend(lresult[prefix])
    return result


def _get_useful_prefix_info(run_infos):
    """Given the run_infos, calculate any prefixes that are common to every one,
    and the length of the longest prefix string.
    """
    try:
        all_sets = [set(s) for s in list(zip(*run_infos))[0]]
        common_to_all = set.intersection(*all_sets)
        longest_prefix_len = max(len(p) for p in set.union(*all_sets))
    except IndexError:
        common_to_all = []
        longest_prefix_len = 0
    else:
        if len(common_to_all) > 1:
            _warn("Multiple prefixes common to all RUN lines: {}".format(common_to_all))
        if common_to_all:
            common_to_all = sorted(common_to_all)[0]
    return common_to_all, longest_prefix_len


def _align_matching_blocks(all_blocks, farthest_indexes):
    """Some sub-sequences of blocks may be common to multiple lists of blocks,
    but at different indexes in each one.

    For example, in the following case, A,B,E,F, and H are common to both
    sets, but only A and B would be identified as such due to the indexes
    matching:

    index | 0 1 2 3 4 5 6
    ------+--------------
    setA  | A B C D E F H
    setB  | A B E F G H

    This function attempts to align the indexes of matching blocks by
    inserting empty blocks into the block list. With this approach, A, B, E,
    F, and H would now be able to be identified as matching blocks:

    index | 0 1 2 3 4 5 6 7
    ------+----------------
    setA  | A B C D E F   H
    setB  | A B     E F G H
    """

    # "Farthest block analysis": essentially, iterate over all blocks and find
    # the highest index into a block list for the first instance of each block.
    # This is relatively expensive, but we're dealing with small numbers of
    # blocks so it doesn't make a perceivable difference to user time.
    for blocks in all_blocks.values():
        for block in blocks:
            if not block:
                continue

            index = blocks.index(block)

            if index > farthest_indexes[block]:
                farthest_indexes[block] = index

    # Use the results of the above analysis to identify any blocks that can be
    # shunted along to match the farthest index value.
    for blocks in all_blocks.values():
        for index, block in enumerate(blocks):
            if not block:
                continue

            changed = False
            # If the block has not already been subject to alignment (i.e. if the
            # previous block is not empty) then insert empty blocks until the index
            # matches the farthest index identified for that block.
            if (index > 0) and blocks[index - 1]:
                while index < farthest_indexes[block]:
                    blocks.insert(index, "")
                    index += 1
                    changed = True

            if changed:
                # Bail out.  We'll need to re-do the farthest block analysis now that
                # we've inserted some blocks.
                return True

    return False


def _get_block_infos(run_infos, test_path, args, common_prefix):  # noqa
    """For each run line, run the tool with the specified args and collect the
    output. We use the concept of 'blocks' for uniquing, where a block is
    a series of lines of text with no more than one newline character between
    each one.  For example:

    This
    is
    one
    block

    This is
    another block

    This is yet another block

    We then build up a 'block_infos' structure containing a dict where the
    text of each block is the key and a list of the sets of prefixes that may
    generate that particular block.  This then goes through a series of
    transformations to minimise the number of CHECK lines that need to be
    written by taking advantage of common prefixes.
    """

    def _block_key(tool_args, prefixes):
        """Get a hashable key based on the current tool_args and prefixes."""
        return " ".join([tool_args] + prefixes)

    all_blocks = {}
    max_block_len = 0

    # A cache of the furthest-back position in any block list of the first
    # instance of each block, indexed by the block itself.
    farthest_indexes = defaultdict(int)

    # Run the tool for each run line to generate all of the blocks.
    for prefixes, tool_args in run_infos:
        key = _block_key(tool_args, prefixes)
        raw_tool_output = common.invoke_tool(args.llvm_mca_binary, tool_args, test_path)

        # Replace any lines consisting of purely whitespace with empty lines.
        raw_tool_output = "\n".join(
            line if line.strip() else "" for line in raw_tool_output.splitlines()
        )

        # Split blocks, stripping all trailing whitespace, but keeping preceding
        # whitespace except for newlines so that columns will line up visually.
        all_blocks[key] = [
            b.lstrip("\n").rstrip() for b in raw_tool_output.split("\n\n")
        ]
        max_block_len = max(max_block_len, len(all_blocks[key]))

        # Attempt to align matching blocks until no more changes can be made.
        made_changes = True
        while made_changes:
            made_changes = _align_matching_blocks(all_blocks, farthest_indexes)

    # If necessary, pad the lists of blocks with empty blocks so that they are
    # all the same length.
    for key in all_blocks:
        len_to_pad = max_block_len - len(all_blocks[key])
        all_blocks[key] += [""] * len_to_pad

    # Create the block_infos structure where it is a nested dict in the form of:
    # block number -> block text -> list of prefix sets
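    # e.g. {0: {'<block text>': [{'ALL', 'FOO'}, {'ALL', 'BAR'}]}} when two RUN
    # lines happen to emit the same text for block 0 (illustrative prefixes).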
    block_infos = defaultdict(lambda: defaultdict(list))
    for prefixes, tool_args in run_infos:
        key = _block_key(tool_args, prefixes)
        for block_num, block_text in enumerate(all_blocks[key]):
            block_infos[block_num][block_text].append(set(prefixes))

    # Now go through the block_infos structure and attempt to smartly prune the
    # number of prefixes per block to the minimal set possible to output.
    for block_num in range(len(block_infos)):
        # When there are multiple block texts for a block num, remove any
        # prefixes that are common to more than one of them.
        # E.g. [ [{ALL,FOO}] , [{ALL,BAR}] ] -> [ [{FOO}] , [{BAR}] ]
        all_sets = list(block_infos[block_num].values())
        pruned_sets = []

        for i, setlist in enumerate(all_sets):
            other_set_values = set(
                [
                    elem
                    for j, setlist2 in enumerate(all_sets)
                    for set_ in setlist2
                    for elem in set_
                    if i != j
                ]
            )
            pruned_sets.append([s - other_set_values for s in setlist])

        for i, block_text in enumerate(block_infos[block_num]):
            # When a block text matches multiple sets of prefixes, try removing any
            # prefixes that aren't common to all of them.
            # E.g. [ {ALL,FOO} , {ALL,BAR} ] -> [{ALL}]
            common_values = set.intersection(*pruned_sets[i])
            if common_values:
                pruned_sets[i] = [common_values]

            # Everything should be uniqued as much as possible by now.  Apply the
            # newly pruned sets to the block_infos structure.
            # If there are any blocks of text that still match multiple prefixes,
            # output a warning.
            current_set = set()
            for s in pruned_sets[i]:
                s = sorted(list(s))
                if s:
                    current_set.add(s[0])
                    if len(s) > 1:
                        _warn(
                            "Multiple prefixes generating same output: {} "
                            "(discarding {})".format(",".join(s), ",".join(s[1:]))
                        )

            if block_text and not current_set:
                raise Error(
                    "block not captured by existing prefixes:\n\n{}".format(block_text)
                )
            block_infos[block_num][block_text] = sorted(list(current_set))

        # If we have multiple block_texts, try to break them down further to avoid
        # the case where we have very similar block_texts repeated after each
        # other.
        if common_prefix and len(block_infos[block_num]) > 1:
            # We'll only attempt this if each of the block_texts has the same number
            # of lines as the others.
            same_num_lines = (
                len(set(len(k.splitlines()) for k in block_infos[block_num].keys()))
                == 1
            )
            if same_num_lines:
                breakdown = _break_down_block(block_infos[block_num], common_prefix)
                if breakdown:
                    block_infos[block_num] = breakdown

    return block_infos


def _write_block(output, block, not_prefix_set, common_prefix, prefix_pad):
    """Write an individual block, with correct padding on the prefixes.
    Returns a set of all of the prefixes that it has written.
    """
    end_prefix = ":     "
    previous_prefix = None
    num_lines_of_prefix = 0
    written_prefixes = set()

    for prefix, line in block:
        if prefix in not_prefix_set:
            _warn(
                'not writing for prefix {0} due to presence of "{0}-NOT:" '
                "in input file.".format(prefix)
            )
            continue

        # If the previous line isn't already blank and we're writing more than one
        # line for the current prefix, output a blank line first, unless either the
        # current or previous prefix is common to all.
        num_lines_of_prefix += 1
        if prefix != previous_prefix:
            if output and output[-1]:
                if num_lines_of_prefix > 1 or any(
                    p == common_prefix for p in (prefix, previous_prefix)
                ):
                    output.append("")
            num_lines_of_prefix = 0
            previous_prefix = prefix

        written_prefixes.add(prefix)
        output.append(
            "{} {}{}{} {}".format(
                COMMENT_CHAR, prefix, end_prefix, " " * (prefix_pad - len(prefix)), line
            ).rstrip()
        )
        end_prefix = "-NEXT:"

    output.append("")
    return written_prefixes


def _write_output(
    test_path,
    input_lines,
    prefix_list,
    block_infos,  # noqa
    args,
    common_prefix,
    prefix_pad,
):
    prefix_set = set([prefix for prefixes, _ in prefix_list for prefix in prefixes])
    not_prefix_set = set()

    output_lines = []
    for input_line in input_lines:
        if input_line.startswith(ADVERT_PREFIX):
            continue

        if input_line.startswith(COMMENT_CHAR):
            m = common.CHECK_RE.match(input_line)
            try:
                prefix = m.group(1)
            except AttributeError:
                prefix = None
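            # e.g. an existing '# BTVER2-NOT: foo' line yields prefix 'BTVER2'
            # here (hypothetical prefix), assuming common.CHECK_RE captures the
            # prefix of a '<prefix>[-NOT|-NEXT|...]:' style check line.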

            if "{}-NOT:".format(prefix) in input_line:
                not_prefix_set.add(prefix)

            if prefix not in prefix_set or prefix in not_prefix_set:
                output_lines.append(input_line)
                continue

        if common.should_add_line_to_output(input_line, prefix_set):
            # This input line of the function body will go as-is into the output.
            # Except make leading whitespace uniform: 2 spaces.
            input_line = common.SCRUB_LEADING_WHITESPACE_RE.sub(r"  ", input_line)

            # Skip empty lines if the previous output line is also empty.
            if input_line or (output_lines and output_lines[-1]):
                output_lines.append(input_line)
        else:
            continue

    # Add a blank line before the new checks if required.
    if len(output_lines) > 0 and output_lines[-1]:
        output_lines.append("")

    output_check_lines = []
    used_prefixes = set()
    for block_num in range(len(block_infos)):
        if isinstance(block_infos[block_num], list):
            # The block is of the type output from _break_down_block().
            used_prefixes |= _write_block(
                output_check_lines,
                block_infos[block_num],
                not_prefix_set,
                common_prefix,
                prefix_pad,
            )
        else:
            # _break_down_block() was unable to do anything so output the block
            # as-is.

            # Rather than writing out each block as soon as we encounter it, save it
            # indexed by prefix so that we can write all of the blocks out sorted by
            # prefix at the end.
            output_blocks = defaultdict(list)

            for block_text in sorted(block_infos[block_num]):
                if not block_text:
                    continue

                lines = block_text.split("\n")
                for prefix in block_infos[block_num][block_text]:
                    assert prefix not in output_blocks
                    used_prefixes |= _write_block(
                        output_blocks[prefix],
                        [(prefix, line) for line in lines],
                        not_prefix_set,
                        common_prefix,
                        prefix_pad,
                    )

            for prefix in sorted(output_blocks):
                output_check_lines.extend(output_blocks[prefix])

    unused_prefixes = (prefix_set - not_prefix_set) - used_prefixes
    if unused_prefixes:
        raise Error("unused prefixes: {}".format(sorted(unused_prefixes)))

    if output_check_lines:
        output_lines.insert(0, ADVERT)
        output_lines.extend(output_check_lines)

    # The file should not end with two newlines. It creates unnecessary churn.
    while len(output_lines) > 0 and output_lines[-1] == "":
        output_lines.pop()

    if input_lines == output_lines:
        sys.stderr.write("            [unchanged]\n")
        return
    sys.stderr.write("      [{} lines total]\n".format(len(output_lines)))

    common.debug("Writing", len(output_lines), "lines to", test_path, "..\n\n")

    with open(test_path, "wb") as f:
        f.writelines(["{}\n".format(l).encode("utf-8") for l in output_lines])


def update_test_file(args, test_path, autogenerated_note):
    sys.stderr.write("Test: {}\n".format(test_path))

    # Call this per test. By default each warning will only be written once
    # per source location. Reset the warning filter so that now each warning
    # will be written once per source location per test.
    _configure_warnings(args)

    with open(test_path) as f:
        input_lines = [l.rstrip() for l in f]

    run_lines = common.find_run_lines(test_path, input_lines)
    run_infos = _get_run_infos(run_lines, args)
    common_prefix, prefix_pad = _get_useful_prefix_info(run_infos)
    block_infos = _get_block_infos(run_infos, test_path, args, common_prefix)
    _write_output(
        test_path,
        input_lines,
        run_infos,
        block_infos,
        args,
        common_prefix,
        prefix_pad,
    )


def main():
    script_name = "utils/" + os.path.basename(__file__)
    parser = _get_parser()
    args = common.parse_commandline_args(parser)
    if not args.llvm_mca_binary:
        raise Error("--llvm-mca-binary value cannot be empty string")

    if "llvm-mca" not in os.path.basename(args.llvm_mca_binary):
        _warn("unexpected binary name: {}".format(args.llvm_mca_binary))

    for ti in common.itertests(args.tests, parser, script_name=script_name):
        try:
            update_test_file(ti.args, ti.path, ti.test_autogenerated_note)
        except Exception:
            common.warn("Error processing file", test_file=ti.path)
            raise
    return 0


if __name__ == "__main__":
    try:
        warnings.showwarning = _showwarning
        sys.exit(main())
    except Error as e:
        sys.stderr.write("error: {}\n".format(e))
        sys.exit(1)