xref: /llvm-project/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp (revision ea2d9383a23ca17b9240ad64c2adc5f2b5a73dc0)
//===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering from high level async operations to async.coro
// and async.runtime operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Async/Passes.h"

#include "PassDetail.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"
#include <memory>
#include <optional>
#include <utility>

3467d0d7acSMichele Scuttari namespace mlir {
3567d0d7acSMichele Scuttari #define GEN_PASS_DEF_ASYNCTOASYNCRUNTIME
366cca6b9aSyijiagu #define GEN_PASS_DEF_ASYNCFUNCTOASYNCRUNTIME
3767d0d7acSMichele Scuttari #include "mlir/Dialect/Async/Passes.h.inc"
3867d0d7acSMichele Scuttari } // namespace mlir
3967d0d7acSMichele Scuttari 
4025f80e16SEugene Zhulenev using namespace mlir;
4125f80e16SEugene Zhulenev using namespace mlir::async;
4225f80e16SEugene Zhulenev 
4325f80e16SEugene Zhulenev #define DEBUG_TYPE "async-to-async-runtime"
4425f80e16SEugene Zhulenev // Prefix for functions outlined from `async.execute` op regions.
4525f80e16SEugene Zhulenev static constexpr const char kAsyncFnPrefix[] = "async_execute_fn";
4625f80e16SEugene Zhulenev 
4725f80e16SEugene Zhulenev namespace {
4825f80e16SEugene Zhulenev 
4925f80e16SEugene Zhulenev class AsyncToAsyncRuntimePass
5067d0d7acSMichele Scuttari     : public impl::AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> {
5125f80e16SEugene Zhulenev public:
5225f80e16SEugene Zhulenev   AsyncToAsyncRuntimePass() = default;
5325f80e16SEugene Zhulenev   void runOnOperation() override;
5425f80e16SEugene Zhulenev };
5525f80e16SEugene Zhulenev 
5625f80e16SEugene Zhulenev } // namespace
5725f80e16SEugene Zhulenev 
586cca6b9aSyijiagu namespace {
596cca6b9aSyijiagu 
606cca6b9aSyijiagu class AsyncFuncToAsyncRuntimePass
616cca6b9aSyijiagu     : public impl::AsyncFuncToAsyncRuntimeBase<AsyncFuncToAsyncRuntimePass> {
626cca6b9aSyijiagu public:
636cca6b9aSyijiagu   AsyncFuncToAsyncRuntimePass() = default;
646cca6b9aSyijiagu   void runOnOperation() override;
656cca6b9aSyijiagu };
666cca6b9aSyijiagu 
676cca6b9aSyijiagu } // namespace
686cca6b9aSyijiagu 
6925f80e16SEugene Zhulenev /// Function targeted for coroutine transformation has two additional blocks at
7025f80e16SEugene Zhulenev /// the end: coroutine cleanup and coroutine suspension.
7125f80e16SEugene Zhulenev ///
7225f80e16SEugene Zhulenev /// async.await op lowering additionaly creates a resume block for each
7325f80e16SEugene Zhulenev /// operation to enable non-blocking waiting via coroutine suspension.
7425f80e16SEugene Zhulenev namespace {
7525f80e16SEugene Zhulenev struct CoroMachinery {
7658ceae95SRiver Riddle   func::FuncOp func;
7739957aa4SEugene Zhulenev 
78f81f8808Syijiagu   // Async function returns an optional token, followed by some async values
79f81f8808Syijiagu   //
80f81f8808Syijiagu   //  async.func @foo() -> !async.value<T> {
81f81f8808Syijiagu   //    %cst = arith.constant 42.0 : T
82f81f8808Syijiagu   //    return %cst: T
83f81f8808Syijiagu   //  }
8425f80e16SEugene Zhulenev   // Async execute region returns a completion token, and an async value for
8525f80e16SEugene Zhulenev   // each yielded value.
8625f80e16SEugene Zhulenev   //
8725f80e16SEugene Zhulenev   //   %token, %result = async.execute -> !async.value<T> {
88cb3aa49eSMogball   //     %0 = arith.constant ... : T
8925f80e16SEugene Zhulenev   //     async.yield %0 : T
9025f80e16SEugene Zhulenev   //   }
910a81ace0SKazu Hirata   std::optional<Value> asyncToken;          // returned completion token
9225f80e16SEugene Zhulenev   llvm::SmallVector<Value, 4> returnValues; // returned async values
9325f80e16SEugene Zhulenev 
94a5aa7836SRiver Riddle   Value coroHandle; // coroutine handle (!async.coro.getHandle value)
951c144410Sbakhtiyar   Block *entry;     // coroutine entry block
96f3b7b300SKazu Hirata   std::optional<Block *> setError; // set returned values to error state
9725f80e16SEugene Zhulenev   Block *cleanup;                  // coroutine cleanup block
98af562fd2SYunlong Liu 
99af562fd2SYunlong Liu   // Coroutine cleanup block for destroy after the coroutine is resumed,
100af562fd2SYunlong Liu   //   e.g. async.coro.suspend state, [suspend], [resume], [destroy]
101af562fd2SYunlong Liu   //
102af562fd2SYunlong Liu   // This cleanup block is a duplicate of the cleanup block followed by the
103af562fd2SYunlong Liu   // resume block. The purpose of having a duplicate cleanup block for destroy
104af562fd2SYunlong Liu   // is to make the CFG clear so that the control flow analysis won't confuse.
105af562fd2SYunlong Liu   //
106af562fd2SYunlong Liu   // The overall structure of the lowered CFG can be the following,
107af562fd2SYunlong Liu   //
108af562fd2SYunlong Liu   //     Entry (calling async.coro.suspend)
109af562fd2SYunlong Liu   //       |                \
110af562fd2SYunlong Liu   //     Resume           Destroy (duplicate of Cleanup)
111af562fd2SYunlong Liu   //       |                 |
112af562fd2SYunlong Liu   //     Cleanup             |
113af562fd2SYunlong Liu   //       |                 /
114af562fd2SYunlong Liu   //      End (ends the corontine)
115af562fd2SYunlong Liu   //
116af562fd2SYunlong Liu   // If there is resume-specific cleanup logic, it can go into the Cleanup
117af562fd2SYunlong Liu   // block but not the destroy block. Otherwise, it can fail block dominance
118af562fd2SYunlong Liu   // check.
119af562fd2SYunlong Liu   Block *cleanupForDestroy;
12025f80e16SEugene Zhulenev   Block *suspend; // coroutine suspension block
12125f80e16SEugene Zhulenev };
12225f80e16SEugene Zhulenev } // namespace
12325f80e16SEugene Zhulenev 
1246cca6b9aSyijiagu using FuncCoroMapPtr =
1256cca6b9aSyijiagu     std::shared_ptr<llvm::DenseMap<func::FuncOp, CoroMachinery>>;
1266cca6b9aSyijiagu 
1276ea22d46Sbakhtiyar /// Utility to partially update the regular function CFG to the coroutine CFG
1286ea22d46Sbakhtiyar /// compatible with LLVM coroutines switched-resume lowering using
1291c144410Sbakhtiyar /// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block
1301c144410Sbakhtiyar /// that branches into preexisting entry block. Also inserts trailing blocks.
1316ea22d46Sbakhtiyar ///
132f81f8808Syijiagu /// The result types of the passed `func` start with an optional `async.token`
1336ea22d46Sbakhtiyar /// and be continued with some number of `async.value`s.
1346ea22d46Sbakhtiyar ///
13525f80e16SEugene Zhulenev /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html
13625f80e16SEugene Zhulenev ///
13725f80e16SEugene Zhulenev ///  - `entry` block sets up the coroutine.
13839957aa4SEugene Zhulenev ///  - `set_error` block sets completion token and async values state to error.
13925f80e16SEugene Zhulenev ///  - `cleanup` block cleans up the coroutine state.
14025f80e16SEugene Zhulenev ///  - `suspend block after the @llvm.coro.end() defines what value will be
14125f80e16SEugene Zhulenev ///    returned to the initial caller of a coroutine. Everything before the
14225f80e16SEugene Zhulenev ///    @llvm.coro.end() will be executed at every suspension point.
14325f80e16SEugene Zhulenev ///
14425f80e16SEugene Zhulenev /// Coroutine structure (only the important bits):
14525f80e16SEugene Zhulenev ///
1466ea22d46Sbakhtiyar ///   func @some_fn(<function-arguments>) -> (!async.token, !async.value<T>)
14725f80e16SEugene Zhulenev ///   {
14825f80e16SEugene Zhulenev ///     ^entry(<function-arguments>):
14925f80e16SEugene Zhulenev ///       %token = <async token> : !async.token    // create async runtime token
15025f80e16SEugene Zhulenev ///       %value = <async value> : !async.value<T> // create async value
151a5aa7836SRiver Riddle ///       %id = async.coro.getId                   // create a coroutine id
15225f80e16SEugene Zhulenev ///       %hdl = async.coro.begin %id              // create a coroutine handle
153ace01605SRiver Riddle ///       cf.br ^preexisting_entry_block
1546ea22d46Sbakhtiyar ///
1551c144410Sbakhtiyar ///     /*  preexisting blocks modified to branch to the cleanup block */
15625f80e16SEugene Zhulenev ///
15739957aa4SEugene Zhulenev ///     ^set_error: // this block created lazily only if needed (see code below)
15839957aa4SEugene Zhulenev ///       async.runtime.set_error %token : !async.token
15939957aa4SEugene Zhulenev ///       async.runtime.set_error %value : !async.value<T>
160ace01605SRiver Riddle ///       cf.br ^cleanup
16139957aa4SEugene Zhulenev ///
16225f80e16SEugene Zhulenev ///     ^cleanup:
16325f80e16SEugene Zhulenev ///       async.coro.free %hdl // delete the coroutine state
164ace01605SRiver Riddle ///       cf.br ^suspend
16525f80e16SEugene Zhulenev ///
16625f80e16SEugene Zhulenev ///     ^suspend:
16725f80e16SEugene Zhulenev ///       async.coro.end %hdl // marks the end of a coroutine
16825f80e16SEugene Zhulenev ///       return %token, %value : !async.token, !async.value<T>
16925f80e16SEugene Zhulenev ///   }
17025f80e16SEugene Zhulenev ///
setupCoroMachinery(func::FuncOp func)17158ceae95SRiver Riddle static CoroMachinery setupCoroMachinery(func::FuncOp func) {
1726ea22d46Sbakhtiyar   assert(!func.getBlocks().empty() && "Function must have an entry block");
17325f80e16SEugene Zhulenev 
17425f80e16SEugene Zhulenev   MLIRContext *ctx = func.getContext();
1756ea22d46Sbakhtiyar   Block *entryBlock = &func.getBlocks().front();
1761c144410Sbakhtiyar   Block *originalEntryBlock =
1771c144410Sbakhtiyar       entryBlock->splitBlock(entryBlock->getOperations().begin());
17825f80e16SEugene Zhulenev   auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock);
17925f80e16SEugene Zhulenev 
18025f80e16SEugene Zhulenev   // ------------------------------------------------------------------------ //
18125f80e16SEugene Zhulenev   // Allocate async token/values that we will return from a ramp function.
18225f80e16SEugene Zhulenev   // ------------------------------------------------------------------------ //
183f81f8808Syijiagu 
184f81f8808Syijiagu   // We treat TokenType as state update marker to represent side-effects of
185f81f8808Syijiagu   // async computations
18634a35a8bSMartin Erhart   bool isStateful = isa<TokenType>(func.getResultTypes().front());
187f81f8808Syijiagu 
1880a81ace0SKazu Hirata   std::optional<Value> retToken;
189f81f8808Syijiagu   if (isStateful)
190f81f8808Syijiagu     retToken.emplace(builder.create<RuntimeCreateOp>(TokenType::get(ctx)));
19125f80e16SEugene Zhulenev 
19225f80e16SEugene Zhulenev   llvm::SmallVector<Value, 4> retValues;
19334a35a8bSMartin Erhart   ArrayRef<Type> resValueTypes =
19434a35a8bSMartin Erhart       isStateful ? func.getResultTypes().drop_front() : func.getResultTypes();
195f81f8808Syijiagu   for (auto resType : resValueTypes)
196a5aa7836SRiver Riddle     retValues.emplace_back(
197a5aa7836SRiver Riddle         builder.create<RuntimeCreateOp>(resType).getResult());
19825f80e16SEugene Zhulenev 
19925f80e16SEugene Zhulenev   // ------------------------------------------------------------------------ //
20025f80e16SEugene Zhulenev   // Initialize coroutine: get coroutine id and coroutine handle.
20125f80e16SEugene Zhulenev   // ------------------------------------------------------------------------ //
20225f80e16SEugene Zhulenev   auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx));
20325f80e16SEugene Zhulenev   auto coroHdlOp =
204a5aa7836SRiver Riddle       builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.getId());
205ace01605SRiver Riddle   builder.create<cf::BranchOp>(originalEntryBlock);
20625f80e16SEugene Zhulenev 
20725f80e16SEugene Zhulenev   Block *cleanupBlock = func.addBlock();
208af562fd2SYunlong Liu   Block *cleanupBlockForDestroy = func.addBlock();
20925f80e16SEugene Zhulenev   Block *suspendBlock = func.addBlock();
21025f80e16SEugene Zhulenev 
21125f80e16SEugene Zhulenev   // ------------------------------------------------------------------------ //
212af562fd2SYunlong Liu   // Coroutine cleanup blocks: deallocate coroutine frame, free the memory.
21325f80e16SEugene Zhulenev   // ------------------------------------------------------------------------ //
214af562fd2SYunlong Liu   auto buildCleanupBlock = [&](Block *cb) {
215af562fd2SYunlong Liu     builder.setInsertionPointToStart(cb);
216a5aa7836SRiver Riddle     builder.create<CoroFreeOp>(coroIdOp.getId(), coroHdlOp.getHandle());
21725f80e16SEugene Zhulenev 
21825f80e16SEugene Zhulenev     // Branch into the suspend block.
219ace01605SRiver Riddle     builder.create<cf::BranchOp>(suspendBlock);
220af562fd2SYunlong Liu   };
221af562fd2SYunlong Liu   buildCleanupBlock(cleanupBlock);
222af562fd2SYunlong Liu   buildCleanupBlock(cleanupBlockForDestroy);
22325f80e16SEugene Zhulenev 
22425f80e16SEugene Zhulenev   // ------------------------------------------------------------------------ //
22525f80e16SEugene Zhulenev   // Coroutine suspend block: mark the end of a coroutine and return allocated
22625f80e16SEugene Zhulenev   // async token.
22725f80e16SEugene Zhulenev   // ------------------------------------------------------------------------ //
22825f80e16SEugene Zhulenev   builder.setInsertionPointToStart(suspendBlock);
22925f80e16SEugene Zhulenev 
23025f80e16SEugene Zhulenev   // Mark the end of a coroutine: async.coro.end
231a5aa7836SRiver Riddle   builder.create<CoroEndOp>(coroHdlOp.getHandle());
23225f80e16SEugene Zhulenev 
233f81f8808Syijiagu   // Return created optional `async.token` and `async.values` from the suspend
234f81f8808Syijiagu   // block. This will be the return value of a coroutine ramp function.
235f81f8808Syijiagu   SmallVector<Value, 4> ret;
236f81f8808Syijiagu   if (retToken)
237f81f8808Syijiagu     ret.push_back(*retToken);
23825f80e16SEugene Zhulenev   ret.insert(ret.end(), retValues.begin(), retValues.end());
23923aa5a74SRiver Riddle   builder.create<func::ReturnOp>(ret);
24025f80e16SEugene Zhulenev 
24125f80e16SEugene Zhulenev   // `async.await` op lowering will create resume blocks for async
24225f80e16SEugene Zhulenev   // continuations, and will conditionally branch to cleanup or suspend blocks.
24325f80e16SEugene Zhulenev 
244c75cedc2SChuanqi Xu   // The switch-resumed API based coroutine should be marked with
245be690ea3Syonillasky   // presplitcoroutine attribute to mark the function as a coroutine.
246735e6c40SChuanqi Xu   func->setAttr("passthrough", builder.getArrayAttr(
247735e6c40SChuanqi Xu                                    StringAttr::get(ctx, "presplitcoroutine")));
248c75cedc2SChuanqi Xu 
24925f80e16SEugene Zhulenev   CoroMachinery machinery;
25039957aa4SEugene Zhulenev   machinery.func = func;
25125f80e16SEugene Zhulenev   machinery.asyncToken = retToken;
25225f80e16SEugene Zhulenev   machinery.returnValues = retValues;
253a5aa7836SRiver Riddle   machinery.coroHandle = coroHdlOp.getHandle();
2541c144410Sbakhtiyar   machinery.entry = entryBlock;
2551a36588eSKazu Hirata   machinery.setError = std::nullopt; // created lazily only if needed
25625f80e16SEugene Zhulenev   machinery.cleanup = cleanupBlock;
257af562fd2SYunlong Liu   machinery.cleanupForDestroy = cleanupBlockForDestroy;
25825f80e16SEugene Zhulenev   machinery.suspend = suspendBlock;
25925f80e16SEugene Zhulenev   return machinery;
26025f80e16SEugene Zhulenev }
26125f80e16SEugene Zhulenev 
26239957aa4SEugene Zhulenev // Lazily creates `set_error` block only if it is required for lowering to the
26339957aa4SEugene Zhulenev // runtime operations (see for example lowering of assert operation).
setupSetErrorBlock(CoroMachinery & coro)26439957aa4SEugene Zhulenev static Block *setupSetErrorBlock(CoroMachinery &coro) {
26539957aa4SEugene Zhulenev   if (coro.setError)
266f81f8808Syijiagu     return *coro.setError;
26739957aa4SEugene Zhulenev 
26839957aa4SEugene Zhulenev   coro.setError = coro.func.addBlock();
269f81f8808Syijiagu   (*coro.setError)->moveBefore(coro.cleanup);
27039957aa4SEugene Zhulenev 
27139957aa4SEugene Zhulenev   auto builder =
272f81f8808Syijiagu       ImplicitLocOpBuilder::atBlockBegin(coro.func->getLoc(), *coro.setError);
27339957aa4SEugene Zhulenev 
27439957aa4SEugene Zhulenev   // Coroutine set_error block: set error on token and all returned values.
275f81f8808Syijiagu   if (coro.asyncToken)
276f81f8808Syijiagu     builder.create<RuntimeSetErrorOp>(*coro.asyncToken);
277f81f8808Syijiagu 
27839957aa4SEugene Zhulenev   for (Value retValue : coro.returnValues)
27939957aa4SEugene Zhulenev     builder.create<RuntimeSetErrorOp>(retValue);
28039957aa4SEugene Zhulenev 
28139957aa4SEugene Zhulenev   // Branch into the cleanup block.
282ace01605SRiver Riddle   builder.create<cf::BranchOp>(coro.cleanup);
28339957aa4SEugene Zhulenev 
284f81f8808Syijiagu   return *coro.setError;
28539957aa4SEugene Zhulenev }
28639957aa4SEugene Zhulenev 
//===----------------------------------------------------------------------===//
// async.execute op outlining to the coroutine functions.
//===----------------------------------------------------------------------===//

