//===- AllReduceLowering.cpp - Implementation of all-reduce lowering ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements in-dialect lowering of the all-reduce op to a block of
// simpler instructions.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

namespace {

struct GpuAllReduceRewriter {
  using AccumulatorFactory = std::function<Value(Value, Value)>;

  GpuAllReduceRewriter(gpu::GPUFuncOp funcOp_, gpu::AllReduceOp reduceOp_,
                       PatternRewriter &rewriter_)
      : funcOp(funcOp_), reduceOp(reduceOp_), rewriter(rewriter_),
        loc(reduceOp.getLoc()), valueType(reduceOp.value().getType()),
        indexType(IndexType::get(reduceOp.getContext())),
        int32Type(IntegerType::get(reduceOp.getContext(), /*width=*/32)) {}

  /// Creates an all_reduce across the workgroup.
  ///
  /// First reduce the elements within a subgroup. The first invocation of each
  /// subgroup writes the intermediate result to workgroup memory. After
  /// synchronizing the workgroup, the first subgroup reduces the values from
  /// workgroup memory. The result is broadcast to all invocations through
  /// workgroup memory.
  ///
  ///     %subgroup_reduce = `createSubgroupReduce(%operand)`
  ///     cond_br %is_first_lane, ^then1, ^continue1
  ///   ^then1:
  ///     store %subgroup_reduce, %workgroup_buffer[%subgroup_id]
  ///     br ^continue1
  ///   ^continue1:
  ///     gpu.barrier
  ///     %is_valid_subgroup = cmpi "slt" %invocation_idx, %num_subgroups
  ///     cond_br %is_valid_subgroup, ^then2, ^continue2
  ///   ^then2:
  ///     %partial_reduce = load %workgroup_buffer[%invocation_idx]
  ///     %all_reduce = `createSubgroupReduce(%partial_reduce)`
  ///     store %all_reduce, %workgroup_buffer[%zero]
  ///     br ^continue2
  ///   ^continue2:
  ///     gpu.barrier
  ///     %result = load %workgroup_buffer[%zero]
  ///     return %result
  ///
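  /// For example (illustrative only; the exact assembly syntax depends on the
  /// MLIR version), a float addition such as
  ///
  ///     %sum = "gpu.all_reduce"(%operand) ({}) {op = "add"} : (f32) -> f32
  ///
  /// is expanded in place into the block structure sketched above.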
  void rewrite() {
    rewriter.setInsertionPoint(reduceOp);

    // Compute linear invocation index and workgroup size.
    Value dimX = getDimOp<gpu::BlockDimOp>("x");
    Value dimY = getDimOp<gpu::BlockDimOp>("y");
    Value dimZ = getDimOp<gpu::BlockDimOp>("z");
    Value tidX = getDimOp<gpu::ThreadIdOp>("x");
    Value tidY = getDimOp<gpu::ThreadIdOp>("y");
    Value tidZ = getDimOp<gpu::ThreadIdOp>("z");
    Value tmp1 = create<MulIOp>(int32Type, tidZ, dimY);
    Value tmp2 = create<AddIOp>(int32Type, tmp1, tidY);
    Value tmp3 = create<MulIOp>(int32Type, tmp2, dimX);
    Value tmp4 = create<MulIOp>(int32Type, dimX, dimY);
    Value invocationIdx = create<AddIOp>(int32Type, tmp3, tidX);
    Value workgroupSize = create<MulIOp>(int32Type, tmp4, dimZ);
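    // In closed form, the two values above are:
    //   invocationIdx = (tidZ * dimY + tidY) * dimX + tidX
    //   workgroupSize = dimX * dimY * dimZ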

    // Compute lane id (invocation id within the subgroup).
    Value subgroupMask = create<ConstantIntOp>(kSubgroupSize - 1, int32Type);
    Value laneId = create<AndOp>(invocationIdx, subgroupMask);
    Value isFirstLane = create<CmpIOp>(CmpIPredicate::eq, laneId,
                                       create<ConstantIntOp>(0, int32Type));

    Value numThreadsWithSmallerSubgroupId =
        create<SubIOp>(invocationIdx, laneId);
    // The number of active invocations starting from the current subgroup.
    // The consumers do not require the value to be clamped to the size of the
    // subgroup.
    Value activeWidth =
        create<SubIOp>(workgroupSize, numThreadsWithSmallerSubgroupId);
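    // E.g., with workgroupSize = 50 (value assumed for illustration),
    // invocation 40 has laneId = 8, so activeWidth = 50 - (40 - 8) = 18: only
    // the first 18 lanes of the second subgroup are active.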

    // Create a factory for the op that accumulates two values.
    AccumulatorFactory accumFactory = getFactory();
    assert(accumFactory && "failed to create accumulator factory");

    // Reduce elements within each subgroup to produce the intermediate results.
    Value subgroupReduce = createSubgroupReduce(activeWidth, laneId,
                                                reduceOp.value(), accumFactory);

    // Add workgroup buffer to parent function for intermediate result.
    Value buffer = createWorkgroupBuffer();

    // Write the intermediate results to workgroup memory, using the first lane
    // of each subgroup.
    createPredicatedBlock(isFirstLane, [&] {
      Value subgroupId = getDivideBySubgroupSize(invocationIdx);
      Value index = create<IndexCastOp>(indexType, subgroupId);
      create<memref::StoreOp>(subgroupReduce, buffer, index);
    });
    create<gpu::BarrierOp>();

    // Compute number of active subgroups.
    Value biasedBlockSize =
        create<AddIOp>(int32Type, workgroupSize, subgroupMask);
    Value numSubgroups = getDivideBySubgroupSize(biasedBlockSize);
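    // This computes ceilDiv(workgroupSize, kSubgroupSize); e.g., a workgroup
    // of 50 invocations yields (50 + 31) / 32 = 2 subgroups.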
    Value isValidSubgroup =
        create<CmpIOp>(CmpIPredicate::slt, invocationIdx, numSubgroups);

    // Use the first numSubgroups invocations to reduce the intermediate results
    // from workgroup memory. The final result is written to workgroup memory
    // again.
    Value zero = create<ConstantIndexOp>(0);
    createPredicatedBlock(isValidSubgroup, [&] {
      Value index = create<IndexCastOp>(indexType, invocationIdx);
      Value value = create<memref::LoadOp>(valueType, buffer, index);
      Value result =
          createSubgroupReduce(numSubgroups, laneId, value, accumFactory);
      create<memref::StoreOp>(result, buffer, zero);
    });

    // Synchronize workgroup and load result from workgroup memory.
    create<gpu::BarrierOp>();
    Value result = create<memref::LoadOp>(valueType, buffer, zero);

    rewriter.replaceOp(reduceOp, result);
  }

private:
  // Shortcut to create an op from rewriter using loc as the first argument.
  template <typename T, typename... Args>
  T create(Args... args) {
    return rewriter.create<T>(loc, std::forward<Args>(args)...);
  }

  // Creates a dimension op of type T, with the result cast to int32.
  template <typename T>
  Value getDimOp(StringRef dimension) {
    Value dim = create<T>(indexType, rewriter.getStringAttr(dimension));
    return create<IndexCastOp>(int32Type, dim);
  }

  /// Adds a workgroup memory attribution of buffer type to funcOp and returns
  /// it. The buffer holds one intermediate result per subgroup.
  Value createWorkgroupBuffer() {
    int workgroupMemoryAddressSpace =
        gpu::GPUDialect::getWorkgroupAddressSpace();
    auto bufferType =
        MemRefType::get({kSubgroupSize}, valueType, ArrayRef<AffineMap>{},
                        workgroupMemoryAddressSpace);
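    // For an f32 reduction this is, e.g., memref<32xf32, 3>, assuming the
    // workgroup address space maps to 3 (as on NVVM/ROCDL targets).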
    return funcOp.addWorkgroupAttribution(bufferType);
  }

  /// Returns an accumulator factory using either the op attribute or the body
  /// region.
  AccumulatorFactory getFactory() {
    auto &body = reduceOp.body();
    if (!body.empty())
      return getFactory(body);
    auto opAttr = reduceOp.op();
    if (opAttr)
      return getFactory(*opAttr);
    return AccumulatorFactory();
  }

  /// Returns an accumulator factory that clones the body. The body's entry
  /// block is expected to have 2 arguments. The gpu.yield op is expected to
  /// return the accumulated value of the same type.
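  ///
  /// For example (hypothetical body, shown for illustration), a region such as
  ///
  ///     ^bb0(%lhs : f32, %rhs : f32):
  ///       %sum = addf %lhs, %rhs : f32
  ///       "gpu.yield"(%sum) : (f32) -> ()
  ///
  /// is cloned at every accumulation point, with %lhs and %rhs mapped to the
  /// running value and the newly shuffled value.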
  AccumulatorFactory getFactory(Region &body) {
    return AccumulatorFactory([&](Value lhs, Value rhs) {
      Block *block = rewriter.getInsertionBlock();
      Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());

      // Clone the accumulator body between the current and the split block.
      BlockAndValueMapping mapping;
      mapping.map(body.getArgument(0), lhs);
      mapping.map(body.getArgument(1), rhs);
      rewriter.cloneRegionBefore(body, *split->getParent(),
                                 split->getIterator(), mapping);

      // Add branch before inserted body, into body.
      block = block->getNextNode();
      create<BranchOp>(block, ValueRange());

      // Replace all gpu.yield ops with branch out of body.
      for (; block != split; block = block->getNextNode()) {
        Operation *terminator = block->getTerminator();
        if (!isa<gpu::YieldOp>(terminator))
          continue;
        rewriter.setInsertionPointToEnd(block);
        rewriter.replaceOpWithNewOp<BranchOp>(
            terminator, split, ValueRange(terminator->getOperand(0)));
      }

      // Return accumulator result.
      rewriter.setInsertionPointToStart(split);
      return split->addArgument(lhs.getType());
    });
  }

  /// Returns an accumulator factory that creates an op specified by opName.
  AccumulatorFactory getFactory(StringRef opName) {
    bool isFloatingPoint = valueType.isa<FloatType>();
    if (opName == "add")
      return isFloatingPoint ? getFactory<AddFOp>() : getFactory<AddIOp>();
    if (opName == "mul")
      return isFloatingPoint ? getFactory<MulFOp>() : getFactory<MulIOp>();
    if (opName == "and")
      return getFactory<AndOp>();
    if (opName == "or")
      return getFactory<OrOp>();
    if (opName == "xor")
      return getFactory<XOrOp>();
    if (opName == "max") {
      return isFloatingPoint
                 ? getCmpFactory<CmpFOp, CmpFPredicate, CmpFPredicate::UGT>()
                 : getCmpFactory<CmpIOp, CmpIPredicate, CmpIPredicate::ugt>();
    }
    if (opName == "min") {
      return isFloatingPoint
                 ? getCmpFactory<CmpFOp, CmpFPredicate, CmpFPredicate::ULT>()
                 : getCmpFactory<CmpIOp, CmpIPredicate, CmpIPredicate::ult>();
    }
    return AccumulatorFactory();
  }

  /// Returns an accumulator factory that creates an op of type T.
  template <typename T>
  AccumulatorFactory getFactory() {
    return [&](Value lhs, Value rhs) {
      return create<T>(lhs.getType(), lhs, rhs);
    };
  }

  /// Returns an accumulator factory for comparison-based reductions such as
  /// min and max. T is the type of the compare op.
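  ///
  /// For example, an integer "max" accumulator expands to (illustrative):
  ///
  ///     %cmp = cmpi "ugt", %lhs, %rhs : i32
  ///     %res = select %cmp, %lhs, %rhs : i32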
  template <typename T, typename PredicateEnum, PredicateEnum predicate>
  AccumulatorFactory getCmpFactory() const {
    return [&](Value lhs, Value rhs) {
      Value cmp = rewriter.create<T>(loc, predicate, lhs, rhs);
      return rewriter.create<SelectOp>(loc, cmp, lhs, rhs);
    };
  }

  /// Creates an if-block skeleton and calls the two factories to generate the
  /// ops in the `then` and `else` blocks.
  ///
  ///     cond_br %condition, ^then, ^else
  ///   ^then:
  ///     %then_operands = `thenOpsFactory()`
  ///     br ^continue(%then_operands)
  ///   ^else:
  ///     %else_operands = `elseOpsFactory()`
  ///     br ^continue(%else_operands)
  ///   ^continue(%block_operands):
  ///
  template <typename ThenOpsFactory, typename ElseOpsFactory>
  void createIf(Value condition, ThenOpsFactory &&thenOpsFactory,
                ElseOpsFactory &&elseOpsFactory) {
    Block *currentBlock = rewriter.getInsertionBlock();
    auto currentPoint = rewriter.getInsertionPoint();

    Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint);
    Block *elseBlock = rewriter.splitBlock(thenBlock, thenBlock->begin());
    Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());

    rewriter.setInsertionPointToEnd(currentBlock);
    create<CondBranchOp>(condition, thenBlock,
                         /*trueOperands=*/ArrayRef<Value>(), elseBlock,
                         /*falseOperands=*/ArrayRef<Value>());

    rewriter.setInsertionPointToStart(thenBlock);
    auto thenOperands = thenOpsFactory();
    create<BranchOp>(continueBlock, thenOperands);

    rewriter.setInsertionPointToStart(elseBlock);
    auto elseOperands = elseOpsFactory();
    create<BranchOp>(continueBlock, elseOperands);

    assert(thenOperands.size() == elseOperands.size());
    rewriter.setInsertionPointToStart(continueBlock);
    for (auto operand : thenOperands)
      continueBlock->addArgument(operand.getType());
  }

  /// Shortcut for createIf with empty else block and no block operands.
  template <typename Factory>
  void createPredicatedBlock(Value condition, Factory &&predicatedOpsFactory) {
    static_assert(std::is_same<decltype(predicatedOpsFactory()), void>::value,
                  "predicatedOpsFactory should not return any value");
    createIf(
        condition,
        [&] {
          predicatedOpsFactory();
          return ArrayRef<Value>();
        },
        [&] { return ArrayRef<Value>(); });
  }

  /// Creates a reduction across the first activeWidth lanes of a subgroup, or
  /// across the entire subgroup if activeWidth is larger than the subgroup
  /// width. The first lane returns the result; the return values of all other
  /// lanes are undefined.
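  ///
  /// The reduction uses a butterfly pattern of xor shuffles: in step i
  /// (i = 1, 2, 4, 8, 16 for a 32-wide subgroup), each lane accumulates the
  /// value shuffled in from lane `laneId ^ i`, so after log2(width) steps
  /// lane 0 holds the reduction over all active lanes. Sketch for a
  /// hypothetical 4-lane subgroup reducing with "add":
  ///
  ///     step 1: lane 0 holds l0+l1, lane 2 holds l2+l3 (and symmetrically)
  ///     step 2: lane 0 holds (l0+l1) + (l2+l3)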
  Value createSubgroupReduce(Value activeWidth, Value laneId, Value operand,
                             AccumulatorFactory &accumFactory) {
    Value subgroupSize = create<ConstantIntOp>(kSubgroupSize, int32Type);
    Value isPartialSubgroup =
        create<CmpIOp>(CmpIPredicate::slt, activeWidth, subgroupSize);
    std::array<Type, 2> shuffleType = {valueType, rewriter.getI1Type()};
    auto xorAttr = rewriter.getStringAttr("xor");

    createIf(
        isPartialSubgroup,
        // Generate reduction over a (potentially) partial subgroup.
        [&] {
          Value value = operand;
          // Repeatedly shuffle value from 'laneId ^ i' and accumulate if source
          // lane is within the active range. The accumulated value is available
          // in the first lane.
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<ConstantIntOp>(i, int32Type);
            auto shuffleOp = create<gpu::ShuffleOp>(shuffleType, value, offset,
                                                    activeWidth, xorAttr);
            // Skip the accumulation if the shuffle op read from a lane outside
            // of the active range.
            createIf(
                shuffleOp.getResult(1),
                [&] {
                  return SmallVector<Value, 1>{
                      accumFactory(value, shuffleOp.getResult(0))};
                },
                [&] { return llvm::makeArrayRef(value); });
            value = rewriter.getInsertionBlock()->getArgument(0);
          }
          return SmallVector<Value, 1>{value};
        },
        // Generate a reduction over the entire subgroup. This is a
        // specialization of the above reduction with unconditional
        // accumulation.
        [&] {
          Value value = operand;
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<ConstantIntOp>(i, int32Type);
            auto shuffleOp = create<gpu::ShuffleOp>(shuffleType, value, offset,
                                                    subgroupSize, xorAttr);
            value = accumFactory(value, shuffleOp.getResult(0));
          }
          return SmallVector<Value, 1>{value};
        });
    return rewriter.getInsertionBlock()->getArgument(0);
  }

  /// Returns the value divided by the subgroup size (i.e., 32).
  Value getDivideBySubgroupSize(Value value) {
    Value subgroupSize = create<ConstantIntOp>(kSubgroupSize, int32Type);
    return create<SignedDivIOp>(int32Type, value, subgroupSize);
  }

  gpu::GPUFuncOp funcOp;
  gpu::AllReduceOp reduceOp;
  PatternRewriter &rewriter;

  Location loc;
  Type valueType;
  Type indexType;
  Type int32Type;

  static constexpr int kSubgroupSize = 32;
};

struct GpuAllReduceConversion : public RewritePattern {
  explicit GpuAllReduceConversion(MLIRContext *context)
      : RewritePattern(gpu::GPUFuncOp::getOperationName(), 1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto funcOp = cast<gpu::GPUFuncOp>(op);
    auto callback = [&](gpu::AllReduceOp reduceOp) {
      GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();
      // Performing a rewrite invalidates the walk iterator. Report interrupt
      // so that we can start a new walk until all all_reduce ops are replaced.
      return WalkResult::interrupt();
    };
    while (funcOp.walk(callback).wasInterrupted()) {
    }
    return success();
  }
};
} // namespace

void mlir::populateGpuAllReducePatterns(RewritePatternSet &patterns) {
  patterns.add<GpuAllReduceConversion>(patterns.getContext());
}