// RUN: mlir-opt %s --buffer-deallocation-simplification --split-input-file | FileCheck %s

// Regression tests for the buffer-deallocation-simplification pass. Every
// split section below contains one input function followed by the output
// expected from the pass.

// Memrefs to deallocate that are also listed as retained (directly, or through
// subviews of the same function argument). According to the expectations:
//  - %0 disappears and its result is replaced by the condition %arg1;
//  - %1 only keeps %arg2 in its dealloc list and its result is OR'ed with
//    %arg1;
//  - %2 is left as written;
//  - %5 disappears; its results become OR's of %arg3 and %arg1, and constant
//    false for the freshly allocated %alloc.
func.func @dealloc_deallocated_in_retained(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>, %arg3: i1) -> (i1, i1, i1, i1, i1, i1, i1) {
  %0 = bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg1) retain (%arg0 : memref<2xi32>)
  %1 = bufferization.dealloc (%arg0, %arg2 : memref<2xi32>, memref<2xi32>) if (%arg1, %arg1) retain (%arg0 : memref<2xi32>)
  %2:2 = bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg1) retain (%arg0, %arg2 : memref<2xi32>, memref<2xi32>)
  // multiple must-alias
  %3 = memref.subview %arg0[0][1][1] : memref<2xi32> to memref<i32>
  %4 = memref.subview %arg0[1][1][1] : memref<2xi32> to memref<1xi32, strided<[1], offset: 1>>
  %alloc = memref.alloc() : memref<2xi32>
  %5:3 = bufferization.dealloc (%arg0, %4 : memref<2xi32>, memref<1xi32, strided<[1], offset: 1>>) if (%arg1, %arg3) retain (%arg0, %alloc, %3 : memref<2xi32>, memref<2xi32>, memref<i32>)
  return %0, %1, %2#0, %2#1, %5#0, %5#1, %5#2 : i1, i1, i1, i1, i1, i1, i1
}

// CHECK-LABEL: func @dealloc_deallocated_in_retained
//  CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>, [[ARG3:%.+]]: i1)
//  CHECK-NEXT: arith.constant false
//  CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>)
//  CHECK-NEXT: [[O1:%.+]] = arith.ori [[V1]], [[ARG1]]
//  CHECK-NEXT: [[V2:%.+]]:2 = bufferization.dealloc ([[ARG0]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]], [[ARG2]] : memref<2xi32>, memref<2xi32>)
//  COM: the RemoveRetainedMemrefsGuaranteedToNotAlias pattern removes all the
//  COM: retained memrefs since the list of memrefs to be deallocated becomes empty
//  COM: due to the pattern under test (and thus there is no memref the retain values
//  COM: could alias to)
//   CHECK-NOT: if
//  CHECK-NEXT: [[V3:%.+]] = arith.ori [[ARG3]], [[ARG1]]
//  CHECK-NEXT: [[V4:%.+]] = arith.ori [[ARG3]], [[ARG1]]
//  CHECK-NEXT: return [[ARG1]], [[O1]], [[V2]]#0, [[V2]]#1, [[V3]], %false{{[0-9_]*}}, [[V4]] :

// -----

