xref: /llvm-project/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp (revision 039b969b32b64b64123dce30dd28ec4e343d893f)
1 //===- AsyncRegionRewriter.cpp - Implementation of GPU async rewriters ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the GPU dialect pattern rewriters that make GPU op
10 // within a region execute asynchronously.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/Dialect/Async/IR/Async.h"
16 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
17 #include "mlir/Dialect/GPU/Transforms/Passes.h"
18 #include "mlir/Dialect/GPU/Transforms/Utils.h"
19 #include "mlir/IR/BlockAndValueMapping.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/IR/SymbolTable.h"
23 #include "mlir/Support/LLVM.h"
24 #include "mlir/Transforms/RegionUtils.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 
27 using namespace mlir;
namespace {
/// Pass that makes GPU ops inside a function region execute asynchronously by
/// threading `!gpu.async.token` values through them. The three nested walk
/// callbacks (defined below) implement the successive rewrite phases.
class GpuAsyncRegionPass : public GpuAsyncRegionPassBase<GpuAsyncRegionPass> {
  // Phase 1: thread async tokens through GPU ops within each block.
  struct ThreadTokenCallback;
  // Phase 2: hoist trailing `gpu.wait` ops out of `async.execute` regions.
  struct DeferWaitCallback;
  // Phase 3: ensure each returned !gpu.async.token has a single use.
  struct SingleTokenUseCallback;
  void runOnOperation() override;
};
} // namespace
36 
37 static bool isTerminator(Operation *op) {
38   return op->mightHaveTrait<OpTrait::IsTerminator>();
39 }
40 static bool hasSideEffects(Operation *op) {
41   return !MemoryEffectOpInterface::hasNoEffect(op);
42 }
43 
44 // Region walk callback which makes GPU ops implementing the AsyncOpInterface
45 // execute asynchronously.
// Region walk callback which makes GPU ops implementing the AsyncOpInterface
// execute asynchronously.
struct GpuAsyncRegionPass::ThreadTokenCallback {
  ThreadTokenCallback(MLIRContext &context) : builder(&context) {}

  // Visits each op of `block` in order. Uses make_early_inc_range because
  // visit() may erase the current op (see rewriteAsyncOp). Interrupts the
  // walk on the first failed visit.
  WalkResult operator()(Block *block) {
    for (Operation &op : make_early_inc_range(*block)) {
      if (failed(visit(&op)))
        return WalkResult::interrupt();
    }
    return WalkResult::advance();
  }

private:
  // If `op` implements the AsyncOpInterface, insert a `gpu.wait async` to
  // create a current token (unless it already exists), and 'thread' that token
  // through the `op` so that it executes asynchronously.
  //
  // If `op` is a terminator or an op with side-effects, insert a `gpu.wait` to
  // host-synchronize execution. A `!gpu.async.token` will therefore only be
  // used inside of its block and GPU execution will always synchronize with
  // the host at block boundaries.
  LogicalResult visit(Operation *op) {
    // gpu.launch regions are not handled; they must be outlined first.
    if (isa<gpu::LaunchOp>(op))
      return op->emitOpError("replace with gpu.launch_func first");
    // Fold pre-existing gpu.wait ops into the token chain: make them depend
    // on the current token and adopt their (possibly null) async token.
    if (auto waitOp = llvm::dyn_cast<gpu::WaitOp>(op)) {
      if (currentToken)
        waitOp.addAsyncDependency(currentToken);
      currentToken = waitOp.asyncToken();
      return success();
    }
    builder.setInsertionPoint(op);
    if (auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(op))
      return rewriteAsyncOp(asyncOp); // Replace GPU op with async version.
    if (!currentToken)
      return success();
    // Insert host synchronization before terminator or op with side effects.
    if (isTerminator(op) || hasSideEffects(op))
      currentToken = createWaitOp(op->getLoc(), Type(), {currentToken});
    return success();
  }

