xref: /llvm-project/mlir/test/Dialect/Linalg/canonicalize.mlir (revision ced2fc7819d5ddea616ec330f18e08ff284c1868)
// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file | FileCheck %s
2
3// CHECK-LABEL: func @memref_cast(
// Verifies that memref.cast from a static to a dynamic shape is folded into
// the consuming linalg.matmul, recovering static memref<16x16xf32> operands.
func.func @memref_cast(%a: index, %b: index) -> memref<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c8 = arith.constant 8 : index
  %c16 = arith.constant 16 : index
  %1 = memref.alloc (%b) : memref<?xi8>
  %2 = memref.view %1[%c0][] : memref<?xi8> to memref<16x16xf32>
  %3 = memref.cast %2 : memref<16x16xf32> to memref<?x?xf32>

  // CHECK:  linalg.matmul ins({{.*}}memref<16x16xf32>, memref<16x16xf32>) outs({{.*}}memref<16x16xf32>)
  linalg.matmul ins(%3, %3: memref<?x?xf32>, memref<?x?xf32>)
               outs(%3: memref<?x?xf32>)
  return %3: memref<?x?xf32>
}
18
19// -----
20
#accesses = [
  affine_map<(i) -> (i)>
]

#trait = {
  indexing_maps = #accesses,
  iterator_types = ["parallel"]
}

// Verifies that the self memref.copy is erased, while the identity
// linalg.generic on tensors is replaced by its out operand (%arg1) rather
// than being left in place (see the directives after this function).
func.func @dce_zero_memref(%arg0 : memref<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf32> {
  // The copy on memref<0xf32> is expected to be dce'ed.
  memref.copy %arg0, %arg0 : memref<0xf32> to memref<0xf32>

  // tensor<0xf32> cannot be dce'ed
  %1 = linalg.generic #trait outs(%arg1 : tensor<0xf32>) {
  ^bb(%0: f32) :
    linalg.yield %0 : f32
  } -> tensor<0xf32>

  return %1: tensor<0xf32>
}
42// CHECK-LABEL: @dce_zero_memref
43//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<0xf32>
44//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<0xf32>
45//   CHECK-NOT:   memref.copy
46//  CHECK-NEXT:   return %[[ARG1]]
47
48// -----
49
// Verifies that a linalg.copy whose source and destination are the same
// buffer is removed as a no-op.
func.func @dce_self_linalg_copy(%arg0 : memref<?xf32>) {
  linalg.copy ins(%arg0: memref<?xf32>) outs(%arg0: memref<?xf32>)
  return
}
54
55// CHECK-LABEL: @dce_self_linalg_copy
56//   CHECK-NOT:   copy
57
58// -----
59
60// CHECK-LABEL: func @tensor.cast(
// Verifies that tensor.cast ops widening to dynamic shapes are folded into the
// consuming linalg.matmul; the result is cast back to the static type.
func.func @tensor.cast(%a : tensor<3x4xf32>, %b : tensor<4x?xf32>, %c : tensor<3x?xf32>)
  -> tensor<3x?xf32>
{
  %ta = tensor.cast %a : tensor<3x4xf32> to tensor<?x?xf32>
  %tb = tensor.cast %b : tensor<4x?xf32> to tensor<?x?xf32>
  %tc = tensor.cast %c : tensor<3x?xf32> to tensor<?x?xf32>

  //      CHECK:  linalg.matmul ins({{.*}}tensor<3x4xf32>, tensor<4x?xf32>)
  // CHECK-SAME:    outs({{.*}}tensor<3x?xf32>) -> tensor<3x?xf32>
  %0 = linalg.matmul ins(%ta, %tb: tensor<?x?xf32>, tensor<?x?xf32>)
                    outs(%tc: tensor<?x?xf32>) -> tensor<?x?xf32>

  %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<3x?xf32>

  return %1: tensor<3x?xf32>
}
77
78// -----
79
80// CHECK-LABEL: func @tensor.cast.unranked(
// Verifies that casts from unranked tensors are NOT folded into the matmul;
// all casts remain in the output.
func.func @tensor.cast.unranked(%a : tensor<*xf32>, %b : tensor<*xf32>, %c : tensor<*xf32>)
  -> tensor<*xf32>
{
  //      CHECK:  tensor.cast
  //      CHECK:  tensor.cast
  //      CHECK:  tensor.cast
  %ta = tensor.cast %a : tensor<*xf32> to tensor<?x?xf32>
  %tb = tensor.cast %b : tensor<*xf32> to tensor<?x?xf32>
  %tc = tensor.cast %c : tensor<*xf32> to tensor<?x?xf32>

  //      CHECK:  linalg.matmul ins({{.*}}tensor<?x?xf32>, tensor<?x?xf32>)
  // CHECK-SAME:    outs({{.*}}tensor<?x?xf32>) -> tensor<?x?xf32>
  %0 = linalg.matmul ins(%ta, %tb: tensor<?x?xf32>, tensor<?x?xf32>)
                    outs(%tc: tensor<?x?xf32>) -> tensor<?x?xf32>

  //      CHECK:  tensor.cast
  %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<*xf32>

  return %1: tensor<*xf32>
}
101
102// -----
103
104// CHECK-LABEL: func @linalg_effects(
// Verifies effect-based DCE: the tensor matmul whose result is unused is
// removed, while the memref matmul writes memory and must be kept.
func.func @linalg_effects(
    %a : tensor<?x?xf32>, %b : tensor<?x?xf32>, %c : tensor<?x?xf32>,
    %d : memref<?x?xf32>, %e : memref<?x?xf32>, %f : memref<?x?xf32>) {
  // CHECK-NOT:   %{{.*}} = linalg.matmul
  %t = linalg.matmul ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>)
                    outs(%c : tensor<?x?xf32>) -> tensor<?x?xf32>

  // CHECK:   linalg.matmul
  linalg.matmul ins(%d, %e : memref<?x?xf32>, memref<?x?xf32>)
               outs(%f : memref<?x?xf32>)
  return
}
117
118// -----
119
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// Verifies that an identity-map generic which only forwards its input values
// is erased and its results replaced by the forwarded inputs (swapped here:
// the yields return %arg1 then %arg0).
func.func @remove_no_op(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>)
  -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
  %1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
  %2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
  %3 = tensor.empty(%0, %1, %2) : tensor<?x?x?xf32>
  %4, %5 = linalg.generic {
    indexing_maps = [#map, #map, #map, #map],
    iterator_types = ["parallel", "parallel", "parallel"]
  } ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
    outs(%3, %3 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
  ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32, %arg5 : f32):
    linalg.yield %arg3, %arg2 : f32, f32
  } -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
  return %4, %5 : tensor<?x?x?xf32>, tensor<?x?x?xf32>
}
140// CHECK-LABEL: func @remove_no_op
141//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
142//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
143//       CHECK:     return %[[ARG1]], %[[ARG0]]
144
145// -----
146
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// Verifies that a no-op generic whose input and result types differ only in
// static vs. dynamic dims folds to a single tensor.cast of the input.
func.func @remove_no_op_mismatched_types(%arg0 : tensor<?x?x?xf32>)
  -> tensor<1x2x3xf32> {
  %out = tensor.empty() : tensor<1x2x3xf32>
  %g = linalg.generic {
    indexing_maps = [#map, #map],
    iterator_types = ["parallel", "parallel", "parallel"]
  } ins(%arg0 : tensor<?x?x?xf32>)
    outs(%out : tensor<1x2x3xf32>) {
  ^bb0(%arg2 : f32, %arg3 : f32):
    linalg.yield %arg2 : f32
  } -> (tensor<1x2x3xf32>)
  return %g : tensor<1x2x3xf32>
}
161// CHECK-LABEL: func @remove_no_op_mismatched_types
162//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
163//       CHECK:     %[[CAST:.*]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<1x2x3xf32>
164//       CHECK:     return %[[CAST]]
165
166// -----
167
#map = affine_map<() -> ()>
// Verifies that a generic forwarding a scalar f32 input into a tensor<f32>
// result cannot be replaced by a tensor.cast and is kept.
func.func @cant_fold_to_tensor_cast(%arg0 : f32) -> tensor<f32> {
  %out = tensor.empty() : tensor<f32>
  %g = linalg.generic {
    indexing_maps = [#map, #map],
    iterator_types = []
  } ins(%arg0 : f32)
    outs(%out : tensor<f32>) {
  ^bb0(%arg2 : f32, %arg3 : f32):
    linalg.yield %arg2 : f32
  } -> (tensor<f32>)
  return %g : tensor<f32>
}
181// CHECK-LABEL: func @cant_fold_to_tensor_cast
182//       CHECK:     linalg.generic
183
184// -----
185
#map = affine_map<(d0, d1) -> (d0, d1)>
// Verifies that a generic yielding a value defined outside its body (the
// block argument %arg1 of ^bb1) is not a no-op and must be kept.
func.func @keep_not_noop(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %cst = arith.constant 1.000000e+00 : f32
  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
  %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
  cf.br ^bb1(%cst : f32)

^bb1(%arg1 : f32):
  %3 = linalg.generic
    {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]}
    ins(%arg0 : tensor<?x?xf32>) outs(%2 : tensor<?x?xf32>) {
    ^bb0(%arg2: f32, %arg3 : f32):
      linalg.yield %arg1 : f32
    } -> tensor<?x?xf32>
  return %3 : tensor<?x?xf32>
}
205// CHECK-LABEL: func @keep_not_noop
206//       CHECK:   %[[RESULT:.+]] = linalg.generic
207//       CHECK:   return %[[RESULT]]
208
209// -----
210
#map = affine_map<(d0, d1) -> (d0, d1)>
// Multi-result variant: the first yielded value comes from outside the body
// (%arg2 of ^bb1), so the whole generic is kept, including its second result.
func.func @keep_not_noop(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>)
  -> (tensor<?x?xf32>, tensor<?x?xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %cst = arith.constant 1.000000e+00 : f32
  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
  %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
  cf.br ^bb1(%cst : f32)

^bb1(%arg2 : f32):
  %3:2 = linalg.generic
    {indexing_maps = [#map, #map, #map, #map],
     iterator_types = ["parallel", "parallel"]}
    ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
    outs(%2, %2 : tensor<?x?xf32>, tensor<?x?xf32>) {
    ^bb0(%arg3: f32, %arg4 : f32, %arg5 : f32, %arg6 : f32):
      linalg.yield %arg2, %arg4 : f32, f32
    } -> (tensor<?x?xf32>, tensor<?x?xf32>)
  return %3#0, %3#1 : tensor<?x?xf32>, tensor<?x?xf32>
}
233// CHECK-LABEL: func @keep_not_noop
234//       CHECK:   %[[RESULT:.+]]:2 = linalg.generic
235//       CHECK:   return %[[RESULT]]#0, %[[RESULT]]#1
236
237// -----
238
#accesses = [
  affine_map<(i, j) -> (i, j)>
]

#trait = {
  indexing_maps = #accesses,
  iterator_types = ["parallel", "parallel"]
}

