//===- LowerVectorScan.cpp - Lower 'vector.scan' operation ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.scan' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#define DEBUG_TYPE "vector-scan-lowering"

using namespace mlir;
using namespace mlir::vector;

/// Returns true when the combining kind `kind` is consistent with the element
/// type category described by `isInt` (true for integer/index element types,
/// false otherwise).
///
/// - MINNUMF / MINIMUMF / MAXNUMF / MAXIMUMF are accepted only when `isInt` is
///   false.
/// - MINUI / MINSI / MAXUI / MAXSI / AND / OR / XOR are accepted only when
///   `isInt` is true.
/// - ADD / MUL are accepted for either category.
/// - Any kind not listed above is rejected.
static bool isValidKind(bool isInt, vector::CombiningKind kind) {
  using vector::CombiningKind;
  enum class KindType { FLOAT, INT, INVALID };
  // Stays INVALID for any enumerator the switch below does not list, which
  // makes the function return false for it.
  KindType type{KindType::INVALID};
  switch (kind) {
  case CombiningKind::MINNUMF:
  case CombiningKind::MINIMUMF:
  case CombiningKind::MAXNUMF:
  case CombiningKind::MAXIMUMF:
    type = KindType::FLOAT;
    break;
  case CombiningKind::MINUI:
  case CombiningKind::MINSI:
  case CombiningKind::MAXUI:
  case CombiningKind::MAXSI:
  case CombiningKind::AND:
  case CombiningKind::OR:
  case CombiningKind::XOR:
    type = KindType::INT;
    break;
  case CombiningKind::ADD:
  case CombiningKind::MUL:
    // ADD and MUL adopt whichever category the element type has.
    type = isInt ? KindType::INT : KindType::FLOAT;
    break;
  }
  bool isValidIntKind = (type == KindType::INT) && isInt;
  bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt);
  return (isValidIntKind || isValidFloatKind);
}

namespace {
/// Convert vector.scan op into arith ops and vector.insert_strided_slice /
/// vector.extract_strided_slice.
///
/// Example:
///
/// ```
///   %0:2 = vector.scan <add>, %arg0, %arg1
///     {inclusive = true, reduction_dim = 1} :
///     (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>)
/// ```
///
/// is converted to:
///
/// ```
///   %cst = arith.constant dense<0> : vector<2x3xi32>
///   %0 = vector.extract_strided_slice %arg0
///     {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]}
///       : vector<2x3xi32> to vector<2x1xi32>
///   %1 = vector.insert_strided_slice %0, %cst
///     {offsets = [0, 0], strides = [1, 1]}
///       : vector<2x1xi32> into vector<2x3xi32>
///   %2 = vector.extract_strided_slice %arg0
///     {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]}
///       : vector<2x3xi32> to vector<2x1xi32>
///   %3 = arith.addi %0, %2 : vector<2x1xi32>
///   %4 = vector.insert_strided_slice %3, %1
///     {offsets = [0, 1], strides = [1, 1]}
///       : vector<2x1xi32> into vector<2x3xi32>
///   %5 = vector.extract_strided_slice %arg0
///     {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]}
///       : vector<2x3xi32> to vector<2x1xi32>
///   %6 = arith.addi %3, %5 : vector<2x1xi32>
///   %7 = vector.insert_strided_slice %6, %4
///     {offsets = [0, 2], strides = [1, 1]}
///       : vector<2x1xi32> into vector<2x3xi32>
///   %8 = vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32>
///   return %7, %8 : vector<2x3xi32>, vector<2xi32>
/// ```
struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
  using OpRewritePattern::OpRewritePattern;

  /// Unrolls `scanOp` along its reduction dimension into one
  /// extract_strided_slice / combine / insert_strided_slice step per element
  /// of that dimension, then replaces the op with the accumulated vector and
  /// the last computed slice (reshaped to the initial-value type).
  ///
  /// Returns failure without modifying the IR when the combining kind does not
  /// fit the element type or when the destination vector is scalable.
  LogicalResult matchAndRewrite(vector::ScanOp scanOp,
                                PatternRewriter &rewriter) const override {
    auto loc = scanOp.getLoc();
    VectorType destType = scanOp.getDestType();
    ArrayRef<int64_t> destShape = destType.getShape();
    auto elType = destType.getElementType();
    bool isInt = elType.isIntOrIndex();
    if (!isValidKind(isInt, scanOp.getKind()))
      return rewriter.notifyMatchFailure(
          scanOp, "combining kind is not valid for the element type");

    // The types built below are created from the static shape only and would
    // drop the scalable-dims flags, so the replacement values would not have
    // the type of the op results. Bail out instead of creating mistyped IR.
    if (destType.isScalable())
      return rewriter.notifyMatchFailure(
          scanOp, "scalable vectors are not supported");

    VectorType resType = VectorType::get(destShape, elType);
    // Accumulator for the scanned vector; every slice along the reduction
    // dimension is overwritten by the loop below.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resType, rewriter.getZeroAttr(resType));
    int64_t reductionDim = scanOp.getReductionDim();
    bool inclusive = scanOp.getInclusive();
    int64_t destRank = destType.getRank();
    VectorType initialValueType = scanOp.getInitialValueType();
    int64_t initialValueRank = initialValueType.getRank();

    // Shape of one slice: the destination shape with the reduction dimension
    // collapsed to size 1.
    SmallVector<int64_t> reductionShape(destShape);
    reductionShape[reductionDim] = 1;
    VectorType reductionType = VectorType::get(reductionShape, elType);
    SmallVector<int64_t> offsets(destRank, 0);
    SmallVector<int64_t> strides(destRank, 1);
    SmallVector<int64_t> sizes(destShape);
    sizes[reductionDim] = 1;
    ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes);
    ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides);

    Value lastOutput, lastInput;
    for (int64_t i = 0, e = destShape[reductionDim]; i < e; ++i) {
      offsets[reductionDim] = i;
      ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
      Value input = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, reductionType, scanOp.getSource(), scanOffsets, scanSizes,
          scanStrides);
      Value output;
      if (i == 0) {
        if (inclusive) {
          // Inclusive scan: the first output slice is the first input slice.
          // The initial value operand is not read on this path.
          output = input;
        } else {
          // Exclusive scan: the first output slice is the initial value,
          // brought to the slice type.
          if (initialValueRank == 0) {
            // ShapeCastOp cannot handle 0-D vectors
            output = rewriter.create<vector::BroadcastOp>(
                loc, input.getType(), scanOp.getInitialValue());
          } else {
            output = rewriter.create<vector::ShapeCastOp>(
                loc, input.getType(), scanOp.getInitialValue());
          }
        }
      } else {
        // Inclusive combines the previous output with the current input;
        // exclusive combines it with the previous input.
        Value y = inclusive ? input : lastInput;
        output = vector::makeArithReduction(rewriter, loc, scanOp.getKind(),
                                            lastOutput, y);
      }
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, output, result, offsets, strides);
      lastOutput = output;
      lastInput = input;
    }

    // Second result: the last computed slice, converted back to the type of
    // the initial value.
    Value reduction;
    if (initialValueRank == 0) {
      Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0);
      reduction =
          rewriter.create<vector::BroadcastOp>(loc, initialValueType, v);
    } else {
      reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType,
                                                       lastOutput);
    }

    rewriter.replaceOp(scanOp, {result, reduction});
    return success();
  }
};
} // namespace

/// Adds the `ScanToArithOps` pattern to `patterns` with the given `benefit`.
void mlir::vector::populateVectorScanLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<ScanToArithOps>(patterns.getContext(), benefit);
}