  // Replaces asyncOp with a clone that returns a token.
  LogicalResult rewriteAsyncOp(gpu::AsyncOpInterface asyncOp) {
    auto *op = asyncOp.getOperation();
    auto tokenType = builder.getType<gpu::AsyncTokenType>();

    // If there is no current token, insert a `gpu.wait async` without
    // dependencies to create one.
    if (!currentToken)
      currentToken = createWaitOp(op->getLoc(), tokenType, {});
    asyncOp.addAsyncDependency(currentToken);

    // Return early if op returns a token already.
    currentToken = asyncOp.getAsyncToken();
    if (currentToken)
      return success();

    // Clone the op to return a token in addition to the other results.
    // The token is appended as the last result type.
    SmallVector<Type, 1> resultTypes;
    resultTypes.reserve(1 + op->getNumResults());
    copy(op->getResultTypes(), std::back_inserter(resultTypes));
    resultTypes.push_back(tokenType);
    auto *newOp = Operation::create(op->getLoc(), op->getName(), resultTypes,
                                    op->getOperands(), op->getAttrDictionary(),
                                    op->getSuccessors(), op->getNumRegions());

    // Clone regions into new op.
    BlockAndValueMapping mapping;
    for (auto pair : llvm::zip_first(op->getRegions(), newOp->getRegions()))
      std::get<0>(pair).cloneInto(&std::get<1>(pair), mapping);

    // Replace the op with the async clone. The token (last result) is not
    // part of the replacement: only the original results are forwarded.
    auto results = newOp->getResults();
    currentToken = results.back();
    builder.insert(newOp);
    op->replaceAllUsesWith(results.drop_back());
    op->erase();

    return success();
  }

  // Creates a gpu.wait op at `loc`; returns its async token (null when
  // `resultType` is null, i.e. for a synchronous wait).
  Value createWaitOp(Location loc, Type resultType, ValueRange operands) {
    return builder.create<gpu::WaitOp>(loc, resultType, operands).asyncToken();
  }

  OpBuilder builder;

  // The token that represents the current asynchronous dependency. It's valid
  // range starts with a `gpu.wait async` op, and ends with a `gpu.wait` op.
  // In between, each gpu::AsyncOpInterface depends on the current token and
  // produces the new one.
  Value currentToken = {};
};
138 
139 /// Erases `executeOp` and returns a clone with additional `results`.
140 async::ExecuteOp addExecuteResults(async::ExecuteOp executeOp,
141                                    ValueRange results) {
142   // Add values to async.yield op.
143   Operation *yieldOp = executeOp.getBody()->getTerminator();
144   yieldOp->insertOperands(yieldOp->getNumOperands(), results);
145 
146   // Construct new result type list with additional types.
147   SmallVector<Type, 2> resultTypes;
148   resultTypes.reserve(executeOp.getNumResults() + results.size());
149   transform(executeOp.getResultTypes(), std::back_inserter(resultTypes),
150             [](Type type) {
151               // Extract value type from !async.value.
152               if (auto valueType = type.dyn_cast<async::ValueType>())
153                 return valueType.getValueType();
154               assert(type.isa<async::TokenType>() && "expected token type");
155               return type;
156             });
157   transform(results, std::back_inserter(resultTypes),
158             [](Value value) { return value.getType(); });
159 
160   // Clone executeOp with the extra results.
161   OpBuilder builder(executeOp);
162   auto newOp = builder.create<async::ExecuteOp>(
163       executeOp.getLoc(), TypeRange{resultTypes}.drop_front() /*drop token*/,
164       executeOp.dependencies(), executeOp.operands());
165   BlockAndValueMapping mapper;
166   newOp.getRegion().getBlocks().clear();
167   executeOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
168 
169   // Replace executeOp with cloned one.
170   executeOp.getOperation()->replaceAllUsesWith(
171       newOp.getResults().drop_back(results.size()));
172   executeOp.erase();
173 
174   return newOp;
175 }
176 
177 // Callback for `async.execute` ops which tries to push the contained
178 // synchronous `gpu.wait` op to the dependencies of the `async.execute`.
// Callback for `async.execute` ops which tries to push the contained
// synchronous `gpu.wait` op to the dependencies of the `async.execute`.
struct GpuAsyncRegionPass::DeferWaitCallback {
  // If the `executeOp`s token is used only in `async.execute` or `async.await`
  // ops, add the region's last `gpu.wait` op to the worklist if it is
  // synchronous and is the last op with side effects.
  void operator()(async::ExecuteOp executeOp) {
    if (!areAllUsersExecuteOrAwait(executeOp.token()))
      return;
    // async.execute's region is currently restricted to one block.
    // Scan backwards from the terminator for a trailing synchronous gpu.wait;
    // bail out if any op with side effects precedes it.
    for (auto &op : llvm::reverse(executeOp.getBody()->without_terminator())) {
      if (auto waitOp = dyn_cast<gpu::WaitOp>(op)) {
        if (!waitOp.asyncToken())
          worklist.push_back(waitOp);
        return;
      }
      if (hasSideEffects(&op))
        return;
    }
  }

