xref: /llvm-project/mlir/utils/generate-test-checks.py (revision 4e7c0a37c9c92baa655d244f5bfde91d52b138d0)
1#!/usr/bin/env python3
2"""A script to generate FileCheck statements for mlir unit tests.
3
4This script is a utility to add FileCheck patterns to an mlir file.
5
6NOTE: The input .mlir is expected to be the output from the parser, not a
7stripped down variant.
8
9Example usage:
10$ generate-test-checks.py foo.mlir
11$ mlir-opt foo.mlir -transformation | generate-test-checks.py
12$ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir
13$ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir -i
14$ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir -i --source_delim_regex='gpu.func @'
15
16The script will heuristically generate CHECK/CHECK-LABEL commands for each line
17within the file. By default this script will also try to insert string
18substitution blocks for all SSA value names. If --source file is specified, the
19script will attempt to insert the generated CHECKs to the source file by looking
20for line positions matched by --source_delim_regex.
21
22The script is designed to make adding checks to a test case fast, it is *not*
23designed to be authoritative about what constitutes a good test!
24"""
25
26# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
27# See https://llvm.org/LICENSE.txt for license information.
28# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
29
30import argparse
31import os  # Used to advertise this file's name ("autogenerated_note").
32import re
33import sys
34
# First part of the banner written at the top of the generated output; the
# path of the generating script is appended to it by the caller.
ADVERT_BEGIN = "// NOTE: Assertions have been autogenerated by "

# Remainder of the banner: a blank line followed by a reminder that the
# generated checks are only a starting point.
ADVERT_END = (
    "\n"
    "// The script is designed to make adding checks to\n"
    "// a test case fast, it is *not* designed to be authoritative\n"
    "// about what constitutes a good test! The CHECK should be\n"
    "// minimized and named to reflect the test intent.\n"
)
42
43
# Pattern for the text that follows the '%' of an SSA identifier. It is meant
# to be applied with `match`, i.e. anchored at the start of the text.
SSA_RE_STR = '|'.join((
    '[0-9]+',  # Purely numeric names, e.g. the "0" of "%0".
    '[a-zA-Z$._-][a-zA-Z0-9$._-]*',  # Named values, e.g. the "arg0" of "%arg0".
))
SSA_RE = re.compile(SSA_RE_STR)
47
48
class SSAVariableNamer:
  """Hands out FileCheck substitution names ('VAL_<n>') for SSA value names.

  Names are recorded in a stack of scopes so that a mapping disappears again
  once the scope it was created in is popped. The numeric suffix comes from a
  single counter shared by all scopes; it only restarts when `clear_counter`
  is called.
  """

  def __init__(self):
    # Stack of {ssa_name: substitution_name} dicts; the last entry is the
    # innermost scope.
    self.scopes = []
    # Suffix used for the next generated substitution name.
    self.name_counter = 0

  def generate_name(self, ssa_name):
    """Create a substitution name for `ssa_name` in the innermost scope.

    At least one scope must have been pushed beforehand; otherwise indexing
    the empty scope stack raises IndexError.
    """
    index = self.name_counter
    self.name_counter = index + 1
    substitution = 'VAL_{}'.format(index)
    innermost = self.scopes[-1]
    innermost[ssa_name] = substitution
    return substitution

  def push_name_scope(self):
    """Open a new, empty innermost scope."""
    self.scopes.append(dict())

  def pop_name_scope(self):
    """Discard the innermost scope together with the names recorded in it."""
    self.scopes.pop()

  def num_scopes(self):
    """Return the current nesting depth, i.e. the number of open scopes."""
    return len(self.scopes)

  def clear_counter(self):
    """Restart the numbering of generated names at zero."""
    self.name_counter = 0
79
80
def process_line(line_chunks, variable_namer):
  """Rewrite the '%'-separated chunks of a line as FileCheck pattern text.

  Args:
    line_chunks: The pieces of an input line that followed each '%' when the
      line was split on that character (the text before the first '%' is
      handled by the caller).
    variable_namer: The SSAVariableNamer holding the known name scopes.

  Returns:
    The rewritten text with trailing whitespace removed and a newline
    appended. A chunk that starts with an SSA name becomes a use
    ('%[[VAL_n]]') if the name is already known in any scope, or a definition
    ('%[[VAL_n:.*]]') otherwise, followed by the rest of the chunk.
  """
  pieces = []

  for chunk in line_chunks:
    m = SSA_RE.match(chunk)
    if m is None:
      # The '%' was not followed by an SSA name (for example a percent sign
      # inside a string attribute). Put the character back and keep the text
      # as it is instead of failing on the missing match.
      pieces.append('%' + chunk)
      continue
    ssa_name = m.group(0)

    # Check if an existing variable exists for this name.
    variable = None
    for scope in variable_namer.scopes:
      variable = scope.get(ssa_name)
      if variable is not None:
        break

    if variable is not None:
      # Known name: refer to the existing substitution block.
      pieces.append('%[[' + variable + ']]')
    else:
      # Unknown name: define a new substitution block that captures it.
      variable = variable_namer.generate_name(ssa_name)
      pieces.append('%[[' + variable + ':.*]]')

    # Append the part of the chunk that follows the SSA name.
    pieces.append(chunk[len(ssa_name):])

  return ''.join(pieces).rstrip() + '\n'
