#!/usr/bin/env python3
"""A script to generate FileCheck statements for mlir unit tests.

This script is a utility to add FileCheck patterns to an mlir file.

NOTE: The input .mlir is expected to be the output from the parser, not a
stripped down variant.

Example usage:
$ generate-test-checks.py foo.mlir
$ mlir-opt foo.mlir -transformation | generate-test-checks.py
$ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir
$ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir -i
$ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir -i --source_delim_regex='gpu.func @'

The script will heuristically generate CHECK/CHECK-LABEL commands for each line
within the file. By default this script will also try to insert string
substitution blocks for all SSA value names. If --source file is specified, the
script will attempt to insert the generated CHECKs to the source file by looking
for line positions matched by --source_delim_regex.

The script is designed to make adding checks to a test case fast, it is *not*
designed to be authoritative about what constitutes a good test!
"""

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import argparse
import os  # Used to advertise this file's name ("autogenerated_note").
import re
import sys

# First line of the banner written at the top of generated output; the name of
# this script is appended to it when the banner is assembled.
ADVERT_BEGIN = '// NOTE: Assertions have been autogenerated by '

# Remainder of the banner: a blank line followed by a reminder that the
# generated checks are only a starting point.
ADVERT_END = (
    '\n'
    '// The script is designed to make adding checks to\n'
    '// a test case fast, it is *not* designed to be authoritative\n'
    '// about what constitutes a good test! The CHECK should be\n'
    '// minimized and named to reflect the test intent.\n'
)


# Regex command to match an SSA identifier.
# The pattern is applied to the text that follows a '%': a name is either
# purely numeric, or starts with a letter or one of '$', '.', '_', '-' and
# continues with alphanumerics or those same punctuation characters.
SSA_RE_STR = '[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*'
SSA_RE = re.compile(SSA_RE_STR)


# Class used to generate and manage string substitution blocks for SSA value
# names.
class SSAVariableNamer:
    """Hands out FileCheck variable names for SSA values.

    Names are generated sequentially as ``VAL_<counter>`` and recorded in the
    innermost entry of ``scopes``, a stack of ``{ssa_name: variable}``
    dictionaries.
    """

    def __init__(self):
        # Stack of name scopes; the last element is the innermost scope.
        self.scopes = []
        # Numeric suffix used for the next generated variable name.
        self.name_counter = 0

    def generate_name(self, ssa_name):
        """Create a new substitution name for ``ssa_name`` and record it.

        The mapping is stored in the innermost scope, so at least one scope
        must have been pushed; otherwise an ``IndexError`` is raised.

        Returns:
          The generated variable name, e.g. ``'VAL_0'``.
        """
        variable = 'VAL_' + str(self.name_counter)
        self.name_counter += 1
        self.scopes[-1][ssa_name] = variable
        return variable

    def push_name_scope(self):
        """Push a new, empty variable name scope."""
        self.scopes.append({})

    def pop_name_scope(self):
        """Pop the innermost variable name scope."""
        self.scopes.pop()

    def num_scopes(self):
        """Return the level of nesting (number of pushed scopes)."""
        return len(self.scopes)

    def clear_counter(self):
        """Reset the counter so that the next generated name is ``VAL_0``."""
        self.name_counter = 0


def process_line(line_chunks, variable_namer):
    """Rewrite the '%'-separated pieces of a line using FileCheck variables.

    Args:
      line_chunks: The pieces of an input line that followed each '%', i.e.
        ``line.split('%')[1:]``. A chunk normally starts with an SSA value
        name.
      variable_namer: The ``SSAVariableNamer`` used to look up names that were
        already seen and to generate names for new ones.

    Returns:
      The rewritten text with trailing whitespace removed and a single
      newline appended. A known SSA name becomes ``%[[VAR]]``, a new one
      becomes ``%[[VAR:.*]]``.
    """
    output_parts = []

    for chunk in line_chunks:
        m = SSA_RE.match(chunk)
        if m is None:
            # The '%' was not followed by an SSA name (for example a percent
            # sign inside a string attribute, or a trailing '%'). Put the '%'
            # removed by the split back and keep the text literally instead
            # of failing on a missing match.
            output_parts.append('%' + chunk)
            continue
        ssa_name = m.group(0)

        # Check if an existing variable exists for this name.
        variable = None
        for scope in variable_namer.scopes:
            variable = scope.get(ssa_name)
            if variable is not None:
                break

        if variable is not None:
            # Re-use the existing variable.
            output_parts.append('%[[' + variable + ']]')
        else:
            # First occurrence: define a new variable matching any text.
            variable = variable_namer.generate_name(ssa_name)
            output_parts.append('%[[' + variable + ':.*]]')

        # Append the rest of the chunk that followed the SSA name.
        output_parts.append(chunk[len(ssa_name):])

    return ''.join(output_parts).rstrip() + '\n'
