# xref: /llvm-project/mlir/test/python/dialects/python_test.py (revision 9315645834ea81cf9550364a4950f289e9706a26)
1# RUN: %PYTHON %s | FileCheck %s
2
3from mlir.ir import *
4import mlir.dialects.func as func
5import mlir.dialects.python_test as test
6import mlir.dialects.tensor as tensor
7import mlir.dialects.arith as arith
8
# Register the python_test dialect with the registry returned by
# get_dialect_registry(). The tests below rely on this: they create plain
# Context() objects and use python_test ops without registering the dialect
# themselves.
_dialect_registry = get_dialect_registry()
test.register_python_test_dialect(_dialect_registry)
10
11
def run(f):
    """Print a ``TEST: <name>`` banner, call ``f`` immediately, and return ``f``.

    Intended for use as a decorator: each test body executes at import time,
    its output is preceded by the banner line (which the FileCheck labels in
    this file anchor on), and the module-level name stays bound to the
    undecorated function.
    """
    banner = f"\nTEST: {f.__name__}"
    print(banner)
    f()
    return f
16
17
# CHECK-LABEL: TEST: testAttributes
@run
def testAttributes():
    """Exercise attribute handling on python_test.attributed_op.

    Covers construction with mandatory, optional and unit attributes, generic
    access through ``op.attributes`` (insert, overwrite, delete, missing key),
    and the per-attribute accessors (read, assign, assign None, del).
    """
    with Context(), Location.unknown():
        # --- Construction ---------------------------------------------------
        i32_type = IntegerType.get_signless(32)
        attr_one = IntegerAttr.get(i32_type, 1)
        attr_two = IntegerAttr.get(i32_type, 2)
        unit_attr = UnitAttr.get()

        # CHECK: python_test.attributed_op  {
        # CHECK-DAG: mandatory_i32 = 1 : i32
        # CHECK-DAG: optional_i32 = 2 : i32
        # CHECK-DAG: unit
        # CHECK: }
        full_op = test.AttributedOp(attr_one, optional_i32=attr_two, unit=unit_attr)
        print(f"{full_op}")

        # CHECK: python_test.attributed_op  {
        # CHECK: mandatory_i32 = 2 : i32
        # CHECK: }
        bare_op = test.AttributedOp(attr_two)
        print(f"{bare_op}")

        # --- Generic "attributes" access and mutation -----------------------
        assert "additional" not in full_op.attributes

        # CHECK: python_test.attributed_op  {
        # CHECK-DAG: additional = 1 : i32
        # CHECK-DAG: mandatory_i32 = 2 : i32
        # CHECK: }
        bare_op.attributes["additional"] = attr_one
        print(f"{bare_op}")

        # CHECK: python_test.attributed_op  {
        # CHECK-DAG: additional = 2 : i32
        # CHECK-DAG: mandatory_i32 = 2 : i32
        # CHECK: }
        bare_op.attributes["additional"] = attr_two
        print(f"{bare_op}")

        # CHECK: python_test.attributed_op  {
        # CHECK-NOT: additional = 2 : i32
        # CHECK:     mandatory_i32 = 2 : i32
        # CHECK: }
        del bare_op.attributes["additional"]
        print(f"{bare_op}")

        # Looking up a key that was never set must raise KeyError.
        missing_key_raised = False
        try:
            print(full_op.attributes["additional"])
        except KeyError:
            missing_key_raised = True
        assert missing_key_raised, "expected KeyError on unknown attribute key"

        # --- Accessors for the attributes declared on the op -----------------

        # CHECK: Mandatory: 1
        # CHECK: Optional: 2
        # CHECK: Unit: True
        print(f"Mandatory: {full_op.mandatory_i32.value}")
        print(f"Optional: {full_op.optional_i32.value}")
        print(f"Unit: {full_op.unit}")

        # CHECK: Mandatory: 2
        # CHECK: Optional: None
        # CHECK: Unit: False
        print(f"Mandatory: {bare_op.mandatory_i32.value}")
        print(f"Optional: {bare_op.optional_i32}")
        print(f"Unit: {bare_op.unit}")

        # Overwrite the mandatory attribute and clear the optional/unit ones;
        # cleared attributes must disappear from the attribute dictionary.
        # CHECK: Mandatory: 2
        # CHECK: Optional: None
        # CHECK: Unit: False
        full_op.mandatory_i32 = attr_two
        full_op.optional_i32 = None
        full_op.unit = False
        print(f"Mandatory: {full_op.mandatory_i32.value}")
        print(f"Optional: {full_op.optional_i32}")
        print(f"Unit: {full_op.unit}")
        assert "optional_i32" not in full_op.attributes
        assert "unit" not in full_op.attributes

        # A mandatory attribute cannot be cleared.
        mandatory_clear_rejected = False
        try:
            full_op.mandatory_i32 = None
        except ValueError:
            mandatory_clear_rejected = True
        assert mandatory_clear_rejected, (
            "expected ValueError on setting a mandatory attribute to None"
        )

        # CHECK: Optional: 2
        full_op.optional_i32 = attr_two
        print(f"Optional: {full_op.optional_i32.value}")

        # CHECK: Optional: None
        del full_op.optional_i32
        print(f"Optional: {full_op.optional_i32}")

        # Assigning None to a unit attribute removes it, like assigning False.
        # CHECK: Unit: False
        full_op.unit = None
        print(f"Unit: {full_op.unit}")
        assert "unit" not in full_op.attributes

        # CHECK: Unit: True
        full_op.unit = True
        print(f"Unit: {full_op.unit}")

        # CHECK: Unit: False
        del full_op.unit
        print(f"Unit: {full_op.unit}")
