xref: /llvm-project/llvm/lib/Transforms/Coroutines/Coroutines.cpp (revision d57525294a6f9318db6aeda46bf77d004b836b43)
1 //===- Coroutines.cpp -----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the common infrastructure for Coroutine Passes.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "CoroInstr.h"
14 #include "CoroInternal.h"
15 #include "CoroShape.h"
16 #include "llvm/ADT/SmallVector.h"
17 #include "llvm/ADT/StringRef.h"
18 #include "llvm/Analysis/CallGraph.h"
19 #include "llvm/IR/Attributes.h"
20 #include "llvm/IR/Constants.h"
21 #include "llvm/IR/DerivedTypes.h"
22 #include "llvm/IR/Function.h"
23 #include "llvm/IR/InstIterator.h"
24 #include "llvm/IR/Instructions.h"
25 #include "llvm/IR/IntrinsicInst.h"
26 #include "llvm/IR/Intrinsics.h"
27 #include "llvm/IR/Module.h"
28 #include "llvm/IR/Type.h"
29 #include "llvm/Support/Casting.h"
30 #include "llvm/Support/ErrorHandling.h"
31 #include "llvm/Transforms/Utils/Local.h"
32 #include <cassert>
33 #include <cstddef>
34 #include <utility>
35 
36 using namespace llvm;
37 
38 // Construct the lowerer base class and initialize its members.
39 coro::LowererBase::LowererBase(Module &M)
40     : TheModule(M), Context(M.getContext()),
41       Int8Ptr(PointerType::get(Context, 0)),
42       ResumeFnType(FunctionType::get(Type::getVoidTy(Context), Int8Ptr,
43                                      /*isVarArg=*/false)),
44       NullPtr(ConstantPointerNull::get(Int8Ptr)) {}
45 
46 // Creates a call to llvm.coro.subfn.addr to obtain a resume function address.
47 // It generates the following:
48 //
49 //    call ptr @llvm.coro.subfn.addr(ptr %Arg, i8 %index)
50 
51 CallInst *coro::LowererBase::makeSubFnCall(Value *Arg, int Index,
52                                            Instruction *InsertPt) {
53   auto *IndexVal = ConstantInt::get(Type::getInt8Ty(Context), Index);
54   auto *Fn = Intrinsic::getDeclaration(&TheModule, Intrinsic::coro_subfn_addr);
55 
56   assert(Index >= CoroSubFnInst::IndexFirst &&
57          Index < CoroSubFnInst::IndexLast &&
58          "makeSubFnCall: Index value out of range");
59   return CallInst::Create(Fn, {Arg, IndexVal}, "", InsertPt->getIterator());
60 }
61 
// Names of every coroutine intrinsic. declaresAnyIntrinsic() probes a module
// for each of these names, and in debug builds isCoroutineIntrinsicName()
// checks candidate names against this table.
//
// NOTE: Must be sorted! The table is searched with
// Intrinsic::lookupLLVMIntrinsicByName, which (per its documentation) does a
// binary search and therefore needs the entries in lexicographic order.
static constexpr const char *CoroIntrinsics[] = {
    "llvm.coro.align",
    "llvm.coro.alloc",
    "llvm.coro.async.context.alloc",
    "llvm.coro.async.context.dealloc",
    "llvm.coro.async.resume",
    "llvm.coro.async.size.replace",
    "llvm.coro.async.store_resume",
    "llvm.coro.await.suspend.bool",
    "llvm.coro.await.suspend.handle",
    "llvm.coro.await.suspend.void",
    "llvm.coro.begin",
    "llvm.coro.destroy",
    "llvm.coro.done",
    "llvm.coro.end",
    "llvm.coro.end.async",
    "llvm.coro.frame",
    "llvm.coro.free",
    "llvm.coro.id",
    "llvm.coro.id.async",
    "llvm.coro.id.retcon",
    "llvm.coro.id.retcon.once",
    "llvm.coro.noop",
    "llvm.coro.prepare.async",
    "llvm.coro.prepare.retcon",
    "llvm.coro.promise",
    "llvm.coro.resume",
    "llvm.coro.save",
    "llvm.coro.size",
    "llvm.coro.subfn.addr",
    "llvm.coro.suspend",
    "llvm.coro.suspend.async",
    "llvm.coro.suspend.retcon",
};
97 
98 #ifndef NDEBUG
99 static bool isCoroutineIntrinsicName(StringRef Name) {
100   return Intrinsic::lookupLLVMIntrinsicByName(CoroIntrinsics, Name) != -1;
101 }
102 #endif
103 
104 bool coro::isSuspendBlock(BasicBlock *BB) {
105   return isa<AnyCoroSuspendInst>(BB->front());
106 }
107 
108 bool coro::declaresAnyIntrinsic(const Module &M) {
109   for (StringRef Name : CoroIntrinsics) {
110     assert(isCoroutineIntrinsicName(Name) && "not a coroutine intrinsic");
111     if (M.getNamedValue(Name))
112       return true;
113   }
114 
115   return false;
116 }
117 
118 // Verifies if a module has named values listed. Also, in debug mode verifies
119 // that names are intrinsic names.
120 bool coro::declaresIntrinsics(const Module &M,
121                               const std::initializer_list<StringRef> List) {
122   for (StringRef Name : List) {
123     assert(isCoroutineIntrinsicName(Name) && "not a coroutine intrinsic");
124     if (M.getNamedValue(Name))
125       return true;
126   }
127 
128   return false;
129 }
130 
131 // Replace all coro.frees associated with the provided CoroId either with 'null'
132 // if Elide is true and with its frame parameter otherwise.
133 void coro::replaceCoroFree(CoroIdInst *CoroId, bool Elide) {
134   SmallVector<CoroFreeInst *, 4> CoroFrees;
135   for (User *U : CoroId->users())
136     if (auto CF = dyn_cast<CoroFreeInst>(U))
137       CoroFrees.push_back(CF);
138 
139   if (CoroFrees.empty())
140     return;
141 
142   Value *Replacement =
143       Elide
144           ? ConstantPointerNull::get(PointerType::get(CoroId->getContext(), 0))
145           : CoroFrees.front()->getFrame();
146 
147   for (CoroFreeInst *CF : CoroFrees) {
148     CF->replaceAllUsesWith(Replacement);
149     CF->eraseFromParent();
150   }
151 }
152 
153 void coro::suppressCoroAllocs(CoroIdInst *CoroId) {
154   SmallVector<CoroAllocInst *, 4> CoroAllocs;
155   for (User *U : CoroId->users())
156     if (auto *CA = dyn_cast<CoroAllocInst>(U))
157       CoroAllocs.push_back(CA);
158 
159   if (CoroAllocs.empty())
160     return;
161 
162   coro::suppressCoroAllocs(CoroId->getContext(), CoroAllocs);
163 }
164 
165 // Replacing llvm.coro.alloc with false will suppress dynamic
166 // allocation as it is expected for the frontend to generate the code that
167 // looks like:
168 //   id = coro.id(...)
169 //   mem = coro.alloc(id) ? malloc(coro.size()) : 0;
170 //   coro.begin(id, mem)
171 void coro::suppressCoroAllocs(LLVMContext &Context,
172                               ArrayRef<CoroAllocInst *> CoroAllocs) {
173   auto *False = ConstantInt::getFalse(Context);
174   for (auto *CA : CoroAllocs) {
175     CA->replaceAllUsesWith(False);
176     CA->eraseFromParent();
177   }
178 }
179 
180 static CoroSaveInst *createCoroSave(CoroBeginInst *CoroBegin,
181                                     CoroSuspendInst *SuspendInst) {
182   Module *M = SuspendInst->getModule();
183   auto *Fn = Intrinsic::getDeclaration(M, Intrinsic::coro_save);
184   auto *SaveInst = cast<CoroSaveInst>(
185       CallInst::Create(Fn, CoroBegin, "", SuspendInst->getIterator()));
186   assert(!SuspendInst->getCoroSave());
187   SuspendInst->setArgOperand(0, SaveInst);
188   return SaveInst;
189 }
190 
191 // Collect "interesting" coroutine intrinsics.
192 void coro::Shape::analyze(Function &F,
193                           SmallVectorImpl<CoroFrameInst *> &CoroFrames,
194                           SmallVectorImpl<CoroSaveInst *> &UnusedCoroSaves) {
195   clear();
196 
197   bool HasFinalSuspend = false;
198   bool HasUnwindCoroEnd = false;
199   size_t FinalSuspendIndex = 0;
200 
201   for (Instruction &I : instructions(F)) {
202     // FIXME: coro_await_suspend_* are not proper `IntrinisicInst`s
203     // because they might be invoked
204     if (auto AWS = dyn_cast<CoroAwaitSuspendInst>(&I)) {
205       CoroAwaitSuspends.push_back(AWS);
206     } else if (auto II = dyn_cast<IntrinsicInst>(&I)) {
207       switch (II->getIntrinsicID()) {
208       default:
209         continue;
210       case Intrinsic::coro_size:
211         CoroSizes.push_back(cast<CoroSizeInst>(II));
212         break;
213       case Intrinsic::coro_align:
214         CoroAligns.push_back(cast<CoroAlignInst>(II));
215         break;
216       case Intrinsic::coro_frame:
217         CoroFrames.push_back(cast<CoroFrameInst>(II));
218         break;
219       case Intrinsic::coro_save:
220         // After optimizations, coro_suspends using this coro_save might have
221         // been removed, remember orphaned coro_saves to remove them later.
222         if (II->use_empty())
223           UnusedCoroSaves.push_back(cast<CoroSaveInst>(II));
224         break;
225       case Intrinsic::coro_suspend_async: {
226         auto *Suspend = cast<CoroSuspendAsyncInst>(II);
227         Suspend->checkWellFormed();
228         CoroSuspends.push_back(Suspend);
229         break;
230       }
231       case Intrinsic::coro_suspend_retcon: {
232         auto Suspend = cast<CoroSuspendRetconInst>(II);
233         CoroSuspends.push_back(Suspend);
234         break;
235       }
236       case Intrinsic::coro_suspend: {
237         auto Suspend = cast<CoroSuspendInst>(II);
238         CoroSuspends.push_back(Suspend);
239         if (Suspend->isFinal()) {
240           if (HasFinalSuspend)
241             report_fatal_error(
242               "Only one suspend point can be marked as final");
243           HasFinalSuspend = true;
244           FinalSuspendIndex = CoroSuspends.size() - 1;
245         }
246         break;
247       }
248       case Intrinsic::coro_begin: {
249         auto CB = cast<CoroBeginInst>(II);
250 
251         // Ignore coro id's that aren't pre-split.
252         auto Id = dyn_cast<CoroIdInst>(CB->getId());
253         if (Id && !Id->getInfo().isPreSplit())
254           break;
255 
256         if (CoroBegin)
257           report_fatal_error(
258                 "coroutine should have exactly one defining @llvm.coro.begin");
259         CB->addRetAttr(Attribute::NonNull);
260         CB->addRetAttr(Attribute::NoAlias);
261         CB->removeFnAttr(Attribute::NoDuplicate);
262         CoroBegin = CB;
263         break;
264       }
265       case Intrinsic::coro_end_async:
266       case Intrinsic::coro_end:
267         CoroEnds.push_back(cast<AnyCoroEndInst>(II));
268         if (auto *AsyncEnd = dyn_cast<CoroAsyncEndInst>(II)) {
269           AsyncEnd->checkWellFormed();
270         }
271 
272         if (CoroEnds.back()->isUnwind())
273           HasUnwindCoroEnd = true;
274 
275         if (CoroEnds.back()->isFallthrough() && isa<CoroEndInst>(II)) {
276           // Make sure that the fallthrough coro.end is the first element in the
277           // CoroEnds vector.
278           // Note: I don't think this is neccessary anymore.
279           if (CoroEnds.size() > 1) {
280             if (CoroEnds.front()->isFallthrough())
281               report_fatal_error(
282                   "Only one coro.end can be marked as fallthrough");
283             std::swap(CoroEnds.front(), CoroEnds.back());
284           }
285         }
286         break;
287       }
288     }
289   }
290 
291   // If there is no CoroBegin then this is not a coroutine.
292   if (!CoroBegin)
293     return;
294 
295   // Determination of ABI and initializing lowering info
296   auto Id = CoroBegin->getId();
297   switch (auto IntrID = Id->getIntrinsicID()) {
298   case Intrinsic::coro_id: {
299     ABI = coro::ABI::Switch;
300     SwitchLowering.HasFinalSuspend = HasFinalSuspend;
301     SwitchLowering.HasUnwindCoroEnd = HasUnwindCoroEnd;
302 
303     auto SwitchId = getSwitchCoroId();
304     SwitchLowering.ResumeSwitch = nullptr;
305     SwitchLowering.PromiseAlloca = SwitchId->getPromise();
306     SwitchLowering.ResumeEntryBlock = nullptr;
307 
308     // Move final suspend to the last element in the CoroSuspends vector.
309     if (SwitchLowering.HasFinalSuspend &&
310         FinalSuspendIndex != CoroSuspends.size() - 1)
311       std::swap(CoroSuspends[FinalSuspendIndex], CoroSuspends.back());
312     break;
313   }
314   case Intrinsic::coro_id_async: {
315     ABI = coro::ABI::Async;
316     auto *AsyncId = getAsyncCoroId();
317     AsyncId->checkWellFormed();
318     AsyncLowering.Context = AsyncId->getStorage();
319     AsyncLowering.ContextArgNo = AsyncId->getStorageArgumentIndex();
320     AsyncLowering.ContextHeaderSize = AsyncId->getStorageSize();
321     AsyncLowering.ContextAlignment = AsyncId->getStorageAlignment().value();
322     AsyncLowering.AsyncFuncPointer = AsyncId->getAsyncFunctionPointer();
323     AsyncLowering.AsyncCC = F.getCallingConv();
324     break;
325   }
326   case Intrinsic::coro_id_retcon:
327   case Intrinsic::coro_id_retcon_once: {
328     ABI = IntrID == Intrinsic::coro_id_retcon ? coro::ABI::Retcon
329                                               : coro::ABI::RetconOnce;
330     auto ContinuationId = getRetconCoroId();
331     ContinuationId->checkWellFormed();
332     auto Prototype = ContinuationId->getPrototype();
333     RetconLowering.ResumePrototype = Prototype;
334     RetconLowering.Alloc = ContinuationId->getAllocFunction();
335     RetconLowering.Dealloc = ContinuationId->getDeallocFunction();
336     RetconLowering.ReturnBlock = nullptr;
337     RetconLowering.IsFrameInlineInStorage = false;
338     break;
339   }
340   default:
341     llvm_unreachable("coro.begin is not dependent on a coro.id call");
342   }
343 }
344 
345 // If for some reason, we were not able to find coro.begin, bailout.
346 void coro::Shape::invalidateCoroutine(
347     Function &F, SmallVectorImpl<CoroFrameInst *> &CoroFrames) {
348   assert(!CoroBegin);
349   {
350     // Replace coro.frame which are supposed to be lowered to the result of
351     // coro.begin with undef.
352     auto *Undef = UndefValue::get(PointerType::get(F.getContext(), 0));
353     for (CoroFrameInst *CF : CoroFrames) {
354       CF->replaceAllUsesWith(Undef);
355       CF->eraseFromParent();
356     }
357     CoroFrames.clear();
358 
359     // Replace all coro.suspend with undef and remove related coro.saves if
360     // present.
361     for (AnyCoroSuspendInst *CS : CoroSuspends) {
362       CS->replaceAllUsesWith(UndefValue::get(CS->getType()));
363       CS->eraseFromParent();
364       if (auto *CoroSave = CS->getCoroSave())
365         CoroSave->eraseFromParent();
366     }
367     CoroSuspends.clear();
368 
369     // Replace all coro.ends with unreachable instruction.
370     for (AnyCoroEndInst *CE : CoroEnds)
371       changeToUnreachable(CE);
372   }
373 }
374 
375 // Perform semantic checking and initialization of the ABI
376 void coro::Shape::initABI() {
377   switch (ABI) {
378   case coro::ABI::Switch: {
379     for (auto *AnySuspend : CoroSuspends) {
380       auto Suspend = dyn_cast<CoroSuspendInst>(AnySuspend);
381       if (!Suspend) {
382 #ifndef NDEBUG
383         AnySuspend->dump();
384 #endif
385         report_fatal_error("coro.id must be paired with coro.suspend");
386       }
387 
388       if (!Suspend->getCoroSave())
389         createCoroSave(CoroBegin, Suspend);
390     }
391     break;
392   }
393   case coro::ABI::Async: {
394     break;
395   };
396   case coro::ABI::Retcon:
397   case coro::ABI::RetconOnce: {
398     // Determine the result value types, and make sure they match up with
399     // the values passed to the suspends.
400     auto ResultTys = getRetconResultTypes();
401     auto ResumeTys = getRetconResumeTypes();
402 
403     for (auto *AnySuspend : CoroSuspends) {
404       auto Suspend = dyn_cast<CoroSuspendRetconInst>(AnySuspend);
405       if (!Suspend) {
406 #ifndef NDEBUG
407         AnySuspend->dump();
408 #endif
409         report_fatal_error("coro.id.retcon.* must be paired with "
410                            "coro.suspend.retcon");
411       }
412 
413       // Check that the argument types of the suspend match the results.
414       auto SI = Suspend->value_begin(), SE = Suspend->value_end();
415       auto RI = ResultTys.begin(), RE = ResultTys.end();
416       for (; SI != SE && RI != RE; ++SI, ++RI) {
417         auto SrcTy = (*SI)->getType();
418         if (SrcTy != *RI) {
419           // The optimizer likes to eliminate bitcasts leading into variadic
420           // calls, but that messes with our invariants.  Re-insert the
421           // bitcast and ignore this type mismatch.
422           if (CastInst::isBitCastable(SrcTy, *RI)) {
423             auto BCI = new BitCastInst(*SI, *RI, "", Suspend->getIterator());
424             SI->set(BCI);
425             continue;
426           }
427 
428 #ifndef NDEBUG
429           Suspend->dump();
430           RetconLowering.ResumePrototype->getFunctionType()->dump();
431 #endif
432           report_fatal_error("argument to coro.suspend.retcon does not "
433                              "match corresponding prototype function result");
434         }
435       }
436       if (SI != SE || RI != RE) {
437 #ifndef NDEBUG
438         Suspend->dump();
439         RetconLowering.ResumePrototype->getFunctionType()->dump();
440 #endif
441         report_fatal_error("wrong number of arguments to coro.suspend.retcon");
442       }
443 
444       // Check that the result type of the suspend matches the resume types.
445       Type *SResultTy = Suspend->getType();
446       ArrayRef<Type *> SuspendResultTys;
447       if (SResultTy->isVoidTy()) {
448         // leave as empty array
449       } else if (auto SResultStructTy = dyn_cast<StructType>(SResultTy)) {
450         SuspendResultTys = SResultStructTy->elements();
451       } else {
452         // forms an ArrayRef using SResultTy, be careful
453         SuspendResultTys = SResultTy;
454       }
455       if (SuspendResultTys.size() != ResumeTys.size()) {
456 #ifndef NDEBUG
457         Suspend->dump();
458         RetconLowering.ResumePrototype->getFunctionType()->dump();
459 #endif
460         report_fatal_error("wrong number of results from coro.suspend.retcon");
461       }
462       for (size_t I = 0, E = ResumeTys.size(); I != E; ++I) {
463         if (SuspendResultTys[I] != ResumeTys[I]) {
464 #ifndef NDEBUG
465           Suspend->dump();
466           RetconLowering.ResumePrototype->getFunctionType()->dump();
467 #endif
468           report_fatal_error("result from coro.suspend.retcon does not "
469                              "match corresponding prototype function param");
470         }
471       }
472     }
473     break;
474   }
475   default:
476     llvm_unreachable("coro.begin is not dependent on a coro.id call");
477   }
478 }
479 
480 void coro::Shape::cleanCoroutine(
481     SmallVectorImpl<CoroFrameInst *> &CoroFrames,
482     SmallVectorImpl<CoroSaveInst *> &UnusedCoroSaves) {
483   // The coro.frame intrinsic is always lowered to the result of coro.begin.
484   for (CoroFrameInst *CF : CoroFrames) {
485     CF->replaceAllUsesWith(CoroBegin);
486     CF->eraseFromParent();
487   }
488   CoroFrames.clear();
489 
490   // Remove orphaned coro.saves.
491   for (CoroSaveInst *CoroSave : UnusedCoroSaves)
492     CoroSave->eraseFromParent();
493   UnusedCoroSaves.clear();
494 }
495 
496 static void propagateCallAttrsFromCallee(CallInst *Call, Function *Callee) {
497   Call->setCallingConv(Callee->getCallingConv());
498   // TODO: attributes?
499 }
500 
501 static void addCallToCallGraph(CallGraph *CG, CallInst *Call, Function *Callee){
502   if (CG)
503     (*CG)[Call->getFunction()]->addCalledFunction(Call, (*CG)[Callee]);
504 }
505 
506 Value *coro::Shape::emitAlloc(IRBuilder<> &Builder, Value *Size,
507                               CallGraph *CG) const {
508   switch (ABI) {
509   case coro::ABI::Switch:
510     llvm_unreachable("can't allocate memory in coro switch-lowering");
511 
512   case coro::ABI::Retcon:
513   case coro::ABI::RetconOnce: {
514     auto Alloc = RetconLowering.Alloc;
515     Size = Builder.CreateIntCast(Size,
516                                  Alloc->getFunctionType()->getParamType(0),
517                                  /*is signed*/ false);
518     auto *Call = Builder.CreateCall(Alloc, Size);
519     propagateCallAttrsFromCallee(Call, Alloc);
520     addCallToCallGraph(CG, Call, Alloc);
521     return Call;
522   }
523   case coro::ABI::Async:
524     llvm_unreachable("can't allocate memory in coro async-lowering");
525   }
526   llvm_unreachable("Unknown coro::ABI enum");
527 }
528 
529 void coro::Shape::emitDealloc(IRBuilder<> &Builder, Value *Ptr,
530                               CallGraph *CG) const {
531   switch (ABI) {
532   case coro::ABI::Switch:
533     llvm_unreachable("can't allocate memory in coro switch-lowering");
534 
535   case coro::ABI::Retcon:
536   case coro::ABI::RetconOnce: {
537     auto Dealloc = RetconLowering.Dealloc;
538     Ptr = Builder.CreateBitCast(Ptr,
539                                 Dealloc->getFunctionType()->getParamType(0));
540     auto *Call = Builder.CreateCall(Dealloc, Ptr);
541     propagateCallAttrsFromCallee(Call, Dealloc);
542     addCallToCallGraph(CG, Call, Dealloc);
543     return;
544   }
545   case coro::ABI::Async:
546     llvm_unreachable("can't allocate memory in coro async-lowering");
547   }
548   llvm_unreachable("Unknown coro::ABI enum");
549 }
550 
551 [[noreturn]] static void fail(const Instruction *I, const char *Reason,
552                               Value *V) {
553 #ifndef NDEBUG
554   I->dump();
555   if (V) {
556     errs() << "  Value: ";
557     V->printAsOperand(llvm::errs());
558     errs() << '\n';
559   }
560 #endif
561   report_fatal_error(Reason);
562 }
563 
564 /// Check that the given value is a well-formed prototype for the
565 /// llvm.coro.id.retcon.* intrinsics.
566 static void checkWFRetconPrototype(const AnyCoroIdRetconInst *I, Value *V) {
567   auto F = dyn_cast<Function>(V->stripPointerCasts());
568   if (!F)
569     fail(I, "llvm.coro.id.retcon.* prototype not a Function", V);
570 
571   auto FT = F->getFunctionType();
572 
573   if (isa<CoroIdRetconInst>(I)) {
574     bool ResultOkay;
575     if (FT->getReturnType()->isPointerTy()) {
576       ResultOkay = true;
577     } else if (auto SRetTy = dyn_cast<StructType>(FT->getReturnType())) {
578       ResultOkay = (!SRetTy->isOpaque() &&
579                     SRetTy->getNumElements() > 0 &&
580                     SRetTy->getElementType(0)->isPointerTy());
581     } else {
582       ResultOkay = false;
583     }
584     if (!ResultOkay)
585       fail(I, "llvm.coro.id.retcon prototype must return pointer as first "
586               "result", F);
587 
588     if (FT->getReturnType() !=
589           I->getFunction()->getFunctionType()->getReturnType())
590       fail(I, "llvm.coro.id.retcon prototype return type must be same as"
591               "current function return type", F);
592   } else {
593     // No meaningful validation to do here for llvm.coro.id.unique.once.
594   }
595 
596   if (FT->getNumParams() == 0 || !FT->getParamType(0)->isPointerTy())
597     fail(I, "llvm.coro.id.retcon.* prototype must take pointer as "
598             "its first parameter", F);
599 }
600 
601 /// Check that the given value is a well-formed allocator.
602 static void checkWFAlloc(const Instruction *I, Value *V) {
603   auto F = dyn_cast<Function>(V->stripPointerCasts());
604   if (!F)
605     fail(I, "llvm.coro.* allocator not a Function", V);
606 
607   auto FT = F->getFunctionType();
608   if (!FT->getReturnType()->isPointerTy())
609     fail(I, "llvm.coro.* allocator must return a pointer", F);
610 
611   if (FT->getNumParams() != 1 ||
612       !FT->getParamType(0)->isIntegerTy())
613     fail(I, "llvm.coro.* allocator must take integer as only param", F);
614 }
615 
616 /// Check that the given value is a well-formed deallocator.
617 static void checkWFDealloc(const Instruction *I, Value *V) {
618   auto F = dyn_cast<Function>(V->stripPointerCasts());
619   if (!F)
620     fail(I, "llvm.coro.* deallocator not a Function", V);
621 
622   auto FT = F->getFunctionType();
623   if (!FT->getReturnType()->isVoidTy())
624     fail(I, "llvm.coro.* deallocator must return void", F);
625 
626   if (FT->getNumParams() != 1 ||
627       !FT->getParamType(0)->isPointerTy())
628     fail(I, "llvm.coro.* deallocator must take pointer as only param", F);
629 }
630 
631 static void checkConstantInt(const Instruction *I, Value *V,
632                              const char *Reason) {
633   if (!isa<ConstantInt>(V)) {
634     fail(I, Reason, V);
635   }
636 }
637 
638 void AnyCoroIdRetconInst::checkWellFormed() const {
639   checkConstantInt(this, getArgOperand(SizeArg),
640                    "size argument to coro.id.retcon.* must be constant");
641   checkConstantInt(this, getArgOperand(AlignArg),
642                    "alignment argument to coro.id.retcon.* must be constant");
643   checkWFRetconPrototype(this, getArgOperand(PrototypeArg));
644   checkWFAlloc(this, getArgOperand(AllocArg));
645   checkWFDealloc(this, getArgOperand(DeallocArg));
646 }
647 
648 static void checkAsyncFuncPointer(const Instruction *I, Value *V) {
649   auto *AsyncFuncPtrAddr = dyn_cast<GlobalVariable>(V->stripPointerCasts());
650   if (!AsyncFuncPtrAddr)
651     fail(I, "llvm.coro.id.async async function pointer not a global", V);
652 }
653 
654 void CoroIdAsyncInst::checkWellFormed() const {
655   checkConstantInt(this, getArgOperand(SizeArg),
656                    "size argument to coro.id.async must be constant");
657   checkConstantInt(this, getArgOperand(AlignArg),
658                    "alignment argument to coro.id.async must be constant");
659   checkConstantInt(this, getArgOperand(StorageArg),
660                    "storage argument offset to coro.id.async must be constant");
661   checkAsyncFuncPointer(this, getArgOperand(AsyncFuncPtrArg));
662 }
663 
664 static void checkAsyncContextProjectFunction(const Instruction *I,
665                                              Function *F) {
666   auto *FunTy = cast<FunctionType>(F->getValueType());
667   if (!FunTy->getReturnType()->isPointerTy())
668     fail(I,
669          "llvm.coro.suspend.async resume function projection function must "
670          "return a ptr type",
671          F);
672   if (FunTy->getNumParams() != 1 || !FunTy->getParamType(0)->isPointerTy())
673     fail(I,
674          "llvm.coro.suspend.async resume function projection function must "
675          "take one ptr type as parameter",
676          F);
677 }
678 
679 void CoroSuspendAsyncInst::checkWellFormed() const {
680   checkAsyncContextProjectFunction(this, getAsyncContextProjectionFunction());
681 }
682 
683 void CoroAsyncEndInst::checkWellFormed() const {
684   auto *MustTailCallFunc = getMustTailCallFunction();
685   if (!MustTailCallFunc)
686     return;
687   auto *FnTy = MustTailCallFunc->getFunctionType();
688   if (FnTy->getNumParams() != (arg_size() - 3))
689     fail(this,
690          "llvm.coro.end.async must tail call function argument type must "
691          "match the tail arguments",
692          MustTailCallFunc);
693 }
694