// xref: llvm-project/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir (revision 40556d08491f530e03746fb188b38e7f9cb272c7)
// RUN: mlir-opt %s -linalg-fold-unit-extent-dims -split-input-file | FileCheck %s
// RUN: mlir-opt %s -linalg-fold-unit-extent-dims="use-rank-reducing-slices" -cse -split-input-file | FileCheck %s --check-prefix=CHECK-SLICES
3
// Unit-trip loops (j and l) and the corresponding size-1 tensor dims are
// folded away, leaving a 3-D all-parallel generic. The default pass bridges
// the types with collapse_shape/expand_shape; the rank-reducing-slices run
// uses extract_slice/insert_slice instead.
#accesses = [
  affine_map<(i, j, k, l, m) -> (i, k, m)>,
  affine_map<(i, j, k, l, m) -> ()>,
  affine_map<(i, j, k, l, m) -> (i, k, j, l, m)>
]

#trait = {
  iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
  indexing_maps = #accesses,
  library_call = "some_external_func"
}

func.func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>, %arg1 : f32, %shape: tensor<?x1x?x1x?xf32>) -> tensor<?x1x?x1x?xf32> {
  %0 = linalg.generic #trait
     ins(%arg0, %arg1 : tensor<?x1x?xf32>, f32)
    outs(%shape : tensor<?x1x?x1x?xf32>) {
       ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32) :
         linalg.yield %arg3 : f32
       } -> tensor<?x1x?x1x?xf32>
  return %0 : tensor<?x1x?x1x?xf32>
}
//   CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
//   CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> ()>
//   CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
//   CHECK-DAG: #[[$MAP4:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
// CHECK-LABEL: func @drop_one_trip_loops
//       CHECK: %[[C2:.*]] = arith.constant 2 : index
//       CHECK: %[[C1:.*]] = arith.constant 1 : index
//       CHECK: %[[C0:.*]] = arith.constant 0 : index
//       CHECK: tensor.collapse_shape %{{.*}} {{\[\[}}0, 1], [2]]
//       CHECK: tensor.collapse_shape %{{.*}} {{\[\[}}0, 1], [2, 3], [4]]
//       CHECK: linalg.generic
//  CHECK-SAME:   indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP3]]]
//  CHECK-SAME:   iterator_types = ["parallel", "parallel", "parallel"]
//       CHECK: %[[DIM:.*]] = tensor.dim %{{.*}}, %[[C0]]
//       CHECK: %[[VAL_1:.*]] = affine.apply #[[$MAP4]]()[%[[DIM]], %[[C1]]]
//       CHECK: %[[DIM_1:.*]] = tensor.dim %{{.*}}, %[[C2]]
//       CHECK: %[[VAL_2:.*]] = affine.apply #[[$MAP4]]()[%[[DIM_1]], %[[C1]]]
//       CHECK: %[[DIM_2:.*]] = tensor.dim %{{.*}}, %[[C2]]
//       CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %{{.*}} {{\[\[}}0, 1], [2, 3], [4]] output_shape [%[[VAL_1]], 1, %[[VAL_2]], 1, %[[DIM_2]]] : tensor<?x?x?xf32> into tensor<?x1x?x1x?xf32>

//   CHECK-SLICES-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
//   CHECK-SLICES-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> ()>
//   CHECK-SLICES-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-SLICES-LABEL: func @drop_one_trip_loops
//       CHECK-SLICES: tensor.extract_slice %{{.*}}[0, 0, 0] [%{{.*}}, 1, %{{.*}}] [1, 1, 1] : tensor<?x1x?xf32> to tensor<?x?xf32>
//       CHECK-SLICES: tensor.extract_slice %{{.*}}[0, 0, 0, 0, 0] [%{{.*}}, 1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1, 1] : tensor<?x1x?x1x?xf32> to tensor<?x?x?xf32>
//       CHECK-SLICES: linalg.generic
//  CHECK-SLICES-SAME:   indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP3]]]
//  CHECK-SLICES-SAME:   iterator_types = ["parallel", "parallel", "parallel"]
//       CHECK-SLICES: tensor.insert_slice %{{.*}} into %{{.*}}[0, 0, 0, 0, 0] [%{{.*}}, 1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1, 1] : tensor<?x?x?xf32> into tensor<?x1x?x1x?xf32>
55
56
57// -----
58
// Every dim except the dynamic one is unit extent, so the generic collapses
// to a single parallel loop over 1-D (and rank-0) tensors.
#accesses = [
  affine_map<(i, j, k, l, m) -> (i, k, m)>,
  affine_map<(i, j, k, l, m) -> ()>,
  affine_map<(i, j, k, l, m) -> (i, k, j, l, m)>
]

#trait = {
  iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
  indexing_maps = #accesses,
  library_call = "some_external_func"
}

func.func @drop_one_trip_loops_all_ones(%arg0 : tensor<1x1x1xf32>, %arg1 : f32, %shape: tensor<1x1x?x1x1xf32>) -> tensor<1x1x?x1x1xf32> {
  %0 = linalg.generic #trait
     ins(%arg0, %arg1 : tensor<1x1x1xf32>, f32)
    outs(%shape : tensor<1x1x?x1x1xf32>) {
       ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32) :
         linalg.yield %arg3 : f32
       } -> tensor<1x1x?x1x1xf32>
  return %0 : tensor<1x1x?x1x1xf32>
}
//   CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> ()>
//   CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0) -> (d0)>
//   CHECK-DAG: #[[$MAP3:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> ((((s0 * s1) * s2) * s3) * s4)>
// CHECK-LABEL: func @drop_one_trip_loops_all_ones
//       CHECK: %[[C2:.*]] = arith.constant 2 : index
//       CHECK: %[[C1:.*]] = arith.constant 1 : index
//       CHECK: tensor.collapse_shape %{{.*}} []
//       CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3, 4]]
//       CHECK: linalg.generic
//  CHECK-SAME:   indexing_maps = [#[[$MAP1]], #[[$MAP1]], #[[$MAP2]]]
//  CHECK-SAME:   iterator_types = ["parallel"]
//       CHECK: %[[DIM:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x1x?x1x1xf32>
//       CHECK: %[[SZ:.*]] = affine.apply #[[$MAP3]]()[%[[C1]], %[[C1]], %[[DIM]], %[[C1]], %[[C1]]]
//       CHECK: %[[EXPAND:.*]] = tensor.expand_shape %{{.*}} {{\[\[}}0, 1, 2, 3, 4]] output_shape [1, 1, %[[SZ]], 1, 1] : tensor<?xf32> into tensor<1x1x?x1x1xf32>
94
95// -----
96
// Verifies that linalg.index results are remapped correctly when unit-trip
// loops are dropped from a tensor generic.
#accesses = [
  affine_map<(i, j, k, l, m) -> (i, k, m)>,
  affine_map<(i, j, k, l, m) -> (i, k, j, l, m)>
]

#trait = {
  iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
  indexing_maps = #accesses,
  library_call = "some_external_func"
}

func.func @drop_one_trip_loops_indexed
  (%arg0 : tensor<?x1x?xi32>, %shape: tensor<?x1x?x1x?xi32>) -> tensor<?x1x?x1x?xi32>
{
  %0 = linalg.generic #trait
     ins(%arg0 : tensor<?x1x?xi32>)
    outs(%shape: tensor<?x1x?x1x?xi32>) {
       ^bb0(%arg6 : i32, %arg7 : i32) :
         %idx0 = linalg.index 0 : index
         %idx1 = linalg.index 1 : index
         %idx2 = linalg.index 2 : index
         %idx3 = linalg.index 3 : index
         %idx4 = linalg.index 4 : index
         %1 = arith.addi %idx0, %idx1 : index
         %2 = arith.subi %1, %idx2 : index
         %3 = arith.subi %2, %idx3 : index
         %4 = arith.addi %3, %idx4 : index
         %5 = arith.index_cast %4 : index to i32
         %6 = arith.addi %5, %arg6 : i32
         linalg.yield %6 : i32
       } -> tensor<?x1x?x1x?xi32>
  return %0 : tensor<?x1x?x1x?xi32>
}
// The subtractions disappear because the access map of the output tensor maps
// its unit dimensions 1 and 3 to the index dimensions 2 and 3.
// CHECK-LABEL: func @drop_one_trip_loops_indexed
//       CHECK:   linalg.generic
//       CHECK:   ^{{.+}}(
//  CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: i32, %{{.*}}: i32)
//       CHECK:     %[[IDX0:.+]] = linalg.index 0 : index
//       CHECK:     %[[IDX1:.+]] = linalg.index 1 : index
//       CHECK:     %[[IDX2:.+]] = linalg.index 2 : index
//       CHECK:     %[[T3:.+]] = arith.addi %[[IDX0]], %[[IDX1]]
//       CHECK:     %[[T4:.+]] = arith.addi %[[T3]], %[[IDX2]]
//       CHECK:     %[[T5:.+]] = arith.index_cast %[[T4]] : index to i32
//       CHECK:     %[[T6:.+]] = arith.addi %[[T5]], %[[ARG4]] : i32
//       CHECK:     linalg.yield %[[T6]] : i32
144
145// -----
146
// A generic over a 1x1 tensor collapses to a zero-dimensional generic with
// empty iterator_types and scalar () indexing maps.
#map0 = affine_map<(i, j) -> (i, j)>
#access = [#map0, #map0]
#trait = {
  iterator_types = ["parallel", "parallel"],
  indexing_maps = #access,
  library_call = "some_external_func"
}