136
137
# CHECK-LABEL: TEST: attrBuilder
@run
def attrBuilder():
    """Build python_test.attributes_op with one argument per attribute kind.

    Most arguments are plain Python values (bools, ints, floats, strings and
    lists/dicts of them) that the op builder has to turn into MLIR attributes;
    a few are already MLIR objects (AffineMap, BoolAttr, StringAttr, types).
    The FileCheck directive above each entry is its expected printed form.
    """
    with Context(), Location.unknown():
        # CHECK: python_test.attributes_op
        builder_kwargs = {
            # CHECK-DAG: x_affinemap = affine_map<() -> (2)>
            "x_affinemap": AffineMap.get_constant(2),
            # CHECK-DAG: x_affinemaparr = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
            "x_affinemaparr": [AffineMap.get_identity(3)],
            # CHECK-DAG: x_arr = [true, "x"]
            "x_arr": [BoolAttr.get(True), StringAttr.get("x")],
            # CHECK-DAG: x_boolarr = [false, true]
            "x_boolarr": [False, True],
            # CHECK-DAG: x_bool = true
            "x_bool": True,
            # CHECK-DAG: x_dboolarr = array<i1: true, false>
            "x_dboolarr": [True, False],
            # CHECK-DAG: x_df16arr = array<i16: 21, 22>
            "x_df16arr": [21, 22],
            # CHECK-DAG: x_df32arr = array<f32: 2.300000e+01, 2.400000e+01>
            "x_df32arr": [23, 24],
            # CHECK-DAG: x_df64arr = array<f64: 2.500000e+01, 2.600000e+01>
            "x_df64arr": [25, 26],
            # CHECK-DAG: x_di32arr = array<i32: 0, 1>
            "x_di32arr": [0, 1],
            # CHECK-DAG: x_di64arr = array<i64: 1, 2>
            "x_di64arr": [1, 2],
            # CHECK-DAG: x_di8arr = array<i8: 2, 3>
            "x_di8arr": [2, 3],
            # CHECK-DAG: x_dictarr = [{a = false}]
            "x_dictarr": [{"a": BoolAttr.get(False)}],
            # CHECK-DAG: x_dict = {b = true}
            "x_dict": {"b": BoolAttr.get(True)},
            # CHECK-DAG: x_f32 = -2.250000e+00 : f32
            "x_f32": -2.25,
            # CHECK-DAG: x_f32arr = [2.000000e+00 : f32, 3.000000e+00 : f32]
            "x_f32arr": [2.0, 3.0],
            # CHECK-DAG: x_f64 = 4.250000e+00 : f64
            "x_f64": 4.25,
            # CHECK-DAG: x_f64arr = [4.000000e+00, 8.000000e+00]
            "x_f64arr": [4.0, 8.0],
            # CHECK-DAG: x_f64elems = dense<[8.000000e+00, 1.600000e+01]> : tensor<2xf64>
            "x_f64elems": [8.0, 16.0],
            # CHECK-DAG: x_flatsymrefarr = [@symbol1, @symbol2]
            "x_flatsymrefarr": ["symbol1", "symbol2"],
            # CHECK-DAG: x_flatsymref = @symbol3
            "x_flatsymref": "symbol3",
            # CHECK-DAG: x_i1 = false
            "x_i1": 0,
            # CHECK-DAG: x_i16 = 42 : i16
            "x_i16": 42,
            # CHECK-DAG: x_i32 = 6 : i32
            "x_i32": 6,
            # CHECK-DAG: x_i32arr = [4 : i32, 5 : i32]
            "x_i32arr": [4, 5],
            # CHECK-DAG: x_i32elems = dense<[5, 6]> : tensor<2xi32>
            "x_i32elems": [5, 6],
            # CHECK-DAG: x_i64 = 9 : i64
            "x_i64": 9,
            # CHECK-DAG: x_i64arr = [7, 8]
            "x_i64arr": [7, 8],
            # CHECK-DAG: x_i64elems = dense<[8, 9]> : tensor<2xi64>
            "x_i64elems": [8, 9],
            # CHECK-DAG: x_i64svecarr = [10, 11]
            "x_i64svecarr": [10, 11],
            # CHECK-DAG: x_i8 = 11 : i8
            "x_i8": 11,
            # CHECK-DAG: x_idx = 10 : index
            "x_idx": 10,
            # CHECK-DAG: x_idxelems = dense<[11, 12]> : tensor<2xindex>
            "x_idxelems": [11, 12],
            # CHECK-DAG: x_idxlistarr = [{{\[}}13], [14, 15]]
            "x_idxlistarr": [[13], [14, 15]],
            # CHECK-DAG: x_si1 = -1 : si1
            "x_si1": -1,
            # CHECK-DAG: x_si16 = -2 : si16
            "x_si16": -2,
            # CHECK-DAG: x_si32 = -3 : si32
            "x_si32": -3,
            # CHECK-DAG: x_si64 = -123 : si64
            "x_si64": -123,
            # CHECK-DAG: x_si8 = -4 : si8
            "x_si8": -4,
            # CHECK-DAG: x_strarr = ["hello", "world"]
            "x_strarr": ["hello", "world"],
            # CHECK-DAG: x_str = "hello world!"
            "x_str": "hello world!",
            # CHECK-DAG: x_symrefarr = [@flatsym, @deep::@sym]
            "x_symrefarr": ["flatsym", ["deep", "sym"]],
            # CHECK-DAG: x_symref = @deep::@sym2
            "x_symref": ["deep", "sym2"],
            # CHECK-DAG: x_sym = "symbol"
            "x_sym": "symbol",
            # CHECK-DAG: x_typearr = [f32]
            "x_typearr": [F32Type.get()],
            # CHECK-DAG: x_type = f64
            "x_type": F64Type.get(),
            # CHECK-DAG: x_ui1 = 1 : ui1
            "x_ui1": 1,
            # CHECK-DAG: x_ui16 = 2 : ui16
            "x_ui16": 2,
            # CHECK-DAG: x_ui32 = 3 : ui32
            "x_ui32": 3,
            # CHECK-DAG: x_ui64 = 4 : ui64
            "x_ui64": 4,
            # CHECK-DAG: x_ui8 = 5 : ui8
            "x_ui8": 5,
            # CHECK-DAG: x_unit
            "x_unit": True,
        }
        attributes_op = test.AttributesOp(**builder_kwargs)
        attributes_op.verify()
        attributes_op.print(use_local_scope=True)
