xref: /llvm-project/mlir/test/python/ir/insertion_point.py (revision 5192e299cf444040025ccf3e75bfad36b4624050)
19f3f6d7bSStella Laurenzo# RUN: %PYTHON %s | FileCheck %s
29f3f6d7bSStella Laurenzo
39f3f6d7bSStella Laurenzoimport gc
49f3f6d7bSStella Laurenzofrom mlir.ir import *
59f3f6d7bSStella Laurenzo
6f9008e63STobias Hieta
def run(f):
    """Execute one test function and verify it leaked no MLIR contexts.

    Prints a ``TEST: <name>`` header that the FileCheck labels anchor on,
    calls ``f``, forces a garbage collection and then asserts that no
    ``Context`` objects are still alive.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    f()
    # Collect garbage first so contexts that are no longer referenced are
    # actually released before the live count is inspected.
    gc.collect()
    remaining = Context._get_live_count()
    assert remaining == 0
129f3f6d7bSStella Laurenzo
139f3f6d7bSStella Laurenzo
# CHECK-LABEL: TEST: test_insert_at_block_end
def test_insert_at_block_end():
    """An InsertionPoint built from a bare block appends at the block's end."""
    asm = r"""
      func.func @foo() -> () {
        "custom.op1"() : () -> ()
      }
    """
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        module = Module.parse(asm)
        func_op = module.body.operations[0]
        body = func_op.regions[0].blocks[0]
        point = InsertionPoint(body)
        # A block-only insertion point carries no reference operation; the
        # new op is appended after the existing one.
        assert point.block == body
        assert point.ref_operation is None
        new_op = Operation.create("custom.op2")
        point.insert(new_op)
        # CHECK: "custom.op1"
        # CHECK: "custom.op2"
        module.operation.print()


run(test_insert_at_block_end)
379f3f6d7bSStella Laurenzo
389f3f6d7bSStella Laurenzo
# CHECK-LABEL: TEST: test_insert_before_operation
def test_insert_before_operation():
    """An InsertionPoint built from an operation inserts right before it."""
    asm = r"""
      func.func @foo() -> () {
        "custom.op1"() : () -> ()
        "custom.op2"() : () -> ()
      }
    """
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        module = Module.parse(asm)
        body = module.body.operations[0].regions[0].blocks[0]
        anchor = body.operations[1]
        point = InsertionPoint(anchor)
        # These checks must run before the insert: afterwards index 1 of the
        # block refers to the newly inserted op.
        assert point.block == body
        assert point.ref_operation == anchor
        new_op = Operation.create("custom.op3")
        point.insert(new_op)
        # CHECK: "custom.op1"
        # CHECK: "custom.op3"
        # CHECK: "custom.op2"
        module.operation.print()


run(test_insert_before_operation)
649f3f6d7bSStella Laurenzo
659f3f6d7bSStella Laurenzo
# CHECK-LABEL: TEST: test_insert_at_block_begin
def test_insert_at_block_begin():
    """at_block_begin on a non-empty block inserts before its first op."""
    asm = r"""
      func.func @foo() -> () {
        "custom.op2"() : () -> ()
      }
    """
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        module = Module.parse(asm)
        body = module.body.operations[0].regions[0].blocks[0]
        first_op = body.operations[0]
        point = InsertionPoint.at_block_begin(body)
        assert point.block == body
        assert point.ref_operation == first_op
        new_op = Operation.create("custom.op1")
        point.insert(new_op)
        # CHECK: "custom.op1"
        # CHECK: "custom.op2"
        module.operation.print()


