xref: /llvm-project/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp (revision a6e72f93923378bffe13367f6dedd526ad39b184)
1*a6e72f93SManupa Karunaratne //===- LowerVectorStep.cpp - Lower 'vector.step' operation ----------------===//
2*a6e72f93SManupa Karunaratne //
3*a6e72f93SManupa Karunaratne // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*a6e72f93SManupa Karunaratne // See https://llvm.org/LICENSE.txt for license information.
5*a6e72f93SManupa Karunaratne // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*a6e72f93SManupa Karunaratne //
7*a6e72f93SManupa Karunaratne //===----------------------------------------------------------------------===//
8*a6e72f93SManupa Karunaratne //
9*a6e72f93SManupa Karunaratne // This file implements target-independent rewrites and utilities to lower the
10*a6e72f93SManupa Karunaratne // 'vector.step' operation.
11*a6e72f93SManupa Karunaratne //
12*a6e72f93SManupa Karunaratne //===----------------------------------------------------------------------===//
13*a6e72f93SManupa Karunaratne 
14*a6e72f93SManupa Karunaratne #include "mlir/Dialect/Arith/IR/Arith.h"
15*a6e72f93SManupa Karunaratne #include "mlir/Dialect/Vector/IR/VectorOps.h"
16*a6e72f93SManupa Karunaratne #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
17*a6e72f93SManupa Karunaratne #include "mlir/IR/PatternMatch.h"
18*a6e72f93SManupa Karunaratne 
19*a6e72f93SManupa Karunaratne #define DEBUG_TYPE "vector-step-lowering"
20*a6e72f93SManupa Karunaratne 
21*a6e72f93SManupa Karunaratne using namespace mlir;
22*a6e72f93SManupa Karunaratne using namespace mlir::vector;
23*a6e72f93SManupa Karunaratne 
24*a6e72f93SManupa Karunaratne namespace {
25*a6e72f93SManupa Karunaratne 
26*a6e72f93SManupa Karunaratne struct StepToArithConstantOpRewrite final : OpRewritePattern<vector::StepOp> {
27*a6e72f93SManupa Karunaratne   using OpRewritePattern::OpRewritePattern;
28*a6e72f93SManupa Karunaratne 
29*a6e72f93SManupa Karunaratne   LogicalResult matchAndRewrite(vector::StepOp stepOp,
30*a6e72f93SManupa Karunaratne                                 PatternRewriter &rewriter) const override {
31*a6e72f93SManupa Karunaratne     auto resultType = cast<VectorType>(stepOp.getType());
32*a6e72f93SManupa Karunaratne     if (resultType.isScalable()) {
33*a6e72f93SManupa Karunaratne       return failure();
34*a6e72f93SManupa Karunaratne     }
35*a6e72f93SManupa Karunaratne     int64_t elementCount = resultType.getNumElements();
36*a6e72f93SManupa Karunaratne     SmallVector<APInt> indices =
37*a6e72f93SManupa Karunaratne         llvm::map_to_vector(llvm::seq(elementCount),
38*a6e72f93SManupa Karunaratne                             [](int64_t i) { return APInt(/*width=*/64, i); });
39*a6e72f93SManupa Karunaratne     rewriter.replaceOpWithNewOp<arith::ConstantOp>(
40*a6e72f93SManupa Karunaratne         stepOp, DenseElementsAttr::get(resultType, indices));
41*a6e72f93SManupa Karunaratne     return success();
42*a6e72f93SManupa Karunaratne   }
43*a6e72f93SManupa Karunaratne };
44*a6e72f93SManupa Karunaratne } // namespace
45*a6e72f93SManupa Karunaratne 
46*a6e72f93SManupa Karunaratne void mlir::vector::populateVectorStepLoweringPatterns(
47*a6e72f93SManupa Karunaratne     RewritePatternSet &patterns, PatternBenefit benefit) {
48*a6e72f93SManupa Karunaratne   patterns.add<StepToArithConstantOpRewrite>(patterns.getContext(), benefit);
49*a6e72f93SManupa Karunaratne }
50