//===- LowerVectorMultiReduction.cpp - Lower `vector.multi_reduction` op --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.multi_reduction' operation.
//
//===----------------------------------------------------------------------===//
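//
// For illustration, the op being lowered has the following general form (the
// shapes and value names are made up, not taken from a particular test):
//
//   %1 = vector.multi_reduction <add>, %0, %acc [1, 3]
//     : vector<4x8x16x32xf32> to vector<4x16xf32>
//
// i.e. dimensions 1 and 3 of %0 are reduced with <add> into an accumulator of
// the remaining (parallel) shape vector<4x16xf32>.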

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/Passes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace vector {
#define GEN_PASS_DEF_LOWERVECTORMULTIREDUCTION
#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
} // namespace vector
} // namespace mlir

#define DEBUG_TYPE "vector-multi-reduction"

using namespace mlir;

namespace {
/// This file implements the following transformations as composable atomic
/// patterns.

/// Converts vector.multi_reduction into inner-most/outer-most reduction form
/// by using vector.transpose.
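///
/// For example, under the InnerReduction strategy (an illustrative sketch;
/// the value names and shapes below are made up):
///
///   %0 = vector.multi_reduction <add>, %src, %acc [0]
///     : vector<2x4xf32> to vector<4xf32>
///
/// is rewritten into
///
///   %1 = vector.transpose %src, [1, 0] : vector<2x4xf32> to vector<4x2xf32>
///   %0 = vector.multi_reduction <add>, %1, %acc [1]
///     : vector<4x2xf32> to vector<4xf32>
///
/// so that the reduced dimension becomes the innermost one.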
class InnerOuterDimReductionConversion
    : public OpRewritePattern<vector::MultiDimReductionOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  explicit InnerOuterDimReductionConversion(
      MLIRContext *context, vector::VectorMultiReductionLowering options,
      PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
        useInnerDimsForReduction(
            options == vector::VectorMultiReductionLowering::InnerReduction) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    Operation *rootOp;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
    } else {
      rootOp = multiReductionOp;
    }

    auto src = multiReductionOp.getSource();
    auto loc = multiReductionOp.getLoc();
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();

    // Separate reduction and parallel dims.
    ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
    llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
                                                  reductionDims.end());
    int64_t reductionSize = reductionDims.size();
    SmallVector<int64_t, 4> parallelDims;
    for (int64_t i = 0; i < srcRank; ++i)
      if (!reductionDimsSet.contains(i))
        parallelDims.push_back(i);

    // Add transpose only if inner-most/outer-most dimensions are not parallel
    // and there are parallel dims.
    if (parallelDims.empty())
      return failure();
    if (useInnerDimsForReduction &&
        (parallelDims ==
         llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
      return failure();

    if (!useInnerDimsForReduction &&
        (parallelDims == llvm::to_vector<4>(llvm::seq<int64_t>(
                             reductionDims.size(),
                             parallelDims.size() + reductionDims.size()))))
      return failure();

    SmallVector<int64_t, 4> indices;
    if (useInnerDimsForReduction) {
      indices.append(parallelDims.begin(), parallelDims.end());
      indices.append(reductionDims.begin(), reductionDims.end());
    } else {
      indices.append(reductionDims.begin(), reductionDims.end());
      indices.append(parallelDims.begin(), parallelDims.end());
    }

    // If masked, transpose the original mask.
    Value transposedMask;
    if (maskableOp.isMasked()) {
      transposedMask = rewriter.create<vector::TransposeOp>(
          loc, maskableOp.getMaskingOp().getMask(), indices);
    }

    // Transpose reduction source.
    auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
    SmallVector<bool> reductionMask(srcRank, false);
    for (int i = 0; i < reductionSize; ++i) {
      if (useInnerDimsForReduction)
        reductionMask[srcRank - i - 1] = true;
      else
        reductionMask[i] = true;
    }
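    // For instance (illustrative), with srcRank = 3 and a single reduced dim,
    // the mask computed above becomes [false, false, true] when reducing along
    // the inner dims and [true, false, false] otherwise.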

    Operation *newMultiRedOp = rewriter.create<vector::MultiDimReductionOp>(
        multiReductionOp.getLoc(), transposeOp.getResult(),
        multiReductionOp.getAcc(), reductionMask, multiReductionOp.getKind());
    newMultiRedOp =
        mlir::vector::maskOperation(rewriter, newMultiRedOp, transposedMask);

    rewriter.replaceOp(rootOp, newMultiRedOp->getResult(0));
    return success();
  }

private:
  const bool useInnerDimsForReduction;
};

/// Reduces the rank of vector.multi_reduction from n-D to 2-D, given that all
/// reduction dimensions are either innermost or outermost.
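///
/// For example, under the InnerReduction strategy (an illustrative sketch;
/// the value names and shapes below are made up):
///
///   %0 = vector.multi_reduction <add>, %src, %acc [2, 3]
///     : vector<2x3x4x5xf32> to vector<2x3xf32>
///
/// is rewritten into
///
///   %1 = vector.shape_cast %src : vector<2x3x4x5xf32> to vector<6x20xf32>
///   %2 = vector.shape_cast %acc : vector<2x3xf32> to vector<6xf32>
///   %3 = vector.multi_reduction <add>, %1, %2 [1]
///     : vector<6x20xf32> to vector<6xf32>
///   %0 = vector.shape_cast %3 : vector<6xf32> to vector<2x3xf32>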
class ReduceMultiDimReductionRank
    : public OpRewritePattern<vector::MultiDimReductionOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  explicit ReduceMultiDimReductionRank(
      MLIRContext *context, vector::VectorMultiReductionLowering options,
      PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
        useInnerDimsForReduction(
            options == vector::VectorMultiReductionLowering::InnerReduction) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    Operation *rootOp;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
    } else {
      rootOp = multiReductionOp;
    }

    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    auto srcShape = multiReductionOp.getSourceVectorType().getShape();
    auto srcScalableDims =
        multiReductionOp.getSourceVectorType().getScalableDims();
    auto loc = multiReductionOp.getLoc();

    // If the rank is less than 2, there is nothing to do.
    if (srcRank < 2)
      return failure();

    // Allow only 1 scalable dimension. Otherwise we could end up with, e.g.,
    // `vscale * vscale`, which is currently not modelled.
    if (llvm::count(srcScalableDims, true) > 1)
      return failure();

    // If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"], bail.
    SmallVector<bool> reductionMask = multiReductionOp.getReductionMask();
    if (srcRank == 2 && reductionMask.front() != reductionMask.back())
      return failure();

    // 1. Separate reduction and parallel dims.
    SmallVector<int64_t, 4> parallelDims, parallelShapes;
    SmallVector<bool, 4> parallelScalableDims;
    SmallVector<int64_t, 4> reductionDims, reductionShapes;
    bool isReductionDimScalable = false;
    for (const auto &it : llvm::enumerate(reductionMask)) {
      int64_t i = it.index();
      bool isReduction = it.value();
      if (isReduction) {
        reductionDims.push_back(i);
        reductionShapes.push_back(srcShape[i]);
        isReductionDimScalable |= srcScalableDims[i];
      } else {
        parallelDims.push_back(i);
        parallelShapes.push_back(srcShape[i]);
        parallelScalableDims.push_back(srcScalableDims[i]);
      }
    }

    // 2. Compute flattened parallel and reduction sizes.
    int flattenedParallelDim = 0;
    int flattenedReductionDim = 0;
    if (!parallelShapes.empty()) {
      flattenedParallelDim = 1;
      for (auto d : parallelShapes)
        flattenedParallelDim *= d;
    }
    if (!reductionShapes.empty()) {
      flattenedReductionDim = 1;
      for (auto d : reductionShapes)
        flattenedReductionDim *= d;
    }
    // We must at least have some parallel or some reduction.
    assert((flattenedParallelDim || flattenedReductionDim) &&
           "expected at least one parallel or reduction dim");

    // 3. Fail if reduction/parallel dims are not contiguous.
    // Check parallelDims are exactly [0 .. size).
    int64_t counter = 0;
    if (useInnerDimsForReduction &&
        llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
      return failure();
    // Check parallelDims are exactly {reductionDims.size()} + [0 .. size).
    counter = reductionDims.size();
    if (!useInnerDimsForReduction &&
        llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
      return failure();

    // 4. Shape cast to collapse the consecutive parallel (resp. reduction)
    // dims into a single parallel (resp. reduction) dim.
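    // For instance (illustrative): a vector<2x3x4x5xf32> source with reduction
    // dims [2, 3] collapses to vector<6x20xf32> under the InnerReduction
    // strategy, while a vector<4x5x2x3xf32> source with reduction dims [0, 1]
    // collapses to vector<20x6xf32> under the OuterReduction strategy.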
    SmallVector<bool, 2> mask;
    SmallVector<bool, 2> scalableDims;
    SmallVector<int64_t, 2> vectorShape;
    bool isParallelDimScalable = llvm::is_contained(parallelScalableDims, true);
    if (flattenedParallelDim) {
      mask.push_back(false);
      vectorShape.push_back(flattenedParallelDim);
      scalableDims.push_back(isParallelDimScalable);
    }
    if (flattenedReductionDim) {
      mask.push_back(true);
      vectorShape.push_back(flattenedReductionDim);
      scalableDims.push_back(isReductionDimScalable);
    }
    if (!useInnerDimsForReduction && vectorShape.size() == 2) {
      std::swap(mask.front(), mask.back());
      std::swap(vectorShape.front(), vectorShape.back());
      std::swap(scalableDims.front(), scalableDims.back());
    }

    Value newVectorMask;
    if (maskableOp.isMasked()) {
      Value vectorMask = maskableOp.getMaskingOp().getMask();
      auto maskCastedType = VectorType::get(
          vectorShape,
          llvm::cast<VectorType>(vectorMask.getType()).getElementType());
      newVectorMask =
          rewriter.create<vector::ShapeCastOp>(loc, maskCastedType, vectorMask);
    }

    auto castedType = VectorType::get(
        vectorShape, multiReductionOp.getSourceVectorType().getElementType(),
        scalableDims);
    Value cast = rewriter.create<vector::ShapeCastOp>(
        loc, castedType, multiReductionOp.getSource());

    Value acc = multiReductionOp.getAcc();
    if (flattenedParallelDim) {
      auto accType = VectorType::get(
          {flattenedParallelDim},
          multiReductionOp.getSourceVectorType().getElementType(),
          /*scalableDims=*/{isParallelDimScalable});
      acc = rewriter.create<vector::ShapeCastOp>(loc, accType, acc);
    }
    // 5. Create the flattened form of vector.multi_reduction with the
    // inner-/outer-most dim as reduction.
    Operation *newMultiDimRedOp = rewriter.create<vector::MultiDimReductionOp>(
        loc, cast, acc, mask, multiReductionOp.getKind());
    newMultiDimRedOp =
        mlir::vector::maskOperation(rewriter, newMultiDimRedOp, newVectorMask);

    // 6. If there are no parallel shapes, the result is a scalar.
    // TODO: support 0-d vectors when available.
    if (parallelShapes.empty()) {
      rewriter.replaceOp(rootOp, newMultiDimRedOp->getResult(0));
      return success();
    }

    // 7. Shape cast the flattened 1-D result back to the original parallel
    // shape.
    VectorType outputCastedType = VectorType::get(
        parallelShapes, multiReductionOp.getSourceVectorType().getElementType(),
        parallelScalableDims);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        rootOp, outputCastedType, newMultiDimRedOp->getResult(0));
    return success();
  }

private:
  const bool useInnerDimsForReduction;
};

/// Unrolls a rank-2 vector.multi_reduction with an outermost reduction
/// dimension and combines the per-row results elementwise.
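///
/// For example (an illustrative sketch; the value names and shapes below are
/// made up, and the exact arith op depends on the reduction kind and element
/// type):
///
///   %0 = vector.multi_reduction <add>, %src, %acc [0]
///     : vector<2x4xf32> to vector<4xf32>
///
/// is rewritten into
///
///   %1 = vector.extract %src[0] : vector<4xf32> from vector<2x4xf32>
///   %2 = arith.addf %1, %acc : vector<4xf32>
///   %3 = vector.extract %src[1] : vector<4xf32> from vector<2x4xf32>
///   %0 = arith.addf %3, %2 : vector<4xf32>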
struct TwoDimMultiReductionToElementWise
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    if (maskableOp.isMasked())
      // TODO: Support masking.
      return failure();

    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    // Rank-2 ["reduce", "parallel"] or bail.
    if (srcRank != 2)
      return failure();

    if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
      return failure();

    auto loc = multiReductionOp.getLoc();
    ArrayRef<int64_t> srcShape =
        multiReductionOp.getSourceVectorType().getShape();

    Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
    if (!elementType.isIntOrIndexOrFloat())
      return failure();

    Value result = multiReductionOp.getAcc();
    for (int64_t i = 0; i < srcShape[0]; i++) {
      auto operand = rewriter.create<vector::ExtractOp>(
          loc, multiReductionOp.getSource(), i);
      result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
                                  operand, result);
    }

    rewriter.replaceOp(multiReductionOp, result);
    return success();
  }
};

/// Converts a 2-D vector.multi_reduction with the innermost reduction
/// dimension into a sequence of vector.reduction ops.
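///
/// For example (an illustrative sketch; the value names and shapes below are
/// made up):
///
///   %0 = vector.multi_reduction <add>, %src, %acc [1]
///     : vector<2x8xf32> to vector<2xf32>
///
/// is rewritten into
///
///   %cst = arith.constant dense<0.000000e+00> : vector<2xf32>
///   %1 = vector.extract %src[0] : vector<8xf32> from vector<2x8xf32>
///   %2 = vector.extract %acc[0] : f32 from vector<2xf32>
///   %3 = vector.reduction <add>, %1, %2 : vector<8xf32> into f32
///   %4 = vector.insert %3, %cst[0] : f32 into vector<2xf32>
///
/// followed by the same extract/reduce/insert sequence for row 1.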
struct TwoDimMultiReductionToReduction
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    if (srcRank != 2)
      return failure();

    if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
      return failure();

    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    Operation *rootOp;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
    } else {
      rootOp = multiReductionOp;
    }

    auto loc = multiReductionOp.getLoc();
    Value result = rewriter.create<arith::ConstantOp>(
        loc, multiReductionOp.getDestType(),
        rewriter.getZeroAttr(multiReductionOp.getDestType()));
    int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];

    for (int i = 0; i < outerDim; ++i) {
      auto v = rewriter.create<vector::ExtractOp>(
          loc, multiReductionOp.getSource(), ArrayRef<int64_t>{i});
      auto acc = rewriter.create<vector::ExtractOp>(
          loc, multiReductionOp.getAcc(), ArrayRef<int64_t>{i});
      Operation *reductionOp = rewriter.create<vector::ReductionOp>(
          loc, multiReductionOp.getKind(), v, acc);

      // If masked, slice the mask and mask the new reduction operation.
      if (maskableOp.isMasked()) {
        Value mask = rewriter.create<vector::ExtractOp>(
            loc, maskableOp.getMaskingOp().getMask(), ArrayRef<int64_t>{i});
        reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask);
      }

      result = rewriter.create<vector::InsertOp>(loc, reductionOp->getResult(0),
                                                 result, i);
    }

    rewriter.replaceOp(rootOp, result);
    return success();
  }
};

/// Converts a 1-D vector.multi_reduction with a single reduction dimension to
/// a 2-D form with both a single parallel and a single reduction dimension.
/// This is achieved with a simple vector.shape_cast that inserts a leading 1.
/// The case with a single parallel dimension is a noop and folds away
/// separately.
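///
/// For example (an illustrative sketch; the value names and shapes below are
/// made up):
///
///   %0 = vector.multi_reduction <add>, %src, %acc [0]
///     : vector<8xf32> to f32
///
/// is rewritten into
///
///   %1 = vector.shape_cast %src : vector<8xf32> to vector<1x8xf32>
///   %2 = vector.broadcast %acc : f32 to vector<1xf32>
///   %3 = vector.multi_reduction <add>, %1, %2 [1]
///     : vector<1x8xf32> to vector<1xf32>
///   %0 = vector.extract %3[0] : f32 from vector<1xf32>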
struct OneDimMultiReductionToTwoDim
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    // Rank-1 or bail.
    if (srcRank != 1)
      return failure();

    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    Operation *rootOp;
    Value mask;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();
    } else {
      rootOp = multiReductionOp;
    }

    auto loc = multiReductionOp.getLoc();
    auto srcVectorType = multiReductionOp.getSourceVectorType();
    auto srcShape = srcVectorType.getShape();
    auto castedType = VectorType::get(
        ArrayRef<int64_t>{1, srcShape.back()}, srcVectorType.getElementType(),
        ArrayRef<bool>{false, srcVectorType.getScalableDims().back()});

    auto accType =
        VectorType::get(ArrayRef<int64_t>{1}, srcVectorType.getElementType());
    assert(!llvm::isa<VectorType>(multiReductionOp.getDestType()) &&
           "multi_reduction with a single dimension expects a scalar result");

    // If the unique dim is reduced and we insert a parallel dim in front, we
    // need a {false, true} mask.
    SmallVector<bool, 2> reductionMask{false, true};

    /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0)
    Value cast = rewriter.create<vector::ShapeCastOp>(
        loc, castedType, multiReductionOp.getSource());
    Value castAcc = rewriter.create<vector::BroadcastOp>(
        loc, accType, multiReductionOp.getAcc());
    Value castMask;
    if (maskableOp.isMasked()) {
      auto maskType = llvm::cast<VectorType>(mask.getType());
      auto castMaskType = VectorType::get(
          ArrayRef<int64_t>{1, maskType.getShape().back()},
          maskType.getElementType(),
          ArrayRef<bool>{false, maskType.getScalableDims().back()});
      castMask = rewriter.create<vector::BroadcastOp>(loc, castMaskType, mask);
    }

    Operation *newOp = rewriter.create<vector::MultiDimReductionOp>(
        loc, cast, castAcc, reductionMask, multiReductionOp.getKind());
    newOp = vector::maskOperation(rewriter, newOp, castMask);

    rewriter.replaceOpWithNewOp<vector::ExtractOp>(rootOp, newOp->getResult(0),
                                                   ArrayRef<int64_t>{0});
    return success();
  }
};

struct LowerVectorMultiReductionPass
    : public vector::impl::LowerVectorMultiReductionBase<
          LowerVectorMultiReductionPass> {
  LowerVectorMultiReductionPass(vector::VectorMultiReductionLowering option) {
    this->loweringStrategy = option;
  }

  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *context = op->getContext();

    RewritePatternSet loweringPatterns(context);
    populateVectorMultiReductionLoweringPatterns(loweringPatterns,
                                                 this->loweringStrategy);

    if (failed(applyPatternsGreedily(op, std::move(loweringPatterns))))
      signalPassFailure();
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<vector::VectorDialect>();
  }
};

} // namespace

void mlir::vector::populateVectorMultiReductionLoweringPatterns(
    RewritePatternSet &patterns, VectorMultiReductionLowering options,
    PatternBenefit benefit) {
  patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
      patterns.getContext(), options, benefit);
  patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext(), benefit);
  if (options == VectorMultiReductionLowering::InnerReduction)
    patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
                                                  benefit);
  else
    patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext(),
                                                    benefit);
}

std::unique_ptr<Pass> vector::createLowerVectorMultiReductionPass(
    vector::VectorMultiReductionLowering option) {
  return std::make_unique<LowerVectorMultiReductionPass>(option);
}