//===- AllReduceLowering.cpp - Implementation of all-reduce lowering ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements in-dialect lowering of the all-reduce op to a block of
// simpler instructions.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

namespace {

struct GpuAllReduceRewriter {
  using AccumulatorFactory = std::function<Value(Value, Value)>;

  GpuAllReduceRewriter(gpu::GPUFuncOp funcOp, gpu::AllReduceOp reduceOp,
                       PatternRewriter &rewriter)
      : funcOp(funcOp), reduceOp(reduceOp), rewriter(rewriter),
        loc(reduceOp.getLoc()), valueType(reduceOp.value().getType()),
        indexType(IndexType::get(reduceOp.getContext())),
        int32Type(IntegerType::get(reduceOp.getContext(), /*width=*/32)) {}

  /// Creates an all_reduce across the workgroup.
  ///
  /// First reduce the elements within a subgroup. The first invocation of each
  /// subgroup writes the intermediate result to workgroup memory. After
  /// synchronizing the workgroup, the first subgroup reduces the values from
  /// workgroup memory. The result is broadcast to all invocations through
  /// workgroup memory.
  ///
  ///     %subgroup_reduce = `createSubgroupReduce(%operand)`
  ///     cond_br %is_first_lane, ^then1, ^continue1
  ///   ^then1:
  ///     store %subgroup_reduce, %workgroup_buffer[%subgroup_id]
  ///     br ^continue1
  ///   ^continue1:
  ///     gpu.barrier
  ///     %is_valid_subgroup = arith.cmpi "slt" %invocation_idx, %num_subgroups
  ///     cond_br %is_valid_subgroup, ^then2, ^continue2
  ///   ^then2:
  ///     %partial_reduce = load %workgroup_buffer[%invocation_idx]
  ///     %all_reduce = `createSubgroupReduce(%partial_reduce)`
  ///     store %all_reduce, %workgroup_buffer[%zero]
  ///     br ^continue2
  ///   ^continue2:
  ///     gpu.barrier
  ///     %result = load %workgroup_buffer[%zero]
  ///     return %result
  ///
  void rewrite() {
    rewriter.setInsertionPoint(reduceOp);

    // Compute linear invocation index and workgroup size.
    Value dimX = getDimOp<gpu::BlockDimOp>("x");
    Value dimY = getDimOp<gpu::BlockDimOp>("y");
    Value dimZ = getDimOp<gpu::BlockDimOp>("z");
    Value tidX = getDimOp<gpu::ThreadIdOp>("x");
    Value tidY = getDimOp<gpu::ThreadIdOp>("y");
    Value tidZ = getDimOp<gpu::ThreadIdOp>("z");
    Value tmp1 = create<arith::MulIOp>(int32Type, tidZ, dimY);
    Value tmp2 = create<arith::AddIOp>(int32Type, tmp1, tidY);
    Value tmp3 = create<arith::MulIOp>(int32Type, tmp2, dimX);
    Value tmp4 = create<arith::MulIOp>(int32Type, dimX, dimY);
    Value invocationIdx = create<arith::AddIOp>(int32Type, tmp3, tidX);
    Value workgroupSize = create<arith::MulIOp>(int32Type, tmp4, dimZ);
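    // Both values follow the usual row-major linearization of the 3-D ids:
    //   invocationIdx = (tidZ * dimY + tidY) * dimX + tidX
    //   workgroupSize = dimX * dimY * dimZ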

    // Compute lane id (invocation id within the subgroup).
    Value subgroupMask =
        create<arith::ConstantIntOp>(kSubgroupSize - 1, int32Type);
    Value laneId = create<arith::AndIOp>(invocationIdx, subgroupMask);
    Value isFirstLane =
        create<arith::CmpIOp>(arith::CmpIPredicate::eq, laneId,
                              create<arith::ConstantIntOp>(0, int32Type));

    Value numThreadsWithSmallerSubgroupId =
        create<arith::SubIOp>(invocationIdx, laneId);
    // The number of active invocations starting from the current subgroup.
    // The consumers do not require the value to be clamped to the size of the
    // subgroup.
    Value activeWidth =
        create<arith::SubIOp>(workgroupSize, numThreadsWithSmallerSubgroupId);

    // Create a factory for the op that accumulates two values.
    AccumulatorFactory accumFactory = getFactory();
    assert(accumFactory && "failed to create accumulator factory");

    // Reduce elements within each subgroup to produce the intermediate
    // results.
    Value subgroupReduce = createSubgroupReduce(activeWidth, laneId,
                                                reduceOp.value(), accumFactory);

    // Add workgroup buffer to parent function for the intermediate result.
    Value buffer = createWorkgroupBuffer();

    // Write the intermediate results to workgroup memory, using the first lane
    // of each subgroup.
    createPredicatedBlock(isFirstLane, [&] {
      Value subgroupId = getDivideBySubgroupSize(invocationIdx);
      Value index = create<arith::IndexCastOp>(indexType, subgroupId);
      create<memref::StoreOp>(subgroupReduce, buffer, index);
    });
    create<gpu::BarrierOp>();

    // Compute the number of active subgroups.
    Value biasedBlockSize =
        create<arith::AddIOp>(int32Type, workgroupSize, subgroupMask);
    Value numSubgroups = getDivideBySubgroupSize(biasedBlockSize);
    Value isValidSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
                                                  invocationIdx, numSubgroups);

    // Use the first numSubgroups invocations to reduce the intermediate
    // results from workgroup memory. The final result is written to workgroup
    // memory again.
    Value zero = create<arith::ConstantIndexOp>(0);
    createPredicatedBlock(isValidSubgroup, [&] {
      Value index = create<arith::IndexCastOp>(indexType, invocationIdx);
      Value value = create<memref::LoadOp>(valueType, buffer, index);
      Value result =
          createSubgroupReduce(numSubgroups, laneId, value, accumFactory);
      create<memref::StoreOp>(result, buffer, zero);
    });

    // Synchronize the workgroup and load the result from workgroup memory.
    create<gpu::BarrierOp>();
    Value result = create<memref::LoadOp>(valueType, buffer, zero);

    rewriter.replaceOp(reduceOp, result);
  }

private:
  // Shortcut to create an op from rewriter using loc as the first argument.
  template <typename T, typename... Args>
  T create(Args... args) {
    return rewriter.create<T>(loc, std::forward<Args>(args)...);
  }

  // Creates a dimension op of type T, with the result cast to int32.
  template <typename T>
  Value getDimOp(StringRef dimension) {
    Value dim = create<T>(indexType, rewriter.getStringAttr(dimension));
    return create<arith::IndexCastOp>(int32Type, dim);
  }

  /// Adds a workgroup buffer attribution to funcOp to hold the intermediate
  /// per-subgroup results, and returns it.
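  /// For example, with f32 values the attribution would have a type along the
  /// lines of memref<32xf32, 3>: one slot per subgroup, allocated in the GPU
  /// dialect's workgroup memory address space.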
  Value createWorkgroupBuffer() {
    int workgroupMemoryAddressSpace =
        gpu::GPUDialect::getWorkgroupAddressSpace();
    auto bufferType = MemRefType::get({kSubgroupSize}, valueType, AffineMap{},
                                      workgroupMemoryAddressSpace);
    return funcOp.addWorkgroupAttribution(bufferType);
  }

  /// Returns an accumulator factory using either the op attribute or the body
  /// region.
  AccumulatorFactory getFactory() {
    auto &body = reduceOp.body();
    if (!body.empty())
      return getFactory(body);
    auto opAttr = reduceOp.op();
    if (opAttr)
      return getFactory(*opAttr);
    return AccumulatorFactory();
  }

  /// Returns an accumulator factory that clones the body. The body's entry
  /// block is expected to have 2 arguments. The gpu.yield op returns the
  /// accumulated value of the same type.
  AccumulatorFactory getFactory(Region &body) {
    return AccumulatorFactory([&](Value lhs, Value rhs) {
      Block *block = rewriter.getInsertionBlock();
      Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());

      // Insert the accumulator body between the split blocks.
      BlockAndValueMapping mapping;
      mapping.map(body.getArgument(0), lhs);
      mapping.map(body.getArgument(1), rhs);
      rewriter.cloneRegionBefore(body, *split->getParent(),
                                 split->getIterator(), mapping);

      // Add a branch before the inserted body, into the body.
      block = block->getNextNode();
      create<BranchOp>(block, ValueRange());

      // Replace all gpu.yield ops with a branch out of the body.
      for (; block != split; block = block->getNextNode()) {
        Operation *terminator = block->getTerminator();
        if (!isa<gpu::YieldOp>(terminator))
          continue;
        rewriter.setInsertionPointToEnd(block);
        rewriter.replaceOpWithNewOp<BranchOp>(
            terminator, split, ValueRange(terminator->getOperand(0)));
      }

      // Return the accumulator result.
      rewriter.setInsertionPointToStart(split);
      return split->addArgument(lhs.getType());
    });
  }

  /// Returns an accumulator factory that creates the op specified by opName.
  AccumulatorFactory getFactory(StringRef opName) {
    bool isFloatingPoint = valueType.isa<FloatType>();
    if (opName == "add")
      return isFloatingPoint ? getFactory<arith::AddFOp>()
                             : getFactory<arith::AddIOp>();
    if (opName == "mul")
      return isFloatingPoint ? getFactory<arith::MulFOp>()
                             : getFactory<arith::MulIOp>();
    if (opName == "and")
      return getFactory<arith::AndIOp>();
    if (opName == "or")
      return getFactory<arith::OrIOp>();
    if (opName == "xor")
      return getFactory<arith::XOrIOp>();
    if (opName == "max")
      return isFloatingPoint
                 ? getCmpFactory<arith::CmpFOp, arith::CmpFPredicate,
                                 arith::CmpFPredicate::UGT>()
                 : getCmpFactory<arith::CmpIOp, arith::CmpIPredicate,
                                 arith::CmpIPredicate::ugt>();
    if (opName == "min")
      return isFloatingPoint
                 ? getCmpFactory<arith::CmpFOp, arith::CmpFPredicate,
                                 arith::CmpFPredicate::ULT>()
                 : getCmpFactory<arith::CmpIOp, arith::CmpIPredicate,
                                 arith::CmpIPredicate::ult>();
    return AccumulatorFactory();
  }

  /// Returns an accumulator factory that creates an op of type T.
  template <typename T>
  AccumulatorFactory getFactory() {
    return [&](Value lhs, Value rhs) {
      return create<T>(lhs.getType(), lhs, rhs);
    };
  }

  /// Returns an accumulator factory for comparison ops such as min and max.
  /// T is the type of the compare op.
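  /// For example, with f32 values the "max" case above, i.e.
  /// getCmpFactory<arith::CmpFOp, arith::CmpFPredicate,
  /// arith::CmpFPredicate::UGT>(), materializes each accumulation roughly as:
  ///   %cmp = arith.cmpf ugt, %lhs, %rhs : f32
  ///   %max = select %cmp, %lhs, %rhs : f32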
  template <typename T, typename PredicateEnum, PredicateEnum predicate>
  AccumulatorFactory getCmpFactory() const {
    return [&](Value lhs, Value rhs) {
      Value cmp = rewriter.create<T>(loc, predicate, lhs, rhs);
      return rewriter.create<SelectOp>(loc, cmp, lhs, rhs);
    };
  }

  /// Creates an if-block skeleton and calls the two factories to generate the
  /// ops in the `then` and `else` block.
  ///
  ///     cond_br %condition, ^then, ^else
  ///   ^then:
  ///     %then_operands = `thenOpsFactory()`
  ///     br ^continue(%then_operands)
  ///   ^else:
  ///     %else_operands = `elseOpsFactory()`
  ///     br ^continue(%else_operands)
  ///   ^continue(%block_operands):
  ///
  template <typename ThenOpsFactory, typename ElseOpsFactory>
  void createIf(Value condition, ThenOpsFactory &&thenOpsFactory,
                ElseOpsFactory &&elseOpsFactory) {
    Block *currentBlock = rewriter.getInsertionBlock();
    auto currentPoint = rewriter.getInsertionPoint();

    Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint);
    Block *elseBlock = rewriter.splitBlock(thenBlock, thenBlock->begin());
    Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());

    rewriter.setInsertionPointToEnd(currentBlock);
    create<CondBranchOp>(condition, thenBlock,
                         /*trueOperands=*/ArrayRef<Value>(), elseBlock,
                         /*falseOperands=*/ArrayRef<Value>());

    rewriter.setInsertionPointToStart(thenBlock);
    auto thenOperands = thenOpsFactory();
    create<BranchOp>(continueBlock, thenOperands);

    rewriter.setInsertionPointToStart(elseBlock);
    auto elseOperands = elseOpsFactory();
    create<BranchOp>(continueBlock, elseOperands);

    assert(thenOperands.size() == elseOperands.size());
    rewriter.setInsertionPointToStart(continueBlock);
    for (auto operand : thenOperands)
      continueBlock->addArgument(operand.getType());
  }

  /// Shortcut for createIf with an empty else block and no block operands.
  template <typename Factory>
  void createPredicatedBlock(Value condition, Factory &&predicatedOpsFactory) {
    static_assert(std::is_same<decltype(predicatedOpsFactory()), void>::value,
                  "predicatedOpsFactory should not return any value");
    createIf(
        condition,
        [&] {
          predicatedOpsFactory();
          return ArrayRef<Value>();
        },
        [&] { return ArrayRef<Value>(); });
  }

  /// Creates a reduction across the first activeWidth lanes of a subgroup, or
  /// the entire subgroup if activeWidth is larger than the subgroup width.
  /// The first lane returns the result; the return values of all other lanes
  /// are undefined.
  Value createSubgroupReduce(Value activeWidth, Value laneId, Value operand,
                             AccumulatorFactory &accumFactory) {
    Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
    Value isPartialSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
                                                    activeWidth, subgroupSize);
    std::array<Type, 2> shuffleType = {valueType, rewriter.getI1Type()};
    auto xorAttr = rewriter.getStringAttr("xor");

    createIf(
        isPartialSubgroup,
        // Generate a reduction over a (potentially) partial subgroup.
        [&] {
          Value value = operand;
          // Repeatedly shuffle value from 'laneId ^ i' and accumulate if the
          // source lane is within the active range. The accumulated value is
          // available in the first lane.
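          // As a sketch with kSubgroupSize = 8 and all lanes active, lane l
          // combines with lane l ^ i in each step of the xor butterfly:
          //   i = 1: 0<->1, 2<->3, 4<->5, 6<->7
          //   i = 2: 0<->2, 1<->3, 4<->6, 5<->7
          //   i = 4: 0<->4, 1<->5, 2<->6, 3<->7
          // so after log2(kSubgroupSize) steps lane 0 has accumulated the
          // values of all lanes.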
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<arith::ConstantIntOp>(i, int32Type);
            auto shuffleOp = create<gpu::ShuffleOp>(shuffleType, value, offset,
                                                    activeWidth, xorAttr);
            // Skip the accumulation if the shuffle op read from a lane outside
            // of the active range.
            createIf(
                shuffleOp.getResult(1),
                [&] {
                  return SmallVector<Value, 1>{
                      accumFactory(value, shuffleOp.getResult(0))};
                },
                [&] { return llvm::makeArrayRef(value); });
            value = rewriter.getInsertionBlock()->getArgument(0);
          }
          return SmallVector<Value, 1>{value};
        },
        // Generate a reduction over the entire subgroup. This is a
        // specialization of the above reduction with unconditional
        // accumulation.
        [&] {
          Value value = operand;
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<arith::ConstantIntOp>(i, int32Type);
            auto shuffleOp = create<gpu::ShuffleOp>(shuffleType, value, offset,
                                                    subgroupSize, xorAttr);
            value = accumFactory(value, shuffleOp.getResult(0));
          }
          return SmallVector<Value, 1>{value};
        });
    return rewriter.getInsertionBlock()->getArgument(0);
  }

  /// Returns value divided by the subgroup size (i.e. 32).
  Value getDivideBySubgroupSize(Value value) {
    Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
    return create<arith::DivSIOp>(int32Type, value, subgroupSize);
  }

  gpu::GPUFuncOp funcOp;
  gpu::AllReduceOp reduceOp;
  PatternRewriter &rewriter;

  Location loc;
  Type valueType;
  Type indexType;
  IntegerType int32Type;

  static constexpr int kSubgroupSize = 32;
};

struct GpuAllReduceConversion : public RewritePattern {
  explicit GpuAllReduceConversion(MLIRContext *context)
      : RewritePattern(gpu::GPUFuncOp::getOperationName(), 1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto funcOp = cast<gpu::GPUFuncOp>(op);
    auto callback = [&](gpu::AllReduceOp reduceOp) {
      GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();
      // Performing a rewrite invalidates the walk iterator. Report interrupt
      // so that we can start a new walk until all all_reduce ops are replaced.
      return WalkResult::interrupt();
    };
    while (funcOp.walk(callback).wasInterrupted()) {
    }
    return success();
  }
};
} // namespace

void mlir::populateGpuAllReducePatterns(RewritePatternSet &patterns) {
  patterns.add<GpuAllReduceConversion>(patterns.getContext());
}