# RUN: %PYTHON %s | FileCheck %s

import gc
import io
import itertools
from mlir.ir import *
from mlir.dialects.builtin import ModuleOp
from mlir.dialects import arith
from mlir.dialects._ods_common import _cext


def run(f):
    """Execute test function ``f`` right away and verify it leaked no contexts.

    Intended as a decorator. It prints the ``TEST: <name>`` banner that the
    ``CHECK-LABEL`` directives anchor on, then calls ``f``. Next it forces a
    garbage collection and asserts that no ``Context`` remains alive. ``f`` is
    returned unchanged, so the module-level name still refers to the function.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    f()
    # Force a collection so contexts that are merely awaiting garbage
    # collection are not reported as leaks.
    gc.collect()
    assert Context._get_live_count() == 0
    return f


def expect_index_error(callback):
    """Invoke ``callback`` and require that it raises ``IndexError``.

    Any other exception propagates unchanged. If ``callback`` returns
    normally, a ``RuntimeError`` is raised instead.
    """
    try:
        callback()
    except IndexError:
        return
    raise RuntimeError("Expected IndexError")


# Verify iterator based traversal of the op/region/block hierarchy.
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators
@run
def testTraverseOpRegionBlockIterators():
    """Walk the op/region/block hierarchy using only Python iteration."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
  func.func @f1(%arg0: i32) -> i32 {
    %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
    return %1 : i32
  }
  """,
        ctx,
    )
    module_op = module.operation
    assert module_op.context is ctx

    # Materialize the named collections through their iterators.
    region_list = list(module_op.regions)
    block_list = list(region_list[0].blocks)
    # CHECK: MODULE REGIONS=1 BLOCKS=1
    print(f"MODULE REGIONS={len(region_list)} BLOCKS={len(block_list)}")

    # Should verify.
    # CHECK: .verify = True
    print(f".verify = {module.operation.verify()}")

    # Iterating a region directly must yield the same blocks as `.blocks`.
    assert list(region_list[0]) == block_list

    # Likewise, iterating a block directly must yield the same operations as
    # its `.operations` collection.
    first_block = block_list[0]
    named_operations = list(first_block.operations)
    assert list(first_block) == named_operations

    def walk_operations(indent, parent):
        # Recursive pre-order dump of every region/block/op under `parent`.
        for region_idx, region in enumerate(parent.regions):
            print(f"{indent}REGION {region_idx}:")
            for block_idx, block in enumerate(region):
                print(f"{indent} BLOCK {block_idx}:")
                for op_idx, child_op in enumerate(block):
                    print(f"{indent} OP {op_idx}: {child_op}")
                    walk_operations(indent + " ", child_op)

    # CHECK: REGION 0:
    # CHECK: BLOCK 0:
    # CHECK: OP 0: func
    # CHECK: REGION 0:
    # CHECK: BLOCK 0:
    # CHECK: OP 0: %0 = "custom.addi"
    # CHECK: OP 1: func.return
    walk_operations("", module_op)

    # CHECK: Region iter: <mlir.{{.+}}.RegionIterator
    # CHECK: Block iter: <mlir.{{.+}}.BlockIterator
    # CHECK: Operation iter: <mlir.{{.+}}.OperationIterator
    print(" Region iter:", iter(module_op.regions))
    print(" Block iter:", iter(module_op.regions[0]))
    print("Operation iter:", iter(module_op.regions[0].blocks[0]))


# Verify index based traversal of the op/region/block hierarchy.
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIndices
@run
def testTraverseOpRegionBlockIndices():
    """Walk the op/region/block hierarchy using len() and integer indexing."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
  func.func @f1(%arg0: i32) -> i32 {
    %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
    return %1 : i32
  }
  """,
        ctx,
    )

    # Deliberately index-based: this exercises __len__/__getitem__ on the
    # region, block and operation lists instead of their iterators.
    def walk_operations(indent, parent):
        region_count = len(parent.regions)
        for region_idx in range(region_count):
            region = parent.regions[region_idx]
            print(f"{indent}REGION {region_idx}:")
            block_count = len(region.blocks)
            for block_idx in range(block_count):
                block = region.blocks[block_idx]
                print(f"{indent} BLOCK {block_idx}:")
                op_count = len(block.operations)
                for op_idx in range(op_count):
                    child_op = block.operations[op_idx]
                    print(f"{indent} OP {op_idx}: {child_op}")
                    parent_name = child_op.operation.parent.name
                    print(f"{indent} OP {op_idx}: parent {parent_name}")
                    walk_operations(indent + " ", child_op)

    # CHECK: REGION 0:
    # CHECK: BLOCK 0:
    # CHECK: OP 0: func
    # CHECK: OP 0: parent builtin.module
    # CHECK: REGION 0:
    # CHECK: BLOCK 0:
    # CHECK: OP 0: %0 = "custom.addi"
    # CHECK: OP 0: parent func.func
    # CHECK: OP 1: func.return
    # CHECK: OP 1: parent func.func
    walk_operations("", module.operation)


# CHECK-LABEL: TEST: testBlockAndRegionOwners
@run
def testBlockAndRegionOwners():
    """Regions and blocks report the operation that owns them."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
  builtin.module {
    func.func @f() {
      func.return
    }
  }
  """,
        ctx,
    )

    module_op = module.operation
    assert module_op.regions[0].owner == module_op
    assert module_op.regions[0].blocks[0].owner == module_op

    func = module.body.operations[0]
    func_region = func.operation.regions[0]
    assert func_region.owner == func
    assert func_region.blocks[0].owner == func


# CHECK-LABEL: TEST: testBlockArgumentList
@run
def testBlockArgumentList():
    """Block argument lists: iteration, retyping, slicing and concatenation."""

    def print_arg(arg):
        # Shared formatter; the CHECK lines below depend on this exact text.
        print(f"Argument {arg.arg_number}, type {arg.type}")

    with Context() as ctx:
        module = Module.parse(
            r"""
    func.func @f1(%arg0: i32, %arg1: f64, %arg2: index) {
      return
    }
    """,
            ctx,
        )
        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        assert len(entry_block.arguments) == 3

        # CHECK: Argument 0, type i32
        # CHECK: Argument 1, type f64
        # CHECK: Argument 2, type index
        for arg in entry_block.arguments:
            print_arg(arg)
            # Retype argument N to a signless integer of width 8 * (N + 1).
            width = 8 * (arg.arg_number + 1)
            arg.set_type(IntegerType.get_signless(width))

        # CHECK: Argument 0, type i8
        # CHECK: Argument 1, type i16
        # CHECK: Argument 2, type i24
        for arg in entry_block.arguments:
            print_arg(arg)

        # Check that slicing works for block argument lists.
        # CHECK: Argument 1, type i16
        # CHECK: Argument 2, type i24
        for arg in entry_block.arguments[1:]:
            print_arg(arg)

        # Check that we can concatenate slices of argument lists.
        # CHECK: Length: 4
        combined = entry_block.arguments[:2] + entry_block.arguments[1:]
        print("Length: ", len(combined))

        # CHECK: Type: i8
        # CHECK: Type: i16
        # CHECK: Type: i24
        for arg_type in entry_block.arguments.types:
            print("Type: ", arg_type)

        # Check that slicing and type access compose.
        # CHECK: Sliced type: i16
        # CHECK: Sliced type: i24
        for arg_type in entry_block.arguments[1:].types:
            print("Sliced type: ", arg_type)

        # Check that slice addition works as expected.
        # CHECK: Argument 2, type i24
        # CHECK: Argument 0, type i8
        restructured = entry_block.arguments[-1:] + entry_block.arguments[:1]
        for arg in restructured:
            print_arg(arg)


# CHECK-LABEL: TEST: testOperationOperands
@run
def testOperationOperands():
    """Operands of an op can be counted and iterated along with their types."""
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
    func.func @f1(%arg0: i32) {
      %0 = "test.producer"() : () -> i64
      "test.consumer"(%arg0, %0) : (i32, i64) -> ()
      return
    }"""
        )
        func = module.body.operations[0]
        consumer = func.regions[0].blocks[0].operations[1]
        assert len(consumer.operands) == 2
        # CHECK: Operand 0, type i32
        # CHECK: Operand 1, type i64
        for position, operand in enumerate(consumer.operands):
            print(f"Operand {position}, type {operand.type}")


