xref: /llvm-project/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h (revision 8906b7be918be653d3c5f2ef3dbd923561603969)
1 #ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES_H
2 #define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES_H
3 
4 #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
5 #include "mlir/Dialect/MemRef/IR/MemRef.h"
6 #include "mlir/Pass/Pass.h"
7 
8 namespace mlir {
// Forward declarations of core MLIR types referenced by the declarations in
// this header (kept in alphabetical order).
class FunctionOpInterface;
class MemRefType;
class ModuleOp;
class OpBuilder;
class RewritePatternSet;
class SymbolTable;

// Forward declaration from the `func` dialect.
namespace func {
class FuncOp;
} // namespace func
19 
20 namespace bufferization {
21 struct OneShotBufferizationOptions;
22 
23 /// Maps from symbol table to its corresponding dealloc helper function.
24 using DeallocHelperMap = llvm::DenseMap<Operation *, func::FuncOp>;
25 
26 //===----------------------------------------------------------------------===//
27 // Passes
28 //===----------------------------------------------------------------------===//
29 
// Expand the pass declarations from the tablegen-generated Passes.h.inc.
// Defining GEN_PASS_DECL before the include selects the declaration section of
// that generated file.
#define GEN_PASS_DECL
#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
32 
33 /// Creates an instance of the BufferDeallocation pass to free all allocated
34 /// buffers.
35 std::unique_ptr<Pass> createBufferDeallocationPass();
36 
37 /// Creates an instance of the OwnershipBasedBufferDeallocation pass to free all
38 /// allocated buffers.
39 std::unique_ptr<Pass> createOwnershipBasedBufferDeallocationPass(
40     DeallocationOptions options = DeallocationOptions());
41 
42 /// Creates a pass that finds all temporary allocations
43 /// and attempts to move the deallocation after the last user/dependency
44 /// of the allocation, thereby optimizing allocation liveness.
45 std::unique_ptr<Pass> createOptimizeAllocationLivenessPass();
46 
47 /// Creates a pass that optimizes `bufferization.dealloc` operations. For
48 /// example, it reduces the number of alias checks needed at runtime using
49 /// static alias analysis.
50 std::unique_ptr<Pass> createBufferDeallocationSimplificationPass();
51 
52 /// Creates an instance of the LowerDeallocations pass to lower
53 /// `bufferization.dealloc` operations to the `memref` dialect.
54 std::unique_ptr<Pass> createLowerDeallocationsPass();
55 
56 /// Adds the conversion pattern of the `bufferization.dealloc` operation to the
57 /// given pattern set for use in other transformation passes.
58 void populateBufferizationDeallocLoweringPattern(
59     RewritePatternSet &patterns, const DeallocHelperMap &deallocHelperFuncMap);
60 
61 /// Construct the library function needed for the fully generic
62 /// `bufferization.dealloc` lowering implemented in the LowerDeallocations pass.
63 /// The function can then be called at bufferization dealloc sites to determine
64 /// aliasing and ownership.
65 ///
66 /// The generated function takes two memrefs of indices and three memrefs of
67 /// booleans as arguments:
68 ///   * The first argument A should contain the result of the
69 ///     extract_aligned_pointer_as_index operation applied to the memrefs to be
70 ///     deallocated
71 ///   * The second argument B should contain the result of the
72 ///     extract_aligned_pointer_as_index operation applied to the memrefs to be
73 ///     retained
74 ///   * The third argument C should contain the conditions as passed directly
75 ///     to the deallocation operation.
76 ///   * The fourth argument D is used to pass results to the caller. Those
77 ///     represent the condition under which the memref at the corresponding
78 ///     position in A should be deallocated.
79 ///   * The fifth argument E is used to pass results to the caller. It
80 ///     provides the ownership value corresponding the the memref at the same
81 ///     position in B
82 ///
83 /// This helper function is supposed to be called once for each
84 /// `bufferization.dealloc` operation to determine the deallocation need and new
85 /// ownership indicator for the retained values, but does not perform the
86 /// deallocation itself.
87 ///
88 /// Generated code:
89 /// ```
90 /// func.func @dealloc_helper(
91 ///     %dyn_dealloc_base_pointer_list: memref<?xindex>,
92 ///     %dyn_retain_base_pointer_list: memref<?xindex>,
93 ///     %dyn_cond_list: memref<?xi1>,
94 ///     %dyn_dealloc_cond_out: memref<?xi1>,
95 ///     %dyn_ownership_out: memref<?xi1>) {
96 ///   %c0 = arith.constant 0 : index
97 ///   %c1 = arith.constant 1 : index
98 ///   %true = arith.constant true
99 ///   %false = arith.constant false
100 ///   %num_dealloc_memrefs = memref.dim %dyn_dealloc_base_pointer_list, %c0
101 ///   %num_retain_memrefs = memref.dim %dyn_retain_base_pointer_list, %c0
102 ///   // Zero initialize result buffer.
103 ///   scf.for %i = %c0 to %num_retain_memrefs step %c1 {
104 ///     memref.store %false, %dyn_ownership_out[%i] : memref<?xi1>
105 ///   }
106 ///   scf.for %i = %c0 to %num_dealloc_memrefs step %c1 {
107 ///     %dealloc_bp = memref.load %dyn_dealloc_base_pointer_list[%i]
108 ///     %cond = memref.load %dyn_cond_list[%i]
109 ///     // Check for aliasing with retained memrefs.
110 ///     %does_not_alias_retained = scf.for %j = %c0 to %num_retain_memrefs
111 ///         step %c1 iter_args(%does_not_alias_aggregated = %true) -> (i1) {
112 ///       %retain_bp = memref.load %dyn_retain_base_pointer_list[%j]
113 ///       %does_alias = arith.cmpi eq, %retain_bp, %dealloc_bp : index
114 ///       scf.if %does_alias {
115 ///         %curr_ownership = memref.load %dyn_ownership_out[%j]
116 ///         %updated_ownership = arith.ori %curr_ownership, %cond : i1
117 ///         memref.store %updated_ownership, %dyn_ownership_out[%j]
118 ///       }
119 ///       %does_not_alias = arith.cmpi ne, %retain_bp, %dealloc_bp : index
120 ///       %updated_aggregate = arith.andi %does_not_alias_aggregated,
121 ///                                       %does_not_alias : i1
122 ///       scf.yield %updated_aggregate : i1
123 ///     }
124 ///     // Check for aliasing with dealloc memrefs in the list before the
125 ///     // current one, i.e.,
126 ///     // `fix i, forall j < i: check_aliasing(%dyn_dealloc_base_pointer[j],
127 ///     // %dyn_dealloc_base_pointer[i])`
128 ///     %does_not_alias_any = scf.for %j = %c0 to %i step %c1
129 ///        iter_args(%does_not_alias_agg = %does_not_alias_retained) -> (i1) {
130 ///       %prev_dealloc_bp = memref.load %dyn_dealloc_base_pointer_list[%j]
131 ///       %does_not_alias = arith.cmpi ne, %prev_dealloc_bp, %dealloc_bp
132 ///       %updated_alias_agg = arith.andi %does_not_alias_agg, %does_not_alias
133 ///       scf.yield %updated_alias_agg : i1
134 ///     }
135 ///     %dealloc_cond = arith.andi %does_not_alias_any, %cond : i1
136 ///     memref.store %dealloc_cond, %dyn_dealloc_cond_out[%i] : memref<?xi1>
137 ///   }
138 ///   return
139 /// }
140 /// ```
141 func::FuncOp buildDeallocationLibraryFunction(OpBuilder &builder, Location loc,
142                                               SymbolTable &symbolTable);
143 
144 /// Run buffer deallocation.
145 LogicalResult deallocateBuffers(Operation *op);
146 
147 /// Run the ownership-based buffer deallocation.
148 LogicalResult deallocateBuffersOwnershipBased(FunctionOpInterface op,
149                                               DeallocationOptions options);
150 
151 /// Creates a pass that moves allocations upwards to reduce the number of
152 /// required copies that are inserted during the BufferDeallocation pass.
153 std::unique_ptr<Pass> createBufferHoistingPass();
154 
155 /// Creates a pass that moves allocations upwards out of loops. This avoids
156 /// reallocations inside of loops.
157 std::unique_ptr<Pass> createBufferLoopHoistingPass();
158 
159 // Options struct for BufferResultsToOutParams pass.
160 // Note: defined only here, not in tablegen.
161 struct BufferResultsToOutParamsOpts {
162   /// Allocator function: Generate a memref allocation with the given type.
163   /// Since `promoteBufferResultsToOutParams` doesn't allow dynamically shaped
164   /// results, we don't allow passing a range of values for dynamic dims.
165   using AllocationFn =
166       std::function<FailureOr<Value>(OpBuilder &, Location, MemRefType)>;
167 
168   /// Memcpy function: Generate a memcpy between two memrefs.
169   using MemCpyFn =
170       std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
171 
172   // Filter function; returns true if the function should be converted.
173   // Defaults to true, i.e. all functions are converted.
174   std::function<bool(func::FuncOp *)> filterFn = [](func::FuncOp *func) {
175     return true;
176   };
177 
178   /// Allocation function; used to allocate a memref.
179   /// Default memref.alloc is used
180   AllocationFn allocationFn = [](OpBuilder &builder, Location loc,
181                                  MemRefType type) {
182     return builder.create<memref::AllocOp>(loc, type).getResult();
183   };
184 
185   /// Memcpy function; used to create a copy between two memrefs.
186   /// Default memref.copy is used.
187   MemCpyFn memCpyFn = [](OpBuilder &builder, Location loc, Value from,
188                          Value to) {
189     builder.create<memref::CopyOp>(loc, from, to);
190     return success();
191   };
192 
193   /// If true, the pass adds a "bufferize.result" attribute to each output
194   /// parameter.
195   bool addResultAttribute = false;
196 
197   /// If true, the pass eliminates the memref.alloc and memcpy if the returned
198   /// memref is allocated in the current function.
199   bool hoistStaticAllocs = false;
200 };
201 
202 /// Creates a pass that converts memref function results to out-params.
203 std::unique_ptr<Pass> createBufferResultsToOutParamsPass(
204     const BufferResultsToOutParamsOpts &options = {});
205 
206 /// Replace buffers that are returned from a function with an out parameter.
207 /// Also update all call sites.
208 LogicalResult
209 promoteBufferResultsToOutParams(ModuleOp module,
210                                 const BufferResultsToOutParamsOpts &options);
211 
212 /// Creates a pass that drops memref function results that are equivalent to a
213 /// function argument.
214 std::unique_ptr<Pass> createDropEquivalentBufferResultsPass();
215 
216 /// Create a pass that rewrites tensor.empty to bufferization.alloc_tensor.
217 std::unique_ptr<Pass> createEmptyTensorToAllocTensorPass();
218 
219 /// Drop all memref function results that are equivalent to a function argument.
220 LogicalResult dropEquivalentBufferResults(ModuleOp module);
221 
222 /// Create a pass that bufferizes all ops that implement BufferizableOpInterface
223 /// with One-Shot Bufferize.
224 std::unique_ptr<Pass> createOneShotBufferizePass();
225 
226 /// Create a pass that bufferizes all ops that implement BufferizableOpInterface
227 /// with One-Shot Bufferize and the specified bufferization options.
228 std::unique_ptr<Pass>
229 createOneShotBufferizePass(const OneShotBufferizationOptions &options);
230 
231 /// Creates a pass that promotes heap-based allocations to stack-based ones.
232 /// Only buffers smaller than the provided size are promoted.
233 /// Dynamic shaped buffers are promoted up to the given rank.
234 std::unique_ptr<Pass>
235 createPromoteBuffersToStackPass(unsigned maxAllocSizeInBytes = 1024,
236                                 unsigned maxRankOfAllocatedMemRef = 1);
237 
238 /// Creates a pass that promotes heap-based allocations to stack-based ones.
239 /// Only buffers smaller with `isSmallAlloc(alloc) == true` are promoted.
240 std::unique_ptr<Pass>
241 createPromoteBuffersToStackPass(std::function<bool(Value)> isSmallAlloc);
242 
243 /// Create a pass that tries to eliminate tensor.empty ops that are anchored on
244 /// insert_slice ops.
245 std::unique_ptr<Pass> createEmptyTensorEliminationPass();
246 
247 //===----------------------------------------------------------------------===//
248 // Registration
249 //===----------------------------------------------------------------------===//
250 
/// Generate the code for registering passes.
// Defining GEN_PASS_REGISTRATION before including the tablegen-generated
// Passes.h.inc selects its registration section (NOTE(review): presumably the
// per-pass and group registration functions; confirm in the generated file).
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
254 
255 } // namespace bufferization
256 } // namespace mlir
257 
258 #endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES_H
259