136ce915aSLei Zhang //===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===// 236ce915aSLei Zhang // 336ce915aSLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 436ce915aSLei Zhang // See https://llvm.org/LICENSE.txt for license information. 536ce915aSLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 636ce915aSLei Zhang // 736ce915aSLei Zhang //===----------------------------------------------------------------------===// 836ce915aSLei Zhang 936ce915aSLei Zhang #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" 1036ce915aSLei Zhang 11c4769ef5SMatthias Springer #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" 125a7b9194SRiver Riddle #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" 1375e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 142ca46421SMarkus Böck #include "mlir/Conversion/LLVMCommon/Pattern.h" 1575e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 16abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 1736ce915aSLei Zhang #include "mlir/Dialect/Async/IR/Async.h" 1823aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h" 1923aa5a74SRiver Riddle #include "mlir/Dialect/Func/Transforms/FuncConversions.h" 205acd6e05SBenjamin Kramer #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 2136ce915aSLei Zhang #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 2275a3f326SChris Lattner #include "mlir/IR/ImplicitLocOpBuilder.h" 2336ce915aSLei Zhang #include "mlir/IR/TypeUtilities.h" 2436ce915aSLei Zhang #include "mlir/Pass/Pass.h" 2536ce915aSLei Zhang #include "mlir/Transforms/DialectConversion.h" 26d8c84d2aSEugene Zhulenev #include "llvm/ADT/TypeSwitch.h" 2736ce915aSLei Zhang 2867d0d7acSMichele Scuttari namespace mlir { 29cd4ca2d7SMarkus Böck #define GEN_PASS_DEF_CONVERTASYNCTOLLVMPASS 3067d0d7acSMichele Scuttari #include "mlir/Conversion/Passes.h.inc" 3167d0d7acSMichele Scuttari } // namespace mlir 3267d0d7acSMichele Scuttari 3336ce915aSLei Zhang #define DEBUG_TYPE "convert-async-to-llvm" 3436ce915aSLei Zhang 3536ce915aSLei Zhang using namespace mlir; 3636ce915aSLei Zhang using namespace mlir::async; 3736ce915aSLei Zhang 3836ce915aSLei Zhang //===----------------------------------------------------------------------===// 3936ce915aSLei Zhang // Async Runtime C API declaration. 4036ce915aSLei Zhang //===----------------------------------------------------------------------===// 4136ce915aSLei Zhang 42a86a9b5eSEugene Zhulenev static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef"; 43a86a9b5eSEugene Zhulenev static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef"; 4436ce915aSLei Zhang static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken"; 45621ad468SEugene Zhulenev static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue"; 46c30ab6c2SEugene Zhulenev static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup"; 4736ce915aSLei Zhang static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken"; 48621ad468SEugene Zhulenev static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue"; 4939957aa4SEugene Zhulenev static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError"; 5039957aa4SEugene Zhulenev static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError"; 5139957aa4SEugene Zhulenev static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError"; 5239957aa4SEugene Zhulenev static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError"; 53d8c84d2aSEugene Zhulenev static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError"; 5436ce915aSLei Zhang static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken"; 55621ad468SEugene Zhulenev static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue"; 56c30ab6c2SEugene Zhulenev static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup"; 5736ce915aSLei Zhang static constexpr const char *kExecute = "mlirAsyncRuntimeExecute"; 58621ad468SEugene Zhulenev static constexpr const char *kGetValueStorage = 59621ad468SEugene Zhulenev "mlirAsyncRuntimeGetValueStorage"; 60c30ab6c2SEugene Zhulenev static constexpr const char *kAddTokenToGroup = 61c30ab6c2SEugene Zhulenev "mlirAsyncRuntimeAddTokenToGroup"; 62621ad468SEugene Zhulenev static constexpr const char *kAwaitTokenAndExecute = 6336ce915aSLei Zhang "mlirAsyncRuntimeAwaitTokenAndExecute"; 64621ad468SEugene Zhulenev static constexpr const char *kAwaitValueAndExecute = 65621ad468SEugene Zhulenev "mlirAsyncRuntimeAwaitValueAndExecute"; 66c30ab6c2SEugene Zhulenev static constexpr const char *kAwaitAllAndExecute = 67c30ab6c2SEugene Zhulenev "mlirAsyncRuntimeAwaitAllInGroupAndExecute"; 68149311b4Sbakhtiyar static constexpr const char *kGetNumWorkerThreads = 69149311b4Sbakhtiyar "mlirAsyncRuntimGetNumWorkerThreads"; 7036ce915aSLei Zhang 7136ce915aSLei Zhang namespace { 72621ad468SEugene Zhulenev /// Async Runtime API function types. 73621ad468SEugene Zhulenev /// 74621ad468SEugene Zhulenev /// Because we can't create API function signature for type parametrized 752ca46421SMarkus Böck /// async.getValue type, we use opaque pointers (!llvm.ptr) instead. After 76621ad468SEugene Zhulenev /// lowering all async data types become opaque pointers at runtime. 7736ce915aSLei Zhang struct AsyncAPI { 782ca46421SMarkus Böck // All async types are lowered to opaque LLVM pointers at runtime. 79749f3708SChristian Ulmann static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) { 802ca46421SMarkus Böck return LLVM::LLVMPointerType::get(ctx); 81621ad468SEugene Zhulenev } 82621ad468SEugene Zhulenev 839c53b8e5SEugene Zhulenev static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) { 849c53b8e5SEugene Zhulenev return LLVM::LLVMTokenType::get(ctx); 859c53b8e5SEugene Zhulenev } 869c53b8e5SEugene Zhulenev 87749f3708SChristian Ulmann static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) { 88749f3708SChristian Ulmann auto ref = opaquePointerType(ctx); 8992db09cdSEugene Zhulenev auto count = IntegerType::get(ctx, 64); 901b97cdf8SRiver Riddle return FunctionType::get(ctx, {ref, count}, {}); 91a86a9b5eSEugene Zhulenev } 92a86a9b5eSEugene Zhulenev 9336ce915aSLei Zhang static FunctionType createTokenFunctionType(MLIRContext *ctx) { 941b97cdf8SRiver Riddle return FunctionType::get(ctx, {}, {TokenType::get(ctx)}); 9536ce915aSLei Zhang } 9636ce915aSLei Zhang 97749f3708SChristian Ulmann static FunctionType createValueFunctionType(MLIRContext *ctx) { 9892db09cdSEugene Zhulenev auto i64 = IntegerType::get(ctx, 64); 99749f3708SChristian Ulmann auto value = opaquePointerType(ctx); 10092db09cdSEugene Zhulenev return FunctionType::get(ctx, {i64}, {value}); 101621ad468SEugene Zhulenev } 102621ad468SEugene Zhulenev 103c30ab6c2SEugene Zhulenev static FunctionType createGroupFunctionType(MLIRContext *ctx) { 104d43b2360SEugene Zhulenev auto i64 = IntegerType::get(ctx, 64); 105d43b2360SEugene Zhulenev return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)}); 106c30ab6c2SEugene Zhulenev } 107c30ab6c2SEugene Zhulenev 108749f3708SChristian Ulmann static FunctionType getValueStorageFunctionType(MLIRContext *ctx) { 109749f3708SChristian Ulmann auto ptrType = opaquePointerType(ctx); 110749f3708SChristian Ulmann return FunctionType::get(ctx, {ptrType}, {ptrType}); 111621ad468SEugene Zhulenev } 112621ad468SEugene Zhulenev 11336ce915aSLei Zhang static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) { 1141b97cdf8SRiver Riddle return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 11536ce915aSLei Zhang } 11636ce915aSLei Zhang 117749f3708SChristian Ulmann static FunctionType emplaceValueFunctionType(MLIRContext *ctx) { 118749f3708SChristian Ulmann auto value = opaquePointerType(ctx); 119621ad468SEugene Zhulenev return FunctionType::get(ctx, {value}, {}); 120621ad468SEugene Zhulenev } 121621ad468SEugene Zhulenev 12239957aa4SEugene Zhulenev static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) { 12339957aa4SEugene Zhulenev return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 12439957aa4SEugene Zhulenev } 12539957aa4SEugene Zhulenev 126749f3708SChristian Ulmann static FunctionType setValueErrorFunctionType(MLIRContext *ctx) { 127749f3708SChristian Ulmann auto value = opaquePointerType(ctx); 12839957aa4SEugene Zhulenev return FunctionType::get(ctx, {value}, {}); 12939957aa4SEugene Zhulenev } 13039957aa4SEugene Zhulenev 13139957aa4SEugene Zhulenev static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) { 13239957aa4SEugene Zhulenev auto i1 = IntegerType::get(ctx, 1); 13339957aa4SEugene Zhulenev return FunctionType::get(ctx, {TokenType::get(ctx)}, {i1}); 13439957aa4SEugene Zhulenev } 13539957aa4SEugene Zhulenev 136749f3708SChristian Ulmann static FunctionType isValueErrorFunctionType(MLIRContext *ctx) { 137749f3708SChristian Ulmann auto value = opaquePointerType(ctx); 13839957aa4SEugene Zhulenev auto i1 = IntegerType::get(ctx, 1); 13939957aa4SEugene Zhulenev return FunctionType::get(ctx, {value}, {i1}); 14039957aa4SEugene Zhulenev } 14139957aa4SEugene Zhulenev 142d8c84d2aSEugene Zhulenev static FunctionType isGroupErrorFunctionType(MLIRContext *ctx) { 143d8c84d2aSEugene Zhulenev auto i1 = IntegerType::get(ctx, 1); 144d8c84d2aSEugene Zhulenev return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1}); 145d8c84d2aSEugene Zhulenev } 146d8c84d2aSEugene Zhulenev 14736ce915aSLei Zhang static FunctionType awaitTokenFunctionType(MLIRContext *ctx) { 1481b97cdf8SRiver Riddle return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); 14936ce915aSLei Zhang } 15036ce915aSLei Zhang 151749f3708SChristian Ulmann static FunctionType awaitValueFunctionType(MLIRContext *ctx) { 152749f3708SChristian Ulmann auto value = opaquePointerType(ctx); 153621ad468SEugene Zhulenev return FunctionType::get(ctx, {value}, {}); 154621ad468SEugene Zhulenev } 155621ad468SEugene Zhulenev 156c30ab6c2SEugene Zhulenev static FunctionType awaitGroupFunctionType(MLIRContext *ctx) { 1571b97cdf8SRiver Riddle return FunctionType::get(ctx, {GroupType::get(ctx)}, {}); 158c30ab6c2SEugene Zhulenev } 159c30ab6c2SEugene Zhulenev 160749f3708SChristian Ulmann static FunctionType executeFunctionType(MLIRContext *ctx) { 161749f3708SChristian Ulmann auto ptrType = opaquePointerType(ctx); 162749f3708SChristian Ulmann return FunctionType::get(ctx, {ptrType, ptrType}, {}); 16336ce915aSLei Zhang } 16436ce915aSLei Zhang 165c30ab6c2SEugene Zhulenev static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) { 1661b97cdf8SRiver Riddle auto i64 = IntegerType::get(ctx, 64); 1671b97cdf8SRiver Riddle return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)}, 1681b97cdf8SRiver Riddle {i64}); 169c30ab6c2SEugene Zhulenev } 170c30ab6c2SEugene Zhulenev 171749f3708SChristian Ulmann static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) { 172749f3708SChristian Ulmann auto ptrType = opaquePointerType(ctx); 173749f3708SChristian Ulmann return FunctionType::get(ctx, {TokenType::get(ctx), ptrType, ptrType}, {}); 17436ce915aSLei Zhang } 17536ce915aSLei Zhang 176749f3708SChristian Ulmann static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) { 177749f3708SChristian Ulmann auto ptrType = opaquePointerType(ctx); 178749f3708SChristian Ulmann return FunctionType::get(ctx, {ptrType, ptrType, ptrType}, {}); 179621ad468SEugene Zhulenev } 180621ad468SEugene Zhulenev 181749f3708SChristian Ulmann static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) { 182749f3708SChristian Ulmann auto ptrType = opaquePointerType(ctx); 183749f3708SChristian Ulmann return FunctionType::get(ctx, {GroupType::get(ctx), ptrType, ptrType}, {}); 184c30ab6c2SEugene Zhulenev } 185c30ab6c2SEugene Zhulenev 186149311b4Sbakhtiyar static FunctionType getNumWorkerThreads(MLIRContext *ctx) { 187149311b4Sbakhtiyar return FunctionType::get(ctx, {}, {IndexType::get(ctx)}); 188149311b4Sbakhtiyar } 189149311b4Sbakhtiyar 19036ce915aSLei Zhang // Auxiliary coroutine resume intrinsic wrapper. 191749f3708SChristian Ulmann static Type resumeFunctionType(MLIRContext *ctx) { 1927ed9cfc7SAlex Zinenko auto voidTy = LLVM::LLVMVoidType::get(ctx); 193749f3708SChristian Ulmann auto ptrType = opaquePointerType(ctx); 1942ca46421SMarkus Böck return LLVM::LLVMFunctionType::get(voidTy, {ptrType}, false); 19536ce915aSLei Zhang } 19636ce915aSLei Zhang }; 19736ce915aSLei Zhang } // namespace 19836ce915aSLei Zhang 199621ad468SEugene Zhulenev /// Adds Async Runtime C API declarations to the module. 200749f3708SChristian Ulmann static void addAsyncRuntimeApiDeclarations(ModuleOp module) { 201973ddb7dSMehdi Amini auto builder = 202973ddb7dSMehdi Amini ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody()); 20336ce915aSLei Zhang 204d8437552SRahul Joshi auto addFuncDecl = [&](StringRef name, FunctionType type) { 205d8437552SRahul Joshi if (module.lookupSymbol(name)) 206d8437552SRahul Joshi return; 20758ceae95SRiver Riddle builder.create<func::FuncOp>(name, type).setPrivate(); 208d8437552SRahul Joshi }; 209d8437552SRahul Joshi 21036ce915aSLei Zhang MLIRContext *ctx = module.getContext(); 211749f3708SChristian Ulmann addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 212749f3708SChristian Ulmann addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx)); 213d8437552SRahul Joshi addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx)); 214749f3708SChristian Ulmann addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx)); 215d8437552SRahul Joshi addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx)); 216d8437552SRahul Joshi addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx)); 217749f3708SChristian Ulmann addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx)); 21839957aa4SEugene Zhulenev addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx)); 219749f3708SChristian Ulmann addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx)); 22039957aa4SEugene Zhulenev addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx)); 221749f3708SChristian Ulmann addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx)); 222d8c84d2aSEugene Zhulenev addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx)); 223d8437552SRahul Joshi addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx)); 224749f3708SChristian Ulmann addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx)); 225d8437552SRahul Joshi addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx)); 226749f3708SChristian Ulmann addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx)); 227749f3708SChristian Ulmann addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx)); 228d8437552SRahul Joshi addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx)); 229749f3708SChristian Ulmann addFuncDecl(kAwaitTokenAndExecute, 230749f3708SChristian Ulmann AsyncAPI::awaitTokenAndExecuteFunctionType(ctx)); 231749f3708SChristian Ulmann addFuncDecl(kAwaitValueAndExecute, 232749f3708SChristian Ulmann AsyncAPI::awaitValueAndExecuteFunctionType(ctx)); 233749f3708SChristian Ulmann addFuncDecl(kAwaitAllAndExecute, 234749f3708SChristian Ulmann AsyncAPI::awaitAllAndExecuteFunctionType(ctx)); 235149311b4Sbakhtiyar addFuncDecl(kGetNumWorkerThreads, AsyncAPI::getNumWorkerThreads(ctx)); 23636ce915aSLei Zhang } 23736ce915aSLei Zhang 23836ce915aSLei Zhang //===----------------------------------------------------------------------===// 23936ce915aSLei Zhang // Coroutine resume function wrapper. 24036ce915aSLei Zhang //===----------------------------------------------------------------------===// 24136ce915aSLei Zhang 24236ce915aSLei Zhang static constexpr const char *kResume = "__resume"; 24336ce915aSLei Zhang 244621ad468SEugene Zhulenev /// A function that takes a coroutine handle and calls a `llvm.coro.resume` 245621ad468SEugene Zhulenev /// intrinsics. We need this function to be able to pass it to the async 246621ad468SEugene Zhulenev /// runtime execute API. 247749f3708SChristian Ulmann static void addResumeFunction(ModuleOp module) { 2485b388169SChristian Sigg if (module.lookupSymbol(kResume)) 2495b388169SChristian Sigg return; 2505b388169SChristian Sigg 25136ce915aSLei Zhang MLIRContext *ctx = module.getContext(); 252973ddb7dSMehdi Amini auto loc = module.getLoc(); 253973ddb7dSMehdi Amini auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody()); 25436ce915aSLei Zhang 2557ed9cfc7SAlex Zinenko auto voidTy = LLVM::LLVMVoidType::get(ctx); 256749f3708SChristian Ulmann Type ptrType = AsyncAPI::opaquePointerType(ctx); 25736ce915aSLei Zhang 25836ce915aSLei Zhang auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>( 2592ca46421SMarkus Böck kResume, LLVM::LLVMFunctionType::get(voidTy, {ptrType})); 260d8437552SRahul Joshi resumeOp.setPrivate(); 26136ce915aSLei Zhang 26291d5653eSMatthias Springer auto *block = resumeOp.addEntryBlock(moduleBuilder); 26375a3f326SChris Lattner auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block); 26436ce915aSLei Zhang 265d37b5393SEugene Zhulenev blockBuilder.create<LLVM::CoroResumeOp>(resumeOp.getArgument(0)); 26675a3f326SChris Lattner blockBuilder.create<LLVM::ReturnOp>(ValueRange()); 26736ce915aSLei Zhang } 26836ce915aSLei Zhang 26936ce915aSLei Zhang //===----------------------------------------------------------------------===// 27036ce915aSLei Zhang // Convert Async dialect types to LLVM types. 27136ce915aSLei Zhang //===----------------------------------------------------------------------===// 27236ce915aSLei Zhang 27336ce915aSLei Zhang namespace { 274621ad468SEugene Zhulenev /// AsyncRuntimeTypeConverter only converts types from the Async dialect to 275621ad468SEugene Zhulenev /// their runtime type (opaque pointers) and does not convert any other types. 27636ce915aSLei Zhang class AsyncRuntimeTypeConverter : public TypeConverter { 27736ce915aSLei Zhang public: 278749f3708SChristian Ulmann AsyncRuntimeTypeConverter(const LowerToLLVMOptions &options) { 279621ad468SEugene Zhulenev addConversion([](Type type) { return type; }); 280749f3708SChristian Ulmann addConversion([](Type type) { return convertAsyncTypes(type); }); 281d2a8a3afSEugene Zhulenev 282d2a8a3afSEugene Zhulenev // Use UnrealizedConversionCast as the bridge so that we don't need to pull 283d2a8a3afSEugene Zhulenev // in patterns for other dialects. 284d2a8a3afSEugene Zhulenev auto addUnrealizedCast = [](OpBuilder &builder, Type type, 285f18c3e4eSMatthias Springer ValueRange inputs, Location loc) -> Value { 286d2a8a3afSEugene Zhulenev auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs); 287f18c3e4eSMatthias Springer return cast.getResult(0); 288d2a8a3afSEugene Zhulenev }; 289d2a8a3afSEugene Zhulenev 290d2a8a3afSEugene Zhulenev addSourceMaterialization(addUnrealizedCast); 291d2a8a3afSEugene Zhulenev addTargetMaterialization(addUnrealizedCast); 292621ad468SEugene Zhulenev } 29336ce915aSLei Zhang 294749f3708SChristian Ulmann static std::optional<Type> convertAsyncTypes(Type type) { 2955550c821STres Popp if (isa<TokenType, GroupType, ValueType>(type)) 296749f3708SChristian Ulmann return AsyncAPI::opaquePointerType(type.getContext()); 2979c53b8e5SEugene Zhulenev 2985550c821STres Popp if (isa<CoroIdType, CoroStateType>(type)) 2999c53b8e5SEugene Zhulenev return AsyncAPI::tokenType(type.getContext()); 3005550c821STres Popp if (isa<CoroHandleType>(type)) 301749f3708SChristian Ulmann return AsyncAPI::opaquePointerType(type.getContext()); 3029c53b8e5SEugene Zhulenev 3031a36588eSKazu Hirata return std::nullopt; 30436ce915aSLei Zhang } 30536ce915aSLei Zhang }; 3062ca46421SMarkus Böck 3072ca46421SMarkus Böck /// Base class for conversion patterns requiring AsyncRuntimeTypeConverter 3082ca46421SMarkus Böck /// as type converter. Allows access to it via the 'getTypeConverter' 3092ca46421SMarkus Böck /// convenience method. 3102ca46421SMarkus Böck template <typename SourceOp> 3112ca46421SMarkus Böck class AsyncOpConversionPattern : public OpConversionPattern<SourceOp> { 3122ca46421SMarkus Böck 3132ca46421SMarkus Böck using Base = OpConversionPattern<SourceOp>; 3142ca46421SMarkus Böck 3152ca46421SMarkus Böck public: 316ce254598SMatthias Springer AsyncOpConversionPattern(const AsyncRuntimeTypeConverter &typeConverter, 3172ca46421SMarkus Böck MLIRContext *context) 3182ca46421SMarkus Böck : Base(typeConverter, context) {} 3192ca46421SMarkus Böck 3202ca46421SMarkus Böck /// Returns the 'AsyncRuntimeTypeConverter' of the pattern. 321ce254598SMatthias Springer const AsyncRuntimeTypeConverter *getTypeConverter() const { 322ce254598SMatthias Springer return static_cast<const AsyncRuntimeTypeConverter *>( 323ce254598SMatthias Springer Base::getTypeConverter()); 3242ca46421SMarkus Böck } 3252ca46421SMarkus Böck }; 3262ca46421SMarkus Böck 32736ce915aSLei Zhang } // namespace 32836ce915aSLei Zhang 32936ce915aSLei Zhang //===----------------------------------------------------------------------===// 3309c53b8e5SEugene Zhulenev // Convert async.coro.id to @llvm.coro.id intrinsic. 33136ce915aSLei Zhang //===----------------------------------------------------------------------===// 33236ce915aSLei Zhang 33336ce915aSLei Zhang namespace { 3342ca46421SMarkus Böck class CoroIdOpConversion : public AsyncOpConversionPattern<CoroIdOp> { 33536ce915aSLei Zhang public: 3362ca46421SMarkus Böck using AsyncOpConversionPattern::AsyncOpConversionPattern; 33736ce915aSLei Zhang 33836ce915aSLei Zhang LogicalResult 339b54c724bSRiver Riddle matchAndRewrite(CoroIdOp op, OpAdaptor adaptor, 34036ce915aSLei Zhang ConversionPatternRewriter &rewriter) const override { 3419c53b8e5SEugene Zhulenev auto token = AsyncAPI::tokenType(op->getContext()); 342749f3708SChristian Ulmann auto ptrType = AsyncAPI::opaquePointerType(op->getContext()); 3439c53b8e5SEugene Zhulenev auto loc = op->getLoc(); 3449c53b8e5SEugene Zhulenev 3459c53b8e5SEugene Zhulenev // Constants for initializing coroutine frame. 3460af643f3SJeff Niu auto constZero = 3470af643f3SJeff Niu rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 0); 34885175eddSTobias Gysi auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, ptrType); 3499c53b8e5SEugene Zhulenev 3509c53b8e5SEugene Zhulenev // Get coroutine id: @llvm.coro.id. 351d37b5393SEugene Zhulenev rewriter.replaceOpWithNewOp<LLVM::CoroIdOp>( 352d37b5393SEugene Zhulenev op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr})); 3539c53b8e5SEugene Zhulenev 35436ce915aSLei Zhang return success(); 35536ce915aSLei Zhang } 35636ce915aSLei Zhang }; 35736ce915aSLei Zhang } // namespace 35836ce915aSLei Zhang 35936ce915aSLei Zhang //===----------------------------------------------------------------------===// 3609c53b8e5SEugene Zhulenev // Convert async.coro.begin to @llvm.coro.begin intrinsic. 3619c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===// 3629c53b8e5SEugene Zhulenev 3639c53b8e5SEugene Zhulenev namespace { 3642ca46421SMarkus Böck class CoroBeginOpConversion : public AsyncOpConversionPattern<CoroBeginOp> { 3659c53b8e5SEugene Zhulenev public: 3662ca46421SMarkus Böck using AsyncOpConversionPattern::AsyncOpConversionPattern; 3679c53b8e5SEugene Zhulenev 3689c53b8e5SEugene Zhulenev LogicalResult 369b54c724bSRiver Riddle matchAndRewrite(CoroBeginOp op, OpAdaptor adaptor, 3709c53b8e5SEugene Zhulenev ConversionPatternRewriter &rewriter) const override { 371749f3708SChristian Ulmann auto ptrType = AsyncAPI::opaquePointerType(op->getContext()); 3729c53b8e5SEugene Zhulenev auto loc = op->getLoc(); 3739c53b8e5SEugene Zhulenev 3749c53b8e5SEugene Zhulenev // Get coroutine frame size: @llvm.coro.size.i64. 375964dc368SBenjamin Kramer Value coroSize = 376d37b5393SEugene Zhulenev rewriter.create<LLVM::CoroSizeOp>(loc, rewriter.getI64Type()); 377dbbe0109SChuanqi Xu // Get coroutine frame alignment: @llvm.coro.align.i64. 378dbbe0109SChuanqi Xu Value coroAlign = 379dbbe0109SChuanqi Xu rewriter.create<LLVM::CoroAlignOp>(loc, rewriter.getI64Type()); 380dbbe0109SChuanqi Xu 381dbbe0109SChuanqi Xu // Round up the size to be multiple of the alignment. Since aligned_alloc 382dbbe0109SChuanqi Xu // requires the size parameter be an integral multiple of the alignment 383dbbe0109SChuanqi Xu // parameter. 384964dc368SBenjamin Kramer auto makeConstant = [&](uint64_t c) { 3855e0c3b43SJeff Niu return rewriter.create<LLVM::ConstantOp>(op->getLoc(), 3865e0c3b43SJeff Niu rewriter.getI64Type(), c); 387964dc368SBenjamin Kramer }; 388dbbe0109SChuanqi Xu coroSize = rewriter.create<LLVM::AddOp>(op->getLoc(), coroSize, coroAlign); 389dbbe0109SChuanqi Xu coroSize = 390dbbe0109SChuanqi Xu rewriter.create<LLVM::SubOp>(op->getLoc(), coroSize, makeConstant(1)); 391f65994c9SMehdi Amini Value negCoroAlign = 392dbbe0109SChuanqi Xu rewriter.create<LLVM::SubOp>(op->getLoc(), makeConstant(0), coroAlign); 393dbbe0109SChuanqi Xu coroSize = 394f65994c9SMehdi Amini rewriter.create<LLVM::AndOp>(op->getLoc(), coroSize, negCoroAlign); 3959c53b8e5SEugene Zhulenev 3969c53b8e5SEugene Zhulenev // Allocate memory for the coroutine frame. 3975acd6e05SBenjamin Kramer auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn( 398749f3708SChristian Ulmann op->getParentOfType<ModuleOp>(), rewriter.getI64Type()); 399*e84f6b6aSLuohao Wang if (failed(allocFuncOp)) 400*e84f6b6aSLuohao Wang return failure(); 4019c53b8e5SEugene Zhulenev auto coroAlloc = rewriter.create<LLVM::CallOp>( 402*e84f6b6aSLuohao Wang loc, allocFuncOp.value(), ValueRange{coroAlign, coroSize}); 4039c53b8e5SEugene Zhulenev 4049c53b8e5SEugene Zhulenev // Begin a coroutine: @llvm.coro.begin. 405a5aa7836SRiver Riddle auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).getId(); 406d37b5393SEugene Zhulenev rewriter.replaceOpWithNewOp<LLVM::CoroBeginOp>( 4072ca46421SMarkus Böck op, ptrType, ValueRange({coroId, coroAlloc.getResult()})); 4089c53b8e5SEugene Zhulenev 4099c53b8e5SEugene Zhulenev return success(); 4109c53b8e5SEugene Zhulenev } 4119c53b8e5SEugene Zhulenev }; 4129c53b8e5SEugene Zhulenev } // namespace 4139c53b8e5SEugene Zhulenev 4149c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===// 4159c53b8e5SEugene Zhulenev // Convert async.coro.free to @llvm.coro.free intrinsic. 4169c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===// 4179c53b8e5SEugene Zhulenev 4189c53b8e5SEugene Zhulenev namespace { 4192ca46421SMarkus Böck class CoroFreeOpConversion : public AsyncOpConversionPattern<CoroFreeOp> { 4209c53b8e5SEugene Zhulenev public: 4212ca46421SMarkus Böck using AsyncOpConversionPattern::AsyncOpConversionPattern; 4229c53b8e5SEugene Zhulenev 4239c53b8e5SEugene Zhulenev LogicalResult 424b54c724bSRiver Riddle matchAndRewrite(CoroFreeOp op, OpAdaptor adaptor, 4259c53b8e5SEugene Zhulenev ConversionPatternRewriter &rewriter) const override { 426749f3708SChristian Ulmann auto ptrType = AsyncAPI::opaquePointerType(op->getContext()); 4279c53b8e5SEugene Zhulenev auto loc = op->getLoc(); 4289c53b8e5SEugene Zhulenev 4299c53b8e5SEugene Zhulenev // Get a pointer to the coroutine frame memory: @llvm.coro.free. 430b54c724bSRiver Riddle auto coroMem = 4312ca46421SMarkus Böck rewriter.create<LLVM::CoroFreeOp>(loc, ptrType, adaptor.getOperands()); 4329c53b8e5SEugene Zhulenev 4339c53b8e5SEugene Zhulenev // Free the memory. 4342ca46421SMarkus Böck auto freeFuncOp = 435749f3708SChristian Ulmann LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>()); 436*e84f6b6aSLuohao Wang if (failed(freeFuncOp)) 437*e84f6b6aSLuohao Wang return failure(); 438*e84f6b6aSLuohao Wang rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp.value(), 439d37b5393SEugene Zhulenev ValueRange(coroMem.getResult())); 4409c53b8e5SEugene Zhulenev 4419c53b8e5SEugene Zhulenev return success(); 4429c53b8e5SEugene Zhulenev } 4439c53b8e5SEugene Zhulenev }; 4449c53b8e5SEugene Zhulenev } // namespace 4459c53b8e5SEugene Zhulenev 4469c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===// 4479c53b8e5SEugene Zhulenev // Convert async.coro.end to @llvm.coro.end intrinsic. 4489c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===// 4499c53b8e5SEugene Zhulenev 4509c53b8e5SEugene Zhulenev namespace { 4519c53b8e5SEugene Zhulenev class CoroEndOpConversion : public OpConversionPattern<CoroEndOp> { 4529c53b8e5SEugene Zhulenev public: 4539c53b8e5SEugene Zhulenev using OpConversionPattern::OpConversionPattern; 4549c53b8e5SEugene Zhulenev 4559c53b8e5SEugene Zhulenev LogicalResult 456b54c724bSRiver Riddle matchAndRewrite(CoroEndOp op, OpAdaptor adaptor, 4579c53b8e5SEugene Zhulenev ConversionPatternRewriter &rewriter) const override { 4589c53b8e5SEugene Zhulenev // We are not in the block that is part of the unwind sequence. 4599c53b8e5SEugene Zhulenev auto constFalse = rewriter.create<LLVM::ConstantOp>( 4609c53b8e5SEugene Zhulenev op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false)); 46151d5d7bbSAnton Korobeynikov auto noneToken = rewriter.create<LLVM::NoneTokenOp>(op->getLoc()); 4629c53b8e5SEugene Zhulenev 4639c53b8e5SEugene Zhulenev // Mark the end of a coroutine: @llvm.coro.end. 464a5aa7836SRiver Riddle auto coroHdl = adaptor.getHandle(); 465749f3708SChristian Ulmann rewriter.create<LLVM::CoroEndOp>( 466749f3708SChristian Ulmann op->getLoc(), rewriter.getI1Type(), 46751d5d7bbSAnton Korobeynikov ValueRange({coroHdl, constFalse, noneToken})); 4689c53b8e5SEugene Zhulenev rewriter.eraseOp(op); 4699c53b8e5SEugene Zhulenev 4709c53b8e5SEugene Zhulenev return success(); 4719c53b8e5SEugene Zhulenev } 4729c53b8e5SEugene Zhulenev }; 4739c53b8e5SEugene Zhulenev } // namespace 4749c53b8e5SEugene Zhulenev 4759c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===// 4769c53b8e5SEugene Zhulenev // Convert async.coro.save to @llvm.coro.save intrinsic. 4779c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===// 4789c53b8e5SEugene Zhulenev 4799c53b8e5SEugene Zhulenev namespace { 4809c53b8e5SEugene Zhulenev class CoroSaveOpConversion : public OpConversionPattern<CoroSaveOp> { 4819c53b8e5SEugene Zhulenev public: 4829c53b8e5SEugene Zhulenev using OpConversionPattern::OpConversionPattern; 4839c53b8e5SEugene Zhulenev 4849c53b8e5SEugene Zhulenev LogicalResult 485b54c724bSRiver Riddle matchAndRewrite(CoroSaveOp op, OpAdaptor adaptor, 4869c53b8e5SEugene Zhulenev ConversionPatternRewriter &rewriter) const override { 4879c53b8e5SEugene Zhulenev // Save the coroutine state: @llvm.coro.save 488d37b5393SEugene Zhulenev rewriter.replaceOpWithNewOp<LLVM::CoroSaveOp>( 489b54c724bSRiver Riddle op, AsyncAPI::tokenType(op->getContext()), adaptor.getOperands()); 4909c53b8e5SEugene Zhulenev 4919c53b8e5SEugene Zhulenev return success(); 4929c53b8e5SEugene Zhulenev } 4939c53b8e5SEugene Zhulenev }; 4949c53b8e5SEugene Zhulenev } // namespace 4959c53b8e5SEugene Zhulenev 4969c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===// 4979c53b8e5SEugene Zhulenev // Convert async.coro.suspend to @llvm.coro.suspend intrinsic. 498a86a9b5eSEugene Zhulenev //===----------------------------------------------------------------------===// 499a86a9b5eSEugene Zhulenev 500a86a9b5eSEugene Zhulenev namespace { 501a86a9b5eSEugene Zhulenev 5029c53b8e5SEugene Zhulenev /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and 5039c53b8e5SEugene Zhulenev /// branch to the appropriate block based on the return code. 5049c53b8e5SEugene Zhulenev /// 5059c53b8e5SEugene Zhulenev /// Before: 5069c53b8e5SEugene Zhulenev /// 5079c53b8e5SEugene Zhulenev /// ^suspended: 5089c53b8e5SEugene Zhulenev /// "opBefore"(...) 5099c53b8e5SEugene Zhulenev /// async.coro.suspend %state, ^suspend, ^resume, ^cleanup 5109c53b8e5SEugene Zhulenev /// ^resume: 5119c53b8e5SEugene Zhulenev /// "op"(...) 5129c53b8e5SEugene Zhulenev /// ^cleanup: ... 5139c53b8e5SEugene Zhulenev /// ^suspend: ... 5149c53b8e5SEugene Zhulenev /// 5159c53b8e5SEugene Zhulenev /// After: 5169c53b8e5SEugene Zhulenev /// 5179c53b8e5SEugene Zhulenev /// ^suspended: 5189c53b8e5SEugene Zhulenev /// "opBefore"(...) 519d37b5393SEugene Zhulenev /// %suspend = llmv.intr.coro.suspend ... 5209c53b8e5SEugene Zhulenev /// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup] 5219c53b8e5SEugene Zhulenev /// ^resume: 5229c53b8e5SEugene Zhulenev /// "op"(...) 5239c53b8e5SEugene Zhulenev /// ^cleanup: ... 5249c53b8e5SEugene Zhulenev /// ^suspend: ... 5259c53b8e5SEugene Zhulenev /// 5269c53b8e5SEugene Zhulenev class CoroSuspendOpConversion : public OpConversionPattern<CoroSuspendOp> { 5279c53b8e5SEugene Zhulenev public: 5289c53b8e5SEugene Zhulenev using OpConversionPattern::OpConversionPattern; 5299c53b8e5SEugene Zhulenev 5309c53b8e5SEugene Zhulenev LogicalResult 531b54c724bSRiver Riddle matchAndRewrite(CoroSuspendOp op, OpAdaptor adaptor, 5329c53b8e5SEugene Zhulenev ConversionPatternRewriter &rewriter) const override { 5339c53b8e5SEugene Zhulenev auto i8 = rewriter.getIntegerType(8); 5349c53b8e5SEugene Zhulenev auto i32 = rewriter.getI32Type(); 5359c53b8e5SEugene Zhulenev auto loc = op->getLoc(); 5369c53b8e5SEugene Zhulenev 5379c53b8e5SEugene Zhulenev // This is not a final suspension point. 5389c53b8e5SEugene Zhulenev auto constFalse = rewriter.create<LLVM::ConstantOp>( 5399c53b8e5SEugene Zhulenev loc, rewriter.getI1Type(), rewriter.getBoolAttr(false)); 5409c53b8e5SEugene Zhulenev 5419c53b8e5SEugene Zhulenev // Suspend a coroutine: @llvm.coro.suspend 542a5aa7836SRiver Riddle auto coroState = adaptor.getState(); 543d37b5393SEugene Zhulenev auto coroSuspend = rewriter.create<LLVM::CoroSuspendOp>( 544d37b5393SEugene Zhulenev loc, i8, ValueRange({coroState, constFalse})); 5459c53b8e5SEugene Zhulenev 5469c53b8e5SEugene Zhulenev // Cast return code to i32. 5479c53b8e5SEugene Zhulenev 5489c53b8e5SEugene Zhulenev // After a suspension point decide if we should branch into resume, cleanup 5499c53b8e5SEugene Zhulenev // or suspend block of the coroutine (see @llvm.coro.suspend return code 5509c53b8e5SEugene Zhulenev // documentation). 5519c53b8e5SEugene Zhulenev llvm::SmallVector<int32_t, 2> caseValues = {0, 1}; 552a5aa7836SRiver Riddle llvm::SmallVector<Block *, 2> caseDest = {op.getResumeDest(), 553a5aa7836SRiver Riddle op.getCleanupDest()}; 5549c53b8e5SEugene Zhulenev rewriter.replaceOpWithNewOp<LLVM::SwitchOp>( 555d37b5393SEugene Zhulenev op, rewriter.create<LLVM::SExtOp>(loc, i32, coroSuspend.getResult()), 556a5aa7836SRiver Riddle /*defaultDestination=*/op.getSuspendDest(), 5579c53b8e5SEugene Zhulenev /*defaultOperands=*/ValueRange(), 5589c53b8e5SEugene Zhulenev /*caseValues=*/caseValues, 5599c53b8e5SEugene Zhulenev /*caseDestinations=*/caseDest, 5604e103a12SRiver Riddle /*caseOperands=*/ArrayRef<ValueRange>({ValueRange(), ValueRange()}), 5619c53b8e5SEugene Zhulenev /*branchWeights=*/ArrayRef<int32_t>()); 5629c53b8e5SEugene Zhulenev 5639c53b8e5SEugene Zhulenev return success(); 5649c53b8e5SEugene Zhulenev } 5659c53b8e5SEugene Zhulenev }; 5669c53b8e5SEugene Zhulenev } // namespace 5679c53b8e5SEugene Zhulenev 5689c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===// 5699c53b8e5SEugene Zhulenev // Convert async.runtime.create to the corresponding runtime API call. 5709c53b8e5SEugene Zhulenev // 5719c53b8e5SEugene Zhulenev // To allocate storage for the async values we use getelementptr trick: 5729c53b8e5SEugene Zhulenev // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt 5739c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===// 5749c53b8e5SEugene Zhulenev 5759c53b8e5SEugene Zhulenev namespace { 5762ca46421SMarkus Böck class RuntimeCreateOpLowering : public ConvertOpToLLVMPattern<RuntimeCreateOp> { 5779c53b8e5SEugene Zhulenev public: 5782ca46421SMarkus Böck using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; 5799c53b8e5SEugene Zhulenev 5809c53b8e5SEugene Zhulenev LogicalResult 581b54c724bSRiver Riddle matchAndRewrite(RuntimeCreateOp op, OpAdaptor adaptor, 5829c53b8e5SEugene Zhulenev ConversionPatternRewriter &rewriter) const override { 583ce254598SMatthias Springer const TypeConverter *converter = getTypeConverter(); 5849c53b8e5SEugene Zhulenev Type resultType = op->getResultTypes()[0]; 5859c53b8e5SEugene Zhulenev 586d43b2360SEugene Zhulenev // Tokens creation maps to a simple function call. 5875550c821STres Popp if (isa<TokenType>(resultType)) { 58823aa5a74SRiver Riddle rewriter.replaceOpWithNewOp<func::CallOp>( 58923aa5a74SRiver Riddle op, kCreateToken, converter->convertType(resultType)); 5909c53b8e5SEugene Zhulenev return success(); 5919c53b8e5SEugene Zhulenev } 5929c53b8e5SEugene Zhulenev 5939c53b8e5SEugene Zhulenev // To create a value we need to compute the storage requirement. 5945550c821STres Popp if (auto value = dyn_cast<ValueType>(resultType)) { 5959c53b8e5SEugene Zhulenev // Returns the size requirements for the async value storage. 5969c53b8e5SEugene Zhulenev auto sizeOf = [&](ValueType valueType) -> Value { 5979c53b8e5SEugene Zhulenev auto loc = op->getLoc(); 59892db09cdSEugene Zhulenev auto i64 = rewriter.getI64Type(); 5999c53b8e5SEugene Zhulenev 6009c53b8e5SEugene Zhulenev auto storedType = converter->convertType(valueType.getValueType()); 601749f3708SChristian Ulmann auto storagePtrType = 602749f3708SChristian Ulmann AsyncAPI::opaquePointerType(rewriter.getContext()); 6039c53b8e5SEugene Zhulenev 6049c53b8e5SEugene Zhulenev // %Size = getelementptr %T* null, int 1 60592db09cdSEugene Zhulenev // %SizeI = ptrtoint %T* %Size to i64 60685175eddSTobias Gysi auto nullPtr = rewriter.create<LLVM::ZeroOp>(loc, storagePtrType); 6072ca46421SMarkus Böck auto gep = 6082ca46421SMarkus Böck rewriter.create<LLVM::GEPOp>(loc, storagePtrType, storedType, 6092ca46421SMarkus Böck nullPtr, ArrayRef<LLVM::GEPArg>{1}); 61092db09cdSEugene Zhulenev return rewriter.create<LLVM::PtrToIntOp>(loc, i64, gep); 6119c53b8e5SEugene Zhulenev }; 6129c53b8e5SEugene Zhulenev 61323aa5a74SRiver Riddle rewriter.replaceOpWithNewOp<func::CallOp>(op, kCreateValue, resultType, 6149c53b8e5SEugene Zhulenev sizeOf(value)); 6159c53b8e5SEugene Zhulenev 6169c53b8e5SEugene Zhulenev return success(); 6179c53b8e5SEugene Zhulenev } 6189c53b8e5SEugene Zhulenev 6199c53b8e5SEugene Zhulenev return rewriter.notifyMatchFailure(op, "unsupported async type"); 6209c53b8e5SEugene Zhulenev } 6219c53b8e5SEugene Zhulenev }; 6229c53b8e5SEugene Zhulenev } // namespace 6239c53b8e5SEugene Zhulenev 6249c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===// 625d43b2360SEugene Zhulenev // Convert async.runtime.create_group to the corresponding runtime API call. 626d43b2360SEugene Zhulenev //===----------------------------------------------------------------------===// 627d43b2360SEugene Zhulenev 628d43b2360SEugene Zhulenev namespace { 629d43b2360SEugene Zhulenev class RuntimeCreateGroupOpLowering 6302ca46421SMarkus Böck : public ConvertOpToLLVMPattern<RuntimeCreateGroupOp> { 631d43b2360SEugene Zhulenev public: 6322ca46421SMarkus Böck using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; 633d43b2360SEugene Zhulenev 634d43b2360SEugene Zhulenev LogicalResult 635b54c724bSRiver Riddle matchAndRewrite(RuntimeCreateGroupOp op, OpAdaptor adaptor, 636d43b2360SEugene Zhulenev ConversionPatternRewriter &rewriter) const override { 637ce254598SMatthias Springer const TypeConverter *converter = getTypeConverter(); 63834a164c9SEugene Zhulenev Type resultType = op.getResult().getType(); 639d43b2360SEugene Zhulenev 64023aa5a74SRiver Riddle rewriter.replaceOpWithNewOp<func::CallOp>( 64123aa5a74SRiver Riddle op, kCreateGroup, converter->convertType(resultType), 642b54c724bSRiver Riddle adaptor.getOperands()); 643d43b2360SEugene Zhulenev return success(); 644d43b2360SEugene Zhulenev } 645d43b2360SEugene Zhulenev }; 646d43b2360SEugene Zhulenev } // namespace 647d43b2360SEugene Zhulenev 648d43b2360SEugene Zhulenev //===----------------------------------------------------------------------===// 6499c53b8e5SEugene Zhulenev // Convert async.runtime.set_available to the corresponding runtime API call. 6509c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===// 6519c53b8e5SEugene Zhulenev 6529c53b8e5SEugene Zhulenev namespace { 6539c53b8e5SEugene Zhulenev class RuntimeSetAvailableOpLowering 6549c53b8e5SEugene Zhulenev : public OpConversionPattern<RuntimeSetAvailableOp> { 6559c53b8e5SEugene Zhulenev public: 6569c53b8e5SEugene Zhulenev using OpConversionPattern::OpConversionPattern; 6579c53b8e5SEugene Zhulenev 6589c53b8e5SEugene Zhulenev LogicalResult 659b54c724bSRiver Riddle matchAndRewrite(RuntimeSetAvailableOp op, OpAdaptor adaptor, 6609c53b8e5SEugene Zhulenev ConversionPatternRewriter &rewriter) const override { 661d8c84d2aSEugene Zhulenev StringRef apiFuncName = 662a5aa7836SRiver Riddle TypeSwitch<Type, StringRef>(op.getOperand().getType()) 663d8c84d2aSEugene Zhulenev .Case<TokenType>([](Type) { return kEmplaceToken; }) 664d8c84d2aSEugene Zhulenev .Case<ValueType>([](Type) { return kEmplaceValue; }); 665d8c84d2aSEugene Zhulenev 66623aa5a74SRiver Riddle rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName, TypeRange(), 667b54c724bSRiver Riddle adaptor.getOperands()); 668d8c84d2aSEugene Zhulenev 6699c53b8e5SEugene Zhulenev return success(); 6709c53b8e5SEugene Zhulenev } 67139957aa4SEugene Zhulenev }; 67239957aa4SEugene Zhulenev } // namespace 6739c53b8e5SEugene Zhulenev 67439957aa4SEugene Zhulenev //===----------------------------------------------------------------------===// 67539957aa4SEugene Zhulenev // Convert async.runtime.set_error to the corresponding runtime API call. 67639957aa4SEugene Zhulenev //===----------------------------------------------------------------------===// 67739957aa4SEugene Zhulenev 67839957aa4SEugene Zhulenev namespace { 67939957aa4SEugene Zhulenev class RuntimeSetErrorOpLowering 68039957aa4SEugene Zhulenev : public OpConversionPattern<RuntimeSetErrorOp> { 68139957aa4SEugene Zhulenev public: 68239957aa4SEugene Zhulenev using OpConversionPattern::OpConversionPattern; 68339957aa4SEugene Zhulenev 68439957aa4SEugene Zhulenev LogicalResult 685b54c724bSRiver Riddle matchAndRewrite(RuntimeSetErrorOp op, OpAdaptor adaptor, 68639957aa4SEugene Zhulenev ConversionPatternRewriter &rewriter) const override { 687d8c84d2aSEugene Zhulenev StringRef apiFuncName = 688a5aa7836SRiver Riddle TypeSwitch<Type, StringRef>(op.getOperand().getType()) 689d8c84d2aSEugene Zhulenev .Case<TokenType>([](Type) { return kSetTokenError; }) 690d8c84d2aSEugene Zhulenev .Case<ValueType>([](Type) { return kSetValueError; }); 691d8c84d2aSEugene Zhulenev 69223aa5a74SRiver Riddle rewriter.replaceOpWithNewOp<func::CallOp>(op, apiFuncName, TypeRange(), 693b54c724bSRiver Riddle adaptor.getOperands()); 694d8c84d2aSEugene Zhulenev 69539957aa4SEugene Zhulenev return success(); 69639957aa4SEugene Zhulenev } 69739957aa4SEugene Zhulenev }; 69839957aa4SEugene Zhulenev } // namespace 69939957aa4SEugene Zhulenev 70039957aa4SEugene Zhulenev //===----------------------------------------------------------------------===// 70139957aa4SEugene Zhulenev // Convert async.runtime.is_error to the corresponding runtime API call. 70239957aa4SEugene Zhulenev //===----------------------------------------------------------------------===// 70339957aa4SEugene Zhulenev 70439957aa4SEugene Zhulenev namespace { 70539957aa4SEugene Zhulenev class RuntimeIsErrorOpLowering : public OpConversionPattern<RuntimeIsErrorOp> { 70639957aa4SEugene Zhulenev public: 70739957aa4SEugene Zhulenev using OpConversionPattern::OpConversionPattern; 70839957aa4SEugene Zhulenev 70939957aa4SEugene Zhulenev LogicalResult 710b54c724bSRiver Riddle matchAndRewrite(RuntimeIsErrorOp op, OpAdaptor adaptor, 71139957aa4SEugene Zhulenev ConversionPatternRewriter &rewriter) const override { 712d8c84d2aSEugene Zhulenev StringRef apiFuncName = 713a5aa7836SRiver Riddle TypeSwitch<Type, StringRef>(op.getOperand().getType()) 714d8c84d2aSEugene Zhulenev .Case<TokenType>([](Type) { return kIsTokenError; }) 715d8c84d2aSEugene Zhulenev .Case<GroupType>([](Type) { return kIsGroupError; }) 716d8c84d2aSEugene Zhulenev .Case<ValueType>([](Type) { return kIsValueError; }); 717d8c84d2aSEugene Zhulenev 71823aa5a74SRiver Riddle rewriter.replaceOpWithNewOp<func::CallOp>( 71923aa5a74SRiver Riddle op, apiFuncName, rewriter.getI1Type(), adaptor.getOperands()); 72039957aa4SEugene Zhulenev return success(); 7219c53b8e5SEugene Zhulenev } 7229c53b8e5SEugene Zhulenev }; 7239c53b8e5SEugene Zhulenev } // namespace 7249c53b8e5SEugene Zhulenev 7259c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===// 7269c53b8e5SEugene Zhulenev // Convert async.runtime.await to the corresponding runtime API call. 7279c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===// 7289c53b8e5SEugene Zhulenev 7299c53b8e5SEugene Zhulenev namespace { 7309c53b8e5SEugene Zhulenev class RuntimeAwaitOpLowering : public OpConversionPattern<RuntimeAwaitOp> { 7319c53b8e5SEugene Zhulenev public: 7329c53b8e5SEugene Zhulenev using OpConversionPattern::OpConversionPattern; 7339c53b8e5SEugene Zhulenev 7349c53b8e5SEugene Zhulenev LogicalResult 735b54c724bSRiver Riddle matchAndRewrite(RuntimeAwaitOp op, OpAdaptor adaptor, 7369c53b8e5SEugene Zhulenev ConversionPatternRewriter &rewriter) const override { 737d8c84d2aSEugene Zhulenev StringRef apiFuncName = 738a5aa7836SRiver Riddle TypeSwitch<Type, StringRef>(op.getOperand().getType()) 739d8c84d2aSEugene Zhulenev .Case<TokenType>([](Type) { return kAwaitToken; }) 740d8c84d2aSEugene Zhulenev .Case<ValueType>([](Type) { return kAwaitValue; }) 741d8c84d2aSEugene Zhulenev .Case<GroupType>([](Type) { return kAwaitGroup; }); 7429c53b8e5SEugene Zhulenev 74323aa5a74SRiver Riddle rewriter.create<func::CallOp>(op->getLoc(), apiFuncName, TypeRange(), 744b54c724bSRiver Riddle adaptor.getOperands()); 7459c53b8e5SEugene Zhulenev rewriter.eraseOp(op); 7469c53b8e5SEugene Zhulenev 7479c53b8e5SEugene Zhulenev return success(); 7489c53b8e5SEugene Zhulenev } 7499c53b8e5SEugene Zhulenev }; 7509c53b8e5SEugene Zhulenev } // namespace 7519c53b8e5SEugene Zhulenev 7529c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===// 7539c53b8e5SEugene Zhulenev // Convert async.runtime.await_and_resume to the corresponding runtime API call. 7549c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===// 7559c53b8e5SEugene Zhulenev 7569c53b8e5SEugene Zhulenev namespace { 7579c53b8e5SEugene Zhulenev class RuntimeAwaitAndResumeOpLowering 7582ca46421SMarkus Böck : public AsyncOpConversionPattern<RuntimeAwaitAndResumeOp> { 7599c53b8e5SEugene Zhulenev public: 7602ca46421SMarkus Böck using AsyncOpConversionPattern::AsyncOpConversionPattern; 7619c53b8e5SEugene Zhulenev 7629c53b8e5SEugene Zhulenev LogicalResult 763b54c724bSRiver Riddle matchAndRewrite(RuntimeAwaitAndResumeOp op, OpAdaptor adaptor, 7649c53b8e5SEugene Zhulenev ConversionPatternRewriter &rewriter) const override { 765d8c84d2aSEugene Zhulenev StringRef apiFuncName = 766a5aa7836SRiver Riddle TypeSwitch<Type, StringRef>(op.getOperand().getType()) 767d8c84d2aSEugene Zhulenev .Case<TokenType>([](Type) { return kAwaitTokenAndExecute; }) 768d8c84d2aSEugene Zhulenev .Case<ValueType>([](Type) { return kAwaitValueAndExecute; }) 769d8c84d2aSEugene Zhulenev .Case<GroupType>([](Type) { return kAwaitAllAndExecute; }); 7709c53b8e5SEugene Zhulenev 771a5aa7836SRiver Riddle Value operand = adaptor.getOperand(); 772a5aa7836SRiver Riddle Value handle = adaptor.getHandle(); 7739c53b8e5SEugene Zhulenev 7749c53b8e5SEugene Zhulenev // A pointer to coroutine resume intrinsic wrapper. 775749f3708SChristian Ulmann addResumeFunction(op->getParentOfType<ModuleOp>()); 7769c53b8e5SEugene Zhulenev auto resumePtr = rewriter.create<LLVM::AddressOfOp>( 777749f3708SChristian Ulmann op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()), 778749f3708SChristian Ulmann kResume); 7799c53b8e5SEugene Zhulenev 78023aa5a74SRiver Riddle rewriter.create<func::CallOp>( 78123aa5a74SRiver Riddle op->getLoc(), apiFuncName, TypeRange(), 782cfb72fd3SJacques Pienaar ValueRange({operand, handle, resumePtr.getRes()})); 7839c53b8e5SEugene Zhulenev rewriter.eraseOp(op); 7849c53b8e5SEugene Zhulenev 7859c53b8e5SEugene Zhulenev return success(); 7869c53b8e5SEugene Zhulenev } 7879c53b8e5SEugene Zhulenev }; 7889c53b8e5SEugene Zhulenev } // namespace 7899c53b8e5SEugene Zhulenev 7909c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===// 7919c53b8e5SEugene Zhulenev // Convert async.runtime.resume to the corresponding runtime API call. 7929c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===// 7939c53b8e5SEugene Zhulenev 7949c53b8e5SEugene Zhulenev namespace { 7952ca46421SMarkus Böck class RuntimeResumeOpLowering 7962ca46421SMarkus Böck : public AsyncOpConversionPattern<RuntimeResumeOp> { 7979c53b8e5SEugene Zhulenev public: 7982ca46421SMarkus Böck using AsyncOpConversionPattern::AsyncOpConversionPattern; 7999c53b8e5SEugene Zhulenev 8009c53b8e5SEugene Zhulenev LogicalResult 801b54c724bSRiver Riddle matchAndRewrite(RuntimeResumeOp op, OpAdaptor adaptor, 8029c53b8e5SEugene Zhulenev ConversionPatternRewriter &rewriter) const override { 8039c53b8e5SEugene Zhulenev // A pointer to coroutine resume intrinsic wrapper. 804749f3708SChristian Ulmann addResumeFunction(op->getParentOfType<ModuleOp>()); 8059c53b8e5SEugene Zhulenev auto resumePtr = rewriter.create<LLVM::AddressOfOp>( 806749f3708SChristian Ulmann op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()), 807749f3708SChristian Ulmann kResume); 8089c53b8e5SEugene Zhulenev 8099c53b8e5SEugene Zhulenev // Call async runtime API to execute a coroutine in the managed thread. 810a5aa7836SRiver Riddle auto coroHdl = adaptor.getHandle(); 81123aa5a74SRiver Riddle rewriter.replaceOpWithNewOp<func::CallOp>( 812cfb72fd3SJacques Pienaar op, TypeRange(), kExecute, ValueRange({coroHdl, resumePtr.getRes()})); 8139c53b8e5SEugene Zhulenev 8149c53b8e5SEugene Zhulenev return success(); 8159c53b8e5SEugene Zhulenev } 8169c53b8e5SEugene Zhulenev }; 8179c53b8e5SEugene Zhulenev } // namespace 8189c53b8e5SEugene Zhulenev 8199c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===// 8209c53b8e5SEugene Zhulenev // Convert async.runtime.store to the corresponding runtime API call. 8219c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===// 8229c53b8e5SEugene Zhulenev 8239c53b8e5SEugene Zhulenev namespace { 8242ca46421SMarkus Böck class RuntimeStoreOpLowering : public ConvertOpToLLVMPattern<RuntimeStoreOp> { 8259c53b8e5SEugene Zhulenev public: 8262ca46421SMarkus Böck using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; 8279c53b8e5SEugene Zhulenev 8289c53b8e5SEugene Zhulenev LogicalResult 829b54c724bSRiver Riddle matchAndRewrite(RuntimeStoreOp op, OpAdaptor adaptor, 8309c53b8e5SEugene Zhulenev ConversionPatternRewriter &rewriter) const override { 8319c53b8e5SEugene Zhulenev Location loc = op->getLoc(); 8329c53b8e5SEugene Zhulenev 8339c53b8e5SEugene Zhulenev // Get a pointer to the async value storage from the runtime. 834749f3708SChristian Ulmann auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext()); 835a5aa7836SRiver Riddle auto storage = adaptor.getStorage(); 8362ca46421SMarkus Böck auto storagePtr = rewriter.create<func::CallOp>( 8372ca46421SMarkus Böck loc, kGetValueStorage, TypeRange(ptrType), storage); 8389c53b8e5SEugene Zhulenev 8399c53b8e5SEugene Zhulenev // Cast from i8* to the LLVM pointer type. 840a5aa7836SRiver Riddle auto valueType = op.getValue().getType(); 8419c53b8e5SEugene Zhulenev auto llvmValueType = getTypeConverter()->convertType(valueType); 84225f80e16SEugene Zhulenev if (!llvmValueType) 84325f80e16SEugene Zhulenev return rewriter.notifyMatchFailure( 84425f80e16SEugene Zhulenev op, "failed to convert stored value type to LLVM type"); 84525f80e16SEugene Zhulenev 8462ca46421SMarkus Böck Value castedStoragePtr = storagePtr.getResult(0); 8479c53b8e5SEugene Zhulenev // Store the yielded value into the async value storage. 848a5aa7836SRiver Riddle auto value = adaptor.getValue(); 8492ca46421SMarkus Böck rewriter.create<LLVM::StoreOp>(loc, value, castedStoragePtr); 8509c53b8e5SEugene Zhulenev 8519c53b8e5SEugene Zhulenev // Erase the original runtime store operation. 8529c53b8e5SEugene Zhulenev rewriter.eraseOp(op); 8539c53b8e5SEugene Zhulenev 8549c53b8e5SEugene Zhulenev return success(); 8559c53b8e5SEugene Zhulenev } 8569c53b8e5SEugene Zhulenev }; 8579c53b8e5SEugene Zhulenev } // namespace 8589c53b8e5SEugene Zhulenev 8599c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===// 8609c53b8e5SEugene Zhulenev // Convert async.runtime.load to the corresponding runtime API call. 8619c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===// 8629c53b8e5SEugene Zhulenev 8639c53b8e5SEugene Zhulenev namespace { 8642ca46421SMarkus Böck class RuntimeLoadOpLowering : public ConvertOpToLLVMPattern<RuntimeLoadOp> { 8659c53b8e5SEugene Zhulenev public: 8662ca46421SMarkus Böck using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; 8679c53b8e5SEugene Zhulenev 8689c53b8e5SEugene Zhulenev LogicalResult 869b54c724bSRiver Riddle matchAndRewrite(RuntimeLoadOp op, OpAdaptor adaptor, 8709c53b8e5SEugene Zhulenev ConversionPatternRewriter &rewriter) const override { 8719c53b8e5SEugene Zhulenev Location loc = op->getLoc(); 8729c53b8e5SEugene Zhulenev 8739c53b8e5SEugene Zhulenev // Get a pointer to the async value storage from the runtime. 874749f3708SChristian Ulmann auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext()); 875a5aa7836SRiver Riddle auto storage = adaptor.getStorage(); 8762ca46421SMarkus Böck auto storagePtr = rewriter.create<func::CallOp>( 8772ca46421SMarkus Böck loc, kGetValueStorage, TypeRange(ptrType), storage); 8789c53b8e5SEugene Zhulenev 8799c53b8e5SEugene Zhulenev // Cast from i8* to the LLVM pointer type. 880a5aa7836SRiver Riddle auto valueType = op.getResult().getType(); 8819c53b8e5SEugene Zhulenev auto llvmValueType = getTypeConverter()->convertType(valueType); 88225f80e16SEugene Zhulenev if (!llvmValueType) 88325f80e16SEugene Zhulenev return rewriter.notifyMatchFailure( 88425f80e16SEugene Zhulenev op, "failed to convert loaded value type to LLVM type"); 88525f80e16SEugene Zhulenev 8862ca46421SMarkus Böck Value castedStoragePtr = storagePtr.getResult(0); 8879c53b8e5SEugene Zhulenev 8889c53b8e5SEugene Zhulenev // Load from the casted pointer. 8892ca46421SMarkus Böck rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, llvmValueType, 8902ca46421SMarkus Böck castedStoragePtr); 8919c53b8e5SEugene Zhulenev 8929c53b8e5SEugene Zhulenev return success(); 8939c53b8e5SEugene Zhulenev } 8949c53b8e5SEugene Zhulenev }; 8959c53b8e5SEugene Zhulenev } // namespace 8969c53b8e5SEugene Zhulenev 8979c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===// 8989c53b8e5SEugene Zhulenev // Convert async.runtime.add_to_group to the corresponding runtime API call. 8999c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===// 9009c53b8e5SEugene Zhulenev 9019c53b8e5SEugene Zhulenev namespace { 9029c53b8e5SEugene Zhulenev class RuntimeAddToGroupOpLowering 9039c53b8e5SEugene Zhulenev : public OpConversionPattern<RuntimeAddToGroupOp> { 9049c53b8e5SEugene Zhulenev public: 9059c53b8e5SEugene Zhulenev using OpConversionPattern::OpConversionPattern; 9069c53b8e5SEugene Zhulenev 9079c53b8e5SEugene Zhulenev LogicalResult 908b54c724bSRiver Riddle matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor, 9099c53b8e5SEugene Zhulenev ConversionPatternRewriter &rewriter) const override { 9109c53b8e5SEugene Zhulenev // Currently we can only add tokens to the group. 9115550c821STres Popp if (!isa<TokenType>(op.getOperand().getType())) 9129c53b8e5SEugene Zhulenev return rewriter.notifyMatchFailure(op, "only token type is supported"); 9139c53b8e5SEugene Zhulenev 9149c53b8e5SEugene Zhulenev // Replace with a runtime API function call. 91523aa5a74SRiver Riddle rewriter.replaceOpWithNewOp<func::CallOp>( 916b54c724bSRiver Riddle op, kAddTokenToGroup, rewriter.getI64Type(), adaptor.getOperands()); 9179c53b8e5SEugene Zhulenev 9189c53b8e5SEugene Zhulenev return success(); 9199c53b8e5SEugene Zhulenev } 9209c53b8e5SEugene Zhulenev }; 9219c53b8e5SEugene Zhulenev } // namespace 9229c53b8e5SEugene Zhulenev 9239c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===// 924149311b4Sbakhtiyar // Convert async.runtime.num_worker_threads to the corresponding runtime API 925149311b4Sbakhtiyar // call. 926149311b4Sbakhtiyar //===----------------------------------------------------------------------===// 927149311b4Sbakhtiyar 928149311b4Sbakhtiyar namespace { 929149311b4Sbakhtiyar class RuntimeNumWorkerThreadsOpLowering 930149311b4Sbakhtiyar : public OpConversionPattern<RuntimeNumWorkerThreadsOp> { 931149311b4Sbakhtiyar public: 932149311b4Sbakhtiyar using OpConversionPattern::OpConversionPattern; 933149311b4Sbakhtiyar 934149311b4Sbakhtiyar LogicalResult 935149311b4Sbakhtiyar matchAndRewrite(RuntimeNumWorkerThreadsOp op, OpAdaptor adaptor, 936149311b4Sbakhtiyar ConversionPatternRewriter &rewriter) const override { 937149311b4Sbakhtiyar 938149311b4Sbakhtiyar // Replace with a runtime API function call. 93923aa5a74SRiver Riddle rewriter.replaceOpWithNewOp<func::CallOp>(op, kGetNumWorkerThreads, 940149311b4Sbakhtiyar rewriter.getIndexType()); 941149311b4Sbakhtiyar 942149311b4Sbakhtiyar return success(); 943149311b4Sbakhtiyar } 944149311b4Sbakhtiyar }; 945149311b4Sbakhtiyar } // namespace 946149311b4Sbakhtiyar 947149311b4Sbakhtiyar //===----------------------------------------------------------------------===// 9489c53b8e5SEugene Zhulenev // Async reference counting ops lowering (`async.runtime.add_ref` and 9499c53b8e5SEugene Zhulenev // `async.runtime.drop_ref` to the corresponding API calls). 9509c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===// 9519c53b8e5SEugene Zhulenev 9529c53b8e5SEugene Zhulenev namespace { 953a86a9b5eSEugene Zhulenev template <typename RefCountingOp> 9549c53b8e5SEugene Zhulenev class RefCountingOpLowering : public OpConversionPattern<RefCountingOp> { 955a86a9b5eSEugene Zhulenev public: 956ce254598SMatthias Springer explicit RefCountingOpLowering(const TypeConverter &converter, 957ce254598SMatthias Springer MLIRContext *ctx, StringRef apiFunctionName) 9589c53b8e5SEugene Zhulenev : OpConversionPattern<RefCountingOp>(converter, ctx), 959a86a9b5eSEugene Zhulenev apiFunctionName(apiFunctionName) {} 960a86a9b5eSEugene Zhulenev 961a86a9b5eSEugene Zhulenev LogicalResult 962b54c724bSRiver Riddle matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor, 963a86a9b5eSEugene Zhulenev ConversionPatternRewriter &rewriter) const override { 964a54f4eaeSMogball auto count = rewriter.create<arith::ConstantOp>( 965a54f4eaeSMogball op->getLoc(), rewriter.getI64Type(), 966a5aa7836SRiver Riddle rewriter.getI64IntegerAttr(op.getCount())); 967a86a9b5eSEugene Zhulenev 968a5aa7836SRiver Riddle auto operand = adaptor.getOperand(); 96923aa5a74SRiver Riddle rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(), apiFunctionName, 9709c53b8e5SEugene Zhulenev ValueRange({operand, count})); 971a86a9b5eSEugene Zhulenev 972a86a9b5eSEugene Zhulenev return success(); 973a86a9b5eSEugene Zhulenev } 974a86a9b5eSEugene Zhulenev 975a86a9b5eSEugene Zhulenev private: 976a86a9b5eSEugene Zhulenev StringRef apiFunctionName; 977a86a9b5eSEugene Zhulenev }; 978a86a9b5eSEugene Zhulenev 9799c53b8e5SEugene Zhulenev class RuntimeAddRefOpLowering : public RefCountingOpLowering<RuntimeAddRefOp> { 980a86a9b5eSEugene Zhulenev public: 981ce254598SMatthias Springer explicit RuntimeAddRefOpLowering(const TypeConverter &converter, 982ce254598SMatthias Springer MLIRContext *ctx) 983621ad468SEugene Zhulenev : RefCountingOpLowering(converter, ctx, kAddRef) {} 984a86a9b5eSEugene Zhulenev }; 985a86a9b5eSEugene Zhulenev 9869c53b8e5SEugene Zhulenev class RuntimeDropRefOpLowering 9879c53b8e5SEugene Zhulenev : public RefCountingOpLowering<RuntimeDropRefOp> { 988a86a9b5eSEugene Zhulenev public: 989ce254598SMatthias Springer explicit RuntimeDropRefOpLowering(const TypeConverter &converter, 990ce254598SMatthias Springer MLIRContext *ctx) 991621ad468SEugene Zhulenev : RefCountingOpLowering(converter, ctx, kDropRef) {} 992a86a9b5eSEugene Zhulenev }; 993a86a9b5eSEugene Zhulenev } // namespace 994a86a9b5eSEugene Zhulenev 995a86a9b5eSEugene Zhulenev //===----------------------------------------------------------------------===// 9969c53b8e5SEugene Zhulenev // Convert return operations that return async values from async regions. 99736ce915aSLei Zhang //===----------------------------------------------------------------------===// 99836ce915aSLei Zhang 99936ce915aSLei Zhang namespace { 100023aa5a74SRiver Riddle class ReturnOpOpConversion : public OpConversionPattern<func::ReturnOp> { 100136ce915aSLei Zhang public: 10029c53b8e5SEugene Zhulenev using OpConversionPattern::OpConversionPattern; 100336ce915aSLei Zhang 100436ce915aSLei Zhang LogicalResult 100523aa5a74SRiver Riddle matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor, 100636ce915aSLei Zhang ConversionPatternRewriter &rewriter) const override { 100723aa5a74SRiver Riddle rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getOperands()); 1008c30ab6c2SEugene Zhulenev return success(); 1009c30ab6c2SEugene Zhulenev } 1010c30ab6c2SEugene Zhulenev }; 1011c30ab6c2SEugene Zhulenev } // namespace 1012c30ab6c2SEugene Zhulenev 1013c30ab6c2SEugene Zhulenev //===----------------------------------------------------------------------===// 101436ce915aSLei Zhang 101536ce915aSLei Zhang namespace { 101636ce915aSLei Zhang struct ConvertAsyncToLLVMPass 1017cd4ca2d7SMarkus Böck : public impl::ConvertAsyncToLLVMPassBase<ConvertAsyncToLLVMPass> { 1018cd4ca2d7SMarkus Böck using Base::Base; 1019cd4ca2d7SMarkus Böck 102036ce915aSLei Zhang void runOnOperation() override; 102136ce915aSLei Zhang }; 10229c53b8e5SEugene Zhulenev } // namespace 102336ce915aSLei Zhang 102436ce915aSLei Zhang void ConvertAsyncToLLVMPass::runOnOperation() { 102536ce915aSLei Zhang ModuleOp module = getOperation(); 102625f80e16SEugene Zhulenev MLIRContext *ctx = module->getContext(); 102736ce915aSLei Zhang 10282ca46421SMarkus Böck LowerToLLVMOptions options(ctx); 10292ca46421SMarkus Böck 10305b388169SChristian Sigg // Add declarations for most functions required by the coroutines lowering. 10315b388169SChristian Sigg // We delay adding the resume function until it's needed because it currently 10325b388169SChristian Sigg // fails to compile unless '-O0' is specified. 1033749f3708SChristian Ulmann addAsyncRuntimeApiDeclarations(module); 103436ce915aSLei Zhang 10359c53b8e5SEugene Zhulenev // Lower async.runtime and async.coro operations to Async Runtime API and 10369c53b8e5SEugene Zhulenev // LLVM coroutine intrinsics. 10379c53b8e5SEugene Zhulenev 103836ce915aSLei Zhang // Convert async dialect types and operations to LLVM dialect. 10392ca46421SMarkus Böck AsyncRuntimeTypeConverter converter(options); 1040dc4e913bSChris Lattner RewritePatternSet patterns(ctx); 104136ce915aSLei Zhang 104225f80e16SEugene Zhulenev // We use conversion to LLVM type to lower async.runtime load and store 104325f80e16SEugene Zhulenev // operations. 10442ca46421SMarkus Böck LLVMTypeConverter llvmConverter(ctx, options); 10452ca46421SMarkus Böck llvmConverter.addConversion([&](Type type) { 1046749f3708SChristian Ulmann return AsyncRuntimeTypeConverter::convertAsyncTypes(type); 10472ca46421SMarkus Böck }); 104825f80e16SEugene Zhulenev 1049621ad468SEugene Zhulenev // Convert async types in function signatures and function calls. 105058ceae95SRiver Riddle populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns, 105158ceae95SRiver Riddle converter); 10523a506b31SChris Lattner populateCallOpTypeConversionPattern(patterns, converter); 1053621ad468SEugene Zhulenev 1054621ad468SEugene Zhulenev // Convert return operations inside async.execute regions. 1055dc4e913bSChris Lattner patterns.add<ReturnOpOpConversion>(converter, ctx); 1056621ad468SEugene Zhulenev 10579c53b8e5SEugene Zhulenev // Lower async.runtime operations to the async runtime API calls. 105839957aa4SEugene Zhulenev patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering, 105939957aa4SEugene Zhulenev RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering, 10609c53b8e5SEugene Zhulenev RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering, 1061149311b4Sbakhtiyar RuntimeAddToGroupOpLowering, RuntimeNumWorkerThreadsOpLowering, 1062149311b4Sbakhtiyar RuntimeAddRefOpLowering, RuntimeDropRefOpLowering>(converter, 1063149311b4Sbakhtiyar ctx); 1064621ad468SEugene Zhulenev 10659c53b8e5SEugene Zhulenev // Lower async.runtime operations that rely on LLVM type converter to convert 10669c53b8e5SEugene Zhulenev // from async value payload type to the LLVM type. 1067d43b2360SEugene Zhulenev patterns.add<RuntimeCreateOpLowering, RuntimeCreateGroupOpLowering, 10682ca46421SMarkus Böck RuntimeStoreOpLowering, RuntimeLoadOpLowering>(llvmConverter); 10699c53b8e5SEugene Zhulenev 10709c53b8e5SEugene Zhulenev // Lower async coroutine operations to LLVM coroutine intrinsics. 1071dc4e913bSChris Lattner patterns 1072dc4e913bSChris Lattner .add<CoroIdOpConversion, CoroBeginOpConversion, CoroFreeOpConversion, 1073dc4e913bSChris Lattner CoroEndOpConversion, CoroSaveOpConversion, CoroSuspendOpConversion>( 1074dc4e913bSChris Lattner converter, ctx); 107536ce915aSLei Zhang 107636ce915aSLei Zhang ConversionTarget target(*ctx); 107723aa5a74SRiver Riddle target.addLegalOp<arith::ConstantOp, func::ConstantOp, 107823aa5a74SRiver Riddle UnrealizedConversionCastOp>(); 107936ce915aSLei Zhang target.addLegalDialect<LLVM::LLVMDialect>(); 1080621ad468SEugene Zhulenev 10819c53b8e5SEugene Zhulenev // All operations from Async dialect must be lowered to the runtime API and 10829c53b8e5SEugene Zhulenev // LLVM intrinsics calls. 108336ce915aSLei Zhang target.addIllegalDialect<AsyncDialect>(); 1084621ad468SEugene Zhulenev 1085621ad468SEugene Zhulenev // Add dynamic legality constraints to apply conversions defined above. 108658ceae95SRiver Riddle target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { 10874a3460a7SRiver Riddle return converter.isSignatureLegal(op.getFunctionType()); 10884a3460a7SRiver Riddle }); 108923aa5a74SRiver Riddle target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) { 109023aa5a74SRiver Riddle return converter.isLegal(op.getOperandTypes()); 109123aa5a74SRiver Riddle }); 109223aa5a74SRiver Riddle target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) { 1093621ad468SEugene Zhulenev return converter.isSignatureLegal(op.getCalleeType()); 1094621ad468SEugene Zhulenev }); 109536ce915aSLei Zhang 10963fffffa8SRiver Riddle if (failed(applyPartialConversion(module, target, std::move(patterns)))) 109736ce915aSLei Zhang signalPassFailure(); 109836ce915aSLei Zhang } 10999c53b8e5SEugene Zhulenev 11009c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===// 11019c53b8e5SEugene Zhulenev // Patterns for structural type conversions for the Async dialect operations. 11029c53b8e5SEugene Zhulenev //===----------------------------------------------------------------------===// 110336ce915aSLei Zhang 1104195728c7SChristian Sigg namespace { 1105195728c7SChristian Sigg class ConvertExecuteOpTypes : public OpConversionPattern<ExecuteOp> { 1106195728c7SChristian Sigg public: 1107195728c7SChristian Sigg using OpConversionPattern::OpConversionPattern; 1108195728c7SChristian Sigg LogicalResult 1109b54c724bSRiver Riddle matchAndRewrite(ExecuteOp op, OpAdaptor adaptor, 1110195728c7SChristian Sigg ConversionPatternRewriter &rewriter) const override { 1111195728c7SChristian Sigg ExecuteOp newOp = 1112195728c7SChristian Sigg cast<ExecuteOp>(rewriter.cloneWithoutRegions(*op.getOperation())); 1113195728c7SChristian Sigg rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), 1114195728c7SChristian Sigg newOp.getRegion().end()); 1115195728c7SChristian Sigg 1116195728c7SChristian Sigg // Set operands and update block argument and result types. 1117b54c724bSRiver Riddle newOp->setOperands(adaptor.getOperands()); 1118195728c7SChristian Sigg if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter))) 1119195728c7SChristian Sigg return failure(); 1120195728c7SChristian Sigg for (auto result : newOp.getResults()) 1121195728c7SChristian Sigg result.setType(typeConverter->convertType(result.getType())); 1122195728c7SChristian Sigg 1123195728c7SChristian Sigg rewriter.replaceOp(op, newOp.getResults()); 1124195728c7SChristian Sigg return success(); 1125195728c7SChristian Sigg } 1126195728c7SChristian Sigg }; 1127195728c7SChristian Sigg 1128195728c7SChristian Sigg // Dummy pattern to trigger the appropriate type conversion / materialization. 1129195728c7SChristian Sigg class ConvertAwaitOpTypes : public OpConversionPattern<AwaitOp> { 1130195728c7SChristian Sigg public: 1131195728c7SChristian Sigg using OpConversionPattern::OpConversionPattern; 1132195728c7SChristian Sigg LogicalResult 1133b54c724bSRiver Riddle matchAndRewrite(AwaitOp op, OpAdaptor adaptor, 1134195728c7SChristian Sigg ConversionPatternRewriter &rewriter) const override { 1135b54c724bSRiver Riddle rewriter.replaceOpWithNewOp<AwaitOp>(op, adaptor.getOperands().front()); 1136195728c7SChristian Sigg return success(); 1137195728c7SChristian Sigg } 1138195728c7SChristian Sigg }; 1139195728c7SChristian Sigg 1140195728c7SChristian Sigg // Dummy pattern to trigger the appropriate type conversion / materialization. 1141195728c7SChristian Sigg class ConvertYieldOpTypes : public OpConversionPattern<async::YieldOp> { 1142195728c7SChristian Sigg public: 1143195728c7SChristian Sigg using OpConversionPattern::OpConversionPattern; 1144195728c7SChristian Sigg LogicalResult 1145b54c724bSRiver Riddle matchAndRewrite(async::YieldOp op, OpAdaptor adaptor, 1146195728c7SChristian Sigg ConversionPatternRewriter &rewriter) const override { 1147b54c724bSRiver Riddle rewriter.replaceOpWithNewOp<async::YieldOp>(op, adaptor.getOperands()); 1148195728c7SChristian Sigg return success(); 1149195728c7SChristian Sigg } 1150195728c7SChristian Sigg }; 1151195728c7SChristian Sigg } // namespace 1152195728c7SChristian Sigg 1153195728c7SChristian Sigg void mlir::populateAsyncStructuralTypeConversionsAndLegality( 1154dc4e913bSChris Lattner TypeConverter &typeConverter, RewritePatternSet &patterns, 11553a506b31SChris Lattner ConversionTarget &target) { 1156195728c7SChristian Sigg typeConverter.addConversion([&](TokenType type) { return type; }); 1157195728c7SChristian Sigg typeConverter.addConversion([&](ValueType type) { 1158fd3f2518SChristian Sigg Type converted = typeConverter.convertType(type.getValueType()); 1159fd3f2518SChristian Sigg return converted ? ValueType::get(converted) : converted; 1160195728c7SChristian Sigg }); 1161195728c7SChristian Sigg 1162dc4e913bSChris Lattner patterns.add<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>( 11633a506b31SChris Lattner typeConverter, patterns.getContext()); 1164195728c7SChristian Sigg 1165195728c7SChristian Sigg target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>( 1166195728c7SChristian Sigg [&](Operation *op) { return typeConverter.isLegal(op); }); 1167195728c7SChristian Sigg } 1168