1*d18ffd61SMatthias Springer //===- LoopCanonicalization.cpp - Cross-dialect canonicalization patterns -===// 2*d18ffd61SMatthias Springer // 3*d18ffd61SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*d18ffd61SMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 5*d18ffd61SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*d18ffd61SMatthias Springer // 7*d18ffd61SMatthias Springer //===----------------------------------------------------------------------===// 8*d18ffd61SMatthias Springer // 9*d18ffd61SMatthias Springer // This file contains cross-dialect canonicalization patterns that cannot be 10*d18ffd61SMatthias Springer // actual canonicalization patterns due to undesired additional dependencies. 11*d18ffd61SMatthias Springer // 12*d18ffd61SMatthias Springer //===----------------------------------------------------------------------===// 13*d18ffd61SMatthias Springer 14*d18ffd61SMatthias Springer #include "PassDetail.h" 15*d18ffd61SMatthias Springer #include "mlir/Dialect/Affine/IR/AffineOps.h" 16*d18ffd61SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h" 17*d18ffd61SMatthias Springer #include "mlir/Dialect/SCF/Passes.h" 18*d18ffd61SMatthias Springer #include "mlir/Dialect/SCF/SCF.h" 19*d18ffd61SMatthias Springer #include "mlir/Dialect/SCF/Transforms.h" 20*d18ffd61SMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h" 21*d18ffd61SMatthias Springer #include "mlir/IR/PatternMatch.h" 22*d18ffd61SMatthias Springer #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 23*d18ffd61SMatthias Springer 24*d18ffd61SMatthias Springer using namespace mlir; 25*d18ffd61SMatthias Springer using namespace mlir::scf; 26*d18ffd61SMatthias Springer 27*d18ffd61SMatthias Springer namespace { 28*d18ffd61SMatthias Springer /// Fold dim ops of iter_args to dim ops of their respective init args. E.g.: 29*d18ffd61SMatthias Springer /// 30*d18ffd61SMatthias Springer /// ``` 31*d18ffd61SMatthias Springer /// %0 = ... : tensor<?x?xf32> 32*d18ffd61SMatthias Springer /// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) { 33*d18ffd61SMatthias Springer /// %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32> 34*d18ffd61SMatthias Springer /// ... 35*d18ffd61SMatthias Springer /// } 36*d18ffd61SMatthias Springer /// ``` 37*d18ffd61SMatthias Springer /// 38*d18ffd61SMatthias Springer /// is folded to: 39*d18ffd61SMatthias Springer /// 40*d18ffd61SMatthias Springer /// ``` 41*d18ffd61SMatthias Springer /// %0 = ... : tensor<?x?xf32> 42*d18ffd61SMatthias Springer /// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) { 43*d18ffd61SMatthias Springer /// %1 = tensor.dim %0, %c0 : tensor<?x?xf32> 44*d18ffd61SMatthias Springer /// ... 45*d18ffd61SMatthias Springer /// } 46*d18ffd61SMatthias Springer /// ``` 47*d18ffd61SMatthias Springer template <typename OpTy> 48*d18ffd61SMatthias Springer struct DimOfIterArgFolder : public OpRewritePattern<OpTy> { 49*d18ffd61SMatthias Springer using OpRewritePattern<OpTy>::OpRewritePattern; 50*d18ffd61SMatthias Springer 51*d18ffd61SMatthias Springer LogicalResult matchAndRewrite(OpTy dimOp, 52*d18ffd61SMatthias Springer PatternRewriter &rewriter) const override { 53*d18ffd61SMatthias Springer auto blockArg = dimOp.source().template dyn_cast<BlockArgument>(); 54*d18ffd61SMatthias Springer if (!blockArg) 55*d18ffd61SMatthias Springer return failure(); 56*d18ffd61SMatthias Springer auto forOp = dyn_cast<ForOp>(blockArg.getParentBlock()->getParentOp()); 57*d18ffd61SMatthias Springer if (!forOp) 58*d18ffd61SMatthias Springer return failure(); 59*d18ffd61SMatthias Springer 60*d18ffd61SMatthias Springer Value initArg = forOp.getOpOperandForRegionIterArg(blockArg).get(); 61*d18ffd61SMatthias Springer rewriter.updateRootInPlace( 62*d18ffd61SMatthias Springer dimOp, [&]() { dimOp.sourceMutable().assign(initArg); }); 63*d18ffd61SMatthias Springer 64*d18ffd61SMatthias Springer return success(); 65*d18ffd61SMatthias Springer }; 66*d18ffd61SMatthias Springer }; 67*d18ffd61SMatthias Springer 68*d18ffd61SMatthias Springer /// Canonicalize AffineMinOp/AffineMaxOp operations in the context of scf.for 69*d18ffd61SMatthias Springer /// and scf.parallel loops with a known range. 70*d18ffd61SMatthias Springer template <typename OpTy, bool IsMin> 71*d18ffd61SMatthias Springer struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern<OpTy> { 72*d18ffd61SMatthias Springer using OpRewritePattern<OpTy>::OpRewritePattern; 73*d18ffd61SMatthias Springer 74*d18ffd61SMatthias Springer LogicalResult matchAndRewrite(OpTy op, 75*d18ffd61SMatthias Springer PatternRewriter &rewriter) const override { 76*d18ffd61SMatthias Springer auto loopMatcher = [](Value iv, Value &lb, Value &ub, Value &step) { 77*d18ffd61SMatthias Springer if (scf::ForOp forOp = scf::getForInductionVarOwner(iv)) { 78*d18ffd61SMatthias Springer lb = forOp.lowerBound(); 79*d18ffd61SMatthias Springer ub = forOp.upperBound(); 80*d18ffd61SMatthias Springer step = forOp.step(); 81*d18ffd61SMatthias Springer return success(); 82*d18ffd61SMatthias Springer } 83*d18ffd61SMatthias Springer if (scf::ParallelOp parOp = scf::getParallelForInductionVarOwner(iv)) { 84*d18ffd61SMatthias Springer for (unsigned idx = 0; idx < parOp.getNumLoops(); ++idx) { 85*d18ffd61SMatthias Springer if (parOp.getInductionVars()[idx] == iv) { 86*d18ffd61SMatthias Springer lb = parOp.lowerBound()[idx]; 87*d18ffd61SMatthias Springer ub = parOp.upperBound()[idx]; 88*d18ffd61SMatthias Springer step = parOp.step()[idx]; 89*d18ffd61SMatthias Springer return success(); 90*d18ffd61SMatthias Springer } 91*d18ffd61SMatthias Springer } 92*d18ffd61SMatthias Springer return failure(); 93*d18ffd61SMatthias Springer } 94*d18ffd61SMatthias Springer return failure(); 95*d18ffd61SMatthias Springer }; 96*d18ffd61SMatthias Springer 97*d18ffd61SMatthias Springer return scf::canonicalizeMinMaxOpInLoop(rewriter, op, op.getAffineMap(), 98*d18ffd61SMatthias Springer op.operands(), IsMin, loopMatcher); 99*d18ffd61SMatthias Springer } 100*d18ffd61SMatthias Springer }; 101*d18ffd61SMatthias Springer 102*d18ffd61SMatthias Springer struct SCFForLoopCanonicalization 103*d18ffd61SMatthias Springer : public SCFForLoopCanonicalizationBase<SCFForLoopCanonicalization> { 104*d18ffd61SMatthias Springer void runOnFunction() override { 105*d18ffd61SMatthias Springer FuncOp funcOp = getFunction(); 106*d18ffd61SMatthias Springer MLIRContext *ctx = funcOp.getContext(); 107*d18ffd61SMatthias Springer RewritePatternSet patterns(ctx); 108*d18ffd61SMatthias Springer scf::populateSCFForLoopCanonicalizationPatterns(patterns); 109*d18ffd61SMatthias Springer if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) 110*d18ffd61SMatthias Springer signalPassFailure(); 111*d18ffd61SMatthias Springer } 112*d18ffd61SMatthias Springer }; 113*d18ffd61SMatthias Springer } // namespace 114*d18ffd61SMatthias Springer 115*d18ffd61SMatthias Springer void mlir::scf::populateSCFForLoopCanonicalizationPatterns( 116*d18ffd61SMatthias Springer RewritePatternSet &patterns) { 117*d18ffd61SMatthias Springer MLIRContext *ctx = patterns.getContext(); 118*d18ffd61SMatthias Springer patterns 119*d18ffd61SMatthias Springer .insert<AffineOpSCFCanonicalizationPattern<AffineMinOp, /*IsMin=*/true>, 120*d18ffd61SMatthias Springer AffineOpSCFCanonicalizationPattern<AffineMaxOp, /*IsMin=*/false>, 121*d18ffd61SMatthias Springer DimOfIterArgFolder<tensor::DimOp>, 122*d18ffd61SMatthias Springer DimOfIterArgFolder<memref::DimOp>>(ctx); 123*d18ffd61SMatthias Springer } 124*d18ffd61SMatthias Springer 125*d18ffd61SMatthias Springer std::unique_ptr<Pass> mlir::createSCFForLoopCanonicalizationPass() { 126*d18ffd61SMatthias Springer return std::make_unique<SCFForLoopCanonicalization>(); 127*d18ffd61SMatthias Springer } 128