109
110
def process_source_lines(source_lines, note, args):
  """Strip stale generated content from the source and split it into segments.

  The source file doesn't have to be .mlir.

  Args:
    source_lines: The lines of the source file without trailing whitespace.
    note: The autogenerated banner text; it may span several lines.
    args: Parsed arguments providing `source_delim_regex` and `check_prefix`.

  Returns:
    A list of segments, each a list of newline-terminated lines. A new
    segment starts at every line that matches `args.source_delim_regex`, so
    the first segment holds everything before the first delimiter line.
  """
  source_split_re = re.compile(args.source_delim_regex)

  # The source is processed line by line while the note spans several lines,
  # so a previous note has to be recognised one line at a time. Blank lines
  # of the note are left out of the set so that unrelated blank lines of the
  # source are not dropped.
  note_lines = {
      note_line.rstrip() for note_line in note.splitlines()
      if note_line.strip()
  }

  source_segments = [[]]
  for line in source_lines:
    # Remove previous note.
    if line == note or line in note_lines:
      continue
    # Remove previous CHECK lines.
    if args.check_prefix in line:
      continue
    # Segment the file based on --source_delim_regex.
    if source_split_re.search(line):
      source_segments.append([])

    source_segments[-1].append(line + '\n')
  return source_segments
129
130
def preprocess_line(line):
  """Escape the bracket sequences of `line` that would confuse FileCheck.

  Two rewrites are applied, in this order:
    * '[[' opens a variable block in FileCheck, so every occurrence is
      replaced by a regex block matching two literal brackets.
    * '[%' would turn into '[[' once the SSA name after the '%' has been
      replaced by a variable block, so the bracket is escaped the same way.

  Returns the rewritten line; the input string is not modified.
  """
  replacements = (
      ('[[', '{{\\[\\[}}'),
      ('[%', '{{\\[}}%'),
  )
  escaped = line
  for needle, substitute in replacements:
    escaped = escaped.replace(needle, substitute)
  return escaped
144
145
def _generate_check_segments(input_lines, check_prefix, starts_from_scope):
  """Turn the input lines into segments of FileCheck lines.

  Args:
    input_lines: The input lines without trailing whitespace.
    check_prefix: The FileCheck prefix to emit, e.g. 'CHECK'.
    starts_from_scope: Lines nested less deeply than this are omitted.

  Returns:
    A list of segments, each a list of newline-terminated check lines. A new
    segment is started by every line that ends with '{' at nesting level
    `starts_from_scope`; the first segment holds whatever came before.
  """
  output_segments = [[]]
  # A map containing data used for naming SSA value names.
  variable_namer = SSAVariableNamer()
  for input_line in input_lines:
    if not input_line:
      continue
    lstripped_input_line = input_line.lstrip()

    # Lines with blocks begin with a ^. These lines have a trailing comment
    # that needs to be stripped.
    is_block = lstripped_input_line[0] == '^'
    if is_block:
      input_line = input_line.rsplit('//', 1)[0].rstrip()

    cur_level = variable_namer.num_scopes()

    # If the line starts with a '}', pop the last name scope.
    if lstripped_input_line[0] == '}':
      variable_namer.pop_name_scope()
      cur_level = variable_namer.num_scopes()

    # If the line ends with a '{', push a new name scope.
    if input_line[-1] == '{':
      variable_namer.push_name_scope()
      if cur_level == starts_from_scope:
        output_segments.append([])

    # Omit lines at the near top level e.g. "module {".
    if cur_level < starts_from_scope:
      continue

    # Restart the variable numbering at the beginning of every segment.
    if len(output_segments[-1]) == 0:
      variable_namer.clear_counter()

    # Preprocess the input to remove any sequences that may be problematic with
    # FileCheck.
    input_line = preprocess_line(input_line)

    # Split the line at the each SSA value name.
    ssa_split = input_line.split('%')

    # If this is a top-level operation use 'CHECK-LABEL', otherwise 'CHECK:'.
    if len(output_segments[-1]) != 0 or not ssa_split[0]:
      output_line = '// ' + check_prefix + ': '
      # Pad to align with the 'LABEL' statements.
      output_line += (' ' * len('-LABEL'))

      # Output the first line chunk that does not contain an SSA name.
      output_line += ssa_split[0]

      # Process the rest of the input line.
      output_line += process_line(ssa_split[1:], variable_namer)

    else:
      # Output the first line chunk that does not contain an SSA name for the
      # label.
      output_line = '// ' + check_prefix + '-LABEL: ' + ssa_split[0] + '\n'

      # Process the rest of the input line on separate check lines.
      for argument in ssa_split[1:]:
        output_line += '// ' + check_prefix + '-SAME:  '

        # Pad to align with the original position in the line.
        output_line += ' ' * len(ssa_split[0])

        # Process the rest of the line.
        output_line += process_line([argument], variable_namer)

    # Append the output line.
    output_segments[-1].append(output_line)

  return output_segments


