xref: /llvm-project/mlir/test/python/dialects/python_test.py (revision 54c99842079997b0fe208acdab01e540c0d81b51)
1# RUN: %PYTHON %s | FileCheck %s
2
3from mlir.ir import *
4import mlir.dialects.python_test as test
5
def run(f):
  """Execute test function `f` right away under a `TEST: <name>` banner.

  Meant to be used as a decorator: the banner is what the CHECK-LABEL lines
  in this file anchor on, `f` is invoked immediately at definition time, and
  `f` itself is handed back so the module-level name stays bound to it.
  """
  test_name = f.__name__
  print("\nTEST:", test_name)
  f()
  return f
10
# CHECK-LABEL: TEST: testAttributes
@run
def testAttributes():
  """Exercise generated attribute accessors and the generic attribute map.

  Builds `python_test.attributed_op` instances in a context that allows
  unregistered dialects, then checks construction with mandatory, optional
  and unit attributes, item-style mutation through `op.attributes`, and the
  typed property accessors (get, set, set-to-None and delete).
  """

  def expect_failure(error_type, action, message):
    # Run `action`; the check passes only if it raises `error_type`.
    try:
      action()
    except error_type:
      pass
    else:
      assert False, message

  with Context() as ctx, Location.unknown():
    ctx.allow_unregistered_dialects = True

    #
    # Check op construction with attributes.
    #

    i32_type = IntegerType.get_signless(32)
    attr_one = IntegerAttr.get(i32_type, 1)
    attr_two = IntegerAttr.get(i32_type, 2)
    unit_attr = UnitAttr.get()

    # CHECK: "python_test.attributed_op"() {
    # CHECK-DAG: mandatory_i32 = 1 : i32
    # CHECK-DAG: optional_i32 = 2 : i32
    # CHECK-DAG: unit
    # CHECK: }
    full_op = test.AttributedOp(attr_one, attr_two, unit_attr)
    print(f"{full_op}")

    # CHECK: "python_test.attributed_op"() {
    # CHECK: mandatory_i32 = 2 : i32
    # CHECK: }
    bare_op = test.AttributedOp(attr_two, None, None)
    print(f"{bare_op}")

    #
    # Check generic "attributes" access and mutation.
    #

    assert "additional" not in full_op.attributes

    # Insert the extra attribute, then overwrite it with a new value; each
    # state is printed once, in this order.
    # CHECK: "python_test.attributed_op"() {
    # CHECK-DAG: additional = 1 : i32
    # CHECK-DAG: mandatory_i32 = 2 : i32
    # CHECK: }
    # CHECK: "python_test.attributed_op"() {
    # CHECK-DAG: additional = 2 : i32
    # CHECK-DAG: mandatory_i32 = 2 : i32
    # CHECK: }
    for extra in (attr_one, attr_two):
      bare_op.attributes["additional"] = extra
      print(f"{bare_op}")

    # CHECK: "python_test.attributed_op"() {
    # CHECK-NOT: additional = 2 : i32
    # CHECK:     mandatory_i32 = 2 : i32
    # CHECK: }
    del bare_op.attributes["additional"]
    print(f"{bare_op}")

    expect_failure(
        KeyError,
        lambda: print(full_op.attributes["additional"]),
        "expected KeyError on unknown attribute key")

    #
    # Check accessors to defined attributes.
    #

    # CHECK: Mandatory: 1
    # CHECK: Optional: 2
    # CHECK: Unit: True
    print(f"Mandatory: {full_op.mandatory_i32.value}")
    print(f"Optional: {full_op.optional_i32.value}")
    print(f"Unit: {full_op.unit}")

    # CHECK: Mandatory: 2
    # CHECK: Optional: None
    # CHECK: Unit: False
    print(f"Mandatory: {bare_op.mandatory_i32.value}")
    print(f"Optional: {bare_op.optional_i32}")
    print(f"Unit: {bare_op.unit}")

    # Setting optional/unit attributes to None/False removes them.
    # CHECK: Mandatory: 2
    # CHECK: Optional: None
    # CHECK: Unit: False
    full_op.mandatory_i32 = attr_two
    full_op.optional_i32 = None
    full_op.unit = False
    print(f"Mandatory: {full_op.mandatory_i32.value}")
    print(f"Optional: {full_op.optional_i32}")
    print(f"Unit: {full_op.unit}")
    assert "optional_i32" not in full_op.attributes
    assert "unit" not in full_op.attributes

    # A mandatory attribute cannot be cleared.
    expect_failure(
        ValueError,
        lambda: setattr(full_op, "mandatory_i32", None),
        "expected ValueError on setting a mandatory attribute to None")

    # CHECK: Optional: 2
    full_op.optional_i32 = attr_two
    print(f"Optional: {full_op.optional_i32.value}")

    # CHECK: Optional: None
    del full_op.optional_i32
    print(f"Optional: {full_op.optional_i32}")

    # CHECK: Unit: False
    full_op.unit = None
    print(f"Unit: {full_op.unit}")
    assert "unit" not in full_op.attributes

    # CHECK: Unit: True
    full_op.unit = True
    print(f"Unit: {full_op.unit}")

    # CHECK: Unit: False
    del full_op.unit
    print(f"Unit: {full_op.unit}")
