# xref: /llvm-project/mlir/test/python/dialects/ods_helpers.py (revision 3766ba44a8945681f4c52acb0331efcff66ef7b1)
# RUN: %PYTHON %s | FileCheck %s
29f3f6d7bSStella Laurenzo
39f3f6d7bSStella Laurenzoimport gc
458a47508SJeff Niu
59f3f6d7bSStella Laurenzofrom mlir.ir import *
6*3766ba44SKasper Nielsenfrom mlir.dialects._ods_common import equally_sized_accessor
79f3f6d7bSStella Laurenzo
858a47508SJeff Niu
def run(f):
    """Run test function ``f`` and assert that it leaked no MLIR Context.

    A ``TEST: <name>`` banner is printed first so the output that follows can
    be attributed to the test that produced it.
    """
    name = f.__name__
    print("\nTEST:", name)
    f()
    # Collect garbage first so Contexts kept alive only by reference cycles
    # are freed before the live count is checked.
    gc.collect()
    live_contexts = Context._get_live_count()
    assert live_contexts == 0
149f3f6d7bSStella Laurenzo
159f3f6d7bSStella Laurenzo
def add_dummy_value():
    """Create a "custom.value" op producing one i32 and return that result.

    The tests call this inside their ``with`` blocks, which supply the
    Context, Location and InsertionPoint the new op is created under.
    """
    i32_type = IntegerType.get_signless(32)
    value_op = Operation.create("custom.value", results=[i32_type])
    return value_op.result
209f3f6d7bSStella Laurenzo
219f3f6d7bSStella Laurenzo
def testOdsBuildDefaultImplicitRegions():
    """Exercise region-count handling in ``OpView.build_generic``.

    ``_ODS_REGIONS`` is ``(count, is_fixed)``. When ``is_fixed`` is True,
    exactly ``count`` regions are accepted. When it is False, ``count`` is a
    minimum and larger values are accepted. If ``regions=`` is omitted,
    ``count`` regions are created.
    """

    class TestFixedRegionsOp(OpView):
        OPERATION_NAME = "custom.test_op"
        _ODS_REGIONS = (2, True)

    class TestVariadicRegionsOp(OpView):
        OPERATION_NAME = "custom.test_any_regions_op"
        _ODS_REGIONS = (2, False)

    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        m = Module.create()
        with InsertionPoint(m.body):
            op = TestFixedRegionsOp.build_generic(results=[], operands=[])
            # CHECK: NUM_REGIONS: 2
            print(f"NUM_REGIONS: {len(op.regions)}")
            # Including a regions= that matches should be fine.
            op = TestFixedRegionsOp.build_generic(results=[], operands=[], regions=2)
            # CHECK: NUM_REGIONS: 2
            print(f"NUM_REGIONS: {len(op.regions)}")
            # Reject greater than.
            try:
                op = TestFixedRegionsOp.build_generic(
                    results=[], operands=[], regions=3
                )
            except ValueError as e:
                # CHECK: ERROR:Operation "custom.test_op" requires a maximum of 2 regions but was built with regions=3
                print(f"ERROR:{e}")
            # Reject less than.
            try:
                op = TestFixedRegionsOp.build_generic(
                    results=[], operands=[], regions=1
                )
            except ValueError as e:
                # CHECK: ERROR:Operation "custom.test_op" requires a minimum of 2 regions but was built with regions=1
                print(f"ERROR:{e}")

            # If no regions specified for a variadic region op, build the minimum.
            op = TestVariadicRegionsOp.build_generic(results=[], operands=[])
            # CHECK: DEFAULT_NUM_REGIONS: 2
            print(f"DEFAULT_NUM_REGIONS: {len(op.regions)}")
            # Should also accept an explicit regions= that matches the minimum.
            op = TestVariadicRegionsOp.build_generic(results=[], operands=[], regions=2)
            # CHECK: EQ_NUM_REGIONS: 2
            print(f"EQ_NUM_REGIONS: {len(op.regions)}")
            # And accept an explicit regions= greater than the minimum.
            op = TestVariadicRegionsOp.build_generic(results=[], operands=[], regions=3)
            # CHECK: GT_NUM_REGIONS: 3
            print(f"GT_NUM_REGIONS: {len(op.regions)}")
            # Should reject less than minimum.
            try:
                op = TestVariadicRegionsOp.build_generic(
                    results=[], operands=[], regions=1
                )
            except ValueError as e:
                # CHECK: ERROR:Operation "custom.test_any_regions_op" requires a minimum of 2 regions but was built with regions=1
                print(f"ERROR:{e}")


