xref: /llvm-project/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp (revision bc29fc937c6cb4a210f80c93c79fc6ed97c801f8)
1db7129a0SChristian Sigg //===- AsyncRegionRewriter.cpp - Implementation of GPU async rewriters ----===//
2db7129a0SChristian Sigg //
3db7129a0SChristian Sigg // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4db7129a0SChristian Sigg // See https://llvm.org/LICENSE.txt for license information.
5db7129a0SChristian Sigg // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6db7129a0SChristian Sigg //
7db7129a0SChristian Sigg //===----------------------------------------------------------------------===//
8db7129a0SChristian Sigg //
// This file implements the GPU dialect pattern rewriters that make GPU ops
// within a region execute asynchronously.
11db7129a0SChristian Sigg //
12db7129a0SChristian Sigg //===----------------------------------------------------------------------===//
13db7129a0SChristian Sigg 
14039b969bSMichele Scuttari #include "mlir/Dialect/GPU/Transforms/Passes.h"
1567d0d7acSMichele Scuttari 
1667d0d7acSMichele Scuttari #include "mlir/Dialect/Async/IR/Async.h"
1767d0d7acSMichele Scuttari #include "mlir/Dialect/Func/IR/FuncOps.h"
1867d0d7acSMichele Scuttari #include "mlir/Dialect/GPU/IR/GPUDialect.h"
19*bc29fc93SPetr Kurapov #include "mlir/Dialect/GPU/Utils/GPUUtils.h"
20db7129a0SChristian Sigg #include "mlir/IR/Builders.h"
214d67b278SJeff Niu #include "mlir/IR/IRMapping.h"
22db7129a0SChristian Sigg #include "mlir/IR/PatternMatch.h"
23db7129a0SChristian Sigg #include "mlir/IR/SymbolTable.h"
24fc367dfaSMahesh Ravishankar #include "mlir/Interfaces/SideEffectInterfaces.h"
25db7129a0SChristian Sigg #include "mlir/Support/LLVM.h"
26db7129a0SChristian Sigg #include "mlir/Transforms/RegionUtils.h"
27d9adde5aSChristian Sigg #include "llvm/ADT/TypeSwitch.h"
28db7129a0SChristian Sigg 
2967d0d7acSMichele Scuttari namespace mlir {
3067d0d7acSMichele Scuttari #define GEN_PASS_DEF_GPUASYNCREGIONPASS
3167d0d7acSMichele Scuttari #include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
3267d0d7acSMichele Scuttari } // namespace mlir
3367d0d7acSMichele Scuttari 
34db7129a0SChristian Sigg using namespace mlir;
3567d0d7acSMichele Scuttari 
36db7129a0SChristian Sigg namespace {
3767d0d7acSMichele Scuttari class GpuAsyncRegionPass
3867d0d7acSMichele Scuttari     : public impl::GpuAsyncRegionPassBase<GpuAsyncRegionPass> {
39d9adde5aSChristian Sigg   struct ThreadTokenCallback;
40d9adde5aSChristian Sigg   struct DeferWaitCallback;
41f03826f8SChristian Sigg   struct SingleTokenUseCallback;
4241574554SRiver Riddle   void runOnOperation() override;
43db7129a0SChristian Sigg };
44db7129a0SChristian Sigg } // namespace
45db7129a0SChristian Sigg 
46fe7c0d90SRiver Riddle static bool isTerminator(Operation *op) {
47fe7c0d90SRiver Riddle   return op->mightHaveTrait<OpTrait::IsTerminator>();
48fe7c0d90SRiver Riddle }
49fc367dfaSMahesh Ravishankar static bool hasSideEffects(Operation *op) { return !isMemoryEffectFree(op); }
50d9adde5aSChristian Sigg 
51db7129a0SChristian Sigg // Region walk callback which makes GPU ops implementing the AsyncOpInterface
52db7129a0SChristian Sigg // execute asynchronously.
53d9adde5aSChristian Sigg struct GpuAsyncRegionPass::ThreadTokenCallback {
54d9adde5aSChristian Sigg   ThreadTokenCallback(MLIRContext &context) : builder(&context) {}
55d9adde5aSChristian Sigg 
560b21371eSChristian Sigg   WalkResult operator()(Block *block) {
570b21371eSChristian Sigg     for (Operation &op : make_early_inc_range(*block)) {
580b21371eSChristian Sigg       if (failed(visit(&op)))
590b21371eSChristian Sigg         return WalkResult::interrupt();
600b21371eSChristian Sigg     }
610b21371eSChristian Sigg     return WalkResult::advance();
620b21371eSChristian Sigg   }
630b21371eSChristian Sigg 
640b21371eSChristian Sigg private:
65db7129a0SChristian Sigg   // If `op` implements the AsyncOpInterface, insert a `gpu.wait async` to
66db7129a0SChristian Sigg   // create a current token (unless it already exists), and 'thread' that token
67db7129a0SChristian Sigg   // through the `op` so that it executes asynchronously.
68db7129a0SChristian Sigg   //
69db7129a0SChristian Sigg   // If `op` is a terminator or an op with side-effects, insert a `gpu.wait` to
70d9adde5aSChristian Sigg   // host-synchronize execution. A `!gpu.async.token` will therefore only be
71d9adde5aSChristian Sigg   // used inside of its block and GPU execution will always synchronize with
72d9adde5aSChristian Sigg   // the host at block boundaries.
730b21371eSChristian Sigg   LogicalResult visit(Operation *op) {
74db7129a0SChristian Sigg     if (isa<gpu::LaunchOp>(op))
75db7129a0SChristian Sigg       return op->emitOpError("replace with gpu.launch_func first");
760b21371eSChristian Sigg     if (auto waitOp = llvm::dyn_cast<gpu::WaitOp>(op)) {
770b21371eSChristian Sigg       if (currentToken)
780b21371eSChristian Sigg         waitOp.addAsyncDependency(currentToken);
7910c04f46SRiver Riddle       currentToken = waitOp.getAsyncToken();
800b21371eSChristian Sigg       return success();
810b21371eSChristian Sigg     }
82db7129a0SChristian Sigg     builder.setInsertionPoint(op);
83db7129a0SChristian Sigg     if (auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(op))
84db7129a0SChristian Sigg       return rewriteAsyncOp(asyncOp); // Replace GPU op with async version.
85db7129a0SChristian Sigg     if (!currentToken)
86db7129a0SChristian Sigg       return success();
87db7129a0SChristian Sigg     // Insert host synchronization before terminator or op with side effects.
88d9adde5aSChristian Sigg     if (isTerminator(op) || hasSideEffects(op))
89db7129a0SChristian Sigg       currentToken = createWaitOp(op->getLoc(), Type(), {currentToken});
90db7129a0SChristian Sigg     return success();
91db7129a0SChristian Sigg   }
92db7129a0SChristian Sigg 
93db7129a0SChristian Sigg   // Replaces asyncOp with a clone that returns a token.
94db7129a0SChristian Sigg   LogicalResult rewriteAsyncOp(gpu::AsyncOpInterface asyncOp) {
95db7129a0SChristian Sigg     auto *op = asyncOp.getOperation();
964c372a35SChristian Sigg     auto tokenType = builder.getType<gpu::AsyncTokenType>();
974c372a35SChristian Sigg 
98db7129a0SChristian Sigg     // If there is no current token, insert a `gpu.wait async` without
99db7129a0SChristian Sigg     // dependencies to create one.
100db7129a0SChristian Sigg     if (!currentToken)
101db7129a0SChristian Sigg       currentToken = createWaitOp(op->getLoc(), tokenType, {});
102db7129a0SChristian Sigg     asyncOp.addAsyncDependency(currentToken);
103db7129a0SChristian Sigg 
1040b21371eSChristian Sigg     // Return early if op returns a token already.
1050b21371eSChristian Sigg     currentToken = asyncOp.getAsyncToken();
1060b21371eSChristian Sigg     if (currentToken)
1070b21371eSChristian Sigg       return success();
1080b21371eSChristian Sigg 
109db7129a0SChristian Sigg     // Clone the op to return a token in addition to the other results.
110a79b26dbSChristian Sigg     SmallVector<Type, 1> resultTypes;
111db7129a0SChristian Sigg     resultTypes.reserve(1 + op->getNumResults());
112db7129a0SChristian Sigg     copy(op->getResultTypes(), std::back_inserter(resultTypes));
113a79b26dbSChristian Sigg     resultTypes.push_back(tokenType);
114bbe5bf17SMehdi Amini     auto *newOp = Operation::create(
115bbe5bf17SMehdi Amini         op->getLoc(), op->getName(), resultTypes, op->getOperands(),
116bbe5bf17SMehdi Amini         op->getDiscardableAttrDictionary(), op->getPropertiesStorage(),
117a0d019fcSChristian Sigg         op->getSuccessors(), op->getNumRegions());
118a0d019fcSChristian Sigg 
119a0d019fcSChristian Sigg     // Clone regions into new op.
1204d67b278SJeff Niu     IRMapping mapping;
121a0d019fcSChristian Sigg     for (auto pair : llvm::zip_first(op->getRegions(), newOp->getRegions()))
122a0d019fcSChristian Sigg       std::get<0>(pair).cloneInto(&std::get<1>(pair), mapping);
123db7129a0SChristian Sigg 
124db7129a0SChristian Sigg     // Replace the op with the async clone.
125db7129a0SChristian Sigg     auto results = newOp->getResults();
126a79b26dbSChristian Sigg     currentToken = results.back();
127db7129a0SChristian Sigg     builder.insert(newOp);
128a79b26dbSChristian Sigg     op->replaceAllUsesWith(results.drop_back());
129db7129a0SChristian Sigg     op->erase();
130db7129a0SChristian Sigg 
131db7129a0SChristian Sigg     return success();
132db7129a0SChristian Sigg   }
133db7129a0SChristian Sigg 
134db7129a0SChristian Sigg   Value createWaitOp(Location loc, Type resultType, ValueRange operands) {
13510c04f46SRiver Riddle     return builder.create<gpu::WaitOp>(loc, resultType, operands)
13610c04f46SRiver Riddle         .getAsyncToken();
137db7129a0SChristian Sigg   }
138db7129a0SChristian Sigg 
139db7129a0SChristian Sigg   OpBuilder builder;
1404c372a35SChristian Sigg 
141db7129a0SChristian Sigg   // The token that represents the current asynchronous dependency. It's valid
142db7129a0SChristian Sigg   // range starts with a `gpu.wait async` op, and ends with a `gpu.wait` op.
143db7129a0SChristian Sigg   // In between, each gpu::AsyncOpInterface depends on the current token and
144db7129a0SChristian Sigg   // produces the new one.
145db7129a0SChristian Sigg   Value currentToken = {};
146db7129a0SChristian Sigg };
147db7129a0SChristian Sigg 
148f03826f8SChristian Sigg /// Erases `executeOp` and returns a clone with additional `results`.
149f03826f8SChristian Sigg async::ExecuteOp addExecuteResults(async::ExecuteOp executeOp,
150f03826f8SChristian Sigg                                    ValueRange results) {
151f03826f8SChristian Sigg   // Add values to async.yield op.
152f03826f8SChristian Sigg   Operation *yieldOp = executeOp.getBody()->getTerminator();
153f03826f8SChristian Sigg   yieldOp->insertOperands(yieldOp->getNumOperands(), results);
154f03826f8SChristian Sigg 
155f03826f8SChristian Sigg   // Construct new result type list with additional types.
156f03826f8SChristian Sigg   SmallVector<Type, 2> resultTypes;
157f03826f8SChristian Sigg   resultTypes.reserve(executeOp.getNumResults() + results.size());
158f03826f8SChristian Sigg   transform(executeOp.getResultTypes(), std::back_inserter(resultTypes),
159f03826f8SChristian Sigg             [](Type type) {
160f03826f8SChristian Sigg               // Extract value type from !async.value.
1615550c821STres Popp               if (auto valueType = dyn_cast<async::ValueType>(type))
162f03826f8SChristian Sigg                 return valueType.getValueType();
1635550c821STres Popp               assert(isa<async::TokenType>(type) && "expected token type");
164f03826f8SChristian Sigg               return type;
165f03826f8SChristian Sigg             });
166f03826f8SChristian Sigg   transform(results, std::back_inserter(resultTypes),
167f03826f8SChristian Sigg             [](Value value) { return value.getType(); });
168f03826f8SChristian Sigg 
169f03826f8SChristian Sigg   // Clone executeOp with the extra results.
170f03826f8SChristian Sigg   OpBuilder builder(executeOp);
171f03826f8SChristian Sigg   auto newOp = builder.create<async::ExecuteOp>(
172f03826f8SChristian Sigg       executeOp.getLoc(), TypeRange{resultTypes}.drop_front() /*drop token*/,
173a5aa7836SRiver Riddle       executeOp.getDependencies(), executeOp.getBodyOperands());
1744d67b278SJeff Niu   IRMapping mapper;
175f03826f8SChristian Sigg   newOp.getRegion().getBlocks().clear();
176f03826f8SChristian Sigg   executeOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
177f03826f8SChristian Sigg 
178f03826f8SChristian Sigg   // Replace executeOp with cloned one.
179f03826f8SChristian Sigg   executeOp.getOperation()->replaceAllUsesWith(
180f03826f8SChristian Sigg       newOp.getResults().drop_back(results.size()));
181f03826f8SChristian Sigg   executeOp.erase();
182f03826f8SChristian Sigg 
183f03826f8SChristian Sigg   return newOp;
184f03826f8SChristian Sigg }
185f03826f8SChristian Sigg 
186d9adde5aSChristian Sigg // Callback for `async.execute` ops which tries to push the contained
187d9adde5aSChristian Sigg // synchronous `gpu.wait` op to the dependencies of the `async.execute`.
188d9adde5aSChristian Sigg struct GpuAsyncRegionPass::DeferWaitCallback {
189d9adde5aSChristian Sigg   // If the `executeOp`s token is used only in `async.execute` or `async.await`
190d9adde5aSChristian Sigg   // ops, add the region's last `gpu.wait` op to the worklist if it is
191d9adde5aSChristian Sigg   // synchronous and is the last op with side effects.
192d9adde5aSChristian Sigg   void operator()(async::ExecuteOp executeOp) {
193a5aa7836SRiver Riddle     if (!areAllUsersExecuteOrAwait(executeOp.getToken()))
194d9adde5aSChristian Sigg       return;
195d9adde5aSChristian Sigg     // async.execute's region is currently restricted to one block.
196d9adde5aSChristian Sigg     for (auto &op : llvm::reverse(executeOp.getBody()->without_terminator())) {
197d9adde5aSChristian Sigg       if (auto waitOp = dyn_cast<gpu::WaitOp>(op)) {
19810c04f46SRiver Riddle         if (!waitOp.getAsyncToken())
199d9adde5aSChristian Sigg           worklist.push_back(waitOp);
200d9adde5aSChristian Sigg         return;
201d9adde5aSChristian Sigg       }
202d9adde5aSChristian Sigg       if (hasSideEffects(&op))
203d9adde5aSChristian Sigg         return;
204d9adde5aSChristian Sigg     }
205d9adde5aSChristian Sigg   }
206d9adde5aSChristian Sigg 
207d9adde5aSChristian Sigg   // The destructor performs the actual rewrite work.
208d9adde5aSChristian Sigg   ~DeferWaitCallback() {
209d9adde5aSChristian Sigg     for (size_t i = 0; i < worklist.size(); ++i) {
210d9adde5aSChristian Sigg       auto waitOp = worklist[i];
2110bf4a82aSChristian Sigg       auto executeOp = waitOp->getParentOfType<async::ExecuteOp>();
212d9adde5aSChristian Sigg 
213f03826f8SChristian Sigg       // Erase `gpu.wait` and return async dependencies from execute op instead.
214a5aa7836SRiver Riddle       SmallVector<Value, 4> dependencies = waitOp.getAsyncDependencies();
215d9adde5aSChristian Sigg       waitOp.erase();
216f03826f8SChristian Sigg       executeOp = addExecuteResults(executeOp, dependencies);
217d9adde5aSChristian Sigg 
218d9adde5aSChristian Sigg       // Add the async dependency to each user of the `async.execute` token.
219f03826f8SChristian Sigg       auto asyncTokens = executeOp.getResults().take_back(dependencies.size());
220a5aa7836SRiver Riddle       SmallVector<Operation *, 4> users(executeOp.getToken().user_begin(),
221a5aa7836SRiver Riddle                                         executeOp.getToken().user_end());
2226e1ac68aSVitaly Buka       for (Operation *user : users)
223d9adde5aSChristian Sigg         addAsyncDependencyAfter(asyncTokens, user);
224d9adde5aSChristian Sigg     }
225d9adde5aSChristian Sigg   }
226d9adde5aSChristian Sigg 
227d9adde5aSChristian Sigg private:
228d9adde5aSChristian Sigg   // Returns whether all token users are either 'async.execute' or 'async.await'
229d9adde5aSChristian Sigg   // ops. This is used as a requirement for pushing 'gpu.wait' ops from a
230d9adde5aSChristian Sigg   // 'async.execute' body to it's users. Specifically, we do not allow
231d9adde5aSChristian Sigg   // terminator users, because it could mean that the `async.execute` is inside
232d9adde5aSChristian Sigg   // control flow code.
233d9adde5aSChristian Sigg   static bool areAllUsersExecuteOrAwait(Value token) {
234f03826f8SChristian Sigg     return !token.use_empty() &&
235971b8525SJakub Kuderski            llvm::all_of(token.getUsers(),
236971b8525SJakub Kuderski                         llvm::IsaPred<async::ExecuteOp, async::AwaitOp>);
237d9adde5aSChristian Sigg   }
238d9adde5aSChristian Sigg 
239d9adde5aSChristian Sigg   // Add the `asyncToken` as dependency as needed after `op`.
240d9adde5aSChristian Sigg   void addAsyncDependencyAfter(ValueRange asyncTokens, Operation *op) {
241d9adde5aSChristian Sigg     OpBuilder builder(op->getContext());
242d9adde5aSChristian Sigg     auto loc = op->getLoc();
243d9adde5aSChristian Sigg 
244d9adde5aSChristian Sigg     Block::iterator it;
245d9adde5aSChristian Sigg     SmallVector<Value, 1> tokens;
246d9adde5aSChristian Sigg     tokens.reserve(asyncTokens.size());
247d9adde5aSChristian Sigg     TypeSwitch<Operation *>(op)
248d9adde5aSChristian Sigg         .Case<async::AwaitOp>([&](auto awaitOp) {
249d9adde5aSChristian Sigg           // Add async.await ops to wait for the !gpu.async.tokens.
250d9adde5aSChristian Sigg           builder.setInsertionPointAfter(op);
251d9adde5aSChristian Sigg           for (auto asyncToken : asyncTokens)
252d9adde5aSChristian Sigg             tokens.push_back(
253a5aa7836SRiver Riddle                 builder.create<async::AwaitOp>(loc, asyncToken).getResult());
254d9adde5aSChristian Sigg           // Set `it` after the inserted async.await ops.
255d9adde5aSChristian Sigg           it = builder.getInsertionPoint();
256d9adde5aSChristian Sigg         })
257d9adde5aSChristian Sigg         .Case<async::ExecuteOp>([&](auto executeOp) {
258d9adde5aSChristian Sigg           // Set `it` to the beginning of the region and add asyncTokens to the
259d9adde5aSChristian Sigg           // async.execute operands.
260d9adde5aSChristian Sigg           it = executeOp.getBody()->begin();
261a5aa7836SRiver Riddle           executeOp.getBodyOperandsMutable().append(asyncTokens);
262d9adde5aSChristian Sigg           SmallVector<Type, 1> tokenTypes(
263d9adde5aSChristian Sigg               asyncTokens.size(), builder.getType<gpu::AsyncTokenType>());
264e084679fSRiver Riddle           SmallVector<Location, 1> tokenLocs(asyncTokens.size(),
265e084679fSRiver Riddle                                              executeOp.getLoc());
266e084679fSRiver Riddle           copy(executeOp.getBody()->addArguments(tokenTypes, tokenLocs),
267d9adde5aSChristian Sigg                std::back_inserter(tokens));
268d9adde5aSChristian Sigg         });
269d9adde5aSChristian Sigg 
270d9adde5aSChristian Sigg     // Advance `it` to terminator or op with side-effects.
271d9adde5aSChristian Sigg     it = std::find_if(it, Block::iterator(), [](Operation &op) {
272d9adde5aSChristian Sigg       return isTerminator(&op) || hasSideEffects(&op);
273d9adde5aSChristian Sigg     });
274d9adde5aSChristian Sigg 
275d9adde5aSChristian Sigg     // If `op` implements the AsyncOpInterface, add `token` to the list of async
276d9adde5aSChristian Sigg     // dependencies.
277d9adde5aSChristian Sigg     if (auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(*it)) {
278d9adde5aSChristian Sigg       for (auto token : tokens)
279d9adde5aSChristian Sigg         asyncOp.addAsyncDependency(token);
280d9adde5aSChristian Sigg       return;
281d9adde5aSChristian Sigg     }
282d9adde5aSChristian Sigg 
283d9adde5aSChristian Sigg     // Otherwise, insert a gpu.wait before 'it'.
284d9adde5aSChristian Sigg     builder.setInsertionPoint(it->getBlock(), it);
285d9adde5aSChristian Sigg     auto waitOp = builder.create<gpu::WaitOp>(loc, Type{}, tokens);
286d9adde5aSChristian Sigg 
287d9adde5aSChristian Sigg     // If the new waitOp is at the end of an async.execute region, add it to the
288d9adde5aSChristian Sigg     // worklist. 'operator()(executeOp)' would do the same, but this is faster.
289d9adde5aSChristian Sigg     auto executeOp = dyn_cast<async::ExecuteOp>(it->getParentOp());
290a5aa7836SRiver Riddle     if (executeOp && areAllUsersExecuteOrAwait(executeOp.getToken()) &&
291d9adde5aSChristian Sigg         !it->getNextNode())
292d9adde5aSChristian Sigg       worklist.push_back(waitOp);
293d9adde5aSChristian Sigg   }
294d9adde5aSChristian Sigg 
295d9adde5aSChristian Sigg   SmallVector<gpu::WaitOp, 8> worklist;
296d9adde5aSChristian Sigg };
297d9adde5aSChristian Sigg 
298f03826f8SChristian Sigg // Callback for `async.execute` ops which repeats !gpu.async.token results
299f03826f8SChristian Sigg // so that each of them is only used once.
300f03826f8SChristian Sigg struct GpuAsyncRegionPass::SingleTokenUseCallback {
301f03826f8SChristian Sigg   void operator()(async::ExecuteOp executeOp) {
302f03826f8SChristian Sigg     // Extract !gpu.async.token results which have multiple uses.
303a5aa7836SRiver Riddle     auto multiUseResults = llvm::make_filter_range(
304a5aa7836SRiver Riddle         executeOp.getBodyResults(), [](OpResult result) {
305f03826f8SChristian Sigg           if (result.use_empty() || result.hasOneUse())
306f03826f8SChristian Sigg             return false;
3075550c821STres Popp           auto valueType = dyn_cast<async::ValueType>(result.getType());
308f03826f8SChristian Sigg           return valueType &&
3095550c821STres Popp                  isa<gpu::AsyncTokenType>(valueType.getValueType());
310f03826f8SChristian Sigg         });
311f03826f8SChristian Sigg     if (multiUseResults.empty())
312f03826f8SChristian Sigg       return;
313f03826f8SChristian Sigg 
314f03826f8SChristian Sigg     // Indices within !async.execute results (i.e. without the async.token).
315f03826f8SChristian Sigg     SmallVector<int, 4> indices;
316f03826f8SChristian Sigg     transform(multiUseResults, std::back_inserter(indices),
317f03826f8SChristian Sigg               [](OpResult result) {
318f03826f8SChristian Sigg                 return result.getResultNumber() - 1; // Index without token.
319f03826f8SChristian Sigg               });
320f03826f8SChristian Sigg 
321f03826f8SChristian Sigg     for (auto index : indices) {
322a5aa7836SRiver Riddle       assert(!executeOp.getBodyResults()[index].getUses().empty());
323f03826f8SChristian Sigg       // Repeat async.yield token result, one for each use after the first one.
324a5aa7836SRiver Riddle       auto uses = llvm::drop_begin(executeOp.getBodyResults()[index].getUses());
325f03826f8SChristian Sigg       auto count = std::distance(uses.begin(), uses.end());
326f03826f8SChristian Sigg       auto yieldOp = cast<async::YieldOp>(executeOp.getBody()->getTerminator());
327f03826f8SChristian Sigg       SmallVector<Value, 4> operands(count, yieldOp.getOperand(index));
328f03826f8SChristian Sigg       executeOp = addExecuteResults(executeOp, operands);
329f03826f8SChristian Sigg       // Update 'uses' to refer to the new executeOp.
330a5aa7836SRiver Riddle       uses = llvm::drop_begin(executeOp.getBodyResults()[index].getUses());
331a5aa7836SRiver Riddle       auto results = executeOp.getBodyResults().take_back(count);
332f03826f8SChristian Sigg       for (auto pair : llvm::zip(uses, results))
333f03826f8SChristian Sigg         std::get<0>(pair).set(std::get<1>(pair));
334f03826f8SChristian Sigg     }
335f03826f8SChristian Sigg   }
336f03826f8SChristian Sigg };
337f03826f8SChristian Sigg 
338db7129a0SChristian Sigg // Replaces synchronous GPU ops in the op's region with asynchronous ones and
339db7129a0SChristian Sigg // inserts the necessary synchronization (as gpu.wait ops). Assumes sequential
340db7129a0SChristian Sigg // execution semantics and that no GPU ops are asynchronous yet.
34141574554SRiver Riddle void GpuAsyncRegionPass::runOnOperation() {
34241574554SRiver Riddle   if (getOperation()->walk(ThreadTokenCallback(getContext())).wasInterrupted())
343db7129a0SChristian Sigg     return signalPassFailure();
344d9adde5aSChristian Sigg 
345a79b26dbSChristian Sigg   // Collect gpu.wait ops that we can move out of async.execute regions.
34641574554SRiver Riddle   getOperation().getRegion().walk(DeferWaitCallback());
347f03826f8SChristian Sigg   // Makes each !gpu.async.token returned from async.execute op have single use.
34841574554SRiver Riddle   getOperation().getRegion().walk(SingleTokenUseCallback());
349db7129a0SChristian Sigg }
350db7129a0SChristian Sigg 
35158ceae95SRiver Riddle std::unique_ptr<OperationPass<func::FuncOp>> mlir::createGpuAsyncRegionPass() {
352db7129a0SChristian Sigg   return std::make_unique<GpuAsyncRegionPass>();
353db7129a0SChristian Sigg }
354