#!/usr/bin/env python3
"""A script to generate FileCheck statements for mlir unit tests.

This script is a utility to add FileCheck patterns to an mlir file.

NOTE: The input .mlir is expected to be the output from the parser, not a
stripped-down variant.

Example usage:
$ generate-test-checks.py foo.mlir
$ mlir-opt foo.mlir -transformation | generate-test-checks.py
$ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir
$ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir -i
$ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir -i --source_delim_regex='gpu.func @'

The script will heuristically generate CHECK/CHECK-LABEL commands for each line
within the file. By default this script will also try to insert string
substitution blocks for all SSA value names. If a --source file is specified,
the script will attempt to insert the generated CHECKs into the source file by
looking for line positions matched by --source_delim_regex.

The script is designed to make adding checks to a test case fast; it is *not*
designed to be authoritative about what constitutes a good test!
"""

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Snapshot of llvm-project/mlir/utils/generate-test-checks.py
# (revision dbf798fa646811c03e40c25f9bb3a456267c5a73).

import argparse
import io
import os  # Used to advertise this file's name ("autogenerated_note").
import re
import sys

ADVERT_BEGIN = "// NOTE: Assertions have been autogenerated by "
ADVERT_END = """
// The script is designed to make adding checks to
// a test case fast, it is *not* designed to be authoritative
// about what constitutes a good test! The CHECK should be
// minimized and named to reflect the test intent.
"""


# Regular expression matching an SSA identifier (the text after '%'): either a
# number or a name. This is a bare top-level alternation, so it must be wrapped
# in a group before being concatenated with other pattern text.
SSA_RE_STR = "[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*"
SSA_RE = re.compile(SSA_RE_STR)

# Regular expression matching one result on the left-hand side of an
# assignment: "%0", "%sum", or the multi-result form "%0:2". The identifier
# alternation is grouped so that the leading '%' applies to both alternatives.
SSA_RESULT_STR = r'%(?:' + SSA_RE_STR + r')(?::[0-9]+)?'

# Regular expression matching the left-hand side of an assignment, e.g.
# "%0, %x =". Group 1 is the first result.
SSA_RESULTS_STR = r'\s*(' + SSA_RESULT_STR + r')(\s*,\s*(' + SSA_RESULT_STR + r'))*\s*='
SSA_RESULTS_RE = re.compile(SSA_RESULTS_STR)

# Regular expression matching attributes
ATTR_RE_STR = r'(#[a-zA-Z._-][a-zA-Z0-9._-]*)'
ATTR_RE = re.compile(ATTR_RE_STR)

# Regular expression matching the left-hand side of an attribute definition
ATTR_DEF_RE_STR = r'\s*' + ATTR_RE_STR + r'\s*='
ATTR_DEF_RE = re.compile(ATTR_DEF_RE_STR)


class VariableNamer:
    """Hand out and track FileCheck substitution names for SSA values.

    Names are recorded in a stack of scopes (innermost last), each mapping a
    source SSA name to its FileCheck variable name. Callers can ask for the
    next few names to be recorded in the parent scope instead of the innermost.
    """

    def __init__(self, variable_names):
        # Stack of {ssa_name: filecheck_name} dictionaries, innermost last.
        self.scopes = []
        # Suffix of the next automatically generated "VAL_<n>" name.
        self.name_counter = 0
        # How many upcoming names must still be recorded in the parent scope.
        self.generate_in_parent_scope_left = 0
        # User-supplied names (upper-cased), consumed in order; an empty entry
        # means "use the automatic name".
        self.variable_names = list(map(str.upper, variable_names.split(',')))
        # Every name handed out since the last clear_names() call.
        self.used_variable_names = set()

    def generate_in_parent_scope(self, n):
        """Record the next ``n`` generated names in the parent scope."""
        self.generate_in_parent_scope_left = n

    def generate_name(self, source_variable_name):
        """Create, record and return a substitution name for an SSA value.

        Raises:
          RuntimeError: if the chosen name was already handed out.
        """
        chosen = self.variable_names.pop(0) if self.variable_names else ''
        if not chosen:
            chosen = f"VAL_{self.name_counter}"
            self.name_counter += 1

        target_scope = len(self.scopes) - 1
        if self.generate_in_parent_scope_left > 0:
            self.generate_in_parent_scope_left -= 1
            target_scope -= 1
        assert target_scope >= 0

        if chosen in self.used_variable_names:
            raise RuntimeError(chosen + ': duplicate variable name')
        self.scopes[target_scope][source_variable_name] = chosen
        self.used_variable_names.add(chosen)
        return chosen

    def push_name_scope(self):
        """Open a new innermost scope."""
        self.scopes.append({})

    def pop_name_scope(self):
        """Close the innermost scope, dropping the names recorded in it."""
        self.scopes.pop()

    def num_scopes(self):
        """Return the current nesting depth (number of open scopes)."""
        return len(self.scopes)

    def clear_names(self):
        """Restart automatic numbering and forget which names were used.

        Mappings already stored in the scopes are left untouched.
        """
        self.name_counter = 0
        self.used_variable_names = set()


