xref: /llvm-project/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp (revision 5262865aac683b72f3e66de7a122e0c455ab6b9b)
//===- LowerVectorScan.cpp - Lower 'vector.scan' operation ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.scan' operation.
//
//===----------------------------------------------------------------------===//
132bc4c3e9SNicolas Vasilache 
142bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Affine/IR/AffineOps.h"
152bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Arith/IR/Arith.h"
162bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Arith/Utils/Utils.h"
172bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Linalg/IR/Linalg.h"
182bc4c3e9SNicolas Vasilache #include "mlir/Dialect/MemRef/IR/MemRef.h"
192bc4c3e9SNicolas Vasilache #include "mlir/Dialect/SCF/IR/SCF.h"
202bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Tensor/IR/Tensor.h"
212bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Utils/IndexingUtils.h"
222bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
232bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Vector/IR/VectorOps.h"
242bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
252bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
262bc4c3e9SNicolas Vasilache #include "mlir/IR/BuiltinAttributeInterfaces.h"
272bc4c3e9SNicolas Vasilache #include "mlir/IR/BuiltinTypes.h"
282bc4c3e9SNicolas Vasilache #include "mlir/IR/ImplicitLocOpBuilder.h"
292bc4c3e9SNicolas Vasilache #include "mlir/IR/Location.h"
302bc4c3e9SNicolas Vasilache #include "mlir/IR/Matchers.h"
312bc4c3e9SNicolas Vasilache #include "mlir/IR/PatternMatch.h"
322bc4c3e9SNicolas Vasilache #include "mlir/IR/TypeUtilities.h"
332bc4c3e9SNicolas Vasilache #include "mlir/Interfaces/VectorInterfaces.h"
342bc4c3e9SNicolas Vasilache 
// Debug category for this file. The previous value named the broadcast
// lowering; this file lowers 'vector.scan', so the tag says so.
// NOTE(review): by LLVM convention this is the name matched by `-debug-only=`;
// nothing in this file currently emits debug output under it.
#define DEBUG_TYPE "vector-scan-lowering"
362bc4c3e9SNicolas Vasilache 
372bc4c3e9SNicolas Vasilache using namespace mlir;
382bc4c3e9SNicolas Vasilache using namespace mlir::vector;
392bc4c3e9SNicolas Vasilache 
402bc4c3e9SNicolas Vasilache /// This function checks to see if the vector combining kind
412bc4c3e9SNicolas Vasilache /// is consistent with the integer or float element type.
422bc4c3e9SNicolas Vasilache static bool isValidKind(bool isInt, vector::CombiningKind kind) {
432bc4c3e9SNicolas Vasilache   using vector::CombiningKind;
442bc4c3e9SNicolas Vasilache   enum class KindType { FLOAT, INT, INVALID };
452bc4c3e9SNicolas Vasilache   KindType type{KindType::INVALID};
462bc4c3e9SNicolas Vasilache   switch (kind) {
47560564f5SJakub Kuderski   case CombiningKind::MINNUMF:
484a831250SDaniil Dudkin   case CombiningKind::MINIMUMF:
49560564f5SJakub Kuderski   case CombiningKind::MAXNUMF:
504a831250SDaniil Dudkin   case CombiningKind::MAXIMUMF:
512bc4c3e9SNicolas Vasilache     type = KindType::FLOAT;
522bc4c3e9SNicolas Vasilache     break;
532bc4c3e9SNicolas Vasilache   case CombiningKind::MINUI:
542bc4c3e9SNicolas Vasilache   case CombiningKind::MINSI:
552bc4c3e9SNicolas Vasilache   case CombiningKind::MAXUI:
562bc4c3e9SNicolas Vasilache   case CombiningKind::MAXSI:
572bc4c3e9SNicolas Vasilache   case CombiningKind::AND:
582bc4c3e9SNicolas Vasilache   case CombiningKind::OR:
592bc4c3e9SNicolas Vasilache   case CombiningKind::XOR:
602bc4c3e9SNicolas Vasilache     type = KindType::INT;
612bc4c3e9SNicolas Vasilache     break;
622bc4c3e9SNicolas Vasilache   case CombiningKind::ADD:
632bc4c3e9SNicolas Vasilache   case CombiningKind::MUL:
642bc4c3e9SNicolas Vasilache     type = isInt ? KindType::INT : KindType::FLOAT;
652bc4c3e9SNicolas Vasilache     break;
662bc4c3e9SNicolas Vasilache   }
672bc4c3e9SNicolas Vasilache   bool isValidIntKind = (type == KindType::INT) && isInt;
682bc4c3e9SNicolas Vasilache   bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt);
692bc4c3e9SNicolas Vasilache   return (isValidIntKind || isValidFloatKind);
702bc4c3e9SNicolas Vasilache }
712bc4c3e9SNicolas Vasilache 
722bc4c3e9SNicolas Vasilache namespace {
732bc4c3e9SNicolas Vasilache /// Convert vector.scan op into arith ops and vector.insert_strided_slice /
742bc4c3e9SNicolas Vasilache /// vector.extract_strided_slice.
752bc4c3e9SNicolas Vasilache ///
762bc4c3e9SNicolas Vasilache /// Example:
772bc4c3e9SNicolas Vasilache ///
782bc4c3e9SNicolas Vasilache /// ```
792bc4c3e9SNicolas Vasilache ///   %0:2 = vector.scan <add>, %arg0, %arg1
802bc4c3e9SNicolas Vasilache ///     {inclusive = true, reduction_dim = 1} :
812bc4c3e9SNicolas Vasilache ///     (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>)
822bc4c3e9SNicolas Vasilache /// ```
832bc4c3e9SNicolas Vasilache ///
842bc4c3e9SNicolas Vasilache /// is converted to:
852bc4c3e9SNicolas Vasilache ///
862bc4c3e9SNicolas Vasilache /// ```
872bc4c3e9SNicolas Vasilache ///   %cst = arith.constant dense<0> : vector<2x3xi32>
882bc4c3e9SNicolas Vasilache ///   %0 = vector.extract_strided_slice %arg0
892bc4c3e9SNicolas Vasilache ///     {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]}
902bc4c3e9SNicolas Vasilache ///       : vector<2x3xi32> to vector<2x1xi32>
912bc4c3e9SNicolas Vasilache ///   %1 = vector.insert_strided_slice %0, %cst
922bc4c3e9SNicolas Vasilache ///     {offsets = [0, 0], strides = [1, 1]}
932bc4c3e9SNicolas Vasilache ///       : vector<2x1xi32> into vector<2x3xi32>
942bc4c3e9SNicolas Vasilache ///   %2 = vector.extract_strided_slice %arg0
952bc4c3e9SNicolas Vasilache ///     {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]}
962bc4c3e9SNicolas Vasilache ///       : vector<2x3xi32> to vector<2x1xi32>
972bc4c3e9SNicolas Vasilache ///   %3 = arith.muli %0, %2 : vector<2x1xi32>
982bc4c3e9SNicolas Vasilache ///   %4 = vector.insert_strided_slice %3, %1
992bc4c3e9SNicolas Vasilache ///     {offsets = [0, 1], strides = [1, 1]}
1002bc4c3e9SNicolas Vasilache ///       : vector<2x1xi32> into vector<2x3xi32>
1012bc4c3e9SNicolas Vasilache ///   %5 = vector.extract_strided_slice %arg0
1022bc4c3e9SNicolas Vasilache ///     {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]}
1032bc4c3e9SNicolas Vasilache ///       : vector<2x3xi32> to vector<2x1xi32>
1042bc4c3e9SNicolas Vasilache ///   %6 = arith.muli %3, %5 : vector<2x1xi32>
1052bc4c3e9SNicolas Vasilache ///   %7 = vector.insert_strided_slice %6, %4
1062bc4c3e9SNicolas Vasilache ///     {offsets = [0, 2], strides = [1, 1]}
1072bc4c3e9SNicolas Vasilache ///       : vector<2x1xi32> into vector<2x3xi32>
1082bc4c3e9SNicolas Vasilache ///   %8 = vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32>
1092bc4c3e9SNicolas Vasilache ///   return %7, %8 : vector<2x3xi32>, vector<2xi32>
1102bc4c3e9SNicolas Vasilache /// ```
1112bc4c3e9SNicolas Vasilache struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
1122bc4c3e9SNicolas Vasilache   using OpRewritePattern::OpRewritePattern;
1132bc4c3e9SNicolas Vasilache 
1142bc4c3e9SNicolas Vasilache   LogicalResult matchAndRewrite(vector::ScanOp scanOp,
1152bc4c3e9SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
1162bc4c3e9SNicolas Vasilache     auto loc = scanOp.getLoc();
1172bc4c3e9SNicolas Vasilache     VectorType destType = scanOp.getDestType();
1182bc4c3e9SNicolas Vasilache     ArrayRef<int64_t> destShape = destType.getShape();
1192bc4c3e9SNicolas Vasilache     auto elType = destType.getElementType();
1202bc4c3e9SNicolas Vasilache     bool isInt = elType.isIntOrIndex();
1212bc4c3e9SNicolas Vasilache     if (!isValidKind(isInt, scanOp.getKind()))
1222bc4c3e9SNicolas Vasilache       return failure();
1232bc4c3e9SNicolas Vasilache 
1242bc4c3e9SNicolas Vasilache     VectorType resType = VectorType::get(destShape, elType);
1252bc4c3e9SNicolas Vasilache     Value result = rewriter.create<arith::ConstantOp>(
1262bc4c3e9SNicolas Vasilache         loc, resType, rewriter.getZeroAttr(resType));
1272bc4c3e9SNicolas Vasilache     int64_t reductionDim = scanOp.getReductionDim();
1282bc4c3e9SNicolas Vasilache     bool inclusive = scanOp.getInclusive();
1292bc4c3e9SNicolas Vasilache     int64_t destRank = destType.getRank();
1302bc4c3e9SNicolas Vasilache     VectorType initialValueType = scanOp.getInitialValueType();
1312bc4c3e9SNicolas Vasilache     int64_t initialValueRank = initialValueType.getRank();
1322bc4c3e9SNicolas Vasilache 
133*5262865aSKazu Hirata     SmallVector<int64_t> reductionShape(destShape);
1342bc4c3e9SNicolas Vasilache     reductionShape[reductionDim] = 1;
1352bc4c3e9SNicolas Vasilache     VectorType reductionType = VectorType::get(reductionShape, elType);
1362bc4c3e9SNicolas Vasilache     SmallVector<int64_t> offsets(destRank, 0);
1372bc4c3e9SNicolas Vasilache     SmallVector<int64_t> strides(destRank, 1);
138*5262865aSKazu Hirata     SmallVector<int64_t> sizes(destShape);
1392bc4c3e9SNicolas Vasilache     sizes[reductionDim] = 1;
1402bc4c3e9SNicolas Vasilache     ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes);
1412bc4c3e9SNicolas Vasilache     ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides);
1422bc4c3e9SNicolas Vasilache 
1432bc4c3e9SNicolas Vasilache     Value lastOutput, lastInput;
1442bc4c3e9SNicolas Vasilache     for (int i = 0; i < destShape[reductionDim]; i++) {
1452bc4c3e9SNicolas Vasilache       offsets[reductionDim] = i;
1462bc4c3e9SNicolas Vasilache       ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
1472bc4c3e9SNicolas Vasilache       Value input = rewriter.create<vector::ExtractStridedSliceOp>(
1482bc4c3e9SNicolas Vasilache           loc, reductionType, scanOp.getSource(), scanOffsets, scanSizes,
1492bc4c3e9SNicolas Vasilache           scanStrides);
1502bc4c3e9SNicolas Vasilache       Value output;
1512bc4c3e9SNicolas Vasilache       if (i == 0) {
1522bc4c3e9SNicolas Vasilache         if (inclusive) {
1532bc4c3e9SNicolas Vasilache           output = input;
1542bc4c3e9SNicolas Vasilache         } else {
1552bc4c3e9SNicolas Vasilache           if (initialValueRank == 0) {
1562bc4c3e9SNicolas Vasilache             // ShapeCastOp cannot handle 0-D vectors
1572bc4c3e9SNicolas Vasilache             output = rewriter.create<vector::BroadcastOp>(
1582bc4c3e9SNicolas Vasilache                 loc, input.getType(), scanOp.getInitialValue());
1592bc4c3e9SNicolas Vasilache           } else {
1602bc4c3e9SNicolas Vasilache             output = rewriter.create<vector::ShapeCastOp>(
1612bc4c3e9SNicolas Vasilache                 loc, input.getType(), scanOp.getInitialValue());
1622bc4c3e9SNicolas Vasilache           }
1632bc4c3e9SNicolas Vasilache         }
1642bc4c3e9SNicolas Vasilache       } else {
1652bc4c3e9SNicolas Vasilache         Value y = inclusive ? input : lastInput;
1669f74e6e6SJakub Kuderski         output = vector::makeArithReduction(rewriter, loc, scanOp.getKind(),
1679f74e6e6SJakub Kuderski                                             lastOutput, y);
1682bc4c3e9SNicolas Vasilache       }
1692bc4c3e9SNicolas Vasilache       result = rewriter.create<vector::InsertStridedSliceOp>(
1702bc4c3e9SNicolas Vasilache           loc, output, result, offsets, strides);
1712bc4c3e9SNicolas Vasilache       lastOutput = output;
1722bc4c3e9SNicolas Vasilache       lastInput = input;
1732bc4c3e9SNicolas Vasilache     }
1742bc4c3e9SNicolas Vasilache 
1752bc4c3e9SNicolas Vasilache     Value reduction;
1762bc4c3e9SNicolas Vasilache     if (initialValueRank == 0) {
1772bc4c3e9SNicolas Vasilache       Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0);
1782bc4c3e9SNicolas Vasilache       reduction =
1792bc4c3e9SNicolas Vasilache           rewriter.create<vector::BroadcastOp>(loc, initialValueType, v);
1802bc4c3e9SNicolas Vasilache     } else {
1812bc4c3e9SNicolas Vasilache       reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType,
1822bc4c3e9SNicolas Vasilache                                                        lastOutput);
1832bc4c3e9SNicolas Vasilache     }
1842bc4c3e9SNicolas Vasilache 
1852bc4c3e9SNicolas Vasilache     rewriter.replaceOp(scanOp, {result, reduction});
1862bc4c3e9SNicolas Vasilache     return success();
1872bc4c3e9SNicolas Vasilache   }
1882bc4c3e9SNicolas Vasilache };
1892bc4c3e9SNicolas Vasilache } // namespace
1902bc4c3e9SNicolas Vasilache 
1912bc4c3e9SNicolas Vasilache void mlir::vector::populateVectorScanLoweringPatterns(
1922bc4c3e9SNicolas Vasilache     RewritePatternSet &patterns, PatternBenefit benefit) {
1932bc4c3e9SNicolas Vasilache   patterns.add<ScanToArithOps>(patterns.getContext(), benefit);
1942bc4c3e9SNicolas Vasilache }
195