// RUN: mlir-opt %s -expand-realloc="emit-deallocs=false" -ownership-based-buffer-deallocation="private-function-dynamic-ownership=true" -canonicalize -buffer-deallocation-simplification | FileCheck %s

// A function that reallocates two buffers inside a loop. The simplification
// pass should be able to figure out that the iter_args always originate
// from different allocations. IR like this appears in the sparse compiler.
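//
// With emit-deallocs=false, the realloc expansion turns each memref.realloc
// into an scf.if that either allocates a larger buffer and copies the old
// contents into it (alloc + subview + copy), or reinterpret_casts the existing
// allocation in place. The ownership-based deallocation pass then threads an
// i1 ownership flag for each buffer through the loop iter_args.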

// CHECK-LABEL: func private @loop_with_realloc(
func.func private @loop_with_realloc(%lb: index, %ub: index, %step: index, %c: i1, %s1: index, %s2: index) -> (memref<?xf32>, memref<?xf32>) {
  // CHECK-DAG: %[[false:.*]] = arith.constant false
  // CHECK-DAG: %[[true:.*]] = arith.constant true

  // CHECK: %[[m0:.*]] = memref.alloc
  %m0 = memref.alloc(%s1) : memref<?xf32>
  // CHECK: %[[m1:.*]] = memref.alloc
  %m1 = memref.alloc(%s1) : memref<?xf32>

  // CHECK: %[[r:.*]]:4 = scf.for {{.*}} iter_args(%[[arg0:.*]] = %[[m0]], %[[arg1:.*]] = %[[m1]], %[[o0:.*]] = %[[false]], %[[o1:.*]] = %[[false]])
  %r0, %r1 = scf.for %iv = %lb to %ub step %step iter_args(%arg0 = %m0, %arg1 = %m1) -> (memref<?xf32>, memref<?xf32>) {
    //      CHECK: %[[m2:.*]]:2 = scf.if %{{.*}} -> (memref<?xf32>, i1) {
    // CHECK-NEXT:   memref.alloc
    // CHECK-NEXT:   memref.subview
    // CHECK-NEXT:   memref.copy
    // CHECK-NEXT:   scf.yield %{{.*}}, %[[true]]
    // CHECK-NEXT: } else {
    // CHECK-NEXT:   memref.reinterpret_cast
    // CHECK-NEXT:   scf.yield %{{.*}}, %[[false]]
    // CHECK-NEXT: }
    %m2 = memref.realloc %arg0(%s2) : memref<?xf32> to memref<?xf32>
    //      CHECK: %[[m3:.*]]:2 = scf.if %{{.*}} -> (memref<?xf32>, i1) {
    // CHECK-NEXT:   memref.alloc
    // CHECK-NEXT:   memref.subview
    // CHECK-NEXT:   memref.copy
    // CHECK-NEXT:   scf.yield %{{.*}}, %[[true]]
    // CHECK-NEXT: } else {
    // CHECK-NEXT:   memref.reinterpret_cast
    // CHECK-NEXT:   scf.yield %{{.*}}, %[[false]]
    // CHECK-NEXT: }
    %m3 = memref.realloc %arg1(%s2) : memref<?xf32> to memref<?xf32>

    // CHECK: %[[base0:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata %[[arg0]]
    // CHECK: %[[base1:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata %[[arg1]]
    // CHECK: %[[d0:.*]] = bufferization.dealloc (%[[base0]] : memref<f32>) if (%[[o0]]) retain (%[[m2]]#0 : memref<?xf32>)
    // CHECK: %[[d1:.*]] = bufferization.dealloc (%[[base1]] : memref<f32>) if (%[[o1]]) retain (%[[m3]]#0 : memref<?xf32>)
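    // Because %arg0 and %arg1 are known to originate from distinct allocations,
    // the simplification pass can split the dealloc into two ops, each of which
    // retains only its corresponding reallocated buffer.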
    // CHECK-DAG: %[[o2:.*]] = arith.ori %[[d0]], %[[m2]]#1
    // CHECK-DAG: %[[o3:.*]] = arith.ori %[[d1]], %[[m3]]#1
    // CHECK: scf.yield %[[m2]]#0, %[[m3]]#0, %[[o2]], %[[o3]]
    scf.yield %m2, %m3 : memref<?xf32>, memref<?xf32>
  }

  // CHECK: %[[d2:.*]] = bufferization.dealloc (%[[m0]] : memref<?xf32>) if (%[[true]]) retain (%[[r]]#0 : memref<?xf32>)
  // CHECK: %[[d3:.*]] = bufferization.dealloc (%[[m1]] : memref<?xf32>) if (%[[true]]) retain (%[[r]]#1 : memref<?xf32>)
  // CHECK-DAG: %[[or0:.*]] = arith.ori %[[d2]], %[[r]]#2
  // CHECK-DAG: %[[or1:.*]] = arith.ori %[[d3]], %[[r]]#3
  // CHECK: return %[[r]]#0, %[[r]]#1, %[[or0]], %[[or1]]
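  // With private-function-dynamic-ownership=true, the private function also
  // returns the ownership indicator of each of its memref results, hence the
  // four returned values in the CHECK above.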
  return %r0, %r1 : memref<?xf32>, memref<?xf32>
}

// -----

// The yielded values of the loop are swapped. Therefore, the
// bufferization.dealloc before the func.return can no longer be split,
// because %r0 could originate from either %m0 or %m1 (and likewise for %r1).
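// Consequently, all dealloc candidates and both retained values have to stay
// in a single bufferization.dealloc op, which yields one combined ownership
// flag per retained memref.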

// CHECK-LABEL: func private @swapping_loop_with_realloc(
func.func private @swapping_loop_with_realloc(%lb: index, %ub: index, %step: index, %c: i1, %s1: index, %s2: index) -> (memref<?xf32>, memref<?xf32>) {
  // CHECK-DAG: %[[false:.*]] = arith.constant false
  // CHECK-DAG: %[[true:.*]] = arith.constant true

  // CHECK: %[[m0:.*]] = memref.alloc
  %m0 = memref.alloc(%s1) : memref<?xf32>
  // CHECK: %[[m1:.*]] = memref.alloc
  %m1 = memref.alloc(%s1) : memref<?xf32>

  // CHECK: %[[r:.*]]:4 = scf.for {{.*}} iter_args(%[[arg0:.*]] = %[[m0]], %[[arg1:.*]] = %[[m1]], %[[o0:.*]] = %[[false]], %[[o1:.*]] = %[[false]])
  %r0, %r1 = scf.for %iv = %lb to %ub step %step iter_args(%arg0 = %m0, %arg1 = %m1) -> (memref<?xf32>, memref<?xf32>) {
    %m2 = memref.realloc %arg0(%s2) : memref<?xf32> to memref<?xf32>
    %m3 = memref.realloc %arg1(%s2) : memref<?xf32> to memref<?xf32>
    scf.yield %m3, %m2 : memref<?xf32>, memref<?xf32>
  }

  // CHECK: %[[base0:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata %[[r]]#0
  // CHECK: %[[base1:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata %[[r]]#1
  // CHECK: %[[d:.*]]:2 = bufferization.dealloc (%[[m0]], %[[m1]], %[[base0]], %[[base1]] : {{.*}}) if (%[[true]], %[[true]], %[[r]]#2, %[[r]]#3) retain (%[[r]]#0, %[[r]]#1 : {{.*}})
  // CHECK: return %[[r]]#0, %[[r]]#1, %[[d]]#0, %[[d]]#1
  return %r0, %r1 : memref<?xf32>, memref<?xf32>
}