# xref: /llvm-project/mlir/test/python/dialects/linalg/ops.py (revision 1bc5fe669f5477eadd84270e971591a718693bba)
# RUN: %PYTHON %s | FileCheck %s

from mlir.dialects import arith, func, linalg, tensor, memref
from mlir.dialects.linalg.opdsl.lang import *
from mlir.ir import *


8def run(f):
9    print("\nTEST:", f.__name__)
10    f()
11    return f
12
13
14# CHECK-LABEL: TEST: testFill
15@run
16def testFill():
17    with Context() as ctx, Location.unknown():
18        module = Module.create()
19        f32 = F32Type.get()
20        with InsertionPoint(module.body):
21            # CHECK-LABEL: func @fill_tensor
22            #  CHECK-SAME:   %[[OUT:[0-9a-z]+]]: tensor<12x?xf32>
23            #  CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}} : f32
24            #  CHECK-NEXT: %[[RES:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[OUT]] : tensor<12x?xf32>) -> tensor<12x?xf32>
25            #  CHECK-NEXT: return %[[RES]] : tensor<12x?xf32>
26            @func.FuncOp.from_py_func(
27                RankedTensorType.get((12, ShapedType.get_dynamic_size()), f32)
28            )
29            def fill_tensor(out):
30                zero = arith.ConstantOp(
31                    value=FloatAttr.get(f32, 0.0), result=f32
32                ).result
33                return linalg.fill(zero, outs=[out])
34
35            # CHECK-LABEL: func @fill_buffer
36            #  CHECK-SAME:   %[[OUT:[0-9a-z]+]]: memref<12x?xf32>
37            #  CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}} : f32
38            #  CHECK-NEXT: linalg.fill ins(%[[CST]] : f32) outs(%[[OUT]] : memref<12x?xf32>)
39            #  CHECK-NEXT: return
40            @func.FuncOp.from_py_func(
41                MemRefType.get((12, ShapedType.get_dynamic_size()), f32)
42            )
43            def fill_buffer(out):
44                zero = arith.ConstantOp(
45                    value=FloatAttr.get(f32, 0.0), result=f32
46                ).result
47                linalg.fill(zero, outs=[out])
48
49    print(module)
50
51
52# CHECK-LABEL: TEST: testNamedStructuredOpCustomForm
53@run
54def testNamedStructuredOpCustomForm():
55    with Context() as ctx, Location.unknown():
56        module = Module.create()
57        f32 = F32Type.get()
58        with InsertionPoint(module.body):
59
60            @func.FuncOp.from_py_func(
61                RankedTensorType.get((4, 8), f32), RankedTensorType.get((4, 8), f32)
62            )
63            def named_form(lhs, rhs):
64                init_result = tensor.EmptyOp([4, 8], f32)
65                # Check for the named form with custom format
66                #      CHECK: linalg.elemwise_unary
67                # CHECK-SAME:    cast = #linalg.type_fn<cast_signed>
68                # CHECK-SAME:    fun = #linalg.unary_fn<exp>
69                # CHECK-SAME:    ins(%{{.*}} : tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>)
70                unary_result = linalg.elemwise_unary(lhs, outs=[init_result.result])
71                #      CHECK: linalg.elemwise_binary
72                # CHECK-SAME:    cast = #linalg.type_fn<cast_unsigned>
73                # CHECK-SAME:    fun = #linalg.binary_fn<mul>
74                # CHECK-SAME:    ins(%{{.*}}, %{{.*}} : tensor<4x8xf32>, tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>)
75                #      CHECK: return
76                binary_result = linalg.elemwise_binary(
77                    lhs,
78                    rhs,
79                    outs=[init_result.result],
80                    fun=BinaryFn.mul,
81                    cast=TypeFn.cast_unsigned,
82                )
83                return unary_result, binary_result
84
85    print(module)
86
87
88# CHECK-LABEL: TEST: testIdentityRegionOps
89@run
90def testIdentityRegionOps():
91    with Context(), Location.unknown():
92        module = Module.create()
93        f32 = F32Type.get()
94        with InsertionPoint(module.body):
95            # CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<1x13xf32>
96            # CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<13x1xf32>
97            op1 = tensor.EmptyOp([1, 13], f32)
98            op2 = tensor.EmptyOp([13, 1], f32)
99            # CHECK: %[[VAL_2:.*]] = linalg.transpose ins(%[[VAL_0]] : tensor<1x13xf32>) outs(%[[VAL_1]] : tensor<13x1xf32>) permutation = [1, 0]
100            op3 = linalg.TransposeOp(
101                result=[RankedTensorType.get((13, 1), f32)],
102                input=op1,
103                init=op2,
104                permutation=[1, 0],
105            )
106            linalg.fill_builtin_region(op3.operation)
107
108            # CHECK: %[[VAL_3:.*]] = linalg.transpose ins(%[[VAL_1]] : tensor<13x1xf32>) outs(%[[VAL_0]] : tensor<1x13xf32>) permutation = [1, 0]
109            op4 = linalg.transpose(op2, outs=[op1], permutation=[1, 0])
110
111            # CHECK:         func.func @transpose_op(%[[VAL_4:.*]]: memref<1x13xf32>, %[[VAL_5:.*]]: memref<13x1xf32>)
112            @func.FuncOp.from_py_func(
113                MemRefType.get((1, 13), f32),
114                MemRefType.get((13, 1), f32),
115            )
116            def transpose_op(op1, op2):
117                # CHECK: linalg.transpose ins(%[[VAL_4]] : memref<1x13xf32>) outs(%[[VAL_5]] : memref<13x1xf32>) permutation = [1, 0]
118                op3 = linalg.TransposeOp(
119                    result=[],
120                    input=op1,
121                    init=op2,
122                    permutation=[1, 0],
123                )
124                linalg.fill_builtin_region(op3.operation)
125                # CHECK: linalg.transpose ins(%[[VAL_5]] : memref<13x1xf32>) outs(%[[VAL_4]] : memref<1x13xf32>) permutation = [1, 0]
126                op4 = linalg.transpose(op2, outs=[op1], permutation=[1, 0])
127
128            # CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<16xf32>
129            # CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<16x64xf32>
130            op1 = tensor.EmptyOp([16], f32)
131            op2 = tensor.EmptyOp([16, 64], f32)
132            # CHECK: %[[VAL_8:.*]] = linalg.broadcast ins(%[[VAL_6]] : tensor<16xf32>) outs(%[[VAL_7]] : tensor<16x64xf32>) dimensions = [1]
133            op3 = linalg.BroadcastOp(
134                result=[RankedTensorType.get((16, 64), f32)],
135                input=op1,
136                init=op2,
137                dimensions=[1],
138            )
139            linalg.fill_builtin_region(op3.operation)
140
141            # CHECK: %[[VAL_9:.*]] = tensor.empty() : tensor<64xf32>
142            op4 = tensor.EmptyOp([64], f32)
143            # CHECK: %[[VAL_10:.*]] = linalg.broadcast ins(%[[VAL_9]] : tensor<64xf32>) outs(%[[VAL_7]] : tensor<16x64xf32>) dimensions = [0]
144            op5 = linalg.broadcast(op4, outs=[op2], dimensions=[0])
145
146            # CHECK: func.func @broadcast_op(%[[VAL_11:.*]]: memref<16xf32>, %[[VAL_12:.*]]: memref<16x64xf32>, %[[VAL_13:.*]]: memref<64xf32>)
147            @func.FuncOp.from_py_func(
148                MemRefType.get((16,), f32),
149                MemRefType.get((16, 64), f32),
150                MemRefType.get((64,), f32),
151            )
152            def broadcast_op(op1, op2, op3):
153                # CHECK: linalg.broadcast ins(%[[VAL_11]] : memref<16xf32>) outs(%[[VAL_12]] : memref<16x64xf32>) dimensions = [1]
154                op4 = linalg.BroadcastOp(
155                    result=[],
156                    input=op1,
157                    init=op2,
158                    dimensions=[1],
159                )
160                linalg.fill_builtin_region(op4.operation)
161                # CHECK: linalg.broadcast ins(%[[VAL_13]] : memref<64xf32>) outs(%[[VAL_12]] : memref<16x64xf32>) dimensions = [0]
162                op5 = linalg.broadcast(op3, outs=[op2], dimensions=[0])
163
164    print(module)
165
166
167# CHECK-LABEL: TEST: testGenericOp
168@run
169def testGenericOp():
170    with Context(), Location.unknown():
171        module = Module.create()
172        f32 = F32Type.get()
173        memref_t = MemRefType.get([10, 10], f32)
174        with InsertionPoint(module.body):
175            id_map_1 = AffineMap.get_identity(2)
176            # CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<16x16xf32>
177            # CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<16x16xf32>
178            x = tensor.empty((16, 16), f32)
179            y = tensor.empty((16, 16), f32)
180
181            # CHECK: %[[VAL_2:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_0]] : tensor<16x16xf32>) outs(%[[VAL_1]] : tensor<16x16xf32>) {
182            # CHECK: ^bb0(%in: f32, %out: f32):
183            # CHECK:   linalg.yield %in : f32
184            # CHECK: } -> tensor<16x16xf32>
185            @linalg.generic(
186                [x],
187                [y],
188                [id_map_1, id_map_1],
189                [linalg.IteratorType.parallel, linalg.IteratorType.parallel],
190            )
191            def f(a, b):
192                assert isinstance(a, Value)
193                assert isinstance(a.type, F32Type)
194                assert isinstance(b, Value)
195                assert isinstance(b.type, F32Type)
196                return a
197
198            assert isinstance(f, Value)
199            assert isinstance(f.type, RankedTensorType)
200
201            # CHECK: %[[VAL_3:.*]] = tensor.empty() : tensor<16x16x16xf32>
202            z = tensor.empty((16, 16, 16), f32)
203
204            minor_id = AffineMap.get_minor_identity(3, 2)
205            id_map_2 = AffineMap.get_identity(3)
206
207            # CHECK: %[[VAL_4:.+]]:2 = linalg.generic {indexing_maps = [#map1, #map2, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_0]] : tensor<16x16xf32>) outs(%[[VAL_3]], %[[VAL_3]] : tensor<16x16x16xf32>, tensor<16x16x16xf32>) {
208            # CHECK: ^bb0(%in: f32, %out: f32, %out_1: f32):
209            # CHECK:   linalg.yield %in, %out : f32, f32
210            # CHECK: } -> (tensor<16x16x16xf32>, tensor<16x16x16xf32>)
211            @linalg.generic(
212                [x],
213                [z, z],
214                [minor_id, id_map_2, id_map_2],
215                [
216                    linalg.IteratorType.parallel,
217                    linalg.IteratorType.parallel,
218                    linalg.IteratorType.parallel,
219                ],
220            )
221            def g(a, b, c):
222                assert isinstance(a, Value)
223                assert isinstance(a.type, F32Type)
224                assert isinstance(b, Value)
225                assert isinstance(b.type, F32Type)
226                assert isinstance(c, Value)
227                assert isinstance(c.type, F32Type)
228                return a, b
229
230            assert isinstance(g, OpResultList)
231            assert len(g) == 2
232            assert isinstance(g[0].type, RankedTensorType)
233            assert isinstance(g[1].type, RankedTensorType)
234
235            # CHECK: %[[VAL_5:.*]] = memref.alloc() : memref<10x10xf32>
236            # CHECK: %[[VAL_6:.*]] = memref.alloc() : memref<10x10xf32>
237            xx = memref.alloc(memref_t, [], [])
238            yy = memref.alloc(memref_t, [], [])
239
240            # CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_5]] : memref<10x10xf32>) outs(%[[VAL_6]] : memref<10x10xf32>) {
241            # CHECK: ^bb0(%in: f32, %out: f32):
242            # CHECK:   linalg.yield %in : f32
243            # CHECK: }
244            @linalg.generic(
245                [xx],
246                [yy],
247                [id_map_1, id_map_1],
248                [linalg.IteratorType.parallel, linalg.IteratorType.parallel],
249            )
250            def f(a, b):
251                assert isinstance(a, Value)
252                assert isinstance(a.type, F32Type)
253                assert isinstance(b, Value)
254                assert isinstance(b.type, F32Type)
255                return a
256
257    module.operation.verify()
258    print(module)
259