run(testOdsBuildDefaultImplicitRegions)
829f3f6d7bSStella Laurenzo
839f3f6d7bSStella Laurenzo
def testOdsBuildDefaultNonVariadic():
    """An op with no segment spec takes flat operand/result lists.

    Because no ``_ODS_OPERAND_SEGMENTS``/``_ODS_RESULT_SEGMENTS`` are
    declared, the built op is expected to carry no segment-size attributes.
    """

    class TestOp(OpView):
        OPERATION_NAME = "custom.test_op"

    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        module = Module.create()
        with InsertionPoint(module.body):
            operand_values = [add_dummy_value() for _ in range(2)]
            result_types = [IntegerType.get_signless(width) for width in (8, 16)]
            TestOp.build_generic(results=result_types, operands=operand_values)
            # CHECK: %[[V0:.+]] = "custom.value"
            # CHECK: %[[V1:.+]] = "custom.value"
            # CHECK: "custom.test_op"(%[[V0]], %[[V1]])
            # CHECK-NOT: operandSegmentSizes
            # CHECK-NOT: resultSegmentSizes
            # CHECK-SAME: : (i32, i32) -> (i8, i16)
            print(module)


run(testOdsBuildDefaultNonVariadic)
1079f3f6d7bSStella Laurenzo
1089f3f6d7bSStella Laurenzo
def testOdsBuildDefaultSizedVariadic():
    """Exercise operand/result segment handling in ``OpView.build_generic``.

    In the segment specs used here, 1 marks a required single element, -1 a
    variadic group (passed as a list) and 0 an optional element (``None`` when
    absent). The built op records the resulting sizes in
    ``operandSegmentSizes``/``resultSegmentSizes``. ``None`` in a required slot
    or inside a variadic list raises ``ValueError``.
    """

    class TestOp(OpView):
        OPERATION_NAME = "custom.test_op"
        _ODS_OPERAND_SEGMENTS = [1, -1, 0]
        _ODS_RESULT_SEGMENTS = [-1, 0, 1]

    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        module = Module.create()
        with InsertionPoint(module.body):
            values = [add_dummy_value() for _ in range(4)]
            i8, i16, i32, i64 = [
                IntegerType.get_signless(width) for width in (8, 16, 32, 64)
            ]

            # Every segment populated; both variadic groups hold two elements.
            # CHECK: %[[V0:.+]] = "custom.value"
            # CHECK: %[[V1:.+]] = "custom.value"
            # CHECK: %[[V2:.+]] = "custom.value"
            # CHECK: %[[V3:.+]] = "custom.value"
            # CHECK: "custom.test_op"(%[[V0]], %[[V1]], %[[V2]], %[[V3]])
            # CHECK-SAME: operandSegmentSizes = array<i32: 1, 2, 1>
            # CHECK-SAME: resultSegmentSizes = array<i32: 2, 1, 1>
            # CHECK-SAME: : (i32, i32, i32, i32) -> (i8, i16, i32, i64)
            TestOp.build_generic(
                results=[[i8, i16], i32, i64],
                operands=[values[0], values[1:3], values[3]],
            )

            # Optional and variadic segments passed as None.
            # CHECK: "custom.test_op"(%[[V0]])
            # CHECK-SAME: operandSegmentSizes = array<i32: 1, 0, 0>
            # CHECK-SAME: resultSegmentSizes = array<i32: 0, 0, 1>
            # CHECK-SAME: (i32) -> i64
            TestOp.build_generic(
                results=[None, None, i64], operands=[values[0], None, None]
            )
            print(module)

            # Each case below is expected to raise ValueError; the message is
            # printed after its label so the directives that follow can match it.
            # CHECK: OPERAND_CAST_ERROR:Operand 0 of operation "custom.test_op" must be a Value (was None and operand is not optional)
            # CHECK: RESULT_CAST_ERROR:Result 2 of operation "custom.test_op" must be a Type (was None and result is not optional)
            # CHECK: OPERAND_LIST_CAST_ERROR:Operand 1 of operation "custom.test_op" must be a Sequence of Values (contained a None item)
            # CHECK: RESULT_LIST_CAST_ERROR:Result 0 of operation "custom.test_op" must be a Sequence of Types (contained a None item)
            failing_builds = [
                # None in a required operand.
                ("OPERAND_CAST_ERROR", [None, None, i64], [None, None, None]),
                # None in a required result.
                ("RESULT_CAST_ERROR", [None, None, None], [values[0], None, None]),
                # None inside a variadic operand list.
                (
                    "OPERAND_LIST_CAST_ERROR",
                    [None, None, i64],
                    [values[0], [None], None],
                ),
                # None inside a variadic result list.
                (
                    "RESULT_LIST_CAST_ERROR",
                    [[None], None, i64],
                    [values[0], None, None],
                ),
            ]
            for label, result_spec, operand_spec in failing_builds:
                try:
                    TestOp.build_generic(results=result_spec, operands=operand_spec)
                except ValueError as e:
                    print(f"{label}:{e}")


