xref: /llvm-project/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCountingOpt.cpp (revision a5aa783685c10f1326a6cb0bb93ebab0c5a3e78d)
1a6628e59SEugene Zhulenev //===- AsyncRuntimeRefCountingOpt.cpp - Async Ref Counting --------------===//
2a6628e59SEugene Zhulenev //
3a6628e59SEugene Zhulenev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4a6628e59SEugene Zhulenev // See https://llvm.org/LICENSE.txt for license information.
5a6628e59SEugene Zhulenev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6a6628e59SEugene Zhulenev //
7a6628e59SEugene Zhulenev //===----------------------------------------------------------------------===//
8a6628e59SEugene Zhulenev //
9a6628e59SEugene Zhulenev // Optimize Async dialect reference counting operations.
10a6628e59SEugene Zhulenev //
11a6628e59SEugene Zhulenev //===----------------------------------------------------------------------===//
12a6628e59SEugene Zhulenev 
13039b969bSMichele Scuttari #include "mlir/Dialect/Async/Passes.h"
1467d0d7acSMichele Scuttari 
1567d0d7acSMichele Scuttari #include "mlir/Dialect/Async/IR/Async.h"
1623aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
17a6628e59SEugene Zhulenev #include "llvm/ADT/SmallSet.h"
18297a5b7cSNico Weber #include "llvm/Support/Debug.h"
19a6628e59SEugene Zhulenev 
2067d0d7acSMichele Scuttari namespace mlir {
2167d0d7acSMichele Scuttari #define GEN_PASS_DEF_ASYNCRUNTIMEREFCOUNTINGOPT
2267d0d7acSMichele Scuttari #include "mlir/Dialect/Async/Passes.h.inc"
2367d0d7acSMichele Scuttari } // namespace mlir
242be8af8fSMichele Scuttari 
25039b969bSMichele Scuttari #define DEBUG_TYPE "async-ref-counting"
26039b969bSMichele Scuttari 
2767d0d7acSMichele Scuttari using namespace mlir;
2867d0d7acSMichele Scuttari using namespace mlir::async;
2967d0d7acSMichele Scuttari 
30a6628e59SEugene Zhulenev namespace {
31a6628e59SEugene Zhulenev 
32a6628e59SEugene Zhulenev class AsyncRuntimeRefCountingOptPass
3367d0d7acSMichele Scuttari     : public impl::AsyncRuntimeRefCountingOptBase<
3467d0d7acSMichele Scuttari           AsyncRuntimeRefCountingOptPass> {
35a6628e59SEugene Zhulenev public:
36a6628e59SEugene Zhulenev   AsyncRuntimeRefCountingOptPass() = default;
378a316b00SEugene Zhulenev   void runOnOperation() override;
38a6628e59SEugene Zhulenev 
39a6628e59SEugene Zhulenev private:
40a6628e59SEugene Zhulenev   LogicalResult optimizeReferenceCounting(
41a6628e59SEugene Zhulenev       Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable);
42a6628e59SEugene Zhulenev };
43a6628e59SEugene Zhulenev 
44a6628e59SEugene Zhulenev } // namespace
45a6628e59SEugene Zhulenev 
optimizeReferenceCounting(Value value,llvm::SmallDenseMap<Operation *,Operation * > & cancellable)46a6628e59SEugene Zhulenev LogicalResult AsyncRuntimeRefCountingOptPass::optimizeReferenceCounting(
47a6628e59SEugene Zhulenev     Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable) {
48a6628e59SEugene Zhulenev   Region *definingRegion = value.getParentRegion();
49a6628e59SEugene Zhulenev 
50a6628e59SEugene Zhulenev   // Find all users of the `value` inside each block, including operations that
51a6628e59SEugene Zhulenev   // do not use `value` directly, but have a direct use inside nested region(s).
52a6628e59SEugene Zhulenev   //
53a6628e59SEugene Zhulenev   // Example:
54a6628e59SEugene Zhulenev   //
55a6628e59SEugene Zhulenev   //  ^bb1:
56a6628e59SEugene Zhulenev   //    %token = ...
57a6628e59SEugene Zhulenev   //    scf.if %cond {
58a6628e59SEugene Zhulenev   //      ^bb2:
59a6628e59SEugene Zhulenev   //      async.runtime.await %token : !async.token
60a6628e59SEugene Zhulenev   //    }
61a6628e59SEugene Zhulenev   //
62a6628e59SEugene Zhulenev   // %token has a use inside ^bb2 (`async.runtime.await`) and inside ^bb1
63a6628e59SEugene Zhulenev   // (`scf.if`).
64a6628e59SEugene Zhulenev 
65a6628e59SEugene Zhulenev   struct BlockUsersInfo {
66a6628e59SEugene Zhulenev     llvm::SmallVector<RuntimeAddRefOp, 4> addRefs;
67a6628e59SEugene Zhulenev     llvm::SmallVector<RuntimeDropRefOp, 4> dropRefs;
68a6628e59SEugene Zhulenev     llvm::SmallVector<Operation *, 4> users;
69a6628e59SEugene Zhulenev   };
70a6628e59SEugene Zhulenev 
71a6628e59SEugene Zhulenev   llvm::DenseMap<Block *, BlockUsersInfo> blockUsers;
72a6628e59SEugene Zhulenev 
73a6628e59SEugene Zhulenev   auto updateBlockUsersInfo = [&](Operation *user) {
74a6628e59SEugene Zhulenev     BlockUsersInfo &info = blockUsers[user->getBlock()];
75a6628e59SEugene Zhulenev     info.users.push_back(user);
76a6628e59SEugene Zhulenev 
77a6628e59SEugene Zhulenev     if (auto addRef = dyn_cast<RuntimeAddRefOp>(user))
78a6628e59SEugene Zhulenev       info.addRefs.push_back(addRef);
79a6628e59SEugene Zhulenev     if (auto dropRef = dyn_cast<RuntimeDropRefOp>(user))
80a6628e59SEugene Zhulenev       info.dropRefs.push_back(dropRef);
81a6628e59SEugene Zhulenev   };
82a6628e59SEugene Zhulenev 
83a6628e59SEugene Zhulenev   for (Operation *user : value.getUsers()) {
84a6628e59SEugene Zhulenev     while (user->getParentRegion() != definingRegion) {
85a6628e59SEugene Zhulenev       updateBlockUsersInfo(user);
86a6628e59SEugene Zhulenev       user = user->getParentOp();
87a6628e59SEugene Zhulenev       assert(user != nullptr && "value user lies outside of the value region");
88a6628e59SEugene Zhulenev     }
89a6628e59SEugene Zhulenev 
90a6628e59SEugene Zhulenev     updateBlockUsersInfo(user);
91a6628e59SEugene Zhulenev   }
92a6628e59SEugene Zhulenev 
93a6628e59SEugene Zhulenev   // Sort all operations found in the block.
94a6628e59SEugene Zhulenev   auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & {
95a6628e59SEugene Zhulenev     auto isBeforeInBlock = [](Operation *a, Operation *b) -> bool {
96a6628e59SEugene Zhulenev       return a->isBeforeInBlock(b);
97a6628e59SEugene Zhulenev     };
98a6628e59SEugene Zhulenev     llvm::sort(info.addRefs, isBeforeInBlock);
99a6628e59SEugene Zhulenev     llvm::sort(info.dropRefs, isBeforeInBlock);
100a6628e59SEugene Zhulenev     llvm::sort(info.users, [&](Operation *a, Operation *b) -> bool {
101a6628e59SEugene Zhulenev       return isBeforeInBlock(a, b);
102a6628e59SEugene Zhulenev     });
103a6628e59SEugene Zhulenev 
104a6628e59SEugene Zhulenev     return info;
105a6628e59SEugene Zhulenev   };
106a6628e59SEugene Zhulenev 
107a6628e59SEugene Zhulenev   // Find and erase matching pairs of `add_ref` / `drop_ref` operations in the
108a6628e59SEugene Zhulenev   // blocks that modify the reference count of the `value`.
109a6628e59SEugene Zhulenev   for (auto &kv : blockUsers) {
110a6628e59SEugene Zhulenev     BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second);
111a6628e59SEugene Zhulenev 
112a6628e59SEugene Zhulenev     for (RuntimeAddRefOp addRef : info.addRefs) {
113a6628e59SEugene Zhulenev       for (RuntimeDropRefOp dropRef : info.dropRefs) {
114a6628e59SEugene Zhulenev         // `drop_ref` operation after the `add_ref` with matching count.
115*a5aa7836SRiver Riddle         if (dropRef.getCount() != addRef.getCount() ||
116a6628e59SEugene Zhulenev             dropRef->isBeforeInBlock(addRef.getOperation()))
117a6628e59SEugene Zhulenev           continue;
118a6628e59SEugene Zhulenev 
1199ccdaac8SEugene Zhulenev         // When reference counted value passed to a function as an argument,
1209ccdaac8SEugene Zhulenev         // function takes ownership of +1 reference and it will drop it before
1219ccdaac8SEugene Zhulenev         // returning.
1229ccdaac8SEugene Zhulenev         //
1239ccdaac8SEugene Zhulenev         // Example:
1249ccdaac8SEugene Zhulenev         //
1259ccdaac8SEugene Zhulenev         //   %token = ... : !async.token
1269ccdaac8SEugene Zhulenev         //
12792db09cdSEugene Zhulenev         //   async.runtime.add_ref %token {count = 1 : i64} : !async.token
1289ccdaac8SEugene Zhulenev         //   call @pass_token(%token: !async.token, ...)
1299ccdaac8SEugene Zhulenev         //
1309ccdaac8SEugene Zhulenev         //   async.await %token : !async.token
13192db09cdSEugene Zhulenev         //   async.runtime.drop_ref %token {count = 1 : i64} : !async.token
1329ccdaac8SEugene Zhulenev         //
1339ccdaac8SEugene Zhulenev         // In this example if we'll cancel a pair of reference counting
1349ccdaac8SEugene Zhulenev         // operations we might end up with a deallocated token when we'll
1359ccdaac8SEugene Zhulenev         // reach `async.await` operation.
1369ccdaac8SEugene Zhulenev         Operation *firstFunctionCallUser = nullptr;
1379ccdaac8SEugene Zhulenev         Operation *lastNonFunctionCallUser = nullptr;
1389ccdaac8SEugene Zhulenev 
1399ccdaac8SEugene Zhulenev         for (Operation *user : info.users) {
1409ccdaac8SEugene Zhulenev           // `user` operation lies after `addRef` ...
1419ccdaac8SEugene Zhulenev           if (user == addRef || user->isBeforeInBlock(addRef))
1429ccdaac8SEugene Zhulenev             continue;
1439ccdaac8SEugene Zhulenev           // ... and before `dropRef`.
1449ccdaac8SEugene Zhulenev           if (user == dropRef || dropRef->isBeforeInBlock(user))
1459ccdaac8SEugene Zhulenev             break;
1469ccdaac8SEugene Zhulenev 
1479ccdaac8SEugene Zhulenev           // Find the first function call user of the reference counted value.
14823aa5a74SRiver Riddle           Operation *functionCall = dyn_cast<func::CallOp>(user);
1499ccdaac8SEugene Zhulenev           if (functionCall &&
1509ccdaac8SEugene Zhulenev               (!firstFunctionCallUser ||
1519ccdaac8SEugene Zhulenev                functionCall->isBeforeInBlock(firstFunctionCallUser))) {
1529ccdaac8SEugene Zhulenev             firstFunctionCallUser = functionCall;
1539ccdaac8SEugene Zhulenev             continue;
1549ccdaac8SEugene Zhulenev           }
1559ccdaac8SEugene Zhulenev 
1569ccdaac8SEugene Zhulenev           // Find the last regular user of the reference counted value.
1579ccdaac8SEugene Zhulenev           if (!functionCall &&
1589ccdaac8SEugene Zhulenev               (!lastNonFunctionCallUser ||
1599ccdaac8SEugene Zhulenev                lastNonFunctionCallUser->isBeforeInBlock(user))) {
1609ccdaac8SEugene Zhulenev             lastNonFunctionCallUser = user;
1619ccdaac8SEugene Zhulenev             continue;
1629ccdaac8SEugene Zhulenev           }
1639ccdaac8SEugene Zhulenev         }
1649ccdaac8SEugene Zhulenev 
1659ccdaac8SEugene Zhulenev         // Non function call user after the function call user of the reference
1669ccdaac8SEugene Zhulenev         // counted value.
1679ccdaac8SEugene Zhulenev         if (firstFunctionCallUser && lastNonFunctionCallUser &&
1689ccdaac8SEugene Zhulenev             firstFunctionCallUser->isBeforeInBlock(lastNonFunctionCallUser))
1699ccdaac8SEugene Zhulenev           continue;
1709ccdaac8SEugene Zhulenev 
171a6628e59SEugene Zhulenev         // Try to cancel the pair of `add_ref` and `drop_ref` operations.
172a6628e59SEugene Zhulenev         auto emplaced = cancellable.try_emplace(dropRef.getOperation(),
173a6628e59SEugene Zhulenev                                                 addRef.getOperation());
174a6628e59SEugene Zhulenev 
175a6628e59SEugene Zhulenev         if (!emplaced.second) // `drop_ref` was already marked for removal
176a6628e59SEugene Zhulenev           continue;           // go to the next `drop_ref`
177a6628e59SEugene Zhulenev 
178a6628e59SEugene Zhulenev         if (emplaced.second) // successfully cancelled `add_ref` <-> `drop_ref`
179a6628e59SEugene Zhulenev           break;             // go to the next `add_ref`
180a6628e59SEugene Zhulenev       }
181a6628e59SEugene Zhulenev     }
182a6628e59SEugene Zhulenev   }
183a6628e59SEugene Zhulenev 
184a6628e59SEugene Zhulenev   return success();
185a6628e59SEugene Zhulenev }
186a6628e59SEugene Zhulenev 
runOnOperation()1878a316b00SEugene Zhulenev void AsyncRuntimeRefCountingOptPass::runOnOperation() {
1888a316b00SEugene Zhulenev   Operation *op = getOperation();
189a6628e59SEugene Zhulenev 
190a6628e59SEugene Zhulenev   // Mapping from `dropRef.getOperation()` to `addRef.getOperation()`.
191a6628e59SEugene Zhulenev   //
192a6628e59SEugene Zhulenev   // Find all cancellable pairs of operation and erase them in the end to keep
193a6628e59SEugene Zhulenev   // all iterators valid while we are walking the function operations.
194a6628e59SEugene Zhulenev   llvm::SmallDenseMap<Operation *, Operation *> cancellable;
195a6628e59SEugene Zhulenev 
196a6628e59SEugene Zhulenev   // Optimize reference counting for values defined by block arguments.
1978a316b00SEugene Zhulenev   WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult {
198a6628e59SEugene Zhulenev     for (BlockArgument arg : block->getArguments())
199a6628e59SEugene Zhulenev       if (isRefCounted(arg.getType()))
200a6628e59SEugene Zhulenev         if (failed(optimizeReferenceCounting(arg, cancellable)))
201a6628e59SEugene Zhulenev           return WalkResult::interrupt();
202a6628e59SEugene Zhulenev 
203a6628e59SEugene Zhulenev     return WalkResult::advance();
204a6628e59SEugene Zhulenev   });
205a6628e59SEugene Zhulenev 
206a6628e59SEugene Zhulenev   if (blockWalk.wasInterrupted())
207a6628e59SEugene Zhulenev     signalPassFailure();
208a6628e59SEugene Zhulenev 
209a6628e59SEugene Zhulenev   // Optimize reference counting for values defined by operation results.
2108a316b00SEugene Zhulenev   WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult {
211a6628e59SEugene Zhulenev     for (unsigned i = 0; i < op->getNumResults(); ++i)
212a6628e59SEugene Zhulenev       if (isRefCounted(op->getResultTypes()[i]))
213a6628e59SEugene Zhulenev         if (failed(optimizeReferenceCounting(op->getResult(i), cancellable)))
214a6628e59SEugene Zhulenev           return WalkResult::interrupt();
215a6628e59SEugene Zhulenev 
216a6628e59SEugene Zhulenev     return WalkResult::advance();
217a6628e59SEugene Zhulenev   });
218a6628e59SEugene Zhulenev 
219a6628e59SEugene Zhulenev   if (opWalk.wasInterrupted())
220a6628e59SEugene Zhulenev     signalPassFailure();
221a6628e59SEugene Zhulenev 
222a6628e59SEugene Zhulenev   LLVM_DEBUG({
223a6628e59SEugene Zhulenev     llvm::dbgs() << "Found " << cancellable.size()
224a6628e59SEugene Zhulenev                  << " cancellable reference counting operations\n";
225a6628e59SEugene Zhulenev   });
226a6628e59SEugene Zhulenev 
227a6628e59SEugene Zhulenev   // Erase all cancellable `add_ref <-> drop_ref` operation pairs.
228a6628e59SEugene Zhulenev   for (auto &kv : cancellable) {
229a6628e59SEugene Zhulenev     kv.first->erase();
230a6628e59SEugene Zhulenev     kv.second->erase();
231a6628e59SEugene Zhulenev   }
232a6628e59SEugene Zhulenev }
233a6628e59SEugene Zhulenev 
createAsyncRuntimeRefCountingOptPass()2348a316b00SEugene Zhulenev std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingOptPass() {
235a6628e59SEugene Zhulenev   return std::make_unique<AsyncRuntimeRefCountingOptPass>();
236a6628e59SEugene Zhulenev }
237