def _strip_previous_note(source_lines, note):
    """Return `source_lines` without earlier copies of the autogenerated note.

    The note may span several lines, so it is matched as a contiguous run of
    lines. The single blank line that main() writes right after the note is
    dropped together with it, which keeps repeated runs from accumulating
    blank lines.
    """
    source_lines = list(source_lines)
    # Source lines arrive right-stripped, so compare against stripped lines.
    note_lines = [note_line.rstrip() for note_line in note.splitlines()]
    if not note_lines:
        return source_lines

    kept_lines = []
    note_len = len(note_lines)
    index = 0
    while index < len(source_lines):
        if source_lines[index:index + note_len] == note_lines:
            index += note_len
            if index < len(source_lines) and not source_lines[index]:
                index += 1
            continue
        kept_lines.append(source_lines[index])
        index += 1
    return kept_lines


# Process the source file lines. The source file doesn't have to be .mlir.
def process_source_lines(source_lines, note, args):
    """Split the source file into segments at each delimiter line.

    Args:
      source_lines: Lines of the source file without trailing whitespace.
      note: The autogenerated note (possibly multi-line); earlier copies of it
        are removed from the source.
      args: Parsed arguments; `source_delim_regex` and `check_prefix` are
        read.

    Returns:
      A list of segments, each a list of newline-terminated lines. The first
      segment holds everything before the first delimiter line; every later
      segment starts with a line matching `args.source_delim_regex`. Lines
      containing `args.check_prefix` are dropped.
    """
    source_split_re = re.compile(args.source_delim_regex)

    source_segments = [[]]
    for line in _strip_previous_note(source_lines, note):
        # Remove previous CHECK lines.
        if args.check_prefix in line:
            continue
        # Segment the file based on --source_delim_regex.
        if source_split_re.search(line):
            source_segments.append([])

        source_segments[-1].append(line + '\n')
    return source_segments


# Pre-process a line of input to remove any character sequences that will be
# problematic with FileCheck.
def preprocess_line(line):
    """Escape bracket sequences that FileCheck would read as variable syntax.

    Returns:
      `line` with every '[[' and every '[' directly followed by '%' wrapped
      in a FileCheck regex block.
    """
    # Replace any double brackets, '[[' with escaped replacements. '[['
    # corresponds to variable names in FileCheck.
    escaped = line.replace('[[', '{{\\[\\[}}')

    # A single bracket followed by an SSA identifier would turn into '[[' once
    # the identifier is replaced by a variable, so escape it as well.
    return escaped.replace('[%', '{{\\[}}%')


