xref: /llvm-project/mlir/test/python/dialects/python_test.py (revision e0ca7e99914609bbed0f30f4834a93d33dcef085)
1# RUN: %PYTHON %s | FileCheck %s
2
3from mlir.ir import *
4import mlir.dialects.func as func
5import mlir.dialects.python_test as test
6import mlir.dialects.tensor as tensor
7import mlir.dialects.arith as arith
8
9
def run(f):
    """Announce, execute and return a test function.

    Meant to be used as a decorator. It prints a blank line followed by
    ``TEST: <name>``, which is the text the ``CHECK-LABEL`` directives in this
    file anchor on. It then calls ``f`` with no arguments and hands ``f`` back
    unchanged, so the decorated name stays bound to the original function.
    """
    banner = "\nTEST: " + f.__name__
    print(banner)
    f()
    return f
14
15
# CHECK-LABEL: TEST: testAttributes
@run
def testAttributes():
    """Exercise construction, generic access and typed accessors of attributes."""
    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True

        #
        # Check op construction with attributes.
        #

        i32 = IntegerType.get_signless(32)
        one = IntegerAttr.get(i32, 1)
        two = IntegerAttr.get(i32, 2)
        unit = UnitAttr.get()

        def show_fields(target):
            # The mandatory attribute is always present, so its integer value is
            # printed; the optional one may be absent, in which case None is shown.
            print(f"Mandatory: {target.mandatory_i32.value}")
            optional = target.optional_i32
            print(f"Optional: {None if optional is None else optional.value}")
            print(f"Unit: {target.unit}")

        # CHECK: "python_test.attributed_op"() {
        # CHECK-DAG: mandatory_i32 = 1 : i32
        # CHECK-DAG: optional_i32 = 2 : i32
        # CHECK-DAG: unit
        # CHECK: }
        full_op = test.AttributedOp(one, optional_i32=two, unit=unit)
        print(full_op)

        # CHECK: "python_test.attributed_op"() {
        # CHECK: mandatory_i32 = 2 : i32
        # CHECK: }
        bare_op = test.AttributedOp(two)
        print(bare_op)

        #
        # Check generic "attributes" access and mutation.
        #

        assert "additional" not in full_op.attributes

        # First pass sets "additional" to 1, second pass overwrites it with 2.
        # CHECK: "python_test.attributed_op"() {
        # CHECK-DAG: additional = 1 : i32
        # CHECK-DAG: mandatory_i32 = 2 : i32
        # CHECK: }
        # CHECK: "python_test.attributed_op"() {
        # CHECK-DAG: additional = 2 : i32
        # CHECK-DAG: mandatory_i32 = 2 : i32
        # CHECK: }
        for extra in (one, two):
            bare_op.attributes["additional"] = extra
            print(bare_op)

        # CHECK: "python_test.attributed_op"() {
        # CHECK-NOT: additional = 2 : i32
        # CHECK:     mandatory_i32 = 2 : i32
        # CHECK: }
        del bare_op.attributes["additional"]
        print(bare_op)

        key_error_raised = False
        try:
            print(full_op.attributes["additional"])
        except KeyError:
            key_error_raised = True
        assert key_error_raised, "expected KeyError on unknown attribute key"

        #
        # Check accessors to defined attributes.
        #

        # CHECK: Mandatory: 1
        # CHECK: Optional: 2
        # CHECK: Unit: True
        show_fields(full_op)

        # CHECK: Mandatory: 2
        # CHECK: Optional: None
        # CHECK: Unit: False
        show_fields(bare_op)

        # CHECK: Mandatory: 2
        # CHECK: Optional: None
        # CHECK: Unit: False
        full_op.mandatory_i32 = two
        full_op.optional_i32 = None
        full_op.unit = False
        show_fields(full_op)
        # Clearing an optional / unit attribute removes it from the dictionary.
        assert "optional_i32" not in full_op.attributes
        assert "unit" not in full_op.attributes

        value_error_raised = False
        try:
            full_op.mandatory_i32 = None
        except ValueError:
            value_error_raised = True
        assert value_error_raised, (
            "expected ValueError on setting a mandatory attribute to None"
        )

        # CHECK: Optional: 2
        full_op.optional_i32 = two
        print(f"Optional: {full_op.optional_i32.value}")

        # CHECK: Optional: None
        del full_op.optional_i32
        print(f"Optional: {full_op.optional_i32}")

        # CHECK: Unit: False
        full_op.unit = None
        print(f"Unit: {full_op.unit}")
        assert "unit" not in full_op.attributes

        # CHECK: Unit: True
        full_op.unit = True
        print(f"Unit: {full_op.unit}")

        # CHECK: Unit: False
        del full_op.unit
        print(f"Unit: {full_op.unit}")
