xref: /llvm-project/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp (revision e84f6b6a88c1222d512edf0644c8f869dd12b8ef)
136ce915aSLei Zhang //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
236ce915aSLei Zhang //
336ce915aSLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
436ce915aSLei Zhang // See https://llvm.org/LICENSE.txt for license information.
536ce915aSLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
636ce915aSLei Zhang //
736ce915aSLei Zhang //===----------------------------------------------------------------------===//
836ce915aSLei Zhang 
936ce915aSLei Zhang #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
1036ce915aSLei Zhang 
11c4769ef5SMatthias Springer #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
125a7b9194SRiver Riddle #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
1375e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
142ca46421SMarkus Böck #include "mlir/Conversion/LLVMCommon/Pattern.h"
1575e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
16abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
1736ce915aSLei Zhang #include "mlir/Dialect/Async/IR/Async.h"
1823aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
1923aa5a74SRiver Riddle #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
205acd6e05SBenjamin Kramer #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
2136ce915aSLei Zhang #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
2275a3f326SChris Lattner #include "mlir/IR/ImplicitLocOpBuilder.h"
2336ce915aSLei Zhang #include "mlir/IR/TypeUtilities.h"
2436ce915aSLei Zhang #include "mlir/Pass/Pass.h"
2536ce915aSLei Zhang #include "mlir/Transforms/DialectConversion.h"
26d8c84d2aSEugene Zhulenev #include "llvm/ADT/TypeSwitch.h"
2736ce915aSLei Zhang 
2867d0d7acSMichele Scuttari namespace mlir {
29cd4ca2d7SMarkus Böck #define GEN_PASS_DEF_CONVERTASYNCTOLLVMPASS
3067d0d7acSMichele Scuttari #include "mlir/Conversion/Passes.h.inc"
3167d0d7acSMichele Scuttari } // namespace mlir
3267d0d7acSMichele Scuttari 
3336ce915aSLei Zhang #define DEBUG_TYPE "convert-async-to-llvm"
3436ce915aSLei Zhang 
3536ce915aSLei Zhang using namespace mlir;
3636ce915aSLei Zhang using namespace mlir::async;
3736ce915aSLei Zhang 
3836ce915aSLei Zhang //===----------------------------------------------------------------------===//
3936ce915aSLei Zhang // Async Runtime C API declaration.
4036ce915aSLei Zhang //===----------------------------------------------------------------------===//
4136ce915aSLei Zhang 
// Symbol names of the Async Runtime C API functions that this conversion
// declares in the module and calls from the lowered IR.

// Reference counting.
static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef";
static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef";

// Creation of async objects.
static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken";
static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue";
static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup";

// Emplacing results and setting errors.
static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken";
static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue";
static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError";
static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError";

// Error queries.
static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError";
static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError";
static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError";

// Awaiting async objects.
static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken";
static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue";
static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup";

// Execution, value storage access and group membership.
static constexpr const char *kExecute = "mlirAsyncRuntimeExecute";
static constexpr const char *kGetValueStorage =
    "mlirAsyncRuntimeGetValueStorage";
static constexpr const char *kAddTokenToGroup =
    "mlirAsyncRuntimeAddTokenToGroup";

// Await-then-execute variants.
static constexpr const char *kAwaitTokenAndExecute =
    "mlirAsyncRuntimeAwaitTokenAndExecute";
static constexpr const char *kAwaitValueAndExecute =
    "mlirAsyncRuntimeAwaitValueAndExecute";
static constexpr const char *kAwaitAllAndExecute =
    "mlirAsyncRuntimeAwaitAllInGroupAndExecute";

// Runtime queries.
// NOTE(review): "Runtim" (missing 'e') differs from every other name here;
// presumably it mirrors the symbol exported by the runtime library, so do not
// "fix" it without changing the runtime too -- confirm before editing.
static constexpr const char *kGetNumWorkerThreads =
    "mlirAsyncRuntimGetNumWorkerThreads";
