1a6628e59SEugene Zhulenev //===- AsyncRuntimeRefCounting.cpp - Async Runtime Ref Counting -----------===// 2a6628e59SEugene Zhulenev // 3a6628e59SEugene Zhulenev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4a6628e59SEugene Zhulenev // See https://llvm.org/LICENSE.txt for license information. 5a6628e59SEugene Zhulenev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6a6628e59SEugene Zhulenev // 7a6628e59SEugene Zhulenev //===----------------------------------------------------------------------===// 8a6628e59SEugene Zhulenev // 9a6628e59SEugene Zhulenev // This file implements automatic reference counting for Async runtime 10a6628e59SEugene Zhulenev // operations and types. 11a6628e59SEugene Zhulenev // 12a6628e59SEugene Zhulenev //===----------------------------------------------------------------------===// 13a6628e59SEugene Zhulenev 14*039b969bSMichele Scuttari #include "PassDetail.h" 15a6628e59SEugene Zhulenev #include "mlir/Analysis/Liveness.h" 16a6628e59SEugene Zhulenev #include "mlir/Dialect/Async/IR/Async.h" 17*039b969bSMichele Scuttari #include "mlir/Dialect/Async/Passes.h" 18ace01605SRiver Riddle #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 1923aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h" 20a6628e59SEugene Zhulenev #include "mlir/IR/ImplicitLocOpBuilder.h" 21a6628e59SEugene Zhulenev #include "mlir/IR/PatternMatch.h" 22a6628e59SEugene Zhulenev #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 23a6628e59SEugene Zhulenev #include "llvm/ADT/SmallSet.h" 24a6628e59SEugene Zhulenev 252be8af8fSMichele Scuttari using namespace mlir; 262be8af8fSMichele Scuttari using namespace mlir::async; 272be8af8fSMichele Scuttari 28*039b969bSMichele Scuttari #define DEBUG_TYPE "async-runtime-ref-counting" 29*039b969bSMichele Scuttari 30f57b2420SEugene Zhulenev //===----------------------------------------------------------------------===// 31f57b2420SEugene Zhulenev // Utility functions shared by reference counting passes. 32f57b2420SEugene Zhulenev //===----------------------------------------------------------------------===// 33f57b2420SEugene Zhulenev 34f57b2420SEugene Zhulenev // Drop the reference count immediately if the value has no uses. 35f57b2420SEugene Zhulenev static LogicalResult dropRefIfNoUses(Value value, unsigned count = 1) { 36f57b2420SEugene Zhulenev if (!value.getUses().empty()) 37f57b2420SEugene Zhulenev return failure(); 38f57b2420SEugene Zhulenev 39f57b2420SEugene Zhulenev OpBuilder b(value.getContext()); 40f57b2420SEugene Zhulenev 41f57b2420SEugene Zhulenev // Set insertion point after the operation producing a value, or at the 42f57b2420SEugene Zhulenev // beginning of the block if the value defined by the block argument. 43f57b2420SEugene Zhulenev if (Operation *op = value.getDefiningOp()) 44f57b2420SEugene Zhulenev b.setInsertionPointAfter(op); 45f57b2420SEugene Zhulenev else 46f57b2420SEugene Zhulenev b.setInsertionPointToStart(value.getParentBlock()); 47f57b2420SEugene Zhulenev 4892db09cdSEugene Zhulenev b.create<RuntimeDropRefOp>(value.getLoc(), value, b.getI64IntegerAttr(1)); 49f57b2420SEugene Zhulenev return success(); 50f57b2420SEugene Zhulenev } 51f57b2420SEugene Zhulenev 52f57b2420SEugene Zhulenev // Calls `addRefCounting` for every reference counted value defined by the 53f57b2420SEugene Zhulenev // operation `op` (block arguments and values defined in nested regions). 54f57b2420SEugene Zhulenev static LogicalResult walkReferenceCountedValues( 55f57b2420SEugene Zhulenev Operation *op, llvm::function_ref<LogicalResult(Value)> addRefCounting) { 56f57b2420SEugene Zhulenev // Check that we do not have high level async operations in the IR because 57f57b2420SEugene Zhulenev // otherwise reference counting will produce incorrect results after high 58f57b2420SEugene Zhulenev // level async operations will be lowered to `async.runtime` 59f57b2420SEugene Zhulenev WalkResult checkNoAsyncWalk = op->walk([&](Operation *op) -> WalkResult { 60f57b2420SEugene Zhulenev if (!isa<ExecuteOp, AwaitOp, AwaitAllOp, YieldOp>(op)) 61f57b2420SEugene Zhulenev return WalkResult::advance(); 62f57b2420SEugene Zhulenev 63f57b2420SEugene Zhulenev return op->emitError() 64f57b2420SEugene Zhulenev << "async operations must be lowered to async runtime operations"; 65f57b2420SEugene Zhulenev }); 66f57b2420SEugene Zhulenev 67f57b2420SEugene Zhulenev if (checkNoAsyncWalk.wasInterrupted()) 68f57b2420SEugene Zhulenev return failure(); 69f57b2420SEugene Zhulenev 70f57b2420SEugene Zhulenev // Add reference counting to block arguments. 71f57b2420SEugene Zhulenev WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult { 72f57b2420SEugene Zhulenev for (BlockArgument arg : block->getArguments()) 73f57b2420SEugene Zhulenev if (isRefCounted(arg.getType())) 74f57b2420SEugene Zhulenev if (failed(addRefCounting(arg))) 75f57b2420SEugene Zhulenev return WalkResult::interrupt(); 76f57b2420SEugene Zhulenev 77f57b2420SEugene Zhulenev return WalkResult::advance(); 78f57b2420SEugene Zhulenev }); 79f57b2420SEugene Zhulenev 80f57b2420SEugene Zhulenev if (blockWalk.wasInterrupted()) 81f57b2420SEugene Zhulenev return failure(); 82f57b2420SEugene Zhulenev 83f57b2420SEugene Zhulenev // Add reference counting to operation results. 84f57b2420SEugene Zhulenev WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult { 85f57b2420SEugene Zhulenev for (unsigned i = 0; i < op->getNumResults(); ++i) 86f57b2420SEugene Zhulenev if (isRefCounted(op->getResultTypes()[i])) 87f57b2420SEugene Zhulenev if (failed(addRefCounting(op->getResult(i)))) 88f57b2420SEugene Zhulenev return WalkResult::interrupt(); 89f57b2420SEugene Zhulenev 90f57b2420SEugene Zhulenev return WalkResult::advance(); 91f57b2420SEugene Zhulenev }); 92f57b2420SEugene Zhulenev 93f57b2420SEugene Zhulenev if (opWalk.wasInterrupted()) 94f57b2420SEugene Zhulenev return failure(); 95f57b2420SEugene Zhulenev 96f57b2420SEugene Zhulenev return success(); 97f57b2420SEugene Zhulenev } 98f57b2420SEugene Zhulenev 99f57b2420SEugene Zhulenev //===----------------------------------------------------------------------===// 100f57b2420SEugene Zhulenev // Automatic reference counting based on the liveness analysis. 101f57b2420SEugene Zhulenev //===----------------------------------------------------------------------===// 102f57b2420SEugene Zhulenev 103a6628e59SEugene Zhulenev namespace { 104a6628e59SEugene Zhulenev 105a6628e59SEugene Zhulenev class AsyncRuntimeRefCountingPass 106*039b969bSMichele Scuttari : public AsyncRuntimeRefCountingBase<AsyncRuntimeRefCountingPass> { 107a6628e59SEugene Zhulenev public: 108a6628e59SEugene Zhulenev AsyncRuntimeRefCountingPass() = default; 1098a316b00SEugene Zhulenev void runOnOperation() override; 110a6628e59SEugene Zhulenev 111a6628e59SEugene Zhulenev private: 112a6628e59SEugene Zhulenev /// Adds an automatic reference counting to the `value`. 113a6628e59SEugene Zhulenev /// 114a6628e59SEugene Zhulenev /// All values (token, group or value) are semantically created with a 115a6628e59SEugene Zhulenev /// reference count of +1 and it is the responsibility of the async value user 116a6628e59SEugene Zhulenev /// to place the `add_ref` and `drop_ref` operations to ensure that the value 117a6628e59SEugene Zhulenev /// is destroyed after the last use. 118a6628e59SEugene Zhulenev /// 119a6628e59SEugene Zhulenev /// The function returns failure if it can't deduce the locations where 120a6628e59SEugene Zhulenev /// to place the reference counting operations. 121a6628e59SEugene Zhulenev /// 122a6628e59SEugene Zhulenev /// Async values "semantically created" when: 123a6628e59SEugene Zhulenev /// 1. Operation returns async result (e.g. `async.runtime.create`) 124a6628e59SEugene Zhulenev /// 2. Async value passed in as a block argument (or function argument, 125a6628e59SEugene Zhulenev /// because function arguments are just entry block arguments) 126a6628e59SEugene Zhulenev /// 127a6628e59SEugene Zhulenev /// Passing async value as a function argument (or block argument) does not 128a6628e59SEugene Zhulenev /// really mean that a new async value is created, it only means that the 129a6628e59SEugene Zhulenev /// caller of a function transfered ownership of `+1` reference to the callee. 130a6628e59SEugene Zhulenev /// It is convenient to think that from the callee perspective async value was 131a6628e59SEugene Zhulenev /// "created" with `+1` reference by the block argument. 132a6628e59SEugene Zhulenev /// 133a6628e59SEugene Zhulenev /// Automatic reference counting algorithm outline: 134a6628e59SEugene Zhulenev /// 135a6628e59SEugene Zhulenev /// #1 Insert `drop_ref` operations after last use of the `value`. 136a6628e59SEugene Zhulenev /// #2 Insert `add_ref` operations before functions calls with reference 137a6628e59SEugene Zhulenev /// counted `value` operand (newly created `+1` reference will be 138a6628e59SEugene Zhulenev /// transferred to the callee). 139a6628e59SEugene Zhulenev /// #3 Verify that divergent control flow does not lead to leaked reference 140a6628e59SEugene Zhulenev /// counted objects. 141a6628e59SEugene Zhulenev /// 142a6628e59SEugene Zhulenev /// Async runtime reference counting optimization pass will optimize away 143a6628e59SEugene Zhulenev /// some of the redundant `add_ref` and `drop_ref` operations inserted by this 144a6628e59SEugene Zhulenev /// strategy (see `async-runtime-ref-counting-opt`). 145a6628e59SEugene Zhulenev LogicalResult addAutomaticRefCounting(Value value); 146a6628e59SEugene Zhulenev 147a6628e59SEugene Zhulenev /// (#1) Adds the `drop_ref` operation after the last use of the `value` 148a6628e59SEugene Zhulenev /// relying on the liveness analysis. 149a6628e59SEugene Zhulenev /// 150a6628e59SEugene Zhulenev /// If the `value` is in the block `liveIn` set and it is not in the block 151a6628e59SEugene Zhulenev /// `liveOut` set, it means that it "dies" in the block. We find the last 152a6628e59SEugene Zhulenev /// use of the value in such block and: 153a6628e59SEugene Zhulenev /// 154a6628e59SEugene Zhulenev /// 1. If the last user is a `ReturnLike` operation we do nothing, because 155a6628e59SEugene Zhulenev /// it forwards the ownership to the caller. 156a6628e59SEugene Zhulenev /// 2. Otherwise we add a `drop_ref` operation immediately after the last 157a6628e59SEugene Zhulenev /// use. 158a6628e59SEugene Zhulenev LogicalResult addDropRefAfterLastUse(Value value); 159a6628e59SEugene Zhulenev 160a6628e59SEugene Zhulenev /// (#2) Adds the `add_ref` operation before the function call taking `value` 161a6628e59SEugene Zhulenev /// operand to ensure that the value passed to the function entry block 162a6628e59SEugene Zhulenev /// has a `+1` reference count. 163a6628e59SEugene Zhulenev LogicalResult addAddRefBeforeFunctionCall(Value value); 164a6628e59SEugene Zhulenev 165c412979cSEugene Zhulenev /// (#3) Adds the `drop_ref` operation to account for successor blocks with 166c412979cSEugene Zhulenev /// divergent `liveIn` property: `value` is not in the `liveIn` set of all 167c412979cSEugene Zhulenev /// successor blocks. 168a6628e59SEugene Zhulenev /// 169a6628e59SEugene Zhulenev /// Example: 170a6628e59SEugene Zhulenev /// 171a6628e59SEugene Zhulenev /// ^entry: 172a6628e59SEugene Zhulenev /// %token = async.runtime.create : !async.token 173ace01605SRiver Riddle /// cf.cond_br %cond, ^bb1, ^bb2 174a6628e59SEugene Zhulenev /// ^bb1: 175a6628e59SEugene Zhulenev /// async.runtime.await %token 176c412979cSEugene Zhulenev /// async.runtime.drop_ref %token 177ace01605SRiver Riddle /// cf.br ^bb2 178a6628e59SEugene Zhulenev /// ^bb2: 179a6628e59SEugene Zhulenev /// return 180a6628e59SEugene Zhulenev /// 181c412979cSEugene Zhulenev /// In this example ^bb2 does not have `value` in the `liveIn` set, so we have 182c412979cSEugene Zhulenev /// to branch into a special "reference counting block" from the ^entry that 183c412979cSEugene Zhulenev /// will have a `drop_ref` operation, and then branch into the ^bb2. 184c412979cSEugene Zhulenev /// 185c412979cSEugene Zhulenev /// After transformation: 186c412979cSEugene Zhulenev /// 187c412979cSEugene Zhulenev /// ^entry: 188c412979cSEugene Zhulenev /// %token = async.runtime.create : !async.token 189ace01605SRiver Riddle /// cf.cond_br %cond, ^bb1, ^reference_counting 190c412979cSEugene Zhulenev /// ^bb1: 191c412979cSEugene Zhulenev /// async.runtime.await %token 192c412979cSEugene Zhulenev /// async.runtime.drop_ref %token 193ace01605SRiver Riddle /// cf.br ^bb2 194c412979cSEugene Zhulenev /// ^reference_counting: 195c412979cSEugene Zhulenev /// async.runtime.drop_ref %token 196ace01605SRiver Riddle /// cf.br ^bb2 197c412979cSEugene Zhulenev /// ^bb2: 198c412979cSEugene Zhulenev /// return 199a6628e59SEugene Zhulenev /// 200a6628e59SEugene Zhulenev /// An exception to this rule are blocks with `async.coro.suspend` terminator, 201a6628e59SEugene Zhulenev /// because in Async to LLVM lowering it is guaranteed that the control flow 202a6628e59SEugene Zhulenev /// will jump into the resume block, and then follow into the cleanup and 203a6628e59SEugene Zhulenev /// suspend blocks. 204a6628e59SEugene Zhulenev /// 205a6628e59SEugene Zhulenev /// Example: 206a6628e59SEugene Zhulenev /// 207a6628e59SEugene Zhulenev /// ^entry(%value: !async.value<f32>): 208a6628e59SEugene Zhulenev /// async.runtime.await_and_resume %value, %hdl : !async.value<f32> 209a6628e59SEugene Zhulenev /// async.coro.suspend %ret, ^suspend, ^resume, ^cleanup 210a6628e59SEugene Zhulenev /// ^resume: 211a6628e59SEugene Zhulenev /// %0 = async.runtime.load %value 212ace01605SRiver Riddle /// cf.br ^cleanup 213a6628e59SEugene Zhulenev /// ^cleanup: 214a6628e59SEugene Zhulenev /// ... 215a6628e59SEugene Zhulenev /// ^suspend: 216a6628e59SEugene Zhulenev /// ... 217a6628e59SEugene Zhulenev /// 218a6628e59SEugene Zhulenev /// Although cleanup and suspend blocks do not have the `value` in the 219a6628e59SEugene Zhulenev /// `liveIn` set, it is guaranteed that execution will eventually continue in 220a6628e59SEugene Zhulenev /// the resume block (we never explicitly destroy coroutines). 221c412979cSEugene Zhulenev LogicalResult addDropRefInDivergentLivenessSuccessor(Value value); 222a6628e59SEugene Zhulenev }; 223a6628e59SEugene Zhulenev 224a6628e59SEugene Zhulenev } // namespace 225a6628e59SEugene Zhulenev 226a6628e59SEugene Zhulenev LogicalResult AsyncRuntimeRefCountingPass::addDropRefAfterLastUse(Value value) { 227a6628e59SEugene Zhulenev OpBuilder builder(value.getContext()); 228a6628e59SEugene Zhulenev Location loc = value.getLoc(); 229a6628e59SEugene Zhulenev 230a6628e59SEugene Zhulenev // Use liveness analysis to find the placement of `drop_ref`operation. 231a6628e59SEugene Zhulenev auto &liveness = getAnalysis<Liveness>(); 232a6628e59SEugene Zhulenev 233a6628e59SEugene Zhulenev // We analyse only the blocks of the region that defines the `value`, and do 234a6628e59SEugene Zhulenev // not check nested blocks attached to operations. 235a6628e59SEugene Zhulenev // 236a6628e59SEugene Zhulenev // By analyzing only the `definingRegion` CFG we potentially loose an 237a6628e59SEugene Zhulenev // opportunity to drop the reference count earlier and can extend the lifetime 238a6628e59SEugene Zhulenev // of reference counted value longer then it is really required. 239a6628e59SEugene Zhulenev // 240a6628e59SEugene Zhulenev // We also assume that all nested regions finish their execution before the 241a6628e59SEugene Zhulenev // completion of the owner operation. The only exception to this rule is 242a6628e59SEugene Zhulenev // `async.execute` operation, and we verify that they are lowered to the 243a6628e59SEugene Zhulenev // `async.runtime` operations before adding automatic reference counting. 244a6628e59SEugene Zhulenev Region *definingRegion = value.getParentRegion(); 245a6628e59SEugene Zhulenev 246a6628e59SEugene Zhulenev // Last users of the `value` inside all blocks where the value dies. 247a6628e59SEugene Zhulenev llvm::SmallSet<Operation *, 4> lastUsers; 248a6628e59SEugene Zhulenev 249a6628e59SEugene Zhulenev // Find blocks in the `definingRegion` that have users of the `value` (if 250a6628e59SEugene Zhulenev // there are multiple users in the block, which one will be selected is 251a6628e59SEugene Zhulenev // undefined). User operation might be not the actual user of the value, but 252a6628e59SEugene Zhulenev // the operation in the block that has a "real user" in one of the attached 253a6628e59SEugene Zhulenev // regions. 254a6628e59SEugene Zhulenev llvm::DenseMap<Block *, Operation *> usersInTheBlocks; 255a6628e59SEugene Zhulenev 256a6628e59SEugene Zhulenev for (Operation *user : value.getUsers()) { 257a6628e59SEugene Zhulenev Block *userBlock = user->getBlock(); 258a6628e59SEugene Zhulenev Block *ancestor = definingRegion->findAncestorBlockInRegion(*userBlock); 259a6628e59SEugene Zhulenev usersInTheBlocks[ancestor] = ancestor->findAncestorOpInBlock(*user); 260a6628e59SEugene Zhulenev assert(ancestor && "ancestor block must be not null"); 261a6628e59SEugene Zhulenev assert(usersInTheBlocks[ancestor] && "ancestor op must be not null"); 262a6628e59SEugene Zhulenev } 263a6628e59SEugene Zhulenev 264a6628e59SEugene Zhulenev // Find blocks where the `value` dies: the value is in `liveIn` set and not 265a6628e59SEugene Zhulenev // in the `liveOut` set. We place `drop_ref` immediately after the last use 266a6628e59SEugene Zhulenev // of the `value` in such regions (after handling few special cases). 267a6628e59SEugene Zhulenev // 268a6628e59SEugene Zhulenev // We do not traverse all the blocks in the `definingRegion`, because the 269a6628e59SEugene Zhulenev // `value` can be in the live in set only if it has users in the block, or it 270a6628e59SEugene Zhulenev // is defined in the block. 271a6628e59SEugene Zhulenev // 272a6628e59SEugene Zhulenev // Values with zero users (only definition) handled explicitly above. 273a6628e59SEugene Zhulenev for (auto &blockAndUser : usersInTheBlocks) { 274a6628e59SEugene Zhulenev Block *block = blockAndUser.getFirst(); 275a6628e59SEugene Zhulenev Operation *userInTheBlock = blockAndUser.getSecond(); 276a6628e59SEugene Zhulenev 277a6628e59SEugene Zhulenev const LivenessBlockInfo *blockLiveness = liveness.getLiveness(block); 278a6628e59SEugene Zhulenev 279a6628e59SEugene Zhulenev // Value must be in the live input set or defined in the block. 280a6628e59SEugene Zhulenev assert(blockLiveness->isLiveIn(value) || 281a6628e59SEugene Zhulenev blockLiveness->getBlock() == value.getParentBlock()); 282a6628e59SEugene Zhulenev 283a6628e59SEugene Zhulenev // If value is in the live out set, it means it doesn't "die" in the block. 284a6628e59SEugene Zhulenev if (blockLiveness->isLiveOut(value)) 285a6628e59SEugene Zhulenev continue; 286a6628e59SEugene Zhulenev 287a6628e59SEugene Zhulenev // At this point we proved that `value` dies in the `block`. Find the last 288a6628e59SEugene Zhulenev // use of the `value` inside the `block`, this is where it "dies". 289a6628e59SEugene Zhulenev Operation *lastUser = blockLiveness->getEndOperation(value, userInTheBlock); 290a6628e59SEugene Zhulenev assert(lastUsers.count(lastUser) == 0 && "last users must be unique"); 291a6628e59SEugene Zhulenev lastUsers.insert(lastUser); 292a6628e59SEugene Zhulenev } 293a6628e59SEugene Zhulenev 294a6628e59SEugene Zhulenev // Process all the last users of the `value` inside each block where the value 295a6628e59SEugene Zhulenev // dies. 296a6628e59SEugene Zhulenev for (Operation *lastUser : lastUsers) { 297a6628e59SEugene Zhulenev // Return like operations forward reference count. 298a6628e59SEugene Zhulenev if (lastUser->hasTrait<OpTrait::ReturnLike>()) 299a6628e59SEugene Zhulenev continue; 300a6628e59SEugene Zhulenev 301a6628e59SEugene Zhulenev // We can't currently handle other types of terminators. 302a6628e59SEugene Zhulenev if (lastUser->hasTrait<OpTrait::IsTerminator>()) 303a6628e59SEugene Zhulenev return lastUser->emitError() << "async reference counting can't handle " 304a6628e59SEugene Zhulenev "terminators that are not ReturnLike"; 305a6628e59SEugene Zhulenev 306a6628e59SEugene Zhulenev // Add a drop_ref immediately after the last user. 307a6628e59SEugene Zhulenev builder.setInsertionPointAfter(lastUser); 30892db09cdSEugene Zhulenev builder.create<RuntimeDropRefOp>(loc, value, builder.getI64IntegerAttr(1)); 309a6628e59SEugene Zhulenev } 310a6628e59SEugene Zhulenev 311a6628e59SEugene Zhulenev return success(); 312a6628e59SEugene Zhulenev } 313a6628e59SEugene Zhulenev 314a6628e59SEugene Zhulenev LogicalResult 315a6628e59SEugene Zhulenev AsyncRuntimeRefCountingPass::addAddRefBeforeFunctionCall(Value value) { 316a6628e59SEugene Zhulenev OpBuilder builder(value.getContext()); 317a6628e59SEugene Zhulenev Location loc = value.getLoc(); 318a6628e59SEugene Zhulenev 319a6628e59SEugene Zhulenev for (Operation *user : value.getUsers()) { 32023aa5a74SRiver Riddle if (!isa<func::CallOp>(user)) 321a6628e59SEugene Zhulenev continue; 322a6628e59SEugene Zhulenev 323a6628e59SEugene Zhulenev // Add a reference before the function call to pass the value at `+1` 324a6628e59SEugene Zhulenev // reference to the function entry block. 325a6628e59SEugene Zhulenev builder.setInsertionPoint(user); 32692db09cdSEugene Zhulenev builder.create<RuntimeAddRefOp>(loc, value, builder.getI64IntegerAttr(1)); 327a6628e59SEugene Zhulenev } 328a6628e59SEugene Zhulenev 329a6628e59SEugene Zhulenev return success(); 330a6628e59SEugene Zhulenev } 331a6628e59SEugene Zhulenev 332c412979cSEugene Zhulenev LogicalResult 333c412979cSEugene Zhulenev AsyncRuntimeRefCountingPass::addDropRefInDivergentLivenessSuccessor( 334c412979cSEugene Zhulenev Value value) { 335c412979cSEugene Zhulenev using BlockSet = llvm::SmallPtrSet<Block *, 4>; 336c412979cSEugene Zhulenev 337a6628e59SEugene Zhulenev OpBuilder builder(value.getContext()); 338a6628e59SEugene Zhulenev 339c412979cSEugene Zhulenev // If a block has successors with different `liveIn` property of the `value`, 340c412979cSEugene Zhulenev // record block successors that do not thave the `value` in the `liveIn` set. 341c412979cSEugene Zhulenev llvm::SmallDenseMap<Block *, BlockSet> divergentLivenessBlocks; 342a6628e59SEugene Zhulenev 343a6628e59SEugene Zhulenev // Use liveness analysis to find the placement of `drop_ref`operation. 344a6628e59SEugene Zhulenev auto &liveness = getAnalysis<Liveness>(); 345a6628e59SEugene Zhulenev 346a6628e59SEugene Zhulenev // Because we only add `drop_ref` operations to the region that defines the 347a6628e59SEugene Zhulenev // `value` we can only process CFG for the same region. 348a6628e59SEugene Zhulenev Region *definingRegion = value.getParentRegion(); 349a6628e59SEugene Zhulenev 350a6628e59SEugene Zhulenev // Collect blocks with successors with mismatching `liveIn` sets. 351a6628e59SEugene Zhulenev for (Block &block : definingRegion->getBlocks()) { 352a6628e59SEugene Zhulenev const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block); 353a6628e59SEugene Zhulenev 354a6628e59SEugene Zhulenev // Skip the block if value is not in the `liveOut` set. 3559136b7d0SEugene Zhulenev if (!blockLiveness || !blockLiveness->isLiveOut(value)) 356a6628e59SEugene Zhulenev continue; 357a6628e59SEugene Zhulenev 358c412979cSEugene Zhulenev BlockSet liveInSuccessors; // `value` is in `liveIn` set 359c412979cSEugene Zhulenev BlockSet noLiveInSuccessors; // `value` is not in the `liveIn` set 360a6628e59SEugene Zhulenev 361a6628e59SEugene Zhulenev // Collect successors that do not have `value` in the `liveIn` set. 362a6628e59SEugene Zhulenev for (Block *successor : block.getSuccessors()) { 363a6628e59SEugene Zhulenev const LivenessBlockInfo *succLiveness = liveness.getLiveness(successor); 3649136b7d0SEugene Zhulenev if (succLiveness && succLiveness->isLiveIn(value)) 365a6628e59SEugene Zhulenev liveInSuccessors.insert(successor); 366a6628e59SEugene Zhulenev else 367a6628e59SEugene Zhulenev noLiveInSuccessors.insert(successor); 368a6628e59SEugene Zhulenev } 369a6628e59SEugene Zhulenev 370a6628e59SEugene Zhulenev // Block has successors with different `liveIn` property of the `value`. 371a6628e59SEugene Zhulenev if (!liveInSuccessors.empty() && !noLiveInSuccessors.empty()) 372c412979cSEugene Zhulenev divergentLivenessBlocks.try_emplace(&block, noLiveInSuccessors); 373a6628e59SEugene Zhulenev } 374a6628e59SEugene Zhulenev 375c412979cSEugene Zhulenev // Try to insert `dropRef` operations to handle blocks with divergent liveness 376c412979cSEugene Zhulenev // in successors blocks. 377c412979cSEugene Zhulenev for (auto kv : divergentLivenessBlocks) { 378c412979cSEugene Zhulenev Block *block = kv.getFirst(); 379c412979cSEugene Zhulenev BlockSet &successors = kv.getSecond(); 380c412979cSEugene Zhulenev 381c412979cSEugene Zhulenev // Coroutine suspension is a special case terminator for wich we do not 382c412979cSEugene Zhulenev // need to create additional reference counting (see details above). 383a6628e59SEugene Zhulenev Operation *terminator = block->getTerminator(); 384a6628e59SEugene Zhulenev if (isa<CoroSuspendOp>(terminator)) 385a6628e59SEugene Zhulenev continue; 386a6628e59SEugene Zhulenev 387c412979cSEugene Zhulenev // We only support successor blocks with empty block argument list. 388c412979cSEugene Zhulenev auto hasArgs = [](Block *block) { return !block->getArguments().empty(); }; 389c412979cSEugene Zhulenev if (llvm::any_of(successors, hasArgs)) 390c412979cSEugene Zhulenev return terminator->emitOpError() 391c412979cSEugene Zhulenev << "successor have different `liveIn` property of the reference " 392c412979cSEugene Zhulenev "counted value"; 393c412979cSEugene Zhulenev 394c412979cSEugene Zhulenev // Make sure that `dropRef` operation is called when branched into the 395c412979cSEugene Zhulenev // successor block without `value` in the `liveIn` set. 396c412979cSEugene Zhulenev for (Block *successor : successors) { 397c412979cSEugene Zhulenev // If successor has a unique predecessor, it is safe to create `dropRef` 398c412979cSEugene Zhulenev // operations directly in the successor block. 399c412979cSEugene Zhulenev // 400c412979cSEugene Zhulenev // Otherwise we need to create a special block for reference counting 401c412979cSEugene Zhulenev // operations, and branch from it to the original successor block. 402c412979cSEugene Zhulenev Block *refCountingBlock = nullptr; 403c412979cSEugene Zhulenev 404c412979cSEugene Zhulenev if (successor->getUniquePredecessor() == block) { 405c412979cSEugene Zhulenev refCountingBlock = successor; 406c412979cSEugene Zhulenev } else { 407c412979cSEugene Zhulenev refCountingBlock = &successor->getParent()->emplaceBlock(); 408c412979cSEugene Zhulenev refCountingBlock->moveBefore(successor); 409c412979cSEugene Zhulenev OpBuilder builder = OpBuilder::atBlockEnd(refCountingBlock); 410ace01605SRiver Riddle builder.create<cf::BranchOp>(value.getLoc(), successor); 411c412979cSEugene Zhulenev } 412c412979cSEugene Zhulenev 413c412979cSEugene Zhulenev OpBuilder builder = OpBuilder::atBlockBegin(refCountingBlock); 414c412979cSEugene Zhulenev builder.create<RuntimeDropRefOp>(value.getLoc(), value, 41592db09cdSEugene Zhulenev builder.getI64IntegerAttr(1)); 416c412979cSEugene Zhulenev 417c412979cSEugene Zhulenev // No need to update the terminator operation. 418c412979cSEugene Zhulenev if (successor == refCountingBlock) 419c412979cSEugene Zhulenev continue; 420c412979cSEugene Zhulenev 421c412979cSEugene Zhulenev // Update terminator `successor` block to `refCountingBlock`. 42289de9cc8SMehdi Amini for (const auto &pair : llvm::enumerate(terminator->getSuccessors())) 423c412979cSEugene Zhulenev if (pair.value() == successor) 424c412979cSEugene Zhulenev terminator->setSuccessor(refCountingBlock, pair.index()); 425c412979cSEugene Zhulenev } 426a6628e59SEugene Zhulenev } 427a6628e59SEugene Zhulenev 428a6628e59SEugene Zhulenev return success(); 429a6628e59SEugene Zhulenev } 430a6628e59SEugene Zhulenev 431a6628e59SEugene Zhulenev LogicalResult 432a6628e59SEugene Zhulenev AsyncRuntimeRefCountingPass::addAutomaticRefCounting(Value value) { 433f57b2420SEugene Zhulenev // Short-circuit reference counting for values without uses. 434f57b2420SEugene Zhulenev if (succeeded(dropRefIfNoUses(value))) 435a6628e59SEugene Zhulenev return success(); 436a6628e59SEugene Zhulenev 437a6628e59SEugene Zhulenev // Add `drop_ref` operations based on the liveness analysis. 438a6628e59SEugene Zhulenev if (failed(addDropRefAfterLastUse(value))) 439a6628e59SEugene Zhulenev return failure(); 440a6628e59SEugene Zhulenev 441a6628e59SEugene Zhulenev // Add `add_ref` operations before function calls. 442a6628e59SEugene Zhulenev if (failed(addAddRefBeforeFunctionCall(value))) 443a6628e59SEugene Zhulenev return failure(); 444a6628e59SEugene Zhulenev 445c412979cSEugene Zhulenev // Add `drop_ref` operations to successors with divergent `value` liveness. 446c412979cSEugene Zhulenev if (failed(addDropRefInDivergentLivenessSuccessor(value))) 447a6628e59SEugene Zhulenev return failure(); 448a6628e59SEugene Zhulenev 449a6628e59SEugene Zhulenev return success(); 450a6628e59SEugene Zhulenev } 451a6628e59SEugene Zhulenev 4528a316b00SEugene Zhulenev void AsyncRuntimeRefCountingPass::runOnOperation() { 453f57b2420SEugene Zhulenev auto functor = [&](Value value) { return addAutomaticRefCounting(value); }; 454f57b2420SEugene Zhulenev if (failed(walkReferenceCountedValues(getOperation(), functor))) 455a6628e59SEugene Zhulenev signalPassFailure(); 456a6628e59SEugene Zhulenev } 457a6628e59SEugene Zhulenev 458f57b2420SEugene Zhulenev //===----------------------------------------------------------------------===// 459f57b2420SEugene Zhulenev // Reference counting based on the user defined policy. 460f57b2420SEugene Zhulenev //===----------------------------------------------------------------------===// 461f57b2420SEugene Zhulenev 462f57b2420SEugene Zhulenev namespace { 463f57b2420SEugene Zhulenev 464f57b2420SEugene Zhulenev class AsyncRuntimePolicyBasedRefCountingPass 465*039b969bSMichele Scuttari : public AsyncRuntimePolicyBasedRefCountingBase< 466f57b2420SEugene Zhulenev AsyncRuntimePolicyBasedRefCountingPass> { 467f57b2420SEugene Zhulenev public: 468f57b2420SEugene Zhulenev AsyncRuntimePolicyBasedRefCountingPass() { initializeDefaultPolicy(); } 469f57b2420SEugene Zhulenev 470f57b2420SEugene Zhulenev void runOnOperation() override; 471f57b2420SEugene Zhulenev 472f57b2420SEugene Zhulenev private: 473f57b2420SEugene Zhulenev // Adds a reference counting operations for all uses of the `value` according 474f57b2420SEugene Zhulenev // to the reference counting policy. 475f57b2420SEugene Zhulenev LogicalResult addRefCounting(Value value); 476f57b2420SEugene Zhulenev 477f57b2420SEugene Zhulenev void initializeDefaultPolicy(); 478f57b2420SEugene Zhulenev 479f57b2420SEugene Zhulenev llvm::SmallVector<std::function<FailureOr<int>(OpOperand &)>> policy; 480f57b2420SEugene Zhulenev }; 481f57b2420SEugene Zhulenev 482f57b2420SEugene Zhulenev } // namespace 483f57b2420SEugene Zhulenev 484f57b2420SEugene Zhulenev LogicalResult 485f57b2420SEugene Zhulenev AsyncRuntimePolicyBasedRefCountingPass::addRefCounting(Value value) { 486f57b2420SEugene Zhulenev // Short-circuit reference counting for values without uses. 487f57b2420SEugene Zhulenev if (succeeded(dropRefIfNoUses(value))) 488f57b2420SEugene Zhulenev return success(); 489f57b2420SEugene Zhulenev 490f57b2420SEugene Zhulenev OpBuilder b(value.getContext()); 491f57b2420SEugene Zhulenev 492f57b2420SEugene Zhulenev // Consult the user defined policy for every value use. 493f57b2420SEugene Zhulenev for (OpOperand &operand : value.getUses()) { 494f57b2420SEugene Zhulenev Location loc = operand.getOwner()->getLoc(); 495f57b2420SEugene Zhulenev 496f57b2420SEugene Zhulenev for (auto &func : policy) { 497f57b2420SEugene Zhulenev FailureOr<int> refCount = func(operand); 498f57b2420SEugene Zhulenev if (failed(refCount)) 499f57b2420SEugene Zhulenev return failure(); 500f57b2420SEugene Zhulenev 5016d5fc1e3SKazu Hirata int cnt = *refCount; 502f57b2420SEugene Zhulenev 503f57b2420SEugene Zhulenev // Create `add_ref` operation before the operand owner. 504f57b2420SEugene Zhulenev if (cnt > 0) { 505f57b2420SEugene Zhulenev b.setInsertionPoint(operand.getOwner()); 50692db09cdSEugene Zhulenev b.create<RuntimeAddRefOp>(loc, value, b.getI64IntegerAttr(cnt)); 507f57b2420SEugene Zhulenev } 508f57b2420SEugene Zhulenev 509f57b2420SEugene Zhulenev // Create `drop_ref` operation after the operand owner. 510f57b2420SEugene Zhulenev if (cnt < 0) { 511f57b2420SEugene Zhulenev b.setInsertionPointAfter(operand.getOwner()); 51292db09cdSEugene Zhulenev b.create<RuntimeDropRefOp>(loc, value, b.getI64IntegerAttr(-cnt)); 513f57b2420SEugene Zhulenev } 514f57b2420SEugene Zhulenev } 515f57b2420SEugene Zhulenev } 516f57b2420SEugene Zhulenev 517f57b2420SEugene Zhulenev return success(); 518f57b2420SEugene Zhulenev } 519f57b2420SEugene Zhulenev 520f57b2420SEugene Zhulenev void AsyncRuntimePolicyBasedRefCountingPass::initializeDefaultPolicy() { 521f57b2420SEugene Zhulenev policy.push_back([](OpOperand &operand) -> FailureOr<int> { 522f57b2420SEugene Zhulenev Operation *op = operand.getOwner(); 523f57b2420SEugene Zhulenev Type type = operand.get().getType(); 524f57b2420SEugene Zhulenev 525f57b2420SEugene Zhulenev bool isToken = type.isa<TokenType>(); 526f57b2420SEugene Zhulenev bool isGroup = type.isa<GroupType>(); 527f57b2420SEugene Zhulenev bool isValue = type.isa<ValueType>(); 528f57b2420SEugene Zhulenev 529f57b2420SEugene Zhulenev // Drop reference after async token or group error check (coro await). 530f57b2420SEugene Zhulenev if (auto await = dyn_cast<RuntimeIsErrorOp>(op)) 531f57b2420SEugene Zhulenev return (isToken || isGroup) ? -1 : 0; 532f57b2420SEugene Zhulenev 533f57b2420SEugene Zhulenev // Drop reference after async value load. 534f57b2420SEugene Zhulenev if (auto load = dyn_cast<RuntimeLoadOp>(op)) 535f57b2420SEugene Zhulenev return isValue ? -1 : 0; 536f57b2420SEugene Zhulenev 537f57b2420SEugene Zhulenev // Drop reference after async token added to the group. 538f57b2420SEugene Zhulenev if (auto add = dyn_cast<RuntimeAddToGroupOp>(op)) 539f57b2420SEugene Zhulenev return isToken ? -1 : 0; 540f57b2420SEugene Zhulenev 541f57b2420SEugene Zhulenev return 0; 542f57b2420SEugene Zhulenev }); 543f57b2420SEugene Zhulenev } 544f57b2420SEugene Zhulenev 545f57b2420SEugene Zhulenev void AsyncRuntimePolicyBasedRefCountingPass::runOnOperation() { 546f57b2420SEugene Zhulenev auto functor = [&](Value value) { return addRefCounting(value); }; 547f57b2420SEugene Zhulenev if (failed(walkReferenceCountedValues(getOperation(), functor))) 548f57b2420SEugene Zhulenev signalPassFailure(); 549f57b2420SEugene Zhulenev } 550f57b2420SEugene Zhulenev 551f57b2420SEugene Zhulenev //----------------------------------------------------------------------------// 552f57b2420SEugene Zhulenev 5538a316b00SEugene Zhulenev std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingPass() { 554a6628e59SEugene Zhulenev return std::make_unique<AsyncRuntimeRefCountingPass>(); 555a6628e59SEugene Zhulenev } 556f57b2420SEugene Zhulenev 557f57b2420SEugene Zhulenev std::unique_ptr<Pass> mlir::createAsyncRuntimePolicyBasedRefCountingPass() { 558f57b2420SEugene Zhulenev return std::make_unique<AsyncRuntimePolicyBasedRefCountingPass>(); 559f57b2420SEugene Zhulenev } 560