212
213
# CHECK-LABEL: TEST: inferReturnTypes
@run
def inferReturnTypes():
    """Check InferTypeOpInterface on an op instance and on an op class.

    An interface bound to an operation exposes that operation's opview. An
    interface built from the op class ("static") must infer the same result
    types but has no opview (TypeError). Ops that do not implement the
    interface are rejected with ValueError, whether given as instance or class.
    """
    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            op = test.InferResultsOp()
            dummy = test.DummyOp()

        # CHECK: [Type(i32), Type(i64)]
        iface = InferTypeOpInterface(op)
        print(iface.inferReturnTypes())

        # Query the static interface itself; previously this printed the
        # bound interface's result a second time and never exercised it.
        # CHECK: [Type(i32), Type(i64)]
        iface_static = InferTypeOpInterface(test.InferResultsOp)
        print(iface_static.inferReturnTypes())

        assert isinstance(iface.opview, test.InferResultsOp)
        assert iface.opview == iface.operation.opview

        try:
            iface_static.opview
        except TypeError:
            pass
        else:
            assert False, (
                "not expected to be able to obtain an opview from a static interface"
            )

        try:
            InferTypeOpInterface(dummy)
        except ValueError:
            pass
        else:
            assert False, "not expected dummy op to implement the interface"

        try:
            InferTypeOpInterface(test.DummyOp)
        except ValueError:
            pass
        else:
            assert False, "not expected dummy op class to implement the interface"
