xref: /llvm-project/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp (revision ace6072bca65f946478edeeee018ca1d6a947a3e)
1 //===- BufferOptimizations.cpp - pre-pass optimizations for bufferization -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for three optimization passes. The first two
10 // passes try to move alloc nodes out of blocks to reduce the number of
11 // allocations and copies during buffer deallocation. The third pass tries to
12 // convert heap-based allocations to stack-based allocations, if possible.
13 
14 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
15 
16 #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/IR/Operation.h"
20 #include "mlir/Interfaces/LoopLikeInterface.h"
21 #include "mlir/Pass/Pass.h"
22 
23 namespace mlir {
24 namespace bufferization {
25 #define GEN_PASS_DEF_BUFFERHOISTING
26 #define GEN_PASS_DEF_BUFFERLOOPHOISTING
27 #define GEN_PASS_DEF_PROMOTEBUFFERSTOSTACK
28 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
29 } // namespace bufferization
30 } // namespace mlir
31 
32 using namespace mlir;
33 using namespace mlir::bufferization;
34 
35 /// Returns true if the given operation implements a known high-level region-
36 /// based control-flow interface.
37 static bool isKnownControlFlowInterface(Operation *op) {
38   return isa<LoopLikeOpInterface, RegionBranchOpInterface>(op);
39 }
40 
41 /// Check if the size of the allocation is less than the given size. The
42 /// transformation is only applied to small buffers since large buffers could
43 /// exceed the stack space.
44 static bool defaultIsSmallAlloc(Value alloc, unsigned maximumSizeInBytes,
45                                 unsigned maxRankOfAllocatedMemRef) {
46   auto type = alloc.getType().dyn_cast<ShapedType>();
47   if (!type || !alloc.getDefiningOp<memref::AllocOp>())
48     return false;
49   if (!type.hasStaticShape()) {
50     // Check if the dynamic shape dimension of the alloc is produced by
51     // `memref.rank`. If this is the case, it is likely to be small.
52     // Furthermore, the dimension is limited to the maximum rank of the
53     // allocated memref to avoid large values by multiplying several small
54     // values.
55     if (type.getRank() <= maxRankOfAllocatedMemRef) {
56       return llvm::all_of(alloc.getDefiningOp()->getOperands(),
57                           [&](Value operand) {
58                             return operand.getDefiningOp<memref::RankOp>();
59                           });
60     }
61     return false;
62   }
63   unsigned bitwidth = mlir::DataLayout::closest(alloc.getDefiningOp())
64                           .getTypeSizeInBits(type.getElementType());
65   return type.getNumElements() * bitwidth <= maximumSizeInBytes * 8;
66 }
67 
68 /// Checks whether the given aliases leave the allocation scope.
69 static bool
70 leavesAllocationScope(Region *parentRegion,
71                       const BufferViewFlowAnalysis::ValueSetT &aliases) {
72   for (Value alias : aliases) {
73     for (auto *use : alias.getUsers()) {
74       // If there is at least one alias that leaves the parent region, we know
75       // that this alias escapes the whole region and hence the associated
76       // allocation leaves allocation scope.
77       if (isRegionReturnLike(use) && use->getParentRegion() == parentRegion)
78         return true;
79     }
80   }
81   return false;
82 }
83 
84 /// Checks, if an automated allocation scope for a given alloc value exists.
85 static bool hasAllocationScope(Value alloc,
86                                const BufferViewFlowAnalysis &aliasAnalysis) {
87   Region *region = alloc.getParentRegion();
88   do {
89     if (Operation *parentOp = region->getParentOp()) {
90       // Check if the operation is an automatic allocation scope and whether an
91       // alias leaves the scope. This means, an allocation yields out of
92       // this scope and can not be transformed in a stack-based allocation.
93       if (parentOp->hasTrait<OpTrait::AutomaticAllocationScope>() &&
94           !leavesAllocationScope(region, aliasAnalysis.resolve(alloc)))
95         return true;
96       // Check if the operation is a known control flow interface and break the
97       // loop to avoid transformation in loops. Furthermore skip transformation
98       // if the operation does not implement a RegionBeanchOpInterface.
99       if (BufferPlacementTransformationBase::isLoop(parentOp) ||
100           !isKnownControlFlowInterface(parentOp))
101         break;
102     }
103   } while ((region = region->getParentRegion()));
104   return false;
105 }
106 
107 namespace {
108 
109 //===----------------------------------------------------------------------===//
110 // BufferAllocationHoisting
111 //===----------------------------------------------------------------------===//
112 
113 /// A base implementation compatible with the `BufferAllocationHoisting` class.
114 struct BufferAllocationHoistingStateBase {
115   /// A pointer to the current dominance info.
116   DominanceInfo *dominators;
117 
118   /// The current allocation value.
119   Value allocValue;
120 
121   /// The current placement block (if any).
122   Block *placementBlock;
123 
124   /// Initializes the state base.
125   BufferAllocationHoistingStateBase(DominanceInfo *dominators, Value allocValue,
126                                     Block *placementBlock)
127       : dominators(dominators), allocValue(allocValue),
128         placementBlock(placementBlock) {}
129 };
130 
131 /// Implements the actual hoisting logic for allocation nodes.
132 template <typename StateT>
133 class BufferAllocationHoisting : public BufferPlacementTransformationBase {
134 public:
135   BufferAllocationHoisting(Operation *op)
136       : BufferPlacementTransformationBase(op), dominators(op),
137         postDominators(op), scopeOp(op) {}
138 
139   /// Moves allocations upwards.
140   void hoist() {
141     SmallVector<Value> allocsAndAllocas;
142     for (BufferPlacementAllocs::AllocEntry &entry : allocs)
143       allocsAndAllocas.push_back(std::get<0>(entry));
144     scopeOp->walk([&](memref::AllocaOp op) {
145       allocsAndAllocas.push_back(op.getMemref());
146     });
147 
148     for (auto allocValue : allocsAndAllocas) {
149       if (!StateT::shouldHoistOpType(allocValue.getDefiningOp()))
150         continue;
151       Operation *definingOp = allocValue.getDefiningOp();
152       assert(definingOp && "No defining op");
153       auto operands = definingOp->getOperands();
154       auto resultAliases = aliases.resolve(allocValue);
155       // Determine the common dominator block of all aliases.
156       Block *dominatorBlock =
157           findCommonDominator(allocValue, resultAliases, dominators);
158       // Init the initial hoisting state.
159       StateT state(&dominators, allocValue, allocValue.getParentBlock());
160       // Check for additional allocation dependencies to compute an upper bound
161       // for hoisting.
162       Block *dependencyBlock = nullptr;
163       // If this node has dependencies, check all dependent nodes. This ensures
164       // that all dependency values have been computed before allocating the
165       // buffer.
166       for (Value depValue : operands) {
167         Block *depBlock = depValue.getParentBlock();
168         if (!dependencyBlock || dominators.dominates(dependencyBlock, depBlock))
169           dependencyBlock = depBlock;
170       }
171 
172       // Find the actual placement block and determine the start operation using
173       // an upper placement-block boundary. The idea is that placement block
174       // cannot be moved any further upwards than the given upper bound.
175       Block *placementBlock = findPlacementBlock(
176           state, state.computeUpperBound(dominatorBlock, dependencyBlock));
177       Operation *startOperation = BufferPlacementAllocs::getStartOperation(
178           allocValue, placementBlock, liveness);
179 
180       // Move the alloc in front of the start operation.
181       Operation *allocOperation = allocValue.getDefiningOp();
182       allocOperation->moveBefore(startOperation);
183     }
184   }
185 
186 private:
187   /// Finds a valid placement block by walking upwards in the CFG until we
188   /// either cannot continue our walk due to constraints (given by the StateT
189   /// implementation) or we have reached the upper-most dominator block.
190   Block *findPlacementBlock(StateT &state, Block *upperBound) {
191     Block *currentBlock = state.placementBlock;
192     // Walk from the innermost regions/loops to the outermost regions/loops and
193     // find an appropriate placement block that satisfies the constraint of the
194     // current StateT implementation. Walk until we reach the upperBound block
195     // (if any).
196 
197     // If we are not able to find a valid parent operation or an associated
198     // parent block, break the walk loop.
199     Operation *parentOp;
200     Block *parentBlock;
201     while ((parentOp = currentBlock->getParentOp()) &&
202            (parentBlock = parentOp->getBlock()) &&
203            (!upperBound ||
204             dominators.properlyDominates(upperBound, currentBlock))) {
205       // Try to find an immediate dominator and check whether the parent block
206       // is above the immediate dominator (if any).
207       DominanceInfoNode *idom = nullptr;
208 
209       // DominanceInfo doesn't support getNode queries for single-block regions.
210       if (!currentBlock->isEntryBlock())
211         idom = dominators.getNode(currentBlock)->getIDom();
212 
213       if (idom && dominators.properlyDominates(parentBlock, idom->getBlock())) {
214         // If the current immediate dominator is below the placement block, move
215         // to the immediate dominator block.
216         currentBlock = idom->getBlock();
217         state.recordMoveToDominator(currentBlock);
218       } else {
219         // We have to move to our parent block since an immediate dominator does
220         // either not exist or is above our parent block. If we cannot move to
221         // our parent operation due to constraints given by the StateT
222         // implementation, break the walk loop. Furthermore, we should not move
223         // allocations out of unknown region-based control-flow operations.
224         if (!isKnownControlFlowInterface(parentOp) ||
225             !state.isLegalPlacement(parentOp))
226           break;
227         // Move to our parent block by notifying the current StateT
228         // implementation.
229         currentBlock = parentBlock;
230         state.recordMoveToParent(currentBlock);
231       }
232     }
233     // Return the finally determined placement block.
234     return state.placementBlock;
235   }
236 
237   /// The dominator info to find the appropriate start operation to move the
238   /// allocs.
239   DominanceInfo dominators;
240 
241   /// The post dominator info to move the dependent allocs in the right
242   /// position.
243   PostDominanceInfo postDominators;
244 
245   /// The map storing the final placement blocks of a given alloc value.
246   llvm::DenseMap<Value, Block *> placementBlocks;
247 
248   /// The operation that this transformation is working on. It is used to also
249   /// gather allocas.
250   Operation *scopeOp;
251 };
252 
253 /// A state implementation compatible with the `BufferAllocationHoisting` class
254 /// that hoists allocations into dominator blocks while keeping them inside of
255 /// loops.
256 struct BufferAllocationHoistingState : BufferAllocationHoistingStateBase {
257   using BufferAllocationHoistingStateBase::BufferAllocationHoistingStateBase;
258 
259   /// Computes the upper bound for the placement block search.
260   Block *computeUpperBound(Block *dominatorBlock, Block *dependencyBlock) {
261     // If we do not have a dependency block, the upper bound is given by the
262     // dominator block.
263     if (!dependencyBlock)
264       return dominatorBlock;
265 
266     // Find the "lower" block of the dominator and the dependency block to
267     // ensure that we do not move allocations above this block.
268     return dominators->properlyDominates(dominatorBlock, dependencyBlock)
269                ? dependencyBlock
270                : dominatorBlock;
271   }
272 
273   /// Returns true if the given operation does not represent a loop.
274   bool isLegalPlacement(Operation *op) {
275     return !BufferPlacementTransformationBase::isLoop(op);
276   }
277 
278   /// Returns true if the given operation should be considered for hoisting.
279   static bool shouldHoistOpType(Operation *op) {
280     return llvm::isa<memref::AllocOp>(op);
281   }
282 
283   /// Sets the current placement block to the given block.
284   void recordMoveToDominator(Block *block) { placementBlock = block; }
285 
286   /// Sets the current placement block to the given block.
287   void recordMoveToParent(Block *block) { recordMoveToDominator(block); }
288 };
289 
290 /// A state implementation compatible with the `BufferAllocationHoisting` class
291 /// that hoists allocations out of loops.
292 struct BufferAllocationLoopHoistingState : BufferAllocationHoistingStateBase {
293   using BufferAllocationHoistingStateBase::BufferAllocationHoistingStateBase;
294 
295   /// Remembers the dominator block of all aliases.
296   Block *aliasDominatorBlock = nullptr;
297 
298   /// Computes the upper bound for the placement block search.
299   Block *computeUpperBound(Block *dominatorBlock, Block *dependencyBlock) {
300     aliasDominatorBlock = dominatorBlock;
301     // If there is a dependency block, we have to use this block as an upper
302     // bound to satisfy all allocation value dependencies.
303     return dependencyBlock ? dependencyBlock : nullptr;
304   }
305 
306   /// Returns true if the given operation represents a loop and one of the
307   /// aliases caused the `aliasDominatorBlock` to be "above" the block of the
308   /// given loop operation. If this is the case, it indicates that the
309   /// allocation is passed via a back edge.
310   bool isLegalPlacement(Operation *op) {
311     return BufferPlacementTransformationBase::isLoop(op) &&
312            !dominators->dominates(aliasDominatorBlock, op->getBlock());
313   }
314 
315   /// Returns true if the given operation should be considered for hoisting.
316   static bool shouldHoistOpType(Operation *op) {
317     return llvm::isa<memref::AllocOp, memref::AllocaOp>(op);
318   }
319 
320   /// Does not change the internal placement block, as we want to move
321   /// operations out of loops only.
322   void recordMoveToDominator(Block *block) {}
323 
324   /// Sets the current placement block to the given block.
325   void recordMoveToParent(Block *block) { placementBlock = block; }
326 };
327 
328 //===----------------------------------------------------------------------===//
329 // BufferPlacementPromotion
330 //===----------------------------------------------------------------------===//
331 
332 /// Promotes heap-based allocations to stack-based allocations (if possible).
333 class BufferPlacementPromotion : BufferPlacementTransformationBase {
334 public:
335   BufferPlacementPromotion(Operation *op)
336       : BufferPlacementTransformationBase(op) {}
337 
338   /// Promote buffers to stack-based allocations.
339   void promote(function_ref<bool(Value)> isSmallAlloc) {
340     for (BufferPlacementAllocs::AllocEntry &entry : allocs) {
341       Value alloc = std::get<0>(entry);
342       Operation *dealloc = std::get<1>(entry);
343       // Checking several requirements to transform an AllocOp into an AllocaOp.
344       // The transformation is done if the allocation is limited to a given
345       // size. Furthermore, a deallocation must not be defined for this
346       // allocation entry and a parent allocation scope must exist.
347       if (!isSmallAlloc(alloc) || dealloc ||
348           !hasAllocationScope(alloc, aliases))
349         continue;
350 
351       Operation *startOperation = BufferPlacementAllocs::getStartOperation(
352           alloc, alloc.getParentBlock(), liveness);
353       // Build a new alloca that is associated with its parent
354       // `AutomaticAllocationScope` determined during the initialization phase.
355       OpBuilder builder(startOperation);
356       Operation *allocOp = alloc.getDefiningOp();
357       Operation *alloca = builder.create<memref::AllocaOp>(
358           alloc.getLoc(), alloc.getType().cast<MemRefType>(),
359           allocOp->getOperands(), allocOp->getAttrs());
360 
361       // Replace the original alloc by a newly created alloca.
362       allocOp->replaceAllUsesWith(alloca);
363       allocOp->erase();
364     }
365   }
366 };
367 
368 //===----------------------------------------------------------------------===//
369 // BufferOptimizationPasses
370 //===----------------------------------------------------------------------===//
371 
372 /// The buffer hoisting pass that hoists allocation nodes into dominating
373 /// blocks.
374 struct BufferHoistingPass
375     : public bufferization::impl::BufferHoistingBase<BufferHoistingPass> {
376 
377   void runOnOperation() override {
378     // Hoist all allocations into dominator blocks.
379     BufferAllocationHoisting<BufferAllocationHoistingState> optimizer(
380         getOperation());
381     optimizer.hoist();
382   }
383 };
384 
385 /// The buffer loop hoisting pass that hoists allocation nodes out of loops.
386 struct BufferLoopHoistingPass
387     : public bufferization::impl::BufferLoopHoistingBase<
388           BufferLoopHoistingPass> {
389 
390   void runOnOperation() override {
391     // Hoist all allocations out of loops.
392     BufferAllocationHoisting<BufferAllocationLoopHoistingState> optimizer(
393         getOperation());
394     optimizer.hoist();
395   }
396 };
397 
398 /// The promote buffer to stack pass that tries to convert alloc nodes into
399 /// alloca nodes.
400 class PromoteBuffersToStackPass
401     : public bufferization::impl::PromoteBuffersToStackBase<
402           PromoteBuffersToStackPass> {
403 public:
404   PromoteBuffersToStackPass(unsigned maxAllocSizeInBytes,
405                             unsigned maxRankOfAllocatedMemRef) {
406     this->maxAllocSizeInBytes = maxAllocSizeInBytes;
407     this->maxRankOfAllocatedMemRef = maxRankOfAllocatedMemRef;
408   }
409 
410   explicit PromoteBuffersToStackPass(std::function<bool(Value)> isSmallAlloc)
411       : isSmallAlloc(std::move(isSmallAlloc)) {}
412 
413   LogicalResult initialize(MLIRContext *context) override {
414     if (isSmallAlloc == nullptr) {
415       isSmallAlloc = [=](Value alloc) {
416         return defaultIsSmallAlloc(alloc, maxAllocSizeInBytes,
417                                    maxRankOfAllocatedMemRef);
418       };
419     }
420     return success();
421   }
422 
423   void runOnOperation() override {
424     // Move all allocation nodes and convert candidates into allocas.
425     BufferPlacementPromotion optimizer(getOperation());
426     optimizer.promote(isSmallAlloc);
427   }
428 
429 private:
430   std::function<bool(Value)> isSmallAlloc;
431 };
432 
433 } // namespace
434 
435 std::unique_ptr<Pass> mlir::bufferization::createBufferHoistingPass() {
436   return std::make_unique<BufferHoistingPass>();
437 }
438 
439 std::unique_ptr<Pass> mlir::bufferization::createBufferLoopHoistingPass() {
440   return std::make_unique<BufferLoopHoistingPass>();
441 }
442 
443 std::unique_ptr<Pass> mlir::bufferization::createPromoteBuffersToStackPass(
444     unsigned maxAllocSizeInBytes, unsigned maxRankOfAllocatedMemRef) {
445   return std::make_unique<PromoteBuffersToStackPass>(maxAllocSizeInBytes,
446                                                      maxRankOfAllocatedMemRef);
447 }
448 
449 std::unique_ptr<Pass> mlir::bufferization::createPromoteBuffersToStackPass(
450     std::function<bool(Value)> isSmallAlloc) {
451   return std::make_unique<PromoteBuffersToStackPass>(std::move(isSmallAlloc));
452 }
453