xref: /llvm-project/llvm/lib/Transforms/Coroutines/Coroutines.cpp (revision 2fe81edef6f0b35abffbbc59b649b30ea9c15a62)
1 //===- Coroutines.cpp -----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the common infrastructure for Coroutine Passes.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "CoroInstr.h"
14 #include "CoroInternal.h"
15 #include "llvm/ADT/SmallVector.h"
16 #include "llvm/ADT/StringRef.h"
17 #include "llvm/Analysis/CallGraph.h"
18 #include "llvm/IR/Attributes.h"
19 #include "llvm/IR/Constants.h"
20 #include "llvm/IR/DerivedTypes.h"
21 #include "llvm/IR/Function.h"
22 #include "llvm/IR/InstIterator.h"
23 #include "llvm/IR/Instructions.h"
24 #include "llvm/IR/IntrinsicInst.h"
25 #include "llvm/IR/Intrinsics.h"
26 #include "llvm/IR/Module.h"
27 #include "llvm/IR/Type.h"
28 #include "llvm/Support/Casting.h"
29 #include "llvm/Support/ErrorHandling.h"
30 #include "llvm/Transforms/Utils/Local.h"
31 #include <cassert>
32 #include <cstddef>
33 #include <utility>
34 
35 using namespace llvm;
36 
37 // Construct the lowerer base class and initialize its members.
38 coro::LowererBase::LowererBase(Module &M)
39     : TheModule(M), Context(M.getContext()),
40       Int8Ptr(PointerType::get(Context, 0)),
41       ResumeFnType(FunctionType::get(Type::getVoidTy(Context), Int8Ptr,
42                                      /*isVarArg=*/false)),
43       NullPtr(ConstantPointerNull::get(Int8Ptr)) {}
44 
45 // Creates a call to llvm.coro.subfn.addr to obtain a resume function address.
46 // It generates the following:
47 //
48 //    call ptr @llvm.coro.subfn.addr(ptr %Arg, i8 %index)
49 
50 Value *coro::LowererBase::makeSubFnCall(Value *Arg, int Index,
51                                         Instruction *InsertPt) {
52   auto *IndexVal = ConstantInt::get(Type::getInt8Ty(Context), Index);
53   auto *Fn = Intrinsic::getDeclaration(&TheModule, Intrinsic::coro_subfn_addr);
54 
55   assert(Index >= CoroSubFnInst::IndexFirst &&
56          Index < CoroSubFnInst::IndexLast &&
57          "makeSubFnCall: Index value out of range");
58   return CallInst::Create(Fn, {Arg, IndexVal}, "", InsertPt->getIterator());
59 }
60 
// Names of the coroutine intrinsics recognized by this file.
//
// NOTE: Must be sorted! isCoroutineIntrinsicName() looks names up in this
// table with Intrinsic::lookupLLVMIntrinsicByName, which relies on the table
// being in lexicographic order (presumably it binary-searches). Keep new
// entries in strict sorted position.
static const char *const CoroIntrinsics[] = {
    // llvm.coro.a*
    "llvm.coro.align", "llvm.coro.alloc",
    "llvm.coro.async.context.alloc", "llvm.coro.async.context.dealloc",
    "llvm.coro.async.resume", "llvm.coro.async.size.replace",
    "llvm.coro.async.store_resume",
    // llvm.coro.b* .. llvm.coro.f*
    "llvm.coro.begin", "llvm.coro.destroy", "llvm.coro.done",
    "llvm.coro.end", "llvm.coro.end.async", "llvm.coro.frame",
    "llvm.coro.free",
    // llvm.coro.i* .. llvm.coro.r*
    "llvm.coro.id", "llvm.coro.id.async", "llvm.coro.id.retcon",
    "llvm.coro.id.retcon.once", "llvm.coro.noop", "llvm.coro.prepare.async",
    "llvm.coro.prepare.retcon", "llvm.coro.promise", "llvm.coro.resume",
    // llvm.coro.s*
    "llvm.coro.save", "llvm.coro.size", "llvm.coro.subfn.addr",
    "llvm.coro.suspend", "llvm.coro.suspend.async",
    "llvm.coro.suspend.retcon",
};
93 
94 #ifndef NDEBUG
95 static bool isCoroutineIntrinsicName(StringRef Name) {
96   return Intrinsic::lookupLLVMIntrinsicByName(CoroIntrinsics, Name) != -1;
97 }
98 #endif
99 
100 bool coro::declaresAnyIntrinsic(const Module &M) {
101   for (StringRef Name : CoroIntrinsics) {
102     assert(isCoroutineIntrinsicName(Name) && "not a coroutine intrinsic");
103     if (M.getNamedValue(Name))
104       return true;
105   }
106 
107   return false;
108 }
109 
110 // Verifies if a module has named values listed. Also, in debug mode verifies
111 // that names are intrinsic names.
112 bool coro::declaresIntrinsics(const Module &M,
113                               const std::initializer_list<StringRef> List) {
114   for (StringRef Name : List) {
115     assert(isCoroutineIntrinsicName(Name) && "not a coroutine intrinsic");
116     if (M.getNamedValue(Name))
117       return true;
118   }
119 
120   return false;
121 }
122 
123 // Replace all coro.frees associated with the provided CoroId either with 'null'
124 // if Elide is true and with its frame parameter otherwise.
125 void coro::replaceCoroFree(CoroIdInst *CoroId, bool Elide) {
126   SmallVector<CoroFreeInst *, 4> CoroFrees;
127   for (User *U : CoroId->users())
128     if (auto CF = dyn_cast<CoroFreeInst>(U))
129       CoroFrees.push_back(CF);
130 
131   if (CoroFrees.empty())
132     return;
133 
134   Value *Replacement =
135       Elide
136           ? ConstantPointerNull::get(PointerType::get(CoroId->getContext(), 0))
137           : CoroFrees.front()->getFrame();
138 
139   for (CoroFreeInst *CF : CoroFrees) {
140     CF->replaceAllUsesWith(Replacement);
141     CF->eraseFromParent();
142   }
143 }
144 
145 static void clear(coro::Shape &Shape) {
146   Shape.CoroBegin = nullptr;
147   Shape.CoroEnds.clear();
148   Shape.CoroSizes.clear();
149   Shape.CoroSuspends.clear();
150 
151   Shape.FrameTy = nullptr;
152   Shape.FramePtr = nullptr;
153   Shape.AllocaSpillBlock = nullptr;
154 }
155 
156 static CoroSaveInst *createCoroSave(CoroBeginInst *CoroBegin,
157                                     CoroSuspendInst *SuspendInst) {
158   Module *M = SuspendInst->getModule();
159   auto *Fn = Intrinsic::getDeclaration(M, Intrinsic::coro_save);
160   auto *SaveInst = cast<CoroSaveInst>(
161       CallInst::Create(Fn, CoroBegin, "", SuspendInst->getIterator()));
162   assert(!SuspendInst->getCoroSave());
163   SuspendInst->setArgOperand(0, SaveInst);
164   return SaveInst;
165 }
166 
167 // Collect "interesting" coroutine intrinsics.
168 void coro::Shape::buildFrom(Function &F) {
169   bool HasFinalSuspend = false;
170   bool HasUnwindCoroEnd = false;
171   size_t FinalSuspendIndex = 0;
172   clear(*this);
173   SmallVector<CoroFrameInst *, 8> CoroFrames;
174   SmallVector<CoroSaveInst *, 2> UnusedCoroSaves;
175 
176   for (Instruction &I : instructions(F)) {
177     if (auto II = dyn_cast<IntrinsicInst>(&I)) {
178       switch (II->getIntrinsicID()) {
179       default:
180         continue;
181       case Intrinsic::coro_size:
182         CoroSizes.push_back(cast<CoroSizeInst>(II));
183         break;
184       case Intrinsic::coro_align:
185         CoroAligns.push_back(cast<CoroAlignInst>(II));
186         break;
187       case Intrinsic::coro_frame:
188         CoroFrames.push_back(cast<CoroFrameInst>(II));
189         break;
190       case Intrinsic::coro_save:
191         // After optimizations, coro_suspends using this coro_save might have
192         // been removed, remember orphaned coro_saves to remove them later.
193         if (II->use_empty())
194           UnusedCoroSaves.push_back(cast<CoroSaveInst>(II));
195         break;
196       case Intrinsic::coro_suspend_async: {
197         auto *Suspend = cast<CoroSuspendAsyncInst>(II);
198         Suspend->checkWellFormed();
199         CoroSuspends.push_back(Suspend);
200         break;
201       }
202       case Intrinsic::coro_suspend_retcon: {
203         auto Suspend = cast<CoroSuspendRetconInst>(II);
204         CoroSuspends.push_back(Suspend);
205         break;
206       }
207       case Intrinsic::coro_suspend: {
208         auto Suspend = cast<CoroSuspendInst>(II);
209         CoroSuspends.push_back(Suspend);
210         if (Suspend->isFinal()) {
211           if (HasFinalSuspend)
212             report_fatal_error(
213               "Only one suspend point can be marked as final");
214           HasFinalSuspend = true;
215           FinalSuspendIndex = CoroSuspends.size() - 1;
216         }
217         break;
218       }
219       case Intrinsic::coro_begin: {
220         auto CB = cast<CoroBeginInst>(II);
221 
222         // Ignore coro id's that aren't pre-split.
223         auto Id = dyn_cast<CoroIdInst>(CB->getId());
224         if (Id && !Id->getInfo().isPreSplit())
225           break;
226 
227         if (CoroBegin)
228           report_fatal_error(
229                 "coroutine should have exactly one defining @llvm.coro.begin");
230         CB->addRetAttr(Attribute::NonNull);
231         CB->addRetAttr(Attribute::NoAlias);
232         CB->removeFnAttr(Attribute::NoDuplicate);
233         CoroBegin = CB;
234         break;
235       }
236       case Intrinsic::coro_end_async:
237       case Intrinsic::coro_end:
238         CoroEnds.push_back(cast<AnyCoroEndInst>(II));
239         if (auto *AsyncEnd = dyn_cast<CoroAsyncEndInst>(II)) {
240           AsyncEnd->checkWellFormed();
241         }
242 
243         if (CoroEnds.back()->isUnwind())
244           HasUnwindCoroEnd = true;
245 
246         if (CoroEnds.back()->isFallthrough() && isa<CoroEndInst>(II)) {
247           // Make sure that the fallthrough coro.end is the first element in the
248           // CoroEnds vector.
249           // Note: I don't think this is neccessary anymore.
250           if (CoroEnds.size() > 1) {
251             if (CoroEnds.front()->isFallthrough())
252               report_fatal_error(
253                   "Only one coro.end can be marked as fallthrough");
254             std::swap(CoroEnds.front(), CoroEnds.back());
255           }
256         }
257         break;
258       }
259     }
260   }
261 
262   // If for some reason, we were not able to find coro.begin, bailout.
263   if (!CoroBegin) {
264     // Replace coro.frame which are supposed to be lowered to the result of
265     // coro.begin with undef.
266     auto *Undef = UndefValue::get(PointerType::get(F.getContext(), 0));
267     for (CoroFrameInst *CF : CoroFrames) {
268       CF->replaceAllUsesWith(Undef);
269       CF->eraseFromParent();
270     }
271 
272     // Replace all coro.suspend with undef and remove related coro.saves if
273     // present.
274     for (AnyCoroSuspendInst *CS : CoroSuspends) {
275       CS->replaceAllUsesWith(UndefValue::get(CS->getType()));
276       CS->eraseFromParent();
277       if (auto *CoroSave = CS->getCoroSave())
278         CoroSave->eraseFromParent();
279     }
280 
281     // Replace all coro.ends with unreachable instruction.
282     for (AnyCoroEndInst *CE : CoroEnds)
283       changeToUnreachable(CE);
284 
285     return;
286   }
287 
288   auto Id = CoroBegin->getId();
289   switch (auto IdIntrinsic = Id->getIntrinsicID()) {
290   case Intrinsic::coro_id: {
291     auto SwitchId = cast<CoroIdInst>(Id);
292     this->ABI = coro::ABI::Switch;
293     this->SwitchLowering.HasFinalSuspend = HasFinalSuspend;
294     this->SwitchLowering.HasUnwindCoroEnd = HasUnwindCoroEnd;
295     this->SwitchLowering.ResumeSwitch = nullptr;
296     this->SwitchLowering.PromiseAlloca = SwitchId->getPromise();
297     this->SwitchLowering.ResumeEntryBlock = nullptr;
298 
299     for (auto *AnySuspend : CoroSuspends) {
300       auto Suspend = dyn_cast<CoroSuspendInst>(AnySuspend);
301       if (!Suspend) {
302 #ifndef NDEBUG
303         AnySuspend->dump();
304 #endif
305         report_fatal_error("coro.id must be paired with coro.suspend");
306       }
307 
308       if (!Suspend->getCoroSave())
309         createCoroSave(CoroBegin, Suspend);
310     }
311     break;
312   }
313   case Intrinsic::coro_id_async: {
314     auto *AsyncId = cast<CoroIdAsyncInst>(Id);
315     AsyncId->checkWellFormed();
316     this->ABI = coro::ABI::Async;
317     this->AsyncLowering.Context = AsyncId->getStorage();
318     this->AsyncLowering.ContextArgNo = AsyncId->getStorageArgumentIndex();
319     this->AsyncLowering.ContextHeaderSize = AsyncId->getStorageSize();
320     this->AsyncLowering.ContextAlignment =
321         AsyncId->getStorageAlignment().value();
322     this->AsyncLowering.AsyncFuncPointer = AsyncId->getAsyncFunctionPointer();
323     this->AsyncLowering.AsyncCC = F.getCallingConv();
324     break;
325   };
326   case Intrinsic::coro_id_retcon:
327   case Intrinsic::coro_id_retcon_once: {
328     auto ContinuationId = cast<AnyCoroIdRetconInst>(Id);
329     ContinuationId->checkWellFormed();
330     this->ABI = (IdIntrinsic == Intrinsic::coro_id_retcon
331                   ? coro::ABI::Retcon
332                   : coro::ABI::RetconOnce);
333     auto Prototype = ContinuationId->getPrototype();
334     this->RetconLowering.ResumePrototype = Prototype;
335     this->RetconLowering.Alloc = ContinuationId->getAllocFunction();
336     this->RetconLowering.Dealloc = ContinuationId->getDeallocFunction();
337     this->RetconLowering.ReturnBlock = nullptr;
338     this->RetconLowering.IsFrameInlineInStorage = false;
339 
340     // Determine the result value types, and make sure they match up with
341     // the values passed to the suspends.
342     auto ResultTys = getRetconResultTypes();
343     auto ResumeTys = getRetconResumeTypes();
344 
345     for (auto *AnySuspend : CoroSuspends) {
346       auto Suspend = dyn_cast<CoroSuspendRetconInst>(AnySuspend);
347       if (!Suspend) {
348 #ifndef NDEBUG
349         AnySuspend->dump();
350 #endif
351         report_fatal_error("coro.id.retcon.* must be paired with "
352                            "coro.suspend.retcon");
353       }
354 
355       // Check that the argument types of the suspend match the results.
356       auto SI = Suspend->value_begin(), SE = Suspend->value_end();
357       auto RI = ResultTys.begin(), RE = ResultTys.end();
358       for (; SI != SE && RI != RE; ++SI, ++RI) {
359         auto SrcTy = (*SI)->getType();
360         if (SrcTy != *RI) {
361           // The optimizer likes to eliminate bitcasts leading into variadic
362           // calls, but that messes with our invariants.  Re-insert the
363           // bitcast and ignore this type mismatch.
364           if (CastInst::isBitCastable(SrcTy, *RI)) {
365             auto BCI = new BitCastInst(*SI, *RI, "", Suspend->getIterator());
366             SI->set(BCI);
367             continue;
368           }
369 
370 #ifndef NDEBUG
371           Suspend->dump();
372           Prototype->getFunctionType()->dump();
373 #endif
374           report_fatal_error("argument to coro.suspend.retcon does not "
375                              "match corresponding prototype function result");
376         }
377       }
378       if (SI != SE || RI != RE) {
379 #ifndef NDEBUG
380         Suspend->dump();
381         Prototype->getFunctionType()->dump();
382 #endif
383         report_fatal_error("wrong number of arguments to coro.suspend.retcon");
384       }
385 
386       // Check that the result type of the suspend matches the resume types.
387       Type *SResultTy = Suspend->getType();
388       ArrayRef<Type*> SuspendResultTys;
389       if (SResultTy->isVoidTy()) {
390         // leave as empty array
391       } else if (auto SResultStructTy = dyn_cast<StructType>(SResultTy)) {
392         SuspendResultTys = SResultStructTy->elements();
393       } else {
394         // forms an ArrayRef using SResultTy, be careful
395         SuspendResultTys = SResultTy;
396       }
397       if (SuspendResultTys.size() != ResumeTys.size()) {
398 #ifndef NDEBUG
399         Suspend->dump();
400         Prototype->getFunctionType()->dump();
401 #endif
402         report_fatal_error("wrong number of results from coro.suspend.retcon");
403       }
404       for (size_t I = 0, E = ResumeTys.size(); I != E; ++I) {
405         if (SuspendResultTys[I] != ResumeTys[I]) {
406 #ifndef NDEBUG
407           Suspend->dump();
408           Prototype->getFunctionType()->dump();
409 #endif
410           report_fatal_error("result from coro.suspend.retcon does not "
411                              "match corresponding prototype function param");
412         }
413       }
414     }
415     break;
416   }
417 
418   default:
419     llvm_unreachable("coro.begin is not dependent on a coro.id call");
420   }
421 
422   // The coro.free intrinsic is always lowered to the result of coro.begin.
423   for (CoroFrameInst *CF : CoroFrames) {
424     CF->replaceAllUsesWith(CoroBegin);
425     CF->eraseFromParent();
426   }
427 
428   // Move final suspend to be the last element in the CoroSuspends vector.
429   if (ABI == coro::ABI::Switch &&
430       SwitchLowering.HasFinalSuspend &&
431       FinalSuspendIndex != CoroSuspends.size() - 1)
432     std::swap(CoroSuspends[FinalSuspendIndex], CoroSuspends.back());
433 
434   // Remove orphaned coro.saves.
435   for (CoroSaveInst *CoroSave : UnusedCoroSaves)
436     CoroSave->eraseFromParent();
437 }
438 
439 static void propagateCallAttrsFromCallee(CallInst *Call, Function *Callee) {
440   Call->setCallingConv(Callee->getCallingConv());
441   // TODO: attributes?
442 }
443 
444 static void addCallToCallGraph(CallGraph *CG, CallInst *Call, Function *Callee){
445   if (CG)
446     (*CG)[Call->getFunction()]->addCalledFunction(Call, (*CG)[Callee]);
447 }
448 
449 Value *coro::Shape::emitAlloc(IRBuilder<> &Builder, Value *Size,
450                               CallGraph *CG) const {
451   switch (ABI) {
452   case coro::ABI::Switch:
453     llvm_unreachable("can't allocate memory in coro switch-lowering");
454 
455   case coro::ABI::Retcon:
456   case coro::ABI::RetconOnce: {
457     auto Alloc = RetconLowering.Alloc;
458     Size = Builder.CreateIntCast(Size,
459                                  Alloc->getFunctionType()->getParamType(0),
460                                  /*is signed*/ false);
461     auto *Call = Builder.CreateCall(Alloc, Size);
462     propagateCallAttrsFromCallee(Call, Alloc);
463     addCallToCallGraph(CG, Call, Alloc);
464     return Call;
465   }
466   case coro::ABI::Async:
467     llvm_unreachable("can't allocate memory in coro async-lowering");
468   }
469   llvm_unreachable("Unknown coro::ABI enum");
470 }
471 
472 void coro::Shape::emitDealloc(IRBuilder<> &Builder, Value *Ptr,
473                               CallGraph *CG) const {
474   switch (ABI) {
475   case coro::ABI::Switch:
476     llvm_unreachable("can't allocate memory in coro switch-lowering");
477 
478   case coro::ABI::Retcon:
479   case coro::ABI::RetconOnce: {
480     auto Dealloc = RetconLowering.Dealloc;
481     Ptr = Builder.CreateBitCast(Ptr,
482                                 Dealloc->getFunctionType()->getParamType(0));
483     auto *Call = Builder.CreateCall(Dealloc, Ptr);
484     propagateCallAttrsFromCallee(Call, Dealloc);
485     addCallToCallGraph(CG, Call, Dealloc);
486     return;
487   }
488   case coro::ABI::Async:
489     llvm_unreachable("can't allocate memory in coro async-lowering");
490   }
491   llvm_unreachable("Unknown coro::ABI enum");
492 }
493 
494 [[noreturn]] static void fail(const Instruction *I, const char *Reason,
495                               Value *V) {
496 #ifndef NDEBUG
497   I->dump();
498   if (V) {
499     errs() << "  Value: ";
500     V->printAsOperand(llvm::errs());
501     errs() << '\n';
502   }
503 #endif
504   report_fatal_error(Reason);
505 }
506 
507 /// Check that the given value is a well-formed prototype for the
508 /// llvm.coro.id.retcon.* intrinsics.
509 static void checkWFRetconPrototype(const AnyCoroIdRetconInst *I, Value *V) {
510   auto F = dyn_cast<Function>(V->stripPointerCasts());
511   if (!F)
512     fail(I, "llvm.coro.id.retcon.* prototype not a Function", V);
513 
514   auto FT = F->getFunctionType();
515 
516   if (isa<CoroIdRetconInst>(I)) {
517     bool ResultOkay;
518     if (FT->getReturnType()->isPointerTy()) {
519       ResultOkay = true;
520     } else if (auto SRetTy = dyn_cast<StructType>(FT->getReturnType())) {
521       ResultOkay = (!SRetTy->isOpaque() &&
522                     SRetTy->getNumElements() > 0 &&
523                     SRetTy->getElementType(0)->isPointerTy());
524     } else {
525       ResultOkay = false;
526     }
527     if (!ResultOkay)
528       fail(I, "llvm.coro.id.retcon prototype must return pointer as first "
529               "result", F);
530 
531     if (FT->getReturnType() !=
532           I->getFunction()->getFunctionType()->getReturnType())
533       fail(I, "llvm.coro.id.retcon prototype return type must be same as"
534               "current function return type", F);
535   } else {
536     // No meaningful validation to do here for llvm.coro.id.unique.once.
537   }
538 
539   if (FT->getNumParams() == 0 || !FT->getParamType(0)->isPointerTy())
540     fail(I, "llvm.coro.id.retcon.* prototype must take pointer as "
541             "its first parameter", F);
542 }
543 
544 /// Check that the given value is a well-formed allocator.
545 static void checkWFAlloc(const Instruction *I, Value *V) {
546   auto F = dyn_cast<Function>(V->stripPointerCasts());
547   if (!F)
548     fail(I, "llvm.coro.* allocator not a Function", V);
549 
550   auto FT = F->getFunctionType();
551   if (!FT->getReturnType()->isPointerTy())
552     fail(I, "llvm.coro.* allocator must return a pointer", F);
553 
554   if (FT->getNumParams() != 1 ||
555       !FT->getParamType(0)->isIntegerTy())
556     fail(I, "llvm.coro.* allocator must take integer as only param", F);
557 }
558 
559 /// Check that the given value is a well-formed deallocator.
560 static void checkWFDealloc(const Instruction *I, Value *V) {
561   auto F = dyn_cast<Function>(V->stripPointerCasts());
562   if (!F)
563     fail(I, "llvm.coro.* deallocator not a Function", V);
564 
565   auto FT = F->getFunctionType();
566   if (!FT->getReturnType()->isVoidTy())
567     fail(I, "llvm.coro.* deallocator must return void", F);
568 
569   if (FT->getNumParams() != 1 ||
570       !FT->getParamType(0)->isPointerTy())
571     fail(I, "llvm.coro.* deallocator must take pointer as only param", F);
572 }
573 
574 static void checkConstantInt(const Instruction *I, Value *V,
575                              const char *Reason) {
576   if (!isa<ConstantInt>(V)) {
577     fail(I, Reason, V);
578   }
579 }
580 
581 void AnyCoroIdRetconInst::checkWellFormed() const {
582   checkConstantInt(this, getArgOperand(SizeArg),
583                    "size argument to coro.id.retcon.* must be constant");
584   checkConstantInt(this, getArgOperand(AlignArg),
585                    "alignment argument to coro.id.retcon.* must be constant");
586   checkWFRetconPrototype(this, getArgOperand(PrototypeArg));
587   checkWFAlloc(this, getArgOperand(AllocArg));
588   checkWFDealloc(this, getArgOperand(DeallocArg));
589 }
590 
591 static void checkAsyncFuncPointer(const Instruction *I, Value *V) {
592   auto *AsyncFuncPtrAddr = dyn_cast<GlobalVariable>(V->stripPointerCasts());
593   if (!AsyncFuncPtrAddr)
594     fail(I, "llvm.coro.id.async async function pointer not a global", V);
595 }
596 
597 void CoroIdAsyncInst::checkWellFormed() const {
598   checkConstantInt(this, getArgOperand(SizeArg),
599                    "size argument to coro.id.async must be constant");
600   checkConstantInt(this, getArgOperand(AlignArg),
601                    "alignment argument to coro.id.async must be constant");
602   checkConstantInt(this, getArgOperand(StorageArg),
603                    "storage argument offset to coro.id.async must be constant");
604   checkAsyncFuncPointer(this, getArgOperand(AsyncFuncPtrArg));
605 }
606 
607 static void checkAsyncContextProjectFunction(const Instruction *I,
608                                              Function *F) {
609   auto *FunTy = cast<FunctionType>(F->getValueType());
610   if (!FunTy->getReturnType()->isPointerTy())
611     fail(I,
612          "llvm.coro.suspend.async resume function projection function must "
613          "return a ptr type",
614          F);
615   if (FunTy->getNumParams() != 1 || !FunTy->getParamType(0)->isPointerTy())
616     fail(I,
617          "llvm.coro.suspend.async resume function projection function must "
618          "take one ptr type as parameter",
619          F);
620 }
621 
622 void CoroSuspendAsyncInst::checkWellFormed() const {
623   checkAsyncContextProjectFunction(this, getAsyncContextProjectionFunction());
624 }
625 
626 void CoroAsyncEndInst::checkWellFormed() const {
627   auto *MustTailCallFunc = getMustTailCallFunction();
628   if (!MustTailCallFunc)
629     return;
630   auto *FnTy = MustTailCallFunc->getFunctionType();
631   if (FnTy->getNumParams() != (arg_size() - 3))
632     fail(this,
633          "llvm.coro.end.async must tail call function argument type must "
634          "match the tail arguments",
635          MustTailCallFunc);
636 }
637