// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-with-reshape-by-expansion -split-input-file | FileCheck %s

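// These tests exercise fusion of tensor.collapse_shape producers and
// tensor.expand_shape consumers into linalg ops by expanding the iteration
// space of the linalg op: each reassociation group becomes a contiguous run
// of loops, the indexing maps are rewritten group-wise, and the reshape is
// either folded away or moved past the fused op. Dynamic sizes of the new
// loops are recovered by dividing the original dynamic extent by the product
// of the static sizes in the group (the arith.divsi ops in the checks).
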
#map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
#map2 = affine_map<(d0, d1, d2) -> ()>
func.func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xf32>,
                                         %arg1 : tensor<?x?x?xf32>,
                                         %arg2 : f32) ->
                                         tensor<?x?x?xf32>
{
  %0 = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] :
    tensor<?x?x4x?xf32> into tensor<?x?x?xf32>
  %1 = linalg.generic {
     indexing_maps = [#map0, #map1, #map2, #map1],
     iterator_types = ["parallel", "parallel", "parallel"]}
       ins(%0, %arg1, %arg2 : tensor<?x?x?xf32>, tensor<?x?x?xf32>, f32)
       outs(%arg1 : tensor<?x?x?xf32>) {
    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32, %s: f32):
      %1 = arith.mulf %arg3, %arg4 : f32
      %2 = arith.addf %1, %arg5 : f32
      linalg.yield %2 : f32
  } -> tensor<?x?x?xf32>
  return %1 : tensor<?x?x?xf32>
}

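// The producer collapse groups [[0], [1, 2], [3]] merge a ?x4 pair of
// dimensions of %arg0. Pulling the reshape through the generic splits the
// loop that indexes this pair into two loops, the inner one of static size 4;
// both uses of %arg1 (input and init) are expanded accordingly, and the new
// dynamic extent is recovered as dim / 4 (the divsi below).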
//  CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)>
//  CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3, d0, d1)>
//  CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1, d2, d3) -> ()>
//      CHECK: func @generic_op_reshape_producer_fusion
// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x4x?xf32>
// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: f32
//      CHECK:   %[[C4:.+]] = arith.constant 4 : index
//      CHECK:   %[[C2:.+]] = arith.constant 2 : index
//      CHECK:   %[[C1:.+]] = arith.constant 1 : index
//      CHECK:   %[[C0:.+]] = arith.constant 0 : index
//      CHECK:   %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?x?xf32>
//      CHECK:   %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?x?xf32>
//      CHECK:   %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<?x?x?xf32>
//      CHECK:   %[[VAL_0:.+]] = arith.divsi %[[DIM_1]], %[[C4]] : index
//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1], [2, 3]] output_shape [%[[DIM]], %[[DIM_0]], %[[VAL_0]], 4] : tensor<?x?x?xf32> into tensor<?x?x?x4xf32>
//      CHECK:   %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?x?xf32>
//      CHECK:   %[[DIM_3:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?x?xf32>
//      CHECK:   %[[DIM_4:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<?x?x?xf32>
//      CHECK:   %[[VAL_1:.+]] = arith.divsi %[[DIM_4]], %[[C4]] : index
//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1], [2, 3]] output_shape [%[[DIM_2]], %[[DIM_3]], %[[VAL_1]], 4] : tensor<?x?x?xf32> into tensor<?x?x?x4xf32>
//      CHECK:   %[[T3:.+]] = linalg.generic
// CHECK-SAME:     indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]], #[[MAP6]]]
// CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME:     ins(%[[ARG0]], %[[T1]], %[[ARG2]] : tensor<?x?x4x?xf32>, tensor<?x?x?x4xf32>, f32)
// CHECK-SAME:     outs(%[[T2]] : tensor<?x?x?x4xf32>)
//      CHECK:   %[[T4:.+]] = tensor.collapse_shape %[[T3]]
// CHECK-SAME:     [0], [1], [2, 3]
// CHECK-SAME:     tensor<?x?x?x4xf32> into tensor<?x?x?xf32>
//      CHECK:   return %[[T4]]

// -----

#map0 = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> ()>
func.func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
                                         %arg1 : tensor<?x?xf32>,
                                         %arg2 : f32,
                                         %sz0: index,
                                         %sz1: index) ->
                                         tensor<?x4x?x5xf32>
{
  %0 = linalg.generic {
     indexing_maps = [#map0, #map0, #map1, #map0],
     iterator_types = ["parallel", "parallel"]}
       ins(%arg0, %arg1, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>, f32)
       outs(%arg0 : tensor<?x?xf32>) {
    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32, %s: f32):
      %1 = arith.mulf %arg3, %arg4 : f32
      %2 = arith.addf %1, %arg5 : f32
      linalg.yield %2 : f32
  } -> tensor<?x?xf32>
  %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] output_shape [%sz0, 4, %sz1, 5] :
    tensor<?x?xf32> into tensor<?x4x?x5xf32>
  return %1 : tensor<?x4x?x5xf32>
}

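// The consumer expand groups [[0], [1, 2, 3]] split the second loop into
// three loops whose trailing static sizes are 4 and 5, so the remaining
// dynamic extent is dim(_, 1) / (4 * 5), which is where the constant 20 in
// the checks below comes from.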
//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> ()>

//      CHECK: func @generic_op_reshape_consumer_fusion
// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME:   %[[SZ0:.+]]: index, %[[SZ1:.+]]: index
//      CHECK:   %[[C20:.+]] = arith.constant 20 : index
//      CHECK:   %[[C1:.+]] = arith.constant 1 : index
//      CHECK:   %[[C0:.+]] = arith.constant 0 : index
//      CHECK:   %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
//      CHECK:   %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
//      CHECK:   %[[VAL_0:.+]] = arith.divsi %[[DIM_0]], %[[C20]] : index
//      CHECK:   %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM]], 4, %[[VAL_0]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
//      CHECK:   %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
//      CHECK:   %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
//      CHECK:   %[[VAL_1:.+]] = arith.divsi %[[DIM_2]], %[[C20]] : index
//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_1]], 4, %[[VAL_1]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
//      CHECK:   %[[DIM_4:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
//      CHECK:   %[[DIM_5:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
//      CHECK:   %[[VAL_2:.+]] = arith.divsi %[[DIM_5]], %[[C20]] : index
//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_4]], 4, %[[VAL_2]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
//      CHECK:   %[[T3:.+]] = linalg.generic
// CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP3]], #[[MAP2]]]
// CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME:     ins(%[[T0]], %[[T1]], %[[ARG2]] : tensor<?x4x?x5xf32>, tensor<?x4x?x5xf32>, f32)
// CHECK-SAME:     outs(%[[T2]] : tensor<?x4x?x5xf32>)
//      CHECK:   return %[[T3]] : tensor<?x4x?x5xf32>


// -----

func.func @reshape_as_consumer_permutation
  (%a : tensor<?x?x?xf32>, %b : tensor<?x?xf32>, %sz0: index, %sz1: index, %sz2: index)
    -> tensor<?x2x?x3x4x?xf32> {
  %c = linalg.generic {
         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
                          affine_map<(d0, d1, d2) -> (d1, d2)>,
                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>],
         iterator_types = ["parallel", "parallel", "parallel"]}
          ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>)
         outs(%a : tensor<?x?x?xf32>) {
       ^bb0(%arg0 : f32, %arg1: f32, %s: f32):
         %1 = arith.addf %arg0, %arg1 : f32
         linalg.yield %1 : f32
       } -> tensor<?x?x?xf32>
  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]] output_shape [%sz0, 2, %sz1, 3, 4, %sz2] : tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
  return %d : tensor<?x2x?x3x4x?xf32>
}
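// Output group [3, 4, 5] carries static sizes 3 and 4 (product 12) and group
// [0, 1] carries static size 2, so the dynamic extents below are recovered as
// dim / 12 and dim / 2; the permuted indexing maps are expanded group-wise,
// which is why the expanded operand types interleave the static 3, 4 and 2.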
//  CHECK-DAG: #[[MAP8:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
//  CHECK-DAG: #[[MAP9:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
//  CHECK-DAG: #[[MAP10:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
//      CHECK: func @reshape_as_consumer_permutation
// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME:   %[[SZ0:.+]]: index, %[[SZ1:.+]]: index, %[[SZ2:.+]]: index
//      CHECK:   %[[C12:.+]] = arith.constant 12 : index
//      CHECK:   %[[C2:.+]] = arith.constant 2 : index
//      CHECK:   %[[C1:.+]] = arith.constant 1 : index
//      CHECK:   %[[C0:.+]] = arith.constant 0 : index
//      CHECK:   %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
//      CHECK:   %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
//      CHECK:   %[[DIM_1:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
//      CHECK:   %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C12]] : index
//      CHECK:   %[[VAL_1:.+]] = arith.divsi %[[DIM_0]], %[[C2]] : index
//      CHECK:   %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3, 4], [5]] output_shape [3, 4, %[[VAL_0]], %[[VAL_1]], 2, %[[DIM_1]]] : tensor<?x?x?xf32> into tensor<3x4x?x?x2x?xf32>
//      CHECK:   %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
//      CHECK:   %[[DIM_3:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
//      CHECK:   %[[VAL_2:.+]] = arith.divsi %[[DIM_2]], %[[C12]] : index
//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [3, 4, %[[VAL_2]], %[[DIM_3]]] : tensor<?x?xf32> into tensor<3x4x?x?xf32>
//      CHECK:   %[[DIM_5:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
//      CHECK:   %[[DIM_6:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
//      CHECK:   %[[DIM_7:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
//      CHECK:   %[[VAL_3:.+]] = arith.divsi %[[DIM_5]], %[[C2]] : index
//      CHECK:   %[[VAL_4:.+]] = arith.divsi %[[DIM_7]], %[[C12]] : index
//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [%[[VAL_3]], 2, %[[DIM_6]], 3, 4, %[[VAL_4]]] : tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
//      CHECK:   %[[T3:.+]] = linalg.generic
// CHECK-SAME:     indexing_maps = [#[[MAP8]], #[[MAP9]], #[[MAP10]]]
// CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME:     ins(%[[T0]], %[[T1]] : tensor<3x4x?x?x2x?xf32>, tensor<3x4x?x?xf32>)
// CHECK-SAME:     outs(%[[T2]] : tensor<?x2x?x3x4x?xf32>)
//      CHECK:   return %[[T3]] : tensor<?x2x?x3x4x?xf32>

// -----

#map0 = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d2)>

func.func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
                                            -> tensor<8x33x4xf32> {
  %cst = arith.constant dense<2.000000e+00> : tensor<264x4xf32>
  %0 = tensor.empty() : tensor<264x4xf32>
  %1 = linalg.generic {
     indexing_maps = [#map0, #map0, #map0],
     iterator_types = ["parallel", "parallel"]}
       ins(%arg0, %cst : tensor<264x4xf32>, tensor<264x4xf32>)
       outs(%0 : tensor<264x4xf32>) {
    ^bb0(%arg1: f32, %arg2: f32, %s: f32):
      %2 = arith.mulf %arg1, %arg2 : f32
      linalg.yield %2 : f32
    } -> tensor<264x4xf32>
  %2 = tensor.expand_shape %1 [[0, 1], [2]] output_shape [8, 33, 4] :
    tensor<264x4xf32> into tensor<8x33x4xf32>
  return %2 : tensor<8x33x4xf32>
}

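// Fully static case: 264 = 8 * 33, so the input and the tensor.empty init
// are both expanded to tensor<8x33x4xf32> directly and no division is
// needed; the trailing expand_shape disappears entirely.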
//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
//      CHECK: func @generic_op_reshape_consumer_static
// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<264x4xf32>
//  CHECK-DAG:   %[[CST:.+]] = arith.constant
// CHECK-SAME:     : tensor<8x33x4xf32>
//  CHECK-DAG:   %[[INIT:.+]] = tensor.empty()
//      CHECK:   %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32>
//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32>
//      CHECK:   %[[T2:.+]] = linalg.generic
// CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]]
// CHECK-SAME:     ["parallel", "parallel", "parallel"]
// CHECK-SAME:     ins(%[[T0]], %[[CST]] :
// CHECK-SAME:     outs(%[[T1]] : tensor<8x33x4xf32>)
//      CHECK:   return %[[T2]] : tensor<8x33x4xf32>

// -----

#map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
func.func @indexed_consumer_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,
                                         %arg1 : tensor<?x?x?xi32>) ->
                                         tensor<?x?x?xi32>
{
  %0 = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] :
    tensor<?x?x4x?xi32> into tensor<?x?x?xi32>
  %1 = linalg.generic {
     indexing_maps = [#map0, #map1, #map1],
     iterator_types = ["parallel", "parallel", "parallel"]}
       ins(%0, %arg1 : tensor<?x?x?xi32>, tensor<?x?x?xi32>)
      outs(%0 : tensor<?x?x?xi32>) {
    ^bb0(%arg3: i32, %arg4: i32, %s: i32):
      %idx0 = linalg.index 0 : index
      %idx1 = linalg.index 1 : index
      %idx2 = linalg.index 2 : index
      %1 = arith.muli %arg3, %arg4 : i32
      %2 = arith.index_cast %idx0 : index to i32
      %3 = arith.addi %1, %2 : i32
      %4 = arith.index_cast %idx1 : index to i32
      %5 = arith.addi %3, %4 : i32
      %6 = arith.index_cast %idx2 : index to i32
      %7 = arith.addi %5, %6 : i32
      linalg.yield %7 : i32
  } -> tensor<?x?x?xi32>
  return %1 : tensor<?x?x?xi32>
}

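// When the fused op uses linalg.index, the index of a collapsed dimension
// must be rebuilt from the split loops: for the ?x4 group here,
// original_index = inner + outer * 4, which is the affine map
// (d0, d1) -> (d0 + d1 * 4) applied in the checks below.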
// Only check the body in the indexed version of the test.
//       CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 4)>
//       CHECK: func @indexed_consumer_reshape_producer_fusion
//       CHECK:   linalg.generic
//       CHECK:   ^{{.*}}(
//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9_]+]]: i32, %[[ARG4:[a-zA-Z0-9_]+]]: i32,
//  CHECK-SAME:     %[[ARG8:[a-zA-Z0-9_]+]]: i32)
//   CHECK-DAG:     %[[IDX0:.+]] = linalg.index 0 : index
//   CHECK-DAG:     %[[IDX1:.+]] = linalg.index 1 : index
//   CHECK-DAG:     %[[IDX2:.+]] = linalg.index 2 : index
//   CHECK-DAG:     %[[IDX3:.+]] = linalg.index 3 : index
//   CHECK-DAG:     %[[T3:.+]] = affine.apply #[[MAP]](%[[IDX1]], %[[IDX0]])
//       CHECK:     %[[T4:.+]] = arith.muli %[[ARG3]], %[[ARG4]]
//       CHECK:     %[[T5:.+]] = arith.index_cast %[[T3]]
//       CHECK:     %[[T6:.+]] = arith.addi %[[T4]], %[[T5]]
//       CHECK:     %[[T7:.+]] = arith.index_cast %[[IDX2]]
//       CHECK:     %[[T8:.+]] = arith.addi %[[T6]], %[[T7]]
//       CHECK:     %[[T9:.+]] = arith.index_cast %[[IDX3]]
//       CHECK:     %[[T10:.+]] = arith.addi %[[T8]], %[[T9]]
//       CHECK:     linalg.yield %[[T10]]

// -----

#map0 = affine_map<(d0, d1) -> (d0, d1)>
func.func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
                                         %arg1 : tensor<?x?xi32>,
                                         %sz0: index, %sz1: index) ->
                                         tensor<?x?x4x5xi32>
{
  %0 = linalg.generic {
     indexing_maps = [#map0, #map0, #map0],
     iterator_types = ["parallel", "parallel"]}
       ins(%arg0, %arg1 : tensor<?x?xi32>, tensor<?x?xi32>)
      outs(%arg0 : tensor<?x?xi32>) {
    ^bb0(%arg3: i32, %arg4: i32, %s: i32):
      %idx0 = linalg.index 0 : index
      %idx1 = linalg.index 1 : index
      %1 = arith.muli %arg3, %arg4 : i32
      %2 = arith.index_cast %idx0 : index to i32
      %3 = arith.addi %1, %2 : i32
      %4 = arith.index_cast %idx1 : index to i32
      %5 = arith.addi %3, %4 : i32
      linalg.yield %5 : i32
  } -> tensor<?x?xi32>
  %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] output_shape [%sz0, %sz1, 4, 5] :
    tensor<?x?xi32> into tensor<?x?x4x5xi32>
  return %1 : tensor<?x?x4x5xi32>
}

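// Same idea for a consumer expand with group [1, 2, 3] of sizes (?, 4, 5):
// the original index along the collapsed dimension is rebuilt in two steps,
// i12 = i2 + i1 * 4 followed by i = i3 + i12 * 5 (i.e. 20*i1 + 5*i2 + i3),
// matching the two affine.apply ops below.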
// Only check the body in the indexed version of the test.
//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 4)>
//   CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 5)>
//       CHECK: func @indexed_producer_reshape_consumer_fusion
//       CHECK:   linalg.generic
//       CHECK:   ^{{.*}}(
//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9_]+]]: i32, %[[ARG4:[a-zA-Z0-9_]+]]: i32,
//  CHECK-SAME:     %[[ARG5:[a-zA-Z0-9_]+]]: i32)
//   CHECK-DAG:     %[[IDX0:.+]] = linalg.index 0 : index
//   CHECK-DAG:     %[[IDX1:.+]] = linalg.index 1 : index
//   CHECK-DAG:     %[[IDX2:.+]] = linalg.index 2 : index
//   CHECK-DAG:     %[[IDX3:.+]] = linalg.index 3 : index
//       CHECK:     %[[T1:.+]] = affine.apply #[[MAP1]](%[[IDX2]], %[[IDX1]])
//       CHECK:     %[[T2:.+]] = affine.apply #[[MAP2]](%[[IDX3]], %[[T1]])
//       CHECK:     %[[T4:.+]] = arith.muli %[[ARG3]], %[[ARG4]]
//       CHECK:     %[[T5:.+]] = arith.index_cast %[[IDX0]]
//       CHECK:     %[[T6:.+]] = arith.addi %[[T4]], %[[T5]]
//       CHECK:     %[[T7:.+]] = arith.index_cast %[[T2]]
//       CHECK:     %[[T8:.+]] = arith.addi %[[T6]], %[[T7]]
//       CHECK:     linalg.yield %[[T8]]

// -----

func.func @reshape_as_consumer_permutation
  (%a : tensor<210x6x4xi32>, %b : tensor<210x4xi32>)
    -> tensor<2x3x4x5x6x7xi32> {
  %shape = tensor.empty() : tensor<6x4x210xi32>
  %c = linalg.generic {
         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
                          affine_map<(d0, d1, d2) -> (d1, d2)>,
                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>],
         iterator_types = ["parallel", "parallel", "parallel"]}
          ins(%a, %b : tensor<210x6x4xi32>, tensor<210x4xi32>)
          outs(%shape : tensor<6x4x210xi32>) {
       ^bb0(%arg3 : i32, %arg4: i32, %s: i32):
         %idx0 = linalg.index 0 : index
         %idx1 = linalg.index 1 : index
         %idx2 = linalg.index 2 : index
         %1 = arith.addi %arg3, %arg4 : i32
         %2 = arith.index_cast %idx0 : index to i32
         %3 = arith.addi %1, %2 : i32
         %4 = arith.index_cast %idx1 : index to i32
         %5 = arith.addi %3, %4 : i32
         %6 = arith.index_cast %idx2 : index to i32
         %7 = arith.addi %5, %6 : i32
         linalg.yield %7 : i32
       } -> tensor<6x4x210xi32>
  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xi32> into tensor<2x3x4x5x6x7xi32>
  return %d : tensor<2x3x4x5x6x7xi32>
}
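
// Static variant of the permuted-consumer case: 210 = 5 * 6 * 7, so one loop
// splits into loops of sizes 5, 6 and 7 (and another into 2 and 3), and the
// linalg.index values are linearized back through the multiplications by 3,
// 6 and 7 in the affine maps of the checks below.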
//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
//   CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
//   CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 3)>
//   CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 6)>
//   CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 7)>
//       CHECK: func @reshape_as_consumer_permutation
//  CHECK-SAME:   %[[ARG0:.+]]: tensor<210x6x4xi32>
//  CHECK-SAME:   %[[ARG1:.+]]: tensor<210x4xi32>
//   CHECK-DAG:   %[[INIT:.+]] = tensor.empty()
//       CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3, 4], [5]] output_shape [5, 6, 7, 2, 3, 4] : tensor<210x6x4xi32> into tensor<5x6x7x2x3x4xi32>
//       CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [5, 6, 7, 4] : tensor<210x4xi32> into tensor<5x6x7x4xi32>
//       CHECK:   %[[T3:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xi32> into tensor<2x3x4x5x6x7xi32>
//       CHECK:   %[[T4:.+]] = linalg.generic
//  CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
//  CHECK-SAME:     ins(%[[T1]], %[[T2]] : tensor<5x6x7x2x3x4xi32>, tensor<5x6x7x4xi32>)
//  CHECK-SAME:     outs(%[[T3]] : tensor<2x3x4x5x6x7xi32>)
//       CHECK:   ^{{.+}}(
//  CHECK-SAME:     %[[ARG8:[a-zA-Z0-9_]+]]: i32, %[[ARG9:[a-zA-Z0-9_]+]]: i32,
//  CHECK-SAME:     %[[ARG10:[a-zA-Z0-9_]+]]: i32)
//   CHECK-DAG:       %[[IDX0:.+]] = linalg.index 0 : index
//   CHECK-DAG:       %[[IDX1:.+]] = linalg.index 1 : index
//   CHECK-DAG:       %[[IDX2:.+]] = linalg.index 2 : index
//   CHECK-DAG:       %[[IDX3:.+]] = linalg.index 3 : index
//   CHECK-DAG:       %[[IDX4:.+]] = linalg.index 4 : index
//   CHECK-DAG:       %[[IDX5:.+]] = linalg.index 5 : index
//   CHECK-DAG:       %[[T5:.+]] = affine.apply #[[MAP3]](%[[IDX1]], %[[IDX0]])
//   CHECK-DAG:       %[[T6:.+]] = affine.apply #[[MAP4]](%[[IDX3]], %[[IDX2]])
//   CHECK-DAG:       %[[T7:.+]] = affine.apply #[[MAP5]](%[[IDX4]], %[[T6]])
//   CHECK-DAG:       %[[T8:.+]] = arith.addi %[[ARG8]], %[[ARG9]]
//       CHECK:       %[[T9:.+]] = arith.index_cast %[[T5]]
//       CHECK:       %[[T10:.+]] = arith.addi %[[T8]], %[[T9]]
//       CHECK:       %[[T11:.+]] = arith.index_cast %[[T7]]
//       CHECK:       %[[T12:.+]] = arith.addi %[[T10]], %[[T11]]
//       CHECK:       %[[T13:.+]] = arith.index_cast %[[IDX5]]
//       CHECK:       %[[T14:.+]] = arith.addi %[[T12]], %[[T13]]

// -----

func.func @reshape_as_producer_projected_permutation(
    %arg0 : tensor<33x8x?xi32>, %shape : tensor<264x?x4xi32>) -> tensor<264x?x4xi32>
{
  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
    : tensor<33x8x?xi32> into tensor<264x?xi32>
  %1 = linalg.generic
    {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>,
                      affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
     iterator_types = ["parallel", "parallel", "parallel"]}
     ins(%0 : tensor<264x?xi32>)
    outs(%shape : tensor<264x?x4xi32>) {
  ^bb0(%arg1: i32, %s: i32):
    %idx0 = linalg.index 0 : index
    %idx1 = linalg.index 1 : index
    %idx2 = linalg.index 2 : index
    %2 = arith.index_cast %idx0 : index to i32
    %3 = arith.addi %arg1, %2 : i32
    %4 = arith.index_cast %idx1 : index to i32
    %5 = arith.addi %3, %4 : i32
    %6 = arith.index_cast %idx2 : index to i32
    %7 = arith.addi %5, %6 : i32
    linalg.yield %7 : i32
  } -> tensor<264x?x4xi32>
  return %1 : tensor<264x?x4xi32>
}

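// The producer collapse merges 33x8 into 264 (33 * 8 = 264). Since the
// producer's indexing map is a projected permutation, the generic simply
// expands to four loops, and only the collapsed row index needs rebuilding
// via (d0, d1) -> (d0 + d1 * 8) below.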
//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
//   CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 8)>
//       CHECK: @reshape_as_producer_projected_permutation
//  CHECK-SAME:   %[[ARG0:.+]]: tensor<33x8x?xi32>
//       CHECK:   %[[RES:.+]] = linalg.generic
//  CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
//  CHECK-SAME:     ins(%[[ARG0]] : tensor<33x8x?xi32>)
//       CHECK:   ^{{.+}}(
//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: i32,
//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: i32)
//   CHECK-DAG:       %[[IDX0:.+]] = linalg.index 0 : index
//   CHECK-DAG:       %[[IDX1:.+]] = linalg.index 1 : index
//   CHECK-DAG:       %[[IDX2:.+]] = linalg.index 2 : index
//   CHECK-DAG:       %[[IDX3:.+]] = linalg.index 3 : index
//   CHECK-DAG:       %[[T0:.+]] = affine.apply #[[MAP2]](%[[IDX1]], %[[IDX0]])
//       CHECK:       %[[T1:.+]] = arith.index_cast %[[T0]] : index to i32
//       CHECK:       %[[T2:.+]] = arith.addi %[[ARG1]], %[[T1]] : i32
//       CHECK:       %[[T3:.+]] = arith.index_cast %[[IDX2]] : index to i32
//       CHECK:       %[[T4:.+]] = arith.addi %[[T2]], %[[T3]] : i32
//       CHECK:       %[[T5:.+]] = arith.index_cast %[[IDX3]] : index to i32
//       CHECK:       %[[T6:.+]] = arith.addi %[[T4]], %[[T5]] : i32
//       CHECK:       linalg.yield %[[T6]] : i32
//       CHECK:    %[[RES2:.+]] = tensor.collapse_shape %[[RES]]
//  CHECK-SAME:      [0, 1], [2], [3]
//  CHECK-SAME:    : tensor<33x8x?x4xi32> into tensor<264x?x4xi32>
//       CHECK:  return %[[RES2]] : tensor<264x?x4xi32>

// -----

#map0 = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d1, d0)>
func.func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,
                                                   %arg1 : tensor<?x?xf32>,
                                                   %sz0: index, %sz1: index) ->
                                                   tensor<?x?x4x5xf32>
{
  %0 = linalg.generic {
     indexing_maps = [#map0, #map0, #map1],
     iterator_types = ["parallel", "parallel"]}
       ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
       outs(%arg0 : tensor<?x?xf32>) {
    ^bb0(%arg3: f32, %arg4: f32, %s: f32):
      %1 = arith.mulf %arg3, %arg4 : f32
      linalg.yield %1 : f32
  } -> tensor<?x?xf32>
  %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] output_shape [%sz0, %sz1, 4, 5] :
    tensor<?x?xf32> into tensor<?x?x4x5xf32>
  return %1 : tensor<?x?x4x5xf32>
}

//  CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
//  CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)>
//      CHECK: func @generic_op_reshape_consumer_fusion_projected
// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME:   %[[SZ0:.+]]: index, %[[SZ1:.+]]: index
//      CHECK:   %[[C20:.+]] = arith.constant 20 : index
//      CHECK:   %[[C1:.+]] = arith.constant 1 : index
//      CHECK:   %[[C0:.+]] = arith.constant 0 : index
//      CHECK:   %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
//      CHECK:   %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
//      CHECK:   %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C20]] : index
//      CHECK:   %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[VAL_0]], 4, 5, %[[DIM_0]]] : tensor<?x?xf32> into tensor<?x4x5x?xf32>
//      CHECK:   %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
//      CHECK:   %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
//      CHECK:   %[[VAL_1:.+]] = arith.divsi %[[DIM_1]], %[[C20]] : index
//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[VAL_1]], 4, 5, %[[DIM_2]]] : tensor<?x?xf32> into tensor<?x4x5x?xf32>
//      CHECK:   %[[DIM_4:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
//      CHECK:   %[[DIM_5:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
//      CHECK:   %[[VAL_2:.+]] = arith.divsi %[[DIM_5]], %[[C20]] : index
//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_4]], %[[VAL_2]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
//      CHECK:   %[[T3:.+]] = linalg.generic
// CHECK-SAME:     indexing_maps = [#[[MAP4]], #[[MAP4]], #[[MAP5]]]
// CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME:     ins(%[[T0]], %[[T1]] : tensor<?x4x5x?xf32>, tensor<?x4x5x?xf32>)
// CHECK-SAME:     outs(%[[T2]] : tensor<?x?x4x5xf32>)
//      CHECK:   return %[[T3]] : tensor<?x?x4x5xf32>

// -----

func.func @no_fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
  %c0 = arith.constant 0 : index
  %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
  %1 = tensor.dim %0, %c0 : tensor<?xf32>
  %2 = tensor.empty(%1) : tensor<?xf32>
  %3 = linalg.generic {
    indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
    iterator_types = ["parallel"]}
    ins(%0 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {
      ^bb0(%arg1 : f32, %arg2: f32):
        %4 = arith.addf %arg1, %arg1 : f32
        linalg.yield %4 : f32
    } -> tensor<?xf32>
  return %3 : tensor<?xf32>
}

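// No fusion here: the collapse merges two dynamic dimensions, so there is no
// static factor by which to divide when reconstructing the expanded sizes,
// and the collapse_shape has to stay, as the checks confirm.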
//      CHECK: func @no_fuse_dynamic_dims
// CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?xf32>
//      CHECK:   %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]]
//      CHECK:   %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME:       ins(%[[RESHAPE]] : tensor<?xf32>)
//      CHECK:   return %[[GENERIC]]

// -----

func.func @no_fuse_mismatched_dynamism(%arg0: tensor<2x1xi64>, %arg1: tensor<?xi64>) -> tensor<2xi64> {
  %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<2x1xi64> into tensor<2xi64>
  %1 = tensor.empty() : tensor<2xi64>
  %2 = linalg.generic
    {indexing_maps = [affine_map<(d0) -> (d0)>,
                      affine_map<(d0) -> (d0)>,
                      affine_map<(d0) -> (d0)>],
     iterator_types = ["parallel"]}
    ins(%0, %arg1 : tensor<2xi64>, tensor<?xi64>)
    outs(%1 : tensor<2xi64>) {
  ^bb0(%arg4: i64, %arg5: i64, %arg6: i64):
    %3 = arith.addi %arg4, %arg5 : i64
    linalg.yield %3 : i64
  } -> tensor<2xi64>
  return %2 : tensor<2xi64>
}

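// No fusion either: the collapsed operand is fully static (2x1 -> 2) while
// %arg1 stays dynamic along the same loop, and expanding the op would mix
// static and dynamic shapes inconsistently, so the collapse_shape is left in
// place.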
//      CHECK: func @no_fuse_mismatched_dynamism
// CHECK-SAME:     %[[ARG0:.+]]: tensor<2x1xi64>
// CHECK-SAME:     %[[ARG1:.+]]: tensor<?xi64>
//      CHECK:   %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]]
//      CHECK:   %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME:       ins(%[[RESHAPE]], %[[ARG1]] : tensor<2xi64>, tensor<?xi64>)
//      CHECK:   return %[[GENERIC]]

// -----

func.func @reshape_as_consumer_permutation_with_multiple_results
  (%a : tensor<?x?x?xf32>, %b : tensor<?x?xf32>, %sz0: index,
   %sz1: index, %sz2: index, %sz3: index, %sz4: index)
    -> (tensor<?x2x?x3x4x?xf32>, tensor<?x?x2x3x4x?xf32>) {
  %c:2 = linalg.generic {
         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
                          affine_map<(d0, d1, d2) -> (d1, d2)>,
                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>,
                          affine_map<(d0, d1, d2) -> (d2, d0, d1)>],
         iterator_types = ["parallel", "parallel", "parallel"]}
          ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>)
         outs(%a, %a : tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
       ^bb0(%arg0 : f32, %arg1: f32, %s: f32, %t : f32):
         %1 = arith.addf %arg0, %arg1 : f32
         linalg.yield %1, %1 : f32, f32
       } -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
  %d = tensor.expand_shape %c#0 [[0, 1], [2], [3, 4, 5]] output_shape [%sz0, 2, %sz1, 3, 4, %sz2] : tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
  %e = tensor.expand_shape %c#1 [[0], [1, 2], [3, 4, 5]] output_shape [%sz3, %sz4, 2, 3, 4, %sz2] : tensor<?x?x?xf32> into tensor<?x?x2x3x4x?xf32>
  return %d, %e : tensor<?x2x?x3x4x?xf32>, tensor<?x?x2x3x4x?xf32>
}
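// Both results feed expand_shape consumers. Their reassociations differ, but
// they imply the same expansion of the iteration space (d0 -> 2 loops,
// d1 -> 3 loops, d2 -> 1 loop), so the op is expanded once with two expanded
// inits and both reshapes fold away.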
//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d5, d0, d1, d2, d3, d4)>
//      CHECK: func @reshape_as_consumer_permutation_with_multiple_results
//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
//  CHECK-SAME:   %[[SZ0:.+]]: index, %[[SZ1:.+]]: index, %[[SZ2:.+]]: index, %[[SZ3:.+]]: index, %[[SZ4:.+]]: index
//       CHECK:   %[[C12:.+]] = arith.constant 12 : index
//       CHECK:   %[[C2:.+]] = arith.constant 2 : index
//       CHECK:   %[[C1:.+]] = arith.constant 1 : index
//       CHECK:   %[[C0:.+]] = arith.constant 0 : index
//       CHECK:   %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
//       CHECK:   %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
//       CHECK:   %[[DIM_1:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
//       CHECK:   %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C12]] : index
//       CHECK:   %[[VAL_1:.+]] = arith.divsi %[[DIM_0]], %[[C2]] : index
//       CHECK:   %[[RESHAPE0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3, 4], [5]] output_shape [3, 4, %[[VAL_0]], %[[VAL_1]], 2, %[[DIM_1]]] : tensor<?x?x?xf32> into tensor<3x4x?x?x2x?xf32>
//       CHECK:   %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
//       CHECK:   %[[DIM_3:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
//       CHECK:   %[[VAL_2:.+]] = arith.divsi %[[DIM_2]], %[[C12]] : index
//       CHECK:   %[[RESHAPE1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [3, 4, %[[VAL_2]], %[[DIM_3]]] : tensor<?x?xf32> into tensor<3x4x?x?xf32>
//       CHECK:   %[[DIM_5:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
//       CHECK:   %[[DIM_6:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
//       CHECK:   %[[DIM_7:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
//       CHECK:   %[[VAL_3:.+]] = arith.divsi %[[DIM_5]], %[[C2]] : index
//       CHECK:   %[[VAL_4:.+]] = arith.divsi %[[DIM_7]], %[[C12]] : index
//       CHECK:   %[[RESHAPE2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [%[[VAL_3]], 2, %[[DIM_6]], 3, 4, %[[VAL_4]]] : tensor<?x?x?xf32> into tensor<?x2x?x3x4x?xf32>
//       CHECK:   %[[DIM_9:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
//       CHECK:   %[[DIM_10:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
//       CHECK:   %[[DIM_11:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?xf32>
//       CHECK:   %[[VAL_5:.+]] = arith.divsi %[[DIM_10]], %[[C2]] : index
//       CHECK:   %[[VAL_6:.+]] = arith.divsi %[[DIM_11]], %[[C12]] : index
//       CHECK:   %[[RESHAPE3:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2], [3, 4, 5]] output_shape [%[[DIM_9]], %[[VAL_5]], 2, 3, 4, %[[VAL_6]]] : tensor<?x?x?xf32> into tensor<?x?x2x3x4x?xf32>
//       CHECK:   %[[GENERIC:.+]]:2 = linalg.generic
//  CHECK-SAME:      indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
//  CHECK-SAME:      ins(%[[RESHAPE0]], %[[RESHAPE1]] :
//  CHECK-SAME:      outs(%[[RESHAPE2]], %[[RESHAPE3]] :
//       CHECK:  return %[[GENERIC]]#0, %[[GENERIC]]#1

// -----

#map0 = affine_map<(d0, d1) -> (d1)>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
module {
  func.func @multi_result_op_expansion(%arg0: tensor<512xf32>, %arg1: tensor<512xf32>,
      %arg2: tensor<512xf32>, %arg3: tensor<200x512xf32>) -> tensor<25x8x1x512xf32> {
    %0:2 = linalg.generic {
        indexing_maps = [#map0, #map0, #map0, #map1],
        iterator_types = ["parallel", "parallel"]}
        ins(%arg0, %arg1 : tensor<512xf32>, tensor<512xf32>)
        outs(%arg2, %arg3 : tensor<512xf32>, tensor<200x512xf32>) {
      ^bb0(%arg4: f32, %arg5: f32, %arg6: f32, %arg7: f32):
        %2 = arith.addf %arg4, %arg5 : f32
        linalg.yield %2, %2 : f32, f32
      } -> (tensor<512xf32>, tensor<200x512xf32>)
    %1 = tensor.expand_shape %0#1 [[0, 1, 2], [3]] output_shape [25, 8, 1, 512] : tensor<200x512xf32> into tensor<25x8x1x512xf32>
    return %1 : tensor<25x8x1x512xf32>
  }
}
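// Only the second result is reshaped here. Expanding d0 into three loops
// (200 = 25 * 8 * 1) turns the rank-1 maps into (d3) and the rank-2 map into
// the 4-d identity, while the untouched tensor<512xf32> result keeps its
// operand as-is.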
//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
//      CHECK: func.func @multi_result_op_expansion(
// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<512xf32>
// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<512xf32>
// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<512xf32>
// CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: tensor<200x512xf32>
//      CHECK:     %[[OUTS:.+]] = tensor.expand_shape %[[ARG3]] {{\[\[}}0, 1, 2], [3]] output_shape [25, 8, 1, 512] : tensor<200x512xf32> into tensor<25x8x1x512xf32>
//      CHECK:   %[[GENERIC:.+]]:2 = linalg.generic
// CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP0]], #[[MAP1]]]
// CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]] :
// CHECK-SAME:       outs(%[[ARG2]], %[[OUTS]] :
//      CHECK:   return %[[GENERIC]]#1

// -----

#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
func.func @generic_op_reshape_consumer_fusion_reduction(%arg0 : tensor<?x?xf32>,
                                                        %arg1 : tensor<?x?xf32>,
                                                        %arg2 : tensor<?x?xf32>,
                                                        %sz0: index,
                                                        %sz1: index) ->
                                                        tensor<?x?x4x5xf32>
{
  %0 = linalg.generic {
     indexing_maps = [#map0, #map1, #map2],
     iterator_types = ["parallel", "parallel", "reduction"]}
       ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
       outs(%arg2 : tensor<?x?xf32>) {
    ^bb0(%arg3: f32, %arg4: f32, %s: f32):
      %1 = arith.mulf %arg3, %arg4 : f32
      linalg.yield %1 : f32
  } -> tensor<?x?xf32>
  %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] output_shape [%sz0, %sz1, 4, 5] :
    tensor<?x?xf32> into tensor<?x?x4x5xf32>
  return %1 : tensor<?x?x4x5xf32>
}

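// Expansion also applies when the op performs a reduction: only the parallel
// loop producing the reshaped result is split (again by 4 * 5 = 20), and the
// reduction loop survives unchanged, as the iterator list below shows.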
//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4)>
//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d3, d4)>
//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
//      CHECK: func @generic_op_reshape_consumer_fusion_reduction
// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME:   %[[SZ0:.+]]: index, %[[SZ1:.+]]: index
//      CHECK:   %[[C20:.+]] = arith.constant 20 : index
//      CHECK:   %[[C1:.+]] = arith.constant 1 : index
//      CHECK:   %[[C0:.+]] = arith.constant 0 : index
//      CHECK:   %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
//      CHECK:   %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
//      CHECK:   %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C20]] : index
//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[VAL_0]], 4, 5, %[[DIM_0]]] : tensor<?x?xf32> into tensor<?x4x5x?xf32>
//      CHECK:   %[[DIM_1:.+]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<?x?xf32>
//      CHECK:   %[[DIM_2:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
//      CHECK:   %[[VAL_1:.+]] = arith.divsi %[[DIM_2]], %[[C20]] : index
//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_1]], %[[VAL_1]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
//      CHECK:   %[[T3:.+]] = linalg.generic
// CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
// CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel", "reduction"]
// CHECK-SAME:     ins(%[[ARG0]], %[[T1]] : tensor<?x?xf32>, tensor<?x4x5x?xf32>)
// CHECK-SAME:     outs(%[[T2]] : tensor<?x?x4x5xf32>)
//      CHECK:   return %[[T3]] : tensor<?x?x4x5xf32>

// -----

#map0 = affine_map<(d0, d1, d2) -> (d2, d0)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
func.func @generic_op_reshape_producer_fusion_with_reduction(%arg0 : tensor<?x7x?x8xf32>,
                                         %arg1 : tensor<?x4x?xf32>,
                                         %arg2 : tensor<?x?xf32>) ->
                                         tensor<?x?xf32>
{
  %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] :
    tensor<?x7x?x8xf32> into tensor<?x?xf32>
  %1 = linalg.generic {
     indexing_maps = [#map0, #map1, #map2],
     iterator_types = ["parallel", "reduction", "parallel"]}
       ins(%0, %arg1 : tensor<?x?xf32>, tensor<?x4x?xf32>)
       outs(%arg2 : tensor<?x?xf32>) {
    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
      %1 = arith.mulf %arg3, %arg4 : f32
      %2 = arith.addf %1, %arg5 : f32
      linalg.yield %2 : f32
  } -> tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}

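// Producer collapse [[0, 1], [2, 3]] of tensor<?x7x?x8xf32> splits two of
// the three loops; the dynamic extents come back as dim / 8 and dim / 7
// while the reduction loop stays untouched, and the result is collapsed back
// at the end because the consumer's own result type was tensor<?x?xf32>.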
//  CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d0, d1)>
//  CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
//  CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
//      CHECK: func @generic_op_reshape_producer_fusion_with_reduction
// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x4x?xf32>
// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
//      CHECK:   %[[C1:.+]] = arith.constant 1 : index
//      CHECK:   %[[C7:.+]] = arith.constant 7 : index
//      CHECK:   %[[C8:.+]] = arith.constant 8 : index
//      CHECK:   %[[C2:.+]] = arith.constant 2 : index
//      CHECK:   %[[C0:.+]] = arith.constant 0 : index
//      CHECK:   %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x4x?xf32>
//      CHECK:   %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<?x4x?xf32>
//      CHECK:   %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C8]] : index
//      CHECK:   %[[VAL_1:.+]] = arith.divsi %[[DIM_0]], %[[C7]] : index
//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2], [3, 4]] output_shape [%[[VAL_0]], 8, 4, %[[VAL_1]], 7] : tensor<?x4x?xf32> into tensor<?x8x4x?x7xf32>
//      CHECK:   %[[DIM_1:.+]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<?x?xf32>
//      CHECK:   %[[DIM_2:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
//      CHECK:   %[[VAL_2:.+]] = arith.divsi %[[DIM_1]], %[[C8]] : index
//      CHECK:   %[[VAL_3:.+]] = arith.divsi %[[DIM_2]], %[[C7]] : index
//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_2]], 8, %[[VAL_3]], 7] : tensor<?x?xf32> into tensor<?x8x?x7xf32>
//      CHECK:   %[[T3:.+]] = linalg.generic
// CHECK-SAME:     indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME:     ["parallel", "parallel", "reduction", "parallel", "parallel"]
// CHECK-SAME:     ins(%[[ARG0]], %[[T1]] : tensor<?x7x?x8xf32>, tensor<?x8x4x?x7xf32>)
// CHECK-SAME:     outs(%[[T2]] : tensor<?x8x?x7xf32>)
//      CHECK:   %[[T4:.+]] = tensor.collapse_shape %[[T3]]
// CHECK-SAME:     [0, 1], [2, 3]
// CHECK-SAME:     tensor<?x8x?x7xf32> into tensor<?x?xf32>
//      CHECK:   return %[[T4]]

// -----

func.func @linalg_add_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
                                              %arg1 : tensor<?x?xf32>,
                                              %arg2 : tensor<?x?xf32>,
                                              %sz0: index,
                                              %sz1: index) ->
                                              tensor<?x?x4x5xf32>
{
  %0 = linalg.add ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
       outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] output_shape [%sz0, %sz1, 4, 5] :
    tensor<?x?xf32> into tensor<?x?x4x5xf32>
  return %1 : tensor<?x?x4x5xf32>
}

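// Named ops participate as well: fusing the expand through linalg.add
// rewrites it into an equivalent expanded linalg.generic (identity maps,
// all-parallel iterators), with the same dim / 20 size reconstruction as in
// the generic-op tests above.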
//  CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
//      CHECK: func @linalg_add_reshape_consumer_fusion
// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME:   %[[SZ0:.+]]: index, %[[SZ1:.+]]: index
//      CHECK:   %[[C20:.+]] = arith.constant 20 : index
//      CHECK:   %[[C1:.+]] = arith.constant 1 : index
//      CHECK:   %[[C0:.+]] = arith.constant 0 : index
//      CHECK:   %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
//      CHECK:   %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
//      CHECK:   %[[VAL_0:.+]] = arith.divsi %[[DIM_0]], %[[C20]] : index
//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM]], %[[VAL_0]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
//      CHECK:   %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
//      CHECK:   %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
//      CHECK:   %[[VAL_1:.+]] = arith.divsi %[[DIM_2]], %[[C20]] : index
//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_1]], %[[VAL_1]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
//      CHECK:   %[[DIM_4:.+]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<?x?xf32>
//      CHECK:   %[[DIM_5:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
//      CHECK:   %[[VAL_2:.+]] = arith.divsi %[[DIM_5]], %[[C20]] : index
//      CHECK:   %[[T3:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_4]], %[[VAL_2]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
//      CHECK:   %[[T4:.+]] = linalg.generic
// CHECK-SAME:     indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
// CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME:     ins(%[[T1]], %[[T2]] : tensor<?x?x4x5xf32>, tensor<?x?x4x5xf32>)
// CHECK-SAME:     outs(%[[T3]] : tensor<?x?x4x5xf32>)
//      CHECK:   return %[[T4]] : tensor<?x?x4x5xf32>

// -----

#map0 = affine_map<(d0, d1, d2) -> (d2, d0)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
                                              %arg1 : tensor<?x?xf32>,
                                              %arg2 : tensor<?x?xf32>) ->
                                              tensor<?x?xf32>
{
  %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] :
    tensor<?x7x?x8xf32> into tensor<?x?xf32>
  %1 = linalg.add ins(%0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
       outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}

//  CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
//      CHECK: func @linalg_add_reshape_producer_fusion
// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
//      CHECK:   %[[C8:.+]] = arith.constant 8 : index
//      CHECK:   %[[C7:.+]] = arith.constant 7 : index
//      CHECK:   %[[C1:.+]] = arith.constant 1 : index
//      CHECK:   %[[C0:.+]] = arith.constant 0 : index
//      CHECK:   %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
//      CHECK:   %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
//      CHECK:   %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C7]] : index
//      CHECK:   %[[VAL_1:.+]] = arith.divsi %[[DIM_0]], %[[C8]] : index
//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_0]], 7, %[[VAL_1]], 8] : tensor<?x?xf32> into tensor<?x7x?x8xf32>
//      CHECK:   %[[DIM_1:.+]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<?x?xf32>
//      CHECK:   %[[DIM_2:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
//      CHECK:   %[[VAL_2:.+]] = arith.divsi %[[DIM_1]], %[[C7]] : index
//      CHECK:   %[[VAL_3:.+]] = arith.divsi %[[DIM_2]], %[[C8]] : index
//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_2]], 7, %[[VAL_3]], 8] : tensor<?x?xf32> into tensor<?x7x?x8xf32>
//      CHECK:   %[[T3:.+]] = linalg.generic
// CHECK-SAME:     indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]]]
// CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME:     ins(%[[ARG0]], %[[T1]] : tensor<?x7x?x8xf32>, tensor<?x7x?x8xf32>)
// CHECK-SAME:     outs(%[[T2]] : tensor<?x7x?x8xf32>)
//      CHECK:   %[[T4:.+]] = tensor.collapse_shape %[[T3]]
// CHECK-SAME:     [0, 1], [2, 3]
// CHECK-SAME:     tensor<?x7x?x8xf32> into tensor<?x?xf32>
//      CHECK:   return %[[T4]]

// -----

func.func @fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x336x14xi32> {
  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
  %cst = arith.constant 0 : i32
  %padded_0 = tensor.pad %collapse low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2] {
  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index):
    tensor.yield %cst : i32
  } : tensor<2x12x5x336x9xi32> to tensor<8x12x17x336x14xi32>
  return %padded_0 : tensor<8x12x17x336x14xi32>
}
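// tensor.pad fuses with a collapse_shape producer when every collapsed group
// that is actually padded is a singleton: the nonzero low/high amounts here
// sit on groups [0], [3] and [7], so the pad is rewritten on the expanded
// type and the collapse moves below it.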
//      CHECK: func @fuse_by_expanding_pad(
// CHECK-SAME:   %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8x9xi32>)
//      CHECK:   %[[PAD:.+]] = tensor.pad %[[ARG0]]
// CHECK-SAME:       low[1, 0, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2]
//      CHECK:       tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x3x4x17x6x7x8x14xi32>
//      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
// CHECK-SAME:       : tensor<8x3x4x17x6x7x8x14xi32> into tensor<8x12x17x336x14xi32>
//      CHECK:   return %[[COLLAPSE]]

// -----

func.func @no_fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x339x14xi32> {
  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
  %cst = arith.constant 0 : i32
  %padded_0 = tensor.pad %collapse low[1, 0, 8, 0, 3] high[5, 0, 4, 3, 2] {
  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index):
    tensor.yield %cst : i32
  } : tensor<2x12x5x336x9xi32> to tensor<8x12x17x339x14xi32>
  return %padded_0 : tensor<8x12x17x339x14xi32>
}
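// Here high[...] pads the collapsed group [4, 5, 6] (336 -> 339), which
// spans several source dimensions, so the padding cannot be distributed over
// the expanded shape and no fusion happens.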
//      CHECK: func @no_fuse_by_expanding_pad(
// CHECK-SAME:   %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8x9xi32>)
//      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
// CHECK-SAME:       : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
//      CHECK:   %[[PAD:.+]] = tensor.pad %[[COLLAPSE]]
// CHECK-SAME:       low[1, 0, 8, 0, 3] high[5, 0, 4, 3, 2]
//      CHECK:       tensor<2x12x5x336x9xi32> to tensor<8x12x17x339x14xi32>
//      CHECK:   return %[[PAD]]

// -----

func.func @fuse_by_expanding_dynamic_pad(%arg0 : tensor<?x?x?x?x?x?xi32>, %l0: index, %l1: index, %h0: index, %h1: index) -> tensor<?x?x?x?xi32> {
  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5]] : tensor<?x?x?x?x?x?xi32> into tensor<?x?x?x?xi32>
  %cst = arith.constant 0 : i32
  %padded_0 = tensor.pad %collapse low[%l0, 0, %l1, 0] high[%h0, 0, %h1, 0] {
  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
    tensor.yield %cst : i32
  } : tensor<?x?x?x?xi32> to tensor<?x?x?x?xi32>
  return %padded_0 : tensor<?x?x?x?xi32>
}
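// Dynamic padding amounts are fine as long as the padded groups remain
// singletons: %l0/%h0 and %l1/%h1 pad groups [0] and [3], so fusion proceeds
// exactly as in the static case.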
//      CHECK: func @fuse_by_expanding_dynamic_pad(
// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?x?x?x?xi32>
// CHECK-SAME:   %[[L0:.+]]: index, %[[L1:.+]]: index, %[[H0:.+]]: index, %[[H1:.+]]: index
//      CHECK:   %[[PAD:.+]] = tensor.pad %[[ARG0]]
// CHECK-SAME:       low[%[[L0]], 0, 0, %[[L1]], 0, 0] high[%[[H0]], 0, 0, %[[H1]], 0, 0]
//      CHECK:       tensor<?x?x?x?x?x?xi32> to tensor<?x?x?x?x?x?xi32>
//      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5]]
// CHECK-SAME:       : tensor<?x?x?x?x?x?xi32> into tensor<?x?x?x?xi32>
//      CHECK:   return %[[COLLAPSE]]