//===- AsyncToLLVM.cpp - Convert Async to LLVM dialect --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"

#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/TypeSwitch.h"

#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTASYNCTOLLVMPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

#define DEBUG_TYPE "convert-async-to-llvm"

using namespace mlir;
using namespace mlir::async;

//===----------------------------------------------------------------------===//
// Async Runtime C API declaration.
//===----------------------------------------------------------------------===// static constexpr const char *kAddRef = "mlirAsyncRuntimeAddRef"; static constexpr const char *kDropRef = "mlirAsyncRuntimeDropRef"; static constexpr const char *kCreateToken = "mlirAsyncRuntimeCreateToken"; static constexpr const char *kCreateValue = "mlirAsyncRuntimeCreateValue"; static constexpr const char *kCreateGroup = "mlirAsyncRuntimeCreateGroup"; static constexpr const char *kEmplaceToken = "mlirAsyncRuntimeEmplaceToken"; static constexpr const char *kEmplaceValue = "mlirAsyncRuntimeEmplaceValue"; static constexpr const char *kSetTokenError = "mlirAsyncRuntimeSetTokenError"; static constexpr const char *kSetValueError = "mlirAsyncRuntimeSetValueError"; static constexpr const char *kIsTokenError = "mlirAsyncRuntimeIsTokenError"; static constexpr const char *kIsValueError = "mlirAsyncRuntimeIsValueError"; static constexpr const char *kIsGroupError = "mlirAsyncRuntimeIsGroupError"; static constexpr const char *kAwaitToken = "mlirAsyncRuntimeAwaitToken"; static constexpr const char *kAwaitValue = "mlirAsyncRuntimeAwaitValue"; static constexpr const char *kAwaitGroup = "mlirAsyncRuntimeAwaitAllInGroup"; static constexpr const char *kExecute = "mlirAsyncRuntimeExecute"; static constexpr const char *kGetValueStorage = "mlirAsyncRuntimeGetValueStorage"; static constexpr const char *kAddTokenToGroup = "mlirAsyncRuntimeAddTokenToGroup"; static constexpr const char *kAwaitTokenAndExecute = "mlirAsyncRuntimeAwaitTokenAndExecute"; static constexpr const char *kAwaitValueAndExecute = "mlirAsyncRuntimeAwaitValueAndExecute"; static constexpr const char *kAwaitAllAndExecute = "mlirAsyncRuntimeAwaitAllInGroupAndExecute"; static constexpr const char *kGetNumWorkerThreads = "mlirAsyncRuntimGetNumWorkerThreads"; namespace { /// Async Runtime API function types. 
/// /// Because we can't create API function signature for type parametrized /// async.getValue type, we use opaque pointers (!llvm.ptr) instead. After /// lowering all async data types become opaque pointers at runtime. struct AsyncAPI { // All async types are lowered to opaque LLVM pointers at runtime. static LLVM::LLVMPointerType opaquePointerType(MLIRContext *ctx) { return LLVM::LLVMPointerType::get(ctx); } static LLVM::LLVMTokenType tokenType(MLIRContext *ctx) { return LLVM::LLVMTokenType::get(ctx); } static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) { auto ref = opaquePointerType(ctx); auto count = IntegerType::get(ctx, 64); return FunctionType::get(ctx, {ref, count}, {}); } static FunctionType createTokenFunctionType(MLIRContext *ctx) { return FunctionType::get(ctx, {}, {TokenType::get(ctx)}); } static FunctionType createValueFunctionType(MLIRContext *ctx) { auto i64 = IntegerType::get(ctx, 64); auto value = opaquePointerType(ctx); return FunctionType::get(ctx, {i64}, {value}); } static FunctionType createGroupFunctionType(MLIRContext *ctx) { auto i64 = IntegerType::get(ctx, 64); return FunctionType::get(ctx, {i64}, {GroupType::get(ctx)}); } static FunctionType getValueStorageFunctionType(MLIRContext *ctx) { auto ptrType = opaquePointerType(ctx); return FunctionType::get(ctx, {ptrType}, {ptrType}); } static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) { return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); } static FunctionType emplaceValueFunctionType(MLIRContext *ctx) { auto value = opaquePointerType(ctx); return FunctionType::get(ctx, {value}, {}); } static FunctionType setTokenErrorFunctionType(MLIRContext *ctx) { return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); } static FunctionType setValueErrorFunctionType(MLIRContext *ctx) { auto value = opaquePointerType(ctx); return FunctionType::get(ctx, {value}, {}); } static FunctionType isTokenErrorFunctionType(MLIRContext *ctx) { auto i1 = IntegerType::get(ctx, 1); return 
FunctionType::get(ctx, {TokenType::get(ctx)}, {i1}); } static FunctionType isValueErrorFunctionType(MLIRContext *ctx) { auto value = opaquePointerType(ctx); auto i1 = IntegerType::get(ctx, 1); return FunctionType::get(ctx, {value}, {i1}); } static FunctionType isGroupErrorFunctionType(MLIRContext *ctx) { auto i1 = IntegerType::get(ctx, 1); return FunctionType::get(ctx, {GroupType::get(ctx)}, {i1}); } static FunctionType awaitTokenFunctionType(MLIRContext *ctx) { return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); } static FunctionType awaitValueFunctionType(MLIRContext *ctx) { auto value = opaquePointerType(ctx); return FunctionType::get(ctx, {value}, {}); } static FunctionType awaitGroupFunctionType(MLIRContext *ctx) { return FunctionType::get(ctx, {GroupType::get(ctx)}, {}); } static FunctionType executeFunctionType(MLIRContext *ctx) { auto ptrType = opaquePointerType(ctx); return FunctionType::get(ctx, {ptrType, ptrType}, {}); } static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) { auto i64 = IntegerType::get(ctx, 64); return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)}, {i64}); } static FunctionType awaitTokenAndExecuteFunctionType(MLIRContext *ctx) { auto ptrType = opaquePointerType(ctx); return FunctionType::get(ctx, {TokenType::get(ctx), ptrType, ptrType}, {}); } static FunctionType awaitValueAndExecuteFunctionType(MLIRContext *ctx) { auto ptrType = opaquePointerType(ctx); return FunctionType::get(ctx, {ptrType, ptrType, ptrType}, {}); } static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) { auto ptrType = opaquePointerType(ctx); return FunctionType::get(ctx, {GroupType::get(ctx), ptrType, ptrType}, {}); } static FunctionType getNumWorkerThreads(MLIRContext *ctx) { return FunctionType::get(ctx, {}, {IndexType::get(ctx)}); } // Auxiliary coroutine resume intrinsic wrapper. 
static Type resumeFunctionType(MLIRContext *ctx) { auto voidTy = LLVM::LLVMVoidType::get(ctx); auto ptrType = opaquePointerType(ctx); return LLVM::LLVMFunctionType::get(voidTy, {ptrType}, false); } }; } // namespace /// Adds Async Runtime C API declarations to the module. static void addAsyncRuntimeApiDeclarations(ModuleOp module) { auto builder = ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody()); auto addFuncDecl = [&](StringRef name, FunctionType type) { if (module.lookupSymbol(name)) return; builder.create(name, type).setPrivate(); }; MLIRContext *ctx = module.getContext(); addFuncDecl(kAddRef, AsyncAPI::addOrDropRefFunctionType(ctx)); addFuncDecl(kDropRef, AsyncAPI::addOrDropRefFunctionType(ctx)); addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx)); addFuncDecl(kCreateValue, AsyncAPI::createValueFunctionType(ctx)); addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx)); addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx)); addFuncDecl(kEmplaceValue, AsyncAPI::emplaceValueFunctionType(ctx)); addFuncDecl(kSetTokenError, AsyncAPI::setTokenErrorFunctionType(ctx)); addFuncDecl(kSetValueError, AsyncAPI::setValueErrorFunctionType(ctx)); addFuncDecl(kIsTokenError, AsyncAPI::isTokenErrorFunctionType(ctx)); addFuncDecl(kIsValueError, AsyncAPI::isValueErrorFunctionType(ctx)); addFuncDecl(kIsGroupError, AsyncAPI::isGroupErrorFunctionType(ctx)); addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx)); addFuncDecl(kAwaitValue, AsyncAPI::awaitValueFunctionType(ctx)); addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx)); addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx)); addFuncDecl(kGetValueStorage, AsyncAPI::getValueStorageFunctionType(ctx)); addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx)); addFuncDecl(kAwaitTokenAndExecute, AsyncAPI::awaitTokenAndExecuteFunctionType(ctx)); addFuncDecl(kAwaitValueAndExecute, AsyncAPI::awaitValueAndExecuteFunctionType(ctx)); 
addFuncDecl(kAwaitAllAndExecute, AsyncAPI::awaitAllAndExecuteFunctionType(ctx)); addFuncDecl(kGetNumWorkerThreads, AsyncAPI::getNumWorkerThreads(ctx)); } //===----------------------------------------------------------------------===// // Coroutine resume function wrapper. //===----------------------------------------------------------------------===// static constexpr const char *kResume = "__resume"; /// A function that takes a coroutine handle and calls a `llvm.coro.resume` /// intrinsics. We need this function to be able to pass it to the async /// runtime execute API. static void addResumeFunction(ModuleOp module) { if (module.lookupSymbol(kResume)) return; MLIRContext *ctx = module.getContext(); auto loc = module.getLoc(); auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody()); auto voidTy = LLVM::LLVMVoidType::get(ctx); Type ptrType = AsyncAPI::opaquePointerType(ctx); auto resumeOp = moduleBuilder.create( kResume, LLVM::LLVMFunctionType::get(voidTy, {ptrType})); resumeOp.setPrivate(); auto *block = resumeOp.addEntryBlock(moduleBuilder); auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block); blockBuilder.create(resumeOp.getArgument(0)); blockBuilder.create(ValueRange()); } //===----------------------------------------------------------------------===// // Convert Async dialect types to LLVM types. //===----------------------------------------------------------------------===// namespace { /// AsyncRuntimeTypeConverter only converts types from the Async dialect to /// their runtime type (opaque pointers) and does not convert any other types. class AsyncRuntimeTypeConverter : public TypeConverter { public: AsyncRuntimeTypeConverter(const LowerToLLVMOptions &options) { addConversion([](Type type) { return type; }); addConversion([](Type type) { return convertAsyncTypes(type); }); // Use UnrealizedConversionCast as the bridge so that we don't need to pull // in patterns for other dialects. 
auto addUnrealizedCast = [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) -> Value { auto cast = builder.create(loc, type, inputs); return cast.getResult(0); }; addSourceMaterialization(addUnrealizedCast); addTargetMaterialization(addUnrealizedCast); } static std::optional convertAsyncTypes(Type type) { if (isa(type)) return AsyncAPI::opaquePointerType(type.getContext()); if (isa(type)) return AsyncAPI::tokenType(type.getContext()); if (isa(type)) return AsyncAPI::opaquePointerType(type.getContext()); return std::nullopt; } }; /// Base class for conversion patterns requiring AsyncRuntimeTypeConverter /// as type converter. Allows access to it via the 'getTypeConverter' /// convenience method. template class AsyncOpConversionPattern : public OpConversionPattern { using Base = OpConversionPattern; public: AsyncOpConversionPattern(const AsyncRuntimeTypeConverter &typeConverter, MLIRContext *context) : Base(typeConverter, context) {} /// Returns the 'AsyncRuntimeTypeConverter' of the pattern. const AsyncRuntimeTypeConverter *getTypeConverter() const { return static_cast( Base::getTypeConverter()); } }; } // namespace //===----------------------------------------------------------------------===// // Convert async.coro.id to @llvm.coro.id intrinsic. //===----------------------------------------------------------------------===// namespace { class CoroIdOpConversion : public AsyncOpConversionPattern { public: using AsyncOpConversionPattern::AsyncOpConversionPattern; LogicalResult matchAndRewrite(CoroIdOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto token = AsyncAPI::tokenType(op->getContext()); auto ptrType = AsyncAPI::opaquePointerType(op->getContext()); auto loc = op->getLoc(); // Constants for initializing coroutine frame. auto constZero = rewriter.create(loc, rewriter.getI32Type(), 0); auto nullPtr = rewriter.create(loc, ptrType); // Get coroutine id: @llvm.coro.id. 
rewriter.replaceOpWithNewOp( op, token, ValueRange({constZero, nullPtr, nullPtr, nullPtr})); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // Convert async.coro.begin to @llvm.coro.begin intrinsic. //===----------------------------------------------------------------------===// namespace { class CoroBeginOpConversion : public AsyncOpConversionPattern { public: using AsyncOpConversionPattern::AsyncOpConversionPattern; LogicalResult matchAndRewrite(CoroBeginOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto ptrType = AsyncAPI::opaquePointerType(op->getContext()); auto loc = op->getLoc(); // Get coroutine frame size: @llvm.coro.size.i64. Value coroSize = rewriter.create(loc, rewriter.getI64Type()); // Get coroutine frame alignment: @llvm.coro.align.i64. Value coroAlign = rewriter.create(loc, rewriter.getI64Type()); // Round up the size to be multiple of the alignment. Since aligned_alloc // requires the size parameter be an integral multiple of the alignment // parameter. auto makeConstant = [&](uint64_t c) { return rewriter.create(op->getLoc(), rewriter.getI64Type(), c); }; coroSize = rewriter.create(op->getLoc(), coroSize, coroAlign); coroSize = rewriter.create(op->getLoc(), coroSize, makeConstant(1)); Value negCoroAlign = rewriter.create(op->getLoc(), makeConstant(0), coroAlign); coroSize = rewriter.create(op->getLoc(), coroSize, negCoroAlign); // Allocate memory for the coroutine frame. auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn( op->getParentOfType(), rewriter.getI64Type()); if (failed(allocFuncOp)) return failure(); auto coroAlloc = rewriter.create( loc, allocFuncOp.value(), ValueRange{coroAlign, coroSize}); // Begin a coroutine: @llvm.coro.begin. 
auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).getId(); rewriter.replaceOpWithNewOp( op, ptrType, ValueRange({coroId, coroAlloc.getResult()})); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // Convert async.coro.free to @llvm.coro.free intrinsic. //===----------------------------------------------------------------------===// namespace { class CoroFreeOpConversion : public AsyncOpConversionPattern { public: using AsyncOpConversionPattern::AsyncOpConversionPattern; LogicalResult matchAndRewrite(CoroFreeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto ptrType = AsyncAPI::opaquePointerType(op->getContext()); auto loc = op->getLoc(); // Get a pointer to the coroutine frame memory: @llvm.coro.free. auto coroMem = rewriter.create(loc, ptrType, adaptor.getOperands()); // Free the memory. auto freeFuncOp = LLVM::lookupOrCreateFreeFn(op->getParentOfType()); if (failed(freeFuncOp)) return failure(); rewriter.replaceOpWithNewOp(op, freeFuncOp.value(), ValueRange(coroMem.getResult())); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // Convert async.coro.end to @llvm.coro.end intrinsic. //===----------------------------------------------------------------------===// namespace { class CoroEndOpConversion : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(CoroEndOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // We are not in the block that is part of the unwind sequence. auto constFalse = rewriter.create( op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false)); auto noneToken = rewriter.create(op->getLoc()); // Mark the end of a coroutine: @llvm.coro.end. 
auto coroHdl = adaptor.getHandle(); rewriter.create( op->getLoc(), rewriter.getI1Type(), ValueRange({coroHdl, constFalse, noneToken})); rewriter.eraseOp(op); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // Convert async.coro.save to @llvm.coro.save intrinsic. //===----------------------------------------------------------------------===// namespace { class CoroSaveOpConversion : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(CoroSaveOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Save the coroutine state: @llvm.coro.save rewriter.replaceOpWithNewOp( op, AsyncAPI::tokenType(op->getContext()), adaptor.getOperands()); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // Convert async.coro.suspend to @llvm.coro.suspend intrinsic. //===----------------------------------------------------------------------===// namespace { /// Convert async.coro.suspend to the @llvm.coro.suspend intrinsic call, and /// branch to the appropriate block based on the return code. /// /// Before: /// /// ^suspended: /// "opBefore"(...) /// async.coro.suspend %state, ^suspend, ^resume, ^cleanup /// ^resume: /// "op"(...) /// ^cleanup: ... /// ^suspend: ... /// /// After: /// /// ^suspended: /// "opBefore"(...) /// %suspend = llmv.intr.coro.suspend ... /// switch %suspend [-1: ^suspend, 0: ^resume, 1: ^cleanup] /// ^resume: /// "op"(...) /// ^cleanup: ... /// ^suspend: ... /// class CoroSuspendOpConversion : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(CoroSuspendOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto i8 = rewriter.getIntegerType(8); auto i32 = rewriter.getI32Type(); auto loc = op->getLoc(); // This is not a final suspension point. 
auto constFalse = rewriter.create( loc, rewriter.getI1Type(), rewriter.getBoolAttr(false)); // Suspend a coroutine: @llvm.coro.suspend auto coroState = adaptor.getState(); auto coroSuspend = rewriter.create( loc, i8, ValueRange({coroState, constFalse})); // Cast return code to i32. // After a suspension point decide if we should branch into resume, cleanup // or suspend block of the coroutine (see @llvm.coro.suspend return code // documentation). llvm::SmallVector caseValues = {0, 1}; llvm::SmallVector caseDest = {op.getResumeDest(), op.getCleanupDest()}; rewriter.replaceOpWithNewOp( op, rewriter.create(loc, i32, coroSuspend.getResult()), /*defaultDestination=*/op.getSuspendDest(), /*defaultOperands=*/ValueRange(), /*caseValues=*/caseValues, /*caseDestinations=*/caseDest, /*caseOperands=*/ArrayRef({ValueRange(), ValueRange()}), /*branchWeights=*/ArrayRef()); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // Convert async.runtime.create to the corresponding runtime API call. // // To allocate storage for the async values we use getelementptr trick: // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt //===----------------------------------------------------------------------===// namespace { class RuntimeCreateOpLowering : public ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(RuntimeCreateOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { const TypeConverter *converter = getTypeConverter(); Type resultType = op->getResultTypes()[0]; // Tokens creation maps to a simple function call. if (isa(resultType)) { rewriter.replaceOpWithNewOp( op, kCreateToken, converter->convertType(resultType)); return success(); } // To create a value we need to compute the storage requirement. if (auto value = dyn_cast(resultType)) { // Returns the size requirements for the async value storage. 
auto sizeOf = [&](ValueType valueType) -> Value { auto loc = op->getLoc(); auto i64 = rewriter.getI64Type(); auto storedType = converter->convertType(valueType.getValueType()); auto storagePtrType = AsyncAPI::opaquePointerType(rewriter.getContext()); // %Size = getelementptr %T* null, int 1 // %SizeI = ptrtoint %T* %Size to i64 auto nullPtr = rewriter.create(loc, storagePtrType); auto gep = rewriter.create(loc, storagePtrType, storedType, nullPtr, ArrayRef{1}); return rewriter.create(loc, i64, gep); }; rewriter.replaceOpWithNewOp(op, kCreateValue, resultType, sizeOf(value)); return success(); } return rewriter.notifyMatchFailure(op, "unsupported async type"); } }; } // namespace //===----------------------------------------------------------------------===// // Convert async.runtime.create_group to the corresponding runtime API call. //===----------------------------------------------------------------------===// namespace { class RuntimeCreateGroupOpLowering : public ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(RuntimeCreateGroupOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { const TypeConverter *converter = getTypeConverter(); Type resultType = op.getResult().getType(); rewriter.replaceOpWithNewOp( op, kCreateGroup, converter->convertType(resultType), adaptor.getOperands()); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // Convert async.runtime.set_available to the corresponding runtime API call. 
//===----------------------------------------------------------------------===// namespace { class RuntimeSetAvailableOpLowering : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(RuntimeSetAvailableOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { StringRef apiFuncName = TypeSwitch(op.getOperand().getType()) .Case([](Type) { return kEmplaceToken; }) .Case([](Type) { return kEmplaceValue; }); rewriter.replaceOpWithNewOp(op, apiFuncName, TypeRange(), adaptor.getOperands()); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // Convert async.runtime.set_error to the corresponding runtime API call. //===----------------------------------------------------------------------===// namespace { class RuntimeSetErrorOpLowering : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(RuntimeSetErrorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { StringRef apiFuncName = TypeSwitch(op.getOperand().getType()) .Case([](Type) { return kSetTokenError; }) .Case([](Type) { return kSetValueError; }); rewriter.replaceOpWithNewOp(op, apiFuncName, TypeRange(), adaptor.getOperands()); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // Convert async.runtime.is_error to the corresponding runtime API call. 
//===----------------------------------------------------------------------===// namespace { class RuntimeIsErrorOpLowering : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(RuntimeIsErrorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { StringRef apiFuncName = TypeSwitch(op.getOperand().getType()) .Case([](Type) { return kIsTokenError; }) .Case([](Type) { return kIsGroupError; }) .Case([](Type) { return kIsValueError; }); rewriter.replaceOpWithNewOp( op, apiFuncName, rewriter.getI1Type(), adaptor.getOperands()); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // Convert async.runtime.await to the corresponding runtime API call. //===----------------------------------------------------------------------===// namespace { class RuntimeAwaitOpLowering : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(RuntimeAwaitOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { StringRef apiFuncName = TypeSwitch(op.getOperand().getType()) .Case([](Type) { return kAwaitToken; }) .Case([](Type) { return kAwaitValue; }) .Case([](Type) { return kAwaitGroup; }); rewriter.create(op->getLoc(), apiFuncName, TypeRange(), adaptor.getOperands()); rewriter.eraseOp(op); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // Convert async.runtime.await_and_resume to the corresponding runtime API call. 
//===----------------------------------------------------------------------===// namespace { class RuntimeAwaitAndResumeOpLowering : public AsyncOpConversionPattern { public: using AsyncOpConversionPattern::AsyncOpConversionPattern; LogicalResult matchAndRewrite(RuntimeAwaitAndResumeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { StringRef apiFuncName = TypeSwitch(op.getOperand().getType()) .Case([](Type) { return kAwaitTokenAndExecute; }) .Case([](Type) { return kAwaitValueAndExecute; }) .Case([](Type) { return kAwaitAllAndExecute; }); Value operand = adaptor.getOperand(); Value handle = adaptor.getHandle(); // A pointer to coroutine resume intrinsic wrapper. addResumeFunction(op->getParentOfType()); auto resumePtr = rewriter.create( op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()), kResume); rewriter.create( op->getLoc(), apiFuncName, TypeRange(), ValueRange({operand, handle, resumePtr.getRes()})); rewriter.eraseOp(op); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // Convert async.runtime.resume to the corresponding runtime API call. //===----------------------------------------------------------------------===// namespace { class RuntimeResumeOpLowering : public AsyncOpConversionPattern { public: using AsyncOpConversionPattern::AsyncOpConversionPattern; LogicalResult matchAndRewrite(RuntimeResumeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // A pointer to coroutine resume intrinsic wrapper. addResumeFunction(op->getParentOfType()); auto resumePtr = rewriter.create( op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()), kResume); // Call async runtime API to execute a coroutine in the managed thread. 
auto coroHdl = adaptor.getHandle(); rewriter.replaceOpWithNewOp( op, TypeRange(), kExecute, ValueRange({coroHdl, resumePtr.getRes()})); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // Convert async.runtime.store to the corresponding runtime API call. //===----------------------------------------------------------------------===// namespace { class RuntimeStoreOpLowering : public ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(RuntimeStoreOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); // Get a pointer to the async value storage from the runtime. auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext()); auto storage = adaptor.getStorage(); auto storagePtr = rewriter.create( loc, kGetValueStorage, TypeRange(ptrType), storage); // Cast from i8* to the LLVM pointer type. auto valueType = op.getValue().getType(); auto llvmValueType = getTypeConverter()->convertType(valueType); if (!llvmValueType) return rewriter.notifyMatchFailure( op, "failed to convert stored value type to LLVM type"); Value castedStoragePtr = storagePtr.getResult(0); // Store the yielded value into the async value storage. auto value = adaptor.getValue(); rewriter.create(loc, value, castedStoragePtr); // Erase the original runtime store operation. rewriter.eraseOp(op); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // Convert async.runtime.load to the corresponding runtime API call. 
//===----------------------------------------------------------------------===// namespace { class RuntimeLoadOpLowering : public ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(RuntimeLoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); // Get a pointer to the async value storage from the runtime. auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext()); auto storage = adaptor.getStorage(); auto storagePtr = rewriter.create( loc, kGetValueStorage, TypeRange(ptrType), storage); // Cast from i8* to the LLVM pointer type. auto valueType = op.getResult().getType(); auto llvmValueType = getTypeConverter()->convertType(valueType); if (!llvmValueType) return rewriter.notifyMatchFailure( op, "failed to convert loaded value type to LLVM type"); Value castedStoragePtr = storagePtr.getResult(0); // Load from the casted pointer. rewriter.replaceOpWithNewOp(op, llvmValueType, castedStoragePtr); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // Convert async.runtime.add_to_group to the corresponding runtime API call. //===----------------------------------------------------------------------===// namespace { class RuntimeAddToGroupOpLowering : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Currently we can only add tokens to the group. if (!isa(op.getOperand().getType())) return rewriter.notifyMatchFailure(op, "only token type is supported"); // Replace with a runtime API function call. 
rewriter.replaceOpWithNewOp( op, kAddTokenToGroup, rewriter.getI64Type(), adaptor.getOperands()); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // Convert async.runtime.num_worker_threads to the corresponding runtime API // call. //===----------------------------------------------------------------------===// namespace { class RuntimeNumWorkerThreadsOpLowering : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(RuntimeNumWorkerThreadsOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Replace with a runtime API function call. rewriter.replaceOpWithNewOp(op, kGetNumWorkerThreads, rewriter.getIndexType()); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // Async reference counting ops lowering (`async.runtime.add_ref` and // `async.runtime.drop_ref` to the corresponding API calls). 
//===----------------------------------------------------------------------===// namespace { template class RefCountingOpLowering : public OpConversionPattern { public: explicit RefCountingOpLowering(const TypeConverter &converter, MLIRContext *ctx, StringRef apiFunctionName) : OpConversionPattern(converter, ctx), apiFunctionName(apiFunctionName) {} LogicalResult matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto count = rewriter.create( op->getLoc(), rewriter.getI64Type(), rewriter.getI64IntegerAttr(op.getCount())); auto operand = adaptor.getOperand(); rewriter.replaceOpWithNewOp(op, TypeRange(), apiFunctionName, ValueRange({operand, count})); return success(); } private: StringRef apiFunctionName; }; class RuntimeAddRefOpLowering : public RefCountingOpLowering { public: explicit RuntimeAddRefOpLowering(const TypeConverter &converter, MLIRContext *ctx) : RefCountingOpLowering(converter, ctx, kAddRef) {} }; class RuntimeDropRefOpLowering : public RefCountingOpLowering { public: explicit RuntimeDropRefOpLowering(const TypeConverter &converter, MLIRContext *ctx) : RefCountingOpLowering(converter, ctx, kDropRef) {} }; } // namespace //===----------------------------------------------------------------------===// // Convert return operations that return async values from async regions. 
//===----------------------------------------------------------------------===//

namespace {
/// Re-creates a `func::ReturnOp` using the type-converted operands
/// (`adaptor.getOperands()`).
///
/// `runOnOperation` marks a return as legal only when the converter accepts
/// its operand types. This pattern therefore fires for returns whose operand
/// types still need conversion.
class ReturnOpOpConversion : public OpConversionPattern {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp(op, adaptor.getOperands());
    return success();
  }
};
} // namespace

