xref: /llvm-project/mlir/test/python/dialects/python_test.py (revision f9008e6366c2496b1ca1785b891d5578174ad63e)
1# RUN: %PYTHON %s | FileCheck %s
2
3from mlir.ir import *
4import mlir.dialects.func as func
5import mlir.dialects.python_test as test
6import mlir.dialects.tensor as tensor
7
8
def run(f):
    """Execute the test function ``f`` right away and hand it back unchanged.

    Intended for use as a decorator. Before calling ``f`` it prints a blank
    line followed by ``TEST: <function name>``, which is the text the
    ``CHECK-LABEL`` directives in this file anchor on.
    """
    banner = f"\nTEST: {f.__name__}"
    print(banner)
    f()
    return f
13
14
# CHECK-LABEL: TEST: testAttributes
@run
def testAttributes():
    """Exercise attribute handling on ``python_test.attributed_op``.

    Covers construction with mandatory, optional and unit attributes, generic
    ``op.attributes`` access/mutation/deletion, and the generated typed
    accessors (get, set, delete). Expected-failure checks raise
    ``AssertionError`` explicitly instead of ``assert False`` so they are not
    stripped when Python runs with ``-O``.
    """
    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True

        #
        # Check op construction with attributes.
        #

        i32 = IntegerType.get_signless(32)
        one = IntegerAttr.get(i32, 1)
        two = IntegerAttr.get(i32, 2)
        unit = UnitAttr.get()

        # CHECK: "python_test.attributed_op"() {
        # CHECK-DAG: mandatory_i32 = 1 : i32
        # CHECK-DAG: optional_i32 = 2 : i32
        # CHECK-DAG: unit
        # CHECK: }
        op = test.AttributedOp(one, optional_i32=two, unit=unit)
        print(f"{op}")

        # CHECK: "python_test.attributed_op"() {
        # CHECK: mandatory_i32 = 2 : i32
        # CHECK: }
        op2 = test.AttributedOp(two)
        print(f"{op2}")

        #
        # Check generic "attributes" access and mutation.
        #

        assert "additional" not in op.attributes

        # CHECK: "python_test.attributed_op"() {
        # CHECK-DAG: additional = 1 : i32
        # CHECK-DAG: mandatory_i32 = 2 : i32
        # CHECK: }
        op2.attributes["additional"] = one
        print(f"{op2}")

        # CHECK: "python_test.attributed_op"() {
        # CHECK-DAG: additional = 2 : i32
        # CHECK-DAG: mandatory_i32 = 2 : i32
        # CHECK: }
        op2.attributes["additional"] = two
        print(f"{op2}")

        # CHECK: "python_test.attributed_op"() {
        # CHECK-NOT: additional = 2 : i32
        # CHECK:     mandatory_i32 = 2 : i32
        # CHECK: }
        del op2.attributes["additional"]
        print(f"{op2}")

        # Reading a key that was never set on `op` must raise KeyError.
        try:
            print(op.attributes["additional"])
        except KeyError:
            pass
        else:
            raise AssertionError("expected KeyError on unknown attribute key")

        #
        # Check accessors to defined attributes.
        #

        # CHECK: Mandatory: 1
        # CHECK: Optional: 2
        # CHECK: Unit: True
        print(f"Mandatory: {op.mandatory_i32.value}")
        print(f"Optional: {op.optional_i32.value}")
        print(f"Unit: {op.unit}")

        # CHECK: Mandatory: 2
        # CHECK: Optional: None
        # CHECK: Unit: False
        print(f"Mandatory: {op2.mandatory_i32.value}")
        print(f"Optional: {op2.optional_i32}")
        print(f"Unit: {op2.unit}")

        # Setting optional/unit accessors to None/False removes the attribute.
        # CHECK: Mandatory: 2
        # CHECK: Optional: None
        # CHECK: Unit: False
        op.mandatory_i32 = two
        op.optional_i32 = None
        op.unit = False
        print(f"Mandatory: {op.mandatory_i32.value}")
        print(f"Optional: {op.optional_i32}")
        print(f"Unit: {op.unit}")
        assert "optional_i32" not in op.attributes
        assert "unit" not in op.attributes

        # A mandatory attribute cannot be cleared.
        try:
            op.mandatory_i32 = None
        except ValueError:
            pass
        else:
            raise AssertionError(
                "expected ValueError on setting a mandatory attribute to None"
            )

        # CHECK: Optional: 2
        op.optional_i32 = two
        print(f"Optional: {op.optional_i32.value}")

        # CHECK: Optional: None
        del op.optional_i32
        print(f"Optional: {op.optional_i32}")

        # CHECK: Unit: False
        op.unit = None
        print(f"Unit: {op.unit}")
        assert "unit" not in op.attributes

        # CHECK: Unit: True
        op.unit = True
        print(f"Unit: {op.unit}")

        # CHECK: Unit: False
        del op.unit
        print(f"Unit: {op.unit}")
