xref: /llvm-project/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp (revision a6e72f93923378bffe13367f6dedd526ad39b184)
1 //===- LowerVectorStep.cpp - Lower 'vector.step' operation ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to lower the
10 // 'vector.step' operation.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Arith/IR/Arith.h"
15 #include "mlir/Dialect/Vector/IR/VectorOps.h"
16 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
17 #include "mlir/IR/PatternMatch.h"
18 
19 #define DEBUG_TYPE "vector-step-lowering"
20 
21 using namespace mlir;
22 using namespace mlir::vector;
23 
24 namespace {
25 
26 struct StepToArithConstantOpRewrite final : OpRewritePattern<vector::StepOp> {
27   using OpRewritePattern::OpRewritePattern;
28 
29   LogicalResult matchAndRewrite(vector::StepOp stepOp,
30                                 PatternRewriter &rewriter) const override {
31     auto resultType = cast<VectorType>(stepOp.getType());
32     if (resultType.isScalable()) {
33       return failure();
34     }
35     int64_t elementCount = resultType.getNumElements();
36     SmallVector<APInt> indices =
37         llvm::map_to_vector(llvm::seq(elementCount),
38                             [](int64_t i) { return APInt(/*width=*/64, i); });
39     rewriter.replaceOpWithNewOp<arith::ConstantOp>(
40         stepOp, DenseElementsAttr::get(resultType, indices));
41     return success();
42   }
43 };
44 } // namespace
45 
46 void mlir::vector::populateVectorStepLoweringPatterns(
47     RewritePatternSet &patterns, PatternBenefit benefit) {
48   patterns.add<StepToArithConstantOpRewrite>(patterns.getContext(), benefit);
49 }
50