func.func @drop_all_loops(%arg0 : tensor<1x1xf32>) -> tensor<1x1xf32>
{
  %0 = linalg.generic #trait
     ins(%arg0 : tensor<1x1xf32>)
    outs(%arg0 : tensor<1x1xf32>) {
       ^bb0(%arg1: f32, %arg2: f32) :
         linalg.yield %arg1 : f32
       } -> tensor<1x1xf32>
  return %0 : tensor<1x1xf32>
}
//       CHECK: #[[$MAP0:.*]] = affine_map<() -> ()>
// CHECK-LABEL: func @drop_all_loops
//       CHECK:   tensor.collapse_shape %{{.*}} []
//       CHECK:   linalg.generic
//  CHECK-SAME:     indexing_maps = [#[[$MAP0]], #[[$MAP0]]]
//  CHECK-SAME:     iterator_types = []
171
172// -----
173
// With every loop unit-extent the linalg.index values fold away and the body
// reduces to a plain yield of the input element.
#map0 = affine_map<(i, j) -> (i, j)>
#access = [#map0, #map0]
#trait = {
  iterator_types = ["parallel", "parallel"],
  indexing_maps = #access,
  library_call = "some_external_func"
}

func.func @drop_all_loops_indexed
  (%arg0 : tensor<1x1xi32>) -> tensor<1x1xi32>{
  %0 = linalg.generic #trait
     ins(%arg0 : tensor<1x1xi32>)
    outs(%arg0 : tensor<1x1xi32>) {
       ^bb0(%arg3: i32, %arg4: i32) :
         %idx0 = linalg.index 0 : index
         %idx1 = linalg.index 1 : index
         %1 = arith.addi %idx0, %idx1 : index
         %2 = arith.index_cast %1 : index to i32
         %3 = arith.addi %2, %arg3 : i32
         linalg.yield %3 : i32
       } -> tensor<1x1xi32>
  return %0 : tensor<1x1xi32>
}

// CHECK-LABEL: func @drop_all_loops_indexed
//       CHECK:   linalg.generic
//       CHECK:   ^{{.+}}(%[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32)
//       CHECK:     linalg.yield %[[ARG1]] : i32
202
203// -----
204
// The leading unit dim accessed through (0, d0) folds away, collapsing the
// 2-D input to 1-D.
#accesses = [
  affine_map<(d0) -> (0, d0)>,
  affine_map<(d0) -> (d0)>
]

#trait = {
  indexing_maps = #accesses,
  iterator_types = ["parallel"],
  library_call = "some_external_fn"
}

func.func @leading_dim_1_canonicalization(%arg0: tensor<1x5xf32>, %shape: tensor<5xf32>) -> tensor<5xf32> {
  %0 = linalg.generic #trait
     ins(%arg0 : tensor<1x5xf32>)
    outs(%shape : tensor<5xf32>) {
  ^bb0(%arg2: f32, %arg3: f32):
    linalg.yield %arg2 : f32
  } -> tensor<5xf32>
  return %0 : tensor<5xf32>
}
//   CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (d0)>

// CHECK-LABEL: func @leading_dim_1_canonicalization
//       CHECK:   tensor.collapse_shape %{{.*}} {{\[}}[0, 1]]
//       CHECK:   linalg.generic
//  CHECK-SAME:     indexing_maps = [#[[$MAP1]], #[[$MAP1]]]
//  CHECK-SAME:     iterator_types = ["parallel"]
232
233// -----
234
// The expand_shape producers are folded into the generic's broadcast-style
// indexing maps, so no reshape op remains in the output.
#accesses = [
  affine_map<(d0, d1) -> (0, d1)>,
  affine_map<(d0, d1) -> (d0, 0)>,
  affine_map<(d0, d1) -> (d0, d1)>
]

#trait = {
  indexing_maps = #accesses,
  iterator_types = ["parallel", "parallel"],
  library_call = "some_external_fn"
}

func.func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>, %shape : tensor<5x5xf32>) -> tensor<5x5xf32>
{
  %0 = tensor.expand_shape %arg0 [[0, 1]] output_shape [1, 5] : tensor<5xf32> into tensor<1x5xf32>
  %1 = tensor.expand_shape %arg1 [[0, 1]] output_shape [5, 1] : tensor<5xf32> into tensor<5x1xf32>
  %2 = linalg.generic #trait
     ins(%0, %1 : tensor<1x5xf32>, tensor<5x1xf32>)
    outs(%shape : tensor<5x5xf32>) {
       ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
         %3 = arith.addf %arg3, %arg4 : f32
         linalg.yield %3 : f32
       } -> tensor<5x5xf32>
  return %2 : tensor<5x5xf32>
}
//   CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d1)>
//   CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
//   CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @broadcast_test
//   CHECK-NOT:   linalg.tensor_{{.*}}shape
//       CHECK:   linalg.generic
//  CHECK-SAME:     indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
//  CHECK-SAME:     iterator_types = ["parallel", "parallel"]
//   CHECK-NOT:   linalg.tensor_{{.*}}shape
269
270// -----
271
// A 1x1 input accessed at (0, 0) collapses to a rank-0 tensor with a scalar
// () access map.
#accesses = [
  affine_map<(d0, d1) -> (0, 0)>,
  affine_map<(d0, d1) -> (d0, d1)>
]

#trait = {
  indexing_maps = #accesses,
  iterator_types = ["parallel", "parallel"],
  library_call = "some_external_fn"
}

func.func @broadcast_scalar(%arg0 : tensor<1x1xf32>, %shape : tensor<?x?xf32>) -> tensor<?x?xf32>
{
   %0 = linalg.generic #trait
     ins(%arg0 : tensor<1x1xf32>)
    outs(%shape : tensor<?x?xf32>) {
      ^bb0(%arg2 : f32, %arg3 : f32):
        linalg.yield %arg2 : f32
   } -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
}
//   CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> ()>
//   CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @broadcast_scalar
//  CHECK-SAME:   %[[ARG0:.*]]: tensor<1x1xf32>
//       CHECK:   %[[A:.*]] = tensor.collapse_shape %[[ARG0]] []
//  CHECK-SAME:     tensor<1x1xf32> into tensor<f32>
//       CHECK:   linalg.generic
//  CHECK-SAME:     indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
//  CHECK-SAME:     iterator_types = ["parallel", "parallel"]
//  CHECK-SAME:     %[[A]]
303
304// -----
305
// A generic whose unit result dim feeds a collapse_shape is folded so that
// the generic produces the collapsed type directly and the reshape vanishes.
#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2)>
func.func @fold_unit_dim_tensor_reshape_op(%arg0 : tensor<5xf32>) -> tensor<2x5xf32>
{
  %1 = tensor.empty() : tensor<1x2x5xf32>
  %2 = linalg.generic {indexing_maps = [#map1, #map0],
    iterator_types = ["parallel", "parallel", "parallel"]}
    ins(%arg0 : tensor<5xf32>) outs(%1 : tensor<1x2x5xf32>) {
    ^bb0(%arg1: f32, %arg2: f32):
      linalg.yield %arg1 : f32
    } -> tensor<1x2x5xf32>
  %3 = tensor.collapse_shape %2 [[0, 1], [2]]
    : tensor<1x2x5xf32> into tensor<2x5xf32>
  return %3 : tensor<2x5xf32>
}
// CHECK-LABEL: func @fold_unit_dim_tensor_reshape_op
//       CHECK:   %[[RESULT:.+]] = linalg.generic
//       CHECK:   return %[[RESULT]]
324
325// -----
326
// The unit dim of the reduction's tensor<1xf32> init is dropped everywhere,
// including through the tensor.empty and linalg.fill producers.
func.func @fold_unit_dim_for_empty_tensor(%input: tensor<1x1000xf32>) -> tensor<1xf32> {
  %cst = arith.constant 0.0 : f32
  %init = tensor.empty() : tensor<1xf32>
  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1xf32>) -> tensor<1xf32>
  %add = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
      iterator_types = ["parallel", "reduction"]}
    ins(%input : tensor<1x1000xf32>)outs(%fill : tensor<1xf32>) {
  ^bb0(%arg1: f32, %arg2: f32):
    %1823 = arith.addf %arg1, %arg2 : f32
    linalg.yield %1823 : f32
  } -> tensor<1xf32>
  return %add : tensor<1xf32>
}


