//===- SubgroupReduceLowering.cpp - subgroup_reduce lowering patterns -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implements gradual lowering of `gpu.subgroup_reduce` ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <cstdint>

using namespace mlir;

namespace {

/// Example, assumes `maxShuffleBitwidth` equal to 32:
/// ```
/// %a = gpu.subgroup_reduce add %x : (vector<3xf16>) -> vector<3xf16>
///  ==>
/// %v0 = arith.constant dense<0.0> : vector<3xf16>
/// %e0 = vector.extract_strided_slice %x
///   {offsets = [0], sizes = [2], strides = [1]} : vector<3xf16> to vector<2xf16>
/// %r0 = gpu.subgroup_reduce add %e0 : (vector<2xf16>) -> vector<2xf16>
/// %v1 = vector.insert_strided_slice %r0, %v0
///   {offsets = [0], strides = [1]} : vector<2xf16> into vector<3xf16>
/// %e1 = vector.extract %x[2] : f16 from vector<3xf16>
/// %r1 = gpu.subgroup_reduce add %e1 : (f16) -> f16
/// %a  = vector.insert %r1, %v1[2] : f16 into vector<3xf16>
/// ```
struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
  BreakDownSubgroupReduce(MLIRContext *ctx, unsigned maxShuffleBitwidth,
                          PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), maxShuffleBitwidth(maxShuffleBitwidth) {
  }

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy || vecTy.getNumElements() < 2)
      return rewriter.notifyMatchFailure(op, "not a multi-element reduction");

    assert(vecTy.getRank() == 1 && "Unexpected vector type");
    assert(!vecTy.isScalable() && "Unexpected vector type");

    Type elemTy = vecTy.getElementType();
    unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
    if (elemBitwidth >= maxShuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("element type too large ({0}), cannot break down "
                            "into vectors of bitwidth {1} or less",
                            elemBitwidth, maxShuffleBitwidth));

    unsigned elementsPerShuffle = maxShuffleBitwidth / elemBitwidth;
    assert(elementsPerShuffle >= 1);

    unsigned numNewReductions =
        llvm::divideCeil(vecTy.getNumElements(), elementsPerShuffle);
    assert(numNewReductions >= 1);
    if (numNewReductions == 1)
      return rewriter.notifyMatchFailure(op, "nothing to break down");

    Location loc = op.getLoc();
    Value res =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(vecTy));

    for (unsigned i = 0; i != numNewReductions; ++i) {
      int64_t startIdx = i * elementsPerShuffle;
      int64_t endIdx =
          std::min(startIdx + elementsPerShuffle, vecTy.getNumElements());
      int64_t numElems = endIdx - startIdx;

      Value extracted;
      if (numElems == 1) {
        extracted =
            rewriter.create<vector::ExtractOp>(loc, op.getValue(), startIdx);
      } else {
        extracted = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, op.getValue(), /*offsets=*/startIdx, /*sizes=*/numElems,
            /*strides=*/1);
      }

      Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
          loc, extracted, op.getOp(), op.getUniform(), op.getClusterSize(),
          op.getClusterStride());
      if (numElems == 1) {
        res = rewriter.create<vector::InsertOp>(loc, reduce, res, startIdx);
        continue;
      }

      res = rewriter.create<vector::InsertStridedSliceOp>(
          loc, reduce, res, /*offsets=*/startIdx, /*strides=*/1);
    }

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned maxShuffleBitwidth = 0;
};

/// Example:
/// ```
/// %a = gpu.subgroup_reduce add %x : (vector<1xf32>) -> vector<1xf32>
///  ==>
/// %e0 = vector.extract %x[0] : f32 from vector<1xf32>
/// %r0 = gpu.subgroup_reduce add %e0 : (f32) -> f32
/// %a = vector.broadcast %r0 : f32 to vector<1xf32>
/// ```
struct ScalarizeSingleElementReduce final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy || vecTy.getNumElements() != 1)
      return rewriter.notifyMatchFailure(op, "not a single-element reduction");

    assert(vecTy.getRank() == 1 && "Unexpected vector type");
    assert(!vecTy.isScalable() && "Unexpected vector type");
    Location loc = op.getLoc();
    Value extracted = rewriter.create<vector::ExtractOp>(loc, op.getValue(), 0);
    Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
        loc, extracted, op.getOp(), op.getUniform(), op.getClusterSize(),
        op.getClusterStride());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecTy, reduce);
    return success();
  }
};

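/// Reduction cluster configuration: `clusterSize` lanes spaced `clusterStride`
/// apart within a subgroup of `subgroupSize` lanes. For example (an
/// illustrative note, matching the XOR butterfly lowering below), with
/// `clusterSize` 4 and `clusterStride` 2, lane `l` is reduced together with
/// lanes `l ^ 2`, `l ^ 4`, and `l ^ 6`, giving clusters such as {0, 2, 4, 6}
/// and {1, 3, 5, 7}.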
struct ClusterInfo {
  unsigned clusterStride;
  unsigned clusterSize;
  unsigned subgroupSize;
};

static FailureOr<ClusterInfo>
getAndValidateClusterInfo(gpu::SubgroupReduceOp op, unsigned subgroupSize) {
  assert(llvm::isPowerOf2_32(subgroupSize));

  std::optional<uint32_t> clusterSize = op.getClusterSize();
  assert(!clusterSize ||
         llvm::isPowerOf2_32(*clusterSize)); // Verifier should've caught this.
  if (clusterSize && *clusterSize > subgroupSize)
    return op.emitOpError()
           << "cluster size " << *clusterSize
           << " is greater than subgroup size " << subgroupSize;
  unsigned effectiveClusterSize = clusterSize.value_or(subgroupSize);

  auto clusterStride = op.getClusterStride();
  assert(llvm::isPowerOf2_32(clusterStride)); // Verifier should've caught this.
  if (clusterStride >= subgroupSize)
    return op.emitOpError()
           << "cluster stride " << clusterStride
           << " is not less than subgroup size " << subgroupSize;

  return ClusterInfo{clusterStride, effectiveClusterSize, subgroupSize};
}

/// Emits a subgroup reduction using a sequence of shuffles. Uses the `packFn`
/// and `unpackFn` to convert to the native shuffle type and back to the
/// reduction type, respectively. For example, with `input` of type `f16`,
/// `packFn` could build ops that cast the value to `i32` for the shuffles,
/// while `unpackFn` would cast it back to `f16` to perform the arithmetic
/// reduction on. Assumes that the subgroup is `subgroupSize` lanes wide and
/// divides it into clusters of `clusterSize` lanes starting at lane 0, with a
/// stride of `clusterStride` between the lanes of a cluster, reducing all
/// lanes in each cluster in parallel.
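///
/// For instance (an illustrative sketch, not the exact emitted IR), a scalar
/// `f32` add reduction with `clusterStride` 1, `clusterSize` 4, and
/// `subgroupSize` 32, using identity `packFn`/`unpackFn`, expands roughly to:
/// ```
/// %s0, %valid0 = gpu.shuffle xor %x, %c1, %c32 : f32
/// %a0 = arith.addf %x, %s0 : f32
/// %s1, %valid1 = gpu.shuffle xor %a0, %c2, %c32 : f32
/// %a1 = arith.addf %a0, %s1 : f32
/// ```
/// where `%c1`, `%c2`, and `%c32` are `i32` constants for the XOR masks and
/// the shuffle width.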
Value createSubgroupShuffleReduction(OpBuilder &builder, Location loc,
                                     Value input, gpu::AllReduceOperation mode,
                                     const ClusterInfo &ci,
                                     function_ref<Value(Value)> packFn,
                                     function_ref<Value(Value)> unpackFn) {
  // Lane value always stays in the original type. We use it to perform arith
  // reductions.
  Value laneVal = input;
  // Parallel reduction using butterfly shuffles.
  for (unsigned i = ci.clusterStride; i < ci.clusterStride * ci.clusterSize;
       i <<= 1) {
    Value shuffled = builder
                         .create<gpu::ShuffleOp>(loc, packFn(laneVal), i,
                                                 /*width=*/ci.subgroupSize,
                                                 /*mode=*/gpu::ShuffleMode::XOR)
                         .getShuffleResult();
    laneVal = vector::makeArithReduction(builder, loc,
                                         gpu::convertReductionKind(mode),
                                         laneVal, unpackFn(shuffled));
    assert(laneVal.getType() == input.getType());
  }

  return laneVal;
}

/// Lowers scalar gpu subgroup reductions to a series of shuffles.
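///
/// When the value's bitwidth is smaller than `shuffleBitwidth`, the value is
/// packed into a shuffle-sized integer around each shuffle. A rough sketch
/// for an `f16` add reduction with 32-bit shuffles (illustrative; value names
/// are not taken from the actual output):
/// ```
/// %p0 = arith.bitcast %val : f16 to i16
/// %p1 = arith.extui %p0 : i16 to i32
/// %s, %valid = gpu.shuffle xor %p1, %offset, %width : i32
/// %u0 = arith.trunci %s : i32 to i16
/// %u1 = arith.bitcast %u0 : i16 to f16
/// %acc = arith.addf %val, %u1 : f16
/// ```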
struct ScalarSubgroupReduceToShuffles final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
                                 unsigned shuffleBitwidth, bool matchClustered,
                                 PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getClusterSize().has_value() != matchClustered) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("op is {0}clustered but pattern is configured to "
                            "only match {1}clustered ops",
                            matchClustered ? "non-" : "",
                            matchClustered ? "" : "non-"));
    }

    auto ci = getAndValidateClusterInfo(op, subgroupSize);
    if (failed(ci))
      return failure();

    Type valueTy = op.getType();
    unsigned elemBitwidth =
        getElementTypeOrSelf(valueTy).getIntOrFloatBitWidth();
    if (!valueTy.isIntOrFloat() || elemBitwidth > shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, "value type is not a compatible scalar");

    Location loc = op.getLoc();
    // Since this is already a native shuffle scalar, no packing is necessary.
    if (elemBitwidth == shuffleBitwidth) {
      auto identityFn = [](Value v) { return v; };
      rewriter.replaceOp(op, createSubgroupShuffleReduction(
                                 rewriter, loc, op.getValue(), op.getOp(), *ci,
                                 identityFn, identityFn));
      return success();
    }

    auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
    auto equivIntType = rewriter.getIntegerType(elemBitwidth);
    auto packFn = [loc, &rewriter, equivIntType,
                   shuffleIntType](Value unpackedVal) -> Value {
      auto asInt =
          rewriter.create<arith::BitcastOp>(loc, equivIntType, unpackedVal);
      return rewriter.create<arith::ExtUIOp>(loc, shuffleIntType, asInt);
    };
    auto unpackFn = [loc, &rewriter, equivIntType,
                     valueTy](Value packedVal) -> Value {
      auto asInt =
          rewriter.create<arith::TruncIOp>(loc, equivIntType, packedVal);
      return rewriter.create<arith::BitcastOp>(loc, valueTy, asInt);
    };

    rewriter.replaceOp(
        op, createSubgroupShuffleReduction(rewriter, loc, op.getValue(),
                                           op.getOp(), *ci, packFn, unpackFn));
    return success();
  }

private:
  unsigned subgroupSize = 0;
  unsigned shuffleBitwidth = 0;
  bool matchClustered = false;
};

/// Lowers vector gpu subgroup reductions to a series of shuffles.
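///
/// The vector operand is bitcast to a single shuffle-sized integer around
/// each shuffle; inputs narrower than the shuffle width are first extended
/// with zero padding and the result is re-extracted at the end. A rough
/// sketch for a `vector<2xf16>` add reduction with 32-bit shuffles
/// (illustrative; value names are not taken from the actual output):
/// ```
/// %iv = vector.bitcast %val : vector<2xf16> to vector<1xi32>
/// %i = vector.extract %iv[0] : i32 from vector<1xi32>
/// %s, %valid = gpu.shuffle xor %i, %offset, %width : i32
/// %bv = vector.broadcast %s : i32 to vector<1xi32>
/// %v = vector.bitcast %bv : vector<1xi32> to vector<2xf16>
/// %acc = arith.addf %val, %v : vector<2xf16>
/// ```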
struct VectorSubgroupReduceToShuffles final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
                                 unsigned shuffleBitwidth, bool matchClustered,
                                 PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {}

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getClusterSize().has_value() != matchClustered) {
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("op is {0}clustered but pattern is configured to "
                            "only match {1}clustered ops",
                            matchClustered ? "non-" : "",
                            matchClustered ? "" : "non-"));
    }

    auto ci = getAndValidateClusterInfo(op, subgroupSize);
    if (failed(ci))
      return failure();

    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy)
      return rewriter.notifyMatchFailure(op, "value type is not a vector");

    unsigned vecBitwidth =
        vecTy.getNumElements() * vecTy.getElementTypeBitWidth();
    if (vecBitwidth > shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op,
          llvm::formatv("vector type bitwidth too large ({0}), cannot lower "
                        "to shuffles of size {1}",
                        vecBitwidth, shuffleBitwidth));

    unsigned elementsPerShuffle =
        shuffleBitwidth / vecTy.getElementTypeBitWidth();
    if (elementsPerShuffle * vecTy.getElementTypeBitWidth() != shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, "shuffle bitwidth is not a multiple of the element bitwidth");

    Location loc = op.getLoc();

    // If the reduced type is smaller than the native shuffle size, extend it,
    // perform the shuffles, and extract at the end.
    auto extendedVecTy = VectorType::get(
        static_cast<int64_t>(elementsPerShuffle), vecTy.getElementType());
    Value extendedInput = op.getValue();
    if (vecBitwidth < shuffleBitwidth) {
      auto zero = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getZeroAttr(extendedVecTy));
      extendedInput = rewriter.create<vector::InsertStridedSliceOp>(
          loc, extendedInput, zero, /*offsets=*/0, /*strides=*/1);
    }

    auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
    auto shuffleVecType = VectorType::get(1, shuffleIntType);

    auto packFn = [loc, &rewriter, shuffleVecType](Value unpackedVal) -> Value {
      auto asIntVec =
          rewriter.create<vector::BitCastOp>(loc, shuffleVecType, unpackedVal);
      return rewriter.create<vector::ExtractOp>(loc, asIntVec, 0);
    };
    auto unpackFn = [loc, &rewriter, shuffleVecType,
                     extendedVecTy](Value packedVal) -> Value {
      auto asIntVec =
          rewriter.create<vector::BroadcastOp>(loc, shuffleVecType, packedVal);
      return rewriter.create<vector::BitCastOp>(loc, extendedVecTy, asIntVec);
    };

    Value res = createSubgroupShuffleReduction(
        rewriter, loc, extendedInput, op.getOp(), *ci, packFn, unpackFn);

    if (vecBitwidth < shuffleBitwidth) {
      res = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, res, /*offsets=*/0, /*sizes=*/vecTy.getNumElements(),
          /*strides=*/1);
    }

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned subgroupSize = 0;
  unsigned shuffleBitwidth = 0;
  bool matchClustered = false;
};
} // namespace

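// Note: a typical driver would gather these patterns into a RewritePatternSet
// and apply them with a greedy rewrite, along the lines of the following
// sketch (illustrative only, not the actual pass implementation):
//
//   RewritePatternSet patterns(ctx);
//   populateGpuBreakDownSubgroupReducePatterns(patterns,
//                                              /*maxShuffleBitwidth=*/32,
//                                              PatternBenefit(2));
//   populateGpuLowerSubgroupReduceToShufflePatterns(
//       patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32,
//       PatternBenefit(1));
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));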
void mlir::populateGpuBreakDownSubgroupReducePatterns(
    RewritePatternSet &patterns, unsigned maxShuffleBitwidth,
    PatternBenefit benefit) {
  patterns.add<BreakDownSubgroupReduce>(patterns.getContext(),
                                        maxShuffleBitwidth, benefit);
  patterns.add<ScalarizeSingleElementReduce>(patterns.getContext(), benefit);
}

void mlir::populateGpuLowerSubgroupReduceToShufflePatterns(
    RewritePatternSet &patterns, unsigned subgroupSize,
    unsigned shuffleBitwidth, PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
      patterns.getContext(), subgroupSize, shuffleBitwidth,
      /*matchClustered=*/false, benefit);
}

void mlir::populateGpuLowerClusteredSubgroupReduceToShufflePatterns(
    RewritePatternSet &patterns, unsigned subgroupSize,
    unsigned shuffleBitwidth, PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
      patterns.getContext(), subgroupSize, shuffleBitwidth,
      /*matchClustered=*/true, benefit);
}