xref: /llvm-project/mlir/test/python/dialects/python_test.py (revision c606fefa8579f1d00b53d0a7a6b63319526003e4)
1# RUN: %PYTHON %s | FileCheck %s
2
3from mlir.ir import *
4import mlir.dialects.func as func
5import mlir.dialects.python_test as test
6import mlir.dialects.tensor as tensor
7
def run(f):
  """Execute the test function `f` right away and return it unchanged.

  Before calling `f`, prints an empty line followed by "TEST: <name>" so the
  FileCheck label directives in this file can anchor each test's output.
  Returning `f` keeps the decorated name bound to the original function.
  """
  test_name = f.__name__
  print("\nTEST:", test_name)
  f()
  return f
12
# CHECK-LABEL: TEST: testAttributes
@run
def testAttributes():
  """Build ops with attributes and exercise generic and generated accessors."""
  with Context() as ctx, Location.unknown():
    ctx.allow_unregistered_dialects = True

    def expect_raises(exc_type, action, message):
      # Invoke `action`; fail the test unless it raises `exc_type`.
      try:
        action()
      except exc_type:
        pass
      else:
        assert False, message

    #
    # Op construction with attributes.
    #

    i32_type = IntegerType.get_signless(32)
    attr_one = IntegerAttr.get(i32_type, 1)
    attr_two = IntegerAttr.get(i32_type, 2)
    unit_attr = UnitAttr.get()

    # CHECK: "python_test.attributed_op"() {
    # CHECK-DAG: mandatory_i32 = 1 : i32
    # CHECK-DAG: optional_i32 = 2 : i32
    # CHECK-DAG: unit
    # CHECK: }
    full = test.AttributedOp(attr_one, optional_i32=attr_two, unit=unit_attr)
    print(full)

    # CHECK: "python_test.attributed_op"() {
    # CHECK: mandatory_i32 = 2 : i32
    # CHECK: }
    minimal = test.AttributedOp(attr_two)
    print(minimal)

    #
    # Generic "attributes" mapping: lookup, insertion, overwrite, deletion.
    #

    assert "additional" not in full.attributes

    # CHECK: "python_test.attributed_op"() {
    # CHECK-DAG: additional = 1 : i32
    # CHECK-DAG: mandatory_i32 = 2 : i32
    # CHECK: }
    minimal.attributes["additional"] = attr_one
    print(minimal)

    # CHECK: "python_test.attributed_op"() {
    # CHECK-DAG: additional = 2 : i32
    # CHECK-DAG: mandatory_i32 = 2 : i32
    # CHECK: }
    minimal.attributes["additional"] = attr_two
    print(minimal)

    # CHECK: "python_test.attributed_op"() {
    # CHECK-NOT: additional = 2 : i32
    # CHECK:     mandatory_i32 = 2 : i32
    # CHECK: }
    del minimal.attributes["additional"]
    print(minimal)

    expect_raises(KeyError,
                  lambda: print(full.attributes["additional"]),
                  "expected KeyError on unknown attribute key")

    #
    # Accessors generated for the declared attributes.
    #

    # CHECK: Mandatory: 1
    # CHECK: Optional: 2
    # CHECK: Unit: True
    print(f"Mandatory: {full.mandatory_i32.value}")
    print(f"Optional: {full.optional_i32.value}")
    print(f"Unit: {full.unit}")

    # CHECK: Mandatory: 2
    # CHECK: Optional: None
    # CHECK: Unit: False
    print(f"Mandatory: {minimal.mandatory_i32.value}")
    print(f"Optional: {minimal.optional_i32}")
    print(f"Unit: {minimal.unit}")

    # CHECK: Mandatory: 2
    # CHECK: Optional: None
    # CHECK: Unit: False
    full.mandatory_i32 = attr_two
    full.optional_i32 = None
    full.unit = False
    print(f"Mandatory: {full.mandatory_i32.value}")
    print(f"Optional: {full.optional_i32}")
    print(f"Unit: {full.unit}")
    # Clearing optional/unit attributes removes them from the dictionary.
    assert "optional_i32" not in full.attributes
    assert "unit" not in full.attributes

    expect_raises(ValueError,
                  lambda: setattr(full, "mandatory_i32", None),
                  "expected ValueError on setting a mandatory attribute to None")

    # CHECK: Optional: 2
    full.optional_i32 = attr_two
    print(f"Optional: {full.optional_i32.value}")

    # CHECK: Optional: None
    del full.optional_i32
    print(f"Optional: {full.optional_i32}")

    # CHECK: Unit: False
    full.unit = None
    print(f"Unit: {full.unit}")
    assert "unit" not in full.attributes

    # CHECK: Unit: True
    full.unit = True
    print(f"Unit: {full.unit}")

    # CHECK: Unit: False
    del full.unit
    print(f"Unit: {full.unit}")
