//===- SubgroupReduceLowering.cpp - subgroup_reduce lowering patterns -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implements gradual lowering of `gpu.subgroup_reduce` ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <cstdint>

using namespace mlir;

namespace {

/// Example, assumes `maxShuffleBitwidth` equal to 32:
/// ```
/// %a = gpu.subgroup_reduce add %x : (vector<3xf16>) -> vector<3xf16>
/// ==>
/// %v0 = arith.constant dense<0.0> : vector<3xf16>
/// %e0 = vector.extract_strided_slice %x
///   {offsets = [0], sizes = [2], strides = [1]}: vector<3xf16> to vector<2xf16>
/// %r0 = gpu.subgroup_reduce add %e0 : (vector<2xf16>) -> vector<2xf16>
/// %v1 = vector.insert_strided_slice %r0, %v0
///   {offsets = [0], strides = [1]}: vector<2xf16> into vector<3xf16>
/// %e1 = vector.extract %x[2] : f16 from vector<3xf16>
/// %r1 = gpu.subgroup_reduce add %e1 : (f16) -> f16
/// %a = vector.insert %r1, %v1[2] : f16 into vector<3xf16>
/// ```
struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
  BreakDownSubgroupReduce(MLIRContext *ctx, unsigned maxShuffleBitwidth,
                          PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), maxShuffleBitwidth(maxShuffleBitwidth) {
  }

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy || vecTy.getNumElements() < 2)
      return rewriter.notifyMatchFailure(op, "not a multi-element reduction");

    assert(vecTy.getRank() == 1 && "Unexpected vector type");
    assert(!vecTy.isScalable() && "Unexpected vector type");

    Type elemTy = vecTy.getElementType();
    unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
    if (elemBitwidth >= maxShuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("element type too large ({0}), cannot break down "
                            "into vectors of bitwidth {1} or less",
                            elemBitwidth, maxShuffleBitwidth));

    unsigned elementsPerShuffle = maxShuffleBitwidth / elemBitwidth;
    assert(elementsPerShuffle >= 1);

    unsigned numNewReductions =
        llvm::divideCeil(vecTy.getNumElements(), elementsPerShuffle);
    assert(numNewReductions >= 1);
    if (numNewReductions == 1)
      return rewriter.notifyMatchFailure(op, "nothing to break down");

    Location loc = op.getLoc();
    Value res =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(vecTy));

    for (unsigned i = 0; i != numNewReductions; ++i) {
      int64_t startIdx = i * elementsPerShuffle;
      int64_t endIdx =
          std::min(startIdx + elementsPerShuffle, vecTy.getNumElements());
      int64_t numElems = endIdx - startIdx;

      Value extracted;
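      // Extract the current chunk: as a scalar when it holds a single element,
      // otherwise as a strided slice of `numElems` elements.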
      if (numElems == 1) {
        extracted =
            rewriter.create<vector::ExtractOp>(loc, op.getValue(), startIdx);
      } else {
        extracted = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, op.getValue(), /*offsets=*/startIdx, /*sizes=*/numElems,
            /*strides=*/1);
      }

      Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
          loc, extracted, op.getOp(), op.getUniform(), op.getClusterSize(),
          op.getClusterStride());
      if (numElems == 1) {
        res = rewriter.create<vector::InsertOp>(loc, reduce, res, startIdx);
        continue;
      }

      res = rewriter.create<vector::InsertStridedSliceOp>(
          loc, reduce, res, /*offsets=*/startIdx, /*strides=*/1);
    }

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned maxShuffleBitwidth = 0;
};

/// Example:
/// ```
/// %a = gpu.subgroup_reduce add %x : (vector<1xf32>) -> vector<1xf32>
/// ==>
/// %e0 = vector.extract %x[0] : f32 from vector<1xf32>
/// %r0 = gpu.subgroup_reduce add %e0 : (f32) -> f32
/// %a = vector.broadcast %r0 : f32 to vector<1xf32>
/// ```
struct ScalarizeSingleElementReduce final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy || vecTy.getNumElements() != 1)
      return rewriter.notifyMatchFailure(op, "not a single-element reduction");

    assert(vecTy.getRank() == 1 && "Unexpected vector type");
    assert(!vecTy.isScalable() && "Unexpected vector type");
    Location loc = op.getLoc();
    Value extracted = rewriter.create<vector::ExtractOp>(loc, op.getValue(), 0);
    Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
        loc, extracted, op.getOp(), op.getUniform(), op.getClusterSize(),
        op.getClusterStride());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecTy, reduce);
    return success();
  }
};

struct ClusterInfo {
  unsigned clusterStride;
  unsigned clusterSize;
  unsigned subgroupSize;
};

static FailureOr<ClusterInfo>
getAndValidateClusterInfo(gpu::SubgroupReduceOp op, unsigned subgroupSize) {
  assert(llvm::isPowerOf2_32(subgroupSize));

  std::optional<uint32_t> clusterSize = op.getClusterSize();
  assert(!clusterSize ||
         llvm::isPowerOf2_32(*clusterSize)); // Verifier should've caught this.
  if (clusterSize && *clusterSize > subgroupSize)
    return op.emitOpError()
           << "cluster size " << *clusterSize
           << " is greater than subgroup size " << subgroupSize;
  unsigned effectiveClusterSize = clusterSize.value_or(subgroupSize);

  auto clusterStride = op.getClusterStride();
  assert(llvm::isPowerOf2_32(clusterStride)); // Verifier should've caught this.
  if (clusterStride >= subgroupSize)
    return op.emitOpError()
           << "cluster stride " << clusterStride
           << " is not less than subgroup size " << subgroupSize;

  return ClusterInfo{clusterStride, effectiveClusterSize, subgroupSize};
}

/// Emits a subgroup reduction using a sequence of shuffles. Uses the `packFn`
/// and `unpackFn` to convert to the native shuffle type and to the reduction
/// type, respectively. For example, with `input` of type `f16`, `packFn` could
/// build ops to cast the value to `i32` to perform shuffles, while `unpackFn`
/// would cast it back to `f16` to perform arithmetic reduction on. Assumes that
/// the subgroup is `subgroupSize` lanes wide and divides it into clusters of
/// `clusterSize` lanes starting at lane 0 with a stride of `clusterStride` for
/// lanes within a cluster, reducing all lanes in each cluster in parallel.
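///
/// For instance (an illustrative sketch, not verbatim output; assumes identity
/// `packFn`/`unpackFn`, an `add` reduction over `f32`, `subgroupSize = 8`,
/// `clusterStride = 1`, and `clusterSize = 4`; `%c1`, `%c2`, and `%c8` are
/// `i32` constants), the emitted butterfly ladder looks roughly like:
/// ```
/// %s0, %valid0 = gpu.shuffle xor %x, %c1, %c8 : f32
/// %r0 = arith.addf %x, %s0 : f32
/// %s1, %valid1 = gpu.shuffle xor %r0, %c2, %c8 : f32
/// %r1 = arith.addf %r0, %s1 : f32
/// ```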
Value createSubgroupShuffleReduction(OpBuilder &builder, Location loc,
                                     Value input, gpu::AllReduceOperation mode,
                                     const ClusterInfo &ci,
                                     function_ref<Value(Value)> packFn,
                                     function_ref<Value(Value)> unpackFn) {
  // Lane value always stays in the original type. We use it to perform arith
  // reductions.
  Value laneVal = input;
  // Parallel reduction using butterfly shuffles.
  for (unsigned i = ci.clusterStride; i < ci.clusterStride * ci.clusterSize;
       i <<= 1) {
    Value shuffled = builder
                         .create<gpu::ShuffleOp>(loc, packFn(laneVal), i,
                                                 /*width=*/ci.subgroupSize,
                                                 /*mode=*/gpu::ShuffleMode::XOR)
                         .getShuffleResult();
    laneVal = vector::makeArithReduction(builder, loc,
                                         gpu::convertReductionKind(mode),
                                         laneVal, unpackFn(shuffled));
    assert(laneVal.getType() == input.getType());
  }

  return laneVal;
}

/// Lowers scalar gpu subgroup reductions to a series of shuffles.
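///
/// Example (an illustrative sketch, not verbatim output; assumes a 32-bit
/// shuffle bitwidth, subgroup size 32, and no clustering; `%offset` and
/// `%width` are `i32` constants, and only the first of the five shuffle
/// iterations is shown):
/// ```
/// %a = gpu.subgroup_reduce add %x : (f16) -> f16
/// ==>
/// %packed = arith.bitcast %x : f16 to i16
/// %ext = arith.extui %packed : i16 to i32
/// %shfl, %valid = gpu.shuffle xor %ext, %offset, %width : i32
/// %trunc = arith.trunci %shfl : i32 to i16
/// %unpacked = arith.bitcast %trunc : i16 to f16
/// %r = arith.addf %x, %unpacked : f16
/// ```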
struct ScalarSubgroupReduceToShuffles final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
                                 unsigned shuffleBitwidth, bool matchClustered,
                                 PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getClusterSize().has_value() != matchClustered) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("op is {0}clustered but pattern is configured to "
                            "only match {1}clustered ops",
                            matchClustered ? "non-" : "",
                            matchClustered ? "" : "non-"));
    }

    auto ci = getAndValidateClusterInfo(op, subgroupSize);
    if (failed(ci))
      return failure();

    Type valueTy = op.getType();
    unsigned elemBitwidth =
        getElementTypeOrSelf(valueTy).getIntOrFloatBitWidth();
    if (!valueTy.isIntOrFloat() || elemBitwidth > shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, "value type is not a compatible scalar");

    Location loc = op.getLoc();
    // Since this is already a native shuffle scalar, no packing is necessary.
    if (elemBitwidth == shuffleBitwidth) {
      auto identityFn = [](Value v) { return v; };
      rewriter.replaceOp(op, createSubgroupShuffleReduction(
                                 rewriter, loc, op.getValue(), op.getOp(), *ci,
                                 identityFn, identityFn));
      return success();
    }

    auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
    auto equivIntType = rewriter.getIntegerType(elemBitwidth);
    auto packFn = [loc, &rewriter, equivIntType,
                   shuffleIntType](Value unpackedVal) -> Value {
      auto asInt =
          rewriter.create<arith::BitcastOp>(loc, equivIntType, unpackedVal);
      return rewriter.create<arith::ExtUIOp>(loc, shuffleIntType, asInt);
    };
    auto unpackFn = [loc, &rewriter, equivIntType,
                     valueTy](Value packedVal) -> Value {
      auto asInt =
          rewriter.create<arith::TruncIOp>(loc, equivIntType, packedVal);
      return rewriter.create<arith::BitcastOp>(loc, valueTy, asInt);
    };

    rewriter.replaceOp(
        op, createSubgroupShuffleReduction(rewriter, loc, op.getValue(),
                                           op.getOp(), *ci, packFn, unpackFn));
    return success();
  }

private:
  unsigned subgroupSize = 0;
  unsigned shuffleBitwidth = 0;
  bool matchClustered = false;
};

/// Lowers vector gpu subgroup reductions to a series of shuffles.
struct VectorSubgroupReduceToShuffles final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
                                 unsigned shuffleBitwidth, bool matchClustered,
                                 PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getClusterSize().has_value() != matchClustered) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("op is {0}clustered but pattern is configured to "
                            "only match {1}clustered ops",
                            matchClustered ? "non-" : "",
                            matchClustered ? "" : "non-"));
    }

    auto ci = getAndValidateClusterInfo(op, subgroupSize);
    if (failed(ci))
      return failure();

    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy)
      return rewriter.notifyMatchFailure(op, "value type is not a vector");

    unsigned vecBitwidth =
        vecTy.getNumElements() * vecTy.getElementTypeBitWidth();
    if (vecBitwidth > shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op,
          llvm::formatv("vector type bitwidth too large ({0}), cannot lower "
                        "to shuffles of size {1}",
                        vecBitwidth, shuffleBitwidth));

    unsigned elementsPerShuffle =
        shuffleBitwidth / vecTy.getElementTypeBitWidth();
    if (elementsPerShuffle * vecTy.getElementTypeBitWidth() != shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, "shuffle bitwidth is not a multiple of the element bitwidth");

    Location loc = op.getLoc();

    // If the reduced type is smaller than the native shuffle size, extend it,
    // perform the shuffles, and extract at the end.
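    // E.g. (a sketch): a `vector<1xf16>` input with a 32-bit shuffle width is
    // widened to `vector<2xf16>` here and narrowed back after the reduction.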
    auto extendedVecTy = VectorType::get(
        static_cast<int64_t>(elementsPerShuffle), vecTy.getElementType());
    Value extendedInput = op.getValue();
    if (vecBitwidth < shuffleBitwidth) {
      auto zero = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getZeroAttr(extendedVecTy));
      extendedInput = rewriter.create<vector::InsertStridedSliceOp>(
          loc, extendedInput, zero, /*offsets=*/0, /*strides=*/1);
    }

    auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
    auto shuffleVecType = VectorType::get(1, shuffleIntType);

    auto packFn = [loc, &rewriter, shuffleVecType](Value unpackedVal) -> Value {
      auto asIntVec =
          rewriter.create<vector::BitCastOp>(loc, shuffleVecType, unpackedVal);
      return rewriter.create<vector::ExtractOp>(loc, asIntVec, 0);
    };
    auto unpackFn = [loc, &rewriter, shuffleVecType,
                     extendedVecTy](Value packedVal) -> Value {
      auto asIntVec =
          rewriter.create<vector::BroadcastOp>(loc, shuffleVecType, packedVal);
      return rewriter.create<vector::BitCastOp>(loc, extendedVecTy, asIntVec);
    };

    Value res = createSubgroupShuffleReduction(
        rewriter, loc, extendedInput, op.getOp(), *ci, packFn, unpackFn);

    if (vecBitwidth < shuffleBitwidth) {
      res = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, res, /*offsets=*/0, /*sizes=*/vecTy.getNumElements(),
          /*strides=*/1);
    }

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned subgroupSize = 0;
  unsigned shuffleBitwidth = 0;
  bool matchClustered = false;
};
} // namespace

void mlir::populateGpuBreakDownSubgroupReducePatterns(
    RewritePatternSet &patterns, unsigned maxShuffleBitwidth,
    PatternBenefit benefit) {
  patterns.add<BreakDownSubgroupReduce>(patterns.getContext(),
                                        maxShuffleBitwidth, benefit);
  patterns.add<ScalarizeSingleElementReduce>(patterns.getContext(), benefit);
}

void mlir::populateGpuLowerSubgroupReduceToShufflePatterns(
    RewritePatternSet &patterns, unsigned subgroupSize,
    unsigned shuffleBitwidth, PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
      patterns.getContext(), subgroupSize, shuffleBitwidth,
      /*matchClustered=*/false, benefit);
}

void mlir::populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
    RewritePatternSet &patterns, unsigned subgroupSize,
    unsigned shuffleBitwidth, PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
      patterns.getContext(), subgroupSize, shuffleBitwidth,
      /*matchClustered=*/true, benefit);
}