// Source: /llvm-project/mlir/test/Dialect/Linalg/transform-op-fuse.mlir (revision 9144fed31b59089f4e3e5fedf7eb87d2695ef843)
// RUN: mlir-opt %s --transform-interpreter --split-input-file -canonicalize | FileCheck %s

// The elemwise_unary producer must end up inside the two-deep scf.for nest
// created by tiling its elemwise_binary consumer (see the transform script
// that follows).
// CHECK-LABEL: func.func @fuse_unary
func.func @fuse_unary(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {

  //     CHECK: %[[RES:.*]] = scf.for
  //     CHECK:    scf.for
  //     CHECK:       linalg.elemwise_unary
  //     CHECK:       linalg.elemwise_binary
  //     CHECK: return %[[RES]]
  %0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
                             outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
  %1 = linalg.elemwise_binary ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>)
                             outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    // Tile the elemwise_binary consumer by [32, 32] (no interchange) and fuse
    // its producer into the generated loop nest.
    %0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]}
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}

// -----

// Same producer/consumer pair as the basic fusion case, but the outer fused
// loop is then peeled with transform.loop.peel. The fused pair therefore
// appears in the main loop nest and again in the peeled remainder nest, which
// continues from the main nest's result.
// CHECK-LABEL: func.func @fuse_unary_peel
func.func @fuse_unary_peel(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {

  //     CHECK: %[[PARTIAL_RES:.*]] = scf.for
  //     CHECK:     scf.for
  //     CHECK:       linalg.elemwise_unary
  //     CHECK:       linalg.elemwise_binary
  //     CHECK: %[[RES:.*]] = scf.for {{.*}}%[[PARTIAL_RES]]
  //     CHECK:     scf.for
  //     CHECK:       linalg.elemwise_unary
  //     CHECK:       linalg.elemwise_binary
  //     CHECK: return %[[RES]]
  %0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
                             outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
  %1 = linalg.elemwise_binary ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>)
                             outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    // Tile-and-fuse the elemwise_binary consumer by [32, 32]. The outer loop
    // handle is typed as scf.for so that it can be fed to transform.loop.peel.
    %0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]}
      : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op)
    transform.loop.peel %loops#0 : (!transform.op<"scf.for">) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}

// -----

// Tiles dims d0 and d2 of a reduction linalg.generic by 5 and 7, with loop
// interchange, and fuses the linalg.fill that produces its output operand.
// The tiled generic is then tiled again along the reduction dim d1 by 4,
// which yields the innermost loop that carries the fill result.
// CHECK-LABEL: func.func @interchange_reduction
//  CHECK-SAME: (%[[INPUT:.+]]: tensor<12x7x25xf32>)
func.func @interchange_reduction(%input: tensor<12x7x25xf32>) -> tensor<12x25xf32> {
  %five = arith.constant 5.0 : f32
  %init = tensor.empty() : tensor<12x25xf32>

//   CHECK-DAG: %[[INIT:.+]] = tensor.empty()
//   CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index
//   CHECK-DAG: %[[C7:.+]] = arith.constant 7 : index
//   CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
//       CHECK: %[[RES:.*]] = scf.for %[[IV0:.+]] = %{{.+}} to %{{.+}} step %[[C5]] iter_args(%[[FOR_ARG0:.+]] = %[[INIT]])
//       CHECK:   scf.for %[[IV1:.+]] = %{{.+}} to %{{.+}} step %[[C7]] iter_args(%[[FOR_ARG1:.+]] = %[[FOR_ARG0]])
//       CHECK:     %[[OUT_SLICE0:.+]] = tensor.extract_slice %[[INPUT]][%[[IV0]], 0, %[[IV1]]]
//       CHECK:     %[[OUT_SLICE1:.+]] = tensor.extract_slice %[[FOR_ARG1]][%[[IV0]], %[[IV1]]]
//       CHECK:     %[[FILL:.+]] = linalg.fill {{.+}} outs(%[[OUT_SLICE1]] : tensor<?x?xf32>)
//       CHECK:     scf.for %[[IV2:.+]] = %{{.+}} to %{{.+}} step %[[C4]] iter_args(%[[FOR_ARG2:.+]] = %[[FILL]])
//       CHECK:       %[[IN_SLICE:.+]] = tensor.extract_slice %[[OUT_SLICE0]]
//       CHECK:       %[[OUT_SLICE2:.+]] = tensor.extract_slice %[[FOR_ARG2]][0, 0]
//       CHECK:       linalg.generic {{.+}} ins(%[[IN_SLICE]] : tensor<?x?x?xf32>) outs(%[[OUT_SLICE2]] : tensor<?x?xf32>)
//       CHECK: return %[[RES]]

  %fill = linalg.fill ins(%five : f32) outs(%init : tensor<12x25xf32>) -> tensor<12x25xf32>
  %0 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2)>],
    iterator_types = ["parallel", "reduction", "parallel"]
  } ins(%input : tensor<12x7x25xf32>) outs(%fill : tensor<12x25xf32>) {
  ^bb0(%arg0: f32, %arg1: f32):
    %2 = arith.addf %arg0, %arg1 : f32
    linalg.yield %2 : f32
  } -> tensor<12x25xf32>
  func.return %0 : tensor<12x25xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    // Tile-and-fuse dims 0 and 2 of the generic (dim 1 left untiled) with
    // interchange [0, 2, 1], then tile the resulting op along dim 1 by 4.
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [5, 0, 7], tile_interchange = [0, 2, 1]}
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    %2, %loops_2 = transform.structured.tile_using_for %1 tile_sizes [0, 4]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}

