xref: /llvm-project/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1 //===- AffineExpandIndexOpsAsAffine.cpp - Expand index ops to apply pass --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to expand affine index ops into one or more
10 // more-fundamental operations.
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/Passes.h"
14 
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Affine/Transforms/Transforms.h"
17 #include "mlir/Dialect/Affine/Utils.h"
18 #include "mlir/Dialect/Arith/Utils/Utils.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 
21 namespace mlir {
22 namespace affine {
23 #define GEN_PASS_DEF_AFFINEEXPANDINDEXOPSASAFFINE
24 #include "mlir/Dialect/Affine/Passes.h.inc"
25 } // namespace affine
26 } // namespace mlir
27 
28 using namespace mlir;
29 using namespace mlir::affine;
30 
31 namespace {
32 /// Lowers `affine.delinearize_index` into a sequence of division and remainder
33 /// operations.
34 struct LowerDelinearizeIndexOps
35     : public OpRewritePattern<AffineDelinearizeIndexOp> {
36   using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
37   LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
38                                 PatternRewriter &rewriter) const override {
39     FailureOr<SmallVector<Value>> multiIndex =
40         delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(),
41                          op.getEffectiveBasis(), /*hasOuterBound=*/false);
42     if (failed(multiIndex))
43       return failure();
44     rewriter.replaceOp(op, *multiIndex);
45     return success();
46   }
47 };
48 
49 /// Lowers `affine.linearize_index` into a sequence of multiplications and
50 /// additions.
51 struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
52   using OpRewritePattern::OpRewritePattern;
53   LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
54                                 PatternRewriter &rewriter) const override {
55     // Should be folded away, included here for safety.
56     if (op.getMultiIndex().empty()) {
57       rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
58       return success();
59     }
60 
61     SmallVector<OpFoldResult> multiIndex =
62         getAsOpFoldResult(op.getMultiIndex());
63     OpFoldResult linearIndex =
64         linearizeIndex(rewriter, op.getLoc(), multiIndex, op.getMixedBasis());
65     Value linearIndexValue =
66         getValueOrCreateConstantIntOp(rewriter, op.getLoc(), linearIndex);
67     rewriter.replaceOp(op, linearIndexValue);
68     return success();
69   }
70 };
71 
72 class ExpandAffineIndexOpsAsAffinePass
73     : public affine::impl::AffineExpandIndexOpsAsAffineBase<
74           ExpandAffineIndexOpsAsAffinePass> {
75 public:
76   ExpandAffineIndexOpsAsAffinePass() = default;
77 
78   void runOnOperation() override {
79     MLIRContext *context = &getContext();
80     RewritePatternSet patterns(context);
81     populateAffineExpandIndexOpsAsAffinePatterns(patterns);
82     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
83       return signalPassFailure();
84   }
85 };
86 
87 } // namespace
88 
89 void mlir::affine::populateAffineExpandIndexOpsAsAffinePatterns(
90     RewritePatternSet &patterns) {
91   patterns.insert<LowerDelinearizeIndexOps, LowerLinearizeIndexOps>(
92       patterns.getContext());
93 }
94 
95 std::unique_ptr<Pass> mlir::affine::createAffineExpandIndexOpsAsAffinePass() {
96   return std::make_unique<ExpandAffineIndexOpsAsAffinePass>();
97 }
98