xref: /llvm-project/llvm/lib/Transforms/Coroutines/Coroutines.cpp (revision e53b28c8330102bf21ac717ade2bfac9b1f09517)
1 //===- Coroutines.cpp -----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the common infrastructure for Coroutine Passes.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "CoroInstr.h"
14 #include "CoroInternal.h"
15 #include "llvm/ADT/SmallVector.h"
16 #include "llvm/ADT/StringRef.h"
17 #include "llvm/Analysis/CallGraph.h"
18 #include "llvm/IR/Attributes.h"
19 #include "llvm/IR/Constants.h"
20 #include "llvm/IR/DerivedTypes.h"
21 #include "llvm/IR/Function.h"
22 #include "llvm/IR/InstIterator.h"
23 #include "llvm/IR/Instructions.h"
24 #include "llvm/IR/IntrinsicInst.h"
25 #include "llvm/IR/Intrinsics.h"
26 #include "llvm/IR/Module.h"
27 #include "llvm/IR/Type.h"
28 #include "llvm/Support/Casting.h"
29 #include "llvm/Support/ErrorHandling.h"
30 #include "llvm/Transforms/Utils/Local.h"
31 #include <cassert>
32 #include <cstddef>
33 #include <utility>
34 
35 using namespace llvm;
36 
// Construct the lowerer base class and initialize its members.
//
// Caches, for the context of module M:
//  - Int8Ptr:      the pointer type in address space 0,
//  - ResumeFnType: the non-variadic function type `void(Int8Ptr)` used for
//                  coroutine resume/destroy functions,
//  - NullPtr:      a null pointer constant of type Int8Ptr.
// NOTE(review): Int8Ptr is read while initializing ResumeFnType and NullPtr,
// so this relies on the member declaration order in the header matching the
// order of this initializer list — confirm if the header is reordered.
coro::LowererBase::LowererBase(Module &M)
    : TheModule(M), Context(M.getContext()),
      Int8Ptr(PointerType::get(Context, 0)),
      ResumeFnType(FunctionType::get(Type::getVoidTy(Context), Int8Ptr,
                                     /*isVarArg=*/false)),
      NullPtr(ConstantPointerNull::get(Int8Ptr)) {}
