# RUN: %PYTHON %s | FileCheck %s

import gc
import io
import itertools
from mlir.ir import *
from mlir.dialects import builtin
from mlir.dialects import cf
from mlir.dialects import func


def run(f):
    """Run a single test function and check for leaked contexts.

    Prints a ``TEST: <name>`` header (matched by the label directives in this
    file), calls ``f``, forces a garbage collection, and asserts that no MLIR
    ``Context`` objects are still alive. Returns ``f`` unchanged so this can be
    used as a decorator that executes the test at import time.
    """
    print("\nTEST:", f.__name__)
    f()
    gc.collect()
    # Every test creates its own Context; none may outlive the test body.
    assert Context._get_live_count() == 0
    return f


# CHECK-LABEL: TEST: testBlockCreation
# CHECK: func @test(%[[ARG0:.*]]: i32 loc("arg0"), %[[ARG1:.*]]: i16 loc("arg1"))
# CHECK: cf.br ^bb1(%[[ARG1]] : i16)
# CHECK: ^bb1(%[[PHI0:.*]]: i16 loc("middle")):
# CHECK: cf.br ^bb2(%[[ARG0]] : i32)
# CHECK: ^bb2(%[[PHI1:.*]]: i32 loc("successor")):
# CHECK: return
@run
def testBlockCreation():
    """Create blocks with create_after/create_before and branch between them."""
    with Context(), Location.unknown():
        module = builtin.ModuleOp()
        with InsertionPoint(module.body):
            f_type = FunctionType.get(
                [IntegerType.get_signless(32), IntegerType.get_signless(16)], []
            )
            f_op = func.FuncOp("test", f_type)
            entry_block = f_op.add_entry_block(
                [Location.name("arg0"), Location.name("arg1")]
            )
            i32_arg, i16_arg = entry_block.arguments
            successor_block = entry_block.create_after(
                i32_arg.type, arg_locs=[Location.name("successor")]
            )
            with InsertionPoint(successor_block) as successor_ip:
                assert successor_ip.block == successor_block
                func.ReturnOp([])
            # Inserted between entry_block and successor_block.
            middle_block = successor_block.create_before(
                i16_arg.type, arg_locs=[Location.name("middle")]
            )

            with InsertionPoint(entry_block) as entry_ip:
                assert entry_ip.block == entry_block
                cf.BranchOp([i16_arg], dest=middle_block)

            with InsertionPoint(middle_block) as middle_ip:
                assert middle_ip.block == middle_block
                cf.BranchOp([i32_arg], dest=successor_block)
        module.print(enable_debug_info=True)
        # Ensure region back references are coherent.
        assert entry_block.region == middle_block.region == successor_block.region


# CHECK-LABEL: TEST: testBlockCreationArgLocs
@run
def testBlockCreationArgLocs():
    """Check how block argument locations are taken or rejected on append."""
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True
        f32 = F32Type.get()
        op = Operation.create("test", regions=1, loc=Location.unknown())
        blocks = op.regions[0].blocks

        # The argument location comes from the enclosing Location context.
        with Location.name("default_loc"):
            blocks.append(f32)
        blocks.append()
        # CHECK: ^bb0(%{{.+}}: f32 loc("default_loc")):
        # CHECK-NEXT: ^bb1:
        op.print(enable_debug_info=True)

        # No Location context is active here, so appending an argument must fail.
        try:
            blocks.append(f32)
        except RuntimeError as err:
            # CHECK: Missing loc: An MLIR function requires a Location but none was provided
            print("Missing loc:", err)
        else:
            raise AssertionError("expected RuntimeError for missing location")

        # Two argument types but only one explicit location must be rejected.
        try:
            blocks.append(f32, f32, arg_locs=[Location.unknown()])
        except ValueError as err:
            # CHECK: Wrong loc count: Expected 2 locations, got: 1
            print("Wrong loc count:", err)
        else:
            raise AssertionError("expected ValueError for mismatched location count")


# CHECK-LABEL: TEST: testFirstBlockCreation
# CHECK: func @test(%{{.*}}: f32 loc("arg_loc"))
# CHECK: return
@run
def testFirstBlockCreation():
    """Create a function's first block with Block.create_at_start."""
    with Context(), Location.unknown():
        module = builtin.ModuleOp()
        f32 = F32Type.get()
        with InsertionPoint(module.body):
            f = func.FuncOp("test", ([f32], []))
            entry_block = Block.create_at_start(
                f.operation.regions[0], [f32], [Location.name("arg_loc")]
            )
            with InsertionPoint(entry_block):
                func.ReturnOp([])

        module.print(enable_debug_info=True)
        assert module.verify()
        assert f.body.blocks[0] == entry_block


# CHECK-LABEL: TEST: testBlockMove
# CHECK:  %0 = "realop"() ({
# CHECK:  ^bb0([[ARG0:%.+]]: f32):
# CHECK:    "ret"([[ARG0]]) : (f32) -> ()
# CHECK:  }) : () -> f32
@run
def testBlockMove():
    """Move a populated block from one op's region into another's."""
    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):
            dummy = Operation.create("dummy", regions=1)
            block = Block.create_at_start(dummy.operation.regions[0], [f32])
            with InsertionPoint(block):
                ret_op = Operation.create("ret", operands=[block.arguments[0]])
            realop = Operation.create(
                "realop", results=[r.type for r in ret_op.operands], regions=1
            )
            block.append_to(realop.operation.regions[0])
            # The block now lives in realop, so erasing dummy must not touch it.
            dummy.operation.erase()
        print(module)


# CHECK-LABEL: TEST: testBlockHash
@run
def testBlockHash():
    """Check that two distinct blocks hash differently."""
    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):
            dummy = Operation.create("dummy", regions=1)
            block1 = Block.create_at_start(dummy.operation.regions[0], [f32])
            block2 = Block.create_at_start(dummy.operation.regions[0], [f32])
            assert hash(block1) != hash(block2)


# CHECK-LABEL: TEST: testBlockAddArgs
@run
def testBlockAddArgs():
    """Add an argument to an existing argument-less block."""
    with Context() as ctx, Location.unknown(ctx) as loc:
        ctx.allow_unregistered_dialects = True
        f32 = F32Type.get()
        op = Operation.create("test", regions=1, loc=Location.unknown())
        blocks = op.regions[0].blocks
        blocks.append()
        # CHECK: ^bb0:
        op.print(enable_debug_info=True)
        blocks[0].add_argument(f32, loc)
        # CHECK: ^bb0(%{{.+}}: f32 loc(unknown)):
        op.print(enable_debug_info=True)


# CHECK-LABEL: TEST: testBlockEraseArgs
@run
def testBlockEraseArgs():
    """Erase the only argument of a block."""
    with Context() as ctx, Location.unknown(ctx) as loc:
        ctx.allow_unregistered_dialects = True
        f32 = F32Type.get()
        op = Operation.create("test", regions=1, loc=Location.unknown())
        blocks = op.regions[0].blocks
        blocks.append(f32)
        # CHECK: ^bb0(%{{.+}}: f32 loc(unknown)):
        op.print(enable_debug_info=True)
        blocks[0].erase_argument(0)
        # CHECK: ^bb0:
        op.print(enable_debug_info=True)