133
# CHECK-LABEL: TEST: attrBuilder
@run
def attrBuilder():
  """Build an op whose attributes are all given as plain Python values.

  Booleans, integers, floats, strings and lists are passed straight to the
  generated builder instead of pre-built attribute objects. The printed op is
  matched against the directives below, so a broken conversion for one of the
  verified attributes makes the test fail instead of passing silently.
  """
  with Context() as ctx, Location.unknown():
    ctx.allow_unregistered_dialects = True
    op = test.AttributesOp(x_bool=True,
                           x_i16=1,
                           x_i32=2,
                           x_i64=3,
                           x_si16=-1,
                           x_si32=-2,
                           x_f32=1.5,
                           x_f64=2.5,
                           x_str='x_str',
                           x_i32_array=[1, 2, 3],
                           x_i64_array=[4, 5, 6],
                           x_f32_array=[1.5, -2.5, 3.5],
                           x_f64_array=[4.5, 5.5, -6.5],
                           x_i64_dense=[1, 2, 3, 4, 5, 6])
    # Only attributes whose printed form is unambiguous are pinned here.
    # CHECK-DAG: x_bool = true
    # CHECK-DAG: x_i16 = 1 : i16
    # CHECK-DAG: x_i32 = 2 : i32
    # CHECK-DAG: x_si16 = -1 : si16
    # CHECK-DAG: x_si32 = -2 : si32
    # CHECK-DAG: x_str = "x_str"
    print(op)
154
155
# CHECK-LABEL: TEST: inferReturnTypes
@run
def inferReturnTypes():
  """Query InferTypeOpInterface through an op instance and through an op class.

  Also checks that a class-based interface cannot produce an opview and that
  ops/classes not implementing the interface are rejected with ValueError.
  """
  with Context() as ctx, Location.unknown(ctx):
    test.register_python_test_dialect(ctx)
    module = Module.create()
    with InsertionPoint(module.body):
      op = test.InferResultsOp()
      dummy = test.DummyOp()

    # CHECK: [Type(i32), Type(i64)]
    iface = InferTypeOpInterface(op)
    print(iface.inferReturnTypes())

    # CHECK: [Type(i32), Type(i64)]
    # Query the interface built from the op class itself; printing `iface`
    # again would leave the class-based inference path untested.
    iface_static = InferTypeOpInterface(test.InferResultsOp)
    print(iface_static.inferReturnTypes())

    assert isinstance(iface.opview, test.InferResultsOp)
    assert iface.opview == iface.operation.opview

    try:
      iface_static.opview
    except TypeError:
      pass
    else:
      assert False, ("not expected to be able to obtain an opview from a static"
                     " interface")

    try:
      InferTypeOpInterface(dummy)
    except ValueError:
      pass
    else:
      assert False, "not expected dummy op to implement the interface"

    try:
      InferTypeOpInterface(test.DummyOp)
    except ValueError:
      pass
    else:
      assert False, "not expected dummy op class to implement the interface"
198
199
# CHECK-LABEL: TEST: resultTypesDefinedByTraits
@run
def resultTypesDefinedByTraits():
  """Print result types of ops that derive them from operands or attributes."""
  with Context() as ctx, Location.unknown(ctx):
    test.register_python_test_dialect(ctx)
    module = Module.create()
    with InsertionPoint(module.body):
      producer = test.InferResultsOp()

      same_type = test.SameOperandAndResultTypeOp([producer.results[0]])
      # CHECK-COUNT-2: i32
      for result in (same_type.one, same_type.two):
        print(result.type)

      from_type_attr = test.FirstAttrDeriveTypeAttrOp(
          producer.results[1], TypeAttr.get(IndexType.get()))
      # CHECK-COUNT-2: index
      for result in (from_type_attr.one, from_type_attr.two):
        print(result.type)

      float_attr = FloatAttr.get(F32Type.get(), 3.14)
      from_attr = test.FirstAttrDeriveAttrOp(float_attr)
      # CHECK-COUNT-3: f32
      for result in (from_attr.one, from_attr.two, from_attr.three):
        print(result.type)

      implied = test.InferResultsImpliedOp()
      # CHECK: i32
      print(implied.integer.type)
      # CHECK: f64
      print(implied.flt.type)
      # CHECK: index
      print(implied.index.type)
233
234
# CHECK-LABEL: TEST: testOptionalOperandOp
@run
def testOptionalOperandOp():
  """Check the optional `input` accessor is None only when it was omitted."""
  with Context() as ctx, Location.unknown():
    test.register_python_test_dialect(ctx)

    module = Module.create()
    with InsertionPoint(module.body):
      without_input = test.OptionalOperandOp()
      # CHECK: op1.input is None: True
      print(f"op1.input is None: {without_input.input is None}")

      with_input = test.OptionalOperandOp(input=without_input)
      # CHECK: op2.input is None: False
      print(f"op2.input is None: {with_input.input is None}")
