xref: /llvm-project/llvm/lib/Transforms/Coroutines/Coroutines.cpp (revision 6070aeb3b71520287e81e322b19f32f18ec2031a)
1 //===- Coroutines.cpp -----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the common infrastructure for Coroutine Passes.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "CoroInternal.h"
14 #include "llvm/ADT/SmallVector.h"
15 #include "llvm/ADT/StringRef.h"
16 #include "llvm/Analysis/CallGraph.h"
17 #include "llvm/IR/Attributes.h"
18 #include "llvm/IR/Constants.h"
19 #include "llvm/IR/DerivedTypes.h"
20 #include "llvm/IR/Function.h"
21 #include "llvm/IR/InstIterator.h"
22 #include "llvm/IR/Instructions.h"
23 #include "llvm/IR/IntrinsicInst.h"
24 #include "llvm/IR/Intrinsics.h"
25 #include "llvm/IR/Module.h"
26 #include "llvm/IR/Type.h"
27 #include "llvm/Support/Casting.h"
28 #include "llvm/Support/ErrorHandling.h"
29 #include "llvm/Transforms/Coroutines/ABI.h"
30 #include "llvm/Transforms/Coroutines/CoroInstr.h"
31 #include "llvm/Transforms/Coroutines/CoroShape.h"
32 #include "llvm/Transforms/Utils/Local.h"
33 #include <cassert>
34 #include <cstddef>
35 #include <utility>
36 
37 using namespace llvm;
38 
39 // Construct the lowerer base class and initialize its members.
40 coro::LowererBase::LowererBase(Module &M)
41     : TheModule(M), Context(M.getContext()),
42       Int8Ptr(PointerType::get(Context, 0)),
43       ResumeFnType(FunctionType::get(Type::getVoidTy(Context), Int8Ptr,
44                                      /*isVarArg=*/false)),
45       NullPtr(ConstantPointerNull::get(Int8Ptr)) {}
46 
47 // Creates a call to llvm.coro.subfn.addr to obtain a resume function address.
48 // It generates the following:
49 //
50 //    call ptr @llvm.coro.subfn.addr(ptr %Arg, i8 %index)
51 
52 CallInst *coro::LowererBase::makeSubFnCall(Value *Arg, int Index,
53                                            Instruction *InsertPt) {
54   auto *IndexVal = ConstantInt::get(Type::getInt8Ty(Context), Index);
55   auto *Fn =
56       Intrinsic::getOrInsertDeclaration(&TheModule, Intrinsic::coro_subfn_addr);
57 
58   assert(Index >= CoroSubFnInst::IndexFirst &&
59          Index < CoroSubFnInst::IndexLast &&
60          "makeSubFnCall: Index value out of range");
61   return CallInst::Create(Fn, {Arg, IndexVal}, "", InsertPt->getIterator());
62 }
63 
64 // NOTE: Must be sorted!
65 static const char *const CoroIntrinsics[] = {
66     "llvm.coro.align",
67     "llvm.coro.alloc",
68     "llvm.coro.async.context.alloc",
69     "llvm.coro.async.context.dealloc",
70     "llvm.coro.async.resume",
71     "llvm.coro.async.size.replace",
72     "llvm.coro.async.store_resume",
73     "llvm.coro.await.suspend.bool",
74     "llvm.coro.await.suspend.handle",
75     "llvm.coro.await.suspend.void",
76     "llvm.coro.begin",
77     "llvm.coro.begin.custom.abi",
78     "llvm.coro.destroy",
79     "llvm.coro.done",
80     "llvm.coro.end",
81     "llvm.coro.end.async",
82     "llvm.coro.frame",
83     "llvm.coro.free",
84     "llvm.coro.id",
85     "llvm.coro.id.async",
86     "llvm.coro.id.retcon",
87     "llvm.coro.id.retcon.once",
88     "llvm.coro.noop",
89     "llvm.coro.prepare.async",
90     "llvm.coro.prepare.retcon",
91     "llvm.coro.promise",
92     "llvm.coro.resume",
93     "llvm.coro.save",
94     "llvm.coro.size",
95     "llvm.coro.subfn.addr",
96     "llvm.coro.suspend",
97     "llvm.coro.suspend.async",
98     "llvm.coro.suspend.retcon",
99 };
100 
101 #ifndef NDEBUG
102 static bool isCoroutineIntrinsicName(StringRef Name) {
103   return Intrinsic::lookupLLVMIntrinsicByName(CoroIntrinsics, Name, "coro") !=
104          -1;
105 }
106 #endif
107 
108 bool coro::isSuspendBlock(BasicBlock *BB) {
109   return isa<AnyCoroSuspendInst>(BB->front());
110 }
111 
112 bool coro::declaresAnyIntrinsic(const Module &M) {
113   for (StringRef Name : CoroIntrinsics) {
114     assert(isCoroutineIntrinsicName(Name) && "not a coroutine intrinsic");
115     if (M.getNamedValue(Name))
116       return true;
117   }
118 
119   return false;
120 }
121 
122 // Verifies if a module has named values listed. Also, in debug mode verifies
123 // that names are intrinsic names.
124 bool coro::declaresIntrinsics(const Module &M,
125                               const std::initializer_list<StringRef> List) {
126   for (StringRef Name : List) {
127     assert(isCoroutineIntrinsicName(Name) && "not a coroutine intrinsic");
128     if (M.getNamedValue(Name))
129       return true;
130   }
131 
132   return false;
133 }
134 
135 // Replace all coro.frees associated with the provided CoroId either with 'null'
136 // if Elide is true and with its frame parameter otherwise.
137 void coro::replaceCoroFree(CoroIdInst *CoroId, bool Elide) {
138   SmallVector<CoroFreeInst *, 4> CoroFrees;
139   for (User *U : CoroId->users())
140     if (auto CF = dyn_cast<CoroFreeInst>(U))
141       CoroFrees.push_back(CF);
142 
143   if (CoroFrees.empty())
144     return;
145 
146   Value *Replacement =
147       Elide
148           ? ConstantPointerNull::get(PointerType::get(CoroId->getContext(), 0))
149           : CoroFrees.front()->getFrame();
150 
151   for (CoroFreeInst *CF : CoroFrees) {
152     CF->replaceAllUsesWith(Replacement);
153     CF->eraseFromParent();
154   }
155 }
156 
157 void coro::suppressCoroAllocs(CoroIdInst *CoroId) {
158   SmallVector<CoroAllocInst *, 4> CoroAllocs;
159   for (User *U : CoroId->users())
160     if (auto *CA = dyn_cast<CoroAllocInst>(U))
161       CoroAllocs.push_back(CA);
162 
163   if (CoroAllocs.empty())
164     return;
165 
166   coro::suppressCoroAllocs(CoroId->getContext(), CoroAllocs);
167 }
168 
169 // Replacing llvm.coro.alloc with false will suppress dynamic
170 // allocation as it is expected for the frontend to generate the code that
171 // looks like:
172 //   id = coro.id(...)
173 //   mem = coro.alloc(id) ? malloc(coro.size()) : 0;
174 //   coro.begin(id, mem)
175 void coro::suppressCoroAllocs(LLVMContext &Context,
176                               ArrayRef<CoroAllocInst *> CoroAllocs) {
177   auto *False = ConstantInt::getFalse(Context);
178   for (auto *CA : CoroAllocs) {
179     CA->replaceAllUsesWith(False);
180     CA->eraseFromParent();
181   }
182 }
183 
184 static CoroSaveInst *createCoroSave(CoroBeginInst *CoroBegin,
185                                     CoroSuspendInst *SuspendInst) {
186   Module *M = SuspendInst->getModule();
187   auto *Fn = Intrinsic::getOrInsertDeclaration(M, Intrinsic::coro_save);
188   auto *SaveInst = cast<CoroSaveInst>(
189       CallInst::Create(Fn, CoroBegin, "", SuspendInst->getIterator()));
190   assert(!SuspendInst->getCoroSave());
191   SuspendInst->setArgOperand(0, SaveInst);
192   return SaveInst;
193 }
194 
195 // Collect "interesting" coroutine intrinsics.
196 void coro::Shape::analyze(Function &F,
197                           SmallVectorImpl<CoroFrameInst *> &CoroFrames,
198                           SmallVectorImpl<CoroSaveInst *> &UnusedCoroSaves) {
199   clear();
200 
201   bool HasFinalSuspend = false;
202   bool HasUnwindCoroEnd = false;
203   size_t FinalSuspendIndex = 0;
204 
205   for (Instruction &I : instructions(F)) {
206     // FIXME: coro_await_suspend_* are not proper `IntrinisicInst`s
207     // because they might be invoked
208     if (auto AWS = dyn_cast<CoroAwaitSuspendInst>(&I)) {
209       CoroAwaitSuspends.push_back(AWS);
210     } else if (auto II = dyn_cast<IntrinsicInst>(&I)) {
211       switch (II->getIntrinsicID()) {
212       default:
213         continue;
214       case Intrinsic::coro_size:
215         CoroSizes.push_back(cast<CoroSizeInst>(II));
216         break;
217       case Intrinsic::coro_align:
218         CoroAligns.push_back(cast<CoroAlignInst>(II));
219         break;
220       case Intrinsic::coro_frame:
221         CoroFrames.push_back(cast<CoroFrameInst>(II));
222         break;
223       case Intrinsic::coro_save:
224         // After optimizations, coro_suspends using this coro_save might have
225         // been removed, remember orphaned coro_saves to remove them later.
226         if (II->use_empty())
227           UnusedCoroSaves.push_back(cast<CoroSaveInst>(II));
228         break;
229       case Intrinsic::coro_suspend_async: {
230         auto *Suspend = cast<CoroSuspendAsyncInst>(II);
231         Suspend->checkWellFormed();
232         CoroSuspends.push_back(Suspend);
233         break;
234       }
235       case Intrinsic::coro_suspend_retcon: {
236         auto Suspend = cast<CoroSuspendRetconInst>(II);
237         CoroSuspends.push_back(Suspend);
238         break;
239       }
240       case Intrinsic::coro_suspend: {
241         auto Suspend = cast<CoroSuspendInst>(II);
242         CoroSuspends.push_back(Suspend);
243         if (Suspend->isFinal()) {
244           if (HasFinalSuspend)
245             report_fatal_error(
246               "Only one suspend point can be marked as final");
247           HasFinalSuspend = true;
248           FinalSuspendIndex = CoroSuspends.size() - 1;
249         }
250         break;
251       }
252       case Intrinsic::coro_begin:
253       case Intrinsic::coro_begin_custom_abi: {
254         auto CB = cast<CoroBeginInst>(II);
255 
256         // Ignore coro id's that aren't pre-split.
257         auto Id = dyn_cast<CoroIdInst>(CB->getId());
258         if (Id && !Id->getInfo().isPreSplit())
259           break;
260 
261         if (CoroBegin)
262           report_fatal_error(
263                 "coroutine should have exactly one defining @llvm.coro.begin");
264         CB->addRetAttr(Attribute::NonNull);
265         CB->addRetAttr(Attribute::NoAlias);
266         CB->removeFnAttr(Attribute::NoDuplicate);
267         CoroBegin = CB;
268         break;
269       }
270       case Intrinsic::coro_end_async:
271       case Intrinsic::coro_end:
272         CoroEnds.push_back(cast<AnyCoroEndInst>(II));
273         if (auto *AsyncEnd = dyn_cast<CoroAsyncEndInst>(II)) {
274           AsyncEnd->checkWellFormed();
275         }
276 
277         if (CoroEnds.back()->isUnwind())
278           HasUnwindCoroEnd = true;
279 
280         if (CoroEnds.back()->isFallthrough() && isa<CoroEndInst>(II)) {
281           // Make sure that the fallthrough coro.end is the first element in the
282           // CoroEnds vector.
283           // Note: I don't think this is neccessary anymore.
284           if (CoroEnds.size() > 1) {
285             if (CoroEnds.front()->isFallthrough())
286               report_fatal_error(
287                   "Only one coro.end can be marked as fallthrough");
288             std::swap(CoroEnds.front(), CoroEnds.back());
289           }
290         }
291         break;
292       }
293     }
294   }
295 
296   // If there is no CoroBegin then this is not a coroutine.
297   if (!CoroBegin)
298     return;
299 
300   // Determination of ABI and initializing lowering info
301   auto Id = CoroBegin->getId();
302   switch (auto IntrID = Id->getIntrinsicID()) {
303   case Intrinsic::coro_id: {
304     ABI = coro::ABI::Switch;
305     SwitchLowering.HasFinalSuspend = HasFinalSuspend;
306     SwitchLowering.HasUnwindCoroEnd = HasUnwindCoroEnd;
307 
308     auto SwitchId = getSwitchCoroId();
309     SwitchLowering.ResumeSwitch = nullptr;
310     SwitchLowering.PromiseAlloca = SwitchId->getPromise();
311     SwitchLowering.ResumeEntryBlock = nullptr;
312 
313     // Move final suspend to the last element in the CoroSuspends vector.
314     if (SwitchLowering.HasFinalSuspend &&
315         FinalSuspendIndex != CoroSuspends.size() - 1)
316       std::swap(CoroSuspends[FinalSuspendIndex], CoroSuspends.back());
317     break;
318   }
319   case Intrinsic::coro_id_async: {
320     ABI = coro::ABI::Async;
321     auto *AsyncId = getAsyncCoroId();
322     AsyncId->checkWellFormed();
323     AsyncLowering.Context = AsyncId->getStorage();
324     AsyncLowering.ContextArgNo = AsyncId->getStorageArgumentIndex();
325     AsyncLowering.ContextHeaderSize = AsyncId->getStorageSize();
326     AsyncLowering.ContextAlignment = AsyncId->getStorageAlignment().value();
327     AsyncLowering.AsyncFuncPointer = AsyncId->getAsyncFunctionPointer();
328     AsyncLowering.AsyncCC = F.getCallingConv();
329     break;
330   }
331   case Intrinsic::coro_id_retcon:
332   case Intrinsic::coro_id_retcon_once: {
333     ABI = IntrID == Intrinsic::coro_id_retcon ? coro::ABI::Retcon
334                                               : coro::ABI::RetconOnce;
335     auto ContinuationId = getRetconCoroId();
336     ContinuationId->checkWellFormed();
337     auto Prototype = ContinuationId->getPrototype();
338     RetconLowering.ResumePrototype = Prototype;
339     RetconLowering.Alloc = ContinuationId->getAllocFunction();
340     RetconLowering.Dealloc = ContinuationId->getDeallocFunction();
341     RetconLowering.ReturnBlock = nullptr;
342     RetconLowering.IsFrameInlineInStorage = false;
343     break;
344   }
345   default:
346     llvm_unreachable("coro.begin is not dependent on a coro.id call");
347   }
348 }
349 
350 // If for some reason, we were not able to find coro.begin, bailout.
351 void coro::Shape::invalidateCoroutine(
352     Function &F, SmallVectorImpl<CoroFrameInst *> &CoroFrames) {
353   assert(!CoroBegin);
354   {
355     // Replace coro.frame which are supposed to be lowered to the result of
356     // coro.begin with poison.
357     auto *Poison = PoisonValue::get(PointerType::get(F.getContext(), 0));
358     for (CoroFrameInst *CF : CoroFrames) {
359       CF->replaceAllUsesWith(Poison);
360       CF->eraseFromParent();
361     }
362     CoroFrames.clear();
363 
364     // Replace all coro.suspend with poison and remove related coro.saves if
365     // present.
366     for (AnyCoroSuspendInst *CS : CoroSuspends) {
367       CS->replaceAllUsesWith(PoisonValue::get(CS->getType()));
368       CS->eraseFromParent();
369       if (auto *CoroSave = CS->getCoroSave())
370         CoroSave->eraseFromParent();
371     }
372     CoroSuspends.clear();
373 
374     // Replace all coro.ends with unreachable instruction.
375     for (AnyCoroEndInst *CE : CoroEnds)
376       changeToUnreachable(CE);
377   }
378 }
379 
380 void coro::SwitchABI::init() {
381   assert(Shape.ABI == coro::ABI::Switch);
382   {
383     for (auto *AnySuspend : Shape.CoroSuspends) {
384       auto Suspend = dyn_cast<CoroSuspendInst>(AnySuspend);
385       if (!Suspend) {
386 #ifndef NDEBUG
387         AnySuspend->dump();
388 #endif
389         report_fatal_error("coro.id must be paired with coro.suspend");
390       }
391 
392       if (!Suspend->getCoroSave())
393         createCoroSave(Shape.CoroBegin, Suspend);
394     }
395   }
396 }
397 
398 void coro::AsyncABI::init() { assert(Shape.ABI == coro::ABI::Async); }
399 
400 void coro::AnyRetconABI::init() {
401   assert(Shape.ABI == coro::ABI::Retcon || Shape.ABI == coro::ABI::RetconOnce);
402   {
403     // Determine the result value types, and make sure they match up with
404     // the values passed to the suspends.
405     auto ResultTys = Shape.getRetconResultTypes();
406     auto ResumeTys = Shape.getRetconResumeTypes();
407 
408     for (auto *AnySuspend : Shape.CoroSuspends) {
409       auto Suspend = dyn_cast<CoroSuspendRetconInst>(AnySuspend);
410       if (!Suspend) {
411 #ifndef NDEBUG
412         AnySuspend->dump();
413 #endif
414         report_fatal_error("coro.id.retcon.* must be paired with "
415                            "coro.suspend.retcon");
416       }
417 
418       // Check that the argument types of the suspend match the results.
419       auto SI = Suspend->value_begin(), SE = Suspend->value_end();
420       auto RI = ResultTys.begin(), RE = ResultTys.end();
421       for (; SI != SE && RI != RE; ++SI, ++RI) {
422         auto SrcTy = (*SI)->getType();
423         if (SrcTy != *RI) {
424           // The optimizer likes to eliminate bitcasts leading into variadic
425           // calls, but that messes with our invariants.  Re-insert the
426           // bitcast and ignore this type mismatch.
427           if (CastInst::isBitCastable(SrcTy, *RI)) {
428             auto BCI = new BitCastInst(*SI, *RI, "", Suspend->getIterator());
429             SI->set(BCI);
430             continue;
431           }
432 
433 #ifndef NDEBUG
434           Suspend->dump();
435           Shape.RetconLowering.ResumePrototype->getFunctionType()->dump();
436 #endif
437           report_fatal_error("argument to coro.suspend.retcon does not "
438                              "match corresponding prototype function result");
439         }
440       }
441       if (SI != SE || RI != RE) {
442 #ifndef NDEBUG
443         Suspend->dump();
444         Shape.RetconLowering.ResumePrototype->getFunctionType()->dump();
445 #endif
446         report_fatal_error("wrong number of arguments to coro.suspend.retcon");
447       }
448 
449       // Check that the result type of the suspend matches the resume types.
450       Type *SResultTy = Suspend->getType();
451       ArrayRef<Type *> SuspendResultTys;
452       if (SResultTy->isVoidTy()) {
453         // leave as empty array
454       } else if (auto SResultStructTy = dyn_cast<StructType>(SResultTy)) {
455         SuspendResultTys = SResultStructTy->elements();
456       } else {
457         // forms an ArrayRef using SResultTy, be careful
458         SuspendResultTys = SResultTy;
459       }
460       if (SuspendResultTys.size() != ResumeTys.size()) {
461 #ifndef NDEBUG
462         Suspend->dump();
463         Shape.RetconLowering.ResumePrototype->getFunctionType()->dump();
464 #endif
465         report_fatal_error("wrong number of results from coro.suspend.retcon");
466       }
467       for (size_t I = 0, E = ResumeTys.size(); I != E; ++I) {
468         if (SuspendResultTys[I] != ResumeTys[I]) {
469 #ifndef NDEBUG
470           Suspend->dump();
471           Shape.RetconLowering.ResumePrototype->getFunctionType()->dump();
472 #endif
473           report_fatal_error("result from coro.suspend.retcon does not "
474                              "match corresponding prototype function param");
475         }
476       }
477     }
478   }
479 }
480 
481 void coro::Shape::cleanCoroutine(
482     SmallVectorImpl<CoroFrameInst *> &CoroFrames,
483     SmallVectorImpl<CoroSaveInst *> &UnusedCoroSaves) {
484   // The coro.frame intrinsic is always lowered to the result of coro.begin.
485   for (CoroFrameInst *CF : CoroFrames) {
486     CF->replaceAllUsesWith(CoroBegin);
487     CF->eraseFromParent();
488   }
489   CoroFrames.clear();
490 
491   // Remove orphaned coro.saves.
492   for (CoroSaveInst *CoroSave : UnusedCoroSaves)
493     CoroSave->eraseFromParent();
494   UnusedCoroSaves.clear();
495 }
496 
497 static void propagateCallAttrsFromCallee(CallInst *Call, Function *Callee) {
498   Call->setCallingConv(Callee->getCallingConv());
499   // TODO: attributes?
500 }
501 
502 static void addCallToCallGraph(CallGraph *CG, CallInst *Call, Function *Callee){
503   if (CG)
504     (*CG)[Call->getFunction()]->addCalledFunction(Call, (*CG)[Callee]);
505 }
506 
507 Value *coro::Shape::emitAlloc(IRBuilder<> &Builder, Value *Size,
508                               CallGraph *CG) const {
509   switch (ABI) {
510   case coro::ABI::Switch:
511     llvm_unreachable("can't allocate memory in coro switch-lowering");
512 
513   case coro::ABI::Retcon:
514   case coro::ABI::RetconOnce: {
515     auto Alloc = RetconLowering.Alloc;
516     Size = Builder.CreateIntCast(Size,
517                                  Alloc->getFunctionType()->getParamType(0),
518                                  /*is signed*/ false);
519     auto *Call = Builder.CreateCall(Alloc, Size);
520     propagateCallAttrsFromCallee(Call, Alloc);
521     addCallToCallGraph(CG, Call, Alloc);
522     return Call;
523   }
524   case coro::ABI::Async:
525     llvm_unreachable("can't allocate memory in coro async-lowering");
526   }
527   llvm_unreachable("Unknown coro::ABI enum");
528 }
529 
530 void coro::Shape::emitDealloc(IRBuilder<> &Builder, Value *Ptr,
531                               CallGraph *CG) const {
532   switch (ABI) {
533   case coro::ABI::Switch:
534     llvm_unreachable("can't allocate memory in coro switch-lowering");
535 
536   case coro::ABI::Retcon:
537   case coro::ABI::RetconOnce: {
538     auto Dealloc = RetconLowering.Dealloc;
539     Ptr = Builder.CreateBitCast(Ptr,
540                                 Dealloc->getFunctionType()->getParamType(0));
541     auto *Call = Builder.CreateCall(Dealloc, Ptr);
542     propagateCallAttrsFromCallee(Call, Dealloc);
543     addCallToCallGraph(CG, Call, Dealloc);
544     return;
545   }
546   case coro::ABI::Async:
547     llvm_unreachable("can't allocate memory in coro async-lowering");
548   }
549   llvm_unreachable("Unknown coro::ABI enum");
550 }
551 
552 [[noreturn]] static void fail(const Instruction *I, const char *Reason,
553                               Value *V) {
554 #ifndef NDEBUG
555   I->dump();
556   if (V) {
557     errs() << "  Value: ";
558     V->printAsOperand(llvm::errs());
559     errs() << '\n';
560   }
561 #endif
562   report_fatal_error(Reason);
563 }
564 
565 /// Check that the given value is a well-formed prototype for the
566 /// llvm.coro.id.retcon.* intrinsics.
567 static void checkWFRetconPrototype(const AnyCoroIdRetconInst *I, Value *V) {
568   auto F = dyn_cast<Function>(V->stripPointerCasts());
569   if (!F)
570     fail(I, "llvm.coro.id.retcon.* prototype not a Function", V);
571 
572   auto FT = F->getFunctionType();
573 
574   if (isa<CoroIdRetconInst>(I)) {
575     bool ResultOkay;
576     if (FT->getReturnType()->isPointerTy()) {
577       ResultOkay = true;
578     } else if (auto SRetTy = dyn_cast<StructType>(FT->getReturnType())) {
579       ResultOkay = (!SRetTy->isOpaque() &&
580                     SRetTy->getNumElements() > 0 &&
581                     SRetTy->getElementType(0)->isPointerTy());
582     } else {
583       ResultOkay = false;
584     }
585     if (!ResultOkay)
586       fail(I, "llvm.coro.id.retcon prototype must return pointer as first "
587               "result", F);
588 
589     if (FT->getReturnType() !=
590           I->getFunction()->getFunctionType()->getReturnType())
591       fail(I, "llvm.coro.id.retcon prototype return type must be same as"
592               "current function return type", F);
593   } else {
594     // No meaningful validation to do here for llvm.coro.id.unique.once.
595   }
596 
597   if (FT->getNumParams() == 0 || !FT->getParamType(0)->isPointerTy())
598     fail(I, "llvm.coro.id.retcon.* prototype must take pointer as "
599             "its first parameter", F);
600 }
601 
602 /// Check that the given value is a well-formed allocator.
603 static void checkWFAlloc(const Instruction *I, Value *V) {
604   auto F = dyn_cast<Function>(V->stripPointerCasts());
605   if (!F)
606     fail(I, "llvm.coro.* allocator not a Function", V);
607 
608   auto FT = F->getFunctionType();
609   if (!FT->getReturnType()->isPointerTy())
610     fail(I, "llvm.coro.* allocator must return a pointer", F);
611 
612   if (FT->getNumParams() != 1 ||
613       !FT->getParamType(0)->isIntegerTy())
614     fail(I, "llvm.coro.* allocator must take integer as only param", F);
615 }
616 
617 /// Check that the given value is a well-formed deallocator.
618 static void checkWFDealloc(const Instruction *I, Value *V) {
619   auto F = dyn_cast<Function>(V->stripPointerCasts());
620   if (!F)
621     fail(I, "llvm.coro.* deallocator not a Function", V);
622 
623   auto FT = F->getFunctionType();
624   if (!FT->getReturnType()->isVoidTy())
625     fail(I, "llvm.coro.* deallocator must return void", F);
626 
627   if (FT->getNumParams() != 1 ||
628       !FT->getParamType(0)->isPointerTy())
629     fail(I, "llvm.coro.* deallocator must take pointer as only param", F);
630 }
631 
632 static void checkConstantInt(const Instruction *I, Value *V,
633                              const char *Reason) {
634   if (!isa<ConstantInt>(V)) {
635     fail(I, Reason, V);
636   }
637 }
638 
639 void AnyCoroIdRetconInst::checkWellFormed() const {
640   checkConstantInt(this, getArgOperand(SizeArg),
641                    "size argument to coro.id.retcon.* must be constant");
642   checkConstantInt(this, getArgOperand(AlignArg),
643                    "alignment argument to coro.id.retcon.* must be constant");
644   checkWFRetconPrototype(this, getArgOperand(PrototypeArg));
645   checkWFAlloc(this, getArgOperand(AllocArg));
646   checkWFDealloc(this, getArgOperand(DeallocArg));
647 }
648 
649 static void checkAsyncFuncPointer(const Instruction *I, Value *V) {
650   auto *AsyncFuncPtrAddr = dyn_cast<GlobalVariable>(V->stripPointerCasts());
651   if (!AsyncFuncPtrAddr)
652     fail(I, "llvm.coro.id.async async function pointer not a global", V);
653 }
654 
655 void CoroIdAsyncInst::checkWellFormed() const {
656   checkConstantInt(this, getArgOperand(SizeArg),
657                    "size argument to coro.id.async must be constant");
658   checkConstantInt(this, getArgOperand(AlignArg),
659                    "alignment argument to coro.id.async must be constant");
660   checkConstantInt(this, getArgOperand(StorageArg),
661                    "storage argument offset to coro.id.async must be constant");
662   checkAsyncFuncPointer(this, getArgOperand(AsyncFuncPtrArg));
663 }
664 
665 static void checkAsyncContextProjectFunction(const Instruction *I,
666                                              Function *F) {
667   auto *FunTy = cast<FunctionType>(F->getValueType());
668   if (!FunTy->getReturnType()->isPointerTy())
669     fail(I,
670          "llvm.coro.suspend.async resume function projection function must "
671          "return a ptr type",
672          F);
673   if (FunTy->getNumParams() != 1 || !FunTy->getParamType(0)->isPointerTy())
674     fail(I,
675          "llvm.coro.suspend.async resume function projection function must "
676          "take one ptr type as parameter",
677          F);
678 }
679 
680 void CoroSuspendAsyncInst::checkWellFormed() const {
681   checkAsyncContextProjectFunction(this, getAsyncContextProjectionFunction());
682 }
683 
684 void CoroAsyncEndInst::checkWellFormed() const {
685   auto *MustTailCallFunc = getMustTailCallFunction();
686   if (!MustTailCallFunc)
687     return;
688   auto *FnTy = MustTailCallFunc->getFunctionType();
689   if (FnTy->getNumParams() != (arg_size() - 3))
690     fail(this,
691          "llvm.coro.end.async must tail call function argument type must "
692          "match the tail arguments",
693          MustTailCallFunc);
694 }
695