# RUN: %PYTHON %s | FileCheck %s

import gc
import io
import itertools
from mlir.ir import *
from mlir.dialects.builtin import ModuleOp
from mlir.dialects import arith
from mlir.dialects._ods_common import _cext


def run(f):
    """Execute the test function ``f`` immediately and check it leaked no Context.

    Used as a decorator on every test below: prints the ``TEST:`` banner that
    the ``CHECK-LABEL`` directives anchor on, calls ``f``, runs a garbage
    collection, and asserts that no ``Context`` object is still alive.
    ``f`` itself is returned unchanged.
    """
    print("\nTEST:", f.__name__)
    f()
    # Collect before counting so objects the test dropped are really freed.
    gc.collect()
    assert Context._get_live_count() == 0
    return f


def expect_index_error(callback):
    """Call ``callback`` and require that it raises ``IndexError``.

    Raises:
        RuntimeError: if ``callback`` returns normally instead of raising.
    """
    try:
        callback()
    except IndexError:
        return
    raise RuntimeError("Expected IndexError")


# Verify iterator based traversal of the op/region/block hierarchy.
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators
@run
def testTraverseOpRegionBlockIterators():
    """Walk regions, blocks and operations using only Python iteration."""
    context = Context()
    context.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
    func.func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """,
        context,
    )
    top = module.operation
    assert top.context is context

    # Reach the block through the named `regions` / `blocks` collections.
    region_list = list(top.regions)
    block_list = list(region_list[0].blocks)
    # CHECK: MODULE REGIONS=1 BLOCKS=1
    print(f"MODULE REGIONS={len(region_list)} BLOCKS={len(block_list)}")

    # The freshly parsed module is expected to verify.
    # CHECK: .verify = True
    print(f".verify = {module.operation.verify()}")

    # Iterating a region directly must give the same blocks as `.blocks`.
    blocks_via_region = list(region_list[0])
    assert blocks_via_region == block_list

    # Iterating a block directly must give the same ops as `.operations`.
    first_block = block_list[0]
    ops_via_named = list(first_block.operations)
    ops_via_block = list(first_block)
    assert ops_via_block == ops_via_named

    def dump(prefix, parent):
        # Recursively print every region / block / op nested under `parent`.
        for region_idx, region in enumerate(parent.regions):
            print(f"{prefix}REGION {region_idx}:")
            for block_idx, block in enumerate(region):
                print(f"{prefix}  BLOCK {block_idx}:")
                for op_idx, child in enumerate(block):
                    print(f"{prefix}    OP {op_idx}: {child}")
                    dump(prefix + "      ", child)

    # CHECK: REGION 0:
    # CHECK:   BLOCK 0:
    # CHECK:     OP 0: func
    # CHECK:       REGION 0:
    # CHECK:         BLOCK 0:
    # CHECK:           OP 0: %0 = "custom.addi"
    # CHECK:           OP 1: func.return
    dump("", top)

    # CHECK:   Region iter: <mlir.{{.+}}.RegionIterator
    # CHECK:    Block iter: <mlir.{{.+}}.BlockIterator
    # CHECK: Operation iter: <mlir.{{.+}}.OperationIterator
    print("   Region iter:", iter(top.regions))
    print("    Block iter:", iter(top.regions[0]))
    print("Operation iter:", iter(top.regions[0].blocks[0]))


# Verify index based traversal of the op/region/block hierarchy.
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIndices
@run
def testTraverseOpRegionBlockIndices():
    """Walk the op/region/block hierarchy using only ``len()`` and indexing."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
    func.func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """,
        ctx,
    )

    def walk_operations(indent, op):
        # Index-based loops are deliberate: this test exercises the
        # __len__/__getitem__ paths of the region, block and operation lists
        # (testTraverseOpRegionBlockIterators covers the iterator paths).
        for i in range(len(op.regions)):
            region = op.regions[i]
            print(f"{indent}REGION {i}:")
            for j in range(len(region.blocks)):
                block = region.blocks[j]
                print(f"{indent}  BLOCK {j}:")
                for k in range(len(block.operations)):
                    child_op = block.operations[k]
                    print(f"{indent}    OP {k}: {child_op}")
                    print(
                        f"{indent}    OP {k}: parent {child_op.operation.parent.name}"
                    )
                    walk_operations(indent + "      ", child_op)

    # CHECK: REGION 0:
    # CHECK:   BLOCK 0:
    # CHECK:     OP 0: func
    # CHECK:     OP 0: parent builtin.module
    # CHECK:       REGION 0:
    # CHECK:         BLOCK 0:
    # CHECK:           OP 0: %0 = "custom.addi"
    # CHECK:           OP 0: parent func.func
    # CHECK:           OP 1: func.return
    # CHECK:           OP 1: parent func.func
    walk_operations("", module.operation)


# CHECK-LABEL: TEST: testBlockAndRegionOwners
@run
def testBlockAndRegionOwners():
    """Regions and blocks report the operation that owns them via ``.owner``."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
    builtin.module {
      func.func @f() {
        func.return
      }
    }
  """,
        ctx,
    )

    assert module.operation.regions[0].owner == module.operation
    assert module.operation.regions[0].blocks[0].owner == module.operation

    func = module.body.operations[0]
    assert func.operation.regions[0].owner == func
    assert func.operation.regions[0].blocks[0].owner == func


# CHECK-LABEL: TEST: testBlockArgumentList
@run
def testBlockArgumentList():
    """Block arguments: iteration, set_type, slicing, concatenation, .types."""
    with Context() as ctx:
        module = Module.parse(
            r"""
      func.func @f1(%arg0: i32, %arg1: f64, %arg2: index) {
        return
      }
    """,
            ctx,
        )
        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        assert len(entry_block.arguments) == 3
        # CHECK: Argument 0, type i32
        # CHECK: Argument 1, type f64
        # CHECK: Argument 2, type index
        for arg in entry_block.arguments:
            print(f"Argument {arg.arg_number}, type {arg.type}")
            # Retype argument N to an (8 * (N + 1))-bit signless integer.
            new_type = IntegerType.get_signless(8 * (arg.arg_number + 1))
            arg.set_type(new_type)

        # CHECK: Argument 0, type i8
        # CHECK: Argument 1, type i16
        # CHECK: Argument 2, type i24
        for arg in entry_block.arguments:
            print(f"Argument {arg.arg_number}, type {arg.type}")

        # Check that slicing works for block argument lists.
        # CHECK: Argument 1, type i16
        # CHECK: Argument 2, type i24
        for arg in entry_block.arguments[1:]:
            print(f"Argument {arg.arg_number}, type {arg.type}")

        # Check that we can concatenate slices of argument lists.
        # CHECK: Length: 4
        print("Length: ", len(entry_block.arguments[:2] + entry_block.arguments[1:]))

        # CHECK: Type: i8
        # CHECK: Type: i16
        # CHECK: Type: i24
        for t in entry_block.arguments.types:
            print("Type: ", t)

        # Check that slicing and type access compose.
        # CHECK: Sliced type: i16
        # CHECK: Sliced type: i24
        for t in entry_block.arguments[1:].types:
            print("Sliced type: ", t)

        # Check that slice addition works as expected.
        # CHECK: Argument 2, type i24
        # CHECK: Argument 0, type i8
        restructured = entry_block.arguments[-1:] + entry_block.arguments[:1]
        for arg in restructured:
            print(f"Argument {arg.arg_number}, type {arg.type}")


# CHECK-LABEL: TEST: testOperationOperands
@run
def testOperationOperands():
    """Operands of an op can be counted and iterated in order."""
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
      func.func @f1(%arg0: i32) {
        %0 = "test.producer"() : () -> i64
        "test.consumer"(%arg0, %0) : (i32, i64) -> ()
        return
      }"""
        )
        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        consumer = entry_block.operations[1]
        assert len(consumer.operands) == 2
        # CHECK: Operand 0, type i32
        # CHECK: Operand 1, type i64
        for i, operand in enumerate(consumer.operands):
            print(f"Operand {i}, type {operand.type}")


