1 //===- AffineExpandIndexOpsAsAffine.cpp - Expand index ops to apply pass --===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to expand affine index ops into one or more more 10 // fundamental operations. 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/Passes.h" 14 15 #include "mlir/Dialect/Affine/IR/AffineOps.h" 16 #include "mlir/Dialect/Affine/Transforms/Transforms.h" 17 #include "mlir/Dialect/Affine/Utils.h" 18 #include "mlir/Dialect/Arith/Utils/Utils.h" 19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 20 21 namespace mlir { 22 namespace affine { 23 #define GEN_PASS_DEF_AFFINEEXPANDINDEXOPSASAFFINE 24 #include "mlir/Dialect/Affine/Passes.h.inc" 25 } // namespace affine 26 } // namespace mlir 27 28 using namespace mlir; 29 using namespace mlir::affine; 30 31 namespace { 32 /// Lowers `affine.delinearize_index` into a sequence of division and remainder 33 /// operations. 34 struct LowerDelinearizeIndexOps 35 : public OpRewritePattern<AffineDelinearizeIndexOp> { 36 using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern; 37 LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op, 38 PatternRewriter &rewriter) const override { 39 FailureOr<SmallVector<Value>> multiIndex = 40 delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(), 41 op.getEffectiveBasis(), /*hasOuterBound=*/false); 42 if (failed(multiIndex)) 43 return failure(); 44 rewriter.replaceOp(op, *multiIndex); 45 return success(); 46 } 47 }; 48 49 /// Lowers `affine.linearize_index` into a sequence of multiplications and 50 /// additions. 51 struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> { 52 using OpRewritePattern::OpRewritePattern; 53 LogicalResult matchAndRewrite(AffineLinearizeIndexOp op, 54 PatternRewriter &rewriter) const override { 55 // Should be folded away, included here for safety. 56 if (op.getMultiIndex().empty()) { 57 rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0); 58 return success(); 59 } 60 61 SmallVector<OpFoldResult> multiIndex = 62 getAsOpFoldResult(op.getMultiIndex()); 63 OpFoldResult linearIndex = 64 linearizeIndex(rewriter, op.getLoc(), multiIndex, op.getMixedBasis()); 65 Value linearIndexValue = 66 getValueOrCreateConstantIntOp(rewriter, op.getLoc(), linearIndex); 67 rewriter.replaceOp(op, linearIndexValue); 68 return success(); 69 } 70 }; 71 72 class ExpandAffineIndexOpsAsAffinePass 73 : public affine::impl::AffineExpandIndexOpsAsAffineBase< 74 ExpandAffineIndexOpsAsAffinePass> { 75 public: 76 ExpandAffineIndexOpsAsAffinePass() = default; 77 78 void runOnOperation() override { 79 MLIRContext *context = &getContext(); 80 RewritePatternSet patterns(context); 81 populateAffineExpandIndexOpsAsAffinePatterns(patterns); 82 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) 83 return signalPassFailure(); 84 } 85 }; 86 87 } // namespace 88 89 void mlir::affine::populateAffineExpandIndexOpsAsAffinePatterns( 90 RewritePatternSet &patterns) { 91 patterns.insert<LowerDelinearizeIndexOps, LowerLinearizeIndexOps>( 92 patterns.getContext()); 93 } 94 95 std::unique_ptr<Pass> mlir::affine::createAffineExpandIndexOpsAsAffinePass() { 96 return std::make_unique<ExpandAffineIndexOpsAsAffinePass>(); 97 } 98