//===- AsyncRuntimeRefCountingOpt.cpp - Async Ref Counting --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Optimize Async dialect reference counting operations by removing redundant
// `async.runtime.add_ref` / `async.runtime.drop_ref` pairs.
//
//===----------------------------------------------------------------------===//
12a6628e59SEugene Zhulenev
13039b969bSMichele Scuttari #include "mlir/Dialect/Async/Passes.h"
1467d0d7acSMichele Scuttari
1567d0d7acSMichele Scuttari #include "mlir/Dialect/Async/IR/Async.h"
1623aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
17a6628e59SEugene Zhulenev #include "llvm/ADT/SmallSet.h"
18297a5b7cSNico Weber #include "llvm/Support/Debug.h"
19a6628e59SEugene Zhulenev
2067d0d7acSMichele Scuttari namespace mlir {
2167d0d7acSMichele Scuttari #define GEN_PASS_DEF_ASYNCRUNTIMEREFCOUNTINGOPT
2267d0d7acSMichele Scuttari #include "mlir/Dialect/Async/Passes.h.inc"
2367d0d7acSMichele Scuttari } // namespace mlir
242be8af8fSMichele Scuttari
25039b969bSMichele Scuttari #define DEBUG_TYPE "async-ref-counting"
26039b969bSMichele Scuttari
2767d0d7acSMichele Scuttari using namespace mlir;
2867d0d7acSMichele Scuttari using namespace mlir::async;
2967d0d7acSMichele Scuttari
30a6628e59SEugene Zhulenev namespace {
31a6628e59SEugene Zhulenev
32a6628e59SEugene Zhulenev class AsyncRuntimeRefCountingOptPass
3367d0d7acSMichele Scuttari : public impl::AsyncRuntimeRefCountingOptBase<
3467d0d7acSMichele Scuttari AsyncRuntimeRefCountingOptPass> {
35a6628e59SEugene Zhulenev public:
36a6628e59SEugene Zhulenev AsyncRuntimeRefCountingOptPass() = default;
378a316b00SEugene Zhulenev void runOnOperation() override;
38a6628e59SEugene Zhulenev
39a6628e59SEugene Zhulenev private:
40a6628e59SEugene Zhulenev LogicalResult optimizeReferenceCounting(
41a6628e59SEugene Zhulenev Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable);
42a6628e59SEugene Zhulenev };
43a6628e59SEugene Zhulenev
44a6628e59SEugene Zhulenev } // namespace
45a6628e59SEugene Zhulenev
optimizeReferenceCounting(Value value,llvm::SmallDenseMap<Operation *,Operation * > & cancellable)46a6628e59SEugene Zhulenev LogicalResult AsyncRuntimeRefCountingOptPass::optimizeReferenceCounting(
47a6628e59SEugene Zhulenev Value value, llvm::SmallDenseMap<Operation *, Operation *> &cancellable) {
48a6628e59SEugene Zhulenev Region *definingRegion = value.getParentRegion();
49a6628e59SEugene Zhulenev
50a6628e59SEugene Zhulenev // Find all users of the `value` inside each block, including operations that
51a6628e59SEugene Zhulenev // do not use `value` directly, but have a direct use inside nested region(s).
52a6628e59SEugene Zhulenev //
53a6628e59SEugene Zhulenev // Example:
54a6628e59SEugene Zhulenev //
55a6628e59SEugene Zhulenev // ^bb1:
56a6628e59SEugene Zhulenev // %token = ...
57a6628e59SEugene Zhulenev // scf.if %cond {
58a6628e59SEugene Zhulenev // ^bb2:
59a6628e59SEugene Zhulenev // async.runtime.await %token : !async.token
60a6628e59SEugene Zhulenev // }
61a6628e59SEugene Zhulenev //
62a6628e59SEugene Zhulenev // %token has a use inside ^bb2 (`async.runtime.await`) and inside ^bb1
63a6628e59SEugene Zhulenev // (`scf.if`).
64a6628e59SEugene Zhulenev
65a6628e59SEugene Zhulenev struct BlockUsersInfo {
66a6628e59SEugene Zhulenev llvm::SmallVector<RuntimeAddRefOp, 4> addRefs;
67a6628e59SEugene Zhulenev llvm::SmallVector<RuntimeDropRefOp, 4> dropRefs;
68a6628e59SEugene Zhulenev llvm::SmallVector<Operation *, 4> users;
69a6628e59SEugene Zhulenev };
70a6628e59SEugene Zhulenev
71a6628e59SEugene Zhulenev llvm::DenseMap<Block *, BlockUsersInfo> blockUsers;
72a6628e59SEugene Zhulenev
73a6628e59SEugene Zhulenev auto updateBlockUsersInfo = [&](Operation *user) {
74a6628e59SEugene Zhulenev BlockUsersInfo &info = blockUsers[user->getBlock()];
75a6628e59SEugene Zhulenev info.users.push_back(user);
76a6628e59SEugene Zhulenev
77a6628e59SEugene Zhulenev if (auto addRef = dyn_cast<RuntimeAddRefOp>(user))
78a6628e59SEugene Zhulenev info.addRefs.push_back(addRef);
79a6628e59SEugene Zhulenev if (auto dropRef = dyn_cast<RuntimeDropRefOp>(user))
80a6628e59SEugene Zhulenev info.dropRefs.push_back(dropRef);
81a6628e59SEugene Zhulenev };
82a6628e59SEugene Zhulenev
83a6628e59SEugene Zhulenev for (Operation *user : value.getUsers()) {
84a6628e59SEugene Zhulenev while (user->getParentRegion() != definingRegion) {
85a6628e59SEugene Zhulenev updateBlockUsersInfo(user);
86a6628e59SEugene Zhulenev user = user->getParentOp();
87a6628e59SEugene Zhulenev assert(user != nullptr && "value user lies outside of the value region");
88a6628e59SEugene Zhulenev }
89a6628e59SEugene Zhulenev
90a6628e59SEugene Zhulenev updateBlockUsersInfo(user);
91a6628e59SEugene Zhulenev }
92a6628e59SEugene Zhulenev
93a6628e59SEugene Zhulenev // Sort all operations found in the block.
94a6628e59SEugene Zhulenev auto preprocessBlockUsersInfo = [](BlockUsersInfo &info) -> BlockUsersInfo & {
95a6628e59SEugene Zhulenev auto isBeforeInBlock = [](Operation *a, Operation *b) -> bool {
96a6628e59SEugene Zhulenev return a->isBeforeInBlock(b);
97a6628e59SEugene Zhulenev };
98a6628e59SEugene Zhulenev llvm::sort(info.addRefs, isBeforeInBlock);
99a6628e59SEugene Zhulenev llvm::sort(info.dropRefs, isBeforeInBlock);
100a6628e59SEugene Zhulenev llvm::sort(info.users, [&](Operation *a, Operation *b) -> bool {
101a6628e59SEugene Zhulenev return isBeforeInBlock(a, b);
102a6628e59SEugene Zhulenev });
103a6628e59SEugene Zhulenev
104a6628e59SEugene Zhulenev return info;
105a6628e59SEugene Zhulenev };
106a6628e59SEugene Zhulenev
107a6628e59SEugene Zhulenev // Find and erase matching pairs of `add_ref` / `drop_ref` operations in the
108a6628e59SEugene Zhulenev // blocks that modify the reference count of the `value`.
109a6628e59SEugene Zhulenev for (auto &kv : blockUsers) {
110a6628e59SEugene Zhulenev BlockUsersInfo &info = preprocessBlockUsersInfo(kv.second);
111a6628e59SEugene Zhulenev
112a6628e59SEugene Zhulenev for (RuntimeAddRefOp addRef : info.addRefs) {
113a6628e59SEugene Zhulenev for (RuntimeDropRefOp dropRef : info.dropRefs) {
114a6628e59SEugene Zhulenev // `drop_ref` operation after the `add_ref` with matching count.
115*a5aa7836SRiver Riddle if (dropRef.getCount() != addRef.getCount() ||
116a6628e59SEugene Zhulenev dropRef->isBeforeInBlock(addRef.getOperation()))
117a6628e59SEugene Zhulenev continue;
118a6628e59SEugene Zhulenev
1199ccdaac8SEugene Zhulenev // When reference counted value passed to a function as an argument,
1209ccdaac8SEugene Zhulenev // function takes ownership of +1 reference and it will drop it before
1219ccdaac8SEugene Zhulenev // returning.
1229ccdaac8SEugene Zhulenev //
1239ccdaac8SEugene Zhulenev // Example:
1249ccdaac8SEugene Zhulenev //
1259ccdaac8SEugene Zhulenev // %token = ... : !async.token
1269ccdaac8SEugene Zhulenev //
12792db09cdSEugene Zhulenev // async.runtime.add_ref %token {count = 1 : i64} : !async.token
1289ccdaac8SEugene Zhulenev // call @pass_token(%token: !async.token, ...)
1299ccdaac8SEugene Zhulenev //
1309ccdaac8SEugene Zhulenev // async.await %token : !async.token
13192db09cdSEugene Zhulenev // async.runtime.drop_ref %token {count = 1 : i64} : !async.token
1329ccdaac8SEugene Zhulenev //
1339ccdaac8SEugene Zhulenev // In this example if we'll cancel a pair of reference counting
1349ccdaac8SEugene Zhulenev // operations we might end up with a deallocated token when we'll
1359ccdaac8SEugene Zhulenev // reach `async.await` operation.
1369ccdaac8SEugene Zhulenev Operation *firstFunctionCallUser = nullptr;
1379ccdaac8SEugene Zhulenev Operation *lastNonFunctionCallUser = nullptr;
1389ccdaac8SEugene Zhulenev
1399ccdaac8SEugene Zhulenev for (Operation *user : info.users) {
1409ccdaac8SEugene Zhulenev // `user` operation lies after `addRef` ...
1419ccdaac8SEugene Zhulenev if (user == addRef || user->isBeforeInBlock(addRef))
1429ccdaac8SEugene Zhulenev continue;
1439ccdaac8SEugene Zhulenev // ... and before `dropRef`.
1449ccdaac8SEugene Zhulenev if (user == dropRef || dropRef->isBeforeInBlock(user))
1459ccdaac8SEugene Zhulenev break;
1469ccdaac8SEugene Zhulenev
1479ccdaac8SEugene Zhulenev // Find the first function call user of the reference counted value.
14823aa5a74SRiver Riddle Operation *functionCall = dyn_cast<func::CallOp>(user);
1499ccdaac8SEugene Zhulenev if (functionCall &&
1509ccdaac8SEugene Zhulenev (!firstFunctionCallUser ||
1519ccdaac8SEugene Zhulenev functionCall->isBeforeInBlock(firstFunctionCallUser))) {
1529ccdaac8SEugene Zhulenev firstFunctionCallUser = functionCall;
1539ccdaac8SEugene Zhulenev continue;
1549ccdaac8SEugene Zhulenev }
1559ccdaac8SEugene Zhulenev
1569ccdaac8SEugene Zhulenev // Find the last regular user of the reference counted value.
1579ccdaac8SEugene Zhulenev if (!functionCall &&
1589ccdaac8SEugene Zhulenev (!lastNonFunctionCallUser ||
1599ccdaac8SEugene Zhulenev lastNonFunctionCallUser->isBeforeInBlock(user))) {
1609ccdaac8SEugene Zhulenev lastNonFunctionCallUser = user;
1619ccdaac8SEugene Zhulenev continue;
1629ccdaac8SEugene Zhulenev }
1639ccdaac8SEugene Zhulenev }
1649ccdaac8SEugene Zhulenev
1659ccdaac8SEugene Zhulenev // Non function call user after the function call user of the reference
1669ccdaac8SEugene Zhulenev // counted value.
1679ccdaac8SEugene Zhulenev if (firstFunctionCallUser && lastNonFunctionCallUser &&
1689ccdaac8SEugene Zhulenev firstFunctionCallUser->isBeforeInBlock(lastNonFunctionCallUser))
1699ccdaac8SEugene Zhulenev continue;
1709ccdaac8SEugene Zhulenev
171a6628e59SEugene Zhulenev // Try to cancel the pair of `add_ref` and `drop_ref` operations.
172a6628e59SEugene Zhulenev auto emplaced = cancellable.try_emplace(dropRef.getOperation(),
173a6628e59SEugene Zhulenev addRef.getOperation());
174a6628e59SEugene Zhulenev
175a6628e59SEugene Zhulenev if (!emplaced.second) // `drop_ref` was already marked for removal
176a6628e59SEugene Zhulenev continue; // go to the next `drop_ref`
177a6628e59SEugene Zhulenev
178a6628e59SEugene Zhulenev if (emplaced.second) // successfully cancelled `add_ref` <-> `drop_ref`
179a6628e59SEugene Zhulenev break; // go to the next `add_ref`
180a6628e59SEugene Zhulenev }
181a6628e59SEugene Zhulenev }
182a6628e59SEugene Zhulenev }
183a6628e59SEugene Zhulenev
184a6628e59SEugene Zhulenev return success();
185a6628e59SEugene Zhulenev }
186a6628e59SEugene Zhulenev
runOnOperation()1878a316b00SEugene Zhulenev void AsyncRuntimeRefCountingOptPass::runOnOperation() {
1888a316b00SEugene Zhulenev Operation *op = getOperation();
189a6628e59SEugene Zhulenev
190a6628e59SEugene Zhulenev // Mapping from `dropRef.getOperation()` to `addRef.getOperation()`.
191a6628e59SEugene Zhulenev //
192a6628e59SEugene Zhulenev // Find all cancellable pairs of operation and erase them in the end to keep
193a6628e59SEugene Zhulenev // all iterators valid while we are walking the function operations.
194a6628e59SEugene Zhulenev llvm::SmallDenseMap<Operation *, Operation *> cancellable;
195a6628e59SEugene Zhulenev
196a6628e59SEugene Zhulenev // Optimize reference counting for values defined by block arguments.
1978a316b00SEugene Zhulenev WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult {
198a6628e59SEugene Zhulenev for (BlockArgument arg : block->getArguments())
199a6628e59SEugene Zhulenev if (isRefCounted(arg.getType()))
200a6628e59SEugene Zhulenev if (failed(optimizeReferenceCounting(arg, cancellable)))
201a6628e59SEugene Zhulenev return WalkResult::interrupt();
202a6628e59SEugene Zhulenev
203a6628e59SEugene Zhulenev return WalkResult::advance();
204a6628e59SEugene Zhulenev });
205a6628e59SEugene Zhulenev
206a6628e59SEugene Zhulenev if (blockWalk.wasInterrupted())
207a6628e59SEugene Zhulenev signalPassFailure();
208a6628e59SEugene Zhulenev
209a6628e59SEugene Zhulenev // Optimize reference counting for values defined by operation results.
2108a316b00SEugene Zhulenev WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult {
211a6628e59SEugene Zhulenev for (unsigned i = 0; i < op->getNumResults(); ++i)
212a6628e59SEugene Zhulenev if (isRefCounted(op->getResultTypes()[i]))
213a6628e59SEugene Zhulenev if (failed(optimizeReferenceCounting(op->getResult(i), cancellable)))
214a6628e59SEugene Zhulenev return WalkResult::interrupt();
215a6628e59SEugene Zhulenev
216a6628e59SEugene Zhulenev return WalkResult::advance();
217a6628e59SEugene Zhulenev });
218a6628e59SEugene Zhulenev
219a6628e59SEugene Zhulenev if (opWalk.wasInterrupted())
220a6628e59SEugene Zhulenev signalPassFailure();
221a6628e59SEugene Zhulenev
222a6628e59SEugene Zhulenev LLVM_DEBUG({
223a6628e59SEugene Zhulenev llvm::dbgs() << "Found " << cancellable.size()
224a6628e59SEugene Zhulenev << " cancellable reference counting operations\n";
225a6628e59SEugene Zhulenev });
226a6628e59SEugene Zhulenev
227a6628e59SEugene Zhulenev // Erase all cancellable `add_ref <-> drop_ref` operation pairs.
228a6628e59SEugene Zhulenev for (auto &kv : cancellable) {
229a6628e59SEugene Zhulenev kv.first->erase();
230a6628e59SEugene Zhulenev kv.second->erase();
231a6628e59SEugene Zhulenev }
232a6628e59SEugene Zhulenev }
233a6628e59SEugene Zhulenev
createAsyncRuntimeRefCountingOptPass()2348a316b00SEugene Zhulenev std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingOptPass() {
235a6628e59SEugene Zhulenev return std::make_unique<AsyncRuntimeRefCountingOptPass>();
236a6628e59SEugene Zhulenev }
237