xref: /llvm-project/llvm/lib/Transforms/Coroutines/CoroAnnotationElide.cpp (revision 86fd4d4b5b95d58844e521cf7319965eea7d8d0b)
1 //===- CoroAnnotationElide.cpp - Elide attributed safe coroutine calls ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file
10 // This pass transforms all Call or Invoke instructions that are annotated
11 // "coro_elide_safe" to call the `.noalloc` variant of coroutine instead.
12 // The frame of the callee coroutine is allocated inside the caller. A pointer
13 // to the allocated frame will be passed into the `.noalloc` ramp function.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "llvm/Transforms/Coroutines/CoroAnnotationElide.h"
18 
19 #include "llvm/Analysis/CGSCCPassManager.h"
20 #include "llvm/Analysis/LazyCallGraph.h"
21 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
22 #include "llvm/IR/Analysis.h"
23 #include "llvm/IR/IRBuilder.h"
24 #include "llvm/IR/Instruction.h"
25 #include "llvm/IR/Module.h"
26 #include "llvm/IR/PassManager.h"
27 #include "llvm/Transforms/Utils/CallGraphUpdater.h"
28 #include "llvm/Transforms/Utils/Cloning.h"
29 
30 #include <cassert>
31 
32 using namespace llvm;
33 
34 #define DEBUG_TYPE "coro-annotation-elide"
35 
36 static Instruction *getFirstNonAllocaInTheEntryBlock(Function *F) {
37   for (Instruction &I : F->getEntryBlock())
38     if (!isa<AllocaInst>(&I))
39       return &I;
40   llvm_unreachable("no terminator in the entry block");
41 }
42 
43 // Create an alloca in the caller, using FrameSize and FrameAlign as the callee
44 // coroutine's activation frame.
45 static Value *allocateFrameInCaller(Function *Caller, uint64_t FrameSize,
46                                     Align FrameAlign) {
47   LLVMContext &C = Caller->getContext();
48   BasicBlock::iterator InsertPt =
49       getFirstNonAllocaInTheEntryBlock(Caller)->getIterator();
50   const DataLayout &DL = Caller->getDataLayout();
51   auto FrameTy = ArrayType::get(Type::getInt8Ty(C), FrameSize);
52   auto *Frame = new AllocaInst(FrameTy, DL.getAllocaAddrSpace(), "", InsertPt);
53   Frame->setAlignment(FrameAlign);
54   return Frame;
55 }
56 
57 // Given a call or invoke instruction to the elide safe coroutine, this function
58 // does the following:
59 //  - Allocate a frame for the callee coroutine in the caller using alloca.
60 //  - Replace the old CB with a new Call or Invoke to `NewCallee`, with the
61 //    pointer to the frame as an additional argument to NewCallee.
62 static void processCall(CallBase *CB, Function *Caller, Function *NewCallee,
63                         uint64_t FrameSize, Align FrameAlign) {
64   // TODO: generate the lifetime intrinsics for the new frame. This will require
65   // introduction of two pesudo lifetime intrinsics in the frontend around the
66   // `co_await` expression and convert them to real lifetime intrinsics here.
67   auto *FramePtr = allocateFrameInCaller(Caller, FrameSize, FrameAlign);
68   auto NewCBInsertPt = CB->getIterator();
69   llvm::CallBase *NewCB = nullptr;
70   SmallVector<Value *, 4> NewArgs;
71   NewArgs.append(CB->arg_begin(), CB->arg_end());
72   NewArgs.push_back(FramePtr);
73 
74   if (auto *CI = dyn_cast<CallInst>(CB)) {
75     auto *NewCI = CallInst::Create(NewCallee->getFunctionType(), NewCallee,
76                                    NewArgs, "", NewCBInsertPt);
77     NewCI->setTailCallKind(CI->getTailCallKind());
78     NewCB = NewCI;
79   } else if (auto *II = dyn_cast<InvokeInst>(CB)) {
80     NewCB = InvokeInst::Create(NewCallee->getFunctionType(), NewCallee,
81                                II->getNormalDest(), II->getUnwindDest(),
82                                NewArgs, {}, "", NewCBInsertPt);
83   } else {
84     llvm_unreachable("CallBase should either be Call or Invoke!");
85   }
86 
87   NewCB->setCalledFunction(NewCallee->getFunctionType(), NewCallee);
88   NewCB->setCallingConv(CB->getCallingConv());
89   NewCB->setAttributes(CB->getAttributes());
90   NewCB->setDebugLoc(CB->getDebugLoc());
91   std::copy(CB->bundle_op_info_begin(), CB->bundle_op_info_end(),
92             NewCB->bundle_op_info_begin());
93 
94   NewCB->removeFnAttr(llvm::Attribute::CoroElideSafe);
95   CB->replaceAllUsesWith(NewCB);
96 
97   InlineFunctionInfo IFI;
98   InlineResult IR = InlineFunction(*NewCB, IFI);
99   if (IR.isSuccess()) {
100     CB->eraseFromParent();
101   } else {
102     NewCB->replaceAllUsesWith(CB);
103     NewCB->eraseFromParent();
104   }
105 }
106 
107 PreservedAnalyses CoroAnnotationElidePass::run(LazyCallGraph::SCC &C,
108                                                CGSCCAnalysisManager &AM,
109                                                LazyCallGraph &CG,
110                                                CGSCCUpdateResult &UR) {
111   bool Changed = false;
112   CallGraphUpdater CGUpdater;
113   CGUpdater.initialize(CG, C, AM, UR);
114 
115   auto &FAM =
116       AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
117 
118   for (LazyCallGraph::Node &N : C) {
119     Function *Callee = &N.getFunction();
120     Function *NewCallee = Callee->getParent()->getFunction(
121         (Callee->getName() + ".noalloc").str());
122     if (!NewCallee)
123       continue;
124 
125     SmallVector<CallBase *, 4> Users;
126     for (auto *U : Callee->users()) {
127       if (auto *CB = dyn_cast<CallBase>(U)) {
128         if (CB->getCalledFunction() == Callee)
129           Users.push_back(CB);
130       }
131     }
132     auto FramePtrArgPosition = NewCallee->arg_size() - 1;
133     auto FrameSize =
134         NewCallee->getParamDereferenceableBytes(FramePtrArgPosition);
135     auto FrameAlign =
136         NewCallee->getParamAlign(FramePtrArgPosition).valueOrOne();
137 
138     auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(*Callee);
139 
140     for (auto *CB : Users) {
141       auto *Caller = CB->getFunction();
142       if (!Caller)
143         continue;
144 
145       bool IsCallerPresplitCoroutine = Caller->isPresplitCoroutine();
146       bool HasAttr = CB->hasFnAttr(llvm::Attribute::CoroElideSafe);
147       if (IsCallerPresplitCoroutine && HasAttr) {
148         auto *CallerN = CG.lookup(*Caller);
149         auto *CallerC = CallerN ? CG.lookupSCC(*CallerN) : nullptr;
150         // If CallerC is nullptr, it means LazyCallGraph hasn't visited Caller
151         // yet. Skip the call graph update.
152         auto ShouldUpdateCallGraph = !!CallerC;
153         processCall(CB, Caller, NewCallee, FrameSize, FrameAlign);
154 
155         ORE.emit([&]() {
156           return OptimizationRemark(DEBUG_TYPE, "CoroAnnotationElide", Caller)
157                  << "'" << ore::NV("callee", Callee->getName())
158                  << "' elided in '" << ore::NV("caller", Caller->getName())
159                  << "'";
160         });
161 
162         FAM.invalidate(*Caller, PreservedAnalyses::none());
163         Changed = true;
164         if (ShouldUpdateCallGraph)
165           updateCGAndAnalysisManagerForCGSCCPass(CG, *CallerC, *CallerN, AM, UR,
166                                                  FAM);
167 
168       } else {
169         ORE.emit([&]() {
170           return OptimizationRemarkMissed(DEBUG_TYPE, "CoroAnnotationElide",
171                                           Caller)
172                  << "'" << ore::NV("callee", Callee->getName())
173                  << "' not elided in '" << ore::NV("caller", Caller->getName())
174                  << "' (caller_presplit="
175                  << ore::NV("caller_presplit", IsCallerPresplitCoroutine)
176                  << ", elide_safe_attr=" << ore::NV("elide_safe_attr", HasAttr)
177                  << ")";
178         });
179       }
180     }
181   }
182 
183   return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
184 }
185