1#!/usr/bin/env python3 2"""A script to generate FileCheck statements for mlir unit tests. 3 4This script is a utility to add FileCheck patterns to an mlir file. 5 6NOTE: The input .mlir is expected to be the output from the parser, not a 7stripped down variant. 8 9Example usage: 10$ generate-test-checks.py foo.mlir 11$ mlir-opt foo.mlir -transformation | generate-test-checks.py 12$ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir 13$ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir -i 14$ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir -i --source_delim_regex='gpu.func @' 15 16The script will heuristically generate CHECK/CHECK-LABEL commands for each line 17within the file. By default this script will also try to insert string 18substitution blocks for all SSA value names. If --source file is specified, the 19script will attempt to insert the generated CHECKs to the source file by looking 20for line positions matched by --source_delim_regex. 21 22The script is designed to make adding checks to a test case fast, it is *not* 23designed to be authoritative about what constitutes a good test! 24""" 25 26# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 27# See https://llvm.org/LICENSE.txt for license information. 28# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 29 30import argparse 31import os # Used to advertise this file's name ("autogenerated_note"). 32import re 33import sys 34 35ADVERT_BEGIN = "// NOTE: Assertions have been autogenerated by " 36ADVERT_END = """ 37// The script is designed to make adding checks to 38// a test case fast, it is *not* designed to be authoritative 39// about what constitutes a good test! The CHECK should be 40// minimized and named to reflect the test intent. 41""" 42 43 44# Regex command to match an SSA identifier. 
SSA_RE_STR = "[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*"
SSA_RE = re.compile(SSA_RE_STR)

# One result on the left-hand side of an assignment, e.g. "%0", "%arg" or the
# packed multi-result form "%res:2". SSA_RE_STR is an alternation, so it must be
# wrapped in a group for the leading '%' to apply to both of its branches.
_SSA_RESULT_STR = r"%(?:" + SSA_RE_STR + r")(?::[0-9]+)?"

# Regex matching the left-hand side of an assignment, e.g. "%0, %x =".
SSA_RESULTS_STR = (
    r"\s*(" + _SSA_RESULT_STR + r")(\s*,\s*(" + _SSA_RESULT_STR + r"))*\s*="
)
SSA_RESULTS_RE = re.compile(SSA_RESULTS_STR)

# Regex matching attributes
ATTR_RE_STR = r"(#[a-zA-Z._-][a-zA-Z0-9._-]*)"
ATTR_RE = re.compile(ATTR_RE_STR)

# Regex matching the left-hand side of an attribute definition
ATTR_DEF_RE_STR = r"\s*" + ATTR_RE_STR + r"\s*="
ATTR_DEF_RE = re.compile(ATTR_DEF_RE_STR)


# Class used to generate and manage string substitution blocks for SSA value
# names.
class VariableNamer:
    def __init__(self, variable_names):
        # Stack of {SSA name -> FileCheck variable name} dicts, one per nesting
        # level that has been pushed.
        self.scopes = []
        self.name_counter = 0

        # Number of variable names to still generate in parent scope
        self.generate_in_parent_scope_left = 0

        # User-requested names (upper-cased), consumed front to back. An empty
        # entry means "use the default VAL_<n> name".
        self.variable_names = [name.upper() for name in variable_names.split(",")]
        self.used_variable_names = set()

    def generate_in_parent_scope(self, n):
        """Generate the following ``n`` variable names in the parent scope."""
        self.generate_in_parent_scope_left = n

    def generate_name(self, source_variable_name):
        """Create and record a substitution name for an SSA value name.

        Takes the next user-supplied name, or falls back to ``VAL_<counter>``.

        Raises:
            RuntimeError: if the chosen name was already used since the last
                ``clear_names()``.
        """
        variable_name = self.variable_names.pop(0) if self.variable_names else ""
        if variable_name == "":
            variable_name = "VAL_" + str(self.name_counter)
            self.name_counter += 1

        # Scope where variable name is saved. Results of an op that opens a
        # region must stay visible after that region is closed.
        scope = len(self.scopes) - 1
        if self.generate_in_parent_scope_left > 0:
            self.generate_in_parent_scope_left -= 1
            scope = len(self.scopes) - 2
        assert scope >= 0

        if variable_name in self.used_variable_names:
            raise RuntimeError(variable_name + ": duplicate variable name")
        self.scopes[scope][source_variable_name] = variable_name
        self.used_variable_names.add(variable_name)

        return variable_name

    def push_name_scope(self):
        """Push a new variable name scope."""
        self.scopes.append({})

    def pop_name_scope(self):
        """Pop the last variable name scope."""
        self.scopes.pop()

    def num_scopes(self):
        """Return the level of nesting (number of pushed scopes)."""
        return len(self.scopes)

    def clear_names(self):
        """Reset the counter and used variable names."""
        self.name_counter = 0
        self.used_variable_names = set()