# CHECK-LABEL: TEST: testOperationOperandsSlice
@run
def testOperationOperandsSlice():
    """Operand lists support Python slicing, including steps and chaining."""
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
    func.func @f1() {
      %0 = "test.producer0"() : () -> i64
      %1 = "test.producer1"() : () -> i64
      %2 = "test.producer2"() : () -> i64
      %3 = "test.producer3"() : () -> i64
      %4 = "test.producer4"() : () -> i64
      "test.consumer"(%0, %1, %2, %3, %4) : (i64, i64, i64, i64, i64) -> ()
      return
    }"""
        )
        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        consumer = entry_block.operations[5]
        assert len(consumer.operands) == 5

        # Reversing twice must round-trip to the original order. The length is
        # checked explicitly because zip() stops at the shorter input, which
        # would otherwise hide a slice that drops elements.
        round_trip = consumer.operands[::-1][::-1]
        assert len(round_trip) == len(consumer.operands)
        for left, right in zip(consumer.operands, round_trip):
            assert left == right

        # CHECK: test.producer0
        # CHECK: test.producer1
        # CHECK: test.producer2
        # CHECK: test.producer3
        # CHECK: test.producer4
        full_slice = consumer.operands[:]
        for operand in full_slice:
            print(operand)

        # CHECK: test.producer0
        # CHECK: test.producer1
        first_two = consumer.operands[0:2]
        for operand in first_two:
            print(operand)

        # CHECK: test.producer3
        # CHECK: test.producer4
        last_two = consumer.operands[3:]
        for operand in last_two:
            print(operand)

        # CHECK: test.producer0
        # CHECK: test.producer2
        # CHECK: test.producer4
        even = consumer.operands[::2]
        for operand in even:
            print(operand)

        # Slicing a slice composes the steps: [::2][1::2] selects index 2.
        # CHECK: test.producer2
        fourth = consumer.operands[::2][1::2]
        for operand in fourth:
            print(operand)


# CHECK-LABEL: TEST: testOperationOperandsSet
@run
def testOperationOperandsSet():
    """Operands can be replaced by positive and negative index assignment."""
    with Context() as ctx, Location.unknown(ctx):
        ctx.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
    func.func @f1() {
      %0 = "test.producer0"() : () -> i64
      %1 = "test.producer1"() : () -> i64
      %2 = "test.producer2"() : () -> i64
      "test.consumer"(%0) : (i64) -> ()
      return
    }"""
        )
        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        producer1 = entry_block.operations[1]
        producer2 = entry_block.operations[2]
        consumer = entry_block.operations[3]
        assert len(consumer.operands) == 1
        # Named `operand_type` (not `type`) to avoid shadowing the builtin.
        operand_type = consumer.operands[0].type

        # CHECK: test.producer1
        consumer.operands[0] = producer1.result
        print(consumer.operands[0])

        # With a single operand, index -1 and index 0 name the same slot.
        # CHECK: test.producer2
        consumer.operands[-1] = producer2.result
        print(consumer.operands[0])

        # Every producer yields i64, so replacement must keep the operand type.
        assert consumer.operands[0].type == operand_type


