// RUN: mlir-opt %s -eliminate-empty-tensors -empty-tensor-to-alloc-tensor -one-shot-bufferize="bufferize-function-boundaries" -cse -canonicalize -split-input-file | FileCheck %s
// RUN: mlir-opt %s -eliminate-empty-tensors | FileCheck %s --check-prefix=CHECK-ELIM
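
// A rough sketch of the rewrite performed by -eliminate-empty-tensors
// (illustrative IR only, not matched by any check prefix; the SSA names are
// made up). A tensor.empty whose filled result is eventually inserted into
// another tensor, e.g.:
//   %0 = tensor.empty(%sz) : tensor<?xf32>
//   %1 = linalg.fill ins(%f : f32) outs(%0 : tensor<?xf32>) -> tensor<?xf32>
//   %2 = tensor.insert_slice %1 into %t[42][%sz][1]
//       : tensor<?xf32> into tensor<?xf32>
// is rewritten so that the computation uses a slice of the destination:
//   %0 = tensor.extract_slice %t[42][%sz][1] : tensor<?xf32> to tensor<?xf32>
//   %1 = linalg.fill ins(%f : f32) outs(%0 : tensor<?xf32>) -> tensor<?xf32>
//   %2 = tensor.insert_slice %1 into %t[42][%sz][1]
//       : tensor<?xf32> into tensor<?xf32>
// After one-shot bufferization the fill then writes directly into %t's
// buffer and the insert_slice becomes a self-copy that canonicalizes away.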

//      CHECK: func @buffer_forwarding_conflict(
// CHECK-SAME:   %[[FUNC_ARG:[0-9a-zA-Z]*]]: memref<?xf32>
// CHECK-SAME:   %[[sz:[0-9a-zA-Z]*]]: index
func.func @buffer_forwarding_conflict(
  %t: tensor<?xf32> {bufferization.buffer_layout = affine_map<(d0) -> (d0)>, bufferization.writable = true},
  %sz: index)
    -> (tensor<?xf32>, tensor<?xf32>)
{
  %f0 = arith.constant 0.0: f32

  //     CHECK: %[[EXTRACT_SLICE_ALLOC:.*]] = memref.alloc(%[[sz]])
  //     CHECK: linalg.fill ins({{.*}} : f32) outs(%[[EXTRACT_SLICE_ALLOC]] : memref<?xf32>)
  // Alloc is needed for the **first** insert_slice (due to backward traversal during analysis).
  //     CHECK: %[[DIM:.*]] = memref.dim %[[FUNC_ARG]]
  // This allocates a buffer of the full size of %t so that %t can be cloned.
  //     CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM]])
  // tensor.empty itself does not alloc but forwards to the **second**
  // insert_slice. The pass replaces the tensor.empty with an out-of-place
  // extract_slice.
  %a = tensor.empty(%sz) : tensor<?xf32>
  %f = linalg.fill ins(%f0 : f32) outs(%a : tensor<?xf32>) -> tensor<?xf32>

  //     CHECK: memref.copy %[[FUNC_ARG]], %[[ALLOC]] : memref<?xf32> to memref<?xf32>
  //     CHECK: %[[SV0_ALLOC:.*]] = memref.subview %[[ALLOC]][0] [%[[sz]]] [1] : memref<?xf32> to memref<?xf32, strided<[1]>>
  //     CHECK: memref.copy %[[EXTRACT_SLICE_ALLOC]], %[[SV0_ALLOC]] : memref<?xf32> to memref<?xf32, strided<[1]>>
  %r0 = tensor.insert_slice %f into %t[0][%sz][1]: tensor<?xf32> into tensor<?xf32>

  //     CHECK: %[[T_SUBVIEW:.*]] =  memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1]
  //     CHECK: memref.copy %[[EXTRACT_SLICE_ALLOC]], %[[T_SUBVIEW]]
  %r1 = tensor.insert_slice %f into %t[42][%sz][1]: tensor<?xf32> into tensor<?xf32>

  return %r0, %r1: tensor<?xf32>, tensor<?xf32>
}

// -----

//      CHECK: func @buffer_forwarding_no_conflict(
// CHECK-SAME:   %[[FUNC_ARG:[0-9a-zA-Z]*]]: memref<?xf32>
// CHECK-SAME:   %[[sz:[0-9a-zA-Z]*]]: index
func.func @buffer_forwarding_no_conflict(
  %t: tensor<?xf32> {bufferization.buffer_layout = affine_map<(d0) -> (d0)>, bufferization.writable = true},
  %sz: index)
    -> (tensor<?xf32>)
{
  %f0 = arith.constant 0.0: f32

  // tensor.empty itself does not alloc but forwards to the insert_slice.
  // EmptyTensorElimination replaces the tensor.empty with an in-place
  // extract_slice.
  // CHECK: %[[T_SUBVIEW:.*]] =  memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1]
  %a = tensor.empty(%sz) : tensor<?xf32>

  // CHECK: linalg.fill ins({{.*}} : f32) outs(%[[T_SUBVIEW]] : memref<?xf32
  %f = linalg.fill ins(%f0 : f32) outs(%a : tensor<?xf32>) -> tensor<?xf32>

  // Self-copy canonicalizes away later.
  %r1 = tensor.insert_slice %f into %t[42][%sz][1]: tensor<?xf32> into tensor<?xf32>

  return %r1: tensor<?xf32>
}

// -----

//      CHECK: func @insertion_point_inside_loop(
// CHECK-SAME:     %[[t:.*]]: memref<?xf32, strided{{.*}}>, %[[sz:.*]]: index)
func.func @insertion_point_inside_loop(%t : tensor<?xf32>, %sz : index) -> (tensor<?xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c5 = arith.constant 5 : index

  // CHECK-NOT: memref.alloc
  %blank = tensor.empty() : tensor<5xf32>

  // CHECK: scf.for %[[iv:.*]] = %{{.*}} to %[[sz]] step %{{.*}} {
  %r = scf.for %iv = %c0 to %sz step %c5 iter_args(%bb = %t) -> (tensor<?xf32>) {
    // CHECK: %[[subview:.*]] = memref.subview %[[t]][%[[iv]]] [5] [1]
    %iv_i32 = arith.index_cast %iv : index to i32
    %f = arith.sitofp %iv_i32 : i32 to f32

    // CHECK: linalg.fill ins(%{{.*}}{{.*}}outs(%[[subview]]
    %filled = linalg.fill ins(%f : f32) outs(%blank : tensor<5xf32>) -> tensor<5xf32>

    // CHECK-NOT: memref.copy
    %inserted = tensor.insert_slice %filled into %bb[%iv][5][1] : tensor<5xf32> into tensor<?xf32>
    scf.yield %inserted : tensor<?xf32>
  }

  return %r : tensor<?xf32>
}

// -----

//      CHECK: func @insertion_point_outside_loop(
// CHECK-SAME:     %[[t:.*]]: memref<?xf32, strided{{.*}}>, %[[sz:.*]]: index, %[[idx:.*]]: index)
func.func @insertion_point_outside_loop(%t : tensor<?xf32>, %sz : index,
                                        %idx : index) -> (tensor<?xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c5 = arith.constant 5 : index

  // CHECK-NOT: memref.alloc
  %blank = tensor.empty() : tensor<5xf32>

  // CHECK: scf.for %[[iv:.*]] = %{{.*}} to %[[sz]] step %{{.*}} {
  %r = scf.for %iv = %c0 to %sz step %c5 iter_args(%bb = %t) -> (tensor<?xf32>) {
    %iv_i32 = arith.index_cast %iv : index to i32
    %f = arith.sitofp %iv_i32 : i32 to f32

    // CHECK: %[[subview:.*]] = memref.subview %[[t]][%[[idx]]] [5] [1]
    // CHECK: linalg.fill ins(%{{.*}}{{.*}}outs(%[[subview]]
    %filled = linalg.fill ins(%f : f32) outs(%blank : tensor<5xf32>) -> tensor<5xf32>

    // CHECK-NOT: memref.copy
    %inserted = tensor.insert_slice %filled into %bb[%idx][5][1] : tensor<5xf32> into tensor<?xf32>
    scf.yield %inserted : tensor<?xf32>
  }

  return %r : tensor<?xf32>
}

// -----

// EmptyTensorElimination currently does not apply to chains in which the type
// changes. (Casts are supported.) This test just ensures that we do not crash
// or generate IR that does not verify.

// CHECK-LABEL: func @shape_mismatch
func.func @shape_mismatch(%t: tensor<5x6x128xf32>) -> tensor<5x6x128xf32> {
  %cst = arith.constant 8.0 : f32
  %0 = tensor.empty() : tensor<128xf32>
  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<128xf32>) -> tensor<128xf32>
  %2 = tensor.expand_shape %1 [[0, 1, 2]] output_shape [1, 1, 128]
      : tensor<128xf32> into tensor<1x1x128xf32>
  %3 = tensor.insert_slice %2 into %t[2, 3, 0][1, 1, 128][1, 1, 1]
      : tensor<1x1x128xf32> into tensor<5x6x128xf32>
  return %3 : tensor<5x6x128xf32>
}

// -----

// CHECK-LABEL: func @cast(
//  CHECK-SAME:     %[[t:.*]]: memref<256xf32,
//       CHECK:   %[[sv:.*]] = memref.subview %[[t]]
//       CHECK:   linalg.fill {{.*}} outs(%[[sv]]
//       CHECK:   return %[[t]]
func.func @cast(%t: tensor<256xf32>) -> tensor<256xf32> {
  %cst = arith.constant 8.0 : f32
  %c128 = arith.constant 128 : index
  %0 = tensor.empty(%c128) : tensor<?xf32>
  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?xf32>) -> tensor<?xf32>
  %2 = tensor.cast %1 : tensor<?xf32> to tensor<128xf32>
  %3 = tensor.insert_slice %2 into %t[2][128][1]
      : tensor<128xf32> into tensor<256xf32>
  return %3 : tensor<256xf32>
}

// -----

//      CHECK: func @parallel_insert_slice(
// CHECK-SAME:   %[[FUNC_ARG:[0-9a-zA-Z]*]]: memref<?xf32>
// CHECK-SAME:   %[[sz:[0-9a-zA-Z]*]]: index
func.func @parallel_insert_slice(
  %t: tensor<?xf32> {bufferization.buffer_layout = affine_map<(d0) -> (d0)>, bufferization.writable = true},
  %sz: index)
    -> (tensor<?xf32>)
{
  %f0 = arith.constant 0.0: f32
  %c512 = arith.constant 512 : index

  %r1 = scf.forall (%iv) in (%c512) shared_outs(%o = %t) -> (tensor<?xf32>) {
    // tensor.empty itself does not alloc but forwards to the insert_slice.
    // EmptyTensorElimination replaces the tensor.empty with an in-place
    // extract_slice.
    // CHECK: %[[T_SUBVIEW:.*]] =  memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1]
    %a = tensor.empty(%sz) : tensor<?xf32>

    // CHECK: linalg.fill ins({{.*}} : f32) outs(%[[T_SUBVIEW]] : memref<?xf32
    %f = linalg.fill ins(%f0 : f32) outs(%a : tensor<?xf32>) -> tensor<?xf32>

    // Self-copy canonicalizes away later.
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %f into %o[42][%sz][1]: tensor<?xf32> into tensor<?xf32>
    }
  }

  return %r1: tensor<?xf32>
}

// -----

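// Both branches of the scf.if forward their tensor.empty into the same
// insert_slice destination, so both branches are expected to bufferize
// in place into the same subview of the function argument (no allocation
// in either branch).
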
// CHECK-LABEL: func @eliminate_multiple_ops(
//  CHECK-SAME:   %[[FUNC_ARG:[0-9a-zA-Z]*]]: memref<?xf32>
//  CHECK-SAME:   %[[sz:[0-9a-zA-Z]*]]: index
func.func @eliminate_multiple_ops(%t: tensor<?xf32> {bufferization.buffer_layout = affine_map<(d0) -> (d0)>}, %sz: index, %c: i1)
    -> (tensor<?xf32>)
{
  %cst1 = arith.constant 0.0: f32
  %cst2 = arith.constant 1.0: f32

  // CHECK: %[[r:.*]] = scf.if %{{.*}} -> (memref
  %if = scf.if %c -> tensor<?xf32> {
    // CHECK: %[[T_SUBVIEW_1:.*]] =  memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1]
    %a1 = tensor.empty(%sz) : tensor<?xf32>
    // CHECK: linalg.fill ins({{.*}} : f32) outs(%[[T_SUBVIEW_1]] : memref<?xf32
    %f1 = linalg.fill ins(%cst1 : f32) outs(%a1 : tensor<?xf32>) -> tensor<?xf32>
    // CHECK: scf.yield %[[T_SUBVIEW_1]]
    scf.yield %f1 : tensor<?xf32>
  } else {
    // CHECK: %[[T_SUBVIEW_2:.*]] =  memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1]
    %a2 = tensor.empty(%sz) : tensor<?xf32>
    // CHECK: linalg.fill ins({{.*}} : f32) outs(%[[T_SUBVIEW_2]] : memref<?xf32
    %f2 = linalg.fill ins(%cst2 : f32) outs(%a2 : tensor<?xf32>) -> tensor<?xf32>
    // CHECK: scf.yield %[[T_SUBVIEW_2]]
    scf.yield %f2 : tensor<?xf32>
  }

  // Self-copy could canonicalize away later.
  // CHECK: %[[T_SUBVIEW_3:.*]] =  memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1]
  // CHECK: memref.copy %[[r]], %[[T_SUBVIEW_3]]
  %r1 = tensor.insert_slice %if into %t[42][%sz][1]: tensor<?xf32> into tensor<?xf32>
  return %r1: tensor<?xf32>
}

// -----

// This is a regression test. Make sure that the tensor.extract_slice is not
// eliminated.

// CHECK-LABEL: func.func @regression_do_not_eliminate_non_empty(
//       CHECK:   memref.subview
//       CHECK:   memref.subview
//       CHECK:   memref.copy
func.func @regression_do_not_eliminate_non_empty(
    %t: tensor<10xf32>, %t2: tensor<10xf32>) -> tensor<10xf32> {
  %1 = tensor.extract_slice %t[0] [5] [1] : tensor<10xf32> to tensor<5xf32>
  %2 = tensor.insert_slice %1 into %t2[1] [5] [1]
      : tensor<5xf32> into tensor<10xf32>
  return %2 : tensor<10xf32>
}

// -----

// This is a regression test. Make sure that there is no crash.

// CHECK-LABEL: func.func @regression_insert_of_bbarg(
func.func @regression_insert_of_bbarg(%t0: tensor<5xf32>, %t1: tensor<10xf32>) -> tensor<10xf32> {
  %0 = tensor.insert_slice %t0 into %t1 [2] [5] [1] : tensor<5xf32> into tensor<10xf32>
  return %0 : tensor<10xf32>
}

// -----

// This is a regression test. Make sure that there is no crash.

// CHECK-LABEL: func.func @regression_eliminate_equivalent_only(
func.func @regression_eliminate_equivalent_only(%sz: index, %p: index, %t0: tensor<?x16xi8>) -> tensor<?x16xi8> {
  %c0 = arith.constant 0 : index
  %c8 = arith.constant 8 : index
  %c16 = arith.constant 16 : index
  %27 = tensor.empty(%sz) : tensor<?x8xi32>
  %extracted_slice = tensor.extract_slice %27[0, 0] [%p, 8] [1, 1] : tensor<?x8xi32> to tensor<?x8xi32>
  %28 = scf.for %arg4 = %c0 to %c16 step %c8 iter_args(%arg5 = %t0) -> (tensor<?x16xi8>) {
    %inserted_slice = tensor.insert_slice %extracted_slice into %27[0, 0] [%sz, 8] [1, 1] : tensor<?x8xi32> into tensor<?x8xi32>
    %extracted_slice_2 = tensor.extract_slice %arg5[%p, %p] [%sz, 8] [1, 1] : tensor<?x16xi8> to tensor<?x8xi8>
    %32 = linalg.generic
        {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
        iterator_types = ["parallel", "parallel"]}
        ins(%inserted_slice : tensor<?x8xi32>) outs(%extracted_slice_2 : tensor<?x8xi8>) {
    ^bb0(%in: i32, %out: i8):
      %tr = arith.trunci %in : i32 to i8
      linalg.yield %tr : i8
    } -> tensor<?x8xi8>
    %inserted_slice_3 = tensor.insert_slice %32 into %arg5[%p, %arg4] [%sz, 8] [1, 1] : tensor<?x8xi8> into tensor<?x16xi8>
    scf.yield %inserted_slice_3 : tensor<?x16xi8>
  }
  func.return %28 : tensor<?x16xi8>
}

// -----

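// The offsets of the insert_slice are produced by two different ops, so the
// injected subset extraction must be inserted at a point that follows both of
// them. The tensor.empty is still expected to be eliminated (no allocation
// remains after bufferization).
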
// CHECK-LABEL: func.func @regression_multiple_insertion_points(
//   CHECK-NOT:   memref.alloc
func.func @regression_multiple_insertion_points(%t1: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %empty = tensor.empty() : tensor<2x5xf32>
  %f0 = arith.constant 5.5 : f32
  %0 = "test.foo"() : () -> (index)
  %1 = "test.bar"() : () -> (index)
  %filled = linalg.fill ins(%f0 : f32) outs(%empty : tensor<2x5xf32>) -> tensor<2x5xf32>
  %2 = tensor.insert_slice %filled into %t1 [%0, %1] [2, 5] [1, 1] : tensor<2x5xf32> into tensor<?x?xf32>
  return %2 : tensor<?x?xf32>
}

// -----

// CHECK-LABEL: func @materialize_in_destination(
//  CHECK-SAME:     %[[m:.*]]: memref<5xf32, strided<[?], offset: ?>>,
//       CHECK:   linalg.fill {{.*}} outs(%[[m]]
//       CHECK:   return %[[m]]
func.func @materialize_in_destination(%t: tensor<5xf32>, %f: f32) -> tensor<5xf32> {
  %0 = tensor.empty() : tensor<5xf32>
  %filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
  %1 = bufferization.materialize_in_destination %filled in %t : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32>
  return %1 : tensor<5xf32>
}

// -----

// CHECK-LABEL: func @materialize_in_destination_buffer(
//  CHECK-SAME:     %[[m:.*]]: memref<5xf32>,
//  CHECK-NEXT:   linalg.fill {{.*}} outs(%[[m]]
//  CHECK-NEXT:   return
func.func @materialize_in_destination_buffer(%m: memref<5xf32>, %f: f32) {
  %0 = tensor.empty() : tensor<5xf32>
  %filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
  bufferization.materialize_in_destination %filled in restrict writable %m : (tensor<5xf32>, memref<5xf32>) -> ()
  return
}

// -----

// CHECK-LABEL: func @linalg_copy(
//  CHECK-SAME:     %[[m:.*]]: memref<5xf32, strided<[?], offset: ?>>,
//       CHECK:   linalg.fill {{.*}} outs(%[[m]]
//       CHECK:   return %[[m]]
func.func @linalg_copy(%t: tensor<5xf32>, %f: f32) -> tensor<5xf32> {
  %0 = tensor.empty() : tensor<5xf32>
  %filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
  %1 = linalg.copy ins(%filled : tensor<5xf32>) outs(%t : tensor<5xf32>) -> tensor<5xf32>
  return %1 : tensor<5xf32>
}

// -----

// CHECK-LABEL: func @linalg_copy_empty(
// CHECK: %[[ret:.*]] = memref.alloc()
// CHECK-NEXT: return %[[ret]]
func.func @linalg_copy_empty() -> tensor<26xi32> {
  %0 = tensor.empty() : tensor<26xi32>
  %1 = linalg.copy ins(%0 : tensor<26xi32>) outs(%0 : tensor<26xi32>) -> tensor<26xi32>
  return %1 : tensor<26xi32>
}

// -----

// CHECK-ELIM-LABEL: func @multiple_materialize_in_destination_buffer(
//  CHECK-ELIM-SAME:     %[[m:.*]]: memref<5xf32>
//       CHECK-ELIM:   tensor.empty
//       CHECK-ELIM:   bufferization.to_tensor %[[m]] restrict writable
//       CHECK-ELIM:   bufferization.materialize_in_destination {{.*}} in writable %[[m]]
func.func @multiple_materialize_in_destination_buffer(%m: memref<5xf32>, %f: f32, %f2: f32, %c: i1) {
  %0 = tensor.empty() : tensor<5xf32>
  %filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>

  %1 = tensor.empty() : tensor<5xf32>
  %filled2 = linalg.fill ins(%f2 : f32) outs(%1 : tensor<5xf32>) -> tensor<5xf32>

  %selected = scf.if %c -> tensor<5xf32> {
    scf.yield %filled : tensor<5xf32>
  } else {
    scf.yield %filled2 : tensor<5xf32>
  }
  bufferization.materialize_in_destination %selected in restrict writable %m : (tensor<5xf32>, memref<5xf32>) -> ()
  return
}

// -----

// `EmptyTensorElimination` fails to find a valid insertion point for the
// newly injected `SubsetExtraction`: the extraction would have to be inserted
// before the linalg.fill ops that use the empty tensors, but its source (the
// insert_slice destination) is defined only after them.
// CHECK-LABEL:   func.func @fail_to_eliminate_any_empty_tensors
func.func @fail_to_eliminate_any_empty_tensors() -> tensor<5x6x128xf32> {
  %cst_1 = arith.constant 1.0 : f32
  %cst_2 = arith.constant 2.0 : f32
  // CHECK: memref.alloc
  // CHECK: memref.alloc
  // CHECK: memref.alloc
  %empty_1 = tensor.empty() : tensor<5x6x64xf32>
  %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
  %empty_2 = tensor.empty() : tensor<5x6x64xf32>
  %res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_2 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
  %cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
  // CHECK: memref.copy
  %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
  %inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
  return %inserted_slice_2 : tensor<5x6x128xf32>
}

// -----

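// Only %empty_1 is expected to be eliminated here: its insert_slice
// destination (%cancatenated_empty) is defined before the corresponding
// linalg.fill. The destination of the second insert_slice
// (%inserted_slice_1) is defined only after the fill of %res_2, so %empty_2
// still bufferizes to an allocation.
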
// CHECK-LABEL:   func.func @succeed_to_eliminate_one_empty_tensor
func.func @succeed_to_eliminate_one_empty_tensor() -> tensor<5x6x128xf32> {
  %cst_1 = arith.constant 1.0 : f32
  %cst_2 = arith.constant 2.0 : f32
  // CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x128xf32>
  // CHECK: memref.alloc
  // CHECK-NOT: memref.alloc
  %cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
  %empty_1 = tensor.empty() : tensor<5x6x64xf32>
  %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
  %empty_2 = tensor.empty() : tensor<5x6x64xf32>
  %res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_2 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
  // CHECK: memref.copy
  %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
  %inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
  return %inserted_slice_2 : tensor<5x6x128xf32>
}

// -----

// `EmptyTensorElimination` replaces only the specific use of the tensor.empty
// that has been tracked with the newly injected `SubsetExtraction`; other
// uses of the same tensor.empty keep the original value.

// CHECK-ELIM-LABEL:   func.func @multi_use_of_the_same_tensor_empty
// CHECK-LABEL:   func.func @multi_use_of_the_same_tensor_empty
func.func @multi_use_of_the_same_tensor_empty() -> tensor<5x6x128xf32> {
  %cst_1 = arith.constant 1.0 : f32
  %cst_2 = arith.constant 2.0 : f32
  %cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
  %empty_1 = tensor.empty() : tensor<5x6x64xf32>
  // CHECK-ELIM: %[[VAL_3:.*]] = tensor.extract_slice
  // CHECK-ELIM: linalg.fill ins(%[[VAL_0:.*]] : f32) outs(%[[VAL_3]]
  // CHECK-ELIM-NOT: linalg.fill ins(%[[VAL_1:.*]] : f32) outs(%[[VAL_3]]
  %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
  %res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
  // CHECK: memref.copy
  %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
  // CHECK-NOT: memref.copy
  %inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
  return %inserted_slice_2 : tensor<5x6x128xf32>
}

// -----

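// Eliminating %empty_1 in favor of a subset of %arg1 would make the
// linalg.generic read data from %arg1, a read that does not exist in the
// original program, so the tensor.empty is expected to stay and to bufferize
// to an allocation.
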
// CHECK-LABEL:   func.func @multi_use_of_the_same_tensor_empty_creates_non_existent_read
// CHECK-ELIM-LABEL:   func.func @multi_use_of_the_same_tensor_empty_creates_non_existent_read
func.func @multi_use_of_the_same_tensor_empty_creates_non_existent_read(%arg1: tensor<5x6x128xf32>, %arg2: tensor<5x6x64xf32>)
    -> (tensor<5x6x128xf32>, tensor<5x6x64xf32>) {
  %cst_1 = arith.constant 1.0 : f32
  %empty_1 = tensor.empty() : tensor<5x6x64xf32>
  // CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x64xf32>
  // CHECK-NOT: memref.alloc
  %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
  %res_2 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
    iterator_types = ["parallel", "parallel", "parallel"]
  }
  ins(%empty_1 : tensor<5x6x64xf32>)
  outs(%arg2 : tensor<5x6x64xf32>) {
  ^bb0(%in: f32, %out: f32):
    %res = arith.addf %in, %in : f32
    linalg.yield %res : f32
  } -> tensor<5x6x64xf32>
  // CHECK-NOT: memref.copy
  %inserted_slice_1 = tensor.insert_slice %res_1 into %arg1[0, 0, 0][5, 6, 64][1, 1, 1]
      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
  return %inserted_slice_1, %res_2 : tensor<5x6x128xf32>, tensor<5x6x64xf32>
}

// -----

// CHECK-LABEL:   func.func @direct_use_of_tensor_empty
func.func @direct_use_of_tensor_empty(%arg0: tensor<5x6x128xf32>) -> tensor<5x6x128xf32> {
  // CHECK-NOT: memref.alloc
  %empty_1 = tensor.empty() : tensor<5x6x64xf32>
  %inserted_slice_1 = tensor.insert_slice %empty_1 into %arg0[0, 0, 0][5, 6, 64][1, 1, 1]
      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
  return %inserted_slice_1 : tensor<5x6x128xf32>
}