//===- SubgroupReduceLowering.cpp - subgroup_reduce lowering patterns -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implements gradual lowering of `gpu.subgroup_reduce` ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <cstdint>

using namespace mlir;

namespace {

/// Example, assuming `maxShuffleBitwidth` equal to 32:
/// ```
/// %a = gpu.subgroup_reduce add %x : (vector<3xf16>) -> vector<3xf16>
///  ==>
/// %v0 = arith.constant dense<0.0> : vector<3xf16>
/// %e0 = vector.extract_strided_slice %x
///   {offsets = [0], sizes = [2], strides = [1]} : vector<3xf16> to vector<2xf16>
/// %r0 = gpu.subgroup_reduce add %e0 : (vector<2xf16>) -> vector<2xf16>
/// %v1 = vector.insert_strided_slice %r0, %v0
///   {offsets = [0], strides = [1]} : vector<2xf16> into vector<3xf16>
/// %e1 = vector.extract %x[2] : f16 from vector<3xf16>
/// %r1 = gpu.subgroup_reduce add %e1 : (f16) -> f16
/// %a  = vector.insert %r1, %v1[2] : f16 into vector<3xf16>
/// ```
struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
  BreakDownSubgroupReduce(MLIRContext *ctx, unsigned maxShuffleBitwidth,
                          PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), maxShuffleBitwidth(maxShuffleBitwidth) {
  }

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy || vecTy.getNumElements() < 2)
      return rewriter.notifyMatchFailure(op, "not a multi-element reduction");

    assert(vecTy.getRank() == 1 && "Unexpected vector type");
    assert(!vecTy.isScalable() && "Unexpected vector type");

    Type elemTy = vecTy.getElementType();
    unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
    if (elemBitwidth >= maxShuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("element type too large ({0}), cannot break down "
                            "into vectors of bitwidth {1} or less",
                            elemBitwidth, maxShuffleBitwidth));

    unsigned elementsPerShuffle = maxShuffleBitwidth / elemBitwidth;
    assert(elementsPerShuffle >= 1);

    unsigned numNewReductions =
        llvm::divideCeil(vecTy.getNumElements(), elementsPerShuffle);
    assert(numNewReductions >= 1);
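    // E.g., for the vector<3xf16> example above with a max shuffle bitwidth
    // of 32: elementsPerShuffle = 2 and numNewReductions = ceil(3 / 2) = 2.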
    if (numNewReductions == 1)
      return rewriter.notifyMatchFailure(op, "nothing to break down");

    Location loc = op.getLoc();
    Value res =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(vecTy));

    for (unsigned i = 0; i != numNewReductions; ++i) {
      int64_t startIdx = i * elementsPerShuffle;
      int64_t endIdx =
          std::min(startIdx + elementsPerShuffle, vecTy.getNumElements());
      int64_t numElems = endIdx - startIdx;

      Value extracted;
      if (numElems == 1) {
        extracted =
            rewriter.create<vector::ExtractOp>(loc, op.getValue(), startIdx);
      } else {
        extracted = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, op.getValue(), /*offsets=*/startIdx, /*sizes=*/numElems,
            /*strides=*/1);
      }

      Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
          loc, extracted, op.getOp(), op.getUniform(), op.getClusterSize(),
          op.getClusterStride());
      if (numElems == 1) {
        res = rewriter.create<vector::InsertOp>(loc, reduce, res, startIdx);
        continue;
      }

      res = rewriter.create<vector::InsertStridedSliceOp>(
          loc, reduce, res, /*offsets=*/startIdx, /*strides=*/1);
    }

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned maxShuffleBitwidth = 0;
};

/// Example:
/// ```
/// %a = gpu.subgroup_reduce add %x : (vector<1xf32>) -> vector<1xf32>
///  ==>
/// %e0 = vector.extract %x[0] : f32 from vector<1xf32>
/// %r0 = gpu.subgroup_reduce add %e0 : (f32) -> f32
/// %a = vector.broadcast %r0 : f32 to vector<1xf32>
/// ```
struct ScalarizeSingleElementReduce final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy || vecTy.getNumElements() != 1)
      return rewriter.notifyMatchFailure(op, "not a single-element reduction");

    assert(vecTy.getRank() == 1 && "Unexpected vector type");
    assert(!vecTy.isScalable() && "Unexpected vector type");
    Location loc = op.getLoc();
    Value extracted = rewriter.create<vector::ExtractOp>(loc, op.getValue(), 0);
    Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
        loc, extracted, op.getOp(), op.getUniform(), op.getClusterSize(),
        op.getClusterStride());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecTy, reduce);
    return success();
  }
};

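/// The shape of a (possibly clustered) subgroup reduction: the stride between
/// lanes within a cluster, the number of lanes per cluster, and the total
/// number of lanes in the subgroup.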
struct ClusterInfo {
  unsigned clusterStride;
  unsigned clusterSize;
  unsigned subgroupSize;
};

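/// Returns the cluster parameters of `op`, with the cluster size defaulting
/// to the full subgroup, or failure if the cluster does not fit within a
/// subgroup of `subgroupSize` lanes.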
static FailureOr<ClusterInfo>
getAndValidateClusterInfo(gpu::SubgroupReduceOp op, unsigned subgroupSize) {
  assert(llvm::isPowerOf2_32(subgroupSize));

  std::optional<uint32_t> clusterSize = op.getClusterSize();
  assert(!clusterSize ||
         llvm::isPowerOf2_32(*clusterSize)); // Verifier should've caught this.
  if (clusterSize && *clusterSize > subgroupSize)
    return op.emitOpError()
           << "cluster size " << *clusterSize
           << " is greater than subgroup size " << subgroupSize;
  unsigned effectiveClusterSize = clusterSize.value_or(subgroupSize);

  auto clusterStride = op.getClusterStride();
  assert(llvm::isPowerOf2_32(clusterStride)); // Verifier should've caught this.
  if (clusterStride >= subgroupSize)
    return op.emitOpError()
           << "cluster stride " << clusterStride
           << " is not less than subgroup size " << subgroupSize;

  return ClusterInfo{clusterStride, effectiveClusterSize, subgroupSize};
}

/// Emits a subgroup reduction using a sequence of shuffles. Uses the `packFn`
/// and `unpackFn` to convert to the native shuffle type and to the reduction
/// type, respectively. For example, with `input` of type `f16`, `packFn` could
/// build ops to cast the value to `i32` to perform shuffles, while `unpackFn`
/// would cast it back to `f16` to perform the arithmetic reduction on. Assumes
/// the subgroup is `subgroupSize` lanes wide and divides it into clusters of
/// `clusterSize` lanes starting at lane 0, with a stride of `clusterStride`
/// between lanes within a cluster, reducing all lanes in each cluster in
/// parallel.
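/// For example, with `clusterStride` = 1 and `clusterSize` = 8, the butterfly
/// loop below issues XOR shuffles with offsets 1, 2, and 4; after these
/// log2(8) = 3 shuffle-and-reduce steps, every lane in the cluster holds the
/// full result. A sketch of the emitted IR for an `f32` `add` reduction,
/// ignoring `packFn`/`unpackFn` (identity in this case):
/// ```
/// %s0, %v0 = gpu.shuffle xor %x, %c1, %width : f32
/// %x0 = arith.addf %x, %s0 : f32
/// %s1, %v1 = gpu.shuffle xor %x0, %c2, %width : f32
/// %x1 = arith.addf %x0, %s1 : f32
/// %s2, %v2 = gpu.shuffle xor %x1, %c4, %width : f32
/// %r = arith.addf %x1, %s2 : f32
/// ```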
Value createSubgroupShuffleReduction(OpBuilder &builder, Location loc,
                                     Value input, gpu::AllReduceOperation mode,
                                     const ClusterInfo &ci,
                                     function_ref<Value(Value)> packFn,
                                     function_ref<Value(Value)> unpackFn) {
  // Lane value always stays in the original type. We use it to perform arith
  // reductions.
  Value laneVal = input;
  // Parallel reduction using butterfly shuffles.
  for (unsigned i = ci.clusterStride; i < ci.clusterStride * ci.clusterSize;
       i <<= 1) {
    Value shuffled = builder
                         .create<gpu::ShuffleOp>(loc, packFn(laneVal), i,
                                                 /*width=*/ci.subgroupSize,
                                                 /*mode=*/gpu::ShuffleMode::XOR)
                         .getShuffleResult();
    laneVal = vector::makeArithReduction(builder, loc,
                                         gpu::convertReductionKind(mode),
                                         laneVal, unpackFn(shuffled));
    assert(laneVal.getType() == input.getType());
  }

  return laneVal;
}

/// Lowers scalar gpu subgroup reductions to a series of shuffles.
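/// When the element type is narrower than the shuffle bitwidth, each lane
/// value is bitcast to an equally wide integer and zero-extended before the
/// shuffle, then truncated and bitcast back for the arithmetic reduction.
/// A sketch of one pack/unpack round trip for an `f16` value with 32-bit
/// shuffles:
/// ```
/// %i16 = arith.bitcast %laneVal : f16 to i16
/// %i32 = arith.extui %i16 : i16 to i32
/// %shfl, %valid = gpu.shuffle xor %i32, %offset, %width : i32
/// %t16 = arith.trunci %shfl : i32 to i16
/// %f16 = arith.bitcast %t16 : i16 to f16
/// ```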
struct ScalarSubgroupReduceToShuffles final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
                                 unsigned shuffleBitwidth, bool matchClustered,
                                 PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getClusterSize().has_value() != matchClustered) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("op is {0}clustered but pattern is configured to "
                            "only match {1}clustered ops",
                            matchClustered ? "non-" : "",
                            matchClustered ? "" : "non-"));
    }

    auto ci = getAndValidateClusterInfo(op, subgroupSize);
    if (failed(ci))
      return failure();

    Type valueTy = op.getType();
    unsigned elemBitwidth =
        getElementTypeOrSelf(valueTy).getIntOrFloatBitWidth();
    if (!valueTy.isIntOrFloat() || elemBitwidth > shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, "value type is not a compatible scalar");

    Location loc = op.getLoc();
    // Since this is already a native shuffle scalar, no packing is necessary.
    if (elemBitwidth == shuffleBitwidth) {
      auto identityFn = [](Value v) { return v; };
      rewriter.replaceOp(op, createSubgroupShuffleReduction(
                                 rewriter, loc, op.getValue(), op.getOp(), *ci,
                                 identityFn, identityFn));
      return success();
    }

    auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
    auto equivIntType = rewriter.getIntegerType(elemBitwidth);
    auto packFn = [loc, &rewriter, equivIntType,
                   shuffleIntType](Value unpackedVal) -> Value {
      auto asInt =
          rewriter.create<arith::BitcastOp>(loc, equivIntType, unpackedVal);
      return rewriter.create<arith::ExtUIOp>(loc, shuffleIntType, asInt);
    };
    auto unpackFn = [loc, &rewriter, equivIntType,
                     valueTy](Value packedVal) -> Value {
      auto asInt =
          rewriter.create<arith::TruncIOp>(loc, equivIntType, packedVal);
      return rewriter.create<arith::BitcastOp>(loc, valueTy, asInt);
    };

    rewriter.replaceOp(
        op, createSubgroupShuffleReduction(rewriter, loc, op.getValue(),
                                           op.getOp(), *ci, packFn, unpackFn));
    return success();
  }

private:
  unsigned subgroupSize = 0;
  unsigned shuffleBitwidth = 0;
  bool matchClustered = false;
};

/// Lowers vector gpu subgroup reductions to a series of shuffles.
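/// The whole vector is packed into a single shuffle-sized integer (after
/// zero-padding vectors narrower than the shuffle bitwidth). A sketch of one
/// pack/unpack round trip for a `vector<2xf16>` value with 32-bit shuffles:
/// ```
/// %iv = vector.bitcast %laneVal : vector<2xf16> to vector<1xi32>
/// %i = vector.extract %iv[0] : i32 from vector<1xi32>
/// %shfl, %valid = gpu.shuffle xor %i, %offset, %width : i32
/// %bv = vector.broadcast %shfl : i32 to vector<1xi32>
/// %v = vector.bitcast %bv : vector<1xi32> to vector<2xf16>
/// ```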
struct VectorSubgroupReduceToShuffles final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
                                 unsigned shuffleBitwidth, bool matchClustered,
                                 PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getClusterSize().has_value() != matchClustered) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("op is {0}clustered but pattern is configured to "
                            "only match {1}clustered ops",
                            matchClustered ? "non-" : "",
                            matchClustered ? "" : "non-"));
    }

    auto ci = getAndValidateClusterInfo(op, subgroupSize);
    if (failed(ci))
      return failure();

    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy)
      return rewriter.notifyMatchFailure(op, "value type is not a vector");

    unsigned vecBitwidth =
        vecTy.getNumElements() * vecTy.getElementTypeBitWidth();
    if (vecBitwidth > shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op,
          llvm::formatv("vector type bitwidth too large ({0}), cannot lower "
                        "to shuffles of size {1}",
                        vecBitwidth, shuffleBitwidth));

    unsigned elementsPerShuffle =
        shuffleBitwidth / vecTy.getElementTypeBitWidth();
    if (elementsPerShuffle * vecTy.getElementTypeBitWidth() != shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, "shuffle bitwidth is not a multiple of the element bitwidth");

    Location loc = op.getLoc();

    // If the reduced type is smaller than the native shuffle size, extend it,
    // perform the shuffles, and extract at the end.
    auto extendedVecTy = VectorType::get(
        static_cast<int64_t>(elementsPerShuffle), vecTy.getElementType());
    Value extendedInput = op.getValue();
    if (vecBitwidth < shuffleBitwidth) {
      auto zero = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getZeroAttr(extendedVecTy));
      extendedInput = rewriter.create<vector::InsertStridedSliceOp>(
          loc, extendedInput, zero, /*offsets=*/0, /*strides=*/1);
    }

    auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
    auto shuffleVecType = VectorType::get(1, shuffleIntType);

    auto packFn = [loc, &rewriter, shuffleVecType](Value unpackedVal) -> Value {
      auto asIntVec =
          rewriter.create<vector::BitCastOp>(loc, shuffleVecType, unpackedVal);
      return rewriter.create<vector::ExtractOp>(loc, asIntVec, 0);
    };
    auto unpackFn = [loc, &rewriter, shuffleVecType,
                     extendedVecTy](Value packedVal) -> Value {
      auto asIntVec =
          rewriter.create<vector::BroadcastOp>(loc, shuffleVecType, packedVal);
      return rewriter.create<vector::BitCastOp>(loc, extendedVecTy, asIntVec);
    };

    Value res = createSubgroupShuffleReduction(
        rewriter, loc, extendedInput, op.getOp(), *ci, packFn, unpackFn);

    if (vecBitwidth < shuffleBitwidth) {
      res = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, res, /*offsets=*/0, /*sizes=*/vecTy.getNumElements(),
          /*strides=*/1);
    }

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned subgroupSize = 0;
  unsigned shuffleBitwidth = 0;
  bool matchClustered = false;
};
} // namespace

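// A minimal usage sketch for the entry points below (hypothetical caller
// code; actual pass wiring and default arguments live elsewhere):
//   RewritePatternSet patterns(ctx);
//   populateGpuBreakDownSubgroupReducePatterns(patterns,
//                                              /*maxShuffleBitwidth=*/32,
//                                              /*benefit=*/2);
//   populateGpuLowerSubgroupReduceToShufflePatterns(
//       patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32, /*benefit=*/1);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));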
void mlir::populateGpuBreakDownSubgroupReducePatterns(
    RewritePatternSet &patterns, unsigned maxShuffleBitwidth,
    PatternBenefit benefit) {
  patterns.add<BreakDownSubgroupReduce>(patterns.getContext(),
                                        maxShuffleBitwidth, benefit);
  patterns.add<ScalarizeSingleElementReduce>(patterns.getContext(), benefit);
}

void mlir::populateGpuLowerSubgroupReduceToShufflePatterns(
    RewritePatternSet &patterns, unsigned subgroupSize,
    unsigned shuffleBitwidth, PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
      patterns.getContext(), subgroupSize, shuffleBitwidth,
      /*matchClustered=*/false, benefit);
}

void mlir::populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
    RewritePatternSet &patterns, unsigned subgroupSize,
    unsigned shuffleBitwidth, PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
      patterns.getContext(), subgroupSize, shuffleBitwidth,
      /*matchClustered=*/true, benefit);
}