xref: /llvm-project/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp (revision ea2d9383a23ca17b9240ad64c2adc5f2b5a73dc0)
1 //===- AsyncToAsyncRuntime.cpp - Lower from Async to Async Runtime --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering from high level async operations to async.coro
10 // and async.runtime operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include <utility>
15 
16 #include "mlir/Dialect/Async/Passes.h"
17 
18 #include "PassDetail.h"
19 #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
20 #include "mlir/Dialect/Arith/IR/Arith.h"
21 #include "mlir/Dialect/Async/IR/Async.h"
22 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
23 #include "mlir/Dialect/Func/IR/FuncOps.h"
24 #include "mlir/Dialect/SCF/IR/SCF.h"
25 #include "mlir/IR/IRMapping.h"
26 #include "mlir/IR/ImplicitLocOpBuilder.h"
27 #include "mlir/IR/PatternMatch.h"
28 #include "mlir/Transforms/DialectConversion.h"
29 #include "mlir/Transforms/RegionUtils.h"
30 #include "llvm/ADT/SetVector.h"
31 #include "llvm/Support/Debug.h"
32 #include <optional>
33 
34 namespace mlir {
35 #define GEN_PASS_DEF_ASYNCTOASYNCRUNTIME
36 #define GEN_PASS_DEF_ASYNCFUNCTOASYNCRUNTIME
37 #include "mlir/Dialect/Async/Passes.h.inc"
38 } // namespace mlir
39 
40 using namespace mlir;
41 using namespace mlir::async;
42 
43 #define DEBUG_TYPE "async-to-async-runtime"
// Base name (prefix) given to the functions that are outlined from
// `async.execute` op regions.
static constexpr char kAsyncFnPrefix[] = "async_execute_fn";
46 
47 namespace {
48 
49 class AsyncToAsyncRuntimePass
50     : public impl::AsyncToAsyncRuntimeBase<AsyncToAsyncRuntimePass> {
51 public:
52   AsyncToAsyncRuntimePass() = default;
53   void runOnOperation() override;
54 };
55 
56 } // namespace
57 
58 namespace {
59 
60 class AsyncFuncToAsyncRuntimePass
61     : public impl::AsyncFuncToAsyncRuntimeBase<AsyncFuncToAsyncRuntimePass> {
62 public:
63   AsyncFuncToAsyncRuntimePass() = default;
64   void runOnOperation() override;
65 };
66 
67 } // namespace
68 
69 /// Function targeted for coroutine transformation has two additional blocks at
70 /// the end: coroutine cleanup and coroutine suspension.
71 ///
72 /// async.await op lowering additionaly creates a resume block for each
73 /// operation to enable non-blocking waiting via coroutine suspension.
74 namespace {
75 struct CoroMachinery {
76   func::FuncOp func;
77 
78   // Async function returns an optional token, followed by some async values
79   //
80   //  async.func @foo() -> !async.value<T> {
81   //    %cst = arith.constant 42.0 : T
82   //    return %cst: T
83   //  }
84   // Async execute region returns a completion token, and an async value for
85   // each yielded value.
86   //
87   //   %token, %result = async.execute -> !async.value<T> {
88   //     %0 = arith.constant ... : T
89   //     async.yield %0 : T
90   //   }
91   std::optional<Value> asyncToken;          // returned completion token
92   llvm::SmallVector<Value, 4> returnValues; // returned async values
93 
94   Value coroHandle; // coroutine handle (!async.coro.getHandle value)
95   Block *entry;     // coroutine entry block
96   std::optional<Block *> setError; // set returned values to error state
97   Block *cleanup;                  // coroutine cleanup block
98 
99   // Coroutine cleanup block for destroy after the coroutine is resumed,
100   //   e.g. async.coro.suspend state, [suspend], [resume], [destroy]
101   //
102   // This cleanup block is a duplicate of the cleanup block followed by the
103   // resume block. The purpose of having a duplicate cleanup block for destroy
104   // is to make the CFG clear so that the control flow analysis won't confuse.
105   //
106   // The overall structure of the lowered CFG can be the following,
107   //
108   //     Entry (calling async.coro.suspend)
109   //       |                \
110   //     Resume           Destroy (duplicate of Cleanup)
111   //       |                 |
112   //     Cleanup             |
113   //       |                 /
114   //      End (ends the corontine)
115   //
116   // If there is resume-specific cleanup logic, it can go into the Cleanup
117   // block but not the destroy block. Otherwise, it can fail block dominance
118   // check.
119   Block *cleanupForDestroy;
120   Block *suspend; // coroutine suspension block
121 };
122 } // namespace
123 
/// Shared-ownership handle to the map from a `func::FuncOp` to the
/// `CoroMachinery` that was set up for it. The lowering patterns in this file
/// each keep a copy of this pointer and use the map to decide whether an
/// operation lives inside a coroutine function.
using FuncCoroMapPtr =
    std::shared_ptr<llvm::DenseMap<func::FuncOp, CoroMachinery>>;