# Class used to generate and look up FileCheck names for attribute aliases
# such as "#map0".
class AttributeNamer:
    def __init__(self, attribute_names):
        self.name_counter = 0
        # User-requested names (upper-cased); empty entries use ATTR_<n>.
        self.attribute_names = [name.upper() for name in attribute_names.split(",")]
        self.map = {}
        self.used_attribute_names = set()

    def generate_name(self, source_attribute_name):
        """Create and record a '$'-prefixed FileCheck name for an attribute.

        The '$' makes the FileCheck variable global, so it survives
        CHECK-LABEL boundaries.

        Raises:
            RuntimeError: if the chosen name was already generated.
        """
        attribute_name = self.attribute_names.pop(0) if self.attribute_names else ""
        if attribute_name == "":
            attribute_name = "ATTR_" + str(self.name_counter)
            self.name_counter += 1

        attribute_name = "$" + attribute_name

        if attribute_name in self.used_attribute_names:
            raise RuntimeError(attribute_name + ": duplicate attribute name")
        self.map[source_attribute_name] = attribute_name
        self.used_attribute_names.add(attribute_name)
        return attribute_name

    def get_name(self, source_attribute_name):
        """Return the saved substitution name for an attribute, or None.

        None means no definition was seen for this attribute (for example a
        dialect attribute like "#gpu.address_space"). Callers should then
        keep the original text.
        """
        return self.map.get(source_attribute_name)


def get_num_ssa_results(input_line):
    """Return the number of SSA results in a line like "%0, %x = ...".

    Returns 0 if the line is not an assignment. A packed result such as
    "%r:2" counts once, matching the single name process_line() generates
    for it.
    """
    m = SSA_RESULTS_RE.match(input_line)
    return m.group().count("%") if m else 0


def process_line(line_chunks, variable_namer):
    """Rewrite the chunks of a line that was split at each '%'.

    Each chunk is expected to start with an SSA name. Known names become
    "%[[NAME]]" uses; new names become "%[[NAME:.*]]" definitions. Returns the
    rewritten text, right-stripped and newline-terminated.
    """
    output_line = ""

    for chunk in line_chunks:
        m = SSA_RE.match(chunk)
        if m is None:
            # This '%' does not start an SSA name (e.g. it is inside a
            # string), so keep it as literal text. Capturing it would
            # produce a greedy ".*" pattern.
            output_line += "%" + chunk
            continue
        ssa_name = m.group(0)

        # Check if an existing variable exists for this name.
        variable = None
        for scope in variable_namer.scopes:
            variable = scope.get(ssa_name)
            if variable is not None:
                break

        if variable is not None:
            output_line += "%[[" + variable + "]]"
        else:
            variable = variable_namer.generate_name(ssa_name)
            output_line += "%[[" + variable + ":.*]]"

        # Append the non named group.
        output_line += chunk[len(ssa_name) :]

    return output_line.rstrip() + "\n"


