xref: /llvm-project/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
19e8200c7SKrzysztof Drewniak //===- AffineExpandIndexOpsAsAffine.cpp - Expand index ops to apply pass --===//
29e8200c7SKrzysztof Drewniak //
39e8200c7SKrzysztof Drewniak // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
49e8200c7SKrzysztof Drewniak // See https://llvm.org/LICENSE.txt for license information.
59e8200c7SKrzysztof Drewniak // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
69e8200c7SKrzysztof Drewniak //
79e8200c7SKrzysztof Drewniak //===----------------------------------------------------------------------===//
89e8200c7SKrzysztof Drewniak //
99e8200c7SKrzysztof Drewniak // This file implements a pass to expand affine index ops into one or more more
109e8200c7SKrzysztof Drewniak // fundamental operations.
119e8200c7SKrzysztof Drewniak //===----------------------------------------------------------------------===//
129e8200c7SKrzysztof Drewniak 
139e8200c7SKrzysztof Drewniak #include "mlir/Dialect/Affine/Passes.h"
149e8200c7SKrzysztof Drewniak 
159e8200c7SKrzysztof Drewniak #include "mlir/Dialect/Affine/IR/AffineOps.h"
169e8200c7SKrzysztof Drewniak #include "mlir/Dialect/Affine/Transforms/Transforms.h"
179e8200c7SKrzysztof Drewniak #include "mlir/Dialect/Affine/Utils.h"
189e8200c7SKrzysztof Drewniak #include "mlir/Dialect/Arith/Utils/Utils.h"
199e8200c7SKrzysztof Drewniak #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
209e8200c7SKrzysztof Drewniak 
219e8200c7SKrzysztof Drewniak namespace mlir {
229e8200c7SKrzysztof Drewniak namespace affine {
239e8200c7SKrzysztof Drewniak #define GEN_PASS_DEF_AFFINEEXPANDINDEXOPSASAFFINE
249e8200c7SKrzysztof Drewniak #include "mlir/Dialect/Affine/Passes.h.inc"
259e8200c7SKrzysztof Drewniak } // namespace affine
269e8200c7SKrzysztof Drewniak } // namespace mlir
279e8200c7SKrzysztof Drewniak 
289e8200c7SKrzysztof Drewniak using namespace mlir;
299e8200c7SKrzysztof Drewniak using namespace mlir::affine;
309e8200c7SKrzysztof Drewniak 
319e8200c7SKrzysztof Drewniak namespace {
329e8200c7SKrzysztof Drewniak /// Lowers `affine.delinearize_index` into a sequence of division and remainder
339e8200c7SKrzysztof Drewniak /// operations.
349e8200c7SKrzysztof Drewniak struct LowerDelinearizeIndexOps
359e8200c7SKrzysztof Drewniak     : public OpRewritePattern<AffineDelinearizeIndexOp> {
369e8200c7SKrzysztof Drewniak   using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
379e8200c7SKrzysztof Drewniak   LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
389e8200c7SKrzysztof Drewniak                                 PatternRewriter &rewriter) const override {
399e8200c7SKrzysztof Drewniak     FailureOr<SmallVector<Value>> multiIndex =
409e8200c7SKrzysztof Drewniak         delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(),
419e8200c7SKrzysztof Drewniak                          op.getEffectiveBasis(), /*hasOuterBound=*/false);
429e8200c7SKrzysztof Drewniak     if (failed(multiIndex))
439e8200c7SKrzysztof Drewniak       return failure();
449e8200c7SKrzysztof Drewniak     rewriter.replaceOp(op, *multiIndex);
459e8200c7SKrzysztof Drewniak     return success();
469e8200c7SKrzysztof Drewniak   }
479e8200c7SKrzysztof Drewniak };
489e8200c7SKrzysztof Drewniak 
499e8200c7SKrzysztof Drewniak /// Lowers `affine.linearize_index` into a sequence of multiplications and
509e8200c7SKrzysztof Drewniak /// additions.
519e8200c7SKrzysztof Drewniak struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
529e8200c7SKrzysztof Drewniak   using OpRewritePattern::OpRewritePattern;
539e8200c7SKrzysztof Drewniak   LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
549e8200c7SKrzysztof Drewniak                                 PatternRewriter &rewriter) const override {
559e8200c7SKrzysztof Drewniak     // Should be folded away, included here for safety.
569e8200c7SKrzysztof Drewniak     if (op.getMultiIndex().empty()) {
579e8200c7SKrzysztof Drewniak       rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
589e8200c7SKrzysztof Drewniak       return success();
599e8200c7SKrzysztof Drewniak     }
609e8200c7SKrzysztof Drewniak 
619e8200c7SKrzysztof Drewniak     SmallVector<OpFoldResult> multiIndex =
629e8200c7SKrzysztof Drewniak         getAsOpFoldResult(op.getMultiIndex());
639e8200c7SKrzysztof Drewniak     OpFoldResult linearIndex =
649e8200c7SKrzysztof Drewniak         linearizeIndex(rewriter, op.getLoc(), multiIndex, op.getMixedBasis());
659e8200c7SKrzysztof Drewniak     Value linearIndexValue =
669e8200c7SKrzysztof Drewniak         getValueOrCreateConstantIntOp(rewriter, op.getLoc(), linearIndex);
679e8200c7SKrzysztof Drewniak     rewriter.replaceOp(op, linearIndexValue);
689e8200c7SKrzysztof Drewniak     return success();
699e8200c7SKrzysztof Drewniak   }
709e8200c7SKrzysztof Drewniak };
719e8200c7SKrzysztof Drewniak 
729e8200c7SKrzysztof Drewniak class ExpandAffineIndexOpsAsAffinePass
739e8200c7SKrzysztof Drewniak     : public affine::impl::AffineExpandIndexOpsAsAffineBase<
749e8200c7SKrzysztof Drewniak           ExpandAffineIndexOpsAsAffinePass> {
759e8200c7SKrzysztof Drewniak public:
769e8200c7SKrzysztof Drewniak   ExpandAffineIndexOpsAsAffinePass() = default;
779e8200c7SKrzysztof Drewniak 
789e8200c7SKrzysztof Drewniak   void runOnOperation() override {
799e8200c7SKrzysztof Drewniak     MLIRContext *context = &getContext();
809e8200c7SKrzysztof Drewniak     RewritePatternSet patterns(context);
819e8200c7SKrzysztof Drewniak     populateAffineExpandIndexOpsAsAffinePatterns(patterns);
82*09dfc571SJacques Pienaar     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
839e8200c7SKrzysztof Drewniak       return signalPassFailure();
849e8200c7SKrzysztof Drewniak   }
859e8200c7SKrzysztof Drewniak };
869e8200c7SKrzysztof Drewniak 
879e8200c7SKrzysztof Drewniak } // namespace
889e8200c7SKrzysztof Drewniak 
899e8200c7SKrzysztof Drewniak void mlir::affine::populateAffineExpandIndexOpsAsAffinePatterns(
909e8200c7SKrzysztof Drewniak     RewritePatternSet &patterns) {
919e8200c7SKrzysztof Drewniak   patterns.insert<LowerDelinearizeIndexOps, LowerLinearizeIndexOps>(
929e8200c7SKrzysztof Drewniak       patterns.getContext());
939e8200c7SKrzysztof Drewniak }
949e8200c7SKrzysztof Drewniak 
959e8200c7SKrzysztof Drewniak std::unique_ptr<Pass> mlir::affine::createAffineExpandIndexOpsAsAffinePass() {
969e8200c7SKrzysztof Drewniak   return std::make_unique<ExpandAffineIndexOpsAsAffinePass>();
979e8200c7SKrzysztof Drewniak }
98