# CHECK-LABEL: TEST: testOperationOperandsSlice
@run
def testOperationOperandsSlice():
    """Operand lists support full, bounded, strided and nested slicing."""
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
      func.func @f1() {
        %0 = "test.producer0"() : () -> i64
        %1 = "test.producer1"() : () -> i64
        %2 = "test.producer2"() : () -> i64
        %3 = "test.producer3"() : () -> i64
        %4 = "test.producer4"() : () -> i64
        "test.consumer"(%0, %1, %2, %3, %4) : (i64, i64, i64, i64, i64) -> ()
        return
      }"""
        )
        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        consumer = entry_block.operations[5]
        assert len(consumer.operands) == 5
        # Reversing twice must give back the original order.
        for left, right in zip(consumer.operands, consumer.operands[::-1][::-1]):
            assert left == right

        # CHECK: test.producer0
        # CHECK: test.producer1
        # CHECK: test.producer2
        # CHECK: test.producer3
        # CHECK: test.producer4
        full_slice = consumer.operands[:]
        for operand in full_slice:
            print(operand)

        # CHECK: test.producer0
        # CHECK: test.producer1
        first_two = consumer.operands[0:2]
        for operand in first_two:
            print(operand)

        # CHECK: test.producer3
        # CHECK: test.producer4
        last_two = consumer.operands[3:]
        for operand in last_two:
            print(operand)

        # CHECK: test.producer0
        # CHECK: test.producer2
        # CHECK: test.producer4
        even = consumer.operands[::2]
        for operand in even:
            print(operand)

        # Slices compose: [::2] keeps indices 0, 2, 4 and [1::2] of that keeps
        # only the element at original index 2.
        # CHECK: test.producer2
        every_fourth = consumer.operands[::2][1::2]
        for operand in every_fourth:
            print(operand)


# CHECK-LABEL: TEST: testOperationOperandsSet
@run
def testOperationOperandsSet():
    """Operands can be replaced by positive and negative index assignment."""
    with Context() as ctx, Location.unknown(ctx):
        ctx.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
      func.func @f1() {
        %0 = "test.producer0"() : () -> i64
        %1 = "test.producer1"() : () -> i64
        %2 = "test.producer2"() : () -> i64
        "test.consumer"(%0) : (i64) -> ()
        return
      }"""
        )
        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        producer1 = entry_block.operations[1]
        producer2 = entry_block.operations[2]
        consumer = entry_block.operations[3]
        assert len(consumer.operands) == 1
        # All producers above return i64, so the operand type must be the same
        # after each replacement.
        operand_type = consumer.operands[0].type

        # CHECK: test.producer1
        consumer.operands[0] = producer1.result
        print(consumer.operands[0])
        assert consumer.operands[0].type == operand_type

        # With a single operand, index -1 addresses the same slot as index 0.
        # CHECK: test.producer2
        consumer.operands[-1] = producer2.result
        print(consumer.operands[0])
        assert consumer.operands[0].type == operand_type


