# RUN: %PYTHON %s | FileCheck %s

import gc
from mlir.ir import *


def run(f):
    """Run one test function and verify that no Context objects leak.

    Prints a ``TEST: <name>`` header (matched by the ``CHECK-LABEL``
    directives below), calls ``f``, forces a garbage collection and then
    asserts that the live ``Context`` count has dropped back to zero.
    """
    print("\nTEST:", f.__name__)
    f()
    gc.collect()
    assert Context._get_live_count() == 0


# CHECK-LABEL: TEST: test_insert_at_block_end
def test_insert_at_block_end():
    """An InsertionPoint built from a block inserts after the existing ops."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        module = Module.parse(
            r"""
      func.func @foo() -> () {
        "custom.op1"() : () -> ()
      }
    """
        )
        entry_block = module.body.operations[0].regions[0].blocks[0]
        ip = InsertionPoint(entry_block)
        assert ip.block == entry_block
        # A block-based insertion point has no reference operation.
        assert ip.ref_operation is None
        ip.insert(Operation.create("custom.op2"))
        # CHECK: "custom.op1"
        # CHECK: "custom.op2"
        module.operation.print()


run(test_insert_at_block_end)


# CHECK-LABEL: TEST: test_insert_before_operation
def test_insert_before_operation():
    """An InsertionPoint built from an operation inserts before that op."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        module = Module.parse(
            r"""
      func.func @foo() -> () {
        "custom.op1"() : () -> ()
        "custom.op2"() : () -> ()
      }
    """
        )
        entry_block = module.body.operations[0].regions[0].blocks[0]
        ip = InsertionPoint(entry_block.operations[1])
        assert ip.block == entry_block
        assert ip.ref_operation == entry_block.operations[1]
        ip.insert(Operation.create("custom.op3"))
        # CHECK: "custom.op1"
        # CHECK: "custom.op3"
        # CHECK: "custom.op2"
        module.operation.print()


run(test_insert_before_operation)


# CHECK-LABEL: TEST: test_insert_at_block_begin
def test_insert_at_block_begin():
    """InsertionPoint.at_block_begin references the block's first op."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        module = Module.parse(
            r"""
      func.func @foo() -> () {
        "custom.op2"() : () -> ()
      }
    """
        )
        entry_block = module.body.operations[0].regions[0].blocks[0]
        ip = InsertionPoint.at_block_begin(entry_block)
        assert ip.block == entry_block
        assert ip.ref_operation == entry_block.operations[0]
        ip.insert(Operation.create("custom.op1"))
        # CHECK: "custom.op1"
        # CHECK: "custom.op2"
        module.operation.print()


run(test_insert_at_block_begin)


# CHECK-LABEL: TEST: test_insert_at_block_begin_empty
def test_insert_at_block_begin_empty():
    """Placeholder for at_block_begin on an empty block (not yet written)."""
    # TODO: Write this test case when we can create such a situation.
    pass


run(test_insert_at_block_begin_empty)


# CHECK-LABEL: TEST: test_insert_at_terminator
def test_insert_at_terminator():
    """InsertionPoint.at_block_terminator references the trailing ``return``."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        module = Module.parse(
            r"""
      func.func @foo() -> () {
        "custom.op1"() : () -> ()
        return
      }
    """
        )
        entry_block = module.body.operations[0].regions[0].blocks[0]
        ip = InsertionPoint.at_block_terminator(entry_block)
        assert ip.block == entry_block
        # operations[1] is the `return` that ends the block.
        assert ip.ref_operation == entry_block.operations[1]
        ip.insert(Operation.create("custom.op2"))
        # CHECK: "custom.op1"
        # CHECK: "custom.op2"
        module.operation.print()


run(test_insert_at_terminator)


# CHECK-LABEL: TEST: test_insert_at_block_terminator_missing
def test_insert_at_block_terminator_missing():
    """at_block_terminator raises ValueError when the block has no terminator."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with ctx:
        module = Module.parse(
            r"""
      func.func @foo() -> () {
        "custom.op1"() : () -> ()
      }
    """
        )
        entry_block = module.body.operations[0].regions[0].blocks[0]
        try:
            # Only the raised error matters; the result is intentionally unused.
            InsertionPoint.at_block_terminator(entry_block)
        except ValueError as e:
            # CHECK: Block has no terminator
            print(e)
        else:
            assert False, "Expected exception"


run(test_insert_at_block_terminator_missing)


# CHECK-LABEL: TEST: test_insert_at_end_with_terminator_errors
def test_insert_at_end_with_terminator_errors():
    """Creating an op at the end of an already-terminated block raises IndexError."""
    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
      func.func @foo() -> () {
        return
      }
    """
        )
        entry_block = module.body.operations[0].regions[0].blocks[0]
        with InsertionPoint(entry_block):
            try:
                Operation.create("custom.op1", results=[], operands=[])
            except IndexError as e:
                # CHECK: ERROR: Cannot insert operation at the end of a block that already has a terminator.
                print(f"ERROR: {e}")
            else:
                # Fail fast in Python too, mirroring the sibling
                # terminator-missing test, rather than relying only on FileCheck.
                assert False, "Expected exception"


run(test_insert_at_end_with_terminator_errors)


# CHECK-LABEL: TEST: test_insertion_point_context
def test_insertion_point_context():
    """Nested ``with InsertionPoint`` blocks direct where created ops land.

    Ops created inside the inner ``at_block_begin`` context go to the start
    of the block; once it exits, creation resumes at the outer (block end)
    insertion point, as pinned by the CHECK lines below.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        module = Module.parse(
            r"""
      func.func @foo() -> () {
        "custom.op1"() : () -> ()
      }
    """
        )
        entry_block = module.body.operations[0].regions[0].blocks[0]
        with InsertionPoint(entry_block):
            Operation.create("custom.op2")
            with InsertionPoint.at_block_begin(entry_block):
                Operation.create("custom.opa")
                Operation.create("custom.opb")
            Operation.create("custom.op3")
        # CHECK: "custom.opa"
        # CHECK: "custom.opb"
        # CHECK: "custom.op1"
        # CHECK: "custom.op2"
        # CHECK: "custom.op3"
        module.operation.print()


run(test_insertion_point_context)