//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0)>
//   CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> ()>

//       CHECK: func @fold_unit_dim_for_empty_tensor

//       CHECK: %[[INPUT_RESHAPE:.+]] = tensor.collapse_shape %{{.+}} {{\[}}[0, 1]] : tensor<1x1000xf32> into tensor<1000xf32>
//       CHECK: %[[INIT:.+]] = tensor.empty() : tensor<f32>
//       CHECK: %[[FILL:.+]] = linalg.fill ins(%cst : f32) outs(%[[INIT]] : tensor<f32>) -> tensor<f32>
//       CHECK: %[[GENERIC:.+]] = linalg.generic
//  CHECK-SAME:     indexing_maps = [#[[MAP1]], #[[MAP2]]]
//  CHECK-SAME:     iterator_types = ["reduction"]
//  CHECK-SAME:   ins(%[[INPUT_RESHAPE]] : tensor<1000xf32>)
//  CHECK-SAME:   outs(%[[FILL]] : tensor<f32>)
//       CHECK: %[[GENERIC_RESHAPE:.+]] = tensor.expand_shape %[[GENERIC]] [] output_shape [1] : tensor<f32> into tensor<1xf32>
//       CHECK: return %[[GENERIC_RESHAPE]] : tensor<1xf32>
358
359
360// -----
361
// Unit dims of extract_slice results are rank-reduced; the results are then
// expanded back to the original 7-D types.
func.func @fold_slice(
    %arg0 : tensor<1x?x?x1x?x1x1xf32>, %arg1 : tensor<1x?x?x?x?x1x1xf32>,
    %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index,
    %arg6 : index, %arg7 : index) -> (tensor<1x?x?x1x?x1x1xf32>, tensor<1x?x?x1x?x1x1xf32>) {
  %0 = tensor.extract_slice %arg0[0, %arg2, %arg3, 0, %arg4, 0, 0]
                             [1, %arg5, %arg6, 1, %arg7, 1, 1] [1, 1, 1, 1, 1, 1, 1] :
      tensor<1x?x?x1x?x1x1xf32> to tensor<1x?x?x1x?x1x1xf32>
  %1 = tensor.extract_slice %arg1[%arg2, 0, %arg3, 0, 0, %arg4, 0]
                             [1, %arg5, %arg6, 1, %arg7, 1, 1] [1, 1, 1, 1, 1, 1, 1] :
      tensor<1x?x?x?x?x1x1xf32> to tensor<1x?x?x1x?x1x1xf32>
  return %0, %1 : tensor<1x?x?x1x?x1x1xf32>, tensor<1x?x?x1x?x1x1xf32>
}
//      CHECK: func @fold_slice
// CHECK-SAME:   %[[ARG0:.+]]: tensor<1x?x?x1x?x1x1xf32>
// CHECK-SAME:   %[[ARG1:.+]]: tensor<1x?x?x?x?x1x1xf32>
//      CHECK:   %[[SLICE1:.+]] = tensor.extract_slice %[[ARG0]]
// CHECK-SAME:       to tensor<?x?x?xf32>
//      CHECK:   %[[RESULT1:.+]] = tensor.expand_shape %[[SLICE1]]
// CHECK-SAME:       {{\[\[}}0, 1], [2], [3, 4, 5, 6]] output_shape [1, %arg5, %arg6, 1, %arg7, 1, 1] : tensor<?x?x?xf32> into tensor<1x?x?x1x?x1x1xf32>
//      CHECK:   %[[SLICE2:.+]] = tensor.extract_slice %[[ARG1]]
// CHECK-SAME:       to tensor<?x?x?xf32>
//      CHECK:   %[[RESULT2:.+]] = tensor.expand_shape %[[SLICE2]]
// CHECK-SAME:       {{\[\[}}0, 1], [2], [3, 4, 5, 6]] output_shape [1, %arg5, %arg6, 1, %arg7, 1, 1] : tensor<?x?x?xf32> into tensor<1x?x?x1x?x1x1xf32>
//      CHECK:   return %[[RESULT1]], %[[RESULT2]]
386
387// -----
388
// Unit dims are dropped from a partial reduction; the 4-D generic becomes a
// 2-D parallel/reduction generic and the 1x? result is expanded back.
func.func @unit_dim_for_reduction(%arg0: tensor<1x?x1x?xf32>) -> tensor<1x?xf32> {
  %cst = arith.constant 1.000000e+00 : f32
  %c3 = arith.constant 3 : index
  %0 = tensor.dim %arg0, %c3 : tensor<1x?x1x?xf32>
  %1 = tensor.empty(%0) : tensor<1x?xf32>
  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<1x?xf32>) -> tensor<1x?xf32>
  %3 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                     affine_map<(d0, d1, d2, d3) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
    ins(%arg0 : tensor<1x?x1x?xf32>)
    outs(%2 : tensor<1x?xf32>) {
  ^bb0(%arg1: f32, %arg2: f32):
    %4 = arith.addf %arg1, %arg2 : f32
    linalg.yield %4 : f32
  } -> tensor<1x?xf32>
  return %3 : tensor<1x?xf32>
}
//  CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0)>
//  CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * s2)>
//      CHECK: func @unit_dim_for_reduction
// CHECK-SAME:   %[[ARG0:.+]]: tensor<1x?x1x?xf32>
//      CHECK: %[[C1:.+]] = arith.constant 1 : index
//      CHECK: %[[CST:.+]] = arith.constant 1.000000e+00 : f32
//      CHECK: %[[C3:.+]] = arith.constant 3 : index
//      CHECK: %[[DIM:.+]] = tensor.dim %arg0, %[[C3]] : tensor<1x?x1x?xf32>
//      CHECK: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]]
//      CHECK: %[[INIT:.+]] = tensor.empty(%{{.+}}) : tensor<?xf32>
//      CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[INIT]]
//      CHECK: %[[RESULT:.+]] = linalg.generic
// CHECK-SAME:     indexing_maps = [#[[MAP]], #[[MAP2]]]
// CHECK-SAME:     iterator_types = ["parallel", "reduction"]
// CHECK-SAME:     ins(%[[RESHAPE]] : tensor<?x?xf32>)
// CHECK-SAME:     outs(%[[FILL]] : tensor<?xf32>)
//      CHECK: %[[DIM_0:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<1x?x1x?xf32>
//      CHECK: %[[VAL_3:.*]] = affine.apply #[[MAP3]]()[%[[C1]], %[[DIM_0]], %[[C1]]]
//      CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[RESULT]] {{\[\[}}0, 1]] output_shape [1, %[[VAL_3]]] : tensor<?xf32> into tensor<1x?xf32>
//      CHECK: return %[[EXPANDED]] : tensor<1x?xf32>
428
429// -----
430
// With all reduced dims unit extent except one, the reduction collapses to a
// 1-D parallel generic that consumes the fill as a second input.
func.func @unit_dim_for_both_reduction(%arg0: tensor<1x?x1x1xf32>) -> tensor<1x1xf32> {
  %cst = arith.constant 1.000000e+00 : f32
  %c3 = arith.constant 3 : index
  %1 = tensor.empty() : tensor<1x1xf32>
  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<1x1xf32>) -> tensor<1x1xf32>
  %3 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                     affine_map<(d0, d1, d2, d3) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
    ins(%arg0 : tensor<1x?x1x1xf32>)
    outs(%2 : tensor<1x1xf32>) {
  ^bb0(%arg1: f32, %arg2: f32):
    %4 = arith.addf %arg1, %arg2 : f32
    linalg.yield %4 : f32
  } -> tensor<1x1xf32>
  return %3 : tensor<1x1xf32>
}
//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0)>
//      CHECK: func @unit_dim_for_both_reduction
// CHECK-SAME:   %[[ARG0:.+]]: tensor<1x?x1x1xf32>
//  CHECK-DAG:   %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2, 3]
//      CHECK:   %[[INIT:.+]] = tensor.empty() : tensor<1xf32>
//      CHECK:   %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[INIT]]
//      CHECK:   %[[INIT2:.+]] = tensor.empty() : tensor<1xf32>
//      CHECK:   %[[RESULT:.+]] = linalg.generic
// CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]]
// CHECK-SAME:     iterator_types = ["parallel"]
// CHECK-SAME:     ins(%[[RESHAPE]], %[[FILL]] : tensor<?xf32>, tensor<1xf32>)
// CHECK-SAME:     outs(%[[INIT2]] : tensor<1xf32>)
//      CHECK:   %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]] output_shape [1, 1]
//      CHECK:   return %[[RESULT_RESHAPE]]
462
463// -----
464
// Inner unit dims are folded pairwise ([0,1] and [2,3]); the ?x1 result is
// reconstructed with an expand_shape whose dynamic size is recomputed.
func.func @unit_dim_for_reduction_inner(%arg0: tensor<?x1x?x1xf32>) -> tensor<?x1xf32> {
  %cst = arith.constant 1.000000e+00 : f32
  %c2 = arith.constant 2 : index
  %0 = tensor.dim %arg0, %c2 : tensor<?x1x?x1xf32>
  %1 = tensor.empty(%0) : tensor<?x1xf32>
  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<?x1xf32>) -> tensor<?x1xf32>
  %3 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                     affine_map<(d0, d1, d2, d3) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
    ins(%arg0 : tensor<?x1x?x1xf32>)
    outs(%2 : tensor<?x1xf32>) {
  ^bb0(%arg1: f32, %arg2: f32):
    %4 = arith.addf %arg1, %arg2 : f32
    linalg.yield %4 : f32
  } -> tensor<?x1xf32>
  return %3 : tensor<?x1xf32>
}
//  CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0)>
//  CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
//      CHECK: func @unit_dim_for_reduction_inner
// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x1x?x1xf32>
//      CHECK: %[[C1:.*]] = arith.constant 1 : index
//      CHECK: %[[C0:.*]] = arith.constant 0 : index
//      CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f32
//      CHECK: %[[C2:.*]] = arith.constant 2 : index
//      CHECK: %[[DIM:.*]] = tensor.dim %arg0, %[[C2]] : tensor<?x1x?x1xf32>
//      CHECK: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]]
//      CHECK: %[[INIT:.+]] = tensor.empty(%{{.+}}) : tensor<?xf32>
//      CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[INIT]]
//      CHECK: %[[RESULT:.+]] = linalg.generic
// CHECK-SAME:     indexing_maps = [#[[MAP]], #[[MAP2]]]
// CHECK-SAME:     iterator_types = ["parallel", "reduction"]
// CHECK-SAME:     ins(%[[RESHAPE]] : tensor<?x?xf32>)
// CHECK-SAME:     outs(%[[FILL]] : tensor<?xf32>)
//      CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x1x?x1xf32>
//      CHECK: %[[VAL_3:.+]] = affine.apply #[[MAP3]]()[%[[DIM_0]], %[[C1]]]
//      CHECK: %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]] output_shape [%[[VAL_3]], 1] : tensor<?xf32> into tensor<?x1xf32>
//      CHECK: return %[[RESULT_RESHAPE]]
505
506// -----
507
// An extract_slice whose result is all unit dims is rank-reduced to
// tensor<f32> and expanded back.
func.func @slice_unit_dims(%arg0: tensor<1x3xf32>) -> tensor<1x1xf32> {
  %0 = tensor.extract_slice %arg0[0, 2] [1, 1] [1, 1] : tensor<1x3xf32> to tensor<1x1xf32>
  return %0 : tensor<1x1xf32>
}
// CHECK-LABEL: func @slice_unit_dims
//       CHECK:   %[[SLICE:.+]] = tensor.extract_slice
//  CHECK-SAME:     tensor<1x3xf32> to tensor<f32>
//       CHECK:   %[[RESULT:.+]] = tensor.expand_shape %[[SLICE]] [] output_shape [1, 1]
//       CHECK:   return %[[RESULT]]
517
518// -----
519
// An already rank-reduced extract_slice is reduced further by dropping the
// remaining unit dim of its result type.
func.func @rank_reduced_extract_slice(%arg0: tensor<1x1x3x1x3xf32>) -> tensor<1x3x3xf32> {
  %0 = tensor.extract_slice %arg0[0, 0, 0, 0, 0] [1, 1, 3, 1, 3] [1, 1, 1, 1, 1] : tensor<1x1x3x1x3xf32> to tensor<1x3x3xf32>
  return %0 : tensor<1x3x3xf32>
}
// CHECK-LABEL: func @rank_reduced_extract_slice
//       CHECK:   %[[SLICE:.+]] = tensor.extract_slice
//  CHECK-SAME:     tensor<1x1x3x1x3xf32> to tensor<3x3xf32>
//       CHECK:   %[[RESULT:.+]] = tensor.expand_shape %[[SLICE]] {{\[}}[0, 1], [2]] output_shape [1, 3, 3]
//       CHECK:   return %[[RESULT]]
529
530// -----
531
// The all-unit-dims source of an insert_slice is collapsed to tensor<f32>
// before insertion.
func.func @insert_slice_unit_dims(%arg0: tensor<1x3xf32>, %arg1: tensor<1x1xf32>) -> tensor<1x3xf32> {
  %0 = tensor.insert_slice %arg1 into %arg0[0, 2] [1, 1] [1, 1] : tensor<1x1xf32> into tensor<1x3xf32>
  return %0 : tensor<1x3xf32>
}
// CHECK-LABEL: func @insert_slice_unit_dims
//       CHECK:   %[[RESHAPE:.+]] = tensor.collapse_shape %{{.+}} []
//       CHECK:   %[[RESULT:.+]] = tensor.insert_slice %[[RESHAPE]]
//  CHECK-SAME:     tensor<f32> into tensor<1x3xf32>
//       CHECK:   return %[[RESULT]]
541
542// -----
543
// Remaining unit dims of an already rank-reduced insert_slice source are
// collapsed away before insertion.
func.func @rank_reduced_insert_slice(%arg0: tensor<1x1x3x1x3xf32>, %arg1: tensor<1x3x3xf32>) -> tensor<1x1x3x1x3xf32> {
  %0 = tensor.insert_slice %arg1 into %arg0[0, 0, 0, 0, 0] [1, 1, 3, 1, 3] [1, 1, 1, 1, 1] : tensor<1x3x3xf32> into tensor<1x1x3x1x3xf32>
  return %0 : tensor<1x1x3x1x3xf32>
}
// CHECK-LABEL: func @rank_reduced_insert_slice
//       CHECK:   %[[RESHAPE:.+]] = tensor.collapse_shape %{{.+}} {{\[}}[0, 1], [2]]
//       CHECK:   %[[RESULT:.+]] = tensor.insert_slice %[[RESHAPE]]
//  CHECK-SAME:     tensor<3x3xf32> into tensor<1x1x3x1x3xf32>
//       CHECK:   return %[[RESULT]]
553
554// -----
555
// Memref variant of the unit-trip-loop folding above; reshapes are emitted as
// memref.collapse_shape and the generic drops to three parallel loops.
#accesses = [
  affine_map<(i, j, k, l, m) -> (i, k, m)>,
  affine_map<(i, j, k, l, m) -> ()>,
  affine_map<(i, j, k, l, m) -> (i, k, j, l, m)>
]