// Same scenario as @dealloc_deallocated_in_retained, except that the dealloc
// lists refer to base buffers obtained with memref.extract_strided_metadata
// instead of the function arguments themselves.
func.func @dealloc_deallocated_in_retained_extract_base_memref(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>, %arg3: i1) -> (i1, i1, i1, i1, i1, i1, i1) {
  %base_buffer, %offset, %size, %stride = memref.extract_strided_metadata %arg0 : memref<2xi32> -> memref<i32>, index, index, index
  %base_buffer0, %offset0, %size0, %stride0 = memref.extract_strided_metadata %arg2 : memref<2xi32> -> memref<i32>, index, index, index
  %0 = bufferization.dealloc (%base_buffer : memref<i32>) if (%arg1) retain (%arg0 : memref<2xi32>)
  %1 = bufferization.dealloc (%base_buffer, %base_buffer0 : memref<i32>, memref<i32>) if (%arg1, %arg1) retain (%arg0 : memref<2xi32>)
  %2:2 = bufferization.dealloc (%base_buffer : memref<i32>) if (%arg1) retain (%arg0, %arg2 : memref<2xi32>, memref<2xi32>)
  // multiple must-alias
  %3 = memref.subview %arg0[0][1][1] : memref<2xi32> to memref<i32>
  %4 = memref.subview %arg0[1][1][1] : memref<2xi32> to memref<1xi32, strided<[1], offset: 1>>
  %alloc = memref.alloc() : memref<2xi32>
  %5:3 = bufferization.dealloc (%base_buffer, %4 : memref<i32>, memref<1xi32, strided<[1], offset: 1>>) if (%arg1, %arg3) retain (%arg0, %alloc, %3 : memref<2xi32>, memref<2xi32>, memref<i32>)
  return %0, %1, %2#0, %2#1, %5#0, %5#1, %5#2 : i1, i1, i1, i1, i1, i1, i1
}

// CHECK-LABEL: func @dealloc_deallocated_in_retained_extract_base_memref
//  CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>, [[ARG3:%.+]]: i1)
//  CHECK-NEXT: arith.constant false
//  CHECK-NEXT: [[BASE0:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ARG0]] :
//  CHECK-NEXT: [[BASE1:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ARG2]] :
//  CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[BASE1]] : memref<i32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>)
//  CHECK-NEXT: [[O1:%.+]] = arith.ori [[V1]], [[ARG1]]
//  CHECK-NEXT: [[V2:%.+]]:2 = bufferization.dealloc ([[BASE0]] : memref<i32>) if ([[ARG1]]) retain ([[ARG0]], [[ARG2]] : memref<2xi32>, memref<2xi32>)
//  COM: the RemoveRetainedMemrefsGuaranteedToNotAlias pattern removes all the
//  COM: retained memrefs since the list of memrefs to be deallocated becomes empty
//  COM: due to the pattern under test (and thus there is no memref the retain values
//  COM: could alias to)
//   CHECK-NOT: if
//  CHECK-NEXT: [[V3:%.+]] = arith.ori [[ARG3]], [[ARG1]]
//  CHECK-NEXT: [[V4:%.+]] = arith.ori [[ARG3]], [[ARG1]]
//  CHECK-NEXT: return [[ARG1]], [[O1]], [[V2]]#0, [[V2]]#1, [[V3]], %false{{[0-9_]*}}, [[V4]] :

// -----

// A local allocation is deallocated while a second local allocation and a
// function argument are retained. The expected output has no retain list and
// returns constant false for both retained values.
func.func @remove_retained_memrefs_guaranteed_to_not_alias(%arg0: i1, %arg1: memref<2xi32>) -> (i1, i1, memref<2xi32>) {
  %alloc = memref.alloc() : memref<2xi32>
  %alloc0 = memref.alloc() : memref<2xi32>
  %0:2 = bufferization.dealloc (%alloc : memref<2xi32>) if (%arg0) retain (%alloc0, %arg1 : memref<2xi32>, memref<2xi32>)
  return %0#0, %0#1, %alloc : i1, i1, memref<2xi32>
}

// CHECK-LABEL: func @remove_retained_memrefs_guaranteed_to_not_alias
//  CHECK-SAME: ([[ARG0:%.+]]: i1, [[ARG1:%.+]]: memref<2xi32>)
//  CHECK-NEXT: [[FALSE:%.+]] = arith.constant false
//  CHECK-NEXT: [[ALLOC:%.+]] = memref.alloc(
//  CHECK-NEXT: bufferization.dealloc ([[ALLOC]] : memref<2xi32>) if ([[ARG0]])
//   CHECK-NOT: retain
//  CHECK-NEXT: return [[FALSE]], [[FALSE]], [[ALLOC]] :

// -----

