# xref: /llvm-project/mlir/test/python/dialects/linalg/ops.py (revision 1bc5fe669f5477eadd84270e971591a718693bba)
# RUN: %PYTHON %s | FileCheck %s

3*1bc5fe66SMaksim Leventalfrom mlir.dialects import arith, func, linalg, tensor, memref
4e8e718faSAlex Zinenkofrom mlir.dialects.linalg.opdsl.lang import *
558a47508SJeff Niufrom mlir.ir import *


def run(f):
    """Execute test function ``f`` immediately and return it unchanged.

    Intended for use as a decorator: each ``@run``-decorated test runs at
    import time, right after a ``TEST: <function name>`` banner is printed.
    That banner is what the per-test label directives in this file anchor on.
    Returning ``f`` keeps the module-level name bound to the original function.
    """
    print("\nTEST:", f.__name__)
    f()
    return f


# CHECK-LABEL: TEST: testFill
@run
def testFill():
    """Build ``linalg.fill`` with a tensor output and with a memref output.

    Two functions are emitted into one module, and the module is printed so
    the FileCheck directives below can match it. Expected results:

    - Tensor output: the fill produces a new tensor, which the function returns.
    - Memref output: the fill produces no SSA result, and the function returns
      nothing.
    """
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):
            # CHECK-LABEL: func @fill_tensor
            #  CHECK-SAME:   %[[OUT:[0-9a-z]+]]: tensor<12x?xf32>
            #  CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}} : f32
            #  CHECK-NEXT: %[[RES:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[OUT]] : tensor<12x?xf32>) -> tensor<12x?xf32>
            #  CHECK-NEXT: return %[[RES]] : tensor<12x?xf32>
            # The Python function name is the symbol the label directive matches.
            # The second dimension is dynamic (printed as `?`).
            @func.FuncOp.from_py_func(
                RankedTensorType.get((12, ShapedType.get_dynamic_size()), f32)
            )
            def fill_tensor(out):
                zero = arith.ConstantOp(
                    value=FloatAttr.get(f32, 0.0), result=f32
                ).result
                # Tensor semantics: the filled tensor is returned to the caller.
                return linalg.fill(zero, outs=[out])

            # CHECK-LABEL: func @fill_buffer
            #  CHECK-SAME:   %[[OUT:[0-9a-z]+]]: memref<12x?xf32>
            #  CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}} : f32
            #  CHECK-NEXT: linalg.fill ins(%[[CST]] : f32) outs(%[[OUT]] : memref<12x?xf32>)
            #  CHECK-NEXT: return
            @func.FuncOp.from_py_func(
                MemRefType.get((12, ShapedType.get_dynamic_size()), f32)
            )
            def fill_buffer(out):
                zero = arith.ConstantOp(
                    value=FloatAttr.get(f32, 0.0), result=f32
                ).result
                # Memref semantics: no SSA result, so nothing is returned.
                linalg.fill(zero, outs=[out])

    print(module)


# CHECK-LABEL: TEST: testNamedStructuredOpCustomForm
@run
def testNamedStructuredOpCustomForm():
    """Check the custom assembly form of named elementwise linalg ops.

    - ``linalg.elemwise_unary`` is built without ``fun``/``cast`` arguments.
      Its printed attributes are expected to be ``exp`` and ``cast_signed``.
    - ``linalg.elemwise_binary`` is built with explicit ``fun=BinaryFn.mul``
      and ``cast=TypeFn.cast_unsigned``.

    Both ops write into the same 4x8 ``tensor.empty`` destination, and the
    module is printed for FileCheck.
    """
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(
                RankedTensorType.get((4, 8), f32), RankedTensorType.get((4, 8), f32)
            )
            def named_form(lhs, rhs):
                # Destination tensor shared by both elementwise ops below.
                init_result = tensor.EmptyOp([4, 8], f32)
                # Check for the named form with custom format
                #      CHECK: linalg.elemwise_unary
                # CHECK-SAME:    cast = #linalg.type_fn<cast_signed>
                # CHECK-SAME:    fun = #linalg.unary_fn<exp>
                # CHECK-SAME:    ins(%{{.*}} : tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>)
                unary_result = linalg.elemwise_unary(lhs, outs=[init_result.result])
                #      CHECK: linalg.elemwise_binary
                # CHECK-SAME:    cast = #linalg.type_fn<cast_unsigned>
                # CHECK-SAME:    fun = #linalg.binary_fn<mul>
                # CHECK-SAME:    ins(%{{.*}}, %{{.*}} : tensor<4x8xf32>, tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>)
                #      CHECK: return
                binary_result = linalg.elemwise_binary(
                    lhs,
                    rhs,
                    outs=[init_result.result],
                    fun=BinaryFn.mul,
                    cast=TypeFn.cast_unsigned,
                )
                return unary_result, binary_result

    print(module)