136
137
# CHECK-LABEL: TEST: attrBuilder
@run
def attrBuilder():
    """Build AttributesOp from plain Python values for every attribute kind.

    NOTE(review): no CHECK lines pin the printed op, so this only verifies that
    the generated builder accepts these Python values without raising.
    """
    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        python_values = dict(
            x_bool=True,
            x_i16=1,
            x_i32=2,
            x_i64=3,
            x_si16=-1,
            x_si32=-2,
            x_f32=1.5,
            x_f64=2.5,
            x_str="x_str",
            x_i32_array=[1, 2, 3],
            x_i64_array=[4, 5, 6],
            x_f32_array=[1.5, -2.5, 3.5],
            x_f64_array=[4.5, 5.5, -6.5],
            x_i64_dense=[1, 2, 3, 4, 5, 6],
        )
        built = test.AttributesOp(**python_values)
        print(built)
160
161
# CHECK-LABEL: TEST: inferReturnTypes
@run
def inferReturnTypes():
    """Exercise InferTypeOpInterface on an op instance and on an op class.

    Both the instance-bound and the class-bound ("static") interface are
    expected to report the same inferred result types. Only the instance-bound
    one exposes an ``opview``. Ops that do not implement the interface must be
    rejected with ``ValueError``.
    """
    with Context() as ctx, Location.unknown(ctx):
        test.register_python_test_dialect(ctx)
        module = Module.create()
        with InsertionPoint(module.body):
            op = test.InferResultsOp()
            dummy = test.DummyOp()

        # CHECK: [Type(i32), Type(i64)]
        iface = InferTypeOpInterface(op)
        print(iface.inferReturnTypes())

        # CHECK: [Type(i32), Type(i64)]
        # Query through the class-bound interface. Printing `iface` again here
        # (as before) would leave the static code path entirely untested.
        iface_static = InferTypeOpInterface(test.InferResultsOp)
        print(iface_static.inferReturnTypes())

        assert isinstance(iface.opview, test.InferResultsOp)
        assert iface.opview == iface.operation.opview

        try:
            iface_static.opview
        except TypeError:
            pass
        else:
            assert False, (
                "not expected to be able to obtain an opview from a static interface"
            )

        try:
            InferTypeOpInterface(dummy)
        except ValueError:
            pass
        else:
            assert False, "not expected dummy op to implement the interface"

        try:
            InferTypeOpInterface(test.DummyOp)
        except ValueError:
            pass
        else:
            assert False, "not expected dummy op class to implement the interface"
205
206
# CHECK-LABEL: TEST: resultTypesDefinedByTraits
@run
def resultTypesDefinedByTraits():
    """Check result types that the op builders fill in without being given them.

    None of the builder calls below pass result types explicitly; the printed
    types must come from the ops' own definitions.
    """
    with Context() as ctx, Location.unknown(ctx):
        test.register_python_test_dialect(ctx)
        module = Module.create()
        with InsertionPoint(module.body):
            inferred = test.InferResultsOp()

            same = test.SameOperandAndResultTypeOp([inferred.results[0]])
            # CHECK-COUNT-2: i32
            for result in (same.one, same.two):
                print(result.type)

            type_attr_op = test.FirstAttrDeriveTypeAttrOp(
                inferred.results[1], TypeAttr.get(IndexType.get())
            )
            # CHECK-COUNT-2: index
            for result in (type_attr_op.one, type_attr_op.two):
                print(result.type)

            attr_op = test.FirstAttrDeriveAttrOp(FloatAttr.get(F32Type.get(), 3.14))
            # CHECK-COUNT-3: f32
            for result in (attr_op.one, attr_op.two, attr_op.three):
                print(result.type)

            implied = test.InferResultsImpliedOp()
            # CHECK: i32
            # CHECK: f64
            # CHECK: index
            for result in (implied.integer, implied.flt, implied.index):
                print(result.type)
240
241
# CHECK-LABEL: TEST: testOptionalOperandOp
@run
def testOptionalOperandOp():
    """Check that an omitted optional operand reads back as None."""
    with Context() as ctx, Location.unknown():
        test.register_python_test_dialect(ctx)

        module = Module.create()
        with InsertionPoint(module.body):
            # No operand supplied: the accessor must yield None.
            op1 = test.OptionalOperandOp()
            # CHECK: op1.input is None: True
            print("op1.input is None:", op1.input is None)

            # Feed the first op's result in as the optional operand.
            op2 = test.OptionalOperandOp(input=op1)
            # CHECK: op2.input is None: False
            print("op2.input is None:", op2.input is None)
