//===- LowerVectorMultiReduction.cpp - Lower `vector.multi_reduction` op --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.multi_reduction' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/Passes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace vector {
#define GEN_PASS_DEF_LOWERVECTORMULTIREDUCTION
#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
} // namespace vector
} // namespace mlir

#define DEBUG_TYPE "vector-multi-reduction"

using namespace mlir;

namespace {
/// This file implements the following transformations as composable atomic
/// patterns.

/// Converts vector.multi_reduction into inner-most/outer-most reduction form
/// by using vector.transpose.
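///
/// For example, with the InnerReduction strategy (illustrative IR, value
/// names and types simplified):
///
///   %0 = vector.multi_reduction <add>, %src, %acc [1]
///          : vector<2x3x4xf32> to vector<2x4xf32>
///
/// is rewritten so that the reduced dimension becomes the innermost one:
///
///   %t = vector.transpose %src, [0, 2, 1]
///          : vector<2x3x4xf32> to vector<2x4x3xf32>
///   %0 = vector.multi_reduction <add>, %t, %acc [2]
///          : vector<2x4x3xf32> to vector<2x4xf32>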
class InnerOuterDimReductionConversion
    : public OpRewritePattern<vector::MultiDimReductionOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  explicit InnerOuterDimReductionConversion(
      MLIRContext *context, vector::VectorMultiReductionLowering options,
      PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
        useInnerDimsForReduction(
            options == vector::VectorMultiReductionLowering::InnerReduction) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    Operation *rootOp;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
    } else {
      rootOp = multiReductionOp;
    }

    auto src = multiReductionOp.getSource();
    auto loc = multiReductionOp.getLoc();
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();

    // Separate reduction and parallel dims.
    ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
    llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
                                                  reductionDims.end());
    int64_t reductionSize = reductionDims.size();
    SmallVector<int64_t, 4> parallelDims;
    for (int64_t i = 0; i < srcRank; ++i)
      if (!reductionDimsSet.contains(i))
        parallelDims.push_back(i);

    // Add transpose only if inner-most/outer-most dimensions are not parallel
    // and there are parallel dims.
    if (parallelDims.empty())
      return failure();
    if (useInnerDimsForReduction &&
        (parallelDims ==
         llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
      return failure();

    if (!useInnerDimsForReduction &&
        (parallelDims == llvm::to_vector<4>(llvm::seq<int64_t>(
                             reductionDims.size(),
                             parallelDims.size() + reductionDims.size()))))
      return failure();

    SmallVector<int64_t, 4> indices;
    if (useInnerDimsForReduction) {
      indices.append(parallelDims.begin(), parallelDims.end());
      indices.append(reductionDims.begin(), reductionDims.end());
    } else {
      indices.append(reductionDims.begin(), reductionDims.end());
      indices.append(parallelDims.begin(), parallelDims.end());
    }

    // If masked, transpose the original mask.
    Value transposedMask;
    if (maskableOp.isMasked()) {
      transposedMask = rewriter.create<vector::TransposeOp>(
          loc, maskableOp.getMaskingOp().getMask(), indices);
    }

    // Transpose reduction source.
    auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
    SmallVector<bool> reductionMask(srcRank, false);
    for (int i = 0; i < reductionSize; ++i) {
      if (useInnerDimsForReduction)
        reductionMask[srcRank - i - 1] = true;
      else
        reductionMask[i] = true;
    }

    Operation *newMultiRedOp = rewriter.create<vector::MultiDimReductionOp>(
        multiReductionOp.getLoc(), transposeOp.getResult(),
        multiReductionOp.getAcc(), reductionMask, multiReductionOp.getKind());
    newMultiRedOp =
        mlir::vector::maskOperation(rewriter, newMultiRedOp, transposedMask);

    rewriter.replaceOp(rootOp, newMultiRedOp->getResult(0));
    return success();
  }

private:
  const bool useInnerDimsForReduction;
};

/// Reduces the rank of vector.multi_reduction nd -> 2d given all reduction
/// dimensions are either innermost or outermost.
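///
/// For example, with the InnerReduction strategy (illustrative IR, value
/// names and types simplified):
///
///   %0 = vector.multi_reduction <add>, %src, %acc [2, 3]
///          : vector<2x3x4x5xf32> to vector<2x3xf32>
///
/// is rewritten to reduce over a single flattened innermost dimension:
///
///   %s = vector.shape_cast %src : vector<2x3x4x5xf32> to vector<6x20xf32>
///   %a = vector.shape_cast %acc : vector<2x3xf32> to vector<6xf32>
///   %r = vector.multi_reduction <add>, %s, %a [1]
///          : vector<6x20xf32> to vector<6xf32>
///   %0 = vector.shape_cast %r : vector<6xf32> to vector<2x3xf32>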
class ReduceMultiDimReductionRank
    : public OpRewritePattern<vector::MultiDimReductionOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  explicit ReduceMultiDimReductionRank(
      MLIRContext *context, vector::VectorMultiReductionLowering options,
      PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
        useInnerDimsForReduction(
            options == vector::VectorMultiReductionLowering::InnerReduction) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    Operation *rootOp;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
    } else {
      rootOp = multiReductionOp;
    }

    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    auto srcShape = multiReductionOp.getSourceVectorType().getShape();
    auto srcScalableDims =
        multiReductionOp.getSourceVectorType().getScalableDims();
    auto loc = multiReductionOp.getLoc();

    // If rank less than 2, nothing to do.
    if (srcRank < 2)
      return failure();

    // Allow only 1 scalable dimension. Otherwise we could end up with, e.g.,
    // `vscale * vscale`, which is currently not modelled.
    if (llvm::count(srcScalableDims, true) > 1)
      return failure();

    // If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"], bail.
    SmallVector<bool> reductionMask = multiReductionOp.getReductionMask();
    if (srcRank == 2 && reductionMask.front() != reductionMask.back())
      return failure();

    // 1. Separate reduction and parallel dims.
    SmallVector<int64_t, 4> parallelDims, parallelShapes;
    SmallVector<bool, 4> parallelScalableDims;
    SmallVector<int64_t, 4> reductionDims, reductionShapes;
    bool isReductionDimScalable = false;
    for (const auto &it : llvm::enumerate(reductionMask)) {
      int64_t i = it.index();
      bool isReduction = it.value();
      if (isReduction) {
        reductionDims.push_back(i);
        reductionShapes.push_back(srcShape[i]);
        isReductionDimScalable |= srcScalableDims[i];
      } else {
        parallelDims.push_back(i);
        parallelShapes.push_back(srcShape[i]);
        parallelScalableDims.push_back(srcScalableDims[i]);
      }
    }

    // 2. Compute flattened parallel and reduction sizes.
    int flattenedParallelDim = 0;
    int flattenedReductionDim = 0;
    if (!parallelShapes.empty()) {
      flattenedParallelDim = 1;
      for (auto d : parallelShapes)
        flattenedParallelDim *= d;
    }
    if (!reductionShapes.empty()) {
      flattenedReductionDim = 1;
      for (auto d : reductionShapes)
        flattenedReductionDim *= d;
    }
    // We must at least have some parallel or some reduction.
    assert((flattenedParallelDim || flattenedReductionDim) &&
           "expected at least one parallel or reduction dim");

    // 3. Fail if reduction/parallel dims are not contiguous.
    // Check parallelDims are exactly [0 .. size).
    int64_t counter = 0;
    if (useInnerDimsForReduction &&
        llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
      return failure();
    // Check parallelDims are exactly {reductionDims.size()} + [0 .. size).
    counter = reductionDims.size();
    if (!useInnerDimsForReduction &&
        llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
      return failure();

    // 4. Shape cast to collapse consecutive parallel (resp. reduction) dims
    // into a single parallel (resp. reduction) dim.
    SmallVector<bool, 2> mask;
    SmallVector<bool, 2> scalableDims;
    SmallVector<int64_t, 2> vectorShape;
    bool isParallelDimScalable = llvm::is_contained(parallelScalableDims, true);
    if (flattenedParallelDim) {
      mask.push_back(false);
      vectorShape.push_back(flattenedParallelDim);
      scalableDims.push_back(isParallelDimScalable);
    }
    if (flattenedReductionDim) {
      mask.push_back(true);
      vectorShape.push_back(flattenedReductionDim);
      scalableDims.push_back(isReductionDimScalable);
    }
    if (!useInnerDimsForReduction && vectorShape.size() == 2) {
      std::swap(mask.front(), mask.back());
      std::swap(vectorShape.front(), vectorShape.back());
      std::swap(scalableDims.front(), scalableDims.back());
    }

    Value newVectorMask;
    if (maskableOp.isMasked()) {
      Value vectorMask = maskableOp.getMaskingOp().getMask();
      auto maskCastedType = VectorType::get(
          vectorShape,
          llvm::cast<VectorType>(vectorMask.getType()).getElementType());
      newVectorMask =
          rewriter.create<vector::ShapeCastOp>(loc, maskCastedType, vectorMask);
    }

    auto castedType = VectorType::get(
        vectorShape, multiReductionOp.getSourceVectorType().getElementType(),
        scalableDims);
    Value cast = rewriter.create<vector::ShapeCastOp>(
        loc, castedType, multiReductionOp.getSource());

    Value acc = multiReductionOp.getAcc();
    if (flattenedParallelDim) {
      auto accType = VectorType::get(
          {flattenedParallelDim},
          multiReductionOp.getSourceVectorType().getElementType(),
          /*scalableDims=*/{isParallelDimScalable});
      acc = rewriter.create<vector::ShapeCastOp>(loc, accType, acc);
    }
    // 5. Create the flattened form of vector.multi_reduction with the
    // inner-most/outer-most dim as the reduction dim.
    Operation *newMultiDimRedOp = rewriter.create<vector::MultiDimReductionOp>(
        loc, cast, acc, mask, multiReductionOp.getKind());
    newMultiDimRedOp =
        mlir::vector::maskOperation(rewriter, newMultiDimRedOp, newVectorMask);

    // 6. If there are no parallel shapes, the result is a scalar.
    // TODO: support 0-d vectors when available.
    if (parallelShapes.empty()) {
      rewriter.replaceOp(rootOp, newMultiDimRedOp->getResult(0));
      return success();
    }

    // 7. Create a shape cast for the output n-D -> 2-D.
    VectorType outputCastedType = VectorType::get(
        parallelShapes, multiReductionOp.getSourceVectorType().getElementType(),
        parallelScalableDims);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        rootOp, outputCastedType, newMultiDimRedOp->getResult(0));
    return success();
  }

private:
  const bool useInnerDimsForReduction;
};

/// Unrolls vector.multi_reduction with outermost reductions and combines
/// results.
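///
/// For example, for an <add> reduction (illustrative IR, value names
/// simplified):
///
///   %0 = vector.multi_reduction <add>, %src, %acc [0]
///          : vector<2x8xf32> to vector<8xf32>
///
/// is unrolled into elementwise combines of the outer rows:
///
///   %r0 = vector.extract %src[0] : vector<8xf32> from vector<2x8xf32>
///   %a0 = arith.addf %r0, %acc : vector<8xf32>
///   %r1 = vector.extract %src[1] : vector<8xf32> from vector<2x8xf32>
///   %0  = arith.addf %r1, %a0 : vector<8xf32>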
struct TwoDimMultiReductionToElementWise
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    if (maskableOp.isMasked())
      // TODO: Support masking.
      return failure();

    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    // Rank-2 ["reduce", "parallel"] or bail.
    if (srcRank != 2)
      return failure();

    if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
      return failure();

    auto loc = multiReductionOp.getLoc();
    ArrayRef<int64_t> srcShape =
        multiReductionOp.getSourceVectorType().getShape();

    Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
    if (!elementType.isIntOrIndexOrFloat())
      return failure();

    Value result = multiReductionOp.getAcc();
    for (int64_t i = 0; i < srcShape[0]; i++) {
      auto operand = rewriter.create<vector::ExtractOp>(
          loc, multiReductionOp.getSource(), i);
      result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
                                  operand, result);
    }

    rewriter.replaceOp(multiReductionOp, result);
    return success();
  }
};

/// Converts 2d vector.multi_reduction with innermost reduction dimension into
/// a sequence of vector.reduction ops.
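///
/// For example (illustrative IR, value names simplified):
///
///   %0 = vector.multi_reduction <add>, %src, %acc [1]
///          : vector<2x8xf32> to vector<2xf32>
///
/// is rewritten into one vector.reduction per outer row:
///
///   %init = arith.constant dense<0.0> : vector<2xf32>
///   %v0 = vector.extract %src[0] : vector<8xf32> from vector<2x8xf32>
///   %a0 = vector.extract %acc[0] : f32 from vector<2xf32>
///   %r0 = vector.reduction <add>, %v0, %a0 : vector<8xf32> into f32
///   %t0 = vector.insert %r0, %init[0] : f32 into vector<2xf32>
///   %v1 = vector.extract %src[1] : vector<8xf32> from vector<2x8xf32>
///   %a1 = vector.extract %acc[1] : f32 from vector<2xf32>
///   %r1 = vector.reduction <add>, %v1, %a1 : vector<8xf32> into f32
///   %0  = vector.insert %r1, %t0[1] : f32 into vector<2xf32>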
struct TwoDimMultiReductionToReduction
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    if (srcRank != 2)
      return failure();

    if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
      return failure();

    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    Operation *rootOp;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
    } else {
      rootOp = multiReductionOp;
    }

    auto loc = multiReductionOp.getLoc();
    Value result = rewriter.create<arith::ConstantOp>(
        loc, multiReductionOp.getDestType(),
        rewriter.getZeroAttr(multiReductionOp.getDestType()));
    int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];

    for (int i = 0; i < outerDim; ++i) {
      auto v = rewriter.create<vector::ExtractOp>(
          loc, multiReductionOp.getSource(), ArrayRef<int64_t>{i});
      auto acc = rewriter.create<vector::ExtractOp>(
          loc, multiReductionOp.getAcc(), ArrayRef<int64_t>{i});
      Operation *reductionOp = rewriter.create<vector::ReductionOp>(
          loc, multiReductionOp.getKind(), v, acc);

      // If masked, slice the mask and mask the new reduction operation.
      if (maskableOp.isMasked()) {
        Value mask = rewriter.create<vector::ExtractOp>(
            loc, maskableOp.getMaskingOp().getMask(), ArrayRef<int64_t>{i});
        reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask);
      }

      result = rewriter.create<vector::InsertOp>(loc, reductionOp->getResult(0),
                                                 result, i);
    }

    rewriter.replaceOp(rootOp, result);
    return success();
  }
};

/// Converts 1d vector.multi_reduction with a single reduction dimension to a 2d
/// form with both a single parallel and reduction dimension.
/// This is achieved with a simple vector.shape_cast that inserts a leading 1.
/// The case with a single parallel dimension is a noop and folds away
/// separately.
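///
/// For example (illustrative IR, value names simplified):
///
///   %0 = vector.multi_reduction <add>, %src, %acc [0] : vector<8xf32> to f32
///
/// becomes:
///
///   %c = vector.shape_cast %src : vector<8xf32> to vector<1x8xf32>
///   %a = vector.broadcast %acc : f32 to vector<1xf32>
///   %r = vector.multi_reduction <add>, %c, %a [1]
///          : vector<1x8xf32> to vector<1xf32>
///   %0 = vector.extract %r[0] : f32 from vector<1xf32>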
struct OneDimMultiReductionToTwoDim
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    // Rank-1 or bail.
    if (srcRank != 1)
      return failure();

    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    Operation *rootOp;
    Value mask;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();
    } else {
      rootOp = multiReductionOp;
    }

    auto loc = multiReductionOp.getLoc();
    auto srcVectorType = multiReductionOp.getSourceVectorType();
    auto srcShape = srcVectorType.getShape();
    auto castedType = VectorType::get(
        ArrayRef<int64_t>{1, srcShape.back()}, srcVectorType.getElementType(),
        ArrayRef<bool>{false, srcVectorType.getScalableDims().back()});

    auto accType =
        VectorType::get(ArrayRef<int64_t>{1}, srcVectorType.getElementType());
    assert(!llvm::isa<VectorType>(multiReductionOp.getDestType()) &&
           "multi_reduction with a single dimension expects a scalar result");

    // If the unique dim is reduced and we insert a parallel in front, we need a
    // {false, true} mask.
    SmallVector<bool, 2> reductionMask{false, true};

    /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0)
    Value cast = rewriter.create<vector::ShapeCastOp>(
        loc, castedType, multiReductionOp.getSource());
    Value castAcc = rewriter.create<vector::BroadcastOp>(
        loc, accType, multiReductionOp.getAcc());
    Value castMask;
    if (maskableOp.isMasked()) {
      auto maskType = llvm::cast<VectorType>(mask.getType());
      auto castMaskType = VectorType::get(
          ArrayRef<int64_t>{1, maskType.getShape().back()},
          maskType.getElementType(),
          ArrayRef<bool>{false, maskType.getScalableDims().back()});
      castMask = rewriter.create<vector::BroadcastOp>(loc, castMaskType, mask);
    }

    Operation *newOp = rewriter.create<vector::MultiDimReductionOp>(
        loc, cast, castAcc, reductionMask, multiReductionOp.getKind());
    newOp = vector::maskOperation(rewriter, newOp, castMask);

    rewriter.replaceOpWithNewOp<vector::ExtractOp>(rootOp, newOp->getResult(0),
                                                   ArrayRef<int64_t>{0});
    return success();
  }
};

struct LowerVectorMultiReductionPass
    : public vector::impl::LowerVectorMultiReductionBase<
          LowerVectorMultiReductionPass> {
  LowerVectorMultiReductionPass(vector::VectorMultiReductionLowering option) {
    this->loweringStrategy = option;
  }

  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *context = op->getContext();

    RewritePatternSet loweringPatterns(context);
    populateVectorMultiReductionLoweringPatterns(loweringPatterns,
                                                 this->loweringStrategy);

    if (failed(applyPatternsGreedily(op, std::move(loweringPatterns))))
      signalPassFailure();
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<vector::VectorDialect>();
  }
};

} // namespace

void mlir::vector::populateVectorMultiReductionLoweringPatterns(
    RewritePatternSet &patterns, VectorMultiReductionLowering options,
    PatternBenefit benefit) {
  patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
      patterns.getContext(), options, benefit);
  patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext(), benefit);
  if (options == VectorMultiReductionLowering::InnerReduction)
    patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
                                                  benefit);
  else
    patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext(),
                                                    benefit);
}

std::unique_ptr<Pass> vector::createLowerVectorMultiReductionPass(
    vector::VectorMultiReductionLowering option) {
  return std::make_unique<LowerVectorMultiReductionPass>(option);
}