//===----------------------------------------------------------------------===//

namespace {
/// Module pass that lowers the Async dialect to Async Runtime API calls and
/// LLVM coroutine intrinsics. The conversion setup is in runOnOperation.
struct ConvertAsyncToLLVMPass
    : public impl::ConvertAsyncToLLVMPassBase {
  using Base::Base;

  void runOnOperation() override;
};
} // namespace

/// Performs the Async -> LLVM lowering as a partial dialect conversion over
/// the module.
///
/// Steps:
///   1. Adds runtime API declarations to the module
///      (addAsyncRuntimeApiDeclarations).
///   2. Builds two type converters:
///      - `converter` (AsyncRuntimeTypeConverter) for signatures, calls,
///        returns and most runtime/coroutine ops;
///      - `llvmConverter` (LLVMTypeConverter plus async type conversion) for
///        the patterns that need LLVM types for async value payloads.
///   3. Marks the Async dialect illegal, with dynamic legality for
///      func/return/call ops.
///   4. Runs applyPartialConversion and signals pass failure if it fails.
void ConvertAsyncToLLVMPass::runOnOperation() {
  ModuleOp module = getOperation();
  MLIRContext *ctx = module->getContext();

  LowerToLLVMOptions options(ctx);

  // Add declarations for most functions required by the coroutines lowering.
  // We delay adding the resume function until it's needed because it currently
  // fails to compile unless '-O0' is specified.
  addAsyncRuntimeApiDeclarations(module);

  // Lower async.runtime and async.coro operations to Async Runtime API and
  // LLVM coroutine intrinsics.

  // Convert async dialect types and operations to LLVM dialect.
  AsyncRuntimeTypeConverter converter(options);
  RewritePatternSet patterns(ctx);

  // We use conversion to LLVM type to lower async.runtime load and store
  // operations.
  LLVMTypeConverter llvmConverter(ctx, options);
  // Let the LLVM type converter also convert async types, using the same
  // rule as AsyncRuntimeTypeConverter.
  llvmConverter.addConversion([&](Type type) {
    return AsyncRuntimeTypeConverter::convertAsyncTypes(type);
  });

  // Convert async types in function signatures and function calls.
  populateFunctionOpInterfaceTypeConversionPattern(patterns,
                                                   converter);
  populateCallOpTypeConversionPattern(patterns, converter);

  // Convert return operations inside async.execute regions.
  patterns.add(converter, ctx);

  // Lower async.runtime operations to the async runtime API calls.
  patterns.add(converter, ctx);

  // Lower async.runtime operations that rely on LLVM type converter to convert
  // from async value payload type to the LLVM type.
  patterns.add(llvmConverter);

  // Lower async coroutine operations to LLVM coroutine intrinsics.
  patterns
      .add(
          converter, ctx);

  ConversionTarget target(*ctx);
  // Ops and dialects registered here are left as-is by the conversion.
  target.addLegalOp();
  target.addLegalDialect();

  // All operations from Async dialect must be lowered to the runtime API and
  // LLVM intrinsics calls.
  target.addIllegalDialect();

  // Add dynamic legality constraints to apply conversions defined above.
  // A function is legal once its signature contains no types the converter
  // would rewrite.
  target.addDynamicallyLegalOp([&](func::FuncOp op) {
    return converter.isSignatureLegal(op.getFunctionType());
  });
  // A return is legal once all operand types are legal for the converter.
  target.addDynamicallyLegalOp([&](func::ReturnOp op) {
    return converter.isLegal(op.getOperandTypes());
  });
  // A call is legal once the callee's function type is legal.
  target.addDynamicallyLegalOp([&](func::CallOp op) {
    return converter.isSignatureLegal(op.getCalleeType());
  });

  // Partial conversion:
  //   - ops not marked illegal may remain in the IR;
  //   - failing to legalize an illegal op (e.g. any remaining Async op) makes
  //     the conversion, and hence the pass, fail.
  if (failed(applyPartialConversion(module, target, std::move(patterns))))
    signalPassFailure();
}

