xref: /llvm-project/llvm/lib/Transforms/Coroutines/CoroAnnotationElide.cpp (revision b02e5bc5b1be9d94689ebe1cf1244b7da540fb19)
1 //===- CoroAnnotationElide.cpp - Elide attributed safe coroutine calls ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file
10 // This pass transforms all Call or Invoke instructions that are annotated
11 // "coro_elide_safe" to call the `.noalloc` variant of coroutine instead.
12 // The frame of the callee coroutine is allocated inside the caller. A pointer
13 // to the allocated frame will be passed into the `.noalloc` ramp function.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "llvm/Transforms/Coroutines/CoroAnnotationElide.h"
18 
19 #include "llvm/Analysis/LazyCallGraph.h"
20 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
21 #include "llvm/IR/Analysis.h"
22 #include "llvm/IR/IRBuilder.h"
23 #include "llvm/IR/Instruction.h"
24 #include "llvm/IR/Module.h"
25 #include "llvm/IR/PassManager.h"
26 
27 #include <cassert>
28 
29 using namespace llvm;
30 
31 #define DEBUG_TYPE "coro-annotation-elide"
32 
33 static Instruction *getFirstNonAllocaInTheEntryBlock(Function *F) {
34   for (Instruction &I : F->getEntryBlock())
35     if (!isa<AllocaInst>(&I))
36       return &I;
37   llvm_unreachable("no terminator in the entry block");
38 }
39 
40 // Create an alloca in the caller, using FrameSize and FrameAlign as the callee
41 // coroutine's activation frame.
42 static Value *allocateFrameInCaller(Function *Caller, uint64_t FrameSize,
43                                     Align FrameAlign) {
44   LLVMContext &C = Caller->getContext();
45   BasicBlock::iterator InsertPt =
46       getFirstNonAllocaInTheEntryBlock(Caller)->getIterator();
47   const DataLayout &DL = Caller->getDataLayout();
48   auto FrameTy = ArrayType::get(Type::getInt8Ty(C), FrameSize);
49   auto *Frame = new AllocaInst(FrameTy, DL.getAllocaAddrSpace(), "", InsertPt);
50   Frame->setAlignment(FrameAlign);
51   return Frame;
52 }
53 
54 // Given a call or invoke instruction to the elide safe coroutine, this function
55 // does the following:
56 //  - Allocate a frame for the callee coroutine in the caller using alloca.
57 //  - Replace the old CB with a new Call or Invoke to `NewCallee`, with the
58 //    pointer to the frame as an additional argument to NewCallee.
59 static void processCall(CallBase *CB, Function *Caller, Function *NewCallee,
60                         uint64_t FrameSize, Align FrameAlign) {
61   // TODO: generate the lifetime intrinsics for the new frame. This will require
62   // introduction of two pesudo lifetime intrinsics in the frontend around the
63   // `co_await` expression and convert them to real lifetime intrinsics here.
64   auto *FramePtr = allocateFrameInCaller(Caller, FrameSize, FrameAlign);
65   auto NewCBInsertPt = CB->getIterator();
66   llvm::CallBase *NewCB = nullptr;
67   SmallVector<Value *, 4> NewArgs;
68   NewArgs.append(CB->arg_begin(), CB->arg_end());
69   NewArgs.push_back(FramePtr);
70 
71   if (auto *CI = dyn_cast<CallInst>(CB)) {
72     auto *NewCI = CallInst::Create(NewCallee->getFunctionType(), NewCallee,
73                                    NewArgs, "", NewCBInsertPt);
74     NewCI->setTailCallKind(CI->getTailCallKind());
75     NewCB = NewCI;
76   } else if (auto *II = dyn_cast<InvokeInst>(CB)) {
77     NewCB = InvokeInst::Create(NewCallee->getFunctionType(), NewCallee,
78                                II->getNormalDest(), II->getUnwindDest(),
79                                NewArgs, {}, "", NewCBInsertPt);
80   } else {
81     llvm_unreachable("CallBase should either be Call or Invoke!");
82   }
83 
84   NewCB->setCalledFunction(NewCallee->getFunctionType(), NewCallee);
85   NewCB->setCallingConv(CB->getCallingConv());
86   NewCB->setAttributes(CB->getAttributes());
87   NewCB->setDebugLoc(CB->getDebugLoc());
88   std::copy(CB->bundle_op_info_begin(), CB->bundle_op_info_end(),
89             NewCB->bundle_op_info_begin());
90 
91   NewCB->removeFnAttr(llvm::Attribute::CoroElideSafe);
92   CB->replaceAllUsesWith(NewCB);
93   CB->eraseFromParent();
94 }
95 
96 PreservedAnalyses CoroAnnotationElidePass::run(Function &F,
97                                                FunctionAnalysisManager &FAM) {
98   bool Changed = false;
99 
100   Function *NewCallee =
101       F.getParent()->getFunction((F.getName() + ".noalloc").str());
102 
103   if (!NewCallee)
104     return PreservedAnalyses::all();
105 
106   auto FramePtrArgPosition = NewCallee->arg_size() - 1;
107   auto FrameSize = NewCallee->getParamDereferenceableBytes(FramePtrArgPosition);
108   auto FrameAlign = NewCallee->getParamAlign(FramePtrArgPosition).valueOrOne();
109 
110   SmallVector<CallBase *, 4> Users;
111   for (auto *U : F.users()) {
112     if (auto *CB = dyn_cast<CallBase>(U)) {
113       if (CB->getCalledFunction() == &F)
114         Users.push_back(CB);
115     }
116   }
117 
118   auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
119 
120   for (auto *CB : Users) {
121     auto *Caller = CB->getFunction();
122     if (!Caller)
123       continue;
124 
125     bool IsCallerPresplitCoroutine = Caller->isPresplitCoroutine();
126     bool HasAttr = CB->hasFnAttr(llvm::Attribute::CoroElideSafe);
127     if (IsCallerPresplitCoroutine && HasAttr) {
128       processCall(CB, Caller, NewCallee, FrameSize, FrameAlign);
129 
130       ORE.emit([&]() {
131         return OptimizationRemark(DEBUG_TYPE, "CoroAnnotationElide", Caller)
132                << "'" << ore::NV("callee", F.getName()) << "' elided in '"
133                << ore::NV("caller", Caller->getName()) << "'";
134       });
135 
136       FAM.invalidate(*Caller, PreservedAnalyses::none());
137       Changed = true;
138     } else {
139       ORE.emit([&]() {
140         return OptimizationRemarkMissed(DEBUG_TYPE, "CoroAnnotationElide",
141                                         Caller)
142                << "'" << ore::NV("callee", F.getName()) << "' not elided in '"
143                << ore::NV("caller", Caller->getName()) << "' (caller_presplit="
144                << ore::NV("caller_presplit", IsCallerPresplitCoroutine)
145                << ", elide_safe_attr=" << ore::NV("elide_safe_attr", HasAttr)
146                << ")";
147       });
148     }
149   }
150 
151   return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
152 }
153