xref: /llvm-project/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp (revision a5aa783685c10f1326a6cb0bb93ebab0c5a3e78d)
1 //===- AsyncRuntimeRefCountingOpt.cpp - Async Ref Counting --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Optimize Async dialect reference counting operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Async/Passes.h"
14 
15 #include "mlir/Dialect/Async/IR/Async.h"
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "llvm/ADT/SmallSet.h"
18 #include "llvm/Support/Debug.h"
19 
20 namespace mlir {
21 #define GEN_PASS_DEF_ASYNCRUNTIMEREFCOUNTINGOPT
22 #include "mlir/Dialect/Async/Passes.h.inc"
23 } // namespace mlir
24 
25 #define DEBUG_TYPE "async-ref-counting"
26 
27 using namespace mlir;
28 using namespace mlir::async;
29 
30 namespace {
31 
32 class AsyncRuntimeRefCountingOptPass
33     : public impl::AsyncRuntimeRefCountingOptBase<
34           AsyncRuntimeRefCountingOptPass> {
35 public:
36   AsyncRuntimeRefCountingOptPass() = default;
37   void runOnOperation() override;
38 
39 private:
40   LogicalResult optimizeReferenceCounting(
41       Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable);
42 };
43 
44 } // namespace
45 
optimizeReferenceCounting(Value value,llvm::SmallDenseMap<Operation *,Operation * > & cancellable)46 LogicalResult AsyncRuntimeRefCountingOptPass::optimizeReferenceCounting(
47     Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable) {
48   Region *definingRegion = value.getParentRegion();
49 
50   // Find all users of the `value` inside each block, including operations that
51   // do not use `value` directly, but have a direct use inside nested region(s).
52   //
53   // Example:
54   //
55   //  ^bb1:
56   //    %token = ...
57   //    scf.if %cond {
58   //      ^bb2:
59   //      async.runtime.await %token : !async.token
60   //    }
61   //
62   // %token has a use inside ^bb2 (`async.runtime.await`) and inside ^bb1
63   // (`scf.if`).
64 
65   struct BlockUsersInfo {
66     llvm::SmallVector<RuntimeAddRefOp, 4> addRefs;
67     llvm::SmallVector<RuntimeDropRefOp, 4> dropRefs;
68     llvm::SmallVector<Operation *, 4> users;
69   };
70 
71   llvm::DenseMap<Block *, BlockUsersInfo> blockUsers;
72 
73   auto updateBlockUsersInfo = [&](Operation *user) {
74     BlockUsersInfo &info = blockUsers[user->getBlock()];
75     info.users.push_back(user);
76 
77     if (auto addRef = dyn_cast<RuntimeAddRefOp>(user))
78       info.addRefs.push_back(addRef);
79     if (auto dropRef = dyn_cast<RuntimeDropRefOp>(user))
80       info.dropRefs.push_back(dropRef);
81   };
82 
83   for (Operation *user : value.getUsers()) {
84     while (user->getParentRegion() != definingRegion) {
85       updateBlockUsersInfo(user);
86       user = user->getParentOp();
87       assert(user != nullptr && "value user lies outside of the value region");
88     }
89 
90     updateBlockUsersInfo(user);
91   }
92 
93   // Sort all operations found in the block.
94   auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & {
95     auto isBeforeInBlock = [](Operation *a, Operation *b) -> bool {
96       return a->isBeforeInBlock(b);
97     };
98     llvm::sort(info.addRefs, isBeforeInBlock);
99     llvm::sort(info.dropRefs, isBeforeInBlock);
100     llvm::sort(info.users, [&](Operation *a, Operation *b) -> bool {
101       return isBeforeInBlock(a, b);
102     });
103 
104     return info;
105   };
106 
107   // Find and erase matching pairs of `add_ref` / `drop_ref` operations in the
108   // blocks that modify the reference count of the `value`.
109   for (auto &kv : blockUsers) {
110     BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second);
111 
112     for (RuntimeAddRefOp addRef : info.addRefs) {
113       for (RuntimeDropRefOp dropRef : info.dropRefs) {
114         // `drop_ref` operation after the `add_ref` with matching count.
115         if (dropRef.getCount() != addRef.getCount() ||
116             dropRef->isBeforeInBlock(addRef.getOperation()))
117           continue;
118 
119         // When reference counted value passed to a function as an argument,
120         // function takes ownership of +1 reference and it will drop it before
121         // returning.
122         //
123         // Example:
124         //
125         //   %token = ... : !async.token
126         //
127         //   async.runtime.add_ref %token {count = 1 : i64} : !async.token
128         //   call @pass_token(%token: !async.token, ...)
129         //
130         //   async.await %token : !async.token
131         //   async.runtime.drop_ref %token {count = 1 : i64} : !async.token
132         //
133         // In this example if we'll cancel a pair of reference counting
134         // operations we might end up with a deallocated token when we'll
135         // reach `async.await` operation.
136         Operation *firstFunctionCallUser = nullptr;
137         Operation *lastNonFunctionCallUser = nullptr;
138 
139         for (Operation *user : info.users) {
140           // `user` operation lies after `addRef` ...
141           if (user == addRef || user->isBeforeInBlock(addRef))
142             continue;
143           // ... and before `dropRef`.
144           if (user == dropRef || dropRef->isBeforeInBlock(user))
145             break;
146 
147           // Find the first function call user of the reference counted value.
148           Operation *functionCall = dyn_cast<func::CallOp>(user);
149           if (functionCall &&
150               (!firstFunctionCallUser ||
151                functionCall->isBeforeInBlock(firstFunctionCallUser))) {
152             firstFunctionCallUser = functionCall;
153             continue;
154           }
155 
156           // Find the last regular user of the reference counted value.
157           if (!functionCall &&
158               (!lastNonFunctionCallUser ||
159                lastNonFunctionCallUser->isBeforeInBlock(user))) {
160             lastNonFunctionCallUser = user;
161             continue;
162           }
163         }
164 
165         // Non function call user after the function call user of the reference
166         // counted value.
167         if (firstFunctionCallUser && lastNonFunctionCallUser &&
168             firstFunctionCallUser->isBeforeInBlock(lastNonFunctionCallUser))
169           continue;
170 
171         // Try to cancel the pair of `add_ref` and `drop_ref` operations.
172         auto emplaced = cancellable.try_emplace(dropRef.getOperation(),
173                                                 addRef.getOperation());
174 
175         if (!emplaced.second) // `drop_ref` was already marked for removal
176           continue;           // go to the next `drop_ref`
177 
178         if (emplaced.second) // successfully cancelled `add_ref` <-> `drop_ref`
179           break;             // go to the next `add_ref`
180       }
181     }
182   }
183 
184   return success();
185 }
186 
runOnOperation()187 void AsyncRuntimeRefCountingOptPass::runOnOperation() {
188   Operation *op = getOperation();
189 
190   // Mapping from `dropRef.getOperation()` to `addRef.getOperation()`.
191   //
192   // Find all cancellable pairs of operation and erase them in the end to keep
193   // all iterators valid while we are walking the function operations.
194   llvm::SmallDenseMap<Operation *, Operation *> cancellable;
195 
196   // Optimize reference counting for values defined by block arguments.
197   WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult {
198     for (BlockArgument arg : block->getArguments())
199       if (isRefCounted(arg.getType()))
200         if (failed(optimizeReferenceCounting(arg, cancellable)))
201           return WalkResult::interrupt();
202 
203     return WalkResult::advance();
204   });
205 
206   if (blockWalk.wasInterrupted())
207     signalPassFailure();
208 
209   // Optimize reference counting for values defined by operation results.
210   WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult {
211     for (unsigned i = 0; i < op->getNumResults(); ++i)
212       if (isRefCounted(op->getResultTypes()[i]))
213         if (failed(optimizeReferenceCounting(op->getResult(i), cancellable)))
214           return WalkResult::interrupt();
215 
216     return WalkResult::advance();
217   });
218 
219   if (opWalk.wasInterrupted())
220     signalPassFailure();
221 
222   LLVM_DEBUG({
223     llvm::dbgs() << "Found " << cancellable.size()
224                  << " cancellable reference counting operations\n";
225   });
226 
227   // Erase all cancellable `add_ref <-> drop_ref` operation pairs.
228   for (auto &kv : cancellable) {
229     kv.first->erase();
230     kv.second->erase();
231   }
232 }
233 
createAsyncRuntimeRefCountingOptPass()234 std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingOptPass() {
235   return std::make_unique<AsyncRuntimeRefCountingOptPass>();
236 }
237