// RUN: mlir-opt %s -expand-realloc="emit-deallocs=false" -ownership-based-buffer-deallocation="private-function-dynamic-ownership=true" -canonicalize -buffer-deallocation-simplification | FileCheck %s

// A function that reallocates two buffers inside of a loop. The simplification
// pass should be able to figure out that the iter_args are always originating
// from different allocations. IR like this one appears in the sparse compiler.

// CHECK-LABEL: func private @loop_with_realloc(
func.func private @loop_with_realloc(%lb: index, %ub: index, %step: index, %c: i1, %s1: index, %s2: index) -> (memref<?xf32>, memref<?xf32>) {
  // CHECK-DAG: %[[false:.*]] = arith.constant false
  // CHECK-DAG: %[[true:.*]] = arith.constant true

  // CHECK: %[[m0:.*]] = memref.alloc
  %m0 = memref.alloc(%s1) : memref<?xf32>
  // CHECK: %[[m1:.*]] = memref.alloc
  %m1 = memref.alloc(%s1) : memref<?xf32>

  // The transformed loop is expected to carry four values: the two memrefs
  // plus two extra iter_args that start out as the `false` constant and are
  // used as the dealloc conditions inside the loop body.
  // CHECK: %[[r:.*]]:4 = scf.for {{.*}} iter_args(%[[arg0:.*]] = %[[m0]], %[[arg1:.*]] = %[[m1]], %[[o0:.*]] = %[[false]], %[[o1:.*]] = %[[false]])
  %r0, %r1 = scf.for %iv = %lb to %ub step %step iter_args(%arg0 = %m0, %arg1 = %m1) -> (memref<?xf32>, memref<?xf32>) {
    // Each memref.realloc is expected to be expanded into an scf.if: one
    // branch allocates a new buffer and copies into it (yielding `true`),
    // the other reinterpret_casts the existing buffer (yielding `false`).
    // CHECK: %[[m2:.*]]:2 = scf.if %{{.*}} -> (memref<?xf32>, i1) {
    // CHECK-NEXT: memref.alloc
    // CHECK-NEXT: memref.subview
    // CHECK-NEXT: memref.copy
    // CHECK-NEXT: scf.yield %{{.*}}, %[[true]]
    // CHECK-NEXT: } else {
    // CHECK-NEXT: memref.reinterpret_cast
    // CHECK-NEXT: scf.yield %{{.*}}, %[[false]]
    // CHECK-NEXT: }
    %m2 = memref.realloc %arg0(%s2) : memref<?xf32> to memref<?xf32>
    // CHECK: %[[m3:.*]]:2 = scf.if %{{.*}} -> (memref<?xf32>, i1) {
    // CHECK-NEXT: memref.alloc
    // CHECK-NEXT: memref.subview
    // CHECK-NEXT: memref.copy
    // CHECK-NEXT: scf.yield %{{.*}}, %[[true]]
    // CHECK-NEXT: } else {
    // CHECK-NEXT: memref.reinterpret_cast
    // CHECK-NEXT: scf.yield %{{.*}}, %[[false]]
    // CHECK-NEXT: }
    %m3 = memref.realloc %arg1(%s2) : memref<?xf32> to memref<?xf32>

    // Two separate single-operand deallocs are expected, one per iter_arg,
    // each retaining only the result of the realloc of that same iter_arg.
    // CHECK: %[[base0:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata %[[arg0]]
    // CHECK: %[[base1:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata %[[arg1]]
    // CHECK: %[[d0:.*]] = bufferization.dealloc (%[[base0]] : memref<f32>) if (%[[o0]]) retain (%[[m2]]#0 : memref<?xf32>)
    // CHECK: %[[d1:.*]] = bufferization.dealloc (%[[base1]] : memref<f32>) if (%[[o1]]) retain (%[[m3]]#0 : memref<?xf32>)
    // CHECK-DAG: %[[o2:.*]] = arith.ori %[[d0]], %[[m2]]#1
    // CHECK-DAG: %[[o3:.*]] = arith.ori %[[d1]], %[[m3]]#1
    // CHECK: scf.yield %[[m2]]#0, %[[m3]]#0, %[[o2]], %[[o3]]
    scf.yield %m2, %m3 : memref<?xf32>, memref<?xf32>
  }

  // After the loop, %m0 and %m1 are each expected to get their own dealloc
  // that retains only the loop result at the matching position.
  // CHECK: %[[d2:.*]] = bufferization.dealloc (%[[m0]] : memref<?xf32>) if (%[[true]]) retain (%[[r]]#0 : memref<?xf32>)
  // CHECK: %[[d3:.*]] = bufferization.dealloc (%[[m1]] : memref<?xf32>) if (%[[true]]) retain (%[[r]]#1 : memref<?xf32>)
  // CHECK-DAG: %[[or0:.*]] = arith.ori %[[d2]], %[[r]]#2
  // CHECK-DAG: %[[or1:.*]] = arith.ori %[[d3]], %[[r]]#3
  // CHECK: return %[[r]]#0, %[[r]]#1, %[[or0]], %[[or1]]
  return %r0, %r1 : memref<?xf32>, memref<?xf32>
}

