# RUN: %PYTHON %s | FileCheck %s

import gc
from mlir.ir import *


def run(f):
    """Run one test function under a banner and verify that no Context leaks.

    The banner printed here is the text that the label directive placed above
    each test function matches in the captured output.
    """
    print("\nTEST:", f.__name__)
    f()
    gc.collect()
    # Every Context created by the test must have been released by now.
    assert Context._get_live_count() == 0


# CHECK-LABEL: TEST: test_insert_at_block_end
def test_insert_at_block_end():
    """Insert through an InsertionPoint constructed directly from a block."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        module = Module.parse(
            r"""
      func.func @foo() -> () {
        "custom.op1"() : () -> ()
      }
    """
        )
        entry_block = module.body.operations[0].regions[0].blocks[0]
        ip = InsertionPoint(entry_block)
        assert ip.block == entry_block
        # A block-based insertion point has no reference operation.
        assert ip.ref_operation is None
        ip.insert(Operation.create("custom.op2"))
        # CHECK: "custom.op1"
        # CHECK: "custom.op2"
        module.operation.print()


run(test_insert_at_block_end)


# CHECK-LABEL: TEST: test_insert_before_operation
def test_insert_before_operation():
    """Insert through an InsertionPoint constructed from an operation."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        module = Module.parse(
            r"""
      func.func @foo() -> () {
        "custom.op1"() : () -> ()
        "custom.op2"() : () -> ()
      }
    """
        )
        entry_block = module.body.operations[0].regions[0].blocks[0]
        ip = InsertionPoint(entry_block.operations[1])
        assert ip.block == entry_block
        assert ip.ref_operation == entry_block.operations[1]
        ip.insert(Operation.create("custom.op3"))
        # The new operation is expected between the two parsed ones.
        # CHECK: "custom.op1"
        # CHECK: "custom.op3"
        # CHECK: "custom.op2"
        module.operation.print()


run(test_insert_before_operation)


# CHECK-LABEL: TEST: test_insert_at_block_begin
def test_insert_at_block_begin():
    """Insert through InsertionPoint.at_block_begin on a non-empty block."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        module = Module.parse(
            r"""
      func.func @foo() -> () {
        "custom.op2"() : () -> ()
      }
    """
        )
        entry_block = module.body.operations[0].regions[0].blocks[0]
        ip = InsertionPoint.at_block_begin(entry_block)
        assert ip.block == entry_block
        # The reference operation is the block's current first operation.
        assert ip.ref_operation == entry_block.operations[0]
        ip.insert(Operation.create("custom.op1"))
        # CHECK: "custom.op1"
        # CHECK: "custom.op2"
        module.operation.print()


run(test_insert_at_block_begin)


# CHECK-LABEL: TEST: test_insert_at_block_begin_empty
def test_insert_at_block_begin_empty():
    """Placeholder for at_block_begin on an empty block (not yet written)."""
    # TODO: Write this test case when we can create such a situation.
    pass


run(test_insert_at_block_begin_empty)


# CHECK-LABEL: TEST: test_insert_at_terminator
def test_insert_at_terminator():
    """Insert through InsertionPoint.at_block_terminator."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        module = Module.parse(
            r"""
      func.func @foo() -> () {
        "custom.op1"() : () -> ()
        return
      }
    """
        )
        entry_block = module.body.operations[0].regions[0].blocks[0]
        ip = InsertionPoint.at_block_terminator(entry_block)
        assert ip.block == entry_block
        # The reference operation is the second operation, i.e. the `return`.
        assert ip.ref_operation == entry_block.operations[1]
        ip.insert(Operation.create("custom.op2"))
        # CHECK: "custom.op1"
        # CHECK: "custom.op2"
        module.operation.print()


run(test_insert_at_terminator)


# CHECK-LABEL: TEST: test_insert_at_block_terminator_missing
def test_insert_at_block_terminator_missing():
    """at_block_terminator must raise ValueError when there is no terminator."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with ctx:
        module = Module.parse(
            r"""
      func.func @foo() -> () {
        "custom.op1"() : () -> ()
      }
    """
        )
        entry_block = module.body.operations[0].regions[0].blocks[0]
        try:
            InsertionPoint.at_block_terminator(entry_block)
        except ValueError as e:
            # CHECK: Block has no terminator
            print(e)
        else:
            assert False, "Expected exception"


run(test_insert_at_block_terminator_missing)


# CHECK-LABEL: TEST: test_insert_at_end_with_terminator_errors
def test_insert_at_end_with_terminator_errors():
    """Creating an op at the end of an already-terminated block must fail."""
    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
      func.func @foo() -> () {
        return
      }
    """
        )
        entry_block = module.body.operations[0].regions[0].blocks[0]
        with InsertionPoint(entry_block):
            try:
                Operation.create("custom.op1", results=[], operands=[])
            except IndexError as e:
                # CHECK: ERROR: Cannot insert operation at the end of a block that already has a terminator.
                print(f"ERROR: {e}")
            else:
                # Fail loudly in Python too, mirroring the sibling
                # missing-terminator test.
                assert False, "Expected exception"


run(test_insert_at_end_with_terminator_errors)


# CHECK-LABEL: TEST: test_insertion_point_context
def test_insertion_point_context():
    """Use InsertionPoint as a (nested) context manager for Operation.create."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        module = Module.parse(
            r"""
      func.func @foo() -> () {
        "custom.op1"() : () -> ()
      }
    """
        )
        entry_block = module.body.operations[0].regions[0].blocks[0]
        with InsertionPoint(entry_block):
            Operation.create("custom.op2")
            # The inner insertion point applies only inside its own `with`;
            # per the expected output below, custom.op3 goes back to the
            # end of the block.
            with InsertionPoint.at_block_begin(entry_block):
                Operation.create("custom.opa")
                Operation.create("custom.opb")
            Operation.create("custom.op3")
        # CHECK: "custom.opa"
        # CHECK: "custom.opb"
        # CHECK: "custom.op1"
        # CHECK: "custom.op2"
        # CHECK: "custom.op3"
        module.operation.print()


run(test_insertion_point_context)