# RUN: %PYTHON %s 2>&1 | FileCheck %s

import gc
import os
import sys
import tempfile

from mlir.ir import *
from mlir.passmanager import *
from mlir.dialects.func import FuncOp
from mlir.dialects.builtin import ModuleOp


# Log everything to stderr and flush so that we have a unified stream to match
# errors/info emitted by MLIR to stderr.
def log(*args):
    """Print ``args`` to stderr and flush immediately.

    Flushing after every message keeps Python-side output interleaved in
    emission order with the diagnostics MLIR writes to the same stream.
    """
    print(*args, file=sys.stderr)
    sys.stderr.flush()


def run(f):
    """Run test ``f`` under a ``TEST:`` header, then verify no Context leaked.

    Returns ``f`` unchanged so that ``@run`` can be used as a decorator
    without rebinding the decorated test's name to ``None``.
    """
    log("\nTEST:", f.__name__)
    f()
    # Collect garbage so Contexts kept alive only by reference cycles are freed
    # before the leak check below.
    gc.collect()
    assert Context._get_live_count() == 0
    return f


# Verify capsule interop.
# CHECK-LABEL: TEST: testCapsule
def testCapsule():
    with Context():
        pm = PassManager()
        pm_capsule = pm._CAPIPtr
        assert '"mlir.passmanager.PassManager._CAPIPtr"' in repr(pm_capsule)
        # NOTE(review): _testing_release() presumably detaches `pm` from the
        # native object so that re-wrapping the capsule below does not lead to
        # a double free -- confirm against the bindings.
        pm._testing_release()
        pm1 = PassManager._CAPICreate(pm_capsule)
        assert pm1 is not None  # And does not crash.


run(testCapsule)


# CHECK-LABEL: TEST: testConstruct
@run
def testConstruct():
    with Context():
        # CHECK: pm1: 'any()'
        # CHECK: pm2: 'builtin.module()'
        pm1 = PassManager()
        pm2 = PassManager("builtin.module")
        log(f"pm1: '{pm1}'")
        log(f"pm2: '{pm2}'")


# Verify that an unregistered pass is rejected and that a registered pass
# round-trips through parse/print.
# CHECK-LABEL: TEST: testParseSuccess
def testParseSuccess():
    with Context():
        # An unregistered pass should not parse.
        try:
            PassManager.parse(
                "builtin.module(func.func(not-existing-pass{json=false}))"
            )
        except ValueError as e:
            # CHECK: ValueError exception: {{.+}} 'not-existing-pass' does not refer to a registered pass
            log("ValueError exception:", e)
        else:
            log("Exception not produced")

        # A registered pass should parse successfully.
        pm = PassManager.parse("builtin.module(func.func(print-op-stats{json=false}))")
        # CHECK: Roundtrip: builtin.module(func.func(print-op-stats{json=false}))
        log("Roundtrip: ", pm)


run(testParseSuccess)


# Verify successful round-trip of a pipeline string containing extra whitespace.
# CHECK-LABEL: TEST: testParseSpacedPipeline
def testParseSpacedPipeline():
    with Context():
        # A registered pass should parse successfully even if the pipeline text
        # has extra spaces for readability.
        pm = PassManager.parse(
            """builtin.module(
        func.func( print-op-stats{ json=false } )
    )"""
        )
        # CHECK: Roundtrip: builtin.module(func.func(print-op-stats{json=false}))
        log("Roundtrip: ", pm)


run(testParseSpacedPipeline)


# Verify failure on unregistered pass.
# CHECK-LABEL: TEST: testParseFail
def testParseFail():
    with Context():
        try:
            PassManager.parse("any(unknown-pass)")
        except ValueError as e:
            # CHECK: ValueError exception: MLIR Textual PassPipeline Parser:1:1: error:
            # CHECK-SAME: 'unknown-pass' does not refer to a registered pass or pass pipeline
            # CHECK: unknown-pass
            # CHECK: ^
            log("ValueError exception:", e)
        else:
            log("Exception not produced")


run(testParseFail)


# Check that adding to a pass manager works
# CHECK-LABEL: TEST: testAdd
@run
def testAdd():
    pm = PassManager("any", Context())
    # CHECK: pm: 'any()'
    log(f"pm: '{pm}'")
    # CHECK: pm: 'any(cse)'
    pm.add("cse")
    log(f"pm: '{pm}'")
    # CHECK: pm: 'any(cse,cse)'
    pm.add("cse")
    log(f"pm: '{pm}'")


