//===- AllReduceLowering.cpp - Implementation of all-reduce lowering ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements in-dialect lowering of the all-reduce op to a block of
// simpler instructions.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/ErrorHandling.h"

using namespace mlir;

namespace {

static vector::CombiningKind
convertReductionKind(gpu::AllReduceOperation mode) {
  switch (mode) {
#define MAP_CASE(X)                                                            \
  case gpu::AllReduceOperation::X:                                             \
    return vector::CombiningKind::X

    MAP_CASE(ADD);
    MAP_CASE(MUL);
    MAP_CASE(MINUI);
    MAP_CASE(MINSI);
    MAP_CASE(MINF);
    MAP_CASE(MAXSI);
    MAP_CASE(MAXUI);
    MAP_CASE(MAXF);
    MAP_CASE(AND);
    MAP_CASE(OR);
    MAP_CASE(XOR);
    MAP_CASE(MINIMUMF);
    MAP_CASE(MAXIMUMF);

#undef MAP_CASE
  }

  llvm_unreachable("Vector and GPU reduction kinds should match 1:1");
}

struct GpuAllReduceRewriter {
  using AccumulatorFactory = std::function<Value(Value, Value)>;

  GpuAllReduceRewriter(gpu::GPUFuncOp funcOp, gpu::AllReduceOp reduceOp,
                       PatternRewriter &rewriter)
      : funcOp(funcOp), reduceOp(reduceOp), rewriter(rewriter),
        loc(reduceOp.getLoc()), valueType(reduceOp.getValue().getType()),
        indexType(IndexType::get(reduceOp.getContext())),
        int32Type(IntegerType::get(reduceOp.getContext(), /*width=*/32)) {}

  /// Creates an all_reduce across the workgroup.
  ///
  /// First reduce the elements within a subgroup. The first invocation of each
  /// subgroup writes the intermediate result to workgroup memory. After
  /// synchronizing the workgroup, the first subgroup reduces the values from
  /// workgroup memory. The result is broadcast to all invocations through
  /// workgroup memory.
  ///
  ///     %subgroup_reduce = `createSubgroupReduce(%operand)`
  ///     cf.cond_br %is_first_lane, ^then1, ^continue1
  ///   ^then1:
  ///     store %subgroup_reduce, %workgroup_buffer[%subgroup_id]
  ///     cf.br ^continue1
  ///   ^continue1:
  ///     gpu.barrier
  ///     %is_valid_subgroup = arith.cmpi "slt" %invocation_idx, %num_subgroups
  ///     cf.cond_br %is_valid_subgroup, ^then2, ^continue2
  ///   ^then2:
  ///     %partial_reduce = load %workgroup_buffer[%invocation_idx]
  ///     %all_reduce = `createSubgroupReduce(%partial_reduce)`
  ///     store %all_reduce, %workgroup_buffer[%zero]
  ///     cf.br ^continue2
  ///   ^continue2:
  ///     gpu.barrier
  ///     %result = load %workgroup_buffer[%zero]
  ///     return %result
  ///
  void rewrite() {
    rewriter.setInsertionPoint(reduceOp);

    // Compute the linear invocation index and the workgroup size.
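    // With block dimensions (dimX, dimY, dimZ) and thread ids (tidX, tidY,
    // tidZ), the code below computes
    //   invocationIdx = (tidZ * dimY + tidY) * dimX + tidX
    //   workgroupSize = dimX * dimY * dimZ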
    Value dimX = getDimOp<gpu::BlockDimOp>(gpu::Dimension::x);
    Value dimY = getDimOp<gpu::BlockDimOp>(gpu::Dimension::y);
    Value dimZ = getDimOp<gpu::BlockDimOp>(gpu::Dimension::z);
    Value tidX = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::x);
    Value tidY = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::y);
    Value tidZ = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::z);
    Value tmp1 = create<arith::MulIOp>(int32Type, tidZ, dimY);
    Value tmp2 = create<arith::AddIOp>(int32Type, tmp1, tidY);
    Value tmp3 = create<arith::MulIOp>(int32Type, tmp2, dimX);
    Value tmp4 = create<arith::MulIOp>(int32Type, dimX, dimY);
    Value invocationIdx = create<arith::AddIOp>(int32Type, tmp3, tidX);
    Value workgroupSize = create<arith::MulIOp>(int32Type, tmp4, dimZ);

    // Compute the lane id (invocation id within the subgroup).
    Value subgroupMask =
        create<arith::ConstantIntOp>(kSubgroupSize - 1, int32Type);
    Value laneId = create<arith::AndIOp>(invocationIdx, subgroupMask);
    Value isFirstLane =
        create<arith::CmpIOp>(arith::CmpIPredicate::eq, laneId,
                              create<arith::ConstantIntOp>(0, int32Type));

    Value numThreadsWithSmallerSubgroupId =
        create<arith::SubIOp>(invocationIdx, laneId);
    // The number of active invocations starting from the current subgroup.
    // The consumers do not require the value to be clamped to the size of the
    // subgroup.
    Value activeWidth =
        create<arith::SubIOp>(workgroupSize, numThreadsWithSmallerSubgroupId);

    // Create a factory for the op that accumulates two values.
    AccumulatorFactory accumFactory = getFactory();
    assert(accumFactory && "failed to create accumulator factory");

    // Reduce elements within each subgroup to produce the intermediate
    // results.
    Value subgroupReduce = createSubgroupReduce(
        activeWidth, laneId, reduceOp.getValue(), accumFactory);

    // Add a workgroup buffer to the parent function for the intermediate
    // results.
    Value buffer = createWorkgroupBuffer();

    // Write the intermediate results to workgroup memory, using the first lane
    // of each subgroup.
    createPredicatedBlock(isFirstLane, [&] {
      Value subgroupId = getDivideBySubgroupSize(invocationIdx);
      Value index = create<arith::IndexCastOp>(indexType, subgroupId);
      create<memref::StoreOp>(subgroupReduce, buffer, index);
    });
    create<gpu::BarrierOp>();

    // Compute the number of active subgroups.
    Value biasedBlockSize =
        create<arith::AddIOp>(int32Type, workgroupSize, subgroupMask);
    Value numSubgroups = getDivideBySubgroupSize(biasedBlockSize);
    Value isValidSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
                                                  invocationIdx, numSubgroups);

    // Use the first numSubgroups invocations to reduce the intermediate
    // results from workgroup memory. The final result is written to workgroup
    // memory again.
    Value zero = create<arith::ConstantIndexOp>(0);
    createPredicatedBlock(isValidSubgroup, [&] {
      Value index = create<arith::IndexCastOp>(indexType, invocationIdx);
      Value value = create<memref::LoadOp>(valueType, buffer, index);
      Value result =
          createSubgroupReduce(numSubgroups, laneId, value, accumFactory);
      create<memref::StoreOp>(result, buffer, zero);
    });

    // Synchronize the workgroup and load the result from workgroup memory.
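    // The barrier ensures that the store of the final result performed by the
    // first active subgroup above is visible to every invocation before the
    // load below, so all invocations observe the same reduced value.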
    create<gpu::BarrierOp>();
    Value result = create<memref::LoadOp>(valueType, buffer, zero);

    rewriter.replaceOp(reduceOp, result);
  }

private:
  // Shortcut to create an op from rewriter using loc as the first argument.
  template <typename T, typename... Args>
  T create(Args... args) {
    return rewriter.create<T>(loc, std::forward<Args>(args)...);
  }

  // Creates a dimension op of type T, with the result cast to int32.
  template <typename T>
  Value getDimOp(gpu::Dimension dimension) {
    Value dim = create<T>(indexType, dimension);
    return create<arith::IndexCastOp>(int32Type, dim);
  }

  /// Adds a workgroup-memory buffer of `valueType` to funcOp's workgroup
  /// attributions and returns it.
  Value createWorkgroupBuffer() {
    // TODO: Pick a proper location for the attribution.
    auto workgroupMemoryAddressSpace = gpu::AddressSpaceAttr::get(
        funcOp->getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
    auto bufferType = MemRefType::get({kSubgroupSize}, valueType, AffineMap{},
                                      workgroupMemoryAddressSpace);
    return funcOp.addWorkgroupAttribution(bufferType,
                                          rewriter.getUnknownLoc());
  }

  /// Returns an accumulator factory using either the op attribute or the body
  /// region.
  AccumulatorFactory getFactory() {
    auto &body = reduceOp.getBody();
    if (!body.empty())
      return getFactory(body);
    auto opAttr = reduceOp.getOp();
    if (opAttr)
      return getFactory(*opAttr);
    return AccumulatorFactory();
  }

  /// Returns an accumulator factory that clones the body. The body's entry
  /// block is expected to have 2 arguments. The gpu.yield is expected to
  /// return the accumulated value of the same type.
  AccumulatorFactory getFactory(Region &body) {
    return [&body, this](Value lhs, Value rhs) -> Value {
      Block *block = rewriter.getInsertionBlock();
      Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());

      // Insert the accumulator body between the split blocks.
      IRMapping mapping;
      mapping.map(body.getArgument(0), lhs);
      mapping.map(body.getArgument(1), rhs);
      rewriter.cloneRegionBefore(body, *split->getParent(),
                                 split->getIterator(), mapping);

      // Add a branch before the inserted body, into the body.
      block = block->getNextNode();
      create<cf::BranchOp>(block, ValueRange());

      // Replace all gpu.yield ops with a branch out of the body.
      for (; block != split; block = block->getNextNode()) {
        Operation *terminator = block->getTerminator();
        if (!isa<gpu::YieldOp>(terminator))
          continue;
        rewriter.setInsertionPointToEnd(block);
        rewriter.replaceOpWithNewOp<cf::BranchOp>(
            terminator, split, ValueRange(terminator->getOperand(0)));
      }

      // Return the accumulator result.
      rewriter.setInsertionPointToStart(split);
      return split->addArgument(lhs.getType(), lhs.getLoc());
    };
  }

  /// Returns an accumulator factory that creates an op specified by opName.
  AccumulatorFactory getFactory(gpu::AllReduceOperation opName) {
    return [opName, this](Value lhs, Value rhs) {
      return vector::makeArithReduction(rewriter, loc,
                                        convertReductionKind(opName), lhs, rhs);
    };
  }

  /// Creates an if-block skeleton and calls the two factories to generate the
  /// ops in the `then` and `else` blocks.
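  /// Both factories must return the same number of values; their types become
  /// the block arguments of ^continue, as sketched below: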
  ///
  ///     cf.cond_br %condition, ^then, ^else
  ///   ^then:
  ///     %then_operands = `thenOpsFactory()`
  ///     cf.br ^continue(%then_operands)
  ///   ^else:
  ///     %else_operands = `elseOpsFactory()`
  ///     cf.br ^continue(%else_operands)
  ///   ^continue(%block_operands):
  ///
  template <typename ThenOpsFactory, typename ElseOpsFactory>
  void createIf(Value condition, ThenOpsFactory &&thenOpsFactory,
                ElseOpsFactory &&elseOpsFactory) {
    Block *currentBlock = rewriter.getInsertionBlock();
    auto currentPoint = rewriter.getInsertionPoint();

    Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint);
    Block *elseBlock = rewriter.splitBlock(thenBlock, thenBlock->begin());
    Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());

    rewriter.setInsertionPointToEnd(currentBlock);
    create<cf::CondBranchOp>(condition, thenBlock,
                             /*trueOperands=*/ArrayRef<Value>(), elseBlock,
                             /*falseOperands=*/ArrayRef<Value>());

    rewriter.setInsertionPointToStart(thenBlock);
    auto thenOperands = thenOpsFactory();
    create<cf::BranchOp>(continueBlock, thenOperands);

    rewriter.setInsertionPointToStart(elseBlock);
    auto elseOperands = elseOpsFactory();
    create<cf::BranchOp>(continueBlock, elseOperands);

    assert(thenOperands.size() == elseOperands.size());
    rewriter.setInsertionPointToStart(continueBlock);
    for (auto operand : thenOperands)
      continueBlock->addArgument(operand.getType(), operand.getLoc());
  }

  /// Shortcut for createIf with an empty else block and no block operands.
  template <typename Factory>
  void createPredicatedBlock(Value condition, Factory &&predicatedOpsFactory) {
    static_assert(std::is_same<decltype(predicatedOpsFactory()), void>::value,
                  "predicatedOpsFactory should not return any value");
    createIf(
        condition,
        [&] {
          predicatedOpsFactory();
          return ArrayRef<Value>();
        },
        [&] { return ArrayRef<Value>(); });
  }

  /// Creates a reduction across the first activeWidth lanes of a subgroup, or
  /// the entire subgroup if activeWidth is larger than the subgroup width.
  /// The first lane returns the result; the return values of all other lanes
  /// are undefined.
  Value createSubgroupReduce(Value activeWidth, Value laneId, Value operand,
                             AccumulatorFactory &accumFactory) {
    Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
    Value isPartialSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
                                                    activeWidth, subgroupSize);
    std::array<Type, 2> shuffleType = {valueType, rewriter.getI1Type()};

    createIf(
        isPartialSubgroup,
        // Generate a reduction over a (potentially) partial subgroup.
        [&] {
          Value value = operand;
          // Repeatedly shuffle the value from 'laneId ^ i' and accumulate if
          // the source lane is within the active range. The accumulated value
          // is available in the first lane.
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<arith::ConstantIntOp>(i, int32Type);
            auto shuffleOp = create<gpu::ShuffleOp>(
                shuffleType, value, offset, activeWidth, gpu::ShuffleMode::XOR);
            // Skip the accumulation if the shuffle op read from a lane outside
            // of the active range.
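            // The second shuffle result is an i1 that is true iff the source
            // lane was active. The createIf below forwards either the
            // accumulated value or the unmodified value as the single block
            // argument of the continue block, which is read back as the new
            // `value`.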
            createIf(
                shuffleOp.getResult(1),
                [&] {
                  return SmallVector<Value, 1>{
                      accumFactory(value, shuffleOp.getResult(0))};
                },
                [&] { return llvm::ArrayRef(value); });
            value = rewriter.getInsertionBlock()->getArgument(0);
          }
          return SmallVector<Value, 1>{value};
        },
        // Generate a reduction over the entire subgroup. This is a
        // specialization of the above reduction with unconditional
        // accumulation.
        [&] {
          Value value = operand;
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<arith::ConstantIntOp>(i, int32Type);
            auto shuffleOp =
                create<gpu::ShuffleOp>(shuffleType, value, offset, subgroupSize,
                                       gpu::ShuffleMode::XOR);
            value = accumFactory(value, shuffleOp.getResult(0));
          }
          return SmallVector<Value, 1>{value};
        });
    return rewriter.getInsertionBlock()->getArgument(0);
  }

  /// Returns value divided by the subgroup size (i.e. 32).
  Value getDivideBySubgroupSize(Value value) {
    Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
    return create<arith::DivSIOp>(int32Type, value, subgroupSize);
  }

  gpu::GPUFuncOp funcOp;
  gpu::AllReduceOp reduceOp;
  PatternRewriter &rewriter;

  Location loc;
  Type valueType;
  Type indexType;
  IntegerType int32Type;

  static constexpr int kSubgroupSize = 32;
};

struct GpuAllReduceRewrite : public RewritePattern {
  explicit GpuAllReduceRewrite(MLIRContext *context)
      : RewritePattern(gpu::GPUFuncOp::getOperationName(), 1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto funcOp = cast<gpu::GPUFuncOp>(op);

    SmallVector<gpu::AllReduceOp> reduceOps;
    auto callback = [&](gpu::AllReduceOp reduceOp) -> WalkResult {
      if (!reduceOp.getUniform())
        return WalkResult::interrupt();

      reduceOps.emplace_back(reduceOp);
      return WalkResult::advance();
    };

    if (funcOp.walk(callback).wasInterrupted() || reduceOps.empty())
      return rewriter.notifyMatchFailure(
          op, "Non uniform reductions are not supported yet.");

    for (gpu::AllReduceOp reduceOp : reduceOps)
      GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();

    return success();
  }
};
} // namespace

void mlir::populateGpuAllReducePatterns(RewritePatternSet &patterns) {
  patterns.add<GpuAllReduceRewrite>(patterns.getContext());
}
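
// A minimal usage sketch (hypothetical pass, not part of this file), assuming
// the greedy driver from "mlir/Transforms/GreedyPatternRewriteDriver.h":
//
//   void MyGpuAllReduceLoweringPass::runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     populateGpuAllReducePatterns(patterns);
//     if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                             std::move(patterns))))
//       signalPassFailure();
//   }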