135
136
# CHECK-LABEL: TEST: attrBuilder
@run
def attrBuilder():
    """Construct ``test.AttributesOp`` from plain Python values and print it.

    Every attribute is supplied as a native Python value (bool, int, float,
    str or list) rather than as an MLIR attribute object, relying on the op
    builder to do the conversion.
    """
    attribute_values = dict(
        x_bool=True,
        x_i16=1,
        x_i32=2,
        x_i64=3,
        x_si16=-1,
        x_si32=-2,
        x_f32=1.5,
        x_f64=2.5,
        x_str="x_str",
        x_i32_array=[1, 2, 3],
        x_i64_array=[4, 5, 6],
        x_f32_array=[1.5, -2.5, 3.5],
        x_f64_array=[4.5, 5.5, -6.5],
        x_i64_dense=[1, 2, 3, 4, 5, 6],
    )
    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        built = test.AttributesOp(**attribute_values)
        print(built)
159
160
# CHECK-LABEL: TEST: inferReturnTypes
@run
def inferReturnTypes():
    """Check ``InferTypeOpInterface`` built from an op instance and an op class.

    Both the instance-bound and the class-bound ("static") interface must
    infer the same result types. Only the instance-bound interface exposes an
    ``opview``, and ops/classes that do not implement the interface are
    rejected with ``ValueError``. Expected-failure checks raise
    ``AssertionError`` explicitly so they survive ``python -O``.
    """
    with Context() as ctx, Location.unknown(ctx):
        test.register_python_test_dialect(ctx)
        module = Module.create()
        with InsertionPoint(module.body):
            op = test.InferResultsOp()
            dummy = test.DummyOp()

        # CHECK: [Type(i32), Type(i64)]
        iface = InferTypeOpInterface(op)
        print(iface.inferReturnTypes())

        # The interface obtained from the op *class* must infer the same types.
        # (Previously this printed `iface` again, so the static path was never
        # exercised.)
        # CHECK: [Type(i32), Type(i64)]
        iface_static = InferTypeOpInterface(test.InferResultsOp)
        print(iface_static.inferReturnTypes())

        assert isinstance(iface.opview, test.InferResultsOp)
        assert iface.opview == iface.operation.opview

        try:
            iface_static.opview
        except TypeError:
            pass
        else:
            raise AssertionError(
                "not expected to be able to obtain an opview from a static interface"
            )

        try:
            InferTypeOpInterface(dummy)
        except ValueError:
            pass
        else:
            raise AssertionError("not expected dummy op to implement the interface")

        try:
            InferTypeOpInterface(test.DummyOp)
        except ValueError:
            pass
        else:
            raise AssertionError(
                "not expected dummy op class to implement the interface"
            )