#trait = {
  iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
  indexing_maps = #accesses,
  library_call = "some_external_func"
}

func.func @drop_one_trip_loops(%arg0 : memref<?x1x?xf32>, %arg1 : f32, %shape: memref<?x1x?x1x?xf32>) -> memref<?x1x?x1x?xf32> {
  linalg.generic #trait
     ins(%arg0, %arg1 : memref<?x1x?xf32>, f32)
    outs(%shape : memref<?x1x?x1x?xf32>) {
       ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32) :
         linalg.yield %arg3 : f32
       }
  return %shape : memref<?x1x?x1x?xf32>
}
//   CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
//   CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> ()>
//   CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @drop_one_trip_loops
//       CHECK: memref.collapse_shape %{{.*}} {{\[}}[0, 1], [2]]
//       CHECK: linalg.generic
//  CHECK-SAME:   indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP3]]]
//  CHECK-SAME:   iterator_types = ["parallel", "parallel", "parallel"]
585
586// -----
587
// Memref variant: verifies linalg.index remapping when unit-trip loops are
// dropped from a buffer-based generic.
#accesses = [
  affine_map<(i, j, k, l, m) -> (i, k, m)>,
  affine_map<(i, j, k, l, m) -> (i, k, j, l, m)>
]

#trait = {
  iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
  indexing_maps = #accesses,
  library_call = "some_external_func"
}