class AttributeNamer:
    """Generate and remember FileCheck substitution names for attribute aliases.

    Each alias definition (``#name = ...``) gets a FileCheck variable name so
    that later references to the alias can be matched against it.
    """

    def __init__(self, attribute_names):
        self.name_counter = 0
        # Requested names (upper-cased), consumed in definition order; an empty
        # entry falls back to an automatic "ATTR_<n>" name.
        self.attribute_names = [name.upper() for name in attribute_names.split(',')]
        # Source attribute name (e.g. "#map") -> generated FileCheck name.
        self.map = {}
        self.used_attribute_names = set()

    def generate_name(self, source_attribute_name):
        """Create, record and return the FileCheck name for an attribute alias.

        The returned name carries a leading '$' (the global-variable marker).

        Raises:
          RuntimeError: if the resulting name was already handed out.
        """
        attribute_name = self.attribute_names.pop(0) if self.attribute_names else ''
        if not attribute_name:
            attribute_name = "ATTR_" + str(self.name_counter)
            self.name_counter += 1

        # Prepend the global symbol marker.
        attribute_name = '$' + attribute_name

        if attribute_name in self.used_attribute_names:
            raise RuntimeError(attribute_name + ': duplicate attribute name')
        self.map[source_attribute_name] = attribute_name
        self.used_attribute_names.add(attribute_name)
        return attribute_name

    def get_name(self, source_attribute_name):
        """Return the saved FileCheck name for ``source_attribute_name``.

        If no name has been generated for the attribute yet, the placeholder
        ``'?'`` is returned (not the source attribute name).
        """
        return self.map.get(source_attribute_name, '?')


def get_num_ssa_results(input_line):
    """Return how many SSA results ``input_line`` assigns.

    The line is matched against SSA_RESULTS_RE (an assignment of the shape
    ``%0, %1, ... = ...``), and every '%' in the matched left-hand side counts
    as one result. Lines that are not such an assignment yield 0.
    """
    lhs = SSA_RESULTS_RE.match(input_line)
    if lhs is None:
        return 0
    return lhs.group().count('%')


def process_line(line_chunks, variable_namer):
    """Rewrite SSA values in a line that was split at each '%'.

    Args:
      line_chunks: the pieces that followed each '%' in the original line (the
        text before the first '%' is not part of this list).
      variable_namer: object exposing ``scopes`` (a sequence of dicts from SSA
        name to FileCheck name) and ``generate_name(ssa_name)``.

    Returns:
      The rewritten text, with trailing whitespace removed and terminated by a
      single newline. Known values become ``%[[NAME]]``, and first occurrences
      become the capturing form ``%[[NAME:.*]]``.
    """
    pieces = []
    for chunk in line_chunks:
        m = SSA_RE.match(chunk)
        if m is None:
            # This '%' does not start an SSA identifier (for example it sits in
            # a string literal such as "100%"). Keep it verbatim instead of
            # inventing a FileCheck variable for an empty name.
            pieces.append("%" + chunk)
            continue
        ssa_name = m.group(0)

        # Reuse the substitution name if any open scope already knows it.
        variable = None
        for scope in variable_namer.scopes:
            variable = scope.get(ssa_name)
            if variable is not None:
                break

        if variable is not None:
            pieces.append("%[[" + variable + "]]")
        else:
            # First occurrence: define a FileCheck variable that captures it.
            variable = variable_namer.generate_name(ssa_name)
            pieces.append("%[[" + variable + ":.*]]")

        # Keep whatever followed the identifier in this chunk.
        pieces.append(chunk[len(ssa_name):])

    return "".join(pieces).rstrip() + "\n"


