xref: /llvm-project/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir (revision a5985ca51dd7e0759d65fac9cb2b6a4448ebc404)
1// RUN: mlir-opt %s -linalg-fuse-elementwise-ops -split-input-file | FileCheck %s
2
// Baseline elementwise fusion: an addf producer feeding a mulf consumer, all
// operands accessed through the identity map. The fused generic carries four
// indexing maps (three inputs + one output) and its region computes the addf
// followed by the mulf with no intermediate linalg.yield.
3// CHECK-DAG: [[$MAP0:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0, d1)>
4#map0 = affine_map<(d0, d1) -> (d0, d1)>
5
6// CHECK-LABEL: @add_mul_fusion
7func.func @add_mul_fusion(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
8{
9  %c0 = arith.constant 0 : index
10  %c1 = arith.constant 1 : index
11  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
12  %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
13  %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
14  %3 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
15      ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
16      outs(%2 : tensor<?x?xf32>) {
17    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
18      %4 = arith.addf %arg3, %arg4 : f32
19      linalg.yield %4 : f32
20  } -> tensor<?x?xf32>
21  // CHECK: linalg.generic {
22  // CHECK-SAME: indexing_maps = {{\[}}[[$MAP0]], [[$MAP0]], [[$MAP0]], [[$MAP0]]{{\]}}
23  %4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
24      ins(%3, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>)
25      outs(%2 : tensor<?x?xf32>) {
26    // CHECK: ^{{[a-zA-Z0-9_]*}}
27    // CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]
28    // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]
29    // CHECK-SAME: [[ARG2:%[a-zA-Z0-9_]*]]
30    ^bb0(%arg5: f32, %arg6: f32, %arg7: f32):
31      // CHECK: [[T1:%[a-zA-Z0-9_]*]] = arith.addf [[ARG0]], [[ARG1]]
32      // CHECK-NOT: linalg.yield
33      // CHECK: arith.mulf [[T1]], [[ARG2]]
34      // CHECK: linalg.yield
35      %5 = arith.mulf %arg5, %arg6 : f32
36      linalg.yield %5 : f32
37    } -> tensor<?x?xf32>
38  return %4 : tensor<?x?xf32>
39}
40
41// -----
42
// Same add-then-mul fusion, but the second input of each generic is a plain f32
// scalar accessed through a ()-result map. The fused op is expected to keep one
// scalar map per scalar operand (maps MAP0, MAP1, MAP1, MAP0 below).
43// CHECK-DAG: [[$MAP0:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0, d1)>
44// CHECK-DAG: [[$MAP1:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> ()>
45#map0 = affine_map<(d0, d1) -> (d0, d1)>
46#map1 = affine_map<(d0, d1) -> ()>
47
48// CHECK-LABEL: @scalar_add_mul_fusion
49func.func @scalar_add_mul_fusion(%arg0: tensor<?x?xf32>, %arg1 : f32, %arg2 : f32) -> tensor<?x?xf32>
50{
51  %c0 = arith.constant 0 : index
52  %c1 = arith.constant 1 : index
53  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
54  %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
55  %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
56  %3 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]}
57      ins(%arg0, %arg1 : tensor<?x?xf32>, f32)
58      outs(%2 : tensor<?x?xf32>) {
59    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
60      %4 = arith.addf %arg3, %arg4 : f32
61      linalg.yield %4 : f32
62  } -> tensor<?x?xf32>
63  // CHECK: linalg.generic {
64  // CHECK-SAME: indexing_maps = {{\[}}[[$MAP0]], [[$MAP1]], [[$MAP1]], [[$MAP0]]{{\]}}
65  %4 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]}
66      ins(%3, %arg2 : tensor<?x?xf32>, f32)
67      outs(%2 : tensor<?x?xf32>) {
68    // CHECK: ^{{[a-zA-Z0-9_]*}}
69    // CHECK-SAME: [[ARG3:%[a-zA-Z0-9_]*]]
70    // CHECK-SAME: [[ARG4:%[a-zA-Z0-9_]*]]
71    // CHECK-SAME: [[ARG5:%[a-zA-Z0-9_]*]]
72    ^bb0(%arg5: f32, %arg6: f32, %arg7: f32):
73      // CHECK: [[T1:%[a-zA-Z0-9_]*]] = arith.addf [[ARG3]], [[ARG4]]
74      // CHECK-NOT: linalg.yield
75      // CHECK: arith.mulf [[T1]], [[ARG5]]
76      // CHECK: linalg.yield
77      %5 = arith.mulf %arg5, %arg6 : f32
78      linalg.yield %5 : f32
79    } -> tensor<?x?xf32>
80  return %4 : tensor<?x?xf32>
81}
82
83// -----
84
// The producer reads its second input through the transposed map #map1; the test
// verifies that map survives fusion on that operand (expected map list is
// MAP0, MAP1, MAP0, MAP0), while everything else stays identity-mapped.
85// CHECK-DAG: [[$MAP0:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0, d1)>
86// CHECK-DAG: [[$MAP1:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d1, d0)>
87#map0 = affine_map<(d0, d1) -> (d0, d1)>
88#map1 = affine_map<(d0, d1) -> (d1, d0)>
89
90// CHECK-LABEL: @transpose_add_mul_fusion
91func.func @transpose_add_mul_fusion(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
92{
93  %c0 = arith.constant 0 : index
94  %c1 = arith.constant 1 : index
95  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
96  %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
97  %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
98  %3 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]}
99      ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
100      outs(%2 : tensor<?x?xf32>) {
101    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
102      %4 = arith.addf %arg3, %arg4 : f32
103      linalg.yield %4 : f32
104  } -> tensor<?x?xf32>
105  // CHECK: linalg.generic {
106  // CHECK-SAME: indexing_maps = {{\[}}[[$MAP0]], [[$MAP1]], [[$MAP0]], [[$MAP0]]{{\]}}
107  %4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
108      ins(%3, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>)
109      outs(%2 : tensor<?x?xf32>) {
110    ^bb0(%arg5: f32, %arg6: f32, %arg7: f32):
111      %5 = arith.mulf %arg5, %arg6 : f32
112      linalg.yield %5 : f32
113    } -> tensor<?x?xf32>
114  return %4 : tensor<?x?xf32>
115}
116
117// -----
118
// Here the CONSUMER reads the producer result through the transposed map #map1.
// The expected fused map list (MAP1, MAP0, MAP0, MAP0) reflects the producer's
// accesses recomputed in the consumer's iteration space via that transpose.
119// CHECK-DAG: [[$MAP0:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0, d1)>
120// CHECK-DAG: [[$MAP1:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d1, d0)>
121#map0 = affine_map<(d0, d1) -> (d0, d1)>
122#map1 = affine_map<(d0, d1) -> (d1, d0)>
123
124// CHECK-LABEL: @add_transpose_mul_fusion
125func.func @add_transpose_mul_fusion(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
126{
127  %c0 = arith.constant 0 : index
128  %c1 = arith.constant 1 : index
129  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
130  %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
131  %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
132  %3 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]}
133      ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
134      outs(%2 : tensor<?x?xf32>) {
135    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
136      %4 = arith.addf %arg3, %arg4 : f32
137      linalg.yield %4 : f32
138  } -> tensor<?x?xf32>
139  // CHECK: linalg.generic {
140  // CHECK-SAME: indexing_maps = {{\[}}[[$MAP1]], [[$MAP0]], [[$MAP0]], [[$MAP0]]{{\]}}
141  %4 = linalg.generic {indexing_maps = [#map1, #map0, #map0], iterator_types = ["parallel", "parallel"]}
142      ins(%3, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>)
143      outs(%2 : tensor<?x?xf32>){
144    ^bb0(%arg5: f32, %arg6: f32, %arg7: f32):
145      %5 = arith.mulf %arg5, %arg6 : f32
146      linalg.yield %5 : f32
147    } -> tensor<?x?xf32>
148  return %4 : tensor<?x?xf32>
149}
150
151// -----
152
// A 1-D addf producer is fused into a 2-D consumer that broadcasts it along d1
// via the (d0, d1) -> (d0) map. After fusion the producer's two 1-D inputs each
// get that broadcast map in the fused 2-D op (MAP1, MAP1, MAP0, MAP0).
153// CHECK-DAG: [[$MAP0:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0, d1)>
154// CHECK-DAG: [[$MAP1:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0)>
155#map0 = affine_map<(d0, d1) -> (d0, d1)>
156#map1 = affine_map<(d0, d1) -> (d0)>
157#map2 = affine_map<(d0) -> (d0)>
158
159// CHECK-LABEL: @add_broadcast_mul_fusion
160func.func @add_broadcast_mul_fusion(%arg0: tensor<?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
161{
162  %c0 = arith.constant 0 : index
163  %c1 = arith.constant 1 : index
164  %0 = tensor.dim %arg0, %c0 : tensor<?xf32>
165  %1 = tensor.empty(%0) : tensor<?xf32>
166  %2 = linalg.generic {indexing_maps = [#map2, #map2, #map2], iterator_types = ["parallel"]}
167      ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
168      outs(%1 : tensor<?xf32>) {
169    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
170      %3 = arith.addf %arg3, %arg4 : f32
171      linalg.yield %3 : f32
172  } -> tensor<?xf32>
173  // CHECK: linalg.generic {
174  // CHECK-SAME: indexing_maps = {{\[}}[[$MAP1]], [[$MAP1]], [[$MAP0]], [[$MAP0]]
175  %3 = tensor.dim %arg2, %c1 : tensor<?x?xf32>
176  %4 = tensor.empty(%0, %3) : tensor<?x?xf32>
177  %5 = linalg.generic {indexing_maps = [#map1, #map0, #map0], iterator_types = ["parallel", "parallel"]}
178      ins(%2, %arg2 : tensor<?xf32>, tensor<?x?xf32>)
179      outs(%4 : tensor<?x?xf32>){
180    ^bb0(%arg5: f32, %arg6: f32, %arg7: f32):
181      %6 = arith.mulf %arg5, %arg6 : f32
182      linalg.yield %6 : f32
183    } -> tensor<?x?xf32>
184  return %5 : tensor<?x?xf32>
185}
186
187// -----
188
// Rank-0 (scalar tensor) case: both generics have empty iterator lists and
// ()-maps. The test only verifies that the addf and mulf end up in a single
// fused region.
189// CHECK: #[[$MAP0:.*]] = affine_map<() -> ()>
190#map0 = affine_map<() -> ()>
191
192// CHECK-LABEL: @add_mul_scalar_fusion
193func.func @add_mul_scalar_fusion(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32>
194{
195  %0 = tensor.empty() : tensor<f32>
196  %1 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = []}
197      ins(%arg0, %arg1 : tensor<f32>, tensor<f32>)
198      outs(%0 : tensor<f32>) {
199    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
200      %2 = arith.addf %arg3, %arg4 : f32
201      linalg.yield %2 : f32
202  } -> tensor<f32>
203  // CHECK: linalg.generic {
204  // CHECK: arith.addf
205  // CHECK: arith.mulf
206  %2 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = []}
207      ins(%1, %arg2 : tensor<f32>, tensor<f32>)
208      outs(%0 : tensor<f32>) {
209    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
210      %3 = arith.mulf %arg3, %arg4 : f32
211      linalg.yield %3 : f32
212  } -> tensor<f32>
213
214  return %2 : tensor<f32>
215}
216
217// -----
218
// A splat dense<42.0> tensor<5xf32> input is folded away: the checks below
// expect it to become a scalar f32 arith.constant that is used directly inside
// the region, leaving the generic with a single tensor input.
219#map0 = affine_map<(d0, d1, d2) -> (d0)>
220#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
221func.func @generic_op_constant_fusion(%arg0 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
222{
223  %c0 = arith.constant 0 : index
224  %c1 = arith.constant 1 : index
225  %c2 = arith.constant 2 : index
226  %cst = arith.constant dense<42.0> : tensor<5xf32>
227  %0 = tensor.dim %arg0, %c1 : tensor<5x?x?xf32>
228  %1 = tensor.dim %arg0, %c2 : tensor<5x?x?xf32>
229  %2 = tensor.empty(%0, %1) : tensor<5x?x?xf32>
230  %3 = linalg.generic {
231    indexing_maps = [#map0, #map1, #map1],
232    iterator_types = ["parallel", "parallel", "parallel"]}
233    ins(%cst, %arg0 : tensor<5xf32>, tensor<5x?x?xf32>)
234    outs(%2 : tensor<5x?x?xf32>) {
235    ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
236      %4 = arith.mulf %arg1, %arg2 : f32
237      linalg.yield %4 : f32
238    } -> tensor<5x?x?xf32>
239  return %3 : tensor<5x?x?xf32>
240}
241//   CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
242// CHECK-LABEL: func @generic_op_constant_fusion
243//       CHECK:   %[[CST:.*]] = arith.constant {{.*}} : f32
244//       CHECK:   linalg.generic
245//       CHECK:   ^{{.+}}(%[[ARG1:[a-zA-Z0-9_]+]]: f32, %{{.+}}: f32):
246//       CHECK:     arith.mulf %[[ARG1]], %[[CST]]
247
248// -----
249
// Same constant-folding fusion as above, but the splat constant is a 0-d
// tensor<f32> accessed through a ()-result map; it must likewise be inlined as
// a scalar f32 constant inside the fused region.
250#map0 = affine_map<(d0, d1, d2) -> ()>
251#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
252func.func @generic_op_zero_dim_constant_fusion(%arg0 : tensor<5x?x?xf32>)
253  -> tensor<5x?x?xf32>
254{
255  %c0 = arith.constant 0 : index
256  %c1 = arith.constant 1 : index
257  %c2 = arith.constant 2 : index
258  %cst = arith.constant dense<42.0> : tensor<f32>
259  %0 = tensor.dim %arg0, %c1 : tensor<5x?x?xf32>
260  %1 = tensor.dim %arg0, %c2 : tensor<5x?x?xf32>
261  %2 = tensor.empty(%0, %1) : tensor<5x?x?xf32>
262  %3 = linalg.generic {
263    indexing_maps = [#map0, #map1, #map1],
264    iterator_types = ["parallel", "parallel", "parallel"]}
265    ins(%cst, %arg0 : tensor<f32>, tensor<5x?x?xf32>)
266    outs(%2 : tensor<5x?x?xf32>) {
267    ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
268      %4 = arith.mulf %arg1, %arg2 : f32
269      linalg.yield %4 : f32
270    } -> tensor<5x?x?xf32>
271  return %3 : tensor<5x?x?xf32>
272}
273//   CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
274// CHECK-LABEL: func @generic_op_zero_dim_constant_fusion
275//       CHECK:   %[[CST:.*]] = arith.constant {{.*}} : f32
276//       CHECK:   linalg.generic
277//       CHECK:   ^{{.*}}(%[[ARG1:[a-zA-Z0-9_]*]]: f32, %{{.*}}: f32)
278//       CHECK:     arith.mulf %[[ARG1]], %[[CST]]
279
280// -----
281
// A plain addi producer is fused into a consumer that uses linalg.index. With
// identity maps on both ops the index values need no remapping; the checks pin
// the fused region order (producer addi first, then the index computations) and
// that no second linalg.generic survives.
282#map0 = affine_map<(d0, d1) -> (d0, d1)>
283func.func @producer_indexed_consumer_fusion(%arg0: tensor<?x?xi32>,
284                                       %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
285  %c0 = arith.constant 0 : index
286  %c1 = arith.constant 1 : index
287  %0 = tensor.dim %arg0, %c0 : tensor<?x?xi32>
288  %1 = tensor.dim %arg0, %c1 : tensor<?x?xi32>
289  %2 = tensor.empty(%0, %1) : tensor<?x?xi32>
290  %3 = linalg.generic {
291    indexing_maps = [#map0, #map0, #map0],
292    iterator_types = ["parallel", "parallel"] }
293    ins(%arg0, %arg1  : tensor<?x?xi32>, tensor<?x?xi32>)
294    outs(%2 : tensor<?x?xi32>) {
295    ^bb0(%arg2: i32, %arg3: i32, %arg4: i32):
296      %10 = arith.addi %arg2, %arg3 : i32
297      linalg.yield %10 : i32
298    } -> tensor<?x?xi32>
299  %4 = linalg.generic {
300    indexing_maps = [#map0, #map0],
301    iterator_types = ["parallel", "parallel"] }
302    ins(%3 : tensor<?x?xi32>)
303    outs(%2 : tensor<?x?xi32>) {
304    ^bb0(%arg2: i32, %arg3: i32):
305      %idx0 = linalg.index 0 : index
306      %idx1 = linalg.index 1 : index
307      %5 = arith.index_cast %idx0 : index to i32
308      %6 = arith.index_cast %idx1 : index to i32
309      %7 = arith.addi %arg2, %5 : i32
310      %8 = arith.subi %7, %6 : i32
311      linalg.yield %8 : i32
312    } -> tensor<?x?xi32>
313  return %4 : tensor<?x?xi32>
314}
315//   CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
316// CHECK-LABEL: func @producer_indexed_consumer_fusion
317//      CHECK: linalg.generic
318// CHECK-SAME:    indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]]
319//      CHECK: ^{{[a-zA-Z0-9_]*}}
320// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: i32
321// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: i32
322//      CHECK:   %[[VAL1:.+]] = arith.addi %[[ARG0]], %[[ARG1]] : i32
323//      CHECK:   %[[IDX0:.+]] = linalg.index 0 : index
324//      CHECK:   %[[IDX1:.+]] = linalg.index 1 : index
325//      CHECK:   %[[ADD_OPERAND:.+]] = arith.index_cast %[[IDX0]] : index to i32
326//      CHECK:   %[[SUB_OPERAND:.+]] = arith.index_cast %[[IDX1]] : index to i32
327//      CHECK:   %[[VAL2:.+]] = arith.addi %[[VAL1]], %[[ADD_OPERAND]] : i32
328//      CHECK:   %[[VAL3:.+]] = arith.subi %[[VAL2]], %[[SUB_OPERAND]] : i32
329//      CHECK:   linalg.yield %[[VAL3]] : i32
330//  CHECK-NOT: linalg.generic
331
332// -----
333
// An indexed producer fused into an addi consumer that ALSO reads %arg0
// directly. The checks expect the duplicate %arg0 operand to be deduplicated in
// the fused op (only two indexing maps remain), with the single block argument
// feeding both the producer computation and the final addi.
334#map0 = affine_map<(d0, d1) -> (d0, d1)>
335func.func @indexed_producer_consumer_fusion(%arg0: tensor<?x?xi32>) -> tensor<?x?xi32> {
336  %c0 = arith.constant 0 : index
337  %c1 = arith.constant 1 : index
338  %0 = tensor.dim %arg0, %c0 : tensor<?x?xi32>
339  %1 = tensor.dim %arg0, %c1 : tensor<?x?xi32>
340  %2 = tensor.empty(%0, %1) : tensor<?x?xi32>
341  %3 = linalg.generic {
342    indexing_maps = [#map0, #map0],
343    iterator_types = ["parallel", "parallel"] }
344    ins(%arg0 : tensor<?x?xi32>)
345    outs(%2 : tensor<?x?xi32>) {
346    ^bb0(%arg4: i32, %arg5: i32):
347      %idx0 = linalg.index 0 : index
348      %idx1 = linalg.index 1 : index
349      %4 = arith.index_cast %idx0 : index to i32
350      %5 = arith.index_cast %idx1 : index to i32
351      %6 = arith.addi %arg4, %4 : i32
352      %7 = arith.subi %6, %5 : i32
353      linalg.yield %7 : i32
354    } -> tensor<?x?xi32>
355  %4 = linalg.generic {
356    indexing_maps = [#map0, #map0, #map0],
357    iterator_types = ["parallel", "parallel"] }
358    ins(%3, %arg0 : tensor<?x?xi32>, tensor<?x?xi32>)
359    outs(%2 : tensor<?x?xi32>) {
360    ^bb0(%arg2: i32, %arg3: i32, %arg4: i32):
361      %10 = arith.addi %arg2, %arg3 : i32
362      linalg.yield %10 : i32
363    } -> tensor<?x?xi32>
364  return %4 : tensor<?x?xi32>
365}
366//   CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
367// CHECK-LABEL: func @indexed_producer_consumer_fusion
368//       CHECK: linalg.generic
369// CHECK-SAME:    indexing_maps = [#[[$MAP0]], #[[$MAP0]]]
370//      CHECK: ^{{[a-zA-Z0-9_]*}}
371// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: i32
372// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: i32
373//      CHECK:   %[[IDX0:.+]] = linalg.index 0 : index
374//      CHECK:   %[[IDX1:.+]] = linalg.index 1 : index
375//      CHECK:   %[[ADD_OPERAND:.+]] = arith.index_cast %[[IDX0]] : index to i32
376//      CHECK:   %[[SUB_OPERAND:.+]] = arith.index_cast %[[IDX1]] : index to i32
377//      CHECK:   %[[VAL1:.+]] = arith.addi %[[ARG0]], %[[ADD_OPERAND]] : i32
378//      CHECK:   %[[VAL2:.+]] = arith.subi %[[VAL1]], %[[SUB_OPERAND]] : i32
379//      CHECK:   %[[VAL3:.+]] = arith.addi %[[VAL2]], %[[ARG0]] : i32
380//      CHECK:   linalg.yield %[[VAL3]] : i32
381//   CHECK-NOT: linalg.generic
382
383// -----
384
// The indices of the first generic op are swapped after fusion.
// Both producer and consumer use linalg.index, but the producer iterates with
// transposed (d1, d0) maps while the consumer uses identity maps. After fusion
// into the consumer's iteration space, the producer's index uses are swapped
// (linalg.index 1 feeds its add, linalg.index 0 its sub), while the consumer's
// own index computations are unchanged.
385// The indices of the first generic op are swapped after fusion.
386#map0 = affine_map<(d0, d1) -> (d1, d0)>
387#map1 = affine_map<(d0, d1) -> (d0, d1)>
388func.func @indexed_producer_indexed_consumer_fusion(%arg0: tensor<?x?xi32>)
389                                               -> tensor<?x?xi32> {
390  %c0 = arith.constant 0 : index
391  %c1 = arith.constant 1 : index
392  %0 = tensor.dim %arg0, %c0 : tensor<?x?xi32>
393  %1 = tensor.dim %arg0, %c1 : tensor<?x?xi32>
394  %2 = tensor.empty(%0, %1) : tensor<?x?xi32>
395  %3 = linalg.generic {
396    indexing_maps = [#map0, #map0],
397    iterator_types = ["parallel", "parallel"] }
398    ins(%arg0 : tensor<?x?xi32>)
399    outs(%2 : tensor<?x?xi32>) {
400    ^bb0(%arg2: i32, %arg3: i32):
401      %idx0 = linalg.index 0 : index
402      %idx1 = linalg.index 1 : index
403      %4 = arith.index_cast %idx0 : index to i32
404      %5 = arith.index_cast %idx1 : index to i32
405      %6 = arith.addi %arg2, %4 : i32
406      %7 = arith.subi %5, %6 : i32
407      linalg.yield %7 : i32
408    } -> tensor<?x?xi32>
409  %4= linalg.generic {
410    indexing_maps = [#map1, #map1],
411    iterator_types = ["parallel", "parallel"] }
412    ins(%3 : tensor<?x?xi32>)
413    outs(%2 : tensor<?x?xi32>) {
414    ^bb0(%arg2: i32, %arg3: i32):
415      %idx0 = linalg.index 0 : index
416      %idx1 = linalg.index 1 : index
417      %5 = arith.index_cast %idx0 : index to i32
418      %6 = arith.index_cast %idx1 : index to i32
419      %7 = arith.addi %arg2, %5 : i32
420      %8 = arith.subi %7, %6 : i32
421      linalg.yield %8 : i32
422    } -> tensor<?x?xi32>
423  return %4 : tensor<?x?xi32>
424}
425//   CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
426// CHECK-LABEL: func @indexed_producer_indexed_consumer_fusion
427//       CHECK: linalg.generic
428// CHECK-SAME:    indexing_maps = [#[[$MAP0]], #[[$MAP0]]]
429//      CHECK: ^{{[a-zA-Z0-9_]*}}
430// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: i32
431//      CHECK:   %[[IDX0:.+]] = linalg.index 0 : index
432//      CHECK:   %[[IDX1:.+]] = linalg.index 1 : index
433//      CHECK:   %[[ADD_OPERAND1:.+]] = arith.index_cast %[[IDX1]] : index to i32
434//      CHECK:   %[[SUB_OPERAND1:.+]] = arith.index_cast %[[IDX0]] : index to i32
435//      CHECK:   %[[VAL1:.+]] = arith.addi %[[ARG0]], %[[ADD_OPERAND1]] : i32
436//      CHECK:   %[[VAL2:.+]] = arith.subi %[[SUB_OPERAND1]], %[[VAL1]] : i32
437//      CHECK:   %[[IDX2:.+]] = linalg.index 0 : index
438//      CHECK:   %[[IDX3:.+]] = linalg.index 1 : index
439//      CHECK:   %[[ADD_OPERAND2:.+]] = arith.index_cast %[[IDX2]] : index to i32
440//      CHECK:   %[[SUB_OPERAND2:.+]] = arith.index_cast %[[IDX3]] : index to i32
441//      CHECK:   %[[VAL3:.+]] = arith.addi %[[VAL2]], %[[ADD_OPERAND2]] : i32
442//      CHECK:   %[[VAL4:.+]] = arith.subi %[[VAL3]], %[[SUB_OPERAND2]] : i32
443//      CHECK:   linalg.yield %[[VAL4]] : i32
444//   CHECK-NOT: linalg.generic
445
446// -----
447
// A 1-D indexed producer is fused into a 2-D consumer that reads it through the
// (d0, d1) -> (d1) map; the producer's "linalg.index 0" must be remapped to
// "linalg.index 1" in the fused 2-D iteration space.
448#map1 = affine_map<(d0) -> (d0)>
449#map2 = affine_map<(d0, d1) -> (d0, d1)>
450#map3 = affine_map<(d0, d1) -> (d1)>
451func.func @one_dim_indexed_producer_consumer_fusion(%arg0 : tensor<?xi32>,
452                                               %arg1 : tensor<?x?xi32>) -> tensor<?x?xi32> {
453  %c0 = arith.constant 0 : index
454  %c1 = arith.constant 1 : index
455  %d0 = tensor.dim %arg0, %c0 : tensor<?xi32>
456  %0 = tensor.empty(%d0) : tensor<?xi32>
457  %1 = linalg.generic
458      {indexing_maps = [#map1, #map1],
459       iterator_types = ["parallel"]}
460      ins(%arg0 : tensor<?xi32>) outs(%0 : tensor<?xi32>) {
461      ^bb0(%arg2 : i32, %arg3 : i32):
462        %2 = linalg.index 0 : index
463        %3 = arith.index_cast %2 : index to i32
464        %4 = arith.addi %arg2, %3 : i32
465        linalg.yield %4 : i32
466      } -> tensor<?xi32>
467  %2 = tensor.dim %arg1, %c0 : tensor<?x?xi32>
468  %3 = tensor.dim %arg1, %c1 : tensor<?x?xi32>
469  %4 = tensor.empty(%2, %3) : tensor<?x?xi32>
470  %5 = linalg.generic
471      {indexing_maps = [#map2, #map3, #map2],
472       iterator_types = ["parallel", "parallel"]}
473      ins(%arg1, %1 : tensor<?x?xi32>, tensor<?xi32>)
474      outs(%4 : tensor<?x?xi32>) {
475      ^bb0(%arg2 : i32, %arg3 : i32, %arg4: i32):
476        %6 = arith.addi %arg2, %arg3 : i32
477        linalg.yield %6 : i32
478     } -> tensor<?x?xi32>
479  return %5 : tensor<?x?xi32>
480}
481//   CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
482//   CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
483// CHECK-LABEL: func @one_dim_indexed_producer_consumer_fusion
484//       CHECK: linalg.generic
485// CHECK-SAME:    indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]]
486//      CHECK: ^{{[a-zA-Z0-9_]*}}
487// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9_]*]]: i32, %[[ARG1:[a-zA-Z0-9_]*]]: i32
488//      CHECK:   %[[IDX1:.+]] = linalg.index 1 : index
489//      CHECK:   %[[VAL1:.+]] = arith.index_cast %[[IDX1]] : index to i32
490//      CHECK:   %[[VAL2:.+]] = arith.addi %[[ARG1]], %[[VAL1]] : i32
491//      CHECK:   %[[VAL3:.+]] = arith.addi %[[ARG0]], %[[VAL2]] : i32
492//      CHECK:   linalg.yield %[[VAL3]] : i32
493//   CHECK-NOT: linalg.generic
494
495// -----
496
// A 0-d producer whose region does a tensor.extract from %arg0 is fused into a
// 1-D consumer; the splat constant operand is folded away, so the expected
// result is a single generic over ins(%arg1) whose region still contains the
// tensor.extract.
497func.func @scalar_generic_fusion
498  (%arg0: tensor<5x1x1xf32>, %arg1 : tensor<i32>) -> tensor<10xf32>
499{
500  %c0 = arith.constant 0 : index
501  %cst = arith.constant dense<1.000000e+00> : tensor<10xf32>
502  %0 = tensor.empty() : tensor<f32>
503  %1 = linalg.generic
504    {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>],
505     iterator_types = []}
506    ins(%arg1 : tensor<i32>) outs(%0 : tensor<f32>) {
507    ^bb0(%arg2: i32, %arg3: f32):
508      %3 = arith.index_cast %arg2 : i32 to index
509      %4 = tensor.extract %arg0[%3, %c0, %c0] : tensor<5x1x1xf32>
510      linalg.yield %4 : f32
511    } -> tensor<f32>
512  %2 = tensor.empty() : tensor<10xf32>
513  %3 = linalg.generic
514   {indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>,
515                     affine_map<(d0) -> (d0)>],
516    iterator_types = ["parallel"]}
517    ins(%1, %cst : tensor<f32>, tensor<10xf32>) outs(%2 : tensor<10xf32>) {
518    ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):
519      %4 = arith.mulf %arg2, %arg3 : f32
520      linalg.yield %4 : f32
521    } -> tensor<10xf32>
522  return %3 : tensor<10xf32>
523}
524//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> ()>
525//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0)>
526//       CHECK: func @scalar_generic_fusion
527//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<5x1x1xf32>
528//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: tensor<i32>
529//       CHECK:   %[[T0:.+]] = linalg.generic
530//  CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
531//  CHECK-SAME:     iterator_types = ["parallel"]
532//  CHECK-SAME:     ins(%[[ARG1]] : tensor<i32>)
533//       CHECK:     tensor.extract %[[ARG0]]
534//       CHECK:     linalg.yield
535//       CHECK:  return %[[T0]]
536
537// -----
538
// A splat dense<1.0> tensor input is folded into a scalar f32 constant used
// directly inside the region, shrinking the generic to a single tensor input
// (two indexing maps: one in, one out).
539func.func @constant_fusion(%arg0 : tensor<4xf32>) -> (tensor<4xf32>) {
540  %cst = arith.constant dense<1.0> : tensor<4xf32>
541  %1 = tensor.empty() : tensor<4xf32>
542  %2 = linalg.generic
543    {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>,
544                      affine_map<(d0) -> (d0)>],
545     iterator_types = ["parallel"]}
546    ins (%arg0, %cst : tensor<4xf32>, tensor<4xf32>)
547    outs (%1 : tensor<4xf32>) {
548    ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
549      %3 = arith.addf %arg1, %arg2 : f32
550      linalg.yield %3 : f32
551    } -> tensor<4xf32>
552  return %2 : tensor<4xf32>
553}
554
555//  CHECK-DAG: #[[MAP:.+]] = affine_map<(d0) -> (d0)>
556//      CHECK: func @constant_fusion(%[[ARG0:.+]]: tensor<4xf32>)
557//  CHECK-DAG:   %[[CST:.+]] = arith.constant 1.000000e+00 : f32
558//  CHECK-DAG:   %[[T0:.+]] = tensor.empty() : tensor<4xf32>
559//      CHECK:   %[[T1:.+]] = linalg.generic
560// CHECK-SAME:     indexing_maps = [#[[MAP]], #[[MAP]]]
561// CHECK-SAME:     ins(%[[ARG0]] : tensor<4xf32>)
562// CHECK-SAME:     outs(%[[T0]] : tensor<4xf32>)
563//      CHECK:   ^{{.+}}(
564// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: f32, %[[ARG2:[a-zA-Z0-9_]+]]: f32)
565//      CHECK:     %[[T2:.+]] = arith.addf %[[ARG1]], %[[CST]]
566//      CHECK:     linalg.yield %[[T2]]
567//      CHECK:   return %[[T1]]
568
569// -----
570
// A parallel elementwise producer is fused into a reduction consumer: the
// fused op adopts the consumer's iteration space (iterator_types ["reduction"],
// (d0) -> (0, d0) input maps), with both addf ops chained in one region.
571#map0 = affine_map<(d0, d1) -> (d0, d1)>
572#map1 = affine_map<(d0) -> (0, d0)>
573#map2 = affine_map<(d0) -> (0)>
574func.func @consumer_with_reduction(%arg0: tensor<1x10xf32>,
575                              %arg1: tensor<1x10xf32>,
576                              %arg2: tensor<1xf32>) -> tensor<1xf32> {
577  %init = tensor.empty() : tensor<1x10xf32>
578  %0 = linalg.generic
579    {indexing_maps = [#map0, #map0, #map0],
580     iterator_types = ["parallel", "parallel"]}
581    ins(%arg0, %arg1 : tensor<1x10xf32>, tensor<1x10xf32>)
582    outs(%init : tensor<1x10xf32>) {
583  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
584    %2 = arith.addf %arg3, %arg4 : f32
585    linalg.yield %2 : f32
586  } -> tensor<1x10xf32>
587  %1 = linalg.generic
588    {indexing_maps = [#map1, #map2],
589     iterator_types = ["reduction"]}
590    ins(%0 : tensor<1x10xf32>)
591    outs(%arg2 : tensor<1xf32>)  {
592  ^bb0(%arg3: f32, %arg4: f32):
593    %2 = arith.addf %arg3, %arg4 : f32
594    linalg.yield %2 : f32
595  } -> tensor<1xf32>
596  return %1 : tensor<1xf32>
597}
598//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (0, d0)>
599//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (0)>
600//      CHECK: func @consumer_with_reduction(%[[ARG0:.+]]: tensor<1x10xf32>, %[[ARG1:.+]]: tensor<1x10xf32>, %[[ARG2:.+]]: tensor<1xf32>)
601//      CHECK:   %[[RES:.+]] = linalg.generic
602// CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP1]]]
603// CHECK-SAME:     iterator_types = ["reduction"]
604// CHECK-SAME:     ins(%[[ARG0]], %[[ARG1]] : tensor<1x10xf32>, tensor<1x10xf32>)
605//      CHECK:   ^{{.+}}(%[[T0:.+]]: f32, %[[T1:.+]]: f32, %[[T2:.+]]: f32)
606//      CHECK:     %[[T3:.+]] = arith.addf %[[T0]], %[[T1]] : f32
607//      CHECK:     %[[T4:.+]] = arith.addf %[[T3]], %[[T2]] : f32
608//      CHECK:     linalg.yield %[[T4]]
609//      CHECK:   return %[[RES]]
610
611// -----
612
// The producer has no tensor inputs (it just yields a constant) and its outs
// tensor is sized via the shape dialect, while the consumer's outs is sized via
// tensor.dim. The test only requires that fusion still fires, leaving exactly
// one linalg.generic.
613// CHECK-LABEL: func @sigmoid_dynamic_dim(
614//       CHECK:   %[[RES:.*]] = linalg.generic
615//   CHECK-NOT:   linalg.generic
616//       CHECK:   return %[[RES]]
617func.func @sigmoid_dynamic_dim(%0: tensor<?x1xf32>) -> tensor<?x1xf32> {
618  %cp5 = arith.constant 5.000000e-01 : f32
619  %c0 = arith.constant 0 : index
620  %shape = shape.shape_of %0 : tensor<?x1xf32> -> tensor<?xindex>
621  %extend = shape.to_extent_tensor %shape : tensor<?xindex> -> tensor<2xindex>
622  %extracted = tensor.extract %extend[%c0] : tensor<2xindex>
623  %init0 = tensor.empty(%extracted) : tensor<?x1xf32>
624  %1 = linalg.generic {indexing_maps = [
625    affine_map<(d0, d1) -> (d0, d1)>],
626    iterator_types = ["parallel", "parallel"]
627  }
628     outs(%init0 : tensor<?x1xf32>) {
629    ^bb0(%a: f32):
630      linalg.yield %cp5 : f32
631  } -> tensor<?x1xf32>
632  %d0 = tensor.dim %0, %c0 : tensor<?x1xf32>
633  %init1 = tensor.empty(%d0) : tensor<?x1xf32>
634  %2 = linalg.generic {indexing_maps = [
635    affine_map<(d0, d1) -> (d0, d1)>,
636    affine_map<(d0, d1) -> (d0, d1)>,
637    affine_map<(d0, d1) -> (d0, d1)>],
638    iterator_types = ["parallel", "parallel"]
639  }
640      ins(%0, %1 : tensor<?x1xf32>, tensor<?x1xf32>)
641     outs(%init1 : tensor<?x1xf32>) {
642  ^bb0(%a: f32, %b: f32, %c: f32):
643      %m = arith.mulf %a, %b : f32
644      linalg.yield %m : f32
645  } -> tensor<?x1xf32>
646  return %2 : tensor<?x1xf32>
647}
648
649// -----
650
// Fusion across opaque calls: the producer calls @compute1, the consumer
// @compute2. The fused generic yields BOTH values (two results, note %[[R]]:2),
// and only result #1 (the i32) is returned by the function.
651func.func private @compute1(%a: f64) -> f64
652func.func private @compute2(%a: f64, %b: i32) -> i32

653
654// CHECK-LABEL: func @generic_index_op2(
655func.func @generic_index_op2(%arg0: tensor<1x8xf64>, %arg1: tensor<1x8xi32>) -> tensor<1x8xi32> {
656  %0 = linalg.generic {
657    indexing_maps = [affine_map<(i, j) -> (i, j)>],
658    iterator_types = ["parallel", "parallel"]}
659  outs(%arg0 : tensor<1x8xf64>) {
660  ^bb0(%a: f64):
661    %r = func.call @compute1(%a) : (f64) -> f64
662    linalg.yield %r : f64
663  } -> tensor<1x8xf64>
664
665  // CHECK-NEXT:   %[[R:.*]]:2 = linalg.generic
666  //      CHECK:     bb0(%[[BBA:[0-9a-zA-Z_]*]]: f64, %[[BBB:[0-9a-zA-Z_]*]]: i32):
667  // CHECK-NEXT:       %[[A:.*]] = func.call @compute1(%[[BBA]]) : (f64) -> f64
668  // CHECK-NEXT:       %[[B:.*]] = func.call @compute2(%[[A]], %[[BBB]]) : (f64, i32) -> i32
669  // CHECK-NEXT:       linalg.yield %[[A]], %[[B]] : f64, i32
670  // CHECK-NEXT:   } -> (tensor<1x8xf64>, tensor<1x8xi32>)
671  %1 = linalg.generic {
672    indexing_maps = [affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (i, j)>],
673    iterator_types = ["parallel", "parallel"]}
674  ins(%0 : tensor<1x8xf64>)
675  outs(%arg1 : tensor<1x8xi32>) {
676  ^bb0(%a: f64, %b: i32):
677    %r = func.call @compute2(%a, %b) : (f64, i32) -> i32
678    linalg.yield %r : i32
679  } -> tensor<1x8xi32>
680
681  // CHECK-NEXT:   return %[[R]]#1 : tensor<1x8xi32>
682  return %1 : tensor<1x8xi32>
683}
684
685// -----
686
// Negative test: a dense splat constant feeding a REDUCTION op must NOT be
// folded into the region; the checks require it to remain a tensor<3x2xf32>
// ins operand of the surviving generic.
687// CHECK-LABEL: func @no_fuse_constant_with_reduction
688func.func @no_fuse_constant_with_reduction() -> tensor<3xf32>
689{
690  //      CHECK: %[[CONST:.+]] = arith.constant {{.+}} : tensor<3x2xf32>
691  //      CHECK: %[[RESULT:.+]] = linalg.generic
692  // CHECK-SAME:   ins(%[[CONST]] : tensor<3x2xf32>)
693  //      CHECK: return %[[RESULT]]
694  %three = arith.constant dense<3.0> : tensor<3x2xf32>
695  %init = tensor.empty() : tensor<3xf32>
696  %result = linalg.generic {
697      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
698                       affine_map<(d0, d1) -> (d0)>],
699      iterator_types = ["parallel", "reduction"]}
700     ins(%three : tensor<3x2xf32>) outs(%init : tensor<3xf32>) {
701     ^bb0(%arg0 : f32, %arg1 : f32):
702        %0 = arith.addf %arg0, %arg1 : f32
703        linalg.yield %0 : f32
704  } -> tensor<3xf32>
705  return %result : tensor<3xf32>
706}
707
708// -----
709
// Both generics reuse an ins value as their outs operand. The pass is expected
// to break that dependency by materializing a fresh tensor.empty (sized with
// tensor.dim) for each generic's outs instead of reusing the input tensor.
710#map = affine_map<(d0, d1) -> (d0, d1)>
711#trait = {
712  indexing_maps = [#map, #map],
713  iterator_types = ["parallel", "parallel"]
714}
715func.func @break_outs_dependency(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
716{
717  %0 = linalg.generic #trait ins(%arg0 : tensor<?x?xf32>) outs(%arg0 : tensor<?x?xf32>) {
718       ^bb0(%arg1 : f32, %arg2 : f32) :
719         %1 = arith.addf %arg1, %arg1 : f32
720         linalg.yield %1 : f32
721       } -> tensor<?x?xf32>
722  %2 = linalg.generic #trait ins(%0 : tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) {
723       ^bb0(%arg1 : f32, %arg2 : f32) :
724         %3 = arith.mulf %arg1, %arg1 : f32
725         linalg.yield %3 : f32
726       } -> tensor<?x?xf32>
727  return %2 : tensor<?x?xf32>
728}
729//      CHECK: func @break_outs_dependency(
730// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?xf32>)
731//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
732//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
733//  CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
734//  CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
735//  CHECK-DAG:   %[[INIT:.+]] = tensor.empty(%[[D0]], %[[D1]])
736//      CHECK:   %[[GENERIC1:.+]] = linalg.generic
737// CHECK-SAME:     outs(%[[INIT]] : tensor<?x?xf32>)
738//  CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[GENERIC1]], %[[C0]]
739//  CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[GENERIC1]], %[[C1]]
740//  CHECK-DAG:   %[[INIT:.+]] = tensor.empty(%[[D0]], %[[D1]])
741//      CHECK:   %[[RESULT:.+]] = linalg.generic
742// CHECK-SAME:     outs(%[[INIT]] : tensor<?x?xf32>)
743
744// -----
745
// Scalar constants (%cst : f32, %c42 : i32) passed as ins operands with empty
// `() -> ()` indexing maps should be pulled into the region, leaving the
// generic with a single tensor input while both results are preserved.
func.func @fuse_scalar_constant(%arg0 : tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xi32>) {
  %cst = arith.constant 4.0 : f32
  %c42 = arith.constant 42 : i32
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
  %0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
  %1 = tensor.empty(%d0, %d1) : tensor<?x?xi32>
  %2:2 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> ()>,
                       affine_map<(d0, d1) -> ()>,
                       affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%arg0, %cst, %c42 : tensor<?x?xf32>, f32, i32)
      outs(%0, %1 : tensor<?x?xf32>, tensor<?x?xi32>) {
      ^bb0(%arg1 : f32, %arg2 : f32, %arg3 : i32, %arg4 : f32, %arg5 : i32) :
        %3 = arith.addf %arg1, %arg2 : f32
        linalg.yield %3, %arg3 : f32, i32
      } -> (tensor<?x?xf32>, tensor<?x?xi32>)
  return %2#0, %2#1 : tensor<?x?xf32>, tensor<?x?xi32>
}
// After folding, the constants are referenced directly inside the region.
// CHECK-LABEL: func @fuse_scalar_constant
//   CHECK-DAG:   %[[CST:.+]] = arith.constant 4.000000e+00 : f32
//   CHECK-DAG:   %[[C42:.+]] = arith.constant 42 : i32
//       CHECK:   linalg.generic
//  CHECK-SAME:       ins(%{{.+}} : tensor<?x?xf32>)
//       CHECK:     %[[YIELD:.+]] = arith.addf %{{.+}}, %[[CST]] : f32
//       CHECK:     linalg.yield %[[YIELD]], %[[C42]] : f32, i32
777
778// -----
779
780// Fusing the broadcast into a reduction would require to insert extra knowledge
781// about the size of the reduction dimension. As long, as this is not
782// implemented, we check that two linalg operations remain.
783// TODO: Support this case in element-wise fusion.
784
#map0 = affine_map<(d0, d1) -> ()>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
// The consumer reads the producer's result through a transposed map (#map2)
// and reduces over d1 into a 1-D result (#map3).
#map2 = affine_map<(d0, d1) -> (d1, d0)>
#map3 = affine_map<(d0, d1) -> (d0)>

// CHECK-LABEL: @no_fusion_missing_reduction_shape
// CHECK: linalg.generic
// CHECK: linalg.generic
func.func @no_fusion_missing_reduction_shape(%arg0: tensor<f32>, %arg1: index) -> tensor<?xf32> {
  %cst = arith.constant 0xFF800000 : f32
  // Broadcast the scalar tensor to a dynamically sized ?x? tensor.
  %4 = tensor.empty(%arg1, %arg1) : tensor<?x?xf32>
  %5 = linalg.generic {
    indexing_maps = [#map0, #map1],
    iterator_types = ["parallel", "parallel"]
  } ins(%arg0 : tensor<f32>) outs(%4 : tensor<?x?xf32>) {
  ^bb0(%arg2: f32, %arg3: f32):
    linalg.yield %arg2 : f32
  } -> tensor<?x?xf32>
  %6 = tensor.empty(%arg1) : tensor<?xf32>
  // Initialize the reduction accumulator with -inf (0xFF800000) for a max.
  %7 = linalg.fill ins(%cst : f32) outs(%6 : tensor<?xf32>) -> tensor<?xf32>
  %8 = linalg.generic {
    indexing_maps = [#map2, #map3],
    iterator_types = ["parallel", "reduction"]
  } ins(%5 : tensor<?x?xf32>) outs(%7 : tensor<?xf32>) {
  ^bb0(%arg2: f32, %arg3: f32):
    %9 = arith.maximumf %arg2, %arg3 : f32
    linalg.yield %9 : f32
  } -> tensor<?xf32>
  return %8 : tensor<?xf32>
}
815
816// -----
817
// Fusion of a 1-D iota-style producer into a 2-D consumer whose input map only
// covers d0. The consumer's region also performs a tensor.extract gather from
// %arg1 using the producer's value as the index.
func.func @fusion_different_axes(%arg0 : tensor<5000xi64>, %arg1 : tensor<5000xi32>) -> tensor<5000xi32> {
  %c1_i32 = arith.constant 1 : i32
  // Producer: writes the iteration index (as i64) into every element.
  %0 = linalg.generic {
        indexing_maps = [affine_map<(d0) -> (d0)>],
        iterator_types = ["parallel"]}
        outs(%arg0 : tensor<5000xi64>) {
        ^bb0(%arg3: i64):  // no predecessors
          %22 = linalg.index 0 : index
          %23 = arith.index_cast %22 : index to i64
          linalg.yield %23 : i64
        } -> tensor<5000xi64>
  %1 = tensor.empty() : tensor<5000xi32>
  // Consumer: gathers %arg1[%0[d0]] into a tensor indexed by d1.
  %2 = linalg.generic {
        indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1)>],
        iterator_types = ["parallel", "parallel"]}
        ins(%0 : tensor<5000xi64>) outs(%1 : tensor<5000xi32>) {
        ^bb0(%arg3: i64, %arg5: i32):  // no predecessors
          %22 = arith.index_cast %arg3 : i64 to index
          %23 = tensor.extract %arg1[%22] : tensor<5000xi32>
          linalg.yield %23 : i32
        } -> tensor<5000xi32>
  return %2 : tensor<5000xi32>
}
// After fusion: one two-result generic; only result #1 is returned.
//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0)>
//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d1)>
//      CHECK: func @fusion_different_axes(
// CHECK-SAME:     %[[ARG0:.+]]: tensor<5000xi64>
// CHECK-SAME:     %[[ARG1:.+]]: tensor<5000xi32>
//  CHECK-DAG:   %[[INIT0:.+]] = tensor.empty() : tensor<5000xi64>
//  CHECK-DAG:   %[[INIT1:.+]] = tensor.empty() : tensor<5000xi32>
//      CHECK:   %[[RESULT:.+]]:2 = linalg.generic
// CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME:       outs(%[[INIT0]], %[[INIT1]] :
// CHECK-NEXT:   ^bb0(
// CHECK-SAME:       %[[B0:.+]]: i64
// CHECK-SAME:       %[[B1:.+]]: i32
//  CHECK-DAG:     %[[T0:.+]] = linalg.index 0
//  CHECK-DAG:     %[[CAST1:.+]] = arith.index_cast %[[T0]] : index to i64
//  CHECK-DAG:     %[[CAST2:.+]] = arith.index_cast %[[CAST1]] : i64 to index
//      CHECK:     %[[EXTRACT:.+]] = tensor.extract %[[ARG1]][%[[CAST2]]]
//      CHECK:     linalg.yield %[[CAST1]], %[[EXTRACT]]
//      CHECK:   return %[[RESULT]]#1
860
861// -----
862
// A linalg.fill feeding an elementwise generic should be folded away: the fill
// value is used directly in the fused region and the generic is left with a
// single tensor input.
// CHECK-LABEL: func @fold_fill_generic_basic
//  CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
//   CHECK-NOT: linalg.fill
//       CHECK: %[[GENERIC_OP:.*]] = linalg.generic
//  CHECK-SAME: ins(%[[ARG0]] : tensor<?xf32>)
//  CHECK-SAME: outs({{.*}} : tensor<?xf32>) {
#map0 = affine_map<(d0) -> (d0)>
func.func @fold_fill_generic_basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 7.0 : f32
  %0 = tensor.dim %arg0, %c0 : tensor<?xf32>
  %1 = tensor.empty(%0) : tensor<?xf32>
  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<?xf32>) -> tensor<?xf32>
  %3 = tensor.empty(%0) : tensor<?xf32>
  %4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %2 : tensor<?xf32>, tensor<?xf32>) outs (%3:tensor<?xf32>) {
  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
    %5 = arith.addf  %arg1, %arg2 : f32
        linalg.yield %5 : f32
  } -> tensor<?xf32>
  return %4 : tensor<?xf32>
}
884
885// -----
886
// Same fill-folding as above, but the fill value is f32 while the filled
// tensor (and the generic's element type) is f16 — the fold must still apply
// with this type-converting fill.
// CHECK-LABEL: func @fold_fill_generic_different_dtype
//  CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf16>) -> tensor<?xf16> {
//   CHECK-NOT: linalg.fill
//       CHECK: %[[GENERIC_OP:.*]] = linalg.generic
//  CHECK-SAME: ins(%[[ARG0]] : tensor<?xf16>)
//  CHECK-SAME: outs({{.*}} : tensor<?xf16>) {
#map0 = affine_map<(d0) -> (d0)>
func.func @fold_fill_generic_different_dtype(%arg0: tensor<?xf16>) -> (tensor<?xf16>) {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 7.0 : f32
  %0 = tensor.dim %arg0, %c0 : tensor<?xf16>
  %1 = tensor.empty(%0) : tensor<?xf16>
  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<?xf16>) -> tensor<?xf16>
  %3 = tensor.empty(%0) : tensor<?xf16>
  %4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %2 : tensor<?xf16>, tensor<?xf16>) outs (%3:tensor<?xf16>) {
  ^bb0(%arg1: f16, %arg2: f16, %arg3: f16):
    %5 = arith.addf  %arg1, %arg2 : f16
        linalg.yield %5 : f16
  } -> tensor<?xf16>
  return %4 : tensor<?xf16>
}
908
909// -----
910
// Both inputs of the generic are fills, one of them accessed through a
// transposed indexing map (#map1). After folding both fills the generic should
// be left with no ins operands at all.
// CHECK-LABEL: func @fold_fill_generic_mixedaccess
//   CHECK-NOT: linalg.fill
//       CHECK: %[[GENERIC_OP:.*]] = linalg.generic
//   CHECK-NOT: ins
//  CHECK-SAME: outs({{.*}} : tensor<?x?xf32>) {
#map0 = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d1, d0)>
func.func @fold_fill_generic_mixedaccess(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
  %c0 = arith.constant 0 : index
  // Was `arith.constant 0`: %c1 must be 1 so %1 is the extent of dimension 1,
  // matching the %c0/%c1 convention used by the other tests in this file and
  // making %4 below a genuinely transposed (d1 x d0) shape.
  %c1 = arith.constant 1 : index
  %cst1 = arith.constant 7.0 : f32
  %cst2 = arith.constant 6.0 : f32
  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
  %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
  %3 = linalg.fill ins(%cst1 : f32) outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32>
  // Transposed-shape init for the operand read through #map1.
  %4 = tensor.empty(%1, %0) : tensor<?x?xf32>
  %5 = linalg.fill ins(%cst2 : f32) outs(%4 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %6 = tensor.empty(%0, %1) : tensor<?x?xf32>
  %7 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types=["parallel","parallel"]} ins(%3, %5 : tensor<?x?xf32>, tensor<?x?xf32>) outs (%6:tensor<?x?xf32>) {
  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
    %8 = arith.divf  %arg1, %arg2 : f32
        linalg.yield %8 : f32
  } -> tensor<?x?xf32>
  return %7 : tensor<?x?xf32>
}
937
938// -----
939
#map = affine_map<() -> ()>
module {
  // Producer with two results of which only %2#1 feeds the consumer. After
  // fusion the producer body is inlined into the consumer and the unused
  // result is dropped, leaving a single-result generic.
  func.func @fuse_multi_result_producer(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<f32>, %arg4: tensor<f32>) -> tensor<f32> {
    %0 = tensor.empty() : tensor<f32>
    %1 = tensor.empty() : tensor<f32>
    %2:2 = linalg.generic {
      indexing_maps = [#map, #map, #map, #map, #map], iterator_types = []}
      ins(%arg0, %arg1, %arg1 : tensor<f32>, tensor<f32>, tensor<f32>) outs(%0, %1 : tensor<f32>, tensor<f32>) {
    ^bb0(%arg5: f32, %arg6: f32, %arg7: f32, %arg8: f32, %arg9: f32):
      %4 = arith.addf %arg5, %arg6 : f32
      %5 = arith.addf %4, %arg7 : f32
      linalg.yield %4, %5 : f32, f32
    } -> (tensor<f32>, tensor<f32>)
    %3 = linalg.generic {
      indexing_maps = [#map, #map, #map], iterator_types = []}
      ins(%2#1, %arg1 : tensor<f32>, tensor<f32>) outs(%arg4 : tensor<f32>) {
    ^bb0(%arg5: f32, %arg6: f32, %arg7: f32):
      %4 = arith.addf %arg5, %arg6 : f32
      %5 = arith.addf %4, %arg6 : f32
      linalg.yield %5 : f32
    } -> tensor<f32>
    return %3 : tensor<f32>
  }
}
// The fused region contains the four addf ops from both bodies, chained, and
// yields only the value the consumer produced.
// CHECK-LABEL: func.func @fuse_multi_result_producer
//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9_]+]]: tensor<f32>
//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: tensor<f32>
//       CHECK:   %[[INIT:.+]] = tensor.empty
//       CHECK:   %[[GENERIC:.+]] = linalg.generic
//  CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]] :
//  CHECK-SAME:       outs(%[[INIT]] :
//  CHECK-NEXT:     ^bb0
//  CHECK-SAME:         %[[B0:[a-zA-Z0-9_]+]]: f32
//  CHECK-SAME:         %[[B1:[a-zA-Z0-9_]+]]: f32
//   CHECK-DAG:     %[[T0:.+]] = arith.addf %[[B0]], %[[B1]]
//   CHECK-DAG:     %[[T1:.+]] = arith.addf %[[T0]], %[[B1]]
//   CHECK-DAG:     %[[T2:.+]] = arith.addf %[[T1]], %[[B1]]
//   CHECK-DAG:     %[[T3:.+]] = arith.addf %[[T2]], %[[B1]]
//       CHECK:     linalg.yield %[[T3]] : f32
//       CHECK:   return %[[GENERIC]]
980