126 
127 /// Utility to partially update the regular function CFG to the coroutine CFG
128 /// compatible with LLVM coroutines switched-resume lowering using
129 /// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block
130 /// that branches into preexisting entry block. Also inserts trailing blocks.
131 ///
132 /// The result types of the passed `func` start with an optional `async.token`
133 /// and be continued with some number of `async.value`s.
134 ///
135 /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html
136 ///
137 ///  - `entry` block sets up the coroutine.
138 ///  - `set_error` block sets completion token and async values state to error.
139 ///  - `cleanup` block cleans up the coroutine state.
140 ///  - `suspend block after the @llvm.coro.end() defines what value will be
141 ///    returned to the initial caller of a coroutine. Everything before the
142 ///    @llvm.coro.end() will be executed at every suspension point.
143 ///
144 /// Coroutine structure (only the important bits):
145 ///
146 ///   func @some_fn(<function-arguments>) -> (!async.token, !async.value<T>)
147 ///   {
148 ///     ^entry(<function-arguments>):
149 ///       %token = <async token> : !async.token    // create async runtime token
150 ///       %value = <async value> : !async.value<T> // create async value
151 ///       %id = async.coro.getId                   // create a coroutine id
152 ///       %hdl = async.coro.begin %id              // create a coroutine handle
153 ///       cf.br ^preexisting_entry_block
154 ///
155 ///     /*  preexisting blocks modified to branch to the cleanup block */
156 ///
157 ///     ^set_error: // this block created lazily only if needed (see code below)
158 ///       async.runtime.set_error %token : !async.token
159 ///       async.runtime.set_error %value : !async.value<T>
160 ///       cf.br ^cleanup
161 ///
162 ///     ^cleanup:
163 ///       async.coro.free %hdl // delete the coroutine state
164 ///       cf.br ^suspend
165 ///
166 ///     ^suspend:
167 ///       async.coro.end %hdl // marks the end of a coroutine
168 ///       return %token, %value : !async.token, !async.value<T>
169 ///   }
170 ///
setupCoroMachinery(func::FuncOp func)171 static CoroMachinery setupCoroMachinery(func::FuncOp func) {
172   assert(!func.getBlocks().empty() && "Function must have an entry block");
173 
174   MLIRContext *ctx = func.getContext();
175   Block *entryBlock = &func.getBlocks().front();
176   Block *originalEntryBlock =
177       entryBlock->splitBlock(entryBlock->getOperations().begin());
178   auto builder = ImplicitLocOpBuilder::atBlockBegin(func->getLoc(), entryBlock);
179 
180   // ------------------------------------------------------------------------ //
181   // Allocate async token/values that we will return from a ramp function.
182   // ------------------------------------------------------------------------ //
183 
184   // We treat TokenType as state update marker to represent side-effects of
185   // async computations
186   bool isStateful = isa<TokenType>(func.getResultTypes().front());
187 
188   std::optional<Value> retToken;
189   if (isStateful)
190     retToken.emplace(builder.create<RuntimeCreateOp>(TokenType::get(ctx)));
191 
192   llvm::SmallVector<Value, 4> retValues;
193   ArrayRef<Type> resValueTypes =
194       isStateful ? func.getResultTypes().drop_front() : func.getResultTypes();
195   for (auto resType : resValueTypes)
196     retValues.emplace_back(
197         builder.create<RuntimeCreateOp>(resType).getResult());
198 
199   // ------------------------------------------------------------------------ //
200   // Initialize coroutine: get coroutine id and coroutine handle.
201   // ------------------------------------------------------------------------ //
202   auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx));
203   auto coroHdlOp =
204       builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.getId());
205   builder.create<cf::BranchOp>(originalEntryBlock);
206 
207   Block *cleanupBlock = func.addBlock();
208   Block *cleanupBlockForDestroy = func.addBlock();
209   Block *suspendBlock = func.addBlock();
210 
211   // ------------------------------------------------------------------------ //
212   // Coroutine cleanup blocks: deallocate coroutine frame, free the memory.
213   // ------------------------------------------------------------------------ //
214   auto buildCleanupBlock = [&](Block *cb) {
215     builder.setInsertionPointToStart(cb);
216     builder.create<CoroFreeOp>(coroIdOp.getId(), coroHdlOp.getHandle());
217 
218     // Branch into the suspend block.
219     builder.create<cf::BranchOp>(suspendBlock);
220   };
221   buildCleanupBlock(cleanupBlock);
222   buildCleanupBlock(cleanupBlockForDestroy);
223 
224   // ------------------------------------------------------------------------ //
225   // Coroutine suspend block: mark the end of a coroutine and return allocated
226   // async token.
227   // ------------------------------------------------------------------------ //
228   builder.setInsertionPointToStart(suspendBlock);
229 
230   // Mark the end of a coroutine: async.coro.end
231   builder.create<CoroEndOp>(coroHdlOp.getHandle());
232 
233   // Return created optional `async.token` and `async.values` from the suspend
234   // block. This will be the return value of a coroutine ramp function.
235   SmallVector<Value, 4> ret;
236   if (retToken)
237     ret.push_back(*retToken);
238   ret.insert(ret.end(), retValues.begin(), retValues.end());
239   builder.create<func::ReturnOp>(ret);
240 
241   // `async.await` op lowering will create resume blocks for async
242   // continuations, and will conditionally branch to cleanup or suspend blocks.
243 
244   // The switch-resumed API based coroutine should be marked with
245   // presplitcoroutine attribute to mark the function as a coroutine.
246   func->setAttr("passthrough", builder.getArrayAttr(
247                                    StringAttr::get(ctx, "presplitcoroutine")));
248 
249   CoroMachinery machinery;
250   machinery.func = func;
251   machinery.asyncToken = retToken;
252   machinery.returnValues = retValues;
253   machinery.coroHandle = coroHdlOp.getHandle();
254   machinery.entry = entryBlock;
255   machinery.setError = std::nullopt; // created lazily only if needed
256   machinery.cleanup = cleanupBlock;
257   machinery.cleanupForDestroy = cleanupBlockForDestroy;
258   machinery.suspend = suspendBlock;
259   return machinery;
260 }
261 
262 // Lazily creates `set_error` block only if it is required for lowering to the
263 // runtime operations (see for example lowering of assert operation).
setupSetErrorBlock(CoroMachinery & coro)264 static Block *setupSetErrorBlock(CoroMachinery &coro) {
265   if (coro.setError)
266     return *coro.setError;
267 
268   coro.setError = coro.func.addBlock();
269   (*coro.setError)->moveBefore(coro.cleanup);
270 
271   auto builder =
272       ImplicitLocOpBuilder::atBlockBegin(coro.func->getLoc(), *coro.setError);
273 
274   // Coroutine set_error block: set error on token and all returned values.
275   if (coro.asyncToken)
276     builder.create<RuntimeSetErrorOp>(*coro.asyncToken);
277 
278   for (Value retValue : coro.returnValues)
279     builder.create<RuntimeSetErrorOp>(retValue);
280 
281   // Branch into the cleanup block.
282   builder.create<cf::BranchOp>(coro.cleanup);
283 
284   return *coro.setError;
285 }
286 
287 //===----------------------------------------------------------------------===//
288 // async.execute op outlining to the coroutine functions.
289 //===----------------------------------------------------------------------===//
290 
291 /// Outline the body region attached to the `async.execute` op into a standalone
292 /// function.
293 ///
294 /// Note that this is not reversible transformation.
295 static std::pair<func::FuncOp, CoroMachinery>
outlineExecuteOp(SymbolTable & symbolTable,ExecuteOp execute)296 outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
297   ModuleOp module = execute->getParentOfType<ModuleOp>();
298 
299   MLIRContext *ctx = module.getContext();
300   Location loc = execute.getLoc();
301 
302   // Make sure that all constants will be inside the outlined async function to
303   // reduce the number of function arguments.
304   cloneConstantsIntoTheRegion(execute.getBodyRegion());
305 
306   // Collect all outlined function inputs.
307   SetVector<mlir::Value> functionInputs(execute.getDependencies().begin(),
308                                         execute.getDependencies().end());
309   functionInputs.insert(execute.getBodyOperands().begin(),
310                         execute.getBodyOperands().end());
311   getUsedValuesDefinedAbove(execute.getBodyRegion(), functionInputs);
312 
313   // Collect types for the outlined function inputs and outputs.
314   auto typesRange = llvm::map_range(
315       functionInputs, [](Value value) { return value.getType(); });
316   SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end());
317   auto outputTypes = execute.getResultTypes();
318 
319   auto funcType = FunctionType::get(ctx, inputTypes, outputTypes);
320   auto funcAttrs = ArrayRef<NamedAttribute>();
321 
322   // TODO: Derive outlined function name from the parent FuncOp (support
323   // multiple nested async.execute operations).
324   func::FuncOp func =
325       func::FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
326   symbolTable.insert(func);
327 
328   SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private);
329   auto builder = ImplicitLocOpBuilder::atBlockBegin(loc, func.addEntryBlock());
330 
331   // Prepare for coroutine conversion by creating the body of the function.
332   {
333     size_t numDependencies = execute.getDependencies().size();
334     size_t numOperands = execute.getBodyOperands().size();
335 
336     // Await on all dependencies before starting to execute the body region.
337     for (size_t i = 0; i < numDependencies; ++i)
338       builder.create<AwaitOp>(func.getArgument(i));
339 
340     // Await on all async value operands and unwrap the payload.
341     SmallVector<Value, 4> unwrappedOperands(numOperands);
342     for (size_t i = 0; i < numOperands; ++i) {
343       Value operand = func.getArgument(numDependencies + i);
344       unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).getResult();
345     }
346 
347     // Map from function inputs defined above the execute op to the function
348     // arguments.
349     IRMapping valueMapping;
350     valueMapping.map(functionInputs, func.getArguments());
351     valueMapping.map(execute.getBodyRegion().getArguments(), unwrappedOperands);
352 
353     // Clone all operations from the execute operation body into the outlined
354     // function body.
355     for (Operation &op : execute.getBodyRegion().getOps())
356       builder.clone(op, valueMapping);
357   }
358 
359   // Adding entry/cleanup/suspend blocks.
360   CoroMachinery coro = setupCoroMachinery(func);
361 
362   // Suspend async function at the end of an entry block, and resume it using
363   // Async resume operation (execution will be resumed in a thread managed by
364   // the async runtime).
365   {
366     cf::BranchOp branch = cast<cf::BranchOp>(coro.entry->getTerminator());
367     builder.setInsertionPointToEnd(coro.entry);
368 
369     // Save the coroutine state: async.coro.save
370     auto coroSaveOp =
371         builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
372 
373     // Pass coroutine to the runtime to be resumed on a runtime managed
374     // thread.
375     builder.create<RuntimeResumeOp>(coro.coroHandle);
376 
377     // Add async.coro.suspend as a suspended block terminator.
378     builder.create<CoroSuspendOp>(coroSaveOp.getState(), coro.suspend,
379                                   branch.getDest(), coro.cleanupForDestroy);
380 
381     branch.erase();
382   }
383 
384   // Replace the original `async.execute` with a call to outlined function.
385   {
386     ImplicitLocOpBuilder callBuilder(loc, execute);
387     auto callOutlinedFunc = callBuilder.create<func::CallOp>(
388         func.getName(), execute.getResultTypes(), functionInputs.getArrayRef());
389     execute.replaceAllUsesWith(callOutlinedFunc.getResults());
390     execute.erase();
391   }
392 
393   return {func, coro};
394 }
395 
396 //===----------------------------------------------------------------------===//
397 // Convert async.create_group operation to async.runtime.create_group
398 //===----------------------------------------------------------------------===//
399 
400 namespace {
401 class CreateGroupOpLowering : public OpConversionPattern<CreateGroupOp> {
402 public:
403   using OpConversionPattern::OpConversionPattern;
404 
405   LogicalResult
matchAndRewrite(CreateGroupOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const406   matchAndRewrite(CreateGroupOp op, OpAdaptor adaptor,
407                   ConversionPatternRewriter &rewriter) const override {
408     rewriter.replaceOpWithNewOp<RuntimeCreateGroupOp>(
409         op, GroupType::get(op->getContext()), adaptor.getOperands());
410     return success();
411   }
412 };
413 } // namespace
414 
415 //===----------------------------------------------------------------------===//
416 // Convert async.add_to_group operation to async.runtime.add_to_group.
417 //===----------------------------------------------------------------------===//
418 
419 namespace {
420 class AddToGroupOpLowering : public OpConversionPattern<AddToGroupOp> {
421 public:
422   using OpConversionPattern::OpConversionPattern;
423 
424   LogicalResult
matchAndRewrite(AddToGroupOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const425   matchAndRewrite(AddToGroupOp op, OpAdaptor adaptor,
426                   ConversionPatternRewriter &rewriter) const override {
427     rewriter.replaceOpWithNewOp<RuntimeAddToGroupOp>(
428         op, rewriter.getIndexType(), adaptor.getOperands());
429     return success();
430   }
431 };
432 } // namespace
433 
434 //===----------------------------------------------------------------------===//
435 // Convert async.func, async.return and async.call operations to non-blocking
436 // operations based on llvm coroutine
437 //===----------------------------------------------------------------------===//
438 
439 namespace {
440 
441 //===----------------------------------------------------------------------===//
442 // Convert async.func operation to func.func
443 //===----------------------------------------------------------------------===//
444 
445 class AsyncFuncOpLowering : public OpConversionPattern<async::FuncOp> {
446 public:
AsyncFuncOpLowering(MLIRContext * ctx,FuncCoroMapPtr coros)447   AsyncFuncOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
448       : OpConversionPattern<async::FuncOp>(ctx), coros(std::move(coros)) {}
449 
450   LogicalResult
matchAndRewrite(async::FuncOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const451   matchAndRewrite(async::FuncOp op, OpAdaptor adaptor,
452                   ConversionPatternRewriter &rewriter) const override {
453     Location loc = op->getLoc();
454 
455     auto newFuncOp =
456         rewriter.create<func::FuncOp>(loc, op.getName(), op.getFunctionType());
457 
458     SymbolTable::setSymbolVisibility(newFuncOp,
459                                      SymbolTable::getSymbolVisibility(op));
460     // Copy over all attributes other than the name.
461     for (const auto &namedAttr : op->getAttrs()) {
462       if (namedAttr.getName() != SymbolTable::getSymbolAttrName())
463         newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
464     }
465 
466     rewriter.inlineRegionBefore(op.getBody(), newFuncOp.getBody(),
467                                 newFuncOp.end());
468 
469     CoroMachinery coro = setupCoroMachinery(newFuncOp);
470     (*coros)[newFuncOp] = coro;
471     // no initial suspend, we should hot-start
472 
473     rewriter.eraseOp(op);
474     return success();
475   }
476 
477 private:
478   FuncCoroMapPtr coros;
479 };
480 
481 //===----------------------------------------------------------------------===//
482 // Convert async.call operation to func.call
483 //===----------------------------------------------------------------------===//
484 
485 class AsyncCallOpLowering : public OpConversionPattern<async::CallOp> {
486 public:
AsyncCallOpLowering(MLIRContext * ctx)487   AsyncCallOpLowering(MLIRContext *ctx)
488       : OpConversionPattern<async::CallOp>(ctx) {}
489 
490   LogicalResult
matchAndRewrite(async::CallOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const491   matchAndRewrite(async::CallOp op, OpAdaptor adaptor,
492                   ConversionPatternRewriter &rewriter) const override {
493     rewriter.replaceOpWithNewOp<func::CallOp>(
494         op, op.getCallee(), op.getResultTypes(), op.getOperands());
495     return success();
496   }
497 };
498 
499 //===----------------------------------------------------------------------===//
500 // Convert async.return operation to async.runtime operations.
501 //===----------------------------------------------------------------------===//
502 
503 class AsyncReturnOpLowering : public OpConversionPattern<async::ReturnOp> {
504 public:
AsyncReturnOpLowering(MLIRContext * ctx,FuncCoroMapPtr coros)505   AsyncReturnOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
506       : OpConversionPattern<async::ReturnOp>(ctx), coros(std::move(coros)) {}
507 
508   LogicalResult
matchAndRewrite(async::ReturnOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const509   matchAndRewrite(async::ReturnOp op, OpAdaptor adaptor,
510                   ConversionPatternRewriter &rewriter) const override {
511     auto func = op->template getParentOfType<func::FuncOp>();
512     auto funcCoro = coros->find(func);
513     if (funcCoro == coros->end())
514       return rewriter.notifyMatchFailure(
515           op, "operation is not inside the async coroutine function");
516 
517     Location loc = op->getLoc();
518     const CoroMachinery &coro = funcCoro->getSecond();
519     rewriter.setInsertionPointAfter(op);
520 
521     // Store return values into the async values storage and switch async
522     // values state to available.
523     for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
524       Value returnValue = std::get<0>(tuple);
525       Value asyncValue = std::get<1>(tuple);
526       rewriter.create<RuntimeStoreOp>(loc, returnValue, asyncValue);
527       rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue);
528     }
529 
530     if (coro.asyncToken)
531       // Switch the coroutine completion token to available state.
532       rewriter.create<RuntimeSetAvailableOp>(loc, *coro.asyncToken);
533 
534     rewriter.eraseOp(op);
535     rewriter.create<cf::BranchOp>(loc, coro.cleanup);
536     return success();
537   }
538 
539 private:
540   FuncCoroMapPtr coros;
541 };
542 } // namespace
543 
544 //===----------------------------------------------------------------------===//
545 // Convert async.await and async.await_all operations to the async.runtime.await
546 // or async.runtime.await_and_resume operations.
547 //===----------------------------------------------------------------------===//
548 
549 namespace {
550 template <typename AwaitType, typename AwaitableType>
551 class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
552   using AwaitAdaptor = typename AwaitType::Adaptor;
553 
554 public:
AwaitOpLoweringBase(MLIRContext * ctx,FuncCoroMapPtr coros,bool shouldLowerBlockingWait)555   AwaitOpLoweringBase(MLIRContext *ctx, FuncCoroMapPtr coros,
556                       bool shouldLowerBlockingWait)
557       : OpConversionPattern<AwaitType>(ctx), coros(std::move(coros)),
558         shouldLowerBlockingWait(shouldLowerBlockingWait) {}
559 
560   LogicalResult
matchAndRewrite(AwaitType op,typename AwaitType::Adaptor adaptor,ConversionPatternRewriter & rewriter) const561   matchAndRewrite(AwaitType op, typename AwaitType::Adaptor adaptor,
562                   ConversionPatternRewriter &rewriter) const override {
563     // We can only await on one the `AwaitableType` (for `await` it can be
564     // a `token` or a `value`, for `await_all` it must be a `group`).
565     if (!isa<AwaitableType>(op.getOperand().getType()))
566       return rewriter.notifyMatchFailure(op, "unsupported awaitable type");
567 
568     // Check if await operation is inside the coroutine function.
569     auto func = op->template getParentOfType<func::FuncOp>();
570     auto funcCoro = coros->find(func);
571     const bool isInCoroutine = funcCoro != coros->end();
572 
573     Location loc = op->getLoc();
574     Value operand = adaptor.getOperand();
575 
576     Type i1 = rewriter.getI1Type();
577 
578     // Delay lowering to block wait in case await op is inside async.execute
579     if (!isInCoroutine && !shouldLowerBlockingWait)
580       return failure();
581 
582     // Inside regular functions we use the blocking wait operation to wait for
583     // the async object (token, value or group) to become available.
584     if (!isInCoroutine) {
585       ImplicitLocOpBuilder builder(loc, rewriter);
586       builder.create<RuntimeAwaitOp>(loc, operand);
587 
588       // Assert that the awaited operands is not in the error state.
589       Value isError = builder.create<RuntimeIsErrorOp>(i1, operand);
590       Value notError = builder.create<arith::XOrIOp>(
591           isError, builder.create<arith::ConstantOp>(
592                        loc, i1, builder.getIntegerAttr(i1, 1)));
593 
594       builder.create<cf::AssertOp>(notError,
595                                    "Awaited async operand is in error state");
596     }
597 
598     // Inside the coroutine we convert await operation into coroutine suspension
599     // point, and resume execution asynchronously.
600     if (isInCoroutine) {
601       CoroMachinery &coro = funcCoro->getSecond();
602       Block *suspended = op->getBlock();
603 
604       ImplicitLocOpBuilder builder(loc, rewriter);
605       MLIRContext *ctx = op->getContext();
606 
607       // Save the coroutine state and resume on a runtime managed thread when
608       // the operand becomes available.
609       auto coroSaveOp =
610           builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle);
611       builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle);
612 
613       // Split the entry block before the await operation.
614       Block *resume = rewriter.splitBlock(suspended, Block::iterator(op));
615 
616       // Add async.coro.suspend as a suspended block terminator.
617       builder.setInsertionPointToEnd(suspended);
618       builder.create<CoroSuspendOp>(coroSaveOp.getState(), coro.suspend, resume,
619                                     coro.cleanupForDestroy);
620 
621       // Split the resume block into error checking and continuation.
622       Block *continuation = rewriter.splitBlock(resume, Block::iterator(op));
623 
624       // Check if the awaited value is in the error state.
625       builder.setInsertionPointToStart(resume);
626       auto isError = builder.create<RuntimeIsErrorOp>(loc, i1, operand);
627       builder.create<cf::CondBranchOp>(isError,
628                                        /*trueDest=*/setupSetErrorBlock(coro),
629                                        /*trueArgs=*/ArrayRef<Value>(),
630                                        /*falseDest=*/continuation,
631                                        /*falseArgs=*/ArrayRef<Value>());
632 
633       // Make sure that replacement value will be constructed in the
634       // continuation block.
635       rewriter.setInsertionPointToStart(continuation);
636     }
637 
638     // Erase or replace the await operation with the new value.
639     if (Value replaceWith = getReplacementValue(op, operand, rewriter))
640       rewriter.replaceOp(op, replaceWith);
641     else
642       rewriter.eraseOp(op);
643 
644     return success();
645   }
646 
getReplacementValue(AwaitType op,Value operand,ConversionPatternRewriter & rewriter) const647   virtual Value getReplacementValue(AwaitType op, Value operand,
648                                     ConversionPatternRewriter &rewriter) const {
649     return Value();
650   }
651 
652 private:
653   FuncCoroMapPtr coros;
654   bool shouldLowerBlockingWait;
655 };
656 
657 /// Lowering for `async.await` with a token operand.
658 class AwaitTokenOpLowering : public AwaitOpLoweringBase<AwaitOp, TokenType> {
659   using Base = AwaitOpLoweringBase<AwaitOp, TokenType>;
660 
661 public:
662   using Base::Base;
663 };
664 
665 /// Lowering for `async.await` with a value operand.
666 class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> {
667   using Base = AwaitOpLoweringBase<AwaitOp, ValueType>;
668 
669 public:
670   using Base::Base;
671 
672   Value
getReplacementValue(AwaitOp op,Value operand,ConversionPatternRewriter & rewriter) const673   getReplacementValue(AwaitOp op, Value operand,
674                       ConversionPatternRewriter &rewriter) const override {
675     // Load from the async value storage.
676     auto valueType = cast<ValueType>(operand.getType()).getValueType();
677     return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand);
678   }
679 };
680 
681 /// Lowering for `async.await_all` operation.
682 class AwaitAllOpLowering : public AwaitOpLoweringBase<AwaitAllOp, GroupType> {
683   using Base = AwaitOpLoweringBase<AwaitAllOp, GroupType>;
684 
685 public:
686   using Base::Base;
687 };
688 
689 } // namespace
690 
691 //===----------------------------------------------------------------------===//
692 // Convert async.yield operation to async.runtime operations.
693 //===----------------------------------------------------------------------===//
694 
695 class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
696 public:
YieldOpLowering(MLIRContext * ctx,FuncCoroMapPtr coros)697   YieldOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
698       : OpConversionPattern<async::YieldOp>(ctx), coros(std::move(coros)) {}
699 
700   LogicalResult
matchAndRewrite(async::YieldOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const701   matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
702                   ConversionPatternRewriter &rewriter) const override {
703     // Check if yield operation is inside the async coroutine function.
704     auto func = op->template getParentOfType<func::FuncOp>();
705     auto funcCoro = coros->find(func);
706     if (funcCoro == coros->end())
707       return rewriter.notifyMatchFailure(
708           op, "operation is not inside the async coroutine function");
709 
710     Location loc = op->getLoc();
711     const CoroMachinery &coro = funcCoro->getSecond();
712 
713     // Store yielded values into the async values storage and switch async
714     // values state to available.
715     for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) {
716       Value yieldValue = std::get<0>(tuple);
717       Value asyncValue = std::get<1>(tuple);
718       rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue);
719       rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue);
720     }
721 
722     if (coro.asyncToken)
723       // Switch the coroutine completion token to available state.
724       rewriter.create<RuntimeSetAvailableOp>(loc, *coro.asyncToken);
725 
726     rewriter.eraseOp(op);
727     rewriter.create<cf::BranchOp>(loc, coro.cleanup);
728 
729     return success();
730   }
731 
732 private:
733   FuncCoroMapPtr coros;
734 };
735 
736 //===----------------------------------------------------------------------===//
737 // Convert cf.assert operation to cf.cond_br into `set_error` block.
738 //===----------------------------------------------------------------------===//
739 
740 class AssertOpLowering : public OpConversionPattern<cf::AssertOp> {
741 public:
AssertOpLowering(MLIRContext * ctx,FuncCoroMapPtr coros)742   AssertOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
743       : OpConversionPattern<cf::AssertOp>(ctx), coros(std::move(coros)) {}
744 
745   LogicalResult
matchAndRewrite(cf::AssertOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const746   matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
747                   ConversionPatternRewriter &rewriter) const override {
748     // Check if assert operation is inside the async coroutine function.
749     auto func = op->template getParentOfType<func::FuncOp>();
750     auto funcCoro = coros->find(func);
751     if (funcCoro == coros->end())
752       return rewriter.notifyMatchFailure(
753           op, "operation is not inside the async coroutine function");
754 
755     Location loc = op->getLoc();
756     CoroMachinery &coro = funcCoro->getSecond();
757 
758     Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op));
759     rewriter.setInsertionPointToEnd(cont->getPrevNode());
760     rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(),
761                                       /*trueDest=*/cont,
762                                       /*trueArgs=*/ArrayRef<Value>(),
763                                       /*falseDest=*/setupSetErrorBlock(coro),
764                                       /*falseArgs=*/ArrayRef<Value>());
765     rewriter.eraseOp(op);
766 
767     return success();
768   }
769 
770 private:
771   FuncCoroMapPtr coros;
772 };
773 
774 //===----------------------------------------------------------------------===//
runOnOperation()775 void AsyncToAsyncRuntimePass::runOnOperation() {
776   ModuleOp module = getOperation();
777   SymbolTable symbolTable(module);
778 
779   // Functions with coroutine CFG setups, which are results of outlining
780   // `async.execute` body regions
781   FuncCoroMapPtr coros =
782       std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
783 
784   module.walk([&](ExecuteOp execute) {
785     coros->insert(outlineExecuteOp(symbolTable, execute));
786   });
787 
788   LLVM_DEBUG({
789     llvm::dbgs() << "Outlined " << coros->size()
790                  << " functions built from async.execute operations\n";
791   });
792 
793   // Returns true if operation is inside the coroutine.
794   auto isInCoroutine = [&](Operation *op) -> bool {
795     auto parentFunc = op->getParentOfType<func::FuncOp>();
796     return coros->find(parentFunc) != coros->end();
797   };
798 
799   // Lower async operations to async.runtime operations.
800   MLIRContext *ctx = module->getContext();
801   RewritePatternSet asyncPatterns(ctx);
802 
803   // Conversion to async runtime augments original CFG with the coroutine CFG,
804   // and we have to make sure that structured control flow operations with async
805   // operations in nested regions will be converted to branch-based control flow
806   // before we add the coroutine basic blocks.
807   populateSCFToControlFlowConversionPatterns(asyncPatterns);
808 
809   // Async lowering does not use type converter because it must preserve all
810   // types for async.runtime operations.
811   asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
812 
813   asyncPatterns
814       .add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
815           ctx, coros, /*should_lower_blocking_wait=*/true);
816 
817   // Lower assertions to conditional branches into error blocks.
818   asyncPatterns.add<YieldOpLowering, AssertOpLowering>(ctx, coros);
819 
820   // All high level async operations must be lowered to the runtime operations.
821   ConversionTarget runtimeTarget(*ctx);
822   runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
823   runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
824   runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
825 
826   // Decide if structured control flow has to be lowered to branch-based CFG.
827   runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) {
828     auto walkResult = op->walk([&](Operation *nested) {
829       bool isAsync = isa<async::AsyncDialect>(nested->getDialect());
830       return isAsync && isInCoroutine(nested) ? WalkResult::interrupt()
831                                               : WalkResult::advance();
832     });
833     return !walkResult.wasInterrupted();
834   });
835   runtimeTarget.addLegalOp<cf::AssertOp, arith::XOrIOp, arith::ConstantOp,
836                            func::ConstantOp, cf::BranchOp, cf::CondBranchOp>();
837 
838   // Assertions must be converted to runtime errors inside async functions.
839   runtimeTarget.addDynamicallyLegalOp<cf::AssertOp>(
840       [&](cf::AssertOp op) -> bool {
841         auto func = op->getParentOfType<func::FuncOp>();
842         return !coros->contains(func);
843       });
844 
845   if (failed(applyPartialConversion(module, runtimeTarget,
846                                     std::move(asyncPatterns)))) {
847     signalPassFailure();
848     return;
849   }
850 }
851 
852 //===----------------------------------------------------------------------===//
populateAsyncFuncToAsyncRuntimeConversionPatterns(RewritePatternSet & patterns,ConversionTarget & target)853 void mlir::populateAsyncFuncToAsyncRuntimeConversionPatterns(
854     RewritePatternSet &patterns, ConversionTarget &target) {
855   // Functions with coroutine CFG setups, which are results of converting
856   // async.func.
857   FuncCoroMapPtr coros =
858       std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
859   MLIRContext *ctx = patterns.getContext();
860   // Lower async.func to func.func with coroutine cfg.
861   patterns.add<AsyncCallOpLowering>(ctx);
862   patterns.add<AsyncFuncOpLowering, AsyncReturnOpLowering>(ctx, coros);
863 
864   patterns.add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
865       ctx, coros, /*should_lower_blocking_wait=*/false);
866   patterns.add<YieldOpLowering, AssertOpLowering>(ctx, coros);
867 
868   target.addDynamicallyLegalOp<AwaitOp, AwaitAllOp, YieldOp, cf::AssertOp>(
869       [coros](Operation *op) {
870         auto exec = op->getParentOfType<ExecuteOp>();
871         auto func = op->getParentOfType<func::FuncOp>();
872         return exec || !coros->contains(func);
873       });
874 }
875 
runOnOperation()876 void AsyncFuncToAsyncRuntimePass::runOnOperation() {
877   ModuleOp module = getOperation();
878 
879   // Lower async operations to async.runtime operations.
880   MLIRContext *ctx = module->getContext();
881   RewritePatternSet asyncPatterns(ctx);
882   ConversionTarget runtimeTarget(*ctx);
883 
884   // Lower async.func to func.func with coroutine cfg.
885   populateAsyncFuncToAsyncRuntimeConversionPatterns(asyncPatterns,
886                                                     runtimeTarget);
887 
888   runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
889   runtimeTarget.addIllegalOp<async::FuncOp, async::CallOp, async::ReturnOp>();
890 
891   runtimeTarget.addLegalOp<arith::XOrIOp, arith::ConstantOp, func::ConstantOp,
892                            cf::BranchOp, cf::CondBranchOp>();
893 
894   if (failed(applyPartialConversion(module, runtimeTarget,
895                                     std::move(asyncPatterns)))) {
896     signalPassFailure();
897     return;
898   }
899 }
900 
createAsyncToAsyncRuntimePass()901 std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() {
902   return std::make_unique<AsyncToAsyncRuntimePass>();
903 }
904 
905 std::unique_ptr<OperationPass<ModuleOp>>
createAsyncFuncToAsyncRuntimePass()906 mlir::createAsyncFuncToAsyncRuntimePass() {
907   return std::make_unique<AsyncFuncToAsyncRuntimePass>();
908 }
909