//===- AllReduceLowering.cpp - Implementation of all-reduce lowering ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements in-dialect lowering of the all-reduce op to a block of
// simpler instructions.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/ErrorHandling.h"

#include <array>
#include <cassert>
#include <functional>
#include <type_traits>
#include <utility>

using namespace mlir;

namespace {

308b2eb7c4SChristian Sigg struct GpuAllReduceRewriter {
318b2eb7c4SChristian Sigg using AccumulatorFactory = std::function<Value(Value, Value)>;
328b2eb7c4SChristian Sigg
GpuAllReduceRewriter__anon813c75b20111::GpuAllReduceRewriter3302b6fb21SMehdi Amini GpuAllReduceRewriter(gpu::GPUFuncOp funcOp, gpu::AllReduceOp reduceOp,
3402b6fb21SMehdi Amini PatternRewriter &rewriter)
3502b6fb21SMehdi Amini : funcOp(funcOp), reduceOp(reduceOp), rewriter(rewriter),
3610c04f46SRiver Riddle loc(reduceOp.getLoc()), valueType(reduceOp.getValue().getType()),
378b2eb7c4SChristian Sigg indexType(IndexType::get(reduceOp.getContext())),
381b97cdf8SRiver Riddle int32Type(IntegerType::get(reduceOp.getContext(), /*width=*/32)) {}
398b2eb7c4SChristian Sigg
408b2eb7c4SChristian Sigg /// Creates an all_reduce across the workgroup.
418b2eb7c4SChristian Sigg ///
428b2eb7c4SChristian Sigg /// First reduce the elements within a subgroup. The first invocation of each
438b2eb7c4SChristian Sigg /// subgroup writes the intermediate result to workgroup memory. After
448b2eb7c4SChristian Sigg /// synchronizing the workgroup, the first subgroup reduces the values from
458b2eb7c4SChristian Sigg /// workgroup memory. The result is broadcasted to all invocations through
468b2eb7c4SChristian Sigg /// workgroup memory.
478b2eb7c4SChristian Sigg ///
488b2eb7c4SChristian Sigg /// %subgroup_reduce = `createSubgroupReduce(%operand)`
49ace01605SRiver Riddle /// cf.cond_br %is_first_lane, ^then1, ^continue1
508b2eb7c4SChristian Sigg /// ^then1:
518b2eb7c4SChristian Sigg /// store %subgroup_reduce, %workgroup_buffer[%subgroup_id]
52ace01605SRiver Riddle /// cf.br ^continue1
538b2eb7c4SChristian Sigg /// ^continue1:
548b2eb7c4SChristian Sigg /// gpu.barrier
55a54f4eaeSMogball /// %is_valid_subgroup = arith.cmpi "slt" %invocation_idx, %num_subgroups
56ace01605SRiver Riddle /// cf.cond_br %is_valid_subgroup, ^then2, ^continue2
578b2eb7c4SChristian Sigg /// ^then2:
588b2eb7c4SChristian Sigg /// %partial_reduce = load %workgroup_buffer[%invocation_idx]
598b2eb7c4SChristian Sigg /// %all_reduce = `createSubgroupReduce(%partial_reduce)`
608b2eb7c4SChristian Sigg /// store %all_reduce, %workgroup_buffer[%zero]
618b2eb7c4SChristian Sigg /// llvm.br ^continue2
628b2eb7c4SChristian Sigg /// ^continue2:
638b2eb7c4SChristian Sigg /// gpu.barrier
648b2eb7c4SChristian Sigg /// %result = load %workgroup_buffer[%zero]
658b2eb7c4SChristian Sigg /// return %result
668b2eb7c4SChristian Sigg ///
rewrite__anon813c75b20111::GpuAllReduceRewriter678b2eb7c4SChristian Sigg void rewrite() {
688b2eb7c4SChristian Sigg rewriter.setInsertionPoint(reduceOp);
698b2eb7c4SChristian Sigg
708b2eb7c4SChristian Sigg // Compute linear invocation index and workgroup size.
71aae51255SMogball Value dimX = getDimOp<gpu::BlockDimOp>(gpu::Dimension::x);
72aae51255SMogball Value dimY = getDimOp<gpu::BlockDimOp>(gpu::Dimension::y);
73aae51255SMogball Value dimZ = getDimOp<gpu::BlockDimOp>(gpu::Dimension::z);
74aae51255SMogball Value tidX = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::x);
75aae51255SMogball Value tidY = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::y);
76aae51255SMogball Value tidZ = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::z);
77a54f4eaeSMogball Value tmp1 = create<arith::MulIOp>(int32Type, tidZ, dimY);
78a54f4eaeSMogball Value tmp2 = create<arith::AddIOp>(int32Type, tmp1, tidY);
79a54f4eaeSMogball Value tmp3 = create<arith::MulIOp>(int32Type, tmp2, dimX);
80a54f4eaeSMogball Value tmp4 = create<arith::MulIOp>(int32Type, dimX, dimY);
81a54f4eaeSMogball Value invocationIdx = create<arith::AddIOp>(int32Type, tmp3, tidX);
82a54f4eaeSMogball Value workgroupSize = create<arith::MulIOp>(int32Type, tmp4, dimZ);
838b2eb7c4SChristian Sigg
848b2eb7c4SChristian Sigg // Compute lane id (invocation id withing the subgroup).
85a54f4eaeSMogball Value subgroupMask =
86a54f4eaeSMogball create<arith::ConstantIntOp>(kSubgroupSize - 1, int32Type);
87a54f4eaeSMogball Value laneId = create<arith::AndIOp>(invocationIdx, subgroupMask);
88a54f4eaeSMogball Value isFirstLane =
89a54f4eaeSMogball create<arith::CmpIOp>(arith::CmpIPredicate::eq, laneId,
90a54f4eaeSMogball create<arith::ConstantIntOp>(0, int32Type));
918b2eb7c4SChristian Sigg
928b2eb7c4SChristian Sigg Value numThreadsWithSmallerSubgroupId =
93a54f4eaeSMogball create<arith::SubIOp>(invocationIdx, laneId);
948b2eb7c4SChristian Sigg // The number of active invocations starting from the current subgroup.
958b2eb7c4SChristian Sigg // The consumers do not require the value to be clamped to the size of the
968b2eb7c4SChristian Sigg // subgroup.
978b2eb7c4SChristian Sigg Value activeWidth =
98a54f4eaeSMogball create<arith::SubIOp>(workgroupSize, numThreadsWithSmallerSubgroupId);
998b2eb7c4SChristian Sigg
1008b2eb7c4SChristian Sigg // Create factory for op which accumulates to values.
1018b2eb7c4SChristian Sigg AccumulatorFactory accumFactory = getFactory();
1028b2eb7c4SChristian Sigg assert(accumFactory && "failed to create accumulator factory");
1038b2eb7c4SChristian Sigg
1048b2eb7c4SChristian Sigg // Reduce elements within each subgroup to produce the intermediate results.
10510c04f46SRiver Riddle Value subgroupReduce = createSubgroupReduce(
10610c04f46SRiver Riddle activeWidth, laneId, reduceOp.getValue(), accumFactory);
1078b2eb7c4SChristian Sigg
1088b2eb7c4SChristian Sigg // Add workgroup buffer to parent function for intermediate result.
1098b2eb7c4SChristian Sigg Value buffer = createWorkgroupBuffer();
1108b2eb7c4SChristian Sigg
1118b2eb7c4SChristian Sigg // Write the intermediate results to workgroup memory, using the first lane
1128b2eb7c4SChristian Sigg // of each subgroup.
1138b2eb7c4SChristian Sigg createPredicatedBlock(isFirstLane, [&] {
1148b2eb7c4SChristian Sigg Value subgroupId = getDivideBySubgroupSize(invocationIdx);
115a54f4eaeSMogball Value index = create<arith::IndexCastOp>(indexType, subgroupId);
116e2310704SJulian Gross create<memref::StoreOp>(subgroupReduce, buffer, index);
1178b2eb7c4SChristian Sigg });
1188b2eb7c4SChristian Sigg create<gpu::BarrierOp>();
1198b2eb7c4SChristian Sigg
1208b2eb7c4SChristian Sigg // Compute number of active subgroups.
1218b2eb7c4SChristian Sigg Value biasedBlockSize =
122a54f4eaeSMogball create<arith::AddIOp>(int32Type, workgroupSize, subgroupMask);
1238b2eb7c4SChristian Sigg Value numSubgroups = getDivideBySubgroupSize(biasedBlockSize);
124a54f4eaeSMogball Value isValidSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
125a54f4eaeSMogball invocationIdx, numSubgroups);
1268b2eb7c4SChristian Sigg
1278b2eb7c4SChristian Sigg // Use the first numSubgroups invocations to reduce the intermediate results
1288b2eb7c4SChristian Sigg // from workgroup memory. The final result is written to workgroup memory
1298b2eb7c4SChristian Sigg // again.
130a54f4eaeSMogball Value zero = create<arith::ConstantIndexOp>(0);
1318b2eb7c4SChristian Sigg createPredicatedBlock(isValidSubgroup, [&] {
132a54f4eaeSMogball Value index = create<arith::IndexCastOp>(indexType, invocationIdx);
133e2310704SJulian Gross Value value = create<memref::LoadOp>(valueType, buffer, index);
1348b2eb7c4SChristian Sigg Value result =
1358b2eb7c4SChristian Sigg createSubgroupReduce(numSubgroups, laneId, value, accumFactory);
136e2310704SJulian Gross create<memref::StoreOp>(result, buffer, zero);
1378b2eb7c4SChristian Sigg });
1388b2eb7c4SChristian Sigg
1398b2eb7c4SChristian Sigg // Synchronize workgroup and load result from workgroup memory.
1408b2eb7c4SChristian Sigg create<gpu::BarrierOp>();
141e2310704SJulian Gross Value result = create<memref::LoadOp>(valueType, buffer, zero);
1428b2eb7c4SChristian Sigg
1438b2eb7c4SChristian Sigg rewriter.replaceOp(reduceOp, result);
1448b2eb7c4SChristian Sigg }
1458b2eb7c4SChristian Sigg
1468b2eb7c4SChristian Sigg private:
1478b2eb7c4SChristian Sigg // Shortcut to create an op from rewriter using loc as the first argument.
148e2310704SJulian Gross template <typename T, typename... Args>
create__anon813c75b20111::GpuAllReduceRewriter149e2310704SJulian Gross T create(Args... args) {
1508b2eb7c4SChristian Sigg return rewriter.create<T>(loc, std::forward<Args>(args)...);
1518b2eb7c4SChristian Sigg }
1528b2eb7c4SChristian Sigg
1538b2eb7c4SChristian Sigg // Creates dimension op of type T, with the result casted to int32.
154e2310704SJulian Gross template <typename T>
getDimOp__anon813c75b20111::GpuAllReduceRewriter155aae51255SMogball Value getDimOp(gpu::Dimension dimension) {
156aae51255SMogball Value dim = create<T>(indexType, dimension);
157a54f4eaeSMogball return create<arith::IndexCastOp>(int32Type, dim);
1588b2eb7c4SChristian Sigg }
1598b2eb7c4SChristian Sigg
1608b2eb7c4SChristian Sigg /// Adds type to funcOp's workgroup attributions.
createWorkgroupBuffer__anon813c75b20111::GpuAllReduceRewriter1618b2eb7c4SChristian Sigg Value createWorkgroupBuffer() {
162e084679fSRiver Riddle // TODO: Pick a proper location for the attribution.
1636ca1a09fSChristopher Bate auto workgroupMemoryAddressSpace = gpu::AddressSpaceAttr::get(
1646ca1a09fSChristopher Bate funcOp->getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
165e41ebbecSVladislav Vinogradov auto bufferType = MemRefType::get({kSubgroupSize}, valueType, AffineMap{},
1668b2eb7c4SChristian Sigg workgroupMemoryAddressSpace);
167e084679fSRiver Riddle return funcOp.addWorkgroupAttribution(bufferType, rewriter.getUnknownLoc());
1688b2eb7c4SChristian Sigg }
1698b2eb7c4SChristian Sigg
1708b2eb7c4SChristian Sigg /// Returns an accumulator factory using either the op attribute or the body
1718b2eb7c4SChristian Sigg /// region.
getFactory__anon813c75b20111::GpuAllReduceRewriter1728b2eb7c4SChristian Sigg AccumulatorFactory getFactory() {
17310c04f46SRiver Riddle auto &body = reduceOp.getBody();
1748b2eb7c4SChristian Sigg if (!body.empty())
1758b2eb7c4SChristian Sigg return getFactory(body);
17610c04f46SRiver Riddle auto opAttr = reduceOp.getOp();
1778b2eb7c4SChristian Sigg if (opAttr)
1788b2eb7c4SChristian Sigg return getFactory(*opAttr);
1798b2eb7c4SChristian Sigg return AccumulatorFactory();
1808b2eb7c4SChristian Sigg }
1818b2eb7c4SChristian Sigg
1828b2eb7c4SChristian Sigg /// Returns an accumulator factory that clones the body. The body's entry
1838b2eb7c4SChristian Sigg /// block is expected to have 2 arguments. The gpu.yield return the
1848b2eb7c4SChristian Sigg /// accumulated value of the same type.
getFactory__anon813c75b20111::GpuAllReduceRewriter1858b2eb7c4SChristian Sigg AccumulatorFactory getFactory(Region &body) {
186*9f74e6e6SJakub Kuderski return [&body, this](Value lhs, Value rhs) -> Value {
1878b2eb7c4SChristian Sigg Block *block = rewriter.getInsertionBlock();
1888b2eb7c4SChristian Sigg Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());
1898b2eb7c4SChristian Sigg
1908b2eb7c4SChristian Sigg // Insert accumulator body between split block.
1914d67b278SJeff Niu IRMapping mapping;
192e2b71610SRahul Joshi mapping.map(body.getArgument(0), lhs);
193e2b71610SRahul Joshi mapping.map(body.getArgument(1), rhs);
1948b2eb7c4SChristian Sigg rewriter.cloneRegionBefore(body, *split->getParent(),
1958b2eb7c4SChristian Sigg split->getIterator(), mapping);
1968b2eb7c4SChristian Sigg
1978b2eb7c4SChristian Sigg // Add branch before inserted body, into body.
1988b2eb7c4SChristian Sigg block = block->getNextNode();
199ace01605SRiver Riddle create<cf::BranchOp>(block, ValueRange());
2008b2eb7c4SChristian Sigg
2018b2eb7c4SChristian Sigg // Replace all gpu.yield ops with branch out of body.
2028b2eb7c4SChristian Sigg for (; block != split; block = block->getNextNode()) {
2038b2eb7c4SChristian Sigg Operation *terminator = block->getTerminator();
2048b2eb7c4SChristian Sigg if (!isa<gpu::YieldOp>(terminator))
2058b2eb7c4SChristian Sigg continue;
2068b2eb7c4SChristian Sigg rewriter.setInsertionPointToEnd(block);
207ace01605SRiver Riddle rewriter.replaceOpWithNewOp<cf::BranchOp>(
2088b2eb7c4SChristian Sigg terminator, split, ValueRange(terminator->getOperand(0)));
2098b2eb7c4SChristian Sigg }
2108b2eb7c4SChristian Sigg
2118b2eb7c4SChristian Sigg // Return accumulator result.
2128b2eb7c4SChristian Sigg rewriter.setInsertionPointToStart(split);
213e084679fSRiver Riddle return split->addArgument(lhs.getType(), lhs.getLoc());
214*9f74e6e6SJakub Kuderski };
2158b2eb7c4SChristian Sigg }
2168b2eb7c4SChristian Sigg
2178b2eb7c4SChristian Sigg /// Returns an accumulator factory that creates an op specified by opName.
getFactory__anon813c75b20111::GpuAllReduceRewriter218aae51255SMogball AccumulatorFactory getFactory(gpu::AllReduceOperation opName) {
219*9f74e6e6SJakub Kuderski return [opName, this](Value lhs, Value rhs) {
220*9f74e6e6SJakub Kuderski return vector::makeArithReduction(rewriter, loc,
221*9f74e6e6SJakub Kuderski convertReductionKind(opName), lhs, rhs);
2228b2eb7c4SChristian Sigg };
2238b2eb7c4SChristian Sigg }
2248b2eb7c4SChristian Sigg
2258b2eb7c4SChristian Sigg /// Creates an if-block skeleton and calls the two factories to generate the
2268b2eb7c4SChristian Sigg /// ops in the `then` and `else` block..
2278b2eb7c4SChristian Sigg ///
2288b2eb7c4SChristian Sigg /// llvm.cond_br %condition, ^then, ^continue
2298b2eb7c4SChristian Sigg /// ^then:
2308b2eb7c4SChristian Sigg /// %then_operands = `thenOpsFactory()`
2318b2eb7c4SChristian Sigg /// llvm.br ^continue(%then_operands)
2328b2eb7c4SChristian Sigg /// ^else:
2338b2eb7c4SChristian Sigg /// %else_operands = `elseOpsFactory()`
2348b2eb7c4SChristian Sigg /// llvm.br ^continue(%else_operands)
2358b2eb7c4SChristian Sigg /// ^continue(%block_operands):
2368b2eb7c4SChristian Sigg ///
2378b2eb7c4SChristian Sigg template <typename ThenOpsFactory, typename ElseOpsFactory>
createIf__anon813c75b20111::GpuAllReduceRewriter2388b2eb7c4SChristian Sigg void createIf(Value condition, ThenOpsFactory &&thenOpsFactory,
2398b2eb7c4SChristian Sigg ElseOpsFactory &&elseOpsFactory) {
2408b2eb7c4SChristian Sigg Block *currentBlock = rewriter.getInsertionBlock();
2418b2eb7c4SChristian Sigg auto currentPoint = rewriter.getInsertionPoint();
2428b2eb7c4SChristian Sigg
2438b2eb7c4SChristian Sigg Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint);
2448b2eb7c4SChristian Sigg Block *elseBlock = rewriter.splitBlock(thenBlock, thenBlock->begin());
2458b2eb7c4SChristian Sigg Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());
2468b2eb7c4SChristian Sigg
2478b2eb7c4SChristian Sigg rewriter.setInsertionPointToEnd(currentBlock);
248ace01605SRiver Riddle create<cf::CondBranchOp>(condition, thenBlock,
2498b2eb7c4SChristian Sigg /*trueOperands=*/ArrayRef<Value>(), elseBlock,
2508b2eb7c4SChristian Sigg /*falseOperands=*/ArrayRef<Value>());
2518b2eb7c4SChristian Sigg
2528b2eb7c4SChristian Sigg rewriter.setInsertionPointToStart(thenBlock);
2538b2eb7c4SChristian Sigg auto thenOperands = thenOpsFactory();
254ace01605SRiver Riddle create<cf::BranchOp>(continueBlock, thenOperands);
2558b2eb7c4SChristian Sigg
2568b2eb7c4SChristian Sigg rewriter.setInsertionPointToStart(elseBlock);
2578b2eb7c4SChristian Sigg auto elseOperands = elseOpsFactory();
258ace01605SRiver Riddle create<cf::BranchOp>(continueBlock, elseOperands);
2598b2eb7c4SChristian Sigg
2608b2eb7c4SChristian Sigg assert(thenOperands.size() == elseOperands.size());
2618b2eb7c4SChristian Sigg rewriter.setInsertionPointToStart(continueBlock);
2628b2eb7c4SChristian Sigg for (auto operand : thenOperands)
263e084679fSRiver Riddle continueBlock->addArgument(operand.getType(), operand.getLoc());
2648b2eb7c4SChristian Sigg }
2658b2eb7c4SChristian Sigg
2668b2eb7c4SChristian Sigg /// Shortcut for createIf with empty else block and no block operands.
2678b2eb7c4SChristian Sigg template <typename Factory>
createPredicatedBlock__anon813c75b20111::GpuAllReduceRewriter2688b2eb7c4SChristian Sigg void createPredicatedBlock(Value condition, Factory &&predicatedOpsFactory) {
2698b2eb7c4SChristian Sigg static_assert(std::is_same<decltype(predicatedOpsFactory()), void>::value,
2708b2eb7c4SChristian Sigg "predicatedOpsFactory should not return any value");
2718b2eb7c4SChristian Sigg createIf(
2728b2eb7c4SChristian Sigg condition,
2738b2eb7c4SChristian Sigg [&] {
2748b2eb7c4SChristian Sigg predicatedOpsFactory();
2758b2eb7c4SChristian Sigg return ArrayRef<Value>();
2768b2eb7c4SChristian Sigg },
2778b2eb7c4SChristian Sigg [&] { return ArrayRef<Value>(); });
2788b2eb7c4SChristian Sigg }
2798b2eb7c4SChristian Sigg
2808b2eb7c4SChristian Sigg /// Creates a reduction across the first activeWidth lanes of a subgroup, or
2818b2eb7c4SChristian Sigg /// the entire subgroup if activeWidth is larger than the subgroup width.
2828b2eb7c4SChristian Sigg /// The first lane returns the result, all others return values are undefined.
createSubgroupReduce__anon813c75b20111::GpuAllReduceRewriter2838b2eb7c4SChristian Sigg Value createSubgroupReduce(Value activeWidth, Value laneId, Value operand,
2848b2eb7c4SChristian Sigg AccumulatorFactory &accumFactory) {
285a54f4eaeSMogball Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
286a54f4eaeSMogball Value isPartialSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
287a54f4eaeSMogball activeWidth, subgroupSize);
2885fc5c7dbSBenjamin Kramer std::array<Type, 2> shuffleType = {valueType, rewriter.getI1Type()};
2898b2eb7c4SChristian Sigg
2908b2eb7c4SChristian Sigg createIf(
2918b2eb7c4SChristian Sigg isPartialSubgroup,
2928b2eb7c4SChristian Sigg // Generate reduction over a (potentially) partial subgroup.
2938b2eb7c4SChristian Sigg [&] {
2948b2eb7c4SChristian Sigg Value value = operand;
2958b2eb7c4SChristian Sigg // Repeatedly shuffle value from 'laneId ^ i' and accumulate if source
2968b2eb7c4SChristian Sigg // lane is within the active range. The accumulated value is available
2978b2eb7c4SChristian Sigg // in the first lane.
2988b2eb7c4SChristian Sigg for (int i = 1; i < kSubgroupSize; i <<= 1) {
299a54f4eaeSMogball Value offset = create<arith::ConstantIntOp>(i, int32Type);
300aae51255SMogball auto shuffleOp = create<gpu::ShuffleOp>(
301aae51255SMogball shuffleType, value, offset, activeWidth, gpu::ShuffleMode::XOR);
3028b2eb7c4SChristian Sigg // Skip the accumulation if the shuffle op read from a lane outside
3038b2eb7c4SChristian Sigg // of the active range.
3048b2eb7c4SChristian Sigg createIf(
3058b2eb7c4SChristian Sigg shuffleOp.getResult(1),
3068b2eb7c4SChristian Sigg [&] {
3078b2eb7c4SChristian Sigg return SmallVector<Value, 1>{
3088b2eb7c4SChristian Sigg accumFactory(value, shuffleOp.getResult(0))};
3098b2eb7c4SChristian Sigg },
310984b800aSserge-sans-paille [&] { return llvm::ArrayRef(value); });
3118b2eb7c4SChristian Sigg value = rewriter.getInsertionBlock()->getArgument(0);
3128b2eb7c4SChristian Sigg }
3138b2eb7c4SChristian Sigg return SmallVector<Value, 1>{value};
3148b2eb7c4SChristian Sigg },
3158b2eb7c4SChristian Sigg // Generate a reduction over the entire subgroup. This is a
3168b2eb7c4SChristian Sigg // specialization of the above reduction with unconditional
3178b2eb7c4SChristian Sigg // accumulation.
3188b2eb7c4SChristian Sigg [&] {
3198b2eb7c4SChristian Sigg Value value = operand;
3208b2eb7c4SChristian Sigg for (int i = 1; i < kSubgroupSize; i <<= 1) {
321a54f4eaeSMogball Value offset = create<arith::ConstantIntOp>(i, int32Type);
322aae51255SMogball auto shuffleOp =
323aae51255SMogball create<gpu::ShuffleOp>(shuffleType, value, offset, subgroupSize,
324aae51255SMogball gpu::ShuffleMode::XOR);
3258b2eb7c4SChristian Sigg value = accumFactory(value, shuffleOp.getResult(0));
3268b2eb7c4SChristian Sigg }
3278b2eb7c4SChristian Sigg return SmallVector<Value, 1>{value};
3288b2eb7c4SChristian Sigg });
3298b2eb7c4SChristian Sigg return rewriter.getInsertionBlock()->getArgument(0);
3308b2eb7c4SChristian Sigg }
3318b2eb7c4SChristian Sigg
3328b2eb7c4SChristian Sigg /// Returns value divided by the subgroup size (i.e. 32).
getDivideBySubgroupSize__anon813c75b20111::GpuAllReduceRewriter3338b2eb7c4SChristian Sigg Value getDivideBySubgroupSize(Value value) {
334a54f4eaeSMogball Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
335a54f4eaeSMogball return create<arith::DivSIOp>(int32Type, value, subgroupSize);
3368b2eb7c4SChristian Sigg }
3378b2eb7c4SChristian Sigg
3388b2eb7c4SChristian Sigg gpu::GPUFuncOp funcOp;
3398b2eb7c4SChristian Sigg gpu::AllReduceOp reduceOp;
3408b2eb7c4SChristian Sigg PatternRewriter &rewriter;
3418b2eb7c4SChristian Sigg
3428b2eb7c4SChristian Sigg Location loc;
3438b2eb7c4SChristian Sigg Type valueType;
3448b2eb7c4SChristian Sigg Type indexType;
345a54f4eaeSMogball IntegerType int32Type;
3468b2eb7c4SChristian Sigg
3478b2eb7c4SChristian Sigg static constexpr int kSubgroupSize = 32;
3488b2eb7c4SChristian Sigg };
3498b2eb7c4SChristian Sigg
350888717e8SNicolas Vasilache struct GpuAllReduceRewrite : public RewritePattern {
GpuAllReduceRewrite__anon813c75b20111::GpuAllReduceRewrite351888717e8SNicolas Vasilache explicit GpuAllReduceRewrite(MLIRContext *context)
3528b2eb7c4SChristian Sigg : RewritePattern(gpu::GPUFuncOp::getOperationName(), 1, context) {}
3538b2eb7c4SChristian Sigg
matchAndRewrite__anon813c75b20111::GpuAllReduceRewrite3543145427dSRiver Riddle LogicalResult matchAndRewrite(Operation *op,
3558b2eb7c4SChristian Sigg PatternRewriter &rewriter) const override {
3568b2eb7c4SChristian Sigg auto funcOp = cast<gpu::GPUFuncOp>(op);
357247d8d4fSIvan Butygin
358247d8d4fSIvan Butygin SmallVector<gpu::AllReduceOp> reduceOps;
359247d8d4fSIvan Butygin auto callback = [&](gpu::AllReduceOp reduceOp) -> WalkResult {
360247d8d4fSIvan Butygin if (!reduceOp.getUniform())
3618b2eb7c4SChristian Sigg return WalkResult::interrupt();
362247d8d4fSIvan Butygin
363247d8d4fSIvan Butygin reduceOps.emplace_back(reduceOp);
364247d8d4fSIvan Butygin return WalkResult::advance();
3658b2eb7c4SChristian Sigg };
366247d8d4fSIvan Butygin
3676ca1a09fSChristopher Bate if (funcOp.walk(callback).wasInterrupted() || reduceOps.empty())
368247d8d4fSIvan Butygin return rewriter.notifyMatchFailure(
369247d8d4fSIvan Butygin op, "Non uniform reductions are not supported yet.");
370247d8d4fSIvan Butygin
371247d8d4fSIvan Butygin for (gpu::AllReduceOp reduceOp : reduceOps)
372247d8d4fSIvan Butygin GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();
373247d8d4fSIvan Butygin
3743145427dSRiver Riddle return success();
3758b2eb7c4SChristian Sigg }
3768b2eb7c4SChristian Sigg };
} // namespace

populateGpuAllReducePatterns(RewritePatternSet & patterns)379dc4e913bSChris Lattner void mlir::populateGpuAllReducePatterns(RewritePatternSet &patterns) {
380888717e8SNicolas Vasilache patterns.add<GpuAllReduceRewrite>(patterns.getContext());
3818b2eb7c4SChristian Sigg }
382