// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="allow-return-allocs-from-loops bufferize-function-boundaries" -cse -canonicalize -drop-equivalent-buffer-results -split-input-file | FileCheck %s

// Run fuzzer with different seeds.
// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="allow-return-allocs-from-loops analysis-heuristic=fuzzer test-analysis-only analysis-fuzzer-seed=23 bufferize-function-boundaries" -split-input-file -o /dev/null
// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="allow-return-allocs-from-loops analysis-heuristic=fuzzer test-analysis-only analysis-fuzzer-seed=59 bufferize-function-boundaries" -split-input-file -o /dev/null
// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="allow-return-allocs-from-loops analysis-heuristic=fuzzer test-analysis-only analysis-fuzzer-seed=91 bufferize-function-boundaries" -split-input-file -o /dev/null

// Test bufferization using memref types that have no layout map.
// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="allow-return-allocs-from-loops unknown-type-conversion=identity-layout-map function-boundary-type-conversion=identity-layout-map bufferize-function-boundaries" -split-input-file -o /dev/null

// CHECK-LABEL: func @scf_for_yield_only(
//  CHECK-SAME:   %[[A:[a-zA-Z0-9]*]]: memref<?xf32, strided<[?], offset: ?>>,
//  CHECK-SAME:   %[[t:[a-zA-Z0-9]*]]: memref<?xf32, strided<[?], offset: ?>>
//  CHECK-SAME:   ) -> memref<?xf32> {
func.func @scf_for_yield_only(
    %A : tensor<?xf32> {bufferization.writable = false},
    %B : tensor<?xf32> {bufferization.writable = true},
    %lb : index, %ub : index, %step : index)
  -> (tensor<?xf32>, tensor<?xf32>)
{
  //     CHECK:   %[[ALLOC_FOR_A:.*]] = memref.alloc
  //     CHECK:   memref.copy %[[A]], %[[ALLOC_FOR_A]]

  // The first scf.for remains but just turns into dead code.
  %r0 = scf.for %i = %lb to %ub step %step iter_args(%t = %A) -> (tensor<?xf32>) {
    scf.yield %t : tensor<?xf32>
  }

  // The second scf.for remains but just turns into dead code.
  %r1 = scf.for %i = %lb to %ub step %step iter_args(%t = %B) -> (tensor<?xf32>) {
    scf.yield %t : tensor<?xf32>
  }

  //     CHECK:   return %[[ALLOC_FOR_A]] : memref<?xf32>
  // CHECK-NOT:   dealloc
  return %r0, %r1: tensor<?xf32>, tensor<?xf32>
}

// -----

// CHECK-LABEL: func @scf_for_is_reading(
//  CHECK-SAME:     %[[A:.*]]: memref<?xf32, strided<[?], offset: ?>>, %[[B:.*]]: memref<?xf32, strided<[?], offset: ?>>
func.func @scf_for_is_reading(%A : tensor<?xf32>, %B : tensor<?xf32>,
                              %lb : index, %ub : index)
  -> (f32, f32)
{
  %c1 = arith.constant 1 : index
  %cst = arith.constant 0.0 : f32

  // This is a regression test to make sure that an alloc + copy is emitted.

  // CHECK: %[[alloc:.*]] = memref.alloc
  // CHECK: memref.copy %[[A]], %[[alloc]]
  // CHECK: scf.for {{.*}} iter_args(%{{.*}} = %[[alloc]])
  %0 = scf.for %iv = %lb to %ub step %c1 iter_args(%1 = %A) -> tensor<?xf32> {
    %r = linalg.fill ins(%cst : f32) outs(%1 : tensor<?xf32>) -> tensor<?xf32>
    scf.yield %B : tensor<?xf32>
  }
  %1 = tensor.extract %0[%c1] : tensor<?xf32>
  %2 = tensor.extract %A[%c1] : tensor<?xf32>
  return %1, %2 : f32, f32
}

// -----

// Ensure that the function bufferizes without error. This tests pre-order
// traversal of scf.for loops during bufferization. No need to check the IR,
// just want to make sure that it does not crash.

// CHECK-LABEL: func @nested_scf_for
func.func @nested_scf_for(%A : tensor<?xf32> {bufferization.writable = true},
                          %v : vector<5xf32>) -> tensor<?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c10 = arith.constant 10 : index
  %r1 = scf.for %i = %c0 to %c10 step %c1 iter_args(%B = %A) -> tensor<?xf32> {
    %r2 = scf.for %j = %c0 to %c10 step %c1 iter_args(%C = %B) -> tensor<?xf32> {
      %w = vector.transfer_write %v, %C[%c0] : vector<5xf32>, tensor<?xf32>
      scf.yield %w : tensor<?xf32>
    }
    scf.yield %r2 : tensor<?xf32>
  }
  return %r1 : tensor<?xf32>
}

// -----

// CHECK-LABEL: func @scf_for_with_tensor.insert_slice
//  CHECK-SAME:   %[[A:[a-zA-Z0-9]*]]: memref<?xf32, strided<[?], offset: ?>>
//  CHECK-SAME:   %[[B:[a-zA-Z0-9]*]]: memref<?xf32, strided<[?], offset: ?>>
//  CHECK-SAME:   %[[C:[a-zA-Z0-9]*]]: memref<4xf32, strided<[?], offset: ?>>
func.func @scf_for_with_tensor.insert_slice(
    %A : tensor<?xf32> {bufferization.writable = false},
    %B : tensor<?xf32> {bufferization.writable = true},
    %C : tensor<4xf32> {bufferization.writable = false},
    %lb : index, %ub : index, %step : index)
  -> (tensor<?xf32>, tensor<?xf32>)
{
  //     CHECK:   %[[ALLOC_FOR_A:.*]] = memref.alloc
  //     CHECK:   memref.copy %[[A]], %[[ALLOC_FOR_A]]

  //     CHECK:   scf.for {{.*}}
  // CHECK-NOT: iter_args
  %r0:2 = scf.for %i = %lb to %ub step %step iter_args(%tA = %A, %tB = %B)
      -> (tensor<?xf32>, tensor<?xf32>)
  {
    // %ttA bufferizes to a direct copy of %C into %svA, a subview of %ALLOC_FOR_A.
    //     CHECK: %[[svA:.*]] = memref.subview %[[ALLOC_FOR_A]][0] [4] [1]
    //     CHECK: memref.copy %[[C]], %[[svA]]
    %ttA = tensor.insert_slice %C into %tA[0][4][1] : tensor<4xf32> into tensor<?xf32>

    // %ttB bufferizes to a direct copy of %C into %svB, a subview of %B.
    //     CHECK: %[[svB:.*]] = memref.subview %[[B]][0] [4] [1]
    //     CHECK:   memref.copy %[[C]], %[[svB]]
    %ttB = tensor.insert_slice %C into %tB[0][4][1] : tensor<4xf32> into tensor<?xf32>

    // CHECK-NOT:   scf.yield
    scf.yield %ttA, %ttB : tensor<?xf32>, tensor<?xf32>
  }

  //     CHECK:  return %[[ALLOC_FOR_A]] : memref<?xf32>
  return %r0#0, %r0#1: tensor<?xf32>, tensor<?xf32>
}

// -----

// CHECK-LABEL: func @execute_region_with_conflict(
//  CHECK-SAME:     %[[m1:.*]]: memref<?xf32
func.func @execute_region_with_conflict(
    %t1 : tensor<?xf32> {bufferization.writable = true})
  -> (f32, tensor<?xf32>, f32)
{
  %f1 = arith.constant 0.0 : f32
  %idx = arith.constant 7 : index

  // scf.execute_region is canonicalized away after bufferization. So just the
  // memref.store is left over.

  // CHECK: %[[alloc:.*]] = memref.alloc
  // CHECK: memref.copy %[[m1]], %[[alloc]]
  // CHECK: memref.store %{{.*}}, %[[alloc]][%{{.*}}]
  %0, %1, %2 = scf.execute_region -> (f32, tensor<?xf32>, f32) {
    %t2 = tensor.insert %f1 into %t1[%idx] : tensor<?xf32>
    scf.yield %f1, %t2, %f1 : f32, tensor<?xf32>, f32
  }

  // CHECK: %[[load:.*]] = memref.load %[[m1]]
  %3 = tensor.extract %t1[%idx] : tensor<?xf32>

  // CHECK: return %{{.*}}, %[[alloc]], %[[load]] : f32, memref<?xf32>, f32
  return %0, %1, %3 : f32, tensor<?xf32>, f32
}

// -----

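// The scf.if bufferizes in-place: the "then" branch yields %t1 unchanged and
// the "else" branch writes directly into %t1, so no allocation is needed.
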
// CHECK-LABEL: func @scf_if_inplace(
//  CHECK-SAME:     %[[cond:.*]]: i1, %[[t1:.*]]: memref<?xf32{{.*}}>, %[[v:.*]]: vector
func.func @scf_if_inplace(%cond: i1,
                          %t1: tensor<?xf32> {bufferization.writable = true},
                          %v: vector<5xf32>, %idx: index) -> tensor<?xf32> {

  //      CHECK: scf.if %[[cond]] {
  // CHECK-NEXT: } else {
  // CHECK-NEXT:   vector.transfer_write %[[v]], %[[t1]]
  // CHECK-NEXT: }
  // CHECK-NEXT: return
  %r = scf.if %cond -> (tensor<?xf32>) {
    scf.yield %t1 : tensor<?xf32>
  } else {
    %t2 = vector.transfer_write %v, %t1[%idx] : vector<5xf32>, tensor<?xf32>
    scf.yield %t2 : tensor<?xf32>
  }
  return %r : tensor<?xf32>
}

// -----

// CHECK-LABEL: func @scf_if_inside_scf_for
//   CHECK-DAG:   %[[c0:.*]] = arith.constant 0 : index
//   CHECK-DAG:   %[[c1:.*]] = arith.constant 1 : index
//   CHECK-DAG:   %[[c10:.*]] = arith.constant 10 : index
//       CHECK:   scf.for %{{.*}} = %[[c0]] to %[[c10]] step %[[c1]] {
//       CHECK:     scf.if %{{.*}} {
//       CHECK:     } else {
//       CHECK:       vector.transfer_write
//       CHECK:     }
//       CHECK:   }
func.func @scf_if_inside_scf_for(
    %t1: tensor<?xf32> {bufferization.writable = true},
    %v: vector<5xf32>, %idx: index,
    %cond: i1)
  -> tensor<?xf32>
{
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c10 = arith.constant 10 : index
  %r = scf.for %iv = %c0 to %c10 step %c1 iter_args(%bb = %t1) -> (tensor<?xf32>) {
    %r2 = scf.if %cond -> (tensor<?xf32>) {
      scf.yield %bb : tensor<?xf32>
    } else {
      %t2 = vector.transfer_write %v, %bb[%idx] : vector<5xf32>, tensor<?xf32>
      scf.yield %t2 : tensor<?xf32>
    }
    scf.yield %r2 : tensor<?xf32>
  }
  return %r : tensor<?xf32>
}

// -----

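// The two branches yield non-equivalent buffers (two different function
// arguments), so the bufferized scf.if folds into an arith.select of the two
// memrefs.
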
// CHECK-LABEL: func @scf_if_non_equiv_yields(
//  CHECK-SAME:     %[[cond:.*]]: i1, %[[A:.*]]: memref<{{.*}}>, %[[B:.*]]: memref<{{.*}}>) -> memref<{{.*}}>
func.func @scf_if_non_equiv_yields(
    %b : i1,
    %A : tensor<4xf32> {bufferization.writable = false},
    %B : tensor<4xf32> {bufferization.writable = false})
  -> tensor<4xf32>
{
  // CHECK: %[[r:.*]] = arith.select %[[cond]], %[[A]], %[[B]]
  %r = scf.if %b -> (tensor<4xf32>) {
    scf.yield %A : tensor<4xf32>
  } else {
    scf.yield %B : tensor<4xf32>
  }
  // CHECK: return %[[r]]
  return %r: tensor<4xf32>
}

// -----

// Note: This bufferization is inefficient, but it bufferizes correctly.

// CHECK-LABEL: func @scf_execute_region_yield_non_equivalent(
//       CHECK:   %[[alloc:.*]] = memref.alloc(%{{.*}})
//       CHECK:   %[[r:.*]] = memref.load %[[alloc]][%{{.*}}]
//       CHECK:   return %[[r]]
func.func @scf_execute_region_yield_non_equivalent(%i: index, %j: index) -> f32 {
  %r = scf.execute_region -> (tensor<?xf32>) {
    %t2 = bufferization.alloc_tensor(%i) : tensor<?xf32>
    scf.yield %t2 : tensor<?xf32>
  }
  %f = tensor.extract %r[%j] : tensor<?xf32>
  return %f : f32
}

// -----

// Note: This bufferizes to inefficient code, but bufferization should not see
// such IR in the first place. The iter_arg would canonicalize away. This test
// case is just to ensure that the bufferization generates correct code.

// CHECK-LABEL: func @scf_for_yield_non_equivalent(
//  CHECK-SAME:     %[[t:.*]]: memref<?xf32
//       CHECK:   %[[alloc:.*]] = memref.alloc(%{{.*}})
//       CHECK:   memref.copy %[[t]], %[[alloc]]
//       CHECK:   %[[for:.*]] = scf.for {{.*}} iter_args(%[[iter:.*]] = %[[alloc]])
//   CHECK-DAG:     %[[alloc2:.*]] = memref.alloc(%{{.*}})
//       CHECK:     memref.copy %[[t]], %[[alloc2]]
//       CHECK:     scf.yield %[[alloc2]]
//       CHECK:   return %[[for]]
func.func @scf_for_yield_non_equivalent(
    %t: tensor<?xf32>, %lb : index, %ub : index, %step : index) -> tensor<?xf32> {
  %r = scf.for %i = %lb to %ub step %step iter_args(%a = %t) -> tensor<?xf32> {
    scf.yield %t : tensor<?xf32>
  }

  return %r : tensor<?xf32>
}

// -----

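// Each iteration yields a newly allocated buffer, which is cast to the
// iter_arg's buffer type before being yielded.
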
// CHECK-LABEL: func @scf_for_yield_allocation(
//  CHECK-SAME:     %[[t:.*]]: memref<?xf32
//       CHECK:   %[[for:.*]] = scf.for {{.*}} iter_args(%[[iter:.*]] = %[[t]])
//   CHECK-DAG:     %[[alloc:.*]] = memref.alloc(%{{.*}})
//       CHECK:     %[[casted:.*]] = memref.cast %[[alloc]]
//       CHECK:     scf.yield %[[casted]]
//       CHECK:   return %[[for]]
func.func @scf_for_yield_allocation(%t: tensor<?xf32>, %lb : index, %ub : index,
                               %step : index) -> tensor<?xf32> {
  %r = scf.for %i = %lb to %ub step %step iter_args(%a = %t) -> tensor<?xf32> {
    %t2 = bufferization.alloc_tensor(%i) : tensor<?xf32>
    scf.yield %t2 : tensor<?xf32>
  }

  return %r : tensor<?xf32>
}

// -----

// TODO: The scf.yield could bufferize to 1 alloc and 2 copies (instead of
// 2 allocs and 2 copies).

// CHECK-LABEL: func @scf_for_swapping_yields(
//  CHECK-SAME:     %[[A:.*]]: memref<?xf32, strided{{.*}}>, %[[B:.*]]: memref<?xf32, strided{{.*}}>
func.func @scf_for_swapping_yields(
    %A : tensor<?xf32>, %B : tensor<?xf32> {bufferization.writable = true},
    %C : tensor<4xf32>, %lb : index, %ub : index, %step : index)
  -> (f32, f32)
{
//       CHECK:   %[[for:.*]]:2 = scf.for {{.*}} iter_args(%[[iter1:.*]] = %[[A]], %[[iter2:.*]] = %[[B]])
  %r0:2 = scf.for %i = %lb to %ub step %step iter_args(%tA = %A, %tB = %B)
      -> (tensor<?xf32>, tensor<?xf32>)
  {
//       CHECK:     %[[sv1:.*]] = memref.subview %[[iter1]]
//       CHECK:     memref.copy %{{.*}}, %[[sv1]]
    %ttA = tensor.insert_slice %C into %tA[0][4][1] : tensor<4xf32> into tensor<?xf32>
//       CHECK:     %[[sv2:.*]] = memref.subview %[[iter2]]
//       CHECK:     memref.copy %{{.*}}, %[[sv2]]
    %ttB = tensor.insert_slice %C into %tB[0][4][1] : tensor<4xf32> into tensor<?xf32>

//       CHECK:     %[[alloc2:.*]] = memref.alloc(%{{.*}})
//       CHECK:     memref.copy %[[iter2]], %[[alloc2]]
//       CHECK:     %[[alloc1:.*]] = memref.alloc(%{{.*}})
//       CHECK:     memref.copy %[[iter1]], %[[alloc1]]
//       CHECK:     %[[casted2:.*]] = memref.cast %[[alloc2]]
//       CHECK:     %[[casted1:.*]] = memref.cast %[[alloc1]]
//       CHECK:     scf.yield %[[casted2]], %[[casted1]]
    // Yield tensors in different order.
    scf.yield %ttB, %ttA : tensor<?xf32>, tensor<?xf32>
  }

//       CHECK:     %[[r0:.*]] = memref.load %[[for]]#0
//       CHECK:     %[[r1:.*]] = memref.load %[[for]]#1
  %f0 = tensor.extract %r0#0[%step] : tensor<?xf32>
  %f1 = tensor.extract %r0#1[%step] : tensor<?xf32>
//       CHECK:     return %[[r0]], %[[r1]]
  return %f0, %f1: f32, f32
}

// -----

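// All yielded values are equivalent to their init values, so the bufferized
// scf.while carries no buffers at all and updates %arg0 in place.
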
// CHECK-LABEL: func @scf_while(
//  CHECK-SAME:     %[[arg0:.*]]: memref<?xi1, strided{{.*}}>
func.func @scf_while(%arg0: tensor<?xi1>, %idx: index) -> tensor<?xi1> {
  // CHECK: scf.while : () -> () {
  %res:2 = scf.while (%arg1 = %arg0, %i = %idx) :
      (tensor<?xi1>, index) -> (tensor<?xi1>, index) {
    // CHECK: %[[condition:.*]] = memref.load %[[arg0]]
    // CHECK: scf.condition(%[[condition]])
    %condition = tensor.extract %arg1[%idx] : tensor<?xi1>
    scf.condition(%condition) %arg1, %idx : tensor<?xi1>, index
  } do {
  ^bb0(%arg2: tensor<?xi1>, %i: index):
    // CHECK: } do {
    // CHECK: memref.store %{{.*}}, %[[arg0]]
    // CHECK: scf.yield
    // CHECK: }
    %pos = "dummy.some_op"() : () -> (index)
    %val = "dummy.another_op"() : () -> (i1)
    %1 = tensor.insert %val into %arg2[%pos] : tensor<?xi1>
    scf.yield %1, %i : tensor<?xi1>, index
  }

  // CHECK: return
  return %res#0 : tensor<?xi1>
}

// -----

// The loop condition yields non-equivalent buffers.

// CHECK-LABEL: func @scf_while_non_equiv_condition(
//  CHECK-SAME:     %[[arg0:.*]]: memref<5xi1, strided{{.*}}>, %[[arg1:.*]]: memref<5xi1, strided{{.*}}>
func.func @scf_while_non_equiv_condition(%arg0: tensor<5xi1>,
                                         %arg1: tensor<5xi1>,
                                         %idx: index)
  -> (tensor<5xi1>, tensor<5xi1>)
{
  // CHECK: %[[loop:.*]]:2 = scf.while (%[[w0:.*]] = %[[arg0]], %[[w1:.*]] = %[[arg1]]) {{.*}} {
  %r0, %r1 = scf.while (%w0 = %arg0, %w1 = %arg1)
      : (tensor<5xi1>, tensor<5xi1>) -> (tensor<5xi1>, tensor<5xi1>) {
    // CHECK: %[[condition:.*]] = memref.load %[[w0]]
    // CHECK: %[[a1:.*]] = memref.alloc() {{.*}} : memref<5xi1>
    // CHECK: memref.copy %[[w1]], %[[a1]]
    // CHECK: %[[a0:.*]] = memref.alloc() {{.*}} : memref<5xi1>
    // CHECK: memref.copy %[[w0]], %[[a0]]
    // CHECK: scf.condition(%[[condition]]) %[[a1]], %[[a0]]
    %condition = tensor.extract %w0[%idx] : tensor<5xi1>
    scf.condition(%condition) %w1, %w0 : tensor<5xi1>, tensor<5xi1>
  } do {
  ^bb0(%b0: tensor<5xi1>, %b1: tensor<5xi1>):
    // CHECK: } do {
    // CHECK: ^bb0(%[[b0:.*]]: memref<5xi1>, %[[b1:.*]]: memref<5xi1>):
    // CHECK: memref.store %{{.*}}, %[[b0]]
    // CHECK: %[[casted0:.*]] = memref.cast %[[b0]] : memref<5xi1> to memref<5xi1, strided{{.*}}>
    // CHECK: %[[casted1:.*]] = memref.cast %[[b1]] : memref<5xi1> to memref<5xi1, strided{{.*}}>
    // CHECK: scf.yield %[[casted0]], %[[casted1]]
    // CHECK: }
    %pos = "dummy.some_op"() : () -> (index)
    %val = "dummy.another_op"() : () -> (i1)
    %1 = tensor.insert %val into %b0[%pos] : tensor<5xi1>
    scf.yield %1, %b1 : tensor<5xi1>, tensor<5xi1>
  }

  // CHECK: return %[[loop]]#0, %[[loop]]#1
  return %r0, %r1 : tensor<5xi1>, tensor<5xi1>
}

// -----

// Both the loop condition and the loop body yield non-equivalent buffers.

// CHECK-LABEL: func @scf_while_non_equiv_condition_and_body(
//  CHECK-SAME:     %[[arg0:.*]]: memref<5xi1, strided{{.*}}>, %[[arg1:.*]]: memref<5xi1, strided{{.*}}>
func.func @scf_while_non_equiv_condition_and_body(%arg0: tensor<5xi1>,
                                                  %arg1: tensor<5xi1>,
                                                  %idx: index)
  -> (tensor<5xi1>, tensor<5xi1>)
{
  // CHECK: %[[loop:.*]]:2 = scf.while (%[[w0:.*]] = %[[arg0]], %[[w1:.*]] = %[[arg1]]) {{.*}} {
  %r0, %r1 = scf.while (%w0 = %arg0, %w1 = %arg1)
      : (tensor<5xi1>, tensor<5xi1>) -> (tensor<5xi1>, tensor<5xi1>) {
    // CHECK: %[[condition:.*]] = memref.load %[[w0]]
    // CHECK: %[[a1:.*]] = memref.alloc() {{.*}} : memref<5xi1>
    // CHECK: memref.copy %[[w1]], %[[a1]]
    // CHECK: %[[a0:.*]] = memref.alloc() {{.*}} : memref<5xi1>
    // CHECK: memref.copy %[[w0]], %[[a0]]
    // CHECK: scf.condition(%[[condition]]) %[[a1]], %[[a0]]
    %condition = tensor.extract %w0[%idx] : tensor<5xi1>
    scf.condition(%condition) %w1, %w0 : tensor<5xi1>, tensor<5xi1>
  } do {
  ^bb0(%b0: tensor<5xi1>, %b1: tensor<5xi1>):
    // CHECK: } do {
    // CHECK: ^bb0(%[[b0:.*]]: memref<5xi1>, %[[b1:.*]]: memref<5xi1>):
    // CHECK: memref.store %{{.*}}, %[[b0]]
    // CHECK: %[[casted1:.*]] = memref.cast %[[b1]]
    // CHECK: %[[casted0:.*]] = memref.cast %[[b0]]
    // CHECK: scf.yield %[[casted1]], %[[casted0]]
    // CHECK: }
    %pos = "dummy.some_op"() : () -> (index)
    %val = "dummy.another_op"() : () -> (i1)
    %1 = tensor.insert %val into %b0[%pos] : tensor<5xi1>
    scf.yield %b1, %1 : tensor<5xi1>, tensor<5xi1>
  }

  // CHECK: return %[[loop]]#0, %[[loop]]#1
  return %r0, %r1 : tensor<5xi1>, tensor<5xi1>
}

// -----

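// The loop body yields a tensor derived from %arg0 instead of the iter_arg
// %arg3, so an alloc + copy of %arg0 is created inside the body.
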
// CHECK-LABEL: func @scf_while_iter_arg_result_mismatch(
//  CHECK-SAME:     %[[arg0:.*]]: memref<5xi1, strided{{.*}}>, %[[arg1:.*]]: memref<5xi1, strided{{.*}}>
//       CHECK:   scf.while (%[[arg3:.*]] = %[[arg1]]) : (memref<5xi1, strided{{.*}}) -> () {
//   CHECK-DAG:     %[[load:.*]] = memref.load %[[arg0]]
//       CHECK:     scf.condition(%[[load]])
//       CHECK:   } do {
//       CHECK:     %[[alloc2:.*]] = memref.alloc() {{.*}} : memref<5xi1>
//       CHECK:     memref.copy %[[arg0]], %[[alloc2]]
//       CHECK:     memref.store %{{.*}}, %[[alloc2]]
//       CHECK:     %[[casted:.*]] = memref.cast %[[alloc2]] : memref<5xi1> to memref<5xi1, strided{{.*}}>
//       CHECK:     scf.yield %[[casted]]
//       CHECK:   }
func.func @scf_while_iter_arg_result_mismatch(%arg0: tensor<5xi1>,
                                              %arg1: tensor<5xi1>,
                                              %arg2: index) {
  scf.while (%arg3 = %arg1) : (tensor<5xi1>) -> () {
    %0 = tensor.extract %arg0[%arg2] : tensor<5xi1>
    %1 = tensor.extract %arg3[%arg2] : tensor<5xi1>
    "dummy.use"(%1) : (i1) -> ()
    scf.condition(%0)
  } do {
    %0 = "dummy.some_op"() : () -> index
    %1 = "dummy.another_op"() : () -> i1
    %2 = tensor.insert %1 into %arg0[%0] : tensor<5xi1>
    scf.yield %2 : tensor<5xi1>
  }
  return
}

// -----

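// The slice is written directly into the shared output %arg2; there is no
// conflicting read, so no allocation or copy is needed.
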
// CHECK-LABEL: func.func @parallel_insert_slice_no_conflict(
//  CHECK-SAME:     %[[idx:.*]]: index, %[[idx2:.*]]: index,
//  CHECK-SAME:     %[[arg1:.*]]: memref<?xf32, strided{{.*}}>,
//  CHECK-SAME:     %[[arg2:.*]]: memref<?xf32, strided{{.*}}>
func.func @parallel_insert_slice_no_conflict(
    %idx: index,
    %idx2: index,
    %arg1: tensor<?xf32> {bufferization.writable = true},
    %arg2: tensor<?xf32> {bufferization.writable = true}) -> (tensor<?xf32>, f32) {
  %cst = arith.constant 4.200000e+01 : f32
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index

  // CHECK: scf.forall (%[[tidx:.*]]) in (%[[idx2]])
  %2 = scf.forall (%arg3) in (%idx2) shared_outs(%o = %arg2) -> (tensor<?xf32>) {
      // CHECK: %[[subview:.*]] = memref.subview %[[arg2]][5] [%[[idx]]] [1]
      %6 = tensor.extract_slice %o[5] [%idx] [%c1] : tensor<?xf32> to tensor<?xf32>
      // CHECK: linalg.fill ins(%{{.*}}) outs(%[[subview]] : memref<?xf32
      %8 = linalg.fill ins(%cst : f32) outs(%6 : tensor<?xf32>) -> tensor<?xf32>
      // CHECK-NOT: memref.copy

      // Empty terminator is elided from pretty-printing.
      // CHECK-NOT: scf.forall.in_parallel
      // CHECK-NOT: parallel_insert_slice
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %8 into %o[5] [%idx] [%c1] :
          tensor<?xf32> into tensor<?xf32>
      }
  } {keep_this_attribute}
  // CHECK: keep_this_attribute

  // CHECK: %[[load:.*]] = memref.load %[[arg2]]
  %f = tensor.extract %2[%c0] : tensor<?xf32>

  // CHECK: return %[[load]] : f32
  return %2, %f : tensor<?xf32>, f32
}

// -----

// CHECK-LABEL: func.func @parallel_insert_slice_with_conflict(
//  CHECK-SAME:     %[[idx:.*]]: index, %[[idx2:.*]]: index,
//  CHECK-SAME:     %[[arg1:.*]]: memref<?xf32, strided{{.*}}>,
//  CHECK-SAME:     %[[arg2:.*]]: memref<?xf32, strided{{.*}}>
func.func @parallel_insert_slice_with_conflict(
    %idx: index,
    %idx2: index,
    %arg1: tensor<?xf32> {bufferization.writable = true},
    %arg2: tensor<?xf32> {bufferization.writable = true}) -> (f32, f32)
{
  %cst = arith.constant 4.200000e+01 : f32
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index

  // The parallel_insert_slice_op bufferizes out-of-place due to a RAW conflict
  // on %arg2, so we need an allocation.
  // CHECK: %[[alloc1:.*]] = memref.alloc
  // CHECK: memref.copy %[[arg2]], %[[alloc1]]

  // CHECK: scf.forall (%[[tidx:.*]]) in (%[[idx2]])
  %2 = scf.forall (%arg3) in (%idx2) shared_outs(%o = %arg2) -> (tensor<?xf32>) {
      // CHECK: %[[subview1:.*]] = memref.subview %[[alloc1]][5] [%[[idx]]] [1]
      %6 = tensor.extract_slice %o[5] [%idx] [%c1] : tensor<?xf32> to tensor<?xf32>

      // CHECK: linalg.fill ins(%{{.*}}) outs(%[[subview1]] : memref<?xf32
      %8 = linalg.fill ins(%cst : f32) outs(%6 : tensor<?xf32>) -> tensor<?xf32>
      // CHECK-NOT: memref.copy

      // Empty terminator is elided from pretty-printing.
      // CHECK-NOT: scf.forall.in_parallel
      // CHECK-NOT: parallel_insert_slice
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %8 into %o[5] [%idx] [%c1] :
          tensor<?xf32> into tensor<?xf32>
      }
  }

  // CHECK: %[[load:.*]] = memref.load %[[arg2]]
  // CHECK: %[[load2:.*]] = memref.load %[[alloc1]]
  %f = tensor.extract %arg2[%c0] : tensor<?xf32>
  %f2 = tensor.extract %2[%c0] : tensor<?xf32>

  // CHECK: return %[[load2]], %[[load]] : f32, f32
  return %f2, %f : f32, f32
}

// -----

#map0 = affine_map<(d0) -> (d0 * 4)>
#map1 = affine_map<(d0) -> (d0 * 2)>

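// Tiled matmul: each scf.forall iteration computes a 4x4 tile of the shared
// 8x8 output in place.
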
// CHECK-LABEL: func.func @matmul
func.func @matmul(%arg0: tensor<8x8xf32>, %arg1: tensor<8x8xf32>, %arg2: tensor<8x8xf32> {bufferization.writable = true}) -> tensor<8x8xf32> {
  %c2 = arith.constant 2 : index
  %c4 = arith.constant 4 : index

  // CHECK: scf.forall {{.*}}
  %0 = scf.forall (%arg3, %arg4) in (%c2, %c4) shared_outs(%o = %arg2) -> (tensor<8x8xf32>) {
    %1 = affine.apply #map0(%arg3)
    %3 = tensor.extract_slice %arg0[%1, 0] [4, 8] [1, 1] : tensor<8x8xf32> to tensor<4x8xf32>
    %4 = affine.apply #map1(%arg4)
    %6 = tensor.extract_slice %arg1[0, %4] [8, 4] [1, 1] : tensor<8x8xf32> to tensor<8x4xf32>
    %7 = tensor.extract_slice %o[%1, %4] [4, 4] [1, 1] : tensor<8x8xf32> to tensor<4x4xf32>

    //      CHECK: linalg.matmul ins({{.*}}memref<4x8xf32, strided<[?, ?], offset: ?>>, memref<8x4xf32, strided<[?, ?], offset: ?>>) outs({{.*}} : memref<4x4xf32, strided<[?, ?], offset: ?>>)
    %8 = linalg.matmul ins(%3, %6 : tensor<4x8xf32>, tensor<8x4xf32>) outs(%7 : tensor<4x4xf32>) -> tensor<4x4xf32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %8 into %o[%1, %4] [4, 4] [1, 1] : tensor<4x4xf32> into tensor<8x8xf32>
    }
  }
  return %0 : tensor<8x8xf32>
}

// -----

// CHECK-LABEL: func @scf_foreach_private_var(
//  CHECK-SAME:     %[[t:.*]]: memref<10xf32
func.func @scf_foreach_private_var(%t: tensor<10xf32>) -> f32 {
  %c2 = arith.constant 2 : index
  %c5 = arith.constant 5 : index

  // A copy is inserted for the uses of %t in the loop.
  // CHECK: %[[t_copy:.*]] = memref.alloc() {{.*}} : memref<10xf32>
  // CHECK: memref.copy %[[t]], %[[t_copy]]

  // CHECK: scf.forall (%{{.*}}) in (2) {

  // Load from the original and store into the copy.
  // CHECK:   %[[subview:.*]] = memref.subview %[[t_copy]]
  // CHECK:   memref.load %[[t]]
  // CHECK:   memref.store %{{.*}}, %[[subview]]
  %0 = scf.forall (%tid) in (%c2) shared_outs(%o = %t) -> tensor<10xf32> {
    %offset = arith.muli %c5, %tid : index
    %slice = tensor.extract_slice %o[%offset] [5] [1]
        : tensor<10xf32> to tensor<5xf32>
    %r2 = tensor.extract %t[%tid] : tensor<10xf32>
    %i = tensor.insert %r2 into %slice[%c2] : tensor<5xf32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %i into %o[%offset] [5] [1]
          : tensor<5xf32> into tensor<10xf32>
    }
  }

  %r = tensor.extract %0[%c2] : tensor<10xf32>
  return %r : f32
}

// -----

// CHECK-LABEL: func.func @scf_foreach_privatized_but_not_copied(
//  CHECK-SAME:     %[[t0:.*]]: memref<10xf32, {{.*}}>, %[[t1:.*]]: memref<10xf32
func.func @scf_foreach_privatized_but_not_copied(
    %t0: tensor<10xf32>, %t1: tensor<10xf32>) -> f32 {
  %c2 = arith.constant 2 : index
  %c5 = arith.constant 5 : index

  // CHECK-NOT: memref.alloc
  // CHECK-NOT: memref.copy
  // CHECK: scf.forall {{.*}} {
  %0 = scf.forall (%tid) in (%c2) shared_outs(%o = %t0) -> tensor<10xf32> {
    %offset = arith.muli %c5, %tid : index
    %slice = tensor.extract_slice %o[%offset] [5] [1]
        : tensor<10xf32> to tensor<5xf32>

    // %t1 is never written to inside the loop, so no copy is needed.
    // CHECK: memref.load %[[t1]]
    %r2 = tensor.extract %t1[%tid] : tensor<10xf32>
    %i = tensor.insert %r2 into %slice[%c2] : tensor<5xf32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %i into %o[%offset] [5] [1]
          : tensor<5xf32> into tensor<10xf32>
    }
  }

  %r = tensor.extract %0[%c2] : tensor<10xf32>
  return %r : f32
}

// -----

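// All buffers, including the allocation in the "else" branch, take memory
// space 1 from the bufferization.alloc_tensor op.
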
// CHECK-LABEL: func @scf_if_memory_space
func.func @scf_if_memory_space(%c: i1, %f: f32, %cst: f32) -> (f32, f32)
{
  %c0 = arith.constant 0 : index
  // CHECK: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<5xf32, 1>
  %alloc = bufferization.alloc_tensor() {memory_space = 1 : i64} : tensor<5xf32>
  // CHECK: linalg.fill {{.*}} outs(%[[alloc]] : memref<5xf32, 1>)
  %filled = linalg.fill ins(%cst : f32) outs(%alloc : tensor<5xf32>) -> tensor<5xf32>
  // CHECK: scf.if %{{.*}} -> (memref<5xf32, 1>) {
  %1 = scf.if %c -> tensor<5xf32> {
    // CHECK: scf.yield %[[alloc]]
    scf.yield %filled : tensor<5xf32>
  } else {
    // CHECK: %[[alloc2:.*]] = memref.alloc() {{.*}} : memref<5xf32, 1>
    // CHECK: memref.store %{{.*}}, %[[alloc2]]
    // CHECK: scf.yield %[[alloc2]]
    %2 = tensor.insert %f into %filled[%c0] : tensor<5xf32>
    scf.yield %2 : tensor<5xf32>
  }
  %r0 = tensor.extract %filled[%c0] : tensor<5xf32>
  %r1 = tensor.extract %1[%c0] : tensor<5xf32>
  return %r0, %r1 : f32, f32
}

// -----

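// The allocation created inside the scf.execute_region keeps memory space 1.
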
// CHECK-LABEL: func @scf_execute_region_memory_space
// CHECK: memref.alloc() {{.*}} : memref<5xf32, 1>
// CHECK: memref.store
// CHECK: memref.load
func.func @scf_execute_region_memory_space(%f: f32) -> f32 {
  %c0 = arith.constant 0 : index
  %0 = scf.execute_region -> tensor<5xf32> {
    %1 = bufferization.alloc_tensor() {memory_space = 1 : i64} : tensor<5xf32>
    %2 = tensor.insert %f into %1[%c0] : tensor<5xf32>
    scf.yield %2 : tensor<5xf32>
  }
  %r = tensor.extract %0[%c0] : tensor<5xf32>
  return %r : f32
}

// -----

// Additional allocs are inserted in the loop body. We just check that all
// allocs have the correct memory space.

// CHECK-LABEL: func @scf_for_swapping_yields_memory_space
func.func @scf_for_swapping_yields_memory_space(
    %sz: index, %C : tensor<4xf32>, %lb : index, %ub : index, %step : index)
  -> (f32, f32)
{
  // CHECK: memref.alloc(%{{.*}}) {{.*}} : memref<?xf32, 1>
  // CHECK: memref.alloc(%{{.*}}) {{.*}} : memref<?xf32, 1>
  %A = bufferization.alloc_tensor(%sz) {memory_space = 1 : i64} : tensor<?xf32>
  %B = bufferization.alloc_tensor(%sz) {memory_space = 1 : i64} : tensor<?xf32>

  // CHECK: scf.for {{.*}} {
  %r0:2 = scf.for %i = %lb to %ub step %step iter_args(%tA = %A, %tB = %B)
      -> (tensor<?xf32>, tensor<?xf32>)
  {
    // CHECK: memref.alloc(%{{.*}}) {{.*}} : memref<?xf32, 1>
    // CHECK: memref.alloc(%{{.*}}) {{.*}} : memref<?xf32, 1>
    %ttA = tensor.insert_slice %C into %tA[0][4][1] : tensor<4xf32> into tensor<?xf32>
    %ttB = tensor.insert_slice %C into %tB[0][4][1] : tensor<4xf32> into tensor<?xf32>
    // Yield tensors in different order.
    scf.yield %ttB, %ttA : tensor<?xf32>, tensor<?xf32>
  }
  // CHECK: }
  %f0 = tensor.extract %r0#0[%step] : tensor<?xf32>
  %f1 = tensor.extract %r0#1[%step] : tensor<?xf32>
  return %f0, %f1: f32, f32
}

// -----

// CHECK-LABEL: func @scf_for_yield_alias_of_non_equivalent(
func.func @scf_for_yield_alias_of_non_equivalent(%sz: index) -> tensor<?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %cst = arith.constant 5.0 : f32

  // CHECK: %[[generate:.*]] = memref.alloc
  %0 = tensor.generate %sz {
  ^bb0(%i: index):
    tensor.yield %cst : f32
  } : tensor<?xf32>

  // A copy is inserted for the iter_arg because %0 is also read inside the
  // loop and must not be overwritten by the loop's writes.
  // CHECK: %[[generate_copy:.*]] = memref.alloc
  // CHECK: memref.copy %[[generate]], %[[generate_copy]]
  // CHECK: scf.for
  %r = scf.for %iv = %c0 to %sz step %c1 iter_args(%t = %0) -> tensor<?xf32> {
    %iv_sub = arith.subi %iv, %c1 : index
    // CHECK: memref.subview %[[generate]]
    %ll = tensor.extract_slice %0[%iv_sub][%sz][1] : tensor<?xf32> to tensor<?xf32>
    %l = tensor.extract %ll[%c0] : tensor<?xf32>
    %double = arith.mulf %cst, %l : f32
    // CHECK: memref.store %{{.*}}, %[[generate_copy]]
    %s = tensor.insert %double into %t[%iv] : tensor<?xf32>
    scf.yield %s : tensor<?xf32>
  }

  // CHECK: return %[[generate_copy]]
  return %r : tensor<?xf32>
}

// -----

// We just check that this example bufferizes to valid IR.

// CHECK-LABEL: func @scf_for_buffer_type_mismatch
func.func @scf_for_buffer_type_mismatch(%sz: index, %sz2: index) -> f32 {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c10 = arith.constant 10 : index
  %0 = bufferization.alloc_tensor(%sz) : tensor<?xf32>
  %e2 = tensor.extract_slice %0[1][%sz2][1] : tensor<?xf32> to tensor<?xf32>
  // init_arg and iter_arg have different buffer types. This must be resolved
  // with casts.
  %r = scf.for %iv = %c0 to %c10 step %c1 iter_args(%t = %e2) -> tensor<?xf32> {
    %s = "test.dummy"() : () -> (index)
    %e = tensor.extract_slice %t[1][%s][1] : tensor<?xf32> to tensor<?xf32>
    scf.yield %e : tensor<?xf32>
  }
  %x = tensor.extract %r[%c1] : tensor<?xf32>
  return %x : f32
}

// -----

// We just check that this example bufferizes to valid IR.

// CHECK-LABEL: func @scf_while_buffer_type_mismatch
func.func @scf_while_buffer_type_mismatch(%sz: index, %sz2: index) -> f32 {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c10 = arith.constant 10 : index
  %cst = arith.constant 5.5 : f32
  %0 = bufferization.alloc_tensor(%sz) : tensor<?xf32>
  %e2 = tensor.extract_slice %0[1][%sz2][1] : tensor<?xf32> to tensor<?xf32>
  // init_arg and iter_arg have different buffer types. This must be resolved
  // with casts.
  %r = scf.while (%t = %e2) : (tensor<?xf32>) -> (tensor<?xf32>) {
    %c = "test.condition"() : () -> (i1)
    %s = "test.dummy"() : () -> (index)
    %e = tensor.extract_slice %t[1][%s][1] : tensor<?xf32> to tensor<?xf32>
    scf.condition(%c) %e : tensor<?xf32>
  } do {
  ^bb0(%b0: tensor<?xf32>):
    %s2 = "test.dummy"() : () -> (index)
    %n = tensor.insert %cst into %b0[%s2] : tensor<?xf32>
    scf.yield %n : tensor<?xf32>
  }
  %x = tensor.extract %r[%c1] : tensor<?xf32>
  return %x : f32
}

// -----

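// The loop carries a non-tensor (index) iter_arg next to the tensor one; only
// the tensor operand is affected by bufferization.
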
// CHECK-LABEL: func @non_tensor_for_arg
func.func @non_tensor_for_arg(%A : tensor<?xf32> {bufferization.writable = true})
    -> tensor<?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2.0 : f32
  %c10 = arith.constant 10 : index
  %r1:2 = scf.for %i = %c0 to %c10 step %c1 iter_args(%idx = %c1, %t = %A) -> (index, tensor<?xf32>) {
    %t2 = tensor.insert %c2 into %t[%idx] : tensor<?xf32>
    scf.yield %idx, %t2 : index, tensor<?xf32>
  }
  return %r1#1 : tensor<?xf32>
}

// -----

// This is a regression test. Just check that the IR bufferizes.

// CHECK-LABEL: func @buffer_type_of_collapse_shape
func.func @buffer_type_of_collapse_shape(%arg0: tensor<f64>) {
  %true = arith.constant true
  %0 = scf.while (%arg1 = %arg0) : (tensor<f64>) -> (tensor<f64>) {
    scf.condition(%true) %arg1 : tensor<f64>
  } do {
  ^bb0(%_: tensor<f64>):
    %3 = bufferization.alloc_tensor() : tensor<1xf64>
    %16 = tensor.collapse_shape %3 [] : tensor<1xf64> into tensor<f64>
    scf.yield %16 : tensor<f64>
  }
  return
}

// -----

// This is a regression test. Just check that the IR bufferizes.

// CHECK-LABEL: func @non_block_argument_yield
func.func @non_block_argument_yield() {
  %true = arith.constant true
  %0 = bufferization.alloc_tensor() : tensor<i32>
  %1 = scf.while (%arg0 = %0) : (tensor<i32>) -> (tensor<i32>) {
    scf.condition(%true) %arg0 : tensor<i32>
  } do {
  ^bb0(%arg0: tensor<i32>):
    %ret = scf.while (%arg1 = %0) : (tensor<i32>) -> (tensor<i32>) {
      scf.condition(%true) %arg1 : tensor<i32>
    } do {
    ^bb0(%arg7: tensor<i32>):
      scf.yield %0 : tensor<i32>
    }
    scf.yield %ret : tensor<i32>
  }
  return
}

// -----

// This is a regression test. Make sure that bufferization succeeds.

// CHECK-LABEL: func @regression_cast_in_loop(
func.func @regression_cast_in_loop() -> tensor<2xindex> {
  %false = arith.constant false
  %c0 = arith.constant 0 : index
  %0 = bufferization.alloc_tensor() : tensor<2xindex>
  // CHECK: scf.while (%{{.*}} = %{{.*}}) : (memref<2xindex>) -> memref<2xindex>
  %1 = scf.while (%arg0 = %0) : (tensor<2xindex>) -> tensor<2xindex> {
    scf.condition(%false) %arg0 : tensor<2xindex>
  } do {
  // CHECK: ^bb0(%{{.*}}: memref<2xindex>):
  ^bb0(%arg0: tensor<2xindex>):
    %cast = tensor.cast %0 : tensor<2xindex> to tensor<?xindex>
    %inserted = tensor.insert %c0 into %cast[%c0] : tensor<?xindex>
    %cast_0 = tensor.cast %inserted : tensor<?xindex> to tensor<2xindex>
    scf.yield %cast_0 : tensor<2xindex>
  }
  return %1 : tensor<2xindex>
}

// -----

// This test does not compute anything meaningful but it tests that
// bufferizesToMemoryWrite is correctly propagated through regions.

// CHECK-LABEL: func @elide_copy_of_non_writing_scf_if(
func.func @elide_copy_of_non_writing_scf_if(%c: i1, %p1: index, %p2: index, %f: f32)
  -> (tensor<10xf32>, f32)
{
  %r = scf.if %c -> tensor<10xf32> {
    // CHECK: memref.alloc
    %t1 = bufferization.alloc_tensor() : tensor<10xf32>
    scf.yield %t1 : tensor<10xf32>
  } else {
    // CHECK: memref.alloc
    %t2 = bufferization.alloc_tensor() : tensor<10xf32>
    scf.yield %t2 : tensor<10xf32>
  }

  // No copy should be inserted because %r does not bufferize to a memory write.
  // I.e., %r does not have defined contents and the copy can be elided.
  // CHECK-NOT: memref.alloc
  // CHECK-NOT: memref.copy
  %r2 = tensor.insert %f into %r[%p1] : tensor<10xf32>
  %r3 = tensor.extract %r[%p2] : tensor<10xf32>
  return %r2, %r3 : tensor<10xf32>, f32
}

// -----

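// Each case must yield a buffer of the switch's common result type, so the
// local allocation is cast to the strided layout before being yielded.
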
// CHECK-LABEL: func @index_switch(
//  CHECK-SAME:     %[[pred:.*]]: index, %[[b:.*]]: memref<{{.*}}>, %[[c:.*]]: memref<{{.*}}>) -> memref<{{.*}}>
func.func @index_switch(%pred: index, %b: tensor<5xf32>, %c: tensor<5xf32>) -> tensor<5xf32> {
  // Throw in a tensor that bufferizes to a different layout map.
  // CHECK: %[[a:.*]] = memref.alloc() {{.*}} : memref<5xf32>
  %a = bufferization.alloc_tensor() : tensor<5xf32>

  // CHECK: %[[r:.*]] = scf.index_switch %[[pred]] -> memref<5xf32, strided<[?], offset: ?>>
  %0 = scf.index_switch %pred -> tensor<5xf32>
  // CHECK: case 2 {
  // CHECK:   %[[cast:.*]] = memref.cast %[[a]] : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>>
  // CHECK:   scf.yield %[[cast]]
  case 2 {
    scf.yield %a: tensor<5xf32>
  }
  // CHECK: case 5 {
  // CHECK:   scf.yield %[[b]] : memref<5xf32, strided<[?], offset: ?>>
  case 5 {
    scf.yield %b: tensor<5xf32>
  }
  // CHECK: default {
  // CHECK:   scf.yield %[[c]] : memref<5xf32, strided<[?], offset: ?>>
  default {
    scf.yield %c: tensor<5xf32>
  }
  // CHECK: return %[[r]]
  return %0 : tensor<5xf32>
}