//===- AllReduceLowering.cpp - Implementation of all-reduce lowering ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements in-dialect lowering of the all-reduce op to a block of
// simpler instructions.
//
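// For example, a uniform workgroup reduction such as (assembly format sketch)
//
//   %sum = gpu.all_reduce add %operand uniform {} : (f32) -> (f32)
//
// is expanded into gpu.shuffle-based subgroup reductions, stores and loads of
// partial results in workgroup memory, and gpu.barrier synchronization.
//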
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

namespace {

struct GpuAllReduceRewriter {
  using AccumulatorFactory = std::function<Value(Value, Value)>;

  GpuAllReduceRewriter(gpu::GPUFuncOp funcOp, gpu::AllReduceOp reduceOp,
                       PatternRewriter &rewriter)
      : funcOp(funcOp), reduceOp(reduceOp), rewriter(rewriter),
        loc(reduceOp.getLoc()), valueType(reduceOp.getValue().getType()),
        indexType(IndexType::get(reduceOp.getContext())),
        int32Type(IntegerType::get(reduceOp.getContext(), /*width=*/32)) {}

  /// Creates an all_reduce across the workgroup.
  ///
  /// First reduce the elements within a subgroup. The first invocation of each
  /// subgroup writes the intermediate result to workgroup memory. After
  /// synchronizing the workgroup, the first subgroup reduces the values from
  /// workgroup memory. The result is broadcast to all invocations through
  /// workgroup memory.
  ///
  ///     %subgroup_reduce = `createSubgroupReduce(%operand)`
  ///     cf.cond_br %is_first_lane, ^then1, ^continue1
  ///   ^then1:
  ///     store %subgroup_reduce, %workgroup_buffer[%subgroup_id]
  ///     cf.br ^continue1
  ///   ^continue1:
  ///     gpu.barrier
  ///     %is_valid_subgroup = arith.cmpi slt, %invocation_idx, %num_subgroups
  ///     cf.cond_br %is_valid_subgroup, ^then2, ^continue2
  ///   ^then2:
  ///     %partial_reduce = load %workgroup_buffer[%invocation_idx]
  ///     %all_reduce = `createSubgroupReduce(%partial_reduce)`
  ///     store %all_reduce, %workgroup_buffer[%zero]
  ///     cf.br ^continue2
  ///   ^continue2:
  ///     gpu.barrier
  ///     %result = load %workgroup_buffer[%zero]
  ///     return %result
  ///
  void rewrite() {
    rewriter.setInsertionPoint(reduceOp);

    // Compute linear invocation index and workgroup size.
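    // invocationIdx = (tidZ * dimY + tidY) * dimX + tidX
    // workgroupSize = dimX * dimY * dimZ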
    Value dimX = getDimOp<gpu::BlockDimOp>(gpu::Dimension::x);
    Value dimY = getDimOp<gpu::BlockDimOp>(gpu::Dimension::y);
    Value dimZ = getDimOp<gpu::BlockDimOp>(gpu::Dimension::z);
    Value tidX = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::x);
    Value tidY = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::y);
    Value tidZ = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::z);
    Value tmp1 = create<arith::MulIOp>(int32Type, tidZ, dimY);
    Value tmp2 = create<arith::AddIOp>(int32Type, tmp1, tidY);
    Value tmp3 = create<arith::MulIOp>(int32Type, tmp2, dimX);
    Value tmp4 = create<arith::MulIOp>(int32Type, dimX, dimY);
    Value invocationIdx = create<arith::AddIOp>(int32Type, tmp3, tidX);
    Value workgroupSize = create<arith::MulIOp>(int32Type, tmp4, dimZ);

    // Compute lane id (invocation id within the subgroup).
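    // Since kSubgroupSize is a power of two, masking with (kSubgroupSize - 1)
    // computes invocationIdx % kSubgroupSize.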
    Value subgroupMask =
        create<arith::ConstantIntOp>(kSubgroupSize - 1, int32Type);
    Value laneId = create<arith::AndIOp>(invocationIdx, subgroupMask);
    Value isFirstLane =
        create<arith::CmpIOp>(arith::CmpIPredicate::eq, laneId,
                              create<arith::ConstantIntOp>(0, int32Type));

    Value numThreadsWithSmallerSubgroupId =
        create<arith::SubIOp>(invocationIdx, laneId);
    // The number of active invocations starting from the current subgroup.
    // The consumers do not require the value to be clamped to the size of the
    // subgroup.
    Value activeWidth =
        create<arith::SubIOp>(workgroupSize, numThreadsWithSmallerSubgroupId);

    // Create a factory for the op that accumulates two values.
    AccumulatorFactory accumFactory = getFactory();
    assert(accumFactory && "failed to create accumulator factory");

    // Reduce elements within each subgroup to produce the intermediate results.
    Value subgroupReduce = createSubgroupReduce(
        activeWidth, laneId, reduceOp.getValue(), accumFactory);

    // Add workgroup buffer to parent function for intermediate result.
    Value buffer = createWorkgroupBuffer();

    // Write the intermediate results to workgroup memory, using the first lane
    // of each subgroup.
    createPredicatedBlock(isFirstLane, [&] {
      Value subgroupId = getDivideBySubgroupSize(invocationIdx);
      Value index = create<arith::IndexCastOp>(indexType, subgroupId);
      create<memref::StoreOp>(subgroupReduce, buffer, index);
    });
    create<gpu::BarrierOp>();

    // Compute number of active subgroups.
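    // numSubgroups = ceildiv(workgroupSize, kSubgroupSize), computed as
    // (workgroupSize + kSubgroupSize - 1) / kSubgroupSize.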
    Value biasedBlockSize =
        create<arith::AddIOp>(int32Type, workgroupSize, subgroupMask);
    Value numSubgroups = getDivideBySubgroupSize(biasedBlockSize);
    Value isValidSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
                                                  invocationIdx, numSubgroups);

    // Use the first numSubgroups invocations to reduce the intermediate results
    // from workgroup memory. The final result is written to workgroup memory
    // again.
    Value zero = create<arith::ConstantIndexOp>(0);
    createPredicatedBlock(isValidSubgroup, [&] {
      Value index = create<arith::IndexCastOp>(indexType, invocationIdx);
      Value value = create<memref::LoadOp>(valueType, buffer, index);
      Value result =
          createSubgroupReduce(numSubgroups, laneId, value, accumFactory);
      create<memref::StoreOp>(result, buffer, zero);
    });

    // Synchronize workgroup and load result from workgroup memory.
    create<gpu::BarrierOp>();
    Value result = create<memref::LoadOp>(valueType, buffer, zero);

    rewriter.replaceOp(reduceOp, result);
  }

private:
  // Shortcut to create an op from rewriter using loc as the first argument.
  template <typename T, typename... Args>
  T create(Args... args) {
    return rewriter.create<T>(loc, std::forward<Args>(args)...);
  }

  // Creates dimension op of type T, with the result casted to int32.
  template <typename T>
  Value getDimOp(gpu::Dimension dimension) {
    Value dim = create<T>(indexType, dimension);
    return create<arith::IndexCastOp>(int32Type, dim);
  }

  /// Adds a workgroup memory attribution of the value type to funcOp and
  /// returns it.
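  /// The buffer holds one element per subgroup (kSubgroupSize elements in
  /// total), which assumes at most kSubgroupSize subgroups per workgroup.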
  Value createWorkgroupBuffer() {
    // TODO: Pick a proper location for the attribution.
    int workgroupMemoryAddressSpace =
        gpu::GPUDialect::getWorkgroupAddressSpace();
    auto bufferType = MemRefType::get({kSubgroupSize}, valueType, AffineMap{},
                                      workgroupMemoryAddressSpace);
    return funcOp.addWorkgroupAttribution(bufferType, rewriter.getUnknownLoc());
  }

  /// Returns an accumulator factory using either the op attribute or the body
  /// region.
  AccumulatorFactory getFactory() {
    auto &body = reduceOp.getBody();
    if (!body.empty())
      return getFactory(body);
    auto opAttr = reduceOp.getOp();
    if (opAttr)
      return getFactory(*opAttr);
    return AccumulatorFactory();
  }

  /// Returns an accumulator factory that clones the body. The body's entry
  /// block is expected to have 2 arguments. The gpu.yield op returns the
  /// accumulated value of the same type.
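  ///
  /// For example (a syntax sketch), the lowering repeatedly clones the body of
  ///
  ///     %sum = gpu.all_reduce %operand {
  ///     ^bb(%lhs : f32, %rhs : f32):
  ///       %partial = arith.addf %lhs, %rhs : f32
  ///       "gpu.yield"(%partial) : (f32) -> ()
  ///     } : (f32) -> (f32)
  ///
  /// mapping %lhs and %rhs to the running value and the shuffled value.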
  AccumulatorFactory getFactory(Region &body) {
    return AccumulatorFactory([&](Value lhs, Value rhs) {
      Block *block = rewriter.getInsertionBlock();
      Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());

      // Clone the accumulator body in between the split blocks.
      BlockAndValueMapping mapping;
      mapping.map(body.getArgument(0), lhs);
      mapping.map(body.getArgument(1), rhs);
      rewriter.cloneRegionBefore(body, *split->getParent(),
                                 split->getIterator(), mapping);

      // Add branch before inserted body, into body.
      block = block->getNextNode();
      create<cf::BranchOp>(block, ValueRange());

      // Replace all gpu.yield ops with branch out of body.
      for (; block != split; block = block->getNextNode()) {
        Operation *terminator = block->getTerminator();
        if (!isa<gpu::YieldOp>(terminator))
          continue;
        rewriter.setInsertionPointToEnd(block);
        rewriter.replaceOpWithNewOp<cf::BranchOp>(
            terminator, split, ValueRange(terminator->getOperand(0)));
      }

      // Return accumulator result.
      rewriter.setInsertionPointToStart(split);
      return split->addArgument(lhs.getType(), lhs.getLoc());
    });
  }

  /// Returns an accumulator factory that creates an op specified by opName.
  AccumulatorFactory getFactory(gpu::AllReduceOperation opName) {
    bool isFloatingPoint = valueType.isa<FloatType>();
    switch (opName) {
    case gpu::AllReduceOperation::ADD:
      return isFloatingPoint ? getFactory<arith::AddFOp>()
                             : getFactory<arith::AddIOp>();
    case gpu::AllReduceOperation::MUL:
      return isFloatingPoint ? getFactory<arith::MulFOp>()
                             : getFactory<arith::MulIOp>();
    case gpu::AllReduceOperation::AND:
      return getFactory<arith::AndIOp>();
    case gpu::AllReduceOperation::OR:
      return getFactory<arith::OrIOp>();
    case gpu::AllReduceOperation::XOR:
      return getFactory<arith::XOrIOp>();
    case gpu::AllReduceOperation::MAX:
      return isFloatingPoint
                 ? getCmpFactory<arith::CmpFOp, arith::CmpFPredicate,
                                 arith::CmpFPredicate::UGT>()
                 : getCmpFactory<arith::CmpIOp, arith::CmpIPredicate,
                                 arith::CmpIPredicate::ugt>();
    case gpu::AllReduceOperation::MIN:
      return isFloatingPoint
                 ? getCmpFactory<arith::CmpFOp, arith::CmpFPredicate,
                                 arith::CmpFPredicate::ULT>()
                 : getCmpFactory<arith::CmpIOp, arith::CmpIPredicate,
                                 arith::CmpIPredicate::ult>();
    }
    llvm_unreachable("unknown GPU AllReduceOperation");
  }

  /// Returns an accumulator factory that creates an op of type T.
  template <typename T>
  AccumulatorFactory getFactory() {
    return [&](Value lhs, Value rhs) {
      return create<T>(lhs.getType(), lhs, rhs);
    };
  }

  /// Returns an accumulator factory for comparison-based reductions such as
  /// min and max. T is the type of the compare op.
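  ///
  /// The reduction is emitted as a compare followed by a select, e.g. for an
  /// integer max:
  ///
  ///     %cmp = arith.cmpi ugt, %lhs, %rhs : i32
  ///     %res = arith.select %cmp, %lhs, %rhs : i32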
  template <typename T, typename PredicateEnum, PredicateEnum predicate>
  AccumulatorFactory getCmpFactory() const {
    return [&](Value lhs, Value rhs) {
      Value cmp = rewriter.create<T>(loc, predicate, lhs, rhs);
      return rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
    };
  }

  /// Creates an if-block skeleton and calls the two factories to generate the
  /// ops in the `then` and `else` blocks.
  ///
  ///     cf.cond_br %condition, ^then, ^else
  ///   ^then:
  ///     %then_operands = `thenOpsFactory()`
  ///     cf.br ^continue(%then_operands)
  ///   ^else:
  ///     %else_operands = `elseOpsFactory()`
  ///     cf.br ^continue(%else_operands)
  ///   ^continue(%block_operands):
  ///
  template <typename ThenOpsFactory, typename ElseOpsFactory>
  void createIf(Value condition, ThenOpsFactory &&thenOpsFactory,
                ElseOpsFactory &&elseOpsFactory) {
    Block *currentBlock = rewriter.getInsertionBlock();
    auto currentPoint = rewriter.getInsertionPoint();

    Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint);
    Block *elseBlock = rewriter.splitBlock(thenBlock, thenBlock->begin());
    Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());

    rewriter.setInsertionPointToEnd(currentBlock);
    create<cf::CondBranchOp>(condition, thenBlock,
                             /*trueOperands=*/ArrayRef<Value>(), elseBlock,
                             /*falseOperands=*/ArrayRef<Value>());

    rewriter.setInsertionPointToStart(thenBlock);
    auto thenOperands = thenOpsFactory();
    create<cf::BranchOp>(continueBlock, thenOperands);

    rewriter.setInsertionPointToStart(elseBlock);
    auto elseOperands = elseOpsFactory();
    create<cf::BranchOp>(continueBlock, elseOperands);

    assert(thenOperands.size() == elseOperands.size());
    rewriter.setInsertionPointToStart(continueBlock);
    for (auto operand : thenOperands)
      continueBlock->addArgument(operand.getType(), operand.getLoc());
  }

  /// Shortcut for createIf with empty else block and no block operands.
  template <typename Factory>
  void createPredicatedBlock(Value condition, Factory &&predicatedOpsFactory) {
    static_assert(std::is_same<decltype(predicatedOpsFactory()), void>::value,
                  "predicatedOpsFactory should not return any value");
    createIf(
        condition,
        [&] {
          predicatedOpsFactory();
          return ArrayRef<Value>();
        },
        [&] { return ArrayRef<Value>(); });
  }

  /// Creates a reduction across the first activeWidth lanes of a subgroup, or
  /// the entire subgroup if activeWidth is larger than the subgroup width.
  /// The first lane returns the result; the return values of all other lanes
  /// are undefined.
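  ///
  /// Each step of the butterfly reduction roughly produces (IR sketch):
  ///
  ///     %shfl, %valid = gpu.shuffle xor %value, %offset, %width : f32
  ///     %value = `accumFactory(%value, %shfl)`
  ///
  /// where the accumulation is skipped when %valid is false (the shuffle read
  /// a lane outside the active range).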
  Value createSubgroupReduce(Value activeWidth, Value laneId, Value operand,
                             AccumulatorFactory &accumFactory) {
    Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
    Value isPartialSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
                                                    activeWidth, subgroupSize);
    std::array<Type, 2> shuffleType = {valueType, rewriter.getI1Type()};

    createIf(
        isPartialSubgroup,
        // Generate reduction over a (potentially) partial subgroup.
        [&] {
          Value value = operand;
          // Repeatedly shuffle value from 'laneId ^ i' and accumulate if source
          // lane is within the active range. The accumulated value is available
          // in the first lane.
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<arith::ConstantIntOp>(i, int32Type);
            auto shuffleOp = create<gpu::ShuffleOp>(
                shuffleType, value, offset, activeWidth, gpu::ShuffleMode::XOR);
            // Skip the accumulation if the shuffle op read from a lane outside
            // of the active range.
            createIf(
                shuffleOp.getResult(1),
                [&] {
                  return SmallVector<Value, 1>{
                      accumFactory(value, shuffleOp.getResult(0))};
                },
                [&] { return llvm::ArrayRef(value); });
            value = rewriter.getInsertionBlock()->getArgument(0);
          }
          return SmallVector<Value, 1>{value};
        },
        // Generate a reduction over the entire subgroup. This is a
        // specialization of the above reduction with unconditional
        // accumulation.
        [&] {
          Value value = operand;
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<arith::ConstantIntOp>(i, int32Type);
            auto shuffleOp =
                create<gpu::ShuffleOp>(shuffleType, value, offset, subgroupSize,
                                       gpu::ShuffleMode::XOR);
            value = accumFactory(value, shuffleOp.getResult(0));
          }
          return SmallVector<Value, 1>{value};
        });
    return rewriter.getInsertionBlock()->getArgument(0);
  }

  /// Returns the value divided by the subgroup size (i.e. 32).
  Value getDivideBySubgroupSize(Value value) {
    Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
    return create<arith::DivSIOp>(int32Type, value, subgroupSize);
  }

  gpu::GPUFuncOp funcOp;
  gpu::AllReduceOp reduceOp;
  PatternRewriter &rewriter;

  Location loc;
  Type valueType;
  Type indexType;
  IntegerType int32Type;

  static constexpr int kSubgroupSize = 32;
};

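/// Pattern that lowers all uniform gpu.all_reduce ops within a gpu.func; the
/// match fails if the function contains a non-uniform reduction. Matching on
/// the surrounding gpu.func (rather than on gpu.all_reduce itself) lets the
/// rewriter add workgroup memory attributions to the function.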
struct GpuAllReduceConversion : public RewritePattern {
  explicit GpuAllReduceConversion(MLIRContext *context)
      : RewritePattern(gpu::GPUFuncOp::getOperationName(), 1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto funcOp = cast<gpu::GPUFuncOp>(op);

    SmallVector<gpu::AllReduceOp> reduceOps;
    auto callback = [&](gpu::AllReduceOp reduceOp) -> WalkResult {
      if (!reduceOp.getUniform())
        return WalkResult::interrupt();

      reduceOps.emplace_back(reduceOp);
      return WalkResult::advance();
    };

    if (funcOp.walk(callback).wasInterrupted())
      return rewriter.notifyMatchFailure(
          op, "Non-uniform reductions are not supported yet.");

    for (gpu::AllReduceOp reduceOp : reduceOps)
      GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();

    return success();
  }
};
} // namespace

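// A pass would typically apply this pattern with the greedy driver, e.g. (a
// minimal sketch, assuming a pass whose runOnOperation() is anchored on the
// surrounding module or gpu.module):
//
//   RewritePatternSet patterns(&getContext());
//   populateGpuAllReducePatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                           std::move(patterns))))
//     signalPassFailure();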
void mlir::populateGpuAllReducePatterns(RewritePatternSet &patterns) {
  patterns.add<GpuAllReduceConversion>(patterns.getContext());
}