1a6628e59SEugene Zhulenev //===- AsyncRuntimeRefCounting.cpp - Async Runtime Ref Counting -----------===//
2a6628e59SEugene Zhulenev //
3a6628e59SEugene Zhulenev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4a6628e59SEugene Zhulenev // See https://llvm.org/LICENSE.txt for license information.
5a6628e59SEugene Zhulenev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6a6628e59SEugene Zhulenev //
7a6628e59SEugene Zhulenev //===----------------------------------------------------------------------===//
8a6628e59SEugene Zhulenev //
9a6628e59SEugene Zhulenev // This file implements automatic reference counting for Async runtime
10a6628e59SEugene Zhulenev // operations and types.
11a6628e59SEugene Zhulenev //
12a6628e59SEugene Zhulenev //===----------------------------------------------------------------------===//
13a6628e59SEugene Zhulenev
1467d0d7acSMichele Scuttari #include "mlir/Dialect/Async/Passes.h"
1567d0d7acSMichele Scuttari
16a6628e59SEugene Zhulenev #include "mlir/Analysis/Liveness.h"
17a6628e59SEugene Zhulenev #include "mlir/Dialect/Async/IR/Async.h"
18ace01605SRiver Riddle #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
1923aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
20a6628e59SEugene Zhulenev #include "mlir/IR/ImplicitLocOpBuilder.h"
21a6628e59SEugene Zhulenev #include "mlir/IR/PatternMatch.h"
22a6628e59SEugene Zhulenev #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23a6628e59SEugene Zhulenev #include "llvm/ADT/SmallSet.h"
24a6628e59SEugene Zhulenev
2567d0d7acSMichele Scuttari namespace mlir {
2667d0d7acSMichele Scuttari #define GEN_PASS_DEF_ASYNCRUNTIMEREFCOUNTING
2767d0d7acSMichele Scuttari #define GEN_PASS_DEF_ASYNCRUNTIMEPOLICYBASEDREFCOUNTING
2867d0d7acSMichele Scuttari #include "mlir/Dialect/Async/Passes.h.inc"
2967d0d7acSMichele Scuttari } // namespace mlir
302be8af8fSMichele Scuttari
31039b969bSMichele Scuttari #define DEBUG_TYPE "async-runtime-ref-counting"
32039b969bSMichele Scuttari
3367d0d7acSMichele Scuttari using namespace mlir;
3467d0d7acSMichele Scuttari using namespace mlir::async;
3567d0d7acSMichele Scuttari
//===----------------------------------------------------------------------===//
// Utility functions shared by reference counting passes.
//===----------------------------------------------------------------------===//
39f57b2420SEugene Zhulenev
40f57b2420SEugene Zhulenev // Drop the reference count immediately if the value has no uses.
dropRefIfNoUses(Value value,unsigned count=1)41f57b2420SEugene Zhulenev static LogicalResult dropRefIfNoUses(Value value, unsigned count = 1) {
42f57b2420SEugene Zhulenev if (!value.getUses().empty())
43f57b2420SEugene Zhulenev return failure();
44f57b2420SEugene Zhulenev
45f57b2420SEugene Zhulenev OpBuilder b(value.getContext());
46f57b2420SEugene Zhulenev
47f57b2420SEugene Zhulenev // Set insertion point after the operation producing a value, or at the
48f57b2420SEugene Zhulenev // beginning of the block if the value defined by the block argument.
49f57b2420SEugene Zhulenev if (Operation *op = value.getDefiningOp())
50f57b2420SEugene Zhulenev b.setInsertionPointAfter(op);
51f57b2420SEugene Zhulenev else
52f57b2420SEugene Zhulenev b.setInsertionPointToStart(value.getParentBlock());
53f57b2420SEugene Zhulenev
5492db09cdSEugene Zhulenev b.create<RuntimeDropRefOp>(value.getLoc(), value, b.getI64IntegerAttr(1));
55f57b2420SEugene Zhulenev return success();
56f57b2420SEugene Zhulenev }
57f57b2420SEugene Zhulenev
58f57b2420SEugene Zhulenev // Calls `addRefCounting` for every reference counted value defined by the
59f57b2420SEugene Zhulenev // operation `op` (block arguments and values defined in nested regions).
walkReferenceCountedValues(Operation * op,llvm::function_ref<LogicalResult (Value)> addRefCounting)60f57b2420SEugene Zhulenev static LogicalResult walkReferenceCountedValues(
61f57b2420SEugene Zhulenev Operation *op, llvm::function_ref<LogicalResult(Value)> addRefCounting) {
62f57b2420SEugene Zhulenev // Check that we do not have high level async operations in the IR because
63f57b2420SEugene Zhulenev // otherwise reference counting will produce incorrect results after high
64f57b2420SEugene Zhulenev // level async operations will be lowered to `async.runtime`
65f57b2420SEugene Zhulenev WalkResult checkNoAsyncWalk = op->walk([&](Operation *op) -> WalkResult {
66f57b2420SEugene Zhulenev if (!isa<ExecuteOp, AwaitOp, AwaitAllOp, YieldOp>(op))
67f57b2420SEugene Zhulenev return WalkResult::advance();
68f57b2420SEugene Zhulenev
69f57b2420SEugene Zhulenev return op->emitError()
70f57b2420SEugene Zhulenev << "async operations must be lowered to async runtime operations";
71f57b2420SEugene Zhulenev });
72f57b2420SEugene Zhulenev
73f57b2420SEugene Zhulenev if (checkNoAsyncWalk.wasInterrupted())
74f57b2420SEugene Zhulenev return failure();
75f57b2420SEugene Zhulenev
76f57b2420SEugene Zhulenev // Add reference counting to block arguments.
77f57b2420SEugene Zhulenev WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult {
78f57b2420SEugene Zhulenev for (BlockArgument arg : block->getArguments())
79f57b2420SEugene Zhulenev if (isRefCounted(arg.getType()))
80f57b2420SEugene Zhulenev if (failed(addRefCounting(arg)))
81f57b2420SEugene Zhulenev return WalkResult::interrupt();
82f57b2420SEugene Zhulenev
83f57b2420SEugene Zhulenev return WalkResult::advance();
84f57b2420SEugene Zhulenev });
85f57b2420SEugene Zhulenev
86f57b2420SEugene Zhulenev if (blockWalk.wasInterrupted())
87f57b2420SEugene Zhulenev return failure();
88f57b2420SEugene Zhulenev
89f57b2420SEugene Zhulenev // Add reference counting to operation results.
90f57b2420SEugene Zhulenev WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult {
91f57b2420SEugene Zhulenev for (unsigned i = 0; i < op->getNumResults(); ++i)
92f57b2420SEugene Zhulenev if (isRefCounted(op->getResultTypes()[i]))
93f57b2420SEugene Zhulenev if (failed(addRefCounting(op->getResult(i))))
94f57b2420SEugene Zhulenev return WalkResult::interrupt();
95f57b2420SEugene Zhulenev
96f57b2420SEugene Zhulenev return WalkResult::advance();
97f57b2420SEugene Zhulenev });
98f57b2420SEugene Zhulenev
99f57b2420SEugene Zhulenev if (opWalk.wasInterrupted())
100f57b2420SEugene Zhulenev return failure();
101f57b2420SEugene Zhulenev
102f57b2420SEugene Zhulenev return success();
103f57b2420SEugene Zhulenev }
104f57b2420SEugene Zhulenev
//===----------------------------------------------------------------------===//
// Automatic reference counting based on the liveness analysis.
//===----------------------------------------------------------------------===//
108f57b2420SEugene Zhulenev
109a6628e59SEugene Zhulenev namespace {
110a6628e59SEugene Zhulenev
111a6628e59SEugene Zhulenev class AsyncRuntimeRefCountingPass
11267d0d7acSMichele Scuttari : public impl::AsyncRuntimeRefCountingBase<AsyncRuntimeRefCountingPass> {
113a6628e59SEugene Zhulenev public:
114a6628e59SEugene Zhulenev AsyncRuntimeRefCountingPass() = default;
1158a316b00SEugene Zhulenev void runOnOperation() override;
116a6628e59SEugene Zhulenev
117a6628e59SEugene Zhulenev private:
118a6628e59SEugene Zhulenev /// Adds an automatic reference counting to the `value`.
119a6628e59SEugene Zhulenev ///
120a6628e59SEugene Zhulenev /// All values (token, group or value) are semantically created with a
121a6628e59SEugene Zhulenev /// reference count of +1 and it is the responsibility of the async value user
122a6628e59SEugene Zhulenev /// to place the `add_ref` and `drop_ref` operations to ensure that the value
123a6628e59SEugene Zhulenev /// is destroyed after the last use.
124a6628e59SEugene Zhulenev ///
125a6628e59SEugene Zhulenev /// The function returns failure if it can't deduce the locations where
126a6628e59SEugene Zhulenev /// to place the reference counting operations.
127a6628e59SEugene Zhulenev ///
128a6628e59SEugene Zhulenev /// Async values "semantically created" when:
129a6628e59SEugene Zhulenev /// 1. Operation returns async result (e.g. `async.runtime.create`)
130a6628e59SEugene Zhulenev /// 2. Async value passed in as a block argument (or function argument,
131a6628e59SEugene Zhulenev /// because function arguments are just entry block arguments)
132a6628e59SEugene Zhulenev ///
133a6628e59SEugene Zhulenev /// Passing async value as a function argument (or block argument) does not
134a6628e59SEugene Zhulenev /// really mean that a new async value is created, it only means that the
135a6628e59SEugene Zhulenev /// caller of a function transfered ownership of `+1` reference to the callee.
136a6628e59SEugene Zhulenev /// It is convenient to think that from the callee perspective async value was
137a6628e59SEugene Zhulenev /// "created" with `+1` reference by the block argument.
138a6628e59SEugene Zhulenev ///
139a6628e59SEugene Zhulenev /// Automatic reference counting algorithm outline:
140a6628e59SEugene Zhulenev ///
141a6628e59SEugene Zhulenev /// #1 Insert `drop_ref` operations after last use of the `value`.
142a6628e59SEugene Zhulenev /// #2 Insert `add_ref` operations before functions calls with reference
143a6628e59SEugene Zhulenev /// counted `value` operand (newly created `+1` reference will be
144a6628e59SEugene Zhulenev /// transferred to the callee).
145a6628e59SEugene Zhulenev /// #3 Verify that divergent control flow does not lead to leaked reference
146a6628e59SEugene Zhulenev /// counted objects.
147a6628e59SEugene Zhulenev ///
148a6628e59SEugene Zhulenev /// Async runtime reference counting optimization pass will optimize away
149a6628e59SEugene Zhulenev /// some of the redundant `add_ref` and `drop_ref` operations inserted by this
150a6628e59SEugene Zhulenev /// strategy (see `async-runtime-ref-counting-opt`).
151a6628e59SEugene Zhulenev LogicalResult addAutomaticRefCounting(Value value);
152a6628e59SEugene Zhulenev
153a6628e59SEugene Zhulenev /// (#1) Adds the `drop_ref` operation after the last use of the `value`
154a6628e59SEugene Zhulenev /// relying on the liveness analysis.
155a6628e59SEugene Zhulenev ///
156a6628e59SEugene Zhulenev /// If the `value` is in the block `liveIn` set and it is not in the block
157a6628e59SEugene Zhulenev /// `liveOut` set, it means that it "dies" in the block. We find the last
158a6628e59SEugene Zhulenev /// use of the value in such block and:
159a6628e59SEugene Zhulenev ///
160a6628e59SEugene Zhulenev /// 1. If the last user is a `ReturnLike` operation we do nothing, because
161a6628e59SEugene Zhulenev /// it forwards the ownership to the caller.
162a6628e59SEugene Zhulenev /// 2. Otherwise we add a `drop_ref` operation immediately after the last
163a6628e59SEugene Zhulenev /// use.
164a6628e59SEugene Zhulenev LogicalResult addDropRefAfterLastUse(Value value);
165a6628e59SEugene Zhulenev
166a6628e59SEugene Zhulenev /// (#2) Adds the `add_ref` operation before the function call taking `value`
167a6628e59SEugene Zhulenev /// operand to ensure that the value passed to the function entry block
168a6628e59SEugene Zhulenev /// has a `+1` reference count.
169a6628e59SEugene Zhulenev LogicalResult addAddRefBeforeFunctionCall(Value value);
170a6628e59SEugene Zhulenev
171c412979cSEugene Zhulenev /// (#3) Adds the `drop_ref` operation to account for successor blocks with
172c412979cSEugene Zhulenev /// divergent `liveIn` property: `value` is not in the `liveIn` set of all
173c412979cSEugene Zhulenev /// successor blocks.
174a6628e59SEugene Zhulenev ///
175a6628e59SEugene Zhulenev /// Example:
176a6628e59SEugene Zhulenev ///
177a6628e59SEugene Zhulenev /// ^entry:
178a6628e59SEugene Zhulenev /// %token = async.runtime.create : !async.token
179ace01605SRiver Riddle /// cf.cond_br %cond, ^bb1, ^bb2
180a6628e59SEugene Zhulenev /// ^bb1:
181a6628e59SEugene Zhulenev /// async.runtime.await %token
182c412979cSEugene Zhulenev /// async.runtime.drop_ref %token
183ace01605SRiver Riddle /// cf.br ^bb2
184a6628e59SEugene Zhulenev /// ^bb2:
185a6628e59SEugene Zhulenev /// return
186a6628e59SEugene Zhulenev ///
187c412979cSEugene Zhulenev /// In this example ^bb2 does not have `value` in the `liveIn` set, so we have
188c412979cSEugene Zhulenev /// to branch into a special "reference counting block" from the ^entry that
189c412979cSEugene Zhulenev /// will have a `drop_ref` operation, and then branch into the ^bb2.
190c412979cSEugene Zhulenev ///
191c412979cSEugene Zhulenev /// After transformation:
192c412979cSEugene Zhulenev ///
193c412979cSEugene Zhulenev /// ^entry:
194c412979cSEugene Zhulenev /// %token = async.runtime.create : !async.token
195ace01605SRiver Riddle /// cf.cond_br %cond, ^bb1, ^reference_counting
196c412979cSEugene Zhulenev /// ^bb1:
197c412979cSEugene Zhulenev /// async.runtime.await %token
198c412979cSEugene Zhulenev /// async.runtime.drop_ref %token
199ace01605SRiver Riddle /// cf.br ^bb2
200c412979cSEugene Zhulenev /// ^reference_counting:
201c412979cSEugene Zhulenev /// async.runtime.drop_ref %token
202ace01605SRiver Riddle /// cf.br ^bb2
203c412979cSEugene Zhulenev /// ^bb2:
204c412979cSEugene Zhulenev /// return
205a6628e59SEugene Zhulenev ///
206a6628e59SEugene Zhulenev /// An exception to this rule are blocks with `async.coro.suspend` terminator,
207a6628e59SEugene Zhulenev /// because in Async to LLVM lowering it is guaranteed that the control flow
208a6628e59SEugene Zhulenev /// will jump into the resume block, and then follow into the cleanup and
209a6628e59SEugene Zhulenev /// suspend blocks.
210a6628e59SEugene Zhulenev ///
211a6628e59SEugene Zhulenev /// Example:
212a6628e59SEugene Zhulenev ///
213a6628e59SEugene Zhulenev /// ^entry(%value: !async.value<f32>):
214a6628e59SEugene Zhulenev /// async.runtime.await_and_resume %value, %hdl : !async.value<f32>
215a6628e59SEugene Zhulenev /// async.coro.suspend %ret, ^suspend, ^resume, ^cleanup
216a6628e59SEugene Zhulenev /// ^resume:
217a6628e59SEugene Zhulenev /// %0 = async.runtime.load %value
218ace01605SRiver Riddle /// cf.br ^cleanup
219a6628e59SEugene Zhulenev /// ^cleanup:
220a6628e59SEugene Zhulenev /// ...
221a6628e59SEugene Zhulenev /// ^suspend:
222a6628e59SEugene Zhulenev /// ...
223a6628e59SEugene Zhulenev ///
224a6628e59SEugene Zhulenev /// Although cleanup and suspend blocks do not have the `value` in the
225a6628e59SEugene Zhulenev /// `liveIn` set, it is guaranteed that execution will eventually continue in
226a6628e59SEugene Zhulenev /// the resume block (we never explicitly destroy coroutines).
227c412979cSEugene Zhulenev LogicalResult addDropRefInDivergentLivenessSuccessor(Value value);
228a6628e59SEugene Zhulenev };
229a6628e59SEugene Zhulenev
230a6628e59SEugene Zhulenev } // namespace
231a6628e59SEugene Zhulenev
addDropRefAfterLastUse(Value value)232a6628e59SEugene Zhulenev LogicalResult AsyncRuntimeRefCountingPass::addDropRefAfterLastUse(Value value) {
233a6628e59SEugene Zhulenev OpBuilder builder(value.getContext());
234a6628e59SEugene Zhulenev Location loc = value.getLoc();
235a6628e59SEugene Zhulenev
236a6628e59SEugene Zhulenev // Use liveness analysis to find the placement of `drop_ref`operation.
237a6628e59SEugene Zhulenev auto &liveness = getAnalysis<Liveness>();
238a6628e59SEugene Zhulenev
239a6628e59SEugene Zhulenev // We analyse only the blocks of the region that defines the `value`, and do
240a6628e59SEugene Zhulenev // not check nested blocks attached to operations.
241a6628e59SEugene Zhulenev //
242a6628e59SEugene Zhulenev // By analyzing only the `definingRegion` CFG we potentially loose an
243a6628e59SEugene Zhulenev // opportunity to drop the reference count earlier and can extend the lifetime
244a6628e59SEugene Zhulenev // of reference counted value longer then it is really required.
245a6628e59SEugene Zhulenev //
246a6628e59SEugene Zhulenev // We also assume that all nested regions finish their execution before the
247a6628e59SEugene Zhulenev // completion of the owner operation. The only exception to this rule is
248a6628e59SEugene Zhulenev // `async.execute` operation, and we verify that they are lowered to the
249a6628e59SEugene Zhulenev // `async.runtime` operations before adding automatic reference counting.
250a6628e59SEugene Zhulenev Region *definingRegion = value.getParentRegion();
251a6628e59SEugene Zhulenev
252a6628e59SEugene Zhulenev // Last users of the `value` inside all blocks where the value dies.
253a6628e59SEugene Zhulenev llvm::SmallSet<Operation *, 4> lastUsers;
254a6628e59SEugene Zhulenev
255a6628e59SEugene Zhulenev // Find blocks in the `definingRegion` that have users of the `value` (if
256a6628e59SEugene Zhulenev // there are multiple users in the block, which one will be selected is
257a6628e59SEugene Zhulenev // undefined). User operation might be not the actual user of the value, but
258a6628e59SEugene Zhulenev // the operation in the block that has a "real user" in one of the attached
259a6628e59SEugene Zhulenev // regions.
260a6628e59SEugene Zhulenev llvm::DenseMap<Block *, Operation *> usersInTheBlocks;
261a6628e59SEugene Zhulenev
262a6628e59SEugene Zhulenev for (Operation *user : value.getUsers()) {
263a6628e59SEugene Zhulenev Block *userBlock = user->getBlock();
264a6628e59SEugene Zhulenev Block *ancestor = definingRegion->findAncestorBlockInRegion(*userBlock);
265a6628e59SEugene Zhulenev usersInTheBlocks[ancestor] = ancestor->findAncestorOpInBlock(*user);
266a6628e59SEugene Zhulenev assert(ancestor && "ancestor block must be not null");
267a6628e59SEugene Zhulenev assert(usersInTheBlocks[ancestor] && "ancestor op must be not null");
268a6628e59SEugene Zhulenev }
269a6628e59SEugene Zhulenev
270a6628e59SEugene Zhulenev // Find blocks where the `value` dies: the value is in `liveIn` set and not
271a6628e59SEugene Zhulenev // in the `liveOut` set. We place `drop_ref` immediately after the last use
272a6628e59SEugene Zhulenev // of the `value` in such regions (after handling few special cases).
273a6628e59SEugene Zhulenev //
274a6628e59SEugene Zhulenev // We do not traverse all the blocks in the `definingRegion`, because the
275a6628e59SEugene Zhulenev // `value` can be in the live in set only if it has users in the block, or it
276a6628e59SEugene Zhulenev // is defined in the block.
277a6628e59SEugene Zhulenev //
278a6628e59SEugene Zhulenev // Values with zero users (only definition) handled explicitly above.
279a6628e59SEugene Zhulenev for (auto &blockAndUser : usersInTheBlocks) {
280a6628e59SEugene Zhulenev Block *block = blockAndUser.getFirst();
281a6628e59SEugene Zhulenev Operation *userInTheBlock = blockAndUser.getSecond();
282a6628e59SEugene Zhulenev
283a6628e59SEugene Zhulenev const LivenessBlockInfo *blockLiveness = liveness.getLiveness(block);
284a6628e59SEugene Zhulenev
285a6628e59SEugene Zhulenev // Value must be in the live input set or defined in the block.
286a6628e59SEugene Zhulenev assert(blockLiveness->isLiveIn(value) ||
287a6628e59SEugene Zhulenev blockLiveness->getBlock() == value.getParentBlock());
288a6628e59SEugene Zhulenev
289a6628e59SEugene Zhulenev // If value is in the live out set, it means it doesn't "die" in the block.
290a6628e59SEugene Zhulenev if (blockLiveness->isLiveOut(value))
291a6628e59SEugene Zhulenev continue;
292a6628e59SEugene Zhulenev
293a6628e59SEugene Zhulenev // At this point we proved that `value` dies in the `block`. Find the last
294a6628e59SEugene Zhulenev // use of the `value` inside the `block`, this is where it "dies".
295a6628e59SEugene Zhulenev Operation *lastUser = blockLiveness->getEndOperation(value, userInTheBlock);
296a6628e59SEugene Zhulenev assert(lastUsers.count(lastUser) == 0 && "last users must be unique");
297a6628e59SEugene Zhulenev lastUsers.insert(lastUser);
298a6628e59SEugene Zhulenev }
299a6628e59SEugene Zhulenev
300a6628e59SEugene Zhulenev // Process all the last users of the `value` inside each block where the value
301a6628e59SEugene Zhulenev // dies.
302a6628e59SEugene Zhulenev for (Operation *lastUser : lastUsers) {
303a6628e59SEugene Zhulenev // Return like operations forward reference count.
304a6628e59SEugene Zhulenev if (lastUser->hasTrait<OpTrait::ReturnLike>())
305a6628e59SEugene Zhulenev continue;
306a6628e59SEugene Zhulenev
307a6628e59SEugene Zhulenev // We can't currently handle other types of terminators.
308a6628e59SEugene Zhulenev if (lastUser->hasTrait<OpTrait::IsTerminator>())
309a6628e59SEugene Zhulenev return lastUser->emitError() << "async reference counting can't handle "
310a6628e59SEugene Zhulenev "terminators that are not ReturnLike";
311a6628e59SEugene Zhulenev
312a6628e59SEugene Zhulenev // Add a drop_ref immediately after the last user.
313a6628e59SEugene Zhulenev builder.setInsertionPointAfter(lastUser);
31492db09cdSEugene Zhulenev builder.create<RuntimeDropRefOp>(loc, value, builder.getI64IntegerAttr(1));
315a6628e59SEugene Zhulenev }
316a6628e59SEugene Zhulenev
317a6628e59SEugene Zhulenev return success();
318a6628e59SEugene Zhulenev }
319a6628e59SEugene Zhulenev
320a6628e59SEugene Zhulenev LogicalResult
addAddRefBeforeFunctionCall(Value value)321a6628e59SEugene Zhulenev AsyncRuntimeRefCountingPass::addAddRefBeforeFunctionCall(Value value) {
322a6628e59SEugene Zhulenev OpBuilder builder(value.getContext());
323a6628e59SEugene Zhulenev Location loc = value.getLoc();
324a6628e59SEugene Zhulenev
325a6628e59SEugene Zhulenev for (Operation *user : value.getUsers()) {
32623aa5a74SRiver Riddle if (!isa<func::CallOp>(user))
327a6628e59SEugene Zhulenev continue;
328a6628e59SEugene Zhulenev
329a6628e59SEugene Zhulenev // Add a reference before the function call to pass the value at `+1`
330a6628e59SEugene Zhulenev // reference to the function entry block.
331a6628e59SEugene Zhulenev builder.setInsertionPoint(user);
33292db09cdSEugene Zhulenev builder.create<RuntimeAddRefOp>(loc, value, builder.getI64IntegerAttr(1));
333a6628e59SEugene Zhulenev }
334a6628e59SEugene Zhulenev
335a6628e59SEugene Zhulenev return success();
336a6628e59SEugene Zhulenev }
337a6628e59SEugene Zhulenev
338c412979cSEugene Zhulenev LogicalResult
addDropRefInDivergentLivenessSuccessor(Value value)339c412979cSEugene Zhulenev AsyncRuntimeRefCountingPass::addDropRefInDivergentLivenessSuccessor(
340c412979cSEugene Zhulenev Value value) {
341c412979cSEugene Zhulenev using BlockSet = llvm::SmallPtrSet<Block *, 4>;
342c412979cSEugene Zhulenev
343a6628e59SEugene Zhulenev OpBuilder builder(value.getContext());
344a6628e59SEugene Zhulenev
345c412979cSEugene Zhulenev // If a block has successors with different `liveIn` property of the `value`,
346c412979cSEugene Zhulenev // record block successors that do not thave the `value` in the `liveIn` set.
347c412979cSEugene Zhulenev llvm::SmallDenseMap<Block *, BlockSet> divergentLivenessBlocks;
348a6628e59SEugene Zhulenev
349a6628e59SEugene Zhulenev // Use liveness analysis to find the placement of `drop_ref`operation.
350a6628e59SEugene Zhulenev auto &liveness = getAnalysis<Liveness>();
351a6628e59SEugene Zhulenev
352a6628e59SEugene Zhulenev // Because we only add `drop_ref` operations to the region that defines the
353a6628e59SEugene Zhulenev // `value` we can only process CFG for the same region.
354a6628e59SEugene Zhulenev Region *definingRegion = value.getParentRegion();
355a6628e59SEugene Zhulenev
356a6628e59SEugene Zhulenev // Collect blocks with successors with mismatching `liveIn` sets.
357a6628e59SEugene Zhulenev for (Block &block : definingRegion->getBlocks()) {
358a6628e59SEugene Zhulenev const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block);
359a6628e59SEugene Zhulenev
360a6628e59SEugene Zhulenev // Skip the block if value is not in the `liveOut` set.
3619136b7d0SEugene Zhulenev if (!blockLiveness || !blockLiveness->isLiveOut(value))
362a6628e59SEugene Zhulenev continue;
363a6628e59SEugene Zhulenev
364c412979cSEugene Zhulenev BlockSet liveInSuccessors; // `value` is in `liveIn` set
365c412979cSEugene Zhulenev BlockSet noLiveInSuccessors; // `value` is not in the `liveIn` set
366a6628e59SEugene Zhulenev
367a6628e59SEugene Zhulenev // Collect successors that do not have `value` in the `liveIn` set.
368a6628e59SEugene Zhulenev for (Block *successor : block.getSuccessors()) {
369a6628e59SEugene Zhulenev const LivenessBlockInfo *succLiveness = liveness.getLiveness(successor);
3709136b7d0SEugene Zhulenev if (succLiveness && succLiveness->isLiveIn(value))
371a6628e59SEugene Zhulenev liveInSuccessors.insert(successor);
372a6628e59SEugene Zhulenev else
373a6628e59SEugene Zhulenev noLiveInSuccessors.insert(successor);
374a6628e59SEugene Zhulenev }
375a6628e59SEugene Zhulenev
376a6628e59SEugene Zhulenev // Block has successors with different `liveIn` property of the `value`.
377a6628e59SEugene Zhulenev if (!liveInSuccessors.empty() && !noLiveInSuccessors.empty())
378c412979cSEugene Zhulenev divergentLivenessBlocks.try_emplace(&block, noLiveInSuccessors);
379a6628e59SEugene Zhulenev }
380a6628e59SEugene Zhulenev
381c412979cSEugene Zhulenev // Try to insert `dropRef` operations to handle blocks with divergent liveness
382c412979cSEugene Zhulenev // in successors blocks.
383c412979cSEugene Zhulenev for (auto kv : divergentLivenessBlocks) {
384c412979cSEugene Zhulenev Block *block = kv.getFirst();
385c412979cSEugene Zhulenev BlockSet &successors = kv.getSecond();
386c412979cSEugene Zhulenev
387c412979cSEugene Zhulenev // Coroutine suspension is a special case terminator for wich we do not
388c412979cSEugene Zhulenev // need to create additional reference counting (see details above).
389a6628e59SEugene Zhulenev Operation *terminator = block->getTerminator();
390a6628e59SEugene Zhulenev if (isa<CoroSuspendOp>(terminator))
391a6628e59SEugene Zhulenev continue;
392a6628e59SEugene Zhulenev
393c412979cSEugene Zhulenev // We only support successor blocks with empty block argument list.
394c412979cSEugene Zhulenev auto hasArgs = [](Block *block) { return !block->getArguments().empty(); };
395c412979cSEugene Zhulenev if (llvm::any_of(successors, hasArgs))
396c412979cSEugene Zhulenev return terminator->emitOpError()
397c412979cSEugene Zhulenev << "successor have different `liveIn` property of the reference "
398c412979cSEugene Zhulenev "counted value";
399c412979cSEugene Zhulenev
400c412979cSEugene Zhulenev // Make sure that `dropRef` operation is called when branched into the
401c412979cSEugene Zhulenev // successor block without `value` in the `liveIn` set.
402c412979cSEugene Zhulenev for (Block *successor : successors) {
403c412979cSEugene Zhulenev // If successor has a unique predecessor, it is safe to create `dropRef`
404c412979cSEugene Zhulenev // operations directly in the successor block.
405c412979cSEugene Zhulenev //
406c412979cSEugene Zhulenev // Otherwise we need to create a special block for reference counting
407c412979cSEugene Zhulenev // operations, and branch from it to the original successor block.
408c412979cSEugene Zhulenev Block *refCountingBlock = nullptr;
409c412979cSEugene Zhulenev
410c412979cSEugene Zhulenev if (successor->getUniquePredecessor() == block) {
411c412979cSEugene Zhulenev refCountingBlock = successor;
412c412979cSEugene Zhulenev } else {
413c412979cSEugene Zhulenev refCountingBlock = &successor->getParent()->emplaceBlock();
414c412979cSEugene Zhulenev refCountingBlock->moveBefore(successor);
415c412979cSEugene Zhulenev OpBuilder builder = OpBuilder::atBlockEnd(refCountingBlock);
416ace01605SRiver Riddle builder.create<cf::BranchOp>(value.getLoc(), successor);
417c412979cSEugene Zhulenev }
418c412979cSEugene Zhulenev
419c412979cSEugene Zhulenev OpBuilder builder = OpBuilder::atBlockBegin(refCountingBlock);
420c412979cSEugene Zhulenev builder.create<RuntimeDropRefOp>(value.getLoc(), value,
42192db09cdSEugene Zhulenev builder.getI64IntegerAttr(1));
422c412979cSEugene Zhulenev
423c412979cSEugene Zhulenev // No need to update the terminator operation.
424c412979cSEugene Zhulenev if (successor == refCountingBlock)
425c412979cSEugene Zhulenev continue;
426c412979cSEugene Zhulenev
427c412979cSEugene Zhulenev // Update terminator `successor` block to `refCountingBlock`.
42889de9cc8SMehdi Amini for (const auto &pair : llvm::enumerate(terminator->getSuccessors()))
429c412979cSEugene Zhulenev if (pair.value() == successor)
430c412979cSEugene Zhulenev terminator->setSuccessor(refCountingBlock, pair.index());
431c412979cSEugene Zhulenev }
432a6628e59SEugene Zhulenev }
433a6628e59SEugene Zhulenev
434a6628e59SEugene Zhulenev return success();
435a6628e59SEugene Zhulenev }
436a6628e59SEugene Zhulenev
437a6628e59SEugene Zhulenev LogicalResult
addAutomaticRefCounting(Value value)438a6628e59SEugene Zhulenev AsyncRuntimeRefCountingPass::addAutomaticRefCounting(Value value) {
439f57b2420SEugene Zhulenev // Short-circuit reference counting for values without uses.
440f57b2420SEugene Zhulenev if (succeeded(dropRefIfNoUses(value)))
441a6628e59SEugene Zhulenev return success();
442a6628e59SEugene Zhulenev
443a6628e59SEugene Zhulenev // Add `drop_ref` operations based on the liveness analysis.
444a6628e59SEugene Zhulenev if (failed(addDropRefAfterLastUse(value)))
445a6628e59SEugene Zhulenev return failure();
446a6628e59SEugene Zhulenev
447a6628e59SEugene Zhulenev // Add `add_ref` operations before function calls.
448a6628e59SEugene Zhulenev if (failed(addAddRefBeforeFunctionCall(value)))
449a6628e59SEugene Zhulenev return failure();
450a6628e59SEugene Zhulenev
451c412979cSEugene Zhulenev // Add `drop_ref` operations to successors with divergent `value` liveness.
452c412979cSEugene Zhulenev if (failed(addDropRefInDivergentLivenessSuccessor(value)))
453a6628e59SEugene Zhulenev return failure();
454a6628e59SEugene Zhulenev
455a6628e59SEugene Zhulenev return success();
456a6628e59SEugene Zhulenev }
457a6628e59SEugene Zhulenev
runOnOperation()4588a316b00SEugene Zhulenev void AsyncRuntimeRefCountingPass::runOnOperation() {
459f57b2420SEugene Zhulenev auto functor = [&](Value value) { return addAutomaticRefCounting(value); };
460f57b2420SEugene Zhulenev if (failed(walkReferenceCountedValues(getOperation(), functor)))
461a6628e59SEugene Zhulenev signalPassFailure();
462a6628e59SEugene Zhulenev }
463a6628e59SEugene Zhulenev
//===----------------------------------------------------------------------===//
// Reference counting based on the user defined policy.
//===----------------------------------------------------------------------===//
467f57b2420SEugene Zhulenev
468f57b2420SEugene Zhulenev namespace {
469f57b2420SEugene Zhulenev
470f57b2420SEugene Zhulenev class AsyncRuntimePolicyBasedRefCountingPass
47167d0d7acSMichele Scuttari : public impl::AsyncRuntimePolicyBasedRefCountingBase<
472f57b2420SEugene Zhulenev AsyncRuntimePolicyBasedRefCountingPass> {
473f57b2420SEugene Zhulenev public:
AsyncRuntimePolicyBasedRefCountingPass()474f57b2420SEugene Zhulenev AsyncRuntimePolicyBasedRefCountingPass() { initializeDefaultPolicy(); }
475f57b2420SEugene Zhulenev
476f57b2420SEugene Zhulenev void runOnOperation() override;
477f57b2420SEugene Zhulenev
478f57b2420SEugene Zhulenev private:
479f57b2420SEugene Zhulenev // Adds a reference counting operations for all uses of the `value` according
480f57b2420SEugene Zhulenev // to the reference counting policy.
481f57b2420SEugene Zhulenev LogicalResult addRefCounting(Value value);
482f57b2420SEugene Zhulenev
483f57b2420SEugene Zhulenev void initializeDefaultPolicy();
484f57b2420SEugene Zhulenev
485f57b2420SEugene Zhulenev llvm::SmallVector<std::function<FailureOr<int>(OpOperand &)>> policy;
486f57b2420SEugene Zhulenev };
487f57b2420SEugene Zhulenev
488f57b2420SEugene Zhulenev } // namespace
489f57b2420SEugene Zhulenev
490f57b2420SEugene Zhulenev LogicalResult
addRefCounting(Value value)491f57b2420SEugene Zhulenev AsyncRuntimePolicyBasedRefCountingPass::addRefCounting(Value value) {
492f57b2420SEugene Zhulenev // Short-circuit reference counting for values without uses.
493f57b2420SEugene Zhulenev if (succeeded(dropRefIfNoUses(value)))
494f57b2420SEugene Zhulenev return success();
495f57b2420SEugene Zhulenev
496f57b2420SEugene Zhulenev OpBuilder b(value.getContext());
497f57b2420SEugene Zhulenev
498f57b2420SEugene Zhulenev // Consult the user defined policy for every value use.
499f57b2420SEugene Zhulenev for (OpOperand &operand : value.getUses()) {
500f57b2420SEugene Zhulenev Location loc = operand.getOwner()->getLoc();
501f57b2420SEugene Zhulenev
502f57b2420SEugene Zhulenev for (auto &func : policy) {
503f57b2420SEugene Zhulenev FailureOr<int> refCount = func(operand);
504f57b2420SEugene Zhulenev if (failed(refCount))
505f57b2420SEugene Zhulenev return failure();
506f57b2420SEugene Zhulenev
5076d5fc1e3SKazu Hirata int cnt = *refCount;
508f57b2420SEugene Zhulenev
509f57b2420SEugene Zhulenev // Create `add_ref` operation before the operand owner.
510f57b2420SEugene Zhulenev if (cnt > 0) {
511f57b2420SEugene Zhulenev b.setInsertionPoint(operand.getOwner());
51292db09cdSEugene Zhulenev b.create<RuntimeAddRefOp>(loc, value, b.getI64IntegerAttr(cnt));
513f57b2420SEugene Zhulenev }
514f57b2420SEugene Zhulenev
515f57b2420SEugene Zhulenev // Create `drop_ref` operation after the operand owner.
516f57b2420SEugene Zhulenev if (cnt < 0) {
517f57b2420SEugene Zhulenev b.setInsertionPointAfter(operand.getOwner());
51892db09cdSEugene Zhulenev b.create<RuntimeDropRefOp>(loc, value, b.getI64IntegerAttr(-cnt));
519f57b2420SEugene Zhulenev }
520f57b2420SEugene Zhulenev }
521f57b2420SEugene Zhulenev }
522f57b2420SEugene Zhulenev
523f57b2420SEugene Zhulenev return success();
524f57b2420SEugene Zhulenev }
525f57b2420SEugene Zhulenev
initializeDefaultPolicy()526f57b2420SEugene Zhulenev void AsyncRuntimePolicyBasedRefCountingPass::initializeDefaultPolicy() {
527f57b2420SEugene Zhulenev policy.push_back([](OpOperand &operand) -> FailureOr<int> {
528f57b2420SEugene Zhulenev Operation *op = operand.getOwner();
529f57b2420SEugene Zhulenev Type type = operand.get().getType();
530f57b2420SEugene Zhulenev
5315550c821STres Popp bool isToken = isa<TokenType>(type);
5325550c821STres Popp bool isGroup = isa<GroupType>(type);
5335550c821STres Popp bool isValue = isa<ValueType>(type);
534f57b2420SEugene Zhulenev
535f57b2420SEugene Zhulenev // Drop reference after async token or group error check (coro await).
536*0a0aff2dSMikhail Goncharov if (dyn_cast<RuntimeIsErrorOp>(op))
537f57b2420SEugene Zhulenev return (isToken || isGroup) ? -1 : 0;
538f57b2420SEugene Zhulenev
539f57b2420SEugene Zhulenev // Drop reference after async value load.
540*0a0aff2dSMikhail Goncharov if (dyn_cast<RuntimeLoadOp>(op))
541f57b2420SEugene Zhulenev return isValue ? -1 : 0;
542f57b2420SEugene Zhulenev
543f57b2420SEugene Zhulenev // Drop reference after async token added to the group.
544*0a0aff2dSMikhail Goncharov if (dyn_cast<RuntimeAddToGroupOp>(op))
545f57b2420SEugene Zhulenev return isToken ? -1 : 0;
546f57b2420SEugene Zhulenev
547f57b2420SEugene Zhulenev return 0;
548f57b2420SEugene Zhulenev });
549f57b2420SEugene Zhulenev }
550f57b2420SEugene Zhulenev
runOnOperation()551f57b2420SEugene Zhulenev void AsyncRuntimePolicyBasedRefCountingPass::runOnOperation() {
552f57b2420SEugene Zhulenev auto functor = [&](Value value) { return addRefCounting(value); };
553f57b2420SEugene Zhulenev if (failed(walkReferenceCountedValues(getOperation(), functor)))
554f57b2420SEugene Zhulenev signalPassFailure();
555f57b2420SEugene Zhulenev }
556f57b2420SEugene Zhulenev
557f57b2420SEugene Zhulenev //----------------------------------------------------------------------------//
558f57b2420SEugene Zhulenev
createAsyncRuntimeRefCountingPass()5598a316b00SEugene Zhulenev std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingPass() {
560a6628e59SEugene Zhulenev return std::make_unique<AsyncRuntimeRefCountingPass>();
561a6628e59SEugene Zhulenev }
562f57b2420SEugene Zhulenev
createAsyncRuntimePolicyBasedRefCountingPass()563f57b2420SEugene Zhulenev std::unique_ptr<Pass> mlir::createAsyncRuntimePolicyBasedRefCountingPass() {
564f57b2420SEugene Zhulenev return std::make_unique<AsyncRuntimePolicyBasedRefCountingPass>();
565f57b2420SEugene Zhulenev }
566