//===- AllReduceLowering.cpp - Implementation of all-reduce lowering ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements in-dialect lowering of the all-reduce op to a block of
// simpler instructions.
//
//===----------------------------------------------------------------------===//

14 #include "mlir/Dialect/Arith/IR/Arith.h"
15 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
16 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
17 #include "mlir/Dialect/GPU/Transforms/Passes.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/Dialect/Vector/IR/VectorOps.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/IRMapping.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Pass/Pass.h"
24 #include "llvm/Support/ErrorHandling.h"
25 
26 using namespace mlir;
27 
28 namespace {
29 
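/// Maps a gpu.all_reduce reduction kind to the corresponding
/// vector::CombiningKind. The two enums are expected to match one-to-one.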
static vector::CombiningKind
convertReductionKind(gpu::AllReduceOperation mode) {
  switch (mode) {
#define MAP_CASE(X)                                                            \
  case gpu::AllReduceOperation::X:                                             \
    return vector::CombiningKind::X

    MAP_CASE(ADD);
    MAP_CASE(MUL);
    MAP_CASE(MINUI);
    MAP_CASE(MINSI);
    MAP_CASE(MINF);
    MAP_CASE(MAXSI);
    MAP_CASE(MAXUI);
    MAP_CASE(MAXF);
    MAP_CASE(AND);
    MAP_CASE(OR);
    MAP_CASE(XOR);
    MAP_CASE(MINIMUMF);
    MAP_CASE(MAXIMUMF);

#undef MAP_CASE
  }

  llvm_unreachable("Vector and GPU reduction kinds should match 1:1");
}

struct GpuAllReduceRewriter {
  using AccumulatorFactory = std::function<Value(Value, Value)>;

  GpuAllReduceRewriter(gpu::GPUFuncOp funcOp, gpu::AllReduceOp reduceOp,
                       PatternRewriter &rewriter)
      : funcOp(funcOp), reduceOp(reduceOp), rewriter(rewriter),
        loc(reduceOp.getLoc()), valueType(reduceOp.getValue().getType()),
        indexType(IndexType::get(reduceOp.getContext())),
        int32Type(IntegerType::get(reduceOp.getContext(), /*width=*/32)) {}

  /// Creates an all_reduce across the workgroup.
  ///
  /// First reduce the elements within a subgroup. The first invocation of each
  /// subgroup writes the intermediate result to workgroup memory. After
  /// synchronizing the workgroup, the first subgroup reduces the values from
  /// workgroup memory. The result is broadcast to all invocations through
  /// workgroup memory.
  ///
  ///     %subgroup_reduce = `createSubgroupReduce(%operand)`
  ///     cf.cond_br %is_first_lane, ^then1, ^continue1
  ///   ^then1:
  ///     store %subgroup_reduce, %workgroup_buffer[%subgroup_id]
  ///     cf.br ^continue1
  ///   ^continue1:
  ///     gpu.barrier
  ///     %is_valid_subgroup = arith.cmpi "slt" %invocation_idx, %num_subgroups
  ///     cf.cond_br %is_valid_subgroup, ^then2, ^continue2
  ///   ^then2:
  ///     %partial_reduce = load %workgroup_buffer[%invocation_idx]
  ///     %all_reduce = `createSubgroupReduce(%partial_reduce)`
  ///     store %all_reduce, %workgroup_buffer[%zero]
  ///     cf.br ^continue2
  ///   ^continue2:
  ///     gpu.barrier
  ///     %result = load %workgroup_buffer[%zero]
  ///     return %result
  ///
  void rewrite() {
    rewriter.setInsertionPoint(reduceOp);

    // Compute linear invocation index and workgroup size.
    Value dimX = getDimOp<gpu::BlockDimOp>(gpu::Dimension::x);
    Value dimY = getDimOp<gpu::BlockDimOp>(gpu::Dimension::y);
    Value dimZ = getDimOp<gpu::BlockDimOp>(gpu::Dimension::z);
    Value tidX = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::x);
    Value tidY = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::y);
    Value tidZ = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::z);
    Value tmp1 = create<arith::MulIOp>(int32Type, tidZ, dimY);
    Value tmp2 = create<arith::AddIOp>(int32Type, tmp1, tidY);
    Value tmp3 = create<arith::MulIOp>(int32Type, tmp2, dimX);
    Value tmp4 = create<arith::MulIOp>(int32Type, dimX, dimY);
    Value invocationIdx = create<arith::AddIOp>(int32Type, tmp3, tidX);
    Value workgroupSize = create<arith::MulIOp>(int32Type, tmp4, dimZ);
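    // In closed form: invocationIdx = (tidZ * dimY + tidY) * dimX + tidX and
    // workgroupSize = dimX * dimY * dimZ.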

    // Compute lane id (invocation id within the subgroup).
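    // kSubgroupSize is a power of two, so masking with kSubgroupSize - 1
    // computes invocationIdx % kSubgroupSize.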
    Value subgroupMask =
        create<arith::ConstantIntOp>(kSubgroupSize - 1, int32Type);
    Value laneId = create<arith::AndIOp>(invocationIdx, subgroupMask);
    Value isFirstLane =
        create<arith::CmpIOp>(arith::CmpIPredicate::eq, laneId,
                              create<arith::ConstantIntOp>(0, int32Type));

    Value numThreadsWithSmallerSubgroupId =
        create<arith::SubIOp>(invocationIdx, laneId);
    // The number of active invocations starting from the current subgroup.
    // The consumers do not require the value to be clamped to the size of the
    // subgroup.
    Value activeWidth =
        create<arith::SubIOp>(workgroupSize, numThreadsWithSmallerSubgroupId);

    // Create factory for op which accumulates two values.
    AccumulatorFactory accumFactory = getFactory();
    assert(accumFactory && "failed to create accumulator factory");

    // Reduce elements within each subgroup to produce the intermediate results.
    Value subgroupReduce = createSubgroupReduce(
        activeWidth, laneId, reduceOp.getValue(), accumFactory);

    // Add workgroup buffer to parent function for intermediate result.
    Value buffer = createWorkgroupBuffer();

    // Write the intermediate results to workgroup memory, using the first lane
    // of each subgroup.
    createPredicatedBlock(isFirstLane, [&] {
      Value subgroupId = getDivideBySubgroupSize(invocationIdx);
      Value index = create<arith::IndexCastOp>(indexType, subgroupId);
      create<memref::StoreOp>(subgroupReduce, buffer, index);
    });
    create<gpu::BarrierOp>();

    // Compute number of active subgroups.
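    // numSubgroups = ceil(workgroupSize / kSubgroupSize).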
    Value biasedBlockSize =
        create<arith::AddIOp>(int32Type, workgroupSize, subgroupMask);
    Value numSubgroups = getDivideBySubgroupSize(biasedBlockSize);
    Value isValidSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
                                                  invocationIdx, numSubgroups);

    // Use the first numSubgroups invocations to reduce the intermediate results
    // from workgroup memory. The final result is written to workgroup memory
    // again.
    Value zero = create<arith::ConstantIndexOp>(0);
    createPredicatedBlock(isValidSubgroup, [&] {
      Value index = create<arith::IndexCastOp>(indexType, invocationIdx);
      Value value = create<memref::LoadOp>(valueType, buffer, index);
      Value result =
          createSubgroupReduce(numSubgroups, laneId, value, accumFactory);
      create<memref::StoreOp>(result, buffer, zero);
    });

    // Synchronize workgroup and load result from workgroup memory.
    create<gpu::BarrierOp>();
    Value result = create<memref::LoadOp>(valueType, buffer, zero);

    rewriter.replaceOp(reduceOp, result);
  }

private:
  // Shortcut to create an op from rewriter using loc as the first argument.
  template <typename T, typename... Args>
  T create(Args... args) {
    return rewriter.create<T>(loc, std::forward<Args>(args)...);
  }

  // Creates a dimension op of type T, with the result cast to int32.
  template <typename T>
  Value getDimOp(gpu::Dimension dimension) {
    Value dim = create<T>(indexType, dimension);
    return create<arith::IndexCastOp>(int32Type, dim);
  }

  /// Adds a workgroup memory attribution for the reduction buffer to funcOp.
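  /// The buffer holds kSubgroupSize elements, one slot per subgroup, which
  /// bounds the supported workgroup size to kSubgroupSize * kSubgroupSize
  /// invocations.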
  Value createWorkgroupBuffer() {
    // TODO: Pick a proper location for the attribution.
    auto workgroupMemoryAddressSpace = gpu::AddressSpaceAttr::get(
        funcOp->getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
    auto bufferType = MemRefType::get({kSubgroupSize}, valueType, AffineMap{},
                                      workgroupMemoryAddressSpace);
    return funcOp.addWorkgroupAttribution(bufferType, rewriter.getUnknownLoc());
  }

  /// Returns an accumulator factory using either the op attribute or the body
  /// region.
  AccumulatorFactory getFactory() {
    auto &body = reduceOp.getBody();
    if (!body.empty())
      return getFactory(body);
    auto opAttr = reduceOp.getOp();
    if (opAttr)
      return getFactory(*opAttr);
    return AccumulatorFactory();
  }

  /// Returns an accumulator factory that clones the body. The body's entry
  /// block is expected to have 2 arguments. The gpu.yield op returns the
  /// accumulated value of the same type.
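  ///
  /// For example, a floating-point add accumulator body looks like:
  ///
  ///     ^bb0(%lhs: f32, %rhs: f32):
  ///       %sum = arith.addf %lhs, %rhs : f32
  ///       gpu.yield %sum : f32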
  AccumulatorFactory getFactory(Region &body) {
    return [&body, this](Value lhs, Value rhs) -> Value {
      Block *block = rewriter.getInsertionBlock();
      Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());

      // Insert accumulator body between the split blocks.
      IRMapping mapping;
      mapping.map(body.getArgument(0), lhs);
      mapping.map(body.getArgument(1), rhs);
      rewriter.cloneRegionBefore(body, *split->getParent(),
                                 split->getIterator(), mapping);

      // Add branch before inserted body, into body.
      block = block->getNextNode();
      create<cf::BranchOp>(block, ValueRange());

      // Replace all gpu.yield ops with branch out of body.
      for (; block != split; block = block->getNextNode()) {
        Operation *terminator = block->getTerminator();
        if (!isa<gpu::YieldOp>(terminator))
          continue;
        rewriter.setInsertionPointToEnd(block);
        rewriter.replaceOpWithNewOp<cf::BranchOp>(
            terminator, split, ValueRange(terminator->getOperand(0)));
      }

      // Return accumulator result.
      rewriter.setInsertionPointToStart(split);
      return split->addArgument(lhs.getType(), lhs.getLoc());
    };
  }

  /// Returns an accumulator factory that creates an op specified by opName.
  AccumulatorFactory getFactory(gpu::AllReduceOperation opName) {
    return [opName, this](Value lhs, Value rhs) {
      return vector::makeArithReduction(rewriter, loc,
                                        convertReductionKind(opName), lhs, rhs);
    };
  }

  /// Creates an if-block skeleton and calls the two factories to generate the
  /// ops in the `then` and `else` blocks.
  ///
  ///     cf.cond_br %condition, ^then, ^else
  ///   ^then:
  ///     %then_operands = `thenOpsFactory()`
  ///     cf.br ^continue(%then_operands)
  ///   ^else:
  ///     %else_operands = `elseOpsFactory()`
  ///     cf.br ^continue(%else_operands)
  ///   ^continue(%block_operands):
  ///
  template <typename ThenOpsFactory, typename ElseOpsFactory>
  void createIf(Value condition, ThenOpsFactory &&thenOpsFactory,
                ElseOpsFactory &&elseOpsFactory) {
    Block *currentBlock = rewriter.getInsertionBlock();
    auto currentPoint = rewriter.getInsertionPoint();

    Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint);
    Block *elseBlock = rewriter.splitBlock(thenBlock, thenBlock->begin());
    Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());

    rewriter.setInsertionPointToEnd(currentBlock);
    create<cf::CondBranchOp>(condition, thenBlock,
                             /*trueOperands=*/ArrayRef<Value>(), elseBlock,
                             /*falseOperands=*/ArrayRef<Value>());

    rewriter.setInsertionPointToStart(thenBlock);
    auto thenOperands = thenOpsFactory();
    create<cf::BranchOp>(continueBlock, thenOperands);

    rewriter.setInsertionPointToStart(elseBlock);
    auto elseOperands = elseOpsFactory();
    create<cf::BranchOp>(continueBlock, elseOperands);

    assert(thenOperands.size() == elseOperands.size());
    rewriter.setInsertionPointToStart(continueBlock);
    for (auto operand : thenOperands)
      continueBlock->addArgument(operand.getType(), operand.getLoc());
  }

  /// Shortcut for createIf with empty else block and no block operands.
  template <typename Factory>
  void createPredicatedBlock(Value condition, Factory &&predicatedOpsFactory) {
    static_assert(std::is_same<decltype(predicatedOpsFactory()), void>::value,
                  "predicatedOpsFactory should not return any value");
    createIf(
        condition,
        [&] {
          predicatedOpsFactory();
          return ArrayRef<Value>();
        },
        [&] { return ArrayRef<Value>(); });
  }

  /// Creates a reduction across the first activeWidth lanes of a subgroup, or
  /// the entire subgroup if activeWidth is larger than the subgroup width.
  /// The first lane returns the result; the return values of all other lanes
  /// are undefined.
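  ///
  /// In the full-subgroup case this is a butterfly reduction over xor
  /// shuffles: after the step with offset `i`, each lane holds the
  /// accumulated value of the 2*i lanes of its surrounding group. E.g. for a
  /// width of 4, after offset 1 lane `l` holds acc(%l, %l^1), and after
  /// offset 2 every lane holds the reduction of all four lanes.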
  Value createSubgroupReduce(Value activeWidth, Value laneId, Value operand,
                             AccumulatorFactory &accumFactory) {
    Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
    Value isPartialSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
                                                    activeWidth, subgroupSize);
    std::array<Type, 2> shuffleType = {valueType, rewriter.getI1Type()};

    createIf(
        isPartialSubgroup,
        // Generate reduction over a (potentially) partial subgroup.
        [&] {
          Value value = operand;
          // Repeatedly shuffle value from 'laneId ^ i' and accumulate if source
          // lane is within the active range. The accumulated value is available
          // in the first lane.
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<arith::ConstantIntOp>(i, int32Type);
            auto shuffleOp = create<gpu::ShuffleOp>(
                shuffleType, value, offset, activeWidth, gpu::ShuffleMode::XOR);
            // Skip the accumulation if the shuffle op read from a lane outside
            // of the active range.
            createIf(
                shuffleOp.getResult(1),
                [&] {
                  return SmallVector<Value, 1>{
                      accumFactory(value, shuffleOp.getResult(0))};
                },
                [&] { return llvm::ArrayRef(value); });
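            // createIf forwards the selected value as an argument of the
            // continue block; read it back as the new running value.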
            value = rewriter.getInsertionBlock()->getArgument(0);
          }
          return SmallVector<Value, 1>{value};
        },
        // Generate a reduction over the entire subgroup. This is a
        // specialization of the above reduction with unconditional
        // accumulation.
        [&] {
          Value value = operand;
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<arith::ConstantIntOp>(i, int32Type);
            auto shuffleOp =
                create<gpu::ShuffleOp>(shuffleType, value, offset, subgroupSize,
                                       gpu::ShuffleMode::XOR);
            value = accumFactory(value, shuffleOp.getResult(0));
          }
          return SmallVector<Value, 1>{value};
        });
    return rewriter.getInsertionBlock()->getArgument(0);
  }

  /// Returns value divided by the subgroup size (i.e. 32).
  Value getDivideBySubgroupSize(Value value) {
    Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
    return create<arith::DivSIOp>(int32Type, value, subgroupSize);
  }

  gpu::GPUFuncOp funcOp;
  gpu::AllReduceOp reduceOp;
  PatternRewriter &rewriter;

  Location loc;
  Type valueType;
  Type indexType;
  IntegerType int32Type;

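  // The lowering assumes a fixed subgroup width of 32 lanes (e.g. an NVIDIA
  // warp); targets with a different subgroup size are not handled.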
  static constexpr int kSubgroupSize = 32;
};

struct GpuAllReduceRewrite : public RewritePattern {
  explicit GpuAllReduceRewrite(MLIRContext *context)
      : RewritePattern(gpu::GPUFuncOp::getOperationName(), 1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto funcOp = cast<gpu::GPUFuncOp>(op);

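    // Collect the all-reduce ops first: rewriting them while walking the
    // function would invalidate the walk.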
    SmallVector<gpu::AllReduceOp> reduceOps;
    auto callback = [&](gpu::AllReduceOp reduceOp) -> WalkResult {
      if (!reduceOp.getUniform())
        return WalkResult::interrupt();

      reduceOps.emplace_back(reduceOp);
      return WalkResult::advance();
    };

    if (funcOp.walk(callback).wasInterrupted() || reduceOps.empty())
      return rewriter.notifyMatchFailure(
          op, "Non-uniform reductions are not supported yet.");

    for (gpu::AllReduceOp reduceOp : reduceOps)
      GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();

    return success();
  }
};
} // namespace

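// Clients typically hand these patterns to a greedy rewrite driver, e.g.
// applyPatternsAndFoldGreedily(op, std::move(patterns)); the choice of driver
// is left to the calling pass.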
void mlir::populateGpuAllReducePatterns(RewritePatternSet &patterns) {
  patterns.add<GpuAllReduceRewrite>(patterns.getContext());
}