func.func @drop_one_trip_loops_indexed
  (%arg0 : memref<?x1x?xi32>, %shape: memref<?x1x?x1x?xi32>) -> memref<?x1x?x1x?xi32>
{
  linalg.generic #trait
     ins(%arg0 : memref<?x1x?xi32>)
    outs(%shape: memref<?x1x?x1x?xi32>) {
       ^bb0(%arg6 : i32, %arg7 : i32) :
         %idx0 = linalg.index 0 : index
         %idx1 = linalg.index 1 : index
         %idx2 = linalg.index 2 : index
         %idx3 = linalg.index 3 : index
         %idx4 = linalg.index 4 : index
         %1 = arith.addi %idx0, %idx1 : index
         %2 = arith.subi %1, %idx2 : index
         %3 = arith.subi %2, %idx3 : index
         %4 = arith.addi %3, %idx4 : index
         %5 = arith.index_cast %4 : index to i32
         %6 = arith.addi %5, %arg6 : i32
         linalg.yield %6 : i32
       }
  return %shape : memref<?x1x?x1x?xi32>
}
// The subtractions disappear because the access map of the output memref maps
// its unit dimensions 1 and 3 to the index dimensions 2 and 3.
// CHECK-LABEL: func @drop_one_trip_loops_indexed
//       CHECK:   linalg.generic
//       CHECK:   ^{{.+}}(
//  CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: i32, %{{.*}}: i32)
//       CHECK:     %[[IDX0:.+]] = linalg.index 0 : index
//       CHECK:     %[[IDX1:.+]] = linalg.index 1 : index
//       CHECK:     %[[IDX2:.+]] = linalg.index 2 : index
//       CHECK:     %[[T3:.+]] = arith.addi %[[IDX0]], %[[IDX1]]
//       CHECK:     %[[T4:.+]] = arith.addi %[[T3]], %[[IDX2]]
//       CHECK:     %[[T5:.+]] = arith.index_cast %[[T4]] : index to i32
//       CHECK:     %[[T6:.+]] = arith.addi %[[T5]], %[[ARG4]] : i32
//       CHECK:     linalg.yield %[[T6]] : i32
635
636// -----
637
// Memref variant: a generic over a 1x1 buffer collapses to a rank-0 generic
// with empty iterator_types.
#map0 = affine_map<(i, j) -> (i, j)>
#access = [#map0, #map0]
#trait = {
  iterator_types = ["parallel", "parallel"],
  indexing_maps = #access,
  library_call = "some_external_func"
}

func.func @drop_all_loops(%arg0 : memref<1x1xf32>) -> memref<1x1xf32>
{
  linalg.generic #trait
     ins(%arg0 : memref<1x1xf32>)
    outs(%arg0 : memref<1x1xf32>) {
       ^bb0(%arg1: f32, %arg2: f32) :
         linalg.yield %arg1 : f32
       }
  return %arg0 : memref<1x1xf32>
}
//       CHECK: #[[$MAP0:.*]] = affine_map<() -> ()>
// CHECK-LABEL: func @drop_all_loops
//       CHECK:   memref.collapse_shape %{{.*}} []
//       CHECK:   linalg.generic
//  CHECK-SAME:     indexing_maps = [#[[$MAP0]], #[[$MAP0]]]
//  CHECK-SAME:     iterator_types = []
662
663// -----
664
// Memref variant: with every loop unit-extent the linalg.index values fold
// away and the body reduces to a plain yield of the input element.
#map0 = affine_map<(i, j) -> (i, j)>
#access = [#map0, #map0]
#trait = {
  iterator_types = ["parallel", "parallel"],
  indexing_maps = #access,
  library_call = "some_external_func"
}

func.func @drop_all_loops_indexed
  (%arg0 : memref<1x1xi32>) -> memref<1x1xi32>{
  linalg.generic #trait
     ins(%arg0 : memref<1x1xi32>)
    outs(%arg0 : memref<1x1xi32>) {
       ^bb0(%arg3: i32, %arg4: i32) :
         %idx0 = linalg.index 0 : index
         %idx1 = linalg.index 1 : index
         %1 = arith.addi %idx0, %idx1 : index
         %2 = arith.index_cast %1 : index to i32
         %3 = arith.addi %2, %arg3 : i32
         linalg.yield %3 : i32
       }
  return %arg0 : memref<1x1xi32>
}

// CHECK-LABEL: func @drop_all_loops_indexed
//       CHECK:   linalg.generic
//       CHECK:   ^{{.+}}(%[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32)
//       CHECK:     linalg.yield %[[ARG1]] : i32
693
694// -----
695
// Memref variant: the leading unit dim accessed through (0, d0) folds away,
// collapsing the 2-D input buffer to 1-D.
#accesses = [
  affine_map<(d0) -> (0, d0)>,
  affine_map<(d0) -> (d0)>
]

#trait = {
  indexing_maps = #accesses,
  iterator_types = ["parallel"],
  library_call = "some_external_fn"
}

func.func @leading_dim_1_canonicalization(%arg0: memref<1x5xf32>, %shape: memref<5xf32>) -> memref<5xf32> {
  linalg.generic #trait
     ins(%arg0 : memref<1x5xf32>)
    outs(%shape : memref<5xf32>) {
  ^bb0(%arg2: f32, %arg3: f32):
    linalg.yield %arg2 : f32
  }
  return %shape : memref<5xf32>
}
//   CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (d0)>

// CHECK-LABEL: func @leading_dim_1_canonicalization
//       CHECK:   memref.collapse_shape %{{.*}} {{\[}}[0, 1]]
//       CHECK:   linalg.generic
//  CHECK-SAME:     indexing_maps = [#[[$MAP1]], #[[$MAP1]]]
//  CHECK-SAME:     iterator_types = ["parallel"]
723
724// -----
725
// Memref variant: expand_shape producers fold into the generic's
// broadcast-style indexing maps, leaving no reshape op.
#accesses = [
  affine_map<(d0, d1) -> (0, d1)>,
  affine_map<(d0, d1) -> (d0, 0)>,
  affine_map<(d0, d1) -> (d0, d1)>
]

#trait = {
  indexing_maps = #accesses,
  iterator_types = ["parallel", "parallel"],
  library_call = "some_external_fn"
}

func.func @broadcast_test(%arg0 : memref<5xf32>, %arg1 : memref<5xf32>, %shape : memref<5x5xf32>) -> memref<5x5xf32>
{
  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [1, 5] : memref<5xf32> into memref<1x5xf32>
  %1 = memref.expand_shape %arg1 [[0, 1]] output_shape [5, 1] : memref<5xf32> into memref<5x1xf32>
  linalg.generic #trait
     ins(%0, %1 : memref<1x5xf32>, memref<5x1xf32>)
    outs(%shape : memref<5x5xf32>) {
       ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
         %3 = arith.addf %arg3, %arg4 : f32
         linalg.yield %3 : f32
       }
  return %shape : memref<5x5xf32>
}
//   CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d1)>
//   CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
//   CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @broadcast_test
//   CHECK-NOT:   linalg.memref_{{.*}}shape
//       CHECK:   linalg.generic
//  CHECK-SAME:     indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
//  CHECK-SAME:     iterator_types = ["parallel", "parallel"]
//   CHECK-NOT:   linalg.memref_{{.*}}shape
760
761// -----
762
// A 1x1 input accessed with the constant map (0, 0) is a scalar broadcast:
// the pass collapses it to a rank-0 memref<f32> and rewrites its indexing
// map to the empty (scalar) map.
#accesses = [
  affine_map<(d0, d1) -> (0, 0)>,
  affine_map<(d0, d1) -> (d0, d1)>
]

#trait = {
  indexing_maps = #accesses,
  iterator_types = ["parallel", "parallel"],
  library_call = "some_external_fn"
}