run(testOdsBuildDefaultSizedVariadic)
1859f3f6d7bSStella Laurenzo
1869f3f6d7bSStella Laurenzo
def testOdsBuildDefaultCastError():
    """An op without segment specs must reject ``None`` operands and results.

    ``build_generic`` raises ``ValueError`` naming the index of the offending
    operand or result.
    """

    class TestOp(OpView):
        OPERATION_NAME = "custom.test_op"

    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        m = Module.create()
        with InsertionPoint(m.body):
            v0 = add_dummy_value()
            v1 = add_dummy_value()
            t0 = IntegerType.get_signless(8)
            t1 = IntegerType.get_signless(16)
            try:
                op = TestOp.build_generic(results=[t0, t1], operands=[None, v1])
            except ValueError as e:
                # CHECK: ERROR: Operand 0 of operation "custom.test_op" must be a Value
                print(f"ERROR: {e}")
            try:
                op = TestOp.build_generic(results=[t0, None], operands=[v0, v1])
            except ValueError as e:
                # Anchor on the same "ERROR: " prefix as the operand case so the
                # match is tied to this print, not any line mentioning the text.
                # CHECK: ERROR: Result 1 of operation "custom.test_op" must be a Type
                print(f"ERROR: {e}")


run(testOdsBuildDefaultCastError)
212*3766ba44SKasper Nielsen
213*3766ba44SKasper Nielsen
def testOdsEquallySizedAccessor():
    """Check ``equally_sized_accessor`` on four flat results.

    The accessor is told to view the results as one leading simple group
    followed by three equally sized variadic groups. The test then looks up
    the first two variadic groups.
    """

    class TestOpFlatResults(OpView):
        OPERATION_NAME = "custom.test_op"
        _ODS_REGIONS = (1, True)

    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        module = Module.create()
        with InsertionPoint(module.body):
            operand = add_dummy_value()
            result_types = [IntegerType.get_signless(8 * k) for k in range(4)]

            op = TestOpFlatResults.build_generic(
                results=list(result_types), operands=[operand]
            )

            # CHECK: start: 1, elements_per_group: 1
            # CHECK: i8
            # CHECK: start: 2, elements_per_group: 1
            # CHECK: i16
            for preceding_variadic in (0, 1):
                start, elements_per_group = equally_sized_accessor(
                    op.results, 1, 3, 1, preceding_variadic
                )
                print(f"start: {start}, elements_per_group: {elements_per_group}")
                print(op.results[start].type)


run(testOdsEquallySizedAccessor)
243*3766ba44SKasper Nielsen
244*3766ba44SKasper Nielsen
def testOdsEquallySizedAccessorMultipleSegments():
    """Check ``equally_sized_accessor`` on an op with two variadic result groups.

    The op is built with one leading result followed by two variadic groups of
    three results each. The accessor is asked for the start and size of each
    variadic group in turn.
    """

    class TestOpMultiResultSegments(OpView):
        OPERATION_NAME = "custom.test_op"
        _ODS_REGIONS = (1, True)
        _ODS_RESULT_SEGMENTS = [0, -1, -1]

    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        module = Module.create()
        with InsertionPoint(module.body):
            operand = add_dummy_value()
            tys = [IntegerType.get_signless(8 * k) for k in range(7)]

            op = TestOpMultiResultSegments.build_generic(
                results=[tys[0], tys[1:4], tys[4:7]],
                operands=[operand],
            )

            # CHECK: start: 1, elements_per_group: 3
            # CHECK: [IntegerType(i8), IntegerType(i16), IntegerType(i24)]
            # CHECK: start: 4, elements_per_group: 3
            # CHECK: [IntegerType(i32), IntegerType(i40), IntegerType(i48)]
            for preceding_variadic in (0, 1):
                start, elements_per_group = equally_sized_accessor(
                    op.results, 1, 2, 1, preceding_variadic
                )
                print(f"start: {start}, elements_per_group: {elements_per_group}")
                group = op.results[start : start + elements_per_group]
                print([result.type for result in group])


run(testOdsEquallySizedAccessorMultipleSegments)
279