# CHECK-LABEL: TEST: testDetachedOperation
@run
def testDetachedOperation():
    """An operation created outside any block can be built and printed."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        # A *signed* 32-bit integer type; it prints as `si32`.
        si32 = IntegerType.get_signed(32)
        op1 = Operation.create(
            "custom.op1",
            results=[si32, si32],
            regions=1,
            attributes={
                "foo": StringAttr.get("foo_value"),
                "bar": StringAttr.get("bar_value"),
            },
        )
        # CHECK: %0:2 = "custom.op1"() ({
        # CHECK: }) {bar = "bar_value", foo = "foo_value"} : () -> (si32, si32)
        print(op1)

    # TODO: Check successors once enough infra exists to do it properly.
3539f3f6d7bSStella Laurenzo 354a54f4eaeSMogball 3559f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: testOperationInsertionPoint 35678f2dae0SAlex Zinenko@run 3579f3f6d7bSStella Laurenzodef testOperationInsertionPoint(): 3589f3f6d7bSStella Laurenzo ctx = Context() 3599f3f6d7bSStella Laurenzo ctx.allow_unregistered_dialects = True 360a54f4eaeSMogball module = Module.parse( 361a54f4eaeSMogball r""" 3622310ced8SRiver Riddle func.func @f1(%arg0: i32) -> i32 { 3639f3f6d7bSStella Laurenzo %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32 3649f3f6d7bSStella Laurenzo return %1 : i32 3659f3f6d7bSStella Laurenzo } 366f9008e63STobias Hieta """, 367f9008e63STobias Hieta ctx, 368f9008e63STobias Hieta ) 3699f3f6d7bSStella Laurenzo 3709f3f6d7bSStella Laurenzo # Create test op. 3719f3f6d7bSStella Laurenzo with Location.unknown(ctx): 3729f3f6d7bSStella Laurenzo op1 = Operation.create("custom.op1") 3739f3f6d7bSStella Laurenzo op2 = Operation.create("custom.op2") 3749f3f6d7bSStella Laurenzo 3759f3f6d7bSStella Laurenzo func = module.body.operations[0] 3769f3f6d7bSStella Laurenzo entry_block = func.regions[0].blocks[0] 3779f3f6d7bSStella Laurenzo ip = InsertionPoint.at_block_begin(entry_block) 3789f3f6d7bSStella Laurenzo ip.insert(op1) 3799f3f6d7bSStella Laurenzo ip.insert(op2) 3809f3f6d7bSStella Laurenzo # CHECK: func @f1 3819f3f6d7bSStella Laurenzo # CHECK: "custom.op1"() 3829f3f6d7bSStella Laurenzo # CHECK: "custom.op2"() 3839f3f6d7bSStella Laurenzo # CHECK: %0 = "custom.addi" 3849f3f6d7bSStella Laurenzo print(module) 3859f3f6d7bSStella Laurenzo 3869f3f6d7bSStella Laurenzo # Trying to add a previously added op should raise. 
3879f3f6d7bSStella Laurenzo try: 3889f3f6d7bSStella Laurenzo ip.insert(op1) 3899f3f6d7bSStella Laurenzo except ValueError: 3909f3f6d7bSStella Laurenzo pass 3919f3f6d7bSStella Laurenzo else: 3929f3f6d7bSStella Laurenzo assert False, "expected insert of attached op to raise" 3939f3f6d7bSStella Laurenzo 394a54f4eaeSMogball 3959f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: testOperationWithRegion 39678f2dae0SAlex Zinenko@run 3979f3f6d7bSStella Laurenzodef testOperationWithRegion(): 3989f3f6d7bSStella Laurenzo ctx = Context() 3999f3f6d7bSStella Laurenzo ctx.allow_unregistered_dialects = True 4009f3f6d7bSStella Laurenzo with Location.unknown(ctx): 4019f3f6d7bSStella Laurenzo i32 = IntegerType.get_signed(32) 4029f3f6d7bSStella Laurenzo op1 = Operation.create("custom.op1", regions=1) 4039f3f6d7bSStella Laurenzo block = op1.regions[0].blocks.append(i32, i32) 4049f3f6d7bSStella Laurenzo # CHECK: "custom.op1"() ({ 405d75c3e83SRiver Riddle # CHECK: ^bb0(%arg0: si32, %arg1: si32): 4069f3f6d7bSStella Laurenzo # CHECK: "custom.terminator"() : () -> () 4079f3f6d7bSStella Laurenzo # CHECK: }) : () -> () 4089f3f6d7bSStella Laurenzo terminator = Operation.create("custom.terminator") 4099f3f6d7bSStella Laurenzo ip = InsertionPoint(block) 4109f3f6d7bSStella Laurenzo ip.insert(terminator) 4119f3f6d7bSStella Laurenzo print(op1) 4129f3f6d7bSStella Laurenzo 4139f3f6d7bSStella Laurenzo # Now add the whole operation to another op. 4149f3f6d7bSStella Laurenzo # TODO: Verify lifetime hazard by nulling out the new owning module and 4159f3f6d7bSStella Laurenzo # accessing op1. 4169f3f6d7bSStella Laurenzo # TODO: Also verify accessing the terminator once both parents are nulled 4179f3f6d7bSStella Laurenzo # out. 
418f9008e63STobias Hieta module = Module.parse( 419f9008e63STobias Hieta r""" 4202310ced8SRiver Riddle func.func @f1(%arg0: i32) -> i32 { 4219f3f6d7bSStella Laurenzo %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32 4229f3f6d7bSStella Laurenzo return %1 : i32 4239f3f6d7bSStella Laurenzo } 424f9008e63STobias Hieta """ 425f9008e63STobias Hieta ) 4269f3f6d7bSStella Laurenzo func = module.body.operations[0] 4279f3f6d7bSStella Laurenzo entry_block = func.regions[0].blocks[0] 4289f3f6d7bSStella Laurenzo ip = InsertionPoint.at_block_begin(entry_block) 4299f3f6d7bSStella Laurenzo ip.insert(op1) 4309f3f6d7bSStella Laurenzo # CHECK: func @f1 4319f3f6d7bSStella Laurenzo # CHECK: "custom.op1"() 4329f3f6d7bSStella Laurenzo # CHECK: "custom.terminator" 4339f3f6d7bSStella Laurenzo # CHECK: %0 = "custom.addi" 4349f3f6d7bSStella Laurenzo print(module) 4359f3f6d7bSStella Laurenzo 436a54f4eaeSMogball 4379f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: testOperationResultList 43878f2dae0SAlex Zinenko@run 4399f3f6d7bSStella Laurenzodef testOperationResultList(): 4409f3f6d7bSStella Laurenzo ctx = Context() 441a54f4eaeSMogball module = Module.parse( 442a54f4eaeSMogball r""" 4432310ced8SRiver Riddle func.func @f1() { 4449f3f6d7bSStella Laurenzo %0:3 = call @f2() : () -> (i32, f64, index) 4456e4ea4eeSmax call @f3() : () -> () 4469f3f6d7bSStella Laurenzo return 4479f3f6d7bSStella Laurenzo } 4482310ced8SRiver Riddle func.func private @f2() -> (i32, f64, index) 4496e4ea4eeSmax func.func private @f3() -> () 450f9008e63STobias Hieta """, 451f9008e63STobias Hieta ctx, 452f9008e63STobias Hieta ) 4539f3f6d7bSStella Laurenzo caller = module.body.operations[0] 4549f3f6d7bSStella Laurenzo call = caller.regions[0].blocks[0].operations[0] 4559f3f6d7bSStella Laurenzo assert len(call.results) == 3 4569f3f6d7bSStella Laurenzo # CHECK: Result 0, type i32 4579f3f6d7bSStella Laurenzo # CHECK: Result 1, type f64 4589f3f6d7bSStella Laurenzo # CHECK: Result 2, type index 4599f3f6d7bSStella Laurenzo for res in 
call.results: 4609f3f6d7bSStella Laurenzo print(f"Result {res.result_number}, type {res.type}") 4619f3f6d7bSStella Laurenzo 462ed9e52f3SAlex Zinenko # CHECK: Result type i32 463ed9e52f3SAlex Zinenko # CHECK: Result type f64 464ed9e52f3SAlex Zinenko # CHECK: Result type index 465ed9e52f3SAlex Zinenko for t in call.results.types: 466ed9e52f3SAlex Zinenko print(f"Result type {t}") 467ed9e52f3SAlex Zinenko 468429b0cf1SStella Laurenzo # Out of range 469429b0cf1SStella Laurenzo expect_index_error(lambda: call.results[3]) 470429b0cf1SStella Laurenzo expect_index_error(lambda: call.results[-4]) 4719f3f6d7bSStella Laurenzo 4726e4ea4eeSmax no_results_call = caller.regions[0].blocks[0].operations[1] 4736e4ea4eeSmax assert len(no_results_call.results) == 0 4746e4ea4eeSmax assert no_results_call.results.owner == no_results_call 4756e4ea4eeSmax 4769f3f6d7bSStella Laurenzo 4779f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: testOperationResultListSlice 47878f2dae0SAlex Zinenko@run 4799f3f6d7bSStella Laurenzodef testOperationResultListSlice(): 4809f3f6d7bSStella Laurenzo with Context() as ctx: 4819f3f6d7bSStella Laurenzo ctx.allow_unregistered_dialects = True 482f9008e63STobias Hieta module = Module.parse( 483f9008e63STobias Hieta r""" 4842310ced8SRiver Riddle func.func @f1() { 4859f3f6d7bSStella Laurenzo "some.op"() : () -> (i1, i2, i3, i4, i5) 4869f3f6d7bSStella Laurenzo return 4879f3f6d7bSStella Laurenzo } 488f9008e63STobias Hieta """ 489f9008e63STobias Hieta ) 4909f3f6d7bSStella Laurenzo func = module.body.operations[0] 4919f3f6d7bSStella Laurenzo entry_block = func.regions[0].blocks[0] 4929f3f6d7bSStella Laurenzo producer = entry_block.operations[0] 4939f3f6d7bSStella Laurenzo 4949f3f6d7bSStella Laurenzo assert len(producer.results) == 5 4959f3f6d7bSStella Laurenzo for left, right in zip(producer.results, producer.results[::-1][::-1]): 4969f3f6d7bSStella Laurenzo assert left == right 4979f3f6d7bSStella Laurenzo assert left.result_number == right.result_number 4989f3f6d7bSStella 
Laurenzo 4999f3f6d7bSStella Laurenzo # CHECK: Result 0, type i1 5009f3f6d7bSStella Laurenzo # CHECK: Result 1, type i2 5019f3f6d7bSStella Laurenzo # CHECK: Result 2, type i3 5029f3f6d7bSStella Laurenzo # CHECK: Result 3, type i4 5039f3f6d7bSStella Laurenzo # CHECK: Result 4, type i5 5049f3f6d7bSStella Laurenzo full_slice = producer.results[:] 5059f3f6d7bSStella Laurenzo for res in full_slice: 5069f3f6d7bSStella Laurenzo print(f"Result {res.result_number}, type {res.type}") 5079f3f6d7bSStella Laurenzo 5089f3f6d7bSStella Laurenzo # CHECK: Result 1, type i2 5099f3f6d7bSStella Laurenzo # CHECK: Result 2, type i3 5109f3f6d7bSStella Laurenzo # CHECK: Result 3, type i4 5119f3f6d7bSStella Laurenzo middle = producer.results[1:4] 5129f3f6d7bSStella Laurenzo for res in middle: 5139f3f6d7bSStella Laurenzo print(f"Result {res.result_number}, type {res.type}") 5149f3f6d7bSStella Laurenzo 5159f3f6d7bSStella Laurenzo # CHECK: Result 1, type i2 5169f3f6d7bSStella Laurenzo # CHECK: Result 3, type i4 5179f3f6d7bSStella Laurenzo odd = producer.results[1::2] 5189f3f6d7bSStella Laurenzo for res in odd: 5199f3f6d7bSStella Laurenzo print(f"Result {res.result_number}, type {res.type}") 5209f3f6d7bSStella Laurenzo 5219f3f6d7bSStella Laurenzo # CHECK: Result 3, type i4 5229f3f6d7bSStella Laurenzo # CHECK: Result 1, type i2 5239f3f6d7bSStella Laurenzo inverted_middle = producer.results[-2:0:-2] 5249f3f6d7bSStella Laurenzo for res in inverted_middle: 5259f3f6d7bSStella Laurenzo print(f"Result {res.result_number}, type {res.type}") 5269f3f6d7bSStella Laurenzo 5279f3f6d7bSStella Laurenzo 5289f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: testOperationAttributes 52978f2dae0SAlex Zinenko@run 5309f3f6d7bSStella Laurenzodef testOperationAttributes(): 5319f3f6d7bSStella Laurenzo ctx = Context() 5329f3f6d7bSStella Laurenzo ctx.allow_unregistered_dialects = True 533a54f4eaeSMogball module = Module.parse( 534a54f4eaeSMogball r""" 5359f3f6d7bSStella Laurenzo "some.op"() { some.attribute = 1 : i8, 
5369f3f6d7bSStella Laurenzo other.attribute = 3.0, 5379f3f6d7bSStella Laurenzo dependent = "text" } : () -> () 538f9008e63STobias Hieta """, 539f9008e63STobias Hieta ctx, 540f9008e63STobias Hieta ) 5419f3f6d7bSStella Laurenzo op = module.body.operations[0] 5429f3f6d7bSStella Laurenzo assert len(op.attributes) == 3 543974c1596SRahul Kayaith iattr = op.attributes["some.attribute"] 544974c1596SRahul Kayaith fattr = op.attributes["other.attribute"] 545974c1596SRahul Kayaith sattr = op.attributes["dependent"] 5469f3f6d7bSStella Laurenzo # CHECK: Attribute type i8, value 1 5479f3f6d7bSStella Laurenzo print(f"Attribute type {iattr.type}, value {iattr.value}") 5489f3f6d7bSStella Laurenzo # CHECK: Attribute type f64, value 3.0 5499f3f6d7bSStella Laurenzo print(f"Attribute type {fattr.type}, value {fattr.value}") 5509f3f6d7bSStella Laurenzo # CHECK: Attribute value text 5519f3f6d7bSStella Laurenzo print(f"Attribute value {sattr.value}") 55262bf6c2eSChris Jones # CHECK: Attribute value b'text' 55362bf6c2eSChris Jones print(f"Attribute value {sattr.value_bytes}") 5549f3f6d7bSStella Laurenzo 5559f3f6d7bSStella Laurenzo # We don't know in which order the attributes are stored. 5569f3f6d7bSStella Laurenzo # CHECK-DAG: NamedAttribute(dependent="text") 5579f3f6d7bSStella Laurenzo # CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64) 5589f3f6d7bSStella Laurenzo # CHECK-DAG: NamedAttribute(some.attribute=1 : i8) 5599f3f6d7bSStella Laurenzo for attr in op.attributes: 5609f3f6d7bSStella Laurenzo print(str(attr)) 5619f3f6d7bSStella Laurenzo 5629f3f6d7bSStella Laurenzo # Check that exceptions are raised as expected. 
5639f3f6d7bSStella Laurenzo try: 5649f3f6d7bSStella Laurenzo op.attributes["does_not_exist"] 5659f3f6d7bSStella Laurenzo except KeyError: 5669f3f6d7bSStella Laurenzo pass 5679f3f6d7bSStella Laurenzo else: 5689f3f6d7bSStella Laurenzo assert False, "expected KeyError on accessing a non-existent attribute" 5699f3f6d7bSStella Laurenzo 5709f3f6d7bSStella Laurenzo try: 5719f3f6d7bSStella Laurenzo op.attributes[42] 5729f3f6d7bSStella Laurenzo except IndexError: 5739f3f6d7bSStella Laurenzo pass 5749f3f6d7bSStella Laurenzo else: 5759f3f6d7bSStella Laurenzo assert False, "expected IndexError on accessing an out-of-bounds attribute" 5769f3f6d7bSStella Laurenzo 5779f3f6d7bSStella Laurenzo 5789f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: testOperationPrint 57978f2dae0SAlex Zinenko@run 5809f3f6d7bSStella Laurenzodef testOperationPrint(): 5819f3f6d7bSStella Laurenzo ctx = Context() 582a54f4eaeSMogball module = Module.parse( 583a54f4eaeSMogball r""" 5842310ced8SRiver Riddle func.func @f1(%arg0: i32) -> i32 { 585bccf27d9SMark Browning %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom") 5869f3f6d7bSStella Laurenzo return %arg0 : i32 5879f3f6d7bSStella Laurenzo } 588f9008e63STobias Hieta """, 589f9008e63STobias Hieta ctx, 590f9008e63STobias Hieta ) 5919f3f6d7bSStella Laurenzo 5929f3f6d7bSStella Laurenzo # Test print to stdout. 5939f3f6d7bSStella Laurenzo # CHECK: return %arg0 : i32 5949f3f6d7bSStella Laurenzo module.operation.print() 5959f3f6d7bSStella Laurenzo 5969f3f6d7bSStella Laurenzo # Test print to text file. 5979f3f6d7bSStella Laurenzo f = io.StringIO() 5989f3f6d7bSStella Laurenzo # CHECK: <class 'str'> 5999f3f6d7bSStella Laurenzo # CHECK: return %arg0 : i32 6009f3f6d7bSStella Laurenzo module.operation.print(file=f) 6019f3f6d7bSStella Laurenzo str_value = f.getvalue() 6029f3f6d7bSStella Laurenzo print(str_value.__class__) 6039f3f6d7bSStella Laurenzo print(f.getvalue()) 6049f3f6d7bSStella Laurenzo 60589418ddcSMehdi Amini # Test roundtrip to bytecode. 
60689418ddcSMehdi Amini bytecode_stream = io.BytesIO() 6075c90e1ffSJacques Pienaar module.operation.write_bytecode(bytecode_stream, desired_version=1) 60889418ddcSMehdi Amini bytecode = bytecode_stream.getvalue() 609f9008e63STobias Hieta assert bytecode.startswith(b"ML\xefR"), "Expected bytecode to start with MLïR" 61089418ddcSMehdi Amini module_roundtrip = Module.parse(bytecode, ctx) 61189418ddcSMehdi Amini f = io.StringIO() 61289418ddcSMehdi Amini module_roundtrip.operation.print(file=f) 61389418ddcSMehdi Amini roundtrip_value = f.getvalue() 61489418ddcSMehdi Amini assert str_value == roundtrip_value, "Mismatch after roundtrip bytecode" 61589418ddcSMehdi Amini 6169f3f6d7bSStella Laurenzo # Test print to binary file. 6179f3f6d7bSStella Laurenzo f = io.BytesIO() 6189f3f6d7bSStella Laurenzo # CHECK: <class 'bytes'> 6199f3f6d7bSStella Laurenzo # CHECK: return %arg0 : i32 6209f3f6d7bSStella Laurenzo module.operation.print(file=f, binary=True) 6219f3f6d7bSStella Laurenzo bytes_value = f.getvalue() 6229f3f6d7bSStella Laurenzo print(bytes_value.__class__) 6239f3f6d7bSStella Laurenzo print(bytes_value) 6249f3f6d7bSStella Laurenzo 625204acc5cSJacques Pienaar # Test print local_scope. 626bccf27d9SMark Browning # CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom") 627bccf27d9SMark Browning module.operation.print(enable_debug_info=True, use_local_scope=True) 628bccf27d9SMark Browning 629204acc5cSJacques Pienaar # Test printing using state. 630204acc5cSJacques Pienaar state = AsmState(module.operation) 631204acc5cSJacques Pienaar # CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32> 632204acc5cSJacques Pienaar module.operation.print(state) 633204acc5cSJacques Pienaar 634*abad8455SJonas Rickert # Test print with options. 
63540abd7eaSRiver Riddle # CHECK: value = dense_resource<__elided__> : tensor<4xi32> 63623aa5a74SRiver Riddle # CHECK: "func.return"(%arg0) : (i32) -> () -:4:7 637a54f4eaeSMogball module.operation.print( 638a54f4eaeSMogball large_elements_limit=2, 639a54f4eaeSMogball enable_debug_info=True, 640a54f4eaeSMogball pretty_debug_info=True, 641a54f4eaeSMogball print_generic_op_form=True, 642f9008e63STobias Hieta use_local_scope=True, 643f9008e63STobias Hieta ) 6449f3f6d7bSStella Laurenzo 645*abad8455SJonas Rickert # Test print with skip_regions option 646*abad8455SJonas Rickert # CHECK: func.func @f1(%arg0: i32) -> i32 647*abad8455SJonas Rickert # CHECK-NOT: func.return 648*abad8455SJonas Rickert module.body.operations[0].print( 649*abad8455SJonas Rickert skip_regions=True, 650*abad8455SJonas Rickert ) 651*abad8455SJonas Rickert 6529f3f6d7bSStella Laurenzo 6539f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: testKnownOpView 65478f2dae0SAlex Zinenko@run 6559f3f6d7bSStella Laurenzodef testKnownOpView(): 6569f3f6d7bSStella Laurenzo with Context(), Location.unknown(): 6579f3f6d7bSStella Laurenzo Context.current.allow_unregistered_dialects = True 658f9008e63STobias Hieta module = Module.parse( 659f9008e63STobias Hieta r""" 6609f3f6d7bSStella Laurenzo %1 = "custom.f32"() : () -> f32 6619f3f6d7bSStella Laurenzo %2 = "custom.f32"() : () -> f32 662a54f4eaeSMogball %3 = arith.addf %1, %2 : f32 663b0e00ca6SMaksim Levental %4 = arith.constant 0 : i32 664f9008e63STobias Hieta """ 665f9008e63STobias Hieta ) 6669f3f6d7bSStella Laurenzo print(module) 6679f3f6d7bSStella Laurenzo 66823aa5a74SRiver Riddle # addf should map to a known OpView class in the arithmetic dialect. 6699f3f6d7bSStella Laurenzo # We know the OpView for it defines an 'lhs' attribute. 
6709f3f6d7bSStella Laurenzo addf = module.body.operations[2] 671a7f8b7cdSRahul Kayaith # CHECK: <mlir.dialects._arith_ops_gen.AddFOp object 6729f3f6d7bSStella Laurenzo print(repr(addf)) 6739f3f6d7bSStella Laurenzo # CHECK: "custom.f32"() 6749f3f6d7bSStella Laurenzo print(addf.lhs) 6759f3f6d7bSStella Laurenzo 6769f3f6d7bSStella Laurenzo # One of the custom ops should resolve to the default OpView. 6779f3f6d7bSStella Laurenzo custom = module.body.operations[0] 678310c9496SStella Laurenzo # CHECK: OpView object 6799f3f6d7bSStella Laurenzo print(repr(custom)) 6809f3f6d7bSStella Laurenzo 6819f3f6d7bSStella Laurenzo # Check again to make sure negative caching works. 6829f3f6d7bSStella Laurenzo custom = module.body.operations[0] 683310c9496SStella Laurenzo # CHECK: OpView object 6849f3f6d7bSStella Laurenzo print(repr(custom)) 6859f3f6d7bSStella Laurenzo 686b0e00ca6SMaksim Levental # constant should map to an extension OpView class in the arithmetic dialect. 687b0e00ca6SMaksim Levental constant = module.body.operations[3] 688b0e00ca6SMaksim Levental # CHECK: <mlir.dialects.arith.ConstantOp object 689b0e00ca6SMaksim Levental print(repr(constant)) 690b0e00ca6SMaksim Levental # Checks that the arith extension is being registered successfully 691b0e00ca6SMaksim Levental # (literal_value is a property on the extension class but not on the default OpView). 692b0e00ca6SMaksim Levental # CHECK: literal value 0 693b0e00ca6SMaksim Levental print("literal value", constant.literal_value) 694b0e00ca6SMaksim Levental 695b0e00ca6SMaksim Levental # Checks that "late" registration/replacement (i.e., post all module loading/initialization) 696b0e00ca6SMaksim Levental # is working correctly. 
697b0e00ca6SMaksim Levental @_cext.register_operation(arith._Dialect, replace=True) 698b0e00ca6SMaksim Levental class ConstantOp(arith.ConstantOp): 699b0e00ca6SMaksim Levental def __init__(self, result, value, *, loc=None, ip=None): 700b0e00ca6SMaksim Levental if isinstance(value, int): 701b0e00ca6SMaksim Levental super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip) 702b0e00ca6SMaksim Levental elif isinstance(value, float): 703b0e00ca6SMaksim Levental super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip) 704b0e00ca6SMaksim Levental else: 705b0e00ca6SMaksim Levental super().__init__(value, loc=loc, ip=ip) 706b0e00ca6SMaksim Levental 707b0e00ca6SMaksim Levental constant = module.body.operations[3] 708b0e00ca6SMaksim Levental # CHECK: <__main__.testKnownOpView.<locals>.ConstantOp object 709b0e00ca6SMaksim Levental print(repr(constant)) 710b0e00ca6SMaksim Levental 711a54f4eaeSMogball 7129f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: testSingleResultProperty 71378f2dae0SAlex Zinenko@run 7149f3f6d7bSStella Laurenzodef testSingleResultProperty(): 7159f3f6d7bSStella Laurenzo with Context(), Location.unknown(): 7169f3f6d7bSStella Laurenzo Context.current.allow_unregistered_dialects = True 717f9008e63STobias Hieta module = Module.parse( 718f9008e63STobias Hieta r""" 7199f3f6d7bSStella Laurenzo "custom.no_result"() : () -> () 7209f3f6d7bSStella Laurenzo %0:2 = "custom.two_result"() : () -> (f32, f32) 7219f3f6d7bSStella Laurenzo %1 = "custom.one_result"() : () -> f32 722f9008e63STobias Hieta """ 723f9008e63STobias Hieta ) 7249f3f6d7bSStella Laurenzo print(module) 7259f3f6d7bSStella Laurenzo 7269f3f6d7bSStella Laurenzo try: 7279f3f6d7bSStella Laurenzo module.body.operations[0].result 7289f3f6d7bSStella Laurenzo except ValueError as e: 7299f3f6d7bSStella Laurenzo # CHECK: Cannot call .result on operation custom.no_result which has 0 results 7309f3f6d7bSStella Laurenzo print(e) 7319f3f6d7bSStella Laurenzo else: 7329f3f6d7bSStella Laurenzo assert False, 
"Expected exception" 7339f3f6d7bSStella Laurenzo 7349f3f6d7bSStella Laurenzo try: 7359f3f6d7bSStella Laurenzo module.body.operations[1].result 7369f3f6d7bSStella Laurenzo except ValueError as e: 7379f3f6d7bSStella Laurenzo # CHECK: Cannot call .result on operation custom.two_result which has 2 results 7389f3f6d7bSStella Laurenzo print(e) 7399f3f6d7bSStella Laurenzo else: 7409f3f6d7bSStella Laurenzo assert False, "Expected exception" 7419f3f6d7bSStella Laurenzo 7429f3f6d7bSStella Laurenzo # CHECK: %1 = "custom.one_result"() : () -> f32 7439f3f6d7bSStella Laurenzo print(module.body.operations[2]) 7449f3f6d7bSStella Laurenzo 745a54f4eaeSMogball 746ace1d0adSStella Laurenzodef create_invalid_operation(): 7479f3f6d7bSStella Laurenzo # This module has two region and is invalid verify that we fallback 7489f3f6d7bSStella Laurenzo # to the generic printer for safety. 749ace1d0adSStella Laurenzo op = Operation.create("builtin.module", regions=2) 750ace1d0adSStella Laurenzo op.regions[0].blocks.append() 751ace1d0adSStella Laurenzo return op 752ace1d0adSStella Laurenzo 753f9008e63STobias Hieta 754ace1d0adSStella Laurenzo# CHECK-LABEL: TEST: testInvalidOperationStrSoftFails 755ace1d0adSStella Laurenzo@run 756ace1d0adSStella Laurenzodef testInvalidOperationStrSoftFails(): 757ace1d0adSStella Laurenzo ctx = Context() 758ace1d0adSStella Laurenzo with Location.unknown(ctx): 759ace1d0adSStella Laurenzo invalid_op = create_invalid_operation() 760ace1d0adSStella Laurenzo # Verify that we fallback to the generic printer for safety. 
761dad10a9aSStella Laurenzo # CHECK: "builtin.module"() ({ 7629f3f6d7bSStella Laurenzo # CHECK: }) : () -> () 763ace1d0adSStella Laurenzo print(invalid_op) 7643ea4c501SRahul Kayaith try: 7653ea4c501SRahul Kayaith invalid_op.verify() 7663ea4c501SRahul Kayaith except MLIRError as e: 7673ea4c501SRahul Kayaith # CHECK: Exception: < 7683ea4c501SRahul Kayaith # CHECK: Verification failed: 7693ea4c501SRahul Kayaith # CHECK: error: unknown: 'builtin.module' op requires one region 7703ea4c501SRahul Kayaith # CHECK: note: unknown: see current operation: 7713ea4c501SRahul Kayaith # CHECK: "builtin.module"() ({ 7723ea4c501SRahul Kayaith # CHECK: ^bb0: 7733ea4c501SRahul Kayaith # CHECK: }, { 7743ea4c501SRahul Kayaith # CHECK: }) : () -> () 7753ea4c501SRahul Kayaith # CHECK: > 7763ea4c501SRahul Kayaith print(f"Exception: <{e}>") 777ace1d0adSStella Laurenzo 778ace1d0adSStella Laurenzo 779ace1d0adSStella Laurenzo# CHECK-LABEL: TEST: testInvalidModuleStrSoftFails 780ace1d0adSStella Laurenzo@run 781ace1d0adSStella Laurenzodef testInvalidModuleStrSoftFails(): 782ace1d0adSStella Laurenzo ctx = Context() 783ace1d0adSStella Laurenzo with Location.unknown(ctx): 784ace1d0adSStella Laurenzo module = Module.create() 785ace1d0adSStella Laurenzo with InsertionPoint(module.body): 786ace1d0adSStella Laurenzo invalid_op = create_invalid_operation() 787ace1d0adSStella Laurenzo # Verify that we fallback to the generic printer for safety. 
7882aa12583SRahul Kayaith # CHECK: "builtin.module"() ({ 7892aa12583SRahul Kayaith # CHECK: }) : () -> () 790ace1d0adSStella Laurenzo print(module) 791ace1d0adSStella Laurenzo 792ace1d0adSStella Laurenzo 793ace1d0adSStella Laurenzo# CHECK-LABEL: TEST: testInvalidOperationGetAsmBinarySoftFails 794ace1d0adSStella Laurenzo@run 795ace1d0adSStella Laurenzodef testInvalidOperationGetAsmBinarySoftFails(): 796ace1d0adSStella Laurenzo ctx = Context() 797ace1d0adSStella Laurenzo with Location.unknown(ctx): 798ace1d0adSStella Laurenzo invalid_op = create_invalid_operation() 799ace1d0adSStella Laurenzo # Verify that we fallback to the generic printer for safety. 8002aa12583SRahul Kayaith # CHECK: b'"builtin.module"() ({\n^bb0:\n}, {\n}) : () -> ()\n' 801ace1d0adSStella Laurenzo print(invalid_op.get_asm(binary=True)) 802a54f4eaeSMogball 803a54f4eaeSMogball 8049f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: testCreateWithInvalidAttributes 80578f2dae0SAlex Zinenko@run 8069f3f6d7bSStella Laurenzodef testCreateWithInvalidAttributes(): 8079f3f6d7bSStella Laurenzo ctx = Context() 8089f3f6d7bSStella Laurenzo with Location.unknown(ctx): 8099f3f6d7bSStella Laurenzo try: 810afeda4b9SAlex Zinenko Operation.create( 811f9008e63STobias Hieta "builtin.module", attributes={None: StringAttr.get("name")} 812f9008e63STobias Hieta ) 8139f3f6d7bSStella Laurenzo except Exception as e: 814dad10a9aSStella Laurenzo # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module" 8159f3f6d7bSStella Laurenzo print(e) 8169f3f6d7bSStella Laurenzo try: 817f9008e63STobias Hieta Operation.create("builtin.module", attributes={42: StringAttr.get("name")}) 8189f3f6d7bSStella Laurenzo except Exception as e: 819dad10a9aSStella Laurenzo # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module" 8209f3f6d7bSStella Laurenzo print(e) 8219f3f6d7bSStella Laurenzo try: 822f8479d9dSRiver Riddle Operation.create("builtin.module", 
attributes={"some_key": ctx}) 8239f3f6d7bSStella Laurenzo except Exception as e: 824dad10a9aSStella Laurenzo # CHECK: Invalid attribute value for the key "some_key" when attempting to create the operation "builtin.module" 8259f3f6d7bSStella Laurenzo print(e) 8269f3f6d7bSStella Laurenzo try: 827f8479d9dSRiver Riddle Operation.create("builtin.module", attributes={"some_key": None}) 8289f3f6d7bSStella Laurenzo except Exception as e: 829dad10a9aSStella Laurenzo # CHECK: Found an invalid (`None`?) attribute value for the key "some_key" when attempting to create the operation "builtin.module" 8309f3f6d7bSStella Laurenzo print(e) 831a54f4eaeSMogball 832a54f4eaeSMogball 8339f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: testOperationName 83478f2dae0SAlex Zinenko@run 8359f3f6d7bSStella Laurenzodef testOperationName(): 8369f3f6d7bSStella Laurenzo ctx = Context() 8379f3f6d7bSStella Laurenzo ctx.allow_unregistered_dialects = True 838a54f4eaeSMogball module = Module.parse( 839a54f4eaeSMogball r""" 8409f3f6d7bSStella Laurenzo %0 = "custom.op1"() : () -> f32 8419f3f6d7bSStella Laurenzo %1 = "custom.op2"() : () -> i32 8429f3f6d7bSStella Laurenzo %2 = "custom.op1"() : () -> f32 843f9008e63STobias Hieta """, 844f9008e63STobias Hieta ctx, 845f9008e63STobias Hieta ) 8469f3f6d7bSStella Laurenzo 8479f3f6d7bSStella Laurenzo # CHECK: custom.op1 8489f3f6d7bSStella Laurenzo # CHECK: custom.op2 8499f3f6d7bSStella Laurenzo # CHECK: custom.op1 8509f3f6d7bSStella Laurenzo for op in module.body.operations: 8519f3f6d7bSStella Laurenzo print(op.operation.name) 8529f3f6d7bSStella Laurenzo 853a54f4eaeSMogball 85478bd1246SAlex Zinenko# CHECK-LABEL: TEST: testCapsuleConversions 85578bd1246SAlex Zinenko@run 85678bd1246SAlex Zinenkodef testCapsuleConversions(): 85778bd1246SAlex Zinenko ctx = Context() 85878bd1246SAlex Zinenko ctx.allow_unregistered_dialects = True 85978bd1246SAlex Zinenko with Location.unknown(ctx): 86078bd1246SAlex Zinenko m = Operation.create("custom.op1").operation 86178bd1246SAlex 
Zinenko m_capsule = m._CAPIPtr 86278bd1246SAlex Zinenko assert '"mlir.ir.Operation._CAPIPtr"' in repr(m_capsule) 86378bd1246SAlex Zinenko m2 = Operation._CAPICreate(m_capsule) 86478bd1246SAlex Zinenko assert m2 is m 86578bd1246SAlex Zinenko 86678bd1246SAlex Zinenko 8679f3f6d7bSStella Laurenzo# CHECK-LABEL: TEST: testOperationErase 86878f2dae0SAlex Zinenko@run 8699f3f6d7bSStella Laurenzodef testOperationErase(): 8709f3f6d7bSStella Laurenzo ctx = Context() 8719f3f6d7bSStella Laurenzo ctx.allow_unregistered_dialects = True 8729f3f6d7bSStella Laurenzo with Location.unknown(ctx): 8739f3f6d7bSStella Laurenzo m = Module.create() 8749f3f6d7bSStella Laurenzo with InsertionPoint(m.body): 8759f3f6d7bSStella Laurenzo op = Operation.create("custom.op1") 8769f3f6d7bSStella Laurenzo 8779f3f6d7bSStella Laurenzo # CHECK: "custom.op1" 8789f3f6d7bSStella Laurenzo print(m) 8799f3f6d7bSStella Laurenzo 8809f3f6d7bSStella Laurenzo op.operation.erase() 8819f3f6d7bSStella Laurenzo 8829f3f6d7bSStella Laurenzo # CHECK-NOT: "custom.op1" 8839f3f6d7bSStella Laurenzo print(m) 8849f3f6d7bSStella Laurenzo 8859f3f6d7bSStella Laurenzo # Ensure we can create another operation 8869f3f6d7bSStella Laurenzo Operation.create("custom.op2") 887d5429a13Srkayaith 888d5429a13Srkayaith 889774818c0SDominik Grewe# CHECK-LABEL: TEST: testOperationClone 890774818c0SDominik Grewe@run 891774818c0SDominik Grewedef testOperationClone(): 892774818c0SDominik Grewe ctx = Context() 893774818c0SDominik Grewe ctx.allow_unregistered_dialects = True 894774818c0SDominik Grewe with Location.unknown(ctx): 895774818c0SDominik Grewe m = Module.create() 896774818c0SDominik Grewe with InsertionPoint(m.body): 897774818c0SDominik Grewe op = Operation.create("custom.op1") 898774818c0SDominik Grewe 899774818c0SDominik Grewe # CHECK: "custom.op1" 900774818c0SDominik Grewe print(m) 901774818c0SDominik Grewe 902774818c0SDominik Grewe clone = op.operation.clone() 903774818c0SDominik Grewe op.operation.erase() 904774818c0SDominik Grewe 
905774818c0SDominik Grewe # CHECK: "custom.op1" 906774818c0SDominik Grewe print(m) 907774818c0SDominik Grewe 908774818c0SDominik Grewe 909d5429a13Srkayaith# CHECK-LABEL: TEST: testOperationLoc 910d5429a13Srkayaith@run 911d5429a13Srkayaithdef testOperationLoc(): 912d5429a13Srkayaith ctx = Context() 913d5429a13Srkayaith ctx.allow_unregistered_dialects = True 914d5429a13Srkayaith with ctx: 915d5429a13Srkayaith loc = Location.name("loc") 916d5429a13Srkayaith op = Operation.create("custom.op", loc=loc) 917d5429a13Srkayaith assert op.location == loc 918d5429a13Srkayaith assert op.operation.location == loc 91924685aaeSAlex Zinenko 920f78fe0b7Srkayaith 92124685aaeSAlex Zinenko# CHECK-LABEL: TEST: testModuleMerge 92224685aaeSAlex Zinenko@run 92324685aaeSAlex Zinenkodef testModuleMerge(): 92424685aaeSAlex Zinenko with Context(): 925a8308020SRiver Riddle m1 = Module.parse("func.func private @foo()") 926f9008e63STobias Hieta m2 = Module.parse( 927f9008e63STobias Hieta """ 9282310ced8SRiver Riddle func.func private @bar() 9292310ced8SRiver Riddle func.func private @qux() 930f9008e63STobias Hieta """ 931f9008e63STobias Hieta ) 93224685aaeSAlex Zinenko foo = m1.body.operations[0] 93324685aaeSAlex Zinenko bar = m2.body.operations[0] 93424685aaeSAlex Zinenko qux = m2.body.operations[1] 93524685aaeSAlex Zinenko bar.move_before(foo) 93624685aaeSAlex Zinenko qux.move_after(foo) 93724685aaeSAlex Zinenko 93824685aaeSAlex Zinenko # CHECK: module 93924685aaeSAlex Zinenko # CHECK: func private @bar 94024685aaeSAlex Zinenko # CHECK: func private @foo 94124685aaeSAlex Zinenko # CHECK: func private @qux 94224685aaeSAlex Zinenko print(m1) 94324685aaeSAlex Zinenko 94424685aaeSAlex Zinenko # CHECK: module { 94524685aaeSAlex Zinenko # CHECK-NEXT: } 94624685aaeSAlex Zinenko print(m2) 94724685aaeSAlex Zinenko 94824685aaeSAlex Zinenko 94924685aaeSAlex Zinenko# CHECK-LABEL: TEST: testAppendMoveFromAnotherBlock 95024685aaeSAlex Zinenko@run 95124685aaeSAlex Zinenkodef testAppendMoveFromAnotherBlock(): 
95224685aaeSAlex Zinenko with Context(): 953a8308020SRiver Riddle m1 = Module.parse("func.func private @foo()") 954a8308020SRiver Riddle m2 = Module.parse("func.func private @bar()") 95524685aaeSAlex Zinenko func = m1.body.operations[0] 95624685aaeSAlex Zinenko m2.body.append(func) 95724685aaeSAlex Zinenko 95824685aaeSAlex Zinenko # CHECK: module 95924685aaeSAlex Zinenko # CHECK: func private @bar 96024685aaeSAlex Zinenko # CHECK: func private @foo 96124685aaeSAlex Zinenko 96224685aaeSAlex Zinenko print(m2) 96324685aaeSAlex Zinenko # CHECK: module { 96424685aaeSAlex Zinenko # CHECK-NEXT: } 96524685aaeSAlex Zinenko print(m1) 96624685aaeSAlex Zinenko 96724685aaeSAlex Zinenko 96824685aaeSAlex Zinenko# CHECK-LABEL: TEST: testDetachFromParent 96924685aaeSAlex Zinenko@run 97024685aaeSAlex Zinenkodef testDetachFromParent(): 97124685aaeSAlex Zinenko with Context(): 972a8308020SRiver Riddle m1 = Module.parse("func.func private @foo()") 97324685aaeSAlex Zinenko func = m1.body.operations[0].detach_from_parent() 97424685aaeSAlex Zinenko 97524685aaeSAlex Zinenko try: 97624685aaeSAlex Zinenko func.detach_from_parent() 97724685aaeSAlex Zinenko except ValueError as e: 97824685aaeSAlex Zinenko if "has no parent" not in str(e): 97924685aaeSAlex Zinenko raise 98024685aaeSAlex Zinenko else: 98124685aaeSAlex Zinenko assert False, "expected ValueError when detaching a detached operation" 98224685aaeSAlex Zinenko 98324685aaeSAlex Zinenko print(m1) 98424685aaeSAlex Zinenko # CHECK-NOT: func private @foo 98530d61893SAlex Zinenko 98630d61893SAlex Zinenko 987f78fe0b7Srkayaith# CHECK-LABEL: TEST: testOperationHash 988f78fe0b7Srkayaith@run 989f78fe0b7Srkayaithdef testOperationHash(): 990f78fe0b7Srkayaith ctx = Context() 991f78fe0b7Srkayaith ctx.allow_unregistered_dialects = True 992f78fe0b7Srkayaith with ctx, Location.unknown(): 993f78fe0b7Srkayaith op = Operation.create("custom.op1") 994f78fe0b7Srkayaith assert hash(op) == hash(op.operation) 99537107e17Srkayaith 99637107e17Srkayaith 
# CHECK-LABEL: TEST: testOperationParse
@run
def testOperationParse():
    """Generic and op-specific parsing, including errors and source names."""
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True

        # Generic operation parsing.
        m = Operation.parse("module {}")
        o = Operation.parse('"test.foo"() : () -> ()')
        assert isinstance(m, ModuleOp)
        assert type(o) is OpView

        # Parsing specific operation.
        m = ModuleOp.parse("module {}")
        assert isinstance(m, ModuleOp)
        try:
            ModuleOp.parse('"test.foo"() : () -> ()')
        except MLIRError as e:
            # CHECK: error: Expected a 'builtin.module' op, got: 'test.foo'
            print(f"error: {e}")
        else:
            assert False, "expected error"

        o = Operation.parse('"test.foo"() : () -> ()', source_name="my-source-string")
        # CHECK: op_with_source_name: "test.foo"() : () -> () loc("my-source-string":1:1)
        print(
            f"op_with_source_name: {o.get_asm(enable_debug_info=True, use_local_scope=True)}"
        )