def process_source_lines(source_lines, note, args):
    """Split the --source file into segments, dropping stale generated text.

    The source file does not have to be .mlir. Lines are handled as follows:

    * Any previous copy of ``note`` (which may span several lines) is removed,
      together with one blank line directly after it; the script writes such a
      blank line after the note.
    * FileCheck directives for ``args.check_prefix`` (``CHECK:``,
      ``CHECK-LABEL:``, ``CHECK-SAME:``, ``CHECK-NOT:``, ...) are removed so they
      can be regenerated. Lines that merely mention the prefix, such as a RUN
      line with ``--check-prefix=CHECK``, are kept.
    * A new segment starts at every line matching ``args.source_delim_regex``.

    Args:
      source_lines: source file lines without their trailing newlines.
      note: the autogenerated note text.
      args: parsed arguments; ``check_prefix`` and ``source_delim_regex`` are read.

    Returns:
      A list of segments, each a list of newline-terminated lines. The first
      segment holds everything before the first delimiter line.
    """
    source_lines = list(source_lines)
    source_split_re = re.compile(args.source_delim_regex)
    # A directive is the prefix (not part of a longer word), an optional
    # "-SUFFIX" (LABEL, SAME, NOT, COUNT-2, ...), an optional {LITERAL}, then ':'.
    directive_re = re.compile(
        r"(?<![\w-])" + re.escape(args.check_prefix) + r"(?:-[\w-]+)?(?:\{LITERAL\})?:"
    )
    note_lines = note.rstrip("\n").split("\n") if note.strip() else []
    num_note_lines = len(note_lines)

    source_segments = [[]]
    index = 0
    while index < len(source_lines):
        # Remove a previously generated note, plus the blank line after it.
        if note_lines and source_lines[index:index + num_note_lines] == note_lines:
            index += num_note_lines
            if index < len(source_lines) and not source_lines[index]:
                index += 1
            continue

        line = source_lines[index]
        index += 1

        # Remove previous check lines; they are regenerated.
        if directive_re.search(line):
            continue
        # Segment the file based on --source_delim_regex.
        if source_split_re.search(line):
            source_segments.append([])
        source_segments[-1].append(line + "\n")
    return source_segments


def process_attribute_definition(line, attribute_namer, output, check_prefix="CHECK"):
    """Emit a CHECK line for an attribute alias definition, if ``line`` is one.

    ``line`` is matched against ATTR_DEF_RE (``#name = ...``). On a match, a
    fresh substitution name is obtained from ``attribute_namer`` and the line
    ``// <check_prefix>: #[[$NAME:.+]] =<rest of definition>`` is written to
    ``output``. Lines that are not definitions produce no output.

    Args:
      line: one input line.
      attribute_namer: AttributeNamer that records the new name.
      output: any object with a ``write(str)`` method.
      check_prefix: FileCheck prefix for the emitted line. The default "CHECK"
        reproduces the historical output; pass the --check-prefix value so these
        lines agree with the rest of the generated checks.
    """
    m = ATTR_DEF_RE.match(line)
    if not m:
        return
    attribute_name = attribute_namer.generate_name(m.group(1))
    output.write(
        '// ' + check_prefix + ': #[[' + attribute_name + ':.+]] ='
        + line[len(m.group(0)):] + '\n'
    )


