1 #ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES_H 2 #define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES_H 3 4 #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h" 5 #include "mlir/Dialect/MemRef/IR/MemRef.h" 6 #include "mlir/Pass/Pass.h" 7 8 namespace mlir { 9 class FunctionOpInterface; 10 class MemRefType; 11 class ModuleOp; 12 class RewritePatternSet; 13 class OpBuilder; 14 class SymbolTable; 15 16 namespace func { 17 class FuncOp; 18 } // namespace func 19 20 namespace bufferization { 21 struct OneShotBufferizationOptions; 22 23 /// Maps from symbol table to its corresponding dealloc helper function. 24 using DeallocHelperMap = llvm::DenseMap<Operation *, func::FuncOp>; 25 26 //===----------------------------------------------------------------------===// 27 // Passes 28 //===----------------------------------------------------------------------===// 29 30 #define GEN_PASS_DECL 31 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc" 32 33 /// Creates an instance of the BufferDeallocation pass to free all allocated 34 /// buffers. 35 std::unique_ptr<Pass> createBufferDeallocationPass(); 36 37 /// Creates an instance of the OwnershipBasedBufferDeallocation pass to free all 38 /// allocated buffers. 39 std::unique_ptr<Pass> createOwnershipBasedBufferDeallocationPass( 40 DeallocationOptions options = DeallocationOptions()); 41 42 /// Creates a pass that finds all temporary allocations 43 /// and attempts to move the deallocation after the last user/dependency 44 /// of the allocation, thereby optimizing allocation liveness. 45 std::unique_ptr<Pass> createOptimizeAllocationLivenessPass(); 46 47 /// Creates a pass that optimizes `bufferization.dealloc` operations. For 48 /// example, it reduces the number of alias checks needed at runtime using 49 /// static alias analysis. 50 std::unique_ptr<Pass> createBufferDeallocationSimplificationPass(); 51 52 /// Creates an instance of the LowerDeallocations pass to lower 53 /// `bufferization.dealloc` operations to the `memref` dialect. 54 std::unique_ptr<Pass> createLowerDeallocationsPass(); 55 56 /// Adds the conversion pattern of the `bufferization.dealloc` operation to the 57 /// given pattern set for use in other transformation passes. 58 void populateBufferizationDeallocLoweringPattern( 59 RewritePatternSet &patterns, const DeallocHelperMap &deallocHelperFuncMap); 60 61 /// Construct the library function needed for the fully generic 62 /// `bufferization.dealloc` lowering implemented in the LowerDeallocations pass. 63 /// The function can then be called at bufferization dealloc sites to determine 64 /// aliasing and ownership. 65 /// 66 /// The generated function takes two memrefs of indices and three memrefs of 67 /// booleans as arguments: 68 /// * The first argument A should contain the result of the 69 /// extract_aligned_pointer_as_index operation applied to the memrefs to be 70 /// deallocated 71 /// * The second argument B should contain the result of the 72 /// extract_aligned_pointer_as_index operation applied to the memrefs to be 73 /// retained 74 /// * The third argument C should contain the conditions as passed directly 75 /// to the deallocation operation. 76 /// * The fourth argument D is used to pass results to the caller. Those 77 /// represent the condition under which the memref at the corresponding 78 /// position in A should be deallocated. 79 /// * The fifth argument E is used to pass results to the caller. It 80 /// provides the ownership value corresponding the the memref at the same 81 /// position in B 82 /// 83 /// This helper function is supposed to be called once for each 84 /// `bufferization.dealloc` operation to determine the deallocation need and new 85 /// ownership indicator for the retained values, but does not perform the 86 /// deallocation itself. 87 /// 88 /// Generated code: 89 /// ``` 90 /// func.func @dealloc_helper( 91 /// %dyn_dealloc_base_pointer_list: memref<?xindex>, 92 /// %dyn_retain_base_pointer_list: memref<?xindex>, 93 /// %dyn_cond_list: memref<?xi1>, 94 /// %dyn_dealloc_cond_out: memref<?xi1>, 95 /// %dyn_ownership_out: memref<?xi1>) { 96 /// %c0 = arith.constant 0 : index 97 /// %c1 = arith.constant 1 : index 98 /// %true = arith.constant true 99 /// %false = arith.constant false 100 /// %num_dealloc_memrefs = memref.dim %dyn_dealloc_base_pointer_list, %c0 101 /// %num_retain_memrefs = memref.dim %dyn_retain_base_pointer_list, %c0 102 /// // Zero initialize result buffer. 103 /// scf.for %i = %c0 to %num_retain_memrefs step %c1 { 104 /// memref.store %false, %dyn_ownership_out[%i] : memref<?xi1> 105 /// } 106 /// scf.for %i = %c0 to %num_dealloc_memrefs step %c1 { 107 /// %dealloc_bp = memref.load %dyn_dealloc_base_pointer_list[%i] 108 /// %cond = memref.load %dyn_cond_list[%i] 109 /// // Check for aliasing with retained memrefs. 110 /// %does_not_alias_retained = scf.for %j = %c0 to %num_retain_memrefs 111 /// step %c1 iter_args(%does_not_alias_aggregated = %true) -> (i1) { 112 /// %retain_bp = memref.load %dyn_retain_base_pointer_list[%j] 113 /// %does_alias = arith.cmpi eq, %retain_bp, %dealloc_bp : index 114 /// scf.if %does_alias { 115 /// %curr_ownership = memref.load %dyn_ownership_out[%j] 116 /// %updated_ownership = arith.ori %curr_ownership, %cond : i1 117 /// memref.store %updated_ownership, %dyn_ownership_out[%j] 118 /// } 119 /// %does_not_alias = arith.cmpi ne, %retain_bp, %dealloc_bp : index 120 /// %updated_aggregate = arith.andi %does_not_alias_aggregated, 121 /// %does_not_alias : i1 122 /// scf.yield %updated_aggregate : i1 123 /// } 124 /// // Check for aliasing with dealloc memrefs in the list before the 125 /// // current one, i.e., 126 /// // `fix i, forall j < i: check_aliasing(%dyn_dealloc_base_pointer[j], 127 /// // %dyn_dealloc_base_pointer[i])` 128 /// %does_not_alias_any = scf.for %j = %c0 to %i step %c1 129 /// iter_args(%does_not_alias_agg = %does_not_alias_retained) -> (i1) { 130 /// %prev_dealloc_bp = memref.load %dyn_dealloc_base_pointer_list[%j] 131 /// %does_not_alias = arith.cmpi ne, %prev_dealloc_bp, %dealloc_bp 132 /// %updated_alias_agg = arith.andi %does_not_alias_agg, %does_not_alias 133 /// scf.yield %updated_alias_agg : i1 134 /// } 135 /// %dealloc_cond = arith.andi %does_not_alias_any, %cond : i1 136 /// memref.store %dealloc_cond, %dyn_dealloc_cond_out[%i] : memref<?xi1> 137 /// } 138 /// return 139 /// } 140 /// ``` 141 func::FuncOp buildDeallocationLibraryFunction(OpBuilder &builder, Location loc, 142 SymbolTable &symbolTable); 143 144 /// Run buffer deallocation. 145 LogicalResult deallocateBuffers(Operation *op); 146 147 /// Run the ownership-based buffer deallocation. 148 LogicalResult deallocateBuffersOwnershipBased(FunctionOpInterface op, 149 DeallocationOptions options); 150 151 /// Creates a pass that moves allocations upwards to reduce the number of 152 /// required copies that are inserted during the BufferDeallocation pass. 153 std::unique_ptr<Pass> createBufferHoistingPass(); 154 155 /// Creates a pass that moves allocations upwards out of loops. This avoids 156 /// reallocations inside of loops. 157 std::unique_ptr<Pass> createBufferLoopHoistingPass(); 158 159 // Options struct for BufferResultsToOutParams pass. 160 // Note: defined only here, not in tablegen. 161 struct BufferResultsToOutParamsOpts { 162 /// Allocator function: Generate a memref allocation with the given type. 163 /// Since `promoteBufferResultsToOutParams` doesn't allow dynamically shaped 164 /// results, we don't allow passing a range of values for dynamic dims. 165 using AllocationFn = 166 std::function<FailureOr<Value>(OpBuilder &, Location, MemRefType)>; 167 168 /// Memcpy function: Generate a memcpy between two memrefs. 169 using MemCpyFn = 170 std::function<LogicalResult(OpBuilder &, Location, Value, Value)>; 171 172 // Filter function; returns true if the function should be converted. 173 // Defaults to true, i.e. all functions are converted. 174 std::function<bool(func::FuncOp *)> filterFn = [](func::FuncOp *func) { 175 return true; 176 }; 177 178 /// Allocation function; used to allocate a memref. 179 /// Default memref.alloc is used 180 AllocationFn allocationFn = [](OpBuilder &builder, Location loc, 181 MemRefType type) { 182 return builder.create<memref::AllocOp>(loc, type).getResult(); 183 }; 184 185 /// Memcpy function; used to create a copy between two memrefs. 186 /// Default memref.copy is used. 187 MemCpyFn memCpyFn = [](OpBuilder &builder, Location loc, Value from, 188 Value to) { 189 builder.create<memref::CopyOp>(loc, from, to); 190 return success(); 191 }; 192 193 /// If true, the pass adds a "bufferize.result" attribute to each output 194 /// parameter. 195 bool addResultAttribute = false; 196 197 /// If true, the pass eliminates the memref.alloc and memcpy if the returned 198 /// memref is allocated in the current function. 199 bool hoistStaticAllocs = false; 200 }; 201 202 /// Creates a pass that converts memref function results to out-params. 203 std::unique_ptr<Pass> createBufferResultsToOutParamsPass( 204 const BufferResultsToOutParamsOpts &options = {}); 205 206 /// Replace buffers that are returned from a function with an out parameter. 207 /// Also update all call sites. 208 LogicalResult 209 promoteBufferResultsToOutParams(ModuleOp module, 210 const BufferResultsToOutParamsOpts &options); 211 212 /// Creates a pass that drops memref function results that are equivalent to a 213 /// function argument. 214 std::unique_ptr<Pass> createDropEquivalentBufferResultsPass(); 215 216 /// Create a pass that rewrites tensor.empty to bufferization.alloc_tensor. 217 std::unique_ptr<Pass> createEmptyTensorToAllocTensorPass(); 218 219 /// Drop all memref function results that are equivalent to a function argument. 220 LogicalResult dropEquivalentBufferResults(ModuleOp module); 221 222 /// Create a pass that bufferizes all ops that implement BufferizableOpInterface 223 /// with One-Shot Bufferize. 224 std::unique_ptr<Pass> createOneShotBufferizePass(); 225 226 /// Create a pass that bufferizes all ops that implement BufferizableOpInterface 227 /// with One-Shot Bufferize and the specified bufferization options. 228 std::unique_ptr<Pass> 229 createOneShotBufferizePass(const OneShotBufferizationOptions &options); 230 231 /// Creates a pass that promotes heap-based allocations to stack-based ones. 232 /// Only buffers smaller than the provided size are promoted. 233 /// Dynamic shaped buffers are promoted up to the given rank. 234 std::unique_ptr<Pass> 235 createPromoteBuffersToStackPass(unsigned maxAllocSizeInBytes = 1024, 236 unsigned maxRankOfAllocatedMemRef = 1); 237 238 /// Creates a pass that promotes heap-based allocations to stack-based ones. 239 /// Only buffers smaller with `isSmallAlloc(alloc) == true` are promoted. 240 std::unique_ptr<Pass> 241 createPromoteBuffersToStackPass(std::function<bool(Value)> isSmallAlloc); 242 243 /// Create a pass that tries to eliminate tensor.empty ops that are anchored on 244 /// insert_slice ops. 245 std::unique_ptr<Pass> createEmptyTensorEliminationPass(); 246 247 //===----------------------------------------------------------------------===// 248 // Registration 249 //===----------------------------------------------------------------------===// 250 251 /// Generate the code for registering passes. 252 #define GEN_PASS_REGISTRATION 253 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc" 254 255 } // namespace bufferization 256 } // namespace mlir 257 258 #endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES_H 259