# CHECK-LABEL: TEST: testOpWalk
@run
def testOpWalk():
    """Walk order, interrupt/skip results, and callback exception propagation."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
    builtin.module {
      func.func @f() {
        func.return
      }
    }
  """,
        ctx,
    )

    def print_and_advance(op):
        print(op.name)
        return WalkResult.ADVANCE

    # Test post-order walk (default).
    # CHECK-NEXT: Post-order
    # CHECK-NEXT: func.return
    # CHECK-NEXT: func.func
    # CHECK-NEXT: builtin.module
    print("Post-order")
    module.operation.walk(print_and_advance)

    # Test pre-order walk.
    # CHECK-NEXT: Pre-order
    # CHECK-NEXT: builtin.module
    # CHECK-NEXT: func.func
    # CHECK-NEXT: func.return
    print("Pre-order")
    module.operation.walk(print_and_advance, WalkOrder.PRE_ORDER)

    # Test interrupt: the walk stops after the first visited op.
    # CHECK-NEXT: Interrupt post-order
    # CHECK-NEXT: func.return
    print("Interrupt post-order")

    def print_and_interrupt(op):
        print(op.name)
        return WalkResult.INTERRUPT

    module.operation.walk(print_and_interrupt)

    # Test skip: in pre-order, skipping the root prevents visiting its children.
    # CHECK-NEXT: Skip pre-order
    # CHECK-NEXT: builtin.module
    print("Skip pre-order")

    def print_and_skip(op):
        print(op.name)
        return WalkResult.SKIP

    module.operation.walk(print_and_skip, WalkOrder.PRE_ORDER)

    # Test exception: an exception raised in the callback surfaces from walk()
    # as a RuntimeError.
    # CHECK: Exception
    # CHECK-NEXT: func.return
    # CHECK-NEXT: Exception raised
    print("Exception")

    def print_and_raise(op):
        print(op.name)
        raise ValueError

    try:
        module.operation.walk(print_and_raise)
    except RuntimeError:
        print("Exception raised")
    else:
        assert False, "expected RuntimeError from a raising walk callback"