204
205
# CHECK-LABEL: TEST: resultTypesDefinedByTraits
@run
def resultTypesDefinedByTraits():
    """Print the result types of ops whose builders take no explicit result types.

    Each op below is built without result types; the printed types show what
    the op derived from its operands, its attributes, or its own definition.
    """
    with Context() as ctx, Location.unknown(ctx):
        test.register_python_test_dialect(ctx)
        module = Module.create()
        with InsertionPoint(module.body):
            inferred = test.InferResultsOp()

            # Result types follow the (i32) operand.
            same_type_op = test.SameOperandAndResultTypeOp([inferred.results[0]])
            # CHECK-COUNT-2: i32
            for value in (same_type_op.one, same_type_op.two):
                print(value.type)

            # Result types follow the type carried by the TypeAttr.
            index_type_attr = TypeAttr.get(IndexType.get())
            type_attr_op = test.FirstAttrDeriveTypeAttrOp(
                inferred.results[1], index_type_attr
            )
            # CHECK-COUNT-2: index
            for value in (type_attr_op.one, type_attr_op.two):
                print(value.type)

            # Result types follow the type of the first attribute.
            f32_attr = FloatAttr.get(F32Type.get(), 3.14)
            attr_op = test.FirstAttrDeriveAttrOp(f32_attr)
            # CHECK-COUNT-3: f32
            for value in (attr_op.one, attr_op.two, attr_op.three):
                print(value.type)

            implied_op = test.InferResultsImpliedOp()
            # CHECK: i32
            print(implied_op.integer.type)
            # CHECK: f64
            print(implied_op.flt.type)
            # CHECK: index
            print(implied_op.index.type)
239
240
# CHECK-LABEL: TEST: testOptionalOperandOp
@run
def testOptionalOperandOp():
    """Check that the ``input`` accessor reports ``None`` only when omitted."""
    with Context() as ctx, Location.unknown():
        test.register_python_test_dialect(ctx)

        module = Module.create()
        with InsertionPoint(module.body):
            without_input = test.OptionalOperandOp()
            # CHECK: op1.input is None: True
            print(f"op1.input is None: {without_input.input is None}")

            # Pass the first op as the optional `input` operand.
            with_input = test.OptionalOperandOp(input=without_input)
            # CHECK: op2.input is None: False
            print(f"op2.input is None: {with_input.input is None}")
257
258
# CHECK-LABEL: TEST: testCustomAttribute
@run
def testCustomAttribute():
    """Check the Python class bound to ``#python_test.test_attr``.

    Covers construction, casting from an existing attribute, and the errors
    raised when casting an attribute of the wrong kind or a non-MLIR object.
    When an expected error does not occur, an ``AssertionError`` naming the
    missing error is raised. A bare ``raise`` there (with no active exception)
    would only produce an unrelated "No active exception to reraise"
    RuntimeError.
    """
    with Context() as ctx:
        test.register_python_test_dialect(ctx)
        a = test.TestAttr.get()
        # CHECK: #python_test.test_attr
        print(a)

        # The following cast must not assert.
        test.TestAttr(a)

        unit = UnitAttr.get()
        try:
            test.TestAttr(unit)
        except ValueError as e:
            assert "Cannot cast attribute to TestAttr" in str(e)
        else:
            raise AssertionError(
                "expected ValueError when casting a UnitAttr to TestAttr"
            )

        # The following must trigger a TypeError from our adaptors and must not
        # crash.
        try:
            test.TestAttr(42)
        except TypeError as e:
            assert "Expected an MLIR object" in str(e)
        else:
            raise AssertionError("expected TypeError when casting an int to TestAttr")

        # The following must trigger a TypeError from pybind (therefore, not
        # checking its message) and must not crash.
        try:
            test.TestAttr(42, 56)
        except TypeError:
            pass
        else:
            raise AssertionError(
                "expected TypeError when calling TestAttr with two arguments"
            )
296
297
# CHECK-LABEL: TEST: testCustomType
@run
def testCustomType():
    """Check the Python class bound to ``!python_test.test_type``.

    Covers construction, casting from an existing type, ``typeid`` on
    instances, the absence of ``static_typeid`` on ``ir.Type`` subclasses, and
    the errors raised when casting a type of the wrong kind or a non-MLIR
    object. Missing expected errors raise a descriptive ``AssertionError``
    instead of a bare ``raise``, which has no active exception to re-raise.
    """
    with Context() as ctx:
        test.register_python_test_dialect(ctx)
        a = test.TestType.get()
        # CHECK: !python_test.test_type
        print(a)

        # The following cast must not assert.
        b = test.TestType(a)
        # Instance custom types should have typeids
        assert isinstance(b.typeid, TypeID)
        # Subclasses of ir.Type should not have a static_typeid
        # CHECK: 'TestType' object has no attribute 'static_typeid'
        try:
            b.static_typeid
        except AttributeError as e:
            print(e)
        else:
            raise AssertionError("expected AttributeError on TestType.static_typeid")

        i8 = IntegerType.get_signless(8)
        try:
            test.TestType(i8)
        except ValueError as e:
            assert "Cannot cast type to TestType" in str(e)
        else:
            raise AssertionError("expected ValueError when casting i8 to TestType")

        # The following must trigger a TypeError from our adaptors and must not
        # crash.
        try:
            test.TestType(42)
        except TypeError as e:
            assert "Expected an MLIR object" in str(e)
        else:
            raise AssertionError("expected TypeError when casting an int to TestType")

        # The following must trigger a TypeError from pybind (therefore, not
        # checking its message) and must not crash.
        try:
            test.TestType(42, 56)
        except TypeError:
            pass
        else:
            raise AssertionError(
                "expected TypeError when calling TestType with two arguments"
            )