256
257
# CHECK-LABEL: TEST: resultTypesDefinedByTraits
@run
def resultTypesDefinedByTraits():
    """Check ops whose result types come from operands, attributes or traits."""
    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            inferred = test.InferResultsOp()
            first_value, second_value = inferred.results[0], inferred.results[1]

            same_type = test.SameOperandAndResultTypeOp([first_value])
            # CHECK-COUNT-2: i32
            for result in (same_type.one, same_type.two):
                print(result.type)

            from_type_attr = test.FirstAttrDeriveTypeAttrOp(
                second_value, TypeAttr.get(IndexType.get())
            )
            # CHECK-COUNT-2: index
            for result in (from_type_attr.one, from_type_attr.two):
                print(result.type)

            from_attr = test.FirstAttrDeriveAttrOp(FloatAttr.get(F32Type.get(), 3.14))
            # CHECK-COUNT-3: f32
            for result in (from_attr.one, from_attr.two, from_attr.three):
                print(result.type)

            implied = test.InferResultsImpliedOp()
            # CHECK: i32
            # CHECK: f64
            # CHECK: index
            for result in (implied.integer, implied.flt, implied.index):
                print(result.type)
290
291
# CHECK-LABEL: TEST: testOptionalOperandOp
@run
def testOptionalOperandOp():
    """Check that an optional operand reads back as None only when omitted."""
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            without_input = test.OptionalOperandOp()
            missing = without_input.input is None
            # CHECK: op1.input is None: True
            print(f"op1.input is None: {missing}")

            with_input = test.OptionalOperandOp(input=without_input)
            missing = with_input.input is None
            # CHECK: op2.input is None: False
            print(f"op2.input is None: {missing}")
305
306
# CHECK-LABEL: TEST: testCustomAttribute
@run
def testCustomAttribute():
    """Check the TestAttr binding: creation, printing, and cast behavior.

    Casting must succeed for a TestAttr, raise ValueError for a different
    attribute kind, and raise TypeError (without crashing) for a non-MLIR
    argument or for a call with two arguments.
    """
    with Context(), Location.unknown():
        a = test.TestAttr.get()
        # CHECK: #python_test.test_attr
        print(a)

        # CHECK: python_test.custom_attributed_op  {
        # CHECK: #python_test.test_attr
        # CHECK: }
        op2 = test.CustomAttributedOp(a)
        print(f"{op2}")

        # CHECK: #python_test.test_attr
        print(f"{op2.test_attr}")

        # CHECK: TestAttr(#python_test.test_attr)
        print(repr(op2.test_attr))

        # The following cast must not assert.
        test.TestAttr(a)

        # A missing exception is reported with an explicit AssertionError: a
        # bare `raise` in an `else` clause has nothing to re-raise and only
        # surfaces as "RuntimeError: No active exception to reraise".
        unit = UnitAttr.get()
        try:
            test.TestAttr(unit)
        except ValueError as e:
            assert "Cannot cast attribute to TestAttr" in str(e)
        else:
            raise AssertionError("expected ValueError when casting UnitAttr to TestAttr")

        # The following must trigger a TypeError from our adaptors and must not
        # crash.
        try:
            test.TestAttr(42)
        except TypeError as e:
            assert "Expected an MLIR object" in str(e)
        else:
            raise AssertionError("expected TypeError when casting an int to TestAttr")

        # The following must trigger a TypeError from pybind (therefore, not
        # checking its message) and must not crash.
        try:
            test.TestAttr(42, 56)
        except TypeError:
            pass
        else:
            raise AssertionError(
                "expected TypeError when calling TestAttr with two arguments"
            )