def process_attribute_references(line, attribute_namer):
    """Replace references to known attribute aliases with FileCheck variables.

    Every ``#name`` token for which ``attribute_namer`` has a substitution is
    rewritten to ``#[[$NAME]]``. Tokens without one are left untouched
    (``attribute_namer.get_name`` returns the placeholder ``'?'`` for them), for
    example an attribute printed inline rather than through a ``#name = ...``
    alias. Emitting ``#[[?]]`` would not be a valid FileCheck variable reference.
    """
    pieces = []
    # ATTR_RE has one capturing group, so split() alternates plain text with
    # the matched "#name" tokens.
    for component in ATTR_RE.split(line):
        m = ATTR_RE.match(component)
        if m is None:
            pieces.append(component)
            continue
        name = attribute_namer.get_name(m.group(1))
        if name == '?':
            pieces.append(component)
        else:
            pieces.append('#[[' + name + ']]' + component[len(m.group()):])
    return ''.join(pieces)


def preprocess_line(line):
    """Escape character sequences in ``line`` that FileCheck would misread.

    ``[[`` opens a FileCheck variable, so it is wrapped in a ``{{...}}`` regex
    block. A single ``[`` directly before ``%`` is escaped the same way, because
    the SSA identifier after it is later replaced by a FileCheck variable. That
    replacement would otherwise create the same kind of bracket sequence.
    The replacements are applied in order.
    """
    escaped = line
    for needle, replacement in (("[[", "{{\\[\\[}}"), ("[%", "{{\\[}}%")):
        escaped = escaped.replace(needle, replacement)
    return escaped