func.func @broadcast_scalar(%arg0 : memref<1x1xf32>, %shape : memref<?x?xf32>) -> memref<?x?xf32>
{
   linalg.generic #trait
     ins(%arg0 : memref<1x1xf32>)
    outs(%shape : memref<?x?xf32>) {
      ^bb0(%arg2 : f32, %arg3 : f32):
        linalg.yield %arg2 : f32
   }
   return %shape : memref<?x?xf32>
}
//   CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> ()>
//   CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @broadcast_scalar
//  CHECK-SAME:   %[[ARG0:.*]]: memref<1x1xf32>
//       CHECK:   %[[A:.*]] = memref.collapse_shape %[[ARG0]] []
//  CHECK-SAME:     memref<1x1xf32> into memref<f32>
//       CHECK:   linalg.generic
//  CHECK-SAME:     indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
//  CHECK-SAME:     iterator_types = ["parallel", "parallel"]
//  CHECK-SAME:     %[[A]]
794
795// -----
796
#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2)>
// The alloc feeding the generic's output has a leading unit dimension. The
// pass drops it by inserting a memref.collapse_shape on the outs operand,
// while the pre-existing collapse_shape consumer of the alloc still produces
// the function result.
func.func @fold_unit_dim_memref_reshape_op(%arg0 : memref<5xf32>) -> memref<2x5xf32>
{
  %1 = memref.alloc() : memref<1x2x5xf32>
  linalg.generic {indexing_maps = [#map1, #map0],
    iterator_types = ["parallel", "parallel", "parallel"]}
    ins(%arg0 : memref<5xf32>) outs(%1 : memref<1x2x5xf32>) {
    ^bb0(%arg1: f32, %arg2: f32):
      linalg.yield %arg1 : f32
    }
  %3 = memref.collapse_shape %1 [[0, 1], [2]]
    : memref<1x2x5xf32> into memref<2x5xf32>
  return %3 : memref<2x5xf32>
}
// CHECK-LABEL: func @fold_unit_dim_memref_reshape_op
//       CHECK:   %[[ALLOC:.*]] = memref.alloc() : memref<1x2x5xf32>
//       CHECK:   %[[OUT:.*]] = memref.collapse_shape %[[ALLOC]]
//       CHECK:   linalg.generic
//       CHECK-SAME:   outs(%[[OUT:.*]] :
//       CHECK:   %[[RESULT:.*]] = memref.collapse_shape %[[ALLOC]]
//       CHECK:   return %[[RESULT]]
819
820// -----
821
// A (1 x 1000) -> (1) reduction: both unit dims are dropped, so the sole
// parallel loop disappears and the op becomes a 1-D reduction into a rank-0
// output memref. The original alloc is still returned unchanged.
func.func @fold_unit_dim_for_init_memref(%input: memref<1x1000xf32>) -> memref<1xf32> {
  %cst = arith.constant 0.0 : f32
  %init = memref.alloc() : memref<1xf32>
  linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
      iterator_types = ["parallel", "reduction"]}
    ins(%input : memref<1x1000xf32>)outs(%init : memref<1xf32>) {
  ^bb0(%arg1: f32, %arg2: f32):
    %1823 = arith.addf %arg1, %arg2 : f32
    linalg.yield %1823 : f32
  }
  return %init : memref<1xf32>
}


//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0)>
//   CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> ()>

//       CHECK: func @fold_unit_dim_for_init_memref
//       CHECK: %[[INIT:.+]] = memref.alloc() : memref<1xf32>
//       CHECK: %[[INPUT_RESHAPE:.+]] = memref.collapse_shape %{{.+}} {{\[}}[0, 1]] : memref<1x1000xf32> into memref<1000xf32>
//       CHECK: %[[INIT_RESHAPE:.+]] = memref.collapse_shape %[[INIT]] [] : memref<1xf32> into memref<f32>
//       CHECK: linalg.generic
//  CHECK-SAME:     indexing_maps = [#[[MAP1]], #[[MAP2]]]
//  CHECK-SAME:     iterator_types = ["reduction"]
//  CHECK-SAME:   ins(%[[INPUT_RESHAPE]] : memref<1000xf32>)
//  CHECK-SAME:   outs(%[[INIT_RESHAPE]] : memref<f32>)
//       CHECK: return %[[INIT:.+]] : memref<1xf32>
850
851
852// -----
// Test that nothing changes and no assertions are fired for memrefs with affine
// maps while still changing the other operations.
// The strided input cannot be reshaped, so it is left as-is and its indexing
// map gets a constant 0 for the dropped unit dimension; the identity-layout
// output is still collapsed.

#accesses = [
  affine_map<(i, j, k, l, m) -> (i, k, m)>,
  affine_map<(i, j, k, l, m) -> ()>,
  affine_map<(i, j, k, l, m) -> (i, k, j, l, m)>
]

#trait = {
  iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
  indexing_maps = #accesses,
  library_call = "some_external_func"
}

func.func @input_stays_same(%arg0 : memref<?x1x?xf32, strided<[?, 1, 1]>>, %arg1 : f32, %shape: memref<?x1x?x1x?xf32>) -> memref<?x1x?x1x?xf32> {
  linalg.generic #trait
     ins(%arg0, %arg1 : memref<?x1x?xf32, strided<[?, 1, 1]>>, f32)
    outs(%shape : memref<?x1x?x1x?xf32>) {
       ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32) :
         linalg.yield %arg3 : f32
       }
  return %shape : memref<?x1x?x1x?xf32>
}

// CHECK-DAG:     #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, 0, d2)>
// CHECK-DAG:     #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> ()>
// CHECK-DAG:     #[[MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK:     func @input_stays_same(
// CHECK-SAME:  %[[ARG0:.*]]: memref<?x1x?xf32, strided<[?, 1, 1]>>,
// CHECK-SAME:  %[[ARG1:.*]]: f32, %[[ARG2:.*]]: memref<?x1x?x1x?xf32>)
// CHECK-SAME:  -> memref<?x1x?x1x?xf32> {
// CHECK:      %[[OUT:.*]] = memref.collapse_shape %[[ARG2]] {{\[}}[0, 1], [2, 3], [4]]
// CHECK-SAME:   : memref<?x1x?x1x?xf32> into memref<?x?x?xf32>
// CHECK:      linalg.generic
// CHECK-SAME:   {indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]],
// CHECK-SAME:   iterator_types = ["parallel", "parallel", "parallel"]}
// CHECK-SAME:   ins(%[[ARG0]], %[[ARG1]] : memref<?x1x?xf32, strided<[?, 1, 1]>>, f32)
// CHECK-SAME:   outs(%[[OUT]] : memref<?x?x?xf32>) {
// CHECK:      ^bb0(%{{.*}}: f32, %[[ARG:.*]]: f32, %{{.*}}: f32):
// CHECK:       linalg.yield %[[ARG]] : f32
// CHECK:      }
// CHECK:      return %[[ARG2]] : memref<?x1x?x1x?xf32>
896
897// -----
898
// Negative test for case with tensor encoding.
// The CSR-encoded operand must block the pattern: the generic immediately
// following tensor.empty proves nothing was rewritten.
#matvec = {
  indexing_maps = [
    affine_map<(i,j) -> (i,j)>, // A
    affine_map<(i,j) -> (j)>,   // b
    affine_map<(i,j) -> (i)>    // x (out)
  ],
  iterator_types = ["parallel", "reduction"]
}

#CSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>

func.func @sparse_case(%arg0: tensor<8x8xf32, #CSR>, %arg1: tensor<8xf32>) -> tensor<8xf32> {
    %0 = tensor.empty() : tensor<8xf32>
    %1 = linalg.generic #matvec
      ins(%arg0, %arg1: tensor<8x8xf32, #CSR>, tensor<8xf32>)
      outs(%0: tensor<8xf32>) {
      ^bb(%a: f32, %b: f32, %x: f32):
        %m = arith.mulf %a, %b : f32
        %add = arith.addf %x, %m : f32
        linalg.yield %add : f32
    } -> tensor<8xf32>
    return %1: tensor<8xf32>
}

// CHECK-LABEL: func @sparse_case
//  CHECK-NEXT:   tensor.empty
//  CHECK-NEXT:   linalg.generic
927
928// -----
929
// The tensor<1x1xf32> slice source inside the scf.forall is rank-reduced to
// tensor<f32> in the rewritten tensor.parallel_insert_slice.
func.func @reduce_dispatch_0() -> tensor<4x2xf32> {
  %c2 = arith.constant 2 : index
  %c4 = arith.constant 4 : index
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<4x2xf32>
  %res = scf.forall (%arg0, %arg1) in (%c4, %c2) shared_outs(%o = %0) -> (tensor<4x2xf32>) {
    %1 = tensor.empty() : tensor<1x1xf32>
    %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<1x1xf32>) -> tensor<1x1xf32>
    scf.forall.in_parallel {
      //      CHECK: tensor.parallel_insert_slice %{{[0-9a-z]*}} into %{{[0-9a-z]*}}
      // CHECK-SAME: [%{{.*}}, %{{.*}}] [1, 1] [1, 1] : tensor<f32> into tensor<4x2xf32>
      tensor.parallel_insert_slice %2 into %o[%arg0, %arg1] [1, 1] [1, 1] :
        tensor<1x1xf32> into tensor<4x2xf32>
    }
  }
  return %res: tensor<4x2xf32>
}
947
948// -----
949
// A 1x1 memref in memory space 3: the non-default memory space must be
// preserved through the rewrite. The default pipeline uses collapse_shape;
// with use-rank-reducing-slices a memref.subview is emitted instead.
#map0 = affine_map<(i, j) -> (i, j)>
#access = [#map0, #map0]
#trait = {
  iterator_types = ["parallel", "parallel"],
  indexing_maps = #access,
  library_call = "some_external_func"
}

