xref: /llvm-project/mlir/test/python/dialects/builtin.py (revision 5d3ae5161210c068d01ffba36c8e0761e9971179)
1# RUN: %PYTHON %s | FileCheck %s
2
3from mlir.ir import *
4import mlir.dialects.builtin as builtin
5import mlir.dialects.func as func
6import numpy as np
7
8
def run(f):
    """Announce ``f`` on stdout, execute it immediately, and hand it back.

    Intended for use as a decorator: the banner line (a blank line followed
    by ``TEST:`` and the function's ``__name__``) is printed before the
    function body runs, and the function itself is returned so the decorated
    name stays bound to it.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    f()
    return f
13
14
15# CHECK-LABEL: TEST: testFromPyFunc
@run
def testFromPyFunc():
    """Exercise ``func.FuncOp.from_py_func`` on a variety of Python callables.

    Every decorated function below is defined while the insertion point is
    set to the body of module ``m``. The module is printed once at the end,
    and the FileCheck directives in the comments describe the expected IR in
    the order the functions are defined.
    """
    with Context() as ctx, Location.unknown() as loc:
        # single_result_op below creates an op named "custom.op1", which does
        # not belong to any dialect imported by this file.
        ctx.allow_unregistered_dialects = True
        m = builtin.ModuleOp()
        f32 = F32Type.get()
        f64 = F64Type.get()
        with InsertionPoint(m.body):
            # One argument, returned unchanged.
            # CHECK-LABEL: func @unary_return(%arg0: f64) -> f64
            # CHECK: return %arg0 : f64
            @func.FuncOp.from_py_func(f64)
            def unary_return(a):
                return a

            # Two arguments returned as a Python tuple; two results expected.
            # CHECK-LABEL: func @binary_return(%arg0: f32, %arg1: f64) -> (f32, f64)
            # CHECK: return %arg0, %arg1 : f32, f64
            @func.FuncOp.from_py_func(f32, f64)
            def binary_return(a, b):
                return a, b

            # The Python body returns None; an operand-less return is expected.
            # CHECK-LABEL: func @none_return(%arg0: f32, %arg1: f64)
            # CHECK: return
            @func.FuncOp.from_py_func(f32, f64)
            def none_return(a, b):
                pass

            # Calling an already captured function from inside another
            # captured function is expected to show up as a call op.
            # CHECK-LABEL: func @call_unary
            # CHECK: %0 = call @unary_return(%arg0) : (f64) -> f64
            # CHECK: return %0 : f64
            @func.FuncOp.from_py_func(f64)
            def call_unary(a):
                return unary_return(a)

            # Same as above, for a callee with two results.
            # CHECK-LABEL: func @call_binary
            # CHECK: %0:2 = call @binary_return(%arg0, %arg1) : (f32, f64) -> (f32, f64)
            # CHECK: return %0#0, %0#1 : f32, f64
            @func.FuncOp.from_py_func(f32, f64)
            def call_binary(a, b):
                return binary_return(a, b)

            # We expect coercion of a single result operation to a returned value.
            # CHECK-LABEL: func @single_result_op
            # CHECK: %0 = "custom.op1"() : () -> f32
            # CHECK: return %0 : f32
            @func.FuncOp.from_py_func()
            def single_result_op():
                return Operation.create("custom.op1", results=[f32])

            # Calling a captured function that has no results.
            # CHECK-LABEL: func @call_none
            # CHECK: call @none_return(%arg0, %arg1) : (f32, f64) -> ()
            # CHECK: return
            @func.FuncOp.from_py_func(f32, f64)
            def call_none(a, b):
                return none_return(a, b)

            ## Variants and optional feature tests.
            # The name= keyword differs from the Python function name; the
            # directive below expects the keyword value as the symbol name.
            # CHECK-LABEL: func @from_name_arg
            @func.FuncOp.from_py_func(f32, f64, name="from_name_arg")
            def explicit_name(a, b):
                return b

            # The next three functions declare only two argument types but
            # also accept `func_op` (positionally, as a keyword with a
            # default, or through **kwargs). Each asserts that the value it
            # receives under that name is a func.FuncOp.
            @func.FuncOp.from_py_func(f32, f64)
            def positional_func_op(a, b, func_op):
                assert isinstance(func_op, func.FuncOp)
                return b

            @func.FuncOp.from_py_func(f32, f64)
            def kw_func_op(a, b=None, func_op=None):
                assert isinstance(func_op, func.FuncOp)
                return b

            @func.FuncOp.from_py_func(f32, f64)
            def kwargs_func_op(a, b=None, **kwargs):
                assert isinstance(kwargs["func_op"], func.FuncOp)
                return b

            # With results= given, the body builds the return op itself and
            # the Python function returns None.
            # CHECK-LABEL: func @explicit_results(%arg0: f32, %arg1: f64) -> f64
            # CHECK: return %arg1 : f64
            @func.FuncOp.from_py_func(f32, f64, results=[f64])
            def explicit_results(a, b):
                func.ReturnOp([b])

    # Printed after both `with` blocks have been left.
    print(m)
99
100
101# CHECK-LABEL: TEST: testFromPyFuncErrors
@run
def testFromPyFuncErrors():
    """Check the error reported when ``results=`` is combined with a body
    that returns a value.

    The decoration happens inside the ``try`` block; the AssertionError
    message is printed so FileCheck can match it.
    """
    # NOTE(review): `ctx`, `loc` and `f32` are bound but not used in this test.
    with Context() as ctx, Location.unknown() as loc:
        m = builtin.ModuleOp()
        f32 = F32Type.get()
        f64 = F64Type.get()
        with InsertionPoint(m.body):
            # The function is captured at decoration time, so the failure is
            # expected to surface while the `def` statement executes.
            try:

                @func.FuncOp.from_py_func(f64, results=[f64])
                def unary_return(a):
                    return a

            except AssertionError as e:
                # CHECK: Capturing a python function with explicit `results=` requires that the wrapped function returns None.
                print(e)
118
119
120# CHECK-LABEL: TEST: testBuildFuncOp
@run
def testBuildFuncOp():
    """Build ``func.FuncOp`` instances directly and inspect their properties.

    Covers: printing name/type/visibility, the IndexError raised when the
    entry block is read before one exists, adding an entry block, the
    IndexError raised when adding a second entry block, and construction via
    a ``body_builder`` callback with the type given as a tuple.
    """
    # The context is not entered with `with` here; it is passed explicitly
    # to Location.unknown instead.
    ctx = Context()
    with Location.unknown(ctx) as loc:
        m = builtin.ModuleOp()

        f32 = F32Type.get()
        tensor_type = RankedTensorType.get((2, 3, 4), f32)
        with InsertionPoint.at_block_begin(m.body):
            # Created without a body; an entry block is added further down.
            f = func.FuncOp(
                name="some_func",
                type=FunctionType.get(
                    inputs=[tensor_type, tensor_type], results=[tensor_type]
                ),
                visibility="nested",
            )
            # CHECK: Name is: "some_func"
            print("Name is: ", f.name)

            # CHECK: Type is: (tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
            print("Type is: ", f.type)

            # CHECK: Visibility is: "nested"
            print("Visibility is: ", f.visibility)

            # Reading entry_block before a body exists is expected to raise.
            # NOTE(review): the `entry_block` local is never read afterwards;
            # only the attribute access matters here.
            try:
                entry_block = f.entry_block
            except IndexError as e:
                # CHECK: External function does not have a body
                print(e)

            # Give the function a body that returns its first argument.
            with InsertionPoint(f.add_entry_block()):
                func.ReturnOp([f.entry_block.arguments[0]])
                # NOTE(review): this `pass` is redundant after the statement above.
                pass

            # A second entry block is expected to be rejected.
            try:
                f.add_entry_block()
            except IndexError as e:
                # CHECK: The function already has an entry block!
                print(e)

            # Try the callback builder and passing type as tuple.
            # `f` is rebound to the new op; the lambda's own `f` parameter is
            # whatever the constructor passes to body_builder.
            f = func.FuncOp(
                name="some_other_func",
                type=([tensor_type, tensor_type], [tensor_type]),
                visibility="nested",
                body_builder=lambda f: func.ReturnOp([f.entry_block.arguments[0]]),
            )

    # The module is printed after the `with` blocks have been left.
    # CHECK: module  {
    # CHECK:  func nested @some_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
    # CHECK:   return %arg0 : tensor<2x3x4xf32>
    # CHECK:  }
    # CHECK:  func nested @some_other_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
    # CHECK:   return %arg0 : tensor<2x3x4xf32>
    # CHECK:  }
    print(m)
178
179
180# CHECK-LABEL: TEST: testFuncArgumentAccess
@run
def testFuncArgumentAccess():
    """Set and read back argument/result attributes on ``func.FuncOp``.

    Two functions are built: ``some_func`` gets ``arg_attrs`` and
    ``result_attrs`` assigned as ``ArrayAttr`` values, while ``other_func``
    gets ``arg_attrs`` assigned as a plain Python list of ``DictAttr``. The
    attributes and the whole module are printed afterwards for FileCheck.
    """
    with Context() as ctx, Location.unknown():
        # The attribute names below use a "custom_dialect." prefix that no
        # import in this file provides.
        ctx.allow_unregistered_dialects = True
        module = Module.create()
        f32 = F32Type.get()
        f64 = F64Type.get()
        with InsertionPoint(module.body):
            f = func.FuncOp("some_func", ([f32, f32], [f32, f32]))
            with InsertionPoint(f.add_entry_block()):
                # The function's own arguments are passed straight to the
                # return op.
                func.ReturnOp(f.arguments)
            # One DictAttr per argument, wrapped in an ArrayAttr.
            f.arg_attrs = ArrayAttr.get(
                [
                    DictAttr.get(
                        {
                            "custom_dialect.foo": StringAttr.get("bar"),
                            "custom_dialect.baz": UnitAttr.get(),
                        }
                    ),
                    DictAttr.get({"custom_dialect.qux": ArrayAttr.get([])}),
                ]
            )
            # One DictAttr per result. The second attribute value is an f64
            # float even though the result type itself is f32.
            f.result_attrs = ArrayAttr.get(
                [
                    DictAttr.get({"custom_dialect.res1": FloatAttr.get(f32, 42.0)}),
                    DictAttr.get({"custom_dialect.res2": FloatAttr.get(f64, 256.0)}),
                ]
            )

            other = func.FuncOp("other_func", ([f32, f32], []))
            with InsertionPoint(other.add_entry_block()):
                func.ReturnOp([])
            # Here arg_attrs is assigned a plain list instead of an ArrayAttr;
            # the second entry is a DictAttr built with no arguments.
            other.arg_attrs = [
                DictAttr.get({"custom_dialect.foo": StringAttr.get("qux")}),
                DictAttr.get(),
            ]

    # Everything below runs after the `with` blocks have been left.
    # CHECK: [{custom_dialect.baz, custom_dialect.foo = "bar"}, {custom_dialect.qux = []}]
    print(f.arg_attrs)

    # CHECK: [{custom_dialect.res1 = 4.200000e+01 : f32}, {custom_dialect.res2 = 2.560000e+02 : f64}]
    print(f.result_attrs)

    # CHECK: func @some_func(
    # CHECK: %[[ARG0:.*]]: f32 {custom_dialect.baz, custom_dialect.foo = "bar"},
    # CHECK: %[[ARG1:.*]]: f32 {custom_dialect.qux = []}) ->
    # CHECK: f32 {custom_dialect.res1 = 4.200000e+01 : f32},
    # CHECK: f32 {custom_dialect.res2 = 2.560000e+02 : f64})
    # CHECK: return %[[ARG0]], %[[ARG1]] : f32, f32
    #
    # CHECK: func @other_func(
    # CHECK: %{{.*}}: f32 {custom_dialect.foo = "qux"},
    # CHECK: %{{.*}}: f32)
    print(module)
235
236
237# CHECK-LABEL: testDenseElementsAttr
@run
def testDenseElementsAttr():
    """Build ``DenseElementsAttr`` values from NumPy arrays and print them.

    Four variants are printed in order: element type only, element type plus
    an explicit shape, a 2x2 vector type, and a vector of ``index`` built
    from int64 data. The FileCheck directive after each print gives the
    expected textual form.
    """
    with Context(), Location.unknown():
        elem_type = IntegerType.get_signless(32)
        int_data = np.arange(4, dtype=np.int32)

        # Only the element type is supplied.
        flat_attr = DenseElementsAttr.get(int_data, type=elem_type)
        print(flat_attr)
        # CHECK{LITERAL}: dense<[0, 1, 2, 3]> : tensor<4xi32>

        # Element type together with an explicit 2x2 shape.
        reshaped_attr = DenseElementsAttr.get(int_data, type=elem_type, shape=(2, 2))
        print(reshaped_attr)
        # CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>

        # A complete vector type is passed as `type`.
        vec_2x2 = VectorType.get((2, 2), elem_type)
        vector_attr = DenseElementsAttr.get(int_data, type=vec_2x2)
        print(vector_attr)
        # CHECK{LITERAL}: dense<[[0, 1], [2, 3]]> : vector<2x2xi32>

        # int64 data paired with a vector-of-index type.
        index_data = np.arange(4, dtype=np.int64)
        index_elem = IndexType.get()
        vec_of_index = VectorType.get([4], index_elem)
        index_attr = DenseElementsAttr.get(index_data, type=vec_of_index)
        print(index_attr)
        # CHECK{LITERAL}: dense<[0, 1, 2, 3]> : vector<4xindex>
253