251
252
# CHECK-LABEL: TEST: testCustomAttribute
@run
def testCustomAttribute():
  """Exercise construction and casting of the dialect's custom TestAttr.

  Raises:
    AssertionError: if a cast or call that must be rejected is accepted.
  """
  with Context() as ctx:
    test.register_python_test_dialect(ctx)
    a = test.TestAttr.get()
    # CHECK: #python_test.test_attr
    print(a)

    # The following cast must not assert.
    test.TestAttr(a)

    # A bare `raise` in the `else:` branches below would have no active
    # exception and would only fail with "No active exception to reraise";
    # raise an explicit AssertionError instead.
    unit = UnitAttr.get()
    try:
      test.TestAttr(unit)
    except ValueError as e:
      assert "Cannot cast attribute to TestAttr" in str(e)
    else:
      raise AssertionError(
          "expected ValueError when casting UnitAttr to TestAttr")

    # The following must trigger a TypeError from our adaptors and must not
    # crash.
    try:
      test.TestAttr(42)
    except TypeError as e:
      assert "Expected an MLIR object" in str(e)
    else:
      raise AssertionError(
          "expected TypeError when constructing TestAttr from a non-MLIR object")

    # The following must trigger a TypeError from pybind (therefore, not
    # checking its message) and must not crash.
    try:
      test.TestAttr(42, 56)
    except TypeError:
      pass
    else:
      raise AssertionError(
          "expected TypeError when calling TestAttr with two arguments")
290
291
# CHECK-LABEL: TEST: testCustomType
@run
def testCustomType():
  """Exercise construction and casting of the dialect's custom TestType.

  Raises:
    AssertionError: if a cast or call that must be rejected is accepted.
  """
  with Context() as ctx:
    test.register_python_test_dialect(ctx)
    a = test.TestType.get()
    # CHECK: !python_test.test_type
    print(a)

    # The following cast must not assert.
    test.TestType(a)

    # A bare `raise` in the `else:` branches below would have no active
    # exception and would only fail with "No active exception to reraise";
    # raise an explicit AssertionError instead.
    i8 = IntegerType.get_signless(8)
    try:
      test.TestType(i8)
    except ValueError as e:
      assert "Cannot cast type to TestType" in str(e)
    else:
      raise AssertionError("expected ValueError when casting i8 to TestType")

    # The following must trigger a TypeError from our adaptors and must not
    # crash.
    try:
      test.TestType(42)
    except TypeError as e:
      assert "Expected an MLIR object" in str(e)
    else:
      raise AssertionError(
          "expected TypeError when constructing TestType from a non-MLIR object")

    # The following must trigger a TypeError from pybind (therefore, not
    # checking its message) and must not crash.
    try:
      test.TestType(42, 56)
    except TypeError:
      pass
    else:
      raise AssertionError(
          "expected TypeError when calling TestType with two arguments")
328
329
# CHECK-LABEL: TEST: testTensorValue
@run
def testTensorValue():
  """Wrap a tensor-typed value in a Python subclass of TestTensorValue."""
  with Context() as ctx, Location.unknown():
    test.register_python_test_dialect(ctx)

    elem_type = IntegerType.get_signless(8)

    class Tensor(test.TestTensorValue):
      # Reuse the base rendering, relabelling "Value" as "Tensor".
      def __str__(self):
        rendered = super().__str__()
        return rendered.replace("Value", "Tensor")

    module = Module.create()
    with InsertionPoint(module.body):
      empty = tensor.EmptyOp([10, 10], elem_type).result

      # CHECK: Value(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
      print(Value(empty))

      wrapped = Tensor(empty)
      # CHECK: Tensor(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
      print(wrapped)

      # CHECK: False
      print(wrapped.is_null())
355
356
# CHECK-LABEL: TEST: inferReturnTypeComponents
@run
def inferReturnTypeComponents():
  """Query InferShapedTypeOpInterface for a ranked and an unranked operand."""
  with Context() as ctx, Location.unknown(ctx):
    test.register_python_test_dialect(ctx)
    module = Module.create()
    i32 = IntegerType.get_signless(32)
    with InsertionPoint(module.body):
      result_type = UnrankedTensorType.get(i32)
      operand_types = [
          RankedTensorType.get([1, 3, 10, 10], i32),
          UnrankedTensorType.get(i32),
      ]
      fn = func.FuncOp("test_inferReturnTypeComponents",
                       (operand_types, [result_type]))
      entry_block = Block.create_at_start(fn.operation.regions[0],
                                          operand_types)
      with InsertionPoint(entry_block):
        ranked_op = test.InferShapedTypeComponentsOp(
            result_type, entry_block.arguments[0])
        unranked_op = test.InferShapedTypeComponentsOp(
            result_type, entry_block.arguments[1])

    def describe(op):
      # Infer shaped-type components from the op's operand and print the
      # first entry field by field.
      components = InferShapedTypeOpInterface(op).inferReturnTypeComponents(
          operands=[op.operand])[0]
      print("has rank:", components.has_rank)
      print("rank:", components.rank)
      print("element type:", components.element_type)
      print("shape:", components.shape)

    # CHECK: has rank: True
    # CHECK: rank: 4
    # CHECK: element type: i32
    # CHECK: shape: [1, 3, 10, 10]
    describe(ranked_op)

    # CHECK: has rank: False
    # CHECK: rank: None
    # CHECK: element type: i32
    # CHECK: shape: None
    describe(unranked_op)
407