7036ce915aSLei Zhang 
7136ce915aSLei Zhang namespace {
72621ad468SEugene Zhulenev /// Async Runtime API function types.
73621ad468SEugene Zhulenev ///
74621ad468SEugene Zhulenev /// Because we can't create API function signature for type parametrized
752ca46421SMarkus Böck /// async.getValue type, we use opaque pointers (!llvm.ptr) instead. After
76621ad468SEugene Zhulenev /// lowering all async data types become opaque pointers at runtime.
7736ce915aSLei Zhang struct AsyncAPI {
782ca46421SMarkus Böck   // All async types are lowered to opaque LLVM pointers at runtime.
79749f3708SChristian Ulmann   static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) {
802ca46421SMarkus Böck     return LLVM::LLVMPointerType::get(ctx);
81621ad468SEugene Zhulenev   }
82621ad468SEugene Zhulenev 
839c53b8e5SEugene Zhulenev   static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) {
849c53b8e5SEugene Zhulenev     return LLVM::LLVMTokenType::get(ctx);
859c53b8e5SEugene Zhulenev   }
869c53b8e5SEugene Zhulenev 
87749f3708SChristian Ulmann   static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
88749f3708SChristian Ulmann     auto ref = opaquePointerType(ctx);
8992db09cdSEugene Zhulenev     auto count = IntegerType::get(ctx, 64);
901b97cdf8SRiver Riddle     return FunctionType::get(ctx, {ref, count}, {});
91a86a9b5eSEugene Zhulenev   }
92a86a9b5eSEugene Zhulenev 
9336ce915aSLei Zhang   static FunctionType createTokenFunctionType(MLIRContext *ctx) {
941b97cdf8SRiver Riddle     return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
9536ce915aSLei Zhang   }
9636ce915aSLei Zhang 
97749f3708SChristian Ulmann   static FunctionType createValueFunctionType(MLIRContext *ctx) {
9892db09cdSEugene Zhulenev     auto i64 = IntegerType::get(ctx, 64);
99749f3708SChristian Ulmann     auto value = opaquePointerType(ctx);
10092db09cdSEugene Zhulenev     return FunctionType::get(ctx, {i64}, {value});
101621ad468SEugene Zhulenev   }
102621ad468SEugene Zhulenev 
103c30ab6c2SEugene Zhulenev   static FunctionType createGroupFunctionType(MLIRContext *ctx) {
104d43b2360SEugene Zhulenev     auto i64 = IntegerType::get(ctx, 64);
105d43b2360SEugene Zhulenev     return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)});
106c30ab6c2SEugene Zhulenev   }
107c30ab6c2SEugene Zhulenev 
108749f3708SChristian Ulmann   static FunctionType getValueStorageFunctionType(MLIRContext *ctx) {
109749f3708SChristian Ulmann     auto ptrType = opaquePointerType(ctx);
110749f3708SChristian Ulmann     return FunctionType::get(ctx, {ptrType}, {ptrType});
111621ad468SEugene Zhulenev   }
112621ad468SEugene Zhulenev 
11336ce915aSLei Zhang   static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
1141b97cdf8SRiver Riddle     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
11536ce915aSLei Zhang   }
11636ce915aSLei Zhang 
117749f3708SChristian Ulmann   static FunctionType emplaceValueFunctionType(MLIRContext *ctx) {
118749f3708SChristian Ulmann     auto value = opaquePointerType(ctx);
119621ad468SEugene Zhulenev     return FunctionType::get(ctx, {value}, {});
120621ad468SEugene Zhulenev   }
121621ad468SEugene Zhulenev 
12239957aa4SEugene Zhulenev   static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) {
12339957aa4SEugene Zhulenev     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
12439957aa4SEugene Zhulenev   }
12539957aa4SEugene Zhulenev 
126749f3708SChristian Ulmann   static FunctionType setValueErrorFunctionType(MLIRContext *ctx) {
127749f3708SChristian Ulmann     auto value = opaquePointerType(ctx);
12839957aa4SEugene Zhulenev     return FunctionType::get(ctx, {value}, {});
12939957aa4SEugene Zhulenev   }
13039957aa4SEugene Zhulenev 
13139957aa4SEugene Zhulenev   static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) {
13239957aa4SEugene Zhulenev     auto i1 = IntegerType::get(ctx, 1);
13339957aa4SEugene Zhulenev     return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1});
13439957aa4SEugene Zhulenev   }
13539957aa4SEugene Zhulenev 
136749f3708SChristian Ulmann   static FunctionType isValueErrorFunctionType(MLIRContext *ctx) {
137749f3708SChristian Ulmann     auto value = opaquePointerType(ctx);
13839957aa4SEugene Zhulenev     auto i1 = IntegerType::get(ctx, 1);
13939957aa4SEugene Zhulenev     return FunctionType::get(ctx, {value}, {i1});
14039957aa4SEugene Zhulenev   }
14139957aa4SEugene Zhulenev 
142d8c84d2aSEugene Zhulenev   static FunctionType isGroupErrorFunctionType(MLIRContext *ctx) {
143d8c84d2aSEugene Zhulenev     auto i1 = IntegerType::get(ctx, 1);
144d8c84d2aSEugene Zhulenev     return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1});
145d8c84d2aSEugene Zhulenev   }
146d8c84d2aSEugene Zhulenev 
14736ce915aSLei Zhang   static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
1481b97cdf8SRiver Riddle     return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
14936ce915aSLei Zhang   }
15036ce915aSLei Zhang 
151749f3708SChristian Ulmann   static FunctionType awaitValueFunctionType(MLIRContext *ctx) {
152749f3708SChristian Ulmann     auto value = opaquePointerType(ctx);
153621ad468SEugene Zhulenev     return FunctionType::get(ctx, {value}, {});
154621ad468SEugene Zhulenev   }
155621ad468SEugene Zhulenev 
156c30ab6c2SEugene Zhulenev   static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
1571b97cdf8SRiver Riddle     return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
158c30ab6c2SEugene Zhulenev   }
159c30ab6c2SEugene Zhulenev 
160749f3708SChristian Ulmann   static FunctionType executeFunctionType(MLIRContext *ctx) {
161749f3708SChristian Ulmann     auto ptrType = opaquePointerType(ctx);
162749f3708SChristian Ulmann     return FunctionType::get(ctx, {ptrType, ptrType}, {});
16336ce915aSLei Zhang   }
16436ce915aSLei Zhang 
165c30ab6c2SEugene Zhulenev   static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
1661b97cdf8SRiver Riddle     auto i64 = IntegerType::get(ctx, 64);
1671b97cdf8SRiver Riddle     return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)},
1681b97cdf8SRiver Riddle                              {i64});
169c30ab6c2SEugene Zhulenev   }
170c30ab6c2SEugene Zhulenev 
171749f3708SChristian Ulmann   static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) {
172749f3708SChristian Ulmann     auto ptrType = opaquePointerType(ctx);
173749f3708SChristian Ulmann     return FunctionType::get(ctx, {TokenType::get(ctx), ptrType, ptrType}, {});
17436ce915aSLei Zhang   }
17536ce915aSLei Zhang 
176749f3708SChristian Ulmann   static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) {
177749f3708SChristian Ulmann     auto ptrType = opaquePointerType(ctx);
178749f3708SChristian Ulmann     return FunctionType::get(ctx, {ptrType, ptrType, ptrType}, {});
179621ad468SEugene Zhulenev   }
180621ad468SEugene Zhulenev 
181749f3708SChristian Ulmann   static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
182749f3708SChristian Ulmann     auto ptrType = opaquePointerType(ctx);
183749f3708SChristian Ulmann     return FunctionType::get(ctx, {GroupType::get(ctx), ptrType, ptrType}, {});
184c30ab6c2SEugene Zhulenev   }
185c30ab6c2SEugene Zhulenev 
186149311b4Sbakhtiyar   static FunctionType getNumWorkerThreads(MLIRContext *ctx) {
187149311b4Sbakhtiyar     return FunctionType::get(ctx, {}, {IndexType::get(ctx)});
188149311b4Sbakhtiyar   }
189149311b4Sbakhtiyar 
19036ce915aSLei Zhang   // Auxiliary coroutine resume intrinsic wrapper.
191749f3708SChristian Ulmann   static Type resumeFunctionType(MLIRContext *ctx) {
1927ed9cfc7SAlex Zinenko     auto voidTy = LLVM::LLVMVoidType::get(ctx);
193749f3708SChristian Ulmann     auto ptrType = opaquePointerType(ctx);
1942ca46421SMarkus Böck     return LLVM::LLVMFunctionType::get(voidTy, {ptrType}, false);
19536ce915aSLei Zhang   }
19636ce915aSLei Zhang };
19736ce915aSLei Zhang } // namespace
19836ce915aSLei Zhang 
199621ad468SEugene Zhulenev /// Adds Async Runtime C API declarations to the module.
200749f3708SChristian Ulmann static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
201973ddb7dSMehdi Amini   auto builder =
202973ddb7dSMehdi Amini       ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody());
20336ce915aSLei Zhang 
204d8437552SRahul Joshi   auto addFuncDecl = [&](StringRef name, FunctionType type) {
205d8437552SRahul Joshi     if (module.lookupSymbol(name))
206d8437552SRahul Joshi       return;
20758ceae95SRiver Riddle     builder.create<func::FuncOp>(name, type).setPrivate();
208d8437552SRahul Joshi   };
209d8437552SRahul Joshi 
21036ce915aSLei Zhang   MLIRContext *ctx = module.getContext();
211749f3708SChristian Ulmann   addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx));
212749f3708SChristian Ulmann   addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx));
213d8437552SRahul Joshi   addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
214749f3708SChristian Ulmann   addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx));
215d8437552SRahul Joshi   addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
216d8437552SRahul Joshi   addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
217749f3708SChristian Ulmann   addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx));
21839957aa4SEugene Zhulenev   addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx));
219749f3708SChristian Ulmann   addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx));
22039957aa4SEugene Zhulenev   addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx));
221749f3708SChristian Ulmann   addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx));
222d8c84d2aSEugene Zhulenev   addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx));
223d8437552SRahul Joshi   addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
224749f3708SChristian Ulmann   addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx));
225d8437552SRahul Joshi   addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
226749f3708SChristian Ulmann   addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
227749f3708SChristian Ulmann   addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx));
228d8437552SRahul Joshi   addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
229749f3708SChristian Ulmann   addFuncDecl(kAwaitTokenAndExecute,
230749f3708SChristian Ulmann               AsyncAPI::awaitTokenAndExecuteFunctionType(ctx));
231749f3708SChristian Ulmann   addFuncDecl(kAwaitValueAndExecute,
232749f3708SChristian Ulmann               AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
233749f3708SChristian Ulmann   addFuncDecl(kAwaitAllAndExecute,
234749f3708SChristian Ulmann               AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
235149311b4Sbakhtiyar   addFuncDecl(kGetNumWorkerThreads, AsyncAPI::getNumWorkerThreads(ctx));
23636ce915aSLei Zhang }
23736ce915aSLei Zhang 
23836ce915aSLei Zhang //===----------------------------------------------------------------------===//
23936ce915aSLei Zhang // Coroutine resume function wrapper.
24036ce915aSLei Zhang //===----------------------------------------------------------------------===//
24136ce915aSLei Zhang 
24236ce915aSLei Zhang static constexpr const char *kResume = "__resume";
24336ce915aSLei Zhang 
244621ad468SEugene Zhulenev /// A function that takes a coroutine handle and calls a `llvm.coro.resume`
245621ad468SEugene Zhulenev /// intrinsics. We need this function to be able to pass it to the async
246621ad468SEugene Zhulenev /// runtime execute API.
247749f3708SChristian Ulmann static void addResumeFunction(ModuleOp module) {
2485b388169SChristian Sigg   if (module.lookupSymbol(kResume))
2495b388169SChristian Sigg     return;
2505b388169SChristian Sigg 
25136ce915aSLei Zhang   MLIRContext *ctx = module.getContext();
252973ddb7dSMehdi Amini   auto loc = module.getLoc();
253973ddb7dSMehdi Amini   auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody());
25436ce915aSLei Zhang 
2557ed9cfc7SAlex Zinenko   auto voidTy = LLVM::LLVMVoidType::get(ctx);
256749f3708SChristian Ulmann   Type ptrType = AsyncAPI::opaquePointerType(ctx);
25736ce915aSLei Zhang 
25836ce915aSLei Zhang   auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
2592ca46421SMarkus Böck       kResume, LLVM::LLVMFunctionType::get(voidTy, {ptrType}));
260d8437552SRahul Joshi   resumeOp.setPrivate();
26136ce915aSLei Zhang 
26291d5653eSMatthias Springer   auto *block = resumeOp.addEntryBlock(moduleBuilder);
26375a3f326SChris Lattner   auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block);
26436ce915aSLei Zhang 
265d37b5393SEugene Zhulenev   blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0));
26675a3f326SChris Lattner   blockBuilder.create<LLVM::ReturnOp>(ValueRange());
26736ce915aSLei Zhang }
26836ce915aSLei Zhang 
26936ce915aSLei Zhang //===----------------------------------------------------------------------===//
27036ce915aSLei Zhang // Convert Async dialect types to LLVM types.
27136ce915aSLei Zhang //===----------------------------------------------------------------------===//
27236ce915aSLei Zhang 
27336ce915aSLei Zhang namespace {
274621ad468SEugene Zhulenev /// AsyncRuntimeTypeConverter only converts types from the Async dialect to
275621ad468SEugene Zhulenev /// their runtime type (opaque pointers) and does not convert any other types.
27636ce915aSLei Zhang class AsyncRuntimeTypeConverter : public TypeConverter {
27736ce915aSLei Zhang public:
278749f3708SChristian Ulmann   AsyncRuntimeTypeConverter(const LowerToLLVMOptions &options) {
279621ad468SEugene Zhulenev     addConversion([](Type type) { return type; });
280749f3708SChristian Ulmann     addConversion([](Type type) { return convertAsyncTypes(type); });
281d2a8a3afSEugene Zhulenev 
282d2a8a3afSEugene Zhulenev     // Use UnrealizedConversionCast as the bridge so that we don't need to pull
283d2a8a3afSEugene Zhulenev     // in patterns for other dialects.
284d2a8a3afSEugene Zhulenev     auto addUnrealizedCast = [](OpBuilder &builder, Type type,
285f18c3e4eSMatthias Springer                                 ValueRange inputs, Location loc) -> Value {
286d2a8a3afSEugene Zhulenev       auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
287f18c3e4eSMatthias Springer       return cast.getResult(0);
288d2a8a3afSEugene Zhulenev     };
289d2a8a3afSEugene Zhulenev 
290d2a8a3afSEugene Zhulenev     addSourceMaterialization(addUnrealizedCast);
291d2a8a3afSEugene Zhulenev     addTargetMaterialization(addUnrealizedCast);
292621ad468SEugene Zhulenev   }
29336ce915aSLei Zhang 
294749f3708SChristian Ulmann   static std::optional<Type> convertAsyncTypes(Type type) {
2955550c821STres Popp     if (isa<TokenType, GroupType, ValueType>(type))
296749f3708SChristian Ulmann       return AsyncAPI::opaquePointerType(type.getContext());
2979c53b8e5SEugene Zhulenev 
2985550c821STres Popp     if (isa<CoroIdType, CoroStateType>(type))
2999c53b8e5SEugene Zhulenev       return AsyncAPI::tokenType(type.getContext());
3005550c821STres Popp     if (isa<CoroHandleType>(type))
301749f3708SChristian Ulmann       return AsyncAPI::opaquePointerType(type.getContext());
3029c53b8e5SEugene Zhulenev 
3031a36588eSKazu Hirata     return std::nullopt;
30436ce915aSLei Zhang   }
30536ce915aSLei Zhang };
3062ca46421SMarkus Böck 
3072ca46421SMarkus Böck /// Base class for conversion patterns requiring AsyncRuntimeTypeConverter
3082ca46421SMarkus Böck /// as type converter. Allows access to it via the 'getTypeConverter'
3092ca46421SMarkus Böck /// convenience method.
3102ca46421SMarkus Böck template <typename SourceOp>
3112ca46421SMarkus Böck class AsyncOpConversionPattern : public OpConversionPattern<SourceOp> {
3122ca46421SMarkus Böck 
3132ca46421SMarkus Böck   using Base = OpConversionPattern<SourceOp>;
3142ca46421SMarkus Böck 
3152ca46421SMarkus Böck public:
316ce254598SMatthias Springer   AsyncOpConversionPattern(const AsyncRuntimeTypeConverter &typeConverter,
3172ca46421SMarkus Böck                            MLIRContext *context)
3182ca46421SMarkus Böck       : Base(typeConverter, context) {}
3192ca46421SMarkus Böck 
3202ca46421SMarkus Böck   /// Returns the 'AsyncRuntimeTypeConverter' of the pattern.
321ce254598SMatthias Springer   const AsyncRuntimeTypeConverter *getTypeConverter() const {
322ce254598SMatthias Springer     return static_cast<const AsyncRuntimeTypeConverter *>(
323ce254598SMatthias Springer         Base::getTypeConverter());
3242ca46421SMarkus Böck   }
3252ca46421SMarkus Böck };
3262ca46421SMarkus Böck 
32736ce915aSLei Zhang } // namespace
32836ce915aSLei Zhang 
32936ce915aSLei Zhang //===----------------------------------------------------------------------===//
3309c53b8e5SEugene Zhulenev // Convert async.coro.id to @llvm.coro.id intrinsic.
33136ce915aSLei Zhang //===----------------------------------------------------------------------===//
33236ce915aSLei Zhang 
33336ce915aSLei Zhang namespace {
3342ca46421SMarkus Böck class CoroIdOpConversion : public AsyncOpConversionPattern<CoroIdOp> {
33536ce915aSLei Zhang public:
3362ca46421SMarkus Böck   using AsyncOpConversionPattern::AsyncOpConversionPattern;
33736ce915aSLei Zhang 
33836ce915aSLei Zhang   LogicalResult
339b54c724bSRiver Riddle   matchAndRewrite(CoroIdOp op, OpAdaptor adaptor,
34036ce915aSLei Zhang                   ConversionPatternRewriter &rewriter) const override {
3419c53b8e5SEugene Zhulenev     auto token = AsyncAPI::tokenType(op->getContext());
342749f3708SChristian Ulmann     auto ptrType = AsyncAPI::opaquePointerType(op->getContext());
3439c53b8e5SEugene Zhulenev     auto loc = op->getLoc();
3449c53b8e5SEugene Zhulenev 
3459c53b8e5SEugene Zhulenev     // Constants for initializing coroutine frame.
3460af643f3SJeff Niu     auto constZero =
3470af643f3SJeff Niu         rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 0);
34885175eddSTobias Gysi     auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, ptrType);
3499c53b8e5SEugene Zhulenev 
3509c53b8e5SEugene Zhulenev     // Get coroutine id: @llvm.coro.id.
351d37b5393SEugene Zhulenev     rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>(
352d37b5393SEugene Zhulenev         op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr}));
3539c53b8e5SEugene Zhulenev 
35436ce915aSLei Zhang     return success();
35536ce915aSLei Zhang   }
35636ce915aSLei Zhang };
35736ce915aSLei Zhang } // namespace
35836ce915aSLei Zhang 
35936ce915aSLei Zhang //===----------------------------------------------------------------------===//
3609c53b8e5SEugene Zhulenev // Convert async.coro.begin to @llvm.coro.begin intrinsic.
3619c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===//
3629c53b8e5SEugene Zhulenev 
3639c53b8e5SEugene Zhulenev namespace {
3642ca46421SMarkus Böck class CoroBeginOpConversion : public AsyncOpConversionPattern<CoroBeginOp> {
3659c53b8e5SEugene Zhulenev public:
3662ca46421SMarkus Böck   using AsyncOpConversionPattern::AsyncOpConversionPattern;
3679c53b8e5SEugene Zhulenev 
3689c53b8e5SEugene Zhulenev   LogicalResult
369b54c724bSRiver Riddle   matchAndRewrite(CoroBeginOp op, OpAdaptor adaptor,
3709c53b8e5SEugene Zhulenev                   ConversionPatternRewriter &rewriter) const override {
371749f3708SChristian Ulmann     auto ptrType = AsyncAPI::opaquePointerType(op->getContext());
3729c53b8e5SEugene Zhulenev     auto loc = op->getLoc();
3739c53b8e5SEugene Zhulenev 
3749c53b8e5SEugene Zhulenev     // Get coroutine frame size: @llvm.coro.size.i64.
375964dc368SBenjamin Kramer     Value coroSize =
376d37b5393SEugene Zhulenev         rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type());
377dbbe0109SChuanqi Xu     // Get coroutine frame alignment: @llvm.coro.align.i64.
378dbbe0109SChuanqi Xu     Value coroAlign =
379dbbe0109SChuanqi Xu         rewriter.create<LLVM::CoroAlignOp>(loc, rewriter.getI64Type());
380dbbe0109SChuanqi Xu 
381dbbe0109SChuanqi Xu     // Round up the size to be multiple of the alignment. Since aligned_alloc
382dbbe0109SChuanqi Xu     // requires the size parameter be an integral multiple of the alignment
383dbbe0109SChuanqi Xu     // parameter.
384964dc368SBenjamin Kramer     auto makeConstant = [&](uint64_t c) {
3855e0c3b43SJeff Niu       return rewriter.create<LLVM::ConstantOp>(op->getLoc(),
3865e0c3b43SJeff Niu                                                rewriter.getI64Type(), c);
387964dc368SBenjamin Kramer     };
388dbbe0109SChuanqi Xu     coroSize = rewriter.create<LLVM::AddOp>(op->getLoc(), coroSize, coroAlign);
389dbbe0109SChuanqi Xu     coroSize =
390dbbe0109SChuanqi Xu         rewriter.create<LLVM::SubOp>(op->getLoc(), coroSize, makeConstant(1));
391f65994c9SMehdi Amini     Value negCoroAlign =
392dbbe0109SChuanqi Xu         rewriter.create<LLVM::SubOp>(op->getLoc(), makeConstant(0), coroAlign);
393dbbe0109SChuanqi Xu     coroSize =
394f65994c9SMehdi Amini         rewriter.create<LLVM::AndOp>(op->getLoc(), coroSize, negCoroAlign);
3959c53b8e5SEugene Zhulenev 
3969c53b8e5SEugene Zhulenev     // Allocate memory for the coroutine frame.
3975acd6e05SBenjamin Kramer     auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
398749f3708SChristian Ulmann         op->getParentOfType<ModuleOp>(), rewriter.getI64Type());
399*e84f6b6aSLuohao Wang     if (failed(allocFuncOp))
400*e84f6b6aSLuohao Wang       return failure();
4019c53b8e5SEugene Zhulenev     auto coroAlloc = rewriter.create<LLVM::CallOp>(
402*e84f6b6aSLuohao Wang         loc, allocFuncOp.value(), ValueRange{coroAlign, coroSize});
4039c53b8e5SEugene Zhulenev 
4049c53b8e5SEugene Zhulenev     // Begin a coroutine: @llvm.coro.begin.
405a5aa7836SRiver Riddle     auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).getId();
406d37b5393SEugene Zhulenev     rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>(
4072ca46421SMarkus Böck         op, ptrType, ValueRange({coroId, coroAlloc.getResult()}));
4089c53b8e5SEugene Zhulenev 
4099c53b8e5SEugene Zhulenev     return success();
4109c53b8e5SEugene Zhulenev   }
4119c53b8e5SEugene Zhulenev };
4129c53b8e5SEugene Zhulenev } // namespace
4139c53b8e5SEugene Zhulenev 
4149c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===//
4159c53b8e5SEugene Zhulenev // Convert async.coro.free to @llvm.coro.free intrinsic.
4169c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===//
4179c53b8e5SEugene Zhulenev 
4189c53b8e5SEugene Zhulenev namespace {
4192ca46421SMarkus Böck class CoroFreeOpConversion : public AsyncOpConversionPattern<CoroFreeOp> {
4209c53b8e5SEugene Zhulenev public:
4212ca46421SMarkus Böck   using AsyncOpConversionPattern::AsyncOpConversionPattern;
4229c53b8e5SEugene Zhulenev 
4239c53b8e5SEugene Zhulenev   LogicalResult
424b54c724bSRiver Riddle   matchAndRewrite(CoroFreeOp op, OpAdaptor adaptor,
4259c53b8e5SEugene Zhulenev                   ConversionPatternRewriter &rewriter) const override {
426749f3708SChristian Ulmann     auto ptrType = AsyncAPI::opaquePointerType(op->getContext());
4279c53b8e5SEugene Zhulenev     auto loc = op->getLoc();
4289c53b8e5SEugene Zhulenev 
4299c53b8e5SEugene Zhulenev     // Get a pointer to the coroutine frame memory: @llvm.coro.free.
430b54c724bSRiver Riddle     auto coroMem =
4312ca46421SMarkus Böck         rewriter.create<LLVM::CoroFreeOp>(loc, ptrType, adaptor.getOperands());
4329c53b8e5SEugene Zhulenev 
4339c53b8e5SEugene Zhulenev     // Free the memory.
4342ca46421SMarkus Böck     auto freeFuncOp =
435749f3708SChristian Ulmann         LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
436*e84f6b6aSLuohao Wang     if (failed(freeFuncOp))
437*e84f6b6aSLuohao Wang       return failure();
438*e84f6b6aSLuohao Wang     rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp.value(),
439d37b5393SEugene Zhulenev                                               ValueRange(coroMem.getResult()));
4409c53b8e5SEugene Zhulenev 
4419c53b8e5SEugene Zhulenev     return success();
4429c53b8e5SEugene Zhulenev   }
4439c53b8e5SEugene Zhulenev };
4449c53b8e5SEugene Zhulenev } // namespace
4459c53b8e5SEugene Zhulenev 
4469c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===//
4479c53b8e5SEugene Zhulenev // Convert async.coro.end to @llvm.coro.end intrinsic.
4489c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===//
4499c53b8e5SEugene Zhulenev 
4509c53b8e5SEugene Zhulenev namespace {
4519c53b8e5SEugene Zhulenev class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> {
4529c53b8e5SEugene Zhulenev public:
4539c53b8e5SEugene Zhulenev   using OpConversionPattern::OpConversionPattern;
4549c53b8e5SEugene Zhulenev 
4559c53b8e5SEugene Zhulenev   LogicalResult
456b54c724bSRiver Riddle   matchAndRewrite(CoroEndOp op, OpAdaptor adaptor,
4579c53b8e5SEugene Zhulenev                   ConversionPatternRewriter &rewriter) const override {
4589c53b8e5SEugene Zhulenev     // We are not in the block that is part of the unwind sequence.
4599c53b8e5SEugene Zhulenev     auto constFalse = rewriter.create<LLVM::ConstantOp>(
4609c53b8e5SEugene Zhulenev         op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false));
46151d5d7bbSAnton Korobeynikov     auto noneToken = rewriter.create<LLVM::NoneTokenOp>(op->getLoc());
4629c53b8e5SEugene Zhulenev 
4639c53b8e5SEugene Zhulenev     // Mark the end of a coroutine: @llvm.coro.end.
464a5aa7836SRiver Riddle     auto coroHdl = adaptor.getHandle();
465749f3708SChristian Ulmann     rewriter.create<LLVM::CoroEndOp>(
466749f3708SChristian Ulmann         op->getLoc(), rewriter.getI1Type(),
46751d5d7bbSAnton Korobeynikov         ValueRange({coroHdl, constFalse, noneToken}));
4689c53b8e5SEugene Zhulenev     rewriter.eraseOp(op);
4699c53b8e5SEugene Zhulenev 
4709c53b8e5SEugene Zhulenev     return success();
4719c53b8e5SEugene Zhulenev   }
4729c53b8e5SEugene Zhulenev };
4739c53b8e5SEugene Zhulenev } // namespace
4749c53b8e5SEugene Zhulenev 
4759c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===//
4769c53b8e5SEugene Zhulenev // Convert async.coro.save to @llvm.coro.save intrinsic.
4779c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===//
4789c53b8e5SEugene Zhulenev 
4799c53b8e5SEugene Zhulenev namespace {
4809c53b8e5SEugene Zhulenev class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> {
4819c53b8e5SEugene Zhulenev public:
4829c53b8e5SEugene Zhulenev   using OpConversionPattern::OpConversionPattern;
4839c53b8e5SEugene Zhulenev 
4849c53b8e5SEugene Zhulenev   LogicalResult
485b54c724bSRiver Riddle   matchAndRewrite(CoroSaveOp op, OpAdaptor adaptor,
4869c53b8e5SEugene Zhulenev                   ConversionPatternRewriter &rewriter) const override {
4879c53b8e5SEugene Zhulenev     // Save the coroutine state: @llvm.coro.save
488d37b5393SEugene Zhulenev     rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>(
489b54c724bSRiver Riddle         op, AsyncAPI::tokenType(op->getContext()), adaptor.getOperands());
4909c53b8e5SEugene Zhulenev 
4919c53b8e5SEugene Zhulenev     return success();
4929c53b8e5SEugene Zhulenev   }
4939c53b8e5SEugene Zhulenev };
4949c53b8e5SEugene Zhulenev } // namespace
4959c53b8e5SEugene Zhulenev 
4969c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===//
4979c53b8e5SEugene Zhulenev // Convert async.coro.suspend to @llvm.coro.suspend intrinsic.
498a86a9b5eSEugene Zhulenev //===----------------------------------------------------------------------===//
499a86a9b5eSEugene Zhulenev 
500a86a9b5eSEugene Zhulenev namespace {
501a86a9b5eSEugene Zhulenev 
5029c53b8e5SEugene Zhulenev /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and
5039c53b8e5SEugene Zhulenev /// branch to the appropriate block based on the return code.
5049c53b8e5SEugene Zhulenev ///
5059c53b8e5SEugene Zhulenev /// Before:
5069c53b8e5SEugene Zhulenev ///
5079c53b8e5SEugene Zhulenev ///   ^suspended:
5089c53b8e5SEugene Zhulenev ///     "opBefore"(...)
5099c53b8e5SEugene Zhulenev ///     async.coro.suspend %state, ^suspend, ^resume, ^cleanup
5109c53b8e5SEugene Zhulenev ///   ^resume:
5119c53b8e5SEugene Zhulenev ///     "op"(...)
5129c53b8e5SEugene Zhulenev ///   ^cleanup: ...
5139c53b8e5SEugene Zhulenev ///   ^suspend: ...
5149c53b8e5SEugene Zhulenev ///
5159c53b8e5SEugene Zhulenev /// After:
5169c53b8e5SEugene Zhulenev ///
5179c53b8e5SEugene Zhulenev ///   ^suspended:
5189c53b8e5SEugene Zhulenev ///     "opBefore"(...)
519d37b5393SEugene Zhulenev ///     %suspend = llmv.intr.coro.suspend ...
5209c53b8e5SEugene Zhulenev ///     switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup]
5219c53b8e5SEugene Zhulenev ///   ^resume:
5229c53b8e5SEugene Zhulenev ///     "op"(...)
5239c53b8e5SEugene Zhulenev ///   ^cleanup: ...
5249c53b8e5SEugene Zhulenev ///   ^suspend: ...
5259c53b8e5SEugene Zhulenev ///
5269c53b8e5SEugene Zhulenev class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> {
5279c53b8e5SEugene Zhulenev public:
5289c53b8e5SEugene Zhulenev   using OpConversionPattern::OpConversionPattern;
5299c53b8e5SEugene Zhulenev 
5309c53b8e5SEugene Zhulenev   LogicalResult
531b54c724bSRiver Riddle   matchAndRewrite(CoroSuspendOp op, OpAdaptor adaptor,
5329c53b8e5SEugene Zhulenev                   ConversionPatternRewriter &rewriter) const override {
5339c53b8e5SEugene Zhulenev     auto i8 = rewriter.getIntegerType(8);
5349c53b8e5SEugene Zhulenev     auto i32 = rewriter.getI32Type();
5359c53b8e5SEugene Zhulenev     auto loc = op->getLoc();
5369c53b8e5SEugene Zhulenev 
5379c53b8e5SEugene Zhulenev     // This is not a final suspension point.
5389c53b8e5SEugene Zhulenev     auto constFalse = rewriter.create<LLVM::ConstantOp>(
5399c53b8e5SEugene Zhulenev         loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
5409c53b8e5SEugene Zhulenev 
5419c53b8e5SEugene Zhulenev     // Suspend a coroutine: @llvm.coro.suspend
542a5aa7836SRiver Riddle     auto coroState = adaptor.getState();
543d37b5393SEugene Zhulenev     auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>(
544d37b5393SEugene Zhulenev         loc, i8, ValueRange({coroState, constFalse}));
5459c53b8e5SEugene Zhulenev 
5469c53b8e5SEugene Zhulenev     // Cast return code to i32.
5479c53b8e5SEugene Zhulenev 
5489c53b8e5SEugene Zhulenev     // After a suspension point decide if we should branch into resume, cleanup
5499c53b8e5SEugene Zhulenev     // or suspend block of the coroutine (see @llvm.coro.suspend return code
5509c53b8e5SEugene Zhulenev     // documentation).
5519c53b8e5SEugene Zhulenev     llvm::SmallVector<int32_t, 2> caseValues = {0, 1};
552a5aa7836SRiver Riddle     llvm::SmallVector<Block *, 2> caseDest = {op.getResumeDest(),
553a5aa7836SRiver Riddle                                               op.getCleanupDest()};
5549c53b8e5SEugene Zhulenev     rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
555d37b5393SEugene Zhulenev         op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()),
556a5aa7836SRiver Riddle         /*defaultDestination=*/op.getSuspendDest(),
5579c53b8e5SEugene Zhulenev         /*defaultOperands=*/ValueRange(),
5589c53b8e5SEugene Zhulenev         /*caseValues=*/caseValues,
5599c53b8e5SEugene Zhulenev         /*caseDestinations=*/caseDest,
5604e103a12SRiver Riddle         /*caseOperands=*/ArrayRef<ValueRange>({ValueRange(), ValueRange()}),
5619c53b8e5SEugene Zhulenev         /*branchWeights=*/ArrayRef<int32_t>());
5629c53b8e5SEugene Zhulenev 
5639c53b8e5SEugene Zhulenev     return success();
5649c53b8e5SEugene Zhulenev   }
5659c53b8e5SEugene Zhulenev };
5669c53b8e5SEugene Zhulenev } // namespace
5679c53b8e5SEugene Zhulenev 
5689c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===//
5699c53b8e5SEugene Zhulenev // Convert async.runtime.create to the corresponding runtime API call.
5709c53b8e5SEugene Zhulenev //
5719c53b8e5SEugene Zhulenev // To allocate storage for the async values we use getelementptr trick:
5729c53b8e5SEugene Zhulenev // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt
5739c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===//
5749c53b8e5SEugene Zhulenev 
5759c53b8e5SEugene Zhulenev namespace {
5762ca46421SMarkus Böck class RuntimeCreateOpLowering : public ConvertOpToLLVMPattern<RuntimeCreateOp> {
5779c53b8e5SEugene Zhulenev public:
5782ca46421SMarkus Böck   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
5799c53b8e5SEugene Zhulenev 
5809c53b8e5SEugene Zhulenev   LogicalResult
581b54c724bSRiver Riddle   matchAndRewrite(RuntimeCreateOp op, OpAdaptor adaptor,
5829c53b8e5SEugene Zhulenev                   ConversionPatternRewriter &rewriter) const override {
583ce254598SMatthias Springer     const TypeConverter *converter = getTypeConverter();
5849c53b8e5SEugene Zhulenev     Type resultType = op->getResultTypes()[0];
5859c53b8e5SEugene Zhulenev 
586d43b2360SEugene Zhulenev     // Tokens creation maps to a simple function call.
5875550c821STres Popp     if (isa<TokenType>(resultType)) {
58823aa5a74SRiver Riddle       rewriter.replaceOpWithNewOp<func::CallOp>(
58923aa5a74SRiver Riddle           op, kCreateToken, converter->convertType(resultType));
5909c53b8e5SEugene Zhulenev       return success();
5919c53b8e5SEugene Zhulenev     }
5929c53b8e5SEugene Zhulenev 
5939c53b8e5SEugene Zhulenev     // To create a value we need to compute the storage requirement.
5945550c821STres Popp     if (auto value = dyn_cast<ValueType>(resultType)) {
5959c53b8e5SEugene Zhulenev       // Returns the size requirements for the async value storage.
5969c53b8e5SEugene Zhulenev       auto sizeOf = [&](ValueType valueType) -> Value {
5979c53b8e5SEugene Zhulenev         auto loc = op->getLoc();
59892db09cdSEugene Zhulenev         auto i64 = rewriter.getI64Type();
5999c53b8e5SEugene Zhulenev 
6009c53b8e5SEugene Zhulenev         auto storedType = converter->convertType(valueType.getValueType());
601749f3708SChristian Ulmann         auto storagePtrType =
602749f3708SChristian Ulmann             AsyncAPI::opaquePointerType(rewriter.getContext());
6039c53b8e5SEugene Zhulenev 
6049c53b8e5SEugene Zhulenev         // %Size = getelementptr %T* null, int 1
60592db09cdSEugene Zhulenev         // %SizeI = ptrtoint %T* %Size to i64
60685175eddSTobias Gysi         auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, storagePtrType);
6072ca46421SMarkus Böck         auto gep =
6082ca46421SMarkus Böck             rewriter.create<LLVM::GEPOp>(loc, storagePtrType, storedType,
6092ca46421SMarkus Böck                                          nullPtr, ArrayRef<LLVM::GEPArg>{1});
61092db09cdSEugene Zhulenev         return rewriter.create<LLVM::PtrToIntOp>(loc, i64, gep);
6119c53b8e5SEugene Zhulenev       };
6129c53b8e5SEugene Zhulenev 
61323aa5a74SRiver Riddle       rewriter.replaceOpWithNewOp<func::CallOp>(op, kCreateValue, resultType,
6149c53b8e5SEugene Zhulenev                                                 sizeOf(value));
6159c53b8e5SEugene Zhulenev 
6169c53b8e5SEugene Zhulenev       return success();
6179c53b8e5SEugene Zhulenev     }
6189c53b8e5SEugene Zhulenev 
6199c53b8e5SEugene Zhulenev     return rewriter.notifyMatchFailure(op, "unsupported async type");
6209c53b8e5SEugene Zhulenev   }
6219c53b8e5SEugene Zhulenev };
6229c53b8e5SEugene Zhulenev } // namespace
6239c53b8e5SEugene Zhulenev 
6249c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===//
625d43b2360SEugene Zhulenev // Convert async.runtime.create_group to the corresponding runtime API call.
626d43b2360SEugene Zhulenev //===----------------------------------------------------------------------===//
627d43b2360SEugene Zhulenev 
628d43b2360SEugene Zhulenev namespace {
629d43b2360SEugene Zhulenev class RuntimeCreateGroupOpLowering
6302ca46421SMarkus Böck     : public ConvertOpToLLVMPattern<RuntimeCreateGroupOp> {
631d43b2360SEugene Zhulenev public:
6322ca46421SMarkus Böck   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
633d43b2360SEugene Zhulenev 
634d43b2360SEugene Zhulenev   LogicalResult
635b54c724bSRiver Riddle   matchAndRewrite(RuntimeCreateGroupOp op, OpAdaptor adaptor,
636d43b2360SEugene Zhulenev                   ConversionPatternRewriter &rewriter) const override {
637ce254598SMatthias Springer     const TypeConverter *converter = getTypeConverter();
63834a164c9SEugene Zhulenev     Type resultType = op.getResult().getType();
639d43b2360SEugene Zhulenev 
64023aa5a74SRiver Riddle     rewriter.replaceOpWithNewOp<func::CallOp>(
64123aa5a74SRiver Riddle         op, kCreateGroup, converter->convertType(resultType),
642b54c724bSRiver Riddle         adaptor.getOperands());
643d43b2360SEugene Zhulenev     return success();
644d43b2360SEugene Zhulenev   }
645d43b2360SEugene Zhulenev };
646d43b2360SEugene Zhulenev } // namespace
647d43b2360SEugene Zhulenev 
648d43b2360SEugene Zhulenev //===----------------------------------------------------------------------===//
6499c53b8e5SEugene Zhulenev // Convert async.runtime.set_available to the corresponding runtime API call.
6509c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===//
6519c53b8e5SEugene Zhulenev 
6529c53b8e5SEugene Zhulenev namespace {
6539c53b8e5SEugene Zhulenev class RuntimeSetAvailableOpLowering
6549c53b8e5SEugene Zhulenev     : public OpConversionPattern<RuntimeSetAvailableOp> {
6559c53b8e5SEugene Zhulenev public:
6569c53b8e5SEugene Zhulenev   using OpConversionPattern::OpConversionPattern;
6579c53b8e5SEugene Zhulenev 
6589c53b8e5SEugene Zhulenev   LogicalResult
659b54c724bSRiver Riddle   matchAndRewrite(RuntimeSetAvailableOp op, OpAdaptor adaptor,
6609c53b8e5SEugene Zhulenev                   ConversionPatternRewriter &rewriter) const override {
661d8c84d2aSEugene Zhulenev     StringRef apiFuncName =
662a5aa7836SRiver Riddle         TypeSwitch<Type, StringRef>(op.getOperand().getType())
663d8c84d2aSEugene Zhulenev             .Case<TokenType>([](Type) { return kEmplaceToken; })
664d8c84d2aSEugene Zhulenev             .Case<ValueType>([](Type) { return kEmplaceValue; });
665d8c84d2aSEugene Zhulenev 
66623aa5a74SRiver Riddle     rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName, TypeRange(),
667b54c724bSRiver Riddle                                               adaptor.getOperands());
668d8c84d2aSEugene Zhulenev 
6699c53b8e5SEugene Zhulenev     return success();
6709c53b8e5SEugene Zhulenev   }
67139957aa4SEugene Zhulenev };
67239957aa4SEugene Zhulenev } // namespace
6739c53b8e5SEugene Zhulenev 
67439957aa4SEugene Zhulenev //===----------------------------------------------------------------------===//
67539957aa4SEugene Zhulenev // Convert async.runtime.set_error to the corresponding runtime API call.
67639957aa4SEugene Zhulenev //===----------------------------------------------------------------------===//
67739957aa4SEugene Zhulenev 
67839957aa4SEugene Zhulenev namespace {
67939957aa4SEugene Zhulenev class RuntimeSetErrorOpLowering
68039957aa4SEugene Zhulenev     : public OpConversionPattern<RuntimeSetErrorOp> {
68139957aa4SEugene Zhulenev public:
68239957aa4SEugene Zhulenev   using OpConversionPattern::OpConversionPattern;
68339957aa4SEugene Zhulenev 
68439957aa4SEugene Zhulenev   LogicalResult
685b54c724bSRiver Riddle   matchAndRewrite(RuntimeSetErrorOp op, OpAdaptor adaptor,
68639957aa4SEugene Zhulenev                   ConversionPatternRewriter &rewriter) const override {
687d8c84d2aSEugene Zhulenev     StringRef apiFuncName =
688a5aa7836SRiver Riddle         TypeSwitch<Type, StringRef>(op.getOperand().getType())
689d8c84d2aSEugene Zhulenev             .Case<TokenType>([](Type) { return kSetTokenError; })
690d8c84d2aSEugene Zhulenev             .Case<ValueType>([](Type) { return kSetValueError; });
691d8c84d2aSEugene Zhulenev 
69223aa5a74SRiver Riddle     rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName, TypeRange(),
693b54c724bSRiver Riddle                                               adaptor.getOperands());
694d8c84d2aSEugene Zhulenev 
69539957aa4SEugene Zhulenev     return success();
69639957aa4SEugene Zhulenev   }
69739957aa4SEugene Zhulenev };
69839957aa4SEugene Zhulenev } // namespace
69939957aa4SEugene Zhulenev 
70039957aa4SEugene Zhulenev //===----------------------------------------------------------------------===//
70139957aa4SEugene Zhulenev // Convert async.runtime.is_error to the corresponding runtime API call.
70239957aa4SEugene Zhulenev //===----------------------------------------------------------------------===//
70339957aa4SEugene Zhulenev 
70439957aa4SEugene Zhulenev namespace {
70539957aa4SEugene Zhulenev class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> {
70639957aa4SEugene Zhulenev public:
70739957aa4SEugene Zhulenev   using OpConversionPattern::OpConversionPattern;
70839957aa4SEugene Zhulenev 
70939957aa4SEugene Zhulenev   LogicalResult
710b54c724bSRiver Riddle   matchAndRewrite(RuntimeIsErrorOp op, OpAdaptor adaptor,
71139957aa4SEugene Zhulenev                   ConversionPatternRewriter &rewriter) const override {
712d8c84d2aSEugene Zhulenev     StringRef apiFuncName =
713a5aa7836SRiver Riddle         TypeSwitch<Type, StringRef>(op.getOperand().getType())
714d8c84d2aSEugene Zhulenev             .Case<TokenType>([](Type) { return kIsTokenError; })
715d8c84d2aSEugene Zhulenev             .Case<GroupType>([](Type) { return kIsGroupError; })
716d8c84d2aSEugene Zhulenev             .Case<ValueType>([](Type) { return kIsValueError; });
717d8c84d2aSEugene Zhulenev 
71823aa5a74SRiver Riddle     rewriter.replaceOpWithNewOp<func::CallOp>(
71923aa5a74SRiver Riddle         op, apiFuncName, rewriter.getI1Type(), adaptor.getOperands());
72039957aa4SEugene Zhulenev     return success();
7219c53b8e5SEugene Zhulenev   }
7229c53b8e5SEugene Zhulenev };
7239c53b8e5SEugene Zhulenev } // namespace
7249c53b8e5SEugene Zhulenev 
7259c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===//
7269c53b8e5SEugene Zhulenev // Convert async.runtime.await to the corresponding runtime API call.
7279c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===//
7289c53b8e5SEugene Zhulenev 
7299c53b8e5SEugene Zhulenev namespace {
7309c53b8e5SEugene Zhulenev class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> {
7319c53b8e5SEugene Zhulenev public:
7329c53b8e5SEugene Zhulenev   using OpConversionPattern::OpConversionPattern;
7339c53b8e5SEugene Zhulenev 
7349c53b8e5SEugene Zhulenev   LogicalResult
735b54c724bSRiver Riddle   matchAndRewrite(RuntimeAwaitOp op, OpAdaptor adaptor,
7369c53b8e5SEugene Zhulenev                   ConversionPatternRewriter &rewriter) const override {
737d8c84d2aSEugene Zhulenev     StringRef apiFuncName =
738a5aa7836SRiver Riddle         TypeSwitch<Type, StringRef>(op.getOperand().getType())
739d8c84d2aSEugene Zhulenev             .Case<TokenType>([](Type) { return kAwaitToken; })
740d8c84d2aSEugene Zhulenev             .Case<ValueType>([](Type) { return kAwaitValue; })
741d8c84d2aSEugene Zhulenev             .Case<GroupType>([](Type) { return kAwaitGroup; });
7429c53b8e5SEugene Zhulenev 
74323aa5a74SRiver Riddle     rewriter.create<func::CallOp>(op->getLoc(), apiFuncName, TypeRange(),
744b54c724bSRiver Riddle                                   adaptor.getOperands());
7459c53b8e5SEugene Zhulenev     rewriter.eraseOp(op);
7469c53b8e5SEugene Zhulenev 
7479c53b8e5SEugene Zhulenev     return success();
7489c53b8e5SEugene Zhulenev   }
7499c53b8e5SEugene Zhulenev };
7509c53b8e5SEugene Zhulenev } // namespace
7519c53b8e5SEugene Zhulenev 
7529c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===//
7539c53b8e5SEugene Zhulenev // Convert async.runtime.await_and_resume to the corresponding runtime API call.
7549c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===//
7559c53b8e5SEugene Zhulenev 
7569c53b8e5SEugene Zhulenev namespace {
7579c53b8e5SEugene Zhulenev class RuntimeAwaitAndResumeOpLowering
7582ca46421SMarkus Böck     : public AsyncOpConversionPattern<RuntimeAwaitAndResumeOp> {
7599c53b8e5SEugene Zhulenev public:
7602ca46421SMarkus Böck   using AsyncOpConversionPattern::AsyncOpConversionPattern;
7619c53b8e5SEugene Zhulenev 
7629c53b8e5SEugene Zhulenev   LogicalResult
763b54c724bSRiver Riddle   matchAndRewrite(RuntimeAwaitAndResumeOp op, OpAdaptor adaptor,
7649c53b8e5SEugene Zhulenev                   ConversionPatternRewriter &rewriter) const override {
765d8c84d2aSEugene Zhulenev     StringRef apiFuncName =
766a5aa7836SRiver Riddle         TypeSwitch<Type, StringRef>(op.getOperand().getType())
767d8c84d2aSEugene Zhulenev             .Case<TokenType>([](Type) { return kAwaitTokenAndExecute; })
768d8c84d2aSEugene Zhulenev             .Case<ValueType>([](Type) { return kAwaitValueAndExecute; })
769d8c84d2aSEugene Zhulenev             .Case<GroupType>([](Type) { return kAwaitAllAndExecute; });
7709c53b8e5SEugene Zhulenev 
771a5aa7836SRiver Riddle     Value operand = adaptor.getOperand();
772a5aa7836SRiver Riddle     Value handle = adaptor.getHandle();
7739c53b8e5SEugene Zhulenev 
7749c53b8e5SEugene Zhulenev     // A pointer to coroutine resume intrinsic wrapper.
775749f3708SChristian Ulmann     addResumeFunction(op->getParentOfType<ModuleOp>());
7769c53b8e5SEugene Zhulenev     auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
777749f3708SChristian Ulmann         op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()),
778749f3708SChristian Ulmann         kResume);
7799c53b8e5SEugene Zhulenev 
78023aa5a74SRiver Riddle     rewriter.create<func::CallOp>(
78123aa5a74SRiver Riddle         op->getLoc(), apiFuncName, TypeRange(),
782cfb72fd3SJacques Pienaar         ValueRange({operand, handle, resumePtr.getRes()}));
7839c53b8e5SEugene Zhulenev     rewriter.eraseOp(op);
7849c53b8e5SEugene Zhulenev 
7859c53b8e5SEugene Zhulenev     return success();
7869c53b8e5SEugene Zhulenev   }
7879c53b8e5SEugene Zhulenev };
7889c53b8e5SEugene Zhulenev } // namespace
7899c53b8e5SEugene Zhulenev 
7909c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===//
7919c53b8e5SEugene Zhulenev // Convert async.runtime.resume to the corresponding runtime API call.
7929c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===//
7939c53b8e5SEugene Zhulenev 
7949c53b8e5SEugene Zhulenev namespace {
7952ca46421SMarkus Böck class RuntimeResumeOpLowering
7962ca46421SMarkus Böck     : public AsyncOpConversionPattern<RuntimeResumeOp> {
7979c53b8e5SEugene Zhulenev public:
7982ca46421SMarkus Böck   using AsyncOpConversionPattern::AsyncOpConversionPattern;
7999c53b8e5SEugene Zhulenev 
8009c53b8e5SEugene Zhulenev   LogicalResult
801b54c724bSRiver Riddle   matchAndRewrite(RuntimeResumeOp op, OpAdaptor adaptor,
8029c53b8e5SEugene Zhulenev                   ConversionPatternRewriter &rewriter) const override {
8039c53b8e5SEugene Zhulenev     // A pointer to coroutine resume intrinsic wrapper.
804749f3708SChristian Ulmann     addResumeFunction(op->getParentOfType<ModuleOp>());
8059c53b8e5SEugene Zhulenev     auto resumePtr = rewriter.create<LLVM::AddressOfOp>(
806749f3708SChristian Ulmann         op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()),
807749f3708SChristian Ulmann         kResume);
8089c53b8e5SEugene Zhulenev 
8099c53b8e5SEugene Zhulenev     // Call async runtime API to execute a coroutine in the managed thread.
810a5aa7836SRiver Riddle     auto coroHdl = adaptor.getHandle();
81123aa5a74SRiver Riddle     rewriter.replaceOpWithNewOp<func::CallOp>(
812cfb72fd3SJacques Pienaar         op, TypeRange(), kExecute, ValueRange({coroHdl, resumePtr.getRes()}));
8139c53b8e5SEugene Zhulenev 
8149c53b8e5SEugene Zhulenev     return success();
8159c53b8e5SEugene Zhulenev   }
8169c53b8e5SEugene Zhulenev };
8179c53b8e5SEugene Zhulenev } // namespace
8189c53b8e5SEugene Zhulenev 
8199c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===//
8209c53b8e5SEugene Zhulenev // Convert async.runtime.store to the corresponding runtime API call.
8219c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===//
8229c53b8e5SEugene Zhulenev 
8239c53b8e5SEugene Zhulenev namespace {
8242ca46421SMarkus Böck class RuntimeStoreOpLowering : public ConvertOpToLLVMPattern<RuntimeStoreOp> {
8259c53b8e5SEugene Zhulenev public:
8262ca46421SMarkus Böck   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
8279c53b8e5SEugene Zhulenev 
8289c53b8e5SEugene Zhulenev   LogicalResult
829b54c724bSRiver Riddle   matchAndRewrite(RuntimeStoreOp op, OpAdaptor adaptor,
8309c53b8e5SEugene Zhulenev                   ConversionPatternRewriter &rewriter) const override {
8319c53b8e5SEugene Zhulenev     Location loc = op->getLoc();
8329c53b8e5SEugene Zhulenev 
8339c53b8e5SEugene Zhulenev     // Get a pointer to the async value storage from the runtime.
834749f3708SChristian Ulmann     auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext());
835a5aa7836SRiver Riddle     auto storage = adaptor.getStorage();
8362ca46421SMarkus Böck     auto storagePtr = rewriter.create<func::CallOp>(
8372ca46421SMarkus Böck         loc, kGetValueStorage, TypeRange(ptrType), storage);
8389c53b8e5SEugene Zhulenev 
8399c53b8e5SEugene Zhulenev     // Cast from i8* to the LLVM pointer type.
840a5aa7836SRiver Riddle     auto valueType = op.getValue().getType();
8419c53b8e5SEugene Zhulenev     auto llvmValueType = getTypeConverter()->convertType(valueType);
84225f80e16SEugene Zhulenev     if (!llvmValueType)
84325f80e16SEugene Zhulenev       return rewriter.notifyMatchFailure(
84425f80e16SEugene Zhulenev           op, "failed to convert stored value type to LLVM type");
84525f80e16SEugene Zhulenev 
8462ca46421SMarkus Böck     Value castedStoragePtr = storagePtr.getResult(0);
8479c53b8e5SEugene Zhulenev     // Store the yielded value into the async value storage.
848a5aa7836SRiver Riddle     auto value = adaptor.getValue();
8492ca46421SMarkus Böck     rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr);
8509c53b8e5SEugene Zhulenev 
8519c53b8e5SEugene Zhulenev     // Erase the original runtime store operation.
8529c53b8e5SEugene Zhulenev     rewriter.eraseOp(op);
8539c53b8e5SEugene Zhulenev 
8549c53b8e5SEugene Zhulenev     return success();
8559c53b8e5SEugene Zhulenev   }
8569c53b8e5SEugene Zhulenev };
8579c53b8e5SEugene Zhulenev } // namespace
8589c53b8e5SEugene Zhulenev 
8599c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===//
8609c53b8e5SEugene Zhulenev // Convert async.runtime.load to the corresponding runtime API call.
8619c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===//
8629c53b8e5SEugene Zhulenev 
8639c53b8e5SEugene Zhulenev namespace {
8642ca46421SMarkus Böck class RuntimeLoadOpLowering : public ConvertOpToLLVMPattern<RuntimeLoadOp> {
8659c53b8e5SEugene Zhulenev public:
8662ca46421SMarkus Böck   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
8679c53b8e5SEugene Zhulenev 
8689c53b8e5SEugene Zhulenev   LogicalResult
869b54c724bSRiver Riddle   matchAndRewrite(RuntimeLoadOp op, OpAdaptor adaptor,
8709c53b8e5SEugene Zhulenev                   ConversionPatternRewriter &rewriter) const override {
8719c53b8e5SEugene Zhulenev     Location loc = op->getLoc();
8729c53b8e5SEugene Zhulenev 
8739c53b8e5SEugene Zhulenev     // Get a pointer to the async value storage from the runtime.
874749f3708SChristian Ulmann     auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext());
875a5aa7836SRiver Riddle     auto storage = adaptor.getStorage();
8762ca46421SMarkus Böck     auto storagePtr = rewriter.create<func::CallOp>(
8772ca46421SMarkus Böck         loc, kGetValueStorage, TypeRange(ptrType), storage);
8789c53b8e5SEugene Zhulenev 
8799c53b8e5SEugene Zhulenev     // Cast from i8* to the LLVM pointer type.
880a5aa7836SRiver Riddle     auto valueType = op.getResult().getType();
8819c53b8e5SEugene Zhulenev     auto llvmValueType = getTypeConverter()->convertType(valueType);
88225f80e16SEugene Zhulenev     if (!llvmValueType)
88325f80e16SEugene Zhulenev       return rewriter.notifyMatchFailure(
88425f80e16SEugene Zhulenev           op, "failed to convert loaded value type to LLVM type");
88525f80e16SEugene Zhulenev 
8862ca46421SMarkus Böck     Value castedStoragePtr = storagePtr.getResult(0);
8879c53b8e5SEugene Zhulenev 
8889c53b8e5SEugene Zhulenev     // Load from the casted pointer.
8892ca46421SMarkus Böck     rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, llvmValueType,
8902ca46421SMarkus Böck                                               castedStoragePtr);
8919c53b8e5SEugene Zhulenev 
8929c53b8e5SEugene Zhulenev     return success();
8939c53b8e5SEugene Zhulenev   }
8949c53b8e5SEugene Zhulenev };
8959c53b8e5SEugene Zhulenev } // namespace
8969c53b8e5SEugene Zhulenev 
8979c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===//
8989c53b8e5SEugene Zhulenev // Convert async.runtime.add_to_group to the corresponding runtime API call.
8999c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===//
9009c53b8e5SEugene Zhulenev 
9019c53b8e5SEugene Zhulenev namespace {
9029c53b8e5SEugene Zhulenev class RuntimeAddToGroupOpLowering
9039c53b8e5SEugene Zhulenev     : public OpConversionPattern<RuntimeAddToGroupOp> {
9049c53b8e5SEugene Zhulenev public:
9059c53b8e5SEugene Zhulenev   using OpConversionPattern::OpConversionPattern;
9069c53b8e5SEugene Zhulenev 
9079c53b8e5SEugene Zhulenev   LogicalResult
908b54c724bSRiver Riddle   matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor,
9099c53b8e5SEugene Zhulenev                   ConversionPatternRewriter &rewriter) const override {
9109c53b8e5SEugene Zhulenev     // Currently we can only add tokens to the group.
9115550c821STres Popp     if (!isa<TokenType>(op.getOperand().getType()))
9129c53b8e5SEugene Zhulenev       return rewriter.notifyMatchFailure(op, "only token type is supported");
9139c53b8e5SEugene Zhulenev 
9149c53b8e5SEugene Zhulenev     // Replace with a runtime API function call.
91523aa5a74SRiver Riddle     rewriter.replaceOpWithNewOp<func::CallOp>(
916b54c724bSRiver Riddle         op, kAddTokenToGroup, rewriter.getI64Type(), adaptor.getOperands());
9179c53b8e5SEugene Zhulenev 
9189c53b8e5SEugene Zhulenev     return success();
9199c53b8e5SEugene Zhulenev   }
9209c53b8e5SEugene Zhulenev };
9219c53b8e5SEugene Zhulenev } // namespace
9229c53b8e5SEugene Zhulenev 
9239c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===//
924149311b4Sbakhtiyar // Convert async.runtime.num_worker_threads to the corresponding runtime API
925149311b4Sbakhtiyar // call.
926149311b4Sbakhtiyar //===----------------------------------------------------------------------===//
927149311b4Sbakhtiyar 
928149311b4Sbakhtiyar namespace {
929149311b4Sbakhtiyar class RuntimeNumWorkerThreadsOpLowering
930149311b4Sbakhtiyar     : public OpConversionPattern<RuntimeNumWorkerThreadsOp> {
931149311b4Sbakhtiyar public:
932149311b4Sbakhtiyar   using OpConversionPattern::OpConversionPattern;
933149311b4Sbakhtiyar 
934149311b4Sbakhtiyar   LogicalResult
935149311b4Sbakhtiyar   matchAndRewrite(RuntimeNumWorkerThreadsOp op, OpAdaptor adaptor,
936149311b4Sbakhtiyar                   ConversionPatternRewriter &rewriter) const override {
937149311b4Sbakhtiyar 
938149311b4Sbakhtiyar     // Replace with a runtime API function call.
93923aa5a74SRiver Riddle     rewriter.replaceOpWithNewOp<func::CallOp>(op, kGetNumWorkerThreads,
940149311b4Sbakhtiyar                                               rewriter.getIndexType());
941149311b4Sbakhtiyar 
942149311b4Sbakhtiyar     return success();
943149311b4Sbakhtiyar   }
944149311b4Sbakhtiyar };
945149311b4Sbakhtiyar } // namespace
946149311b4Sbakhtiyar 
947149311b4Sbakhtiyar //===----------------------------------------------------------------------===//
9489c53b8e5SEugene Zhulenev // Async reference counting ops lowering (`async.runtime.add_ref` and
9499c53b8e5SEugene Zhulenev // `async.runtime.drop_ref` to the corresponding API calls).
9509c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===//
9519c53b8e5SEugene Zhulenev 
9529c53b8e5SEugene Zhulenev namespace {
953a86a9b5eSEugene Zhulenev template <typename RefCountingOp>
9549c53b8e5SEugene Zhulenev class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> {
955a86a9b5eSEugene Zhulenev public:
956ce254598SMatthias Springer   explicit RefCountingOpLowering(const TypeConverter &converter,
957ce254598SMatthias Springer                                  MLIRContext *ctx, StringRef apiFunctionName)
9589c53b8e5SEugene Zhulenev       : OpConversionPattern<RefCountingOp>(converter, ctx),
959a86a9b5eSEugene Zhulenev         apiFunctionName(apiFunctionName) {}
960a86a9b5eSEugene Zhulenev 
961a86a9b5eSEugene Zhulenev   LogicalResult
962b54c724bSRiver Riddle   matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor,
963a86a9b5eSEugene Zhulenev                   ConversionPatternRewriter &rewriter) const override {
964a54f4eaeSMogball     auto count = rewriter.create<arith::ConstantOp>(
965a54f4eaeSMogball         op->getLoc(), rewriter.getI64Type(),
966a5aa7836SRiver Riddle         rewriter.getI64IntegerAttr(op.getCount()));
967a86a9b5eSEugene Zhulenev 
968a5aa7836SRiver Riddle     auto operand = adaptor.getOperand();
96923aa5a74SRiver Riddle     rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(), apiFunctionName,
9709c53b8e5SEugene Zhulenev                                               ValueRange({operand, count}));
971a86a9b5eSEugene Zhulenev 
972a86a9b5eSEugene Zhulenev     return success();
973a86a9b5eSEugene Zhulenev   }
974a86a9b5eSEugene Zhulenev 
975a86a9b5eSEugene Zhulenev private:
976a86a9b5eSEugene Zhulenev   StringRef apiFunctionName;
977a86a9b5eSEugene Zhulenev };
978a86a9b5eSEugene Zhulenev 
9799c53b8e5SEugene Zhulenev class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> {
980a86a9b5eSEugene Zhulenev public:
981ce254598SMatthias Springer   explicit RuntimeAddRefOpLowering(const TypeConverter &converter,
982ce254598SMatthias Springer                                    MLIRContext *ctx)
983621ad468SEugene Zhulenev       : RefCountingOpLowering(converter, ctx, kAddRef) {}
984a86a9b5eSEugene Zhulenev };
985a86a9b5eSEugene Zhulenev 
9869c53b8e5SEugene Zhulenev class RuntimeDropRefOpLowering
9879c53b8e5SEugene Zhulenev     : public RefCountingOpLowering<RuntimeDropRefOp> {
988a86a9b5eSEugene Zhulenev public:
989ce254598SMatthias Springer   explicit RuntimeDropRefOpLowering(const TypeConverter &converter,
990ce254598SMatthias Springer                                     MLIRContext *ctx)
991621ad468SEugene Zhulenev       : RefCountingOpLowering(converter, ctx, kDropRef) {}
992a86a9b5eSEugene Zhulenev };
993a86a9b5eSEugene Zhulenev } // namespace
994a86a9b5eSEugene Zhulenev 
995a86a9b5eSEugene Zhulenev //===----------------------------------------------------------------------===//
9969c53b8e5SEugene Zhulenev // Convert return operations that return async values from async regions.
99736ce915aSLei Zhang //===----------------------------------------------------------------------===//
99836ce915aSLei Zhang 
99936ce915aSLei Zhang namespace {
100023aa5a74SRiver Riddle class ReturnOpOpConversion : public OpConversionPattern<func::ReturnOp> {
100136ce915aSLei Zhang public:
10029c53b8e5SEugene Zhulenev   using OpConversionPattern::OpConversionPattern;
100336ce915aSLei Zhang 
100436ce915aSLei Zhang   LogicalResult
100523aa5a74SRiver Riddle   matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
100636ce915aSLei Zhang                   ConversionPatternRewriter &rewriter) const override {
100723aa5a74SRiver Riddle     rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands());
1008c30ab6c2SEugene Zhulenev     return success();
1009c30ab6c2SEugene Zhulenev   }
1010c30ab6c2SEugene Zhulenev };
1011c30ab6c2SEugene Zhulenev } // namespace
1012c30ab6c2SEugene Zhulenev 
1013c30ab6c2SEugene Zhulenev //===----------------------------------------------------------------------===//
101436ce915aSLei Zhang 
101536ce915aSLei Zhang namespace {
101636ce915aSLei Zhang struct ConvertAsyncToLLVMPass
1017cd4ca2d7SMarkus Böck     : public impl::ConvertAsyncToLLVMPassBase<ConvertAsyncToLLVMPass> {
1018cd4ca2d7SMarkus Böck   using Base::Base;
1019cd4ca2d7SMarkus Böck 
102036ce915aSLei Zhang   void runOnOperation() override;
102136ce915aSLei Zhang };
10229c53b8e5SEugene Zhulenev } // namespace
102336ce915aSLei Zhang 
102436ce915aSLei Zhang void ConvertAsyncToLLVMPass::runOnOperation() {
102536ce915aSLei Zhang   ModuleOp module = getOperation();
102625f80e16SEugene Zhulenev   MLIRContext *ctx = module->getContext();
102736ce915aSLei Zhang 
10282ca46421SMarkus Böck   LowerToLLVMOptions options(ctx);
10292ca46421SMarkus Böck 
10305b388169SChristian Sigg   // Add declarations for most functions required by the coroutines lowering.
10315b388169SChristian Sigg   // We delay adding the resume function until it's needed because it currently
10325b388169SChristian Sigg   // fails to compile unless '-O0' is specified.
1033749f3708SChristian Ulmann   addAsyncRuntimeApiDeclarations(module);
103436ce915aSLei Zhang 
10359c53b8e5SEugene Zhulenev   // Lower async.runtime and async.coro operations to Async Runtime API and
10369c53b8e5SEugene Zhulenev   // LLVM coroutine intrinsics.
10379c53b8e5SEugene Zhulenev 
103836ce915aSLei Zhang   // Convert async dialect types and operations to LLVM dialect.
10392ca46421SMarkus Böck   AsyncRuntimeTypeConverter converter(options);
1040dc4e913bSChris Lattner   RewritePatternSet patterns(ctx);
104136ce915aSLei Zhang 
104225f80e16SEugene Zhulenev   // We use conversion to LLVM type to lower async.runtime load and store
104325f80e16SEugene Zhulenev   // operations.
10442ca46421SMarkus Böck   LLVMTypeConverter llvmConverter(ctx, options);
10452ca46421SMarkus Böck   llvmConverter.addConversion([&](Type type) {
1046749f3708SChristian Ulmann     return AsyncRuntimeTypeConverter::convertAsyncTypes(type);
10472ca46421SMarkus Böck   });
104825f80e16SEugene Zhulenev 
1049621ad468SEugene Zhulenev   // Convert async types in function signatures and function calls.
105058ceae95SRiver Riddle   populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
105158ceae95SRiver Riddle                                                                  converter);
10523a506b31SChris Lattner   populateCallOpTypeConversionPattern(patterns, converter);
1053621ad468SEugene Zhulenev 
1054621ad468SEugene Zhulenev   // Convert return operations inside async.execute regions.
1055dc4e913bSChris Lattner   patterns.add<ReturnOpOpConversion>(converter, ctx);
1056621ad468SEugene Zhulenev 
10579c53b8e5SEugene Zhulenev   // Lower async.runtime operations to the async runtime API calls.
105839957aa4SEugene Zhulenev   patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering,
105939957aa4SEugene Zhulenev                RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering,
10609c53b8e5SEugene Zhulenev                RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering,
1061149311b4Sbakhtiyar                RuntimeAddToGroupOpLowering, RuntimeNumWorkerThreadsOpLowering,
1062149311b4Sbakhtiyar                RuntimeAddRefOpLowering, RuntimeDropRefOpLowering>(converter,
1063149311b4Sbakhtiyar                                                                   ctx);
1064621ad468SEugene Zhulenev 
10659c53b8e5SEugene Zhulenev   // Lower async.runtime operations that rely on LLVM type converter to convert
10669c53b8e5SEugene Zhulenev   // from async value payload type to the LLVM type.
1067d43b2360SEugene Zhulenev   patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering,
10682ca46421SMarkus Böck                RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter);
10699c53b8e5SEugene Zhulenev 
10709c53b8e5SEugene Zhulenev   // Lower async coroutine operations to LLVM coroutine intrinsics.
1071dc4e913bSChris Lattner   patterns
1072dc4e913bSChris Lattner       .add<CoroIdOpConversion, CoroBeginOpConversion, CoroFreeOpConversion,
1073dc4e913bSChris Lattner            CoroEndOpConversion, CoroSaveOpConversion, CoroSuspendOpConversion>(
1074dc4e913bSChris Lattner           converter, ctx);
107536ce915aSLei Zhang 
107636ce915aSLei Zhang   ConversionTarget target(*ctx);
107723aa5a74SRiver Riddle   target.addLegalOp<arith::ConstantOp, func::ConstantOp,
107823aa5a74SRiver Riddle                     UnrealizedConversionCastOp>();
107936ce915aSLei Zhang   target.addLegalDialect<LLVM::LLVMDialect>();
1080621ad468SEugene Zhulenev 
10819c53b8e5SEugene Zhulenev   // All operations from Async dialect must be lowered to the runtime API and
10829c53b8e5SEugene Zhulenev   // LLVM intrinsics calls.
108336ce915aSLei Zhang   target.addIllegalDialect<AsyncDialect>();
1084621ad468SEugene Zhulenev 
1085621ad468SEugene Zhulenev   // Add dynamic legality constraints to apply conversions defined above.
108658ceae95SRiver Riddle   target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
10874a3460a7SRiver Riddle     return converter.isSignatureLegal(op.getFunctionType());
10884a3460a7SRiver Riddle   });
108923aa5a74SRiver Riddle   target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
109023aa5a74SRiver Riddle     return converter.isLegal(op.getOperandTypes());
109123aa5a74SRiver Riddle   });
109223aa5a74SRiver Riddle   target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
1093621ad468SEugene Zhulenev     return converter.isSignatureLegal(op.getCalleeType());
1094621ad468SEugene Zhulenev   });
109536ce915aSLei Zhang 
10963fffffa8SRiver Riddle   if (failed(applyPartialConversion(module, target, std::move(patterns))))
109736ce915aSLei Zhang     signalPassFailure();
109836ce915aSLei Zhang }
10999c53b8e5SEugene Zhulenev 
11009c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===//
11019c53b8e5SEugene Zhulenev // Patterns for structural type conversions for the Async dialect operations.
11029c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===//
110336ce915aSLei Zhang 
1104195728c7SChristian Sigg namespace {
1105195728c7SChristian Sigg class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> {
1106195728c7SChristian Sigg public:
1107195728c7SChristian Sigg   using OpConversionPattern::OpConversionPattern;
1108195728c7SChristian Sigg   LogicalResult
1109b54c724bSRiver Riddle   matchAndRewrite(ExecuteOp op, OpAdaptor adaptor,
1110195728c7SChristian Sigg                   ConversionPatternRewriter &rewriter) const override {
1111195728c7SChristian Sigg     ExecuteOp newOp =
1112195728c7SChristian Sigg         cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
1113195728c7SChristian Sigg     rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
1114195728c7SChristian Sigg                                 newOp.getRegion().end());
1115195728c7SChristian Sigg 
1116195728c7SChristian Sigg     // Set operands and update block argument and result types.
1117b54c724bSRiver Riddle     newOp->setOperands(adaptor.getOperands());
1118195728c7SChristian Sigg     if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter)))
1119195728c7SChristian Sigg       return failure();
1120195728c7SChristian Sigg     for (auto result : newOp.getResults())
1121195728c7SChristian Sigg       result.setType(typeConverter->convertType(result.getType()));
1122195728c7SChristian Sigg 
1123195728c7SChristian Sigg     rewriter.replaceOp(op, newOp.getResults());
1124195728c7SChristian Sigg     return success();
1125195728c7SChristian Sigg   }
1126195728c7SChristian Sigg };
1127195728c7SChristian Sigg 
1128195728c7SChristian Sigg // Dummy pattern to trigger the appropriate type conversion / materialization.
1129195728c7SChristian Sigg class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> {
1130195728c7SChristian Sigg public:
1131195728c7SChristian Sigg   using OpConversionPattern::OpConversionPattern;
1132195728c7SChristian Sigg   LogicalResult
1133b54c724bSRiver Riddle   matchAndRewrite(AwaitOp op, OpAdaptor adaptor,
1134195728c7SChristian Sigg                   ConversionPatternRewriter &rewriter) const override {
1135b54c724bSRiver Riddle     rewriter.replaceOpWithNewOp<AwaitOp>(op, adaptor.getOperands().front());
1136195728c7SChristian Sigg     return success();
1137195728c7SChristian Sigg   }
1138195728c7SChristian Sigg };
1139195728c7SChristian Sigg 
1140195728c7SChristian Sigg // Dummy pattern to trigger the appropriate type conversion / materialization.
1141195728c7SChristian Sigg class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> {
1142195728c7SChristian Sigg public:
1143195728c7SChristian Sigg   using OpConversionPattern::OpConversionPattern;
1144195728c7SChristian Sigg   LogicalResult
1145b54c724bSRiver Riddle   matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
1146195728c7SChristian Sigg                   ConversionPatternRewriter &rewriter) const override {
1147b54c724bSRiver Riddle     rewriter.replaceOpWithNewOp<async::YieldOp>(op, adaptor.getOperands());
1148195728c7SChristian Sigg     return success();
1149195728c7SChristian Sigg   }
1150195728c7SChristian Sigg };
1151195728c7SChristian Sigg } // namespace
1152195728c7SChristian Sigg 
1153195728c7SChristian Sigg void mlir::populateAsyncStructuralTypeConversionsAndLegality(
1154dc4e913bSChris Lattner     TypeConverter &typeConverter, RewritePatternSet &patterns,
11553a506b31SChris Lattner     ConversionTarget &target) {
1156195728c7SChristian Sigg   typeConverter.addConversion([&](TokenType type) { return type; });
1157195728c7SChristian Sigg   typeConverter.addConversion([&](ValueType type) {
1158fd3f2518SChristian Sigg     Type converted = typeConverter.convertType(type.getValueType());
1159fd3f2518SChristian Sigg     return converted ? ValueType::get(converted) : converted;
1160195728c7SChristian Sigg   });
1161195728c7SChristian Sigg 
1162dc4e913bSChris Lattner   patterns.add<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>(
11633a506b31SChris Lattner       typeConverter, patterns.getContext());
1164195728c7SChristian Sigg 
1165195728c7SChristian Sigg   target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>(
1166195728c7SChristian Sigg       [&](Operation *op) { return typeConverter.isLegal(op); });
1167195728c7SChristian Sigg }
1168