//===- CoroAnnotationElide.cpp - Elide attributed safe coroutine calls ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file
// This pass transforms all Call or Invoke instructions that are annotated
// "coro_elide_safe" to call the `.noalloc` variant of coroutine instead.
// The frame of the callee coroutine is allocated inside the caller. A pointer
// to the allocated frame will be passed into the `.noalloc` ramp function.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Coroutines/CoroAnnotationElide.h"

#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/CGSCCPassManager.h"
#include "llvm/Analysis/LazyCallGraph.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/IR/Analysis.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Transforms/Utils/CallGraphUpdater.h"
#include "llvm/Transforms/Utils/Cloning.h"

#include <cassert>

using namespace llvm;
33a416267aSYuxuan Chen 34a416267aSYuxuan Chen #define DEBUG_TYPE "coro-annotation-elide" 35a416267aSYuxuan Chen 36a416267aSYuxuan Chen static Instruction *getFirstNonAllocaInTheEntryBlock(Function *F) { 37a416267aSYuxuan Chen for (Instruction &I : F->getEntryBlock()) 38a416267aSYuxuan Chen if (!isa<AllocaInst>(&I)) 39a416267aSYuxuan Chen return &I; 40a416267aSYuxuan Chen llvm_unreachable("no terminator in the entry block"); 41a416267aSYuxuan Chen } 42a416267aSYuxuan Chen 43a416267aSYuxuan Chen // Create an alloca in the caller, using FrameSize and FrameAlign as the callee 44a416267aSYuxuan Chen // coroutine's activation frame. 45a416267aSYuxuan Chen static Value *allocateFrameInCaller(Function *Caller, uint64_t FrameSize, 46a416267aSYuxuan Chen Align FrameAlign) { 47a416267aSYuxuan Chen LLVMContext &C = Caller->getContext(); 48a416267aSYuxuan Chen BasicBlock::iterator InsertPt = 49a416267aSYuxuan Chen getFirstNonAllocaInTheEntryBlock(Caller)->getIterator(); 50a416267aSYuxuan Chen const DataLayout &DL = Caller->getDataLayout(); 51a416267aSYuxuan Chen auto FrameTy = ArrayType::get(Type::getInt8Ty(C), FrameSize); 52a416267aSYuxuan Chen auto *Frame = new AllocaInst(FrameTy, DL.getAllocaAddrSpace(), "", InsertPt); 53a416267aSYuxuan Chen Frame->setAlignment(FrameAlign); 54a416267aSYuxuan Chen return Frame; 55a416267aSYuxuan Chen } 56a416267aSYuxuan Chen 57a416267aSYuxuan Chen // Given a call or invoke instruction to the elide safe coroutine, this function 58a416267aSYuxuan Chen // does the following: 59a416267aSYuxuan Chen // - Allocate a frame for the callee coroutine in the caller using alloca. 60a416267aSYuxuan Chen // - Replace the old CB with a new Call or Invoke to `NewCallee`, with the 61a416267aSYuxuan Chen // pointer to the frame as an additional argument to NewCallee. 
62a416267aSYuxuan Chen static void processCall(CallBase *CB, Function *Caller, Function *NewCallee, 63a416267aSYuxuan Chen uint64_t FrameSize, Align FrameAlign) { 64a416267aSYuxuan Chen // TODO: generate the lifetime intrinsics for the new frame. This will require 65a416267aSYuxuan Chen // introduction of two pesudo lifetime intrinsics in the frontend around the 66a416267aSYuxuan Chen // `co_await` expression and convert them to real lifetime intrinsics here. 67a416267aSYuxuan Chen auto *FramePtr = allocateFrameInCaller(Caller, FrameSize, FrameAlign); 68a416267aSYuxuan Chen auto NewCBInsertPt = CB->getIterator(); 69a416267aSYuxuan Chen llvm::CallBase *NewCB = nullptr; 70a416267aSYuxuan Chen SmallVector<Value *, 4> NewArgs; 71a416267aSYuxuan Chen NewArgs.append(CB->arg_begin(), CB->arg_end()); 72a416267aSYuxuan Chen NewArgs.push_back(FramePtr); 73a416267aSYuxuan Chen 74a416267aSYuxuan Chen if (auto *CI = dyn_cast<CallInst>(CB)) { 75a416267aSYuxuan Chen auto *NewCI = CallInst::Create(NewCallee->getFunctionType(), NewCallee, 76a416267aSYuxuan Chen NewArgs, "", NewCBInsertPt); 77a416267aSYuxuan Chen NewCI->setTailCallKind(CI->getTailCallKind()); 78a416267aSYuxuan Chen NewCB = NewCI; 79a416267aSYuxuan Chen } else if (auto *II = dyn_cast<InvokeInst>(CB)) { 80a416267aSYuxuan Chen NewCB = InvokeInst::Create(NewCallee->getFunctionType(), NewCallee, 81a416267aSYuxuan Chen II->getNormalDest(), II->getUnwindDest(), 82e03f4271SJay Foad NewArgs, {}, "", NewCBInsertPt); 83a416267aSYuxuan Chen } else { 84a416267aSYuxuan Chen llvm_unreachable("CallBase should either be Call or Invoke!"); 85a416267aSYuxuan Chen } 86a416267aSYuxuan Chen 87a416267aSYuxuan Chen NewCB->setCalledFunction(NewCallee->getFunctionType(), NewCallee); 88a416267aSYuxuan Chen NewCB->setCallingConv(CB->getCallingConv()); 89a416267aSYuxuan Chen NewCB->setAttributes(CB->getAttributes()); 90a416267aSYuxuan Chen NewCB->setDebugLoc(CB->getDebugLoc()); 91a416267aSYuxuan Chen std::copy(CB->bundle_op_info_begin(), 
CB->bundle_op_info_end(), 92a416267aSYuxuan Chen NewCB->bundle_op_info_begin()); 93a416267aSYuxuan Chen 94a416267aSYuxuan Chen NewCB->removeFnAttr(llvm::Attribute::CoroElideSafe); 95a416267aSYuxuan Chen CB->replaceAllUsesWith(NewCB); 96c6414970SYuxuan Chen 97c6414970SYuxuan Chen InlineFunctionInfo IFI; 98c6414970SYuxuan Chen InlineResult IR = InlineFunction(*NewCB, IFI); 99c6414970SYuxuan Chen if (IR.isSuccess()) { 100a416267aSYuxuan Chen CB->eraseFromParent(); 101c6414970SYuxuan Chen } else { 102c6414970SYuxuan Chen NewCB->replaceAllUsesWith(CB); 103c6414970SYuxuan Chen NewCB->eraseFromParent(); 104c6414970SYuxuan Chen } 105a416267aSYuxuan Chen } 106a416267aSYuxuan Chen 107c6414970SYuxuan Chen PreservedAnalyses CoroAnnotationElidePass::run(LazyCallGraph::SCC &C, 108c6414970SYuxuan Chen CGSCCAnalysisManager &AM, 109c6414970SYuxuan Chen LazyCallGraph &CG, 110c6414970SYuxuan Chen CGSCCUpdateResult &UR) { 111a416267aSYuxuan Chen bool Changed = false; 112c6414970SYuxuan Chen CallGraphUpdater CGUpdater; 113c6414970SYuxuan Chen CGUpdater.initialize(CG, C, AM, UR); 114a416267aSYuxuan Chen 115c6414970SYuxuan Chen auto &FAM = 116c6414970SYuxuan Chen AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager(); 117b5cdb039SYuxuan Chen 118c6414970SYuxuan Chen for (LazyCallGraph::Node &N : C) { 119c6414970SYuxuan Chen Function *Callee = &N.getFunction(); 120c6414970SYuxuan Chen Function *NewCallee = Callee->getParent()->getFunction( 121c6414970SYuxuan Chen (Callee->getName() + ".noalloc").str()); 122b5cdb039SYuxuan Chen if (!NewCallee) 123c6414970SYuxuan Chen continue; 124a416267aSYuxuan Chen 125a416267aSYuxuan Chen SmallVector<CallBase *, 4> Users; 126c6414970SYuxuan Chen for (auto *U : Callee->users()) { 127a416267aSYuxuan Chen if (auto *CB = dyn_cast<CallBase>(U)) { 128c6414970SYuxuan Chen if (CB->getCalledFunction() == Callee) 129a416267aSYuxuan Chen Users.push_back(CB); 130a416267aSYuxuan Chen } 131a416267aSYuxuan Chen } 132c6414970SYuxuan Chen auto 
FramePtrArgPosition = NewCallee->arg_size() - 1; 133c6414970SYuxuan Chen auto FrameSize = 134c6414970SYuxuan Chen NewCallee->getParamDereferenceableBytes(FramePtrArgPosition); 135c6414970SYuxuan Chen auto FrameAlign = 136c6414970SYuxuan Chen NewCallee->getParamAlign(FramePtrArgPosition).valueOrOne(); 137a416267aSYuxuan Chen 138c6414970SYuxuan Chen auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(*Callee); 139a416267aSYuxuan Chen 140a416267aSYuxuan Chen for (auto *CB : Users) { 141a416267aSYuxuan Chen auto *Caller = CB->getFunction(); 142b5cdb039SYuxuan Chen if (!Caller) 143b5cdb039SYuxuan Chen continue; 144a416267aSYuxuan Chen 145b5cdb039SYuxuan Chen bool IsCallerPresplitCoroutine = Caller->isPresplitCoroutine(); 146b5cdb039SYuxuan Chen bool HasAttr = CB->hasFnAttr(llvm::Attribute::CoroElideSafe); 147b5cdb039SYuxuan Chen if (IsCallerPresplitCoroutine && HasAttr) { 148c6414970SYuxuan Chen auto *CallerN = CG.lookup(*Caller); 149*86fd4d4bSYuxuan Chen auto *CallerC = CallerN ? CG.lookupSCC(*CallerN) : nullptr; 150*86fd4d4bSYuxuan Chen // If CallerC is nullptr, it means LazyCallGraph hasn't visited Caller 151*86fd4d4bSYuxuan Chen // yet. Skip the call graph update. 
152*86fd4d4bSYuxuan Chen auto ShouldUpdateCallGraph = !!CallerC; 153a416267aSYuxuan Chen processCall(CB, Caller, NewCallee, FrameSize, FrameAlign); 154a416267aSYuxuan Chen 155a416267aSYuxuan Chen ORE.emit([&]() { 156a416267aSYuxuan Chen return OptimizationRemark(DEBUG_TYPE, "CoroAnnotationElide", Caller) 157c6414970SYuxuan Chen << "'" << ore::NV("callee", Callee->getName()) 158c6414970SYuxuan Chen << "' elided in '" << ore::NV("caller", Caller->getName()) 159c6414970SYuxuan Chen << "'"; 160a416267aSYuxuan Chen }); 161b5cdb039SYuxuan Chen 162761bf333SYuxuan Chen FAM.invalidate(*Caller, PreservedAnalyses::none()); 163a416267aSYuxuan Chen Changed = true; 164*86fd4d4bSYuxuan Chen if (ShouldUpdateCallGraph) 165c6414970SYuxuan Chen updateCGAndAnalysisManagerForCGSCCPass(CG, *CallerC, *CallerN, AM, UR, 166c6414970SYuxuan Chen FAM); 167c6414970SYuxuan Chen 168b5cdb039SYuxuan Chen } else { 169b5cdb039SYuxuan Chen ORE.emit([&]() { 170b5cdb039SYuxuan Chen return OptimizationRemarkMissed(DEBUG_TYPE, "CoroAnnotationElide", 171b5cdb039SYuxuan Chen Caller) 172c6414970SYuxuan Chen << "'" << ore::NV("callee", Callee->getName()) 173c6414970SYuxuan Chen << "' not elided in '" << ore::NV("caller", Caller->getName()) 174c6414970SYuxuan Chen << "' (caller_presplit=" 175b5cdb039SYuxuan Chen << ore::NV("caller_presplit", IsCallerPresplitCoroutine) 176b5cdb039SYuxuan Chen << ", elide_safe_attr=" << ore::NV("elide_safe_attr", HasAttr) 177b5cdb039SYuxuan Chen << ")"; 178b5cdb039SYuxuan Chen }); 179a416267aSYuxuan Chen } 180a416267aSYuxuan Chen } 181c6414970SYuxuan Chen } 182761bf333SYuxuan Chen 183a416267aSYuxuan Chen return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); 184a416267aSYuxuan Chen } 185