run(test_insert_at_block_begin)
899f3f6d7bSStella Laurenzo
909f3f6d7bSStella Laurenzo
# CHECK-LABEL: TEST: test_insert_at_block_begin_empty
def test_insert_at_block_begin_empty():
    """at_block_begin on an empty block yields a point with no reference op.

    The body block of a freshly created module is used as the empty block;
    the test asserts that it really is empty before relying on it.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        module = Module.create()
        entry_block = module.body
        assert len(entry_block.operations) == 0
        ip = InsertionPoint.at_block_begin(entry_block)
        # There is no first operation to insert before, so the insertion
        # point must not reference any operation.
        assert ip.block == entry_block
        assert ip.ref_operation is None
        ip.insert(Operation.create("custom.op1"))
        # CHECK: "custom.op1"
        module.operation.print()


run(test_insert_at_block_begin_empty)
989f3f6d7bSStella Laurenzo
999f3f6d7bSStella Laurenzo
# CHECK-LABEL: TEST: test_insert_at_terminator
def test_insert_at_terminator():
    """at_block_terminator inserts immediately before the block's terminator."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        module = Module.parse(
            r"""
      func.func @foo() -> () {
        "custom.op1"() : () -> ()
        return
      }
    """
        )
        entry_block = module.body.operations[0].regions[0].blocks[0]
        ip = InsertionPoint.at_block_terminator(entry_block)
        # The reference operation is the `return` at index 1.
        assert ip.block == entry_block
        assert ip.ref_operation == entry_block.operations[1]
        ip.insert(Operation.create("custom.op2"))
        # The terminator must still be the last op: op2 lands before it.
        # CHECK: "custom.op1"
        # CHECK: "custom.op2"
        # CHECK: return
        module.operation.print()


run(test_insert_at_terminator)
1249f3f6d7bSStella Laurenzo
1259f3f6d7bSStella Laurenzo
# CHECK-LABEL: TEST: test_insert_at_block_terminator_missing
def test_insert_at_block_terminator_missing():
    """at_block_terminator raises ValueError for a block with no terminator."""
    asm = r"""
      func.func @foo() -> () {
        "custom.op1"() : () -> ()
      }
    """
    context = Context()
    context.allow_unregistered_dialects = True
    with context:
        module = Module.parse(asm)
        body = module.body.operations[0].regions[0].blocks[0]
        try:
            # Only the raised error matters; the result is never used.
            InsertionPoint.at_block_terminator(body)
        except ValueError as err:
            # CHECK: Block has no terminator
            print(err)
        else:
            assert False, "Expected exception"


run(test_insert_at_block_terminator_missing)
1499f3f6d7bSStella Laurenzo
1509f3f6d7bSStella Laurenzo
# CHECK-LABEL: TEST: test_insert_at_end_with_terminator_errors
def test_insert_at_end_with_terminator_errors():
    """Appending after an existing terminator raises IndexError."""
    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
      func.func @foo() -> () {
        return
      }
    """
        )
        entry_block = module.body.operations[0].regions[0].blocks[0]
        # A block-only insertion point targets the end of the block, i.e.
        # after the existing `return`.
        with InsertionPoint(entry_block):
            try:
                Operation.create("custom.op1", results=[], operands=[])
            except IndexError as e:
                # CHECK: ERROR: Cannot insert operation at the end of a block that already has a terminator.
                print(f"ERROR: {e}")
            else:
                # Fail loudly instead of silently passing if no error occurs.
                assert False, "Expected exception"


run(test_insert_at_end_with_terminator_errors)
1729f3f6d7bSStella Laurenzo
1739f3f6d7bSStella Laurenzo
# CHECK-LABEL: TEST: test_insertion_point_context
def test_insertion_point_context():
    """Ops created under ``with InsertionPoint(...)`` use the innermost point.

    Once the nested point is exited, creation resumes at the outer point.
    """
    asm = r"""
      func.func @foo() -> () {
        "custom.op1"() : () -> ()
      }
    """
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        module = Module.parse(asm)
        body = module.body.operations[0].regions[0].blocks[0]
        with InsertionPoint(body):
            Operation.create("custom.op2")
            # The nested point is created while op1 is the first op, so opa
            # and opb both land ahead of op1, in creation order.
            with InsertionPoint.at_block_begin(body):
                for op_name in ("custom.opa", "custom.opb"):
                    Operation.create(op_name)
            Operation.create("custom.op3")
        # CHECK: "custom.opa"
        # CHECK: "custom.opb"
        # CHECK: "custom.op1"
        # CHECK: "custom.op2"
        # CHECK: "custom.op3"
        module.operation.print()


run(test_insertion_point_context)
202