//===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements lowering from high level async operations to async.coro // and async.runtime operations. // //===----------------------------------------------------------------------===// #include #include "mlir/Dialect/Async/Passes.h" #include "PassDetail.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Async/IR/Async.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/SetVector.h" #include "llvm/Support/Debug.h" #include namespace mlir { #define GEN_PASS_DEF_ASYNCTOASYNCRUNTIME #define GEN_PASS_DEF_ASYNCFUNCTOASYNCRUNTIME #include "mlir/Dialect/Async/Passes.h.inc" } // namespace mlir using namespace mlir; using namespace mlir::async; #define DEBUG_TYPE "async-to-async-runtime" // Prefix for functions outlined from `async.execute` op regions. static constexpr const char kAsyncFnPrefix[] = "async_execute_fn"; namespace { class AsyncToAsyncRuntimePass : public impl::AsyncToAsyncRuntimeBase { public: AsyncToAsyncRuntimePass() = default; void runOnOperation() override; }; } // namespace namespace { class AsyncFuncToAsyncRuntimePass : public impl::AsyncFuncToAsyncRuntimeBase { public: AsyncFuncToAsyncRuntimePass() = default; void runOnOperation() override; }; } // namespace /// Function targeted for coroutine transformation has two additional blocks at /// the end: coroutine cleanup and coroutine suspension. /// /// async.await op lowering additionaly creates a resume block for each /// operation to enable non-blocking waiting via coroutine suspension. namespace { struct CoroMachinery { func::FuncOp func; // Async function returns an optional token, followed by some async values // // async.func @foo() -> !async.value { // %cst = arith.constant 42.0 : T // return %cst: T // } // Async execute region returns a completion token, and an async value for // each yielded value. // // %token, %result = async.execute -> !async.value { // %0 = arith.constant ... : T // async.yield %0 : T // } std::optional asyncToken; // returned completion token llvm::SmallVector returnValues; // returned async values Value coroHandle; // coroutine handle (!async.coro.getHandle value) Block *entry; // coroutine entry block std::optional setError; // set returned values to error state Block *cleanup; // coroutine cleanup block // Coroutine cleanup block for destroy after the coroutine is resumed, // e.g. async.coro.suspend state, [suspend], [resume], [destroy] // // This cleanup block is a duplicate of the cleanup block followed by the // resume block. The purpose of having a duplicate cleanup block for destroy // is to make the CFG clear so that the control flow analysis won't confuse. // // The overall structure of the lowered CFG can be the following, // // Entry (calling async.coro.suspend) // | \ // Resume Destroy (duplicate of Cleanup) // | | // Cleanup | // | / // End (ends the corontine) // // If there is resume-specific cleanup logic, it can go into the Cleanup // block but not the destroy block. Otherwise, it can fail block dominance // check. Block *cleanupForDestroy; Block *suspend; // coroutine suspension block }; } // namespace using FuncCoroMapPtr = std::shared_ptr>; /// Utility to partially update the regular function CFG to the coroutine CFG /// compatible with LLVM coroutines switched-resume lowering using /// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block /// that branches into preexisting entry block. Also inserts trailing blocks. /// /// The result types of the passed `func` start with an optional `async.token` /// and be continued with some number of `async.value`s. /// /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html /// /// - `entry` block sets up the coroutine. /// - `set_error` block sets completion token and async values state to error. /// - `cleanup` block cleans up the coroutine state. /// - `suspend block after the @llvm.coro.end() defines what value will be /// returned to the initial caller of a coroutine. Everything before the /// @llvm.coro.end() will be executed at every suspension point. /// /// Coroutine structure (only the important bits): /// /// func @some_fn() -> (!async.token, !async.value) /// { /// ^entry(): /// %token = : !async.token // create async runtime token /// %value = : !async.value // create async value /// %id = async.coro.getId // create a coroutine id /// %hdl = async.coro.begin %id // create a coroutine handle /// cf.br ^preexisting_entry_block /// /// /* preexisting blocks modified to branch to the cleanup block */ /// /// ^set_error: // this block created lazily only if needed (see code below) /// async.runtime.set_error %token : !async.token /// async.runtime.set_error %value : !async.value /// cf.br ^cleanup /// /// ^cleanup: /// async.coro.free %hdl // delete the coroutine state /// cf.br ^suspend /// /// ^suspend: /// async.coro.end %hdl // marks the end of a coroutine /// return %token, %value : !async.token, !async.value /// } /// static CoroMachinery setupCoroMachinery(func::FuncOp func) { assert(!func.getBlocks().empty() && "Function must have an entry block"); MLIRContext *ctx = func.getContext(); Block *entryBlock = &func.getBlocks().front(); Block *originalEntryBlock = entryBlock->splitBlock(entryBlock->getOperations().begin()); auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock); // ------------------------------------------------------------------------ // // Allocate async token/values that we will return from a ramp function. // ------------------------------------------------------------------------ // // We treat TokenType as state update marker to represent side-effects of // async computations bool isStateful = isa(func.getResultTypes().front()); std::optional retToken; if (isStateful) retToken.emplace(builder.create(TokenType::get(ctx))); llvm::SmallVector retValues; ArrayRef resValueTypes = isStateful ? func.getResultTypes().drop_front() : func.getResultTypes(); for (auto resType : resValueTypes) retValues.emplace_back( builder.create(resType).getResult()); // ------------------------------------------------------------------------ // // Initialize coroutine: get coroutine id and coroutine handle. // ------------------------------------------------------------------------ // auto coroIdOp = builder.create(CoroIdType::get(ctx)); auto coroHdlOp = builder.create(CoroHandleType::get(ctx), coroIdOp.getId()); builder.create(originalEntryBlock); Block *cleanupBlock = func.addBlock(); Block *cleanupBlockForDestroy = func.addBlock(); Block *suspendBlock = func.addBlock(); // ------------------------------------------------------------------------ // // Coroutine cleanup blocks: deallocate coroutine frame, free the memory. // ------------------------------------------------------------------------ // auto buildCleanupBlock = [&](Block *cb) { builder.setInsertionPointToStart(cb); builder.create(coroIdOp.getId(), coroHdlOp.getHandle()); // Branch into the suspend block. builder.create(suspendBlock); }; buildCleanupBlock(cleanupBlock); buildCleanupBlock(cleanupBlockForDestroy); // ------------------------------------------------------------------------ // // Coroutine suspend block: mark the end of a coroutine and return allocated // async token. // ------------------------------------------------------------------------ // builder.setInsertionPointToStart(suspendBlock); // Mark the end of a coroutine: async.coro.end builder.create(coroHdlOp.getHandle()); // Return created optional `async.token` and `async.values` from the suspend // block. This will be the return value of a coroutine ramp function. SmallVector ret; if (retToken) ret.push_back(*retToken); ret.insert(ret.end(), retValues.begin(), retValues.end()); builder.create(ret); // `async.await` op lowering will create resume blocks for async // continuations, and will conditionally branch to cleanup or suspend blocks. // The switch-resumed API based coroutine should be marked with // presplitcoroutine attribute to mark the function as a coroutine. func->setAttr("passthrough", builder.getArrayAttr( StringAttr::get(ctx, "presplitcoroutine"))); CoroMachinery machinery; machinery.func = func; machinery.asyncToken = retToken; machinery.returnValues = retValues; machinery.coroHandle = coroHdlOp.getHandle(); machinery.entry = entryBlock; machinery.setError = std::nullopt; // created lazily only if needed machinery.cleanup = cleanupBlock; machinery.cleanupForDestroy = cleanupBlockForDestroy; machinery.suspend = suspendBlock; return machinery; } // Lazily creates `set_error` block only if it is required for lowering to the // runtime operations (see for example lowering of assert operation). static Block *setupSetErrorBlock(CoroMachinery &coro) { if (coro.setError) return *coro.setError; coro.setError = coro.func.addBlock(); (*coro.setError)->moveBefore(coro.cleanup); auto builder = ImplicitLocOpBuilder::atBlockBegin(coro.func->getLoc(), *coro.setError); // Coroutine set_error block: set error on token and all returned values. if (coro.asyncToken) builder.create(*coro.asyncToken); for (Value retValue : coro.returnValues) builder.create(retValue); // Branch into the cleanup block. builder.create(coro.cleanup); return *coro.setError; } //===----------------------------------------------------------------------===// // async.execute op outlining to the coroutine functions. //===----------------------------------------------------------------------===// /// Outline the body region attached to the `async.execute` op into a standalone /// function. /// /// Note that this is not reversible transformation. static std::pair outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) { ModuleOp module = execute->getParentOfType(); MLIRContext *ctx = module.getContext(); Location loc = execute.getLoc(); // Make sure that all constants will be inside the outlined async function to // reduce the number of function arguments. cloneConstantsIntoTheRegion(execute.getBodyRegion()); // Collect all outlined function inputs. SetVector functionInputs(execute.getDependencies().begin(), execute.getDependencies().end()); functionInputs.insert(execute.getBodyOperands().begin(), execute.getBodyOperands().end()); getUsedValuesDefinedAbove(execute.getBodyRegion(), functionInputs); // Collect types for the outlined function inputs and outputs. auto typesRange = llvm::map_range( functionInputs, [](Value value) { return value.getType(); }); SmallVector inputTypes(typesRange.begin(), typesRange.end()); auto outputTypes = execute.getResultTypes(); auto funcType = FunctionType::get(ctx, inputTypes, outputTypes); auto funcAttrs = ArrayRef(); // TODO: Derive outlined function name from the parent FuncOp (support // multiple nested async.execute operations). func::FuncOp func = func::FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs); symbolTable.insert(func); SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private); auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, func.addEntryBlock()); // Prepare for coroutine conversion by creating the body of the function. { size_t numDependencies = execute.getDependencies().size(); size_t numOperands = execute.getBodyOperands().size(); // Await on all dependencies before starting to execute the body region. for (size_t i = 0; i < numDependencies; ++i) builder.create(func.getArgument(i)); // Await on all async value operands and unwrap the payload. SmallVector unwrappedOperands(numOperands); for (size_t i = 0; i < numOperands; ++i) { Value operand = func.getArgument(numDependencies + i); unwrappedOperands[i] = builder.create(loc, operand).getResult(); } // Map from function inputs defined above the execute op to the function // arguments. IRMapping valueMapping; valueMapping.map(functionInputs, func.getArguments()); valueMapping.map(execute.getBodyRegion().getArguments(), unwrappedOperands); // Clone all operations from the execute operation body into the outlined // function body. for (Operation &op : execute.getBodyRegion().getOps()) builder.clone(op, valueMapping); } // Adding entry/cleanup/suspend blocks. CoroMachinery coro = setupCoroMachinery(func); // Suspend async function at the end of an entry block, and resume it using // Async resume operation (execution will be resumed in a thread managed by // the async runtime). { cf::BranchOp branch = cast(coro.entry->getTerminator()); builder.setInsertionPointToEnd(coro.entry); // Save the coroutine state: async.coro.save auto coroSaveOp = builder.create(CoroStateType::get(ctx), coro.coroHandle); // Pass coroutine to the runtime to be resumed on a runtime managed // thread. builder.create(coro.coroHandle); // Add async.coro.suspend as a suspended block terminator. builder.create(coroSaveOp.getState(), coro.suspend, branch.getDest(), coro.cleanupForDestroy); branch.erase(); } // Replace the original `async.execute` with a call to outlined function. { ImplicitLocOpBuilder callBuilder(loc, execute); auto callOutlinedFunc = callBuilder.create( func.getName(), execute.getResultTypes(), functionInputs.getArrayRef()); execute.replaceAllUsesWith(callOutlinedFunc.getResults()); execute.erase(); } return {func, coro}; } //===----------------------------------------------------------------------===// // Convert async.create_group operation to async.runtime.create_group //===----------------------------------------------------------------------===// namespace { class CreateGroupOpLowering : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(CreateGroupOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( op, GroupType::get(op->getContext()), adaptor.getOperands()); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // Convert async.add_to_group operation to async.runtime.add_to_group. //===----------------------------------------------------------------------===// namespace { class AddToGroupOpLowering : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(AddToGroupOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( op, rewriter.getIndexType(), adaptor.getOperands()); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // Convert async.func, async.return and async.call operations to non-blocking // operations based on llvm coroutine //===----------------------------------------------------------------------===// namespace { //===----------------------------------------------------------------------===// // Convert async.func operation to func.func //===----------------------------------------------------------------------===// class AsyncFuncOpLowering : public OpConversionPattern { public: AsyncFuncOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros) : OpConversionPattern(ctx), coros(std::move(coros)) {} LogicalResult matchAndRewrite(async::FuncOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); auto newFuncOp = rewriter.create(loc, op.getName(), op.getFunctionType()); SymbolTable::setSymbolVisibility(newFuncOp, SymbolTable::getSymbolVisibility(op)); // Copy over all attributes other than the name. for (const auto &namedAttr : op->getAttrs()) { if (namedAttr.getName() != SymbolTable::getSymbolAttrName()) newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue()); } rewriter.inlineRegionBefore(op.getBody(), newFuncOp.getBody(), newFuncOp.end()); CoroMachinery coro = setupCoroMachinery(newFuncOp); (*coros)[newFuncOp] = coro; // no initial suspend, we should hot-start rewriter.eraseOp(op); return success(); } private: FuncCoroMapPtr coros; }; //===----------------------------------------------------------------------===// // Convert async.call operation to func.call //===----------------------------------------------------------------------===// class AsyncCallOpLowering : public OpConversionPattern { public: AsyncCallOpLowering(MLIRContext *ctx) : OpConversionPattern(ctx) {} LogicalResult matchAndRewrite(async::CallOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( op, op.getCallee(), op.getResultTypes(), op.getOperands()); return success(); } }; //===----------------------------------------------------------------------===// // Convert async.return operation to async.runtime operations. //===----------------------------------------------------------------------===// class AsyncReturnOpLowering : public OpConversionPattern { public: AsyncReturnOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros) : OpConversionPattern(ctx), coros(std::move(coros)) {} LogicalResult matchAndRewrite(async::ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto func = op->template getParentOfType(); auto funcCoro = coros->find(func); if (funcCoro == coros->end()) return rewriter.notifyMatchFailure( op, "operation is not inside the async coroutine function"); Location loc = op->getLoc(); const CoroMachinery &coro = funcCoro->getSecond(); rewriter.setInsertionPointAfter(op); // Store return values into the async values storage and switch async // values state to available. for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) { Value returnValue = std::get<0>(tuple); Value asyncValue = std::get<1>(tuple); rewriter.create(loc, returnValue, asyncValue); rewriter.create(loc, asyncValue); } if (coro.asyncToken) // Switch the coroutine completion token to available state. rewriter.create(loc, *coro.asyncToken); rewriter.eraseOp(op); rewriter.create(loc, coro.cleanup); return success(); } private: FuncCoroMapPtr coros; }; } // namespace //===----------------------------------------------------------------------===// // Convert async.await and async.await_all operations to the async.runtime.await // or async.runtime.await_and_resume operations. //===----------------------------------------------------------------------===// namespace { template class AwaitOpLoweringBase : public OpConversionPattern { using AwaitAdaptor = typename AwaitType::Adaptor; public: AwaitOpLoweringBase(MLIRContext *ctx, FuncCoroMapPtr coros, bool shouldLowerBlockingWait) : OpConversionPattern(ctx), coros(std::move(coros)), shouldLowerBlockingWait(shouldLowerBlockingWait) {} LogicalResult matchAndRewrite(AwaitType op, typename AwaitType::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { // We can only await on one the `AwaitableType` (for `await` it can be // a `token` or a `value`, for `await_all` it must be a `group`). if (!isa(op.getOperand().getType())) return rewriter.notifyMatchFailure(op, "unsupported awaitable type"); // Check if await operation is inside the coroutine function. auto func = op->template getParentOfType(); auto funcCoro = coros->find(func); const bool isInCoroutine = funcCoro != coros->end(); Location loc = op->getLoc(); Value operand = adaptor.getOperand(); Type i1 = rewriter.getI1Type(); // Delay lowering to block wait in case await op is inside async.execute if (!isInCoroutine && !shouldLowerBlockingWait) return failure(); // Inside regular functions we use the blocking wait operation to wait for // the async object (token, value or group) to become available. if (!isInCoroutine) { ImplicitLocOpBuilder builder(loc, rewriter); builder.create(loc, operand); // Assert that the awaited operands is not in the error state. Value isError = builder.create(i1, operand); Value notError = builder.create( isError, builder.create( loc, i1, builder.getIntegerAttr(i1, 1))); builder.create(notError, "Awaited async operand is in error state"); } // Inside the coroutine we convert await operation into coroutine suspension // point, and resume execution asynchronously. if (isInCoroutine) { CoroMachinery &coro = funcCoro->getSecond(); Block *suspended = op->getBlock(); ImplicitLocOpBuilder builder(loc, rewriter); MLIRContext *ctx = op->getContext(); // Save the coroutine state and resume on a runtime managed thread when // the operand becomes available. auto coroSaveOp = builder.create(CoroStateType::get(ctx), coro.coroHandle); builder.create(operand, coro.coroHandle); // Split the entry block before the await operation. Block *resume = rewriter.splitBlock(suspended, Block::iterator(op)); // Add async.coro.suspend as a suspended block terminator. builder.setInsertionPointToEnd(suspended); builder.create(coroSaveOp.getState(), coro.suspend, resume, coro.cleanupForDestroy); // Split the resume block into error checking and continuation. Block *continuation = rewriter.splitBlock(resume, Block::iterator(op)); // Check if the awaited value is in the error state. builder.setInsertionPointToStart(resume); auto isError = builder.create(loc, i1, operand); builder.create(isError, /*trueDest=*/setupSetErrorBlock(coro), /*trueArgs=*/ArrayRef(), /*falseDest=*/continuation, /*falseArgs=*/ArrayRef()); // Make sure that replacement value will be constructed in the // continuation block. rewriter.setInsertionPointToStart(continuation); } // Erase or replace the await operation with the new value. if (Value replaceWith = getReplacementValue(op, operand, rewriter)) rewriter.replaceOp(op, replaceWith); else rewriter.eraseOp(op); return success(); } virtual Value getReplacementValue(AwaitType op, Value operand, ConversionPatternRewriter &rewriter) const { return Value(); } private: FuncCoroMapPtr coros; bool shouldLowerBlockingWait; }; /// Lowering for `async.await` with a token operand. class AwaitTokenOpLowering : public AwaitOpLoweringBase { using Base = AwaitOpLoweringBase; public: using Base::Base; }; /// Lowering for `async.await` with a value operand. class AwaitValueOpLowering : public AwaitOpLoweringBase { using Base = AwaitOpLoweringBase; public: using Base::Base; Value getReplacementValue(AwaitOp op, Value operand, ConversionPatternRewriter &rewriter) const override { // Load from the async value storage. auto valueType = cast(operand.getType()).getValueType(); return rewriter.create(op->getLoc(), valueType, operand); } }; /// Lowering for `async.await_all` operation. class AwaitAllOpLowering : public AwaitOpLoweringBase { using Base = AwaitOpLoweringBase; public: using Base::Base; }; } // namespace //===----------------------------------------------------------------------===// // Convert async.yield operation to async.runtime operations. //===----------------------------------------------------------------------===// class YieldOpLowering : public OpConversionPattern { public: YieldOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros) : OpConversionPattern(ctx), coros(std::move(coros)) {} LogicalResult matchAndRewrite(async::YieldOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Check if yield operation is inside the async coroutine function. auto func = op->template getParentOfType(); auto funcCoro = coros->find(func); if (funcCoro == coros->end()) return rewriter.notifyMatchFailure( op, "operation is not inside the async coroutine function"); Location loc = op->getLoc(); const CoroMachinery &coro = funcCoro->getSecond(); // Store yielded values into the async values storage and switch async // values state to available. for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) { Value yieldValue = std::get<0>(tuple); Value asyncValue = std::get<1>(tuple); rewriter.create(loc, yieldValue, asyncValue); rewriter.create(loc, asyncValue); } if (coro.asyncToken) // Switch the coroutine completion token to available state. rewriter.create(loc, *coro.asyncToken); rewriter.eraseOp(op); rewriter.create(loc, coro.cleanup); return success(); } private: FuncCoroMapPtr coros; }; //===----------------------------------------------------------------------===// // Convert cf.assert operation to cf.cond_br into `set_error` block. //===----------------------------------------------------------------------===// class AssertOpLowering : public OpConversionPattern { public: AssertOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros) : OpConversionPattern(ctx), coros(std::move(coros)) {} LogicalResult matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Check if assert operation is inside the async coroutine function. auto func = op->template getParentOfType(); auto funcCoro = coros->find(func); if (funcCoro == coros->end()) return rewriter.notifyMatchFailure( op, "operation is not inside the async coroutine function"); Location loc = op->getLoc(); CoroMachinery &coro = funcCoro->getSecond(); Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op)); rewriter.setInsertionPointToEnd(cont->getPrevNode()); rewriter.create(loc, adaptor.getArg(), /*trueDest=*/cont, /*trueArgs=*/ArrayRef(), /*falseDest=*/setupSetErrorBlock(coro), /*falseArgs=*/ArrayRef()); rewriter.eraseOp(op); return success(); } private: FuncCoroMapPtr coros; }; //===----------------------------------------------------------------------===// void AsyncToAsyncRuntimePass::runOnOperation() { ModuleOp module = getOperation(); SymbolTable symbolTable(module); // Functions with coroutine CFG setups, which are results of outlining // `async.execute` body regions FuncCoroMapPtr coros = std::make_shared>(); module.walk([&](ExecuteOp execute) { coros->insert(outlineExecuteOp(symbolTable, execute)); }); LLVM_DEBUG({ llvm::dbgs() << "Outlined " << coros->size() << " functions built from async.execute operations\n"; }); // Returns true if operation is inside the coroutine. auto isInCoroutine = [&](Operation *op) -> bool { auto parentFunc = op->getParentOfType(); return coros->find(parentFunc) != coros->end(); }; // Lower async operations to async.runtime operations. MLIRContext *ctx = module->getContext(); RewritePatternSet asyncPatterns(ctx); // Conversion to async runtime augments original CFG with the coroutine CFG, // and we have to make sure that structured control flow operations with async // operations in nested regions will be converted to branch-based control flow // before we add the coroutine basic blocks. populateSCFToControlFlowConversionPatterns(asyncPatterns); // Async lowering does not use type converter because it must preserve all // types for async.runtime operations. asyncPatterns.add(ctx); asyncPatterns .add( ctx, coros, /*should_lower_blocking_wait=*/true); // Lower assertions to conditional branches into error blocks. asyncPatterns.add(ctx, coros); // All high level async operations must be lowered to the runtime operations. ConversionTarget runtimeTarget(*ctx); runtimeTarget.addLegalDialect(); runtimeTarget.addIllegalOp(); runtimeTarget.addIllegalOp(); // Decide if structured control flow has to be lowered to branch-based CFG. runtimeTarget.addDynamicallyLegalDialect([&](Operation *op) { auto walkResult = op->walk([&](Operation *nested) { bool isAsync = isa(nested->getDialect()); return isAsync && isInCoroutine(nested) ? WalkResult::interrupt() : WalkResult::advance(); }); return !walkResult.wasInterrupted(); }); runtimeTarget.addLegalOp(); // Assertions must be converted to runtime errors inside async functions. runtimeTarget.addDynamicallyLegalOp( [&](cf::AssertOp op) -> bool { auto func = op->getParentOfType(); return !coros->contains(func); }); if (failed(applyPartialConversion(module, runtimeTarget, std::move(asyncPatterns)))) { signalPassFailure(); return; } } //===----------------------------------------------------------------------===// void mlir::populateAsyncFuncToAsyncRuntimeConversionPatterns( RewritePatternSet &patterns, ConversionTarget &target) { // Functions with coroutine CFG setups, which are results of converting // async.func. FuncCoroMapPtr coros = std::make_shared>(); MLIRContext *ctx = patterns.getContext(); // Lower async.func to func.func with coroutine cfg. patterns.add(ctx); patterns.add(ctx, coros); patterns.add( ctx, coros, /*should_lower_blocking_wait=*/false); patterns.add(ctx, coros); target.addDynamicallyLegalOp( [coros](Operation *op) { auto exec = op->getParentOfType(); auto func = op->getParentOfType(); return exec || !coros->contains(func); }); } void AsyncFuncToAsyncRuntimePass::runOnOperation() { ModuleOp module = getOperation(); // Lower async operations to async.runtime operations. MLIRContext *ctx = module->getContext(); RewritePatternSet asyncPatterns(ctx); ConversionTarget runtimeTarget(*ctx); // Lower async.func to func.func with coroutine cfg. populateAsyncFuncToAsyncRuntimeConversionPatterns(asyncPatterns, runtimeTarget); runtimeTarget.addLegalDialect(); runtimeTarget.addIllegalOp(); runtimeTarget.addLegalOp(); if (failed(applyPartialConversion(module, runtimeTarget, std::move(asyncPatterns)))) { signalPassFailure(); return; } } std::unique_ptr> mlir::createAsyncToAsyncRuntimePass() { return std::make_unique(); } std::unique_ptr> mlir::createAsyncFuncToAsyncRuntimePass() { return std::make_unique(); }