func.func @drop_all_loops(%arg0 : memref<1x1xf32, 3>) -> memref<1x1xf32, 3>
{
  linalg.generic #trait
     ins(%arg0 : memref<1x1xf32, 3>)
    outs(%arg0 : memref<1x1xf32, 3>) {
       ^bb0(%arg1: f32, %arg2: f32) :
         linalg.yield %arg1 : f32
       }
  return %arg0 : memref<1x1xf32, 3>
}

// CHECK-LABEL: func @drop_all_loops
//       CHECK:   memref.collapse_shape
//  CHECK-SAME:     [] : memref<1x1xf32, 3> into memref<f32, 3>
//       CHECK:   linalg.generic{{.*}}memref<f32, 3>

// CHECK-SLICES-LABEL: func @drop_all_loops
//       CHECK-SLICES:   memref.subview %{{.*}}[0, 0] [1, 1] [1, 1] : memref<1x1xf32, 3> to memref<f32, strided<[]>, 3>
//       CHECK-SLICES:   linalg.generic{{.*}}memref<f32, strided<[]>, 3>
977
978// -----
979
// Unit dims of a tensor.pad whose low/high padding is 0 (including the %c0
// dynamic-but-zero values) are dropped; dims that are actually padded (j and
// the last dim) survive. Result is reconstructed via expand_shape, or via
// insert_slice in the rank-reducing-slices mode.
func.func @drop_unit_pad_dims(%arg0: tensor<1x1x3x1x1xf32>) -> tensor<1x2x3x1x3xf32>
{
  %c0 = arith.constant 0 : index
  %cst0 = arith.constant 0.0 : f32
  %0 = tensor.pad %arg0 low[0, 1, 0, %c0, 0] high[0, 0, 0, %c0, 2] {
    ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index):
      tensor.yield %cst0 : f32
  } : tensor<1x1x3x1x1xf32> to tensor<1x2x3x1x3xf32>
  return %0 : tensor<1x2x3x1x3xf32>
}

// CHECK-LABEL: func @drop_unit_pad_dims
//       CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape
//  CHECK-SAME:     {{\[}}[0, 1], [2, 3], [4]{{\]}} : tensor<1x1x3x1x1xf32> into tensor<1x3x1xf32>
//       CHECK:   %[[PADDED:.+]] = tensor.pad %[[COLLAPSE]] low[1, 0, 0] high[0, 0, 2]
//       CHECK:   } : tensor<1x3x1xf32> to tensor<2x3x3xf32>
//       CHECK:   tensor.expand_shape %[[PADDED]]
//  CHECK-SAME:     {{\[}}[0, 1], [2, 3], [4]{{\]}} output_shape [1, 2, 3, 1, 3] : tensor<2x3x3xf32> into tensor<1x2x3x1x3xf32>

// CHECK-SLICES-LABEL: func @drop_unit_pad_dims
//       CHECK-SLICES:   %[[EXTRACT:.+]] = tensor.extract_slice
//  CHECK-SLICES-SAME:     [0, 0, 0, 0, 0] [1, 1, 3, 1, 1] [1, 1, 1, 1, 1] : tensor<1x1x3x1x1xf32> to tensor<1x3x1xf32>
//       CHECK-SLICES:   %[[PADDED:.+]] = tensor.pad %[[EXTRACT]] low[1, 0, 0] high[0, 0, 2]
//       CHECK-SLICES:   } : tensor<1x3x1xf32> to tensor<2x3x3xf32>
//       CHECK-SLICES:   tensor.insert_slice %[[PADDED]]
//  CHECK-SLICES-SAME:     [0, 0, 0, 0, 0] [1, 2, 3, 1, 3] [1, 1, 1, 1, 1] : tensor<2x3x3xf32> into tensor<1x2x3x1x3xf32>
1006
1007// -----
1008
// The static unit dim of the pad is dropped even though the padded dim is
// dynamic; the expanded result's dynamic extent is recomputed from
// tensor.dim via the affine maps (dim * 1, then + 11 for low 5 + high 6).
func.func @drop_unit_pad_dynamic_dims(%arg0: tensor<1x?xf32>) -> tensor<1x?xf32>
{
  %c0 = arith.constant 0 : index
  %cst0 = arith.constant 0.0 : f32
  %0 = tensor.pad %arg0 low[0, 5] high[0, 6] {
    ^bb0(%arg1: index, %arg2: index):
      tensor.yield %cst0 : f32
  } : tensor<1x?xf32> to tensor<1x?xf32>
  return %0 : tensor<1x?xf32>
}

// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 + 11)>
// CHECK-LABEL: func @drop_unit_pad_dynamic_dims
//       CHECK:   %[[C1:.*]] = arith.constant 1 : index
//       CHECK:   %[[CST:.*]] = arith.constant 0.000000e+00 : f32
//       CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape
//  CHECK-SAME:     {{\[}}[0, 1]{{\]}} : tensor<1x?xf32> into tensor<?xf32>
//       CHECK:   %[[PADDED:.+]] = tensor.pad %[[COLLAPSE]] low[5] high[6]
//       CHECK:   } : tensor<?xf32> to tensor<?xf32>
//       CHECK:   %[[DIM:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?xf32>
//       CHECK:   %[[VAL_0:.+]] = affine.apply #[[$MAP]]()[%[[C1]], %[[DIM]]]
//       CHECK:   %[[VAL_1:.+]] = affine.apply #[[$MAP1]]()[%[[VAL_0]]]
//       CHECK:   %[[EXPANDED:.+]] = tensor.expand_shape %[[PADDED]] {{\[\[}}0, 1]] output_shape [1, %[[VAL_1]]] : tensor<?xf32> into tensor<1x?xf32>

// CHECK-SLICES: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 11)>

// CHECK-SLICES-LABEL: func @drop_unit_pad_dynamic_dims
//  CHECK-SLICES-SAME:   %[[ARG0:[A-Za-z0-9]+]]: tensor<1x?xf32>
//       CHECK-SLICES:   %[[DIM:.+]] = tensor.dim %[[ARG0]], %c1
//       CHECK-SLICES:   %[[EXTRACT:.+]] = tensor.extract_slice
//  CHECK-SLICES-SAME:     [0, 0] [1, %[[DIM]]] [1, 1] : tensor<1x?xf32> to tensor<?xf32>
//       CHECK-SLICES:   %[[PADDED:.+]] = tensor.pad %[[EXTRACT]] low[5] high[6]
//       CHECK-SLICES:   } : tensor<?xf32> to tensor<?xf32>
//       CHECK-SLICES:   %[[PADDED_DIM:.+]] = affine.apply #[[$MAP]]()[%[[DIM]]]
//       CHECK-SLICES:   %[[EMPTY:.+]] = tensor.empty(%[[PADDED_DIM]]) : tensor<1x?xf32>
//       CHECK-SLICES:   tensor.insert_slice %[[PADDED]] into %[[EMPTY]]
//  CHECK-SLICES-SAME:     [0, 0] [1, %[[PADDED_DIM]]] [1, 1] : tensor<?xf32> into tensor<1x?xf32>
1047
1048// -----
1049
// Negative test: the pad body yields a value computed from an index and a
// function argument rather than a constant, so unit dims must not be
// dropped and the tensor.pad is left untouched in both modes.
func.func @do_not_drop_non_constant_padding(%arg0: tensor<1x1x3x1x1xf32>, %pad: f32) -> tensor<1x2x3x1x3xf32>
{
  %c0 = arith.constant 0 : index
  %0 = tensor.pad %arg0 low[0, 1, 0, %c0, 0] high[0, 0, 0, %c0, 2] {
    ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index):
      %0 = arith.index_cast %arg3 : index to i64
      %1 = arith.sitofp %0 : i64 to f32
      %add = arith.addf %pad, %1 : f32
      tensor.yield %add : f32
  } : tensor<1x1x3x1x1xf32> to tensor<1x2x3x1x3xf32>
  return %0 : tensor<1x2x3x1x3xf32>
}

