xref: /llvm-project/llvm/lib/Transforms/Coroutines/CoroAnnotationElide.cpp (revision b5cdb039712d0c24b0d10c96b6b6d52456088f84)
1 //===- CoroAnnotationElide.cpp - Elide attributed safe coroutine calls ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file
// This pass rewrites every call or invoke instruction annotated with
// `coro_elide_safe` so that it calls the `.noalloc` variant of the callee
// coroutine instead, provided the caller is itself a presplit coroutine. The
// callee coroutine's frame is allocated (as an alloca) inside the caller, and a
// pointer to that frame is passed to the `.noalloc` ramp function.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "llvm/Transforms/Coroutines/CoroAnnotationElide.h"
18 
19 #include "llvm/Analysis/LazyCallGraph.h"
20 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
21 #include "llvm/IR/Analysis.h"
22 #include "llvm/IR/IRBuilder.h"
23 #include "llvm/IR/InstIterator.h"
24 #include "llvm/IR/Instruction.h"
25 #include "llvm/IR/Module.h"
26 #include "llvm/IR/PassManager.h"
27 #include "llvm/Transforms/Utils/CallGraphUpdater.h"
28 
29 #include <cassert>
30 
31 using namespace llvm;
32 
33 #define DEBUG_TYPE "coro-annotation-elide"
34 
35 static Instruction *getFirstNonAllocaInTheEntryBlock(Function *F) {
36   for (Instruction &I : F->getEntryBlock())
37     if (!isa<AllocaInst>(&I))
38       return &I;
39   llvm_unreachable("no terminator in the entry block");
40 }
41 
42 // Create an alloca in the caller, using FrameSize and FrameAlign as the callee
43 // coroutine's activation frame.
44 static Value *allocateFrameInCaller(Function *Caller, uint64_t FrameSize,
45                                     Align FrameAlign) {
46   LLVMContext &C = Caller->getContext();
47   BasicBlock::iterator InsertPt =
48       getFirstNonAllocaInTheEntryBlock(Caller)->getIterator();
49   const DataLayout &DL = Caller->getDataLayout();
50   auto FrameTy = ArrayType::get(Type::getInt8Ty(C), FrameSize);
51   auto *Frame = new AllocaInst(FrameTy, DL.getAllocaAddrSpace(), "", InsertPt);
52   Frame->setAlignment(FrameAlign);
53   return Frame;
54 }
55 
56 // Given a call or invoke instruction to the elide safe coroutine, this function
57 // does the following:
58 //  - Allocate a frame for the callee coroutine in the caller using alloca.
59 //  - Replace the old CB with a new Call or Invoke to `NewCallee`, with the
60 //    pointer to the frame as an additional argument to NewCallee.
61 static void processCall(CallBase *CB, Function *Caller, Function *NewCallee,
62                         uint64_t FrameSize, Align FrameAlign) {
63   // TODO: generate the lifetime intrinsics for the new frame. This will require
64   // introduction of two pesudo lifetime intrinsics in the frontend around the
65   // `co_await` expression and convert them to real lifetime intrinsics here.
66   auto *FramePtr = allocateFrameInCaller(Caller, FrameSize, FrameAlign);
67   auto NewCBInsertPt = CB->getIterator();
68   llvm::CallBase *NewCB = nullptr;
69   SmallVector<Value *, 4> NewArgs;
70   NewArgs.append(CB->arg_begin(), CB->arg_end());
71   NewArgs.push_back(FramePtr);
72 
73   if (auto *CI = dyn_cast<CallInst>(CB)) {
74     auto *NewCI = CallInst::Create(NewCallee->getFunctionType(), NewCallee,
75                                    NewArgs, "", NewCBInsertPt);
76     NewCI->setTailCallKind(CI->getTailCallKind());
77     NewCB = NewCI;
78   } else if (auto *II = dyn_cast<InvokeInst>(CB)) {
79     NewCB = InvokeInst::Create(NewCallee->getFunctionType(), NewCallee,
80                                II->getNormalDest(), II->getUnwindDest(),
81                                NewArgs, {}, "", NewCBInsertPt);
82   } else {
83     llvm_unreachable("CallBase should either be Call or Invoke!");
84   }
85 
86   NewCB->setCalledFunction(NewCallee->getFunctionType(), NewCallee);
87   NewCB->setCallingConv(CB->getCallingConv());
88   NewCB->setAttributes(CB->getAttributes());
89   NewCB->setDebugLoc(CB->getDebugLoc());
90   std::copy(CB->bundle_op_info_begin(), CB->bundle_op_info_end(),
91             NewCB->bundle_op_info_begin());
92 
93   NewCB->removeFnAttr(llvm::Attribute::CoroElideSafe);
94   CB->replaceAllUsesWith(NewCB);
95   CB->eraseFromParent();
96 }
97 
98 PreservedAnalyses CoroAnnotationElidePass::run(Function &F,
99                                                FunctionAnalysisManager &FAM) {
100   bool Changed = false;
101 
102   Function *NewCallee =
103       F.getParent()->getFunction((F.getName() + ".noalloc").str());
104 
105   if (!NewCallee)
106     return PreservedAnalyses::all();
107 
108   auto FramePtrArgPosition = NewCallee->arg_size() - 1;
109   auto FrameSize = NewCallee->getParamDereferenceableBytes(FramePtrArgPosition);
110   auto FrameAlign = NewCallee->getParamAlign(FramePtrArgPosition).valueOrOne();
111 
112   SmallVector<CallBase *, 4> Users;
113   for (auto *U : F.users()) {
114     if (auto *CB = dyn_cast<CallBase>(U)) {
115       if (CB->getCalledFunction() == &F)
116         Users.push_back(CB);
117     }
118   }
119 
120   auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
121 
122   for (auto *CB : Users) {
123     auto *Caller = CB->getFunction();
124     if (!Caller)
125       continue;
126 
127     bool IsCallerPresplitCoroutine = Caller->isPresplitCoroutine();
128     bool HasAttr = CB->hasFnAttr(llvm::Attribute::CoroElideSafe);
129     if (IsCallerPresplitCoroutine && HasAttr) {
130       processCall(CB, Caller, NewCallee, FrameSize, FrameAlign);
131 
132       ORE.emit([&]() {
133         return OptimizationRemark(DEBUG_TYPE, "CoroAnnotationElide", Caller)
134                << "'" << ore::NV("callee", F.getName()) << "' elided in '"
135                << ore::NV("caller", Caller->getName()) << "'";
136       });
137 
138       FAM.invalidate(*Caller, PreservedAnalyses::none());
139       Changed = true;
140     } else {
141       ORE.emit([&]() {
142         return OptimizationRemarkMissed(DEBUG_TYPE, "CoroAnnotationElide",
143                                         Caller)
144                << "'" << ore::NV("callee", F.getName()) << "' not elided in '"
145                << ore::NV("caller", Caller->getName()) << "' (caller_presplit="
146                << ore::NV("caller_presplit", IsCallerPresplitCoroutine)
147                << ", elide_safe_attr=" << ore::NV("elide_safe_attr", HasAttr)
148                << ")";
149       });
150     }
151   }
152 
153   return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
154 }
155