//===- AllReduceLowering.cpp - Implementation of all-reduce lowering ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements in-dialect lowering of the all-reduce op to a block of
// simpler instructions.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/ErrorHandling.h"

using namespace mlir;

namespace {

struct GpuAllReduceRewriter {
  using AccumulatorFactory = std::function<Value(Value, Value)>;

  GpuAllReduceRewriter(gpu::GPUFuncOp funcOp, gpu::AllReduceOp reduceOp,
                       PatternRewriter &rewriter)
      : funcOp(funcOp), reduceOp(reduceOp), rewriter(rewriter),
        loc(reduceOp.getLoc()), valueType(reduceOp.getValue().getType()),
        indexType(IndexType::get(reduceOp.getContext())),
        int32Type(IntegerType::get(reduceOp.getContext(), /*width=*/32)) {}

  /// Creates an all_reduce across the workgroup.
  ///
  /// First reduce the elements within a subgroup. The first invocation of each
  /// subgroup writes the intermediate result to workgroup memory. After
  /// synchronizing the workgroup, the first subgroup reduces the values from
  /// workgroup memory. The result is broadcasted to all invocations through
  /// workgroup memory.
  ///
  ///     %subgroup_reduce = `createSubgroupReduce(%operand)`
  ///     cf.cond_br %is_first_lane, ^then1, ^continue1
  ///   ^then1:
  ///     store %subgroup_reduce, %workgroup_buffer[%subgroup_id]
  ///     cf.br ^continue1
  ///   ^continue1:
  ///     gpu.barrier
  ///     %is_valid_subgroup = arith.cmpi "slt" %invocation_idx, %num_subgroups
  ///     cf.cond_br %is_valid_subgroup, ^then2, ^continue2
  ///   ^then2:
  ///     %partial_reduce = load %workgroup_buffer[%invocation_idx]
  ///     %all_reduce = `createSubgroupReduce(%partial_reduce)`
  ///     store %all_reduce, %workgroup_buffer[%zero]
  ///     cf.br ^continue2
  ///   ^continue2:
  ///     gpu.barrier
  ///     %result = load %workgroup_buffer[%zero]
  ///     return %result
  ///
  void rewrite() {
    rewriter.setInsertionPoint(reduceOp);

    // Compute linear invocation index and workgroup size.
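    //   invocationIdx = (tidZ * dimY + tidY) * dimX + tidX
    //   workgroupSize = dimX * dimY * dimZ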
    Value dimX = getDimOp<gpu::BlockDimOp>(gpu::Dimension::x);
    Value dimY = getDimOp<gpu::BlockDimOp>(gpu::Dimension::y);
    Value dimZ = getDimOp<gpu::BlockDimOp>(gpu::Dimension::z);
    Value tidX = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::x);
    Value tidY = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::y);
    Value tidZ = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::z);
    Value tmp1 = create<arith::MulIOp>(int32Type, tidZ, dimY);
    Value tmp2 = create<arith::AddIOp>(int32Type, tmp1, tidY);
    Value tmp3 = create<arith::MulIOp>(int32Type, tmp2, dimX);
    Value tmp4 = create<arith::MulIOp>(int32Type, dimX, dimY);
    Value invocationIdx = create<arith::AddIOp>(int32Type, tmp3, tidX);
    Value workgroupSize = create<arith::MulIOp>(int32Type, tmp4, dimZ);

    // Compute lane id (invocation id within the subgroup).
    Value subgroupMask =
        create<arith::ConstantIntOp>(kSubgroupSize - 1, int32Type);
    Value laneId = create<arith::AndIOp>(invocationIdx, subgroupMask);
    Value isFirstLane =
        create<arith::CmpIOp>(arith::CmpIPredicate::eq, laneId,
                              create<arith::ConstantIntOp>(0, int32Type));

    Value numThreadsWithSmallerSubgroupId =
        create<arith::SubIOp>(invocationIdx, laneId);
    // The number of active invocations starting from the current subgroup.
    // The consumers do not require the value to be clamped to the size of the
    // subgroup.
    Value activeWidth =
        create<arith::SubIOp>(workgroupSize, numThreadsWithSmallerSubgroupId);
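    // E.g. for workgroupSize = 70, invocations 64..69 form the last subgroup:
    // numThreadsWithSmallerSubgroupId = 64, so activeWidth = 6. Full subgroups
    // see an activeWidth >= kSubgroupSize.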

    // Create a factory for the op that accumulates two values.
    AccumulatorFactory accumFactory = getFactory();
    assert(accumFactory && "failed to create accumulator factory");

    // Reduce elements within each subgroup to produce the intermediate results.
    Value subgroupReduce = createSubgroupReduce(
        activeWidth, laneId, reduceOp.getValue(), accumFactory);

    // Add workgroup buffer to parent function for intermediate result.
    Value buffer = createWorkgroupBuffer();

    // Write the intermediate results to workgroup memory, using the first lane
    // of each subgroup.
    createPredicatedBlock(isFirstLane, [&] {
      Value subgroupId = getDivideBySubgroupSize(invocationIdx);
      Value index = create<arith::IndexCastOp>(indexType, subgroupId);
      create<memref::StoreOp>(subgroupReduce, buffer, index);
    });
    create<gpu::BarrierOp>();

    // Compute number of active subgroups.
    Value biasedBlockSize =
        create<arith::AddIOp>(int32Type, workgroupSize, subgroupMask);
    Value numSubgroups = getDivideBySubgroupSize(biasedBlockSize);
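    // This is ceil(workgroupSize / kSubgroupSize): e.g. 70 invocations give
    // (70 + 31) / 32 = 3 subgroups.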
    Value isValidSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
                                                  invocationIdx, numSubgroups);

    // Use the first numSubgroups invocations to reduce the intermediate results
    // from workgroup memory. The final result is written to workgroup memory
    // again.
    Value zero = create<arith::ConstantIndexOp>(0);
    createPredicatedBlock(isValidSubgroup, [&] {
      Value index = create<arith::IndexCastOp>(indexType, invocationIdx);
      Value value = create<memref::LoadOp>(valueType, buffer, index);
      Value result =
          createSubgroupReduce(numSubgroups, laneId, value, accumFactory);
      create<memref::StoreOp>(result, buffer, zero);
    });

    // Synchronize workgroup and load result from workgroup memory.
    create<gpu::BarrierOp>();
    Value result = create<memref::LoadOp>(valueType, buffer, zero);

    rewriter.replaceOp(reduceOp, result);
  }

private:
  // Shortcut to create an op from rewriter using loc as the first argument.
  template <typename T, typename... Args>
  T create(Args... args) {
    return rewriter.create<T>(loc, std::forward<Args>(args)...);
  }

  // Creates dimension op of type T, with the result casted to int32.
  template <typename T>
  Value getDimOp(gpu::Dimension dimension) {
    Value dim = create<T>(indexType, dimension);
    return create<arith::IndexCastOp>(int32Type, dim);
  }

  /// Adds a workgroup buffer of valueType to funcOp's workgroup attributions
  /// and returns it.
  Value createWorkgroupBuffer() {
    // TODO: Pick a proper location for the attribution.
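    // The buffer holds one partial result per subgroup; kSubgroupSize entries
    // suffice as long as the workgroup has at most kSubgroupSize *
    // kSubgroupSize invocations (1024 for a subgroup size of 32).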
    auto workgroupMemoryAddressSpace = gpu::AddressSpaceAttr::get(
        funcOp->getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
    auto bufferType = MemRefType::get({kSubgroupSize}, valueType, AffineMap{},
                                      workgroupMemoryAddressSpace);
    return funcOp.addWorkgroupAttribution(bufferType, rewriter.getUnknownLoc());
  }

  /// Returns an accumulator factory using either the op attribute or the body
  /// region.
  AccumulatorFactory getFactory() {
    auto &body = reduceOp.getBody();
    if (!body.empty())
      return getFactory(body);
    auto opAttr = reduceOp.getOp();
    if (opAttr)
      return getFactory(*opAttr);
    return AccumulatorFactory();
  }

  /// Returns an accumulator factory that clones the body. The body's entry
  /// block is expected to have 2 arguments. The gpu.yield op returns the
  /// accumulated value of the same type.
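  ///
  /// The cloned body is spliced into the current block via unstructured
  /// control flow, roughly (block names are illustrative):
  ///
  ///     cf.br ^body_entry
  ///   ^body_entry:   // cloned body; its arguments are mapped to lhs/rhs
  ///     ...
  ///     cf.br ^split(%accumulated)   // replaces gpu.yield
  ///   ^split(%result):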
  AccumulatorFactory getFactory(Region &body) {
    return [&body, this](Value lhs, Value rhs) -> Value {
      Block *block = rewriter.getInsertionBlock();
      Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());

      // Clone the accumulator body in between the current block and the split
      // block.
      IRMapping mapping;
      mapping.map(body.getArgument(0), lhs);
      mapping.map(body.getArgument(1), rhs);
      rewriter.cloneRegionBefore(body, *split->getParent(),
                                 split->getIterator(), mapping);

      // Add branch before inserted body, into body.
      block = block->getNextNode();
      create<cf::BranchOp>(block, ValueRange());

      // Replace all gpu.yield ops with branch out of body.
      for (; block != split; block = block->getNextNode()) {
        Operation *terminator = block->getTerminator();
        if (!isa<gpu::YieldOp>(terminator))
          continue;
        rewriter.setInsertionPointToEnd(block);
        rewriter.replaceOpWithNewOp<cf::BranchOp>(
            terminator, split, ValueRange(terminator->getOperand(0)));
      }

      // Return accumulator result.
      rewriter.setInsertionPointToStart(split);
      return split->addArgument(lhs.getType(), lhs.getLoc());
    };
  }

  /// Returns an accumulator factory that creates an op specified by opName.
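  /// E.g. an `add` reduction of f32 values lowers to a single arith.addf.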
  AccumulatorFactory getFactory(gpu::AllReduceOperation opName) {
    return [opName, this](Value lhs, Value rhs) {
      return vector::makeArithReduction(rewriter, loc,
                                        convertReductionKind(opName), lhs, rhs);
    };
  }

  /// Creates an if-block skeleton and calls the two factories to generate the
  /// ops in the `then` and `else` block.
  ///
  ///     cf.cond_br %condition, ^then, ^else
  ///   ^then:
  ///     %then_operands = `thenOpsFactory()`
  ///     cf.br ^continue(%then_operands)
  ///   ^else:
  ///     %else_operands = `elseOpsFactory()`
  ///     cf.br ^continue(%else_operands)
  ///   ^continue(%block_operands):
  ///
  template <typename ThenOpsFactory, typename ElseOpsFactory>
  void createIf(Value condition, ThenOpsFactory &&thenOpsFactory,
                ElseOpsFactory &&elseOpsFactory) {
    Block *currentBlock = rewriter.getInsertionBlock();
    auto currentPoint = rewriter.getInsertionPoint();

    Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint);
    Block *elseBlock = rewriter.splitBlock(thenBlock, thenBlock->begin());
    Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());

    rewriter.setInsertionPointToEnd(currentBlock);
    create<cf::CondBranchOp>(condition, thenBlock,
                             /*trueOperands=*/ArrayRef<Value>(), elseBlock,
                             /*falseOperands=*/ArrayRef<Value>());

    rewriter.setInsertionPointToStart(thenBlock);
    auto thenOperands = thenOpsFactory();
    create<cf::BranchOp>(continueBlock, thenOperands);

    rewriter.setInsertionPointToStart(elseBlock);
    auto elseOperands = elseOpsFactory();
    create<cf::BranchOp>(continueBlock, elseOperands);

    assert(thenOperands.size() == elseOperands.size());
    rewriter.setInsertionPointToStart(continueBlock);
    for (auto operand : thenOperands)
      continueBlock->addArgument(operand.getType(), operand.getLoc());
  }

  /// Shortcut for createIf with empty else block and no block operands.
  template <typename Factory>
  void createPredicatedBlock(Value condition, Factory &&predicatedOpsFactory) {
    static_assert(std::is_same<decltype(predicatedOpsFactory()), void>::value,
                  "predicatedOpsFactory should not return any value");
    createIf(
        condition,
        [&] {
          predicatedOpsFactory();
          return ArrayRef<Value>();
        },
        [&] { return ArrayRef<Value>(); });
  }

  /// Creates a reduction across the first activeWidth lanes of a subgroup, or
  /// the entire subgroup if activeWidth is larger than the subgroup width.
  /// The first lane returns the result; the return values of all other lanes
  /// are undefined.
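  ///
  /// The reduction is a butterfly of XOR shuffles. Roughly, for each offset in
  /// {1, 2, 4, ..., kSubgroupSize / 2} (names are illustrative):
  ///
  ///     %shfl, %valid = gpu.shuffle xor %value, %offset, %width : type
  ///     %value = accumulate(%value, %shfl)   // only if %valid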
  Value createSubgroupReduce(Value activeWidth, Value laneId, Value operand,
                             AccumulatorFactory &accumFactory) {
    Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
    Value isPartialSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
                                                    activeWidth, subgroupSize);
    std::array<Type, 2> shuffleType = {valueType, rewriter.getI1Type()};

    createIf(
        isPartialSubgroup,
        // Generate reduction over a (potentially) partial subgroup.
        [&] {
          Value value = operand;
          // Repeatedly shuffle value from 'laneId ^ i' and accumulate if source
          // lane is within the active range. The accumulated value is available
          // in the first lane.
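          // With kSubgroupSize = 32 this takes log2(32) = 5 steps, using XOR
          // offsets 1, 2, 4, 8, 16.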
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<arith::ConstantIntOp>(i, int32Type);
            auto shuffleOp = create<gpu::ShuffleOp>(
                shuffleType, value, offset, activeWidth, gpu::ShuffleMode::XOR);
            // Skip the accumulation if the shuffle op read from a lane outside
            // of the active range.
            createIf(
                shuffleOp.getResult(1),
                [&] {
                  return SmallVector<Value, 1>{
                      accumFactory(value, shuffleOp.getResult(0))};
                },
                [&] { return llvm::ArrayRef(value); });
            value = rewriter.getInsertionBlock()->getArgument(0);
          }
          return SmallVector<Value, 1>{value};
        },
        // Generate a reduction over the entire subgroup. This is a
        // specialization of the above reduction with unconditional
        // accumulation.
        [&] {
          Value value = operand;
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<arith::ConstantIntOp>(i, int32Type);
            auto shuffleOp =
                create<gpu::ShuffleOp>(shuffleType, value, offset, subgroupSize,
                                       gpu::ShuffleMode::XOR);
            value = accumFactory(value, shuffleOp.getResult(0));
          }
          return SmallVector<Value, 1>{value};
        });
    return rewriter.getInsertionBlock()->getArgument(0);
  }

  /// Returns value divided by the subgroup size (i.e. 32).
  Value getDivideBySubgroupSize(Value value) {
    Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
    return create<arith::DivSIOp>(int32Type, value, subgroupSize);
  }

  gpu::GPUFuncOp funcOp;
  gpu::AllReduceOp reduceOp;
  PatternRewriter &rewriter;

  Location loc;
  Type valueType;
  Type indexType;
  IntegerType int32Type;

  static constexpr int kSubgroupSize = 32;
};

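/// Lowers all uniform gpu.all_reduce ops within a gpu.func. The pattern
/// matches on the function rather than on the individual reductions because
/// the rewrite restructures the surrounding control flow and adds workgroup
/// memory attributions to the function.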
struct GpuAllReduceRewrite : public RewritePattern {
  explicit GpuAllReduceRewrite(MLIRContext *context)
      : RewritePattern(gpu::GPUFuncOp::getOperationName(), 1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto funcOp = cast<gpu::GPUFuncOp>(op);

    SmallVector<gpu::AllReduceOp> reduceOps;
    auto callback = [&](gpu::AllReduceOp reduceOp) -> WalkResult {
      if (!reduceOp.getUniform())
        return WalkResult::interrupt();

      reduceOps.emplace_back(reduceOp);
      return WalkResult::advance();
    };

    if (funcOp.walk(callback).wasInterrupted() || reduceOps.empty())
      return rewriter.notifyMatchFailure(
          op, "Non uniform reductions are not supported yet.");

    for (gpu::AllReduceOp reduceOp : reduceOps)
      GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();

    return success();
  }
};
} // namespace

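// A minimal usage sketch (the greedy driver call below is an assumption about
// how callers typically apply these patterns, not something this file does):
//
//   RewritePatternSet patterns(ctx);
//   populateGpuAllReducePatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
//     signalPassFailure();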
void mlir::populateGpuAllReducePatterns(RewritePatternSet &patterns) {
  patterns.add<GpuAllReduceRewrite>(patterns.getContext());
}