# CHECK-LABEL: TEST: testDetachedOperation
@run
def testDetachedOperation():
    """A detached op with results, a region and attributes prints correctly."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        i32 = IntegerType.get_signed(32)
        op1 = Operation.create(
            "custom.op1",
            results=[i32, i32],
            regions=1,
            attributes={
                "foo": StringAttr.get("foo_value"),
                "bar": StringAttr.get("bar_value"),
            },
        )
        # CHECK: %0:2 = "custom.op1"() ({
        # CHECK: }) {bar = "bar_value", foo = "foo_value"} : () -> (si32, si32)
        print(op1)

    # TODO: Check successors once enough infra exists to do it properly.


# CHECK-LABEL: TEST: testOperationInsertionPoint
@run
def testOperationInsertionPoint():
    """InsertionPoint.insert places ops and rejects ops that already have a parent."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
    func.func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """,
        ctx,
    )

    # Create test op.
    with Location.unknown(ctx):
        op1 = Operation.create("custom.op1")
        op2 = Operation.create("custom.op2")

        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        ip = InsertionPoint.at_block_begin(entry_block)
        ip.insert(op1)
        ip.insert(op2)
        # CHECK: func @f1
        # CHECK: "custom.op1"()
        # CHECK: "custom.op2"()
        # CHECK: %0 = "custom.addi"
        print(module)

    # Trying to add a previously added op should raise.
    try:
        ip.insert(op1)
    except ValueError:
        pass
    else:
        assert False, "expected insert of attached op to raise"


# CHECK-LABEL: TEST: testOperationWithRegion
@run
def testOperationWithRegion():
    """An op built with a region and block can be inserted into another module."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        i32 = IntegerType.get_signed(32)
        op1 = Operation.create("custom.op1", regions=1)
        block = op1.regions[0].blocks.append(i32, i32)
        # CHECK: "custom.op1"() ({
        # CHECK: ^bb0(%arg0: si32, %arg1: si32):
        # CHECK:   "custom.terminator"() : () -> ()
        # CHECK: }) : () -> ()
        terminator = Operation.create("custom.terminator")
        ip = InsertionPoint(block)
        ip.insert(terminator)
        print(op1)

        # Now add the whole operation to another op.
        # TODO: Verify lifetime hazard by nulling out the new owning module and
        # accessing op1.
        # TODO: Also verify accessing the terminator once both parents are nulled
        # out.
        module = Module.parse(
            r"""
      func.func @f1(%arg0: i32) -> i32 {
        %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
        return %1 : i32
      }
    """
        )
        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        ip = InsertionPoint.at_block_begin(entry_block)
        ip.insert(op1)
        # CHECK: func @f1
        # CHECK: "custom.op1"()
        # CHECK:   "custom.terminator"
        # CHECK: %0 = "custom.addi"
        print(module)