355
356
# CHECK-LABEL: TEST: testCustomType
@run
def testCustomType():
    """Check the TestType binding: creation, typeid access, and cast behavior.

    Casting must succeed for a TestType, raise ValueError for a different
    type, and raise TypeError (without crashing) for a non-MLIR argument or
    for a call with two arguments.
    """
    with Context():
        a = test.TestType.get()
        # CHECK: !python_test.test_type
        print(a)

        # The following cast must not assert.
        b = test.TestType(a)
        # Instances of custom types should have typeids.
        assert isinstance(b.typeid, TypeID)
        # Subclasses of ir.Type should not have a static_typeid.
        # CHECK: 'TestType' object has no attribute 'static_typeid'
        try:
            b.static_typeid
        except AttributeError as e:
            print(e)
        else:
            raise AssertionError("expected TestType instance to lack static_typeid")

        # A missing exception is reported with an explicit AssertionError: a
        # bare `raise` in an `else` clause has nothing to re-raise and only
        # surfaces as "RuntimeError: No active exception to reraise".
        i8 = IntegerType.get_signless(8)
        try:
            test.TestType(i8)
        except ValueError as e:
            assert "Cannot cast type to TestType" in str(e)
        else:
            raise AssertionError("expected ValueError when casting i8 to TestType")

        # The following must trigger a TypeError from our adaptors and must not
        # crash.
        try:
            test.TestType(42)
        except TypeError as e:
            assert "Expected an MLIR object" in str(e)
        else:
            raise AssertionError("expected TypeError when casting an int to TestType")

        # The following must trigger a TypeError from pybind (therefore, not
        # checking its message) and must not crash.
        try:
            test.TestType(42, 56)
        except TypeError:
            pass
        else:
            raise AssertionError(
                "expected TypeError when calling TestType with two arguments"
            )
400
401
# CHECK-LABEL: TEST: testTensorValue
@run
def testTensorValue():
    """Check the TestTensorValue binding and static_typeid of a custom tensor type."""
    with Context(), Location.unknown():

        # A TestTensorValue subclass whose str() says "Tensor" instead of "Value".
        class Tensor(test.TestTensorValue):
            def __str__(self):
                return super().__str__().replace("Value", "Tensor")

        i8 = IntegerType.get_signless(8)
        module = Module.create()
        with InsertionPoint(module.body):
            empty_i8 = tensor.EmptyOp([10, 10], i8).result

            # CHECK: Value(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
            print(Value(empty_i8))

            wrapped = Tensor(empty_i8)
            # CHECK: Tensor(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
            print(wrapped)

            # CHECK: False
            print(wrapped.is_null())

            # A custom type class deriving from a concrete in-tree type must
            # expose a static_typeid equal to that in-tree type's TypeID.
            static_tid = test.TestIntegerRankedTensorType.static_typeid
            assert isinstance(static_tid, TypeID)
            assert static_tid == empty_i8.type.typeid

            empty_i5 = tensor.EmptyOp([1, 2, 3], IntegerType.get_signless(5)).result
            # CHECK: Value(%{{.*}} = tensor.empty() : tensor<1x2x3xi5>)
            print(empty_i5)
            # CHECK: TestTensorValue
            print(repr(empty_i5))
