xref: /llvm-project/mlir/test/python/dialects/ods_helpers.py (revision 3766ba44a8945681f4c52acb0331efcff66ef7b1)
1# RUN: %PYTHON %s | FileCheck %s
2
3import gc
4
5from mlir.ir import *
6from mlir.dialects._ods_common import equally_sized_accessor
7
8
def run(f):
    """Run test function ``f`` and verify that no MLIR contexts stay alive.

    Prints a banner with the function name, calls ``f``, forces a garbage
    collection, and asserts that ``Context._get_live_count()`` is zero.
    """
    print(f"\nTEST: {f.__name__}")
    f()
    gc.collect()
    live_contexts = Context._get_live_count()
    assert live_contexts == 0
14
15
def add_dummy_value():
    """Create a ``custom.value`` op with one signless 32-bit integer result.

    Returns the ``result`` of the created operation.
    """
    result_type = IntegerType.get_signless(32)
    dummy_op = Operation.create("custom.value", results=[result_type])
    return dummy_op.result
20
21
def testOdsBuildDefaultImplicitRegions():
    """Exercise the ``regions=`` handling of ``OpView.build_generic``.

    Two local op classes are declared: one with ``_ODS_REGIONS = (2, True)``
    (fixed region count) and one with ``_ODS_REGIONS = (2, False)`` (variadic
    region count).  Each is built with and without an explicit ``regions=``;
    the test prints either the resulting number of regions or the text of the
    raised ``ValueError`` so that FileCheck can verify the output.
    """

    class TestFixedRegionsOp(OpView):
        OPERATION_NAME = "custom.test_op"
        _ODS_REGIONS = (2, True)

    class TestVariadicRegionsOp(OpView):
        OPERATION_NAME = "custom.test_any_regions_op"
        _ODS_REGIONS = (2, False)

    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        m = Module.create()
        with InsertionPoint(m.body):
            op = TestFixedRegionsOp.build_generic(results=[], operands=[])
            # CHECK: NUM_REGIONS: 2
            print(f"NUM_REGIONS: {len(op.regions)}")
            # Including a regions= that matches should be fine.
            op = TestFixedRegionsOp.build_generic(results=[], operands=[], regions=2)
            # CHECK: NUM_REGIONS: 2
            print(f"NUM_REGIONS: {len(op.regions)}")
            # Reject greater than.
            try:
                TestFixedRegionsOp.build_generic(results=[], operands=[], regions=3)
            except ValueError as e:
                # CHECK: ERROR:Operation "custom.test_op" requires a maximum of 2 regions but was built with regions=3
                print(f"ERROR:{e}")
            # Reject less than.
            try:
                TestFixedRegionsOp.build_generic(results=[], operands=[], regions=1)
            except ValueError as e:
                # CHECK: ERROR:Operation "custom.test_op" requires a minimum of 2 regions but was built with regions=1
                print(f"ERROR:{e}")

            # If no regions specified for a variadic region op, build the minimum.
            op = TestVariadicRegionsOp.build_generic(results=[], operands=[])
            # CHECK: DEFAULT_NUM_REGIONS: 2
            print(f"DEFAULT_NUM_REGIONS: {len(op.regions)}")
            # Should also accept an explicit regions= that matches the minimum.
            op = TestVariadicRegionsOp.build_generic(results=[], operands=[], regions=2)
            # CHECK: EQ_NUM_REGIONS: 2
            print(f"EQ_NUM_REGIONS: {len(op.regions)}")
            # And accept greater than minimum.
            op = TestVariadicRegionsOp.build_generic(results=[], operands=[], regions=3)
            # CHECK: GT_NUM_REGIONS: 3
            print(f"GT_NUM_REGIONS: {len(op.regions)}")
            # Should reject less than minimum.
            try:
                TestVariadicRegionsOp.build_generic(results=[], operands=[], regions=1)
            except ValueError as e:
                # CHECK: ERROR:Operation "custom.test_any_regions_op" requires a minimum of 2 regions but was built with regions=1
                print(f"ERROR:{e}")


run(testOdsBuildDefaultImplicitRegions)
82
83
def testOdsBuildDefaultNonVariadic():
    """Build an op whose class declares no segment specs and print the module.

    Two dummy values are passed as a flat operand list together with two
    result types (signless 8- and 16-bit integers).  The printed module is
    verified by FileCheck, including the absence of segment-size attributes.
    """

    class TestOp(OpView):
        OPERATION_NAME = "custom.test_op"

    with Context() as context:
        with Location.unknown():
            context.allow_unregistered_dialects = True
            module = Module.create()
            with InsertionPoint(module.body):
                lhs, rhs = (add_dummy_value() for _ in range(2))
                result_types = [IntegerType.get_signless(width) for width in (8, 16)]
                built_op = TestOp.build_generic(
                    operands=[lhs, rhs], results=result_types
                )
                # CHECK: %[[V0:.+]] = "custom.value"
                # CHECK: %[[V1:.+]] = "custom.value"
                # CHECK: "custom.test_op"(%[[V0]], %[[V1]])
                # CHECK-NOT: operandSegmentSizes
                # CHECK-NOT: resultSegmentSizes
                # CHECK-SAME: : (i32, i32) -> (i8, i16)
                print(module)