# CHECK-LABEL: TEST: testOperationResultList
@run
def testOperationResultList():
    """Result lists: iteration, .types, bounds checks and the empty case."""
    ctx = Context()
    module = Module.parse(
        r"""
    func.func @f1() {
      %0:3 = call @f2() : () -> (i32, f64, index)
      call @f3() : () -> ()
      return
    }
    func.func private @f2() -> (i32, f64, index)
    func.func private @f3() -> ()
  """,
        ctx,
    )
    caller = module.body.operations[0]
    call = caller.regions[0].blocks[0].operations[0]
    assert len(call.results) == 3
    # CHECK: Result 0, type i32
    # CHECK: Result 1, type f64
    # CHECK: Result 2, type index
    for res in call.results:
        print(f"Result {res.result_number}, type {res.type}")

    # CHECK: Result type i32
    # CHECK: Result type f64
    # CHECK: Result type index
    for t in call.results.types:
        print(f"Result type {t}")

    # Out of range, in both directions.
    expect_index_error(lambda: call.results[3])
    expect_index_error(lambda: call.results[-4])

    no_results_call = caller.regions[0].blocks[0].operations[1]
    assert len(no_results_call.results) == 0
    assert no_results_call.results.owner == no_results_call


# CHECK-LABEL: TEST: testOperationResultListSlice
@run
def testOperationResultListSlice():
    """Result lists support full, bounded, strided and negative-step slicing."""
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
      func.func @f1() {
        "some.op"() : () -> (i1, i2, i3, i4, i5)
        return
      }
    """
        )
        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        producer = entry_block.operations[0]

        assert len(producer.results) == 5
        # Reversing twice must give back the original order and numbering.
        for left, right in zip(producer.results, producer.results[::-1][::-1]):
            assert left == right
            assert left.result_number == right.result_number

        # CHECK: Result 0, type i1
        # CHECK: Result 1, type i2
        # CHECK: Result 2, type i3
        # CHECK: Result 3, type i4
        # CHECK: Result 4, type i5
        full_slice = producer.results[:]
        for res in full_slice:
            print(f"Result {res.result_number}, type {res.type}")

        # CHECK: Result 1, type i2
        # CHECK: Result 2, type i3
        # CHECK: Result 3, type i4
        middle = producer.results[1:4]
        for res in middle:
            print(f"Result {res.result_number}, type {res.type}")

        # CHECK: Result 1, type i2
        # CHECK: Result 3, type i4
        odd = producer.results[1::2]
        for res in odd:
            print(f"Result {res.result_number}, type {res.type}")

        # CHECK: Result 3, type i4
        # CHECK: Result 1, type i2
        inverted_middle = producer.results[-2:0:-2]
        for res in inverted_middle:
            print(f"Result {res.result_number}, type {res.type}")


# CHECK-LABEL: TEST: testOperationAttributes
@run
def testOperationAttributes():
    """Attribute lookup by name and index, iteration, and error cases."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
    "some.op"() { some.attribute = 1 : i8,
                  other.attribute = 3.0,
                  dependent = "text" } : () -> ()
  """,
        ctx,
    )
    op = module.body.operations[0]
    assert len(op.attributes) == 3
    iattr = op.attributes["some.attribute"]
    fattr = op.attributes["other.attribute"]
    sattr = op.attributes["dependent"]
    # CHECK: Attribute type i8, value 1
    print(f"Attribute type {iattr.type}, value {iattr.value}")
    # CHECK: Attribute type f64, value 3.0
    print(f"Attribute type {fattr.type}, value {fattr.value}")
    # CHECK: Attribute value text
    print(f"Attribute value {sattr.value}")
    # CHECK: Attribute value b'text'
    print(f"Attribute value {sattr.value_bytes}")

    # We don't know in which order the attributes are stored.
    # CHECK-DAG: NamedAttribute(dependent="text")
    # CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64)
    # CHECK-DAG: NamedAttribute(some.attribute=1 : i8)
    for attr in op.attributes:
        print(str(attr))

    # Check that exceptions are raised as expected.
    try:
        op.attributes["does_not_exist"]
    except KeyError:
        pass
    else:
        assert False, "expected KeyError on accessing a non-existent attribute"

    try:
        op.attributes[42]
    except IndexError:
        pass
    else:
        assert False, "expected IndexError on accessing an out-of-bounds attribute"


# CHECK-LABEL: TEST: testOperationPrint
@run
def testOperationPrint():
    """Printing to stdout, text/binary files, bytecode roundtrip and options."""
    ctx = Context()
    # NOTE: the "-:4:7" location CHECK below depends on `return` being on
    # line 4, column 7 of this string; keep its layout unchanged.
    module = Module.parse(
        r"""
    func.func @f1(%arg0: i32) -> i32 {
      %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom")
      return %arg0 : i32
    }
  """,
        ctx,
    )

    # Test print to stdout.
    # CHECK: return %arg0 : i32
    module.operation.print()

    # Test print to text file.
    f = io.StringIO()
    # CHECK: <class 'str'>
    # CHECK: return %arg0 : i32
    module.operation.print(file=f)
    str_value = f.getvalue()
    print(str_value.__class__)
    print(str_value)

    # Test roundtrip to bytecode.
    bytecode_stream = io.BytesIO()
    module.operation.write_bytecode(bytecode_stream, desired_version=1)
    bytecode = bytecode_stream.getvalue()
    assert bytecode.startswith(b"ML\xefR"), "Expected bytecode to start with MLïR"
    module_roundtrip = Module.parse(bytecode, ctx)
    f = io.StringIO()
    module_roundtrip.operation.print(file=f)
    roundtrip_value = f.getvalue()
    assert str_value == roundtrip_value, "Mismatch after roundtrip bytecode"

    # Test print to binary file.
    f = io.BytesIO()
    # CHECK: <class 'bytes'>
    # CHECK: return %arg0 : i32
    module.operation.print(file=f, binary=True)
    bytes_value = f.getvalue()
    print(bytes_value.__class__)
    print(bytes_value)

    # Test print local_scope.
    # CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom")
    module.operation.print(enable_debug_info=True, use_local_scope=True)

    # Test printing using state.
    state = AsmState(module.operation)
    # CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32>
    module.operation.print(state)

    # Test print with options.
    # CHECK: value = dense_resource<__elided__> : tensor<4xi32>
    # CHECK: "func.return"(%arg0) : (i32) -> () -:4:7
    module.operation.print(
        large_elements_limit=2,
        enable_debug_info=True,
        pretty_debug_info=True,
        print_generic_op_form=True,
        use_local_scope=True,
    )

    # Test print with skip_regions option
    # CHECK: func.func @f1(%arg0: i32) -> i32
    # CHECK-NOT: func.return
    module.body.operations[0].print(
        skip_regions=True,
    )


# CHECK-LABEL: TEST: testKnownOpView
@run
def testKnownOpView():
    """Parsed ops resolve to generated, default and late-registered OpViews."""
    with Context(), Location.unknown():
        Context.current.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
      %1 = "custom.f32"() : () -> f32
      %2 = "custom.f32"() : () -> f32
      %3 = arith.addf %1, %2 : f32
      %4 = arith.constant 0 : i32
    """
        )
        print(module)

        # addf should map to a known OpView class in the arithmetic dialect.
        # We know the OpView for it defines an 'lhs' attribute.
        addf = module.body.operations[2]
        # CHECK: <mlir.dialects._arith_ops_gen.AddFOp object
        print(repr(addf))
        # CHECK: "custom.f32"()
        print(addf.lhs)

        # One of the custom ops should resolve to the default OpView.
        custom = module.body.operations[0]
        # CHECK: OpView object
        print(repr(custom))

        # Check again to make sure negative caching works.
        custom = module.body.operations[0]
        # CHECK: OpView object
        print(repr(custom))

        # constant should map to an extension OpView class in the arithmetic dialect.
        constant = module.body.operations[3]
        # CHECK: <mlir.dialects.arith.ConstantOp object
        print(repr(constant))
        # Checks that the arith extension is being registered successfully
        # (literal_value is a property on the extension class but not on the default OpView).
        # CHECK: literal value 0
        print("literal value", constant.literal_value)

        # Checks that "late" registration/replacement (i.e., post all module loading/initialization)
        # is working correctly.
        @_cext.register_operation(arith._Dialect, replace=True)
        class ConstantOp(arith.ConstantOp):
            def __init__(self, result, value, *, loc=None, ip=None):
                if isinstance(value, int):
                    super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
                elif isinstance(value, float):
                    super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
                else:
                    super().__init__(value, loc=loc, ip=ip)

        constant = module.body.operations[3]
        # CHECK: <__main__.testKnownOpView.<locals>.ConstantOp object
        print(repr(constant))