// -----

// A tensor.unpack producer must be fused into the 2-D tiled loop nest of its
// linalg.elemwise_unary consumer.
// CHECK-LABEL: func.func @unpack_elemwise
// CHECK:         %[[RES:.*]] = scf.for
// CHECK:           scf.for
// CHECK:             tensor.unpack
// CHECK:             linalg.elemwise_unary
// CHECK:         return %[[RES]]
func.func @unpack_elemwise(%arg0: tensor<16x48x8x8xf32>, %arg1: tensor<128x384xf32>) -> tensor<128x384xf32> {
  %0 = tensor.empty() : tensor<128x384xf32>
  %1 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %0
      : tensor<16x48x8x8xf32> -> tensor<128x384xf32>
  %2 = linalg.elemwise_unary ins(%1: tensor<128x384xf32>)
                             outs(%arg1: tensor<128x384xf32>) -> tensor<128x384xf32>
  return %2 : tensor<128x384xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    // Tile the elemwise_unary consumer by [16, 32] and fuse the unpack into it.
    %0 = transform.structured.match ops{["linalg.elemwise_unary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [16, 32], tile_interchange = [0, 1]}
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}

// -----

// A tensor.pack producer is fused into the consumer's loop nest when only the
// two outer dims are tiled (tile sizes [3, 5, 0, 0]). The inner 8x8 tile dims
// stay untiled.
// CHECK-LABEL: func.func @pack_elemwise
// CHECK:         %[[RES:.*]] = scf.for
// CHECK:           scf.for
// CHECK:             tensor.pack
// CHECK:             linalg.elemwise_unary
// CHECK:         return %[[RES]]
func.func @pack_elemwise(%arg0: tensor<128x384xf32>, %arg1: tensor<16x48x8x8xf32>) -> tensor<16x48x8x8xf32> {
  %0 = tensor.empty() : tensor<16x48x8x8xf32>
  %1 = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %0
      : tensor<128x384xf32> -> tensor<16x48x8x8xf32>
  %2 = linalg.elemwise_unary ins(%1: tensor<16x48x8x8xf32>)
                             outs(%arg1: tensor<16x48x8x8xf32>) -> tensor<16x48x8x8xf32>
  return %2 : tensor<16x48x8x8xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    // Tile only dims 0 and 1 of the 4-D consumer; two loops are produced.
    %0 = transform.structured.match ops{["linalg.elemwise_unary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [3, 5, 0, 0]}
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}

// -----

// When the consumer is also tiled along dim 2, an inner 8-wide pack tile dim
// (tile sizes [3, 5, 2, 0]), the tensor.pack producer is not fused. It stays
// ahead of the loop nest.
// CHECK-LABEL: func.func @nofuse_pack_elemwise
// CHECK:         tensor.pack
// CHECK:         %[[RES:.*]] = scf.for
// CHECK:           scf.for
// CHECK:             linalg.elemwise_unary
// CHECK:         return %[[RES]]
func.func @nofuse_pack_elemwise(%arg0: tensor<128x384xf32>, %arg1: tensor<16x48x8x8xf32>) -> tensor<16x48x8x8xf32> {
  %0 = tensor.empty() : tensor<16x48x8x8xf32>
  %1 = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %0
      : tensor<128x384xf32> -> tensor<16x48x8x8xf32>
  %2 = linalg.elemwise_unary ins(%1: tensor<16x48x8x8xf32>)
                             outs(%arg1: tensor<16x48x8x8xf32>) -> tensor<16x48x8x8xf32>
  return %2 : tensor<16x48x8x8xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    // Tile dims 0, 1 and 2 of the consumer; three loops are produced.
    %0 = transform.structured.match ops{["linalg.elemwise_unary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1, %loops:3 = transform.structured.fuse %0 {tile_sizes = [3, 5, 2, 0]}
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}

// -----

// The elemwise_unary producer feeds the consumer only through an intervening
// tensor.extract_slice, and must still be fused into the consumer's loop nest.
// CHECK-LABEL: func.func @fuse_through_slice
func.func @fuse_through_slice(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {

  //     CHECK: %[[RES:.*]] = scf.for
  //     CHECK:     scf.for
  //     CHECK:       linalg.elemwise_unary
  //     CHECK:       linalg.elemwise_binary
  //     CHECK: return %[[RES]]
  %0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
                             outs(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32>
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %dim0 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
  %dim1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
  %1 = tensor.extract_slice %0 [1, 1] [%dim0, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
  %2 = linalg.elemwise_binary ins(%1, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
                             outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
  return %2 : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    // Tile-and-fuse the consumer by [32, 32] with apply_cleanup enabled
    // (presumably required to fuse through the slice -- confirm).
    %0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1], apply_cleanup = true}
      : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op)
    transform.yield
  }
}

// -----

// The producer reaches the consumer through a chain of tensor.cast,
// tensor.extract_slice, tensor.cast and tensor.extract_slice. It must still be
// fused into the consumer's loop nest.
// CHECK-LABEL: func.func @fuse_through_slice_and_cast_chain
func.func @fuse_through_slice_and_cast_chain(%arg0: tensor<100x100xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {

  //     CHECK: %[[RES:.*]] = scf.for
  //     CHECK:     scf.for
  //     CHECK:       linalg.elemwise_unary
  //     CHECK:       linalg.elemwise_binary
  //     CHECK: return %[[RES]]
  %0 = linalg.elemwise_unary ins(%arg0 : tensor<100x100xf32>)
                             outs(%arg0: tensor<100x100xf32>) -> tensor<100x100xf32>
  %1 = tensor.cast %0 : tensor<100x100xf32> to tensor<100x?xf32>
  %2 = tensor.extract_slice %1 [1, 1] [98, 98] [1, 1] : tensor<100x?xf32> to tensor<98x98xf32>
  %3 = tensor.cast %2 : tensor<98x98xf32> to tensor<?x?xf32>
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %dim0 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
  %dim1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
  %4 = tensor.extract_slice %3 [1, 1] [%dim0, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
  %5 = linalg.elemwise_binary ins(%4, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
                             outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
  return %5 : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    // Tile-and-fuse the consumer by [32, 32] with apply_cleanup enabled.
    %0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1], apply_cleanup = true}
      : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op)
    transform.yield
  }
}

// -----

// Extract_slice ops on a function argument that do not feed the fused
// producer/consumer chain (%slice1, %slice2) must stay outside the loop nest.
// Fusion through %1 must still happen.
// CHECK-LABEL: func.func @fuse_unrelated_slices
func.func @fuse_unrelated_slices(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<10x10xf32>) {

  //     CHECK: %[[SLICE1:.+]] = tensor.extract_slice
  //     CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[SLICE1]]
  //     CHECK: %[[RES:.*]] = scf.for
  //     CHECK:     scf.for
  //     CHECK:       linalg.elemwise_unary
  //     CHECK:       linalg.elemwise_binary
  //     CHECK: return %[[RES]], %[[SLICE2]]
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %dim0 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
  %dim1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
  %slice1 = tensor.extract_slice %arg0 [1, 1] [%dim0, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
  %slice2 = tensor.extract_slice %slice1 [1, 1] [10, 10] [1, 1] : tensor<?x?xf32> to tensor<10x10xf32>
  %0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
                             outs(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32>
  %1 = tensor.extract_slice %0 [1, 1] [%dim0, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
  %2 = linalg.elemwise_binary ins(%1, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
                             outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
  return %2, %slice2 : tensor<?x?xf32>, tensor<10x10xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    // Tile-and-fuse the consumer by [32, 32] with apply_cleanup enabled.
    %0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1], apply_cleanup = true}
      : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op)
    transform.yield
  }
}