257            print(f"op2.input is None: {op2.input is None}")
258
259
# CHECK-LABEL: TEST: testCustomAttribute
@run
def testCustomAttribute():
    """Check the Python adaptor for the dialect's custom attribute class.

    Casting must succeed for a matching attribute, raise ``ValueError`` for a
    mismatched attribute, and raise ``TypeError`` (never crash) for non-MLIR
    arguments.
    """
    # Each `else` branch states what was expected. The former bare `raise`
    # there had no active exception and only produced a misleading
    # "RuntimeError: No active exception to reraise".
    with Context() as ctx:
        test.register_python_test_dialect(ctx)
        a = test.TestAttr.get()
        # CHECK: #python_test.test_attr
        print(a)

        # The following cast must not assert.
        test.TestAttr(a)

        unit = UnitAttr.get()
        try:
            test.TestAttr(unit)
        except ValueError as e:
            assert "Cannot cast attribute to TestAttr" in str(e)
        else:
            assert False, "expected ValueError when casting UnitAttr to TestAttr"

        # The following must trigger a TypeError from our adaptors and must not
        # crash.
        try:
            test.TestAttr(42)
        except TypeError as e:
            assert "Expected an MLIR object" in str(e)
        else:
            assert False, "expected TypeError when casting a Python int to TestAttr"

        # The following must trigger a TypeError from pybind (therefore, not
        # checking its message) and must not crash.
        try:
            test.TestAttr(42, 56)
        except TypeError:
            pass
        else:
            assert False, "expected TypeError from a two-argument TestAttr call"
297
298
# CHECK-LABEL: TEST: testCustomType
@run
def testCustomType():
    """Check the Python adaptor for the dialect's custom type class.

    Casting must succeed for a matching type, raise ``ValueError`` for a
    mismatched type, and raise ``TypeError`` (never crash) for non-MLIR
    arguments. Instances must expose ``typeid``, while this subclass must not
    provide ``static_typeid``.
    """
    # Each `else` branch states what was expected instead of a bare `raise`,
    # which (with no active exception) would only report "No active exception
    # to reraise".
    with Context() as ctx:
        test.register_python_test_dialect(ctx)
        a = test.TestType.get()
        # CHECK: !python_test.test_type
        print(a)

        # The following cast must not assert.
        b = test.TestType(a)
        # Instance custom types should have typeids
        assert isinstance(b.typeid, TypeID)
        # Subclasses of ir.Type should not have a static_typeid
        # CHECK: 'TestType' object has no attribute 'static_typeid'
        try:
            b.static_typeid
        except AttributeError as e:
            print(e)
        else:
            assert False, "expected TestType instances to lack static_typeid"

        i8 = IntegerType.get_signless(8)
        try:
            test.TestType(i8)
        except ValueError as e:
            assert "Cannot cast type to TestType" in str(e)
        else:
            assert False, "expected ValueError when casting i8 to TestType"

        # The following must trigger a TypeError from our adaptors and must not
        # crash.
        try:
            test.TestType(42)
        except TypeError as e:
            assert "Expected an MLIR object" in str(e)
        else:
            assert False, "expected TypeError when casting a Python int to TestType"

        # The following must trigger a TypeError from pybind (therefore, not
        # checking its message) and must not crash.
        try:
            test.TestType(42, 56)
        except TypeError:
            pass
        else:
            assert False, "expected TypeError from a two-argument TestType call"
343
344
# CHECK-LABEL: TEST: testTensorValue
@run
def testTensorValue():
    """Wrap a tensor-typed value in a user subclass of TestTensorValue.

    The subclass only customizes ``__str__``. The CHECKs expect it to render
    like a plain ``Value`` but labelled ``Tensor``, and expect it not to be null.
    """
    with Context() as ctx, Location.unknown():
        test.register_python_test_dialect(ctx)

        i8 = IntegerType.get_signless(8)

        class Tensor(test.TestTensorValue):
            def __str__(self):
                # Relabel the default "Value(...)" rendering as "Tensor(...)".
                text = super().__str__()
                return text.replace("Value", "Tensor")

        module = Module.create()
        with InsertionPoint(module.body):
            empty_value = tensor.EmptyOp([10, 10], i8).result

            # CHECK: Value(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
            print(Value(empty_value))

            wrapped = Tensor(empty_value)
            # CHECK: Tensor(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
            print(wrapped)

            # CHECK: False
            print(wrapped.is_null())

            # Classes of custom types that inherit from concrete types should
            # have a static_typeid, equal to that of the in-tree concrete type.
            static_id = test.TestIntegerRankedTensorType.static_typeid
            assert isinstance(static_id, TypeID)
            assert static_id == empty_value.type.typeid
