# xref: /llvm-project/mlir/test/python/ir/insertion_point.py (revision 5192e299cf444040025ccf3e75bfad36b4624050)
# RUN: %PYTHON %s | FileCheck %s
2
3import gc
4from mlir.ir import *
5
6
def run(f):
    """Run one test function and verify it did not leak an MLIR context.

    A ``TEST: <name>`` banner is printed first so each test's output can be
    anchored by the FileCheck labels in this file. After ``f`` returns, a
    garbage collection is forced and the run fails if any ``Context``
    instances are still alive.
    """
    banner = f"\nTEST: {f.__name__}"
    print(banner)
    f()
    # Collect first so only contexts that are genuinely still referenced
    # are counted as live.
    gc.collect()
    live_contexts = Context._get_live_count()
    assert live_contexts == 0


# CHECK-LABEL: TEST: test_insert_at_block_end
def test_insert_at_block_end():
    """An insertion point built from a bare block appends at the block's end."""
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        module = Module.parse(
            r"""
      func.func @foo() -> () {
        "custom.op1"() : () -> ()
      }
    """
        )
        func_op = module.body.operations[0]
        body_block = func_op.regions[0].blocks[0]
        insertion_point = InsertionPoint(body_block)
        # No reference operation: insertion happens at the end of the block.
        assert insertion_point.block == body_block
        assert insertion_point.ref_operation is None
        new_op = Operation.create("custom.op2")
        insertion_point.insert(new_op)
        # CHECK: "custom.op1"
        # CHECK: "custom.op2"
        module.operation.print()


run(test_insert_at_block_end)


# CHECK-LABEL: TEST: test_insert_before_operation
def test_insert_before_operation():
    """An insertion point anchored on an operation inserts right before it."""
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        module = Module.parse(
            r"""
      func.func @foo() -> () {
        "custom.op1"() : () -> ()
        "custom.op2"() : () -> ()
      }
    """
        )
        body_block = module.body.operations[0].regions[0].blocks[0]
        anchor = body_block.operations[1]
        insertion_point = InsertionPoint(anchor)
        assert insertion_point.block == body_block
        assert insertion_point.ref_operation == anchor
        insertion_point.insert(Operation.create("custom.op3"))
        # The new op is expected between the two pre-existing ones.
        # CHECK: "custom.op1"
        # CHECK: "custom.op3"
        # CHECK: "custom.op2"
        module.operation.print()


run(test_insert_before_operation)


# CHECK-LABEL: TEST: test_insert_at_block_begin
def test_insert_at_block_begin():
    """``InsertionPoint.at_block_begin`` anchors on the block's first op."""
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        module = Module.parse(
            r"""
      func.func @foo() -> () {
        "custom.op2"() : () -> ()
      }
    """
        )
        body_block = module.body.operations[0].regions[0].blocks[0]
        insertion_point = InsertionPoint.at_block_begin(body_block)
        first_op = body_block.operations[0]
        assert insertion_point.block == body_block
        assert insertion_point.ref_operation == first_op
        new_op = Operation.create("custom.op1")
        insertion_point.insert(new_op)
        # CHECK: "custom.op1"
        # CHECK: "custom.op2"
        module.operation.print()


run(test_insert_at_block_begin)


# CHECK-LABEL: TEST: test_insert_at_block_begin_empty
def test_insert_at_block_begin_empty():
    """Placeholder for inserting at the beginning of an empty block."""
    # TODO: Write this test case when we can create such a situation.


run(test_insert_at_block_begin_empty)


# CHECK-LABEL: TEST: test_insert_at_terminator
def test_insert_at_terminator():
    """``InsertionPoint.at_block_terminator`` inserts just before the terminator."""
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        module = Module.parse(
            r"""
      func.func @foo() -> () {
        "custom.op1"() : () -> ()
        return
      }
    """
        )
        body_block = module.body.operations[0].regions[0].blocks[0]
        insertion_point = InsertionPoint.at_block_terminator(body_block)
        terminator = body_block.operations[1]
        assert insertion_point.block == body_block
        assert insertion_point.ref_operation == terminator
        insertion_point.insert(Operation.create("custom.op2"))
        # The new op is expected after custom.op1, ahead of the `return`.
        # CHECK: "custom.op1"
        # CHECK: "custom.op2"
        module.operation.print()


run(test_insert_at_terminator)


# CHECK-LABEL: TEST: test_insert_at_block_terminator_missing
def test_insert_at_block_terminator_missing():
    """Requesting the terminator of a block that lacks one raises ValueError."""
    context = Context()
    context.allow_unregistered_dialects = True
    with context:
        module = Module.parse(
            r"""
      func.func @foo() -> () {
        "custom.op1"() : () -> ()
      }
    """
        )
        body_block = module.body.operations[0].regions[0].blocks[0]
        try:
            # Only the raised error matters; the result is never used.
            InsertionPoint.at_block_terminator(body_block)
        except ValueError as err:
            # CHECK: Block has no terminator
            print(err)
        else:
            assert False, "Expected exception"


run(test_insert_at_block_terminator_missing)


# CHECK-LABEL: TEST: test_insert_at_end_with_terminator_errors
def test_insert_at_end_with_terminator_errors():
    """Creating an op at the end of a block that has a terminator must fail.

    ``with InsertionPoint(entry_block)`` targets the end of the block. Here the
    block already ends in ``return``, so ``Operation.create`` is expected to
    raise ``IndexError`` rather than append past the terminator.
    """
    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
      func.func @foo() -> () {
        return
      }
    """
        )
        entry_block = module.body.operations[0].regions[0].blocks[0]
        with InsertionPoint(entry_block):
            try:
                Operation.create("custom.op1", results=[], operands=[])
            except IndexError as e:
                # CHECK: ERROR: Cannot insert operation at the end of a block that already has a terminator.
                print(f"ERROR: {e}")
            else:
                # Fail loudly, matching the sibling negative test. Otherwise a
                # regression that silently appends past the terminator would
                # only surface indirectly as a FileCheck mismatch.
                assert False, "Expected exception"


run(test_insert_at_end_with_terminator_errors)


# CHECK-LABEL: TEST: test_insertion_point_context
def test_insertion_point_context():
    """Verify that ``with InsertionPoint`` contexts nest and restore.

    ``Operation.create`` is called without an explicit insertion point, so
    each op goes to whichever ``with InsertionPoint`` context is innermost at
    that moment. The expected final order, per the FileCheck directives
    below, is opa, opb, op1, op2, op3.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        module = Module.parse(
            r"""
      func.func @foo() -> () {
        "custom.op1"() : () -> ()
      }
    """
        )
        entry_block = module.body.operations[0].regions[0].blocks[0]
        # Outer context: append at the end of the block.
        with InsertionPoint(entry_block):
            Operation.create("custom.op2")
            # Inner context temporarily overrides the outer one. It is anchored
            # on the block's first op at creation time (custom.op1), so both
            # ops below land before custom.op1, in creation order.
            with InsertionPoint.at_block_begin(entry_block):
                Operation.create("custom.opa")
                Operation.create("custom.opb")
            # After the inner `with` exits, the outer end-of-block context is
            # active again.
            Operation.create("custom.op3")
        # CHECK: "custom.opa"
        # CHECK: "custom.opb"
        # CHECK: "custom.op1"
        # CHECK: "custom.op2"
        # CHECK: "custom.op3"
        module.operation.print()


run(test_insertion_point_context)
