1 //===- CoroAnnotationElide.cpp - Elide attributed safe coroutine calls ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // \file 10 // This pass transforms all Call or Invoke instructions that are annotated 11 // "coro_elide_safe" to call the `.noalloc` variant of coroutine instead. 12 // The frame of the callee coroutine is allocated inside the caller. A pointer 13 // to the allocated frame will be passed into the `.noalloc` ramp function. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "llvm/Transforms/Coroutines/CoroAnnotationElide.h" 18 19 #include "llvm/Analysis/CGSCCPassManager.h" 20 #include "llvm/Analysis/LazyCallGraph.h" 21 #include "llvm/Analysis/OptimizationRemarkEmitter.h" 22 #include "llvm/IR/Analysis.h" 23 #include "llvm/IR/IRBuilder.h" 24 #include "llvm/IR/Instruction.h" 25 #include "llvm/IR/Module.h" 26 #include "llvm/IR/PassManager.h" 27 #include "llvm/Transforms/Utils/CallGraphUpdater.h" 28 #include "llvm/Transforms/Utils/Cloning.h" 29 30 #include <cassert> 31 32 using namespace llvm; 33 34 #define DEBUG_TYPE "coro-annotation-elide" 35 36 static Instruction *getFirstNonAllocaInTheEntryBlock(Function *F) { 37 for (Instruction &I : F->getEntryBlock()) 38 if (!isa<AllocaInst>(&I)) 39 return &I; 40 llvm_unreachable("no terminator in the entry block"); 41 } 42 43 // Create an alloca in the caller, using FrameSize and FrameAlign as the callee 44 // coroutine's activation frame. 45 static Value *allocateFrameInCaller(Function *Caller, uint64_t FrameSize, 46 Align FrameAlign) { 47 LLVMContext &C = Caller->getContext(); 48 BasicBlock::iterator InsertPt = 49 getFirstNonAllocaInTheEntryBlock(Caller)->getIterator(); 50 const DataLayout &DL = Caller->getDataLayout(); 51 auto FrameTy = ArrayType::get(Type::getInt8Ty(C), FrameSize); 52 auto *Frame = new AllocaInst(FrameTy, DL.getAllocaAddrSpace(), "", InsertPt); 53 Frame->setAlignment(FrameAlign); 54 return Frame; 55 } 56 57 // Given a call or invoke instruction to the elide safe coroutine, this function 58 // does the following: 59 // - Allocate a frame for the callee coroutine in the caller using alloca. 60 // - Replace the old CB with a new Call or Invoke to `NewCallee`, with the 61 // pointer to the frame as an additional argument to NewCallee. 62 static void processCall(CallBase *CB, Function *Caller, Function *NewCallee, 63 uint64_t FrameSize, Align FrameAlign) { 64 // TODO: generate the lifetime intrinsics for the new frame. This will require 65 // introduction of two pesudo lifetime intrinsics in the frontend around the 66 // `co_await` expression and convert them to real lifetime intrinsics here. 67 auto *FramePtr = allocateFrameInCaller(Caller, FrameSize, FrameAlign); 68 auto NewCBInsertPt = CB->getIterator(); 69 llvm::CallBase *NewCB = nullptr; 70 SmallVector<Value *, 4> NewArgs; 71 NewArgs.append(CB->arg_begin(), CB->arg_end()); 72 NewArgs.push_back(FramePtr); 73 74 if (auto *CI = dyn_cast<CallInst>(CB)) { 75 auto *NewCI = CallInst::Create(NewCallee->getFunctionType(), NewCallee, 76 NewArgs, "", NewCBInsertPt); 77 NewCI->setTailCallKind(CI->getTailCallKind()); 78 NewCB = NewCI; 79 } else if (auto *II = dyn_cast<InvokeInst>(CB)) { 80 NewCB = InvokeInst::Create(NewCallee->getFunctionType(), NewCallee, 81 II->getNormalDest(), II->getUnwindDest(), 82 NewArgs, {}, "", NewCBInsertPt); 83 } else { 84 llvm_unreachable("CallBase should either be Call or Invoke!"); 85 } 86 87 NewCB->setCalledFunction(NewCallee->getFunctionType(), NewCallee); 88 NewCB->setCallingConv(CB->getCallingConv()); 89 NewCB->setAttributes(CB->getAttributes()); 90 NewCB->setDebugLoc(CB->getDebugLoc()); 91 std::copy(CB->bundle_op_info_begin(), CB->bundle_op_info_end(), 92 NewCB->bundle_op_info_begin()); 93 94 NewCB->removeFnAttr(llvm::Attribute::CoroElideSafe); 95 CB->replaceAllUsesWith(NewCB); 96 97 InlineFunctionInfo IFI; 98 InlineResult IR = InlineFunction(*NewCB, IFI); 99 if (IR.isSuccess()) { 100 CB->eraseFromParent(); 101 } else { 102 NewCB->replaceAllUsesWith(CB); 103 NewCB->eraseFromParent(); 104 } 105 } 106 107 PreservedAnalyses CoroAnnotationElidePass::run(LazyCallGraph::SCC &C, 108 CGSCCAnalysisManager &AM, 109 LazyCallGraph &CG, 110 CGSCCUpdateResult &UR) { 111 bool Changed = false; 112 CallGraphUpdater CGUpdater; 113 CGUpdater.initialize(CG, C, AM, UR); 114 115 auto &FAM = 116 AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager(); 117 118 for (LazyCallGraph::Node &N : C) { 119 Function *Callee = &N.getFunction(); 120 Function *NewCallee = Callee->getParent()->getFunction( 121 (Callee->getName() + ".noalloc").str()); 122 if (!NewCallee) 123 continue; 124 125 SmallVector<CallBase *, 4> Users; 126 for (auto *U : Callee->users()) { 127 if (auto *CB = dyn_cast<CallBase>(U)) { 128 if (CB->getCalledFunction() == Callee) 129 Users.push_back(CB); 130 } 131 } 132 auto FramePtrArgPosition = NewCallee->arg_size() - 1; 133 auto FrameSize = 134 NewCallee->getParamDereferenceableBytes(FramePtrArgPosition); 135 auto FrameAlign = 136 NewCallee->getParamAlign(FramePtrArgPosition).valueOrOne(); 137 138 auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(*Callee); 139 140 for (auto *CB : Users) { 141 auto *Caller = CB->getFunction(); 142 if (!Caller) 143 continue; 144 145 bool IsCallerPresplitCoroutine = Caller->isPresplitCoroutine(); 146 bool HasAttr = CB->hasFnAttr(llvm::Attribute::CoroElideSafe); 147 if (IsCallerPresplitCoroutine && HasAttr) { 148 auto *CallerN = CG.lookup(*Caller); 149 auto *CallerC = CallerN ? CG.lookupSCC(*CallerN) : nullptr; 150 // If CallerC is nullptr, it means LazyCallGraph hasn't visited Caller 151 // yet. Skip the call graph update. 152 auto ShouldUpdateCallGraph = !!CallerC; 153 processCall(CB, Caller, NewCallee, FrameSize, FrameAlign); 154 155 ORE.emit([&]() { 156 return OptimizationRemark(DEBUG_TYPE, "CoroAnnotationElide", Caller) 157 << "'" << ore::NV("callee", Callee->getName()) 158 << "' elided in '" << ore::NV("caller", Caller->getName()) 159 << "'"; 160 }); 161 162 FAM.invalidate(*Caller, PreservedAnalyses::none()); 163 Changed = true; 164 if (ShouldUpdateCallGraph) 165 updateCGAndAnalysisManagerForCGSCCPass(CG, *CallerC, *CallerN, AM, UR, 166 FAM); 167 168 } else { 169 ORE.emit([&]() { 170 return OptimizationRemarkMissed(DEBUG_TYPE, "CoroAnnotationElide", 171 Caller) 172 << "'" << ore::NV("callee", Callee->getName()) 173 << "' not elided in '" << ore::NV("caller", Caller->getName()) 174 << "' (caller_presplit=" 175 << ore::NV("caller_presplit", IsCallerPresplitCoroutine) 176 << ", elide_safe_attr=" << ore::NV("elide_safe_attr", HasAttr) 177 << ")"; 178 }); 179 } 180 } 181 } 182 183 return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); 184 } 185