376
377
# CHECK-LABEL: TEST: inferReturnTypeComponents
@run
def inferReturnTypeComponents():
    """Check InferShapedTypeOpInterface for a ranked and an unranked operand."""
    with Context() as ctx, Location.unknown(ctx):
        test.register_python_test_dialect(ctx)
        module = Module.create()
        i32 = IntegerType.get_signless(32)
        with InsertionPoint(module.body):
            result_type = UnrankedTensorType.get(i32)
            operand_types = [
                RankedTensorType.get([1, 3, 10, 10], i32),
                UnrankedTensorType.get(i32),
            ]
            f = func.FuncOp(
                "test_inferReturnTypeComponents", (operand_types, [result_type])
            )
            entry_block = Block.create_at_start(f.operation.regions[0], operand_types)
            with InsertionPoint(entry_block):
                ranked_op = test.InferShapedTypeComponentsOp(
                    result_type, entry_block.arguments[0]
                )
                unranked_op = test.InferShapedTypeComponentsOp(
                    result_type, entry_block.arguments[1]
                )

        def print_components(op):
            # Infer from the op's own operand and report the first component.
            iface = InferShapedTypeOpInterface(op)
            components = iface.inferReturnTypeComponents(operands=[op.operand])[0]
            print("has rank:", components.has_rank)
            print("rank:", components.rank)
            print("element type:", components.element_type)
            print("shape:", components.shape)

        # CHECK: has rank: True
        # CHECK: rank: 4
        # CHECK: element type: i32
        # CHECK: shape: [1, 3, 10, 10]
        print_components(ranked_op)

        # CHECK: has rank: False
        # CHECK: rank: None
        # CHECK: element type: i32
        # CHECK: shape: None
        print_components(unranked_op)
428
429
# CHECK-LABEL: TEST: testCustomTypeTypeCaster
@run
def testCustomTypeTypeCaster():
    """Check that registered type casters decide the Python class of MLIR types."""
    with Context() as ctx, Location.unknown():
        test.register_python_test_dialect(ctx)

        a = test.TestType.get()
        assert a.typeid is not None

        parsed = Type.parse("!python_test.test_type")
        # CHECK: !python_test.test_type
        print(parsed)
        # CHECK: TestType(!python_test.test_type)
        print(repr(parsed))

        ranked = test.TestIntegerRankedTensorType.get([10, 10], 5)
        # CHECK: tensor<10x10xi5>
        print(ranked)
        # CHECK: TestIntegerRankedTensorType(tensor<10x10xi5>)
        print(repr(ranked))

        def type_caster(pytype):
            return test.TestIntegerRankedTensorType(pytype)

        # The CHECK below expects registering without replace=True to fail,
        # since a caster for this typeid is already in place.
        # CHECK: Type caster is already registered
        try:
            register_type_caster(ranked.typeid, type_caster)
        except RuntimeError as e:
            print(e)

        register_type_caster(ranked.typeid, type_caster, replace=True)

        empty_result = tensor.EmptyOp([10, 10], IntegerType.get_signless(5)).result
        # CHECK: tensor<10x10xi5>
        print(empty_result.type)
        # CHECK: TestIntegerRankedTensorType(tensor<10x10xi5>)
        print(repr(empty_result.type))
471
472
# CHECK-LABEL: TEST: testInferTypeOpInterface
@run
def testInferTypeOpInterface():
    """Check that the inferred result type depends on the variadic operand count.

    Per the CHECKs, omitting ``doubled`` yields an i32 result and supplying it
    yields an f32 result.
    """
    with Context() as ctx, Location.unknown(ctx):
        test.register_python_test_dialect(ctx)
        module = Module.create()
        with InsertionPoint(module.body):
            i64 = IntegerType.get_signless(64)
            zero = arith.ConstantOp(i64, 0)

            # CHECK: i32
            # CHECK: f32
            for doubled in (None, zero):
                variadic_op = test.InferResultsVariadicInputsOp(
                    single=zero, doubled=doubled
                )
                print(variadic_op.result.type)
490