run(testOdsBuildDefaultNonVariadic)
107
108
def testOdsBuildDefaultSizedVariadic():
    """Exercise ``build_generic`` with operand and result segment specs.

    The op class declares ``_ODS_OPERAND_SEGMENTS = [1, -1, 0]`` and
    ``_ODS_RESULT_SEGMENTS = [-1, 0, 1]``.  The op is built once with every
    segment populated and once with ``None`` in the segments that accept it,
    then the module is printed.  Afterwards several invalid ``None``
    placements are attempted and the text of each ``ValueError`` is printed
    for FileCheck to verify.
    """

    class TestOp(OpView):
        OPERATION_NAME = "custom.test_op"
        _ODS_OPERAND_SEGMENTS = [1, -1, 0]
        _ODS_RESULT_SEGMENTS = [-1, 0, 1]

    with Context() as context:
        with Location.unknown():
            context.allow_unregistered_dialects = True
            module = Module.create()
            with InsertionPoint(module.body):
                val0, val1, val2, val3 = (add_dummy_value() for _ in range(4))
                i8, i16, i32, i64 = (
                    IntegerType.get_signless(width) for width in (8, 16, 32, 64)
                )
                # CHECK: %[[V0:.+]] = "custom.value"
                # CHECK: %[[V1:.+]] = "custom.value"
                # CHECK: %[[V2:.+]] = "custom.value"
                # CHECK: %[[V3:.+]] = "custom.value"
                # CHECK: "custom.test_op"(%[[V0]], %[[V1]], %[[V2]], %[[V3]])
                # CHECK-SAME: operandSegmentSizes = array<i32: 1, 2, 1>
                # CHECK-SAME: resultSegmentSizes = array<i32: 2, 1, 1>
                # CHECK-SAME: : (i32, i32, i32, i32) -> (i8, i16, i32, i64)
                built_op = TestOp.build_generic(
                    operands=[val0, [val1, val2], val3],
                    results=[[i8, i16], i32, i64],
                )

                # Now test with optional omitted.
                # CHECK: "custom.test_op"(%[[V0]])
                # CHECK-SAME: operandSegmentSizes = array<i32: 1, 0, 0>
                # CHECK-SAME: resultSegmentSizes = array<i32: 0, 0, 1>
                # CHECK-SAME: (i32) -> i64
                built_op = TestOp.build_generic(
                    operands=[val0, None, None],
                    results=[None, None, i64],
                )
                print(module)

                # And verify that errors are raised for None in a required operand.
                try:
                    built_op = TestOp.build_generic(
                        operands=[None, None, None],
                        results=[None, None, i64],
                    )
                except ValueError as err:
                    # CHECK: OPERAND_CAST_ERROR:Operand 0 of operation "custom.test_op" must be a Value (was None and operand is not optional)
                    print(f"OPERAND_CAST_ERROR:{err}")

                # And verify that errors are raised for None in a required result.
                try:
                    built_op = TestOp.build_generic(
                        operands=[val0, None, None],
                        results=[None, None, None],
                    )
                except ValueError as err:
                    # CHECK: RESULT_CAST_ERROR:Result 2 of operation "custom.test_op" must be a Type (was None and result is not optional)
                    print(f"RESULT_CAST_ERROR:{err}")

                # Variadic lists with None elements should reject.
                try:
                    built_op = TestOp.build_generic(
                        operands=[val0, [None], None],
                        results=[None, None, i64],
                    )
                except ValueError as err:
                    # CHECK: OPERAND_LIST_CAST_ERROR:Operand 1 of operation "custom.test_op" must be a Sequence of Values (contained a None item)
                    print(f"OPERAND_LIST_CAST_ERROR:{err}")
                try:
                    built_op = TestOp.build_generic(
                        operands=[val0, None, None],
                        results=[[None], None, i64],
                    )
                except ValueError as err:
                    # CHECK: RESULT_LIST_CAST_ERROR:Result 0 of operation "custom.test_op" must be a Sequence of Types (contained a None item)
                    print(f"RESULT_LIST_CAST_ERROR:{err}")