// CHECK-LABEL: func @do_not_drop_non_constant_padding
//       CHECK:   tensor.pad %{{.*}} low[0, 1, 0, %c0, 0] high[0, 0, 0, %c0, 2]
//       CHECK:   } : tensor<1x1x3x1x1xf32> to tensor<1x2x3x1x3xf32>

// CHECK-SLICES-LABEL: func @do_not_drop_non_constant_padding
//       CHECK-SLICES:   tensor.pad %{{.*}} low[0, 1, 0, %c0, 0] high[0, 0, 0, %c0, 2]
//       CHECK-SLICES:   } : tensor<1x1x3x1x1xf32> to tensor<1x2x3x1x3xf32>
1070
1071// -----
1072
// Low/high amounts given as arith.constant SSA values (not literals) are
// still recognized: the zero-padded unit dims are dropped while the dim
// padded by %c1 is kept.
func.func @drop_known_unit_constant_low_high(%arg0: tensor<1x383x128xf32>) -> tensor<1x384x128xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %cst = arith.constant 0.000000e+00 : f32
  %padded = tensor.pad %arg0 low[%c0, %c1, %c0] high[%c0, %c0, %c0] {
  ^bb0(%arg1: index, %arg2: index, %arg3: index):
    tensor.yield %cst : f32
  } : tensor<1x383x128xf32> to tensor<1x384x128xf32>
  return %padded : tensor<1x384x128xf32>
}
// CHECK-LABEL: func @drop_known_unit_constant_low_high
//       CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape
//  CHECK-SAME:     {{\[}}[0, 1], [2]] : tensor<1x383x128xf32> into tensor<383x128xf32>
//       CHECK:   %[[PADDED:.+]] = tensor.pad %[[COLLAPSE]] low[1, 0] high[0, 0]
//       CHECK:   } : tensor<383x128xf32> to tensor<384x128xf32>
//       CHECK:   tensor.expand_shape %[[PADDED]]
//  CHECK-SAME:     {{\[}}[0, 1], [2]] output_shape [1, 384, 128] : tensor<384x128xf32> into tensor<1x384x128xf32>
1090
1091// -----
1092
// CHECK: #[[$MAP0:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (0, d0)>
// CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> ()>

// CHECK-LABEL: func @drop_unit_dim_corresponding_to_dynamic_dim
// CHECK-SAME:                    %[[ARG0:.*]]: tensor<1x?x?x1xf32>,
// CHECK-SAME:                    %[[ARG1:.*]]: index) -> tensor<?x1x61x1xf32> {
// CHECK:           %[[VAL_0:.*]] = arith.constant 0 : index
// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
// CHECK:           %[[VAL_2:.*]] = arith.constant dense<1.000000e+00> : tensor<f32>
// CHECK:           %[[VAL_3:.*]] = tensor.collapse_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] : tensor<1x?x?x1xf32> into tensor<?x?xf32>
// CHECK:           %[[VAL_4:.*]] = tensor.empty(%[[ARG1]]) : tensor<?x61xf32>
// CHECK:           %[[VAL_5:.*]] = affine.apply #[[$MAP0]](){{\[}}%[[ARG1]], %[[VAL_1]]]
// CHECK:           %[[VAL_6:.*]] = tensor.empty(%[[VAL_5]]) : tensor<?x61xf32>
// CHECK:           %[[VAL_7:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel"]} ins(%[[VAL_3]], %[[VAL_2]], %[[VAL_4]] : tensor<?x?xf32>, tensor<f32>, tensor<?x61xf32>) outs(%[[VAL_6]] : tensor<?x61xf32>) {
// CHECK:           ^bb0(%[[VAL_8:.*]]: f32, %[[VAL_9:.*]]: f32, %[[VAL_10:.*]]: f32, %[[VAL_11:.*]]: f32):
// CHECK:             %[[VAL_12:.*]] = arith.mulf %[[VAL_8]], %[[VAL_9]] : f32
// CHECK:             %[[VAL_13:.*]] = arith.addf %[[VAL_10]], %[[VAL_12]] : f32
// CHECK:             linalg.yield %[[VAL_13]] : f32
// CHECK:           } -> tensor<?x61xf32>
// CHECK:           %[[VAL_14:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0, 1], [2, 3]] output_shape {{\[}}%[[VAL_0]], 1, 61, 1] : tensor<?x61xf32> into tensor<?x1x61x1xf32>
// CHECK:           return %[[VAL_14]] : tensor<?x1x61x1xf32>
// CHECK:         }

// Convolution-style generic with 7 loops. Unit extents are dropped even
// when the corresponding result dimension is dynamic (%arg1): the 1x1x1x1
// constant becomes a rank-0 tensor, operands collapse to rank 2, and the
// result is rebuilt with tensor.expand_shape using the dynamic extent.
#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
module {
  func.func @drop_unit_dim_corresponding_to_dynamic_dim(%arg0: tensor<1x?x?x1xf32>, %arg1: index) -> tensor<?x1x61x1xf32> {
    %cst = arith.constant dense<1.000000e+00> : tensor<1x1x1x1xf32>
    %0 = tensor.empty(%arg1) : tensor<?x1x61x1xf32>
    %1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %cst : tensor<1x?x?x1xf32>, tensor<1x1x1x1xf32>) outs(%0 : tensor<?x1x61x1xf32>) {
    ^bb0(%in: f32, %in_0: f32, %out: f32):
      %2 = arith.mulf %in, %in_0 : f32
      %3 = arith.addf %out, %2 : f32
      linalg.yield %3 : f32
    } -> tensor<?x1x61x1xf32>
    return %1 : tensor<?x1x61x1xf32>
  }
}
1133
1134// -----
1135
// Negative test: the tensor.dim index (7) is out of bounds for the rank-3
// source, so the tensor.empty/tensor.dim pair must not be folded and the IR
// stays as written.
func.func @no_fold_empty_tensor_dim_out_of_bounds(%arg0: tensor<1x?x10xf32>) -> tensor<1x?xf32> {
  %cst = arith.constant 1.000000e+00 : f32
  %cst7 = arith.constant 7 : index
  %dim = tensor.dim %arg0, %cst7 : tensor<1x?x10xf32>
  %0 = tensor.empty(%dim) : tensor<1x?xf32>
  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x?xf32>) -> tensor<1x?xf32>
  return %1 : tensor<1x?xf32>
}
// CHECK-LABEL: func.func @no_fold_empty_tensor_dim_out_of_bounds
//  CHECK-SAME:                 %[[ARG0:.*]]: tensor<1x?x10xf32>) -> tensor<1x?xf32> {
//       CHECK:   %[[CST:.*]] = arith.constant 1.000000e+00 : f32
//       CHECK:   %[[C7:.*]] = arith.constant 7
//       CHECK:   %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C7]] : tensor<1x?x10xf32>
//       CHECK:   %[[VAL_0:.*]] = tensor.empty(%[[DIM]]) : tensor<1x?xf32>
//       CHECK:   %[[VAL_1:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[VAL_0]] : tensor<1x?xf32>) -> tensor<1x?xf32>
//       CHECK:   return %[[VAL_1]] : tensor<1x?xf32>
//       CHECK: }
1153
1154// -----
1155
// The dim index is an in-bounds index.constant (2) selecting the static
// extent 10, so the dynamic tensor.empty folds to a fully static
// tensor<1x10xf32> followed by a tensor.cast back to the dynamic type.
func.func @fold_empty_tensor_dim_op(%arg0: tensor<1x?x10xf32>) -> tensor<1x?xf32> {
  %cst = arith.constant 1.000000e+00 : f32
  %cst2 = index.constant 2
  %dim10 = tensor.dim %arg0, %cst2 : tensor<1x?x10xf32>
  %0 = tensor.empty(%dim10) : tensor<1x?xf32>
  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x?xf32>) -> tensor<1x?xf32>
  return %1 : tensor<1x?xf32>
}
// CHECK-LABEL: func.func @fold_empty_tensor_dim_op
//  CHECK-SAME:                 %[[ARG0:.*]]: tensor<1x?x10xf32>) -> tensor<1x?xf32> {
//       CHECK:   %[[CST:.*]] = arith.constant 1.000000e+00 : f32
//       CHECK:   %[[VAL_0:.*]] = tensor.empty() : tensor<1x10xf32>
//       CHECK:   %[[VAL_1:.*]] = tensor.cast %[[VAL_0]] : tensor<1x10xf32> to tensor<1x?xf32>
//       CHECK:   %[[VAL_2:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[VAL_1]] : tensor<1x?xf32>) -> tensor<1x?xf32>
//       CHECK:   return %[[VAL_2]] : tensor<1x?xf32>
//       CHECK: }
1172