342
343
# CHECK-LABEL: TEST: testTensorValue
@run
def testTensorValue():
    """Check a user subclass of the custom ``test.TestTensorValue`` class.

    The subclass overrides ``__str__`` only, so printing it is compared
    against printing the same result as a plain ``Value``. The
    ``static_typeid`` of ``test.TestTensorType`` is also compared against the
    typeid of a ``tensor.empty`` result type.
    """
    with Context() as ctx, Location.unknown():
        test.register_python_test_dialect(ctx)

        i8 = IntegerType.get_signless(8)

        class Tensor(test.TestTensorValue):
            # Reuse the base class rendering, relabelled so the printed text
            # distinguishes the subclass from a plain Value.
            def __str__(self):
                rendered = super().__str__()
                return rendered.replace("Value", "Tensor")

        module = Module.create()
        with InsertionPoint(module.body):
            empty_result = tensor.EmptyOp([10, 10], i8).result

            # CHECK: Value(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
            print(Value(empty_result))

            wrapped = Tensor(empty_result)
            # CHECK: Tensor(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
            print(wrapped)

            # CHECK: False
            print(wrapped.is_null())

            # Classes of custom types that inherit from concrete types should
            # expose a static_typeid equal to the in-tree concrete type's id.
            static_id = test.TestTensorType.static_typeid
            assert isinstance(static_id, TypeID)
            assert static_id == empty_result.type.typeid
375
376
# CHECK-LABEL: TEST: inferReturnTypeComponents
@run
def inferReturnTypeComponents():
    """Check ``InferShapedTypeOpInterface`` on a ranked and an unranked operand.

    A function with one ranked and one unranked tensor argument is built, and
    an ``InferShapedTypeComponentsOp`` is created for each argument. The first
    inferred component set of each op is then printed.
    """

    def dump_components(shaped_op):
        # Infer from the op's own operand and print the first component set.
        interface = InferShapedTypeOpInterface(shaped_op)
        components = interface.inferReturnTypeComponents(
            operands=[shaped_op.operand]
        )[0]
        print("has rank:", components.has_rank)
        print("rank:", components.rank)
        print("element type:", components.element_type)
        print("shape:", components.shape)

    with Context() as ctx, Location.unknown(ctx):
        test.register_python_test_dialect(ctx)
        module = Module.create()
        i32 = IntegerType.get_signless(32)
        with InsertionPoint(module.body):
            result_type = UnrankedTensorType.get(i32)
            operand_types = [
                RankedTensorType.get([1, 3, 10, 10], i32),
                UnrankedTensorType.get(i32),
            ]
            fn = func.FuncOp(
                "test_inferReturnTypeComponents", (operand_types, [result_type])
            )
            entry_block = Block.create_at_start(
                fn.operation.regions[0], operand_types
            )
            ranked_arg, unranked_arg = entry_block.arguments[0], entry_block.arguments[1]
            with InsertionPoint(entry_block):
                ranked_op = test.InferShapedTypeComponentsOp(result_type, ranked_arg)
                unranked_op = test.InferShapedTypeComponentsOp(
                    result_type, unranked_arg
                )

        # CHECK: has rank: True
        # CHECK: rank: 4
        # CHECK: element type: i32
        # CHECK: shape: [1, 3, 10, 10]
        dump_components(ranked_op)

        # CHECK: has rank: False
        # CHECK: rank: None
        # CHECK: element type: i32
        # CHECK: shape: None
        dump_components(unranked_op)
427