# CHECK-LABEL: TEST: testIdentityRegionOps
@run
def testIdentityRegionOps():
    """Build ``linalg.transpose`` and ``linalg.broadcast`` in two ways.

    Each op is created both on tensors (at module level) and on memrefs
    (inside a ``func.func``). Each is built two ways:

    - via the generated op class (``TransposeOp`` / ``BroadcastOp``),
      followed by ``linalg.fill_builtin_region``;
    - via the lower-case helper (``linalg.transpose`` / ``linalg.broadcast``).

    Every form is expected to print in the same custom syntax. Construction
    order matters: the expected SSA value numbering follows it.
    """
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):
            # --- transpose on tensors ---
            # CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<1x13xf32>
            # CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<13x1xf32>
            op1 = tensor.EmptyOp([1, 13], f32)
            op2 = tensor.EmptyOp([13, 1], f32)
            # CHECK: %[[VAL_2:.*]] = linalg.transpose ins(%[[VAL_0]] : tensor<1x13xf32>) outs(%[[VAL_1]] : tensor<13x1xf32>) permutation = [1, 0]
            op3 = linalg.TransposeOp(
                result=[RankedTensorType.get((13, 1), f32)],
                input=op1,
                init=op2,
                permutation=[1, 0],
            )
            # The op-class form is paired with fill_builtin_region (presumably
            # to populate the op's body region); the helper form needs no such
            # call.
            linalg.fill_builtin_region(op3.operation)

            # CHECK: %[[VAL_3:.*]] = linalg.transpose ins(%[[VAL_1]] : tensor<13x1xf32>) outs(%[[VAL_0]] : tensor<1x13xf32>) permutation = [1, 0]
            op4 = linalg.transpose(op2, outs=[op1], permutation=[1, 0])

            # --- transpose on memrefs ---
            # CHECK:         func.func @transpose_op(%[[VAL_4:.*]]: memref<1x13xf32>, %[[VAL_5:.*]]: memref<13x1xf32>)
            @func.FuncOp.from_py_func(
                MemRefType.get((1, 13), f32),
                MemRefType.get((13, 1), f32),
            )
            def transpose_op(op1, op2):
                # CHECK: linalg.transpose ins(%[[VAL_4]] : memref<1x13xf32>) outs(%[[VAL_5]] : memref<13x1xf32>) permutation = [1, 0]
                # Memref operands: the op produces no SSA results, hence result=[].
                op3 = linalg.TransposeOp(
                    result=[],
                    input=op1,
                    init=op2,
                    permutation=[1, 0],
                )
                linalg.fill_builtin_region(op3.operation)
                # CHECK: linalg.transpose ins(%[[VAL_5]] : memref<13x1xf32>) outs(%[[VAL_4]] : memref<1x13xf32>) permutation = [1, 0]
                op4 = linalg.transpose(op2, outs=[op1], permutation=[1, 0])

            # --- broadcast on tensors (op1..op5 are rebound from here on) ---
            # CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<16xf32>
            # CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<16x64xf32>
            op1 = tensor.EmptyOp([16], f32)
            op2 = tensor.EmptyOp([16, 64], f32)
            # CHECK: %[[VAL_8:.*]] = linalg.broadcast ins(%[[VAL_6]] : tensor<16xf32>) outs(%[[VAL_7]] : tensor<16x64xf32>) dimensions = [1]
            op3 = linalg.BroadcastOp(
                result=[RankedTensorType.get((16, 64), f32)],
                input=op1,
                init=op2,
                dimensions=[1],
            )
            linalg.fill_builtin_region(op3.operation)

            # CHECK: %[[VAL_9:.*]] = tensor.empty() : tensor<64xf32>
            op4 = tensor.EmptyOp([64], f32)
            # CHECK: %[[VAL_10:.*]] = linalg.broadcast ins(%[[VAL_9]] : tensor<64xf32>) outs(%[[VAL_7]] : tensor<16x64xf32>) dimensions = [0]
            op5 = linalg.broadcast(op4, outs=[op2], dimensions=[0])

            # --- broadcast on memrefs ---
            # CHECK: func.func @broadcast_op(%[[VAL_11:.*]]: memref<16xf32>, %[[VAL_12:.*]]: memref<16x64xf32>, %[[VAL_13:.*]]: memref<64xf32>)
            @func.FuncOp.from_py_func(
                MemRefType.get((16,), f32),
                MemRefType.get((16, 64), f32),
                MemRefType.get((64,), f32),
            )
            def broadcast_op(op1, op2, op3):
                # CHECK: linalg.broadcast ins(%[[VAL_11]] : memref<16xf32>) outs(%[[VAL_12]] : memref<16x64xf32>) dimensions = [1]
                # Memref operands: no SSA results, hence result=[].
                op4 = linalg.BroadcastOp(
                    result=[],
                    input=op1,
                    init=op2,
                    dimensions=[1],
                )
                linalg.fill_builtin_region(op4.operation)
                # CHECK: linalg.broadcast ins(%[[VAL_13]] : memref<64xf32>) outs(%[[VAL_12]] : memref<16x64xf32>) dimensions = [0]
                op5 = linalg.broadcast(op3, outs=[op2], dimensions=[0])

    print(module)


