1db7129a0SChristian Sigg //===- AsyncRegionRewriter.cpp - Implementation of GPU async rewriters ----===// 2db7129a0SChristian Sigg // 3db7129a0SChristian Sigg // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4db7129a0SChristian Sigg // See https://llvm.org/LICENSE.txt for license information. 5db7129a0SChristian Sigg // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6db7129a0SChristian Sigg // 7db7129a0SChristian Sigg //===----------------------------------------------------------------------===// 8db7129a0SChristian Sigg // 9db7129a0SChristian Sigg // This file implements the GPU dialect pattern rewriters that make GPU op 10db7129a0SChristian Sigg // within a region execute asynchronously. 11db7129a0SChristian Sigg // 12db7129a0SChristian Sigg //===----------------------------------------------------------------------===// 13db7129a0SChristian Sigg 14039b969bSMichele Scuttari #include "mlir/Dialect/GPU/Transforms/Passes.h" 1567d0d7acSMichele Scuttari 1667d0d7acSMichele Scuttari #include "mlir/Dialect/Async/IR/Async.h" 1767d0d7acSMichele Scuttari #include "mlir/Dialect/Func/IR/FuncOps.h" 1867d0d7acSMichele Scuttari #include "mlir/Dialect/GPU/IR/GPUDialect.h" 19*bc29fc93SPetr Kurapov #include "mlir/Dialect/GPU/Utils/GPUUtils.h" 20db7129a0SChristian Sigg #include "mlir/IR/Builders.h" 214d67b278SJeff Niu #include "mlir/IR/IRMapping.h" 22db7129a0SChristian Sigg #include "mlir/IR/PatternMatch.h" 23db7129a0SChristian Sigg #include "mlir/IR/SymbolTable.h" 24fc367dfaSMahesh Ravishankar #include "mlir/Interfaces/SideEffectInterfaces.h" 25db7129a0SChristian Sigg #include "mlir/Support/LLVM.h" 26db7129a0SChristian Sigg #include "mlir/Transforms/RegionUtils.h" 27d9adde5aSChristian Sigg #include "llvm/ADT/TypeSwitch.h" 28db7129a0SChristian Sigg 2967d0d7acSMichele Scuttari namespace mlir { 3067d0d7acSMichele Scuttari #define GEN_PASS_DEF_GPUASYNCREGIONPASS 3167d0d7acSMichele Scuttari #include "mlir/Dialect/GPU/Transforms/Passes.h.inc" 3267d0d7acSMichele Scuttari } // namespace mlir 3367d0d7acSMichele Scuttari 34db7129a0SChristian Sigg using namespace mlir; 3567d0d7acSMichele Scuttari 36db7129a0SChristian Sigg namespace { 3767d0d7acSMichele Scuttari class GpuAsyncRegionPass 3867d0d7acSMichele Scuttari : public impl::GpuAsyncRegionPassBase<GpuAsyncRegionPass> { 39d9adde5aSChristian Sigg struct ThreadTokenCallback; 40d9adde5aSChristian Sigg struct DeferWaitCallback; 41f03826f8SChristian Sigg struct SingleTokenUseCallback; 4241574554SRiver Riddle void runOnOperation() override; 43db7129a0SChristian Sigg }; 44db7129a0SChristian Sigg } // namespace 45db7129a0SChristian Sigg 46fe7c0d90SRiver Riddle static bool isTerminator(Operation *op) { 47fe7c0d90SRiver Riddle return op->mightHaveTrait<OpTrait::IsTerminator>(); 48fe7c0d90SRiver Riddle } 49fc367dfaSMahesh Ravishankar static bool hasSideEffects(Operation *op) { return !isMemoryEffectFree(op); } 50d9adde5aSChristian Sigg 51db7129a0SChristian Sigg // Region walk callback which makes GPU ops implementing the AsyncOpInterface 52db7129a0SChristian Sigg // execute asynchronously. 53d9adde5aSChristian Sigg struct GpuAsyncRegionPass::ThreadTokenCallback { 54d9adde5aSChristian Sigg ThreadTokenCallback(MLIRContext &context) : builder(&context) {} 55d9adde5aSChristian Sigg 560b21371eSChristian Sigg WalkResult operator()(Block *block) { 570b21371eSChristian Sigg for (Operation &op : make_early_inc_range(*block)) { 580b21371eSChristian Sigg if (failed(visit(&op))) 590b21371eSChristian Sigg return WalkResult::interrupt(); 600b21371eSChristian Sigg } 610b21371eSChristian Sigg return WalkResult::advance(); 620b21371eSChristian Sigg } 630b21371eSChristian Sigg 640b21371eSChristian Sigg private: 65db7129a0SChristian Sigg // If `op` implements the AsyncOpInterface, insert a `gpu.wait async` to 66db7129a0SChristian Sigg // create a current token (unless it already exists), and 'thread' that token 67db7129a0SChristian Sigg // through the `op` so that it executes asynchronously. 68db7129a0SChristian Sigg // 69db7129a0SChristian Sigg // If `op` is a terminator or an op with side-effects, insert a `gpu.wait` to 70d9adde5aSChristian Sigg // host-synchronize execution. A `!gpu.async.token` will therefore only be 71d9adde5aSChristian Sigg // used inside of its block and GPU execution will always synchronize with 72d9adde5aSChristian Sigg // the host at block boundaries. 730b21371eSChristian Sigg LogicalResult visit(Operation *op) { 74db7129a0SChristian Sigg if (isa<gpu::LaunchOp>(op)) 75db7129a0SChristian Sigg return op->emitOpError("replace with gpu.launch_func first"); 760b21371eSChristian Sigg if (auto waitOp = llvm::dyn_cast<gpu::WaitOp>(op)) { 770b21371eSChristian Sigg if (currentToken) 780b21371eSChristian Sigg waitOp.addAsyncDependency(currentToken); 7910c04f46SRiver Riddle currentToken = waitOp.getAsyncToken(); 800b21371eSChristian Sigg return success(); 810b21371eSChristian Sigg } 82db7129a0SChristian Sigg builder.setInsertionPoint(op); 83db7129a0SChristian Sigg if (auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(op)) 84db7129a0SChristian Sigg return rewriteAsyncOp(asyncOp); // Replace GPU op with async version. 85db7129a0SChristian Sigg if (!currentToken) 86db7129a0SChristian Sigg return success(); 87db7129a0SChristian Sigg // Insert host synchronization before terminator or op with side effects. 88d9adde5aSChristian Sigg if (isTerminator(op) || hasSideEffects(op)) 89db7129a0SChristian Sigg currentToken = createWaitOp(op->getLoc(), Type(), {currentToken}); 90db7129a0SChristian Sigg return success(); 91db7129a0SChristian Sigg } 92db7129a0SChristian Sigg 93db7129a0SChristian Sigg // Replaces asyncOp with a clone that returns a token. 94db7129a0SChristian Sigg LogicalResult rewriteAsyncOp(gpu::AsyncOpInterface asyncOp) { 95db7129a0SChristian Sigg auto *op = asyncOp.getOperation(); 964c372a35SChristian Sigg auto tokenType = builder.getType<gpu::AsyncTokenType>(); 974c372a35SChristian Sigg 98db7129a0SChristian Sigg // If there is no current token, insert a `gpu.wait async` without 99db7129a0SChristian Sigg // dependencies to create one. 100db7129a0SChristian Sigg if (!currentToken) 101db7129a0SChristian Sigg currentToken = createWaitOp(op->getLoc(), tokenType, {}); 102db7129a0SChristian Sigg asyncOp.addAsyncDependency(currentToken); 103db7129a0SChristian Sigg 1040b21371eSChristian Sigg // Return early if op returns a token already. 1050b21371eSChristian Sigg currentToken = asyncOp.getAsyncToken(); 1060b21371eSChristian Sigg if (currentToken) 1070b21371eSChristian Sigg return success(); 1080b21371eSChristian Sigg 109db7129a0SChristian Sigg // Clone the op to return a token in addition to the other results. 110a79b26dbSChristian Sigg SmallVector<Type, 1> resultTypes; 111db7129a0SChristian Sigg resultTypes.reserve(1 + op->getNumResults()); 112db7129a0SChristian Sigg copy(op->getResultTypes(), std::back_inserter(resultTypes)); 113a79b26dbSChristian Sigg resultTypes.push_back(tokenType); 114bbe5bf17SMehdi Amini auto *newOp = Operation::create( 115bbe5bf17SMehdi Amini op->getLoc(), op->getName(), resultTypes, op->getOperands(), 116bbe5bf17SMehdi Amini op->getDiscardableAttrDictionary(), op->getPropertiesStorage(), 117a0d019fcSChristian Sigg op->getSuccessors(), op->getNumRegions()); 118a0d019fcSChristian Sigg 119a0d019fcSChristian Sigg // Clone regions into new op. 1204d67b278SJeff Niu IRMapping mapping; 121a0d019fcSChristian Sigg for (auto pair : llvm::zip_first(op->getRegions(), newOp->getRegions())) 122a0d019fcSChristian Sigg std::get<0>(pair).cloneInto(&std::get<1>(pair), mapping); 123db7129a0SChristian Sigg 124db7129a0SChristian Sigg // Replace the op with the async clone. 125db7129a0SChristian Sigg auto results = newOp->getResults(); 126a79b26dbSChristian Sigg currentToken = results.back(); 127db7129a0SChristian Sigg builder.insert(newOp); 128a79b26dbSChristian Sigg op->replaceAllUsesWith(results.drop_back()); 129db7129a0SChristian Sigg op->erase(); 130db7129a0SChristian Sigg 131db7129a0SChristian Sigg return success(); 132db7129a0SChristian Sigg } 133db7129a0SChristian Sigg 134db7129a0SChristian Sigg Value createWaitOp(Location loc, Type resultType, ValueRange operands) { 13510c04f46SRiver Riddle return builder.create<gpu::WaitOp>(loc, resultType, operands) 13610c04f46SRiver Riddle .getAsyncToken(); 137db7129a0SChristian Sigg } 138db7129a0SChristian Sigg 139db7129a0SChristian Sigg OpBuilder builder; 1404c372a35SChristian Sigg 141db7129a0SChristian Sigg // The token that represents the current asynchronous dependency. It's valid 142db7129a0SChristian Sigg // range starts with a `gpu.wait async` op, and ends with a `gpu.wait` op. 143db7129a0SChristian Sigg // In between, each gpu::AsyncOpInterface depends on the current token and 144db7129a0SChristian Sigg // produces the new one. 145db7129a0SChristian Sigg Value currentToken = {}; 146db7129a0SChristian Sigg }; 147db7129a0SChristian Sigg 148f03826f8SChristian Sigg /// Erases `executeOp` and returns a clone with additional `results`. 149f03826f8SChristian Sigg async::ExecuteOp addExecuteResults(async::ExecuteOp executeOp, 150f03826f8SChristian Sigg ValueRange results) { 151f03826f8SChristian Sigg // Add values to async.yield op. 152f03826f8SChristian Sigg Operation *yieldOp = executeOp.getBody()->getTerminator(); 153f03826f8SChristian Sigg yieldOp->insertOperands(yieldOp->getNumOperands(), results); 154f03826f8SChristian Sigg 155f03826f8SChristian Sigg // Construct new result type list with additional types. 156f03826f8SChristian Sigg SmallVector<Type, 2> resultTypes; 157f03826f8SChristian Sigg resultTypes.reserve(executeOp.getNumResults() + results.size()); 158f03826f8SChristian Sigg transform(executeOp.getResultTypes(), std::back_inserter(resultTypes), 159f03826f8SChristian Sigg [](Type type) { 160f03826f8SChristian Sigg // Extract value type from !async.value. 1615550c821STres Popp if (auto valueType = dyn_cast<async::ValueType>(type)) 162f03826f8SChristian Sigg return valueType.getValueType(); 1635550c821STres Popp assert(isa<async::TokenType>(type) && "expected token type"); 164f03826f8SChristian Sigg return type; 165f03826f8SChristian Sigg }); 166f03826f8SChristian Sigg transform(results, std::back_inserter(resultTypes), 167f03826f8SChristian Sigg [](Value value) { return value.getType(); }); 168f03826f8SChristian Sigg 169f03826f8SChristian Sigg // Clone executeOp with the extra results. 170f03826f8SChristian Sigg OpBuilder builder(executeOp); 171f03826f8SChristian Sigg auto newOp = builder.create<async::ExecuteOp>( 172f03826f8SChristian Sigg executeOp.getLoc(), TypeRange{resultTypes}.drop_front() /*drop token*/, 173a5aa7836SRiver Riddle executeOp.getDependencies(), executeOp.getBodyOperands()); 1744d67b278SJeff Niu IRMapping mapper; 175f03826f8SChristian Sigg newOp.getRegion().getBlocks().clear(); 176f03826f8SChristian Sigg executeOp.getRegion().cloneInto(&newOp.getRegion(), mapper); 177f03826f8SChristian Sigg 178f03826f8SChristian Sigg // Replace executeOp with cloned one. 179f03826f8SChristian Sigg executeOp.getOperation()->replaceAllUsesWith( 180f03826f8SChristian Sigg newOp.getResults().drop_back(results.size())); 181f03826f8SChristian Sigg executeOp.erase(); 182f03826f8SChristian Sigg 183f03826f8SChristian Sigg return newOp; 184f03826f8SChristian Sigg } 185f03826f8SChristian Sigg 186d9adde5aSChristian Sigg // Callback for `async.execute` ops which tries to push the contained 187d9adde5aSChristian Sigg // synchronous `gpu.wait` op to the dependencies of the `async.execute`. 188d9adde5aSChristian Sigg struct GpuAsyncRegionPass::DeferWaitCallback { 189d9adde5aSChristian Sigg // If the `executeOp`s token is used only in `async.execute` or `async.await` 190d9adde5aSChristian Sigg // ops, add the region's last `gpu.wait` op to the worklist if it is 191d9adde5aSChristian Sigg // synchronous and is the last op with side effects. 192d9adde5aSChristian Sigg void operator()(async::ExecuteOp executeOp) { 193a5aa7836SRiver Riddle if (!areAllUsersExecuteOrAwait(executeOp.getToken())) 194d9adde5aSChristian Sigg return; 195d9adde5aSChristian Sigg // async.execute's region is currently restricted to one block. 196d9adde5aSChristian Sigg for (auto &op : llvm::reverse(executeOp.getBody()->without_terminator())) { 197d9adde5aSChristian Sigg if (auto waitOp = dyn_cast<gpu::WaitOp>(op)) { 19810c04f46SRiver Riddle if (!waitOp.getAsyncToken()) 199d9adde5aSChristian Sigg worklist.push_back(waitOp); 200d9adde5aSChristian Sigg return; 201d9adde5aSChristian Sigg } 202d9adde5aSChristian Sigg if (hasSideEffects(&op)) 203d9adde5aSChristian Sigg return; 204d9adde5aSChristian Sigg } 205d9adde5aSChristian Sigg } 206d9adde5aSChristian Sigg 207d9adde5aSChristian Sigg // The destructor performs the actual rewrite work. 208d9adde5aSChristian Sigg ~DeferWaitCallback() { 209d9adde5aSChristian Sigg for (size_t i = 0; i < worklist.size(); ++i) { 210d9adde5aSChristian Sigg auto waitOp = worklist[i]; 2110bf4a82aSChristian Sigg auto executeOp = waitOp->getParentOfType<async::ExecuteOp>(); 212d9adde5aSChristian Sigg 213f03826f8SChristian Sigg // Erase `gpu.wait` and return async dependencies from execute op instead. 214a5aa7836SRiver Riddle SmallVector<Value, 4> dependencies = waitOp.getAsyncDependencies(); 215d9adde5aSChristian Sigg waitOp.erase(); 216f03826f8SChristian Sigg executeOp = addExecuteResults(executeOp, dependencies); 217d9adde5aSChristian Sigg 218d9adde5aSChristian Sigg // Add the async dependency to each user of the `async.execute` token. 219f03826f8SChristian Sigg auto asyncTokens = executeOp.getResults().take_back(dependencies.size()); 220a5aa7836SRiver Riddle SmallVector<Operation *, 4> users(executeOp.getToken().user_begin(), 221a5aa7836SRiver Riddle executeOp.getToken().user_end()); 2226e1ac68aSVitaly Buka for (Operation *user : users) 223d9adde5aSChristian Sigg addAsyncDependencyAfter(asyncTokens, user); 224d9adde5aSChristian Sigg } 225d9adde5aSChristian Sigg } 226d9adde5aSChristian Sigg 227d9adde5aSChristian Sigg private: 228d9adde5aSChristian Sigg // Returns whether all token users are either 'async.execute' or 'async.await' 229d9adde5aSChristian Sigg // ops. This is used as a requirement for pushing 'gpu.wait' ops from a 230d9adde5aSChristian Sigg // 'async.execute' body to it's users. Specifically, we do not allow 231d9adde5aSChristian Sigg // terminator users, because it could mean that the `async.execute` is inside 232d9adde5aSChristian Sigg // control flow code. 233d9adde5aSChristian Sigg static bool areAllUsersExecuteOrAwait(Value token) { 234f03826f8SChristian Sigg return !token.use_empty() && 235971b8525SJakub Kuderski llvm::all_of(token.getUsers(), 236971b8525SJakub Kuderski llvm::IsaPred<async::ExecuteOp, async::AwaitOp>); 237d9adde5aSChristian Sigg } 238d9adde5aSChristian Sigg 239d9adde5aSChristian Sigg // Add the `asyncToken` as dependency as needed after `op`. 240d9adde5aSChristian Sigg void addAsyncDependencyAfter(ValueRange asyncTokens, Operation *op) { 241d9adde5aSChristian Sigg OpBuilder builder(op->getContext()); 242d9adde5aSChristian Sigg auto loc = op->getLoc(); 243d9adde5aSChristian Sigg 244d9adde5aSChristian Sigg Block::iterator it; 245d9adde5aSChristian Sigg SmallVector<Value, 1> tokens; 246d9adde5aSChristian Sigg tokens.reserve(asyncTokens.size()); 247d9adde5aSChristian Sigg TypeSwitch<Operation *>(op) 248d9adde5aSChristian Sigg .Case<async::AwaitOp>([&](auto awaitOp) { 249d9adde5aSChristian Sigg // Add async.await ops to wait for the !gpu.async.tokens. 250d9adde5aSChristian Sigg builder.setInsertionPointAfter(op); 251d9adde5aSChristian Sigg for (auto asyncToken : asyncTokens) 252d9adde5aSChristian Sigg tokens.push_back( 253a5aa7836SRiver Riddle builder.create<async::AwaitOp>(loc, asyncToken).getResult()); 254d9adde5aSChristian Sigg // Set `it` after the inserted async.await ops. 255d9adde5aSChristian Sigg it = builder.getInsertionPoint(); 256d9adde5aSChristian Sigg }) 257d9adde5aSChristian Sigg .Case<async::ExecuteOp>([&](auto executeOp) { 258d9adde5aSChristian Sigg // Set `it` to the beginning of the region and add asyncTokens to the 259d9adde5aSChristian Sigg // async.execute operands. 260d9adde5aSChristian Sigg it = executeOp.getBody()->begin(); 261a5aa7836SRiver Riddle executeOp.getBodyOperandsMutable().append(asyncTokens); 262d9adde5aSChristian Sigg SmallVector<Type, 1> tokenTypes( 263d9adde5aSChristian Sigg asyncTokens.size(), builder.getType<gpu::AsyncTokenType>()); 264e084679fSRiver Riddle SmallVector<Location, 1> tokenLocs(asyncTokens.size(), 265e084679fSRiver Riddle executeOp.getLoc()); 266e084679fSRiver Riddle copy(executeOp.getBody()->addArguments(tokenTypes, tokenLocs), 267d9adde5aSChristian Sigg std::back_inserter(tokens)); 268d9adde5aSChristian Sigg }); 269d9adde5aSChristian Sigg 270d9adde5aSChristian Sigg // Advance `it` to terminator or op with side-effects. 271d9adde5aSChristian Sigg it = std::find_if(it, Block::iterator(), [](Operation &op) { 272d9adde5aSChristian Sigg return isTerminator(&op) || hasSideEffects(&op); 273d9adde5aSChristian Sigg }); 274d9adde5aSChristian Sigg 275d9adde5aSChristian Sigg // If `op` implements the AsyncOpInterface, add `token` to the list of async 276d9adde5aSChristian Sigg // dependencies. 277d9adde5aSChristian Sigg if (auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(*it)) { 278d9adde5aSChristian Sigg for (auto token : tokens) 279d9adde5aSChristian Sigg asyncOp.addAsyncDependency(token); 280d9adde5aSChristian Sigg return; 281d9adde5aSChristian Sigg } 282d9adde5aSChristian Sigg 283d9adde5aSChristian Sigg // Otherwise, insert a gpu.wait before 'it'. 284d9adde5aSChristian Sigg builder.setInsertionPoint(it->getBlock(), it); 285d9adde5aSChristian Sigg auto waitOp = builder.create<gpu::WaitOp>(loc, Type{}, tokens); 286d9adde5aSChristian Sigg 287d9adde5aSChristian Sigg // If the new waitOp is at the end of an async.execute region, add it to the 288d9adde5aSChristian Sigg // worklist. 'operator()(executeOp)' would do the same, but this is faster. 289d9adde5aSChristian Sigg auto executeOp = dyn_cast<async::ExecuteOp>(it->getParentOp()); 290a5aa7836SRiver Riddle if (executeOp && areAllUsersExecuteOrAwait(executeOp.getToken()) && 291d9adde5aSChristian Sigg !it->getNextNode()) 292d9adde5aSChristian Sigg worklist.push_back(waitOp); 293d9adde5aSChristian Sigg } 294d9adde5aSChristian Sigg 295d9adde5aSChristian Sigg SmallVector<gpu::WaitOp, 8> worklist; 296d9adde5aSChristian Sigg }; 297d9adde5aSChristian Sigg 298f03826f8SChristian Sigg // Callback for `async.execute` ops which repeats !gpu.async.token results 299f03826f8SChristian Sigg // so that each of them is only used once. 300f03826f8SChristian Sigg struct GpuAsyncRegionPass::SingleTokenUseCallback { 301f03826f8SChristian Sigg void operator()(async::ExecuteOp executeOp) { 302f03826f8SChristian Sigg // Extract !gpu.async.token results which have multiple uses. 303a5aa7836SRiver Riddle auto multiUseResults = llvm::make_filter_range( 304a5aa7836SRiver Riddle executeOp.getBodyResults(), [](OpResult result) { 305f03826f8SChristian Sigg if (result.use_empty() || result.hasOneUse()) 306f03826f8SChristian Sigg return false; 3075550c821STres Popp auto valueType = dyn_cast<async::ValueType>(result.getType()); 308f03826f8SChristian Sigg return valueType && 3095550c821STres Popp isa<gpu::AsyncTokenType>(valueType.getValueType()); 310f03826f8SChristian Sigg }); 311f03826f8SChristian Sigg if (multiUseResults.empty()) 312f03826f8SChristian Sigg return; 313f03826f8SChristian Sigg 314f03826f8SChristian Sigg // Indices within !async.execute results (i.e. without the async.token). 315f03826f8SChristian Sigg SmallVector<int, 4> indices; 316f03826f8SChristian Sigg transform(multiUseResults, std::back_inserter(indices), 317f03826f8SChristian Sigg [](OpResult result) { 318f03826f8SChristian Sigg return result.getResultNumber() - 1; // Index without token. 319f03826f8SChristian Sigg }); 320f03826f8SChristian Sigg 321f03826f8SChristian Sigg for (auto index : indices) { 322a5aa7836SRiver Riddle assert(!executeOp.getBodyResults()[index].getUses().empty()); 323f03826f8SChristian Sigg // Repeat async.yield token result, one for each use after the first one. 324a5aa7836SRiver Riddle auto uses = llvm::drop_begin(executeOp.getBodyResults()[index].getUses()); 325f03826f8SChristian Sigg auto count = std::distance(uses.begin(), uses.end()); 326f03826f8SChristian Sigg auto yieldOp = cast<async::YieldOp>(executeOp.getBody()->getTerminator()); 327f03826f8SChristian Sigg SmallVector<Value, 4> operands(count, yieldOp.getOperand(index)); 328f03826f8SChristian Sigg executeOp = addExecuteResults(executeOp, operands); 329f03826f8SChristian Sigg // Update 'uses' to refer to the new executeOp. 330a5aa7836SRiver Riddle uses = llvm::drop_begin(executeOp.getBodyResults()[index].getUses()); 331a5aa7836SRiver Riddle auto results = executeOp.getBodyResults().take_back(count); 332f03826f8SChristian Sigg for (auto pair : llvm::zip(uses, results)) 333f03826f8SChristian Sigg std::get<0>(pair).set(std::get<1>(pair)); 334f03826f8SChristian Sigg } 335f03826f8SChristian Sigg } 336f03826f8SChristian Sigg }; 337f03826f8SChristian Sigg 338db7129a0SChristian Sigg // Replaces synchronous GPU ops in the op's region with asynchronous ones and 339db7129a0SChristian Sigg // inserts the necessary synchronization (as gpu.wait ops). Assumes sequential 340db7129a0SChristian Sigg // execution semantics and that no GPU ops are asynchronous yet. 34141574554SRiver Riddle void GpuAsyncRegionPass::runOnOperation() { 34241574554SRiver Riddle if (getOperation()->walk(ThreadTokenCallback(getContext())).wasInterrupted()) 343db7129a0SChristian Sigg return signalPassFailure(); 344d9adde5aSChristian Sigg 345a79b26dbSChristian Sigg // Collect gpu.wait ops that we can move out of async.execute regions. 34641574554SRiver Riddle getOperation().getRegion().walk(DeferWaitCallback()); 347f03826f8SChristian Sigg // Makes each !gpu.async.token returned from async.execute op have single use. 34841574554SRiver Riddle getOperation().getRegion().walk(SingleTokenUseCallback()); 349db7129a0SChristian Sigg } 350db7129a0SChristian Sigg 35158ceae95SRiver Riddle std::unique_ptr<OperationPass<func::FuncOp>> mlir::createGpuAsyncRegionPass() { 352db7129a0SChristian Sigg return std::make_unique<GpuAsyncRegionPass>(); 353db7129a0SChristian Sigg } 354