# CHECK-LABEL: TEST: testSingleResultProperty
@run
def testSingleResultProperty():
    """``.result`` works only on ops with exactly one result."""
    with Context(), Location.unknown():
        Context.current.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
      "custom.no_result"() : () -> ()
      %0:2 = "custom.two_result"() : () -> (f32, f32)
      %1 = "custom.one_result"() : () -> f32
    """
        )
        print(module)

        try:
            module.body.operations[0].result
        except ValueError as e:
            # CHECK: Cannot call .result on operation custom.no_result which has 0 results
            print(e)
        else:
            assert False, "Expected exception"

        try:
            module.body.operations[1].result
        except ValueError as e:
            # CHECK: Cannot call .result on operation custom.two_result which has 2 results
            print(e)
        else:
            assert False, "Expected exception"

        # CHECK: %1 = "custom.one_result"() : () -> f32
        print(module.body.operations[2])


def create_invalid_operation():
    """Return a detached, deliberately invalid ``builtin.module`` op.

    The op has two regions (``builtin.module`` requires exactly one), so it
    fails verification. The tests below use it to check that printing falls
    back to the generic printer instead of failing.
    """
    op = Operation.create("builtin.module", regions=2)
    op.regions[0].blocks.append()
    return op


# CHECK-LABEL: TEST: testInvalidOperationStrSoftFails
@run
def testInvalidOperationStrSoftFails():
    """str() of an invalid op falls back to generic form; verify() raises."""
    ctx = Context()
    with Location.unknown(ctx):
        invalid_op = create_invalid_operation()
        # Verify that we fallback to the generic printer for safety.
        # CHECK: "builtin.module"() ({
        # CHECK: }) : () -> ()
        print(invalid_op)
        try:
            invalid_op.verify()
        except MLIRError as e:
            # CHECK: Exception: <
            # CHECK: Verification failed:
            # CHECK: error: unknown: 'builtin.module' op requires one region
            # CHECK: note: unknown: see current operation:
            # CHECK: "builtin.module"() ({
            # CHECK: ^bb0:
            # CHECK: }, {
            # CHECK: }) : () -> ()
            # CHECK: >
            print(f"Exception: <{e}>")
        else:
            assert False, "expected verify() of an invalid op to raise MLIRError"


# CHECK-LABEL: TEST: testInvalidModuleStrSoftFails
@run
def testInvalidModuleStrSoftFails():
    """str() of a module containing an invalid op uses the generic printer."""
    ctx = Context()
    with Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            # Created at the active insertion point, i.e. inside module.body.
            create_invalid_operation()
        # Verify that we fallback to the generic printer for safety.
        # CHECK: "builtin.module"() ({
        # CHECK: }) : () -> ()
        print(module)


# CHECK-LABEL: TEST: testInvalidOperationGetAsmBinarySoftFails
@run
def testInvalidOperationGetAsmBinarySoftFails():
    """get_asm(binary=True) of an invalid op returns generic-form bytes."""
    ctx = Context()
    with Location.unknown(ctx):
        invalid_op = create_invalid_operation()
        # Verify that we fallback to the generic printer for safety.
        # CHECK: b'"builtin.module"() ({\n^bb0:\n}, {\n}) : () -> ()\n'
        print(invalid_op.get_asm(binary=True))


# CHECK-LABEL: TEST: testCreateWithInvalidAttributes
@run
def testCreateWithInvalidAttributes():
    """Operation.create rejects non-string keys and non-Attribute values."""
    ctx = Context()
    with Location.unknown(ctx):
        try:
            Operation.create(
                "builtin.module", attributes={None: StringAttr.get("name")}
            )
        except Exception as e:
            # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
            print(e)
        else:
            assert False, "expected a None attribute key to be rejected"
        try:
            Operation.create("builtin.module", attributes={42: StringAttr.get("name")})
        except Exception as e:
            # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
            print(e)
        else:
            assert False, "expected an int attribute key to be rejected"
        try:
            Operation.create("builtin.module", attributes={"some_key": ctx})
        except Exception as e:
            # CHECK: Invalid attribute value for the key "some_key" when attempting to create the operation "builtin.module"
            print(e)
        else:
            assert False, "expected a non-Attribute value to be rejected"
        try:
            Operation.create("builtin.module", attributes={"some_key": None})
        except Exception as e:
            # CHECK: Found an invalid (`None`?) attribute value for the key "some_key" when attempting to create the operation "builtin.module"
            print(e)
        else:
            assert False, "expected a None attribute value to be rejected"


# CHECK-LABEL: TEST: testOperationName
@run
def testOperationName():
    """Each op reports its fully qualified name."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
    %0 = "custom.op1"() : () -> f32
    %1 = "custom.op2"() : () -> i32
    %2 = "custom.op1"() : () -> f32
  """,
        ctx,
    )

    # CHECK: custom.op1
    # CHECK: custom.op2
    # CHECK: custom.op1
    for op in module.body.operations:
        print(op.operation.name)


# CHECK-LABEL: TEST: testCapsuleConversions
@run
def testCapsuleConversions():
    """Round-tripping an Operation through its C-API capsule yields the same object."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        m = Operation.create("custom.op1").operation
        m_capsule = m._CAPIPtr
        assert '"mlir.ir.Operation._CAPIPtr"' in repr(m_capsule)
        m2 = Operation._CAPICreate(m_capsule)
        assert m2 is m


