xref: /llvm-project/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp (revision bc29fc937c6cb4a210f80c93c79fc6ed97c801f8)
1 //===- AsyncRegionRewriter.cpp - Implementation of GPU async rewriters ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements the GPU dialect pattern rewriters that make GPU ops
// within a region execute asynchronously.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/GPU/Transforms/Passes.h"
15 
16 #include "mlir/Dialect/Async/IR/Async.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
19 #include "mlir/Dialect/GPU/Utils/GPUUtils.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/IRMapping.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/IR/SymbolTable.h"
24 #include "mlir/Interfaces/SideEffectInterfaces.h"
25 #include "mlir/Support/LLVM.h"
26 #include "mlir/Transforms/RegionUtils.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 
29 namespace mlir {
30 #define GEN_PASS_DEF_GPUASYNCREGIONPASS
31 #include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
32 } // namespace mlir
33 
34 using namespace mlir;
35 
36 namespace {
37 class GpuAsyncRegionPass
38     : public impl::GpuAsyncRegionPassBase<GpuAsyncRegionPass> {
39   struct ThreadTokenCallback;
40   struct DeferWaitCallback;
41   struct SingleTokenUseCallback;
42   void runOnOperation() override;
43 };
44 } // namespace
45 
46 static bool isTerminator(Operation *op) {
47   return op->mightHaveTrait<OpTrait::IsTerminator>();
48 }
49 static bool hasSideEffects(Operation *op) { return !isMemoryEffectFree(op); }
50 
51 // Region walk callback which makes GPU ops implementing the AsyncOpInterface
52 // execute asynchronously.
53 struct GpuAsyncRegionPass::ThreadTokenCallback {
54   ThreadTokenCallback(MLIRContext &context) : builder(&context) {}
55 
56   WalkResult operator()(Block *block) {
57     for (Operation &op : make_early_inc_range(*block)) {
58       if (failed(visit(&op)))
59         return WalkResult::interrupt();
60     }
61     return WalkResult::advance();
62   }
63 
64 private:
65   // If `op` implements the AsyncOpInterface, insert a `gpu.wait async` to
66   // create a current token (unless it already exists), and 'thread' that token
67   // through the `op` so that it executes asynchronously.
68   //
69   // If `op` is a terminator or an op with side-effects, insert a `gpu.wait` to
70   // host-synchronize execution. A `!gpu.async.token` will therefore only be
71   // used inside of its block and GPU execution will always synchronize with
72   // the host at block boundaries.
73   LogicalResult visit(Operation *op) {
74     if (isa<gpu::LaunchOp>(op))
75       return op->emitOpError("replace with gpu.launch_func first");
76     if (auto waitOp = llvm::dyn_cast<gpu::WaitOp>(op)) {
77       if (currentToken)
78         waitOp.addAsyncDependency(currentToken);
79       currentToken = waitOp.getAsyncToken();
80       return success();
81     }
82     builder.setInsertionPoint(op);
83     if (auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(op))
84       return rewriteAsyncOp(asyncOp); // Replace GPU op with async version.
85     if (!currentToken)
86       return success();
87     // Insert host synchronization before terminator or op with side effects.
88     if (isTerminator(op) || hasSideEffects(op))
89       currentToken = createWaitOp(op->getLoc(), Type(), {currentToken});
90     return success();
91   }
92 
93   // Replaces asyncOp with a clone that returns a token.
94   LogicalResult rewriteAsyncOp(gpu::AsyncOpInterface asyncOp) {
95     auto *op = asyncOp.getOperation();
96     auto tokenType = builder.getType<gpu::AsyncTokenType>();
97 
98     // If there is no current token, insert a `gpu.wait async` without
99     // dependencies to create one.
100     if (!currentToken)
101       currentToken = createWaitOp(op->getLoc(), tokenType, {});
102     asyncOp.addAsyncDependency(currentToken);
103 
104     // Return early if op returns a token already.
105     currentToken = asyncOp.getAsyncToken();
106     if (currentToken)
107       return success();
108 
109     // Clone the op to return a token in addition to the other results.
110     SmallVector<Type, 1> resultTypes;
111     resultTypes.reserve(1 + op->getNumResults());
112     copy(op->getResultTypes(), std::back_inserter(resultTypes));
113     resultTypes.push_back(tokenType);
114     auto *newOp = Operation::create(
115         op->getLoc(), op->getName(), resultTypes, op->getOperands(),
116         op->getDiscardableAttrDictionary(), op->getPropertiesStorage(),
117         op->getSuccessors(), op->getNumRegions());
118 
119     // Clone regions into new op.
120     IRMapping mapping;
121     for (auto pair : llvm::zip_first(op->getRegions(), newOp->getRegions()))
122       std::get<0>(pair).cloneInto(&std::get<1>(pair), mapping);
123 
124     // Replace the op with the async clone.
125     auto results = newOp->getResults();
126     currentToken = results.back();
127     builder.insert(newOp);
128     op->replaceAllUsesWith(results.drop_back());
129     op->erase();
130 
131     return success();
132   }
133 
134   Value createWaitOp(Location loc, Type resultType, ValueRange operands) {
135     return builder.create<gpu::WaitOp>(loc, resultType, operands)
136         .getAsyncToken();
137   }
138 
139   OpBuilder builder;
140 
141   // The token that represents the current asynchronous dependency. It's valid
142   // range starts with a `gpu.wait async` op, and ends with a `gpu.wait` op.
143   // In between, each gpu::AsyncOpInterface depends on the current token and
144   // produces the new one.
145   Value currentToken = {};
146 };
147 
148 /// Erases `executeOp` and returns a clone with additional `results`.
149 async::ExecuteOp addExecuteResults(async::ExecuteOp executeOp,
150                                    ValueRange results) {
151   // Add values to async.yield op.
152   Operation *yieldOp = executeOp.getBody()->getTerminator();
153   yieldOp->insertOperands(yieldOp->getNumOperands(), results);
154 
155   // Construct new result type list with additional types.
156   SmallVector<Type, 2> resultTypes;
157   resultTypes.reserve(executeOp.getNumResults() + results.size());
158   transform(executeOp.getResultTypes(), std::back_inserter(resultTypes),
159             [](Type type) {
160               // Extract value type from !async.value.
161               if (auto valueType = dyn_cast<async::ValueType>(type))
162                 return valueType.getValueType();
163               assert(isa<async::TokenType>(type) && "expected token type");
164               return type;
165             });
166   transform(results, std::back_inserter(resultTypes),
167             [](Value value) { return value.getType(); });
168 
169   // Clone executeOp with the extra results.
170   OpBuilder builder(executeOp);
171   auto newOp = builder.create<async::ExecuteOp>(
172       executeOp.getLoc(), TypeRange{resultTypes}.drop_front() /*drop token*/,
173       executeOp.getDependencies(), executeOp.getBodyOperands());
174   IRMapping mapper;
175   newOp.getRegion().getBlocks().clear();
176   executeOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
177 
178   // Replace executeOp with cloned one.
179   executeOp.getOperation()->replaceAllUsesWith(
180       newOp.getResults().drop_back(results.size()));
181   executeOp.erase();
182 
183   return newOp;
184 }
185 
186 // Callback for `async.execute` ops which tries to push the contained
187 // synchronous `gpu.wait` op to the dependencies of the `async.execute`.
188 struct GpuAsyncRegionPass::DeferWaitCallback {
189   // If the `executeOp`s token is used only in `async.execute` or `async.await`
190   // ops, add the region's last `gpu.wait` op to the worklist if it is
191   // synchronous and is the last op with side effects.
192   void operator()(async::ExecuteOp executeOp) {
193     if (!areAllUsersExecuteOrAwait(executeOp.getToken()))
194       return;
195     // async.execute's region is currently restricted to one block.
196     for (auto &op : llvm::reverse(executeOp.getBody()->without_terminator())) {
197       if (auto waitOp = dyn_cast<gpu::WaitOp>(op)) {
198         if (!waitOp.getAsyncToken())
199           worklist.push_back(waitOp);
200         return;
201       }
202       if (hasSideEffects(&op))
203         return;
204     }
205   }
206 
207   // The destructor performs the actual rewrite work.
208   ~DeferWaitCallback() {
209     for (size_t i = 0; i < worklist.size(); ++i) {
210       auto waitOp = worklist[i];
211       auto executeOp = waitOp->getParentOfType<async::ExecuteOp>();
212 
213       // Erase `gpu.wait` and return async dependencies from execute op instead.
214       SmallVector<Value, 4> dependencies = waitOp.getAsyncDependencies();
215       waitOp.erase();
216       executeOp = addExecuteResults(executeOp, dependencies);
217 
218       // Add the async dependency to each user of the `async.execute` token.
219       auto asyncTokens = executeOp.getResults().take_back(dependencies.size());
220       SmallVector<Operation *, 4> users(executeOp.getToken().user_begin(),
221                                         executeOp.getToken().user_end());
222       for (Operation *user : users)
223         addAsyncDependencyAfter(asyncTokens, user);
224     }
225   }
226 
227 private:
228   // Returns whether all token users are either 'async.execute' or 'async.await'
229   // ops. This is used as a requirement for pushing 'gpu.wait' ops from a
230   // 'async.execute' body to it's users. Specifically, we do not allow
231   // terminator users, because it could mean that the `async.execute` is inside
232   // control flow code.
233   static bool areAllUsersExecuteOrAwait(Value token) {
234     return !token.use_empty() &&
235            llvm::all_of(token.getUsers(),
236                         llvm::IsaPred<async::ExecuteOp, async::AwaitOp>);
237   }
238 
239   // Add the `asyncToken` as dependency as needed after `op`.
240   void addAsyncDependencyAfter(ValueRange asyncTokens, Operation *op) {
241     OpBuilder builder(op->getContext());
242     auto loc = op->getLoc();
243 
244     Block::iterator it;
245     SmallVector<Value, 1> tokens;
246     tokens.reserve(asyncTokens.size());
247     TypeSwitch<Operation *>(op)
248         .Case<async::AwaitOp>([&](auto awaitOp) {
249           // Add async.await ops to wait for the !gpu.async.tokens.
250           builder.setInsertionPointAfter(op);
251           for (auto asyncToken : asyncTokens)
252             tokens.push_back(
253                 builder.create<async::AwaitOp>(loc, asyncToken).getResult());
254           // Set `it` after the inserted async.await ops.
255           it = builder.getInsertionPoint();
256         })
257         .Case<async::ExecuteOp>([&](auto executeOp) {
258           // Set `it` to the beginning of the region and add asyncTokens to the
259           // async.execute operands.
260           it = executeOp.getBody()->begin();
261           executeOp.getBodyOperandsMutable().append(asyncTokens);
262           SmallVector<Type, 1> tokenTypes(
263               asyncTokens.size(), builder.getType<gpu::AsyncTokenType>());
264           SmallVector<Location, 1> tokenLocs(asyncTokens.size(),
265                                              executeOp.getLoc());
266           copy(executeOp.getBody()->addArguments(tokenTypes, tokenLocs),
267                std::back_inserter(tokens));
268         });
269 
270     // Advance `it` to terminator or op with side-effects.
271     it = std::find_if(it, Block::iterator(), [](Operation &op) {
272       return isTerminator(&op) || hasSideEffects(&op);
273     });
274 
275     // If `op` implements the AsyncOpInterface, add `token` to the list of async
276     // dependencies.
277     if (auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(*it)) {
278       for (auto token : tokens)
279         asyncOp.addAsyncDependency(token);
280       return;
281     }
282 
283     // Otherwise, insert a gpu.wait before 'it'.
284     builder.setInsertionPoint(it->getBlock(), it);
285     auto waitOp = builder.create<gpu::WaitOp>(loc, Type{}, tokens);
286 
287     // If the new waitOp is at the end of an async.execute region, add it to the
288     // worklist. 'operator()(executeOp)' would do the same, but this is faster.
289     auto executeOp = dyn_cast<async::ExecuteOp>(it->getParentOp());
290     if (executeOp && areAllUsersExecuteOrAwait(executeOp.getToken()) &&
291         !it->getNextNode())
292       worklist.push_back(waitOp);
293   }
294 
295   SmallVector<gpu::WaitOp, 8> worklist;
296 };
297 
298 // Callback for `async.execute` ops which repeats !gpu.async.token results
299 // so that each of them is only used once.
300 struct GpuAsyncRegionPass::SingleTokenUseCallback {
301   void operator()(async::ExecuteOp executeOp) {
302     // Extract !gpu.async.token results which have multiple uses.
303     auto multiUseResults = llvm::make_filter_range(
304         executeOp.getBodyResults(), [](OpResult result) {
305           if (result.use_empty() || result.hasOneUse())
306             return false;
307           auto valueType = dyn_cast<async::ValueType>(result.getType());
308           return valueType &&
309                  isa<gpu::AsyncTokenType>(valueType.getValueType());
310         });
311     if (multiUseResults.empty())
312       return;
313 
314     // Indices within !async.execute results (i.e. without the async.token).
315     SmallVector<int, 4> indices;
316     transform(multiUseResults, std::back_inserter(indices),
317               [](OpResult result) {
318                 return result.getResultNumber() - 1; // Index without token.
319               });
320 
321     for (auto index : indices) {
322       assert(!executeOp.getBodyResults()[index].getUses().empty());
323       // Repeat async.yield token result, one for each use after the first one.
324       auto uses = llvm::drop_begin(executeOp.getBodyResults()[index].getUses());
325       auto count = std::distance(uses.begin(), uses.end());
326       auto yieldOp = cast<async::YieldOp>(executeOp.getBody()->getTerminator());
327       SmallVector<Value, 4> operands(count, yieldOp.getOperand(index));
328       executeOp = addExecuteResults(executeOp, operands);
329       // Update 'uses' to refer to the new executeOp.
330       uses = llvm::drop_begin(executeOp.getBodyResults()[index].getUses());
331       auto results = executeOp.getBodyResults().take_back(count);
332       for (auto pair : llvm::zip(uses, results))
333         std::get<0>(pair).set(std::get<1>(pair));
334     }
335   }
336 };
337 
338 // Replaces synchronous GPU ops in the op's region with asynchronous ones and
339 // inserts the necessary synchronization (as gpu.wait ops). Assumes sequential
340 // execution semantics and that no GPU ops are asynchronous yet.
341 void GpuAsyncRegionPass::runOnOperation() {
342   if (getOperation()->walk(ThreadTokenCallback(getContext())).wasInterrupted())
343     return signalPassFailure();
344 
345   // Collect gpu.wait ops that we can move out of async.execute regions.
346   getOperation().getRegion().walk(DeferWaitCallback());
347   // Makes each !gpu.async.token returned from async.execute op have single use.
348   getOperation().getRegion().walk(SingleTokenUseCallback());
349 }
350 
351 std::unique_ptr<OperationPass<func::FuncOp>> mlir::createGpuAsyncRegionPass() {
352   return std::make_unique<GpuAsyncRegionPass>();
353 }
354