// -----

// NOTE(review): the lit command at the top of this file does not pass
// -split-input-file, so the marker above acts only as a visual separator and
// both functions are processed as one module. Confirm that this is intended.
//
// The yielded values of the loop are swapped. Therefore, the
// bufferization.dealloc before the func.return can no longer be split,
// because %r0 could originate from either %m0 or %m1 (same for %r1).
// CHECK-LABEL: func private @swapping_loop_with_realloc(
func.func private @swapping_loop_with_realloc(%lb: index, %ub: index, %step: index, %c: i1, %s1: index, %s2: index) -> (memref<?xf32>, memref<?xf32>) {
  // CHECK-DAG: %[[false:.*]] = arith.constant false
  // CHECK-DAG: %[[true:.*]] = arith.constant true

  // CHECK: %[[m0:.*]] = memref.alloc
  %m0 = memref.alloc(%s1) : memref<?xf32>
  // CHECK: %[[m1:.*]] = memref.alloc
  %m1 = memref.alloc(%s1) : memref<?xf32>

  // The loop body reallocates both iter_args and yields the results in
  // swapped order (%m3 first, then %m2), so each loop result may stem from
  // either of the two initial allocations.
  // CHECK: %[[r:.*]]:4 = scf.for {{.*}} iter_args(%[[arg0:.*]] = %[[m0]], %[[arg1:.*]] = %[[m1]], %[[o0:.*]] = %[[false]], %[[o1:.*]] = %[[false]])
  %r0, %r1 = scf.for %iv = %lb to %ub step %step iter_args(%arg0 = %m0, %arg1 = %m1) -> (memref<?xf32>, memref<?xf32>) {
    %m2 = memref.realloc %arg0(%s2) : memref<?xf32> to memref<?xf32>
    %m3 = memref.realloc %arg1(%s2) : memref<?xf32> to memref<?xf32>
    scf.yield %m3, %m2 : memref<?xf32>, memref<?xf32>
  }

  // A single combined dealloc is expected here: it takes both original
  // allocations together with the base buffers of both loop results, and
  // retains both loop results.
  // CHECK: %[[base0:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata %[[r]]#0
  // CHECK: %[[base1:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata %[[r]]#1
  // CHECK: %[[d:.*]]:2 = bufferization.dealloc (%[[m0]], %[[m1]], %[[base0]], %[[base1]] : {{.*}}) if (%[[true]], %[[true]], %[[r]]#2, %[[r]]#3) retain (%[[r]]#0, %[[r]]#1 : {{.*}})
  // CHECK: return %[[r]]#0, %[[r]]#1, %[[d]]#0, %[[d]]#1
  return %r0, %r1 : memref<?xf32>, memref<?xf32>
}