# CHECK-LABEL: TEST: testOperationErase
@run
def testOperationErase():
    """Erasing an op removes it from its parent block."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        m = Module.create()
        with InsertionPoint(m.body):
            op = Operation.create("custom.op1")

            # CHECK: "custom.op1"
            print(m)

            op.operation.erase()

            # CHECK-NOT: "custom.op1"
            print(m)

            # Ensure we can create another operation
            Operation.create("custom.op2")


# CHECK-LABEL: TEST: testOperationClone
@run
def testOperationClone():
    """A clone survives erasure of the original op."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        m = Module.create()
        with InsertionPoint(m.body):
            op = Operation.create("custom.op1")

            # CHECK: "custom.op1"
            print(m)

            # The clone is placed at the active insertion point (m.body), so it
            # is still printed after the original is erased.
            op.operation.clone()
            op.operation.erase()

            # CHECK: "custom.op1"
            print(m)


# CHECK-LABEL: TEST: testOperationLoc
@run
def testOperationLoc():
    """An op created with an explicit location reports that location."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with ctx:
        loc = Location.name("loc")
        op = Operation.create("custom.op", loc=loc)
        assert op.location == loc
        assert op.operation.location == loc


# CHECK-LABEL: TEST: testModuleMerge
@run
def testModuleMerge():
    """move_before/move_after transfer ops from one module to another."""
    with Context():
        m1 = Module.parse("func.func private @foo()")
        m2 = Module.parse(
            """
      func.func private @bar()
      func.func private @qux()
    """
        )
        foo = m1.body.operations[0]
        bar = m2.body.operations[0]
        qux = m2.body.operations[1]
        bar.move_before(foo)
        qux.move_after(foo)

        # CHECK: module
        # CHECK: func private @bar
        # CHECK: func private @foo
        # CHECK: func private @qux
        print(m1)

        # CHECK: module {
        # CHECK-NEXT: }
        print(m2)


# CHECK-LABEL: TEST: testAppendMoveFromAnotherBlock
@run
def testAppendMoveFromAnotherBlock():
    """Block.append of an attached op moves it out of its old block."""
    with Context():
        m1 = Module.parse("func.func private @foo()")
        m2 = Module.parse("func.func private @bar()")
        func = m1.body.operations[0]
        m2.body.append(func)

        # CHECK: module
        # CHECK: func private @bar
        # CHECK: func private @foo

        print(m2)
        # CHECK: module {
        # CHECK-NEXT: }
        print(m1)


# CHECK-LABEL: TEST: testDetachFromParent
@run
def testDetachFromParent():
    """detach_from_parent removes an op; detaching a detached op raises."""
    with Context():
        m1 = Module.parse("func.func private @foo()")
        func = m1.body.operations[0].detach_from_parent()

        try:
            func.detach_from_parent()
        except ValueError as e:
            if "has no parent" not in str(e):
                raise
        else:
            assert False, "expected ValueError when detaching a detached operation"

        print(m1)
        # CHECK-NOT: func private @foo


# CHECK-LABEL: TEST: testOperationHash
@run
def testOperationHash():
    """An OpView hashes the same as its underlying Operation."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with ctx, Location.unknown():
        op = Operation.create("custom.op1")
        assert hash(op) == hash(op.operation)
