xref: /llvm-project/mlir/test/python/dialects/python_test.py (revision b0e00ca6a605b88e83129c8c6be4e177f93cbfea)
1# RUN: %PYTHON %s | FileCheck %s
2
3from mlir.ir import *
4import mlir.dialects.func as func
5import mlir.dialects.python_test as test
6import mlir.dialects.tensor as tensor
7import mlir.dialects.arith as arith
8
9
def run(f):
    """Announce the test case `f` on stdout, execute it, and hand it back.

    Intended for use as a decorator: the decorated function is invoked
    immediately, right after a blank line and a "TEST: <name>" banner.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    f()
    return f
14
15
16# CHECK-LABEL: TEST: testAttributes
@run
def testAttributes():
    """Exercise attribute handling on `test.AttributedOp`.

    Covers three areas in order: building the op with mandatory, optional and
    unit attributes; reading, overwriting and deleting entries through the
    generic `attributes` mapping; and getting, setting and deleting values
    through the named per-attribute accessors. The printed output is verified
    by the FileCheck directives placed next to each statement.
    """
    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True

        #
        # Check op construction with attributes.
        #

        i32 = IntegerType.get_signless(32)
        one = IntegerAttr.get(i32, 1)
        two = IntegerAttr.get(i32, 2)
        unit = UnitAttr.get()

        # CHECK: "python_test.attributed_op"() {
        # CHECK-DAG: mandatory_i32 = 1 : i32
        # CHECK-DAG: optional_i32 = 2 : i32
        # CHECK-DAG: unit
        # CHECK: }
        op = test.AttributedOp(one, optional_i32=two, unit=unit)
        print(f"{op}")

        # CHECK: "python_test.attributed_op"() {
        # CHECK: mandatory_i32 = 2 : i32
        # CHECK: }
        op2 = test.AttributedOp(two)
        print(f"{op2}")

        #
        # Check generic "attributes" access and mutation.
        #

        assert "additional" not in op.attributes

        # CHECK: "python_test.attributed_op"() {
        # CHECK-DAG: additional = 1 : i32
        # CHECK-DAG: mandatory_i32 = 2 : i32
        # CHECK: }
        op2.attributes["additional"] = one
        print(f"{op2}")

        # CHECK: "python_test.attributed_op"() {
        # CHECK-DAG: additional = 2 : i32
        # CHECK-DAG: mandatory_i32 = 2 : i32
        # CHECK: }
        op2.attributes["additional"] = two
        print(f"{op2}")

        # CHECK: "python_test.attributed_op"() {
        # CHECK-NOT: additional = 2 : i32
        # CHECK:     mandatory_i32 = 2 : i32
        # CHECK: }
        del op2.attributes["additional"]
        print(f"{op2}")

        # Only `op2` ever received an "additional" entry, so looking it up on
        # `op` must fail with KeyError.
        try:
            print(op.attributes["additional"])
        except KeyError:
            pass
        else:
            assert False, "expected KeyError on unknown attribute key"

        #
        # Check accessors to defined attributes.
        #

        # CHECK: Mandatory: 1
        # CHECK: Optional: 2
        # CHECK: Unit: True
        print(f"Mandatory: {op.mandatory_i32.value}")
        print(f"Optional: {op.optional_i32.value}")
        print(f"Unit: {op.unit}")

        # CHECK: Mandatory: 2
        # CHECK: Optional: None
        # CHECK: Unit: False
        print(f"Mandatory: {op2.mandatory_i32.value}")
        print(f"Optional: {op2.optional_i32}")
        print(f"Unit: {op2.unit}")

        # Assigning None to the optional attribute and False to the unit
        # attribute removes both from the op, as the asserts below confirm.
        # CHECK: Mandatory: 2
        # CHECK: Optional: None
        # CHECK: Unit: False
        op.mandatory_i32 = two
        op.optional_i32 = None
        op.unit = False
        print(f"Mandatory: {op.mandatory_i32.value}")
        print(f"Optional: {op.optional_i32}")
        print(f"Unit: {op.unit}")
        assert "optional_i32" not in op.attributes
        assert "unit" not in op.attributes

        try:
            op.mandatory_i32 = None
        except ValueError:
            pass
        else:
            assert False, "expected ValueError on setting a mandatory attribute to None"

        # CHECK: Optional: 2
        op.optional_i32 = two
        print(f"Optional: {op.optional_i32.value}")

        # CHECK: Optional: None
        del op.optional_i32
        print(f"Optional: {op.optional_i32}")

        # For the unit attribute, None is accepted and leaves it absent.
        # CHECK: Unit: False
        op.unit = None
        print(f"Unit: {op.unit}")
        assert "unit" not in op.attributes

        # CHECK: Unit: True
        op.unit = True
        print(f"Unit: {op.unit}")

        # CHECK: Unit: False
        del op.unit
        print(f"Unit: {op.unit}")
136
137
138# CHECK-LABEL: TEST: attrBuilder
@run
def attrBuilder():
    """Build `test.AttributesOp` from plain Python values and print it.

    Every keyword argument is given as a Python literal (or a list/dict of
    attributes), and the FileCheck directive beside it states the MLIR
    attribute text expected in the printed op. The op is verified before it is
    printed with `use_local_scope=True`.

    NOTE(review): the expected text for `x_f64elems` lists denormal values
    rather than the 8.0 and 16.0 passed in, which looks like the inputs being
    stored as integer bit patterns -- confirm whether this pins a known
    builder quirk.
    """
    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True
        # CHECK: python_test.attributes_op
        op = test.AttributesOp(
            # CHECK-DAG: x_affinemap = affine_map<() -> (2)>
            x_affinemap=AffineMap.get_constant(2),
            # CHECK-DAG: x_affinemaparr = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
            x_affinemaparr=[AffineMap.get_identity(3)],
            # CHECK-DAG: x_arr = [true, "x"]
            x_arr=[BoolAttr.get(True), StringAttr.get("x")],
            x_boolarr=[False, True],  # CHECK-DAG: x_boolarr = [false, true]
            x_bool=True,  # CHECK-DAG: x_bool = true
            x_dboolarr=[True, False],  # CHECK-DAG: x_dboolarr = array<i1: true, false>
            x_df16arr=[21, 22],  # CHECK-DAG: x_df16arr = array<i16: 21, 22>
            # CHECK-DAG: x_df32arr = array<f32: 2.300000e+01, 2.400000e+01>
            x_df32arr=[23, 24],
            # CHECK-DAG: x_df64arr = array<f64: 2.500000e+01, 2.600000e+01>
            x_df64arr=[25, 26],
            x_di32arr=[0, 1],  # CHECK-DAG: x_di32arr = array<i32: 0, 1>
            # CHECK-DAG: x_di64arr = array<i64: 1, 2>
            x_di64arr=[1, 2],
            x_di8arr=[2, 3],  # CHECK-DAG: x_di8arr = array<i8: 2, 3>
            # CHECK-DAG: x_dictarr = [{a = false}]
            x_dictarr=[{"a": BoolAttr.get(False)}],
            x_dict={"b": BoolAttr.get(True)},  # CHECK-DAG: x_dict = {b = true}
            x_f32=-2.25,  # CHECK-DAG: x_f32 = -2.250000e+00 : f32
            # CHECK-DAG: x_f32arr = [2.000000e+00 : f32, 3.000000e+00 : f32]
            x_f32arr=[2.0, 3.0],
            x_f64=4.25,  # CHECK-DAG: x_f64 = 4.250000e+00 : f64
            x_f64arr=[4.0, 8.0],  # CHECK-DAG: x_f64arr = [4.000000e+00, 8.000000e+00]
            # CHECK-DAG: x_f64elems = dense<[3.952530e-323, 7.905050e-323]> : tensor<2xf64>
            x_f64elems=[8.0, 16.0],
            # CHECK-DAG: x_flatsymrefarr = [@symbol1, @symbol2]
            x_flatsymrefarr=["symbol1", "symbol2"],
            x_flatsymref="symbol3",  # CHECK-DAG: x_flatsymref = @symbol3
            x_i1=0,  # CHECK-DAG: x_i1 = false
            x_i16=42,  # CHECK-DAG: x_i16 = 42 : i16
            x_i32=6,  # CHECK-DAG: x_i32 = 6 : i32
            x_i32arr=[4, 5],  # CHECK-DAG: x_i32arr = [4 : i32, 5 : i32]
            x_i32elems=[5, 6],  # CHECK-DAG: x_i32elems = dense<[5, 6]> : tensor<2xsi32>
            x_i64=9,  # CHECK-DAG: x_i64 = 9 : i64
            x_i64arr=[7, 8],  # CHECK-DAG: x_i64arr = [7, 8]
            x_i64elems=[8, 9],  # CHECK-DAG: x_i64elems = dense<[8, 9]> : tensor<2xsi64>
            x_i64svecarr=[10, 11],  # CHECK-DAG: x_i64svecarr = [10, 11]
            x_i8=11,  # CHECK-DAG: x_i8 = 11 : i8
            x_idx=10,  # CHECK-DAG: x_idx = 10 : index
            # CHECK-DAG: x_idxelems = dense<[11, 12]> : tensor<2xindex>
            x_idxelems=[11, 12],
            # CHECK-DAG: x_idxlistarr = [{{\[}}13], [14, 15]]
            x_idxlistarr=[[13], [14, 15]],
            x_si1=-1,  # CHECK-DAG: x_si1 = -1 : si1
            x_si16=-2,  # CHECK-DAG: x_si16 = -2 : si16
            x_si32=-3,  # CHECK-DAG: x_si32 = -3 : si32
            x_si64=-123,  # CHECK-DAG: x_si64 = -123 : si64
            x_si8=-4,  # CHECK-DAG: x_si8 = -4 : si8
            x_strarr=["hello", "world"],  # CHECK-DAG: x_strarr = ["hello", "world"]
            x_str="hello world!",  # CHECK-DAG: x_str = "hello world!"
            # CHECK-DAG: x_symrefarr = [@flatsym, @deep::@sym]
            x_symrefarr=["flatsym", ["deep", "sym"]],
            x_symref=["deep", "sym2"],  # CHECK-DAG: x_symref = @deep::@sym2
            x_sym="symbol",  # CHECK-DAG: x_sym = "symbol"
            x_typearr=[F32Type.get()],  # CHECK-DAG: x_typearr = [f32]
            x_type=F64Type.get(),  # CHECK-DAG: x_type = f64
            x_ui1=1,  # CHECK-DAG: x_ui1 = 1 : ui1
            x_ui16=2,  # CHECK-DAG: x_ui16 = 2 : ui16
            x_ui32=3,  # CHECK-DAG: x_ui32 = 3 : ui32
            x_ui64=4,  # CHECK-DAG: x_ui64 = 4 : ui64
            x_ui8=5,  # CHECK-DAG: x_ui8 = 5 : ui8
            x_unit=True,  # CHECK-DAG: x_unit
        )
        op.verify()
        op.print(use_local_scope=True)
213
214
215# CHECK-LABEL: TEST: inferReturnTypes
@run
def inferReturnTypes():
    """Exercise `InferTypeOpInterface` on an op instance and on an op class.

    Prints the inferred return types through both the instance-bound and the
    class-bound (static) interface, checks that `opview` is available only for
    the instance-bound one, and checks that wrapping an op or op class that
    does not implement the interface raises ValueError.
    """
    with Context() as ctx, Location.unknown(ctx):
        test.register_python_test_dialect(ctx)
        module = Module.create()
        with InsertionPoint(module.body):
            op = test.InferResultsOp()
            dummy = test.DummyOp()

        # CHECK: [Type(i32), Type(i64)]
        iface = InferTypeOpInterface(op)
        print(iface.inferReturnTypes())

        # Query through the class-bound interface itself; printing the
        # instance-bound result a second time would leave this path untested.
        # CHECK: [Type(i32), Type(i64)]
        iface_static = InferTypeOpInterface(test.InferResultsOp)
        print(iface_static.inferReturnTypes())

        assert isinstance(iface.opview, test.InferResultsOp)
        assert iface.opview == iface.operation.opview

        # A class-bound interface has no underlying operation to view.
        try:
            iface_static.opview
        except TypeError:
            pass
        else:
            assert False, (
                "not expected to be able to obtain an opview from a static interface"
            )

        try:
            InferTypeOpInterface(dummy)
        except ValueError:
            pass
        else:
            assert False, "not expected dummy op to implement the interface"

        try:
            InferTypeOpInterface(test.DummyOp)
        except ValueError:
            pass
        else:
            assert False, "not expected dummy op class to implement the interface"
258
259
260# CHECK-LABEL: TEST: resultTypesDefinedByTraits
@run
def resultTypesDefinedByTraits():
    """Build ops without passing result types and print the types they get.

    None of the ops created here is given an explicit result type; the printed
    types are matched against the FileCheck directives beside each group.
    """
    with Context() as ctx, Location.unknown(ctx):
        test.register_python_test_dialect(ctx)
        module = Module.create()
        with InsertionPoint(module.body):
            inferred = test.InferResultsOp()

            same_type_op = test.SameOperandAndResultTypeOp([inferred.results[0]])
            # CHECK-COUNT-2: i32
            for value in (same_type_op.one, same_type_op.two):
                print(value.type)

            index_type_attr = TypeAttr.get(IndexType.get())
            type_attr_op = test.FirstAttrDeriveTypeAttrOp(
                inferred.results[1], index_type_attr
            )
            # CHECK-COUNT-2: index
            for value in (type_attr_op.one, type_attr_op.two):
                print(value.type)

            float_attr = FloatAttr.get(F32Type.get(), 3.14)
            attr_op = test.FirstAttrDeriveAttrOp(float_attr)
            # CHECK-COUNT-3: f32
            for value in (attr_op.one, attr_op.two, attr_op.three):
                print(value.type)

            implied_op = test.InferResultsImpliedOp()
            # CHECK: i32
            print(implied_op.integer.type)
            # CHECK: f64
            print(implied_op.flt.type)
            # CHECK: index
            print(implied_op.index.type)
293
294
295# CHECK-LABEL: TEST: testOptionalOperandOp
@run
def testOptionalOperandOp():
    """Check that `input` reads as None only when the operand was omitted."""
    with Context() as ctx, Location.unknown():
        test.register_python_test_dialect(ctx)

        module = Module.create()
        with InsertionPoint(module.body):
            bare_op = test.OptionalOperandOp()
            bare_lacks_input = bare_op.input is None
            # CHECK: op1.input is None: True
            print(f"op1.input is None: {bare_lacks_input}")

            fed_op = test.OptionalOperandOp(input=bare_op)
            fed_lacks_input = fed_op.input is None
            # CHECK: op2.input is None: False
            print(f"op2.input is None: {fed_lacks_input}")
310
311
312# CHECK-LABEL: TEST: testCustomAttribute
@run
def testCustomAttribute():
    """Check construction and casting of the dialect's `TestAttr`.

    Prints a freshly built attribute, casts it to its own class, and confirms
    that invalid constructor arguments raise ValueError / TypeError. Each
    "expected failure" branch raises AssertionError with a descriptive message
    if the constructor unexpectedly succeeds.
    """
    with Context() as ctx:
        test.register_python_test_dialect(ctx)
        a = test.TestAttr.get()
        # CHECK: #python_test.test_attr
        print(a)

        # The following cast must not assert.
        b = test.TestAttr(a)

        unit = UnitAttr.get()
        try:
            test.TestAttr(unit)
        except ValueError as e:
            assert "Cannot cast attribute to TestAttr" in str(e)
        else:
            # A bare `raise` here has no active exception to re-raise and
            # would only produce an unrelated RuntimeError.
            raise AssertionError("expected ValueError when casting UnitAttr to TestAttr")

        # The following must trigger a TypeError from our adaptors and must not
        # crash.
        try:
            test.TestAttr(42)
        except TypeError as e:
            assert "Expected an MLIR object" in str(e)
        else:
            raise AssertionError("expected TypeError when building TestAttr from an int")

        # The following must trigger a TypeError from pybind (therefore, not
        # checking its message) and must not crash.
        try:
            test.TestAttr(42, 56)
        except TypeError:
            pass
        else:
            raise AssertionError(
                "expected TypeError when building TestAttr from two arguments"
            )
349
350
@run
def testCustomType():
    """Check construction, casting and typeid behaviour of `TestType`.

    Prints a freshly built type, casts it to its own class, checks that the
    instance exposes `typeid` while `static_typeid` raises AttributeError, and
    confirms that invalid constructor arguments raise ValueError / TypeError.
    Each "expected failure" branch raises AssertionError with a descriptive
    message if the constructor unexpectedly succeeds.
    """
    with Context() as ctx:
        test.register_python_test_dialect(ctx)
        a = test.TestType.get()
        # CHECK: !python_test.test_type
        print(a)

        # The following cast must not assert.
        b = test.TestType(a)
        # Instance custom types should have typeids
        assert isinstance(b.typeid, TypeID)
        # Subclasses of ir.Type should not have a static_typeid
        # CHECK: 'TestType' object has no attribute 'static_typeid'
        try:
            b.static_typeid
        except AttributeError as e:
            print(e)

        i8 = IntegerType.get_signless(8)
        try:
            test.TestType(i8)
        except ValueError as e:
            assert "Cannot cast type to TestType" in str(e)
        else:
            # A bare `raise` here has no active exception to re-raise and
            # would only produce an unrelated RuntimeError.
            raise AssertionError("expected ValueError when casting i8 to TestType")

        # The following must trigger a TypeError from our adaptors and must not
        # crash.
        try:
            test.TestType(42)
        except TypeError as e:
            assert "Expected an MLIR object" in str(e)
        else:
            raise AssertionError("expected TypeError when building TestType from an int")

        # The following must trigger a TypeError from pybind (therefore, not
        # checking its message) and must not crash.
        try:
            test.TestType(42, 56)
        except TypeError:
            pass
        else:
            raise AssertionError(
                "expected TypeError when building TestType from two arguments"
            )
395
396
@run
# CHECK-LABEL: TEST: testTensorValue
def testTensorValue():
    """Wrap a `tensor.empty` result in a `TestTensorValue` subclass and print it.

    Also checks that `TestIntegerRankedTensorType` exposes a `static_typeid`
    equal to the typeid of the created value's type.
    """
    with Context() as ctx, Location.unknown():
        test.register_python_test_dialect(ctx)

        class Tensor(test.TestTensorValue):
            def __str__(self):
                base_text = super().__str__()
                return base_text.replace("Value", "Tensor")

        element_type = IntegerType.get_signless(8)
        module = Module.create()
        with InsertionPoint(module.body):
            empty_result = tensor.EmptyOp([10, 10], element_type).result

            # CHECK: Value(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
            print(Value(empty_result))

            wrapped = Tensor(empty_result)
            # CHECK: Tensor(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
            print(wrapped)

            # CHECK: False
            print(wrapped.is_null())

            tensor_type_cls = test.TestIntegerRankedTensorType
            # Classes of custom types that inherit from concrete types should have
            # static_typeid
            assert isinstance(tensor_type_cls.static_typeid, TypeID)
            # And it should be equal to the in-tree concrete type
            assert tensor_type_cls.static_typeid == empty_result.type.typeid
428
429
430# CHECK-LABEL: TEST: inferReturnTypeComponents
@run
def inferReturnTypeComponents():
    """Print inferred shaped-type components for a ranked and an unranked operand.

    Two `InferShapedTypeComponentsOp` instances are created inside a function
    body, one per block argument, and the first inferred component of each is
    dumped field by field.
    """

    def report_components(shaped_op):
        # Infer from the op's own operand and dump the first component.
        iface = InferShapedTypeOpInterface(shaped_op)
        inferred = iface.inferReturnTypeComponents(operands=[shaped_op.operand])
        first = inferred[0]
        print("has rank:", first.has_rank)
        print("rank:", first.rank)
        print("element type:", first.element_type)
        print("shape:", first.shape)

    with Context() as ctx, Location.unknown(ctx):
        test.register_python_test_dialect(ctx)
        module = Module.create()
        i32 = IntegerType.get_signless(32)
        with InsertionPoint(module.body):
            result_type = UnrankedTensorType.get(i32)
            arg_types = [
                RankedTensorType.get([1, 3, 10, 10], i32),
                UnrankedTensorType.get(i32),
            ]
            signature = (arg_types, [result_type])
            fn = func.FuncOp("test_inferReturnTypeComponents", signature)
            body = Block.create_at_start(fn.operation.regions[0], arg_types)
            with InsertionPoint(body):
                ranked_op = test.InferShapedTypeComponentsOp(
                    result_type, body.arguments[0]
                )
                unranked_op = test.InferShapedTypeComponentsOp(
                    result_type, body.arguments[1]
                )

        # CHECK: has rank: True
        # CHECK: rank: 4
        # CHECK: element type: i32
        # CHECK: shape: [1, 3, 10, 10]
        report_components(ranked_op)

        # CHECK: has rank: False
        # CHECK: rank: None
        # CHECK: element type: i32
        # CHECK: shape: None
        report_components(unranked_op)
480
481
482# CHECK-LABEL: TEST: testCustomTypeTypeCaster
@run
def testCustomTypeTypeCaster():
    """Exercise type-caster registration for the test dialect's tensor type.

    Prints the repr of parsed/built custom types, shows that registering a
    second caster for the same typeid without `replace=True` raises
    RuntimeError, then swaps the caster to `RankedTensorType` and back while
    printing the type of a freshly created `tensor.empty` result each time.
    """

    def cast_to_test_tensor(pytype):
        return test.TestIntegerRankedTensorType(pytype)

    def cast_to_ranked_tensor(pytype):
        return RankedTensorType(pytype)

    with Context() as ctx, Location.unknown():
        test.register_python_test_dialect(ctx)

        plain_type = test.TestType.get()
        assert plain_type.typeid is not None

        parsed_type = Type.parse("!python_test.test_type")
        # CHECK: !python_test.test_type
        print(parsed_type)
        # CHECK: TestType(!python_test.test_type)
        print(repr(parsed_type))

        int_tensor_type = test.TestIntegerRankedTensorType.get([10, 10], 5)
        # CHECK: tensor<10x10xi5>
        print(int_tensor_type)
        # CHECK: TestIntegerRankedTensorType(tensor<10x10xi5>)
        print(repr(int_tensor_type))

        # CHECK: Type caster is already registered
        try:
            register_type_caster(int_tensor_type.typeid, cast_to_test_tensor)
        except RuntimeError as err:
            print(err)

        # python_test dialect registers a caster for RankedTensorType in its extension (pybind) module.
        # So this one replaces that one (successfully). And then just to be sure we restore the original caster below.
        register_type_caster(int_tensor_type.typeid, cast_to_ranked_tensor, replace=True)

        empty_result = tensor.EmptyOp([10, 10], IntegerType.get_signless(5)).result
        # CHECK: tensor<10x10xi5>
        print(empty_result.type)
        # CHECK: ranked tensor type RankedTensorType(tensor<10x10xi5>)
        print("ranked tensor type", repr(empty_result.type))

        register_type_caster(int_tensor_type.typeid, cast_to_test_tensor, replace=True)

        empty_result = tensor.EmptyOp([10, 10], IntegerType.get_signless(5)).result
        # CHECK: tensor<10x10xi5>
        print(empty_result.type)
        # CHECK: TestIntegerRankedTensorType(tensor<10x10xi5>)
        print(repr(empty_result.type))
536
537
538# CHECK-LABEL: TEST: testInferTypeOpInterface
@run
def testInferTypeOpInterface():
    """Print the result type of `InferResultsVariadicInputsOp` for one and two inputs.

    The op is built once with `doubled=None` and once with both inputs set to
    the same i64 constant; no result type is passed in either case.
    """
    with Context() as ctx, Location.unknown(ctx):
        test.register_python_test_dialect(ctx)
        module = Module.create()
        with InsertionPoint(module.body):
            zero = arith.ConstantOp(IntegerType.get_signless(64), 0)

            single_only = test.InferResultsVariadicInputsOp(single=zero, doubled=None)
            single_only_type = single_only.result.type
            # CHECK: i32
            print(single_only_type)

            both_inputs = test.InferResultsVariadicInputsOp(single=zero, doubled=zero)
            both_inputs_type = both_inputs.result.type
            # CHECK: f32
            print(both_inputs_type)
555