# CHECK-LABEL: TEST: testGenericOp
@run
def testGenericOp():
    """Exercise the ``linalg.generic`` decorator on tensors and memrefs.

    How the decorator behaves, as exercised here:

    - The decorated Python function becomes the body of a ``linalg.generic``
      op. It receives one block argument per input and per output (f32
      scalars here), and the value(s) it returns are yielded.
    - After decoration, the function's name is bound to the op's result(s):
      a single ``Value`` for one tensor output, or an ``OpResultList`` for
      several tensor outputs.
    - The memref variant is expected to print without results.

    The module is verified before it is printed for FileCheck.
    """
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        memref_t = MemRefType.get([10, 10], f32)
        with InsertionPoint(module.body):
            # Identity indexing map for the 2-D iteration space.
            id_map_1 = AffineMap.get_identity(2)
            # CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<16x16xf32>
            # CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<16x16xf32>
            x = tensor.empty((16, 16), f32)
            y = tensor.empty((16, 16), f32)

            # CHECK: %[[VAL_2:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_0]] : tensor<16x16xf32>) outs(%[[VAL_1]] : tensor<16x16xf32>) {
            # CHECK: ^bb0(%in: f32, %out: f32):
            # CHECK:   linalg.yield %in : f32
            # CHECK: } -> tensor<16x16xf32>
            @linalg.generic(
                [x],
                [y],
                [id_map_1, id_map_1],
                [linalg.IteratorType.parallel, linalg.IteratorType.parallel],
            )
            def f(a, b):
                assert isinstance(a, Value)
                assert isinstance(a.type, F32Type)
                assert isinstance(b, Value)
                assert isinstance(b.type, F32Type)
                # Copy semantics: yield the input element.
                return a

            # One tensor output, so the decorator leaves a single Value behind.
            assert isinstance(f, Value)
            assert isinstance(f.type, RankedTensorType)

            # CHECK: %[[VAL_3:.*]] = tensor.empty() : tensor<16x16x16xf32>
            z = tensor.empty((16, 16, 16), f32)

            # Indexes the 2-D input `x` from the 3-D iteration space.
            minor_id = AffineMap.get_minor_identity(3, 2)
            # Identity map for the two 3-D outputs.
            id_map_2 = AffineMap.get_identity(3)

            # CHECK: %[[VAL_4:.+]]:2 = linalg.generic {indexing_maps = [#map1, #map2, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_0]] : tensor<16x16xf32>) outs(%[[VAL_3]], %[[VAL_3]] : tensor<16x16x16xf32>, tensor<16x16x16xf32>) {
            # CHECK: ^bb0(%in: f32, %out: f32, %out_1: f32):
            # CHECK:   linalg.yield %in, %out : f32, f32
            # CHECK: } -> (tensor<16x16x16xf32>, tensor<16x16x16xf32>)
            @linalg.generic(
                [x],
                [z, z],
                [minor_id, id_map_2, id_map_2],
                [
                    linalg.IteratorType.parallel,
                    linalg.IteratorType.parallel,
                    linalg.IteratorType.parallel,
                ],
            )
            def g(a, b, c):
                assert isinstance(a, Value)
                assert isinstance(a.type, F32Type)
                assert isinstance(b, Value)
                assert isinstance(b.type, F32Type)
                assert isinstance(c, Value)
                assert isinstance(c.type, F32Type)
                # One yielded value per output.
                return a, b

            # Two tensor outputs, so the decorator leaves an OpResultList.
            assert isinstance(g, OpResultList)
            assert len(g) == 2
            assert isinstance(g[0].type, RankedTensorType)
            assert isinstance(g[1].type, RankedTensorType)

            # CHECK: %[[VAL_5:.*]] = memref.alloc() : memref<10x10xf32>
            # CHECK: %[[VAL_6:.*]] = memref.alloc() : memref<10x10xf32>
            # Static 10x10 type, so both extra operand lists are empty.
            xx = memref.alloc(memref_t, [], [])
            yy = memref.alloc(memref_t, [], [])

            # CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_5]] : memref<10x10xf32>) outs(%[[VAL_6]] : memref<10x10xf32>) {
            # CHECK: ^bb0(%in: f32, %out: f32):
            # CHECK:   linalg.yield %in : f32
            # CHECK: }
            # Rebinds `f`; the tensor result bound to it above was already checked.
            @linalg.generic(
                [xx],
                [yy],
                [id_map_1, id_map_1],
                [linalg.IteratorType.parallel, linalg.IteratorType.parallel],
            )
            def f(a, b):
                assert isinstance(a, Value)
                assert isinstance(a.type, F32Type)
                assert isinstance(b, Value)
                assert isinstance(b.type, F32Type)
                return a

    module.operation.verify()
    print(module)
