xref: /llvm-project/mlir/lib/Dialect/Bufferization/Transforms/OptimizeAllocationLiveness.cpp (revision 6de04e6fe8b1520ef3e4073ff222e623b7dc9cb9)
1 //===- OptimizeAllocationLiveness.cpp - impl. optimize allocation liveness pass
2 //-===//
3 //
4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements a pass for optimizing allocation liveness.
11 // The pass moves the deallocation operation after the last user of the
12 // allocated buffer.
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
16 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/IR/Operation.h"
20 #include "llvm/Support/Debug.h"
21 
22 #define DEBUG_TYPE "optimize-allocation-liveness"
23 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
24 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
25 
26 namespace mlir {
27 namespace bufferization {
28 #define GEN_PASS_DEF_OPTIMIZEALLOCATIONLIVENESS
29 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
30 } // namespace bufferization
31 } // namespace mlir
32 
33 using namespace mlir;
34 
35 namespace {
36 
37 //===----------------------------------------------------------------------===//
38 // Helper functions
39 //===----------------------------------------------------------------------===//
40 
41 /// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
42 /// properly dominates `b` and `b` is not inside `a`.
43 static bool happensBefore(Operation *a, Operation *b) {
44   do {
45     if (a->isProperAncestor(b))
46       return false;
47     if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) {
48       return a->isBeforeInBlock(bAncestor);
49     }
50   } while ((a = a->getParentOp()));
51   return false;
52 }
53 
54 /// This method searches for a user of value that is a dealloc operation.
55 /// If multiple users with free effect are found, return nullptr.
56 Operation *findUserWithFreeSideEffect(Value value) {
57   Operation *freeOpUser = nullptr;
58   for (Operation *user : value.getUsers()) {
59     if (MemoryEffectOpInterface memEffectOp =
60             dyn_cast<MemoryEffectOpInterface>(user)) {
61       SmallVector<MemoryEffects::EffectInstance, 2> effects;
62       memEffectOp.getEffects(effects);
63 
64       for (const auto &effect : effects) {
65         if (isa<MemoryEffects::Free>(effect.getEffect())) {
66           if (freeOpUser) {
67             LDBG("Multiple users with free effect found: " << *freeOpUser
68                                                            << " and " << *user);
69             return nullptr;
70           }
71           freeOpUser = user;
72         }
73       }
74     }
75   }
76   return freeOpUser;
77 }
78 
79 /// Checks if the given op allocates memory.
80 static bool hasMemoryAllocEffect(MemoryEffectOpInterface memEffectOp) {
81   SmallVector<MemoryEffects::EffectInstance, 2> effects;
82   memEffectOp.getEffects(effects);
83   for (const auto &effect : effects) {
84     if (isa<MemoryEffects::Allocate>(effect.getEffect())) {
85       return true;
86     }
87   }
88   return false;
89 }
90 
91 struct OptimizeAllocationLiveness
92     : public bufferization::impl::OptimizeAllocationLivenessBase<
93           OptimizeAllocationLiveness> {
94 public:
95   OptimizeAllocationLiveness() = default;
96 
97   void runOnOperation() override {
98     func::FuncOp func = getOperation();
99 
100     if (func.isExternal())
101       return;
102 
103     BufferViewFlowAnalysis analysis = BufferViewFlowAnalysis(func);
104 
105     func.walk([&](MemoryEffectOpInterface memEffectOp) -> WalkResult {
106       if (!hasMemoryAllocEffect(memEffectOp))
107         return WalkResult::advance();
108 
109       auto allocOp = memEffectOp;
110       LDBG("Checking alloc op: " << allocOp);
111 
112       auto deallocOp = findUserWithFreeSideEffect(allocOp->getResult(0));
113       if (!deallocOp || (deallocOp->getBlock() != allocOp->getBlock())) {
114         // The pass handles allocations that have a single dealloc op in the
115         // same block. We also should not hoist the dealloc op out of
116         // conditionals.
117         return WalkResult::advance();
118       }
119 
120       Operation *lastUser = nullptr;
121       const BufferViewFlowAnalysis::ValueSetT &deps =
122           analysis.resolve(allocOp->getResult(0));
123       for (auto dep : llvm::make_early_inc_range(deps)) {
124         for (auto user : dep.getUsers()) {
125           // We are looking for a non dealloc op user.
126           // check if user is the dealloc op itself.
127           if (user == deallocOp)
128             continue;
129 
130           // find the ancestor of user that is in the same block as the allocOp.
131           auto topUser = allocOp->getBlock()->findAncestorOpInBlock(*user);
132           if (!lastUser || happensBefore(lastUser, topUser)) {
133             lastUser = topUser;
134           }
135         }
136       }
137       if (lastUser == nullptr) {
138         return WalkResult::advance();
139       }
140       LDBG("Last user found: " << *lastUser);
141       assert(lastUser->getBlock() == allocOp->getBlock());
142       assert(lastUser->getBlock() == deallocOp->getBlock());
143       // Move the dealloc op after the last user.
144       deallocOp->moveAfter(lastUser);
145       LDBG("Moved dealloc op after: " << *lastUser);
146 
147       return WalkResult::advance();
148     });
149   }
150 };
151 
152 } // end anonymous namespace
153 
154 //===----------------------------------------------------------------------===//
// OptimizeAllocationLiveness construction
156 //===----------------------------------------------------------------------===//
157 
158 std::unique_ptr<Pass>
159 mlir::bufferization::createOptimizeAllocationLivenessPass() {
160   return std::make_unique<OptimizeAllocationLiveness>();
161 }
162