# Verify failure on incorrect level of nesting.
# CHECK-LABEL: TEST: testInvalidNesting
def testInvalidNesting():
    with Context():
        try:
            PassManager.parse("func.func(normalize-memrefs)")
        except ValueError as e:
            # CHECK: ValueError exception: Can't add pass 'NormalizeMemRefs' restricted to 'builtin.module' on a PassManager intended to run on 'func.func', did you intend to nest?
            log("ValueError exception:", e)
        else:
            log("Exception not produced")


run(testInvalidNesting)


# Verify that a pass manager can execute on IR
# CHECK-LABEL: TEST: testRunPipeline
def testRunPipeline():
    with Context():
        pm = PassManager.parse("any(print-op-stats{json=false})")
        func = FuncOp.parse(r"""func.func @successfulParse() { return }""")
        pm.run(func)


# CHECK: Operations encountered:
# CHECK: func.func , 1
# CHECK: func.return , 1
run(testRunPipeline)


# CHECK-LABEL: TEST: testRunPipelineError
@run
def testRunPipelineError():
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True
        op = Operation.parse('"test.op"() : () -> ()')
        pm = PassManager.parse("any(cse)")
        try:
            pm.run(op)
        except MLIRError as e:
            # CHECK: Exception: <
            # CHECK: Failure while executing pass pipeline:
            # CHECK: error: "-":1:1: 'test.op' op trying to schedule a pass on an unregistered operation
            # CHECK: note: "-":1:1: see current operation: "test.op"() : () -> ()
            # CHECK: >
            log(f"Exception: <{e}>")
        else:
            log("Exception not produced")


# CHECK-LABEL: TEST: testPostPassOpInvalidation
@run
def testPostPassOpInvalidation():
    with Context() as ctx:

        def log_op_count():
            """Log how many operation handles the context currently tracks."""
            log("live ops:", ctx._get_live_operation_count())

        # CHECK: invalidate_ops=False
        log("invalidate_ops=False")

        # CHECK: live ops: 0
        log_op_count()

        module = ModuleOp.parse(
            """
      module {
        arith.constant 10
        func.func @foo() {
          arith.constant 10
          return
        }
      }
    """
        )

        # CHECK: live ops: 1
        log_op_count()

        outer_const_op = module.body.operations[0]
        # CHECK: %[[VAL0:.*]] = arith.constant 10 : i64
        log(outer_const_op)

        func_op = module.body.operations[1]
        # CHECK: func.func @[[FOO:.*]]() {
        # CHECK: %[[VAL1:.*]] = arith.constant 10 : i64
        # CHECK: return
        # CHECK: }
        log(func_op)

        inner_const_op = func_op.body.blocks[0].operations[0]
        # CHECK: %[[VAL1]] = arith.constant 10 : i64
        log(inner_const_op)

        # CHECK: live ops: 4
        log_op_count()

        PassManager.parse("builtin.module(canonicalize)").run(
            module, invalidate_ops=False
        )
        # CHECK: func.func @foo() {
        # CHECK: return
        # CHECK: }
        log(func_op)

        # CHECK: func.func @foo() {
        # CHECK: return
        # CHECK: }
        log(module)

        # CHECK: invalidate_ops=True
        log("invalidate_ops=True")

        # CHECK: live ops: 4
        log_op_count()

        module = ModuleOp.parse(
            """
      module {
        arith.constant 10
        func.func @foo() {
          arith.constant 10
          return
        }
      }
    """
        )
        outer_const_op = module.body.operations[0]
        func_op = module.body.operations[1]
        inner_const_op = func_op.body.blocks[0].operations[0]

        # CHECK: live ops: 4
        log_op_count()

        PassManager.parse("builtin.module(canonicalize)").run(module)

        # CHECK: live ops: 1
        log_op_count()

        # run() was called without invalidate_ops=False here, so every op handle
        # obtained before it is expected to raise when printed.
        try:
            log(func_op)
        except RuntimeError as e:
            # CHECK: the operation has been invalidated
            log(e)
        else:
            log("Exception not produced")

        try:
            log(outer_const_op)
        except RuntimeError as e:
            # CHECK: the operation has been invalidated
            log(e)
        else:
            log("Exception not produced")

        try:
            log(inner_const_op)
        except RuntimeError as e:
            # CHECK: the operation has been invalidated
            log(e)
        else:
            log("Exception not produced")

        # CHECK: func.func @foo() {
        # CHECK: return
        # CHECK: }
        log(module)