  // The destructor performs the actual rewrite work.
  // Note: indexed loop (not range-for) because addAsyncDependencyAfter() may
  // append to the worklist while we iterate.
  ~DeferWaitCallback() {
    for (size_t i = 0; i < worklist.size(); ++i) {
      auto waitOp = worklist[i];
      auto executeOp = waitOp->getParentOfType<async::ExecuteOp>();

      // Erase `gpu.wait` and return async dependencies from execute op instead.
      SmallVector<Value, 4> dependencies = waitOp.asyncDependencies();
      waitOp.erase();
      executeOp = addExecuteResults(executeOp, dependencies);

      // Add the async dependency to each user of the `async.execute` token.
      // Users are copied first because addAsyncDependencyAfter() mutates IR.
      auto asyncTokens = executeOp.getResults().take_back(dependencies.size());
      SmallVector<Operation *, 4> users(executeOp.token().user_begin(),
                                        executeOp.token().user_end());
      for (Operation *user : users)
        addAsyncDependencyAfter(asyncTokens, user);
    }
  }

private:
  // Returns whether all token users are either 'async.execute' or 'async.await'
  // ops. This is used as a requirement for pushing 'gpu.wait' ops from a
  // 'async.execute' body to it's users. Specifically, we do not allow
  // terminator users, because it could mean that the `async.execute` is inside
  // control flow code.
  static bool areAllUsersExecuteOrAwait(Value token) {
    return !token.use_empty() &&
           llvm::all_of(token.getUsers(), [](Operation *user) {
             return isa<async::ExecuteOp, async::AwaitOp>(user);
           });
  }

  // Add the `asyncToken` as dependency as needed after `op`.
  void addAsyncDependencyAfter(ValueRange asyncTokens, Operation *op) {
    OpBuilder builder(op->getContext());
    auto loc = op->getLoc();

    // `it` is set to the first op after which the tokens are live; `tokens`
    // are the values usable at that point.
    Block::iterator it;
    SmallVector<Value, 1> tokens;
    tokens.reserve(asyncTokens.size());
    TypeSwitch<Operation *>(op)
        .Case<async::AwaitOp>([&](auto awaitOp) {
          // Add async.await ops to wait for the !gpu.async.tokens.
          builder.setInsertionPointAfter(op);
          for (auto asyncToken : asyncTokens)
            tokens.push_back(
                builder.create<async::AwaitOp>(loc, asyncToken).result());
          // Set `it` after the inserted async.await ops.
          it = builder.getInsertionPoint();
        })
        .Case<async::ExecuteOp>([&](auto executeOp) {
          // Set `it` to the beginning of the region and add asyncTokens to the
          // async.execute operands.
          it = executeOp.getBody()->begin();
          executeOp.operandsMutable().append(asyncTokens);
          SmallVector<Type, 1> tokenTypes(
              asyncTokens.size(), builder.getType<gpu::AsyncTokenType>());
          SmallVector<Location, 1> tokenLocs(asyncTokens.size(),
                                             executeOp.getLoc());
          copy(executeOp.getBody()->addArguments(tokenTypes, tokenLocs),
               std::back_inserter(tokens));
        });

    // Advance `it` to terminator or op with side-effects.
    it = std::find_if(it, Block::iterator(), [](Operation &op) {
      return isTerminator(&op) || hasSideEffects(&op);
    });

    // If `op` implements the AsyncOpInterface, add `token` to the list of async
    // dependencies.
    if (auto asyncOp = dyn_cast<gpu::AsyncOpInterface>(*it)) {
      for (auto token : tokens)
        asyncOp.addAsyncDependency(token);
      return;
    }

    // Otherwise, insert a gpu.wait before 'it'.
    builder.setInsertionPoint(it->getBlock(), it);
    auto waitOp = builder.create<gpu::WaitOp>(loc, Type{}, tokens);

    // If the new waitOp is at the end of an async.execute region, add it to the
    // worklist. 'operator()(executeOp)' would do the same, but this is faster.
    auto executeOp = dyn_cast<async::ExecuteOp>(it->getParentOp());
    if (executeOp && areAllUsersExecuteOrAwait(executeOp.token()) &&
        !it->getNextNode())
      worklist.push_back(waitOp);
  }

