1 //===- LowerVectorStep.cpp - Lower 'vector.step' operation ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements target-independent rewrites and utilities to lower the 10 // 'vector.step' operation. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Arith/IR/Arith.h" 15 #include "mlir/Dialect/Vector/IR/VectorOps.h" 16 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 17 #include "mlir/IR/PatternMatch.h" 18 19 #define DEBUG_TYPE "vector-step-lowering" 20 21 using namespace mlir; 22 using namespace mlir::vector; 23 24 namespace { 25 26 struct StepToArithConstantOpRewrite final : OpRewritePattern<vector::StepOp> { 27 using OpRewritePattern::OpRewritePattern; 28 29 LogicalResult matchAndRewrite(vector::StepOp stepOp, 30 PatternRewriter &rewriter) const override { 31 auto resultType = cast<VectorType>(stepOp.getType()); 32 if (resultType.isScalable()) { 33 return failure(); 34 } 35 int64_t elementCount = resultType.getNumElements(); 36 SmallVector<APInt> indices = 37 llvm::map_to_vector(llvm::seq(elementCount), 38 [](int64_t i) { return APInt(/*width=*/64, i); }); 39 rewriter.replaceOpWithNewOp<arith::ConstantOp>( 40 stepOp, DenseElementsAttr::get(resultType, indices)); 41 return success(); 42 } 43 }; 44 } // namespace 45 46 void mlir::vector::populateVectorStepLoweringPatterns( 47 RewritePatternSet &patterns, PatternBenefit benefit) { 48 patterns.add<StepToArithConstantOpRewrite>(patterns.getContext(), benefit); 49 } 50