125f80e16SEugene Zhulenev //===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===//
225f80e16SEugene Zhulenev //
325f80e16SEugene Zhulenev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
425f80e16SEugene Zhulenev // See https://llvm.org/LICENSE.txt for license information.
525f80e16SEugene Zhulenev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
625f80e16SEugene Zhulenev //
725f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
825f80e16SEugene Zhulenev //
925f80e16SEugene Zhulenev // This file implements lowering from high level async operations to async.coro
1025f80e16SEugene Zhulenev // and async.runtime operations.
1125f80e16SEugene Zhulenev //
1225f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
1325f80e16SEugene Zhulenev
14190cdf44SMehdi Amini #include <utility>
15190cdf44SMehdi Amini
1667d0d7acSMichele Scuttari #include "mlir/Dialect/Async/Passes.h"
1767d0d7acSMichele Scuttari
1825f80e16SEugene Zhulenev #include "PassDetail.h"
19ace01605SRiver Riddle #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
20abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
2125f80e16SEugene Zhulenev #include "mlir/Dialect/Async/IR/Async.h"
22ace01605SRiver Riddle #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
2323aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
248b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
254d67b278SJeff Niu #include "mlir/IR/IRMapping.h"
2625f80e16SEugene Zhulenev #include "mlir/IR/ImplicitLocOpBuilder.h"
2725f80e16SEugene Zhulenev #include "mlir/IR/PatternMatch.h"
2825f80e16SEugene Zhulenev #include "mlir/Transforms/DialectConversion.h"
2925f80e16SEugene Zhulenev #include "mlir/Transforms/RegionUtils.h"
3025f80e16SEugene Zhulenev #include "llvm/ADT/SetVector.h"
31297a5b7cSNico Weber #include "llvm/Support/Debug.h"
32f3b7b300SKazu Hirata #include <optional>
3325f80e16SEugene Zhulenev
3467d0d7acSMichele Scuttari namespace mlir {
3567d0d7acSMichele Scuttari #define GEN_PASS_DEF_ASYNCTOASYNCRUNTIME
366cca6b9aSyijiagu #define GEN_PASS_DEF_ASYNCFUNCTOASYNCRUNTIME
3767d0d7acSMichele Scuttari #include "mlir/Dialect/Async/Passes.h.inc"
3867d0d7acSMichele Scuttari } // namespace mlir
3967d0d7acSMichele Scuttari
4025f80e16SEugene Zhulenev using namespace mlir;
4125f80e16SEugene Zhulenev using namespace mlir::async;
4225f80e16SEugene Zhulenev
4325f80e16SEugene Zhulenev #define DEBUG_TYPE "async-to-async-runtime"
/// Base symbol name given to every function outlined from an `async.execute`
/// region. Name clashes between several outlined functions are resolved when
/// the function is inserted into the module symbol table (see
/// outlineExecuteOp, which reads the final name back after insertion).
static constexpr char kAsyncFnPrefix[] = "async_execute_fn";
4625f80e16SEugene Zhulenev
4725f80e16SEugene Zhulenev namespace {
4825f80e16SEugene Zhulenev
4925f80e16SEugene Zhulenev class AsyncToAsyncRuntimePass
5067d0d7acSMichele Scuttari : public impl::AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> {
5125f80e16SEugene Zhulenev public:
5225f80e16SEugene Zhulenev AsyncToAsyncRuntimePass() = default;
5325f80e16SEugene Zhulenev void runOnOperation() override;
5425f80e16SEugene Zhulenev };
5525f80e16SEugene Zhulenev
5625f80e16SEugene Zhulenev } // namespace
5725f80e16SEugene Zhulenev
586cca6b9aSyijiagu namespace {
596cca6b9aSyijiagu
606cca6b9aSyijiagu class AsyncFuncToAsyncRuntimePass
616cca6b9aSyijiagu : public impl::AsyncFuncToAsyncRuntimeBase<AsyncFuncToAsyncRuntimePass> {
626cca6b9aSyijiagu public:
636cca6b9aSyijiagu AsyncFuncToAsyncRuntimePass() = default;
646cca6b9aSyijiagu void runOnOperation() override;
656cca6b9aSyijiagu };
666cca6b9aSyijiagu
676cca6b9aSyijiagu } // namespace
686cca6b9aSyijiagu
6925f80e16SEugene Zhulenev /// Function targeted for coroutine transformation has two additional blocks at
7025f80e16SEugene Zhulenev /// the end: coroutine cleanup and coroutine suspension.
7125f80e16SEugene Zhulenev ///
7225f80e16SEugene Zhulenev /// async.await op lowering additionaly creates a resume block for each
7325f80e16SEugene Zhulenev /// operation to enable non-blocking waiting via coroutine suspension.
7425f80e16SEugene Zhulenev namespace {
7525f80e16SEugene Zhulenev struct CoroMachinery {
7658ceae95SRiver Riddle func::FuncOp func;
7739957aa4SEugene Zhulenev
78f81f8808Syijiagu // Async function returns an optional token, followed by some async values
79f81f8808Syijiagu //
80f81f8808Syijiagu // async.func @foo() -> !async.value<T> {
81f81f8808Syijiagu // %cst = arith.constant 42.0 : T
82f81f8808Syijiagu // return %cst: T
83f81f8808Syijiagu // }
8425f80e16SEugene Zhulenev // Async execute region returns a completion token, and an async value for
8525f80e16SEugene Zhulenev // each yielded value.
8625f80e16SEugene Zhulenev //
8725f80e16SEugene Zhulenev // %token, %result = async.execute -> !async.value<T> {
88cb3aa49eSMogball // %0 = arith.constant ... : T
8925f80e16SEugene Zhulenev // async.yield %0 : T
9025f80e16SEugene Zhulenev // }
910a81ace0SKazu Hirata std::optional<Value> asyncToken; // returned completion token
9225f80e16SEugene Zhulenev llvm::SmallVector<Value, 4> returnValues; // returned async values
9325f80e16SEugene Zhulenev
94a5aa7836SRiver Riddle Value coroHandle; // coroutine handle (!async.coro.getHandle value)
951c144410Sbakhtiyar Block *entry; // coroutine entry block
96f3b7b300SKazu Hirata std::optional<Block *> setError; // set returned values to error state
9725f80e16SEugene Zhulenev Block *cleanup; // coroutine cleanup block
98af562fd2SYunlong Liu
99af562fd2SYunlong Liu // Coroutine cleanup block for destroy after the coroutine is resumed,
100af562fd2SYunlong Liu // e.g. async.coro.suspend state, [suspend], [resume], [destroy]
101af562fd2SYunlong Liu //
102af562fd2SYunlong Liu // This cleanup block is a duplicate of the cleanup block followed by the
103af562fd2SYunlong Liu // resume block. The purpose of having a duplicate cleanup block for destroy
104af562fd2SYunlong Liu // is to make the CFG clear so that the control flow analysis won't confuse.
105af562fd2SYunlong Liu //
106af562fd2SYunlong Liu // The overall structure of the lowered CFG can be the following,
107af562fd2SYunlong Liu //
108af562fd2SYunlong Liu // Entry (calling async.coro.suspend)
109af562fd2SYunlong Liu // | \
110af562fd2SYunlong Liu // Resume Destroy (duplicate of Cleanup)
111af562fd2SYunlong Liu // | |
112af562fd2SYunlong Liu // Cleanup |
113af562fd2SYunlong Liu // | /
114af562fd2SYunlong Liu // End (ends the corontine)
115af562fd2SYunlong Liu //
116af562fd2SYunlong Liu // If there is resume-specific cleanup logic, it can go into the Cleanup
117af562fd2SYunlong Liu // block but not the destroy block. Otherwise, it can fail block dominance
118af562fd2SYunlong Liu // check.
119af562fd2SYunlong Liu Block *cleanupForDestroy;
12025f80e16SEugene Zhulenev Block *suspend; // coroutine suspension block
12125f80e16SEugene Zhulenev };
12225f80e16SEugene Zhulenev } // namespace
12325f80e16SEugene Zhulenev
1246cca6b9aSyijiagu using FuncCoroMapPtr =
1256cca6b9aSyijiagu std::shared_ptr<llvm::DenseMap<func::FuncOp, CoroMachinery>>;
1266cca6b9aSyijiagu
1276ea22d46Sbakhtiyar /// Utility to partially update the regular function CFG to the coroutine CFG
1286ea22d46Sbakhtiyar /// compatible with LLVM coroutines switched-resume lowering using
1291c144410Sbakhtiyar /// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block
1301c144410Sbakhtiyar /// that branches into preexisting entry block. Also inserts trailing blocks.
1316ea22d46Sbakhtiyar ///
132f81f8808Syijiagu /// The result types of the passed `func` start with an optional `async.token`
1336ea22d46Sbakhtiyar /// and be continued with some number of `async.value`s.
1346ea22d46Sbakhtiyar ///
13525f80e16SEugene Zhulenev /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html
13625f80e16SEugene Zhulenev ///
13725f80e16SEugene Zhulenev /// - `entry` block sets up the coroutine.
13839957aa4SEugene Zhulenev /// - `set_error` block sets completion token and async values state to error.
13925f80e16SEugene Zhulenev /// - `cleanup` block cleans up the coroutine state.
14025f80e16SEugene Zhulenev /// - `suspend block after the @llvm.coro.end() defines what value will be
14125f80e16SEugene Zhulenev /// returned to the initial caller of a coroutine. Everything before the
14225f80e16SEugene Zhulenev /// @llvm.coro.end() will be executed at every suspension point.
14325f80e16SEugene Zhulenev ///
14425f80e16SEugene Zhulenev /// Coroutine structure (only the important bits):
14525f80e16SEugene Zhulenev ///
1466ea22d46Sbakhtiyar /// func @some_fn(<function-arguments>) -> (!async.token, !async.value<T>)
14725f80e16SEugene Zhulenev /// {
14825f80e16SEugene Zhulenev /// ^entry(<function-arguments>):
14925f80e16SEugene Zhulenev /// %token = <async token> : !async.token // create async runtime token
15025f80e16SEugene Zhulenev /// %value = <async value> : !async.value<T> // create async value
151a5aa7836SRiver Riddle /// %id = async.coro.getId // create a coroutine id
15225f80e16SEugene Zhulenev /// %hdl = async.coro.begin %id // create a coroutine handle
153ace01605SRiver Riddle /// cf.br ^preexisting_entry_block
1546ea22d46Sbakhtiyar ///
1551c144410Sbakhtiyar /// /* preexisting blocks modified to branch to the cleanup block */
15625f80e16SEugene Zhulenev ///
15739957aa4SEugene Zhulenev /// ^set_error: // this block created lazily only if needed (see code below)
15839957aa4SEugene Zhulenev /// async.runtime.set_error %token : !async.token
15939957aa4SEugene Zhulenev /// async.runtime.set_error %value : !async.value<T>
160ace01605SRiver Riddle /// cf.br ^cleanup
16139957aa4SEugene Zhulenev ///
16225f80e16SEugene Zhulenev /// ^cleanup:
16325f80e16SEugene Zhulenev /// async.coro.free %hdl // delete the coroutine state
164ace01605SRiver Riddle /// cf.br ^suspend
16525f80e16SEugene Zhulenev ///
16625f80e16SEugene Zhulenev /// ^suspend:
16725f80e16SEugene Zhulenev /// async.coro.end %hdl // marks the end of a coroutine
16825f80e16SEugene Zhulenev /// return %token, %value : !async.token, !async.value<T>
16925f80e16SEugene Zhulenev /// }
17025f80e16SEugene Zhulenev ///
setupCoroMachinery(func::FuncOp func)17158ceae95SRiver Riddle static CoroMachinery setupCoroMachinery(func::FuncOp func) {
1726ea22d46Sbakhtiyar assert(!func.getBlocks().empty() && "Function must have an entry block");
17325f80e16SEugene Zhulenev
17425f80e16SEugene Zhulenev MLIRContext *ctx = func.getContext();
1756ea22d46Sbakhtiyar Block *entryBlock = &func.getBlocks().front();
1761c144410Sbakhtiyar Block *originalEntryBlock =
1771c144410Sbakhtiyar entryBlock->splitBlock(entryBlock->getOperations().begin());
17825f80e16SEugene Zhulenev auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock);
17925f80e16SEugene Zhulenev
18025f80e16SEugene Zhulenev // ------------------------------------------------------------------------ //
18125f80e16SEugene Zhulenev // Allocate async token/values that we will return from a ramp function.
18225f80e16SEugene Zhulenev // ------------------------------------------------------------------------ //
183f81f8808Syijiagu
184f81f8808Syijiagu // We treat TokenType as state update marker to represent side-effects of
185f81f8808Syijiagu // async computations
18634a35a8bSMartin Erhart bool isStateful = isa<TokenType>(func.getResultTypes().front());
187f81f8808Syijiagu
1880a81ace0SKazu Hirata std::optional<Value> retToken;
189f81f8808Syijiagu if (isStateful)
190f81f8808Syijiagu retToken.emplace(builder.create<RuntimeCreateOp>(TokenType::get(ctx)));
19125f80e16SEugene Zhulenev
19225f80e16SEugene Zhulenev llvm::SmallVector<Value, 4> retValues;
19334a35a8bSMartin Erhart ArrayRef<Type> resValueTypes =
19434a35a8bSMartin Erhart isStateful ? func.getResultTypes().drop_front() : func.getResultTypes();
195f81f8808Syijiagu for (auto resType : resValueTypes)
196a5aa7836SRiver Riddle retValues.emplace_back(
197a5aa7836SRiver Riddle builder.create<RuntimeCreateOp>(resType).getResult());
19825f80e16SEugene Zhulenev
19925f80e16SEugene Zhulenev // ------------------------------------------------------------------------ //
20025f80e16SEugene Zhulenev // Initialize coroutine: get coroutine id and coroutine handle.
20125f80e16SEugene Zhulenev // ------------------------------------------------------------------------ //
20225f80e16SEugene Zhulenev auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx));
20325f80e16SEugene Zhulenev auto coroHdlOp =
204a5aa7836SRiver Riddle builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.getId());
205ace01605SRiver Riddle builder.create<cf::BranchOp>(originalEntryBlock);
20625f80e16SEugene Zhulenev
20725f80e16SEugene Zhulenev Block *cleanupBlock = func.addBlock();
208af562fd2SYunlong Liu Block *cleanupBlockForDestroy = func.addBlock();
20925f80e16SEugene Zhulenev Block *suspendBlock = func.addBlock();
21025f80e16SEugene Zhulenev
21125f80e16SEugene Zhulenev // ------------------------------------------------------------------------ //
212af562fd2SYunlong Liu // Coroutine cleanup blocks: deallocate coroutine frame, free the memory.
21325f80e16SEugene Zhulenev // ------------------------------------------------------------------------ //
214af562fd2SYunlong Liu auto buildCleanupBlock = [&](Block *cb) {
215af562fd2SYunlong Liu builder.setInsertionPointToStart(cb);
216a5aa7836SRiver Riddle builder.create<CoroFreeOp>(coroIdOp.getId(), coroHdlOp.getHandle());
21725f80e16SEugene Zhulenev
21825f80e16SEugene Zhulenev // Branch into the suspend block.
219ace01605SRiver Riddle builder.create<cf::BranchOp>(suspendBlock);
220af562fd2SYunlong Liu };
221af562fd2SYunlong Liu buildCleanupBlock(cleanupBlock);
222af562fd2SYunlong Liu buildCleanupBlock(cleanupBlockForDestroy);
22325f80e16SEugene Zhulenev
22425f80e16SEugene Zhulenev // ------------------------------------------------------------------------ //
22525f80e16SEugene Zhulenev // Coroutine suspend block: mark the end of a coroutine and return allocated
22625f80e16SEugene Zhulenev // async token.
22725f80e16SEugene Zhulenev // ------------------------------------------------------------------------ //
22825f80e16SEugene Zhulenev builder.setInsertionPointToStart(suspendBlock);
22925f80e16SEugene Zhulenev
23025f80e16SEugene Zhulenev // Mark the end of a coroutine: async.coro.end
231a5aa7836SRiver Riddle builder.create<CoroEndOp>(coroHdlOp.getHandle());
23225f80e16SEugene Zhulenev
233f81f8808Syijiagu // Return created optional `async.token` and `async.values` from the suspend
234f81f8808Syijiagu // block. This will be the return value of a coroutine ramp function.
235f81f8808Syijiagu SmallVector<Value, 4> ret;
236f81f8808Syijiagu if (retToken)
237f81f8808Syijiagu ret.push_back(*retToken);
23825f80e16SEugene Zhulenev ret.insert(ret.end(), retValues.begin(), retValues.end());
23923aa5a74SRiver Riddle builder.create<func::ReturnOp>(ret);
24025f80e16SEugene Zhulenev
24125f80e16SEugene Zhulenev // `async.await` op lowering will create resume blocks for async
24225f80e16SEugene Zhulenev // continuations, and will conditionally branch to cleanup or suspend blocks.
24325f80e16SEugene Zhulenev
244c75cedc2SChuanqi Xu // The switch-resumed API based coroutine should be marked with
245be690ea3Syonillasky // presplitcoroutine attribute to mark the function as a coroutine.
246735e6c40SChuanqi Xu func->setAttr("passthrough", builder.getArrayAttr(
247735e6c40SChuanqi Xu StringAttr::get(ctx, "presplitcoroutine")));
248c75cedc2SChuanqi Xu
24925f80e16SEugene Zhulenev CoroMachinery machinery;
25039957aa4SEugene Zhulenev machinery.func = func;
25125f80e16SEugene Zhulenev machinery.asyncToken = retToken;
25225f80e16SEugene Zhulenev machinery.returnValues = retValues;
253a5aa7836SRiver Riddle machinery.coroHandle = coroHdlOp.getHandle();
2541c144410Sbakhtiyar machinery.entry = entryBlock;
2551a36588eSKazu Hirata machinery.setError = std::nullopt; // created lazily only if needed
25625f80e16SEugene Zhulenev machinery.cleanup = cleanupBlock;
257af562fd2SYunlong Liu machinery.cleanupForDestroy = cleanupBlockForDestroy;
25825f80e16SEugene Zhulenev machinery.suspend = suspendBlock;
25925f80e16SEugene Zhulenev return machinery;
26025f80e16SEugene Zhulenev }
26125f80e16SEugene Zhulenev
26239957aa4SEugene Zhulenev // Lazily creates `set_error` block only if it is required for lowering to the
26339957aa4SEugene Zhulenev // runtime operations (see for example lowering of assert operation).
setupSetErrorBlock(CoroMachinery & coro)26439957aa4SEugene Zhulenev static Block *setupSetErrorBlock(CoroMachinery &coro) {
26539957aa4SEugene Zhulenev if (coro.setError)
266f81f8808Syijiagu return *coro.setError;
26739957aa4SEugene Zhulenev
26839957aa4SEugene Zhulenev coro.setError = coro.func.addBlock();
269f81f8808Syijiagu (*coro.setError)->moveBefore(coro.cleanup);
27039957aa4SEugene Zhulenev
27139957aa4SEugene Zhulenev auto builder =
272f81f8808Syijiagu ImplicitLocOpBuilder::atBlockBegin(coro.func->getLoc(), *coro.setError);
27339957aa4SEugene Zhulenev
27439957aa4SEugene Zhulenev // Coroutine set_error block: set error on token and all returned values.
275f81f8808Syijiagu if (coro.asyncToken)
276f81f8808Syijiagu builder.create<RuntimeSetErrorOp>(*coro.asyncToken);
277f81f8808Syijiagu
27839957aa4SEugene Zhulenev for (Value retValue : coro.returnValues)
27939957aa4SEugene Zhulenev builder.create<RuntimeSetErrorOp>(retValue);
28039957aa4SEugene Zhulenev
28139957aa4SEugene Zhulenev // Branch into the cleanup block.
282ace01605SRiver Riddle builder.create<cf::BranchOp>(coro.cleanup);
28339957aa4SEugene Zhulenev
284f81f8808Syijiagu return *coro.setError;
28539957aa4SEugene Zhulenev }
28639957aa4SEugene Zhulenev
287f81f8808Syijiagu //===----------------------------------------------------------------------===//
288f81f8808Syijiagu // async.execute op outlining to the coroutine functions.
289f81f8808Syijiagu //===----------------------------------------------------------------------===//
290f81f8808Syijiagu
29125f80e16SEugene Zhulenev /// Outline the body region attached to the `async.execute` op into a standalone
29225f80e16SEugene Zhulenev /// function.
29325f80e16SEugene Zhulenev ///
29425f80e16SEugene Zhulenev /// Note that this is not reversible transformation.
29558ceae95SRiver Riddle static std::pair<func::FuncOp, CoroMachinery>
outlineExecuteOp(SymbolTable & symbolTable,ExecuteOp execute)29625f80e16SEugene Zhulenev outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
29725f80e16SEugene Zhulenev ModuleOp module = execute->getParentOfType<ModuleOp>();
29825f80e16SEugene Zhulenev
29925f80e16SEugene Zhulenev MLIRContext *ctx = module.getContext();
30025f80e16SEugene Zhulenev Location loc = execute.getLoc();
30125f80e16SEugene Zhulenev
302b537c5b4SEugene Zhulenev // Make sure that all constants will be inside the outlined async function to
303b537c5b4SEugene Zhulenev // reduce the number of function arguments.
304a5aa7836SRiver Riddle cloneConstantsIntoTheRegion(execute.getBodyRegion());
305b537c5b4SEugene Zhulenev
30625f80e16SEugene Zhulenev // Collect all outlined function inputs.
307a5aa7836SRiver Riddle SetVector<mlir::Value> functionInputs(execute.getDependencies().begin(),
308a5aa7836SRiver Riddle execute.getDependencies().end());
309a5aa7836SRiver Riddle functionInputs.insert(execute.getBodyOperands().begin(),
310a5aa7836SRiver Riddle execute.getBodyOperands().end());
311a5aa7836SRiver Riddle getUsedValuesDefinedAbove(execute.getBodyRegion(), functionInputs);
31225f80e16SEugene Zhulenev
31325f80e16SEugene Zhulenev // Collect types for the outlined function inputs and outputs.
31425f80e16SEugene Zhulenev auto typesRange = llvm::map_range(
31525f80e16SEugene Zhulenev functionInputs, [](Value value) { return value.getType(); });
31625f80e16SEugene Zhulenev SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end());
31725f80e16SEugene Zhulenev auto outputTypes = execute.getResultTypes();
31825f80e16SEugene Zhulenev
31925f80e16SEugene Zhulenev auto funcType = FunctionType::get(ctx, inputTypes, outputTypes);
32025f80e16SEugene Zhulenev auto funcAttrs = ArrayRef<NamedAttribute>();
32125f80e16SEugene Zhulenev
32225f80e16SEugene Zhulenev // TODO: Derive outlined function name from the parent FuncOp (support
32325f80e16SEugene Zhulenev // multiple nested async.execute operations).
32458ceae95SRiver Riddle func::FuncOp func =
32558ceae95SRiver Riddle func::FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
326973ddb7dSMehdi Amini symbolTable.insert(func);
32725f80e16SEugene Zhulenev
32825f80e16SEugene Zhulenev SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private);
3291c144410Sbakhtiyar auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, func.addEntryBlock());
33025f80e16SEugene Zhulenev
3311c144410Sbakhtiyar // Prepare for coroutine conversion by creating the body of the function.
3321c144410Sbakhtiyar {
333a5aa7836SRiver Riddle size_t numDependencies = execute.getDependencies().size();
334a5aa7836SRiver Riddle size_t numOperands = execute.getBodyOperands().size();
33525f80e16SEugene Zhulenev
33625f80e16SEugene Zhulenev // Await on all dependencies before starting to execute the body region.
33725f80e16SEugene Zhulenev for (size_t i = 0; i < numDependencies; ++i)
33825f80e16SEugene Zhulenev builder.create<AwaitOp>(func.getArgument(i));
33925f80e16SEugene Zhulenev
34025f80e16SEugene Zhulenev // Await on all async value operands and unwrap the payload.
34125f80e16SEugene Zhulenev SmallVector<Value, 4> unwrappedOperands(numOperands);
34225f80e16SEugene Zhulenev for (size_t i = 0; i < numOperands; ++i) {
34325f80e16SEugene Zhulenev Value operand = func.getArgument(numDependencies + i);
344a5aa7836SRiver Riddle unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).getResult();
34525f80e16SEugene Zhulenev }
34625f80e16SEugene Zhulenev
34725f80e16SEugene Zhulenev // Map from function inputs defined above the execute op to the function
34825f80e16SEugene Zhulenev // arguments.
3494d67b278SJeff Niu IRMapping valueMapping;
35025f80e16SEugene Zhulenev valueMapping.map(functionInputs, func.getArguments());
351a5aa7836SRiver Riddle valueMapping.map(execute.getBodyRegion().getArguments(), unwrappedOperands);
35225f80e16SEugene Zhulenev
35325f80e16SEugene Zhulenev // Clone all operations from the execute operation body into the outlined
35425f80e16SEugene Zhulenev // function body.
355a5aa7836SRiver Riddle for (Operation &op : execute.getBodyRegion().getOps())
35625f80e16SEugene Zhulenev builder.clone(op, valueMapping);
3571c144410Sbakhtiyar }
3581c144410Sbakhtiyar
3591c144410Sbakhtiyar // Adding entry/cleanup/suspend blocks.
3601c144410Sbakhtiyar CoroMachinery coro = setupCoroMachinery(func);
3611c144410Sbakhtiyar
3621c144410Sbakhtiyar // Suspend async function at the end of an entry block, and resume it using
3631c144410Sbakhtiyar // Async resume operation (execution will be resumed in a thread managed by
3641c144410Sbakhtiyar // the async runtime).
3651c144410Sbakhtiyar {
366ace01605SRiver Riddle cf::BranchOp branch = cast<cf::BranchOp>(coro.entry->getTerminator());
3671c144410Sbakhtiyar builder.setInsertionPointToEnd(coro.entry);
3681c144410Sbakhtiyar
3691c144410Sbakhtiyar // Save the coroutine state: async.coro.save
3701c144410Sbakhtiyar auto coroSaveOp =
3711c144410Sbakhtiyar builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
3721c144410Sbakhtiyar
3731c144410Sbakhtiyar // Pass coroutine to the runtime to be resumed on a runtime managed
3741c144410Sbakhtiyar // thread.
3751c144410Sbakhtiyar builder.create<RuntimeResumeOp>(coro.coroHandle);
3761c144410Sbakhtiyar
3771c144410Sbakhtiyar // Add async.coro.suspend as a suspended block terminator.
378a5aa7836SRiver Riddle builder.create<CoroSuspendOp>(coroSaveOp.getState(), coro.suspend,
379af562fd2SYunlong Liu branch.getDest(), coro.cleanupForDestroy);
3801c144410Sbakhtiyar
3811c144410Sbakhtiyar branch.erase();
3821c144410Sbakhtiyar }
38325f80e16SEugene Zhulenev
38425f80e16SEugene Zhulenev // Replace the original `async.execute` with a call to outlined function.
3851c144410Sbakhtiyar {
38625f80e16SEugene Zhulenev ImplicitLocOpBuilder callBuilder(loc, execute);
38723aa5a74SRiver Riddle auto callOutlinedFunc = callBuilder.create<func::CallOp>(
38825f80e16SEugene Zhulenev func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());
38925f80e16SEugene Zhulenev execute.replaceAllUsesWith(callOutlinedFunc.getResults());
39025f80e16SEugene Zhulenev execute.erase();
3911c144410Sbakhtiyar }
39225f80e16SEugene Zhulenev
39325f80e16SEugene Zhulenev return {func, coro};
39425f80e16SEugene Zhulenev }
39525f80e16SEugene Zhulenev
39625f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
397d43b2360SEugene Zhulenev // Convert async.create_group operation to async.runtime.create_group
39825f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
39925f80e16SEugene Zhulenev
40025f80e16SEugene Zhulenev namespace {
40125f80e16SEugene Zhulenev class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> {
40225f80e16SEugene Zhulenev public:
40325f80e16SEugene Zhulenev using OpConversionPattern::OpConversionPattern;
40425f80e16SEugene Zhulenev
40525f80e16SEugene Zhulenev LogicalResult
matchAndRewrite(CreateGroupOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const406b54c724bSRiver Riddle matchAndRewrite(CreateGroupOp op, OpAdaptor adaptor,
40725f80e16SEugene Zhulenev ConversionPatternRewriter &rewriter) const override {
408d43b2360SEugene Zhulenev rewriter.replaceOpWithNewOp<RuntimeCreateGroupOp>(
409b54c724bSRiver Riddle op, GroupType::get(op->getContext()), adaptor.getOperands());
41025f80e16SEugene Zhulenev return success();
41125f80e16SEugene Zhulenev }
41225f80e16SEugene Zhulenev };
41325f80e16SEugene Zhulenev } // namespace
41425f80e16SEugene Zhulenev
41525f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
41625f80e16SEugene Zhulenev // Convert async.add_to_group operation to async.runtime.add_to_group.
41725f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
41825f80e16SEugene Zhulenev
41925f80e16SEugene Zhulenev namespace {
42025f80e16SEugene Zhulenev class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> {
42125f80e16SEugene Zhulenev public:
42225f80e16SEugene Zhulenev using OpConversionPattern::OpConversionPattern;
42325f80e16SEugene Zhulenev
42425f80e16SEugene Zhulenev LogicalResult
matchAndRewrite(AddToGroupOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const425b54c724bSRiver Riddle matchAndRewrite(AddToGroupOp op, OpAdaptor adaptor,
42625f80e16SEugene Zhulenev ConversionPatternRewriter &rewriter) const override {
42725f80e16SEugene Zhulenev rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>(
428b54c724bSRiver Riddle op, rewriter.getIndexType(), adaptor.getOperands());
42925f80e16SEugene Zhulenev return success();
43025f80e16SEugene Zhulenev }
43125f80e16SEugene Zhulenev };
43225f80e16SEugene Zhulenev } // namespace
43325f80e16SEugene Zhulenev
43425f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
// Convert async.func, async.return and async.call operations to non-blocking
// operations based on LLVM coroutines.
437f81f8808Syijiagu //===----------------------------------------------------------------------===//
438f81f8808Syijiagu
439f81f8808Syijiagu namespace {
440f81f8808Syijiagu
441f81f8808Syijiagu //===----------------------------------------------------------------------===//
442f81f8808Syijiagu // Convert async.func operation to func.func
443f81f8808Syijiagu //===----------------------------------------------------------------------===//
444f81f8808Syijiagu
445f81f8808Syijiagu class AsyncFuncOpLowering : public OpConversionPattern<async::FuncOp> {
446f81f8808Syijiagu public:
AsyncFuncOpLowering(MLIRContext * ctx,FuncCoroMapPtr coros)4476cca6b9aSyijiagu AsyncFuncOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
44898b13979SMehdi Amini : OpConversionPattern<async::FuncOp>(ctx), coros(std::move(coros)) {}
449f81f8808Syijiagu
450f81f8808Syijiagu LogicalResult
matchAndRewrite(async::FuncOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const451f81f8808Syijiagu matchAndRewrite(async::FuncOp op, OpAdaptor adaptor,
452f81f8808Syijiagu ConversionPatternRewriter &rewriter) const override {
453f81f8808Syijiagu Location loc = op->getLoc();
454f81f8808Syijiagu
455f81f8808Syijiagu auto newFuncOp =
456f81f8808Syijiagu rewriter.create<func::FuncOp>(loc, op.getName(), op.getFunctionType());
457f81f8808Syijiagu
458f81f8808Syijiagu SymbolTable::setSymbolVisibility(newFuncOp,
459f81f8808Syijiagu SymbolTable::getSymbolVisibility(op));
460f81f8808Syijiagu // Copy over all attributes other than the name.
461f81f8808Syijiagu for (const auto &namedAttr : op->getAttrs()) {
462f81f8808Syijiagu if (namedAttr.getName() != SymbolTable::getSymbolAttrName())
463f81f8808Syijiagu newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
464f81f8808Syijiagu }
465f81f8808Syijiagu
466f81f8808Syijiagu rewriter.inlineRegionBefore(op.getBody(), newFuncOp.getBody(),
467f81f8808Syijiagu newFuncOp.end());
468f81f8808Syijiagu
469f81f8808Syijiagu CoroMachinery coro = setupCoroMachinery(newFuncOp);
47098b13979SMehdi Amini (*coros)[newFuncOp] = coro;
471f81f8808Syijiagu // no initial suspend, we should hot-start
472f81f8808Syijiagu
473f81f8808Syijiagu rewriter.eraseOp(op);
474f81f8808Syijiagu return success();
475f81f8808Syijiagu }
476f81f8808Syijiagu
477f81f8808Syijiagu private:
47898b13979SMehdi Amini FuncCoroMapPtr coros;
479f81f8808Syijiagu };
480f81f8808Syijiagu
481f81f8808Syijiagu //===----------------------------------------------------------------------===//
482f81f8808Syijiagu // Convert async.call operation to func.call
483f81f8808Syijiagu //===----------------------------------------------------------------------===//
484f81f8808Syijiagu
485f81f8808Syijiagu class AsyncCallOpLowering : public OpConversionPattern<async::CallOp> {
486f81f8808Syijiagu public:
AsyncCallOpLowering(MLIRContext * ctx)487f81f8808Syijiagu AsyncCallOpLowering(MLIRContext *ctx)
488f81f8808Syijiagu : OpConversionPattern<async::CallOp>(ctx) {}
489f81f8808Syijiagu
490f81f8808Syijiagu LogicalResult
matchAndRewrite(async::CallOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const491f81f8808Syijiagu matchAndRewrite(async::CallOp op, OpAdaptor adaptor,
492f81f8808Syijiagu ConversionPatternRewriter &rewriter) const override {
493f81f8808Syijiagu rewriter.replaceOpWithNewOp<func::CallOp>(
494f81f8808Syijiagu op, op.getCallee(), op.getResultTypes(), op.getOperands());
495f81f8808Syijiagu return success();
496f81f8808Syijiagu }
497f81f8808Syijiagu };
498f81f8808Syijiagu
499f81f8808Syijiagu //===----------------------------------------------------------------------===//
500f81f8808Syijiagu // Convert async.return operation to async.runtime operations.
501f81f8808Syijiagu //===----------------------------------------------------------------------===//
502f81f8808Syijiagu
503f81f8808Syijiagu class AsyncReturnOpLowering : public OpConversionPattern<async::ReturnOp> {
504f81f8808Syijiagu public:
AsyncReturnOpLowering(MLIRContext * ctx,FuncCoroMapPtr coros)5056cca6b9aSyijiagu AsyncReturnOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
50698b13979SMehdi Amini : OpConversionPattern<async::ReturnOp>(ctx), coros(std::move(coros)) {}
507f81f8808Syijiagu
508f81f8808Syijiagu LogicalResult
matchAndRewrite(async::ReturnOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const509f81f8808Syijiagu matchAndRewrite(async::ReturnOp op, OpAdaptor adaptor,
510f81f8808Syijiagu ConversionPatternRewriter &rewriter) const override {
511f81f8808Syijiagu auto func = op->template getParentOfType<func::FuncOp>();
51298b13979SMehdi Amini auto funcCoro = coros->find(func);
51398b13979SMehdi Amini if (funcCoro == coros->end())
514f81f8808Syijiagu return rewriter.notifyMatchFailure(
515f81f8808Syijiagu op, "operation is not inside the async coroutine function");
516f81f8808Syijiagu
517f81f8808Syijiagu Location loc = op->getLoc();
518f81f8808Syijiagu const CoroMachinery &coro = funcCoro->getSecond();
519f81f8808Syijiagu rewriter.setInsertionPointAfter(op);
520f81f8808Syijiagu
521f81f8808Syijiagu // Store return values into the async values storage and switch async
522f81f8808Syijiagu // values state to available.
523f81f8808Syijiagu for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
524f81f8808Syijiagu Value returnValue = std::get<0>(tuple);
525f81f8808Syijiagu Value asyncValue = std::get<1>(tuple);
526f81f8808Syijiagu rewriter.create<RuntimeStoreOp>(loc, returnValue, asyncValue);
527f81f8808Syijiagu rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue);
528f81f8808Syijiagu }
529f81f8808Syijiagu
530f81f8808Syijiagu if (coro.asyncToken)
531f81f8808Syijiagu // Switch the coroutine completion token to available state.
532f81f8808Syijiagu rewriter.create<RuntimeSetAvailableOp>(loc, *coro.asyncToken);
533f81f8808Syijiagu
534f81f8808Syijiagu rewriter.eraseOp(op);
535f81f8808Syijiagu rewriter.create<cf::BranchOp>(loc, coro.cleanup);
536f81f8808Syijiagu return success();
537f81f8808Syijiagu }
538f81f8808Syijiagu
539f81f8808Syijiagu private:
54098b13979SMehdi Amini FuncCoroMapPtr coros;
541f81f8808Syijiagu };
542f81f8808Syijiagu } // namespace
543f81f8808Syijiagu
544f81f8808Syijiagu //===----------------------------------------------------------------------===//
54525f80e16SEugene Zhulenev // Convert async.await and async.await_all operations to the async.runtime.await
54625f80e16SEugene Zhulenev // or async.runtime.await_and_resume operations.
54725f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
54825f80e16SEugene Zhulenev
54925f80e16SEugene Zhulenev namespace {
55025f80e16SEugene Zhulenev template <typename AwaitType, typename AwaitableType>
55125f80e16SEugene Zhulenev class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
55225f80e16SEugene Zhulenev using AwaitAdaptor = typename AwaitType::Adaptor;
55325f80e16SEugene Zhulenev
55425f80e16SEugene Zhulenev public:
AwaitOpLoweringBase(MLIRContext * ctx,FuncCoroMapPtr coros,bool shouldLowerBlockingWait)5556cca6b9aSyijiagu AwaitOpLoweringBase(MLIRContext *ctx, FuncCoroMapPtr coros,
55698b13979SMehdi Amini bool shouldLowerBlockingWait)
55798b13979SMehdi Amini : OpConversionPattern<AwaitType>(ctx), coros(std::move(coros)),
55898b13979SMehdi Amini shouldLowerBlockingWait(shouldLowerBlockingWait) {}
55925f80e16SEugene Zhulenev
56025f80e16SEugene Zhulenev LogicalResult
matchAndRewrite(AwaitType op,typename AwaitType::Adaptor adaptor,ConversionPatternRewriter & rewriter) const561b54c724bSRiver Riddle matchAndRewrite(AwaitType op, typename AwaitType::Adaptor adaptor,
56225f80e16SEugene Zhulenev ConversionPatternRewriter &rewriter) const override {
56325f80e16SEugene Zhulenev // We can only await on one the `AwaitableType` (for `await` it can be
56425f80e16SEugene Zhulenev // a `token` or a `value`, for `await_all` it must be a `group`).
5655550c821STres Popp if (!isa<AwaitableType>(op.getOperand().getType()))
56625f80e16SEugene Zhulenev return rewriter.notifyMatchFailure(op, "unsupported awaitable type");
56725f80e16SEugene Zhulenev
5686cca6b9aSyijiagu // Check if await operation is inside the coroutine function.
56958ceae95SRiver Riddle auto func = op->template getParentOfType<func::FuncOp>();
57098b13979SMehdi Amini auto funcCoro = coros->find(func);
57198b13979SMehdi Amini const bool isInCoroutine = funcCoro != coros->end();
57225f80e16SEugene Zhulenev
57325f80e16SEugene Zhulenev Location loc = op->getLoc();
574a5aa7836SRiver Riddle Value operand = adaptor.getOperand();
57525f80e16SEugene Zhulenev
576fd52b435SEugene Zhulenev Type i1 = rewriter.getI1Type();
577fd52b435SEugene Zhulenev
5786cca6b9aSyijiagu // Delay lowering to block wait in case await op is inside async.execute
57998b13979SMehdi Amini if (!isInCoroutine && !shouldLowerBlockingWait)
5806cca6b9aSyijiagu return failure();
5816cca6b9aSyijiagu
58225f80e16SEugene Zhulenev // Inside regular functions we use the blocking wait operation to wait for
58325f80e16SEugene Zhulenev // the async object (token, value or group) to become available.
584fd52b435SEugene Zhulenev if (!isInCoroutine) {
585*ea2d9383SMatthias Springer ImplicitLocOpBuilder builder(loc, rewriter);
586fd52b435SEugene Zhulenev builder.create<RuntimeAwaitOp>(loc, operand);
587fd52b435SEugene Zhulenev
588fd52b435SEugene Zhulenev // Assert that the awaited operands is not in the error state.
589fd52b435SEugene Zhulenev Value isError = builder.create<RuntimeIsErrorOp>(i1, operand);
590a54f4eaeSMogball Value notError = builder.create<arith::XOrIOp>(
591a54f4eaeSMogball isError, builder.create<arith::ConstantOp>(
592a54f4eaeSMogball loc, i1, builder.getIntegerAttr(i1, 1)));
593fd52b435SEugene Zhulenev
594ace01605SRiver Riddle builder.create<cf::AssertOp>(notError,
595fd52b435SEugene Zhulenev "Awaited async operand is in error state");
596fd52b435SEugene Zhulenev }
59725f80e16SEugene Zhulenev
59825f80e16SEugene Zhulenev // Inside the coroutine we convert await operation into coroutine suspension
59925f80e16SEugene Zhulenev // point, and resume execution asynchronously.
60025f80e16SEugene Zhulenev if (isInCoroutine) {
601f81f8808Syijiagu CoroMachinery &coro = funcCoro->getSecond();
60225f80e16SEugene Zhulenev Block *suspended = op->getBlock();
60325f80e16SEugene Zhulenev
604*ea2d9383SMatthias Springer ImplicitLocOpBuilder builder(loc, rewriter);
60525f80e16SEugene Zhulenev MLIRContext *ctx = op->getContext();
60625f80e16SEugene Zhulenev
60725f80e16SEugene Zhulenev // Save the coroutine state and resume on a runtime managed thread when
60825f80e16SEugene Zhulenev // the operand becomes available.
60925f80e16SEugene Zhulenev auto coroSaveOp =
61025f80e16SEugene Zhulenev builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
61125f80e16SEugene Zhulenev builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle);
61225f80e16SEugene Zhulenev
61325f80e16SEugene Zhulenev // Split the entry block before the await operation.
61425f80e16SEugene Zhulenev Block *resume = rewriter.splitBlock(suspended, Block::iterator(op));
61525f80e16SEugene Zhulenev
61625f80e16SEugene Zhulenev // Add async.coro.suspend as a suspended block terminator.
61725f80e16SEugene Zhulenev builder.setInsertionPointToEnd(suspended);
618a5aa7836SRiver Riddle builder.create<CoroSuspendOp>(coroSaveOp.getState(), coro.suspend, resume,
619af562fd2SYunlong Liu coro.cleanupForDestroy);
62025f80e16SEugene Zhulenev
62139957aa4SEugene Zhulenev // Split the resume block into error checking and continuation.
62239957aa4SEugene Zhulenev Block *continuation = rewriter.splitBlock(resume, Block::iterator(op));
62339957aa4SEugene Zhulenev
62439957aa4SEugene Zhulenev // Check if the awaited value is in the error state.
62539957aa4SEugene Zhulenev builder.setInsertionPointToStart(resume);
626fd52b435SEugene Zhulenev auto isError = builder.create<RuntimeIsErrorOp>(loc, i1, operand);
627ace01605SRiver Riddle builder.create<cf::CondBranchOp>(isError,
62839957aa4SEugene Zhulenev /*trueDest=*/setupSetErrorBlock(coro),
62939957aa4SEugene Zhulenev /*trueArgs=*/ArrayRef<Value>(),
63039957aa4SEugene Zhulenev /*falseDest=*/continuation,
63139957aa4SEugene Zhulenev /*falseArgs=*/ArrayRef<Value>());
63239957aa4SEugene Zhulenev
63339957aa4SEugene Zhulenev // Make sure that replacement value will be constructed in the
63439957aa4SEugene Zhulenev // continuation block.
63539957aa4SEugene Zhulenev rewriter.setInsertionPointToStart(continuation);
63639957aa4SEugene Zhulenev }
63725f80e16SEugene Zhulenev
63825f80e16SEugene Zhulenev // Erase or replace the await operation with the new value.
63925f80e16SEugene Zhulenev if (Value replaceWith = getReplacementValue(op, operand, rewriter))
64025f80e16SEugene Zhulenev rewriter.replaceOp(op, replaceWith);
64125f80e16SEugene Zhulenev else
64225f80e16SEugene Zhulenev rewriter.eraseOp(op);
64325f80e16SEugene Zhulenev
64425f80e16SEugene Zhulenev return success();
64525f80e16SEugene Zhulenev }
64625f80e16SEugene Zhulenev
getReplacementValue(AwaitType op,Value operand,ConversionPatternRewriter & rewriter) const64725f80e16SEugene Zhulenev virtual Value getReplacementValue(AwaitType op, Value operand,
64825f80e16SEugene Zhulenev ConversionPatternRewriter &rewriter) const {
64925f80e16SEugene Zhulenev return Value();
65025f80e16SEugene Zhulenev }
65125f80e16SEugene Zhulenev
65225f80e16SEugene Zhulenev private:
65398b13979SMehdi Amini FuncCoroMapPtr coros;
65498b13979SMehdi Amini bool shouldLowerBlockingWait;
65525f80e16SEugene Zhulenev };
65625f80e16SEugene Zhulenev
65725f80e16SEugene Zhulenev /// Lowering for `async.await` with a token operand.
65825f80e16SEugene Zhulenev class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
65925f80e16SEugene Zhulenev using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
66025f80e16SEugene Zhulenev
66125f80e16SEugene Zhulenev public:
66225f80e16SEugene Zhulenev using Base::Base;
66325f80e16SEugene Zhulenev };
66425f80e16SEugene Zhulenev
66525f80e16SEugene Zhulenev /// Lowering for `async.await` with a value operand.
66625f80e16SEugene Zhulenev class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> {
66725f80e16SEugene Zhulenev using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;
66825f80e16SEugene Zhulenev
66925f80e16SEugene Zhulenev public:
67025f80e16SEugene Zhulenev using Base::Base;
67125f80e16SEugene Zhulenev
67225f80e16SEugene Zhulenev Value
getReplacementValue(AwaitOp op,Value operand,ConversionPatternRewriter & rewriter) const67325f80e16SEugene Zhulenev getReplacementValue(AwaitOp op, Value operand,
67425f80e16SEugene Zhulenev ConversionPatternRewriter &rewriter) const override {
67525f80e16SEugene Zhulenev // Load from the async value storage.
6765550c821STres Popp auto valueType = cast<ValueType>(operand.getType()).getValueType();
67725f80e16SEugene Zhulenev return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand);
67825f80e16SEugene Zhulenev }
67925f80e16SEugene Zhulenev };
68025f80e16SEugene Zhulenev
68125f80e16SEugene Zhulenev /// Lowering for `async.await_all` operation.
68225f80e16SEugene Zhulenev class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
68325f80e16SEugene Zhulenev using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
68425f80e16SEugene Zhulenev
68525f80e16SEugene Zhulenev public:
68625f80e16SEugene Zhulenev using Base::Base;
68725f80e16SEugene Zhulenev };
68825f80e16SEugene Zhulenev
68925f80e16SEugene Zhulenev } // namespace
69025f80e16SEugene Zhulenev
69125f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
69225f80e16SEugene Zhulenev // Convert async.yield operation to async.runtime operations.
69325f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
69425f80e16SEugene Zhulenev
69525f80e16SEugene Zhulenev class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
69625f80e16SEugene Zhulenev public:
YieldOpLowering(MLIRContext * ctx,FuncCoroMapPtr coros)6976cca6b9aSyijiagu YieldOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
69898b13979SMehdi Amini : OpConversionPattern<async::YieldOp>(ctx), coros(std::move(coros)) {}
69925f80e16SEugene Zhulenev
70025f80e16SEugene Zhulenev LogicalResult
matchAndRewrite(async::YieldOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const701b54c724bSRiver Riddle matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
70225f80e16SEugene Zhulenev ConversionPatternRewriter &rewriter) const override {
70339957aa4SEugene Zhulenev // Check if yield operation is inside the async coroutine function.
70458ceae95SRiver Riddle auto func = op->template getParentOfType<func::FuncOp>();
70598b13979SMehdi Amini auto funcCoro = coros->find(func);
70698b13979SMehdi Amini if (funcCoro == coros->end())
70725f80e16SEugene Zhulenev return rewriter.notifyMatchFailure(
70839957aa4SEugene Zhulenev op, "operation is not inside the async coroutine function");
70925f80e16SEugene Zhulenev
71025f80e16SEugene Zhulenev Location loc = op->getLoc();
711f81f8808Syijiagu const CoroMachinery &coro = funcCoro->getSecond();
71225f80e16SEugene Zhulenev
71325f80e16SEugene Zhulenev // Store yielded values into the async values storage and switch async
71425f80e16SEugene Zhulenev // values state to available.
715b54c724bSRiver Riddle for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
71625f80e16SEugene Zhulenev Value yieldValue = std::get<0>(tuple);
71725f80e16SEugene Zhulenev Value asyncValue = std::get<1>(tuple);
71825f80e16SEugene Zhulenev rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue);
71925f80e16SEugene Zhulenev rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue);
72025f80e16SEugene Zhulenev }
72125f80e16SEugene Zhulenev
722f81f8808Syijiagu if (coro.asyncToken)
72325f80e16SEugene Zhulenev // Switch the coroutine completion token to available state.
724f81f8808Syijiagu rewriter.create<RuntimeSetAvailableOp>(loc, *coro.asyncToken);
725f81f8808Syijiagu
726f81f8808Syijiagu rewriter.eraseOp(op);
727f81f8808Syijiagu rewriter.create<cf::BranchOp>(loc, coro.cleanup);
72825f80e16SEugene Zhulenev
72925f80e16SEugene Zhulenev return success();
73025f80e16SEugene Zhulenev }
73125f80e16SEugene Zhulenev
73225f80e16SEugene Zhulenev private:
73398b13979SMehdi Amini FuncCoroMapPtr coros;
73425f80e16SEugene Zhulenev };
73525f80e16SEugene Zhulenev
73625f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
// Convert cf.assert operations inside coroutines to a cf.cond_br that branches
// to the `set_error` block when the assertion fails.
73839957aa4SEugene Zhulenev //===----------------------------------------------------------------------===//
73939957aa4SEugene Zhulenev
740ace01605SRiver Riddle class AssertOpLowering : public OpConversionPattern<cf::AssertOp> {
74139957aa4SEugene Zhulenev public:
AssertOpLowering(MLIRContext * ctx,FuncCoroMapPtr coros)7426cca6b9aSyijiagu AssertOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
74398b13979SMehdi Amini : OpConversionPattern<cf::AssertOp>(ctx), coros(std::move(coros)) {}
74439957aa4SEugene Zhulenev
74539957aa4SEugene Zhulenev LogicalResult
matchAndRewrite(cf::AssertOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const746ace01605SRiver Riddle matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
74739957aa4SEugene Zhulenev ConversionPatternRewriter &rewriter) const override {
74839957aa4SEugene Zhulenev // Check if assert operation is inside the async coroutine function.
74958ceae95SRiver Riddle auto func = op->template getParentOfType<func::FuncOp>();
75098b13979SMehdi Amini auto funcCoro = coros->find(func);
75198b13979SMehdi Amini if (funcCoro == coros->end())
75239957aa4SEugene Zhulenev return rewriter.notifyMatchFailure(
75339957aa4SEugene Zhulenev op, "operation is not inside the async coroutine function");
75439957aa4SEugene Zhulenev
75539957aa4SEugene Zhulenev Location loc = op->getLoc();
756f81f8808Syijiagu CoroMachinery &coro = funcCoro->getSecond();
75739957aa4SEugene Zhulenev
75839957aa4SEugene Zhulenev Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op));
75939957aa4SEugene Zhulenev rewriter.setInsertionPointToEnd(cont->getPrevNode());
760ace01605SRiver Riddle rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(),
76139957aa4SEugene Zhulenev /*trueDest=*/cont,
76239957aa4SEugene Zhulenev /*trueArgs=*/ArrayRef<Value>(),
76339957aa4SEugene Zhulenev /*falseDest=*/setupSetErrorBlock(coro),
76439957aa4SEugene Zhulenev /*falseArgs=*/ArrayRef<Value>());
76539957aa4SEugene Zhulenev rewriter.eraseOp(op);
76639957aa4SEugene Zhulenev
76739957aa4SEugene Zhulenev return success();
76839957aa4SEugene Zhulenev }
76939957aa4SEugene Zhulenev
77039957aa4SEugene Zhulenev private:
77198b13979SMehdi Amini FuncCoroMapPtr coros;
77239957aa4SEugene Zhulenev };
77339957aa4SEugene Zhulenev
77439957aa4SEugene Zhulenev //===----------------------------------------------------------------------===//
runOnOperation()77525f80e16SEugene Zhulenev void AsyncToAsyncRuntimePass::runOnOperation() {
77625f80e16SEugene Zhulenev ModuleOp module = getOperation();
77725f80e16SEugene Zhulenev SymbolTable symbolTable(module);
77825f80e16SEugene Zhulenev
779f81f8808Syijiagu // Functions with coroutine CFG setups, which are results of outlining
7806cca6b9aSyijiagu // `async.execute` body regions
7816cca6b9aSyijiagu FuncCoroMapPtr coros =
7826cca6b9aSyijiagu std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
78325f80e16SEugene Zhulenev
78425f80e16SEugene Zhulenev module.walk([&](ExecuteOp execute) {
7856cca6b9aSyijiagu coros->insert(outlineExecuteOp(symbolTable, execute));
78625f80e16SEugene Zhulenev });
78725f80e16SEugene Zhulenev
78825f80e16SEugene Zhulenev LLVM_DEBUG({
7896cca6b9aSyijiagu llvm::dbgs() << "Outlined " << coros->size()
79025f80e16SEugene Zhulenev << " functions built from async.execute operations\n";
79125f80e16SEugene Zhulenev });
79225f80e16SEugene Zhulenev
793de7a4e53SEugene Zhulenev // Returns true if operation is inside the coroutine.
794de7a4e53SEugene Zhulenev auto isInCoroutine = [&](Operation *op) -> bool {
79558ceae95SRiver Riddle auto parentFunc = op->getParentOfType<func::FuncOp>();
7966cca6b9aSyijiagu return coros->find(parentFunc) != coros->end();
797de7a4e53SEugene Zhulenev };
798de7a4e53SEugene Zhulenev
79925f80e16SEugene Zhulenev // Lower async operations to async.runtime operations.
80025f80e16SEugene Zhulenev MLIRContext *ctx = module->getContext();
801dc4e913bSChris Lattner RewritePatternSet asyncPatterns(ctx);
80225f80e16SEugene Zhulenev
803de7a4e53SEugene Zhulenev // Conversion to async runtime augments original CFG with the coroutine CFG,
804de7a4e53SEugene Zhulenev // and we have to make sure that structured control flow operations with async
805de7a4e53SEugene Zhulenev // operations in nested regions will be converted to branch-based control flow
806de7a4e53SEugene Zhulenev // before we add the coroutine basic blocks.
807ace01605SRiver Riddle populateSCFToControlFlowConversionPatterns(asyncPatterns);
808de7a4e53SEugene Zhulenev
80925f80e16SEugene Zhulenev // Async lowering does not use type converter because it must preserve all
81025f80e16SEugene Zhulenev // types for async.runtime operations.
811dc4e913bSChris Lattner asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
812f81f8808Syijiagu
8136cca6b9aSyijiagu asyncPatterns
8146cca6b9aSyijiagu .add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
8156cca6b9aSyijiagu ctx, coros, /*should_lower_blocking_wait=*/true);
81625f80e16SEugene Zhulenev
81739957aa4SEugene Zhulenev // Lower assertions to conditional branches into error blocks.
8186cca6b9aSyijiagu asyncPatterns.add<YieldOpLowering, AssertOpLowering>(ctx, coros);
81939957aa4SEugene Zhulenev
82025f80e16SEugene Zhulenev // All high level async operations must be lowered to the runtime operations.
82125f80e16SEugene Zhulenev ConversionTarget runtimeTarget(*ctx);
822f81f8808Syijiagu runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
82325f80e16SEugene Zhulenev runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
8246cca6b9aSyijiagu runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
82525f80e16SEugene Zhulenev
826de7a4e53SEugene Zhulenev // Decide if structured control flow has to be lowered to branch-based CFG.
827de7a4e53SEugene Zhulenev runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) {
828de7a4e53SEugene Zhulenev auto walkResult = op->walk([&](Operation *nested) {
829de7a4e53SEugene Zhulenev bool isAsync = isa<async::AsyncDialect>(nested->getDialect());
830de7a4e53SEugene Zhulenev return isAsync && isInCoroutine(nested) ? WalkResult::interrupt()
831de7a4e53SEugene Zhulenev : WalkResult::advance();
832de7a4e53SEugene Zhulenev });
833de7a4e53SEugene Zhulenev return !walkResult.wasInterrupted();
834de7a4e53SEugene Zhulenev });
835ace01605SRiver Riddle runtimeTarget.addLegalOp<cf::AssertOp, arith::XOrIOp, arith::ConstantOp,
83623aa5a74SRiver Riddle func::ConstantOp, cf::BranchOp, cf::CondBranchOp>();
837de7a4e53SEugene Zhulenev
8388f23fac4SEugene Zhulenev // Assertions must be converted to runtime errors inside async functions.
839ace01605SRiver Riddle runtimeTarget.addDynamicallyLegalOp<cf::AssertOp>(
840ace01605SRiver Riddle [&](cf::AssertOp op) -> bool {
84158ceae95SRiver Riddle auto func = op->getParentOfType<func::FuncOp>();
84269ffd49cSKazu Hirata return !coros->contains(func);
8438f23fac4SEugene Zhulenev });
84439957aa4SEugene Zhulenev
84525f80e16SEugene Zhulenev if (failed(applyPartialConversion(module, runtimeTarget,
84625f80e16SEugene Zhulenev std::move(asyncPatterns)))) {
84725f80e16SEugene Zhulenev signalPassFailure();
84825f80e16SEugene Zhulenev return;
84925f80e16SEugene Zhulenev }
85025f80e16SEugene Zhulenev }
85125f80e16SEugene Zhulenev
8526cca6b9aSyijiagu //===----------------------------------------------------------------------===//
populateAsyncFuncToAsyncRuntimeConversionPatterns(RewritePatternSet & patterns,ConversionTarget & target)8536cca6b9aSyijiagu void mlir::populateAsyncFuncToAsyncRuntimeConversionPatterns(
8546cca6b9aSyijiagu RewritePatternSet &patterns, ConversionTarget &target) {
8556cca6b9aSyijiagu // Functions with coroutine CFG setups, which are results of converting
8566cca6b9aSyijiagu // async.func.
8576cca6b9aSyijiagu FuncCoroMapPtr coros =
8586cca6b9aSyijiagu std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
8596cca6b9aSyijiagu MLIRContext *ctx = patterns.getContext();
8606cca6b9aSyijiagu // Lower async.func to func.func with coroutine cfg.
8616cca6b9aSyijiagu patterns.add<AsyncCallOpLowering>(ctx);
8626cca6b9aSyijiagu patterns.add<AsyncFuncOpLowering, AsyncReturnOpLowering>(ctx, coros);
8636cca6b9aSyijiagu
8646cca6b9aSyijiagu patterns.add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
8656cca6b9aSyijiagu ctx, coros, /*should_lower_blocking_wait=*/false);
8666cca6b9aSyijiagu patterns.add<YieldOpLowering, AssertOpLowering>(ctx, coros);
8676cca6b9aSyijiagu
8686cca6b9aSyijiagu target.addDynamicallyLegalOp<AwaitOp, AwaitAllOp, YieldOp, cf::AssertOp>(
8696cca6b9aSyijiagu [coros](Operation *op) {
870a5ddd920Syijiagu auto exec = op->getParentOfType<ExecuteOp>();
8716cca6b9aSyijiagu auto func = op->getParentOfType<func::FuncOp>();
87269ffd49cSKazu Hirata return exec || !coros->contains(func);
8736cca6b9aSyijiagu });
8746cca6b9aSyijiagu }
8756cca6b9aSyijiagu
runOnOperation()8766cca6b9aSyijiagu void AsyncFuncToAsyncRuntimePass::runOnOperation() {
8776cca6b9aSyijiagu ModuleOp module = getOperation();
8786cca6b9aSyijiagu
8796cca6b9aSyijiagu // Lower async operations to async.runtime operations.
8806cca6b9aSyijiagu MLIRContext *ctx = module->getContext();
8816cca6b9aSyijiagu RewritePatternSet asyncPatterns(ctx);
8826cca6b9aSyijiagu ConversionTarget runtimeTarget(*ctx);
8836cca6b9aSyijiagu
8846cca6b9aSyijiagu // Lower async.func to func.func with coroutine cfg.
8856cca6b9aSyijiagu populateAsyncFuncToAsyncRuntimeConversionPatterns(asyncPatterns,
8866cca6b9aSyijiagu runtimeTarget);
8876cca6b9aSyijiagu
8886cca6b9aSyijiagu runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
8896cca6b9aSyijiagu runtimeTarget.addIllegalOp<async::FuncOp, async::CallOp, async::ReturnOp>();
8906cca6b9aSyijiagu
8916cca6b9aSyijiagu runtimeTarget.addLegalOp<arith::XOrIOp, arith::ConstantOp, func::ConstantOp,
8926cca6b9aSyijiagu cf::BranchOp, cf::CondBranchOp>();
8936cca6b9aSyijiagu
8946cca6b9aSyijiagu if (failed(applyPartialConversion(module, runtimeTarget,
8956cca6b9aSyijiagu std::move(asyncPatterns)))) {
8966cca6b9aSyijiagu signalPassFailure();
8976cca6b9aSyijiagu return;
8986cca6b9aSyijiagu }
8996cca6b9aSyijiagu }
9006cca6b9aSyijiagu
createAsyncToAsyncRuntimePass()90125f80e16SEugene Zhulenev std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() {
90225f80e16SEugene Zhulenev return std::make_unique<AsyncToAsyncRuntimePass>();
90325f80e16SEugene Zhulenev }
9046cca6b9aSyijiagu
9056cca6b9aSyijiagu std::unique_ptr<OperationPass<ModuleOp>>
createAsyncFuncToAsyncRuntimePass()9066cca6b9aSyijiagu mlir::createAsyncFuncToAsyncRuntimePass() {
9076cca6b9aSyijiagu return std::make_unique<AsyncFuncToAsyncRuntimePass>();
9086cca6b9aSyijiagu }
909