def process_source_lines(source_lines, note, args):
    """Split the source file into segments at --source_delim_regex lines.

    A previously generated note and previously generated directives for
    ``args.check_prefix`` are dropped. Each returned line is newline-terminated.
    The source file doesn't have to be .mlir.
    """
    source_lines = list(source_lines)
    source_split_re = re.compile(args.source_delim_regex)
    # Only real FileCheck directives ("PREFIX:", "PREFIX-LABEL:", ...) are
    # dropped. Merely mentioning the prefix, as a RUN line's
    # "--check-prefix=CHECK" does, must not delete the line.
    check_directive_re = re.compile(
        r"(?<![\w-])" + re.escape(args.check_prefix) + r"(-[A-Za-z0-9-]+)?:"
    )
    # The note spans several lines (and is followed by a blank line when
    # written), so it has to be matched as a run of lines.
    note_lines = [l.rstrip() for l in note.split("\n")] if note.strip() else []

    source_segments = [[]]
    i = 0
    while i < len(source_lines):
        # Remove previous note.
        if note_lines and source_lines[i : i + len(note_lines)] == note_lines:
            i += len(note_lines)
            continue
        line = source_lines[i]
        i += 1
        # Remove previous CHECK lines.
        if check_directive_re.search(line):
            continue
        # Segment the file based on --source_delim_regex.
        if source_split_re.search(line):
            source_segments.append([])

        source_segments[-1].append(line + "\n")
    return source_segments


def process_attribute_definition(line, attribute_namer, output, check_prefix="CHECK"):
    """Write a CHECK line for an attribute alias definition like "#map0 = ...".

    The CHECK line is written to ``output`` (anything with ``write``) using
    ``check_prefix``. Returns True if ``line`` was such a definition.
    """
    m = ATTR_DEF_RE.match(line)
    if not m:
        return False
    attribute_name = attribute_namer.generate_name(m.group(1))
    output.write(
        "// " + check_prefix + ": #[[" + attribute_name + ":.+]] ="
        + line[len(m.group(0)) :] + "\n"
    )
    return True


def process_attribute_references(line, attribute_namer):
    """Replace references to defined attribute aliases with "#[[$NAME]]".

    References with no recorded definition are left unchanged.
    """
    output_line = ""
    # ATTR_RE has a capturing group, so split() keeps the matched attributes
    # as separate components.
    for component in ATTR_RE.split(line):
        m = ATTR_RE.match(component)
        name = attribute_namer.get_name(m.group(1)) if m else None
        if name is None:
            output_line += component
        else:
            output_line += "#[[" + name + "]]" + component[len(m.group()) :]
    return output_line


def preprocess_line(line):
    """Escape character sequences in ``line`` that FileCheck would misread."""
    # Replace any double brackets, '[[' with escaped replacements. '[['
    # corresponds to variable names in FileCheck.
    output_line = line.replace("[[", "{{\\[\\[}}")

    # Replace any single brackets that are followed by an SSA identifier, the
    # identifier will be replace by a variable; Creating the same situation as
    # above.
    output_line = output_line.replace("[%", "{{\\[}}%")

    return output_line


