1 //===- OptimizeAllocationLiveness.cpp - impl. optimize allocation liveness pass 2 //-===// 3 // 4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 5 // See https://llvm.org/LICENSE.txt for license information. 6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // This file implements a pass for optimizing allocation liveness. 11 // The pass moves the deallocation operation after the last user of the 12 // allocated buffer. 13 //===----------------------------------------------------------------------===// 14 15 #include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h" 16 #include "mlir/Dialect/Bufferization/Transforms/Passes.h" 17 #include "mlir/Dialect/Func/IR/FuncOps.h" 18 #include "mlir/Dialect/MemRef/IR/MemRef.h" 19 #include "mlir/IR/Operation.h" 20 #include "llvm/Support/Debug.h" 21 22 #define DEBUG_TYPE "optimize-allocation-liveness" 23 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") 24 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") 25 26 namespace mlir { 27 namespace bufferization { 28 #define GEN_PASS_DEF_OPTIMIZEALLOCATIONLIVENESS 29 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc" 30 } // namespace bufferization 31 } // namespace mlir 32 33 using namespace mlir; 34 35 namespace { 36 37 //===----------------------------------------------------------------------===// 38 // Helper functions 39 //===----------------------------------------------------------------------===// 40 41 /// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors 42 /// properly dominates `b` and `b` is not inside `a`. 43 static bool happensBefore(Operation *a, Operation *b) { 44 do { 45 if (a->isProperAncestor(b)) 46 return false; 47 if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) { 48 return a->isBeforeInBlock(bAncestor); 49 } 50 } while ((a = a->getParentOp())); 51 return false; 52 } 53 54 /// This method searches for a user of value that is a dealloc operation. 55 /// If multiple users with free effect are found, return nullptr. 56 Operation *findUserWithFreeSideEffect(Value value) { 57 Operation *freeOpUser = nullptr; 58 for (Operation *user : value.getUsers()) { 59 if (MemoryEffectOpInterface memEffectOp = 60 dyn_cast<MemoryEffectOpInterface>(user)) { 61 SmallVector<MemoryEffects::EffectInstance, 2> effects; 62 memEffectOp.getEffects(effects); 63 64 for (const auto &effect : effects) { 65 if (isa<MemoryEffects::Free>(effect.getEffect())) { 66 if (freeOpUser) { 67 LDBG("Multiple users with free effect found: " << *freeOpUser 68 << " and " << *user); 69 return nullptr; 70 } 71 freeOpUser = user; 72 } 73 } 74 } 75 } 76 return freeOpUser; 77 } 78 79 /// Checks if the given op allocates memory. 80 static bool hasMemoryAllocEffect(MemoryEffectOpInterface memEffectOp) { 81 SmallVector<MemoryEffects::EffectInstance, 2> effects; 82 memEffectOp.getEffects(effects); 83 for (const auto &effect : effects) { 84 if (isa<MemoryEffects::Allocate>(effect.getEffect())) { 85 return true; 86 } 87 } 88 return false; 89 } 90 91 struct OptimizeAllocationLiveness 92 : public bufferization::impl::OptimizeAllocationLivenessBase< 93 OptimizeAllocationLiveness> { 94 public: 95 OptimizeAllocationLiveness() = default; 96 97 void runOnOperation() override { 98 func::FuncOp func = getOperation(); 99 100 if (func.isExternal()) 101 return; 102 103 BufferViewFlowAnalysis analysis = BufferViewFlowAnalysis(func); 104 105 func.walk([&](MemoryEffectOpInterface memEffectOp) -> WalkResult { 106 if (!hasMemoryAllocEffect(memEffectOp)) 107 return WalkResult::advance(); 108 109 auto allocOp = memEffectOp; 110 LDBG("Checking alloc op: " << allocOp); 111 112 auto deallocOp = findUserWithFreeSideEffect(allocOp->getResult(0)); 113 if (!deallocOp || (deallocOp->getBlock() != allocOp->getBlock())) { 114 // The pass handles allocations that have a single dealloc op in the 115 // same block. We also should not hoist the dealloc op out of 116 // conditionals. 117 return WalkResult::advance(); 118 } 119 120 Operation *lastUser = nullptr; 121 const BufferViewFlowAnalysis::ValueSetT &deps = 122 analysis.resolve(allocOp->getResult(0)); 123 for (auto dep : llvm::make_early_inc_range(deps)) { 124 for (auto user : dep.getUsers()) { 125 // We are looking for a non dealloc op user. 126 // check if user is the dealloc op itself. 127 if (user == deallocOp) 128 continue; 129 130 // find the ancestor of user that is in the same block as the allocOp. 131 auto topUser = allocOp->getBlock()->findAncestorOpInBlock(*user); 132 if (!lastUser || happensBefore(lastUser, topUser)) { 133 lastUser = topUser; 134 } 135 } 136 } 137 if (lastUser == nullptr) { 138 return WalkResult::advance(); 139 } 140 LDBG("Last user found: " << *lastUser); 141 assert(lastUser->getBlock() == allocOp->getBlock()); 142 assert(lastUser->getBlock() == deallocOp->getBlock()); 143 // Move the dealloc op after the last user. 144 deallocOp->moveAfter(lastUser); 145 LDBG("Moved dealloc op after: " << *lastUser); 146 147 return WalkResult::advance(); 148 }); 149 } 150 }; 151 152 } // end anonymous namespace 153 154 //===----------------------------------------------------------------------===// 155 // OptimizeAllocatinliveness construction 156 //===----------------------------------------------------------------------===// 157 158 std::unique_ptr<Pass> 159 mlir::bufferization::createOptimizeAllocationLivenessPass() { 160 return std::make_unique<OptimizeAllocationLiveness>(); 161 } 162