12af186f9SJakub Kuderski //===- SubgroupReduceLowering.cpp - subgroup_reduce lowering patterns -----===// 22af186f9SJakub Kuderski // 32af186f9SJakub Kuderski // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 42af186f9SJakub Kuderski // See https://llvm.org/LICENSE.txt for license information. 52af186f9SJakub Kuderski // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 62af186f9SJakub Kuderski // 72af186f9SJakub Kuderski //===----------------------------------------------------------------------===// 82af186f9SJakub Kuderski // 92af186f9SJakub Kuderski // Implements gradual lowering of `gpu.subgroup_reduce` ops. 102af186f9SJakub Kuderski // 112af186f9SJakub Kuderski //===----------------------------------------------------------------------===// 122af186f9SJakub Kuderski 132af186f9SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 142af186f9SJakub Kuderski #include "mlir/Dialect/GPU/IR/GPUDialect.h" 152af186f9SJakub Kuderski #include "mlir/Dialect/GPU/Transforms/Passes.h" 16*bc29fc93SPetr Kurapov #include "mlir/Dialect/GPU/Utils/GPUUtils.h" 172af186f9SJakub Kuderski #include "mlir/Dialect/Vector/IR/VectorOps.h" 18c0345b46SJakub Kuderski #include "mlir/IR/BuiltinTypes.h" 192af186f9SJakub Kuderski #include "mlir/IR/Location.h" 202af186f9SJakub Kuderski #include "mlir/IR/PatternMatch.h" 21c0345b46SJakub Kuderski #include "mlir/IR/TypeUtilities.h" 222af186f9SJakub Kuderski #include "llvm/Support/FormatVariadic.h" 232af186f9SJakub Kuderski #include "llvm/Support/MathExtras.h" 242af186f9SJakub Kuderski #include <cassert> 25c0345b46SJakub Kuderski #include <cstdint> 262af186f9SJakub Kuderski 272af186f9SJakub Kuderski using namespace mlir; 282af186f9SJakub Kuderski 292af186f9SJakub Kuderski namespace { 302af186f9SJakub Kuderski 312af186f9SJakub Kuderski /// Example, assumes `maxShuffleBitwidth` equal to 32: 322af186f9SJakub Kuderski /// ``` 332af186f9SJakub Kuderski /// %a = gpu.subgroup_reduce add %x : (vector<3xf16>) -> vector<3xf16> 342af186f9SJakub Kuderski /// ==> 352af186f9SJakub Kuderski /// %v0 = arith.constant dense<0.0> : vector<3xf16> 362af186f9SJakub Kuderski /// %e0 = vector.extract_strided_slice %x 372af186f9SJakub Kuderski /// {offsets = [0], sizes = [2], strides = [1}: vector<3xf32> to vector<2xf32> 382af186f9SJakub Kuderski /// %r0 = gpu.subgroup_reduce add %e0 : (vector<2xf16>) -> vector<2xf16> 392af186f9SJakub Kuderski /// %v1 = vector.insert_strided_slice %r0, %v0 402af186f9SJakub Kuderski /// {offsets = [0], strides = [1}: vector<2xf32> into vector<3xf32> 412af186f9SJakub Kuderski /// %e1 = vector.extract %x[2] : f16 from vector<2xf16> 422af186f9SJakub Kuderski /// %r1 = gpu.subgroup_reduce add %e1 : (f16) -> f16 432af186f9SJakub Kuderski /// %a = vector.insert %r1, %v1[2] : f16 into vector<3xf16> 442af186f9SJakub Kuderski /// ``` 452af186f9SJakub Kuderski struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> { 462af186f9SJakub Kuderski BreakDownSubgroupReduce(MLIRContext *ctx, unsigned maxShuffleBitwidth, 472af186f9SJakub Kuderski PatternBenefit benefit) 482af186f9SJakub Kuderski : OpRewritePattern(ctx, benefit), maxShuffleBitwidth(maxShuffleBitwidth) { 492af186f9SJakub Kuderski } 502af186f9SJakub Kuderski 512af186f9SJakub Kuderski LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op, 522af186f9SJakub Kuderski PatternRewriter &rewriter) const override { 532af186f9SJakub Kuderski auto vecTy = dyn_cast<VectorType>(op.getType()); 542af186f9SJakub Kuderski if (!vecTy || vecTy.getNumElements() < 2) 552af186f9SJakub Kuderski return rewriter.notifyMatchFailure(op, "not a multi-element reduction"); 562af186f9SJakub Kuderski 572af186f9SJakub Kuderski assert(vecTy.getRank() == 1 && "Unexpected vector type"); 582af186f9SJakub Kuderski assert(!vecTy.isScalable() && "Unexpected vector type"); 592af186f9SJakub Kuderski 602af186f9SJakub Kuderski Type elemTy = vecTy.getElementType(); 612af186f9SJakub Kuderski unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth(); 622af186f9SJakub Kuderski if (elemBitwidth >= maxShuffleBitwidth) 632af186f9SJakub Kuderski return rewriter.notifyMatchFailure( 64c0345b46SJakub Kuderski op, llvm::formatv("element type too large ({0}), cannot break down " 652af186f9SJakub Kuderski "into vectors of bitwidth {1} or less", 662af186f9SJakub Kuderski elemBitwidth, maxShuffleBitwidth)); 672af186f9SJakub Kuderski 682af186f9SJakub Kuderski unsigned elementsPerShuffle = maxShuffleBitwidth / elemBitwidth; 692af186f9SJakub Kuderski assert(elementsPerShuffle >= 1); 702af186f9SJakub Kuderski 712af186f9SJakub Kuderski unsigned numNewReductions = 722af186f9SJakub Kuderski llvm::divideCeil(vecTy.getNumElements(), elementsPerShuffle); 732af186f9SJakub Kuderski assert(numNewReductions >= 1); 742af186f9SJakub Kuderski if (numNewReductions == 1) 752af186f9SJakub Kuderski return rewriter.notifyMatchFailure(op, "nothing to break down"); 762af186f9SJakub Kuderski 772af186f9SJakub Kuderski Location loc = op.getLoc(); 782af186f9SJakub Kuderski Value res = 792af186f9SJakub Kuderski rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(vecTy)); 802af186f9SJakub Kuderski 812af186f9SJakub Kuderski for (unsigned i = 0; i != numNewReductions; ++i) { 822af186f9SJakub Kuderski int64_t startIdx = i * elementsPerShuffle; 832af186f9SJakub Kuderski int64_t endIdx = 842af186f9SJakub Kuderski std::min(startIdx + elementsPerShuffle, vecTy.getNumElements()); 852af186f9SJakub Kuderski int64_t numElems = endIdx - startIdx; 862af186f9SJakub Kuderski 872af186f9SJakub Kuderski Value extracted; 882af186f9SJakub Kuderski if (numElems == 1) { 892af186f9SJakub Kuderski extracted = 902af186f9SJakub Kuderski rewriter.create<vector::ExtractOp>(loc, op.getValue(), startIdx); 912af186f9SJakub Kuderski } else { 922af186f9SJakub Kuderski extracted = rewriter.create<vector::ExtractStridedSliceOp>( 932af186f9SJakub Kuderski loc, op.getValue(), /*offsets=*/startIdx, /*sizes=*/numElems, 942af186f9SJakub Kuderski /*strides=*/1); 952af186f9SJakub Kuderski } 962af186f9SJakub Kuderski 972af186f9SJakub Kuderski Value reduce = rewriter.create<gpu::SubgroupReduceOp>( 983d01f0a3SAndrea Faulds loc, extracted, op.getOp(), op.getUniform(), op.getClusterSize(), 993d01f0a3SAndrea Faulds op.getClusterStride()); 1002af186f9SJakub Kuderski if (numElems == 1) { 1012af186f9SJakub Kuderski res = rewriter.create<vector::InsertOp>(loc, reduce, res, startIdx); 1022af186f9SJakub Kuderski continue; 1032af186f9SJakub Kuderski } 1042af186f9SJakub Kuderski 1052af186f9SJakub Kuderski res = rewriter.create<vector::InsertStridedSliceOp>( 1062af186f9SJakub Kuderski loc, reduce, res, /*offsets=*/startIdx, /*strides=*/1); 1072af186f9SJakub Kuderski } 1082af186f9SJakub Kuderski 1092af186f9SJakub Kuderski rewriter.replaceOp(op, res); 1102af186f9SJakub Kuderski return success(); 1112af186f9SJakub Kuderski } 1122af186f9SJakub Kuderski 1132af186f9SJakub Kuderski private: 1142af186f9SJakub Kuderski unsigned maxShuffleBitwidth = 0; 1152af186f9SJakub Kuderski }; 1162af186f9SJakub Kuderski 1172af186f9SJakub Kuderski /// Example: 1182af186f9SJakub Kuderski /// ``` 1192af186f9SJakub Kuderski /// %a = gpu.subgroup_reduce add %x : (vector<1xf32>) -> vector<1xf32> 1202af186f9SJakub Kuderski /// ==> 1212af186f9SJakub Kuderski /// %e0 = vector.extract %x[0] : f32 from vector<1xf32> 1222af186f9SJakub Kuderski /// %r0 = gpu.subgroup_reduce add %e0 : (f32) -> f32 1232af186f9SJakub Kuderski /// %a = vector.broadcast %r0 : f32 to vector<1xf32> 1242af186f9SJakub Kuderski /// ``` 1252af186f9SJakub Kuderski struct ScalarizeSingleElementReduce final 1262af186f9SJakub Kuderski : OpRewritePattern<gpu::SubgroupReduceOp> { 1272af186f9SJakub Kuderski using OpRewritePattern::OpRewritePattern; 1282af186f9SJakub Kuderski 1292af186f9SJakub Kuderski LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op, 1302af186f9SJakub Kuderski PatternRewriter &rewriter) const override { 1312af186f9SJakub Kuderski auto vecTy = dyn_cast<VectorType>(op.getType()); 1322af186f9SJakub Kuderski if (!vecTy || vecTy.getNumElements() != 1) 1332af186f9SJakub Kuderski return rewriter.notifyMatchFailure(op, "not a single-element reduction"); 1342af186f9SJakub Kuderski 1352af186f9SJakub Kuderski assert(vecTy.getRank() == 1 && "Unexpected vector type"); 1362af186f9SJakub Kuderski assert(!vecTy.isScalable() && "Unexpected vector type"); 1372af186f9SJakub Kuderski Location loc = op.getLoc(); 1382af186f9SJakub Kuderski Value extracted = rewriter.create<vector::ExtractOp>(loc, op.getValue(), 0); 1392af186f9SJakub Kuderski Value reduce = rewriter.create<gpu::SubgroupReduceOp>( 1403d01f0a3SAndrea Faulds loc, extracted, op.getOp(), op.getUniform(), op.getClusterSize(), 1413d01f0a3SAndrea Faulds op.getClusterStride()); 1422af186f9SJakub Kuderski rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecTy, reduce); 1432af186f9SJakub Kuderski return success(); 1442af186f9SJakub Kuderski } 1452af186f9SJakub Kuderski }; 1462af186f9SJakub Kuderski 1473d01f0a3SAndrea Faulds struct ClusterInfo { 1483d01f0a3SAndrea Faulds unsigned clusterStride; 1493d01f0a3SAndrea Faulds unsigned clusterSize; 1503d01f0a3SAndrea Faulds unsigned subgroupSize; 1513d01f0a3SAndrea Faulds }; 1523d01f0a3SAndrea Faulds 1533d01f0a3SAndrea Faulds static FailureOr<ClusterInfo> 1543d01f0a3SAndrea Faulds getAndValidateClusterInfo(gpu::SubgroupReduceOp op, unsigned subgroupSize) { 1553d01f0a3SAndrea Faulds assert(llvm::isPowerOf2_32(subgroupSize)); 1563d01f0a3SAndrea Faulds 1573d01f0a3SAndrea Faulds std::optional<uint32_t> clusterSize = op.getClusterSize(); 1583d01f0a3SAndrea Faulds assert(!clusterSize || 1593d01f0a3SAndrea Faulds llvm::isPowerOf2_32(*clusterSize)); // Verifier should've caught this. 1603d01f0a3SAndrea Faulds if (clusterSize && *clusterSize > subgroupSize) 1613d01f0a3SAndrea Faulds return op.emitOpError() 1623d01f0a3SAndrea Faulds << "cluster size " << *clusterSize 1633d01f0a3SAndrea Faulds << " is greater than subgroup size " << subgroupSize; 1643d01f0a3SAndrea Faulds unsigned effectiveClusterSize = clusterSize.value_or(subgroupSize); 1653d01f0a3SAndrea Faulds 1663d01f0a3SAndrea Faulds auto clusterStride = op.getClusterStride(); 1673d01f0a3SAndrea Faulds assert(llvm::isPowerOf2_32(clusterStride)); // Verifier should've caught this. 1683d01f0a3SAndrea Faulds if (clusterStride >= subgroupSize) 1693d01f0a3SAndrea Faulds return op.emitOpError() 1703d01f0a3SAndrea Faulds << "cluster stride " << clusterStride 1713d01f0a3SAndrea Faulds << " is not less than subgroup size " << subgroupSize; 1723d01f0a3SAndrea Faulds 1733d01f0a3SAndrea Faulds return ClusterInfo{clusterStride, effectiveClusterSize, subgroupSize}; 1743d01f0a3SAndrea Faulds } 1753d01f0a3SAndrea Faulds 176c0345b46SJakub Kuderski /// Emits a subgroup reduction using a sequence of shuffles. Uses the `packFn` 177c0345b46SJakub Kuderski /// and `unpackFn` to convert to the native shuffle type and to the reduction 178c0345b46SJakub Kuderski /// type, respectively. For example, with `input` of type `f16`, `packFn` could 179c0345b46SJakub Kuderski /// build ops to cast the value to `i32` to perform shuffles, while `unpackFn` 180c0345b46SJakub Kuderski /// would cast it back to `f16` to perform arithmetic reduction on. Assumes that 1817aa22f01SAndrea Faulds /// the subgroup is `subgroupSize` lanes wide and divides it into clusters of 1823d01f0a3SAndrea Faulds /// `clusterSize` lanes starting at lane 0 with a stride of `clusterStride` for 1833d01f0a3SAndrea Faulds /// lanes within a cluster, reducing all lanes in each cluster in parallel. 1843d01f0a3SAndrea Faulds Value createSubgroupShuffleReduction(OpBuilder &builder, Location loc, 1853d01f0a3SAndrea Faulds Value input, gpu::AllReduceOperation mode, 1863d01f0a3SAndrea Faulds const ClusterInfo &ci, 1873d01f0a3SAndrea Faulds function_ref<Value(Value)> packFn, 1883d01f0a3SAndrea Faulds function_ref<Value(Value)> unpackFn) { 189c0345b46SJakub Kuderski // Lane value always stays in the original type. We use it to perform arith 190c0345b46SJakub Kuderski // reductions. 191c0345b46SJakub Kuderski Value laneVal = input; 192c0345b46SJakub Kuderski // Parallel reduction using butterfly shuffles. 1933d01f0a3SAndrea Faulds for (unsigned i = ci.clusterStride; i < ci.clusterStride * ci.clusterSize; 1943d01f0a3SAndrea Faulds i <<= 1) { 195c0345b46SJakub Kuderski Value shuffled = builder 196c0345b46SJakub Kuderski .create<gpu::ShuffleOp>(loc, packFn(laneVal), i, 1973d01f0a3SAndrea Faulds /*width=*/ci.subgroupSize, 198c0345b46SJakub Kuderski /*mode=*/gpu::ShuffleMode::XOR) 199c0345b46SJakub Kuderski .getShuffleResult(); 200c0345b46SJakub Kuderski laneVal = vector::makeArithReduction(builder, loc, 201c0345b46SJakub Kuderski gpu::convertReductionKind(mode), 202c0345b46SJakub Kuderski laneVal, unpackFn(shuffled)); 203c0345b46SJakub Kuderski assert(laneVal.getType() == input.getType()); 204c0345b46SJakub Kuderski } 205c0345b46SJakub Kuderski 206c0345b46SJakub Kuderski return laneVal; 207c0345b46SJakub Kuderski } 208c0345b46SJakub Kuderski 209c0345b46SJakub Kuderski /// Lowers scalar gpu subgroup reductions to a series of shuffles. 210c0345b46SJakub Kuderski struct ScalarSubgroupReduceToShuffles final 211c0345b46SJakub Kuderski : OpRewritePattern<gpu::SubgroupReduceOp> { 212c0345b46SJakub Kuderski ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize, 213a800ffacSAndrea Faulds unsigned shuffleBitwidth, bool matchClustered, 214c0345b46SJakub Kuderski PatternBenefit benefit) 215c0345b46SJakub Kuderski : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize), 216a800ffacSAndrea Faulds shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {} 217c0345b46SJakub Kuderski 218c0345b46SJakub Kuderski LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op, 219c0345b46SJakub Kuderski PatternRewriter &rewriter) const override { 220a800ffacSAndrea Faulds if (op.getClusterSize().has_value() != matchClustered) { 221a800ffacSAndrea Faulds return rewriter.notifyMatchFailure( 222a800ffacSAndrea Faulds op, llvm::formatv("op is {0}clustered but pattern is configured to " 223a800ffacSAndrea Faulds "only match {1}clustered ops", 224a800ffacSAndrea Faulds matchClustered ? "non-" : "", 225a800ffacSAndrea Faulds matchClustered ? "" : "non-")); 226a800ffacSAndrea Faulds } 227a800ffacSAndrea Faulds 2283d01f0a3SAndrea Faulds auto ci = getAndValidateClusterInfo(op, subgroupSize); 2293d01f0a3SAndrea Faulds if (failed(ci)) 2303d01f0a3SAndrea Faulds return failure(); 2317aa22f01SAndrea Faulds 232c0345b46SJakub Kuderski Type valueTy = op.getType(); 233c0345b46SJakub Kuderski unsigned elemBitwidth = 234c0345b46SJakub Kuderski getElementTypeOrSelf(valueTy).getIntOrFloatBitWidth(); 235c0345b46SJakub Kuderski if (!valueTy.isIntOrFloat() || elemBitwidth > shuffleBitwidth) 236c0345b46SJakub Kuderski return rewriter.notifyMatchFailure( 237c0345b46SJakub Kuderski op, "value type is not a compatible scalar"); 238c0345b46SJakub Kuderski 239c0345b46SJakub Kuderski Location loc = op.getLoc(); 240c0345b46SJakub Kuderski // Since this is already a native shuffle scalar, no packing is necessary. 241c0345b46SJakub Kuderski if (elemBitwidth == shuffleBitwidth) { 242c0345b46SJakub Kuderski auto identityFn = [](Value v) { return v; }; 243c0345b46SJakub Kuderski rewriter.replaceOp(op, createSubgroupShuffleReduction( 2443d01f0a3SAndrea Faulds rewriter, loc, op.getValue(), op.getOp(), *ci, 2453d01f0a3SAndrea Faulds identityFn, identityFn)); 246c0345b46SJakub Kuderski return success(); 247c0345b46SJakub Kuderski } 248c0345b46SJakub Kuderski 249c0345b46SJakub Kuderski auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth); 250c0345b46SJakub Kuderski auto equivIntType = rewriter.getIntegerType(elemBitwidth); 251c0345b46SJakub Kuderski auto packFn = [loc, &rewriter, equivIntType, 252c0345b46SJakub Kuderski shuffleIntType](Value unpackedVal) -> Value { 253c0345b46SJakub Kuderski auto asInt = 254c0345b46SJakub Kuderski rewriter.create<arith::BitcastOp>(loc, equivIntType, unpackedVal); 255c0345b46SJakub Kuderski return rewriter.create<arith::ExtUIOp>(loc, shuffleIntType, asInt); 256c0345b46SJakub Kuderski }; 257c0345b46SJakub Kuderski auto unpackFn = [loc, &rewriter, equivIntType, 258c0345b46SJakub Kuderski valueTy](Value packedVal) -> Value { 259c0345b46SJakub Kuderski auto asInt = 260c0345b46SJakub Kuderski rewriter.create<arith::TruncIOp>(loc, equivIntType, packedVal); 261c0345b46SJakub Kuderski return rewriter.create<arith::BitcastOp>(loc, valueTy, asInt); 262c0345b46SJakub Kuderski }; 263c0345b46SJakub Kuderski 2647aa22f01SAndrea Faulds rewriter.replaceOp( 2657aa22f01SAndrea Faulds op, createSubgroupShuffleReduction(rewriter, loc, op.getValue(), 2663d01f0a3SAndrea Faulds op.getOp(), *ci, packFn, unpackFn)); 267c0345b46SJakub Kuderski return success(); 268c0345b46SJakub Kuderski } 269c0345b46SJakub Kuderski 270c0345b46SJakub Kuderski private: 271c0345b46SJakub Kuderski unsigned subgroupSize = 0; 272c0345b46SJakub Kuderski unsigned shuffleBitwidth = 0; 273a800ffacSAndrea Faulds bool matchClustered = false; 274c0345b46SJakub Kuderski }; 275c0345b46SJakub Kuderski 276c0345b46SJakub Kuderski /// Lowers vector gpu subgroup reductions to a series of shuffles. 277c0345b46SJakub Kuderski struct VectorSubgroupReduceToShuffles final 278c0345b46SJakub Kuderski : OpRewritePattern<gpu::SubgroupReduceOp> { 279c0345b46SJakub Kuderski VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize, 280a800ffacSAndrea Faulds unsigned shuffleBitwidth, bool matchClustered, 281c0345b46SJakub Kuderski PatternBenefit benefit) 282c0345b46SJakub Kuderski : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize), 283a800ffacSAndrea Faulds shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {} 284c0345b46SJakub Kuderski 285c0345b46SJakub Kuderski LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op, 286c0345b46SJakub Kuderski PatternRewriter &rewriter) const override { 287a800ffacSAndrea Faulds if (op.getClusterSize().has_value() != matchClustered) { 288a800ffacSAndrea Faulds return rewriter.notifyMatchFailure( 289a800ffacSAndrea Faulds op, llvm::formatv("op is {0}clustered but pattern is configured to " 290a800ffacSAndrea Faulds "only match {1}clustered ops", 291a800ffacSAndrea Faulds matchClustered ? "non-" : "", 292a800ffacSAndrea Faulds matchClustered ? "" : "non-")); 293a800ffacSAndrea Faulds } 294a800ffacSAndrea Faulds 2953d01f0a3SAndrea Faulds auto ci = getAndValidateClusterInfo(op, subgroupSize); 2963d01f0a3SAndrea Faulds if (failed(ci)) 2973d01f0a3SAndrea Faulds return failure(); 2987aa22f01SAndrea Faulds 299c0345b46SJakub Kuderski auto vecTy = dyn_cast<VectorType>(op.getType()); 300c0345b46SJakub Kuderski if (!vecTy) 301c0345b46SJakub Kuderski return rewriter.notifyMatchFailure(op, "value type is not a vector"); 302c0345b46SJakub Kuderski 303c0345b46SJakub Kuderski unsigned vecBitwidth = 304c0345b46SJakub Kuderski vecTy.getNumElements() * vecTy.getElementTypeBitWidth(); 305c0345b46SJakub Kuderski if (vecBitwidth > shuffleBitwidth) 306c0345b46SJakub Kuderski return rewriter.notifyMatchFailure( 307c0345b46SJakub Kuderski op, 308c0345b46SJakub Kuderski llvm::formatv("vector type bitwidth too large ({0}), cannot lower " 309c0345b46SJakub Kuderski "to shuffles of size {1}", 310c0345b46SJakub Kuderski vecBitwidth, shuffleBitwidth)); 311c0345b46SJakub Kuderski 312c0345b46SJakub Kuderski unsigned elementsPerShuffle = 313c0345b46SJakub Kuderski shuffleBitwidth / vecTy.getElementTypeBitWidth(); 314c0345b46SJakub Kuderski if (elementsPerShuffle * vecTy.getElementTypeBitWidth() != shuffleBitwidth) 315c0345b46SJakub Kuderski return rewriter.notifyMatchFailure( 316c0345b46SJakub Kuderski op, "shuffle bitwidth is not a multiple of the element bitwidth"); 317c0345b46SJakub Kuderski 318c0345b46SJakub Kuderski Location loc = op.getLoc(); 319c0345b46SJakub Kuderski 320c0345b46SJakub Kuderski // If the reduced type is smaller than the native shuffle size, extend it, 321c0345b46SJakub Kuderski // perform the shuffles, and extract at the end. 322c0345b46SJakub Kuderski auto extendedVecTy = VectorType::get( 323c0345b46SJakub Kuderski static_cast<int64_t>(elementsPerShuffle), vecTy.getElementType()); 324c0345b46SJakub Kuderski Value extendedInput = op.getValue(); 325c0345b46SJakub Kuderski if (vecBitwidth < shuffleBitwidth) { 326c0345b46SJakub Kuderski auto zero = rewriter.create<arith::ConstantOp>( 327c0345b46SJakub Kuderski loc, rewriter.getZeroAttr(extendedVecTy)); 328c0345b46SJakub Kuderski extendedInput = rewriter.create<vector::InsertStridedSliceOp>( 329c0345b46SJakub Kuderski loc, extendedInput, zero, /*offsets=*/0, /*strides=*/1); 330c0345b46SJakub Kuderski } 331c0345b46SJakub Kuderski 332c0345b46SJakub Kuderski auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth); 333c0345b46SJakub Kuderski auto shuffleVecType = VectorType::get(1, shuffleIntType); 334c0345b46SJakub Kuderski 335c0345b46SJakub Kuderski auto packFn = [loc, &rewriter, shuffleVecType](Value unpackedVal) -> Value { 336c0345b46SJakub Kuderski auto asIntVec = 337c0345b46SJakub Kuderski rewriter.create<vector::BitCastOp>(loc, shuffleVecType, unpackedVal); 338c0345b46SJakub Kuderski return rewriter.create<vector::ExtractOp>(loc, asIntVec, 0); 339c0345b46SJakub Kuderski }; 340c0345b46SJakub Kuderski auto unpackFn = [loc, &rewriter, shuffleVecType, 341c0345b46SJakub Kuderski extendedVecTy](Value packedVal) -> Value { 342c0345b46SJakub Kuderski auto asIntVec = 343c0345b46SJakub Kuderski rewriter.create<vector::BroadcastOp>(loc, shuffleVecType, packedVal); 344c0345b46SJakub Kuderski return rewriter.create<vector::BitCastOp>(loc, extendedVecTy, asIntVec); 345c0345b46SJakub Kuderski }; 346c0345b46SJakub Kuderski 3473d01f0a3SAndrea Faulds Value res = createSubgroupShuffleReduction( 3483d01f0a3SAndrea Faulds rewriter, loc, extendedInput, op.getOp(), *ci, packFn, unpackFn); 349c0345b46SJakub Kuderski 350c0345b46SJakub Kuderski if (vecBitwidth < shuffleBitwidth) { 351c0345b46SJakub Kuderski res = rewriter.create<vector::ExtractStridedSliceOp>( 352c0345b46SJakub Kuderski loc, res, /*offsets=*/0, /*sizes=*/vecTy.getNumElements(), 353c0345b46SJakub Kuderski /*strides=*/1); 354c0345b46SJakub Kuderski } 355c0345b46SJakub Kuderski 356c0345b46SJakub Kuderski rewriter.replaceOp(op, res); 357c0345b46SJakub Kuderski return success(); 358c0345b46SJakub Kuderski } 359c0345b46SJakub Kuderski 360c0345b46SJakub Kuderski private: 361c0345b46SJakub Kuderski unsigned subgroupSize = 0; 362c0345b46SJakub Kuderski unsigned shuffleBitwidth = 0; 363a800ffacSAndrea Faulds bool matchClustered = false; 364c0345b46SJakub Kuderski }; 3652af186f9SJakub Kuderski } // namespace 3662af186f9SJakub Kuderski 367fd26f844SAndrea Faulds void mlir::populateGpuBreakDownSubgroupReducePatterns( 3682af186f9SJakub Kuderski RewritePatternSet &patterns, unsigned maxShuffleBitwidth, 3692af186f9SJakub Kuderski PatternBenefit benefit) { 3702af186f9SJakub Kuderski patterns.add<BreakDownSubgroupReduce>(patterns.getContext(), 3712af186f9SJakub Kuderski maxShuffleBitwidth, benefit); 3722af186f9SJakub Kuderski patterns.add<ScalarizeSingleElementReduce>(patterns.getContext(), benefit); 3732af186f9SJakub Kuderski } 374c0345b46SJakub Kuderski 375fd26f844SAndrea Faulds void mlir::populateGpuLowerSubgroupReduceToShufflePatterns( 376c0345b46SJakub Kuderski RewritePatternSet &patterns, unsigned subgroupSize, 377c0345b46SJakub Kuderski unsigned shuffleBitwidth, PatternBenefit benefit) { 378c0345b46SJakub Kuderski patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>( 379a800ffacSAndrea Faulds patterns.getContext(), subgroupSize, shuffleBitwidth, 380a800ffacSAndrea Faulds /*matchClustered=*/false, benefit); 381a800ffacSAndrea Faulds } 382a800ffacSAndrea Faulds 383a800ffacSAndrea Faulds void mlir::populateGpuLowerClusteredSubgroupReduceToShufflePatterns( 384a800ffacSAndrea Faulds RewritePatternSet &patterns, unsigned subgroupSize, 385a800ffacSAndrea Faulds unsigned shuffleBitwidth, PatternBenefit benefit) { 386a800ffacSAndrea Faulds patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>( 387a800ffacSAndrea Faulds patterns.getContext(), subgroupSize, shuffleBitwidth, 388a800ffacSAndrea Faulds /*matchClustered=*/true, benefit); 389c0345b46SJakub Kuderski } 390