run(testOdsBuildDefaultSizedVariadic)
185
186
def testOdsBuildDefaultCastError():
    """Check the errors raised for ``None`` operands/results without segments.

    The op class declares no segment specs.  ``build_generic`` is called once
    with ``None`` as operand 0 and once with ``None`` as result 1; the text of
    each ``ValueError`` is printed with an ``ERROR: `` prefix for FileCheck to
    verify.
    """

    class TestOp(OpView):
        OPERATION_NAME = "custom.test_op"

    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        m = Module.create()
        with InsertionPoint(m.body):
            v0 = add_dummy_value()
            v1 = add_dummy_value()
            t0 = IntegerType.get_signless(8)
            t1 = IntegerType.get_signless(16)
            try:
                TestOp.build_generic(results=[t0, t1], operands=[None, v1])
            except ValueError as e:
                # CHECK: ERROR: Operand 0 of operation "custom.test_op" must be a Value
                print(f"ERROR: {e}")
            try:
                TestOp.build_generic(results=[t0, None], operands=[v0, v1])
            except ValueError as e:
                # Match the printed prefix as well, like the operand case above.
                # CHECK: ERROR: Result 1 of operation "custom.test_op" must be a Type
                print(f"ERROR: {e}")


run(testOdsBuildDefaultCastError)
212
213
def testOdsEquallySizedAccessor():
    """Query ``equally_sized_accessor`` on an op with four flat results.

    The op is built with four signless integer result types of widths
    0, 8, 16 and 24.  ``equally_sized_accessor(op.results, 1, 3, 1, k)`` is
    called for ``k`` in 0 and 1; the returned start index, group size and the
    type of the result at the start index are printed for FileCheck to verify.
    """

    class TestOpMultiResultSegments(OpView):
        OPERATION_NAME = "custom.test_op"
        _ODS_REGIONS = (1, True)

    with Context() as context:
        with Location.unknown():
            context.allow_unregistered_dialects = True
            module = Module.create()
            with InsertionPoint(module.body):
                operand = add_dummy_value()
                result_types = [
                    IntegerType.get_signless(width) for width in range(0, 32, 8)
                ]

                op = TestOpMultiResultSegments.build_generic(
                    results=list(result_types), operands=[operand]
                )

                # Expected output of the two loop iterations below, in order.
                # CHECK: start: 1, elements_per_group: 1
                # CHECK: i8
                # CHECK: start: 2, elements_per_group: 1
                # CHECK: i16
                for preceding_variadic in (0, 1):
                    start, elements_per_group = equally_sized_accessor(
                        op.results, 1, 3, 1, preceding_variadic
                    )
                    print(f"start: {start}, elements_per_group: {elements_per_group}")
                    print(op.results[start].type)


run(testOdsEquallySizedAccessor)
243
244
def testOdsEquallySizedAccessorMultipleSegments():
    """Query ``equally_sized_accessor`` on an op with segmented results.

    The op class declares ``_ODS_RESULT_SEGMENTS = [0, -1, -1]`` and is built
    with one single result type followed by two lists of three result types
    (signless integers of widths 0 through 48 in steps of 8).  For each of the
    two list segments the start index and group size returned by
    ``equally_sized_accessor(op.results, 1, 2, 1, k)`` are printed, followed
    by the types of the results in that slice, for FileCheck to verify.
    """

    class TestOpMultiResultSegments(OpView):
        OPERATION_NAME = "custom.test_op"
        _ODS_REGIONS = (1, True)
        _ODS_RESULT_SEGMENTS = [0, -1, -1]

    with Context() as context:
        with Location.unknown():
            context.allow_unregistered_dialects = True
            module = Module.create()
            with InsertionPoint(module.body):
                operand = add_dummy_value()
                int_types = [
                    IntegerType.get_signless(width) for width in range(0, 56, 8)
                ]

                op = TestOpMultiResultSegments.build_generic(
                    results=[int_types[0], int_types[1:4], int_types[4:7]],
                    operands=[operand],
                )

                begin, group_size = equally_sized_accessor(op.results, 1, 2, 1, 0)
                # CHECK: start: 1, elements_per_group: 3
                print(f"start: {begin}, elements_per_group: {group_size}")
                # CHECK: [IntegerType(i8), IntegerType(i16), IntegerType(i24)]
                print([res.type for res in op.results[begin : begin + group_size]])

                begin, group_size = equally_sized_accessor(op.results, 1, 2, 1, 1)
                # CHECK: start: 4, elements_per_group: 3
                print(f"start: {begin}, elements_per_group: {group_size}")
                # CHECK: [IntegerType(i32), IntegerType(i40), IntegerType(i48)]
                print([res.type for res in op.results[begin : begin + group_size]])


run(testOdsEquallySizedAccessorMultipleSegments)
279