1 //===- Coroutines.cpp -----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the common infrastructure for Coroutine Passes.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "llvm/Transforms/Coroutines.h"
14 #include "CoroInstr.h"
15 #include "CoroInternal.h"
16 #include "llvm-c/Transforms/Coroutines.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/ADT/StringRef.h"
19 #include "llvm/Analysis/CallGraph.h"
20 #include "llvm/Analysis/CallGraphSCCPass.h"
21 #include "llvm/IR/Attributes.h"
22 #include "llvm/IR/Constants.h"
23 #include "llvm/IR/DerivedTypes.h"
24 #include "llvm/IR/Function.h"
25 #include "llvm/IR/InstIterator.h"
26 #include "llvm/IR/Instructions.h"
27 #include "llvm/IR/IntrinsicInst.h"
28 #include "llvm/IR/Intrinsics.h"
29 #include "llvm/IR/LegacyPassManager.h"
30 #include "llvm/IR/Module.h"
31 #include "llvm/IR/Type.h"
32 #include "llvm/InitializePasses.h"
33 #include "llvm/Support/Casting.h"
34 #include "llvm/Support/ErrorHandling.h"
35 #include "llvm/Transforms/IPO.h"
36 #include "llvm/Transforms/IPO/PassManagerBuilder.h"
37 #include "llvm/Transforms/Utils/Local.h"
38 #include <cassert>
39 #include <cstddef>
40 #include <utility>
41
42 using namespace llvm;
43
initializeCoroutines(PassRegistry & Registry)44 void llvm::initializeCoroutines(PassRegistry &Registry) {
45 initializeCoroEarlyLegacyPass(Registry);
46 initializeCoroSplitLegacyPass(Registry);
47 initializeCoroElideLegacyPass(Registry);
48 initializeCoroCleanupLegacyPass(Registry);
49 }
50
addCoroutineOpt0Passes(const PassManagerBuilder & Builder,legacy::PassManagerBase & PM)51 static void addCoroutineOpt0Passes(const PassManagerBuilder &Builder,
52 legacy::PassManagerBase &PM) {
53 PM.add(createCoroSplitLegacyPass());
54 PM.add(createCoroElideLegacyPass());
55
56 PM.add(createBarrierNoopPass());
57 PM.add(createCoroCleanupLegacyPass());
58 }
59
addCoroutineEarlyPasses(const PassManagerBuilder & Builder,legacy::PassManagerBase & PM)60 static void addCoroutineEarlyPasses(const PassManagerBuilder &Builder,
61 legacy::PassManagerBase &PM) {
62 PM.add(createCoroEarlyLegacyPass());
63 }
64
addCoroutineScalarOptimizerPasses(const PassManagerBuilder & Builder,legacy::PassManagerBase & PM)65 static void addCoroutineScalarOptimizerPasses(const PassManagerBuilder &Builder,
66 legacy::PassManagerBase &PM) {
67 PM.add(createCoroElideLegacyPass());
68 }
69
addCoroutineSCCPasses(const PassManagerBuilder & Builder,legacy::PassManagerBase & PM)70 static void addCoroutineSCCPasses(const PassManagerBuilder &Builder,
71 legacy::PassManagerBase &PM) {
72 PM.add(createCoroSplitLegacyPass(Builder.OptLevel != 0));
73 }
74
addCoroutineOptimizerLastPasses(const PassManagerBuilder & Builder,legacy::PassManagerBase & PM)75 static void addCoroutineOptimizerLastPasses(const PassManagerBuilder &Builder,
76 legacy::PassManagerBase &PM) {
77 PM.add(createCoroCleanupLegacyPass());
78 }
79
addCoroutinePassesToExtensionPoints(PassManagerBuilder & Builder)80 void llvm::addCoroutinePassesToExtensionPoints(PassManagerBuilder &Builder) {
81 Builder.addExtension(PassManagerBuilder::EP_EarlyAsPossible,
82 addCoroutineEarlyPasses);
83 Builder.addExtension(PassManagerBuilder::EP_EnabledOnOptLevel0,
84 addCoroutineOpt0Passes);
85 Builder.addExtension(PassManagerBuilder::EP_CGSCCOptimizerLate,
86 addCoroutineSCCPasses);
87 Builder.addExtension(PassManagerBuilder::EP_ScalarOptimizerLate,
88 addCoroutineScalarOptimizerPasses);
89 Builder.addExtension(PassManagerBuilder::EP_OptimizerLast,
90 addCoroutineOptimizerLastPasses);
91 }
92
93 // Construct the lowerer base class and initialize its members.
LowererBase(Module & M)94 coro::LowererBase::LowererBase(Module &M)
95 : TheModule(M), Context(M.getContext()),
96 Int8Ptr(Type::getInt8PtrTy(Context)),
97 ResumeFnType(FunctionType::get(Type::getVoidTy(Context), Int8Ptr,
98 /*isVarArg=*/false)),
99 NullPtr(ConstantPointerNull::get(Int8Ptr)) {}
100
101 // Creates a sequence of instructions to obtain a resume function address using
102 // llvm.coro.subfn.addr. It generates the following sequence:
103 //
104 // call i8* @llvm.coro.subfn.addr(i8* %Arg, i8 %index)
105 // bitcast i8* %2 to void(i8*)*
106
makeSubFnCall(Value * Arg,int Index,Instruction * InsertPt)107 Value *coro::LowererBase::makeSubFnCall(Value *Arg, int Index,
108 Instruction *InsertPt) {
109 auto *IndexVal = ConstantInt::get(Type::getInt8Ty(Context), Index);
110 auto *Fn = Intrinsic::getDeclaration(&TheModule, Intrinsic::coro_subfn_addr);
111
112 assert(Index >= CoroSubFnInst::IndexFirst &&
113 Index < CoroSubFnInst::IndexLast &&
114 "makeSubFnCall: Index value out of range");
115 auto *Call = CallInst::Create(Fn, {Arg, IndexVal}, "", InsertPt);
116
117 auto *Bitcast =
118 new BitCastInst(Call, ResumeFnType->getPointerTo(), "", InsertPt);
119 return Bitcast;
120 }
121
122 #ifndef NDEBUG
isCoroutineIntrinsicName(StringRef Name)123 static bool isCoroutineIntrinsicName(StringRef Name) {
124 // NOTE: Must be sorted!
125 static const char *const CoroIntrinsics[] = {
126 "llvm.coro.alloc",
127 "llvm.coro.async.context.alloc",
128 "llvm.coro.async.context.dealloc",
129 "llvm.coro.async.size.replace",
130 "llvm.coro.async.store_resume",
131 "llvm.coro.begin",
132 "llvm.coro.destroy",
133 "llvm.coro.done",
134 "llvm.coro.end",
135 "llvm.coro.end.async",
136 "llvm.coro.frame",
137 "llvm.coro.free",
138 "llvm.coro.id",
139 "llvm.coro.id.async",
140 "llvm.coro.id.retcon",
141 "llvm.coro.id.retcon.once",
142 "llvm.coro.noop",
143 "llvm.coro.param",
144 "llvm.coro.prepare.async",
145 "llvm.coro.prepare.retcon",
146 "llvm.coro.promise",
147 "llvm.coro.resume",
148 "llvm.coro.save",
149 "llvm.coro.size",
150 "llvm.coro.subfn.addr",
151 "llvm.coro.suspend",
152 "llvm.coro.suspend.async",
153 "llvm.coro.suspend.retcon",
154 };
155 return Intrinsic::lookupLLVMIntrinsicByName(CoroIntrinsics, Name) != -1;
156 }
157 #endif
158
159 // Verifies if a module has named values listed. Also, in debug mode verifies
160 // that names are intrinsic names.
declaresIntrinsics(const Module & M,const std::initializer_list<StringRef> List)161 bool coro::declaresIntrinsics(const Module &M,
162 const std::initializer_list<StringRef> List) {
163 for (StringRef Name : List) {
164 assert(isCoroutineIntrinsicName(Name) && "not a coroutine intrinsic");
165 if (M.getNamedValue(Name))
166 return true;
167 }
168
169 return false;
170 }
171
172 // Replace all coro.frees associated with the provided CoroId either with 'null'
173 // if Elide is true and with its frame parameter otherwise.
replaceCoroFree(CoroIdInst * CoroId,bool Elide)174 void coro::replaceCoroFree(CoroIdInst *CoroId, bool Elide) {
175 SmallVector<CoroFreeInst *, 4> CoroFrees;
176 for (User *U : CoroId->users())
177 if (auto CF = dyn_cast<CoroFreeInst>(U))
178 CoroFrees.push_back(CF);
179
180 if (CoroFrees.empty())
181 return;
182
183 Value *Replacement =
184 Elide ? ConstantPointerNull::get(Type::getInt8PtrTy(CoroId->getContext()))
185 : CoroFrees.front()->getFrame();
186
187 for (CoroFreeInst *CF : CoroFrees) {
188 CF->replaceAllUsesWith(Replacement);
189 CF->eraseFromParent();
190 }
191 }
192
193 // FIXME: This code is stolen from CallGraph::addToCallGraph(Function *F), which
194 // happens to be private. It is better for this functionality exposed by the
195 // CallGraph.
buildCGN(CallGraph & CG,CallGraphNode * Node)196 static void buildCGN(CallGraph &CG, CallGraphNode *Node) {
197 Function *F = Node->getFunction();
198
199 // Look for calls by this function.
200 for (Instruction &I : instructions(F))
201 if (auto *Call = dyn_cast<CallBase>(&I)) {
202 const Function *Callee = Call->getCalledFunction();
203 if (!Callee || !Intrinsic::isLeaf(Callee->getIntrinsicID()))
204 // Indirect calls of intrinsics are not allowed so no need to check.
205 // We can be more precise here by using TargetArg returned by
206 // Intrinsic::isLeaf.
207 Node->addCalledFunction(Call, CG.getCallsExternalNode());
208 else if (!Callee->isIntrinsic())
209 Node->addCalledFunction(Call, CG.getOrInsertFunction(Callee));
210 }
211 }
212
213 // Rebuild CGN after we extracted parts of the code from ParentFunc into
214 // NewFuncs. Builds CGNs for the NewFuncs and adds them to the current SCC.
updateCallGraph(Function & ParentFunc,ArrayRef<Function * > NewFuncs,CallGraph & CG,CallGraphSCC & SCC)215 void coro::updateCallGraph(Function &ParentFunc, ArrayRef<Function *> NewFuncs,
216 CallGraph &CG, CallGraphSCC &SCC) {
217 // Rebuild CGN from scratch for the ParentFunc
218 auto *ParentNode = CG[&ParentFunc];
219 ParentNode->removeAllCalledFunctions();
220 buildCGN(CG, ParentNode);
221
222 SmallVector<CallGraphNode *, 8> Nodes(SCC.begin(), SCC.end());
223
224 for (Function *F : NewFuncs) {
225 CallGraphNode *Callee = CG.getOrInsertFunction(F);
226 Nodes.push_back(Callee);
227 buildCGN(CG, Callee);
228 }
229
230 SCC.initialize(Nodes);
231 }
232
clear(coro::Shape & Shape)233 static void clear(coro::Shape &Shape) {
234 Shape.CoroBegin = nullptr;
235 Shape.CoroEnds.clear();
236 Shape.CoroSizes.clear();
237 Shape.CoroSuspends.clear();
238
239 Shape.FrameTy = nullptr;
240 Shape.FramePtr = nullptr;
241 Shape.AllocaSpillBlock = nullptr;
242 }
243
createCoroSave(CoroBeginInst * CoroBegin,CoroSuspendInst * SuspendInst)244 static CoroSaveInst *createCoroSave(CoroBeginInst *CoroBegin,
245 CoroSuspendInst *SuspendInst) {
246 Module *M = SuspendInst->getModule();
247 auto *Fn = Intrinsic::getDeclaration(M, Intrinsic::coro_save);
248 auto *SaveInst =
249 cast<CoroSaveInst>(CallInst::Create(Fn, CoroBegin, "", SuspendInst));
250 assert(!SuspendInst->getCoroSave());
251 SuspendInst->setArgOperand(0, SaveInst);
252 return SaveInst;
253 }
254
255 // Collect "interesting" coroutine intrinsics.
buildFrom(Function & F)256 void coro::Shape::buildFrom(Function &F) {
257 bool HasFinalSuspend = false;
258 size_t FinalSuspendIndex = 0;
259 clear(*this);
260 SmallVector<CoroFrameInst *, 8> CoroFrames;
261 SmallVector<CoroSaveInst *, 2> UnusedCoroSaves;
262
263 for (Instruction &I : instructions(F)) {
264 if (auto II = dyn_cast<IntrinsicInst>(&I)) {
265 switch (II->getIntrinsicID()) {
266 default:
267 continue;
268 case Intrinsic::coro_size:
269 CoroSizes.push_back(cast<CoroSizeInst>(II));
270 break;
271 case Intrinsic::coro_frame:
272 CoroFrames.push_back(cast<CoroFrameInst>(II));
273 break;
274 case Intrinsic::coro_save:
275 // After optimizations, coro_suspends using this coro_save might have
276 // been removed, remember orphaned coro_saves to remove them later.
277 if (II->use_empty())
278 UnusedCoroSaves.push_back(cast<CoroSaveInst>(II));
279 break;
280 case Intrinsic::coro_suspend_async: {
281 auto *Suspend = cast<CoroSuspendAsyncInst>(II);
282 Suspend->checkWellFormed();
283 CoroSuspends.push_back(Suspend);
284 break;
285 }
286 case Intrinsic::coro_suspend_retcon: {
287 auto Suspend = cast<CoroSuspendRetconInst>(II);
288 CoroSuspends.push_back(Suspend);
289 break;
290 }
291 case Intrinsic::coro_suspend: {
292 auto Suspend = cast<CoroSuspendInst>(II);
293 CoroSuspends.push_back(Suspend);
294 if (Suspend->isFinal()) {
295 if (HasFinalSuspend)
296 report_fatal_error(
297 "Only one suspend point can be marked as final");
298 HasFinalSuspend = true;
299 FinalSuspendIndex = CoroSuspends.size() - 1;
300 }
301 break;
302 }
303 case Intrinsic::coro_begin: {
304 auto CB = cast<CoroBeginInst>(II);
305
306 // Ignore coro id's that aren't pre-split.
307 auto Id = dyn_cast<CoroIdInst>(CB->getId());
308 if (Id && !Id->getInfo().isPreSplit())
309 break;
310
311 if (CoroBegin)
312 report_fatal_error(
313 "coroutine should have exactly one defining @llvm.coro.begin");
314 CB->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull);
315 CB->addAttribute(AttributeList::ReturnIndex, Attribute::NoAlias);
316 CB->removeAttribute(AttributeList::FunctionIndex,
317 Attribute::NoDuplicate);
318 CoroBegin = CB;
319 break;
320 }
321 case Intrinsic::coro_end_async:
322 case Intrinsic::coro_end:
323 CoroEnds.push_back(cast<AnyCoroEndInst>(II));
324 if (auto *AsyncEnd = dyn_cast<CoroAsyncEndInst>(II)) {
325 AsyncEnd->checkWellFormed();
326 }
327 if (CoroEnds.back()->isFallthrough() && isa<CoroEndInst>(II)) {
328 // Make sure that the fallthrough coro.end is the first element in the
329 // CoroEnds vector.
330 // Note: I don't think this is neccessary anymore.
331 if (CoroEnds.size() > 1) {
332 if (CoroEnds.front()->isFallthrough())
333 report_fatal_error(
334 "Only one coro.end can be marked as fallthrough");
335 std::swap(CoroEnds.front(), CoroEnds.back());
336 }
337 }
338 break;
339 }
340 }
341 }
342
343 // If for some reason, we were not able to find coro.begin, bailout.
344 if (!CoroBegin) {
345 // Replace coro.frame which are supposed to be lowered to the result of
346 // coro.begin with undef.
347 auto *Undef = UndefValue::get(Type::getInt8PtrTy(F.getContext()));
348 for (CoroFrameInst *CF : CoroFrames) {
349 CF->replaceAllUsesWith(Undef);
350 CF->eraseFromParent();
351 }
352
353 // Replace all coro.suspend with undef and remove related coro.saves if
354 // present.
355 for (AnyCoroSuspendInst *CS : CoroSuspends) {
356 CS->replaceAllUsesWith(UndefValue::get(CS->getType()));
357 CS->eraseFromParent();
358 if (auto *CoroSave = CS->getCoroSave())
359 CoroSave->eraseFromParent();
360 }
361
362 // Replace all coro.ends with unreachable instruction.
363 for (AnyCoroEndInst *CE : CoroEnds)
364 changeToUnreachable(CE, /*UseLLVMTrap=*/false);
365
366 return;
367 }
368
369 auto Id = CoroBegin->getId();
370 switch (auto IdIntrinsic = Id->getIntrinsicID()) {
371 case Intrinsic::coro_id: {
372 auto SwitchId = cast<CoroIdInst>(Id);
373 this->ABI = coro::ABI::Switch;
374 this->SwitchLowering.HasFinalSuspend = HasFinalSuspend;
375 this->SwitchLowering.ResumeSwitch = nullptr;
376 this->SwitchLowering.PromiseAlloca = SwitchId->getPromise();
377 this->SwitchLowering.ResumeEntryBlock = nullptr;
378
379 for (auto AnySuspend : CoroSuspends) {
380 auto Suspend = dyn_cast<CoroSuspendInst>(AnySuspend);
381 if (!Suspend) {
382 #ifndef NDEBUG
383 AnySuspend->dump();
384 #endif
385 report_fatal_error("coro.id must be paired with coro.suspend");
386 }
387
388 if (!Suspend->getCoroSave())
389 createCoroSave(CoroBegin, Suspend);
390 }
391 break;
392 }
393 case Intrinsic::coro_id_async: {
394 auto *AsyncId = cast<CoroIdAsyncInst>(Id);
395 AsyncId->checkWellFormed();
396 this->ABI = coro::ABI::Async;
397 this->AsyncLowering.Context = AsyncId->getStorage();
398 this->AsyncLowering.ContextArgNo = AsyncId->getStorageArgumentIndex();
399 this->AsyncLowering.ContextHeaderSize = AsyncId->getStorageSize();
400 this->AsyncLowering.ContextAlignment =
401 AsyncId->getStorageAlignment().value();
402 this->AsyncLowering.AsyncFuncPointer = AsyncId->getAsyncFunctionPointer();
403 this->AsyncLowering.AsyncCC = F.getCallingConv();
404 break;
405 };
406 case Intrinsic::coro_id_retcon:
407 case Intrinsic::coro_id_retcon_once: {
408 auto ContinuationId = cast<AnyCoroIdRetconInst>(Id);
409 ContinuationId->checkWellFormed();
410 this->ABI = (IdIntrinsic == Intrinsic::coro_id_retcon
411 ? coro::ABI::Retcon
412 : coro::ABI::RetconOnce);
413 auto Prototype = ContinuationId->getPrototype();
414 this->RetconLowering.ResumePrototype = Prototype;
415 this->RetconLowering.Alloc = ContinuationId->getAllocFunction();
416 this->RetconLowering.Dealloc = ContinuationId->getDeallocFunction();
417 this->RetconLowering.ReturnBlock = nullptr;
418 this->RetconLowering.IsFrameInlineInStorage = false;
419
420 // Determine the result value types, and make sure they match up with
421 // the values passed to the suspends.
422 auto ResultTys = getRetconResultTypes();
423 auto ResumeTys = getRetconResumeTypes();
424
425 for (auto AnySuspend : CoroSuspends) {
426 auto Suspend = dyn_cast<CoroSuspendRetconInst>(AnySuspend);
427 if (!Suspend) {
428 #ifndef NDEBUG
429 AnySuspend->dump();
430 #endif
431 report_fatal_error("coro.id.retcon.* must be paired with "
432 "coro.suspend.retcon");
433 }
434
435 // Check that the argument types of the suspend match the results.
436 auto SI = Suspend->value_begin(), SE = Suspend->value_end();
437 auto RI = ResultTys.begin(), RE = ResultTys.end();
438 for (; SI != SE && RI != RE; ++SI, ++RI) {
439 auto SrcTy = (*SI)->getType();
440 if (SrcTy != *RI) {
441 // The optimizer likes to eliminate bitcasts leading into variadic
442 // calls, but that messes with our invariants. Re-insert the
443 // bitcast and ignore this type mismatch.
444 if (CastInst::isBitCastable(SrcTy, *RI)) {
445 auto BCI = new BitCastInst(*SI, *RI, "", Suspend);
446 SI->set(BCI);
447 continue;
448 }
449
450 #ifndef NDEBUG
451 Suspend->dump();
452 Prototype->getFunctionType()->dump();
453 #endif
454 report_fatal_error("argument to coro.suspend.retcon does not "
455 "match corresponding prototype function result");
456 }
457 }
458 if (SI != SE || RI != RE) {
459 #ifndef NDEBUG
460 Suspend->dump();
461 Prototype->getFunctionType()->dump();
462 #endif
463 report_fatal_error("wrong number of arguments to coro.suspend.retcon");
464 }
465
466 // Check that the result type of the suspend matches the resume types.
467 Type *SResultTy = Suspend->getType();
468 ArrayRef<Type*> SuspendResultTys;
469 if (SResultTy->isVoidTy()) {
470 // leave as empty array
471 } else if (auto SResultStructTy = dyn_cast<StructType>(SResultTy)) {
472 SuspendResultTys = SResultStructTy->elements();
473 } else {
474 // forms an ArrayRef using SResultTy, be careful
475 SuspendResultTys = SResultTy;
476 }
477 if (SuspendResultTys.size() != ResumeTys.size()) {
478 #ifndef NDEBUG
479 Suspend->dump();
480 Prototype->getFunctionType()->dump();
481 #endif
482 report_fatal_error("wrong number of results from coro.suspend.retcon");
483 }
484 for (size_t I = 0, E = ResumeTys.size(); I != E; ++I) {
485 if (SuspendResultTys[I] != ResumeTys[I]) {
486 #ifndef NDEBUG
487 Suspend->dump();
488 Prototype->getFunctionType()->dump();
489 #endif
490 report_fatal_error("result from coro.suspend.retcon does not "
491 "match corresponding prototype function param");
492 }
493 }
494 }
495 break;
496 }
497
498 default:
499 llvm_unreachable("coro.begin is not dependent on a coro.id call");
500 }
501
502 // The coro.free intrinsic is always lowered to the result of coro.begin.
503 for (CoroFrameInst *CF : CoroFrames) {
504 CF->replaceAllUsesWith(CoroBegin);
505 CF->eraseFromParent();
506 }
507
508 // Move final suspend to be the last element in the CoroSuspends vector.
509 if (ABI == coro::ABI::Switch &&
510 SwitchLowering.HasFinalSuspend &&
511 FinalSuspendIndex != CoroSuspends.size() - 1)
512 std::swap(CoroSuspends[FinalSuspendIndex], CoroSuspends.back());
513
514 // Remove orphaned coro.saves.
515 for (CoroSaveInst *CoroSave : UnusedCoroSaves)
516 CoroSave->eraseFromParent();
517 }
518
propagateCallAttrsFromCallee(CallInst * Call,Function * Callee)519 static void propagateCallAttrsFromCallee(CallInst *Call, Function *Callee) {
520 Call->setCallingConv(Callee->getCallingConv());
521 // TODO: attributes?
522 }
523
addCallToCallGraph(CallGraph * CG,CallInst * Call,Function * Callee)524 static void addCallToCallGraph(CallGraph *CG, CallInst *Call, Function *Callee){
525 if (CG)
526 (*CG)[Call->getFunction()]->addCalledFunction(Call, (*CG)[Callee]);
527 }
528
emitAlloc(IRBuilder<> & Builder,Value * Size,CallGraph * CG) const529 Value *coro::Shape::emitAlloc(IRBuilder<> &Builder, Value *Size,
530 CallGraph *CG) const {
531 switch (ABI) {
532 case coro::ABI::Switch:
533 llvm_unreachable("can't allocate memory in coro switch-lowering");
534
535 case coro::ABI::Retcon:
536 case coro::ABI::RetconOnce: {
537 auto Alloc = RetconLowering.Alloc;
538 Size = Builder.CreateIntCast(Size,
539 Alloc->getFunctionType()->getParamType(0),
540 /*is signed*/ false);
541 auto *Call = Builder.CreateCall(Alloc, Size);
542 propagateCallAttrsFromCallee(Call, Alloc);
543 addCallToCallGraph(CG, Call, Alloc);
544 return Call;
545 }
546 case coro::ABI::Async:
547 llvm_unreachable("can't allocate memory in coro async-lowering");
548 }
549 llvm_unreachable("Unknown coro::ABI enum");
550 }
551
emitDealloc(IRBuilder<> & Builder,Value * Ptr,CallGraph * CG) const552 void coro::Shape::emitDealloc(IRBuilder<> &Builder, Value *Ptr,
553 CallGraph *CG) const {
554 switch (ABI) {
555 case coro::ABI::Switch:
556 llvm_unreachable("can't allocate memory in coro switch-lowering");
557
558 case coro::ABI::Retcon:
559 case coro::ABI::RetconOnce: {
560 auto Dealloc = RetconLowering.Dealloc;
561 Ptr = Builder.CreateBitCast(Ptr,
562 Dealloc->getFunctionType()->getParamType(0));
563 auto *Call = Builder.CreateCall(Dealloc, Ptr);
564 propagateCallAttrsFromCallee(Call, Dealloc);
565 addCallToCallGraph(CG, Call, Dealloc);
566 return;
567 }
568 case coro::ABI::Async:
569 llvm_unreachable("can't allocate memory in coro async-lowering");
570 }
571 llvm_unreachable("Unknown coro::ABI enum");
572 }
573
574 LLVM_ATTRIBUTE_NORETURN
fail(const Instruction * I,const char * Reason,Value * V)575 static void fail(const Instruction *I, const char *Reason, Value *V) {
576 #ifndef NDEBUG
577 I->dump();
578 if (V) {
579 errs() << " Value: ";
580 V->printAsOperand(llvm::errs());
581 errs() << '\n';
582 }
583 #endif
584 report_fatal_error(Reason);
585 }
586
587 /// Check that the given value is a well-formed prototype for the
588 /// llvm.coro.id.retcon.* intrinsics.
checkWFRetconPrototype(const AnyCoroIdRetconInst * I,Value * V)589 static void checkWFRetconPrototype(const AnyCoroIdRetconInst *I, Value *V) {
590 auto F = dyn_cast<Function>(V->stripPointerCasts());
591 if (!F)
592 fail(I, "llvm.coro.id.retcon.* prototype not a Function", V);
593
594 auto FT = F->getFunctionType();
595
596 if (isa<CoroIdRetconInst>(I)) {
597 bool ResultOkay;
598 if (FT->getReturnType()->isPointerTy()) {
599 ResultOkay = true;
600 } else if (auto SRetTy = dyn_cast<StructType>(FT->getReturnType())) {
601 ResultOkay = (!SRetTy->isOpaque() &&
602 SRetTy->getNumElements() > 0 &&
603 SRetTy->getElementType(0)->isPointerTy());
604 } else {
605 ResultOkay = false;
606 }
607 if (!ResultOkay)
608 fail(I, "llvm.coro.id.retcon prototype must return pointer as first "
609 "result", F);
610
611 if (FT->getReturnType() !=
612 I->getFunction()->getFunctionType()->getReturnType())
613 fail(I, "llvm.coro.id.retcon prototype return type must be same as"
614 "current function return type", F);
615 } else {
616 // No meaningful validation to do here for llvm.coro.id.unique.once.
617 }
618
619 if (FT->getNumParams() == 0 || !FT->getParamType(0)->isPointerTy())
620 fail(I, "llvm.coro.id.retcon.* prototype must take pointer as "
621 "its first parameter", F);
622 }
623
624 /// Check that the given value is a well-formed allocator.
checkWFAlloc(const Instruction * I,Value * V)625 static void checkWFAlloc(const Instruction *I, Value *V) {
626 auto F = dyn_cast<Function>(V->stripPointerCasts());
627 if (!F)
628 fail(I, "llvm.coro.* allocator not a Function", V);
629
630 auto FT = F->getFunctionType();
631 if (!FT->getReturnType()->isPointerTy())
632 fail(I, "llvm.coro.* allocator must return a pointer", F);
633
634 if (FT->getNumParams() != 1 ||
635 !FT->getParamType(0)->isIntegerTy())
636 fail(I, "llvm.coro.* allocator must take integer as only param", F);
637 }
638
639 /// Check that the given value is a well-formed deallocator.
checkWFDealloc(const Instruction * I,Value * V)640 static void checkWFDealloc(const Instruction *I, Value *V) {
641 auto F = dyn_cast<Function>(V->stripPointerCasts());
642 if (!F)
643 fail(I, "llvm.coro.* deallocator not a Function", V);
644
645 auto FT = F->getFunctionType();
646 if (!FT->getReturnType()->isVoidTy())
647 fail(I, "llvm.coro.* deallocator must return void", F);
648
649 if (FT->getNumParams() != 1 ||
650 !FT->getParamType(0)->isPointerTy())
651 fail(I, "llvm.coro.* deallocator must take pointer as only param", F);
652 }
653
checkConstantInt(const Instruction * I,Value * V,const char * Reason)654 static void checkConstantInt(const Instruction *I, Value *V,
655 const char *Reason) {
656 if (!isa<ConstantInt>(V)) {
657 fail(I, Reason, V);
658 }
659 }
660
checkWellFormed() const661 void AnyCoroIdRetconInst::checkWellFormed() const {
662 checkConstantInt(this, getArgOperand(SizeArg),
663 "size argument to coro.id.retcon.* must be constant");
664 checkConstantInt(this, getArgOperand(AlignArg),
665 "alignment argument to coro.id.retcon.* must be constant");
666 checkWFRetconPrototype(this, getArgOperand(PrototypeArg));
667 checkWFAlloc(this, getArgOperand(AllocArg));
668 checkWFDealloc(this, getArgOperand(DeallocArg));
669 }
670
checkAsyncFuncPointer(const Instruction * I,Value * V)671 static void checkAsyncFuncPointer(const Instruction *I, Value *V) {
672 auto *AsyncFuncPtrAddr = dyn_cast<GlobalVariable>(V->stripPointerCasts());
673 if (!AsyncFuncPtrAddr)
674 fail(I, "llvm.coro.id.async async function pointer not a global", V);
675
676 auto *StructTy =
677 cast<StructType>(AsyncFuncPtrAddr->getType()->getPointerElementType());
678 if (StructTy->isOpaque() || !StructTy->isPacked() ||
679 StructTy->getNumElements() != 2 ||
680 !StructTy->getElementType(0)->isIntegerTy(32) ||
681 !StructTy->getElementType(1)->isIntegerTy(32))
682 fail(I,
683 "llvm.coro.id.async async function pointer argument's type is not "
684 "<{i32, i32}>",
685 V);
686 }
687
checkWellFormed() const688 void CoroIdAsyncInst::checkWellFormed() const {
689 checkConstantInt(this, getArgOperand(SizeArg),
690 "size argument to coro.id.async must be constant");
691 checkConstantInt(this, getArgOperand(AlignArg),
692 "alignment argument to coro.id.async must be constant");
693 checkConstantInt(this, getArgOperand(StorageArg),
694 "storage argument offset to coro.id.async must be constant");
695 checkAsyncFuncPointer(this, getArgOperand(AsyncFuncPtrArg));
696 }
697
checkAsyncContextProjectFunction(const Instruction * I,Function * F)698 static void checkAsyncContextProjectFunction(const Instruction *I,
699 Function *F) {
700 auto *FunTy = cast<FunctionType>(F->getType()->getPointerElementType());
701 if (!FunTy->getReturnType()->isPointerTy() ||
702 !FunTy->getReturnType()->getPointerElementType()->isIntegerTy(8))
703 fail(I,
704 "llvm.coro.suspend.async resume function projection function must "
705 "return an i8* type",
706 F);
707 if (FunTy->getNumParams() != 1 || !FunTy->getParamType(0)->isPointerTy() ||
708 !FunTy->getParamType(0)->getPointerElementType()->isIntegerTy(8))
709 fail(I,
710 "llvm.coro.suspend.async resume function projection function must "
711 "take one i8* type as parameter",
712 F);
713 }
714
checkWellFormed() const715 void CoroSuspendAsyncInst::checkWellFormed() const {
716 checkAsyncContextProjectFunction(this, getAsyncContextProjectionFunction());
717 }
718
checkWellFormed() const719 void CoroAsyncEndInst::checkWellFormed() const {
720 auto *MustTailCallFunc = getMustTailCallFunction();
721 if (!MustTailCallFunc)
722 return;
723 auto *FnTy =
724 cast<FunctionType>(MustTailCallFunc->getType()->getPointerElementType());
725 if (FnTy->getNumParams() != (getNumArgOperands() - 3))
726 fail(this,
727 "llvm.coro.end.async must tail call function argument type must "
728 "match the tail arguments",
729 MustTailCallFunc);
730 }
731
LLVMAddCoroEarlyPass(LLVMPassManagerRef PM)732 void LLVMAddCoroEarlyPass(LLVMPassManagerRef PM) {
733 unwrap(PM)->add(createCoroEarlyLegacyPass());
734 }
735
LLVMAddCoroSplitPass(LLVMPassManagerRef PM)736 void LLVMAddCoroSplitPass(LLVMPassManagerRef PM) {
737 unwrap(PM)->add(createCoroSplitLegacyPass());
738 }
739
LLVMAddCoroElidePass(LLVMPassManagerRef PM)740 void LLVMAddCoroElidePass(LLVMPassManagerRef PM) {
741 unwrap(PM)->add(createCoroElideLegacyPass());
742 }
743
LLVMAddCoroCleanupPass(LLVMPassManagerRef PM)744 void LLVMAddCoroCleanupPass(LLVMPassManagerRef PM) {
745 unwrap(PM)->add(createCoroCleanupLegacyPass());
746 }
747
748 void
LLVMPassManagerBuilderAddCoroutinePassesToExtensionPoints(LLVMPassManagerBuilderRef PMB)749 LLVMPassManagerBuilderAddCoroutinePassesToExtensionPoints(LLVMPassManagerBuilderRef PMB) {
750 PassManagerBuilder *Builder = unwrap(PMB);
751 addCoroutinePassesToExtensionPoints(*Builder);
752 }
753