//===- LowerVectorMultiReduction.cpp - Lower `vector.multi_reduction` op --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.multi_reduction' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/Passes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace vector {
#define GEN_PASS_DEF_LOWERVECTORMULTIREDUCTION
#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
} // namespace vector
} // namespace mlir

#define DEBUG_TYPE "vector-multi-reduction"

using namespace mlir;

namespace {
/// This file implements the following transformations as composable atomic
/// patterns.

/// Converts vector.multi_reduction into inner-most/outer-most reduction form
/// by using vector.transpose.
class InnerOuterDimReductionConversion
    : public OpRewritePattern<vector::MultiDimReductionOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  explicit InnerOuterDimReductionConversion(
      MLIRContext *context, vector::VectorMultiReductionLowering options,
      PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
        useInnerDimsForReduction(
            options == vector::VectorMultiReductionLowering::InnerReduction) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    // Vector mask setup.
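    // If the reduction is wrapped in a vector.mask, the masking op is the op
    // that ultimately gets replaced, so the rewrite (insertion point and root)
    // is anchored at the vector.mask rather than at the multi_reduction.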
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    Operation *rootOp;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
    } else {
      rootOp = multiReductionOp;
    }

    auto src = multiReductionOp.getSource();
    auto loc = multiReductionOp.getLoc();
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();

    // Separate reduction and parallel dims.
    ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
    llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
                                                  reductionDims.end());
    int64_t reductionSize = reductionDims.size();
    SmallVector<int64_t, 4> parallelDims;
    for (int64_t i = 0; i < srcRank; ++i)
      if (!reductionDimsSet.contains(i))
        parallelDims.push_back(i);

    // Add a transpose only if there are parallel dims and the reduction dims
    // are not already in the desired inner-most (resp. outer-most) position.
    if (parallelDims.empty())
      return failure();
    if (useInnerDimsForReduction &&
        (parallelDims ==
         llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
      return failure();

    if (!useInnerDimsForReduction &&
        (parallelDims == llvm::to_vector<4>(llvm::seq<int64_t>(
                             reductionDims.size(),
                             parallelDims.size() + reductionDims.size()))))
      return failure();

    SmallVector<int64_t, 4> indices;
    if (useInnerDimsForReduction) {
      indices.append(parallelDims.begin(), parallelDims.end());
      indices.append(reductionDims.begin(), reductionDims.end());
    } else {
      indices.append(reductionDims.begin(), reductionDims.end());
      indices.append(parallelDims.begin(), parallelDims.end());
    }

    // If masked, transpose the original mask.
    Value transposedMask;
    if (maskableOp.isMasked()) {
      transposedMask = rewriter.create<vector::TransposeOp>(
          loc, maskableOp.getMaskingOp().getMask(), indices);
    }

    // Transpose the reduction source.
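    // The transpose moves the reduction dims inner-most (resp. outer-most).
    // For example (illustrative IR, inner-reduction mode), the reduction
    //   %r = vector.multi_reduction <add>, %src, %acc [0]
    //          : vector<8x4xf32> to vector<4xf32>
    // is rewritten to roughly
    //   %t = vector.transpose %src, [1, 0] : vector<8x4xf32> to vector<4x8xf32>
    //   %r = vector.multi_reduction <add>, %t, %acc [1]
    //          : vector<4x8xf32> to vector<4xf32>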
    auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
    SmallVector<bool> reductionMask(srcRank, false);
    for (int i = 0; i < reductionSize; ++i) {
      if (useInnerDimsForReduction)
        reductionMask[srcRank - i - 1] = true;
      else
        reductionMask[i] = true;
    }

    Operation *newMultiRedOp = rewriter.create<vector::MultiDimReductionOp>(
        multiReductionOp.getLoc(), transposeOp.getResult(),
        multiReductionOp.getAcc(), reductionMask, multiReductionOp.getKind());
    newMultiRedOp =
        mlir::vector::maskOperation(rewriter, newMultiRedOp, transposedMask);

    rewriter.replaceOp(rootOp, newMultiRedOp->getResult(0));
    return success();
  }

private:
  const bool useInnerDimsForReduction;
};

/// Reduces the rank of a vector.multi_reduction from n-D to 2-D, provided all
/// reduction dimensions are either the inner-most or the outer-most ones.
class ReduceMultiDimReductionRank
    : public OpRewritePattern<vector::MultiDimReductionOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  explicit ReduceMultiDimReductionRank(
      MLIRContext *context, vector::VectorMultiReductionLowering options,
      PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
        useInnerDimsForReduction(
            options == vector::VectorMultiReductionLowering::InnerReduction) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    Operation *rootOp;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
    } else {
      rootOp = multiReductionOp;
    }

    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    auto srcShape = multiReductionOp.getSourceVectorType().getShape();
    auto srcScalableDims =
        multiReductionOp.getSourceVectorType().getScalableDims();
    auto loc = multiReductionOp.getLoc();

    // If the rank is less than 2, there is nothing to do.
    if (srcRank < 2)
      return failure();

    // Allow at most one scalable dimension. Otherwise we could end up with
    // e.g. `vscale * vscale`, which is currently not modelled.
    if (llvm::count(srcScalableDims, true) > 1)
      return failure();

    // If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"], bail.
    SmallVector<bool> reductionMask = multiReductionOp.getReductionMask();
    if (srcRank == 2 && reductionMask.front() != reductionMask.back())
      return failure();

    // 1. Separate reduction and parallel dims.
    SmallVector<int64_t, 4> parallelDims, parallelShapes;
    SmallVector<bool, 4> parallelScalableDims;
    SmallVector<int64_t, 4> reductionDims, reductionShapes;
    bool isReductionDimScalable = false;
    for (const auto &it : llvm::enumerate(reductionMask)) {
      int64_t i = it.index();
      bool isReduction = it.value();
      if (isReduction) {
        reductionDims.push_back(i);
        reductionShapes.push_back(srcShape[i]);
        isReductionDimScalable |= srcScalableDims[i];
      } else {
        parallelDims.push_back(i);
        parallelShapes.push_back(srcShape[i]);
        parallelScalableDims.push_back(srcScalableDims[i]);
      }
    }

    // 2. Compute flattened parallel and reduction sizes.
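    // E.g. (illustrative) parallel shapes [2, 3] flatten to 2 * 3 = 6 and
    // reduction shapes [4, 5] flatten to 4 * 5 = 20.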
    int flattenedParallelDim = 0;
    int flattenedReductionDim = 0;
    if (!parallelShapes.empty()) {
      flattenedParallelDim = 1;
      for (auto d : parallelShapes)
        flattenedParallelDim *= d;
    }
    if (!reductionShapes.empty()) {
      flattenedReductionDim = 1;
      for (auto d : reductionShapes)
        flattenedReductionDim *= d;
    }
    // We must have at least some parallel or some reduction dims.
    assert((flattenedParallelDim || flattenedReductionDim) &&
           "expected at least one parallel or reduction dim");

    // 3. Fail if reduction/parallel dims are not contiguous.
    // Check parallelDims is exactly [0 .. parallelDims.size()).
    int64_t counter = 0;
    if (useInnerDimsForReduction &&
        llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
      return failure();
    // Check parallelDims is exactly
    // [reductionDims.size() .. reductionDims.size() + parallelDims.size()).
    counter = reductionDims.size();
    if (!useInnerDimsForReduction &&
        llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
      return failure();

    // 4. Shape cast to collapse the consecutive parallel (resp. reduction)
    // dims into a single parallel (resp. reduction) dim.
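    // For example (illustrative IR, inner-reduction mode), a source of type
    // vector<2x3x4x5xf32> with reduction dims [2, 3] is collapsed with
    //   %c = vector.shape_cast %src : vector<2x3x4x5xf32> to vector<6x20xf32>
    // so that the reduction becomes a rank-2 reduction over dim [1].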
    SmallVector<bool, 2> mask;
    SmallVector<bool, 2> scalableDims;
    SmallVector<int64_t, 2> vectorShape;
    bool isParallelDimScalable = llvm::is_contained(parallelScalableDims, true);
    if (flattenedParallelDim) {
      mask.push_back(false);
      vectorShape.push_back(flattenedParallelDim);
      scalableDims.push_back(isParallelDimScalable);
    }
    if (flattenedReductionDim) {
      mask.push_back(true);
      vectorShape.push_back(flattenedReductionDim);
      scalableDims.push_back(isReductionDimScalable);
    }
    if (!useInnerDimsForReduction && vectorShape.size() == 2) {
      std::swap(mask.front(), mask.back());
      std::swap(vectorShape.front(), vectorShape.back());
      std::swap(scalableDims.front(), scalableDims.back());
    }

    Value newVectorMask;
    if (maskableOp.isMasked()) {
      Value vectorMask = maskableOp.getMaskingOp().getMask();
      auto maskCastedType = VectorType::get(
          vectorShape,
          llvm::cast<VectorType>(vectorMask.getType()).getElementType());
      newVectorMask =
          rewriter.create<vector::ShapeCastOp>(loc, maskCastedType, vectorMask);
    }

    auto castedType = VectorType::get(
        vectorShape, multiReductionOp.getSourceVectorType().getElementType(),
        scalableDims);
    Value cast = rewriter.create<vector::ShapeCastOp>(
        loc, castedType, multiReductionOp.getSource());

    // 5. Shape cast the accumulator to the flattened parallel form.
    Value acc = multiReductionOp.getAcc();
    if (flattenedParallelDim) {
      auto accType = VectorType::get(
          {flattenedParallelDim},
          multiReductionOp.getSourceVectorType().getElementType(),
          /*scalableDims=*/{isParallelDimScalable});
      acc = rewriter.create<vector::ShapeCastOp>(loc, accType, acc);
    }
    // 6. Create the flattened form of vector.multi_reduction with the
    // inner-/outer-most dim as reduction.
    Operation *newMultiDimRedOp = rewriter.create<vector::MultiDimReductionOp>(
        loc, cast, acc, mask, multiReductionOp.getKind());
    newMultiDimRedOp =
        mlir::vector::maskOperation(rewriter, newMultiDimRedOp, newVectorMask);

    // 7. If there are no parallel shapes, the result is a scalar.
    // TODO: support 0-d vectors when available.
    if (parallelShapes.empty()) {
      rewriter.replaceOp(rootOp, newMultiDimRedOp->getResult(0));
      return success();
    }

    // 8. Shape cast the flattened 1-D result back to the original n-D parallel
    // shape.
    VectorType outputCastedType = VectorType::get(
        parallelShapes, multiReductionOp.getSourceVectorType().getElementType(),
        parallelScalableDims);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        rootOp, outputCastedType, newMultiDimRedOp->getResult(0));
    return success();
  }

private:
  const bool useInnerDimsForReduction;
};

/// Unrolls vector.multi_reduction with outermost reductions
/// and combines the results.
struct TwoDimMultiReductionToElementWise
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    if (maskableOp.isMasked())
      // TODO: Support masking.
      return failure();

    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    // Rank-2 ["reduce", "parallel"] or bail.
    if (srcRank != 2)
      return failure();

    if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
      return failure();

    auto loc = multiReductionOp.getLoc();
    ArrayRef<int64_t> srcShape =
        multiReductionOp.getSourceVectorType().getShape();

    Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
    if (!elementType.isIntOrIndexOrFloat())
      return failure();

    Value result = multiReductionOp.getAcc();
    for (int64_t i = 0; i < srcShape[0]; i++) {
      auto operand = rewriter.create<vector::ExtractOp>(
          loc, multiReductionOp.getSource(), i);
      result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
                                  operand, result);
    }

    rewriter.replaceOp(multiReductionOp, result);
    return success();
  }
};

/// Converts a 2d vector.multi_reduction with the inner-most reduction
/// dimension into a sequence of vector.reduction ops.
struct TwoDimMultiReductionToReduction
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    if (srcRank != 2)
      return failure();

    if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
      return failure();

    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    Operation *rootOp;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
    } else {
      rootOp = multiReductionOp;
    }

    auto loc = multiReductionOp.getLoc();
    Value result = rewriter.create<arith::ConstantOp>(
        loc, multiReductionOp.getDestType(),
        rewriter.getZeroAttr(multiReductionOp.getDestType()));
    int outerDim = multiReductionOp.getSourceVectorType().getShape()[0];

    for (int i = 0; i < outerDim; ++i) {
      auto v = rewriter.create<vector::ExtractOp>(
          loc, multiReductionOp.getSource(), ArrayRef<int64_t>{i});
      auto acc = rewriter.create<vector::ExtractOp>(
          loc, multiReductionOp.getAcc(), ArrayRef<int64_t>{i});
      Operation *reductionOp = rewriter.create<vector::ReductionOp>(
          loc, multiReductionOp.getKind(), v, acc);

      // If masked, slice the mask and mask the new reduction operation.
      if (maskableOp.isMasked()) {
        Value mask = rewriter.create<vector::ExtractOp>(
            loc, maskableOp.getMaskingOp().getMask(), ArrayRef<int64_t>{i});
        reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask);
      }

      result = rewriter.create<vector::InsertOp>(loc, reductionOp->getResult(0),
                                                 result, i);
    }

    rewriter.replaceOp(rootOp, result);
    return success();
  }
};

/// Converts a 1d vector.multi_reduction with a single reduction dimension to a
/// 2d form with a single parallel and a single reduction dimension.
/// This is achieved with a simple vector.shape_cast that inserts a leading 1.
/// The case with a single parallel dimension is a noop and folds away
/// separately.
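///
/// For example (illustrative IR):
///   %r = vector.multi_reduction <add>, %v, %acc [0] : vector<8xf32> to f32
/// is rewritten to roughly
///   %c = vector.shape_cast %v : vector<8xf32> to vector<1x8xf32>
///   %a = vector.broadcast %acc : f32 to vector<1xf32>
///   %m = vector.multi_reduction <add>, %c, %a [1]
///          : vector<1x8xf32> to vector<1xf32>
///   %r = vector.extract %m[0] : f32 from vector<1xf32>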
struct OneDimMultiReductionToTwoDim
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    // Rank-1 or bail.
    if (srcRank != 1)
      return failure();

    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
    Operation *rootOp;
    Value mask;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();
    } else {
      rootOp = multiReductionOp;
    }

    auto loc = multiReductionOp.getLoc();
    auto srcVectorType = multiReductionOp.getSourceVectorType();
    auto srcShape = srcVectorType.getShape();
    auto castedType = VectorType::get(
        ArrayRef<int64_t>{1, srcShape.back()}, srcVectorType.getElementType(),
        ArrayRef<bool>{false, srcVectorType.getScalableDims().back()});

    auto accType =
        VectorType::get(ArrayRef<int64_t>{1}, srcVectorType.getElementType());
    assert(!llvm::isa<VectorType>(multiReductionOp.getDestType()) &&
           "multi_reduction with a single dimension expects a scalar result");

    // Since the unique dim is reduced and we insert a parallel dim in front,
    // we need a {false, true} mask.
    SmallVector<bool, 2> reductionMask{false, true};

    /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0)
    Value cast = rewriter.create<vector::ShapeCastOp>(
        loc, castedType, multiReductionOp.getSource());
    Value castAcc = rewriter.create<vector::BroadcastOp>(
        loc, accType, multiReductionOp.getAcc());
    Value castMask;
    if (maskableOp.isMasked()) {
      auto maskType = llvm::cast<VectorType>(mask.getType());
      auto castMaskType = VectorType::get(
          ArrayRef<int64_t>{1, maskType.getShape().back()},
          maskType.getElementType(),
          ArrayRef<bool>{false, maskType.getScalableDims().back()});
      castMask = rewriter.create<vector::BroadcastOp>(loc, castMaskType, mask);
    }

    Operation *newOp = rewriter.create<vector::MultiDimReductionOp>(
        loc, cast, castAcc, reductionMask, multiReductionOp.getKind());
    newOp = vector::maskOperation(rewriter, newOp, castMask);

    rewriter.replaceOpWithNewOp<vector::ExtractOp>(rootOp, newOp->getResult(0),
                                                   ArrayRef<int64_t>{0});
    return success();
  }
};

struct LowerVectorMultiReductionPass
    : public vector::impl::LowerVectorMultiReductionBase<
          LowerVectorMultiReductionPass> {
  LowerVectorMultiReductionPass(vector::VectorMultiReductionLowering option) {
    this->loweringStrategy = option;
  }

  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *context = op->getContext();

    RewritePatternSet loweringPatterns(context);
    populateVectorMultiReductionLoweringPatterns(loweringPatterns,
                                                 this->loweringStrategy);

    if (failed(applyPatternsGreedily(op, std::move(loweringPatterns))))
      signalPassFailure();
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<vector::VectorDialect>();
  }
};

} // namespace

void mlir::vector::populateVectorMultiReductionLoweringPatterns(
    RewritePatternSet &patterns, VectorMultiReductionLowering options,
    PatternBenefit benefit) {
  patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
      patterns.getContext(), options, benefit);
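  // The transpose and rank-reducing patterns above normalize an arbitrary n-D
  // multi_reduction into one of the canonical 1-D/2-D forms; the patterns
  // below lower those canonical forms further.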
  patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext(), benefit);
  if (options == VectorMultiReductionLowering::InnerReduction)
    patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
                                                  benefit);
  else
    patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext(),
                                                    benefit);
}

std::unique_ptr<Pass> vector::createLowerVectorMultiReductionPass(
    vector::VectorMultiReductionLowering option) {
  return std::make_unique<LowerVectorMultiReductionPass>(option);
}