//===----------------------------------------------------------------------===//
// Patterns for structural type conversions for the Async dialect operations.
//===----------------------------------------------------------------------===//

namespace {
/// Re-creates an `ExecuteOp` with type-converted operands, region
/// block-argument types and result types.
class ConvertExecuteOpTypes : public OpConversionPattern {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExecuteOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Clone the op without its body, then move (not copy) the original body
    // into the clone.
    ExecuteOp newOp =
        cast(rewriter.cloneWithoutRegions(*op.getOperation()));
    rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
                                newOp.getRegion().end());

    // Set operands and update block argument and result types.
    newOp->setOperands(adaptor.getOperands());
    if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter)))
      return failure();
    // The clone still carries the original result types, so convert them in
    // place.
    // NOTE(review): the result of convertType is not checked. A result type
    // the converter cannot handle would be set to a null Type. Confirm this
    // cannot happen for execute results.
    for (auto result : newOp.getResults())
      result.setType(typeConverter->convertType(result.getType()));

    rewriter.replaceOp(op, newOp.getResults());
    return success();
  }
};

// Dummy pattern to trigger the appropriate type conversion / materialization.
// Re-creates the op from its first converted operand only; assumes the op has
// at least one operand (`front()` is called unconditionally).
class ConvertAwaitOpTypes : public OpConversionPattern {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AwaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp(op, adaptor.getOperands().front());
    return success();
  }
};