// Verifies that tensor-semantics linalg ops and tensor.pad whose results are
// unused are all removed as dead code (no memory effects).
// CHECK-LABEL: func @dead_linalg_tensor
//   CHECK-NOT:   linalg.fill
//   CHECK-NOT:   linalg.matmul
//   CHECK-NOT:   linalg.generic
//   CHECK-NOT:   tensor.pad
//       CHECK:   return
func.func @dead_linalg_tensor(%arg0 : tensor<7x7xi32>, %arg1 : tensor<7x7xf32>,
                         %arg2: tensor<?x?xf32>, %high : index) {
  %c0_i32 = arith.constant 0 : i32
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  %0 = linalg.fill ins(%c0_i32 : i32) outs(%arg0 : tensor<7x7xi32>) -> tensor<7x7xi32>
  %1 = linalg.matmul ins(%arg1, %arg1: tensor<7x7xf32>, tensor<7x7xf32>)
                     outs(%arg1: tensor<7x7xf32>) -> tensor<7x7xf32>
  %2 = linalg.generic #trait outs(%arg0 : tensor<7x7xi32>) {
  ^bb(%3: i32) :
    linalg.yield %3 : i32
  } -> tensor<7x7xi32>
  %3 = tensor.pad %arg2 low[%c0, %c0] high[%high, %high] {
        ^bb0(%arg9: index, %arg10: index):
          tensor.yield %cst : f32
  } : tensor<?x?xf32> to tensor<2x4xf32>
  return
}
272
273// -----
274
// Verifies that the constant sizes feeding tensor.empty become static and
// propagate through linalg.fill and tensor.insert_slice; a final tensor.cast
// restores the original dynamic result type.
func.func @propagate_casts(%arg0 : tensor<?x?xf32>, %arg1 : f32, %arg2 : index,
    %arg3 : index) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c21 = arith.constant 21 : index
  %c42 = arith.constant 42 : index
  %0 = tensor.empty(%c21, %c42) : tensor<?x?xf32>
  %1 = linalg.fill ins(%arg1 : f32) outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %2 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %3 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
  %4 = tensor.insert_slice %arg0 into %1[%arg2, %arg3] [%2, %3] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
  return %4 : tensor<?x?xf32>
}
288// CHECK-LABEL: func @propagate_casts
289//       CHECK:   %[[INIT:.+]] = tensor.empty
290//       CHECK:   %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[INIT]]
291//       CHECK:   %[[INSERTED:.+]] = tensor.insert_slice %{{.+}} into %[[FILL]]
292//       CHECK:   %[[RESULT:.+]] = tensor.cast %[[INSERTED]]
293//       CHECK:   return %[[RESULT]]
294
295// -----
296
297// CHECK-LABEL: @self_copy
func.func @self_copy(%arg0 : memref<2x3x?x4xf32>) {

// A memref.copy whose source and destination are the same buffer is erased.
//   CHECK-NOT: memref.copy
  memref.copy %arg0, %arg0 : memref<2x3x?x4xf32> to memref<2x3x?x4xf32>

//   CHECK: return
  return
}
306
307// -----
308// CHECK-LABEL: func @fold_fill_reshape()
func.func @fold_fill_reshape() -> tensor<6x4xf32> {
  %zero = arith.constant 0.0 : f32
  %empty = tensor.empty() : tensor<1x2x3x4xf32>
  // The collapse_shape is pushed past the fill: the collapsed init is filled
  // directly instead of reshaping the filled tensor.
  // CHECK:      %[[COLLAPSE:.+]] = tensor.collapse_shape
  // CHECK-NEXT: %[[FILL:.+]] = linalg.fill ins(%cst : f32)
  // CHECK-SAME:   outs(%[[COLLAPSE]] : tensor<6x4xf32>)
  %fill = linalg.fill ins(%zero : f32) outs(%empty : tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
  %reshape = tensor.collapse_shape %fill [[0, 1, 2], [3]]
      : tensor<1x2x3x4xf32> into tensor<6x4xf32>
  // CHECK: return %[[FILL]] : tensor<6x4xf32>
  return %reshape : tensor<6x4xf32>
}
321
322// -----
323
324//       CHECK: func @fold_fill_reshape_dynamic
325//  CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?x?x?xf32>
// Dynamic variant of the fill+collapse_shape fold: the reshape moves before
// the fill so the fill operates on the collapsed tensor directly.
func.func @fold_fill_reshape_dynamic(%arg0 : tensor<?x?x?x?x?xf32>) -> tensor<?x?xf32> {
  %zero = arith.constant 0.0 : f32
  // CHECK: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]]
  %0 = linalg.fill ins(%zero : f32) outs(%arg0 : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
  // CHECK: %[[RESULT:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[RESHAPE]]
  %1 = tensor.collapse_shape %0 [[0, 1, 2], [3, 4]]
      : tensor<?x?x?x?x?xf32> into tensor<?x?xf32>
  // CHECK: return %[[RESULT]]
  return %1 : tensor<?x?xf32>
}
336
337// -----
338//       CHECK: func @fold_fill_extract
339//  CHECK-SAME:   %[[ARG0:.+]]: i1
// tensor.extract of a linalg.fill result folds directly to the fill value.
func.func @fold_fill_extract(%arg0 : i1) -> i1 {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index

  %empty_dynamic = tensor.empty(%c1) : tensor<1x2x3x?xi1>
  %filled = linalg.fill ins(%arg0 : i1) outs(%empty_dynamic : tensor<1x2x3x?xi1>) -> tensor<1x2x3x?xi1>

  %extracted = tensor.extract %filled[%c0, %c0, %c0, %c0] : tensor<1x2x3x?xi1>

  //  CHECK:   return %[[ARG0]]
  return %extracted : i1
}
352
353// -----
354
// tensor.pack of a filled tensor folds to filling the packed destination.
func.func @fill_pack() -> tensor<24x32x16x16xf32> {
  %dest = tensor.empty() : tensor<384x512xf32>
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<24x32x16x16xf32>
  %1 = linalg.fill ins(%cst : f32) outs(%dest : tensor<384x512xf32>) -> tensor<384x512xf32>
  %pack = tensor.pack %1 inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %0 : tensor<384x512xf32> -> tensor<24x32x16x16xf32>
  return %pack : tensor<24x32x16x16xf32>
}
363// CHECK-LABEL: func.func @fill_pack
364// CHECK:         %[[PACKED_EMPTY:.+]] = tensor.empty() : tensor<24x32x16x16xf32>
365// CHECK:         %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[PACKED_EMPTY]]
366// CHECK:         return %[[FILL]]
367
368// -----
369
// The fill+pack fold also applies with outer_dims_perm and a pack destination
// coming from bufferization.to_tensor: the destination is filled directly.
func.func @fill_pack_general() -> tensor<1x1x8x4x4x8xi32>{
  %c0_i32 = arith.constant 0 : i32
  %alloc = memref.alloc() : memref<1x1x8x4x4x8xi32>
  %9 = tensor.empty() : tensor<1x1x16x64xi32>
  %extracted_slice_15 = tensor.extract_slice %9[0, 0, 0, 0] [1, 1, 16, 64] [1, 1, 1, 1] : tensor<1x1x16x64xi32> to tensor<1x1x16x64xi32>
  %16 = linalg.fill ins(%c0_i32 : i32) outs(%extracted_slice_15 : tensor<1x1x16x64xi32>) -> tensor<1x1x16x64xi32>
  %0 = bufferization.to_tensor %alloc restrict writable : memref<1x1x8x4x4x8xi32> to tensor<1x1x8x4x4x8xi32>
  %pack_18 = tensor.pack %16 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %0 : tensor<1x1x16x64xi32> -> tensor<1x1x8x4x4x8xi32>
  return %pack_18 : tensor<1x1x8x4x4x8xi32>
}
380
381// CHECK-LABEL: func.func @fill_pack_general
382// CHECK:         %[[ALLOC:.+]] = memref.alloc() : memref<1x1x8x4x4x8xi32>
383// CHECK:         %[[TENSOR:.+]] = bufferization.to_tensor %[[ALLOC]]
384// CHECK:         %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[TENSOR]]
385// CHECK:         return %[[FILL]]
386
387// -----
388
#map = affine_map<()[s0] -> (s0 ceildiv 16)>
// Dynamic fill+pack fold: the fill moves onto the packed tensor.empty whose
// outer sizes are computed with ceildiv from the source dims.
func.func @dynamic_fill_pack(%arg0: tensor<?x?xf32>) -> tensor<?x?x16x16xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %dim = tensor.dim %0, %c0 : tensor<?x?xf32>
  %dim_0 = tensor.dim %0, %c1 : tensor<?x?xf32>
  %1 = affine.apply #map()[%dim]
  %2 = affine.apply #map()[%dim_0]
  %3 = tensor.empty(%1, %2) : tensor<?x?x16x16xf32>
  %pack = tensor.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %3 : tensor<?x?xf32> -> tensor<?x?x16x16xf32>
  return %pack : tensor<?x?x16x16xf32>
}
403// CHECK-DAG:   #[[MAP:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
404// CHECK:       func.func @dynamic_fill_pack
405// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
406// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
407// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
408// CHECK:         %[[D0:.+]] = tensor.dim %[[DEST]], %[[C0]]
409// CHECK:         %[[D1:.+]] = tensor.dim %[[DEST]], %[[C1]]
410// CHECK:         %[[PACKED_D0:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
411// CHECK:         %[[PACKED_D1:.+]] = affine.apply #[[MAP]]()[%[[D1]]]
412// CHECK:         %[[PACKED_EMPTY:.+]] = tensor.empty(%[[PACKED_D0]], %[[PACKED_D1]]) : tensor<?x?x16x16xf32>
413// CHECK:         %[[FILL:.+]] = linalg.fill ins(%{{.+}}) outs(%[[PACKED_EMPTY]]
414// CHECK:         return %[[FILL]]
415
416// -----
417
418// CHECK: func @fold_self_copy
func.func @fold_self_copy(%0 : memref<4x16xf32>) {
// An identity generic reading and writing the same buffer is a self-copy and
// is removed entirely.
// CHECK-NEXT: return
  linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                                   affine_map<(d0, d1) -> (d0, d1)>],
                  iterator_types = ["parallel", "parallel"]}
    ins(%0 : memref<4x16xf32>)
    outs(%0 : memref<4x16xf32>) {
      ^bb0(%arg4: f32, %arg5: f32):
        linalg.yield %arg4 : f32
    }
  return
}
431
432// -----
433
434// CHECK-LABEL: func @fold_static_pad_fill
435//       CHECK:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
436//       CHECK:   %[[INIT:.+]] = tensor.empty() : tensor<412x276xf32>
437//       CHECK:   %[[FILL:.+]] = linalg.fill ins(%[[F0]]{{.*}}outs(%[[INIT]]
438//       CHECK:   return %[[FILL]]
// tensor.pad of a fill with the same constant folds to one larger fill.
func.func @fold_static_pad_fill() -> tensor<412x276xf32> {
  %f0 = arith.constant 0.0 : f32
  %empty = tensor.empty() : tensor<400x273xf32>
  %fill = linalg.fill ins(%f0 : f32) outs(%empty : tensor<400x273xf32>) -> tensor<400x273xf32>
  %pad = tensor.pad %fill low[4, 1] high[8, 2] {
  ^bb0(%arg1: index, %arg2: index):
    tensor.yield %f0 : f32
  } : tensor<400x273xf32> to tensor<412x276xf32>
  return %pad : tensor<412x276xf32>
}
449
450// -----
451
452// CHECK: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 + 9)>
453// CHECK: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 + 10)>
454// CHECK: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 + 23)>
455// CHECK: #[[MAP3:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 32)>
456
457//      CHECK: func @fold_dynamic_pad_fill
458// CHECK-SAME: %[[SRC:.+]]: tensor<8x?x16x32xf32>, %[[LOW0:.+]]: index, %[[LOW3:.+]]: index, %[[HIGH2:.+]]: index, %[[HIGH3:.+]]: index
459
460//  CHECK-DAG:   %[[I1:.+]] = arith.constant 1 : index
461//  CHECK-DAG:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
462//      CHECK:   %[[S0:.+]] = affine.apply #[[MAP0]]()[%[[LOW0]]]
463//      CHECK:   %[[DIM1:.+]] = tensor.dim %[[SRC]], %[[I1]] : tensor<8x?x16x32xf32>
464//      CHECK:   %[[S1:.+]] = affine.apply #[[MAP1]]()[%[[DIM1]]]
465//      CHECK:   %[[S2:.+]] = affine.apply #[[MAP2]]()[%[[HIGH2]]]
466//      CHECK:   %[[S3:.+]] = affine.apply #[[MAP3]]()[%[[LOW3]], %[[HIGH3]]]
467//      CHECK:   %[[INIT:.+]] = tensor.empty(%[[S0]], %[[S1]], %[[S2]], %[[S3]]) : tensor<?x?x?x?xf32>
468//      CHECK:   %[[FILL:.+]] = linalg.fill ins(%[[F0]]{{.*}}outs(%[[INIT]]
469//      CHECK:   return %[[FILL]]
// Dynamic pad-of-fill fold: the padded sizes are recomputed with affine maps
// from the low/high padding amounts and the source dims.
func.func @fold_dynamic_pad_fill(%empty: tensor<8x?x16x32xf32>, %low0: index, %low3: index, %high2: index, %high3: index) -> tensor<?x?x?x?xf32> {
  %f0 = arith.constant 0.0 : f32
  %fill = linalg.fill ins(%f0 : f32) outs(%empty : tensor<8x?x16x32xf32>) -> tensor<8x?x16x32xf32>
  %pad = tensor.pad %fill low[%low0, 8, 7, %low3] high[1, 2, %high2, %high3] {
  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
    tensor.yield %f0 : f32
  } : tensor<8x?x16x32xf32> to tensor<?x?x?x?xf32>
  return %pad : tensor<?x?x?x?xf32>
}
479
480// -----
481
482// CHECK-LABEL: func @no_fold_pad_fill_value_mismatch
// No pad-into-fill fold when the pad value differs from the fill value.
func.func @no_fold_pad_fill_value_mismatch() -> tensor<412x276xf32> {
  %f0 = arith.constant 0.0 : f32
  %f1 = arith.constant 1.0 : f32
  %empty = tensor.empty() : tensor<400x273xf32>
  %fill = linalg.fill ins(%f0 : f32) outs(%empty : tensor<400x273xf32>) -> tensor<400x273xf32>
  // CHECK: tensor.pad
  %pad = tensor.pad %fill low[4, 1] high[8, 2] {
  ^bb0(%arg1: index, %arg2: index):
    tensor.yield %f1 : f32
  } : tensor<400x273xf32> to tensor<412x276xf32>
  return %pad : tensor<412x276xf32>
}
495
496// -----
497
498// Tests below verify whether static information is propagated through all the operands of generic op.
499// 1. If one of the inputs of generic op has static info and it has no cast source.
// 2. If one of the inputs of generic op has static info and it is coming from tensor.cast operation.
// 3. If one of the outputs of generic op has static info and it is coming from tensor.cast operation.
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @static_input_without_cast
// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x3x4xf32>, %[[ARG1:.*]]: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
// Case 1: the fully static shape of %arg0 propagates to the dynamic input
// and the out operand via inserted casts.
func.func @static_input_without_cast(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %0 = tensor.dim %arg0, %c0 : tensor<2x3x4xf32>
  %1 = tensor.dim %arg0, %c1 : tensor<2x3x4xf32>
  %2 = tensor.dim %arg0, %c2 : tensor<2x3x4xf32>
  %3 = tensor.empty(%0, %1, %2) : tensor<?x?x?xf32>
  %4 = linalg.generic {
    indexing_maps = [#map, #map, #map],
    iterator_types = ["parallel", "parallel", "parallel"]
  } ins(%arg0, %arg1 : tensor<2x3x4xf32>, tensor<?x?x?xf32>)
    outs(%3 : tensor<?x?x?xf32>) {
  ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
    %9 = arith.addf %arg2, %arg3 : f32
    linalg.yield %9 : f32
  } -> (tensor<?x?x?xf32>)
  %5 = tensor.cast %4 : tensor<?x?x?xf32> to tensor<2x3x4xf32>
  return %5 : tensor<2x3x4xf32>
    //  CHECK:      %[[CAST_ARG1:.*]] = tensor.cast %[[ARG1]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
    //  CHECK-NEXT: %[[GENERIC_OP:.*]] = linalg.generic
    //  CHECK-SAME: ins(%[[ARG0]], %[[CAST_ARG1]] : tensor<2x3x4xf32>, tensor<2x3x4xf32>)
    //  CHECK-SAME: outs({{.*}} : tensor<2x3x4xf32>)
}
529
530// -----
531
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @static_input_with_cast
// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x3x4xf32>, %[[ARG1:.*]]: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
// Case 2: the partially static cast of %arg1 (to 2x?x?) is tightened to the
// fully static shape implied by %arg0.
func.func @static_input_with_cast(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %0 = tensor.dim %arg0, %c0 : tensor<2x3x4xf32>
  %1 = tensor.dim %arg0, %c1 : tensor<2x3x4xf32>
  %2 = tensor.dim %arg0, %c2 : tensor<2x3x4xf32>
  %3 = tensor.empty(%0, %1, %2) : tensor<?x?x?xf32>
  %4 = tensor.cast %arg1 : tensor<?x?x?xf32> to tensor<2x?x?xf32>
  %5 = linalg.generic {
    indexing_maps = [#map, #map, #map],
    iterator_types = ["parallel", "parallel", "parallel"]
  } ins(%arg0, %4 : tensor<2x3x4xf32>, tensor<2x?x?xf32>)
    outs(%3 : tensor<?x?x?xf32>) {
  ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
    %9 = arith.addf %arg2, %arg3 : f32
    linalg.yield %9 : f32
  } -> (tensor<?x?x?xf32>)
  %6 = tensor.cast %5 : tensor<?x?x?xf32> to tensor<2x3x4xf32>
  return %6: tensor<2x3x4xf32>
    //  CHECK:      %[[CAST_ARG1:.*]] = tensor.cast %[[ARG1]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
    //  CHECK-NEXT: %[[GENERIC_OP:.*]] = linalg.generic
    //  CHECK-SAME: ins(%[[ARG0]], %[[CAST_ARG1]] : tensor<2x3x4xf32>, tensor<2x3x4xf32>)
    //  CHECK-SAME: outs({{.*}} : tensor<2x3x4xf32>)
}
560
561// -----
562
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @static_output_with_cast
// CHECK-SAME:  (%[[ARG0:.*]]: tensor<?x?x?xf32>, %[[ARG1:.*]]: tensor<?x?x?xf32>, %[[ARG2:.*]]: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
// Case 3: the static type of the out operand (a cast of tensor.empty)
// propagates to both dynamic inputs via inserted casts.
func.func @static_output_with_cast(%arg0 : tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %0 = tensor.dim %arg2, %c0 : tensor<2x3x4xf32>
  %1 = tensor.dim %arg2, %c1 : tensor<2x3x4xf32>
  %2 = tensor.dim %arg2, %c2 : tensor<2x3x4xf32>
  %3 = tensor.empty(%0, %1, %2) : tensor<?x?x?xf32>
  %4 = tensor.cast %3 : tensor<?x?x?xf32> to tensor<2x3x4xf32>
  %5 = tensor.cast %arg1 : tensor<?x?x?xf32> to tensor<2x?x?xf32>
  %6 = linalg.generic {
    indexing_maps = [#map, #map, #map],
    iterator_types = ["parallel", "parallel", "parallel"]
  } ins(%arg0, %5 : tensor<?x?x?xf32>, tensor<2x?x?xf32>)
    outs(%4 : tensor<2x3x4xf32>) {
  ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
    %9 = arith.addf %arg3, %arg4 : f32
    linalg.yield %9 : f32
  } -> (tensor<2x3x4xf32>)
  return %6: tensor<2x3x4xf32>
    //  CHECK:      %[[CAST_ARG0:.*]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
    //  CHECK-NEXT: %[[CAST_ARG1:.*]] = tensor.cast %[[ARG1]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
    //  CHECK-NEXT: %[[GENERIC_OP:.*]] = linalg.generic
    //  CHECK-SAME: ins(%[[CAST_ARG0]], %[[CAST_ARG1]] : tensor<2x3x4xf32>, tensor<2x3x4xf32>)
    //  CHECK-SAME: outs({{.*}} : tensor<2x3x4xf32>)
}
592
593// -----
594
595// This test checks the folding of tensor.cast operation when the source value of cast
596// has more static information than the destination value.
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @cast_source
// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x3x4xf32>, %[[ARG1:.*]]: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
// The casts that erase static info are folded away entirely: the generic
// consumes the original fully static sources.
func.func @cast_source(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %0 = tensor.dim %arg0, %c0 : tensor<2x3x4xf32>
  %1 = tensor.dim %arg0, %c1 : tensor<2x3x4xf32>
  %2 = tensor.dim %arg0, %c2 : tensor<2x3x4xf32>
  %3 = tensor.empty(%0, %1, %2) : tensor<?x?x?xf32>
  %4 = tensor.cast %arg0 : tensor<2x3x4xf32> to tensor<2x?x?xf32>
  %5 = tensor.cast %arg1 : tensor<2x3x4xf32> to tensor<2x?x?xf32>
  %6 = linalg.generic {
    indexing_maps = [#map, #map, #map],
    iterator_types = ["parallel", "parallel", "parallel"]
  } ins(%4, %5 : tensor<2x?x?xf32>, tensor<2x?x?xf32>)
    outs(%3 : tensor<?x?x?xf32>) {
  ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
    %9 = arith.addf %arg2, %arg3 : f32
    linalg.yield %9 : f32
  } -> (tensor<?x?x?xf32>)
  %7 = tensor.cast %6 : tensor<?x?x?xf32> to tensor<2x3x4xf32>
  return %7: tensor<2x3x4xf32>
    //  CHECK:      %[[GENERIC_OP:.*]] = linalg.generic
    //  CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3x4xf32>, tensor<2x3x4xf32>)
    //  CHECK-SAME: outs({{.*}} : tensor<2x3x4xf32>)
}
625
626// -----
627
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @cast_dest
// CHECK-SAME:  (%[[ARG0:.*]]: tensor<?x?x?xf32>, %[[ARG1:.*]]: tensor<1x?x?xf32>,
// The static leading dim of %arg1 propagates to the other operands and the
// result; a trailing tensor.cast restores the original dynamic result type.
func.func @cast_dest(%arg0: tensor<?x?x?xf32>, %arg1: tensor<1x?x?xf32>, %arg2: index, %arg3: index, %arg4: index) -> tensor<?x?x?xf32> {
  %0 = tensor.empty(%arg2, %arg3, %arg4) : tensor<?x?x?xf32>
  %1 = tensor.cast %arg1 : tensor<1x?x?xf32> to tensor<?x?x?xf32>
  %2 = linalg.generic {
    indexing_maps = [#map, #map, #map],
    iterator_types = ["parallel", "parallel", "parallel"]
  } ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<1x?x?xf32>)
    outs(%0 : tensor<?x?x?xf32>) {
  ^bb0(%arg5: f32, %arg6: f32, %arg7: f32):
    %3 = arith.subf %arg5, %arg6 : f32
    linalg.yield %3 : f32
  } -> tensor<?x?x?xf32>
  return %2 : tensor<?x?x?xf32>
// CHECK:      %[[GENERIC_OP:.*]] = linalg.generic
// CHECK-SAME: ins(%{{.*}}, %[[ARG1]] : tensor<1x?x?xf32>, tensor<1x?x?xf32>)
// CHECK-SAME: outs(%{{.*}} : tensor<1x?x?xf32>)
// CHECK: tensor.cast %[[GENERIC_OP]] : tensor<1x?x?xf32> to tensor<?x?x?xf32>
}
649
650// -----
651
652//       CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 1)>
653// CHECK-LABEL: func @insert_pad_into_fill
654//  CHECK-SAME: (%[[INPUT:.+]]: tensor<?x?x?xf32>, %[[LOW0:.+]]: index, %[[LOW1:.+]]: index, %{{.+}}: index, %{{.+}}: index)
655//   CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
656//   CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
657//   CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
658//   CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
659//       CHECK: %[[INIT:.+]] = tensor.empty()
660//       CHECK: %[[FILL:.+]] = linalg.fill ins(%[[F0]]{{.*}}outs(%[[INIT]]
661//       CHECK: %[[OFFSET1:.+]] = affine.apply #[[$MAP]]()[%[[LOW1]]]
662//       CHECK: %[[D0:.+]] = tensor.dim %[[INPUT]], %[[C0]] : tensor<?x?x?xf32>
663//       CHECK: %[[D1:.+]] = tensor.dim %[[INPUT]], %[[C1]] : tensor<?x?x?xf32>
664//       CHECK: %[[D2:.+]] = tensor.dim %[[INPUT]], %[[C2]] : tensor<?x?x?xf32>
665//       CHECK: tensor.insert_slice %[[INPUT]] into %[[FILL]][%[[LOW0]], %[[OFFSET1]], 2] [%[[D0]], %[[D1]], %[[D2]]] [1, 1, 1]
// The pad feeding the insert_slice is removed: the unpadded input is inserted
// into the zero-fill with offsets and sizes adjusted by the low padding.
func.func @insert_pad_into_fill(%input: tensor<?x?x?xf32>, %low0: index, %low1: index, %high1: index, %high2: index) -> tensor<8x384x384xf32> {
  %f0 = arith.constant 0.0 : f32
  %c0 = arith.constant 0 : index
  %pad = tensor.pad %input low[%low0, %low1, %c0] high[%c0, %high1, %high2] {
  ^bb0(%arg3: index, %arg4: index, %arg5: index):
    tensor.yield %f0 : f32
  } : tensor<?x?x?xf32> to tensor<8x128x128xf32>
  %empty = tensor.empty() : tensor<8x384x384xf32>
  %fill = linalg.fill ins(%f0 : f32) outs(%empty : tensor<8x384x384xf32>) -> tensor<8x384x384xf32>
  %0 = tensor.insert_slice %pad into %fill[0, 1, 2] [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
  return %0: tensor<8x384x384xf32>
}
678
679// -----
680
681// CHECK-LABEL: func @multi_insert_pad_into_fill
682//  CHECK-SAME: (%[[INPUT:.+]]: tensor<7x123x124xf32>, %[[A:.+]]: tensor<8x128x128xf32>, %[[OFFSET:.+]]: index)
683//       CHECK:   %[[FILL:.+]] = linalg.fill
684//       CHECK:   %[[INSERT0:.+]] = tensor.insert_slice %[[A]] into %[[FILL]][%[[OFFSET]], 0, 0] [8, 128, 128] [1, 1, 1]
685//       CHECK:   %[[INSERT1:.+]] = tensor.insert_slice %[[A]] into %[[INSERT0]][0, 128, %[[OFFSET]]] [8, 128, 128] [1, 1, 1]
686//       CHECK:                  tensor.insert_slice %[[INPUT]] into %[[INSERT1]][1, 2, 256] [7, 123, 124] [1, 1, 1]
// The pad still folds into the fill when the other insert_slices into the
// same fill do not overlap the padded region.
func.func @multi_insert_pad_into_fill(%input: tensor<7x123x124xf32>, %a: tensor<8x128x128xf32>, %offset: index) -> tensor<8x384x384xf32> {
  %f0 = arith.constant 0.0 : f32
  %c0 = arith.constant 0 : index
  %pad = tensor.pad %input low[1, 2, 0] high[0, 3, 4] {
  ^bb0(%arg3: index, %arg4: index, %arg5: index):
    tensor.yield %f0 : f32
  } : tensor<7x123x124xf32> to tensor<8x128x128xf32>
  %empty = tensor.empty() : tensor<8x384x384xf32>
  %fill = linalg.fill ins(%f0 : f32) outs(%empty : tensor<8x384x384xf32>) -> tensor<8x384x384xf32>
  %0 = tensor.insert_slice %a   into %fill[%offset, 0, 0]  [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
  %1 = tensor.insert_slice %a   into %0   [0, 128, %offset][8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
  %2 = tensor.insert_slice %pad into %1   [0, 0, 256]      [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
  return %2: tensor<8x384x384xf32>
}
701
702// -----
703
704// CHECK-LABEL: func @multi_insert_pad_into_fill_overlap
func.func @multi_insert_pad_into_fill_overlap(%input: tensor<7x123x124xf32>, %a: tensor<8x128x128xf32>, %offset: index) -> tensor<8x384x384xf32> {
  %f0 = arith.constant 0.0 : f32
  %c0 = arith.constant 0 : index
  // The pad cannot fold: another insert into the same fill overlaps its
  // destination region, so it must stay.
  // CHECK: tensor.pad
  %pad = tensor.pad %input low[1, 2, 0] high[0, 3, 4] {
  ^bb0(%arg3: index, %arg4: index, %arg5: index):
    tensor.yield %f0 : f32
  } : tensor<7x123x124xf32> to tensor<8x128x128xf32>
  %empty = tensor.empty() : tensor<8x384x384xf32>
  %fill = linalg.fill ins(%f0 : f32) outs(%empty : tensor<8x384x384xf32>) -> tensor<8x384x384xf32>
  %0 = tensor.insert_slice %a   into %fill[%offset, 0, 0]  [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
  // Ranges [129, 257) of %1 and [256, 384) of %2 overlap in the last dim.
  %1 = tensor.insert_slice %a   into %0   [0, 0, 129]      [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
  %2 = tensor.insert_slice %pad into %1   [0, 0, 256]      [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
  return %2: tensor<8x384x384xf32>
}
721
722// -----
723
// Same as above, but the potential overlap involves an insertion with a
// dynamic offset; the pad must conservatively NOT fold into the fill.
// CHECK-LABEL: func @multi_insert_pad_into_fill_overlap
func.func @multi_insert_pad_into_fill_overlap(%input: tensor<7x123x124xf32>, %a: tensor<8x128x128xf32>, %offset: index) -> tensor<8x384x384xf32> {
  %f0 = arith.constant 0.0 : f32
  %c0 = arith.constant 0 : index
  // CHECK: tensor.pad
  %pad = tensor.pad %input low[1, 2, 0] high[0, 3, 4] {
  ^bb0(%arg3: index, %arg4: index, %arg5: index):
    tensor.yield %f0 : f32
  } : tensor<7x123x124xf32> to tensor<8x128x128xf32>
  %empty = tensor.empty() : tensor<8x384x384xf32>
  %fill = linalg.fill ins(%f0 : f32) outs(%empty : tensor<8x384x384xf32>) -> tensor<8x384x384xf32>
  %0 = tensor.insert_slice %a   into %fill[0, 0, %offset]  [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
  %1 = tensor.insert_slice %a   into %0   [0, 128, 255]    [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
  // Possible range overlap with %0 in the last dimension (%offset is unknown).
  %2 = tensor.insert_slice %pad into %1   [0, 0, 256]      [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
  return %2: tensor<8x384x384xf32>
}
741
742// -----
743
// CHECK-LABEL: func @multi_insert_pad_into_fill
func.func @multi_insert_pad_into_fill(%input: tensor<7x123x124xf32>, %a: tensor<8x128x128xf32>, %offset: index) -> tensor<8x384x384xf32> {
  %f0 = arith.constant 0.0 : f32
  %c0 = arith.constant 0 : index
  // CHECK-NOT: tensor.pad
  %pad = tensor.pad %input low[1, 2, 0] high[0, 3, 4] {
  ^bb0(%arg3: index, %arg4: index, %arg5: index):
    tensor.yield %f0 : f32
  } : tensor<7x123x124xf32> to tensor<8x128x128xf32>
  %empty = tensor.empty() : tensor<8x384x384xf32>
  %fill = linalg.fill ins(%f0 : f32) outs(%empty : tensor<8x384x384xf32>) -> tensor<8x384x384xf32>
  // Overlap between %0 and %1 is fine (neither involves the pad), but the
  // pad insertion %2 must not overlap them — here it does not, so the pad
  // still folds into the fill.
  // CHECK-COUNT-3: tensor.insert_slice
  %0 = tensor.insert_slice %a   into %fill[0, 0, %offset]  [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
  %1 = tensor.insert_slice %a   into %0   [0, 1, %offset]  [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
  %2 = tensor.insert_slice %pad into %1   [0, 256, 256]    [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
  return %2: tensor<8x384x384xf32>
}
762
763// -----
764
// The fill value (%f1 = 1.0) differs from the padding value (%f0 = 0.0), so
// the pad must NOT fold into the fill.
// CHECK-LABEL: func @multi_insert_pad_into_fill_mismatch
func.func @multi_insert_pad_into_fill_mismatch(%input: tensor<7x123x124xf32>, %a: tensor<8x128x128xf32>, %offset: index) -> tensor<8x384x384xf32> {
  %f0 = arith.constant 0.0 : f32
  %f1 = arith.constant 1.0 : f32
  %c0 = arith.constant 0 : index
  // CHECK: tensor.pad
  %pad = tensor.pad %input low[1, 2, 0] high[0, 3, 4] {
  ^bb0(%arg3: index, %arg4: index, %arg5: index):
    tensor.yield %f0 : f32
  } : tensor<7x123x124xf32> to tensor<8x128x128xf32>
  %empty = tensor.empty() : tensor<8x384x384xf32>
  // Different filling value than padding value.
  %fill = linalg.fill ins(%f1 : f32) outs(%empty : tensor<8x384x384xf32>) -> tensor<8x384x384xf32>
  %0 = tensor.insert_slice %a   into %fill[%offset, 0, 0]  [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
  %1 = tensor.insert_slice %a   into %0   [0, 128, %offset][8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
  %2 = tensor.insert_slice %pad into %1   [0, 0, 256]      [8, 128, 128] [1, 1, 1] : tensor<8x128x128xf32> into tensor<8x384x384xf32>
  return %2: tensor<8x384x384xf32>
}
783
784// -----
785
// A tensor.cast consumer refining the matmul result to a static shape is
// folded into the op: the operands are cast to matching static shapes and a
// cast back to dynamic is kept for the other (dynamic) use.
func.func @fold_linalgop_with_cast_consumer(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
    %arg2 : tensor<?x?xf32>) -> (tensor<4x8xf32>, tensor<?x?xf32>) {
  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x8xf32>
  return %1, %0 : tensor<4x8xf32>, tensor<?x?xf32>
}
//       CHECK: func @fold_linalgop_with_cast_consumer(
//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
//   CHECK-DAG:  %[[LHS_CAST:.+]] = tensor.cast %[[ARG0]] : tensor<?x?xf32> to tensor<4x?xf32>
//   CHECK-DAG:  %[[RHS_CAST:.+]] = tensor.cast %[[ARG1]] : tensor<?x?xf32> to tensor<?x8xf32>
//   CHECK-DAG:  %[[OUT_CAST:.+]] = tensor.cast %[[ARG2]] : tensor<?x?xf32> to tensor<4x8xf32>
//       CHECK:  %[[MATMUL:.+]] = linalg.matmul
//  CHECK-SAME:      ins(%[[LHS_CAST]], %[[RHS_CAST]] :
//  CHECK-SAME:      outs(%[[OUT_CAST]] :
//       CHECK:  %[[RESULT_CAST:.+]] = tensor.cast %[[MATMUL]]
//       CHECK:  return %[[MATMUL]], %[[RESULT_CAST]]
805
806// -----
807
func.func private @some_use(%0 : tensor<4x8xf32>)

func.func @linalgop_with_cond_cast_consumer(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
    %arg2 : tensor<?x?xf32>, %arg3 : i1) -> tensor<?x?xf32> {
  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
  scf.if %arg3 {
    %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x8xf32>
    func.call @some_use(%1) : (tensor<4x8xf32>) -> ()
  }
  return %0 : tensor<?x?xf32>
}

// Check that a cast that is only conditionally reachable (inside scf.if) is
// not folded into its producer.
// CHECK-LABEL: func @linalgop_with_cond_cast_consumer
//  CHECK-SAME:     (%[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: tensor<?x?xf32>, %[[ARG2:.*]]: tensor<?x?xf32>, %[[ARG3:.*]]: i1)
//       CHECK: %[[RES:.*]] = linalg.matmul ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>)
//  CHECK-SAME:      outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
//       CHECK: scf.if %[[ARG3]] {
//       CHECK:   %[[CAST:.*]] = tensor.cast %[[RES]] : tensor<?x?xf32> to tensor<4x8xf32>
//       CHECK:   func.call @some_use(%[[CAST]]) : (tensor<4x8xf32>) -> ()
//       CHECK: }
//       CHECK: return %[[RES]] : tensor<?x?xf32>
831
832
833// -----
834
// Same cast-folding as for matmul, applied to a convolution: only the output
// operand is refined (input/filter shapes cannot be inferred from the result).
func.func @fold_conv_op_with_cast_consumer(%arg0 : tensor<?x?x?x?xf32>,
    %arg1 : tensor<?x?x?x?xf32>,  %arg2 : tensor<?x?x?x?xf32>) ->
    (tensor<4x8x12x16xf32>, tensor<?x?x?x?xf32>) {
  %0 = linalg.conv_2d_nchw_fchw ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
      outs(%arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
  %1 = tensor.cast %0 : tensor<?x?x?x?xf32> to tensor<4x8x12x16xf32>
  return %1, %0 : tensor<4x8x12x16xf32>, tensor<?x?x?x?xf32>
}
//       CHECK: func @fold_conv_op_with_cast_consumer(
//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>
//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>)
//       CHECK:  %[[OUT_CAST:.+]] = tensor.cast %[[ARG2]] : tensor<?x?x?x?xf32> to tensor<4x8x12x16xf32>
//       CHECK:  %[[CONV:.+]] = linalg.conv_2d_nchw_fchw
//  CHECK-SAME:      ins(%[[ARG0]], %[[ARG1]] :
//  CHECK-SAME:      outs(%[[OUT_CAST]] :
//       CHECK:  %[[RESULT_CAST:.+]] = tensor.cast %[[CONV]]
//       CHECK:  return %[[CONV]], %[[RESULT_CAST]]
853
854// -----
855
// Cast folding into a multi-result linalg.generic: refining one result's
// shape propagates static sizes to the input and both inits via the
// indexing maps, and the other result gets a cast back to dynamic.
func.func @fold_multi_use_generic_op_with_consumer(%arg0 : tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<2x3x4xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %d0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
  %d1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
  %d2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
  %empty1 = tensor.empty(%d1, %d2, %d0) : tensor<?x?x?xf32>
  %empty2 = tensor.empty(%d2, %d1, %d0) : tensor<?x?x?xf32>
  %0:2 = linalg.generic {
      iterator_types = ["parallel", "parallel", "parallel"],
      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                       affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
                       affine_map<(d0, d1, d2) -> (d2, d1, d0)>]}
      ins(%arg0 : tensor<?x?x?xf32>) outs(%empty1, %empty2 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32) :
      linalg.yield %b0, %b0 : f32, f32
    } -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
  %1 = tensor.cast %0#1 : tensor<?x?x?xf32> to tensor<2x3x4xf32>
  return %0#0, %1 : tensor<?x?x?xf32>, tensor<2x3x4xf32>
}
//       CHECK: func @fold_multi_use_generic_op_with_consumer
//  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x?xf32>
//   CHECK-DAG:   %[[INIT1:.+]] = tensor.empty() : tensor<2x3x4xf32>
//   CHECK-DAG:   %[[CAST:.+]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<4x3x2xf32>
//   CHECK-DAG:   %[[INIT2:.+]] = tensor.empty() : tensor<3x2x4xf32>
//       CHECK:   %[[GENERIC:.+]]:2 = linalg.generic
//  CHECK-SAME:       ins(%[[CAST]] :
//  CHECK-SAME:       outs(%[[INIT2]], %[[INIT1]] :
//       CHECK:   %[[RETURN_CAST:.+]] = tensor.cast %[[GENERIC]]#0 : tensor<3x2x4xf32> to tensor<?x?x?xf32>
//       CHECK:   return %[[RETURN_CAST]], %[[GENERIC]]#1
887
888// -----
889
#map = affine_map<(d0) -> (d0)>
func.func @identity_buffer(%arg0 : memref<?xf32>, %arg1: memref<?xf32>) {
  linalg.generic {
    indexing_maps = [#map, #map],
    iterator_types = ["parallel"]
  } ins(%arg0 : memref<?xf32>)
    outs(%arg1 : memref<?xf32>) {
  ^bb0(%arg2 : f32, %arg3 : f32):
    linalg.yield %arg2 : f32
  }
  return
}

// Do not erase identity-like ops with buffer semantics: the write to %arg1
// is an observable side effect.
// CHECK-LABEL: func @identity_buffer
//  CHECK-SAME:     (%[[ARG1:.*]]: memref<?xf32>, %[[ARG2:.*]]: memref<?xf32>)
//       CHECK:     linalg.generic {
//  CHECK-SAME:    indexing_maps = [#map, #map],
//  CHECK-SAME:    iterator_types = ["parallel"]
//  CHECK-SAME:  } ins(%[[ARG1]] : memref<?xf32>)
//  CHECK-SAME:    outs(%[[ARG2]] : memref<?xf32>) {
911
912// -----
913
#map = affine_map<(d0, d1) -> (d1, d0)>
func.func @erase_non_identity_noop(%arg0 : tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.generic {
    indexing_maps = [#map, #map],
    iterator_types = ["parallel", "parallel"]
  } ins(%arg0 : tensor<?x?xf32>)
    outs(%arg1 : tensor<?x?xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in: f32
  } -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

// Erase a no-op generic with tensor semantics: both operands use the same
// (non-identity) permutation map and the body just yields the input, so the
// result is %arg0 itself.
// CHECK-LABEL: func @erase_non_identity_noop
//  CHECK-SAME:   (%[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: tensor<?x?xf32>)
//       CHECK:   return %[[ARG0]] : tensor<?x?xf32>
931
932// -----
933
// Regression test for input-operand deduplication; just make sure that we
// don't crash. (NOTE: "dedeplicate" is a historical typo for "deduplicate",
// kept because the CHECK-LABEL matches the symbol name.)

// CHECK-LABEL: func @dedeplicate_regression_test
func.func @dedeplicate_regression_test(%0: tensor<4xf32>, %1: tensor<4xf32>) {
  %36 = linalg.generic
    {indexing_maps = [affine_map<(d0) -> (d0)>,
                      affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
     iterator_types = ["parallel"]}
    ins(%1, %1 : tensor<4xf32>, tensor<4xf32>)
    outs(%0 : tensor<4xf32>) {
  ^bb0(%in: f32, %in_24: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<4xf32>
  %53 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>],
                        iterator_types = ["parallel"]}
                        outs(%36 : tensor<4xf32>) {
  ^bb0(%out: f32):
    linalg.yield %out : f32
  } -> tensor<4xf32>
  return
}
955
956// -----
957
// A linalg.softmax whose result is unused is dead and gets removed.
// CHECK-LABEL: dead_softmax
func.func @dead_softmax(%arg0: tensor<16x64x256xf32>) -> tensor<16x64x256xf32> {
  %0 = tensor.empty() : tensor<16x64x256xf32>
  // CHECK-NOT: linalg.softmax
  %1 = linalg.softmax dimension(1)
    ins(%arg0 : tensor<16x64x256xf32>) outs(%0 : tensor<16x64x256xf32>) -> tensor<16x64x256xf32>
  return %arg0 : tensor<16x64x256xf32>
}
966
967// -----
968
// tensor.dim of a destination-style op's result folds to tensor.dim of the
// corresponding init operand, so only the two dims of %arg0 remain.
// CHECK-LABEL: func @canonicalize_dim_of_dest_style_op
//       CHECK: tensor.dim
//       CHECK: tensor.dim
//   CHECK-NOT: tensor.dim
//       CHECK: return
func.func @canonicalize_dim_of_dest_style_op(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %dim0_0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %dim1_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
  %0 = tensor.empty(%dim0_0, %dim1_0) : tensor<?x?xf32>
  %1 = linalg.copy ins(%arg0 : tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %dim0_1 = tensor.dim %1, %c0 : tensor<?x?xf32>
  %dim1_1 = tensor.dim %1, %c1 : tensor<?x?xf32>
  %2 = tensor.empty(%dim0_1, %dim1_1) : tensor<?x?xf32>
  %3 = linalg.copy ins(%1 : tensor<?x?xf32>) outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32>
  return %3: tensor<?x?xf32>
}
987// -----
988
// linalg.copy of a fill result becomes a fill directly into the copy's dest.
// CHECK-LABEL: func @canonicalize_fill_to_copy_input(
//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
//       CHECK:   %[[ZERO:.+]] = arith.constant 0.0
//       CHECK:   linalg.fill ins(%[[ZERO]] : f32) outs(%[[ARG1]] : tensor<?x?xf32>)
func.func @canonicalize_fill_to_copy_input(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0.0 : f32
  %fill = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %copy = linalg.copy ins(%fill : tensor<?x?xf32>) outs(%arg1 : tensor<?x?xf32>) -> tensor<?x?xf32>
  return %copy : tensor<?x?xf32>
}
1000
1001// -----
1002
// A fill used only as the copy's destination is dead: the copy fully
// overwrites it, so the fill's init is used directly.
// CHECK-LABEL: func @canonicalize_fill_to_copy_dest(
//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
//       CHECK:   linalg.copy ins(%[[ARG1]] : tensor<?x?xf32>) outs(%[[ARG0]] : tensor<?x?xf32>)
func.func @canonicalize_fill_to_copy_dest(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0.0 : f32
  %fill = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %copy = linalg.copy ins(%arg1 : tensor<?x?xf32>) outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
  return %copy : tensor<?x?xf32>
}
1013
1014// -----
1015
// Transposing a fill is the same as filling the transpose's dest directly
// (a splat is invariant under permutation).
// CHECK-LABEL: func @canonicalize_fill_to_transpose_input(
//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
//       CHECK:   %[[ZERO:.+]] = arith.constant 0.0
//       CHECK:   linalg.fill ins(%[[ZERO]] : f32) outs(%[[ARG1]] : tensor<?x?xf32>)
func.func @canonicalize_fill_to_transpose_input(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0.0 : f32
  %fill = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %transpose = linalg.transpose ins(%fill : tensor<?x?xf32>) outs(%arg1 : tensor<?x?xf32>) permutation = [1, 0]
  return %transpose : tensor<?x?xf32>
}
1027
1028// -----
1029
// A broadcast with no broadcasted dimensions is a no-op and folds to its input.
// CHECK-LABEL: func @broadcast_same_shape(
//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<2x3xf32>
//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<2x3xf32>)
//       CHECK-NOT:   linalg.broadcast
//       CHECK:       return %[[ARG0]] : tensor<2x3xf32>
func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>) -> tensor<2x3xf32> {
  %0 = linalg.broadcast ins(%input: tensor<2x3xf32>) outs(%init: tensor<2x3xf32>) dimensions = []
  return %0 : tensor<2x3xf32>
}
1039
1040// -----
1041
// A 1-D transpose is necessarily the identity and folds to its input.
func.func @transpose_1d(%input: tensor<16xf32>,
                        %init: tensor<16xf32>) -> tensor<16xf32> {
  %transpose = linalg.transpose
      ins(%input:tensor<16xf32>)
      outs(%init:tensor<16xf32>)
      permutation = [0]
  func.return %transpose : tensor<16xf32>
}

// CHECK-LABEL: func @transpose_1d(
//  CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<16xf32>,
//  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<16xf32>)
//   CHECK-NOT:   linalg.transpose
//       CHECK:   return %[[INPUT]] : tensor<16xf32>
1056
1057// -----
1058
// A transpose with the identity permutation folds to its input.
func.func @transpose_identity_perm(%input: tensor<16x32x64xf32>,
                                   %init: tensor<16x32x64xf32>) -> tensor<16x32x64xf32> {
  %transpose = linalg.transpose
      ins(%input:tensor<16x32x64xf32>)
      outs(%init:tensor<16x32x64xf32>)
      permutation = [0, 1, 2]
  func.return %transpose : tensor<16x32x64xf32>
}

// CHECK-LABEL: func @transpose_identity_perm(
//  CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<16x32x64xf32>,
//  CHECK-SAME:     %[[INIT:[a-zA-Z0-9]+]]: tensor<16x32x64xf32>)
//   CHECK-NOT:   linalg.transpose
//       CHECK:   return %[[INPUT]] : tensor<16x32x64xf32>
1073
1074// -----
1075
// Back-to-back transposes whose permutations compose to the identity cancel.
func.func @transpose_transpose_cancel(%input: tensor<5x4x3xf32>,
                                      %init1: tensor<4x3x5xf32>,
                                      %init2: tensor<5x4x3xf32>) -> tensor<5x4x3xf32> {
  // CHECK-LABEL: @transpose_transpose_cancel
  //  CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<5x4x3xf32>
  //  CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<4x3x5xf32>
  //  CHECK-SAME:     %[[INIT2:[a-zA-Z0-9]+]]: tensor<5x4x3xf32>
  //   CHECK-NOT:   linalg.transpose
  //       CHECK:   return %[[INPUT]] : tensor<5x4x3xf32>
  %transpose1 = linalg.transpose
      ins(%input:tensor<5x4x3xf32>)
      outs(%init1:tensor<4x3x5xf32>)
      permutation = [1, 2, 0]
  %transpose2 = linalg.transpose
      ins(%transpose1:tensor<4x3x5xf32>)
      outs(%init2:tensor<5x4x3xf32>)
      permutation = [2, 0, 1]
  func.return %transpose2 : tensor<5x4x3xf32>
}
1095
1096// -----
1097
// Back-to-back transposes fold into a single transpose with the composed
// permutation ([1,2,0] then [1,0,2] composes to [2,1,0]).
func.func @transpose_transpose_fold(%input: tensor<5x4x3xf32>,
                                    %init1: tensor<4x3x5xf32>,
                                    %init2: tensor<3x4x5xf32>) -> tensor<3x4x5xf32> {
  // CHECK-LABEL: @transpose_transpose_fold
  //  CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<5x4x3xf32>
  //  CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<4x3x5xf32>
  //  CHECK-SAME:     %[[INIT2:[a-zA-Z0-9]+]]: tensor<3x4x5xf32>
  //       CHECK:   %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[INPUT]] : tensor<5x4x3xf32>) outs(%[[INIT2]] : tensor<3x4x5xf32>) permutation = [2, 1, 0]
  //   CHECK-NOT:   linalg.transpose
  //       CHECK:   return %[[TRANSPOSE]] : tensor<3x4x5xf32>
  %transpose1 = linalg.transpose
      ins(%input:tensor<5x4x3xf32>)
      outs(%init1:tensor<4x3x5xf32>)
      permutation = [1, 2, 0]
  %transpose2 = linalg.transpose
      ins(%transpose1:tensor<4x3x5xf32>)
      outs(%init2:tensor<3x4x5xf32>)
      permutation = [1, 0, 2]
  func.return %transpose2 : tensor<3x4x5xf32>
}
1118
1119// -----
1120
// transpose(broadcast(x)) is rewritten as broadcast(transpose(x)): the
// transpose then acts on the small pre-broadcast tensor.
func.func @broadcast_transpose_fold(%input: tensor<2x4x5xf32>,
                                    %init1: tensor<1x2x3x4x5x6xf32>,
                                    %init2: tensor<1x6x2x3x5x4xf32>) -> tensor<1x6x2x3x5x4xf32> {
  // CHECK-LABEL: @broadcast_transpose_fold
  //  CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<2x4x5xf32>
  //  CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<1x2x3x4x5x6xf32>
  //  CHECK-SAME:     %[[INIT2:[a-zA-Z0-9]+]]: tensor<1x6x2x3x5x4xf32>
  //       CHECK:   %[[TMP_INIT:.+]] = tensor.empty() : tensor<2x5x4xf32>
  //       CHECK:   %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[INPUT]] : tensor<2x4x5xf32>) outs(%[[TMP_INIT]] : tensor<2x5x4xf32>) permutation = [0, 2, 1]
  //       CHECK:   %[[BROADCAST:.+]] = linalg.broadcast ins(%[[TRANSPOSE]] : tensor<2x5x4xf32>) outs(%[[INIT2]] : tensor<1x6x2x3x5x4xf32>) dimensions = [0, 3, 1]
  //       CHECK:   return %[[BROADCAST]] : tensor<1x6x2x3x5x4xf32>
  %broadcast = linalg.broadcast
      ins(%input : tensor<2x4x5xf32>)
      outs(%init1 : tensor<1x2x3x4x5x6xf32>)
      dimensions = [0, 2, 5]
  %transpose = linalg.transpose
      ins(%broadcast : tensor<1x2x3x4x5x6xf32>)
      outs(%init2 : tensor<1x6x2x3x5x4xf32>)
      permutation = [0, 5, 1, 2, 4, 3]
  func.return %transpose : tensor<1x6x2x3x5x4xf32>
}
1142
1143// -----
1144
// Same broadcast/transpose swap with dynamic dimensions: the intermediate
// init is created with tensor.dim queries on the input.
func.func @broadcast_transpose_fold_dynamic(%input: tensor<?x?x5xf32>,
                                            %init1: tensor<1x?x3x?x5x6xf32>,
                                            %init2: tensor<1x3x?x6x5x?xf32>) -> tensor<1x3x?x6x5x?xf32> {
  // CHECK-LABEL: @broadcast_transpose_fold_dynamic
  //  CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<?x?x5xf32>
  //  CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<1x?x3x?x5x6xf32>
  //  CHECK-SAME:     %[[INIT2:[a-zA-Z0-9]+]]: tensor<1x3x?x6x5x?xf32>
  //   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
  //   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
  //       CHECK:   %[[DIM0:.+]] = tensor.dim %[[INPUT]], %[[C0]] : tensor<?x?x5xf32>
  //       CHECK:   %[[DIM1:.+]] = tensor.dim %[[INPUT]], %[[C1]] : tensor<?x?x5xf32>
  //       CHECK:   %[[TMP_INIT:.+]] = tensor.empty(%[[DIM1]], %[[DIM0]]) : tensor<?x5x?xf32>
  //       CHECK:   %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[INPUT]] : tensor<?x?x5xf32>) outs(%[[TMP_INIT]] : tensor<?x5x?xf32>) permutation = [1, 2, 0]
  //       CHECK:   %[[BROADCAST:.+]] = linalg.broadcast ins(%[[TRANSPOSE]] : tensor<?x5x?xf32>) outs(%[[INIT2]] : tensor<1x3x?x6x5x?xf32>) dimensions = [0, 1, 3]
  //       CHECK:   return %[[BROADCAST]] : tensor<1x3x?x6x5x?xf32>
  %broadcast = linalg.broadcast
      ins(%input : tensor<?x?x5xf32>)
      outs(%init1 : tensor<1x?x3x?x5x6xf32>)
      dimensions = [0, 2, 5]
  %transpose = linalg.transpose
      ins(%broadcast : tensor<1x?x3x?x5x6xf32>)
      outs(%init2 : tensor<1x3x?x6x5x?xf32>)
      permutation = [0, 2, 3, 5, 4, 1]
  func.return %transpose : tensor<1x3x?x6x5x?xf32>
}
1170
1171// -----
1172
// When the post-swap transpose is an identity (1-D input), the pair folds to
// a single broadcast.
func.func @broadcast_transpose_fold_2dim(%input: tensor<2xf32>,
                                         %init1: tensor<2x4xf32>,
                                         %init2: tensor<4x2xf32>) -> tensor<4x2xf32> {
  // CHECK-LABEL: @broadcast_transpose_fold_2dim
  //  CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32>
  //  CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x4xf32>
  //  CHECK-SAME:     %[[INIT2:[a-zA-Z0-9]+]]: tensor<4x2xf32>
  //       CHECK:   %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<4x2xf32>) dimensions = [0]
  //       CHECK:   return %[[BROADCAST]] : tensor<4x2xf32>
  %broadcast = linalg.broadcast
      ins(%input : tensor<2xf32>)
      outs(%init1 : tensor<2x4xf32>)
      dimensions = [1]
  %transpose = linalg.transpose
      ins(%broadcast : tensor<2x4xf32>)
      outs(%init2 : tensor<4x2xf32>)
      permutation = [1, 0]
  func.return %transpose : tensor<4x2xf32>
}
1192
1193// -----
1194
// Concatenating two fills of the same (constant) value becomes a single fill
// of the concatenated empty tensors.
func.func @concats_of_fill(
    %arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index)
    -> tensor<5x?x?xf32>
{
  %cst0 = arith.constant 0.0 : f32
  %cst1 = arith.constant 0.0 : f32
  %0 = tensor.empty(%arg0, %arg1) : tensor<5x?x?xf32>
  %1 = linalg.fill ins(%cst0 : f32) outs(%0 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
  %2 = tensor.empty(%arg2, %arg3) : tensor<5x?x?xf32>
  %3 = linalg.fill ins(%cst1 : f32) outs(%2 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
  %4 = tensor.concat dim(1) %1, %3 : (tensor<5x?x?xf32>, tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
  return %4 : tensor<5x?x?xf32>
}
//       CHECK: func @concats_of_fill(
//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index,
//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index)
//   CHECK-DAG:   %[[CST:.+]] = arith.constant 0.0
//   CHECK-DAG:   %[[EMPTY0:.+]] = tensor.empty(%[[ARG0]], %[[ARG1]])
//   CHECK-DAG:   %[[EMPTY1:.+]] = tensor.empty(%[[ARG2]], %[[ARG3]])
//       CHECK:   %[[CONCAT:.+]] = tensor.concat dim(1) %[[EMPTY0]], %[[EMPTY1]]
//       CHECK:   %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[CONCAT]] :
//       CHECK:   return %[[FILL]]
1219
1220// -----
1221
// A 1-D transpose on buffers must NOT be folded away: it writes %init, which
// is an observable side effect.
func.func @transpose_buffer(%input: memref<?xf32>,
                            %init: memref<?xf32>) {
  linalg.transpose ins(%input:memref<?xf32>)
                   outs(%init:memref<?xf32>)
                   permutation = [0]
  func.return
}

// CHECK-LABEL:   func.func @transpose_buffer(
//  CHECK-SAME:            %[[VAL_0:.*]]: memref<?xf32>,
//  CHECK-SAME:            %[[VAL_1:.*]]: memref<?xf32>) {
//       CHECK:     linalg.transpose ins(%[[VAL_0]] : memref<?xf32>)
//  CHECK-SAME:       outs(%[[VAL_1]] : memref<?xf32>) permutation = [0]
1235
1236// -----
1237
1238// This test checks linalg op has a recursive memory effect. Otherwise
1239// linalg.map without a user would be DCEd.
1240func.func @recursive_effect(%arg : tensor<1xf32>) {
1241  %init = arith.constant dense<0.0> : tensor<1xf32>
1242  %mapped = linalg.map ins(%arg:tensor<1xf32>) outs(%init :tensor<1xf32>)
1243            (%in : f32) {
1244              vector.print %in : f32
1245              linalg.yield %in : f32
1246            }
1247  func.return
1248}
1249
1250// CHECK-LABEL: @recursive_effect
1251//       CHECK: linalg.map
1252