def main():
    """Generate FileCheck lines for the input IR and write them out.

    Output goes to stdout, to -o, or, with -i, back into the --source file.
    The --source file is only rewritten after all checks have been generated
    and validated, so a failure cannot truncate it.
    """
    import io  # Buffer for attribute-definition checks.

    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument(
        "--check-prefix", default="CHECK", help="Prefix to use from check file."
    )
    parser.add_argument(
        "-o", "--output", nargs="?", type=argparse.FileType("w"), default=None
    )
    parser.add_argument(
        "input", nargs="?", type=argparse.FileType("r"), default=sys.stdin
    )
    parser.add_argument(
        "--source",
        type=str,
        help="Print each CHECK chunk before each delimiter line in the source "
        "file, respectively. The delimiter lines are identified by "
        "--source_delim_regex.",
    )
    parser.add_argument("--source_delim_regex", type=str, default="func @")
    parser.add_argument(
        "--starts_from_scope",
        type=int,
        default=1,
        help="Omit the top specified level of content. For example, by default "
        'it omits "module {"',
    )
    parser.add_argument("-i", "--inplace", action="store_true", default=False)
    parser.add_argument(
        "--variable_names",
        type=str,
        default="",
        help="Names to be used in FileCheck regular expression to represent SSA "
        "variables in the order they are encountered. Separate names with commas, "
        "and leave empty entries for default names (e.g.: 'DIM,,SUM,RESULT')",
    )
    parser.add_argument(
        "--attribute_names",
        type=str,
        default="",
        help="Names to be used in FileCheck regular expression to represent "
        "attributes in the order they are defined. Separate names with commas, "
        "and leave empty entries for default names (e.g.: 'MAP0,,,MAP1')",
    )

    args = parser.parse_args()

    if args.inplace:
        if not args.source:
            parser.error("--inplace requires --source")
        if args.output is not None:
            parser.error("--inplace cannot be combined with --output")

    input_lines = [l.rstrip() for l in args.input]
    args.input.close()

    # Generate a note used for the generated check file.
    script_name = os.path.basename(__file__)
    autogenerated_note = ADVERT_BEGIN + "utils/" + script_name + "\n" + ADVERT_END

    source_segments = None
    if args.source:
        with open(args.source, "r") as source_file:
            source_lines = [l.rstrip() for l in source_file]
        source_segments = process_source_lines(source_lines, autogenerated_note, args)

    output_segments = [[]]
    # Attribute definition checks are buffered so nothing is written before
    # the whole input has been processed successfully.
    attribute_checks = io.StringIO()

    variable_namer = VariableNamer(args.variable_names)
    attribute_namer = AttributeNamer(args.attribute_names)

    for input_line in input_lines:
        if not input_line:
            continue

        # Attribute alias definitions get their own CHECK line and are not
        # processed further as operations.
        if process_attribute_definition(
            input_line, attribute_namer, attribute_checks, args.check_prefix
        ):
            continue

        # Lines with blocks begin with a ^. These lines have a trailing comment
        # that needs to be stripped.
        lstripped_input_line = input_line.lstrip()
        is_block = lstripped_input_line[0] == "^"
        if is_block:
            input_line = input_line.rsplit("//", 1)[0].rstrip()

        cur_level = variable_namer.num_scopes()

        # If the line starts with a '}', pop the last name scope.
        if lstripped_input_line[0] == "}":
            variable_namer.pop_name_scope()
            cur_level = variable_namer.num_scopes()

        # If the line ends with a '{', push a new name scope.
        if input_line[-1] == "{":
            variable_namer.push_name_scope()
            if cur_level == args.starts_from_scope:
                output_segments.append([])

            # Result SSA values must still be pushed to parent scope
            num_ssa_results = get_num_ssa_results(input_line)
            variable_namer.generate_in_parent_scope(num_ssa_results)

        # Omit lines at the near top level e.g. "module {".
        if cur_level < args.starts_from_scope:
            continue

        if len(output_segments[-1]) == 0:
            variable_namer.clear_names()

        input_line = preprocess_line(input_line)
        input_line = process_attribute_references(input_line, attribute_namer)

        # Split the line at the each SSA value name.
        ssa_split = input_line.split("%")

        # If this is a top-level operation use 'CHECK-LABEL', otherwise 'CHECK:'.
        if len(output_segments[-1]) != 0 or not ssa_split[0]:
            output_line = "// " + args.check_prefix + ": "
            # Pad to align with the 'LABEL' statements.
            output_line += " " * len("-LABEL")
            output_line += ssa_split[0]
            output_line += process_line(ssa_split[1:], variable_namer)
        else:
            output_line = "// " + args.check_prefix + "-LABEL: " + ssa_split[0] + "\n"

            # Process the rest of the input line on separate check lines.
            for argument in ssa_split[1:]:
                output_line += "// " + args.check_prefix + "-SAME: "
                # Pad to align with the original position in the line.
                output_line += " " * len(ssa_split[0])
                output_line += process_line([argument], variable_namer)

        output_segments[-1].append(output_line)

    # Validate before opening the output: with -i, opening truncates --source.
    if source_segments is not None and len(output_segments) != len(source_segments):
        parser.error(
            "the input produced %d CHECK segment(s) but --source was split into "
            "%d segment(s); adjust --source_delim_regex or --starts_from_scope"
            % (len(output_segments), len(source_segments))
        )

    if args.inplace:
        output = open(args.source, "w")
    elif args.output is None:
        output = sys.stdout
    else:
        output = args.output

    try:
        output.write(attribute_checks.getvalue())
        output.write(autogenerated_note + "\n")

        if source_segments is not None:
            for check_segment, source_segment in zip(output_segments, source_segments):
                output.writelines(check_segment)
                output.writelines(source_segment)
        else:
            for segment in output_segments:
                output.write("\n")
                output.writelines(segment)
            output.write("\n")
    finally:
        output.close()


if __name__ == "__main__":
    main()