29125f80e16SEugene Zhulenev /// Outline the body region attached to the `async.execute` op into a standalone
29225f80e16SEugene Zhulenev /// function.
29325f80e16SEugene Zhulenev ///
29425f80e16SEugene Zhulenev /// Note that this is not reversible transformation.
29558ceae95SRiver Riddle static std::pair<func::FuncOp, CoroMachinery>
outlineExecuteOp(SymbolTable & symbolTable,ExecuteOp execute)29625f80e16SEugene Zhulenev outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
29725f80e16SEugene Zhulenev   ModuleOp module = execute->getParentOfType<ModuleOp>();
29825f80e16SEugene Zhulenev 
29925f80e16SEugene Zhulenev   MLIRContext *ctx = module.getContext();
30025f80e16SEugene Zhulenev   Location loc = execute.getLoc();
30125f80e16SEugene Zhulenev 
302b537c5b4SEugene Zhulenev   // Make sure that all constants will be inside the outlined async function to
303b537c5b4SEugene Zhulenev   // reduce the number of function arguments.
304a5aa7836SRiver Riddle   cloneConstantsIntoTheRegion(execute.getBodyRegion());
305b537c5b4SEugene Zhulenev 
30625f80e16SEugene Zhulenev   // Collect all outlined function inputs.
307a5aa7836SRiver Riddle   SetVector<mlir::Value> functionInputs(execute.getDependencies().begin(),
308a5aa7836SRiver Riddle                                         execute.getDependencies().end());
309a5aa7836SRiver Riddle   functionInputs.insert(execute.getBodyOperands().begin(),
310a5aa7836SRiver Riddle                         execute.getBodyOperands().end());
311a5aa7836SRiver Riddle   getUsedValuesDefinedAbove(execute.getBodyRegion(), functionInputs);
31225f80e16SEugene Zhulenev 
31325f80e16SEugene Zhulenev   // Collect types for the outlined function inputs and outputs.
31425f80e16SEugene Zhulenev   auto typesRange = llvm::map_range(
31525f80e16SEugene Zhulenev       functionInputs, [](Value value) { return value.getType(); });
31625f80e16SEugene Zhulenev   SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end());
31725f80e16SEugene Zhulenev   auto outputTypes = execute.getResultTypes();
31825f80e16SEugene Zhulenev 
31925f80e16SEugene Zhulenev   auto funcType = FunctionType::get(ctx, inputTypes, outputTypes);
32025f80e16SEugene Zhulenev   auto funcAttrs = ArrayRef<NamedAttribute>();
32125f80e16SEugene Zhulenev 
32225f80e16SEugene Zhulenev   // TODO: Derive outlined function name from the parent FuncOp (support
32325f80e16SEugene Zhulenev   // multiple nested async.execute operations).
32458ceae95SRiver Riddle   func::FuncOp func =
32558ceae95SRiver Riddle       func::FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
326973ddb7dSMehdi Amini   symbolTable.insert(func);
32725f80e16SEugene Zhulenev 
32825f80e16SEugene Zhulenev   SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private);
3291c144410Sbakhtiyar   auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, func.addEntryBlock());
33025f80e16SEugene Zhulenev 
3311c144410Sbakhtiyar   // Prepare for coroutine conversion by creating the body of the function.
3321c144410Sbakhtiyar   {
333a5aa7836SRiver Riddle     size_t numDependencies = execute.getDependencies().size();
334a5aa7836SRiver Riddle     size_t numOperands = execute.getBodyOperands().size();
33525f80e16SEugene Zhulenev 
33625f80e16SEugene Zhulenev     // Await on all dependencies before starting to execute the body region.
33725f80e16SEugene Zhulenev     for (size_t i = 0; i < numDependencies; ++i)
33825f80e16SEugene Zhulenev       builder.create<AwaitOp>(func.getArgument(i));
33925f80e16SEugene Zhulenev 
34025f80e16SEugene Zhulenev     // Await on all async value operands and unwrap the payload.
34125f80e16SEugene Zhulenev     SmallVector<Value, 4> unwrappedOperands(numOperands);
34225f80e16SEugene Zhulenev     for (size_t i = 0; i < numOperands; ++i) {
34325f80e16SEugene Zhulenev       Value operand = func.getArgument(numDependencies + i);
344a5aa7836SRiver Riddle       unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).getResult();
34525f80e16SEugene Zhulenev     }
34625f80e16SEugene Zhulenev 
34725f80e16SEugene Zhulenev     // Map from function inputs defined above the execute op to the function
34825f80e16SEugene Zhulenev     // arguments.
3494d67b278SJeff Niu     IRMapping valueMapping;
35025f80e16SEugene Zhulenev     valueMapping.map(functionInputs, func.getArguments());
351a5aa7836SRiver Riddle     valueMapping.map(execute.getBodyRegion().getArguments(), unwrappedOperands);
35225f80e16SEugene Zhulenev 
35325f80e16SEugene Zhulenev     // Clone all operations from the execute operation body into the outlined
35425f80e16SEugene Zhulenev     // function body.
355a5aa7836SRiver Riddle     for (Operation &op : execute.getBodyRegion().getOps())
35625f80e16SEugene Zhulenev       builder.clone(op, valueMapping);
3571c144410Sbakhtiyar   }
3581c144410Sbakhtiyar 
3591c144410Sbakhtiyar   // Adding entry/cleanup/suspend blocks.
3601c144410Sbakhtiyar   CoroMachinery coro = setupCoroMachinery(func);
3611c144410Sbakhtiyar 
3621c144410Sbakhtiyar   // Suspend async function at the end of an entry block, and resume it using
3631c144410Sbakhtiyar   // Async resume operation (execution will be resumed in a thread managed by
3641c144410Sbakhtiyar   // the async runtime).
3651c144410Sbakhtiyar   {
366ace01605SRiver Riddle     cf::BranchOp branch = cast<cf::BranchOp>(coro.entry->getTerminator());
3671c144410Sbakhtiyar     builder.setInsertionPointToEnd(coro.entry);
3681c144410Sbakhtiyar 
3691c144410Sbakhtiyar     // Save the coroutine state: async.coro.save
3701c144410Sbakhtiyar     auto coroSaveOp =
3711c144410Sbakhtiyar         builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
3721c144410Sbakhtiyar 
3731c144410Sbakhtiyar     // Pass coroutine to the runtime to be resumed on a runtime managed
3741c144410Sbakhtiyar     // thread.
3751c144410Sbakhtiyar     builder.create<RuntimeResumeOp>(coro.coroHandle);
3761c144410Sbakhtiyar 
3771c144410Sbakhtiyar     // Add async.coro.suspend as a suspended block terminator.
378a5aa7836SRiver Riddle     builder.create<CoroSuspendOp>(coroSaveOp.getState(), coro.suspend,
379af562fd2SYunlong Liu                                   branch.getDest(), coro.cleanupForDestroy);
3801c144410Sbakhtiyar 
3811c144410Sbakhtiyar     branch.erase();
3821c144410Sbakhtiyar   }
38325f80e16SEugene Zhulenev 
38425f80e16SEugene Zhulenev   // Replace the original `async.execute` with a call to outlined function.
3851c144410Sbakhtiyar   {
38625f80e16SEugene Zhulenev     ImplicitLocOpBuilder callBuilder(loc, execute);
38723aa5a74SRiver Riddle     auto callOutlinedFunc = callBuilder.create<func::CallOp>(
38825f80e16SEugene Zhulenev         func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());
38925f80e16SEugene Zhulenev     execute.replaceAllUsesWith(callOutlinedFunc.getResults());
39025f80e16SEugene Zhulenev     execute.erase();
3911c144410Sbakhtiyar   }
39225f80e16SEugene Zhulenev 
39325f80e16SEugene Zhulenev   return {func, coro};
39425f80e16SEugene Zhulenev }
39525f80e16SEugene Zhulenev 
39625f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
397d43b2360SEugene Zhulenev // Convert async.create_group operation to async.runtime.create_group
39825f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
39925f80e16SEugene Zhulenev 
40025f80e16SEugene Zhulenev namespace {
40125f80e16SEugene Zhulenev class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> {
40225f80e16SEugene Zhulenev public:
40325f80e16SEugene Zhulenev   using OpConversionPattern::OpConversionPattern;
40425f80e16SEugene Zhulenev 
40525f80e16SEugene Zhulenev   LogicalResult
matchAndRewrite(CreateGroupOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const406b54c724bSRiver Riddle   matchAndRewrite(CreateGroupOp op, OpAdaptor adaptor,
40725f80e16SEugene Zhulenev                   ConversionPatternRewriter &rewriter) const override {
408d43b2360SEugene Zhulenev     rewriter.replaceOpWithNewOp<RuntimeCreateGroupOp>(
409b54c724bSRiver Riddle         op, GroupType::get(op->getContext()), adaptor.getOperands());
41025f80e16SEugene Zhulenev     return success();
41125f80e16SEugene Zhulenev   }
41225f80e16SEugene Zhulenev };
41325f80e16SEugene Zhulenev } // namespace
41425f80e16SEugene Zhulenev 
41525f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
41625f80e16SEugene Zhulenev // Convert async.add_to_group operation to async.runtime.add_to_group.
41725f80e16SEugene Zhulenev //===----------------------------------------------------------------------===//
41825f80e16SEugene Zhulenev 
41925f80e16SEugene Zhulenev namespace {
42025f80e16SEugene Zhulenev class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> {
42125f80e16SEugene Zhulenev public:
42225f80e16SEugene Zhulenev   using OpConversionPattern::OpConversionPattern;
42325f80e16SEugene Zhulenev 
42425f80e16SEugene Zhulenev   LogicalResult
matchAndRewrite(AddToGroupOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const425b54c724bSRiver Riddle   matchAndRewrite(AddToGroupOp op, OpAdaptor adaptor,
42625f80e16SEugene Zhulenev                   ConversionPatternRewriter &rewriter) const override {
42725f80e16SEugene Zhulenev     rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>(
428b54c724bSRiver Riddle         op, rewriter.getIndexType(), adaptor.getOperands());
42925f80e16SEugene Zhulenev     return success();
43025f80e16SEugene Zhulenev   }
43125f80e16SEugene Zhulenev };
43225f80e16SEugene Zhulenev } // namespace
43325f80e16SEugene Zhulenev 
//===----------------------------------------------------------------------===//
// Convert async.func, async.return and async.call operations to non-blocking
// operations based on LLVM coroutines.
//===----------------------------------------------------------------------===//

439f81f8808Syijiagu namespace {
440f81f8808Syijiagu 
441f81f8808Syijiagu //===----------------------------------------------------------------------===//
442f81f8808Syijiagu // Convert async.func operation to func.func
443f81f8808Syijiagu //===----------------------------------------------------------------------===//
444f81f8808Syijiagu 
445f81f8808Syijiagu class AsyncFuncOpLowering : public OpConversionPattern<async::FuncOp> {
446f81f8808Syijiagu public:
AsyncFuncOpLowering(MLIRContext * ctx,FuncCoroMapPtr coros)4476cca6b9aSyijiagu   AsyncFuncOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
44898b13979SMehdi Amini       : OpConversionPattern<async::FuncOp>(ctx), coros(std::move(coros)) {}
449f81f8808Syijiagu 
450f81f8808Syijiagu   LogicalResult
matchAndRewrite(async::FuncOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const451f81f8808Syijiagu   matchAndRewrite(async::FuncOp op, OpAdaptor adaptor,
452f81f8808Syijiagu                   ConversionPatternRewriter &rewriter) const override {
453f81f8808Syijiagu     Location loc = op->getLoc();
454f81f8808Syijiagu 
455f81f8808Syijiagu     auto newFuncOp =
456f81f8808Syijiagu         rewriter.create<func::FuncOp>(loc, op.getName(), op.getFunctionType());
457f81f8808Syijiagu 
458f81f8808Syijiagu     SymbolTable::setSymbolVisibility(newFuncOp,
459f81f8808Syijiagu                                      SymbolTable::getSymbolVisibility(op));
460f81f8808Syijiagu     // Copy over all attributes other than the name.
461f81f8808Syijiagu     for (const auto &namedAttr : op->getAttrs()) {
462f81f8808Syijiagu       if (namedAttr.getName() != SymbolTable::getSymbolAttrName())
463f81f8808Syijiagu         newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
464f81f8808Syijiagu     }
465f81f8808Syijiagu 
466f81f8808Syijiagu     rewriter.inlineRegionBefore(op.getBody(), newFuncOp.getBody(),
467f81f8808Syijiagu                                 newFuncOp.end());
468f81f8808Syijiagu 
469f81f8808Syijiagu     CoroMachinery coro = setupCoroMachinery(newFuncOp);
47098b13979SMehdi Amini     (*coros)[newFuncOp] = coro;
471f81f8808Syijiagu     // no initial suspend, we should hot-start
472f81f8808Syijiagu 
473f81f8808Syijiagu     rewriter.eraseOp(op);
474f81f8808Syijiagu     return success();
475f81f8808Syijiagu   }
476f81f8808Syijiagu 
477f81f8808Syijiagu private:
47898b13979SMehdi Amini   FuncCoroMapPtr coros;
479f81f8808Syijiagu };
480f81f8808Syijiagu 
481f81f8808Syijiagu //===----------------------------------------------------------------------===//
482f81f8808Syijiagu // Convert async.call operation to func.call
483f81f8808Syijiagu //===----------------------------------------------------------------------===//
484f81f8808Syijiagu 
485f81f8808Syijiagu class AsyncCallOpLowering : public OpConversionPattern<async::CallOp> {
486f81f8808Syijiagu public:
AsyncCallOpLowering(MLIRContext * ctx)487f81f8808Syijiagu   AsyncCallOpLowering(MLIRContext *ctx)
488f81f8808Syijiagu       : OpConversionPattern<async::CallOp>(ctx) {}
489f81f8808Syijiagu 
490f81f8808Syijiagu   LogicalResult
matchAndRewrite(async::CallOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const491f81f8808Syijiagu   matchAndRewrite(async::CallOp op, OpAdaptor adaptor,
492f81f8808Syijiagu                   ConversionPatternRewriter &rewriter) const override {
493f81f8808Syijiagu     rewriter.replaceOpWithNewOp<func::CallOp>(
494f81f8808Syijiagu         op, op.getCallee(), op.getResultTypes(), op.getOperands());
495f81f8808Syijiagu     return success();
496f81f8808Syijiagu   }
497f81f8808Syijiagu };
498f81f8808Syijiagu 
//===----------------------------------------------------------------------===//
// Convert async.return operation to async.runtime operations.
//===----------------------------------------------------------------------===//
502f81f8808Syijiagu 
503f81f8808Syijiagu class AsyncReturnOpLowering : public OpConversionPattern<async::ReturnOp> {
504f81f8808Syijiagu public:
AsyncReturnOpLowering(MLIRContext * ctx,FuncCoroMapPtr coros)5056cca6b9aSyijiagu   AsyncReturnOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
50698b13979SMehdi Amini       : OpConversionPattern<async::ReturnOp>(ctx), coros(std::move(coros)) {}
507f81f8808Syijiagu 
508f81f8808Syijiagu   LogicalResult
matchAndRewrite(async::ReturnOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const509f81f8808Syijiagu   matchAndRewrite(async::ReturnOp op, OpAdaptor adaptor,
510f81f8808Syijiagu                   ConversionPatternRewriter &rewriter) const override {
511f81f8808Syijiagu     auto func = op->template getParentOfType<func::FuncOp>();
51298b13979SMehdi Amini     auto funcCoro = coros->find(func);
51398b13979SMehdi Amini     if (funcCoro == coros->end())
514f81f8808Syijiagu       return rewriter.notifyMatchFailure(
515f81f8808Syijiagu           op, "operation is not inside the async coroutine function");
516f81f8808Syijiagu 
517f81f8808Syijiagu     Location loc = op->getLoc();
518f81f8808Syijiagu     const CoroMachinery &coro = funcCoro->getSecond();
519f81f8808Syijiagu     rewriter.setInsertionPointAfter(op);
520f81f8808Syijiagu 
521f81f8808Syijiagu     // Store return values into the async values storage and switch async
522f81f8808Syijiagu     // values state to available.
523f81f8808Syijiagu     for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
524f81f8808Syijiagu       Value returnValue = std::get<0>(tuple);
525f81f8808Syijiagu       Value asyncValue = std::get<1>(tuple);
526f81f8808Syijiagu       rewriter.create<RuntimeStoreOp>(loc, returnValue, asyncValue);
527f81f8808Syijiagu       rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue);
528f81f8808Syijiagu     }
529f81f8808Syijiagu 
530f81f8808Syijiagu     if (coro.asyncToken)
531f81f8808Syijiagu       // Switch the coroutine completion token to available state.
532f81f8808Syijiagu       rewriter.create<RuntimeSetAvailableOp>(loc, *coro.asyncToken);
533f81f8808Syijiagu 
534f81f8808Syijiagu     rewriter.eraseOp(op);
535f81f8808Syijiagu     rewriter.create<cf::BranchOp>(loc, coro.cleanup);
536f81f8808Syijiagu     return success();
537f81f8808Syijiagu   }
538f81f8808Syijiagu 
539f81f8808Syijiagu private:
54098b13979SMehdi Amini   FuncCoroMapPtr coros;
541f81f8808Syijiagu };
542f81f8808Syijiagu } // namespace
543f81f8808Syijiagu 
//===----------------------------------------------------------------------===//
// Convert async.await and async.await_all operations to the async.runtime.await
// or async.runtime.await_and_resume operations.
//===----------------------------------------------------------------------===//
54825f80e16SEugene Zhulenev 
54925f80e16SEugene Zhulenev namespace {
55025f80e16SEugene Zhulenev template <typename AwaitType, typename AwaitableType>
55125f80e16SEugene Zhulenev class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
55225f80e16SEugene Zhulenev   using AwaitAdaptor = typename AwaitType::Adaptor;
55325f80e16SEugene Zhulenev 
55425f80e16SEugene Zhulenev public:
AwaitOpLoweringBase(MLIRContext * ctx,FuncCoroMapPtr coros,bool shouldLowerBlockingWait)5556cca6b9aSyijiagu   AwaitOpLoweringBase(MLIRContext *ctx, FuncCoroMapPtr coros,
55698b13979SMehdi Amini                       bool shouldLowerBlockingWait)
55798b13979SMehdi Amini       : OpConversionPattern<AwaitType>(ctx), coros(std::move(coros)),
55898b13979SMehdi Amini         shouldLowerBlockingWait(shouldLowerBlockingWait) {}
55925f80e16SEugene Zhulenev 
56025f80e16SEugene Zhulenev   LogicalResult
matchAndRewrite(AwaitType op,typename AwaitType::Adaptor adaptor,ConversionPatternRewriter & rewriter) const561b54c724bSRiver Riddle   matchAndRewrite(AwaitType op, typename AwaitType::Adaptor adaptor,
56225f80e16SEugene Zhulenev                   ConversionPatternRewriter &rewriter) const override {
56325f80e16SEugene Zhulenev     // We can only await on one the `AwaitableType` (for `await` it can be
56425f80e16SEugene Zhulenev     // a `token` or a `value`, for `await_all` it must be a `group`).
5655550c821STres Popp     if (!isa<AwaitableType>(op.getOperand().getType()))
56625f80e16SEugene Zhulenev       return rewriter.notifyMatchFailure(op, "unsupported awaitable type");
56725f80e16SEugene Zhulenev 
5686cca6b9aSyijiagu     // Check if await operation is inside the coroutine function.
56958ceae95SRiver Riddle     auto func = op->template getParentOfType<func::FuncOp>();
57098b13979SMehdi Amini     auto funcCoro = coros->find(func);
57198b13979SMehdi Amini     const bool isInCoroutine = funcCoro != coros->end();
57225f80e16SEugene Zhulenev 
57325f80e16SEugene Zhulenev     Location loc = op->getLoc();
574a5aa7836SRiver Riddle     Value operand = adaptor.getOperand();
57525f80e16SEugene Zhulenev 
576fd52b435SEugene Zhulenev     Type i1 = rewriter.getI1Type();
577fd52b435SEugene Zhulenev 
5786cca6b9aSyijiagu     // Delay lowering to block wait in case await op is inside async.execute
57998b13979SMehdi Amini     if (!isInCoroutine && !shouldLowerBlockingWait)
5806cca6b9aSyijiagu       return failure();
5816cca6b9aSyijiagu 
58225f80e16SEugene Zhulenev     // Inside regular functions we use the blocking wait operation to wait for
58325f80e16SEugene Zhulenev     // the async object (token, value or group) to become available.
584fd52b435SEugene Zhulenev     if (!isInCoroutine) {
585*ea2d9383SMatthias Springer       ImplicitLocOpBuilder builder(loc, rewriter);
586fd52b435SEugene Zhulenev       builder.create<RuntimeAwaitOp>(loc, operand);
587fd52b435SEugene Zhulenev 
588fd52b435SEugene Zhulenev       // Assert that the awaited operands is not in the error state.
589fd52b435SEugene Zhulenev       Value isError = builder.create<RuntimeIsErrorOp>(i1, operand);
590a54f4eaeSMogball       Value notError = builder.create<arith::XOrIOp>(
591a54f4eaeSMogball           isError, builder.create<arith::ConstantOp>(
592a54f4eaeSMogball                        loc, i1, builder.getIntegerAttr(i1, 1)));
593fd52b435SEugene Zhulenev 
594ace01605SRiver Riddle       builder.create<cf::AssertOp>(notError,
595fd52b435SEugene Zhulenev                                    "Awaited async operand is in error state");
596fd52b435SEugene Zhulenev     }
59725f80e16SEugene Zhulenev 
59825f80e16SEugene Zhulenev     // Inside the coroutine we convert await operation into coroutine suspension
59925f80e16SEugene Zhulenev     // point, and resume execution asynchronously.
60025f80e16SEugene Zhulenev     if (isInCoroutine) {
601f81f8808Syijiagu       CoroMachinery &coro = funcCoro->getSecond();
60225f80e16SEugene Zhulenev       Block *suspended = op->getBlock();
60325f80e16SEugene Zhulenev 
604*ea2d9383SMatthias Springer       ImplicitLocOpBuilder builder(loc, rewriter);
60525f80e16SEugene Zhulenev       MLIRContext *ctx = op->getContext();
60625f80e16SEugene Zhulenev 
60725f80e16SEugene Zhulenev       // Save the coroutine state and resume on a runtime managed thread when
60825f80e16SEugene Zhulenev       // the operand becomes available.
60925f80e16SEugene Zhulenev       auto coroSaveOp =
61025f80e16SEugene Zhulenev           builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
61125f80e16SEugene Zhulenev       builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle);
61225f80e16SEugene Zhulenev 
61325f80e16SEugene Zhulenev       // Split the entry block before the await operation.
61425f80e16SEugene Zhulenev       Block *resume = rewriter.splitBlock(suspended, Block::iterator(op));
61525f80e16SEugene Zhulenev 
61625f80e16SEugene Zhulenev       // Add async.coro.suspend as a suspended block terminator.
61725f80e16SEugene Zhulenev       builder.setInsertionPointToEnd(suspended);
618a5aa7836SRiver Riddle       builder.create<CoroSuspendOp>(coroSaveOp.getState(), coro.suspend, resume,
619af562fd2SYunlong Liu                                     coro.cleanupForDestroy);
62025f80e16SEugene Zhulenev 
62139957aa4SEugene Zhulenev       // Split the resume block into error checking and continuation.
62239957aa4SEugene Zhulenev       Block *continuation = rewriter.splitBlock(resume, Block::iterator(op));
62339957aa4SEugene Zhulenev 
62439957aa4SEugene Zhulenev       // Check if the awaited value is in the error state.
62539957aa4SEugene Zhulenev       builder.setInsertionPointToStart(resume);
626fd52b435SEugene Zhulenev       auto isError = builder.create<RuntimeIsErrorOp>(loc, i1, operand);
627ace01605SRiver Riddle       builder.create<cf::CondBranchOp>(isError,
62839957aa4SEugene Zhulenev                                        /*trueDest=*/setupSetErrorBlock(coro),
62939957aa4SEugene Zhulenev                                        /*trueArgs=*/ArrayRef<Value>(),
63039957aa4SEugene Zhulenev                                        /*falseDest=*/continuation,
63139957aa4SEugene Zhulenev                                        /*falseArgs=*/ArrayRef<Value>());
63239957aa4SEugene Zhulenev 
63339957aa4SEugene Zhulenev       // Make sure that replacement value will be constructed in the
63439957aa4SEugene Zhulenev       // continuation block.
63539957aa4SEugene Zhulenev       rewriter.setInsertionPointToStart(continuation);
63639957aa4SEugene Zhulenev     }
63725f80e16SEugene Zhulenev 
63825f80e16SEugene Zhulenev     // Erase or replace the await operation with the new value.
63925f80e16SEugene Zhulenev     if (Value replaceWith = getReplacementValue(op, operand, rewriter))
64025f80e16SEugene Zhulenev       rewriter.replaceOp(op, replaceWith);
64125f80e16SEugene Zhulenev     else
64225f80e16SEugene Zhulenev       rewriter.eraseOp(op);
64325f80e16SEugene Zhulenev 
64425f80e16SEugene Zhulenev     return success();
64525f80e16SEugene Zhulenev   }
64625f80e16SEugene Zhulenev 
getReplacementValue(AwaitType op,Value operand,ConversionPatternRewriter & rewriter) const64725f80e16SEugene Zhulenev   virtual Value getReplacementValue(AwaitType op, Value operand,
64825f80e16SEugene Zhulenev                                     ConversionPatternRewriter &rewriter) const {
64925f80e16SEugene Zhulenev     return Value();
65025f80e16SEugene Zhulenev   }
65125f80e16SEugene Zhulenev 
65225f80e16SEugene Zhulenev private:
65398b13979SMehdi Amini   FuncCoroMapPtr coros;
65498b13979SMehdi Amini   bool shouldLowerBlockingWait;
65525f80e16SEugene Zhulenev };
65625f80e16SEugene Zhulenev 
65725f80e16SEugene Zhulenev /// Lowering for `async.await` with a token operand.
65825f80e16SEugene Zhulenev class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
65925f80e16SEugene Zhulenev   using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
66025f80e16SEugene Zhulenev 
66125f80e16SEugene Zhulenev public:
66225f80e16SEugene Zhulenev   using Base::Base;
66325f80e16SEugene Zhulenev };
66425f80e16SEugene Zhulenev 
66525f80e16SEugene Zhulenev /// Lowering for `async.await` with a value operand.
66625f80e16SEugene Zhulenev class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> {
66725f80e16SEugene Zhulenev   using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;
66825f80e16SEugene Zhulenev 
66925f80e16SEugene Zhulenev public:
67025f80e16SEugene Zhulenev   using Base::Base;
67125f80e16SEugene Zhulenev 
67225f80e16SEugene Zhulenev   Value
getReplacementValue(AwaitOp op,Value operand,ConversionPatternRewriter & rewriter) const67325f80e16SEugene Zhulenev   getReplacementValue(AwaitOp op, Value operand,
67425f80e16SEugene Zhulenev                       ConversionPatternRewriter &rewriter) const override {
67525f80e16SEugene Zhulenev     // Load from the async value storage.
6765550c821STres Popp     auto valueType = cast<ValueType>(operand.getType()).getValueType();
67725f80e16SEugene Zhulenev     return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand);
67825f80e16SEugene Zhulenev   }
67925f80e16SEugene Zhulenev };
68025f80e16SEugene Zhulenev 
68125f80e16SEugene Zhulenev /// Lowering for `async.await_all` operation.
68225f80e16SEugene Zhulenev class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
68325f80e16SEugene Zhulenev   using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
68425f80e16SEugene Zhulenev 
68525f80e16SEugene Zhulenev public:
68625f80e16SEugene Zhulenev   using Base::Base;
68725f80e16SEugene Zhulenev };
68825f80e16SEugene Zhulenev 
68925f80e16SEugene Zhulenev } // namespace
69025f80e16SEugene Zhulenev 
//===----------------------------------------------------------------------===//
// Convert async.yield operation to async.runtime operations.
//===----------------------------------------------------------------------===//
69425f80e16SEugene Zhulenev 
69525f80e16SEugene Zhulenev class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
69625f80e16SEugene Zhulenev public:
YieldOpLowering(MLIRContext * ctx,FuncCoroMapPtr coros)6976cca6b9aSyijiagu   YieldOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
69898b13979SMehdi Amini       : OpConversionPattern<async::YieldOp>(ctx), coros(std::move(coros)) {}
69925f80e16SEugene Zhulenev 
70025f80e16SEugene Zhulenev   LogicalResult
matchAndRewrite(async::YieldOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const701b54c724bSRiver Riddle   matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
70225f80e16SEugene Zhulenev                   ConversionPatternRewriter &rewriter) const override {
70339957aa4SEugene Zhulenev     // Check if yield operation is inside the async coroutine function.
70458ceae95SRiver Riddle     auto func = op->template getParentOfType<func::FuncOp>();
70598b13979SMehdi Amini     auto funcCoro = coros->find(func);
70698b13979SMehdi Amini     if (funcCoro == coros->end())
70725f80e16SEugene Zhulenev       return rewriter.notifyMatchFailure(
70839957aa4SEugene Zhulenev           op, "operation is not inside the async coroutine function");
70925f80e16SEugene Zhulenev 
71025f80e16SEugene Zhulenev     Location loc = op->getLoc();
711f81f8808Syijiagu     const CoroMachinery &coro = funcCoro->getSecond();
71225f80e16SEugene Zhulenev 
71325f80e16SEugene Zhulenev     // Store yielded values into the async values storage and switch async
71425f80e16SEugene Zhulenev     // values state to available.
715b54c724bSRiver Riddle     for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
71625f80e16SEugene Zhulenev       Value yieldValue = std::get<0>(tuple);
71725f80e16SEugene Zhulenev       Value asyncValue = std::get<1>(tuple);
71825f80e16SEugene Zhulenev       rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue);
71925f80e16SEugene Zhulenev       rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue);
72025f80e16SEugene Zhulenev     }
72125f80e16SEugene Zhulenev 
722f81f8808Syijiagu     if (coro.asyncToken)
72325f80e16SEugene Zhulenev       // Switch the coroutine completion token to available state.
724f81f8808Syijiagu       rewriter.create<RuntimeSetAvailableOp>(loc, *coro.asyncToken);
725f81f8808Syijiagu 
726f81f8808Syijiagu     rewriter.eraseOp(op);
727f81f8808Syijiagu     rewriter.create<cf::BranchOp>(loc, coro.cleanup);
72825f80e16SEugene Zhulenev 
72925f80e16SEugene Zhulenev     return success();
73025f80e16SEugene Zhulenev   }
73125f80e16SEugene Zhulenev 
73225f80e16SEugene Zhulenev private:
73398b13979SMehdi Amini   FuncCoroMapPtr coros;
73425f80e16SEugene Zhulenev };
73525f80e16SEugene Zhulenev 
//===----------------------------------------------------------------------===//
// Convert cf.assert operation to cf.cond_br into `set_error` block.
//===----------------------------------------------------------------------===//
73939957aa4SEugene Zhulenev 
740ace01605SRiver Riddle class AssertOpLowering : public OpConversionPattern<cf::AssertOp> {
74139957aa4SEugene Zhulenev public:
AssertOpLowering(MLIRContext * ctx,FuncCoroMapPtr coros)7426cca6b9aSyijiagu   AssertOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
74398b13979SMehdi Amini       : OpConversionPattern<cf::AssertOp>(ctx), coros(std::move(coros)) {}
74439957aa4SEugene Zhulenev 
74539957aa4SEugene Zhulenev   LogicalResult
matchAndRewrite(cf::AssertOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const746ace01605SRiver Riddle   matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
74739957aa4SEugene Zhulenev                   ConversionPatternRewriter &rewriter) const override {
74839957aa4SEugene Zhulenev     // Check if assert operation is inside the async coroutine function.
74958ceae95SRiver Riddle     auto func = op->template getParentOfType<func::FuncOp>();
75098b13979SMehdi Amini     auto funcCoro = coros->find(func);
75198b13979SMehdi Amini     if (funcCoro == coros->end())
75239957aa4SEugene Zhulenev       return rewriter.notifyMatchFailure(
75339957aa4SEugene Zhulenev           op, "operation is not inside the async coroutine function");
75439957aa4SEugene Zhulenev 
75539957aa4SEugene Zhulenev     Location loc = op->getLoc();
756f81f8808Syijiagu     CoroMachinery &coro = funcCoro->getSecond();
75739957aa4SEugene Zhulenev 
75839957aa4SEugene Zhulenev     Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op));
75939957aa4SEugene Zhulenev     rewriter.setInsertionPointToEnd(cont->getPrevNode());
760ace01605SRiver Riddle     rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(),
76139957aa4SEugene Zhulenev                                       /*trueDest=*/cont,
76239957aa4SEugene Zhulenev                                       /*trueArgs=*/ArrayRef<Value>(),
76339957aa4SEugene Zhulenev                                       /*falseDest=*/setupSetErrorBlock(coro),
76439957aa4SEugene Zhulenev                                       /*falseArgs=*/ArrayRef<Value>());
76539957aa4SEugene Zhulenev     rewriter.eraseOp(op);
76639957aa4SEugene Zhulenev 
76739957aa4SEugene Zhulenev     return success();
76839957aa4SEugene Zhulenev   }
76939957aa4SEugene Zhulenev 
77039957aa4SEugene Zhulenev private:
77198b13979SMehdi Amini   FuncCoroMapPtr coros;
77239957aa4SEugene Zhulenev };
77339957aa4SEugene Zhulenev 
77439957aa4SEugene Zhulenev //===----------------------------------------------------------------------===//
runOnOperation()77525f80e16SEugene Zhulenev void AsyncToAsyncRuntimePass::runOnOperation() {
77625f80e16SEugene Zhulenev   ModuleOp module = getOperation();
77725f80e16SEugene Zhulenev   SymbolTable symbolTable(module);
77825f80e16SEugene Zhulenev 
779f81f8808Syijiagu   // Functions with coroutine CFG setups, which are results of outlining
7806cca6b9aSyijiagu   // `async.execute` body regions
7816cca6b9aSyijiagu   FuncCoroMapPtr coros =
7826cca6b9aSyijiagu       std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
78325f80e16SEugene Zhulenev 
78425f80e16SEugene Zhulenev   module.walk([&](ExecuteOp execute) {
7856cca6b9aSyijiagu     coros->insert(outlineExecuteOp(symbolTable, execute));
78625f80e16SEugene Zhulenev   });
78725f80e16SEugene Zhulenev 
78825f80e16SEugene Zhulenev   LLVM_DEBUG({
7896cca6b9aSyijiagu     llvm::dbgs() << "Outlined " << coros->size()
79025f80e16SEugene Zhulenev                  << " functions built from async.execute operations\n";
79125f80e16SEugene Zhulenev   });
79225f80e16SEugene Zhulenev 
793de7a4e53SEugene Zhulenev   // Returns true if operation is inside the coroutine.
794de7a4e53SEugene Zhulenev   auto isInCoroutine = [&](Operation *op) -> bool {
79558ceae95SRiver Riddle     auto parentFunc = op->getParentOfType<func::FuncOp>();
7966cca6b9aSyijiagu     return coros->find(parentFunc) != coros->end();
797de7a4e53SEugene Zhulenev   };
798de7a4e53SEugene Zhulenev 
79925f80e16SEugene Zhulenev   // Lower async operations to async.runtime operations.
80025f80e16SEugene Zhulenev   MLIRContext *ctx = module->getContext();
801dc4e913bSChris Lattner   RewritePatternSet asyncPatterns(ctx);
80225f80e16SEugene Zhulenev 
803de7a4e53SEugene Zhulenev   // Conversion to async runtime augments original CFG with the coroutine CFG,
804de7a4e53SEugene Zhulenev   // and we have to make sure that structured control flow operations with async
805de7a4e53SEugene Zhulenev   // operations in nested regions will be converted to branch-based control flow
806de7a4e53SEugene Zhulenev   // before we add the coroutine basic blocks.
807ace01605SRiver Riddle   populateSCFToControlFlowConversionPatterns(asyncPatterns);
808de7a4e53SEugene Zhulenev 
80925f80e16SEugene Zhulenev   // Async lowering does not use type converter because it must preserve all
81025f80e16SEugene Zhulenev   // types for async.runtime operations.
811dc4e913bSChris Lattner   asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
812f81f8808Syijiagu 
8136cca6b9aSyijiagu   asyncPatterns
8146cca6b9aSyijiagu       .add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
8156cca6b9aSyijiagu           ctx, coros, /*should_lower_blocking_wait=*/true);
81625f80e16SEugene Zhulenev 
81739957aa4SEugene Zhulenev   // Lower assertions to conditional branches into error blocks.
8186cca6b9aSyijiagu   asyncPatterns.add<YieldOpLowering, AssertOpLowering>(ctx, coros);
81939957aa4SEugene Zhulenev 
82025f80e16SEugene Zhulenev   // All high level async operations must be lowered to the runtime operations.
82125f80e16SEugene Zhulenev   ConversionTarget runtimeTarget(*ctx);
822f81f8808Syijiagu   runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
82325f80e16SEugene Zhulenev   runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
8246cca6b9aSyijiagu   runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
82525f80e16SEugene Zhulenev 
826de7a4e53SEugene Zhulenev   // Decide if structured control flow has to be lowered to branch-based CFG.
827de7a4e53SEugene Zhulenev   runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) {
828de7a4e53SEugene Zhulenev     auto walkResult = op->walk([&](Operation *nested) {
829de7a4e53SEugene Zhulenev       bool isAsync = isa<async::AsyncDialect>(nested->getDialect());
830de7a4e53SEugene Zhulenev       return isAsync && isInCoroutine(nested) ? WalkResult::interrupt()
831de7a4e53SEugene Zhulenev                                               : WalkResult::advance();
832de7a4e53SEugene Zhulenev     });
833de7a4e53SEugene Zhulenev     return !walkResult.wasInterrupted();
834de7a4e53SEugene Zhulenev   });
835ace01605SRiver Riddle   runtimeTarget.addLegalOp<cf::AssertOp, arith::XOrIOp, arith::ConstantOp,
83623aa5a74SRiver Riddle                            func::ConstantOp, cf::BranchOp, cf::CondBranchOp>();
837de7a4e53SEugene Zhulenev 
8388f23fac4SEugene Zhulenev   // Assertions must be converted to runtime errors inside async functions.
839ace01605SRiver Riddle   runtimeTarget.addDynamicallyLegalOp<cf::AssertOp>(
840ace01605SRiver Riddle       [&](cf::AssertOp op) -> bool {
84158ceae95SRiver Riddle         auto func = op->getParentOfType<func::FuncOp>();
84269ffd49cSKazu Hirata         return !coros->contains(func);
8438f23fac4SEugene Zhulenev       });
84439957aa4SEugene Zhulenev 
84525f80e16SEugene Zhulenev   if (failed(applyPartialConversion(module, runtimeTarget,
84625f80e16SEugene Zhulenev                                     std::move(asyncPatterns)))) {
84725f80e16SEugene Zhulenev     signalPassFailure();
84825f80e16SEugene Zhulenev     return;
84925f80e16SEugene Zhulenev   }
85025f80e16SEugene Zhulenev }
85125f80e16SEugene Zhulenev 
8526cca6b9aSyijiagu //===----------------------------------------------------------------------===//
populateAsyncFuncToAsyncRuntimeConversionPatterns(RewritePatternSet & patterns,ConversionTarget & target)8536cca6b9aSyijiagu void mlir::populateAsyncFuncToAsyncRuntimeConversionPatterns(
8546cca6b9aSyijiagu     RewritePatternSet &patterns, ConversionTarget &target) {
8556cca6b9aSyijiagu   // Functions with coroutine CFG setups, which are results of converting
8566cca6b9aSyijiagu   // async.func.
8576cca6b9aSyijiagu   FuncCoroMapPtr coros =
8586cca6b9aSyijiagu       std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
8596cca6b9aSyijiagu   MLIRContext *ctx = patterns.getContext();
8606cca6b9aSyijiagu   // Lower async.func to func.func with coroutine cfg.
8616cca6b9aSyijiagu   patterns.add<AsyncCallOpLowering>(ctx);
8626cca6b9aSyijiagu   patterns.add<AsyncFuncOpLowering, AsyncReturnOpLowering>(ctx, coros);
8636cca6b9aSyijiagu 
8646cca6b9aSyijiagu   patterns.add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
8656cca6b9aSyijiagu       ctx, coros, /*should_lower_blocking_wait=*/false);
8666cca6b9aSyijiagu   patterns.add<YieldOpLowering, AssertOpLowering>(ctx, coros);
8676cca6b9aSyijiagu 
8686cca6b9aSyijiagu   target.addDynamicallyLegalOp<AwaitOp, AwaitAllOp, YieldOp, cf::AssertOp>(
8696cca6b9aSyijiagu       [coros](Operation *op) {
870a5ddd920Syijiagu         auto exec = op->getParentOfType<ExecuteOp>();
8716cca6b9aSyijiagu         auto func = op->getParentOfType<func::FuncOp>();
87269ffd49cSKazu Hirata         return exec || !coros->contains(func);
8736cca6b9aSyijiagu       });
8746cca6b9aSyijiagu }
8756cca6b9aSyijiagu 
runOnOperation()8766cca6b9aSyijiagu void AsyncFuncToAsyncRuntimePass::runOnOperation() {
8776cca6b9aSyijiagu   ModuleOp module = getOperation();
8786cca6b9aSyijiagu 
8796cca6b9aSyijiagu   // Lower async operations to async.runtime operations.
8806cca6b9aSyijiagu   MLIRContext *ctx = module->getContext();
8816cca6b9aSyijiagu   RewritePatternSet asyncPatterns(ctx);
8826cca6b9aSyijiagu   ConversionTarget runtimeTarget(*ctx);
8836cca6b9aSyijiagu 
8846cca6b9aSyijiagu   // Lower async.func to func.func with coroutine cfg.
8856cca6b9aSyijiagu   populateAsyncFuncToAsyncRuntimeConversionPatterns(asyncPatterns,
8866cca6b9aSyijiagu                                                     runtimeTarget);
8876cca6b9aSyijiagu 
8886cca6b9aSyijiagu   runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
8896cca6b9aSyijiagu   runtimeTarget.addIllegalOp<async::FuncOp, async::CallOp, async::ReturnOp>();
8906cca6b9aSyijiagu 
8916cca6b9aSyijiagu   runtimeTarget.addLegalOp<arith::XOrIOp, arith::ConstantOp, func::ConstantOp,
8926cca6b9aSyijiagu                            cf::BranchOp, cf::CondBranchOp>();
8936cca6b9aSyijiagu 
8946cca6b9aSyijiagu   if (failed(applyPartialConversion(module, runtimeTarget,
8956cca6b9aSyijiagu                                     std::move(asyncPatterns)))) {
8966cca6b9aSyijiagu     signalPassFailure();
8976cca6b9aSyijiagu     return;
8986cca6b9aSyijiagu   }
8996cca6b9aSyijiagu }
9006cca6b9aSyijiagu 
createAsyncToAsyncRuntimePass()90125f80e16SEugene Zhulenev std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() {
90225f80e16SEugene Zhulenev   return std::make_unique<AsyncToAsyncRuntimePass>();
90325f80e16SEugene Zhulenev }
9046cca6b9aSyijiagu 
9056cca6b9aSyijiagu std::unique_ptr<OperationPass<ModuleOp>>
createAsyncFuncToAsyncRuntimePass()9066cca6b9aSyijiagu mlir::createAsyncFuncToAsyncRuntimePass() {
9076cca6b9aSyijiagu   return std::make_unique<AsyncFuncToAsyncRuntimePass>();
9086cca6b9aSyijiagu }
909