def main():
    """Command-line entry point: read MLIR, generate FileCheck lines, write them.

    All output is assembled in memory and written only after processing has
    finished. A failure part-way through therefore never leaves a truncated
    file behind; this matters for --inplace, which overwrites the --source file.
    """
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument(
        "--check-prefix", default="CHECK", help="Prefix to use from check file."
    )
    parser.add_argument(
        "-o", "--output", nargs="?", type=argparse.FileType("w"), default=None
    )
    parser.add_argument(
        "input", nargs="?", type=argparse.FileType("r"), default=sys.stdin
    )
    parser.add_argument(
        "--source",
        type=str,
        help="Print each CHECK chunk before each delimiter line in the source "
        "file, respectively. The delimiter lines are identified by "
        "--source_delim_regex.",
    )
    parser.add_argument("--source_delim_regex", type=str, default="func @")
    parser.add_argument(
        "--starts_from_scope",
        type=int,
        default=1,
        help="Omit the top specified level of content. For example, by default "
        'it omits "module {"',
    )
    parser.add_argument("-i", "--inplace", action="store_true", default=False)
    parser.add_argument(
        "--variable_names",
        type=str,
        default='',
        help="Names to be used in FileCheck regular expression to represent SSA "
        "variables in the order they are encountered. Separate names with commas, "
        "and leave empty entries for default names (e.g.: 'DIM,,SUM,RESULT')")
    parser.add_argument(
        "--attribute_names",
        type=str,
        default='',
        help="Names to be used in FileCheck regular expression to represent "
        "attributes in the order they are defined. Separate names with commas, "
        "and leave empty entries for default names (e.g.: 'MAP0,,,MAP1')")

    args = parser.parse_args()

    # --inplace rewrites the --source file, so it needs one and cannot be
    # combined with an explicit --output.
    if args.inplace:
        if not args.source:
            parser.error("--inplace (-i) requires --source")
        if args.output is not None:
            parser.error("--inplace (-i) cannot be combined with --output (-o)")

    # Read the input file.
    input_lines = [line.rstrip() for line in args.input]
    args.input.close()

    # Generate a note used for the generated check file.
    script_name = os.path.basename(__file__)
    autogenerated_note = ADVERT_BEGIN + "utils/" + script_name + "\n" + ADVERT_END

    source_segments = None
    if args.source:
        with open(args.source, "r") as source_file:
            source_lines = [line.rstrip() for line in source_file]
        source_segments = process_source_lines(source_lines, autogenerated_note, args)

    output_segments = [[]]

    # CHECK lines for attribute alias definitions are produced while scanning;
    # buffer them so that they still come first in the final output.
    attribute_output = io.StringIO()

    # Namers
    variable_namer = VariableNamer(args.variable_names)
    attribute_namer = AttributeNamer(args.attribute_names)

    # Process lines
    for input_line in input_lines:
        if not input_line:
            continue

        # Check if this is an attribute definition and process it.
        process_attribute_definition(input_line, attribute_namer, attribute_output)

        # Lines with blocks begin with a ^. These lines have a trailing comment
        # that needs to be stripped.
        lstripped_input_line = input_line.lstrip()
        is_block = lstripped_input_line[0] == "^"
        if is_block:
            input_line = input_line.rsplit("//", 1)[0].rstrip()

        cur_level = variable_namer.num_scopes()

        # If the line starts with a '}', pop the last name scope.
        if lstripped_input_line[0] == "}":
            variable_namer.pop_name_scope()
            cur_level = variable_namer.num_scopes()

        # If the line ends with a '{', push a new name scope.
        if input_line[-1] == "{":
            variable_namer.push_name_scope()
            if cur_level == args.starts_from_scope:
                output_segments.append([])

            # Result SSA values must still be pushed to the parent scope.
            num_ssa_results = get_num_ssa_results(input_line)
            variable_namer.generate_in_parent_scope(num_ssa_results)

        # Omit lines at the near top level e.g. "module {".
        if cur_level < args.starts_from_scope:
            continue

        if len(output_segments[-1]) == 0:
            variable_namer.clear_names()

        # Preprocess the input to remove any sequences that may be problematic
        # with FileCheck.
        input_line = preprocess_line(input_line)

        # Process uses of attributes in this line.
        input_line = process_attribute_references(input_line, attribute_namer)

        # Split the line at each SSA value name.
        ssa_split = input_line.split("%")

        # If this is a top-level operation use 'CHECK-LABEL', otherwise 'CHECK:'.
        if len(output_segments[-1]) != 0 or not ssa_split[0]:
            output_line = "// " + args.check_prefix + ": "
            # Pad to align with the 'LABEL' statements.
            output_line += " " * len("-LABEL")

            # Output the first line chunk that does not contain an SSA name.
            output_line += ssa_split[0]

            # Process the rest of the input line.
            output_line += process_line(ssa_split[1:], variable_namer)

        else:
            # Output the first line chunk that does not contain an SSA name for
            # the label.
            output_line = "// " + args.check_prefix + "-LABEL: " + ssa_split[0] + "\n"

            # Process the rest of the input line on separate check lines.
            for argument in ssa_split[1:]:
                output_line += "// " + args.check_prefix + "-SAME:  "

                # Pad to align with the original position in the line.
                output_line += " " * len(ssa_split[0])

                # Process the rest of the line.
                output_line += process_line([argument], variable_namer)

        # Append the output line.
        output_segments[-1].append(output_line)

    # Each generated CHECK segment is paired with one source segment; a
    # mismatch means the delimiter regex or scope level does not fit the input.
    if source_segments is not None and len(output_segments) != len(source_segments):
        sys.exit(
            "error: generated %d CHECK segments but found %d segments in %s; "
            "check --source_delim_regex and --starts_from_scope"
            % (len(output_segments), len(source_segments), args.source)
        )

    # Assemble the complete output: attribute checks, note, then segments.
    chunks = [attribute_output.getvalue(), autogenerated_note + "\n"]
    if source_segments is not None:
        for check_segment, source_segment in zip(output_segments, source_segments):
            chunks.extend(check_segment)
            chunks.extend(source_segment)
    else:
        for segment in output_segments:
            chunks.append("\n")
            chunks.extend(segment)
        chunks.append("\n")
    text = "".join(chunks)

    # Only now open (and thereby truncate) the in-place destination.
    if args.inplace:
        with open(args.source, "w") as output:
            output.write(text)
    else:
        output = sys.stdout if args.output is None else args.output
        output.write(text)
        output.close()


if __name__ == "__main__":
    # Run the command-line tool only when executed as a script, not on import.
    main()