# CHECK-LABEL: TEST: testPrintIrAfterAll
@run
def testPrintIrAfterAll():
    with Context() as ctx:
        module = ModuleOp.parse(
            """
      module {
        func.func @main() {
          %0 = arith.constant 10
          return
        }
      }
    """
        )
        pm = PassManager.parse("builtin.module(canonicalize)")
        ctx.enable_multithreading(False)
        pm.enable_ir_printing()
        # CHECK: // -----// IR Dump After Canonicalizer (canonicalize) //----- //
        # CHECK: module {
        # CHECK: func.func @main() {
        # CHECK: return
        # CHECK: }
        # CHECK: }
        pm.run(module)


# CHECK-LABEL: TEST: testPrintIrBeforeAndAfterAll
@run
def testPrintIrBeforeAndAfterAll():
    with Context() as ctx:
        module = ModuleOp.parse(
            """
      module {
        func.func @main() {
          %0 = arith.constant 10
          return
        }
      }
    """
        )
        pm = PassManager.parse("builtin.module(canonicalize)")
        ctx.enable_multithreading(False)
        pm.enable_ir_printing(print_before_all=True, print_after_all=True)
        # CHECK: // -----// IR Dump Before Canonicalizer (canonicalize) //----- //
        # CHECK: module {
        # CHECK: func.func @main() {
        # CHECK: %[[C10:.*]] = arith.constant 10 : i64
        # CHECK: return
        # CHECK: }
        # CHECK: }
        # CHECK: // -----// IR Dump After Canonicalizer (canonicalize) //----- //
        # CHECK: module {
        # CHECK: func.func @main() {
        # CHECK: return
        # CHECK: }
        # CHECK: }
        pm.run(module)


# CHECK-LABEL: TEST: testPrintIrLargeLimitElements
@run
def testPrintIrLargeLimitElements():
    with Context() as ctx:
        module = ModuleOp.parse(
            """
      module {
        func.func @main() -> tensor<3xi64> {
          %0 = arith.constant dense<[1, 2, 3]> : tensor<3xi64>
          return %0 : tensor<3xi64>
        }
      }
    """
        )
        pm = PassManager.parse("builtin.module(canonicalize)")
        ctx.enable_multithreading(False)
        pm.enable_ir_printing(large_elements_limit=2)
        # CHECK: %[[CST:.*]] = arith.constant dense_resource<__elided__> : tensor<3xi64>
        pm.run(module)


# CHECK-LABEL: TEST: testPrintIrTree
@run
def testPrintIrTree():
    with Context() as ctx:
        module = ModuleOp.parse(
            """
      module {
        func.func @main() {
          %0 = arith.constant 10
          return
        }
      }
    """
        )
        pm = PassManager.parse("builtin.module(canonicalize)")
        ctx.enable_multithreading(False)
        pm.enable_ir_printing()
        # CHECK-LABEL: // Tree printing begin
        # CHECK: \-- builtin_module_no-symbol-name
        # CHECK:     \-- 0_canonicalize.mlir
        # CHECK-LABEL: // Tree printing end
        pm.run(module)
        log("// Tree printing begin")
        with tempfile.TemporaryDirectory() as temp_dir:
            pm.enable_ir_printing(tree_printing_dir_path=temp_dir)
            pm.run(module)

            def print_file_tree(directory, prefix=""):
                """Log the contents of ``directory`` recursively as an ASCII tree."""
                entries = sorted(os.listdir(directory))
                for i, entry in enumerate(entries):
                    is_last = i == len(entries) - 1
                    path = os.path.join(directory, entry)
                    # Raw string: a backslash followed by "-" is not a valid
                    # escape sequence and warns on Python 3.12+.
                    connector = r"\-- " if is_last else "|-- "
                    log(f"{prefix}{connector}{entry}")
                    if os.path.isdir(path):
                        # ASCII-only indent, matching the connectors, so the
                        # output is encodable on non-UTF-8 stderr streams.
                        print_file_tree(path, prefix + ("    " if is_last else "|   "))

            print_file_tree(temp_dir)
        log("// Tree printing end")