xref: /llvm-project/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp (revision e84f6b6a88c1222d512edf0644c8f869dd12b8ef)
1 //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
10 
11 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
12 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
13 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
14 #include "mlir/Conversion/LLVMCommon/Pattern.h"
15 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/Async/IR/Async.h"
18 #include "mlir/Dialect/Func/IR/FuncOps.h"
19 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
20 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
21 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22 #include "mlir/IR/ImplicitLocOpBuilder.h"
23 #include "mlir/IR/TypeUtilities.h"
24 #include "mlir/Pass/Pass.h"
25 #include "mlir/Transforms/DialectConversion.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 
28 namespace mlir {
29 #define GEN_PASS_DEF_CONVERTASYNCTOLLVMPASS
30 #include "mlir/Conversion/Passes.h.inc"
31 } // namespace mlir
32 
33 #define DEBUG_TYPE "convert-async-to-llvm"
34 
35 using namespace mlir;
36 using namespace mlir::async;
37 
38 //===----------------------------------------------------------------------===//
39 // Async Runtime C API declaration.
40 //===----------------------------------------------------------------------===//
41 
// Symbol names of the Async Runtime C API entry points. They are declared as
// private external functions (see addAsyncRuntimeApiDeclarations) and called by
// name from the lowered IR, so each string is expected to resolve to the
// symbol exported by the async runtime library and must not be altered.
static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue";
static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError";
static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError";
static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError";
static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError";
static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError";
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
static constexpr const char *kGetValueStorage =
    "mlirAsyncRuntimeGetValueStorage";
static constexpr const char *kAddTokenToGroup =
    "mlirAsyncRuntimeAddTokenToGroup";
static constexpr const char *kAwaitTokenAndExecute =
    "mlirAsyncRuntimeAwaitTokenAndExecute";
static constexpr const char *kAwaitValueAndExecute =
    "mlirAsyncRuntimeAwaitValueAndExecute";
static constexpr const char *kAwaitAllAndExecute =
    "mlirAsyncRuntimeAwaitAllInGroupAndExecute";
// NOTE(review): "Runtim" (missing 'e') looks like a typo, but this string has
// to match the symbol exported by the async runtime library. Confirm against
// the runtime (and update both sides together) before changing it.
static constexpr const char *kGetNumWorkerThreads =
    "mlirAsyncRuntimGetNumWorkerThreads";
