xref: /llvm-project/llvm/lib/Transforms/Coroutines/CoroAnnotationElide.cpp (revision 86fd4d4b5b95d58844e521cf7319965eea7d8d0b)
1a416267aSYuxuan Chen //===- CoroAnnotationElide.cpp - Elide attributed safe coroutine calls ----===//
2a416267aSYuxuan Chen //
3a416267aSYuxuan Chen // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4a416267aSYuxuan Chen // See https://llvm.org/LICENSE.txt for license information.
5a416267aSYuxuan Chen // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6a416267aSYuxuan Chen //
7a416267aSYuxuan Chen //===----------------------------------------------------------------------===//
8a416267aSYuxuan Chen //
9a416267aSYuxuan Chen // \file
10a416267aSYuxuan Chen // This pass transforms all Call or Invoke instructions that are annotated
11a416267aSYuxuan Chen // "coro_elide_safe" to call the `.noalloc` variant of coroutine instead.
12a416267aSYuxuan Chen // The frame of the callee coroutine is allocated inside the caller. A pointer
13a416267aSYuxuan Chen // to the allocated frame will be passed into the `.noalloc` ramp function.
14a416267aSYuxuan Chen //
15a416267aSYuxuan Chen //===----------------------------------------------------------------------===//
16a416267aSYuxuan Chen 
17a416267aSYuxuan Chen #include "llvm/Transforms/Coroutines/CoroAnnotationElide.h"
18a416267aSYuxuan Chen 
19c6414970SYuxuan Chen #include "llvm/Analysis/CGSCCPassManager.h"
20a416267aSYuxuan Chen #include "llvm/Analysis/LazyCallGraph.h"
21a416267aSYuxuan Chen #include "llvm/Analysis/OptimizationRemarkEmitter.h"
22a416267aSYuxuan Chen #include "llvm/IR/Analysis.h"
23a416267aSYuxuan Chen #include "llvm/IR/IRBuilder.h"
24a416267aSYuxuan Chen #include "llvm/IR/Instruction.h"
25a416267aSYuxuan Chen #include "llvm/IR/Module.h"
26a416267aSYuxuan Chen #include "llvm/IR/PassManager.h"
27c6414970SYuxuan Chen #include "llvm/Transforms/Utils/CallGraphUpdater.h"
28c6414970SYuxuan Chen #include "llvm/Transforms/Utils/Cloning.h"
29a416267aSYuxuan Chen 
30a416267aSYuxuan Chen #include <cassert>
31a416267aSYuxuan Chen 
32a416267aSYuxuan Chen using namespace llvm;
33a416267aSYuxuan Chen 
34a416267aSYuxuan Chen #define DEBUG_TYPE "coro-annotation-elide"
35a416267aSYuxuan Chen 
36a416267aSYuxuan Chen static Instruction *getFirstNonAllocaInTheEntryBlock(Function *F) {
37a416267aSYuxuan Chen   for (Instruction &I : F->getEntryBlock())
38a416267aSYuxuan Chen     if (!isa<AllocaInst>(&I))
39a416267aSYuxuan Chen       return &I;
40a416267aSYuxuan Chen   llvm_unreachable("no terminator in the entry block");
41a416267aSYuxuan Chen }
42a416267aSYuxuan Chen 
43a416267aSYuxuan Chen // Create an alloca in the caller, using FrameSize and FrameAlign as the callee
44a416267aSYuxuan Chen // coroutine's activation frame.
45a416267aSYuxuan Chen static Value *allocateFrameInCaller(Function *Caller, uint64_t FrameSize,
46a416267aSYuxuan Chen                                     Align FrameAlign) {
47a416267aSYuxuan Chen   LLVMContext &C = Caller->getContext();
48a416267aSYuxuan Chen   BasicBlock::iterator InsertPt =
49a416267aSYuxuan Chen       getFirstNonAllocaInTheEntryBlock(Caller)->getIterator();
50a416267aSYuxuan Chen   const DataLayout &DL = Caller->getDataLayout();
51a416267aSYuxuan Chen   auto FrameTy = ArrayType::get(Type::getInt8Ty(C), FrameSize);
52a416267aSYuxuan Chen   auto *Frame = new AllocaInst(FrameTy, DL.getAllocaAddrSpace(), "", InsertPt);
53a416267aSYuxuan Chen   Frame->setAlignment(FrameAlign);
54a416267aSYuxuan Chen   return Frame;
55a416267aSYuxuan Chen }
56a416267aSYuxuan Chen 
57a416267aSYuxuan Chen // Given a call or invoke instruction to the elide safe coroutine, this function
58a416267aSYuxuan Chen // does the following:
59a416267aSYuxuan Chen //  - Allocate a frame for the callee coroutine in the caller using alloca.
60a416267aSYuxuan Chen //  - Replace the old CB with a new Call or Invoke to `NewCallee`, with the
61a416267aSYuxuan Chen //    pointer to the frame as an additional argument to NewCallee.
62a416267aSYuxuan Chen static void processCall(CallBase *CB, Function *Caller, Function *NewCallee,
63a416267aSYuxuan Chen                         uint64_t FrameSize, Align FrameAlign) {
64a416267aSYuxuan Chen   // TODO: generate the lifetime intrinsics for the new frame. This will require
65a416267aSYuxuan Chen   // introduction of two pesudo lifetime intrinsics in the frontend around the
66a416267aSYuxuan Chen   // `co_await` expression and convert them to real lifetime intrinsics here.
67a416267aSYuxuan Chen   auto *FramePtr = allocateFrameInCaller(Caller, FrameSize, FrameAlign);
68a416267aSYuxuan Chen   auto NewCBInsertPt = CB->getIterator();
69a416267aSYuxuan Chen   llvm::CallBase *NewCB = nullptr;
70a416267aSYuxuan Chen   SmallVector<Value *, 4> NewArgs;
71a416267aSYuxuan Chen   NewArgs.append(CB->arg_begin(), CB->arg_end());
72a416267aSYuxuan Chen   NewArgs.push_back(FramePtr);
73a416267aSYuxuan Chen 
74a416267aSYuxuan Chen   if (auto *CI = dyn_cast<CallInst>(CB)) {
75a416267aSYuxuan Chen     auto *NewCI = CallInst::Create(NewCallee->getFunctionType(), NewCallee,
76a416267aSYuxuan Chen                                    NewArgs, "", NewCBInsertPt);
77a416267aSYuxuan Chen     NewCI->setTailCallKind(CI->getTailCallKind());
78a416267aSYuxuan Chen     NewCB = NewCI;
79a416267aSYuxuan Chen   } else if (auto *II = dyn_cast<InvokeInst>(CB)) {
80a416267aSYuxuan Chen     NewCB = InvokeInst::Create(NewCallee->getFunctionType(), NewCallee,
81a416267aSYuxuan Chen                                II->getNormalDest(), II->getUnwindDest(),
82e03f4271SJay Foad                                NewArgs, {}, "", NewCBInsertPt);
83a416267aSYuxuan Chen   } else {
84a416267aSYuxuan Chen     llvm_unreachable("CallBase should either be Call or Invoke!");
85a416267aSYuxuan Chen   }
86a416267aSYuxuan Chen 
87a416267aSYuxuan Chen   NewCB->setCalledFunction(NewCallee->getFunctionType(), NewCallee);
88a416267aSYuxuan Chen   NewCB->setCallingConv(CB->getCallingConv());
89a416267aSYuxuan Chen   NewCB->setAttributes(CB->getAttributes());
90a416267aSYuxuan Chen   NewCB->setDebugLoc(CB->getDebugLoc());
91a416267aSYuxuan Chen   std::copy(CB->bundle_op_info_begin(), CB->bundle_op_info_end(),
92a416267aSYuxuan Chen             NewCB->bundle_op_info_begin());
93a416267aSYuxuan Chen 
94a416267aSYuxuan Chen   NewCB->removeFnAttr(llvm::Attribute::CoroElideSafe);
95a416267aSYuxuan Chen   CB->replaceAllUsesWith(NewCB);
96c6414970SYuxuan Chen 
97c6414970SYuxuan Chen   InlineFunctionInfo IFI;
98c6414970SYuxuan Chen   InlineResult IR = InlineFunction(*NewCB, IFI);
99c6414970SYuxuan Chen   if (IR.isSuccess()) {
100a416267aSYuxuan Chen     CB->eraseFromParent();
101c6414970SYuxuan Chen   } else {
102c6414970SYuxuan Chen     NewCB->replaceAllUsesWith(CB);
103c6414970SYuxuan Chen     NewCB->eraseFromParent();
104c6414970SYuxuan Chen   }
105a416267aSYuxuan Chen }
106a416267aSYuxuan Chen 
107c6414970SYuxuan Chen PreservedAnalyses CoroAnnotationElidePass::run(LazyCallGraph::SCC &C,
108c6414970SYuxuan Chen                                                CGSCCAnalysisManager &AM,
109c6414970SYuxuan Chen                                                LazyCallGraph &CG,
110c6414970SYuxuan Chen                                                CGSCCUpdateResult &UR) {
111a416267aSYuxuan Chen   bool Changed = false;
112c6414970SYuxuan Chen   CallGraphUpdater CGUpdater;
113c6414970SYuxuan Chen   CGUpdater.initialize(CG, C, AM, UR);
114a416267aSYuxuan Chen 
115c6414970SYuxuan Chen   auto &FAM =
116c6414970SYuxuan Chen       AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
117b5cdb039SYuxuan Chen 
118c6414970SYuxuan Chen   for (LazyCallGraph::Node &N : C) {
119c6414970SYuxuan Chen     Function *Callee = &N.getFunction();
120c6414970SYuxuan Chen     Function *NewCallee = Callee->getParent()->getFunction(
121c6414970SYuxuan Chen         (Callee->getName() + ".noalloc").str());
122b5cdb039SYuxuan Chen     if (!NewCallee)
123c6414970SYuxuan Chen       continue;
124a416267aSYuxuan Chen 
125a416267aSYuxuan Chen     SmallVector<CallBase *, 4> Users;
126c6414970SYuxuan Chen     for (auto *U : Callee->users()) {
127a416267aSYuxuan Chen       if (auto *CB = dyn_cast<CallBase>(U)) {
128c6414970SYuxuan Chen         if (CB->getCalledFunction() == Callee)
129a416267aSYuxuan Chen           Users.push_back(CB);
130a416267aSYuxuan Chen       }
131a416267aSYuxuan Chen     }
132c6414970SYuxuan Chen     auto FramePtrArgPosition = NewCallee->arg_size() - 1;
133c6414970SYuxuan Chen     auto FrameSize =
134c6414970SYuxuan Chen         NewCallee->getParamDereferenceableBytes(FramePtrArgPosition);
135c6414970SYuxuan Chen     auto FrameAlign =
136c6414970SYuxuan Chen         NewCallee->getParamAlign(FramePtrArgPosition).valueOrOne();
137a416267aSYuxuan Chen 
138c6414970SYuxuan Chen     auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(*Callee);
139a416267aSYuxuan Chen 
140a416267aSYuxuan Chen     for (auto *CB : Users) {
141a416267aSYuxuan Chen       auto *Caller = CB->getFunction();
142b5cdb039SYuxuan Chen       if (!Caller)
143b5cdb039SYuxuan Chen         continue;
144a416267aSYuxuan Chen 
145b5cdb039SYuxuan Chen       bool IsCallerPresplitCoroutine = Caller->isPresplitCoroutine();
146b5cdb039SYuxuan Chen       bool HasAttr = CB->hasFnAttr(llvm::Attribute::CoroElideSafe);
147b5cdb039SYuxuan Chen       if (IsCallerPresplitCoroutine && HasAttr) {
148c6414970SYuxuan Chen         auto *CallerN = CG.lookup(*Caller);
149*86fd4d4bSYuxuan Chen         auto *CallerC = CallerN ? CG.lookupSCC(*CallerN) : nullptr;
150*86fd4d4bSYuxuan Chen         // If CallerC is nullptr, it means LazyCallGraph hasn't visited Caller
151*86fd4d4bSYuxuan Chen         // yet. Skip the call graph update.
152*86fd4d4bSYuxuan Chen         auto ShouldUpdateCallGraph = !!CallerC;
153a416267aSYuxuan Chen         processCall(CB, Caller, NewCallee, FrameSize, FrameAlign);
154a416267aSYuxuan Chen 
155a416267aSYuxuan Chen         ORE.emit([&]() {
156a416267aSYuxuan Chen           return OptimizationRemark(DEBUG_TYPE, "CoroAnnotationElide", Caller)
157c6414970SYuxuan Chen                  << "'" << ore::NV("callee", Callee->getName())
158c6414970SYuxuan Chen                  << "' elided in '" << ore::NV("caller", Caller->getName())
159c6414970SYuxuan Chen                  << "'";
160a416267aSYuxuan Chen         });
161b5cdb039SYuxuan Chen 
162761bf333SYuxuan Chen         FAM.invalidate(*Caller, PreservedAnalyses::none());
163a416267aSYuxuan Chen         Changed = true;
164*86fd4d4bSYuxuan Chen         if (ShouldUpdateCallGraph)
165c6414970SYuxuan Chen           updateCGAndAnalysisManagerForCGSCCPass(CG, *CallerC, *CallerN, AM, UR,
166c6414970SYuxuan Chen                                                  FAM);
167c6414970SYuxuan Chen 
168b5cdb039SYuxuan Chen       } else {
169b5cdb039SYuxuan Chen         ORE.emit([&]() {
170b5cdb039SYuxuan Chen           return OptimizationRemarkMissed(DEBUG_TYPE, "CoroAnnotationElide",
171b5cdb039SYuxuan Chen                                           Caller)
172c6414970SYuxuan Chen                  << "'" << ore::NV("callee", Callee->getName())
173c6414970SYuxuan Chen                  << "' not elided in '" << ore::NV("caller", Caller->getName())
174c6414970SYuxuan Chen                  << "' (caller_presplit="
175b5cdb039SYuxuan Chen                  << ore::NV("caller_presplit", IsCallerPresplitCoroutine)
176b5cdb039SYuxuan Chen                  << ", elide_safe_attr=" << ore::NV("elide_safe_attr", HasAttr)
177b5cdb039SYuxuan Chen                  << ")";
178b5cdb039SYuxuan Chen         });
179a416267aSYuxuan Chen       }
180a416267aSYuxuan Chen     }
181c6414970SYuxuan Chen   }
182761bf333SYuxuan Chen 
183a416267aSYuxuan Chen   return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
184a416267aSYuxuan Chen }
185