# CHECK-LABEL: TEST: testOperationParse
@run
def testOperationParse():
    """Parse single ops generically and through a specific OpView class."""
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True

        # Generic operation parsing.
        m = Operation.parse("module {}")
        o = Operation.parse('"test.foo"() : () -> ()')
        assert isinstance(m, ModuleOp)
        # Exact type check on purpose: an unregistered op must get the plain
        # OpView, not a subclass.
        assert type(o) is OpView

        # Parsing specific operation.
        m = ModuleOp.parse("module {}")
        assert isinstance(m, ModuleOp)
        try:
            ModuleOp.parse('"test.foo"() : () -> ()')
        except MLIRError as e:
            # CHECK: error: Expected a 'builtin.module' op, got: 'test.foo'
            print(f"error: {e}")
        else:
            assert False, "expected error"

        o = Operation.parse('"test.foo"() : () -> ()', source_name="my-source-string")
        # CHECK: op_with_source_name: "test.foo"() : () -> () loc("my-source-string":1:1)
        print(
            f"op_with_source_name: {o.get_asm(enable_debug_info=True, use_local_scope=True)}"
        )


# CHECK-LABEL: TEST: testOpWalk
@run
def testOpWalk():
    """Operation.walk honors WalkOrder and the ADVANCE/INTERRUPT/SKIP results."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
    builtin.module {
      func.func @f() {
        func.return
      }
    }
  """,
        ctx,
    )

    def print_and_advance(op):
        print(op.name)
        return WalkResult.ADVANCE

    # Test post-order walk (default).
    # CHECK-NEXT: Post-order
    # CHECK-NEXT: func.return
    # CHECK-NEXT: func.func
    # CHECK-NEXT: builtin.module
    print("Post-order")
    module.operation.walk(print_and_advance)

    # Test pre-order walk.
    # CHECK-NEXT: Pre-order
    # CHECK-NEXT: builtin.module
    # CHECK-NEXT: func.func
    # CHECK-NEXT: func.return
    print("Pre-order")
    module.operation.walk(print_and_advance, WalkOrder.PRE_ORDER)

    # Test interrupt: the post-order walk stops after the first visited op.
    # CHECK-NEXT: Interrupt post-order
    # CHECK-NEXT: func.return
    print("Interrupt post-order")

    def print_and_interrupt(op):
        print(op.name)
        return WalkResult.INTERRUPT

    module.operation.walk(print_and_interrupt)

    # Test skip: returning SKIP on the module in pre-order skips its nested ops.
    # CHECK-NEXT: Skip pre-order
    # CHECK-NEXT: builtin.module
    print("Skip pre-order")

    def print_and_skip(op):
        print(op.name)
        return WalkResult.SKIP

    module.operation.walk(print_and_skip, WalkOrder.PRE_ORDER)

    # Test exception: an exception raised by the callback aborts the walk and
    # is caught here as a RuntimeError.
    # CHECK: Exception
    # CHECK-NEXT: func.return
    # CHECK-NEXT: Exception raised
    print("Exception")

    def print_and_raise(op):
        print(op.name)
        raise ValueError

    try:
        module.operation.walk(print_and_raise)
    except RuntimeError:
        print("Exception raised")
    else:
        assert False, "expected the callback's exception to abort the walk"