xref: /llvm-project/llvm/lib/Transforms/Coroutines/Coroutines.cpp (revision 4e36d5b92d78822f9eeef6b62e7b037f5c2cb5b9)
1 //===- Coroutines.cpp -----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the common infrastructure for Coroutine Passes.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "CoroInternal.h"
14 #include "llvm/ADT/SmallVector.h"
15 #include "llvm/ADT/StringRef.h"
16 #include "llvm/Analysis/CallGraph.h"
17 #include "llvm/IR/Attributes.h"
18 #include "llvm/IR/Constants.h"
19 #include "llvm/IR/DerivedTypes.h"
20 #include "llvm/IR/Function.h"
21 #include "llvm/IR/InstIterator.h"
22 #include "llvm/IR/Instructions.h"
23 #include "llvm/IR/IntrinsicInst.h"
24 #include "llvm/IR/Intrinsics.h"
25 #include "llvm/IR/Module.h"
26 #include "llvm/IR/Type.h"
27 #include "llvm/Support/Casting.h"
28 #include "llvm/Support/ErrorHandling.h"
29 #include "llvm/Transforms/Coroutines/ABI.h"
30 #include "llvm/Transforms/Coroutines/CoroInstr.h"
31 #include "llvm/Transforms/Coroutines/CoroShape.h"
32 #include "llvm/Transforms/Utils/Local.h"
33 #include <cassert>
34 #include <cstddef>
35 #include <utility>
36 
37 using namespace llvm;
38 
39 // Construct the lowerer base class and initialize its members.
40 coro::LowererBase::LowererBase(Module &M)
41     : TheModule(M), Context(M.getContext()),
42       Int8Ptr(PointerType::get(Context, 0)),
43       ResumeFnType(FunctionType::get(Type::getVoidTy(Context), Int8Ptr,
44                                      /*isVarArg=*/false)),
45       NullPtr(ConstantPointerNull::get(Int8Ptr)) {}
46 
47 // Creates a call to llvm.coro.subfn.addr to obtain a resume function address.
48 // It generates the following:
49 //
50 //    call ptr @llvm.coro.subfn.addr(ptr %Arg, i8 %index)
51 
52 CallInst *coro::LowererBase::makeSubFnCall(Value *Arg, int Index,
53                                            Instruction *InsertPt) {
54   auto *IndexVal = ConstantInt::get(Type::getInt8Ty(Context), Index);
55   auto *Fn =
56       Intrinsic::getOrInsertDeclaration(&TheModule, Intrinsic::coro_subfn_addr);
57 
58   assert(Index >= CoroSubFnInst::IndexFirst &&
59          Index < CoroSubFnInst::IndexLast &&
60          "makeSubFnCall: Index value out of range");
61   return CallInst::Create(Fn, {Arg, IndexVal}, "", InsertPt->getIterator());
62 }
63 
// Names of the coroutine intrinsics known to this file.
// NOTE: Must be sorted! isCoroutineIntrinsicName() looks names up with a
// binary search, so the ordering is enforced by the static_assert below.
static constexpr const char *CoroIntrinsics[] = {
    "llvm.coro.align",
    "llvm.coro.alloc",
    "llvm.coro.async.context.alloc",
    "llvm.coro.async.context.dealloc",
    "llvm.coro.async.resume",
    "llvm.coro.async.size.replace",
    "llvm.coro.await.suspend.bool",
    "llvm.coro.await.suspend.handle",
    "llvm.coro.await.suspend.void",
    "llvm.coro.begin",
    "llvm.coro.begin.custom.abi",
    "llvm.coro.destroy",
    "llvm.coro.done",
    "llvm.coro.end",
    "llvm.coro.end.async",
    "llvm.coro.frame",
    "llvm.coro.free",
    "llvm.coro.id",
    "llvm.coro.id.async",
    "llvm.coro.id.retcon",
    "llvm.coro.id.retcon.once",
    "llvm.coro.noop",
    "llvm.coro.prepare.async",
    "llvm.coro.prepare.retcon",
    "llvm.coro.promise",
    "llvm.coro.resume",
    "llvm.coro.save",
    "llvm.coro.size",
    "llvm.coro.subfn.addr",
    "llvm.coro.suspend",
    "llvm.coro.suspend.async",
    "llvm.coro.suspend.retcon",
};

