xref: /llvm-project/llvm/utils/update_mc_test_checks.py (revision 6f973fd4ab18ff58689e83383190ed4767c2a7dd)
1#!/usr/bin/env python3
2"""
3A test update script.  This script is a utility to update LLVM 'llvm-mc' based test cases with new FileCheck patterns.
4"""
5
6from __future__ import print_function
7
8import argparse
9import functools
10import os  # Used to advertise this file's name ("autogenerated_note").
11
12from UpdateTestChecks import common
13
14import subprocess
15import re
16
# Tool names whose RUN lines this script knows how to regenerate checks for.
mc_LIKE_TOOLS = [
    "llvm-mc",
]
# A diagnostic as it appears in tool output: ":<line>: warning|error: <text>".
ERROR_RE = re.compile(r":\d+: (warning|error): .*")
# Marker for "# COM: ..." lines, which are dropped when rewriting dasm tests.
ERROR_CHECK_RE = re.compile(r"# COM: .*")
# Output lines carrying the ".text" directive are not turned into checks.
# The dot is escaped so that only a literal ".text" matches; unescaped it
# matched any character followed by "text" (e.g. "context").
OUTPUT_SKIPPED_RE = re.compile(r"(\.text)")
# Comment leader used for generated check lines, keyed by mc mode.
COMMENT = {"asm": "//", "dasm": "#"}
24
25
def invoke_tool(exe, check_rc, cmd_args, testline, verbose=False):
    """Run a single test line through an mc-like tool and return its stdout.

    Args:
        exe: Tool executable (name or path) placed at the front of the
            shell command.
        check_rc: Forwarded to ``subprocess.run(check=...)``; when true a
            non-zero exit status raises ``subprocess.CalledProcessError``.
        cmd_args: Tool arguments, either a single string or a list of
            strings (joined with spaces).
        testline: The test input; it is written to the command's stdin
            followed by a newline.
        verbose: When true, print the command and its input before running.

    Returns:
        The decoded stdout of the command with "\\r\\n" normalised to "\\n".
        stderr is discarded, so diagnostics only appear if the command line
        itself redirects them to stdout.
    """
    if isinstance(cmd_args, list):
        args = " ".join(cmd_args)
    else:
        args = cmd_args

    cmd = exe + " " + args
    if verbose:
        print("Command: ", cmd)
        print("Input: ", testline)

    # Feed the test line through stdin rather than splicing it into an
    # `echo "..." |` pipeline: quotes, `$`, backticks and backslashes in the
    # test line would otherwise be interpreted by the shell.
    out = subprocess.run(
        cmd,
        shell=True,
        check=check_rc,
        input=(testline + "\n").encode(),
        stdout=subprocess.PIPE,
        stderr=subprocess.DEVNULL,
    ).stdout

    # Normalise line endings to unix LF style.
    return out.decode().replace("\r\n", "\n")
46
47
48# create tests line-by-line, here we just filter out the check lines and comments
49# and treat all others as tests
def isTestLine(input_line, mc_mode):
    """Return True when `input_line` should be fed to the tool as a test.

    Blank lines, lines starting with the comment leader for `mc_mode`, and
    lines matching the shared CHECK pattern are not tests.
    """
    stripped = input_line.strip()
    blank_or_comment = not stripped or stripped.startswith(COMMENT[mc_mode])
    if blank_or_comment:
        return False
    # The CHECK pattern is applied to the raw (unstripped) line.
    return not common.CHECK_RE.match(input_line)
59
60
def isRunLine(l):
    """Match `l` against the shared RUN-line pattern.

    Returns whatever ``common.RUN_LINE_RE.match`` yields, which callers use
    for its truthiness.
    """
    run_match = common.RUN_LINE_RE.match(l)
    return run_match
63
64
def hasErr(err):
    """Tell whether `err` contains a warning/error diagnostic.

    A falsy `err` (None or an empty string) is returned unchanged; otherwise
    the result is a bool.
    """
    if not err:
        return err
    found = ERROR_RE.search(err)
    return found is not None
67
68
def getErrString(err):
    """Return the first diagnostic found in `err`, or "" if there is none.

    The text is scanned line by line and the matched portion of the first
    line containing a diagnostic is returned.
    """
    if not err:
        return ""

    hits = (ERROR_RE.search(text) for text in err.splitlines())
    return next((hit.group(0) for hit in hits if hit), "")
