1 //===- CoroElide.cpp - Coroutine Frame Allocation Elision Pass ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Transforms/Coroutines/CoroElide.h" 10 #include "CoroInternal.h" 11 #include "llvm/ADT/DenseMap.h" 12 #include "llvm/ADT/Statistic.h" 13 #include "llvm/Analysis/AliasAnalysis.h" 14 #include "llvm/Analysis/InstructionSimplify.h" 15 #include "llvm/Analysis/OptimizationRemarkEmitter.h" 16 #include "llvm/IR/Dominators.h" 17 #include "llvm/IR/InstIterator.h" 18 #include "llvm/Support/ErrorHandling.h" 19 #include "llvm/Support/FileSystem.h" 20 #include <optional> 21 22 using namespace llvm; 23 24 #define DEBUG_TYPE "coro-elide" 25 26 STATISTIC(NumOfCoroElided, "The # of coroutine get elided."); 27 28 #ifndef NDEBUG 29 static cl::opt<std::string> CoroElideInfoOutputFilename( 30 "coro-elide-info-output-file", cl::value_desc("filename"), 31 cl::desc("File to record the coroutines got elided"), cl::Hidden); 32 #endif 33 34 namespace { 35 // Created on demand if the coro-elide pass has work to do. 36 class FunctionElideInfo { 37 public: 38 FunctionElideInfo(Function *F) : ContainingFunction(F) { 39 this->collectPostSplitCoroIds(); 40 } 41 42 bool hasCoroIds() const { return !CoroIds.empty(); } 43 44 const SmallVectorImpl<CoroIdInst *> &getCoroIds() const { return CoroIds; } 45 46 private: 47 Function *ContainingFunction; 48 SmallVector<CoroIdInst *, 4> CoroIds; 49 // Used in canCoroBeginEscape to distinguish coro.suspend switchs. 50 SmallPtrSet<const SwitchInst *, 4> CoroSuspendSwitches; 51 52 void collectPostSplitCoroIds(); 53 friend class CoroIdElider; 54 }; 55 56 class CoroIdElider { 57 public: 58 CoroIdElider(CoroIdInst *CoroId, FunctionElideInfo &FEI, AAResults &AA, 59 DominatorTree &DT, OptimizationRemarkEmitter &ORE); 60 void elideHeapAllocations(uint64_t FrameSize, Align FrameAlign); 61 bool lifetimeEligibleForElide() const; 62 bool attemptElide(); 63 bool canCoroBeginEscape(const CoroBeginInst *, 64 const SmallPtrSetImpl<BasicBlock *> &) const; 65 66 private: 67 CoroIdInst *CoroId; 68 FunctionElideInfo &FEI; 69 AAResults &AA; 70 DominatorTree &DT; 71 OptimizationRemarkEmitter &ORE; 72 73 SmallVector<CoroBeginInst *, 1> CoroBegins; 74 SmallVector<CoroAllocInst *, 1> CoroAllocs; 75 SmallVector<CoroSubFnInst *, 4> ResumeAddr; 76 DenseMap<CoroBeginInst *, SmallVector<CoroSubFnInst *, 4>> DestroyAddr; 77 }; 78 } // end anonymous namespace 79 80 // Go through the list of coro.subfn.addr intrinsics and replace them with the 81 // provided constant. 82 static void replaceWithConstant(Constant *Value, 83 SmallVectorImpl<CoroSubFnInst *> &Users) { 84 if (Users.empty()) 85 return; 86 87 // See if we need to bitcast the constant to match the type of the intrinsic 88 // being replaced. Note: All coro.subfn.addr intrinsics return the same type, 89 // so we only need to examine the type of the first one in the list. 90 Type *IntrTy = Users.front()->getType(); 91 Type *ValueTy = Value->getType(); 92 if (ValueTy != IntrTy) { 93 // May need to tweak the function type to match the type expected at the 94 // use site. 95 assert(ValueTy->isPointerTy() && IntrTy->isPointerTy()); 96 Value = ConstantExpr::getBitCast(Value, IntrTy); 97 } 98 99 // Now the value type matches the type of the intrinsic. Replace them all! 100 for (CoroSubFnInst *I : Users) 101 replaceAndRecursivelySimplify(I, Value); 102 } 103 104 // See if any operand of the call instruction references the coroutine frame. 105 static bool operandReferences(CallInst *CI, AllocaInst *Frame, AAResults &AA) { 106 for (Value *Op : CI->operand_values()) 107 if (!AA.isNoAlias(Op, Frame)) 108 return true; 109 return false; 110 } 111 112 // Look for any tail calls referencing the coroutine frame and remove tail 113 // attribute from them, since now coroutine frame resides on the stack and tail 114 // call implies that the function does not references anything on the stack. 115 // However if it's a musttail call, we cannot remove the tailcall attribute. 116 // It's safe to keep it there as the musttail call is for symmetric transfer, 117 // and by that point the frame should have been destroyed and hence not 118 // interfering with operands. 119 static void removeTailCallAttribute(AllocaInst *Frame, AAResults &AA) { 120 Function &F = *Frame->getFunction(); 121 for (Instruction &I : instructions(F)) 122 if (auto *Call = dyn_cast<CallInst>(&I)) 123 if (Call->isTailCall() && operandReferences(Call, Frame, AA) && 124 !Call->isMustTailCall()) 125 Call->setTailCall(false); 126 } 127 128 // Given a resume function @f.resume(%f.frame* %frame), returns the size 129 // and expected alignment of %f.frame type. 130 static std::optional<std::pair<uint64_t, Align>> 131 getFrameLayout(Function *Resume) { 132 // Pull information from the function attributes. 133 auto Size = Resume->getParamDereferenceableBytes(0); 134 if (!Size) 135 return std::nullopt; 136 return std::make_pair(Size, Resume->getParamAlign(0).valueOrOne()); 137 } 138 139 // Finds first non alloca instruction in the entry block of a function. 140 static Instruction *getFirstNonAllocaInTheEntryBlock(Function *F) { 141 for (Instruction &I : F->getEntryBlock()) 142 if (!isa<AllocaInst>(&I)) 143 return &I; 144 llvm_unreachable("no terminator in the entry block"); 145 } 146 147 #ifndef NDEBUG 148 static std::unique_ptr<raw_fd_ostream> getOrCreateLogFile() { 149 assert(!CoroElideInfoOutputFilename.empty() && 150 "coro-elide-info-output-file shouldn't be empty"); 151 std::error_code EC; 152 auto Result = std::make_unique<raw_fd_ostream>(CoroElideInfoOutputFilename, 153 EC, sys::fs::OF_Append); 154 if (!EC) 155 return Result; 156 llvm::errs() << "Error opening coro-elide-info-output-file '" 157 << CoroElideInfoOutputFilename << " for appending!\n"; 158 return std::make_unique<raw_fd_ostream>(2, false); // stderr. 159 } 160 #endif 161 162 void FunctionElideInfo::collectPostSplitCoroIds() { 163 for (auto &I : instructions(this->ContainingFunction)) { 164 if (auto *CII = dyn_cast<CoroIdInst>(&I)) 165 if (CII->getInfo().isPostSplit()) 166 // If it is the coroutine itself, don't touch it. 167 if (CII->getCoroutine() != CII->getFunction()) 168 CoroIds.push_back(CII); 169 170 // Consider case like: 171 // %0 = call i8 @llvm.coro.suspend(...) 172 // switch i8 %0, label %suspend [i8 0, label %resume 173 // i8 1, label %cleanup] 174 // and collect the SwitchInsts which are used by escape analysis later. 175 if (auto *CSI = dyn_cast<CoroSuspendInst>(&I)) 176 if (CSI->hasOneUse() && isa<SwitchInst>(CSI->use_begin()->getUser())) { 177 SwitchInst *SWI = cast<SwitchInst>(CSI->use_begin()->getUser()); 178 if (SWI->getNumCases() == 2) 179 CoroSuspendSwitches.insert(SWI); 180 } 181 } 182 } 183 184 CoroIdElider::CoroIdElider(CoroIdInst *CoroId, FunctionElideInfo &FEI, 185 AAResults &AA, DominatorTree &DT, 186 OptimizationRemarkEmitter &ORE) 187 : CoroId(CoroId), FEI(FEI), AA(AA), DT(DT), ORE(ORE) { 188 // Collect all coro.begin and coro.allocs associated with this coro.id. 189 for (User *U : CoroId->users()) { 190 if (auto *CB = dyn_cast<CoroBeginInst>(U)) 191 CoroBegins.push_back(CB); 192 else if (auto *CA = dyn_cast<CoroAllocInst>(U)) 193 CoroAllocs.push_back(CA); 194 } 195 196 // Collect all coro.subfn.addrs associated with coro.begin. 197 // Note, we only devirtualize the calls if their coro.subfn.addr refers to 198 // coro.begin directly. If we run into cases where this check is too 199 // conservative, we can consider relaxing the check. 200 for (CoroBeginInst *CB : CoroBegins) { 201 for (User *U : CB->users()) 202 if (auto *II = dyn_cast<CoroSubFnInst>(U)) 203 switch (II->getIndex()) { 204 case CoroSubFnInst::ResumeIndex: 205 ResumeAddr.push_back(II); 206 break; 207 case CoroSubFnInst::DestroyIndex: 208 DestroyAddr[CB].push_back(II); 209 break; 210 default: 211 llvm_unreachable("unexpected coro.subfn.addr constant"); 212 } 213 } 214 } 215 216 // To elide heap allocations we need to suppress code blocks guarded by 217 // llvm.coro.alloc and llvm.coro.free instructions. 218 void CoroIdElider::elideHeapAllocations(uint64_t FrameSize, Align FrameAlign) { 219 LLVMContext &C = FEI.ContainingFunction->getContext(); 220 BasicBlock::iterator InsertPt = 221 getFirstNonAllocaInTheEntryBlock(FEI.ContainingFunction)->getIterator(); 222 223 // Replacing llvm.coro.alloc with false will suppress dynamic 224 // allocation as it is expected for the frontend to generate the code that 225 // looks like: 226 // id = coro.id(...) 227 // mem = coro.alloc(id) ? malloc(coro.size()) : 0; 228 // coro.begin(id, mem) 229 auto *False = ConstantInt::getFalse(C); 230 for (auto *CA : CoroAllocs) { 231 CA->replaceAllUsesWith(False); 232 CA->eraseFromParent(); 233 } 234 235 // FIXME: Design how to transmit alignment information for every alloca that 236 // is spilled into the coroutine frame and recreate the alignment information 237 // here. Possibly we will need to do a mini SROA here and break the coroutine 238 // frame into individual AllocaInst recreating the original alignment. 239 const DataLayout &DL = FEI.ContainingFunction->getDataLayout(); 240 auto FrameTy = ArrayType::get(Type::getInt8Ty(C), FrameSize); 241 auto *Frame = new AllocaInst(FrameTy, DL.getAllocaAddrSpace(), "", InsertPt); 242 Frame->setAlignment(FrameAlign); 243 auto *FrameVoidPtr = 244 new BitCastInst(Frame, PointerType::getUnqual(C), "vFrame", InsertPt); 245 246 for (auto *CB : CoroBegins) { 247 CB->replaceAllUsesWith(FrameVoidPtr); 248 CB->eraseFromParent(); 249 } 250 251 // Since now coroutine frame lives on the stack we need to make sure that 252 // any tail call referencing it, must be made non-tail call. 253 removeTailCallAttribute(Frame, AA); 254 } 255 256 bool CoroIdElider::canCoroBeginEscape( 257 const CoroBeginInst *CB, const SmallPtrSetImpl<BasicBlock *> &TIs) const { 258 const auto &It = DestroyAddr.find(CB); 259 assert(It != DestroyAddr.end()); 260 261 // Limit the number of blocks we visit. 262 unsigned Limit = 32 * (1 + It->second.size()); 263 264 SmallVector<const BasicBlock *, 32> Worklist; 265 Worklist.push_back(CB->getParent()); 266 267 SmallPtrSet<const BasicBlock *, 32> Visited; 268 // Consider basicblock of coro.destroy as visited one, so that we 269 // skip the path pass through coro.destroy. 270 for (auto *DA : It->second) 271 Visited.insert(DA->getParent()); 272 273 SmallPtrSet<const BasicBlock *, 32> EscapingBBs; 274 for (auto *U : CB->users()) { 275 // The use from coroutine intrinsics are not a problem. 276 if (isa<CoroFreeInst, CoroSubFnInst, CoroSaveInst>(U)) 277 continue; 278 279 // Think all other usages may be an escaping candidate conservatively. 280 // 281 // Note that the major user of switch ABI coroutine (the C++) will store 282 // resume.fn, destroy.fn and the index to the coroutine frame immediately. 283 // So the parent of the coro.begin in C++ will be always escaping. 284 // Then we can't get any performance benefits for C++ by improving the 285 // precision of the method. 286 // 287 // The reason why we still judge it is we want to make LLVM Coroutine in 288 // switch ABIs to be self contained as much as possible instead of a 289 // by-product of C++20 Coroutines. 290 EscapingBBs.insert(cast<Instruction>(U)->getParent()); 291 } 292 293 bool PotentiallyEscaped = false; 294 295 do { 296 const auto *BB = Worklist.pop_back_val(); 297 if (!Visited.insert(BB).second) 298 continue; 299 300 // A Path insensitive marker to test whether the coro.begin escapes. 301 // It is intentional to make it path insensitive while it may not be 302 // precise since we don't want the process to be too slow. 303 PotentiallyEscaped |= EscapingBBs.count(BB); 304 305 if (TIs.count(BB)) { 306 if (isa<ReturnInst>(BB->getTerminator()) || PotentiallyEscaped) 307 return true; 308 309 // If the function ends with the exceptional terminator, the memory used 310 // by the coroutine frame can be released by stack unwinding 311 // automatically. So we can think the coro.begin doesn't escape if it 312 // exits the function by exceptional terminator. 313 314 continue; 315 } 316 317 // Conservatively say that there is potentially a path. 318 if (!--Limit) 319 return true; 320 321 auto TI = BB->getTerminator(); 322 // Although the default dest of coro.suspend switches is suspend pointer 323 // which means a escape path to normal terminator, it is reasonable to skip 324 // it since coroutine frame doesn't change outside the coroutine body. 325 if (isa<SwitchInst>(TI) && 326 FEI.CoroSuspendSwitches.count(cast<SwitchInst>(TI))) { 327 Worklist.push_back(cast<SwitchInst>(TI)->getSuccessor(1)); 328 Worklist.push_back(cast<SwitchInst>(TI)->getSuccessor(2)); 329 } else 330 Worklist.append(succ_begin(BB), succ_end(BB)); 331 332 } while (!Worklist.empty()); 333 334 // We have exhausted all possible paths and are certain that coro.begin can 335 // not reach to any of terminators. 336 return false; 337 } 338 339 bool CoroIdElider::lifetimeEligibleForElide() const { 340 // If no CoroAllocs, we cannot suppress allocation, so elision is not 341 // possible. 342 if (CoroAllocs.empty()) 343 return false; 344 345 // Check that for every coro.begin there is at least one coro.destroy directly 346 // referencing the SSA value of that coro.begin along each 347 // non-exceptional path. 348 // 349 // If the value escaped, then coro.destroy would have been referencing a 350 // memory location storing that value and not the virtual register. 351 352 SmallPtrSet<BasicBlock *, 8> Terminators; 353 // First gather all of the terminators for the function. 354 // Consider the final coro.suspend as the real terminator when the current 355 // function is a coroutine. 356 for (BasicBlock &B : *FEI.ContainingFunction) { 357 auto *TI = B.getTerminator(); 358 359 if (TI->getNumSuccessors() != 0 || isa<UnreachableInst>(TI)) 360 continue; 361 362 Terminators.insert(&B); 363 } 364 365 // Filter out the coro.destroy that lie along exceptional paths. 366 for (const auto *CB : CoroBegins) { 367 auto It = DestroyAddr.find(CB); 368 369 // FIXME: If we have not found any destroys for this coro.begin, we 370 // disqualify this elide. 371 if (It == DestroyAddr.end()) 372 return false; 373 374 const auto &CorrespondingDestroyAddrs = It->second; 375 376 // If every terminators is dominated by coro.destroy, we could know the 377 // corresponding coro.begin wouldn't escape. 378 auto DominatesTerminator = [&](auto *TI) { 379 return llvm::any_of(CorrespondingDestroyAddrs, [&](auto *Destroy) { 380 return DT.dominates(Destroy, TI->getTerminator()); 381 }); 382 }; 383 384 if (llvm::all_of(Terminators, DominatesTerminator)) 385 continue; 386 387 // Otherwise canCoroBeginEscape would decide whether there is any paths from 388 // coro.begin to Terminators which not pass through any of the 389 // coro.destroys. This is a slower analysis. 390 // 391 // canCoroBeginEscape is relatively slow, so we avoid to run it as much as 392 // possible. 393 if (canCoroBeginEscape(CB, Terminators)) 394 return false; 395 } 396 397 // We have checked all CoroBegins and their paths to the terminators without 398 // finding disqualifying code patterns, so we can perform heap allocations. 399 return true; 400 } 401 402 bool CoroIdElider::attemptElide() { 403 // PostSplit coro.id refers to an array of subfunctions in its Info 404 // argument. 405 ConstantArray *Resumers = CoroId->getInfo().Resumers; 406 assert(Resumers && "PostSplit coro.id Info argument must refer to an array" 407 "of coroutine subfunctions"); 408 auto *ResumeAddrConstant = 409 Resumers->getAggregateElement(CoroSubFnInst::ResumeIndex); 410 411 replaceWithConstant(ResumeAddrConstant, ResumeAddr); 412 413 bool EligibleForElide = lifetimeEligibleForElide(); 414 415 auto *DestroyAddrConstant = Resumers->getAggregateElement( 416 EligibleForElide ? CoroSubFnInst::CleanupIndex 417 : CoroSubFnInst::DestroyIndex); 418 419 for (auto &It : DestroyAddr) 420 replaceWithConstant(DestroyAddrConstant, It.second); 421 422 auto FrameSizeAndAlign = getFrameLayout(cast<Function>(ResumeAddrConstant)); 423 424 auto CallerFunctionName = FEI.ContainingFunction->getName(); 425 auto CalleeCoroutineName = CoroId->getCoroutine()->getName(); 426 427 if (EligibleForElide && FrameSizeAndAlign) { 428 elideHeapAllocations(FrameSizeAndAlign->first, FrameSizeAndAlign->second); 429 coro::replaceCoroFree(CoroId, /*Elide=*/true); 430 NumOfCoroElided++; 431 432 #ifndef NDEBUG 433 if (!CoroElideInfoOutputFilename.empty()) 434 *getOrCreateLogFile() << "Elide " << CalleeCoroutineName << " in " 435 << FEI.ContainingFunction->getName() << "\n"; 436 #endif 437 438 ORE.emit([&]() { 439 return OptimizationRemark(DEBUG_TYPE, "CoroElide", CoroId) 440 << "'" << ore::NV("callee", CalleeCoroutineName) 441 << "' elided in '" << ore::NV("caller", CallerFunctionName) 442 << "' (frame_size=" 443 << ore::NV("frame_size", FrameSizeAndAlign->first) << ", align=" 444 << ore::NV("align", FrameSizeAndAlign->second.value()) << ")"; 445 }); 446 } else { 447 ORE.emit([&]() { 448 auto Remark = OptimizationRemarkMissed(DEBUG_TYPE, "CoroElide", CoroId) 449 << "'" << ore::NV("callee", CalleeCoroutineName) 450 << "' not elided in '" 451 << ore::NV("caller", CallerFunctionName); 452 453 if (FrameSizeAndAlign) 454 return Remark << "' (frame_size=" 455 << ore::NV("frame_size", FrameSizeAndAlign->first) 456 << ", align=" 457 << ore::NV("align", FrameSizeAndAlign->second.value()) 458 << ")"; 459 else 460 return Remark << "' (frame_size=unknown, align=unknown)"; 461 }); 462 } 463 464 return true; 465 } 466 467 PreservedAnalyses CoroElidePass::run(Function &F, FunctionAnalysisManager &AM) { 468 auto &M = *F.getParent(); 469 if (!coro::declaresIntrinsics(M, {"llvm.coro.id"})) 470 return PreservedAnalyses::all(); 471 472 FunctionElideInfo FEI{&F}; 473 // Elide is not necessary if there's no coro.id within the function. 474 if (!FEI.hasCoroIds()) 475 return PreservedAnalyses::all(); 476 477 AAResults &AA = AM.getResult<AAManager>(F); 478 DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F); 479 auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F); 480 481 bool Changed = false; 482 for (auto *CII : FEI.getCoroIds()) { 483 CoroIdElider CIE(CII, FEI, AA, DT, ORE); 484 Changed |= CIE.attemptElide(); 485 } 486 487 return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); 488 } 489