xref: /llvm-project/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp (revision 55e3857931d1e187af574b322e026ced379ee1f2)
1 //===- BufferOptimizations.cpp - pre-pass optimizations for bufferization -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic for three optimization passes. The first two
10 // passes try to move alloc nodes out of blocks to reduce the number of
11 // allocations and copies during buffer deallocation. The third pass tries to
12 // convert heap-based allocations to stack-based allocations, if possible.
13 
14 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
15 
16 #include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
17 #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
18 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
19 #include "mlir/Dialect/Func/IR/FuncOps.h"
20 #include "mlir/Dialect/MemRef/IR/MemRef.h"
21 #include "mlir/IR/Operation.h"
22 #include "mlir/Interfaces/LoopLikeInterface.h"
23 #include "mlir/Pass/Pass.h"
24 
25 namespace mlir {
26 namespace bufferization {
27 #define GEN_PASS_DEF_BUFFERHOISTING
28 #define GEN_PASS_DEF_BUFFERLOOPHOISTING
29 #define GEN_PASS_DEF_PROMOTEBUFFERSTOSTACK
30 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
31 } // namespace bufferization
32 } // namespace mlir
33 
34 using namespace mlir;
35 using namespace mlir::bufferization;
36 
37 /// Returns true if the given operation implements a known high-level region-
38 /// based control-flow interface.
39 static bool isKnownControlFlowInterface(Operation *op) {
40   return isa<LoopLikeOpInterface, RegionBranchOpInterface>(op);
41 }
42 
43 /// Returns true if the given operation implements the AllocationOpInterface
44 /// and it supports the dominate block hoisting.
45 static bool allowAllocDominateBlockHoisting(Operation *op) {
46   auto allocOp = dyn_cast<AllocationOpInterface>(op);
47   return allocOp && (static_cast<uint8_t>(allocOp.getHoistingKind()) &
48                      static_cast<uint8_t>(HoistingKind::Block));
49 }
50 
51 /// Returns true if the given operation implements the AllocationOpInterface
52 /// and it supports the loop hoisting.
53 static bool allowAllocLoopHoisting(Operation *op) {
54   auto allocOp = dyn_cast<AllocationOpInterface>(op);
55   return allocOp && (static_cast<uint8_t>(allocOp.getHoistingKind()) &
56                      static_cast<uint8_t>(HoistingKind::Loop));
57 }
58 
59 /// Check if the size of the allocation is less than the given size. The
60 /// transformation is only applied to small buffers since large buffers could
61 /// exceed the stack space.
62 static bool defaultIsSmallAlloc(Value alloc, unsigned maximumSizeInBytes,
63                                 unsigned maxRankOfAllocatedMemRef) {
64   auto type = dyn_cast<ShapedType>(alloc.getType());
65   if (!type || !alloc.getDefiningOp<memref::AllocOp>())
66     return false;
67   if (!type.hasStaticShape()) {
68     // Check if the dynamic shape dimension of the alloc is produced by
69     // `memref.rank`. If this is the case, it is likely to be small.
70     // Furthermore, the dimension is limited to the maximum rank of the
71     // allocated memref to avoid large values by multiplying several small
72     // values.
73     if (type.getRank() <= maxRankOfAllocatedMemRef) {
74       return llvm::all_of(alloc.getDefiningOp()->getOperands(),
75                           [&](Value operand) {
76                             return operand.getDefiningOp<memref::RankOp>();
77                           });
78     }
79     return false;
80   }
81   unsigned bitwidth = mlir::DataLayout::closest(alloc.getDefiningOp())
82                           .getTypeSizeInBits(type.getElementType());
83   return type.getNumElements() * bitwidth <= maximumSizeInBytes * 8;
84 }
85 
86 /// Checks whether the given aliases leave the allocation scope.
87 static bool
88 leavesAllocationScope(Region *parentRegion,
89                       const BufferViewFlowAnalysis::ValueSetT &aliases) {
90   for (Value alias : aliases) {
91     for (auto *use : alias.getUsers()) {
92       // If there is at least one alias that leaves the parent region, we know
93       // that this alias escapes the whole region and hence the associated
94       // allocation leaves allocation scope.
95       if (isa<RegionBranchTerminatorOpInterface>(use) &&
96           use->getParentRegion() == parentRegion)
97         return true;
98     }
99   }
100   return false;
101 }
102 
103 /// Checks, if an automated allocation scope for a given alloc value exists.
104 static bool hasAllocationScope(Value alloc,
105                                const BufferViewFlowAnalysis &aliasAnalysis) {
106   Region *region = alloc.getParentRegion();
107   do {
108     if (Operation *parentOp = region->getParentOp()) {
109       // Check if the operation is an automatic allocation scope and whether an
110       // alias leaves the scope. This means, an allocation yields out of
111       // this scope and can not be transformed in a stack-based allocation.
112       if (parentOp->hasTrait<OpTrait::AutomaticAllocationScope>() &&
113           !leavesAllocationScope(region, aliasAnalysis.resolve(alloc)))
114         return true;
115       // Check if the operation is a known control flow interface and break the
116       // loop to avoid transformation in loops. Furthermore skip transformation
117       // if the operation does not implement a RegionBeanchOpInterface.
118       if (BufferPlacementTransformationBase::isLoop(parentOp) ||
119           !isKnownControlFlowInterface(parentOp))
120         break;
121     }
122   } while ((region = region->getParentRegion()));
123   return false;
124 }
125 
126 namespace {
127 
128 //===----------------------------------------------------------------------===//
129 // BufferAllocationHoisting
130 //===----------------------------------------------------------------------===//
131 
132 /// A base implementation compatible with the `BufferAllocationHoisting` class.
133 struct BufferAllocationHoistingStateBase {
134   /// A pointer to the current dominance info.
135   DominanceInfo *dominators;
136 
137   /// The current allocation value.
138   Value allocValue;
139 
140   /// The current placement block (if any).
141   Block *placementBlock;
142 
143   /// Initializes the state base.
144   BufferAllocationHoistingStateBase(DominanceInfo *dominators, Value allocValue,
145                                     Block *placementBlock)
146       : dominators(dominators), allocValue(allocValue),
147         placementBlock(placementBlock) {}
148 };
149 
150 /// Implements the actual hoisting logic for allocation nodes.
151 template <typename StateT>
152 class BufferAllocationHoisting : public BufferPlacementTransformationBase {
153 public:
154   BufferAllocationHoisting(Operation *op)
155       : BufferPlacementTransformationBase(op), dominators(op),
156         postDominators(op), scopeOp(op) {}
157 
158   /// Moves allocations upwards.
159   void hoist() {
160     SmallVector<Value> allocsAndAllocas;
161     for (BufferPlacementAllocs::AllocEntry &entry : allocs)
162       allocsAndAllocas.push_back(std::get<0>(entry));
163     scopeOp->walk([&](memref::AllocaOp op) {
164       allocsAndAllocas.push_back(op.getMemref());
165     });
166 
167     for (auto allocValue : allocsAndAllocas) {
168       if (!StateT::shouldHoistOpType(allocValue.getDefiningOp()))
169         continue;
170       Operation *definingOp = allocValue.getDefiningOp();
171       assert(definingOp && "No defining op");
172       auto operands = definingOp->getOperands();
173       auto resultAliases = aliases.resolve(allocValue);
174       // Determine the common dominator block of all aliases.
175       Block *dominatorBlock =
176           findCommonDominator(allocValue, resultAliases, dominators);
177       // Init the initial hoisting state.
178       StateT state(&dominators, allocValue, allocValue.getParentBlock());
179       // Check for additional allocation dependencies to compute an upper bound
180       // for hoisting.
181       Block *dependencyBlock = nullptr;
182       // If this node has dependencies, check all dependent nodes. This ensures
183       // that all dependency values have been computed before allocating the
184       // buffer.
185       for (Value depValue : operands) {
186         Block *depBlock = depValue.getParentBlock();
187         if (!dependencyBlock || dominators.dominates(dependencyBlock, depBlock))
188           dependencyBlock = depBlock;
189       }
190 
191       // Find the actual placement block and determine the start operation using
192       // an upper placement-block boundary. The idea is that placement block
193       // cannot be moved any further upwards than the given upper bound.
194       Block *placementBlock = findPlacementBlock(
195           state, state.computeUpperBound(dominatorBlock, dependencyBlock));
196       Operation *startOperation = BufferPlacementAllocs::getStartOperation(
197           allocValue, placementBlock, liveness);
198 
199       // Move the alloc in front of the start operation.
200       Operation *allocOperation = allocValue.getDefiningOp();
201       allocOperation->moveBefore(startOperation);
202     }
203   }
204 
205 private:
206   /// Finds a valid placement block by walking upwards in the CFG until we
207   /// either cannot continue our walk due to constraints (given by the StateT
208   /// implementation) or we have reached the upper-most dominator block.
209   Block *findPlacementBlock(StateT &state, Block *upperBound) {
210     Block *currentBlock = state.placementBlock;
211     // Walk from the innermost regions/loops to the outermost regions/loops and
212     // find an appropriate placement block that satisfies the constraint of the
213     // current StateT implementation. Walk until we reach the upperBound block
214     // (if any).
215 
216     // If we are not able to find a valid parent operation or an associated
217     // parent block, break the walk loop.
218     Operation *parentOp;
219     Block *parentBlock;
220     while ((parentOp = currentBlock->getParentOp()) &&
221            (parentBlock = parentOp->getBlock()) &&
222            (!upperBound ||
223             dominators.properlyDominates(upperBound, currentBlock))) {
224       // Try to find an immediate dominator and check whether the parent block
225       // is above the immediate dominator (if any).
226       DominanceInfoNode *idom = nullptr;
227 
228       // DominanceInfo doesn't support getNode queries for single-block regions.
229       if (!currentBlock->isEntryBlock())
230         idom = dominators.getNode(currentBlock)->getIDom();
231 
232       if (idom && dominators.properlyDominates(parentBlock, idom->getBlock())) {
233         // If the current immediate dominator is below the placement block, move
234         // to the immediate dominator block.
235         currentBlock = idom->getBlock();
236         state.recordMoveToDominator(currentBlock);
237       } else {
238         // We have to move to our parent block since an immediate dominator does
239         // either not exist or is above our parent block. If we cannot move to
240         // our parent operation due to constraints given by the StateT
241         // implementation, break the walk loop. Furthermore, we should not move
242         // allocations out of unknown region-based control-flow operations.
243         if (!isKnownControlFlowInterface(parentOp) ||
244             !state.isLegalPlacement(parentOp))
245           break;
246         // Move to our parent block by notifying the current StateT
247         // implementation.
248         currentBlock = parentBlock;
249         state.recordMoveToParent(currentBlock);
250       }
251     }
252     // Return the finally determined placement block.
253     return state.placementBlock;
254   }
255 
256   /// The dominator info to find the appropriate start operation to move the
257   /// allocs.
258   DominanceInfo dominators;
259 
260   /// The post dominator info to move the dependent allocs in the right
261   /// position.
262   PostDominanceInfo postDominators;
263 
264   /// The map storing the final placement blocks of a given alloc value.
265   llvm::DenseMap<Value, Block *> placementBlocks;
266 
267   /// The operation that this transformation is working on. It is used to also
268   /// gather allocas.
269   Operation *scopeOp;
270 };
271 
272 /// A state implementation compatible with the `BufferAllocationHoisting` class
273 /// that hoists allocations into dominator blocks while keeping them inside of
274 /// loops.
275 struct BufferAllocationHoistingState : BufferAllocationHoistingStateBase {
276   using BufferAllocationHoistingStateBase::BufferAllocationHoistingStateBase;
277 
278   /// Computes the upper bound for the placement block search.
279   Block *computeUpperBound(Block *dominatorBlock, Block *dependencyBlock) {
280     // If we do not have a dependency block, the upper bound is given by the
281     // dominator block.
282     if (!dependencyBlock)
283       return dominatorBlock;
284 
285     // Find the "lower" block of the dominator and the dependency block to
286     // ensure that we do not move allocations above this block.
287     return dominators->properlyDominates(dominatorBlock, dependencyBlock)
288                ? dependencyBlock
289                : dominatorBlock;
290   }
291 
292   /// Returns true if the given operation does not represent a loop.
293   bool isLegalPlacement(Operation *op) {
294     return !BufferPlacementTransformationBase::isLoop(op);
295   }
296 
297   /// Returns true if the given operation should be considered for hoisting.
298   static bool shouldHoistOpType(Operation *op) {
299     return allowAllocDominateBlockHoisting(op);
300   }
301 
302   /// Sets the current placement block to the given block.
303   void recordMoveToDominator(Block *block) { placementBlock = block; }
304 
305   /// Sets the current placement block to the given block.
306   void recordMoveToParent(Block *block) { recordMoveToDominator(block); }
307 };
308 
309 /// A state implementation compatible with the `BufferAllocationHoisting` class
310 /// that hoists allocations out of loops.
311 struct BufferAllocationLoopHoistingState : BufferAllocationHoistingStateBase {
312   using BufferAllocationHoistingStateBase::BufferAllocationHoistingStateBase;
313 
314   /// Remembers the dominator block of all aliases.
315   Block *aliasDominatorBlock = nullptr;
316 
317   /// Computes the upper bound for the placement block search.
318   Block *computeUpperBound(Block *dominatorBlock, Block *dependencyBlock) {
319     aliasDominatorBlock = dominatorBlock;
320     // If there is a dependency block, we have to use this block as an upper
321     // bound to satisfy all allocation value dependencies.
322     return dependencyBlock ? dependencyBlock : nullptr;
323   }
324 
325   /// Returns true if the given operation represents a loop and one of the
326   /// aliases caused the `aliasDominatorBlock` to be "above" the block of the
327   /// given loop operation. If this is the case, it indicates that the
328   /// allocation is passed via a back edge.
329   bool isLegalPlacement(Operation *op) {
330     return BufferPlacementTransformationBase::isLoop(op) &&
331            !dominators->dominates(aliasDominatorBlock, op->getBlock());
332   }
333 
334   /// Returns true if the given operation should be considered for hoisting.
335   static bool shouldHoistOpType(Operation *op) {
336     return allowAllocLoopHoisting(op);
337   }
338 
339   /// Does not change the internal placement block, as we want to move
340   /// operations out of loops only.
341   void recordMoveToDominator(Block *block) {}
342 
343   /// Sets the current placement block to the given block.
344   void recordMoveToParent(Block *block) { placementBlock = block; }
345 };
346 
347 //===----------------------------------------------------------------------===//
348 // BufferPlacementPromotion
349 //===----------------------------------------------------------------------===//
350 
351 /// Promotes heap-based allocations to stack-based allocations (if possible).
352 class BufferPlacementPromotion : BufferPlacementTransformationBase {
353 public:
354   BufferPlacementPromotion(Operation *op)
355       : BufferPlacementTransformationBase(op) {}
356 
357   /// Promote buffers to stack-based allocations.
358   void promote(function_ref<bool(Value)> isSmallAlloc) {
359     for (BufferPlacementAllocs::AllocEntry &entry : allocs) {
360       Value alloc = std::get<0>(entry);
361       Operation *dealloc = std::get<1>(entry);
362       // Checking several requirements to transform an AllocOp into an AllocaOp.
363       // The transformation is done if the allocation is limited to a given
364       // size. Furthermore, a deallocation must not be defined for this
365       // allocation entry and a parent allocation scope must exist.
366       if (!isSmallAlloc(alloc) || dealloc ||
367           !hasAllocationScope(alloc, aliases))
368         continue;
369 
370       Operation *startOperation = BufferPlacementAllocs::getStartOperation(
371           alloc, alloc.getParentBlock(), liveness);
372       // Build a new alloca that is associated with its parent
373       // `AutomaticAllocationScope` determined during the initialization phase.
374       OpBuilder builder(startOperation);
375       Operation *allocOp = alloc.getDefiningOp();
376       if (auto allocInterface = dyn_cast<AllocationOpInterface>(allocOp)) {
377         Operation *alloca =
378             allocInterface.buildPromotedAlloc(builder, alloc).value();
379         if (!alloca)
380           continue;
381         // Replace the original alloc by a newly created alloca.
382         allocOp->replaceAllUsesWith(alloca);
383         allocOp->erase();
384       }
385     }
386   }
387 };
388 
389 //===----------------------------------------------------------------------===//
390 // BufferOptimizationPasses
391 //===----------------------------------------------------------------------===//
392 
393 /// The buffer hoisting pass that hoists allocation nodes into dominating
394 /// blocks.
395 struct BufferHoistingPass
396     : public bufferization::impl::BufferHoistingBase<BufferHoistingPass> {
397 
398   void runOnOperation() override {
399     // Hoist all allocations into dominator blocks.
400     BufferAllocationHoisting<BufferAllocationHoistingState> optimizer(
401         getOperation());
402     optimizer.hoist();
403   }
404 };
405 
406 /// The buffer loop hoisting pass that hoists allocation nodes out of loops.
407 struct BufferLoopHoistingPass
408     : public bufferization::impl::BufferLoopHoistingBase<
409           BufferLoopHoistingPass> {
410 
411   void runOnOperation() override {
412     // Hoist all allocations out of loops.
413     hoistBuffersFromLoops(getOperation());
414   }
415 };
416 
417 /// The promote buffer to stack pass that tries to convert alloc nodes into
418 /// alloca nodes.
419 class PromoteBuffersToStackPass
420     : public bufferization::impl::PromoteBuffersToStackBase<
421           PromoteBuffersToStackPass> {
422 public:
423   PromoteBuffersToStackPass(unsigned maxAllocSizeInBytes,
424                             unsigned maxRankOfAllocatedMemRef) {
425     this->maxAllocSizeInBytes = maxAllocSizeInBytes;
426     this->maxRankOfAllocatedMemRef = maxRankOfAllocatedMemRef;
427   }
428 
429   explicit PromoteBuffersToStackPass(std::function<bool(Value)> isSmallAlloc)
430       : isSmallAlloc(std::move(isSmallAlloc)) {}
431 
432   LogicalResult initialize(MLIRContext *context) override {
433     if (isSmallAlloc == nullptr) {
434       isSmallAlloc = [=](Value alloc) {
435         return defaultIsSmallAlloc(alloc, maxAllocSizeInBytes,
436                                    maxRankOfAllocatedMemRef);
437       };
438     }
439     return success();
440   }
441 
442   void runOnOperation() override {
443     // Move all allocation nodes and convert candidates into allocas.
444     BufferPlacementPromotion optimizer(getOperation());
445     optimizer.promote(isSmallAlloc);
446   }
447 
448 private:
449   std::function<bool(Value)> isSmallAlloc;
450 };
451 
452 } // namespace
453 
454 void mlir::bufferization::hoistBuffersFromLoops(Operation *op) {
455   BufferAllocationHoisting<BufferAllocationLoopHoistingState> optimizer(op);
456   optimizer.hoist();
457 }
458 
459 std::unique_ptr<Pass> mlir::bufferization::createBufferHoistingPass() {
460   return std::make_unique<BufferHoistingPass>();
461 }
462 
463 std::unique_ptr<Pass> mlir::bufferization::createBufferLoopHoistingPass() {
464   return std::make_unique<BufferLoopHoistingPass>();
465 }
466 
467 std::unique_ptr<Pass> mlir::bufferization::createPromoteBuffersToStackPass(
468     unsigned maxAllocSizeInBytes, unsigned maxRankOfAllocatedMemRef) {
469   return std::make_unique<PromoteBuffersToStackPass>(maxAllocSizeInBytes,
470                                                      maxRankOfAllocatedMemRef);
471 }
472 
473 std::unique_ptr<Pass> mlir::bufferization::createPromoteBuffersToStackPass(
474     std::function<bool(Value)> isSmallAlloc) {
475   return std::make_unique<PromoteBuffersToStackPass>(std::move(isSmallAlloc));
476 }
477