xref: /llvm-project/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp (revision 0a0aff2d24518b1931a1a333c32a13028a0f88f6)
1 //===- AsyncRuntimeRefCounting.cpp - Async Runtime Ref Counting -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements automatic reference counting for Async runtime
10 // operations and types.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Async/Passes.h"
15 
16 #include "mlir/Analysis/Liveness.h"
17 #include "mlir/Dialect/Async/IR/Async.h"
18 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
19 #include "mlir/Dialect/Func/IR/FuncOps.h"
20 #include "mlir/IR/ImplicitLocOpBuilder.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23 #include "llvm/ADT/SmallSet.h"
24 
25 namespace mlir {
26 #define GEN_PASS_DEF_ASYNCRUNTIMEREFCOUNTING
27 #define GEN_PASS_DEF_ASYNCRUNTIMEPOLICYBASEDREFCOUNTING
28 #include "mlir/Dialect/Async/Passes.h.inc"
29 } // namespace mlir
30 
31 #define DEBUG_TYPE "async-runtime-ref-counting"
32 
33 using namespace mlir;
34 using namespace mlir::async;
35 
36 //===----------------------------------------------------------------------===//
37 // Utility functions shared by reference counting passes.
38 //===----------------------------------------------------------------------===//
39 
40 // Drop the reference count immediately if the value has no uses.
dropRefIfNoUses(Value value,unsigned count=1)41 static LogicalResult dropRefIfNoUses(Value value, unsigned count = 1) {
42   if (!value.getUses().empty())
43     return failure();
44 
45   OpBuilder b(value.getContext());
46 
47   // Set insertion point after the operation producing a value, or at the
48   // beginning of the block if the value defined by the block argument.
49   if (Operation *op = value.getDefiningOp())
50     b.setInsertionPointAfter(op);
51   else
52     b.setInsertionPointToStart(value.getParentBlock());
53 
54   b.create<RuntimeDropRefOp>(value.getLoc(), value, b.getI64IntegerAttr(1));
55   return success();
56 }
57 
58 // Calls `addRefCounting` for every reference counted value defined by the
59 // operation `op` (block arguments and values defined in nested regions).
walkReferenceCountedValues(Operation * op,llvm::function_ref<LogicalResult (Value)> addRefCounting)60 static LogicalResult walkReferenceCountedValues(
61     Operation *op, llvm::function_ref<LogicalResult(Value)> addRefCounting) {
62   // Check that we do not have high level async operations in the IR because
63   // otherwise reference counting will produce incorrect results after high
64   // level async operations will be lowered to `async.runtime`
65   WalkResult checkNoAsyncWalk = op->walk([&](Operation *op) -> WalkResult {
66     if (!isa<ExecuteOp, AwaitOp, AwaitAllOp, YieldOp>(op))
67       return WalkResult::advance();
68 
69     return op->emitError()
70            << "async operations must be lowered to async runtime operations";
71   });
72 
73   if (checkNoAsyncWalk.wasInterrupted())
74     return failure();
75 
76   // Add reference counting to block arguments.
77   WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult {
78     for (BlockArgument arg : block->getArguments())
79       if (isRefCounted(arg.getType()))
80         if (failed(addRefCounting(arg)))
81           return WalkResult::interrupt();
82 
83     return WalkResult::advance();
84   });
85 
86   if (blockWalk.wasInterrupted())
87     return failure();
88 
89   // Add reference counting to operation results.
90   WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult {
91     for (unsigned i = 0; i < op->getNumResults(); ++i)
92       if (isRefCounted(op->getResultTypes()[i]))
93         if (failed(addRefCounting(op->getResult(i))))
94           return WalkResult::interrupt();
95 
96     return WalkResult::advance();
97   });
98 
99   if (opWalk.wasInterrupted())
100     return failure();
101 
102   return success();
103 }
104 
105 //===----------------------------------------------------------------------===//
106 // Automatic reference counting based on the liveness analysis.
107 //===----------------------------------------------------------------------===//
108 
109 namespace {
110 
111 class AsyncRuntimeRefCountingPass
112     : public impl::AsyncRuntimeRefCountingBase<AsyncRuntimeRefCountingPass> {
113 public:
114   AsyncRuntimeRefCountingPass() = default;
115   void runOnOperation() override;
116 
117 private:
118   /// Adds an automatic reference counting to the `value`.
119   ///
120   /// All values (token, group or value) are semantically created with a
121   /// reference count of +1 and it is the responsibility of the async value user
122   /// to place the `add_ref` and `drop_ref` operations to ensure that the value
123   /// is destroyed after the last use.
124   ///
125   /// The function returns failure if it can't deduce the locations where
126   /// to place the reference counting operations.
127   ///
128   /// Async values "semantically created" when:
129   ///   1. Operation returns async result (e.g. `async.runtime.create`)
130   ///   2. Async value passed in as a block argument (or function argument,
131   ///      because function arguments are just entry block arguments)
132   ///
133   /// Passing async value as a function argument (or block argument) does not
134   /// really mean that a new async value is created, it only means that the
135   /// caller of a function transfered ownership of `+1` reference to the callee.
136   /// It is convenient to think that from the callee perspective async value was
137   /// "created" with `+1` reference by the block argument.
138   ///
139   /// Automatic reference counting algorithm outline:
140   ///
141   /// #1 Insert `drop_ref` operations after last use of the `value`.
142   /// #2 Insert `add_ref` operations before functions calls with reference
143   ///    counted `value` operand (newly created `+1` reference will be
144   ///    transferred to the callee).
145   /// #3 Verify that divergent control flow does not lead to leaked reference
146   ///    counted objects.
147   ///
148   /// Async runtime reference counting optimization pass will optimize away
149   /// some of the redundant `add_ref` and `drop_ref` operations inserted by this
150   /// strategy (see `async-runtime-ref-counting-opt`).
151   LogicalResult addAutomaticRefCounting(Value value);
152 
153   /// (#1) Adds the `drop_ref` operation after the last use of the `value`
154   /// relying on the liveness analysis.
155   ///
156   /// If the `value` is in the block `liveIn` set and it is not in the block
157   /// `liveOut` set, it means that it "dies" in the block. We find the last
158   /// use of the value in such block and:
159   ///
160   ///   1. If the last user is a `ReturnLike` operation we do nothing, because
161   ///      it forwards the ownership to the caller.
162   ///   2. Otherwise we add a `drop_ref` operation immediately after the last
163   ///      use.
164   LogicalResult addDropRefAfterLastUse(Value value);
165 
166   /// (#2) Adds the `add_ref` operation before the function call taking `value`
167   /// operand to ensure that the value passed to the function entry block
168   /// has a `+1` reference count.
169   LogicalResult addAddRefBeforeFunctionCall(Value value);
170 
171   /// (#3) Adds the `drop_ref` operation to account for successor blocks with
172   /// divergent `liveIn` property: `value` is not in the `liveIn` set of all
173   /// successor blocks.
174   ///
175   /// Example:
176   ///
177   ///   ^entry:
178   ///     %token = async.runtime.create : !async.token
179   ///     cf.cond_br %cond, ^bb1, ^bb2
180   ///   ^bb1:
181   ///     async.runtime.await %token
182   ///     async.runtime.drop_ref %token
183   ///     cf.br ^bb2
184   ///   ^bb2:
185   ///     return
186   ///
187   /// In this example ^bb2 does not have `value` in the `liveIn` set, so we have
188   /// to branch into a special "reference counting block" from the ^entry that
189   /// will have a `drop_ref` operation, and then branch into the ^bb2.
190   ///
191   /// After transformation:
192   ///
193   ///   ^entry:
194   ///     %token = async.runtime.create : !async.token
195   ///     cf.cond_br %cond, ^bb1, ^reference_counting
196   ///   ^bb1:
197   ///     async.runtime.await %token
198   ///     async.runtime.drop_ref %token
199   ///     cf.br ^bb2
200   ///   ^reference_counting:
201   ///     async.runtime.drop_ref %token
202   ///     cf.br ^bb2
203   ///   ^bb2:
204   ///     return
205   ///
206   /// An exception to this rule are blocks with `async.coro.suspend` terminator,
207   /// because in Async to LLVM lowering it is guaranteed that the control flow
208   /// will jump into the resume block, and then follow into the cleanup and
209   /// suspend blocks.
210   ///
211   /// Example:
212   ///
213   ///  ^entry(%value: !async.value<f32>):
214   ///     async.runtime.await_and_resume %value, %hdl : !async.value<f32>
215   ///     async.coro.suspend %ret, ^suspend, ^resume, ^cleanup
216   ///   ^resume:
217   ///     %0 = async.runtime.load %value
218   ///     cf.br ^cleanup
219   ///   ^cleanup:
220   ///     ...
221   ///   ^suspend:
222   ///     ...
223   ///
224   /// Although cleanup and suspend blocks do not have the `value` in the
225   /// `liveIn` set, it is guaranteed that execution will eventually continue in
226   /// the resume block (we never explicitly destroy coroutines).
227   LogicalResult addDropRefInDivergentLivenessSuccessor(Value value);
228 };
229 
230 } // namespace
231 
addDropRefAfterLastUse(Value value)232 LogicalResult AsyncRuntimeRefCountingPass::addDropRefAfterLastUse(Value value) {
233   OpBuilder builder(value.getContext());
234   Location loc = value.getLoc();
235 
236   // Use liveness analysis to find the placement of `drop_ref`operation.
237   auto &liveness = getAnalysis<Liveness>();
238 
239   // We analyse only the blocks of the region that defines the `value`, and do
240   // not check nested blocks attached to operations.
241   //
242   // By analyzing only the `definingRegion` CFG we potentially loose an
243   // opportunity to drop the reference count earlier and can extend the lifetime
244   // of reference counted value longer then it is really required.
245   //
246   // We also assume that all nested regions finish their execution before the
247   // completion of the owner operation. The only exception to this rule is
248   // `async.execute` operation, and we verify that they are lowered to the
249   // `async.runtime` operations before adding automatic reference counting.
250   Region *definingRegion = value.getParentRegion();
251 
252   // Last users of the `value` inside all blocks where the value dies.
253   llvm::SmallSet<Operation *, 4> lastUsers;
254 
255   // Find blocks in the `definingRegion` that have users of the `value` (if
256   // there are multiple users in the block, which one will be selected is
257   // undefined). User operation might be not the actual user of the value, but
258   // the operation in the block that has a "real user" in one of the attached
259   // regions.
260   llvm::DenseMap<Block *, Operation *> usersInTheBlocks;
261 
262   for (Operation *user : value.getUsers()) {
263     Block *userBlock = user->getBlock();
264     Block *ancestor = definingRegion->findAncestorBlockInRegion(*userBlock);
265     usersInTheBlocks[ancestor] = ancestor->findAncestorOpInBlock(*user);
266     assert(ancestor && "ancestor block must be not null");
267     assert(usersInTheBlocks[ancestor] && "ancestor op must be not null");
268   }
269 
270   // Find blocks where the `value` dies: the value is in `liveIn` set and not
271   // in the `liveOut` set. We place `drop_ref` immediately after the last use
272   // of the `value` in such regions (after handling few special cases).
273   //
274   // We do not traverse all the blocks in the `definingRegion`, because the
275   // `value` can be in the live in set only if it has users in the block, or it
276   // is defined in the block.
277   //
278   // Values with zero users (only definition) handled explicitly above.
279   for (auto &blockAndUser : usersInTheBlocks) {
280     Block *block = blockAndUser.getFirst();
281     Operation *userInTheBlock = blockAndUser.getSecond();
282 
283     const LivenessBlockInfo *blockLiveness = liveness.getLiveness(block);
284 
285     // Value must be in the live input set or defined in the block.
286     assert(blockLiveness->isLiveIn(value) ||
287            blockLiveness->getBlock() == value.getParentBlock());
288 
289     // If value is in the live out set, it means it doesn't "die" in the block.
290     if (blockLiveness->isLiveOut(value))
291       continue;
292 
293     // At this point we proved that `value` dies in the `block`. Find the last
294     // use of the `value` inside the `block`, this is where it "dies".
295     Operation *lastUser = blockLiveness->getEndOperation(value, userInTheBlock);
296     assert(lastUsers.count(lastUser) == 0 && "last users must be unique");
297     lastUsers.insert(lastUser);
298   }
299 
300   // Process all the last users of the `value` inside each block where the value
301   // dies.
302   for (Operation *lastUser : lastUsers) {
303     // Return like operations forward reference count.
304     if (lastUser->hasTrait<OpTrait::ReturnLike>())
305       continue;
306 
307     // We can't currently handle other types of terminators.
308     if (lastUser->hasTrait<OpTrait::IsTerminator>())
309       return lastUser->emitError() << "async reference counting can't handle "
310                                       "terminators that are not ReturnLike";
311 
312     // Add a drop_ref immediately after the last user.
313     builder.setInsertionPointAfter(lastUser);
314     builder.create<RuntimeDropRefOp>(loc, value, builder.getI64IntegerAttr(1));
315   }
316 
317   return success();
318 }
319 
320 LogicalResult
addAddRefBeforeFunctionCall(Value value)321 AsyncRuntimeRefCountingPass::addAddRefBeforeFunctionCall(Value value) {
322   OpBuilder builder(value.getContext());
323   Location loc = value.getLoc();
324 
325   for (Operation *user : value.getUsers()) {
326     if (!isa<func::CallOp>(user))
327       continue;
328 
329     // Add a reference before the function call to pass the value at `+1`
330     // reference to the function entry block.
331     builder.setInsertionPoint(user);
332     builder.create<RuntimeAddRefOp>(loc, value, builder.getI64IntegerAttr(1));
333   }
334 
335   return success();
336 }
337 
338 LogicalResult
addDropRefInDivergentLivenessSuccessor(Value value)339 AsyncRuntimeRefCountingPass::addDropRefInDivergentLivenessSuccessor(
340     Value value) {
341   using BlockSet = llvm::SmallPtrSet<Block *, 4>;
342 
343   OpBuilder builder(value.getContext());
344 
345   // If a block has successors with different `liveIn` property of the `value`,
346   // record block successors that do not thave the `value` in the `liveIn` set.
347   llvm::SmallDenseMap<Block *, BlockSet> divergentLivenessBlocks;
348 
349   // Use liveness analysis to find the placement of `drop_ref`operation.
350   auto &liveness = getAnalysis<Liveness>();
351 
352   // Because we only add `drop_ref` operations to the region that defines the
353   // `value` we can only process CFG for the same region.
354   Region *definingRegion = value.getParentRegion();
355 
356   // Collect blocks with successors with mismatching `liveIn` sets.
357   for (Block &block : definingRegion->getBlocks()) {
358     const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block);
359 
360     // Skip the block if value is not in the `liveOut` set.
361     if (!blockLiveness || !blockLiveness->isLiveOut(value))
362       continue;
363 
364     BlockSet liveInSuccessors;   // `value` is in `liveIn` set
365     BlockSet noLiveInSuccessors; // `value` is not in the `liveIn` set
366 
367     // Collect successors that do not have `value` in the `liveIn` set.
368     for (Block *successor : block.getSuccessors()) {
369       const LivenessBlockInfo *succLiveness = liveness.getLiveness(successor);
370       if (succLiveness && succLiveness->isLiveIn(value))
371         liveInSuccessors.insert(successor);
372       else
373         noLiveInSuccessors.insert(successor);
374     }
375 
376     // Block has successors with different `liveIn` property of the `value`.
377     if (!liveInSuccessors.empty() && !noLiveInSuccessors.empty())
378       divergentLivenessBlocks.try_emplace(&block, noLiveInSuccessors);
379   }
380 
381   // Try to insert `dropRef` operations to handle blocks with divergent liveness
382   // in successors blocks.
383   for (auto kv : divergentLivenessBlocks) {
384     Block *block = kv.getFirst();
385     BlockSet &successors = kv.getSecond();
386 
387     // Coroutine suspension is a special case terminator for wich we do not
388     // need to create additional reference counting (see details above).
389     Operation *terminator = block->getTerminator();
390     if (isa<CoroSuspendOp>(terminator))
391       continue;
392 
393     // We only support successor blocks with empty block argument list.
394     auto hasArgs = [](Block *block) { return !block->getArguments().empty(); };
395     if (llvm::any_of(successors, hasArgs))
396       return terminator->emitOpError()
397              << "successor have different `liveIn` property of the reference "
398                 "counted value";
399 
400     // Make sure that `dropRef` operation is called when branched into the
401     // successor block without `value` in the `liveIn` set.
402     for (Block *successor : successors) {
403       // If successor has a unique predecessor, it is safe to create `dropRef`
404       // operations directly in the successor block.
405       //
406       // Otherwise we need to create a special block for reference counting
407       // operations, and branch from it to the original successor block.
408       Block *refCountingBlock = nullptr;
409 
410       if (successor->getUniquePredecessor() == block) {
411         refCountingBlock = successor;
412       } else {
413         refCountingBlock = &successor->getParent()->emplaceBlock();
414         refCountingBlock->moveBefore(successor);
415         OpBuilder builder = OpBuilder::atBlockEnd(refCountingBlock);
416         builder.create<cf::BranchOp>(value.getLoc(), successor);
417       }
418 
419       OpBuilder builder = OpBuilder::atBlockBegin(refCountingBlock);
420       builder.create<RuntimeDropRefOp>(value.getLoc(), value,
421                                        builder.getI64IntegerAttr(1));
422 
423       // No need to update the terminator operation.
424       if (successor == refCountingBlock)
425         continue;
426 
427       // Update terminator `successor` block to `refCountingBlock`.
428       for (const auto &pair : llvm::enumerate(terminator->getSuccessors()))
429         if (pair.value() == successor)
430           terminator->setSuccessor(refCountingBlock, pair.index());
431     }
432   }
433 
434   return success();
435 }
436 
437 LogicalResult
addAutomaticRefCounting(Value value)438 AsyncRuntimeRefCountingPass::addAutomaticRefCounting(Value value) {
439   // Short-circuit reference counting for values without uses.
440   if (succeeded(dropRefIfNoUses(value)))
441     return success();
442 
443   // Add `drop_ref` operations based on the liveness analysis.
444   if (failed(addDropRefAfterLastUse(value)))
445     return failure();
446 
447   // Add `add_ref` operations before function calls.
448   if (failed(addAddRefBeforeFunctionCall(value)))
449     return failure();
450 
451   // Add `drop_ref` operations to successors with divergent `value` liveness.
452   if (failed(addDropRefInDivergentLivenessSuccessor(value)))
453     return failure();
454 
455   return success();
456 }
457 
runOnOperation()458 void AsyncRuntimeRefCountingPass::runOnOperation() {
459   auto functor = [&](Value value) { return addAutomaticRefCounting(value); };
460   if (failed(walkReferenceCountedValues(getOperation(), functor)))
461     signalPassFailure();
462 }
463 
464 //===----------------------------------------------------------------------===//
465 // Reference counting based on the user defined policy.
466 //===----------------------------------------------------------------------===//
467 
468 namespace {
469 
470 class AsyncRuntimePolicyBasedRefCountingPass
471     : public impl::AsyncRuntimePolicyBasedRefCountingBase<
472           AsyncRuntimePolicyBasedRefCountingPass> {
473 public:
AsyncRuntimePolicyBasedRefCountingPass()474   AsyncRuntimePolicyBasedRefCountingPass() { initializeDefaultPolicy(); }
475 
476   void runOnOperation() override;
477 
478 private:
479   // Adds a reference counting operations for all uses of the `value` according
480   // to the reference counting policy.
481   LogicalResult addRefCounting(Value value);
482 
483   void initializeDefaultPolicy();
484 
485   llvm::SmallVector<std::function<FailureOr<int>(OpOperand &)>> policy;
486 };
487 
488 } // namespace
489 
490 LogicalResult
addRefCounting(Value value)491 AsyncRuntimePolicyBasedRefCountingPass::addRefCounting(Value value) {
492   // Short-circuit reference counting for values without uses.
493   if (succeeded(dropRefIfNoUses(value)))
494     return success();
495 
496   OpBuilder b(value.getContext());
497 
498   // Consult the user defined policy for every value use.
499   for (OpOperand &operand : value.getUses()) {
500     Location loc = operand.getOwner()->getLoc();
501 
502     for (auto &func : policy) {
503       FailureOr<int> refCount = func(operand);
504       if (failed(refCount))
505         return failure();
506 
507       int cnt = *refCount;
508 
509       // Create `add_ref` operation before the operand owner.
510       if (cnt > 0) {
511         b.setInsertionPoint(operand.getOwner());
512         b.create<RuntimeAddRefOp>(loc, value, b.getI64IntegerAttr(cnt));
513       }
514 
515       // Create `drop_ref` operation after the operand owner.
516       if (cnt < 0) {
517         b.setInsertionPointAfter(operand.getOwner());
518         b.create<RuntimeDropRefOp>(loc, value, b.getI64IntegerAttr(-cnt));
519       }
520     }
521   }
522 
523   return success();
524 }
525 
initializeDefaultPolicy()526 void AsyncRuntimePolicyBasedRefCountingPass::initializeDefaultPolicy() {
527   policy.push_back([](OpOperand &operand) -> FailureOr<int> {
528     Operation *op = operand.getOwner();
529     Type type = operand.get().getType();
530 
531     bool isToken = isa<TokenType>(type);
532     bool isGroup = isa<GroupType>(type);
533     bool isValue = isa<ValueType>(type);
534 
535     // Drop reference after async token or group error check (coro await).
536     if (dyn_cast<RuntimeIsErrorOp>(op))
537       return (isToken || isGroup) ? -1 : 0;
538 
539     // Drop reference after async value load.
540     if (dyn_cast<RuntimeLoadOp>(op))
541       return isValue ? -1 : 0;
542 
543     // Drop reference after async token added to the group.
544     if (dyn_cast<RuntimeAddToGroupOp>(op))
545       return isToken ? -1 : 0;
546 
547     return 0;
548   });
549 }
550 
runOnOperation()551 void AsyncRuntimePolicyBasedRefCountingPass::runOnOperation() {
552   auto functor = [&](Value value) { return addRefCounting(value); };
553   if (failed(walkReferenceCountedValues(getOperation(), functor)))
554     signalPassFailure();
555 }
556 
557 //----------------------------------------------------------------------------//
558 
createAsyncRuntimeRefCountingPass()559 std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingPass() {
560   return std::make_unique<AsyncRuntimeRefCountingPass>();
561 }
562 
createAsyncRuntimePolicyBasedRefCountingPass()563 std::unique_ptr<Pass> mlir::createAsyncRuntimePolicyBasedRefCountingPass() {
564   return std::make_unique<AsyncRuntimePolicyBasedRefCountingPass>();
565 }
566