//===- AllReduceLowering.cpp - Implementation of all-reduce lowering ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements in-dialect lowering of the all-reduce op to a block of
// simpler instructions.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

namespace {

struct GpuAllReduceRewriter {
  using AccumulatorFactory = std::function<Value(Value, Value)>;

  GpuAllReduceRewriter(gpu::GPUFuncOp funcOp_, gpu::AllReduceOp reduceOp_,
                       PatternRewriter &rewriter_)
      : funcOp(funcOp_), reduceOp(reduceOp_), rewriter(rewriter_),
        loc(reduceOp.getLoc()), valueType(reduceOp.value().getType()),
        indexType(IndexType::get(reduceOp.getContext())),
        int32Type(IntegerType::get(reduceOp.getContext(), /*width=*/32)) {}

  /// Creates an all_reduce across the workgroup.
  ///
  /// First reduce the elements within a subgroup. The first invocation of each
  /// subgroup writes the intermediate result to workgroup memory. After
  /// synchronizing the workgroup, the first subgroup reduces the values from
  /// workgroup memory. The result is broadcast to all invocations through
  /// workgroup memory.
  ///
  ///     %subgroup_reduce = `createSubgroupReduce(%operand)`
  ///     cond_br %is_first_lane, ^then1, ^continue1
  ///   ^then1:
  ///     store %subgroup_reduce, %workgroup_buffer[%subgroup_id]
  ///     br ^continue1
  ///   ^continue1:
  ///     gpu.barrier
  ///     %is_valid_subgroup = cmpi "slt" %invocation_idx, %num_subgroups
  ///     cond_br %is_valid_subgroup, ^then2, ^continue2
  ///   ^then2:
  ///     %partial_reduce = load %workgroup_buffer[%invocation_idx]
  ///     %all_reduce = `createSubgroupReduce(%partial_reduce)`
  ///     store %all_reduce, %workgroup_buffer[%zero]
  ///     br ^continue2
  ///   ^continue2:
  ///     gpu.barrier
  ///     %result = load %workgroup_buffer[%zero]
  ///     return %result
  ///
  void rewrite() {
    rewriter.setInsertionPoint(reduceOp);

    // Compute linear invocation index and workgroup size.
    Value dimX = getDimOp<gpu::BlockDimOp>("x");
    Value dimY = getDimOp<gpu::BlockDimOp>("y");
    Value dimZ = getDimOp<gpu::BlockDimOp>("z");
    Value tidX = getDimOp<gpu::ThreadIdOp>("x");
    Value tidY = getDimOp<gpu::ThreadIdOp>("y");
    Value tidZ = getDimOp<gpu::ThreadIdOp>("z");
    Value tmp1 = create<MulIOp>(int32Type, tidZ, dimY);
    Value tmp2 = create<AddIOp>(int32Type, tmp1, tidY);
    Value tmp3 = create<MulIOp>(int32Type, tmp2, dimX);
    Value tmp4 = create<MulIOp>(int32Type, dimX, dimY);
    Value invocationIdx = create<AddIOp>(int32Type, tmp3, tidX);
    Value workgroupSize = create<MulIOp>(int32Type, tmp4, dimZ);
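    // The computation above linearizes the 3-d invocation id as
    // (tidZ * dimY + tidY) * dimX + tidX. For example, in a workgroup of
    // size (dimX, dimY, dimZ) = (32, 4, 2), invocation (5, 3, 1) gets
    // linear index (1 * 4 + 3) * 32 + 5 = 229.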
    // Compute lane id (invocation id within the subgroup).
    Value subgroupMask = create<ConstantIntOp>(kSubgroupSize - 1, int32Type);
    Value laneId = create<AndOp>(invocationIdx, subgroupMask);
    Value isFirstLane = create<CmpIOp>(CmpIPredicate::eq, laneId,
                                       create<ConstantIntOp>(0, int32Type));

    Value numThreadsWithSmallerSubgroupId =
        create<SubIOp>(invocationIdx, laneId);
    // The number of active invocations starting from the current subgroup.
    // The consumers do not require the value to be clamped to the size of the
    // subgroup.
    Value activeWidth =
        create<SubIOp>(workgroupSize, numThreadsWithSmallerSubgroupId);

    // Create factory for op which accumulates two values.
    AccumulatorFactory accumFactory = getFactory();
    assert(accumFactory && "failed to create accumulator factory");

    // Reduce elements within each subgroup to produce the intermediate
    // results.
    Value subgroupReduce = createSubgroupReduce(activeWidth, laneId,
                                                reduceOp.value(), accumFactory);

    // Add workgroup buffer to parent function for intermediate result.
    Value buffer = createWorkgroupBuffer();

    // Write the intermediate results to workgroup memory, using the first lane
    // of each subgroup.
    createPredicatedBlock(isFirstLane, [&] {
      Value subgroupId = getDivideBySubgroupSize(invocationIdx);
      Value index = create<IndexCastOp>(indexType, subgroupId);
      create<memref::StoreOp>(subgroupReduce, buffer, index);
    });
    create<gpu::BarrierOp>();

    // Compute number of active subgroups.
    Value biasedBlockSize =
        create<AddIOp>(int32Type, workgroupSize, subgroupMask);
    Value numSubgroups = getDivideBySubgroupSize(biasedBlockSize);
    Value isValidSubgroup =
        create<CmpIOp>(CmpIPredicate::slt, invocationIdx, numSubgroups);

    // Use the first numSubgroups invocations to reduce the intermediate
    // results from workgroup memory. The final result is written to workgroup
    // memory again.
    Value zero = create<ConstantIndexOp>(0);
    createPredicatedBlock(isValidSubgroup, [&] {
      Value index = create<IndexCastOp>(indexType, invocationIdx);
      Value value = create<memref::LoadOp>(valueType, buffer, index);
      Value result =
          createSubgroupReduce(numSubgroups, laneId, value, accumFactory);
      create<memref::StoreOp>(result, buffer, zero);
    });

    // Synchronize workgroup and load result from workgroup memory.
    create<gpu::BarrierOp>();
    Value result = create<memref::LoadOp>(valueType, buffer, zero);

    rewriter.replaceOp(reduceOp, result);
  }

private:
  // Shortcut to create an op from rewriter using loc as the first argument.
  template <typename T, typename... Args>
  T create(Args... args) {
    return rewriter.create<T>(loc, std::forward<Args>(args)...);
  }

  // Creates dimension op of type T, with the result cast to int32.
  template <typename T>
  Value getDimOp(StringRef dimension) {
    Value dim = create<T>(indexType, rewriter.getStringAttr(dimension));
    return create<IndexCastOp>(int32Type, dim);
  }

  /// Adds type to funcOp's workgroup attributions.
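  /// For example, for an f32 all-reduce with the fixed subgroup size of 32,
  /// this is expected to add a buffer of type memref<32xf32, 3>, where 3 is
  /// the workgroup memory address space used by the GPU dialect.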
  Value createWorkgroupBuffer() {
    int workgroupMemoryAddressSpace =
        gpu::GPUDialect::getWorkgroupAddressSpace();
    auto bufferType =
        MemRefType::get({kSubgroupSize}, valueType, ArrayRef<AffineMap>{},
                        workgroupMemoryAddressSpace);
    return funcOp.addWorkgroupAttribution(bufferType);
  }

  /// Returns an accumulator factory using either the op attribute or the body
  /// region.
  AccumulatorFactory getFactory() {
    auto &body = reduceOp.body();
    if (!body.empty())
      return getFactory(body);
    auto opAttr = reduceOp.op();
    if (opAttr)
      return getFactory(*opAttr);
    return AccumulatorFactory();
  }

  /// Returns an accumulator factory that clones the body. The body's entry
  /// block is expected to have 2 arguments. The gpu.yield ops are expected to
  /// return the accumulated value of the same type.
  AccumulatorFactory getFactory(Region &body) {
    return AccumulatorFactory([&](Value lhs, Value rhs) {
      Block *block = rewriter.getInsertionBlock();
      Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());

      // Insert accumulator body between the split blocks.
      BlockAndValueMapping mapping;
      mapping.map(body.getArgument(0), lhs);
      mapping.map(body.getArgument(1), rhs);
      rewriter.cloneRegionBefore(body, *split->getParent(),
                                 split->getIterator(), mapping);

      // Add branch before inserted body, into body.
      block = block->getNextNode();
      create<BranchOp>(block, ValueRange());

      // Replace all gpu.yield ops with branch out of body.
      for (; block != split; block = block->getNextNode()) {
        Operation *terminator = block->getTerminator();
        if (!isa<gpu::YieldOp>(terminator))
          continue;
        rewriter.setInsertionPointToEnd(block);
        rewriter.replaceOpWithNewOp<BranchOp>(
            terminator, split, ValueRange(terminator->getOperand(0)));
      }

      // Return accumulator result.
      rewriter.setInsertionPointToStart(split);
      return split->addArgument(lhs.getType());
    });
  }

  /// Returns an accumulator factory that creates an op specified by opName.
  AccumulatorFactory getFactory(StringRef opName) {
    bool isFloatingPoint = valueType.isa<FloatType>();
    if (opName == "add")
      return isFloatingPoint ? getFactory<AddFOp>() : getFactory<AddIOp>();
    if (opName == "mul")
      return isFloatingPoint ? getFactory<MulFOp>() : getFactory<MulIOp>();
    if (opName == "and")
      return getFactory<AndOp>();
    if (opName == "or")
      return getFactory<OrOp>();
    if (opName == "xor")
      return getFactory<XOrOp>();
    if (opName == "max")
      return isFloatingPoint
                 ? getCmpFactory<CmpFOp, CmpFPredicate, CmpFPredicate::UGT>()
                 : getCmpFactory<CmpIOp, CmpIPredicate, CmpIPredicate::ugt>();
    if (opName == "min")
      return isFloatingPoint
                 ? getCmpFactory<CmpFOp, CmpFPredicate, CmpFPredicate::ULT>()
                 : getCmpFactory<CmpIOp, CmpIPredicate, CmpIPredicate::ult>();
    return AccumulatorFactory();
  }

  /// Returns an accumulator factory that creates an op of type T.
  template <typename T>
  AccumulatorFactory getFactory() {
    return [&](Value lhs, Value rhs) {
      return create<T>(lhs.getType(), lhs, rhs);
    };
  }

  /// Returns an accumulator for comparison such as min, max. T is the type
  /// of the compare op.
  template <typename T, typename PredicateEnum, PredicateEnum predicate>
  AccumulatorFactory getCmpFactory() const {
    return [&](Value lhs, Value rhs) {
      Value cmp = rewriter.create<T>(loc, predicate, lhs, rhs);
      return rewriter.create<SelectOp>(loc, cmp, lhs, rhs);
    };
  }
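  // Illustrative note: for a "max" reduction over f32 values, the factory
  // above produces a compare-and-select pair, roughly:
  //   %cmp = cmpf "ugt", %lhs, %rhs : f32
  //   %res = select %cmp, %lhs, %rhs : f32
  // The unordered predicate makes the select return %lhs whenever either
  // operand is NaN.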
  /// Creates an if-block skeleton and calls the two factories to generate the
  /// ops in the `then` and `else` blocks.
  ///
  ///     cond_br %condition, ^then, ^else
  ///   ^then:
  ///     %then_operands = `thenOpsFactory()`
  ///     br ^continue(%then_operands)
  ///   ^else:
  ///     %else_operands = `elseOpsFactory()`
  ///     br ^continue(%else_operands)
  ///   ^continue(%block_operands):
  ///
  template <typename ThenOpsFactory, typename ElseOpsFactory>
  void createIf(Value condition, ThenOpsFactory &&thenOpsFactory,
                ElseOpsFactory &&elseOpsFactory) {
    Block *currentBlock = rewriter.getInsertionBlock();
    auto currentPoint = rewriter.getInsertionPoint();

    Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint);
    Block *elseBlock = rewriter.splitBlock(thenBlock, thenBlock->begin());
    Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());

    rewriter.setInsertionPointToEnd(currentBlock);
    create<CondBranchOp>(condition, thenBlock,
                         /*trueOperands=*/ArrayRef<Value>(), elseBlock,
                         /*falseOperands=*/ArrayRef<Value>());

    rewriter.setInsertionPointToStart(thenBlock);
    auto thenOperands = thenOpsFactory();
    create<BranchOp>(continueBlock, thenOperands);

    rewriter.setInsertionPointToStart(elseBlock);
    auto elseOperands = elseOpsFactory();
    create<BranchOp>(continueBlock, elseOperands);

    assert(thenOperands.size() == elseOperands.size());
    rewriter.setInsertionPointToStart(continueBlock);
    for (auto operand : thenOperands)
      continueBlock->addArgument(operand.getType());
  }

  /// Shortcut for createIf with empty else block and no block operands.
  template <typename Factory>
  void createPredicatedBlock(Value condition, Factory &&predicatedOpsFactory) {
    static_assert(std::is_same<decltype(predicatedOpsFactory()), void>::value,
                  "predicatedOpsFactory should not return any value");
    createIf(
        condition,
        [&] {
          predicatedOpsFactory();
          return ArrayRef<Value>();
        },
        [&] { return ArrayRef<Value>(); });
  }
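  // Illustrative usage, mirroring the calls in rewrite() above: execute a
  // store only on the first lane of each subgroup, with an empty else block:
  //
  //   createPredicatedBlock(isFirstLane, [&] {
  //     create<memref::StoreOp>(subgroupReduce, buffer, index);
  //   });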
  /// Creates a reduction across the first activeWidth lanes of a subgroup, or
  /// the entire subgroup if activeWidth is larger than the subgroup width.
  /// The first lane returns the result; the return values of all other lanes
  /// are undefined.
  Value createSubgroupReduce(Value activeWidth, Value laneId, Value operand,
                             AccumulatorFactory &accumFactory) {
    Value subgroupSize = create<ConstantIntOp>(kSubgroupSize, int32Type);
    Value isPartialSubgroup =
        create<CmpIOp>(CmpIPredicate::slt, activeWidth, subgroupSize);
    std::array<Type, 2> shuffleType = {valueType, rewriter.getI1Type()};
    auto xorAttr = rewriter.getStringAttr("xor");

    createIf(
        isPartialSubgroup,
        // Generate reduction over a (potentially) partial subgroup.
        [&] {
          Value value = operand;
          // Repeatedly shuffle value from 'laneId ^ i' and accumulate if the
          // source lane is within the active range. The accumulated value is
          // available in the first lane.
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<ConstantIntOp>(i, int32Type);
            auto shuffleOp = create<gpu::ShuffleOp>(shuffleType, value, offset,
                                                    activeWidth, xorAttr);
            // Skip the accumulation if the shuffle op read from a lane
            // outside of the active range.
            createIf(
                shuffleOp.getResult(1),
                [&] {
                  return SmallVector<Value, 1>{
                      accumFactory(value, shuffleOp.getResult(0))};
                },
                [&] { return llvm::makeArrayRef(value); });
            value = rewriter.getInsertionBlock()->getArgument(0);
          }
          return SmallVector<Value, 1>{value};
        },
        // Generate a reduction over the entire subgroup. This is a
        // specialization of the above reduction with unconditional
        // accumulation.
        [&] {
          Value value = operand;
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<ConstantIntOp>(i, int32Type);
            auto shuffleOp = create<gpu::ShuffleOp>(shuffleType, value, offset,
                                                    subgroupSize, xorAttr);
            value = accumFactory(value, shuffleOp.getResult(0));
          }
          return SmallVector<Value, 1>{value};
        });
    return rewriter.getInsertionBlock()->getArgument(0);
  }

  /// Returns value divided by the subgroup size (i.e. 32).
  Value getDivideBySubgroupSize(Value value) {
    Value subgroupSize = create<ConstantIntOp>(kSubgroupSize, int32Type);
    return create<SignedDivIOp>(int32Type, value, subgroupSize);
  }

  gpu::GPUFuncOp funcOp;
  gpu::AllReduceOp reduceOp;
  PatternRewriter &rewriter;

  Location loc;
  Type valueType;
  Type indexType;
  Type int32Type;

  static constexpr int kSubgroupSize = 32;
};

struct GpuAllReduceConversion : public RewritePattern {
  explicit GpuAllReduceConversion(MLIRContext *context)
      : RewritePattern(gpu::GPUFuncOp::getOperationName(), 1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto funcOp = cast<gpu::GPUFuncOp>(op);
    auto callback = [&](gpu::AllReduceOp reduceOp) {
      GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();
      // Performing a rewrite invalidates the walk iterator. Report an
      // interrupt to start a new walk, until all all_reduce ops are replaced.
      return WalkResult::interrupt();
    };
    while (funcOp.walk(callback).wasInterrupted()) {
    }
    return success();
  }
};
} // namespace

void mlir::populateGpuAllReducePatterns(RewritePatternSet &patterns) {
  patterns.add<GpuAllReduceConversion>(patterns.getContext());
}
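// A minimal usage sketch (illustrative, not part of this file): driving the
// pattern greedily from a module pass. The pass name is hypothetical;
// applyPatternsAndFoldGreedily is the standard greedy rewrite driver from
// "mlir/Transforms/GreedyPatternRewriteDriver.h".
//
//   struct LowerGpuAllReducePass
//       : public PassWrapper<LowerGpuAllReducePass, OperationPass<ModuleOp>> {
//     void runOnOperation() override {
//       RewritePatternSet patterns(&getContext());
//       populateGpuAllReducePatterns(patterns);
//       (void)applyPatternsAndFoldGreedily(getOperation(),
//                                          std::move(patterns));
//     }
//   };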