def main():
  """Parse the command line, generate the CHECK lines and write the result.

  The complete output is assembled in memory before any destination is
  opened for writing, so that a failure while generating the checks cannot
  leave a truncated file behind when --inplace is used.
  """
  parser = argparse.ArgumentParser(
      description=__doc__, formatter_class=argparse.RawTextHelpFormatter)
  parser.add_argument(
      '--check-prefix', default='CHECK', help='Prefix to use from check file.')
  parser.add_argument(
      '-o',
      '--output',
      nargs='?',
      type=argparse.FileType('w'),
      default=None)
  parser.add_argument(
      'input',
      nargs='?',
      type=argparse.FileType('r'),
      default=sys.stdin)
  parser.add_argument(
      '--source', type=str,
      help='Print each CHECK chunk before each delimeter line in the source '
           'file, respectively. The delimeter lines are identified by '
           '--source_delim_regex.')
  parser.add_argument('--source_delim_regex', type=str, default='func @')
  parser.add_argument(
      '--starts_from_scope', type=int, default=1,
      help='Omit the top specified level of content. For example, by default '
           'it omits "module {"')
  parser.add_argument('-i', '--inplace', action='store_true', default=False)

  args = parser.parse_args()

  # Reject inconsistent options up front with a usage error rather than an
  # assertion, which would be skipped when Python runs with -O.
  if args.inplace:
    if not args.source:
      parser.error('-i/--inplace requires --source')
    if args.output is not None:
      parser.error('-i/--inplace cannot be combined with -o/--output')

  # Read the whole input.
  input_lines = [l.rstrip() for l in args.input]
  args.input.close()

  # Generate a note used for the generated check file.
  script_name = os.path.basename(__file__)
  autogenerated_note = (ADVERT_BEGIN + 'utils/' + script_name + "\n" + ADVERT_END)

  source_segments = None
  if args.source:
    with open(args.source, 'r') as source_file:
      source_lines = [l.rstrip() for l in source_file]
    source_segments = process_source_lines(
        source_lines, autogenerated_note, args)

  output_segments = _generate_check_segments(
      input_lines, args.check_prefix, args.starts_from_scope)

  # Assemble the complete output text.
  chunks = [autogenerated_note + '\n']
  if source_segments:
    # Every check segment is written in front of the source segment with the
    # same index, so the two lists have to line up exactly.
    if len(output_segments) != len(source_segments):
      sys.exit(
          'error: generated {} CHECK segment(s) but found {} source '
          'segment(s); check --source_delim_regex and '
          '--starts_from_scope'.format(
              len(output_segments), len(source_segments)))
    for check_segment, source_segment in zip(output_segments, source_segments):
      chunks.extend(check_segment)
      chunks.extend(source_segment)
  else:
    for segment in output_segments:
      chunks.append('\n')
      chunks.extend(segment)
    chunks.append('\n')
  text = ''.join(chunks)

  # Write the output. The source file is only opened for writing at this
  # point, once everything that could fail has already succeeded.
  if args.inplace:
    with open(args.source, 'w') as output:
      output.write(text)
  elif args.output is None:
    sys.stdout.write(text)
    sys.stdout.flush()
  else:
    args.output.write(text)
    args.output.close()
288
289
# Run the generator only when this file is executed as a script; importing it
# as a module just defines the helpers above.
if __name__ == '__main__':
  main()
292