// Source: llvm-project/mlir/test/Dialect/SCF/buffer-deallocation.mlir (revision 10056c821a56a19cef732129e4e0c5883ae1ee49)
// RUN: mlir-opt -verify-diagnostics -ownership-based-buffer-deallocation \
// RUN:   -buffer-deallocation-simplification -split-input-file %s | FileCheck %s

// An allocation made inside the scf.forall body is expected to be deallocated
// inside that body. The allocation made before the loop is expected to be
// deallocated after the loop's closing brace. Neither dealloc is expected to
// carry a retain list.
// NOTE(review): despite its name, this function contains no
// parallel_insert_slice op; the name appears to be historical.
func.func @parallel_insert_slice(%arg0: index) {
  %c0 = arith.constant 0 : index
  %alloc = memref.alloc() : memref<2xf32>
  scf.forall (%arg1) in (%arg0) {
    %alloc0 = memref.alloc() : memref<2xf32>
    %0 = memref.load %alloc[%c0] : memref<2xf32>
    linalg.fill ins(%0 : f32) outs(%alloc0 : memref<2xf32>)
  }
  return
}

// CHECK-LABEL: func @parallel_insert_slice
//  CHECK-SAME: (%arg0: index)
//       CHECK: [[ALLOC0:%.+]] = memref.alloc(
//       CHECK: scf.forall
//       CHECK:   [[ALLOC1:%.+]] = memref.alloc(
//       CHECK:   bufferization.dealloc ([[ALLOC1]] : memref<2xf32>) if (%true
//   CHECK-NOT: retain
//       CHECK: }
//       CHECK: bufferization.dealloc ([[ALLOC0]] : memref<2xf32>) if (%true
//   CHECK-NOT: retain

// -----

// An allocation made inside the scf.reduce combiner region is expected to be
// deallocated within that same region, before scf.reduce.return.
func.func @reduce(%buffer: memref<100xf32>) {
  %init = arith.constant 0.0 : f32
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  scf.parallel (%iv) = (%c0) to (%c1) step (%c1) init (%init) -> f32 {
    %elem_to_reduce = memref.load %buffer[%iv] : memref<100xf32>
    scf.reduce(%elem_to_reduce : f32) {
      ^bb0(%lhs : f32, %rhs: f32):
        %alloc = memref.alloc() : memref<2xf32>
        memref.store %lhs, %alloc [%c0] : memref<2xf32>
        memref.store %rhs, %alloc [%c1] : memref<2xf32>
        %0 = memref.load %alloc[%c0] : memref<2xf32>
        %1 = memref.load %alloc[%c1] : memref<2xf32>
        %res = arith.addf %0, %1 : f32
        scf.reduce.return %res : f32
    }
  }
  func.return
}

// CHECK-LABEL: func @reduce
//       CHECK: scf.reduce
//       CHECK:   [[ALLOC:%.+]] = memref.alloc(
//       CHECK:   bufferization.dealloc ([[ALLOC]] :{{.*}}) if (%true
//       CHECK:   scf.reduce.return