// One dealloc op with two memrefs and two retained values. The expected
// output contains two dealloc ops, each with a single memref and a single
// retained value.
func.func @dealloc_split_when_no_other_aliasing(%arg0: i1, %arg1: memref<2xi32>, %arg2: memref<2xi32>, %arg3: i1) -> (i1, i1) {
  %alloc = memref.alloc() : memref<2xi32>
  %alloc0 = memref.alloc() : memref<2xi32>
  %0 = arith.select %arg0, %alloc, %alloc0 : memref<2xi32>
  %1:2 = bufferization.dealloc (%alloc, %arg2 : memref<2xi32>, memref<2xi32>) if (%arg0, %arg3) retain (%arg1, %0 : memref<2xi32>, memref<2xi32>)
  return %1#0, %1#1 : i1, i1
}

// CHECK-LABEL: func @dealloc_split_when_no_other_aliasing
//  CHECK-SAME: ([[ARG0:%.+]]: i1, [[ARG1:%.+]]: memref<2xi32>, [[ARG2:%.+]]: memref<2xi32>, [[ARG3:%.+]]: i1)
//  CHECK-NEXT:   [[ALLOC0:%.+]] = memref.alloc(
//  CHECK-NEXT:   [[ALLOC1:%.+]] = memref.alloc(
//  CHECK-NEXT:   [[V0:%.+]] = arith.select{{.*}}[[ALLOC0]], [[ALLOC1]] :
//  COM: there is only one value in the retained lists because the
//  COM: RemoveRetainedMemrefsGuaranteedToNotAlias pattern also applies here:
//  COM: - %alloc is guaranteed to not alias with %arg1.
//  COM: - %arg2 is guaranteed to not alias with %0.
//  CHECK-NEXT:   [[V1:%.+]] = bufferization.dealloc ([[ALLOC0]] : memref<2xi32>) if ([[ARG0]]) retain ([[V0]] : memref<2xi32>)
//  CHECK-NEXT:   [[V2:%.+]] = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG3]]) retain ([[ARG1]] : memref<2xi32>)
//  CHECK-NEXT:   return [[V2]], [[V1]] :

// -----