44 
45 // Creates a sequence of instructions to obtain a resume function address using
46 // llvm.coro.subfn.addr. It generates the following sequence:
47 //
48 //    call i8* @llvm.coro.subfn.addr(i8* %Arg, i8 %index)
49 //    bitcast i8* %2 to void(i8*)*
50 
51 Value *coro::LowererBase::makeSubFnCall(Value *Arg, int Index,
52                                         Instruction *InsertPt) {
53   auto *IndexVal = ConstantInt::get(Type::getInt8Ty(Context), Index);
54   auto *Fn = Intrinsic::getDeclaration(&TheModule, Intrinsic::coro_subfn_addr);
55 
56   assert(Index >= CoroSubFnInst::IndexFirst &&
57          Index < CoroSubFnInst::IndexLast &&
58          "makeSubFnCall: Index value out of range");
59   auto *Call = CallInst::Create(Fn, {Arg, IndexVal}, "", InsertPt);
60 
61   auto *Bitcast =
62       new BitCastInst(Call, ResumeFnType->getPointerTo(), "", InsertPt);
63   return Bitcast;
64 }
65 
// Names of the coroutine intrinsics known to this file.
//
// NOTE: Must be sorted! The table is passed to
// Intrinsic::lookupLLVMIntrinsicByName (in isCoroutineIntrinsicName), which
// presumably searches it by lexicographic order (binary search). Insert new
// names at their sorted position. coro::declaresAnyIntrinsic also scans this
// table linearly.
static const char *const CoroIntrinsics[] = {
    "llvm.coro.align",
    "llvm.coro.alloc",
    "llvm.coro.async.context.alloc",
    "llvm.coro.async.context.dealloc",
    "llvm.coro.async.resume",
    "llvm.coro.async.size.replace",
    "llvm.coro.async.store_resume",
    "llvm.coro.begin",
    "llvm.coro.destroy",
    "llvm.coro.done",
    "llvm.coro.end",
    "llvm.coro.end.async",
    "llvm.coro.frame",
    "llvm.coro.free",
    "llvm.coro.id",
    "llvm.coro.id.async",
    "llvm.coro.id.retcon",
    "llvm.coro.id.retcon.once",
    "llvm.coro.noop",
    "llvm.coro.prepare.async",
    "llvm.coro.prepare.retcon",
    "llvm.coro.promise",
    "llvm.coro.resume",
    "llvm.coro.save",
    "llvm.coro.size",
    "llvm.coro.subfn.addr",
    "llvm.coro.suspend",
    "llvm.coro.suspend.async",
    "llvm.coro.suspend.retcon",
};
98 
99 #ifndef NDEBUG
100 static bool isCoroutineIntrinsicName(StringRef Name) {
101   return Intrinsic::lookupLLVMIntrinsicByName(CoroIntrinsics, Name) != -1;
102 }
103 #endif
104 
105 bool coro::declaresAnyIntrinsic(const Module &M) {
106   for (StringRef Name : CoroIntrinsics) {
107     assert(isCoroutineIntrinsicName(Name) && "not a coroutine intrinsic");
108     if (M.getNamedValue(Name))
109       return true;
110   }
111 
112   return false;
113 }
114 
115 // Verifies if a module has named values listed. Also, in debug mode verifies
116 // that names are intrinsic names.
117 bool coro::declaresIntrinsics(const Module &M,
118                               const std::initializer_list<StringRef> List) {
119   for (StringRef Name : List) {
120     assert(isCoroutineIntrinsicName(Name) && "not a coroutine intrinsic");
121     if (M.getNamedValue(Name))
122       return true;
123   }
124 
125   return false;
126 }
127 
128 // Replace all coro.frees associated with the provided CoroId either with 'null'
129 // if Elide is true and with its frame parameter otherwise.
130 void coro::replaceCoroFree(CoroIdInst *CoroId, bool Elide) {
131   SmallVector<CoroFreeInst *, 4> CoroFrees;
132   for (User *U : CoroId->users())
133     if (auto CF = dyn_cast<CoroFreeInst>(U))
134       CoroFrees.push_back(CF);
135 
136   if (CoroFrees.empty())
137     return;
138 
139   Value *Replacement =
140       Elide
141           ? ConstantPointerNull::get(PointerType::get(CoroId->getContext(), 0))
142           : CoroFrees.front()->getFrame();
143 
144   for (CoroFreeInst *CF : CoroFrees) {
145     CF->replaceAllUsesWith(Replacement);
146     CF->eraseFromParent();
147   }
148 }
149 
150 static void clear(coro::Shape &Shape) {
151   Shape.CoroBegin = nullptr;
152   Shape.CoroEnds.clear();
153   Shape.CoroSizes.clear();
154   Shape.CoroSuspends.clear();
155 
156   Shape.FrameTy = nullptr;
157   Shape.FramePtr = nullptr;
158   Shape.AllocaSpillBlock = nullptr;
159 }
160 
161 static CoroSaveInst *createCoroSave(CoroBeginInst *CoroBegin,
162                                     CoroSuspendInst *SuspendInst) {
163   Module *M = SuspendInst->getModule();
164   auto *Fn = Intrinsic::getDeclaration(M, Intrinsic::coro_save);
165   auto *SaveInst =
166       cast<CoroSaveInst>(CallInst::Create(Fn, CoroBegin, "", SuspendInst));
167   assert(!SuspendInst->getCoroSave());
168   SuspendInst->setArgOperand(0, SaveInst);
169   return SaveInst;
170 }
171 
172 // Collect "interesting" coroutine intrinsics.
173 void coro::Shape::buildFrom(Function &F) {
174   bool HasFinalSuspend = false;
175   bool HasUnwindCoroEnd = false;
176   size_t FinalSuspendIndex = 0;
177   clear(*this);
178   SmallVector<CoroFrameInst *, 8> CoroFrames;
179   SmallVector<CoroSaveInst *, 2> UnusedCoroSaves;
180 
181   for (Instruction &I : instructions(F)) {
182     if (auto II = dyn_cast<IntrinsicInst>(&I)) {
183       switch (II->getIntrinsicID()) {
184       default:
185         continue;
186       case Intrinsic::coro_size:
187         CoroSizes.push_back(cast<CoroSizeInst>(II));
188         break;
189       case Intrinsic::coro_align:
190         CoroAligns.push_back(cast<CoroAlignInst>(II));
191         break;
192       case Intrinsic::coro_frame:
193         CoroFrames.push_back(cast<CoroFrameInst>(II));
194         break;
195       case Intrinsic::coro_save:
196         // After optimizations, coro_suspends using this coro_save might have
197         // been removed, remember orphaned coro_saves to remove them later.
198         if (II->use_empty())
199           UnusedCoroSaves.push_back(cast<CoroSaveInst>(II));
200         break;
201       case Intrinsic::coro_suspend_async: {
202         auto *Suspend = cast<CoroSuspendAsyncInst>(II);
203         Suspend->checkWellFormed();
204         CoroSuspends.push_back(Suspend);
205         break;
206       }
207       case Intrinsic::coro_suspend_retcon: {
208         auto Suspend = cast<CoroSuspendRetconInst>(II);
209         CoroSuspends.push_back(Suspend);
210         break;
211       }
212       case Intrinsic::coro_suspend: {
213         auto Suspend = cast<CoroSuspendInst>(II);
214         CoroSuspends.push_back(Suspend);
215         if (Suspend->isFinal()) {
216           if (HasFinalSuspend)
217             report_fatal_error(
218               "Only one suspend point can be marked as final");
219           HasFinalSuspend = true;
220           FinalSuspendIndex = CoroSuspends.size() - 1;
221         }
222         break;
223       }
224       case Intrinsic::coro_begin: {
225         auto CB = cast<CoroBeginInst>(II);
226 
227         // Ignore coro id's that aren't pre-split.
228         auto Id = dyn_cast<CoroIdInst>(CB->getId());
229         if (Id && !Id->getInfo().isPreSplit())
230           break;
231 
232         if (CoroBegin)
233           report_fatal_error(
234                 "coroutine should have exactly one defining @llvm.coro.begin");
235         CB->addRetAttr(Attribute::NonNull);
236         CB->addRetAttr(Attribute::NoAlias);
237         CB->removeFnAttr(Attribute::NoDuplicate);
238         CoroBegin = CB;
239         break;
240       }
241       case Intrinsic::coro_end_async:
242       case Intrinsic::coro_end:
243         CoroEnds.push_back(cast<AnyCoroEndInst>(II));
244         if (auto *AsyncEnd = dyn_cast<CoroAsyncEndInst>(II)) {
245           AsyncEnd->checkWellFormed();
246         }
247 
248         if (CoroEnds.back()->isUnwind())
249           HasUnwindCoroEnd = true;
250 
251         if (CoroEnds.back()->isFallthrough() && isa<CoroEndInst>(II)) {
252           // Make sure that the fallthrough coro.end is the first element in the
253           // CoroEnds vector.
254           // Note: I don't think this is neccessary anymore.
255           if (CoroEnds.size() > 1) {
256             if (CoroEnds.front()->isFallthrough())
257               report_fatal_error(
258                   "Only one coro.end can be marked as fallthrough");
259             std::swap(CoroEnds.front(), CoroEnds.back());
260           }
261         }
262         break;
263       }
264     }
265   }
266 
267   // If for some reason, we were not able to find coro.begin, bailout.
268   if (!CoroBegin) {
269     // Replace coro.frame which are supposed to be lowered to the result of
270     // coro.begin with undef.
271     auto *Undef = UndefValue::get(PointerType::get(F.getContext(), 0));
272     for (CoroFrameInst *CF : CoroFrames) {
273       CF->replaceAllUsesWith(Undef);
274       CF->eraseFromParent();
275     }
276 
277     // Replace all coro.suspend with undef and remove related coro.saves if
278     // present.
279     for (AnyCoroSuspendInst *CS : CoroSuspends) {
280       CS->replaceAllUsesWith(UndefValue::get(CS->getType()));
281       CS->eraseFromParent();
282       if (auto *CoroSave = CS->getCoroSave())
283         CoroSave->eraseFromParent();
284     }
285 
286     // Replace all coro.ends with unreachable instruction.
287     for (AnyCoroEndInst *CE : CoroEnds)
288       changeToUnreachable(CE);
289 
290     return;
291   }
292 
293   auto Id = CoroBegin->getId();
294   switch (auto IdIntrinsic = Id->getIntrinsicID()) {
295   case Intrinsic::coro_id: {
296     auto SwitchId = cast<CoroIdInst>(Id);
297     this->ABI = coro::ABI::Switch;
298     this->SwitchLowering.HasFinalSuspend = HasFinalSuspend;
299     this->SwitchLowering.HasUnwindCoroEnd = HasUnwindCoroEnd;
300     this->SwitchLowering.ResumeSwitch = nullptr;
301     this->SwitchLowering.PromiseAlloca = SwitchId->getPromise();
302     this->SwitchLowering.ResumeEntryBlock = nullptr;
303 
304     for (auto *AnySuspend : CoroSuspends) {
305       auto Suspend = dyn_cast<CoroSuspendInst>(AnySuspend);
306       if (!Suspend) {
307 #ifndef NDEBUG
308         AnySuspend->dump();
309 #endif
310         report_fatal_error("coro.id must be paired with coro.suspend");
311       }
312 
313       if (!Suspend->getCoroSave())
314         createCoroSave(CoroBegin, Suspend);
315     }
316     break;
317   }
318   case Intrinsic::coro_id_async: {
319     auto *AsyncId = cast<CoroIdAsyncInst>(Id);
320     AsyncId->checkWellFormed();
321     this->ABI = coro::ABI::Async;
322     this->AsyncLowering.Context = AsyncId->getStorage();
323     this->AsyncLowering.ContextArgNo = AsyncId->getStorageArgumentIndex();
324     this->AsyncLowering.ContextHeaderSize = AsyncId->getStorageSize();
325     this->AsyncLowering.ContextAlignment =
326         AsyncId->getStorageAlignment().value();
327     this->AsyncLowering.AsyncFuncPointer = AsyncId->getAsyncFunctionPointer();
328     this->AsyncLowering.AsyncCC = F.getCallingConv();
329     break;
330   };
331   case Intrinsic::coro_id_retcon:
332   case Intrinsic::coro_id_retcon_once: {
333     auto ContinuationId = cast<AnyCoroIdRetconInst>(Id);
334     ContinuationId->checkWellFormed();
335     this->ABI = (IdIntrinsic == Intrinsic::coro_id_retcon
336                   ? coro::ABI::Retcon
337                   : coro::ABI::RetconOnce);
338     auto Prototype = ContinuationId->getPrototype();
339     this->RetconLowering.ResumePrototype = Prototype;
340     this->RetconLowering.Alloc = ContinuationId->getAllocFunction();
341     this->RetconLowering.Dealloc = ContinuationId->getDeallocFunction();
342     this->RetconLowering.ReturnBlock = nullptr;
343     this->RetconLowering.IsFrameInlineInStorage = false;
344 
345     // Determine the result value types, and make sure they match up with
346     // the values passed to the suspends.
347     auto ResultTys = getRetconResultTypes();
348     auto ResumeTys = getRetconResumeTypes();
349 
350     for (auto *AnySuspend : CoroSuspends) {
351       auto Suspend = dyn_cast<CoroSuspendRetconInst>(AnySuspend);
352       if (!Suspend) {
353 #ifndef NDEBUG
354         AnySuspend->dump();
355 #endif
356         report_fatal_error("coro.id.retcon.* must be paired with "
357                            "coro.suspend.retcon");
358       }
359 
360       // Check that the argument types of the suspend match the results.
361       auto SI = Suspend->value_begin(), SE = Suspend->value_end();
362       auto RI = ResultTys.begin(), RE = ResultTys.end();
363       for (; SI != SE && RI != RE; ++SI, ++RI) {
364         auto SrcTy = (*SI)->getType();
365         if (SrcTy != *RI) {
366           // The optimizer likes to eliminate bitcasts leading into variadic
367           // calls, but that messes with our invariants.  Re-insert the
368           // bitcast and ignore this type mismatch.
369           if (CastInst::isBitCastable(SrcTy, *RI)) {
370             auto BCI = new BitCastInst(*SI, *RI, "", Suspend);
371             SI->set(BCI);
372             continue;
373           }
374 
375 #ifndef NDEBUG
376           Suspend->dump();
377           Prototype->getFunctionType()->dump();
378 #endif
379           report_fatal_error("argument to coro.suspend.retcon does not "
380                              "match corresponding prototype function result");
381         }
382       }
383       if (SI != SE || RI != RE) {
384 #ifndef NDEBUG
385         Suspend->dump();
386         Prototype->getFunctionType()->dump();
387 #endif
388         report_fatal_error("wrong number of arguments to coro.suspend.retcon");
389       }
390 
391       // Check that the result type of the suspend matches the resume types.
392       Type *SResultTy = Suspend->getType();
393       ArrayRef<Type*> SuspendResultTys;
394       if (SResultTy->isVoidTy()) {
395         // leave as empty array
396       } else if (auto SResultStructTy = dyn_cast<StructType>(SResultTy)) {
397         SuspendResultTys = SResultStructTy->elements();
398       } else {
399         // forms an ArrayRef using SResultTy, be careful
400         SuspendResultTys = SResultTy;
401       }
402       if (SuspendResultTys.size() != ResumeTys.size()) {
403 #ifndef NDEBUG
404         Suspend->dump();
405         Prototype->getFunctionType()->dump();
406 #endif
407         report_fatal_error("wrong number of results from coro.suspend.retcon");
408       }
409       for (size_t I = 0, E = ResumeTys.size(); I != E; ++I) {
410         if (SuspendResultTys[I] != ResumeTys[I]) {
411 #ifndef NDEBUG
412           Suspend->dump();
413           Prototype->getFunctionType()->dump();
414 #endif
415           report_fatal_error("result from coro.suspend.retcon does not "
416                              "match corresponding prototype function param");
417         }
418       }
419     }
420     break;
421   }
422 
423   default:
424     llvm_unreachable("coro.begin is not dependent on a coro.id call");
425   }
426 
427   // The coro.free intrinsic is always lowered to the result of coro.begin.
428   for (CoroFrameInst *CF : CoroFrames) {
429     CF->replaceAllUsesWith(CoroBegin);
430     CF->eraseFromParent();
431   }
432 
433   // Move final suspend to be the last element in the CoroSuspends vector.
434   if (ABI == coro::ABI::Switch &&
435       SwitchLowering.HasFinalSuspend &&
436       FinalSuspendIndex != CoroSuspends.size() - 1)
437     std::swap(CoroSuspends[FinalSuspendIndex], CoroSuspends.back());
438 
439   // Remove orphaned coro.saves.
440   for (CoroSaveInst *CoroSave : UnusedCoroSaves)
441     CoroSave->eraseFromParent();
442 }
443 
444 static void propagateCallAttrsFromCallee(CallInst *Call, Function *Callee) {
445   Call->setCallingConv(Callee->getCallingConv());
446   // TODO: attributes?
447 }
448 
449 static void addCallToCallGraph(CallGraph *CG, CallInst *Call, Function *Callee){
450   if (CG)
451     (*CG)[Call->getFunction()]->addCalledFunction(Call, (*CG)[Callee]);
452 }
453 
454 Value *coro::Shape::emitAlloc(IRBuilder<> &Builder, Value *Size,
455                               CallGraph *CG) const {
456   switch (ABI) {
457   case coro::ABI::Switch:
458     llvm_unreachable("can't allocate memory in coro switch-lowering");
459 
460   case coro::ABI::Retcon:
461   case coro::ABI::RetconOnce: {
462     auto Alloc = RetconLowering.Alloc;
463     Size = Builder.CreateIntCast(Size,
464                                  Alloc->getFunctionType()->getParamType(0),
465                                  /*is signed*/ false);
466     auto *Call = Builder.CreateCall(Alloc, Size);
467     propagateCallAttrsFromCallee(Call, Alloc);
468     addCallToCallGraph(CG, Call, Alloc);
469     return Call;
470   }
471   case coro::ABI::Async:
472     llvm_unreachable("can't allocate memory in coro async-lowering");
473   }
474   llvm_unreachable("Unknown coro::ABI enum");
475 }
476 
477 void coro::Shape::emitDealloc(IRBuilder<> &Builder, Value *Ptr,
478                               CallGraph *CG) const {
479   switch (ABI) {
480   case coro::ABI::Switch:
481     llvm_unreachable("can't allocate memory in coro switch-lowering");
482 
483   case coro::ABI::Retcon:
484   case coro::ABI::RetconOnce: {
485     auto Dealloc = RetconLowering.Dealloc;
486     Ptr = Builder.CreateBitCast(Ptr,
487                                 Dealloc->getFunctionType()->getParamType(0));
488     auto *Call = Builder.CreateCall(Dealloc, Ptr);
489     propagateCallAttrsFromCallee(Call, Dealloc);
490     addCallToCallGraph(CG, Call, Dealloc);
491     return;
492   }
493   case coro::ABI::Async:
494     llvm_unreachable("can't allocate memory in coro async-lowering");
495   }
496   llvm_unreachable("Unknown coro::ABI enum");
497 }
498 
499 [[noreturn]] static void fail(const Instruction *I, const char *Reason,
500                               Value *V) {
501 #ifndef NDEBUG
502   I->dump();
503   if (V) {
504     errs() << "  Value: ";
505     V->printAsOperand(llvm::errs());
506     errs() << '\n';
507   }
508 #endif
509   report_fatal_error(Reason);
510 }
511 
512 /// Check that the given value is a well-formed prototype for the
513 /// llvm.coro.id.retcon.* intrinsics.
514 static void checkWFRetconPrototype(const AnyCoroIdRetconInst *I, Value *V) {
515   auto F = dyn_cast<Function>(V->stripPointerCasts());
516   if (!F)
517     fail(I, "llvm.coro.id.retcon.* prototype not a Function", V);
518 
519   auto FT = F->getFunctionType();
520 
521   if (isa<CoroIdRetconInst>(I)) {
522     bool ResultOkay;
523     if (FT->getReturnType()->isPointerTy()) {
524       ResultOkay = true;
525     } else if (auto SRetTy = dyn_cast<StructType>(FT->getReturnType())) {
526       ResultOkay = (!SRetTy->isOpaque() &&
527                     SRetTy->getNumElements() > 0 &&
528                     SRetTy->getElementType(0)->isPointerTy());
529     } else {
530       ResultOkay = false;
531     }
532     if (!ResultOkay)
533       fail(I, "llvm.coro.id.retcon prototype must return pointer as first "
534               "result", F);
535 
536     if (FT->getReturnType() !=
537           I->getFunction()->getFunctionType()->getReturnType())
538       fail(I, "llvm.coro.id.retcon prototype return type must be same as"
539               "current function return type", F);
540   } else {
541     // No meaningful validation to do here for llvm.coro.id.unique.once.
542   }
543 
544   if (FT->getNumParams() == 0 || !FT->getParamType(0)->isPointerTy())
545     fail(I, "llvm.coro.id.retcon.* prototype must take pointer as "
546             "its first parameter", F);
547 }
548 
549 /// Check that the given value is a well-formed allocator.
550 static void checkWFAlloc(const Instruction *I, Value *V) {
551   auto F = dyn_cast<Function>(V->stripPointerCasts());
552   if (!F)
553     fail(I, "llvm.coro.* allocator not a Function", V);
554 
555   auto FT = F->getFunctionType();
556   if (!FT->getReturnType()->isPointerTy())
557     fail(I, "llvm.coro.* allocator must return a pointer", F);
558 
559   if (FT->getNumParams() != 1 ||
560       !FT->getParamType(0)->isIntegerTy())
561     fail(I, "llvm.coro.* allocator must take integer as only param", F);
562 }
563 
564 /// Check that the given value is a well-formed deallocator.
565 static void checkWFDealloc(const Instruction *I, Value *V) {
566   auto F = dyn_cast<Function>(V->stripPointerCasts());
567   if (!F)
568     fail(I, "llvm.coro.* deallocator not a Function", V);
569 
570   auto FT = F->getFunctionType();
571   if (!FT->getReturnType()->isVoidTy())
572     fail(I, "llvm.coro.* deallocator must return void", F);
573 
574   if (FT->getNumParams() != 1 ||
575       !FT->getParamType(0)->isPointerTy())
576     fail(I, "llvm.coro.* deallocator must take pointer as only param", F);
577 }
578 
579 static void checkConstantInt(const Instruction *I, Value *V,
580                              const char *Reason) {
581   if (!isa<ConstantInt>(V)) {
582     fail(I, Reason, V);
583   }
584 }
585 
586 void AnyCoroIdRetconInst::checkWellFormed() const {
587   checkConstantInt(this, getArgOperand(SizeArg),
588                    "size argument to coro.id.retcon.* must be constant");
589   checkConstantInt(this, getArgOperand(AlignArg),
590                    "alignment argument to coro.id.retcon.* must be constant");
591   checkWFRetconPrototype(this, getArgOperand(PrototypeArg));
592   checkWFAlloc(this, getArgOperand(AllocArg));
593   checkWFDealloc(this, getArgOperand(DeallocArg));
594 }
595 
596 static void checkAsyncFuncPointer(const Instruction *I, Value *V) {
597   auto *AsyncFuncPtrAddr = dyn_cast<GlobalVariable>(V->stripPointerCasts());
598   if (!AsyncFuncPtrAddr)
599     fail(I, "llvm.coro.id.async async function pointer not a global", V);
600 }
601 
602 void CoroIdAsyncInst::checkWellFormed() const {
603   checkConstantInt(this, getArgOperand(SizeArg),
604                    "size argument to coro.id.async must be constant");
605   checkConstantInt(this, getArgOperand(AlignArg),
606                    "alignment argument to coro.id.async must be constant");
607   checkConstantInt(this, getArgOperand(StorageArg),
608                    "storage argument offset to coro.id.async must be constant");
609   checkAsyncFuncPointer(this, getArgOperand(AsyncFuncPtrArg));
610 }
611 
612 static void checkAsyncContextProjectFunction(const Instruction *I,
613                                              Function *F) {
614   auto *FunTy = cast<FunctionType>(F->getValueType());
615   if (!FunTy->getReturnType()->isPointerTy())
616     fail(I,
617          "llvm.coro.suspend.async resume function projection function must "
618          "return a ptr type",
619          F);
620   if (FunTy->getNumParams() != 1 || !FunTy->getParamType(0)->isPointerTy())
621     fail(I,
622          "llvm.coro.suspend.async resume function projection function must "
623          "take one ptr type as parameter",
624          F);
625 }
626 
627 void CoroSuspendAsyncInst::checkWellFormed() const {
628   checkAsyncContextProjectFunction(this, getAsyncContextProjectionFunction());
629 }
630 
631 void CoroAsyncEndInst::checkWellFormed() const {
632   auto *MustTailCallFunc = getMustTailCallFunction();
633   if (!MustTailCallFunc)
634     return;
635   auto *FnTy = MustTailCallFunc->getFunctionType();
636   if (FnTy->getNumParams() != (arg_size() - 3))
637     fail(this,
638          "llvm.coro.end.async must tail call function argument type must "
639          "match the tail arguments",
640          MustTailCallFunc);
641 }
642