79
80
def getOutputString(out):
    """Reduce raw tool output to the text that check lines are built from.

    Lines matching OUTPUT_SKIPPED_RE and lines made up only of tabs/spaces
    are dropped; the remaining lines lose their leading tabs/spaces.
    NOTE(review): the kept lines are concatenated with no separator, so
    multi-line output collapses into one string; confirm this is intended.
    """
    if not out:
        return ""

    kept = [
        text.lstrip("\t ")
        for text in out.splitlines()
        if not OUTPUT_SKIPPED_RE.search(text) and text.strip("\t ") != ""
    ]
    return "".join(kept)
93
94
def should_add_line_to_output(input_line, prefix_set, mc_mode):
    """Decide whether an existing input line is carried over to the output.

    In "dasm" mode, lines matching ERROR_CHECK_RE are always dropped. Every
    other line is delegated to ``common.should_add_line_to_output`` with the
    comment marker for `mc_mode`.
    """
    is_dasm_error_check = mc_mode == "dasm" and ERROR_CHECK_RE.search(input_line)
    if is_dasm_error_check:
        return False

    marker = COMMENT[mc_mode]
    return common.should_add_line_to_output(
        input_line, prefix_set, comment_marker=marker
    )
103
104
def getStdCheckLine(prefix, output, mc_mode):
    """Build "<comment> <prefix>: <line>" check text for regular output.

    One newline-terminated check line is produced per line of `output`;
    an empty `output` yields an empty string.
    """
    return "".join(
        COMMENT[mc_mode] + " " + prefix + ": " + text + "\n"
        for text in output.splitlines()
    )
110
111
def getErrCheckLine(prefix, output, mc_mode, line_offset=1):
    """Build one check line for a diagnostic.

    The line is "<comment> <prefix>: :[[@LINE-<line_offset>]]<output>",
    terminated by a newline; `output` is appended verbatim.
    """
    line_ref = ":[[@LINE-{}]]".format(line_offset)
    pieces = [COMMENT[mc_mode], " ", prefix, ": ", line_ref, output, "\n"]
    return "".join(pieces)