131
132
# CHECK-LABEL: TEST: inferReturnTypes
@run
def inferReturnTypes():
  """Check InferTypeOpInterface built from an op instance and from an op class.

  The interface can wrap an operation (bound) or an OpView class (static).
  Both forms are expected to infer the same result types for
  `test.InferResultsOp`. Only the bound form exposes `opview`; the static
  form raises TypeError. Wrapping an op or class that does not implement
  the interface (`test.DummyOp`) raises ValueError.
  """
  with Context() as ctx, Location.unknown(ctx):
    test.register_python_test_dialect(ctx)
    module = Module.create()
    with InsertionPoint(module.body):
      op = test.InferResultsOp()
      dummy = test.DummyOp()

    # CHECK: [Type(i32), Type(i64)]
    iface = InferTypeOpInterface(op)
    print(iface.inferReturnTypes())

    # CHECK: [Type(i32), Type(i64)]
    iface_static = InferTypeOpInterface(test.InferResultsOp)
    # Query the *static* interface here; printing `iface` a second time would
    # leave the class-based inference path untested.
    print(iface_static.inferReturnTypes())

    assert isinstance(iface.opview, test.InferResultsOp)
    assert iface.opview == iface.operation.opview

    try:
      iface_static.opview
    except TypeError:
      pass
    else:
      assert False, ("not expected to be able to obtain an opview from a static"
                     " interface")

    try:
      InferTypeOpInterface(dummy)
    except ValueError:
      pass
    else:
      assert False, "not expected dummy op to implement the interface"

    try:
      InferTypeOpInterface(test.DummyOp)
    except ValueError:
      pass
    else:
      assert False, "not expected dummy op class to implement the interface"
175
176
# CHECK-LABEL: TEST: resultTypesDefinedByTraits
@run
def resultTypesDefinedByTraits():
  """Check result types that come from op traits rather than explicit types.

  Covers four sources of result types:
  - an operand (same operand and result type);
  - a TypeAttr argument;
  - a typed attribute argument;
  - types implied by the op definition itself.
  """

  def print_result_types(view, *accessors):
    # Print the `.type` of each named result accessor of `view`, in order.
    for accessor in accessors:
      print(getattr(view, accessor).type)

  with Context() as ctx, Location.unknown(ctx):
    test.register_python_test_dialect(ctx)
    module = Module.create()
    with InsertionPoint(module.body):
      inferred = test.InferResultsOp()

      same = test.SameOperandAndResultTypeOp([inferred.results[0]])
      # CHECK-COUNT-2: i32
      print_result_types(same, "one", "two")

      index_type_attr = TypeAttr.get(IndexType.get())
      first_type_attr = test.FirstAttrDeriveTypeAttrOp(
          inferred.results[1], index_type_attr)
      # CHECK-COUNT-2: index
      print_result_types(first_type_attr, "one", "two")

      f32_attr = FloatAttr.get(F32Type.get(), 3.14)
      first_attr = test.FirstAttrDeriveAttrOp(f32_attr)
      # CHECK-COUNT-3: f32
      print_result_types(first_attr, "one", "two", "three")

      implied = test.InferResultsImpliedOp()
      # CHECK: i32
      # CHECK: f64
      # CHECK: index
      print_result_types(implied, "integer", "flt", "index")
210
211
# CHECK-LABEL: TEST: testOptionalOperandOp
@run
def testOptionalOperandOp():
  """Check that an optional operand reads back as None only when omitted."""
  with Context() as ctx, Location.unknown():
    test.register_python_test_dialect(ctx)

    module = Module.create()
    with InsertionPoint(module.body):
      # Build without the optional operand.
      op1 = test.OptionalOperandOp(None)
      op1_lacks_input = op1.input is None
      # CHECK: op1.input is None: True
      print(f"op1.input is None: {op1_lacks_input}")

      # Pass op1 itself as the optional operand.
      op2 = test.OptionalOperandOp(op1)
      op2_lacks_input = op2.input is None
      # CHECK: op2.input is None: False
      print(f"op2.input is None: {op2_lacks_input}")
228