// Returns true if Names is in strictly increasing lexicographic order. Bytes
// are compared as unsigned char, which is the memcmp-style ordering used when
// comparing StringRefs. Strict ordering also rejects duplicate entries.
template <size_t N>
static constexpr bool isStrictlySortedNames(const char *const (&Names)[N]) {
  for (size_t I = 1; I < N; ++I) {
    const char *Prev = Names[I - 1];
    const char *Cur = Names[I];
    // Skip the common prefix; the first differing byte (or a terminating
    // NUL) decides the order.
    size_t K = 0;
    while (Prev[K] != '\0' && Prev[K] == Cur[K])
      ++K;
    if (static_cast<unsigned char>(Prev[K]) >=
        static_cast<unsigned char>(Cur[K]))
      return false;
  }
  return true;
}

static_assert(isStrictlySortedNames(CoroIntrinsics),
              "CoroIntrinsics must be sorted and free of duplicates");
99 
100 #ifndef NDEBUG
101 static bool isCoroutineIntrinsicName(StringRef Name) {
102   return llvm::binary_search(CoroIntrinsics, Name);
103 }
104 #endif
105 
106 bool coro::isSuspendBlock(BasicBlock *BB) {
107   return isa<AnyCoroSuspendInst>(BB->front());
108 }
109 
110 bool coro::declaresAnyIntrinsic(const Module &M) {
111   for (StringRef Name : CoroIntrinsics) {
112     if (M.getNamedValue(Name))
113       return true;
114   }
115 
116   return false;
117 }
118 
119 // Verifies if a module has named values listed. Also, in debug mode verifies
120 // that names are intrinsic names.
121 bool coro::declaresIntrinsics(const Module &M,
122                               const std::initializer_list<StringRef> List) {
123   for (StringRef Name : List) {
124     assert(isCoroutineIntrinsicName(Name) && "not a coroutine intrinsic");
125     if (M.getNamedValue(Name))
126       return true;
127   }
128 
129   return false;
130 }
131 
132 // Replace all coro.frees associated with the provided CoroId either with 'null'
133 // if Elide is true and with its frame parameter otherwise.
134 void coro::replaceCoroFree(CoroIdInst *CoroId, bool Elide) {
135   SmallVector<CoroFreeInst *, 4> CoroFrees;
136   for (User *U : CoroId->users())
137     if (auto CF = dyn_cast<CoroFreeInst>(U))
138       CoroFrees.push_back(CF);
139 
140   if (CoroFrees.empty())
141     return;
142 
143   Value *Replacement =
144       Elide
145           ? ConstantPointerNull::get(PointerType::get(CoroId->getContext(), 0))
146           : CoroFrees.front()->getFrame();
147 
148   for (CoroFreeInst *CF : CoroFrees) {
149     CF->replaceAllUsesWith(Replacement);
150     CF->eraseFromParent();
151   }
152 }
153 
154 void coro::suppressCoroAllocs(CoroIdInst *CoroId) {
155   SmallVector<CoroAllocInst *, 4> CoroAllocs;
156   for (User *U : CoroId->users())
157     if (auto *CA = dyn_cast<CoroAllocInst>(U))
158       CoroAllocs.push_back(CA);
159 
160   if (CoroAllocs.empty())
161     return;
162 
163   coro::suppressCoroAllocs(CoroId->getContext(), CoroAllocs);
164 }
165 
166 // Replacing llvm.coro.alloc with false will suppress dynamic
167 // allocation as it is expected for the frontend to generate the code that
168 // looks like:
169 //   id = coro.id(...)
170 //   mem = coro.alloc(id) ? malloc(coro.size()) : 0;
171 //   coro.begin(id, mem)
172 void coro::suppressCoroAllocs(LLVMContext &Context,
173                               ArrayRef<CoroAllocInst *> CoroAllocs) {
174   auto *False = ConstantInt::getFalse(Context);
175   for (auto *CA : CoroAllocs) {
176     CA->replaceAllUsesWith(False);
177     CA->eraseFromParent();
178   }
179 }
180 
181 static CoroSaveInst *createCoroSave(CoroBeginInst *CoroBegin,
182                                     CoroSuspendInst *SuspendInst) {
183   Module *M = SuspendInst->getModule();
184   auto *Fn = Intrinsic::getOrInsertDeclaration(M, Intrinsic::coro_save);
185   auto *SaveInst = cast<CoroSaveInst>(
186       CallInst::Create(Fn, CoroBegin, "", SuspendInst->getIterator()));
187   assert(!SuspendInst->getCoroSave());
188   SuspendInst->setArgOperand(0, SaveInst);
189   return SaveInst;
190 }
191 
// Collect "interesting" coroutine intrinsics from F and classify the
// coroutine's ABI.
//
// The scan walks every instruction of F and records, in this Shape:
// - coro.await.suspend.* calls;
// - coro.size and coro.align;
// - coro.suspend* and coro.end*;
// - the single coro.begin (or coro.begin.custom.abi) whose coro.id is still
//   pre-split.
// coro.frame calls are appended to CoroFrames, and coro.save calls with no
// uses are appended to UnusedCoroSaves, so the caller can clean them up
// afterwards.
//
// If a defining coro.begin is found, ABI and the matching SwitchLowering /
// AsyncLowering / RetconLowering fields are initialized from the coro.id*
// intrinsic it refers to. Otherwise this returns right after the scan without
// setting the ABI.
//
// Aborts via report_fatal_error when any of the following is found:
// - more than one suspend marked final;
// - more than one defining coro.begin;
// - a fallthrough llvm.coro.end when one is already at the front of CoroEnds.
void coro::Shape::analyze(Function &F,
                          SmallVectorImpl<CoroFrameInst *> &CoroFrames,
                          SmallVectorImpl<CoroSaveInst *> &UnusedCoroSaves) {
  clear();

  bool HasFinalSuspend = false;
  bool HasUnwindCoroEnd = false;
  size_t FinalSuspendIndex = 0;

  for (Instruction &I : instructions(F)) {
    // FIXME: coro_await_suspend_* are not proper `IntrinsicInst`s
    // because they might be invoked
    if (auto AWS = dyn_cast<CoroAwaitSuspendInst>(&I)) {
      CoroAwaitSuspends.push_back(AWS);
    } else if (auto II = dyn_cast<IntrinsicInst>(&I)) {
      switch (II->getIntrinsicID()) {
      default:
        // Not an intrinsic tracked by Shape; move on to the next instruction.
        continue;
      case Intrinsic::coro_size:
        CoroSizes.push_back(cast<CoroSizeInst>(II));
        break;
      case Intrinsic::coro_align:
        CoroAligns.push_back(cast<CoroAlignInst>(II));
        break;
      case Intrinsic::coro_frame:
        CoroFrames.push_back(cast<CoroFrameInst>(II));
        break;
      case Intrinsic::coro_save:
        // After optimizations, coro_suspends using this coro_save might have
        // been removed, remember orphaned coro_saves to remove them later.
        if (II->use_empty())
          UnusedCoroSaves.push_back(cast<CoroSaveInst>(II));
        break;
      case Intrinsic::coro_suspend_async: {
        auto *Suspend = cast<CoroSuspendAsyncInst>(II);
        Suspend->checkWellFormed();
        CoroSuspends.push_back(Suspend);
        break;
      }
      case Intrinsic::coro_suspend_retcon: {
        auto Suspend = cast<CoroSuspendRetconInst>(II);
        CoroSuspends.push_back(Suspend);
        break;
      }
      case Intrinsic::coro_suspend: {
        auto Suspend = cast<CoroSuspendInst>(II);
        CoroSuspends.push_back(Suspend);
        if (Suspend->isFinal()) {
          if (HasFinalSuspend)
            report_fatal_error(
              "Only one suspend point can be marked as final");
          HasFinalSuspend = true;
          // Remember its position; the switch ABI below moves it to the end
          // of CoroSuspends.
          FinalSuspendIndex = CoroSuspends.size() - 1;
        }
        break;
      }
      case Intrinsic::coro_begin:
      case Intrinsic::coro_begin_custom_abi: {
        auto CB = cast<CoroBeginInst>(II);

        // Ignore coro id's that aren't pre-split.
        // Only plain coro.id ids are filtered here; for other id kinds the
        // dyn_cast yields null and the begin is always accepted.
        auto Id = dyn_cast<CoroIdInst>(CB->getId());
        if (Id && !Id->getInfo().isPreSplit())
          break;

        if (CoroBegin)
          report_fatal_error(
                "coroutine should have exactly one defining @llvm.coro.begin");
        // Annotate the defining coro.begin: its result is nonnull and noalias,
        // and the NoDuplicate function attribute is dropped from the call.
        CB->addRetAttr(Attribute::NonNull);
        CB->addRetAttr(Attribute::NoAlias);
        CB->removeFnAttr(Attribute::NoDuplicate);
        CoroBegin = CB;
        break;
      }
      case Intrinsic::coro_end_async:
      case Intrinsic::coro_end:
        CoroEnds.push_back(cast<AnyCoroEndInst>(II));
        if (auto *AsyncEnd = dyn_cast<CoroAsyncEndInst>(II)) {
          AsyncEnd->checkWellFormed();
        }

        if (CoroEnds.back()->isUnwind())
          HasUnwindCoroEnd = true;

        if (CoroEnds.back()->isFallthrough() && isa<CoroEndInst>(II)) {
          // Make sure that the fallthrough coro.end is the first element in the
          // CoroEnds vector.
          // Note: I don't think this is necessary anymore.
          if (CoroEnds.size() > 1) {
            if (CoroEnds.front()->isFallthrough())
              report_fatal_error(
                  "Only one coro.end can be marked as fallthrough");
            std::swap(CoroEnds.front(), CoroEnds.back());
          }
        }
        break;
      }
    }
  }

  // If there is no CoroBegin then this is not a coroutine.
  if (!CoroBegin)
    return;

  // Determination of ABI and initializing lowering info
  // The ABI is selected by which coro.id* intrinsic the coro.begin refers to.
  auto Id = CoroBegin->getId();
  switch (auto IntrID = Id->getIntrinsicID()) {
  case Intrinsic::coro_id: {
    ABI = coro::ABI::Switch;
    SwitchLowering.HasFinalSuspend = HasFinalSuspend;
    SwitchLowering.HasUnwindCoroEnd = HasUnwindCoroEnd;

    auto SwitchId = getSwitchCoroId();
    SwitchLowering.ResumeSwitch = nullptr;
    SwitchLowering.PromiseAlloca = SwitchId->getPromise();
    SwitchLowering.ResumeEntryBlock = nullptr;

    // Move final suspend to the last element in the CoroSuspends vector.
    if (SwitchLowering.HasFinalSuspend &&
        FinalSuspendIndex != CoroSuspends.size() - 1)
      std::swap(CoroSuspends[FinalSuspendIndex], CoroSuspends.back());
    break;
  }
  case Intrinsic::coro_id_async: {
    ABI = coro::ABI::Async;
    auto *AsyncId = getAsyncCoroId();
    AsyncId->checkWellFormed();
    AsyncLowering.Context = AsyncId->getStorage();
    AsyncLowering.ContextArgNo = AsyncId->getStorageArgumentIndex();
    AsyncLowering.ContextHeaderSize = AsyncId->getStorageSize();
    AsyncLowering.ContextAlignment = AsyncId->getStorageAlignment().value();
    AsyncLowering.AsyncFuncPointer = AsyncId->getAsyncFunctionPointer();
    AsyncLowering.AsyncCC = F.getCallingConv();
    break;
  }
  case Intrinsic::coro_id_retcon:
  case Intrinsic::coro_id_retcon_once: {
    ABI = IntrID == Intrinsic::coro_id_retcon ? coro::ABI::Retcon
                                              : coro::ABI::RetconOnce;
    auto ContinuationId = getRetconCoroId();
    ContinuationId->checkWellFormed();
    auto Prototype = ContinuationId->getPrototype();
    RetconLowering.ResumePrototype = Prototype;
    RetconLowering.Alloc = ContinuationId->getAllocFunction();
    RetconLowering.Dealloc = ContinuationId->getDeallocFunction();
    RetconLowering.ReturnBlock = nullptr;
    RetconLowering.IsFrameInlineInStorage = false;
    break;
  }
  default:
    llvm_unreachable("coro.begin is not dependent on a coro.id call");
  }
}
346 
347 // If for some reason, we were not able to find coro.begin, bailout.
348 void coro::Shape::invalidateCoroutine(
349     Function &F, SmallVectorImpl<CoroFrameInst *> &CoroFrames) {
350   assert(!CoroBegin);
351   {
352     // Replace coro.frame which are supposed to be lowered to the result of
353     // coro.begin with poison.
354     auto *Poison = PoisonValue::get(PointerType::get(F.getContext(), 0));
355     for (CoroFrameInst *CF : CoroFrames) {
356       CF->replaceAllUsesWith(Poison);
357       CF->eraseFromParent();
358     }
359     CoroFrames.clear();
360 
361     // Replace all coro.suspend with poison and remove related coro.saves if
362     // present.
363     for (AnyCoroSuspendInst *CS : CoroSuspends) {
364       CS->replaceAllUsesWith(PoisonValue::get(CS->getType()));
365       CS->eraseFromParent();
366       if (auto *CoroSave = CS->getCoroSave())
367         CoroSave->eraseFromParent();
368     }
369     CoroSuspends.clear();
370 
371     // Replace all coro.ends with unreachable instruction.
372     for (AnyCoroEndInst *CE : CoroEnds)
373       changeToUnreachable(CE);
374   }
375 }
376 
377 void coro::SwitchABI::init() {
378   assert(Shape.ABI == coro::ABI::Switch);
379   {
380     for (auto *AnySuspend : Shape.CoroSuspends) {
381       auto Suspend = dyn_cast<CoroSuspendInst>(AnySuspend);
382       if (!Suspend) {
383 #ifndef NDEBUG
384         AnySuspend->dump();
385 #endif
386         report_fatal_error("coro.id must be paired with coro.suspend");
387       }
388 
389       if (!Suspend->getCoroSave())
390         createCoroSave(Shape.CoroBegin, Suspend);
391     }
392   }
393 }
394 
395 void coro::AsyncABI::init() { assert(Shape.ABI == coro::ABI::Async); }
396 
// Retcon / retcon.once initialization. Verifies that every collected suspend is
// an llvm.coro.suspend.retcon and checks it against the prototype:
// - its operands must match Shape.getRetconResultTypes();
// - its own result type must match Shape.getRetconResumeTypes().
// An operand whose type differs but is bitcastable to the expected type is
// repaired by inserting a bitcast before the suspend. Every other mismatch
// aborts via report_fatal_error; when NDEBUG is not defined, the suspend and
// the prototype's function type are dumped first.
void coro::AnyRetconABI::init() {
  assert(Shape.ABI == coro::ABI::Retcon || Shape.ABI == coro::ABI::RetconOnce);
  {
    // Determine the result value types, and make sure they match up with
    // the values passed to the suspends.
    auto ResultTys = Shape.getRetconResultTypes();
    auto ResumeTys = Shape.getRetconResumeTypes();

    for (auto *AnySuspend : Shape.CoroSuspends) {
      auto Suspend = dyn_cast<CoroSuspendRetconInst>(AnySuspend);
      if (!Suspend) {
#ifndef NDEBUG
        AnySuspend->dump();
#endif
        report_fatal_error("coro.id.retcon.* must be paired with "
                           "coro.suspend.retcon");
      }

      // Check that the argument types of the suspend match the results.
      // Walk suspend operands and expected result types in lockstep.
      auto SI = Suspend->value_begin(), SE = Suspend->value_end();
      auto RI = ResultTys.begin(), RE = ResultTys.end();
      for (; SI != SE && RI != RE; ++SI, ++RI) {
        auto SrcTy = (*SI)->getType();
        if (SrcTy != *RI) {
          // The optimizer likes to eliminate bitcasts leading into variadic
          // calls, but that messes with our invariants.  Re-insert the
          // bitcast and ignore this type mismatch.
          if (CastInst::isBitCastable(SrcTy, *RI)) {
            // The new bitcast is placed before the suspend and replaces the
            // operand in place.
            auto BCI = new BitCastInst(*SI, *RI, "", Suspend->getIterator());
            SI->set(BCI);
            continue;
          }

#ifndef NDEBUG
          Suspend->dump();
          Shape.RetconLowering.ResumePrototype->getFunctionType()->dump();
#endif
          report_fatal_error("argument to coro.suspend.retcon does not "
                             "match corresponding prototype function result");
        }
      }
      // Either iterator left unfinished means the counts differ.
      if (SI != SE || RI != RE) {
#ifndef NDEBUG
        Suspend->dump();
        Shape.RetconLowering.ResumePrototype->getFunctionType()->dump();
#endif
        report_fatal_error("wrong number of arguments to coro.suspend.retcon");
      }

      // Check that the result type of the suspend matches the resume types.
      // A void result means no values; a struct result means one value per
      // element; any other type is a single value.
      Type *SResultTy = Suspend->getType();
      ArrayRef<Type *> SuspendResultTys;
      if (SResultTy->isVoidTy()) {
        // leave as empty array
      } else if (auto SResultStructTy = dyn_cast<StructType>(SResultTy)) {
        SuspendResultTys = SResultStructTy->elements();
      } else {
        // forms an ArrayRef using SResultTy, be careful
        // (it points at the local SResultTy, so it must not outlive this loop
        // iteration).
        SuspendResultTys = SResultTy;
      }
      if (SuspendResultTys.size() != ResumeTys.size()) {
#ifndef NDEBUG
        Suspend->dump();
        Shape.RetconLowering.ResumePrototype->getFunctionType()->dump();
#endif
        report_fatal_error("wrong number of results from coro.suspend.retcon");
      }
      for (size_t I = 0, E = ResumeTys.size(); I != E; ++I) {
        if (SuspendResultTys[I] != ResumeTys[I]) {
#ifndef NDEBUG
          Suspend->dump();
          Shape.RetconLowering.ResumePrototype->getFunctionType()->dump();
#endif
          report_fatal_error("result from coro.suspend.retcon does not "
                             "match corresponding prototype function param");
        }
      }
    }
  }
}
477 
478 void coro::Shape::cleanCoroutine(
479     SmallVectorImpl<CoroFrameInst *> &CoroFrames,
480     SmallVectorImpl<CoroSaveInst *> &UnusedCoroSaves) {
481   // The coro.frame intrinsic is always lowered to the result of coro.begin.
482   for (CoroFrameInst *CF : CoroFrames) {
483     CF->replaceAllUsesWith(CoroBegin);
484     CF->eraseFromParent();
485   }
486   CoroFrames.clear();
487 
488   // Remove orphaned coro.saves.
489   for (CoroSaveInst *CoroSave : UnusedCoroSaves)
490     CoroSave->eraseFromParent();
491   UnusedCoroSaves.clear();
492 }
493 
494 static void propagateCallAttrsFromCallee(CallInst *Call, Function *Callee) {
495   Call->setCallingConv(Callee->getCallingConv());
496   // TODO: attributes?
497 }
498 
499 static void addCallToCallGraph(CallGraph *CG, CallInst *Call, Function *Callee){
500   if (CG)
501     (*CG)[Call->getFunction()]->addCalledFunction(Call, (*CG)[Callee]);
502 }
503 
504 Value *coro::Shape::emitAlloc(IRBuilder<> &Builder, Value *Size,
505                               CallGraph *CG) const {
506   switch (ABI) {
507   case coro::ABI::Switch:
508     llvm_unreachable("can't allocate memory in coro switch-lowering");
509 
510   case coro::ABI::Retcon:
511   case coro::ABI::RetconOnce: {
512     auto Alloc = RetconLowering.Alloc;
513     Size = Builder.CreateIntCast(Size,
514                                  Alloc->getFunctionType()->getParamType(0),
515                                  /*is signed*/ false);
516     auto *Call = Builder.CreateCall(Alloc, Size);
517     propagateCallAttrsFromCallee(Call, Alloc);
518     addCallToCallGraph(CG, Call, Alloc);
519     return Call;
520   }
521   case coro::ABI::Async:
522     llvm_unreachable("can't allocate memory in coro async-lowering");
523   }
524   llvm_unreachable("Unknown coro::ABI enum");
525 }
526 
527 void coro::Shape::emitDealloc(IRBuilder<> &Builder, Value *Ptr,
528                               CallGraph *CG) const {
529   switch (ABI) {
530   case coro::ABI::Switch:
531     llvm_unreachable("can't allocate memory in coro switch-lowering");
532 
533   case coro::ABI::Retcon:
534   case coro::ABI::RetconOnce: {
535     auto Dealloc = RetconLowering.Dealloc;
536     Ptr = Builder.CreateBitCast(Ptr,
537                                 Dealloc->getFunctionType()->getParamType(0));
538     auto *Call = Builder.CreateCall(Dealloc, Ptr);
539     propagateCallAttrsFromCallee(Call, Dealloc);
540     addCallToCallGraph(CG, Call, Dealloc);
541     return;
542   }
543   case coro::ABI::Async:
544     llvm_unreachable("can't allocate memory in coro async-lowering");
545   }
546   llvm_unreachable("Unknown coro::ABI enum");
547 }
548 
549 [[noreturn]] static void fail(const Instruction *I, const char *Reason,
550                               Value *V) {
551 #ifndef NDEBUG
552   I->dump();
553   if (V) {
554     errs() << "  Value: ";
555     V->printAsOperand(llvm::errs());
556     errs() << '\n';
557   }
558 #endif
559   report_fatal_error(Reason);
560 }
561 
562 /// Check that the given value is a well-formed prototype for the
563 /// llvm.coro.id.retcon.* intrinsics.
564 static void checkWFRetconPrototype(const AnyCoroIdRetconInst *I, Value *V) {
565   auto F = dyn_cast<Function>(V->stripPointerCasts());
566   if (!F)
567     fail(I, "llvm.coro.id.retcon.* prototype not a Function", V);
568 
569   auto FT = F->getFunctionType();
570 
571   if (isa<CoroIdRetconInst>(I)) {
572     bool ResultOkay;
573     if (FT->getReturnType()->isPointerTy()) {
574       ResultOkay = true;
575     } else if (auto SRetTy = dyn_cast<StructType>(FT->getReturnType())) {
576       ResultOkay = (!SRetTy->isOpaque() &&
577                     SRetTy->getNumElements() > 0 &&
578                     SRetTy->getElementType(0)->isPointerTy());
579     } else {
580       ResultOkay = false;
581     }
582     if (!ResultOkay)
583       fail(I, "llvm.coro.id.retcon prototype must return pointer as first "
584               "result", F);
585 
586     if (FT->getReturnType() !=
587           I->getFunction()->getFunctionType()->getReturnType())
588       fail(I, "llvm.coro.id.retcon prototype return type must be same as"
589               "current function return type", F);
590   } else {
591     // No meaningful validation to do here for llvm.coro.id.unique.once.
592   }
593 
594   if (FT->getNumParams() == 0 || !FT->getParamType(0)->isPointerTy())
595     fail(I, "llvm.coro.id.retcon.* prototype must take pointer as "
596             "its first parameter", F);
597 }
598 
599 /// Check that the given value is a well-formed allocator.
600 static void checkWFAlloc(const Instruction *I, Value *V) {
601   auto F = dyn_cast<Function>(V->stripPointerCasts());
602   if (!F)
603     fail(I, "llvm.coro.* allocator not a Function", V);
604 
605   auto FT = F->getFunctionType();
606   if (!FT->getReturnType()->isPointerTy())
607     fail(I, "llvm.coro.* allocator must return a pointer", F);
608 
609   if (FT->getNumParams() != 1 ||
610       !FT->getParamType(0)->isIntegerTy())
611     fail(I, "llvm.coro.* allocator must take integer as only param", F);
612 }
613 
614 /// Check that the given value is a well-formed deallocator.
615 static void checkWFDealloc(const Instruction *I, Value *V) {
616   auto F = dyn_cast<Function>(V->stripPointerCasts());
617   if (!F)
618     fail(I, "llvm.coro.* deallocator not a Function", V);
619 
620   auto FT = F->getFunctionType();
621   if (!FT->getReturnType()->isVoidTy())
622     fail(I, "llvm.coro.* deallocator must return void", F);
623 
624   if (FT->getNumParams() != 1 ||
625       !FT->getParamType(0)->isPointerTy())
626     fail(I, "llvm.coro.* deallocator must take pointer as only param", F);
627 }
628 
629 static void checkConstantInt(const Instruction *I, Value *V,
630                              const char *Reason) {
631   if (!isa<ConstantInt>(V)) {
632     fail(I, Reason, V);
633   }
634 }
635 
636 void AnyCoroIdRetconInst::checkWellFormed() const {
637   checkConstantInt(this, getArgOperand(SizeArg),
638                    "size argument to coro.id.retcon.* must be constant");
639   checkConstantInt(this, getArgOperand(AlignArg),
640                    "alignment argument to coro.id.retcon.* must be constant");
641   checkWFRetconPrototype(this, getArgOperand(PrototypeArg));
642   checkWFAlloc(this, getArgOperand(AllocArg));
643   checkWFDealloc(this, getArgOperand(DeallocArg));
644 }
645 
646 static void checkAsyncFuncPointer(const Instruction *I, Value *V) {
647   auto *AsyncFuncPtrAddr = dyn_cast<GlobalVariable>(V->stripPointerCasts());
648   if (!AsyncFuncPtrAddr)
649     fail(I, "llvm.coro.id.async async function pointer not a global", V);
650 }
651 
652 void CoroIdAsyncInst::checkWellFormed() const {
653   checkConstantInt(this, getArgOperand(SizeArg),
654                    "size argument to coro.id.async must be constant");
655   checkConstantInt(this, getArgOperand(AlignArg),
656                    "alignment argument to coro.id.async must be constant");
657   checkConstantInt(this, getArgOperand(StorageArg),
658                    "storage argument offset to coro.id.async must be constant");
659   checkAsyncFuncPointer(this, getArgOperand(AsyncFuncPtrArg));
660 }
661 
662 static void checkAsyncContextProjectFunction(const Instruction *I,
663                                              Function *F) {
664   auto *FunTy = cast<FunctionType>(F->getValueType());
665   if (!FunTy->getReturnType()->isPointerTy())
666     fail(I,
667          "llvm.coro.suspend.async resume function projection function must "
668          "return a ptr type",
669          F);
670   if (FunTy->getNumParams() != 1 || !FunTy->getParamType(0)->isPointerTy())
671     fail(I,
672          "llvm.coro.suspend.async resume function projection function must "
673          "take one ptr type as parameter",
674          F);
675 }
676 
677 void CoroSuspendAsyncInst::checkWellFormed() const {
678   checkAsyncContextProjectFunction(this, getAsyncContextProjectionFunction());
679 }
680 
681 void CoroAsyncEndInst::checkWellFormed() const {
682   auto *MustTailCallFunc = getMustTailCallFunction();
683   if (!MustTailCallFunc)
684     return;
685   auto *FnTy = MustTailCallFunc->getFunctionType();
686   if (FnTy->getNumParams() != (arg_size() - 3))
687     fail(this,
688          "llvm.coro.end.async must tail call function argument type must "
689          "match the tail arguments",
690          MustTailCallFunc);
691 }
692