def _generate_check_segments(input_lines, args):
    """Translate parser output into segments of CHECK lines.

    Args:
      input_lines: Lines of the printed IR.
      args: Parsed arguments; `check_prefix` and `starts_from_scope` are read.

    Returns:
      A list of segments, each a list of newline-terminated check strings. A
      new segment is started whenever a line ending in '{' is found at nesting
      level `args.starts_from_scope`.
    """
    output_segments = [[]]
    # Tracks the substitution names used for SSA values in each scope.
    variable_namer = SSAVariableNamer()
    for input_line in input_lines:
        input_line = input_line.rstrip()
        if not input_line:
            continue
        lstripped_input_line = input_line.lstrip()

        # Lines with blocks begin with a ^. These lines have a trailing
        # comment that needs to be stripped.
        if lstripped_input_line.startswith('^'):
            input_line = input_line.rsplit('//', 1)[0].rstrip()

        cur_level = variable_namer.num_scopes()

        # If the line starts with a '}', pop the last name scope.
        if lstripped_input_line.startswith('}'):
            variable_namer.pop_name_scope()
            cur_level = variable_namer.num_scopes()

        # If the line ends with a '{', push a new name scope.
        if input_line.endswith('{'):
            variable_namer.push_name_scope()
            if cur_level == args.starts_from_scope:
                output_segments.append([])

        # Omit lines at the near top level e.g. "module {".
        if cur_level < args.starts_from_scope:
            continue

        # Variable numbering restarts with every segment.
        if not output_segments[-1]:
            variable_namer.clear_counter()

        # Preprocess the input to remove any sequences that may be
        # problematic with FileCheck.
        input_line = preprocess_line(input_line)

        # Split the line at each SSA value name.
        ssa_split = input_line.split('%')

        # The first line of a segment becomes a 'CHECK-LABEL' (unless it
        # starts with an SSA value); every other line is a plain 'CHECK'.
        if output_segments[-1] or not ssa_split[0]:
            output_line = '// ' + args.check_prefix + ': '
            # Pad to align with the 'LABEL' statements.
            output_line += ' ' * len('-LABEL')

            # Output the first line chunk that does not contain an SSA name.
            output_line += ssa_split[0]

            # Process the rest of the input line.
            output_line += process_line(ssa_split[1:], variable_namer)
        else:
            # Output the first line chunk that does not contain an SSA name
            # for the label.
            output_line = ('// ' + args.check_prefix + '-LABEL: ' +
                           ssa_split[0] + '\n')

            # Process the rest of the input line on separate check lines.
            for argument in ssa_split[1:]:
                output_line += '// ' + args.check_prefix + '-SAME: '

                # Pad to align with the original position in the line.
                output_line += ' ' * len(ssa_split[0])

                # Process the rest of the line.
                output_line += process_line([argument], variable_namer)

        output_segments[-1].append(output_line)

    return output_segments


def main():
    """Parse the command line, generate the checks and write the result.

    All checks are generated and validated before the destination is opened,
    so a failure in --inplace mode leaves the source file untouched. Invalid
    option combinations and a segment count mismatch end the script with an
    error message.
    """
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument(
        '--check-prefix', default='CHECK', help='Prefix to use from check file.')
    parser.add_argument(
        '-o',
        '--output',
        nargs='?',
        type=argparse.FileType('w'),
        default=None)
    parser.add_argument(
        'input',
        nargs='?',
        type=argparse.FileType('r'),
        default=sys.stdin)
    parser.add_argument(
        '--source', type=str,
        help='Print each CHECK chunk before each delimiter line in the source '
        'file, respectively. The delimiter lines are identified by '
        '--source_delim_regex.')
    parser.add_argument('--source_delim_regex', type=str, default='func @')
    parser.add_argument(
        '--starts_from_scope', type=int, default=1,
        help='Omit the top specified level of content. For example, by default '
        'it omits "module {"')
    parser.add_argument('-i', '--inplace', action='store_true', default=False)

    args = parser.parse_args()

    if args.inplace:
        if not args.source:
            parser.error('--inplace requires --source')
        if args.output is not None:
            parser.error('--inplace cannot be combined with --output')

    # Read the given input file.
    input_lines = [l.rstrip() for l in args.input]
    args.input.close()

    # Generate a note used for the generated check file.
    script_name = os.path.basename(__file__)
    autogenerated_note = (
        ADVERT_BEGIN + 'utils/' + script_name + "\n" + ADVERT_END)

    source_segments = None
    if args.source:
        with open(args.source, 'r') as source_file:
            source_lines = [l.rstrip() for l in source_file]
        source_segments = process_source_lines(
            source_lines, autogenerated_note, args)

    output_segments = _generate_check_segments(input_lines, args)

    if source_segments and len(output_segments) != len(source_segments):
        sys.exit(
            'error: generated %d check segment(s) but --source has %d '
            'segment(s); check --source_delim_regex and --starts_from_scope'
            % (len(output_segments), len(source_segments)))

    # Only open (and thereby truncate) the destination once everything that
    # can fail has succeeded.
    if args.inplace:
        output = open(args.source, 'w')
    elif args.output is None:
        output = sys.stdout
    else:
        output = args.output

    try:
        output.write(autogenerated_note + '\n')

        # Write the output.
        if source_segments:
            for check_segment, source_segment in zip(output_segments,
                                                     source_segments):
                for line in check_segment:
                    output.write(line)
                for line in source_segment:
                    output.write(line)
        else:
            for segment in output_segments:
                output.write('\n')
                for output_line in segment:
                    output.write(output_line)
            output.write('\n')
    finally:
        # Leave the interpreter's stdout alone; close only real files.
        if output is not sys.stdout:
            output.close()


if __name__ == '__main__':
    main()