122
123
def main():
    """Regenerate FileCheck lines for each test file given on the command line.

    For every test file the function:
      1. picks the mc mode from the extension (".s" -> "asm", ".txt" ->
         "dasm"); other files are skipped with a warning,
      2. parses the RUN lines into (prefixes, tool, check_rc, args, triple,
         march) tuples, skipping lines that are not "<mc tool> ... | FileCheck ...",
      3. runs every distinct test line through every RUN command,
      4. chooses, per test line, which prefixes get a check line,
      5. rewrites the file, optionally de-duplicating (--unique) and/or
         sorting (--sort) the blank-line separated test units.

    Returns:
        -1 if --sort is requested for a ".txt" (dasm) file; otherwise None.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--llvm-mc-binary",
        default=None,
        help='The "mc" binary to use to generate the test case',
    )
    parser.add_argument(
        "--tool",
        default=None,
        help="Treat the given tool name as an mc-like tool for which check lines should be generated",
    )
    parser.add_argument(
        "--default-march",
        default=None,
        help="Set a default -march for when neither triple nor arch are found in a RUN line",
    )
    parser.add_argument(
        "--unique",
        action="store_true",
        default=False,
        help="remove duplicated test line if found",
    )
    parser.add_argument(
        "--sort",
        action="store_true",
        default=False,
        # Adjacent literals are concatenated as-is, so the first one must end
        # with a space to keep "as it" and "could" apart.
        help="sort testline in alphabetic order (keep run-lines on top), this option could be dangerous as it "
        "could change the order of lines that are not expected",
    )
    parser.add_argument("tests", nargs="+")
    initial_args = common.parse_commandline_args(parser)

    script_name = os.path.basename(__file__)

    for ti in common.itertests(
        initial_args.tests, parser, script_name="utils/" + script_name
    ):
        if ti.path.endswith(".s"):
            mc_mode = "asm"
        elif ti.path.endswith(".txt"):
            mc_mode = "dasm"

            if ti.args.sort:
                print("sorting with dasm(.txt) file is not supported!")
                return -1

        else:
            common.warn("Expected .s and .txt, Skipping file : ", ti.path)
            continue

        triple_in_ir = None
        for l in ti.input_lines:
            m = common.TRIPLE_IR_RE.match(l)
            if m:
                triple_in_ir = m.groups()[0]
                break

        # ---- Parse RUN lines -------------------------------------------
        run_list = []
        for l in ti.run_lines:
            if "|" not in l:
                common.warn("Skipping unparsable RUN line: " + l)
                continue

            # Everything before the last pipe is the tool command; the last
            # pipeline stage is expected to be FileCheck.
            commands = [cmd.strip() for cmd in l.split("|")]
            assert len(commands) >= 2
            mc_cmd = " | ".join(commands[:-1])
            filecheck_cmd = commands[-1]

            # special handling for negating exit status
            # if not is used in runline, disable rc check, since
            # the command might or might not
            # return non-zero code on a single line run
            check_rc = True
            mc_cmd_args = mc_cmd.strip().split()
            if mc_cmd_args[0] == "not":
                check_rc = False
                mc_tool = mc_cmd_args[1]
                mc_cmd = mc_cmd[len(mc_cmd_args[0]) :].strip()
            else:
                mc_tool = mc_cmd_args[0]

            triple_in_cmd = None
            m = common.TRIPLE_ARG_RE.search(mc_cmd)
            if m:
                triple_in_cmd = m.groups()[0]

            march_in_cmd = ti.args.default_march
            m = common.MARCH_ARG_RE.search(mc_cmd)
            if m:
                march_in_cmd = m.groups()[0]

            common.verify_filecheck_prefixes(filecheck_cmd)

            mc_like_tools = mc_LIKE_TOOLS[:]
            if ti.args.tool:
                mc_like_tools.append(ti.args.tool)
            if mc_tool not in mc_like_tools:
                common.warn("Skipping non-mc RUN line: " + l)
                continue

            if not filecheck_cmd.startswith("FileCheck "):
                common.warn("Skipping non-FileChecked RUN line: " + l)
                continue

            # Strip the tool name and the "%s" input placeholder: the input
            # is supplied one test line at a time by invoke_tool instead.
            mc_cmd_args = mc_cmd[len(mc_tool) :].strip()
            mc_cmd_args = mc_cmd_args.replace("< %s", "").replace("%s", "").strip()
            check_prefixes = common.get_check_prefixes(filecheck_cmd)

            run_list.append(
                (
                    check_prefixes,
                    mc_tool,
                    check_rc,
                    mc_cmd_args,
                    triple_in_cmd,
                    march_in_cmd,
                )
            )

        # ---- Collect tool output ---------------------------------------
        # find all test line from input
        testlines = [l for l in ti.input_lines if isTestLine(l, mc_mode)]
        # remove duplicated lines to save running time
        testlines = list(dict.fromkeys(testlines))
        common.debug("Valid test line found: ", len(testlines))

        # raw_output[run_id][test_id] is the tool output of run `run_id` for
        # test line `test_id`; raw_prefixes[run_id] is that run's prefix list.
        raw_output = []
        raw_prefixes = []
        for (
            prefixes,
            mc_tool,
            check_rc,
            mc_args,
            triple_in_cmd,
            march_in_cmd,
        ) in run_list:
            common.debug("Extracted mc cmd:", mc_tool, mc_args)
            common.debug("Extracted FileCheck prefixes:", str(prefixes))
            common.debug("Extracted triple :", str(triple_in_cmd))
            common.debug("Extracted march:", str(march_in_cmd))

            # NOTE(review): `triple` is not read again in this function; the
            # computation is kept as-is in case the helper call matters.
            triple = triple_in_cmd or triple_in_ir
            if not triple:
                triple = common.get_triple_from_march(march_in_cmd)

            raw_output.append([])
            for line in testlines:
                # get output for each testline
                out = invoke_tool(
                    ti.args.llvm_mc_binary or mc_tool,
                    check_rc,
                    mc_args,
                    line,
                    verbose=ti.args.verbose,
                )
                raw_output[-1].append(out)

            common.debug("Collect raw tool lines:", str(len(raw_output[-1])))

            raw_prefixes.append(prefixes)

        # ---- Choose prefixes and build check lines ---------------------
        output_lines = []
        generated_prefixes = {}
        used_prefixes = set()
        prefix_set = {prefix for p in run_list for prefix in p[0]}
        common.debug("Rewriting FileCheck prefixes:", str(prefix_set))

        for test_id, input_line in enumerate(testlines):
            # a {prefix : output, [runid] } dict
            # insert output to a prefix-key dict, and do a max sorting
            # to select the most-used prefix which share the same output string
            p_dict = {}
            for run_id, (run_output, prefixes) in enumerate(
                zip(raw_output, raw_prefixes)
            ):
                out = run_output[test_id]

                if hasErr(out):
                    o = getErrString(out)
                else:
                    o = getOutputString(out)

                for p in prefixes:
                    if p not in p_dict:
                        p_dict[p] = o, [run_id]
                    else:
                        # (None, []) marks a prefix already found to be in
                        # conflict; it stays discarded.
                        if p_dict[p] == (None, []):
                            continue

                        prev_o, run_ids = p_dict[p]
                        if o == prev_o:
                            run_ids.append(run_id)
                            p_dict[p] = o, run_ids
                        else:
                            # conflict, discard
                            p_dict[p] = None, []

            # Prefixes shared by the most runs come first.
            p_dict_sorted = dict(
                sorted(p_dict.items(), key=lambda item: -len(item[1][1]))
            )

            # prefix is selected and generated with most shared output lines
            # each run_id can only be used once
            gen_prefix = ""
            used_runid = set()

            # line number diff between generated prefix and testline
            line_offset = 1
            for prefix, tup in p_dict_sorted.items():
                o, run_ids = tup

                if len(run_ids) == 0:
                    continue

                # A prefix is skipped if any of its runs is already covered;
                # its not-yet-covered runs are still recorded as used.
                skip = False
                for i in run_ids:
                    if i in used_runid:
                        skip = True
                    else:
                        used_runid.add(i)
                if not skip:
                    used_prefixes.add(prefix)

                    if hasErr(o):
                        newline = getErrCheckLine(prefix, o, mc_mode, line_offset)
                    else:
                        newline = getStdCheckLine(prefix, o, mc_mode)

                    if newline:
                        gen_prefix += newline
                        line_offset += 1

            generated_prefixes[input_line] = gen_prefix.rstrip("\n")

        # ---- Rebuild the file contents ---------------------------------
        # write output
        for input_info in ti.iterlines(output_lines):
            input_line = input_info.line
            # generated_prefixes has exactly one key per entry of testlines,
            # so this is the same test as `in testlines` without the list scan.
            if input_line in generated_prefixes:
                output_lines.append(input_line)
                output_lines.append(generated_prefixes[input_line])

            elif should_add_line_to_output(input_line, prefix_set, mc_mode):
                output_lines.append(input_line)

        if ti.args.unique or ti.args.sort:
            # split with double newlines
            test_units = "\n".join(output_lines).split("\n\n")

            # select the key line for each test unit
            test_dic = {}
            for unit in test_units:
                lines = unit.split("\n")
                for l in lines:
                    # if contains multiple lines, use
                    # the first testline or runline as key
                    if isTestLine(l, mc_mode):
                        test_dic[unit] = l
                        break
                    if isRunLine(l):
                        test_dic[unit] = l
                        break

            # unique
            if ti.args.unique:
                new_test_units = []
                written_lines = set()
                for unit in test_units:
                    # if not testline/runline, we just add it
                    if unit not in test_dic:
                        new_test_units.append(unit)
                    else:
                        if test_dic[unit] in written_lines:
                            common.debug("Duplicated test skipped: ", unit)
                            continue

                        written_lines.add(test_dic[unit])
                        new_test_units.append(unit)
                test_units = new_test_units

            # sort
            if ti.args.sort:

                def getkey(l):
                    # find key of test unit, otherwise use first line
                    if l in test_dic:
                        line = test_dic[l]
                    else:
                        line = l.split("\n")[0]

                    # runline placed on the top
                    return (not isRunLine(line), line)

                test_units = sorted(test_units, key=getkey)

            # join back to be output string
            output_lines = "\n\n".join(test_units).split("\n")

        # output
        if ti.args.gen_unused_prefix_body:
            output_lines.extend(
                ti.get_checks_for_unused_prefixes(run_list, used_prefixes)
            )

        common.debug("Writing %d lines to %s..." % (len(output_lines), ti.path))
        with open(ti.path, "wb") as f:
            f.writelines(["{}\n".format(l).encode("utf-8") for l in output_lines])
435
436
if __name__ == "__main__":
    # Use main()'s return value as the exit status: None (the normal case)
    # exits with 0, while the -1 returned for an unsupported option
    # combination now yields a non-zero status instead of being dropped.
    raise SystemExit(main())
439