// All dealloc conditions are constant true, and %arg0 and %arg1 are both
// deallocated and retained. The expected output only keeps %arg2 in the
// dealloc list and returns constant true for the two retained values.
func.func @dealloc_remove_dealloc_memref_contained_in_retained_with_const_true_condition(
  %arg0: memref<2xi32>, %arg1: memref<2xi32>, %arg2: memref<2xi32>) -> (memref<2xi32>, memref<2xi32>, i1, i1) {
  %true = arith.constant true
  %0:2 = bufferization.dealloc (%arg0, %arg1, %arg2 : memref<2xi32>, memref<2xi32>, memref<2xi32>) if (%true, %true, %true) retain (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
  return %arg0, %arg1, %0#0, %0#1 : memref<2xi32>, memref<2xi32>, i1, i1
}

// CHECK-LABEL: func @dealloc_remove_dealloc_memref_contained_in_retained_with_const_true_condition
//  CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: memref<2xi32>, [[ARG2:%.+]]: memref<2xi32>)
//       CHECK: bufferization.dealloc ([[ARG2]] :{{.*}}) if (%true{{[0-9_]*}})
//  CHECK-NEXT: return [[ARG0]], [[ARG1]], %true{{[0-9_]*}}, %true{{[0-9_]*}} :

// -----

// Same scenario as the previous case, except that the dealloc list refers to
// base buffers obtained with memref.extract_strided_metadata. The expected
// output only keeps the base buffer of %arg2 in the dealloc list.
func.func @dealloc_remove_dealloc_memref_contained_in_retained_with_const_true_condition_extract_base_memref(
  %arg0: memref<2xi32>, %arg1: memref<2xi32>, %arg2: memref<2xi32>) -> (memref<2xi32>, memref<2xi32>, i1, i1) {
  %true = arith.constant true
  %base_buffer, %offset, %size, %stride = memref.extract_strided_metadata %arg0 : memref<2xi32> -> memref<i32>, index, index, index
  %base_buffer_1, %offset_1, %size_1, %stride_1 = memref.extract_strided_metadata %arg1 : memref<2xi32> -> memref<i32>, index, index, index
  %base_buffer_2, %offset_2, %size_2, %stride_2 = memref.extract_strided_metadata %arg2 : memref<2xi32> -> memref<i32>, index, index, index
  %0:2 = bufferization.dealloc (%base_buffer, %base_buffer_1, %base_buffer_2 : memref<i32>, memref<i32>, memref<i32>) if (%true, %true, %true) retain (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
  return %arg0, %arg1, %0#0, %0#1 : memref<2xi32>, memref<2xi32>, i1, i1
}

// CHECK-LABEL: func @dealloc_remove_dealloc_memref_contained_in_retained_with_const_true_condition_extract_base_memref
//  CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: memref<2xi32>, [[ARG2:%.+]]: memref<2xi32>)
//       CHECK: [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[ARG2]]
//       CHECK: bufferization.dealloc ([[BASE]] :{{.*}}) if (%true{{[0-9_]*}})
//  CHECK-NEXT: return [[ARG0]], [[ARG1]], %true{{[0-9_]*}}, %true{{[0-9_]*}} :

// -----

// Inside an scf.for loop, the base buffer of the loop-carried memref and a
// fresh allocation are deallocated while the fresh allocation is retained.
// The expected output only deallocates the base buffer, has no retain list,
// and yields the allocation together with constant true.
func.func @alloc_and_bbarg(%arg0: memref<5xf32>, %arg1: index, %arg2: index, %arg3: index) -> f32 {
  %true = arith.constant true
  %false = arith.constant false
  %0:2 = scf.for %arg4 = %arg1 to %arg2 step %arg3 iter_args(%arg5 = %arg0, %arg6 = %false) -> (memref<5xf32>, i1) {
    %alloc = memref.alloc() : memref<5xf32>
    memref.copy %arg5, %alloc : memref<5xf32> to memref<5xf32>
    %base_buffer_0, %offset_1, %sizes_2, %strides_3 = memref.extract_strided_metadata %arg5 : memref<5xf32> -> memref<f32>, index, index, index
    %2 = bufferization.dealloc (%base_buffer_0, %alloc : memref<f32>, memref<5xf32>) if (%arg6, %true) retain (%alloc : memref<5xf32>)
    scf.yield %alloc, %2 : memref<5xf32>, i1
  }
  %1 = memref.load %0#0[%arg1] : memref<5xf32>
  %base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %0#0 : memref<5xf32> -> memref<f32>, index, index, index
  bufferization.dealloc (%base_buffer : memref<f32>) if (%0#1)
  return %1 : f32
}

// CHECK-LABEL: func @alloc_and_bbarg
//       CHECK:   %[[true:.*]] = arith.constant true
//       CHECK:   scf.for {{.*}} iter_args(%[[iter:.*]] = %{{.*}}, %{{.*}} = %{{.*}})
//       CHECK:     %[[alloc:.*]] = memref.alloc
//       CHECK:     %[[view:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata %[[iter]]
//       CHECK:     bufferization.dealloc (%[[view]] : memref<f32>)
//   CHECK-NOT:     retain
//       CHECK:     scf.yield %[[alloc]], %[[true]]

// -----

// The same memref is listed twice in the dealloc list; the expected output
// lists a single memref<5xf32> with a single condition.
func.func @duplicate_memref(%arg0: memref<5xf32>, %arg1: memref<6xf32>, %c: i1) -> i1 {
  %0 = bufferization.dealloc (%arg0, %arg0 : memref<5xf32>, memref<5xf32>) if (%c, %c) retain (%arg1 : memref<6xf32>)
  return %0 : i1
}

// CHECK-LABEL: func @duplicate_memref(
//       CHECK:   %[[r:.*]] = bufferization.dealloc (%{{.*}} : memref<5xf32>) if (%{{.*}}) retain (%{{.*}} : memref<6xf32>)
//       CHECK:   return %[[r]]