//===- AllReduceLowering.cpp - Implementation of all-reduce lowering ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements in-dialect lowering of the all-reduce op to a block of
// simpler instructions.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

namespace {

struct GpuAllReduceRewriter {
  using AccumulatorFactory = std::function<Value(Value, Value)>;

  GpuAllReduceRewriter(gpu::GPUFuncOp funcOp, gpu::AllReduceOp reduceOp,
                       PatternRewriter &rewriter)
      : funcOp(funcOp), reduceOp(reduceOp), rewriter(rewriter),
        loc(reduceOp.getLoc()), valueType(reduceOp.getValue().getType()),
        indexType(IndexType::get(reduceOp.getContext())),
        int32Type(IntegerType::get(reduceOp.getContext(), /*width=*/32)) {}

  /// Creates an all_reduce across the workgroup.
  ///
  /// First reduce the elements within a subgroup. The first invocation of each
  /// subgroup writes the intermediate result to workgroup memory. After
  /// synchronizing the workgroup, the first subgroup reduces the values from
  /// workgroup memory. The result is broadcast to all invocations through
  /// workgroup memory.
  ///
  ///     %subgroup_reduce = `createSubgroupReduce(%operand)`
  ///     cf.cond_br %is_first_lane, ^then1, ^continue1
  ///   ^then1:
  ///     store %subgroup_reduce, %workgroup_buffer[%subgroup_id]
  ///     cf.br ^continue1
  ///   ^continue1:
  ///     gpu.barrier
  ///     %is_valid_subgroup = arith.cmpi "slt" %invocation_idx, %num_subgroups
  ///     cf.cond_br %is_valid_subgroup, ^then2, ^continue2
  ///   ^then2:
  ///     %partial_reduce = load %workgroup_buffer[%invocation_idx]
  ///     %all_reduce = `createSubgroupReduce(%partial_reduce)`
  ///     store %all_reduce, %workgroup_buffer[%zero]
  ///     cf.br ^continue2
  ///   ^continue2:
  ///     gpu.barrier
  ///     %result = load %workgroup_buffer[%zero]
  ///     return %result
  ///
  void rewrite() {
    rewriter.setInsertionPoint(reduceOp);

    // Compute linear invocation index and workgroup size.
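    // Spelled out, the computation below linearizes the 3-d thread id with x
    // as the fastest-varying dimension:
    //   invocationIdx = tidX + dimX * (tidY + dimY * tidZ)
    //   workgroupSize = dimX * dimY * dimZ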
    Value dimX = getDimOp<gpu::BlockDimOp>(gpu::Dimension::x);
    Value dimY = getDimOp<gpu::BlockDimOp>(gpu::Dimension::y);
    Value dimZ = getDimOp<gpu::BlockDimOp>(gpu::Dimension::z);
    Value tidX = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::x);
    Value tidY = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::y);
    Value tidZ = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::z);
    Value tmp1 = create<arith::MulIOp>(int32Type, tidZ, dimY);
    Value tmp2 = create<arith::AddIOp>(int32Type, tmp1, tidY);
    Value tmp3 = create<arith::MulIOp>(int32Type, tmp2, dimX);
    Value tmp4 = create<arith::MulIOp>(int32Type, dimX, dimY);
    Value invocationIdx = create<arith::AddIOp>(int32Type, tmp3, tidX);
    Value workgroupSize = create<arith::MulIOp>(int32Type, tmp4, dimZ);

    // Compute lane id (invocation id within the subgroup).
    Value subgroupMask =
        create<arith::ConstantIntOp>(kSubgroupSize - 1, int32Type);
    Value laneId = create<arith::AndIOp>(invocationIdx, subgroupMask);
    Value isFirstLane =
        create<arith::CmpIOp>(arith::CmpIPredicate::eq, laneId,
                              create<arith::ConstantIntOp>(0, int32Type));

    Value numThreadsWithSmallerSubgroupId =
        create<arith::SubIOp>(invocationIdx, laneId);
    // The number of active invocations starting from the current subgroup.
    // The consumers do not require the value to be clamped to the size of the
    // subgroup.
    Value activeWidth =
        create<arith::SubIOp>(workgroupSize, numThreadsWithSmallerSubgroupId);

    // Create a factory for the op that accumulates two values.
    AccumulatorFactory accumFactory = getFactory();
    assert(accumFactory && "failed to create accumulator factory");

    // Reduce elements within each subgroup to produce intermediate results.
    Value subgroupReduce = createSubgroupReduce(
        activeWidth, laneId, reduceOp.getValue(), accumFactory);

    // Add workgroup buffer to parent function for intermediate result.
    Value buffer = createWorkgroupBuffer();

    // Write the intermediate results to workgroup memory, using the first lane
    // of each subgroup.
    createPredicatedBlock(isFirstLane, [&] {
      Value subgroupId = getDivideBySubgroupSize(invocationIdx);
      Value index = create<arith::IndexCastOp>(indexType, subgroupId);
      create<memref::StoreOp>(subgroupReduce, buffer, index);
    });
    create<gpu::BarrierOp>();

    // Compute number of active subgroups.
    Value biasedBlockSize =
        create<arith::AddIOp>(int32Type, workgroupSize, subgroupMask);
    Value numSubgroups = getDivideBySubgroupSize(biasedBlockSize);
    Value isValidSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
                                                  invocationIdx, numSubgroups);

    // Use the first numSubgroups invocations to reduce the intermediate
    // results from workgroup memory. The final result is written to workgroup
    // memory again.
    Value zero = create<arith::ConstantIndexOp>(0);
    createPredicatedBlock(isValidSubgroup, [&] {
      Value index = create<arith::IndexCastOp>(indexType, invocationIdx);
      Value value = create<memref::LoadOp>(valueType, buffer, index);
      Value result =
          createSubgroupReduce(numSubgroups, laneId, value, accumFactory);
      create<memref::StoreOp>(result, buffer, zero);
    });

    // Synchronize workgroup and load result from workgroup memory.
    create<gpu::BarrierOp>();
    Value result = create<memref::LoadOp>(valueType, buffer, zero);

    rewriter.replaceOp(reduceOp, result);
  }

private:
  // Shortcut to create an op from rewriter using loc as the first argument.
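  // For example, create<arith::AddIOp>(int32Type, lhs, rhs) expands to
  // rewriter.create<arith::AddIOp>(loc, int32Type, lhs, rhs).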
  template <typename T, typename... Args>
  T create(Args... args) {
    return rewriter.create<T>(loc, std::forward<Args>(args)...);
  }

  // Creates a dimension op of type T, with the result cast to int32.
  template <typename T>
  Value getDimOp(gpu::Dimension dimension) {
    Value dim = create<T>(indexType, dimension);
    return create<arith::IndexCastOp>(int32Type, dim);
  }

  /// Adds a workgroup buffer attribution to funcOp and returns it.
  Value createWorkgroupBuffer() {
    // TODO: Pick a proper location for the attribution.
    int workgroupMemoryAddressSpace =
        gpu::GPUDialect::getWorkgroupAddressSpace();
    auto bufferType = MemRefType::get({kSubgroupSize}, valueType, AffineMap{},
                                      workgroupMemoryAddressSpace);
    return funcOp.addWorkgroupAttribution(bufferType, rewriter.getUnknownLoc());
  }

  /// Returns an accumulator factory using either the op attribute or the body
  /// region.
  AccumulatorFactory getFactory() {
    auto &body = reduceOp.getBody();
    if (!body.empty())
      return getFactory(body);
    auto opAttr = reduceOp.getOp();
    if (opAttr)
      return getFactory(*opAttr);
    return AccumulatorFactory();
  }

  /// Returns an accumulator factory that clones the body. The body's entry
  /// block is expected to have 2 arguments. The gpu.yield ops are expected to
  /// return the accumulated value of the same type.
  AccumulatorFactory getFactory(Region &body) {
    return AccumulatorFactory([&](Value lhs, Value rhs) {
      Block *block = rewriter.getInsertionBlock();
      Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());

      // Clone the accumulator body between the current block and the split
      // block, mapping its arguments to the accumulated values.
      BlockAndValueMapping mapping;
      mapping.map(body.getArgument(0), lhs);
      mapping.map(body.getArgument(1), rhs);
      rewriter.cloneRegionBefore(body, *split->getParent(),
                                 split->getIterator(), mapping);

      // Add a branch from the block before the cloned body into the body.
      block = block->getNextNode();
      create<cf::BranchOp>(block, ValueRange());

      // Replace all gpu.yield ops with a branch out of the body.
      for (; block != split; block = block->getNextNode()) {
        Operation *terminator = block->getTerminator();
        if (!isa<gpu::YieldOp>(terminator))
          continue;
        rewriter.setInsertionPointToEnd(block);
        rewriter.replaceOpWithNewOp<cf::BranchOp>(
            terminator, split, ValueRange(terminator->getOperand(0)));
      }

      // Return accumulator result.
      rewriter.setInsertionPointToStart(split);
      return split->addArgument(lhs.getType(), lhs.getLoc());
    });
  }

  /// Returns an accumulator factory that creates an op specified by opName.
  AccumulatorFactory getFactory(gpu::AllReduceOperation opName) {
    bool isFloatingPoint = valueType.isa<FloatType>();
    switch (opName) {
    case gpu::AllReduceOperation::ADD:
      return isFloatingPoint ? getFactory<arith::AddFOp>()
                             : getFactory<arith::AddIOp>();
    case gpu::AllReduceOperation::MUL:
      return isFloatingPoint ? getFactory<arith::MulFOp>()
                             : getFactory<arith::MulIOp>();
    case gpu::AllReduceOperation::AND:
      return getFactory<arith::AndIOp>();
    case gpu::AllReduceOperation::OR:
      return getFactory<arith::OrIOp>();
    case gpu::AllReduceOperation::XOR:
      return getFactory<arith::XOrIOp>();
    case gpu::AllReduceOperation::MAX:
      return isFloatingPoint
                 ? getCmpFactory<arith::CmpFOp, arith::CmpFPredicate,
                                 arith::CmpFPredicate::UGT>()
                 : getCmpFactory<arith::CmpIOp, arith::CmpIPredicate,
                                 arith::CmpIPredicate::ugt>();
    case gpu::AllReduceOperation::MIN:
      return isFloatingPoint
                 ? getCmpFactory<arith::CmpFOp, arith::CmpFPredicate,
                                 arith::CmpFPredicate::ULT>()
                 : getCmpFactory<arith::CmpIOp, arith::CmpIPredicate,
                                 arith::CmpIPredicate::ult>();
    }
    llvm_unreachable("unknown GPU AllReduceOperation");
  }

  /// Returns an accumulator factory that creates an op of type T.
  template <typename T>
  AccumulatorFactory getFactory() {
    return [&](Value lhs, Value rhs) {
      return create<T>(lhs.getType(), lhs, rhs);
    };
  }

  /// Returns an accumulator factory for comparisons such as min and max. T is
  /// the type of the compare op.
  template <typename T, typename PredicateEnum, PredicateEnum predicate>
  AccumulatorFactory getCmpFactory() const {
    return [&](Value lhs, Value rhs) {
      Value cmp = rewriter.create<T>(loc, predicate, lhs, rhs);
      return rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
    };
  }

  /// Creates an if-block skeleton and calls the two factories to generate the
  /// ops in the `then` and `else` blocks.
  ///
  ///     cf.cond_br %condition, ^then, ^else
  ///   ^then:
  ///     %then_operands = `thenOpsFactory()`
  ///     cf.br ^continue(%then_operands)
  ///   ^else:
  ///     %else_operands = `elseOpsFactory()`
  ///     cf.br ^continue(%else_operands)
  ///   ^continue(%block_operands):
  ///
  template <typename ThenOpsFactory, typename ElseOpsFactory>
  void createIf(Value condition, ThenOpsFactory &&thenOpsFactory,
                ElseOpsFactory &&elseOpsFactory) {
    Block *currentBlock = rewriter.getInsertionBlock();
    auto currentPoint = rewriter.getInsertionPoint();

    Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint);
    Block *elseBlock = rewriter.splitBlock(thenBlock, thenBlock->begin());
    Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());

    rewriter.setInsertionPointToEnd(currentBlock);
    create<cf::CondBranchOp>(condition, thenBlock,
                             /*trueOperands=*/ArrayRef<Value>(), elseBlock,
                             /*falseOperands=*/ArrayRef<Value>());

    rewriter.setInsertionPointToStart(thenBlock);
    auto thenOperands = thenOpsFactory();
    create<cf::BranchOp>(continueBlock, thenOperands);

    rewriter.setInsertionPointToStart(elseBlock);
    auto elseOperands = elseOpsFactory();
    create<cf::BranchOp>(continueBlock, elseOperands);

    assert(thenOperands.size() == elseOperands.size());
    rewriter.setInsertionPointToStart(continueBlock);
    for (auto operand : thenOperands)
      continueBlock->addArgument(operand.getType(), operand.getLoc());
  }

  /// Shortcut for createIf with empty else block and no block operands.
  template <typename Factory>
  void createPredicatedBlock(Value condition, Factory &&predicatedOpsFactory) {
    static_assert(std::is_same<decltype(predicatedOpsFactory()), void>::value,
                  "predicatedOpsFactory should not return any value");
    createIf(
        condition,
        [&] {
          predicatedOpsFactory();
          return ArrayRef<Value>();
        },
        [&] { return ArrayRef<Value>(); });
  }

  /// Creates a reduction across the first activeWidth lanes of a subgroup, or
  /// the entire subgroup if activeWidth is larger than the subgroup width.
  /// The first lane returns the result; the return values of all other lanes
  /// are undefined.
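  ///
  /// Both branches below implement a butterfly reduction over gpu.shuffle xor
  /// ops. As a sketch (shuffle_xor and accum stand for gpu.shuffle xor and the
  /// accumulator factory; the active-width handling is omitted):
  ///
  ///   for (offset = 1; offset < kSubgroupSize; offset *= 2)
  ///     value = accum(value, shuffle_xor(value, offset))
  ///
  /// so after log2(kSubgroupSize) steps the first lane holds the reduction of
  /// all active lanes.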
  Value createSubgroupReduce(Value activeWidth, Value laneId, Value operand,
                             AccumulatorFactory &accumFactory) {
    Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
    Value isPartialSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
                                                    activeWidth, subgroupSize);
    std::array<Type, 2> shuffleType = {valueType, rewriter.getI1Type()};

    createIf(
        isPartialSubgroup,
        // Generate reduction over a (potentially) partial subgroup.
        [&] {
          Value value = operand;
          // Repeatedly shuffle value from 'laneId ^ i' and accumulate if the
          // source lane is within the active range. The accumulated value is
          // available in the first lane.
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<arith::ConstantIntOp>(i, int32Type);
            auto shuffleOp = create<gpu::ShuffleOp>(
                shuffleType, value, offset, activeWidth, gpu::ShuffleMode::XOR);
            // Skip the accumulation if the shuffle op read from a lane outside
            // of the active range.
            createIf(
                shuffleOp.getResult(1),
                [&] {
                  return SmallVector<Value, 1>{
                      accumFactory(value, shuffleOp.getResult(0))};
                },
                [&] { return llvm::ArrayRef(value); });
            value = rewriter.getInsertionBlock()->getArgument(0);
          }
          return SmallVector<Value, 1>{value};
        },
        // Generate a reduction over the entire subgroup. This is a
        // specialization of the above reduction with unconditional
        // accumulation.
        [&] {
          Value value = operand;
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<arith::ConstantIntOp>(i, int32Type);
            auto shuffleOp =
                create<gpu::ShuffleOp>(shuffleType, value, offset, subgroupSize,
                                       gpu::ShuffleMode::XOR);
            value = accumFactory(value, shuffleOp.getResult(0));
          }
          return SmallVector<Value, 1>{value};
        });
    return rewriter.getInsertionBlock()->getArgument(0);
  }

  /// Returns value divided by the subgroup size (i.e. 32).
  Value getDivideBySubgroupSize(Value value) {
    Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
    return create<arith::DivSIOp>(int32Type, value, subgroupSize);
  }

  gpu::GPUFuncOp funcOp;
  gpu::AllReduceOp reduceOp;
  PatternRewriter &rewriter;

  Location loc;
  Type valueType;
  Type indexType;
  IntegerType int32Type;

  static constexpr int kSubgroupSize = 32;
};

struct GpuAllReduceConversion : public RewritePattern {
  explicit GpuAllReduceConversion(MLIRContext *context)
      : RewritePattern(gpu::GPUFuncOp::getOperationName(), 1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto funcOp = cast<gpu::GPUFuncOp>(op);

    SmallVector<gpu::AllReduceOp> reduceOps;
    auto callback = [&](gpu::AllReduceOp reduceOp) -> WalkResult {
      if (!reduceOp.getUniform())
        return WalkResult::interrupt();

      reduceOps.emplace_back(reduceOp);
      return WalkResult::advance();
    };

    if (funcOp.walk(callback).wasInterrupted())
      return rewriter.notifyMatchFailure(
          op, "Non-uniform reductions are not supported yet.");

    for (gpu::AllReduceOp reduceOp : reduceOps)
      GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();

    return success();
  }
};
} // namespace

void mlir::populateGpuAllReducePatterns(RewritePatternSet &patterns) {
  patterns.add<GpuAllReduceConversion>(patterns.getContext());
}
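
// A minimal usage sketch: these patterns are typically driven from a pass
// using a greedy rewrite driver (the pass name below is hypothetical and not
// part of this file):
//
//   void LowerAllReducePass::runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     populateGpuAllReducePatterns(patterns);
//     if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                             std::move(patterns))))
//       signalPassFailure();
//   }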