//===- AllReduceLowering.cpp - Implementation of all-reduce lowering ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements in-dialect lowering of the all-reduce op to a block of
// simpler instructions.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/ErrorHandling.h"

using namespace mlir;

namespace {

struct GpuAllReduceRewriter {
  using AccumulatorFactory = std::function<Value(Value, Value)>;

  GpuAllReduceRewriter(gpu::GPUFuncOp funcOp, gpu::AllReduceOp reduceOp,
                       PatternRewriter &rewriter)
      : funcOp(funcOp), reduceOp(reduceOp), rewriter(rewriter),
        loc(reduceOp.getLoc()), valueType(reduceOp.getValue().getType()),
        indexType(IndexType::get(reduceOp.getContext())),
        int32Type(IntegerType::get(reduceOp.getContext(), /*width=*/32)) {}

  /// Creates an all_reduce across the workgroup.
  ///
  /// First reduce the elements within a subgroup. The first invocation of each
  /// subgroup writes the intermediate result to workgroup memory. After
  /// synchronizing the workgroup, the first subgroup reduces the values from
  /// workgroup memory. The result is broadcast to all invocations through
  /// workgroup memory.
  ///
  ///     %subgroup_reduce = `createSubgroupReduce(%operand)`
  ///     cf.cond_br %is_first_lane, ^then1, ^continue1
  ///   ^then1:
  ///     store %subgroup_reduce, %workgroup_buffer[%subgroup_id]
  ///     cf.br ^continue1
  ///   ^continue1:
  ///     gpu.barrier
  ///     %is_valid_subgroup = arith.cmpi "slt" %invocation_idx, %num_subgroups
  ///     cf.cond_br %is_valid_subgroup, ^then2, ^continue2
  ///   ^then2:
  ///     %partial_reduce = load %workgroup_buffer[%invocation_idx]
  ///     %all_reduce = `createSubgroupReduce(%partial_reduce)`
  ///     store %all_reduce, %workgroup_buffer[%zero]
  ///     cf.br ^continue2
  ///   ^continue2:
  ///     gpu.barrier
  ///     %result = load %workgroup_buffer[%zero]
  ///     return %result
  ///
  void rewrite() {
    rewriter.setInsertionPoint(reduceOp);

    // Compute linear invocation index and workgroup size.
    Value dimX = getDimOp<gpu::BlockDimOp>(gpu::Dimension::x);
    Value dimY = getDimOp<gpu::BlockDimOp>(gpu::Dimension::y);
    Value dimZ = getDimOp<gpu::BlockDimOp>(gpu::Dimension::z);
    Value tidX = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::x);
    Value tidY = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::y);
    Value tidZ = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::z);
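    // invocationIdx = (tidZ * dimY + tidY) * dimX + tidX
    // workgroupSize = dimX * dimY * dimZ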
    Value tmp1 = create<arith::MulIOp>(int32Type, tidZ, dimY);
    Value tmp2 = create<arith::AddIOp>(int32Type, tmp1, tidY);
    Value tmp3 = create<arith::MulIOp>(int32Type, tmp2, dimX);
    Value tmp4 = create<arith::MulIOp>(int32Type, dimX, dimY);
    Value invocationIdx = create<arith::AddIOp>(int32Type, tmp3, tidX);
    Value workgroupSize = create<arith::MulIOp>(int32Type, tmp4, dimZ);

    // Compute lane id (invocation id within the subgroup).
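    // kSubgroupSize is a power of two, so masking with kSubgroupSize - 1 is
    // equivalent to invocationIdx % kSubgroupSize.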
    Value subgroupMask =
        create<arith::ConstantIntOp>(kSubgroupSize - 1, int32Type);
    Value laneId = create<arith::AndIOp>(invocationIdx, subgroupMask);
    Value isFirstLane =
        create<arith::CmpIOp>(arith::CmpIPredicate::eq, laneId,
                              create<arith::ConstantIntOp>(0, int32Type));

    Value numThreadsWithSmallerSubgroupId =
        create<arith::SubIOp>(invocationIdx, laneId);
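    // numThreadsWithSmallerSubgroupId is the invocation index rounded down to
    // a multiple of kSubgroupSize, i.e. the first invocation of this subgroup.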
    // The number of active invocations starting from the current subgroup.
    // The consumers do not require the value to be clamped to the size of the
    // subgroup.
    Value activeWidth =
        create<arith::SubIOp>(workgroupSize, numThreadsWithSmallerSubgroupId);

    // Create factory for the op that accumulates two values.
    AccumulatorFactory accumFactory = getFactory();
    assert(accumFactory && "failed to create accumulator factory");

    // Reduce elements within each subgroup to produce the intermediate results.
    Value subgroupReduce = createSubgroupReduce(
        activeWidth, laneId, reduceOp.getValue(), accumFactory);

    // Add workgroup buffer to parent function for intermediate result.
    Value buffer = createWorkgroupBuffer();

    // Write the intermediate results to workgroup memory, using the first lane
    // of each subgroup.
    createPredicatedBlock(isFirstLane, [&] {
      Value subgroupId = getDivideBySubgroupSize(invocationIdx);
      Value index = create<arith::IndexCastOp>(indexType, subgroupId);
      create<memref::StoreOp>(subgroupReduce, buffer, index);
    });
    create<gpu::BarrierOp>();

    // Compute number of active subgroups.
    Value biasedBlockSize =
        create<arith::AddIOp>(int32Type, workgroupSize, subgroupMask);
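    // Adding kSubgroupSize - 1 before the division rounds up, so this is
    // ceil(workgroupSize / kSubgroupSize).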
    Value numSubgroups = getDivideBySubgroupSize(biasedBlockSize);
    Value isValidSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
                                                  invocationIdx, numSubgroups);

    // Use the first numSubgroups invocations to reduce the intermediate results
    // from workgroup memory. The final result is written to workgroup memory
    // again.
    Value zero = create<arith::ConstantIndexOp>(0);
    createPredicatedBlock(isValidSubgroup, [&] {
      Value index = create<arith::IndexCastOp>(indexType, invocationIdx);
      Value value = create<memref::LoadOp>(valueType, buffer, index);
      Value result =
          createSubgroupReduce(numSubgroups, laneId, value, accumFactory);
      create<memref::StoreOp>(result, buffer, zero);
    });

    // Synchronize workgroup and load result from workgroup memory.
    create<gpu::BarrierOp>();
    Value result = create<memref::LoadOp>(valueType, buffer, zero);

    rewriter.replaceOp(reduceOp, result);
  }

private:
  // Shortcut to create an op from rewriter using loc as the first argument.
  template <typename T, typename... Args>
  T create(Args... args) {
    return rewriter.create<T>(loc, std::forward<Args>(args)...);
  }

  // Creates dimension op of type T, with the result cast to int32.
  template <typename T>
  Value getDimOp(gpu::Dimension dimension) {
    Value dim = create<T>(indexType, dimension);
    return create<arith::IndexCastOp>(int32Type, dim);
  }

  /// Adds a workgroup memory attribution of the buffer type to funcOp and
  /// returns it.
  Value createWorkgroupBuffer() {
    // TODO: Pick a proper location for the attribution.
    auto workgroupMemoryAddressSpace = gpu::AddressSpaceAttr::get(
        funcOp->getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
    auto bufferType = MemRefType::get({kSubgroupSize}, valueType, AffineMap{},
                                      workgroupMemoryAddressSpace);
    return funcOp.addWorkgroupAttribution(bufferType, rewriter.getUnknownLoc());
  }

  /// Returns an accumulator factory using either the op attribute or the body
  /// region.
  AccumulatorFactory getFactory() {
    auto &body = reduceOp.getBody();
    if (!body.empty())
      return getFactory(body);
    auto opAttr = reduceOp.getOp();
    if (opAttr)
      return getFactory(*opAttr);
    return AccumulatorFactory();
  }

  /// Returns an accumulator factory that clones the body. The body's entry
  /// block is expected to have 2 arguments. The gpu.yield returns the
  /// accumulated value of the same type.
  AccumulatorFactory getFactory(Region &body) {
    return [&body, this](Value lhs, Value rhs) -> Value {
      Block *block = rewriter.getInsertionBlock();
      Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());

      // Insert accumulator body between split blocks.
      IRMapping mapping;
      mapping.map(body.getArgument(0), lhs);
      mapping.map(body.getArgument(1), rhs);
      rewriter.cloneRegionBefore(body, *split->getParent(),
                                 split->getIterator(), mapping);

      // Add branch before inserted body, into body.
      block = block->getNextNode();
      create<cf::BranchOp>(block, ValueRange());

      // Replace all gpu.yield ops with branch out of body.
      for (; block != split; block = block->getNextNode()) {
        Operation *terminator = block->getTerminator();
        if (!isa<gpu::YieldOp>(terminator))
          continue;
        rewriter.setInsertionPointToEnd(block);
        rewriter.replaceOpWithNewOp<cf::BranchOp>(
            terminator, split, ValueRange(terminator->getOperand(0)));
      }

      // Return accumulator result.
      rewriter.setInsertionPointToStart(split);
      return split->addArgument(lhs.getType(), lhs.getLoc());
    };
  }

  /// Returns an accumulator factory that creates an op specified by opName.
  AccumulatorFactory getFactory(gpu::AllReduceOperation opName) {
    return [opName, this](Value lhs, Value rhs) {
      return vector::makeArithReduction(rewriter, loc,
                                        convertReductionKind(opName), lhs, rhs);
    };
  }

  /// Creates an if-block skeleton and calls the two factories to generate the
  /// ops in the `then` and `else` blocks.
  ///
  ///     cf.cond_br %condition, ^then, ^continue
  ///   ^then:
  ///     %then_operands = `thenOpsFactory()`
  ///     cf.br ^continue(%then_operands)
  ///   ^else:
  ///     %else_operands = `elseOpsFactory()`
  ///     cf.br ^continue(%else_operands)
  ///   ^continue(%block_operands):
  ///
  template <typename ThenOpsFactory, typename ElseOpsFactory>
  void createIf(Value condition, ThenOpsFactory &&thenOpsFactory,
                ElseOpsFactory &&elseOpsFactory) {
    Block *currentBlock = rewriter.getInsertionBlock();
    auto currentPoint = rewriter.getInsertionPoint();

    Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint);
    Block *elseBlock = rewriter.splitBlock(thenBlock, thenBlock->begin());
    Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());

    rewriter.setInsertionPointToEnd(currentBlock);
    create<cf::CondBranchOp>(condition, thenBlock,
                             /*trueOperands=*/ArrayRef<Value>(), elseBlock,
                             /*falseOperands=*/ArrayRef<Value>());

    rewriter.setInsertionPointToStart(thenBlock);
    auto thenOperands = thenOpsFactory();
    create<cf::BranchOp>(continueBlock, thenOperands);

    rewriter.setInsertionPointToStart(elseBlock);
    auto elseOperands = elseOpsFactory();
    create<cf::BranchOp>(continueBlock, elseOperands);

    assert(thenOperands.size() == elseOperands.size());
    rewriter.setInsertionPointToStart(continueBlock);
    for (auto operand : thenOperands)
      continueBlock->addArgument(operand.getType(), operand.getLoc());
  }

  /// Shortcut for createIf with empty else block and no block operands.
  template <typename Factory>
  void createPredicatedBlock(Value condition, Factory &&predicatedOpsFactory) {
    static_assert(std::is_same<decltype(predicatedOpsFactory()), void>::value,
                  "predicatedOpsFactory should not return any value");
    createIf(
        condition,
        [&] {
          predicatedOpsFactory();
          return ArrayRef<Value>();
        },
        [&] { return ArrayRef<Value>(); });
  }

  /// Creates a reduction across the first activeWidth lanes of a subgroup, or
  /// the entire subgroup if activeWidth is larger than the subgroup width.
  /// The first lane returns the result; the return values of all other lanes
  /// are undefined.
  Value createSubgroupReduce(Value activeWidth, Value laneId, Value operand,
                             AccumulatorFactory &accumFactory) {
    Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
    Value isPartialSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
                                                    activeWidth, subgroupSize);
    std::array<Type, 2> shuffleType = {valueType, rewriter.getI1Type()};

    createIf(
        isPartialSubgroup,
        // Generate reduction over a (potentially) partial subgroup.
        [&] {
          Value value = operand;
          // Repeatedly shuffle value from 'laneId ^ i' and accumulate if the
          // source lane is within the active range. The accumulated value is
          // available in the first lane.
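          // After the step with offset i, lane 0 holds the reduction of the
          // first min(2 * i, activeWidth) lanes.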
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<arith::ConstantIntOp>(i, int32Type);
            auto shuffleOp = create<gpu::ShuffleOp>(
                shuffleType, value, offset, activeWidth, gpu::ShuffleMode::XOR);
            // Skip the accumulation if the shuffle op read from a lane outside
            // of the active range.
            createIf(
                shuffleOp.getResult(1),
                [&] {
                  return SmallVector<Value, 1>{
                      accumFactory(value, shuffleOp.getResult(0))};
                },
                [&] { return llvm::ArrayRef(value); });
            value = rewriter.getInsertionBlock()->getArgument(0);
          }
          return SmallVector<Value, 1>{value};
        },
        // Generate a reduction over the entire subgroup. This is a
        // specialization of the above reduction with unconditional
        // accumulation.
        [&] {
          Value value = operand;
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<arith::ConstantIntOp>(i, int32Type);
            auto shuffleOp =
                create<gpu::ShuffleOp>(shuffleType, value, offset, subgroupSize,
                                       gpu::ShuffleMode::XOR);
            value = accumFactory(value, shuffleOp.getResult(0));
          }
          return SmallVector<Value, 1>{value};
        });
    return rewriter.getInsertionBlock()->getArgument(0);
  }

  /// Returns value divided by the subgroup size (i.e. 32).
  Value getDivideBySubgroupSize(Value value) {
    Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
    return create<arith::DivSIOp>(int32Type, value, subgroupSize);
  }

  gpu::GPUFuncOp funcOp;
  gpu::AllReduceOp reduceOp;
  PatternRewriter &rewriter;

  Location loc;
  Type valueType;
  Type indexType;
  IntegerType int32Type;

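  /// Subgroup size assumed by this lowering (e.g. a 32-lane warp). The
  /// workgroup buffer holds one partial result per subgroup, so workgroups
  /// may contain at most kSubgroupSize * kSubgroupSize invocations.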
  static constexpr int kSubgroupSize = 32;
};

struct GpuAllReduceRewrite : public RewritePattern {
  explicit GpuAllReduceRewrite(MLIRContext *context)
      : RewritePattern(gpu::GPUFuncOp::getOperationName(), 1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto funcOp = cast<gpu::GPUFuncOp>(op);

    SmallVector<gpu::AllReduceOp> reduceOps;
    auto callback = [&](gpu::AllReduceOp reduceOp) -> WalkResult {
      if (!reduceOp.getUniform())
        return WalkResult::interrupt();

      reduceOps.emplace_back(reduceOp);
      return WalkResult::advance();
    };

    if (funcOp.walk(callback).wasInterrupted() || reduceOps.empty())
      return rewriter.notifyMatchFailure(
          op, "Non-uniform reductions are not supported yet.");

    for (gpu::AllReduceOp reduceOp : reduceOps)
      GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();

    return success();
  }
};
} // namespace

void mlir::populateGpuAllReducePatterns(RewritePatternSet &patterns) {
  patterns.add<GpuAllReduceRewrite>(patterns.getContext());
}