//===- AsyncRuntimeRefCounting.cpp - Async Runtime Ref Counting -----------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements automatic reference counting for Async runtime // operations and types. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Async/Passes.h" #include "mlir/Analysis/Liveness.h" #include "mlir/Dialect/Async/IR/Async.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/SmallSet.h" namespace mlir { #define GEN_PASS_DEF_ASYNCRUNTIMEREFCOUNTING #define GEN_PASS_DEF_ASYNCRUNTIMEPOLICYBASEDREFCOUNTING #include "mlir/Dialect/Async/Passes.h.inc" } // namespace mlir #define DEBUG_TYPE "async-runtime-ref-counting" using namespace mlir; using namespace mlir::async; //===----------------------------------------------------------------------===// // Utility functions shared by reference counting passes. //===----------------------------------------------------------------------===// // Drop the reference count immediately if the value has no uses. static LogicalResult dropRefIfNoUses(Value value, unsigned count = 1) { if (!value.getUses().empty()) return failure(); OpBuilder b(value.getContext()); // Set insertion point after the operation producing a value, or at the // beginning of the block if the value defined by the block argument. if (Operation *op = value.getDefiningOp()) b.setInsertionPointAfter(op); else b.setInsertionPointToStart(value.getParentBlock()); b.create(value.getLoc(), value, b.getI64IntegerAttr(1)); return success(); } // Calls `addRefCounting` for every reference counted value defined by the // operation `op` (block arguments and values defined in nested regions). static LogicalResult walkReferenceCountedValues( Operation *op, llvm::function_ref addRefCounting) { // Check that we do not have high level async operations in the IR because // otherwise reference counting will produce incorrect results after high // level async operations will be lowered to `async.runtime` WalkResult checkNoAsyncWalk = op->walk([&](Operation *op) -> WalkResult { if (!isa(op)) return WalkResult::advance(); return op->emitError() << "async operations must be lowered to async runtime operations"; }); if (checkNoAsyncWalk.wasInterrupted()) return failure(); // Add reference counting to block arguments. WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult { for (BlockArgument arg : block->getArguments()) if (isRefCounted(arg.getType())) if (failed(addRefCounting(arg))) return WalkResult::interrupt(); return WalkResult::advance(); }); if (blockWalk.wasInterrupted()) return failure(); // Add reference counting to operation results. WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult { for (unsigned i = 0; i < op->getNumResults(); ++i) if (isRefCounted(op->getResultTypes()[i])) if (failed(addRefCounting(op->getResult(i)))) return WalkResult::interrupt(); return WalkResult::advance(); }); if (opWalk.wasInterrupted()) return failure(); return success(); } //===----------------------------------------------------------------------===// // Automatic reference counting based on the liveness analysis. //===----------------------------------------------------------------------===// namespace { class AsyncRuntimeRefCountingPass : public impl::AsyncRuntimeRefCountingBase { public: AsyncRuntimeRefCountingPass() = default; void runOnOperation() override; private: /// Adds an automatic reference counting to the `value`. /// /// All values (token, group or value) are semantically created with a /// reference count of +1 and it is the responsibility of the async value user /// to place the `add_ref` and `drop_ref` operations to ensure that the value /// is destroyed after the last use. /// /// The function returns failure if it can't deduce the locations where /// to place the reference counting operations. /// /// Async values "semantically created" when: /// 1. Operation returns async result (e.g. `async.runtime.create`) /// 2. Async value passed in as a block argument (or function argument, /// because function arguments are just entry block arguments) /// /// Passing async value as a function argument (or block argument) does not /// really mean that a new async value is created, it only means that the /// caller of a function transfered ownership of `+1` reference to the callee. /// It is convenient to think that from the callee perspective async value was /// "created" with `+1` reference by the block argument. /// /// Automatic reference counting algorithm outline: /// /// #1 Insert `drop_ref` operations after last use of the `value`. /// #2 Insert `add_ref` operations before functions calls with reference /// counted `value` operand (newly created `+1` reference will be /// transferred to the callee). /// #3 Verify that divergent control flow does not lead to leaked reference /// counted objects. /// /// Async runtime reference counting optimization pass will optimize away /// some of the redundant `add_ref` and `drop_ref` operations inserted by this /// strategy (see `async-runtime-ref-counting-opt`). LogicalResult addAutomaticRefCounting(Value value); /// (#1) Adds the `drop_ref` operation after the last use of the `value` /// relying on the liveness analysis. /// /// If the `value` is in the block `liveIn` set and it is not in the block /// `liveOut` set, it means that it "dies" in the block. We find the last /// use of the value in such block and: /// /// 1. If the last user is a `ReturnLike` operation we do nothing, because /// it forwards the ownership to the caller. /// 2. Otherwise we add a `drop_ref` operation immediately after the last /// use. LogicalResult addDropRefAfterLastUse(Value value); /// (#2) Adds the `add_ref` operation before the function call taking `value` /// operand to ensure that the value passed to the function entry block /// has a `+1` reference count. LogicalResult addAddRefBeforeFunctionCall(Value value); /// (#3) Adds the `drop_ref` operation to account for successor blocks with /// divergent `liveIn` property: `value` is not in the `liveIn` set of all /// successor blocks. /// /// Example: /// /// ^entry: /// %token = async.runtime.create : !async.token /// cf.cond_br %cond, ^bb1, ^bb2 /// ^bb1: /// async.runtime.await %token /// async.runtime.drop_ref %token /// cf.br ^bb2 /// ^bb2: /// return /// /// In this example ^bb2 does not have `value` in the `liveIn` set, so we have /// to branch into a special "reference counting block" from the ^entry that /// will have a `drop_ref` operation, and then branch into the ^bb2. /// /// After transformation: /// /// ^entry: /// %token = async.runtime.create : !async.token /// cf.cond_br %cond, ^bb1, ^reference_counting /// ^bb1: /// async.runtime.await %token /// async.runtime.drop_ref %token /// cf.br ^bb2 /// ^reference_counting: /// async.runtime.drop_ref %token /// cf.br ^bb2 /// ^bb2: /// return /// /// An exception to this rule are blocks with `async.coro.suspend` terminator, /// because in Async to LLVM lowering it is guaranteed that the control flow /// will jump into the resume block, and then follow into the cleanup and /// suspend blocks. /// /// Example: /// /// ^entry(%value: !async.value): /// async.runtime.await_and_resume %value, %hdl : !async.value /// async.coro.suspend %ret, ^suspend, ^resume, ^cleanup /// ^resume: /// %0 = async.runtime.load %value /// cf.br ^cleanup /// ^cleanup: /// ... /// ^suspend: /// ... /// /// Although cleanup and suspend blocks do not have the `value` in the /// `liveIn` set, it is guaranteed that execution will eventually continue in /// the resume block (we never explicitly destroy coroutines). LogicalResult addDropRefInDivergentLivenessSuccessor(Value value); }; } // namespace LogicalResult AsyncRuntimeRefCountingPass::addDropRefAfterLastUse(Value value) { OpBuilder builder(value.getContext()); Location loc = value.getLoc(); // Use liveness analysis to find the placement of `drop_ref`operation. auto &liveness = getAnalysis(); // We analyse only the blocks of the region that defines the `value`, and do // not check nested blocks attached to operations. // // By analyzing only the `definingRegion` CFG we potentially loose an // opportunity to drop the reference count earlier and can extend the lifetime // of reference counted value longer then it is really required. // // We also assume that all nested regions finish their execution before the // completion of the owner operation. The only exception to this rule is // `async.execute` operation, and we verify that they are lowered to the // `async.runtime` operations before adding automatic reference counting. Region *definingRegion = value.getParentRegion(); // Last users of the `value` inside all blocks where the value dies. llvm::SmallSet lastUsers; // Find blocks in the `definingRegion` that have users of the `value` (if // there are multiple users in the block, which one will be selected is // undefined). User operation might be not the actual user of the value, but // the operation in the block that has a "real user" in one of the attached // regions. llvm::DenseMap usersInTheBlocks; for (Operation *user : value.getUsers()) { Block *userBlock = user->getBlock(); Block *ancestor = definingRegion->findAncestorBlockInRegion(*userBlock); usersInTheBlocks[ancestor] = ancestor->findAncestorOpInBlock(*user); assert(ancestor && "ancestor block must be not null"); assert(usersInTheBlocks[ancestor] && "ancestor op must be not null"); } // Find blocks where the `value` dies: the value is in `liveIn` set and not // in the `liveOut` set. We place `drop_ref` immediately after the last use // of the `value` in such regions (after handling few special cases). // // We do not traverse all the blocks in the `definingRegion`, because the // `value` can be in the live in set only if it has users in the block, or it // is defined in the block. // // Values with zero users (only definition) handled explicitly above. for (auto &blockAndUser : usersInTheBlocks) { Block *block = blockAndUser.getFirst(); Operation *userInTheBlock = blockAndUser.getSecond(); const LivenessBlockInfo *blockLiveness = liveness.getLiveness(block); // Value must be in the live input set or defined in the block. assert(blockLiveness->isLiveIn(value) || blockLiveness->getBlock() == value.getParentBlock()); // If value is in the live out set, it means it doesn't "die" in the block. if (blockLiveness->isLiveOut(value)) continue; // At this point we proved that `value` dies in the `block`. Find the last // use of the `value` inside the `block`, this is where it "dies". Operation *lastUser = blockLiveness->getEndOperation(value, userInTheBlock); assert(lastUsers.count(lastUser) == 0 && "last users must be unique"); lastUsers.insert(lastUser); } // Process all the last users of the `value` inside each block where the value // dies. for (Operation *lastUser : lastUsers) { // Return like operations forward reference count. if (lastUser->hasTrait()) continue; // We can't currently handle other types of terminators. if (lastUser->hasTrait()) return lastUser->emitError() << "async reference counting can't handle " "terminators that are not ReturnLike"; // Add a drop_ref immediately after the last user. builder.setInsertionPointAfter(lastUser); builder.create(loc, value, builder.getI64IntegerAttr(1)); } return success(); } LogicalResult AsyncRuntimeRefCountingPass::addAddRefBeforeFunctionCall(Value value) { OpBuilder builder(value.getContext()); Location loc = value.getLoc(); for (Operation *user : value.getUsers()) { if (!isa(user)) continue; // Add a reference before the function call to pass the value at `+1` // reference to the function entry block. builder.setInsertionPoint(user); builder.create(loc, value, builder.getI64IntegerAttr(1)); } return success(); } LogicalResult AsyncRuntimeRefCountingPass::addDropRefInDivergentLivenessSuccessor( Value value) { using BlockSet = llvm::SmallPtrSet; OpBuilder builder(value.getContext()); // If a block has successors with different `liveIn` property of the `value`, // record block successors that do not thave the `value` in the `liveIn` set. llvm::SmallDenseMap divergentLivenessBlocks; // Use liveness analysis to find the placement of `drop_ref`operation. auto &liveness = getAnalysis(); // Because we only add `drop_ref` operations to the region that defines the // `value` we can only process CFG for the same region. Region *definingRegion = value.getParentRegion(); // Collect blocks with successors with mismatching `liveIn` sets. for (Block &block : definingRegion->getBlocks()) { const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block); // Skip the block if value is not in the `liveOut` set. if (!blockLiveness || !blockLiveness->isLiveOut(value)) continue; BlockSet liveInSuccessors; // `value` is in `liveIn` set BlockSet noLiveInSuccessors; // `value` is not in the `liveIn` set // Collect successors that do not have `value` in the `liveIn` set. for (Block *successor : block.getSuccessors()) { const LivenessBlockInfo *succLiveness = liveness.getLiveness(successor); if (succLiveness && succLiveness->isLiveIn(value)) liveInSuccessors.insert(successor); else noLiveInSuccessors.insert(successor); } // Block has successors with different `liveIn` property of the `value`. if (!liveInSuccessors.empty() && !noLiveInSuccessors.empty()) divergentLivenessBlocks.try_emplace(&block, noLiveInSuccessors); } // Try to insert `dropRef` operations to handle blocks with divergent liveness // in successors blocks. for (auto kv : divergentLivenessBlocks) { Block *block = kv.getFirst(); BlockSet &successors = kv.getSecond(); // Coroutine suspension is a special case terminator for wich we do not // need to create additional reference counting (see details above). Operation *terminator = block->getTerminator(); if (isa(terminator)) continue; // We only support successor blocks with empty block argument list. auto hasArgs = [](Block *block) { return !block->getArguments().empty(); }; if (llvm::any_of(successors, hasArgs)) return terminator->emitOpError() << "successor have different `liveIn` property of the reference " "counted value"; // Make sure that `dropRef` operation is called when branched into the // successor block without `value` in the `liveIn` set. for (Block *successor : successors) { // If successor has a unique predecessor, it is safe to create `dropRef` // operations directly in the successor block. // // Otherwise we need to create a special block for reference counting // operations, and branch from it to the original successor block. Block *refCountingBlock = nullptr; if (successor->getUniquePredecessor() == block) { refCountingBlock = successor; } else { refCountingBlock = &successor->getParent()->emplaceBlock(); refCountingBlock->moveBefore(successor); OpBuilder builder = OpBuilder::atBlockEnd(refCountingBlock); builder.create(value.getLoc(), successor); } OpBuilder builder = OpBuilder::atBlockBegin(refCountingBlock); builder.create(value.getLoc(), value, builder.getI64IntegerAttr(1)); // No need to update the terminator operation. if (successor == refCountingBlock) continue; // Update terminator `successor` block to `refCountingBlock`. for (const auto &pair : llvm::enumerate(terminator->getSuccessors())) if (pair.value() == successor) terminator->setSuccessor(refCountingBlock, pair.index()); } } return success(); } LogicalResult AsyncRuntimeRefCountingPass::addAutomaticRefCounting(Value value) { // Short-circuit reference counting for values without uses. if (succeeded(dropRefIfNoUses(value))) return success(); // Add `drop_ref` operations based on the liveness analysis. if (failed(addDropRefAfterLastUse(value))) return failure(); // Add `add_ref` operations before function calls. if (failed(addAddRefBeforeFunctionCall(value))) return failure(); // Add `drop_ref` operations to successors with divergent `value` liveness. if (failed(addDropRefInDivergentLivenessSuccessor(value))) return failure(); return success(); } void AsyncRuntimeRefCountingPass::runOnOperation() { auto functor = [&](Value value) { return addAutomaticRefCounting(value); }; if (failed(walkReferenceCountedValues(getOperation(), functor))) signalPassFailure(); } //===----------------------------------------------------------------------===// // Reference counting based on the user defined policy. //===----------------------------------------------------------------------===// namespace { class AsyncRuntimePolicyBasedRefCountingPass : public impl::AsyncRuntimePolicyBasedRefCountingBase< AsyncRuntimePolicyBasedRefCountingPass> { public: AsyncRuntimePolicyBasedRefCountingPass() { initializeDefaultPolicy(); } void runOnOperation() override; private: // Adds a reference counting operations for all uses of the `value` according // to the reference counting policy. LogicalResult addRefCounting(Value value); void initializeDefaultPolicy(); llvm::SmallVector(OpOperand &)>> policy; }; } // namespace LogicalResult AsyncRuntimePolicyBasedRefCountingPass::addRefCounting(Value value) { // Short-circuit reference counting for values without uses. if (succeeded(dropRefIfNoUses(value))) return success(); OpBuilder b(value.getContext()); // Consult the user defined policy for every value use. for (OpOperand &operand : value.getUses()) { Location loc = operand.getOwner()->getLoc(); for (auto &func : policy) { FailureOr refCount = func(operand); if (failed(refCount)) return failure(); int cnt = *refCount; // Create `add_ref` operation before the operand owner. if (cnt > 0) { b.setInsertionPoint(operand.getOwner()); b.create(loc, value, b.getI64IntegerAttr(cnt)); } // Create `drop_ref` operation after the operand owner. if (cnt < 0) { b.setInsertionPointAfter(operand.getOwner()); b.create(loc, value, b.getI64IntegerAttr(-cnt)); } } } return success(); } void AsyncRuntimePolicyBasedRefCountingPass::initializeDefaultPolicy() { policy.push_back([](OpOperand &operand) -> FailureOr { Operation *op = operand.getOwner(); Type type = operand.get().getType(); bool isToken = isa(type); bool isGroup = isa(type); bool isValue = isa(type); // Drop reference after async token or group error check (coro await). if (dyn_cast(op)) return (isToken || isGroup) ? -1 : 0; // Drop reference after async value load. if (dyn_cast(op)) return isValue ? -1 : 0; // Drop reference after async token added to the group. if (dyn_cast(op)) return isToken ? -1 : 0; return 0; }); } void AsyncRuntimePolicyBasedRefCountingPass::runOnOperation() { auto functor = [&](Value value) { return addRefCounting(value); }; if (failed(walkReferenceCountedValues(getOperation(), functor))) signalPassFailure(); } //----------------------------------------------------------------------------// std::unique_ptr mlir::createAsyncRuntimeRefCountingPass() { return std::make_unique(); } std::unique_ptr mlir::createAsyncRuntimePolicyBasedRefCountingPass() { return std::make_unique(); }