// Dummy pattern to trigger the appropriate type conversion / materialization.
// Re-creates the yield with all of its converted operands.
class ConvertYieldOpTypes : public OpConversionPattern {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp(op, adaptor.getOperands());
    return success();
  }
};
} // namespace

/// Registers structural type conversions for the Async dialect on the given
/// type converter, pattern set and conversion target.
///
///   - Token types are converted to themselves (kept unchanged).
///   - Value types are rebuilt around the converted payload type. If the
///     payload conversion yields a null Type, that null Type is returned.
///   - The structural patterns above are added.
///   - The ops registered via addDynamicallyLegalOp are legal exactly when
///     `typeConverter.isLegal(op)` holds.
///
/// NOTE(review): the lambdas capture `typeConverter` by reference, so the
/// converter must outlive every use of these conversions and of `target`.
void mlir::populateAsyncStructuralTypeConversionsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target) {
  typeConverter.addConversion([&](TokenType type) { return type; });
  typeConverter.addConversion([&](ValueType type) {
    Type converted = typeConverter.convertType(type.getValueType());
    return converted ? ValueType::get(converted) : converted;
  });

  patterns.add(
      typeConverter, patterns.getContext());

  target.addDynamicallyLegalOp(
      [&](Operation *op) { return typeConverter.isLegal(op); });
}