1*a6e72f93SManupa Karunaratne //===- LowerVectorStep.cpp - Lower 'vector.step' operation ----------------===// 2*a6e72f93SManupa Karunaratne // 3*a6e72f93SManupa Karunaratne // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*a6e72f93SManupa Karunaratne // See https://llvm.org/LICENSE.txt for license information. 5*a6e72f93SManupa Karunaratne // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*a6e72f93SManupa Karunaratne // 7*a6e72f93SManupa Karunaratne //===----------------------------------------------------------------------===// 8*a6e72f93SManupa Karunaratne // 9*a6e72f93SManupa Karunaratne // This file implements target-independent rewrites and utilities to lower the 10*a6e72f93SManupa Karunaratne // 'vector.step' operation. 11*a6e72f93SManupa Karunaratne // 12*a6e72f93SManupa Karunaratne //===----------------------------------------------------------------------===// 13*a6e72f93SManupa Karunaratne 14*a6e72f93SManupa Karunaratne #include "mlir/Dialect/Arith/IR/Arith.h" 15*a6e72f93SManupa Karunaratne #include "mlir/Dialect/Vector/IR/VectorOps.h" 16*a6e72f93SManupa Karunaratne #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 17*a6e72f93SManupa Karunaratne #include "mlir/IR/PatternMatch.h" 18*a6e72f93SManupa Karunaratne 19*a6e72f93SManupa Karunaratne #define DEBUG_TYPE "vector-step-lowering" 20*a6e72f93SManupa Karunaratne 21*a6e72f93SManupa Karunaratne using namespace mlir; 22*a6e72f93SManupa Karunaratne using namespace mlir::vector; 23*a6e72f93SManupa Karunaratne 24*a6e72f93SManupa Karunaratne namespace { 25*a6e72f93SManupa Karunaratne 26*a6e72f93SManupa Karunaratne struct StepToArithConstantOpRewrite final : OpRewritePattern<vector::StepOp> { 27*a6e72f93SManupa Karunaratne using OpRewritePattern::OpRewritePattern; 28*a6e72f93SManupa Karunaratne 29*a6e72f93SManupa Karunaratne LogicalResult matchAndRewrite(vector::StepOp stepOp, 30*a6e72f93SManupa Karunaratne PatternRewriter &rewriter) const override { 31*a6e72f93SManupa Karunaratne auto resultType = cast<VectorType>(stepOp.getType()); 32*a6e72f93SManupa Karunaratne if (resultType.isScalable()) { 33*a6e72f93SManupa Karunaratne return failure(); 34*a6e72f93SManupa Karunaratne } 35*a6e72f93SManupa Karunaratne int64_t elementCount = resultType.getNumElements(); 36*a6e72f93SManupa Karunaratne SmallVector<APInt> indices = 37*a6e72f93SManupa Karunaratne llvm::map_to_vector(llvm::seq(elementCount), 38*a6e72f93SManupa Karunaratne [](int64_t i) { return APInt(/*width=*/64, i); }); 39*a6e72f93SManupa Karunaratne rewriter.replaceOpWithNewOp<arith::ConstantOp>( 40*a6e72f93SManupa Karunaratne stepOp, DenseElementsAttr::get(resultType, indices)); 41*a6e72f93SManupa Karunaratne return success(); 42*a6e72f93SManupa Karunaratne } 43*a6e72f93SManupa Karunaratne }; 44*a6e72f93SManupa Karunaratne } // namespace 45*a6e72f93SManupa Karunaratne 46*a6e72f93SManupa Karunaratne void mlir::vector::populateVectorStepLoweringPatterns( 47*a6e72f93SManupa Karunaratne RewritePatternSet &patterns, PatternBenefit benefit) { 48*a6e72f93SManupa Karunaratne patterns.add<StepToArithConstantOpRewrite>(patterns.getContext(), benefit); 49*a6e72f93SManupa Karunaratne } 50