70 
71 namespace {
72 /// Async Runtime API function types.
73 ///
74 /// Because we can't create API function signature for type parametrized
75 /// async.getValue type, we use opaque pointers (!llvm.ptr) instead. After
76 /// lowering all async data types become opaque pointers at runtime.
77 struct AsyncAPI {
78   // All async types are lowered to opaque LLVM pointers at runtime.
79   static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) {
80     return LLVM::LLVMPointerType::get(ctx);
81   }
82 
83   static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) {
84     return LLVM::LLVMTokenType::get(ctx);
85   }
86 
87   static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
88     auto ref = opaquePointerType(ctx);
89     auto count = IntegerType::get(ctx, 64);
90     return FunctionType::get(ctx, {ref, count}, {});
91   }
92 
93   static FunctionType createTokenFunctionType(MLIRContext *ctx) {
94     return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
95   }
96 
97   static FunctionType createValueFunctionType(MLIRContext *ctx) {
98     auto i64 = IntegerType::get(ctx, 64);
99     auto value = opaquePointerType(ctx);
100     return FunctionType::get(ctx, {i64}, {value});
101   }
102 
103   static FunctionType createGroupFunctionType(MLIRContext *ctx) {
104     auto i64 = IntegerType::get(ctx, 64);
105     return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)});
106   }
107 
108   static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
109     auto ptrType = opaquePointerType(ctx);
110     return FunctionType::get(ctx, {ptrType}, {ptrType});
111   }
112 
113   static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
114     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
115   }
116 
117   static FunctionType emplaceValueFunctionType(MLIRContext *ctx) {
118     auto value = opaquePointerType(ctx);
119     return FunctionType::get(ctx, {value}, {});
120   }
121 
122   static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) {
123     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
124   }
125 
126   static FunctionType setValueErrorFunctionType(MLIRContext *ctx) {
127     auto value = opaquePointerType(ctx);
128     return FunctionType::get(ctx, {value}, {});
129   }
130 
131   static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) {
132     auto i1 = IntegerType::get(ctx, 1);
133     return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1});
134   }
135 
136   static FunctionType isValueErrorFunctionType(MLIRContext *ctx) {
137     auto value = opaquePointerType(ctx);
138     auto i1 = IntegerType::get(ctx, 1);
139     return FunctionType::get(ctx, {value}, {i1});
140   }
141 
142   static FunctionType isGroupErrorFunctionType(MLIRContext *ctx) {
143     auto i1 = IntegerType::get(ctx, 1);
144     return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1});
145   }
146 
147   static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
148     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
149   }
150 
151   static FunctionType awaitValueFunctionType(MLIRContext *ctx) {
152     auto value = opaquePointerType(ctx);
153     return FunctionType::get(ctx, {value}, {});
154   }
155 
156   static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
157     return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
158   }
159 
160   static FunctionType executeFunctionType(MLIRContext *ctx) {
161     auto ptrType = opaquePointerType(ctx);
162     return FunctionType::get(ctx, {ptrType, ptrType}, {});
163   }
164 
165   static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
166     auto i64 = IntegerType::get(ctx, 64);
167     return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)},
168                              {i64});
169   }
170 
171   static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
172     auto ptrType = opaquePointerType(ctx);
173     return FunctionType::get(ctx, {TokenType::get(ctx), ptrType, ptrType}, {});
174   }
175 
176   static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
177     auto ptrType = opaquePointerType(ctx);
178     return FunctionType::get(ctx, {ptrType, ptrType, ptrType}, {});
179   }
180 
181   static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
182     auto ptrType = opaquePointerType(ctx);
183     return FunctionType::get(ctx, {GroupType::get(ctx), ptrType, ptrType}, {});
184   }
185 
186   static FunctionType getNumWorkerThreads(MLIRContext *ctx) {
187     return FunctionType::get(ctx, {}, {IndexType::get(ctx)});
188   }
189 
190   // Auxiliary coroutine resume intrinsic wrapper.
191   static Type resumeFunctionType(MLIRContext *ctx) {
192     auto voidTy = LLVM::LLVMVoidType::get(ctx);
193     auto ptrType = opaquePointerType(ctx);
194     return LLVM::LLVMFunctionType::get(voidTy, {ptrType}, false);
195   }
196 };
197 } // namespace
198 
199 /// Adds Async Runtime C API declarations to the module.
200 static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
201   auto builder =
202       ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody());
203 
204   auto addFuncDecl = [&](StringRef name, FunctionType type) {
205     if (module.lookupSymbol(name))
206       return;
207     builder.create<func::FuncOp>(name, type).setPrivate();
208   };
209 
210   MLIRContext *ctx = module.getContext();
211   addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
212   addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
213   addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
214   addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx));
215   addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
216   addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
217   addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
218   addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx));
219   addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx));
220   addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx));
221   addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx));
222   addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx));
223   addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
224   addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
225   addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
226   addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
227   addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx));
228   addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
229   addFuncDecl(kAwaitTokenAndExecute,
230               AsyncAPI::awaitTokenAndExecuteFunctionType(ctx));
231   addFuncDecl(kAwaitValueAndExecute,
232               AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
233   addFuncDecl(kAwaitAllAndExecute,
234               AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
235   addFuncDecl(kGetNumWorkerThreads, AsyncAPI::getNumWorkerThreads(ctx));
236 }
237 
238 //===----------------------------------------------------------------------===//
239 // Coroutine resume function wrapper.
240 //===----------------------------------------------------------------------===//
241 
242 static constexpr const char *kResume = "__resume";
243 
244 /// A function that takes a coroutine handle and calls a `llvm.coro.resume`
245 /// intrinsics. We need this function to be able to pass it to the async
246 /// runtime execute API.
247 static void addResumeFunction(ModuleOp module) {
248   if (module.lookupSymbol(kResume))
249     return;
250 
251   MLIRContext *ctx = module.getContext();
252   auto loc = module.getLoc();
253   auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody());
254 
255   auto voidTy = LLVM::LLVMVoidType::get(ctx);
256   Type ptrType = AsyncAPI::opaquePointerType(ctx);
257 
258   auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
259       kResume, LLVM::LLVMFunctionType::get(voidTy, {ptrType}));
260   resumeOp.setPrivate();
261 
262   auto *block = resumeOp.addEntryBlock(moduleBuilder);
263   auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block);
264 
265   blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0));
266   blockBuilder.create<LLVM::ReturnOp>(ValueRange());
267 }
268 
269 //===----------------------------------------------------------------------===//
270 // Convert Async dialect types to LLVM types.
271 //===----------------------------------------------------------------------===//
272 
273 namespace {
274 /// AsyncRuntimeTypeConverter only converts types from the Async dialect to
275 /// their runtime type (opaque pointers) and does not convert any other types.
276 class AsyncRuntimeTypeConverter : public TypeConverter {
277 public:
278   AsyncRuntimeTypeConverter(const LowerToLLVMOptions &options) {
279     addConversion([](Type type) { return type; });
280     addConversion([](Type type) { return convertAsyncTypes(type); });
281 
282     // Use UnrealizedConversionCast as the bridge so that we don't need to pull
283     // in patterns for other dialects.
284     auto addUnrealizedCast = [](OpBuilder &builder, Type type,
285                                 ValueRange inputs, Location loc) -> Value {
286       auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
287       return cast.getResult(0);
288     };
289 
290     addSourceMaterialization(addUnrealizedCast);
291     addTargetMaterialization(addUnrealizedCast);
292   }
293 
294   static std::optional<Type> convertAsyncTypes(Type type) {
295     if (isa<TokenType, GroupType, ValueType>(type))
296       return AsyncAPI::opaquePointerType(type.getContext());
297 
298     if (isa<CoroIdType, CoroStateType>(type))
299       return AsyncAPI::tokenType(type.getContext());
300     if (isa<CoroHandleType>(type))
301       return AsyncAPI::opaquePointerType(type.getContext());
302 
303     return std::nullopt;
304   }
305 };
306 
307 /// Base class for conversion patterns requiring AsyncRuntimeTypeConverter
308 /// as type converter. Allows access to it via the 'getTypeConverter'
309 /// convenience method.
310 template <typename SourceOp>
311 class AsyncOpConversionPattern : public OpConversionPattern<SourceOp> {
312 
313   using Base = OpConversionPattern<SourceOp>;
314 
315 public:
316   AsyncOpConversionPattern(const AsyncRuntimeTypeConverter &typeConverter,
317                            MLIRContext *context)
318       : Base(typeConverter, context) {}
319 
320   /// Returns the 'AsyncRuntimeTypeConverter' of the pattern.
321   const AsyncRuntimeTypeConverter *getTypeConverter() const {
322     return static_cast<const AsyncRuntimeTypeConverter *>(
323         Base::getTypeConverter());
324   }
325 };
326 
327 } // namespace
328 
329 //===----------------------------------------------------------------------===//
330 // Convert async.coro.id to @llvm.coro.id intrinsic.
331 //===----------------------------------------------------------------------===//
332 
333 namespace {
334 class CoroIdOpConversion : public AsyncOpConversionPattern<CoroIdOp> {
335 public:
336   using AsyncOpConversionPattern::AsyncOpConversionPattern;
337 
338   LogicalResult
339   matchAndRewrite(CoroIdOp op, OpAdaptor adaptor,
340                   ConversionPatternRewriter &rewriter) const override {
341     auto token = AsyncAPI::tokenType(op->getContext());
342     auto ptrType = AsyncAPI::opaquePointerType(op->getContext());
343     auto loc = op->getLoc();
344 
345     // Constants for initializing coroutine frame.
346     auto constZero =
347         rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 0);
348     auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, ptrType);
349 
350     // Get coroutine id: @llvm.coro.id.
351     rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>(
352         op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr}));
353 
354     return success();
355   }
356 };
357 } // namespace
358 
359 //===----------------------------------------------------------------------===//
360 // Convert async.coro.begin to @llvm.coro.begin intrinsic.
361 //===----------------------------------------------------------------------===//
362 
363 namespace {
364 class CoroBeginOpConversion : public AsyncOpConversionPattern<CoroBeginOp> {
365 public:
366   using AsyncOpConversionPattern::AsyncOpConversionPattern;
367 
368   LogicalResult
369   matchAndRewrite(CoroBeginOp op, OpAdaptor adaptor,
370                   ConversionPatternRewriter &rewriter) const override {
371     auto ptrType = AsyncAPI::opaquePointerType(op->getContext());
372     auto loc = op->getLoc();
373 
374     // Get coroutine frame size: @llvm.coro.size.i64.
375     Value coroSize =
376         rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type());
377     // Get coroutine frame alignment: @llvm.coro.align.i64.
378     Value coroAlign =
379         rewriter.create<LLVM::CoroAlignOp>(loc, rewriter.getI64Type());
380 
381     // Round up the size to be multiple of the alignment. Since aligned_alloc
382     // requires the size parameter be an integral multiple of the alignment
383     // parameter.
384     auto makeConstant = [&](uint64_t c) {
385       return rewriter.create<LLVM::ConstantOp>(op->getLoc(),
386                                                rewriter.getI64Type(), c);
387     };
388     coroSize = rewriter.create<LLVM::AddOp>(op->getLoc(), coroSize, coroAlign);
389     coroSize =
390         rewriter.create<LLVM::SubOp>(op->getLoc(), coroSize, makeConstant(1));
391     Value negCoroAlign =
392         rewriter.create<LLVM::SubOp>(op->getLoc(), makeConstant(0), coroAlign);
393     coroSize =
394         rewriter.create<LLVM::AndOp>(op->getLoc(), coroSize, negCoroAlign);
395 
396     // Allocate memory for the coroutine frame.
397     auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
398         op->getParentOfType<ModuleOp>(), rewriter.getI64Type());
399     if (failed(allocFuncOp))
400       return failure();
401     auto coroAlloc = rewriter.create<LLVM::CallOp>(
402         loc, allocFuncOp.value(), ValueRange{coroAlign, coroSize});
403 
404     // Begin a coroutine: @llvm.coro.begin.
405     auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).getId();
406     rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>(
407         op, ptrType, ValueRange({coroId, coroAlloc.getResult()}));
408 
409     return success();
410   }
411 };
412 } // namespace
413 
414 //===----------------------------------------------------------------------===//
415 // Convert async.coro.free to @llvm.coro.free intrinsic.
416 //===----------------------------------------------------------------------===//
417 
418 namespace {
419 class CoroFreeOpConversion : public AsyncOpConversionPattern<CoroFreeOp> {
420 public:
421   using AsyncOpConversionPattern::AsyncOpConversionPattern;
422 
423   LogicalResult
424   matchAndRewrite(CoroFreeOp op, OpAdaptor adaptor,
425                   ConversionPatternRewriter &rewriter) const override {
426     auto ptrType = AsyncAPI::opaquePointerType(op->getContext());
427     auto loc = op->getLoc();
428 
429     // Get a pointer to the coroutine frame memory: @llvm.coro.free.
430     auto coroMem =
431         rewriter.create<LLVM::CoroFreeOp>(loc, ptrType, adaptor.getOperands());
432 
433     // Free the memory.
434     auto freeFuncOp =
435         LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
436     if (failed(freeFuncOp))
437       return failure();
438     rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp.value(),
439                                               ValueRange(coroMem.getResult()));
440 
441     return success();
442   }
443 };
444 } // namespace
445 
446 //===----------------------------------------------------------------------===//
447 // Convert async.coro.end to @llvm.coro.end intrinsic.
448 //===----------------------------------------------------------------------===//
449 
450 namespace {
451 class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> {
452 public:
453   using OpConversionPattern::OpConversionPattern;
454 
455   LogicalResult
456   matchAndRewrite(CoroEndOp op, OpAdaptor adaptor,
457                   ConversionPatternRewriter &rewriter) const override {
458     // We are not in the block that is part of the unwind sequence.
459     auto constFalse = rewriter.create<LLVM::ConstantOp>(
460         op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false));
461     auto noneToken = rewriter.create<LLVM::NoneTokenOp>(op->getLoc());
462 
463     // Mark the end of a coroutine: @llvm.coro.end.
464     auto coroHdl = adaptor.getHandle();
465     rewriter.create<LLVM::CoroEndOp>(
466         op->getLoc(), rewriter.getI1Type(),
467         ValueRange({coroHdl, constFalse, noneToken}));
468     rewriter.eraseOp(op);
469 
470     return success();
471   }
472 };
473 } // namespace
474 
475 //===----------------------------------------------------------------------===//
476 // Convert async.coro.save to @llvm.coro.save intrinsic.
477 //===----------------------------------------------------------------------===//
478 
479 namespace {
480 class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> {
481 public:
482   using OpConversionPattern::OpConversionPattern;
483 
484   LogicalResult
485   matchAndRewrite(CoroSaveOp op, OpAdaptor adaptor,
486                   ConversionPatternRewriter &rewriter) const override {
487     // Save the coroutine state: @llvm.coro.save
488     rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>(
489         op, AsyncAPI::tokenType(op->getContext()), adaptor.getOperands());
490 
491     return success();
492   }
493 };
494 } // namespace
495 
496 //===----------------------------------------------------------------------===//
497 // Convert async.coro.suspend to @llvm.coro.suspend intrinsic.
498 //===----------------------------------------------------------------------===//
499 
500 namespace {
501 
502 /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and
503 /// branch to the appropriate block based on the return code.
504 ///
505 /// Before:
506 ///
507 ///   ^suspended:
508 ///     "opBefore"(...)
509 ///     async.coro.suspend %state, ^suspend, ^resume, ^cleanup
510 ///   ^resume:
511 ///     "op"(...)
512 ///   ^cleanup: ...
513 ///   ^suspend: ...
514 ///
515 /// After:
516 ///
517 ///   ^suspended:
518 ///     "opBefore"(...)
519 ///     %suspend = llmv.intr.coro.suspend ...
520 ///     switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
521 ///   ^resume:
522 ///     "op"(...)
523 ///   ^cleanup: ...
524 ///   ^suspend: ...
525 ///
526 class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> {
527 public:
528   using OpConversionPattern::OpConversionPattern;
529 
530   LogicalResult
531   matchAndRewrite(CoroSuspendOp op, OpAdaptor adaptor,
532                   ConversionPatternRewriter &rewriter) const override {
533     auto i8 = rewriter.getIntegerType(8);
534     auto i32 = rewriter.getI32Type();
535     auto loc = op->getLoc();
536 
537     // This is not a final suspension point.
538     auto constFalse = rewriter.create<LLVM::ConstantOp>(
539         loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
540 
541     // Suspend a coroutine: @llvm.coro.suspend
542     auto coroState = adaptor.getState();
543     auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>(
544         loc, i8, ValueRange({coroState, constFalse}));
545 
546     // Cast return code to i32.
547 
548     // After a suspension point decide if we should branch into resume, cleanup
549     // or suspend block of the coroutine (see @llvm.coro.suspend return code
550     // documentation).
551     llvm::SmallVector<int32_t, 2> caseValues = {0, 1};
552     llvm::SmallVector<Block *, 2> caseDest = {op.getResumeDest(),
553                                               op.getCleanupDest()};
554     rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
555         op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()),
556         /*defaultDestination=*/op.getSuspendDest(),
557         /*defaultOperands=*/ValueRange(),
558         /*caseValues=*/caseValues,
559         /*caseDestinations=*/caseDest,
560         /*caseOperands=*/ArrayRef<ValueRange>({ValueRange(), ValueRange()}),
561         /*branchWeights=*/ArrayRef<int32_t>());
562 
563     return success();
564   }
565 };
566 } // namespace
567 
568 //===----------------------------------------------------------------------===//
569 // Convert async.runtime.create to the corresponding runtime API call.
570 //
571 // To allocate storage for the async values we use getelementptr trick:
572 // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt
573 //===----------------------------------------------------------------------===//
574 
575 namespace {
576 class RuntimeCreateOpLowering : public ConvertOpToLLVMPattern<RuntimeCreateOp> {
577 public:
578   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
579 
580   LogicalResult
581   matchAndRewrite(RuntimeCreateOp op, OpAdaptor adaptor,
582                   ConversionPatternRewriter &rewriter) const override {
583     const TypeConverter *converter = getTypeConverter();
584     Type resultType = op->getResultTypes()[0];
585 
586     // Tokens creation maps to a simple function call.
587     if (isa<TokenType>(resultType)) {
588       rewriter.replaceOpWithNewOp<func::CallOp>(
589           op, kCreateToken, converter->convertType(resultType));
590       return success();
591     }
592 
593     // To create a value we need to compute the storage requirement.
594     if (auto value = dyn_cast<ValueType>(resultType)) {
595       // Returns the size requirements for the async value storage.
596       auto sizeOf = [&](ValueType valueType) -> Value {
597         auto loc = op->getLoc();
598         auto i64 = rewriter.getI64Type();
599 
600         auto storedType = converter->convertType(valueType.getValueType());
601         auto storagePtrType =
602             AsyncAPI::opaquePointerType(rewriter.getContext());
603 
604         // %Size = getelementptr %T* null, int 1
605         // %SizeI = ptrtoint %T* %Size to i64
606         auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, storagePtrType);
607         auto gep =
608             rewriter.create<LLVM::GEPOp>(loc, storagePtrType, storedType,
609                                          nullPtr, ArrayRef<LLVM::GEPArg>{1});
610         return rewriter.create<LLVM::PtrToIntOp>(loc, i64, gep);
611       };
612 
613       rewriter.replaceOpWithNewOp<func::CallOp>(op, kCreateValue, resultType,
614                                                 sizeOf(value));
615 
616       return success();
617     }
618 
619     return rewriter.notifyMatchFailure(op, "unsupported async type");
620   }
621 };
622 } // namespace
623 
624 //===----------------------------------------------------------------------===//
625 // Convert async.runtime.create_group to the corresponding runtime API call.
626 //===----------------------------------------------------------------------===//
627 
628 namespace {
629 class RuntimeCreateGroupOpLowering
630     : public ConvertOpToLLVMPattern<RuntimeCreateGroupOp> {
631 public:
632   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
633 
634   LogicalResult
635   matchAndRewrite(RuntimeCreateGroupOp op, OpAdaptor adaptor,
636                   ConversionPatternRewriter &rewriter) const override {
637     const TypeConverter *converter = getTypeConverter();
638     Type resultType = op.getResult().getType();
639 
640     rewriter.replaceOpWithNewOp<func::CallOp>(
641         op, kCreateGroup, converter->convertType(resultType),
642         adaptor.getOperands());
643     return success();
644   }
645 };
646 } // namespace
647 
648 //===----------------------------------------------------------------------===//
649 // Convert async.runtime.set_available to the corresponding runtime API call.
650 //===----------------------------------------------------------------------===//
651 
652 namespace {
653 class RuntimeSetAvailableOpLowering
654     : public OpConversionPattern<RuntimeSetAvailableOp> {
655 public:
656   using OpConversionPattern::OpConversionPattern;
657 
658   LogicalResult
659   matchAndRewrite(RuntimeSetAvailableOp op, OpAdaptor adaptor,
660                   ConversionPatternRewriter &rewriter) const override {
661     StringRef apiFuncName =
662         TypeSwitch<Type, StringRef>(op.getOperand().getType())
663             .Case<TokenType>([](Type) { return kEmplaceToken; })
664             .Case<ValueType>([](Type) { return kEmplaceValue; });
665 
666     rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName, TypeRange(),
667                                               adaptor.getOperands());
668 
669     return success();
670   }
671 };
672 } // namespace
673 
674 //===----------------------------------------------------------------------===//
675 // Convert async.runtime.set_error to the corresponding runtime API call.
676 //===----------------------------------------------------------------------===//
677 
678 namespace {
679 class RuntimeSetErrorOpLowering
680     : public OpConversionPattern<RuntimeSetErrorOp> {
681 public:
682   using OpConversionPattern::OpConversionPattern;
683 
684   LogicalResult
685   matchAndRewrite(RuntimeSetErrorOp op, OpAdaptor adaptor,
686                   ConversionPatternRewriter &rewriter) const override {
687     StringRef apiFuncName =
688         TypeSwitch<Type, StringRef>(op.getOperand().getType())
689             .Case<TokenType>([](Type) { return kSetTokenError; })
690             .Case<ValueType>([](Type) { return kSetValueError; });
691 
692     rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName, TypeRange(),
693                                               adaptor.getOperands());
694 
695     return success();
696   }
697 };
698 } // namespace
699 
700 //===----------------------------------------------------------------------===//
701 // Convert async.runtime.is_error to the corresponding runtime API call.
702 //===----------------------------------------------------------------------===//
703 
704 namespace {
705 class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> {
706 public:
707   using OpConversionPattern::OpConversionPattern;
708 
709   LogicalResult
710   matchAndRewrite(RuntimeIsErrorOp op, OpAdaptor adaptor,
711                   ConversionPatternRewriter &rewriter) const override {
712     StringRef apiFuncName =
713         TypeSwitch<Type, StringRef>(op.getOperand().getType())
714             .Case<TokenType>([](Type) { return kIsTokenError; })
715             .Case<GroupType>([](Type) { return kIsGroupError; })
716             .Case<ValueType>([](Type) { return kIsValueError; });
717 
718     rewriter.replaceOpWithNewOp<func::CallOp>(
719         op, apiFuncName, rewriter.getI1Type(), adaptor.getOperands());
720     return success();
721   }
722 };
723 } // namespace
724 
725 //===----------------------------------------------------------------------===//
726 // Convert async.runtime.await to the corresponding runtime API call.
727 //===----------------------------------------------------------------------===//
728 
729 namespace {
730 class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> {
731 public:
732   using OpConversionPattern::OpConversionPattern;
733 
734   LogicalResult
735   matchAndRewrite(RuntimeAwaitOp op, OpAdaptor adaptor,
736                   ConversionPatternRewriter &rewriter) const override {
737     StringRef apiFuncName =
738         TypeSwitch<Type, StringRef>(op.getOperand().getType())
739             .Case<TokenType>([](Type) { return kAwaitToken; })
740             .Case<ValueType>([](Type) { return kAwaitValue; })
741             .Case<GroupType>([](Type) { return kAwaitGroup; });
742 
743     rewriter.create<func::CallOp>(op->getLoc(), apiFuncName, TypeRange(),
744                                   adaptor.getOperands());
745     rewriter.eraseOp(op);
746 
747     return success();
748   }
749 };
750 } // namespace
751 
752 //===----------------------------------------------------------------------===//
753 // Convert async.runtime.await_and_resume to the corresponding runtime API call.
754 //===----------------------------------------------------------------------===//
755 
756 namespace {
757 class RuntimeAwaitAndResumeOpLowering
758     : public AsyncOpConversionPattern<RuntimeAwaitAndResumeOp> {
759 public:
760   using AsyncOpConversionPattern::AsyncOpConversionPattern;
761 
762   LogicalResult
763   matchAndRewrite(RuntimeAwaitAndResumeOp op, OpAdaptor adaptor,
764                   ConversionPatternRewriter &rewriter) const override {
765     StringRef apiFuncName =
766         TypeSwitch<Type, StringRef>(op.getOperand().getType())
767             .Case<TokenType>([](Type) { return kAwaitTokenAndExecute; })
768             .Case<ValueType>([](Type) { return kAwaitValueAndExecute; })
769             .Case<GroupType>([](Type) { return kAwaitAllAndExecute; });
770 
771     Value operand = adaptor.getOperand();
772     Value handle = adaptor.getHandle();
773 
774     // A pointer to coroutine resume intrinsic wrapper.
775     addResumeFunction(op->getParentOfType<ModuleOp>());
776     auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
777         op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()),
778         kResume);
779 
780     rewriter.create<func::CallOp>(
781         op->getLoc(), apiFuncName, TypeRange(),
782         ValueRange({operand, handle, resumePtr.getRes()}));
783     rewriter.eraseOp(op);
784 
785     return success();
786   }
787 };
788 } // namespace
789 
790 //===----------------------------------------------------------------------===//
791 // Convert async.runtime.resume to the corresponding runtime API call.
792 //===----------------------------------------------------------------------===//
793 
794 namespace {
795 class RuntimeResumeOpLowering
796     : public AsyncOpConversionPattern<RuntimeResumeOp> {
797 public:
798   using AsyncOpConversionPattern::AsyncOpConversionPattern;
799 
800   LogicalResult
801   matchAndRewrite(RuntimeResumeOp op, OpAdaptor adaptor,
802                   ConversionPatternRewriter &rewriter) const override {
803     // A pointer to coroutine resume intrinsic wrapper.
804     addResumeFunction(op->getParentOfType<ModuleOp>());
805     auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
806         op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()),
807         kResume);
808 
809     // Call async runtime API to execute a coroutine in the managed thread.
810     auto coroHdl = adaptor.getHandle();
811     rewriter.replaceOpWithNewOp<func::CallOp>(
812         op, TypeRange(), kExecute, ValueRange({coroHdl, resumePtr.getRes()}));
813 
814     return success();
815   }
816 };
817 } // namespace
818 
819 //===----------------------------------------------------------------------===//
820 // Convert async.runtime.store to the corresponding runtime API call.
821 //===----------------------------------------------------------------------===//
822 
823 namespace {
824 class RuntimeStoreOpLowering : public ConvertOpToLLVMPattern<RuntimeStoreOp> {
825 public:
826   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
827 
828   LogicalResult
829   matchAndRewrite(RuntimeStoreOp op, OpAdaptor adaptor,
830                   ConversionPatternRewriter &rewriter) const override {
831     Location loc = op->getLoc();
832 
833     // Get a pointer to the async value storage from the runtime.
834     auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext());
835     auto storage = adaptor.getStorage();
836     auto storagePtr = rewriter.create<func::CallOp>(
837         loc, kGetValueStorage, TypeRange(ptrType), storage);
838 
839     // Cast from i8* to the LLVM pointer type.
840     auto valueType = op.getValue().getType();
841     auto llvmValueType = getTypeConverter()->convertType(valueType);
842     if (!llvmValueType)
843       return rewriter.notifyMatchFailure(
844           op, "failed to convert stored value type to LLVM type");
845 
846     Value castedStoragePtr = storagePtr.getResult(0);
847     // Store the yielded value into the async value storage.
848     auto value = adaptor.getValue();
849     rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr);
850 
851     // Erase the original runtime store operation.
852     rewriter.eraseOp(op);
853 
854     return success();
855   }
856 };
857 } // namespace
858 
859 //===----------------------------------------------------------------------===//
860 // Convert async.runtime.load to the corresponding runtime API call.
861 //===----------------------------------------------------------------------===//
862 
863 namespace {
864 class RuntimeLoadOpLowering : public ConvertOpToLLVMPattern<RuntimeLoadOp> {
865 public:
866   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
867 
868   LogicalResult
869   matchAndRewrite(RuntimeLoadOp op, OpAdaptor adaptor,
870                   ConversionPatternRewriter &rewriter) const override {
871     Location loc = op->getLoc();
872 
873     // Get a pointer to the async value storage from the runtime.
874     auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext());
875     auto storage = adaptor.getStorage();
876     auto storagePtr = rewriter.create<func::CallOp>(
877         loc, kGetValueStorage, TypeRange(ptrType), storage);
878 
879     // Cast from i8* to the LLVM pointer type.
880     auto valueType = op.getResult().getType();
881     auto llvmValueType = getTypeConverter()->convertType(valueType);
882     if (!llvmValueType)
883       return rewriter.notifyMatchFailure(
884           op, "failed to convert loaded value type to LLVM type");
885 
886     Value castedStoragePtr = storagePtr.getResult(0);
887 
888     // Load from the casted pointer.
889     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, llvmValueType,
890                                               castedStoragePtr);
891 
892     return success();
893   }
894 };
895 } // namespace
896 
897 //===----------------------------------------------------------------------===//
898 // Convert async.runtime.add_to_group to the corresponding runtime API call.
899 //===----------------------------------------------------------------------===//
900 
901 namespace {
902 class RuntimeAddToGroupOpLowering
903     : public OpConversionPattern<RuntimeAddToGroupOp> {
904 public:
905   using OpConversionPattern::OpConversionPattern;
906 
907   LogicalResult
908   matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor,
909                   ConversionPatternRewriter &rewriter) const override {
910     // Currently we can only add tokens to the group.
911     if (!isa<TokenType>(op.getOperand().getType()))
912       return rewriter.notifyMatchFailure(op, "only token type is supported");
913 
914     // Replace with a runtime API function call.
915     rewriter.replaceOpWithNewOp<func::CallOp>(
916         op, kAddTokenToGroup, rewriter.getI64Type(), adaptor.getOperands());
917 
918     return success();
919   }
920 };
921 } // namespace
922 
923 //===----------------------------------------------------------------------===//
924 // Convert async.runtime.num_worker_threads to the corresponding runtime API
925 // call.
926 //===----------------------------------------------------------------------===//
927 
928 namespace {
929 class RuntimeNumWorkerThreadsOpLowering
930     : public OpConversionPattern<RuntimeNumWorkerThreadsOp> {
931 public:
932   using OpConversionPattern::OpConversionPattern;
933 
934   LogicalResult
935   matchAndRewrite(RuntimeNumWorkerThreadsOp op, OpAdaptor adaptor,
936                   ConversionPatternRewriter &rewriter) const override {
937 
938     // Replace with a runtime API function call.
939     rewriter.replaceOpWithNewOp<func::CallOp>(op, kGetNumWorkerThreads,
940                                               rewriter.getIndexType());
941 
942     return success();
943   }
944 };
945 } // namespace
946 
947 //===----------------------------------------------------------------------===//
948 // Async reference counting ops lowering (`async.runtime.add_ref` and
949 // `async.runtime.drop_ref` to the corresponding API calls).
950 //===----------------------------------------------------------------------===//
951 
952 namespace {
953 template <typename RefCountingOp>
954 class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> {
955 public:
956   explicit RefCountingOpLowering(const TypeConverter &converter,
957                                  MLIRContext *ctx, StringRef apiFunctionName)
958       : OpConversionPattern<RefCountingOp>(converter, ctx),
959         apiFunctionName(apiFunctionName) {}
960 
961   LogicalResult
962   matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor,
963                   ConversionPatternRewriter &rewriter) const override {
964     auto count = rewriter.create<arith::ConstantOp>(
965         op->getLoc(), rewriter.getI64Type(),
966         rewriter.getI64IntegerAttr(op.getCount()));
967 
968     auto operand = adaptor.getOperand();
969     rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(), apiFunctionName,
970                                               ValueRange({operand, count}));
971 
972     return success();
973   }
974 
975 private:
976   StringRef apiFunctionName;
977 };
978 
979 class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> {
980 public:
981   explicit RuntimeAddRefOpLowering(const TypeConverter &converter,
982                                    MLIRContext *ctx)
983       : RefCountingOpLowering(converter, ctx, kAddRef) {}
984 };
985 
986 class RuntimeDropRefOpLowering
987     : public RefCountingOpLowering<RuntimeDropRefOp> {
988 public:
989   explicit RuntimeDropRefOpLowering(const TypeConverter &converter,
990                                     MLIRContext *ctx)
991       : RefCountingOpLowering(converter, ctx, kDropRef) {}
992 };
993 } // namespace
994 
995 //===----------------------------------------------------------------------===//
996 // Convert return operations that return async values from async regions.
997 //===----------------------------------------------------------------------===//
998 
999 namespace {
1000 class ReturnOpOpConversion : public OpConversionPattern<func::ReturnOp> {
1001 public:
1002   using OpConversionPattern::OpConversionPattern;
1003 
1004   LogicalResult
1005   matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
1006                   ConversionPatternRewriter &rewriter) const override {
1007     rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
1008     return success();
1009   }
1010 };
1011 } // namespace
1012 
1013 //===----------------------------------------------------------------------===//
1014 
1015 namespace {
1016 struct ConvertAsyncToLLVMPass
1017     : public impl::ConvertAsyncToLLVMPassBase<ConvertAsyncToLLVMPass> {
1018   using Base::Base;
1019 
1020   void runOnOperation() override;
1021 };
1022 } // namespace
1023 
1024 void ConvertAsyncToLLVMPass::runOnOperation() {
1025   ModuleOp module = getOperation();
1026   MLIRContext *ctx = module->getContext();
1027 
1028   LowerToLLVMOptions options(ctx);
1029 
1030   // Add declarations for most functions required by the coroutines lowering.
1031   // We delay adding the resume function until it's needed because it currently
1032   // fails to compile unless '-O0' is specified.
1033   addAsyncRuntimeApiDeclarations(module);
1034 
1035   // Lower async.runtime and async.coro operations to Async Runtime API and
1036   // LLVM coroutine intrinsics.
1037 
1038   // Convert async dialect types and operations to LLVM dialect.
1039   AsyncRuntimeTypeConverter converter(options);
1040   RewritePatternSet patterns(ctx);
1041 
1042   // We use conversion to LLVM type to lower async.runtime load and store
1043   // operations.
1044   LLVMTypeConverter llvmConverter(ctx, options);
1045   llvmConverter.addConversion([&](Type type) {
1046     return AsyncRuntimeTypeConverter::convertAsyncTypes(type);
1047   });
1048 
1049   // Convert async types in function signatures and function calls.
1050   populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
1051                                                                  converter);
1052   populateCallOpTypeConversionPattern(patterns, converter);
1053 
1054   // Convert return operations inside async.execute regions.
1055   patterns.add<ReturnOpOpConversion>(converter, ctx);
1056 
1057   // Lower async.runtime operations to the async runtime API calls.
1058   patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering,
1059                RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering,
1060                RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering,
1061                RuntimeAddToGroupOpLowering, RuntimeNumWorkerThreadsOpLowering,
1062                RuntimeAddRefOpLowering, RuntimeDropRefOpLowering>(converter,
1063                                                                   ctx);
1064 
1065   // Lower async.runtime operations that rely on LLVM type converter to convert
1066   // from async value payload type to the LLVM type.
1067   patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering,
1068                RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter);
1069 
1070   // Lower async coroutine operations to LLVM coroutine intrinsics.
1071   patterns
1072       .add<CoroIdOpConversion, CoroBeginOpConversion, CoroFreeOpConversion,
1073            CoroEndOpConversion, CoroSaveOpConversion, CoroSuspendOpConversion>(
1074           converter, ctx);
1075 
1076   ConversionTarget target(*ctx);
1077   target.addLegalOp<arith::ConstantOp, func::ConstantOp,
1078                     UnrealizedConversionCastOp>();
1079   target.addLegalDialect<LLVM::LLVMDialect>();
1080 
1081   // All operations from Async dialect must be lowered to the runtime API and
1082   // LLVM intrinsics calls.
1083   target.addIllegalDialect<AsyncDialect>();
1084 
1085   // Add dynamic legality constraints to apply conversions defined above.
1086   target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1087     return converter.isSignatureLegal(op.getFunctionType());
1088   });
1089   target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
1090     return converter.isLegal(op.getOperandTypes());
1091   });
1092   target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
1093     return converter.isSignatureLegal(op.getCalleeType());
1094   });
1095 
1096   if (failed(applyPartialConversion(module, target, std::move(patterns))))
1097     signalPassFailure();
1098 }
1099 
1100 //===----------------------------------------------------------------------===//
1101 // Patterns for structural type conversions for the Async dialect operations.
1102 //===----------------------------------------------------------------------===//
1103 
1104 namespace {
1105 class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> {
1106 public:
1107   using OpConversionPattern::OpConversionPattern;
1108   LogicalResult
1109   matchAndRewrite(ExecuteOp op, OpAdaptor adaptor,
1110                   ConversionPatternRewriter &rewriter) const override {
1111     ExecuteOp newOp =
1112         cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
1113     rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
1114                                 newOp.getRegion().end());
1115 
1116     // Set operands and update block argument and result types.
1117     newOp->setOperands(adaptor.getOperands());
1118     if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter)))
1119       return failure();
1120     for (auto result : newOp.getResults())
1121       result.setType(typeConverter->convertType(result.getType()));
1122 
1123     rewriter.replaceOp(op, newOp.getResults());
1124     return success();
1125   }
1126 };
1127 
1128 // Dummy pattern to trigger the appropriate type conversion / materialization.
1129 class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> {
1130 public:
1131   using OpConversionPattern::OpConversionPattern;
1132   LogicalResult
1133   matchAndRewrite(AwaitOp op, OpAdaptor adaptor,
1134                   ConversionPatternRewriter &rewriter) const override {
1135     rewriter.replaceOpWithNewOp<AwaitOp>(op, adaptor.getOperands().front());
1136     return success();
1137   }
1138 };
1139 
1140 // Dummy pattern to trigger the appropriate type conversion / materialization.
1141 class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> {
1142 public:
1143   using OpConversionPattern::OpConversionPattern;
1144   LogicalResult
1145   matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
1146                   ConversionPatternRewriter &rewriter) const override {
1147     rewriter.replaceOpWithNewOp<async::YieldOp>(op, adaptor.getOperands());
1148     return success();
1149   }
1150 };
1151 } // namespace
1152 
1153 void mlir::populateAsyncStructuralTypeConversionsAndLegality(
1154     TypeConverter &typeConverter, RewritePatternSet &patterns,
1155     ConversionTarget &target) {
1156   typeConverter.addConversion([&](TokenType type) { return type; });
1157   typeConverter.addConversion([&](ValueType type) {
1158     Type converted = typeConverter.convertType(type.getValueType());
1159     return converted ? ValueType::get(converted) : converted;
1160   });
1161 
1162   patterns.add<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>(
1163       typeConverter, patterns.getContext());
1164 
1165   target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>(
1166       [&](Operation *op) { return typeConverter.isLegal(op); });
1167 }
1168