437
438
# CHECK-LABEL: TEST: inferReturnTypeComponents
@run
def inferReturnTypeComponents():
    """Check InferShapedTypeOpInterface for a ranked and an unranked operand."""
    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        i32 = IntegerType.get_signless(32)
        with InsertionPoint(module.body):
            result_type = UnrankedTensorType.get(i32)
            operand_types = [
                RankedTensorType.get([1, 3, 10, 10], i32),
                UnrankedTensorType.get(i32),
            ]
            func_op = func.FuncOp(
                "test_inferReturnTypeComponents", (operand_types, [result_type])
            )
            entry_block = Block.create_at_start(
                func_op.operation.regions[0], operand_types
            )
            ranked_arg = entry_block.arguments[0]
            unranked_arg = entry_block.arguments[1]
            with InsertionPoint(entry_block):
                ranked_op = test.InferShapedTypeComponentsOp(result_type, ranked_arg)
                unranked_op = test.InferShapedTypeComponentsOp(
                    result_type, unranked_arg
                )

        def dump_components(shaped_op):
            # Print the first ShapedTypeComponents inferred for shaped_op when
            # its own operand is supplied as the operand list.
            shaped_iface = InferShapedTypeOpInterface(shaped_op)
            components = shaped_iface.inferReturnTypeComponents(
                operands=[shaped_op.operand]
            )[0]
            print("has rank:", components.has_rank)
            print("rank:", components.rank)
            print("element type:", components.element_type)
            print("shape:", components.shape)

        # CHECK: has rank: True
        # CHECK: rank: 4
        # CHECK: element type: i32
        # CHECK: shape: [1, 3, 10, 10]
        dump_components(ranked_op)

        # CHECK: has rank: False
        # CHECK: rank: None
        # CHECK: element type: i32
        # CHECK: shape: None
        dump_components(unranked_op)
488
489
# CHECK-LABEL: TEST: testCustomTypeTypeCaster
@run
def testCustomTypeTypeCaster():
    """Check that the caster registered for a TypeID picks the Python type class."""
    with Context(), Location.unknown():
        assert test.TestType.get().typeid is not None

        parsed = Type.parse("!python_test.test_type")
        # CHECK: !python_test.test_type
        print(parsed)
        # CHECK: TestType(!python_test.test_type)
        print(repr(parsed))

        custom_tensor_type = test.TestIntegerRankedTensorType.get([10, 10], 5)
        # CHECK: tensor<10x10xi5>
        print(custom_tensor_type)
        # CHECK: TestIntegerRankedTensorType(tensor<10x10xi5>)
        print(repr(custom_tensor_type))
        tensor_typeid = custom_tensor_type.typeid

        # Without replace=True, registering another caster for this TypeID is
        # refused with a RuntimeError.
        # CHECK: Type caster is already registered
        try:
            register_type_caster(tensor_typeid)(
                lambda pytype: test.TestIntegerRankedTensorType(pytype)
            )
        except RuntimeError as e:
            print(e)

        # python_test dialect registers a caster for RankedTensorType in its
        # extension (pybind) module. Replace it with one yielding plain
        # RankedTensorType objects...
        @register_type_caster(tensor_typeid, replace=True)
        def to_plain_ranked_tensor(pytype):
            return RankedTensorType(pytype)

        plain_result = tensor.EmptyOp([10, 10], IntegerType.get_signless(5)).result
        # CHECK: tensor<10x10xi5>
        print(plain_result.type)
        # CHECK: ranked tensor type RankedTensorType(tensor<10x10xi5>)
        print("ranked tensor type", repr(plain_result.type))

        # ...then install a caster that again yields TestIntegerRankedTensorType.
        @register_type_caster(tensor_typeid, replace=True)
        def to_test_ranked_tensor(pytype):
            return test.TestIntegerRankedTensorType(pytype)

        custom_result = tensor.EmptyOp([10, 10], IntegerType.get_signless(5)).result
        # CHECK: tensor<10x10xi5>
        print(custom_result.type)
        # CHECK: TestIntegerRankedTensorType(tensor<10x10xi5>)
        print(repr(custom_result.type))
540
541
# CHECK-LABEL: TEST: testInferTypeOpInterface
@run
def testInferTypeOpInterface():
    """Check result inference that depends on which variadic operands are given.

    With only ``single`` the inferred result is i32; when ``doubled`` is also
    supplied it is f32.
    """
    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            zero_i64 = arith.ConstantOp(IntegerType.get_signless(64), 0)

            single_only = test.InferResultsVariadicInputsOp(
                single=zero_i64, doubled=None
            )
            # CHECK: i32
            print(single_only.result.type)

            single_and_doubled = test.InferResultsVariadicInputsOp(
                single=zero_i64, doubled=zero_i64
            )
            # CHECK: f32
            print(single_and_doubled.result.type)
558