  SmallVector<gpu::WaitOp, 8> worklist;
};
289 
290 // Callback for `async.execute` ops which repeats !gpu.async.token results
291 // so that each of them is only used once.
// Callback for `async.execute` ops which repeats !gpu.async.token results
// so that each of them is only used once.
struct GpuAsyncRegionPass::SingleTokenUseCallback {
  void operator()(async::ExecuteOp executeOp) {
    // Extract !gpu.async.token results which have multiple uses.
    auto multiUseResults =
        llvm::make_filter_range(executeOp.results(), [](OpResult result) {
          if (result.use_empty() || result.hasOneUse())
            return false;
          // Only !async.value<!gpu.async.token> results qualify.
          auto valueType = result.getType().dyn_cast<async::ValueType>();
          return valueType &&
                 valueType.getValueType().isa<gpu::AsyncTokenType>();
        });
    if (multiUseResults.empty())
      return;

    // Indices within !async.execute results (i.e. without the async.token).
    // Collected up front because the loop below replaces `executeOp`.
    SmallVector<int, 4> indices;
    transform(multiUseResults, std::back_inserter(indices),
              [](OpResult result) {
                return result.getResultNumber() - 1; // Index without token.
              });

    for (auto index : indices) {
      assert(!executeOp.results()[index].getUses().empty());
      // Repeat async.yield token result, one for each use after the first one.
      auto uses = llvm::drop_begin(executeOp.results()[index].getUses());
      auto count = std::distance(uses.begin(), uses.end());
      auto yieldOp = cast<async::YieldOp>(executeOp.getBody()->getTerminator());
      SmallVector<Value, 4> operands(count, yieldOp.getOperand(index));
      executeOp = addExecuteResults(executeOp, operands);
      // Update 'uses' to refer to the new executeOp.
      uses = llvm::drop_begin(executeOp.results()[index].getUses());
      // Re-point each extra use to its own duplicated token result.
      auto results = executeOp.results().take_back(count);
      for (auto pair : llvm::zip(uses, results))
        std::get<0>(pair).set(std::get<1>(pair));
    }
  }
};
329 
330 // Replaces synchronous GPU ops in the op's region with asynchronous ones and
331 // inserts the necessary synchronization (as gpu.wait ops). Assumes sequential
332 // execution semantics and that no GPU ops are asynchronous yet.
333 void GpuAsyncRegionPass::runOnOperation() {
334   if (getOperation()->walk(ThreadTokenCallback(getContext())).wasInterrupted())
335     return signalPassFailure();
336 
337   // Collect gpu.wait ops that we can move out of async.execute regions.
338   getOperation().getRegion().walk(DeferWaitCallback());
339   // Makes each !gpu.async.token returned from async.execute op have single use.
340   getOperation().getRegion().walk(SingleTokenUseCallback());
341 }
342 
343 std::unique_ptr<OperationPass<func::FuncOp>> mlir::createGpuAsyncRegionPass() {
344   return std::make_unique<GpuAsyncRegionPass>();
345 }
346