1d18ffd61SMatthias Springer //===- LoopCanonicalization.cpp - Cross-dialect canonicalization patterns -===// 2d18ffd61SMatthias Springer // 3d18ffd61SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4d18ffd61SMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 5d18ffd61SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6d18ffd61SMatthias Springer // 7d18ffd61SMatthias Springer //===----------------------------------------------------------------------===// 8d18ffd61SMatthias Springer // 9d18ffd61SMatthias Springer // This file contains cross-dialect canonicalization patterns that cannot be 10d18ffd61SMatthias Springer // actual canonicalization patterns due to undesired additional dependencies. 11d18ffd61SMatthias Springer // 12d18ffd61SMatthias Springer //===----------------------------------------------------------------------===// 13d18ffd61SMatthias Springer 1467d0d7acSMichele Scuttari #include "mlir/Dialect/SCF/Transforms/Passes.h" 1567d0d7acSMichele Scuttari 16d18ffd61SMatthias Springer #include "mlir/Dialect/Affine/IR/AffineOps.h" 17d18ffd61SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h" 188b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h" 194a6b31b8SAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Patterns.h" 20f40475c7SAdrian Kuegel #include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h" 21d18ffd61SMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h" 22d18ffd61SMatthias Springer #include "mlir/IR/PatternMatch.h" 23d18ffd61SMatthias Springer #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 244fa6c273SMatthias Springer #include "llvm/ADT/TypeSwitch.h" 25d18ffd61SMatthias Springer 2667d0d7acSMichele Scuttari namespace mlir { 2767d0d7acSMichele Scuttari #define GEN_PASS_DEF_SCFFORLOOPCANONICALIZATION 2867d0d7acSMichele Scuttari #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" 2967d0d7acSMichele Scuttari } // namespace mlir 3067d0d7acSMichele Scuttari 31d18ffd61SMatthias Springer using namespace mlir; 32d18ffd61SMatthias Springer using namespace mlir::scf; 33d18ffd61SMatthias Springer 344fa6c273SMatthias Springer /// A simple, conservative analysis to determine if the loop is shape 354fa6c273SMatthias Springer /// conserving. I.e., the type of the arg-th yielded value is the same as the 364fa6c273SMatthias Springer /// type of the corresponding basic block argument of the loop. 374fa6c273SMatthias Springer /// Note: This function handles only simple cases. Expand as needed. 384fa6c273SMatthias Springer static bool isShapePreserving(ForOp forOp, int64_t arg) { 39ab737a86SMatthias Springer assert(arg < static_cast<int64_t>(forOp.getNumResults()) && 404fa6c273SMatthias Springer "arg is out of bounds"); 41ab737a86SMatthias Springer Value value = forOp.getYieldedValues()[arg]; 424fa6c273SMatthias Springer while (value) { 434fa6c273SMatthias Springer if (value == forOp.getRegionIterArgs()[arg]) 444fa6c273SMatthias Springer return true; 455550c821STres Popp OpResult opResult = dyn_cast<OpResult>(value); 464fa6c273SMatthias Springer if (!opResult) 474fa6c273SMatthias Springer return false; 484fa6c273SMatthias Springer 494fa6c273SMatthias Springer using tensor::InsertSliceOp; 505cf714bbSMatthias Springer value = llvm::TypeSwitch<Operation *, Value>(opResult.getOwner()) 514fa6c273SMatthias Springer .template Case<InsertSliceOp>( 5204235d07SJacques Pienaar [&](InsertSliceOp op) { return op.getDest(); }) 534fa6c273SMatthias Springer .template Case<ForOp>([&](ForOp forOp) { 544fa6c273SMatthias Springer return isShapePreserving(forOp, opResult.getResultNumber()) 555cf714bbSMatthias Springer ? forOp.getInitArgs()[opResult.getResultNumber()] 564fa6c273SMatthias Springer : Value(); 574fa6c273SMatthias Springer }) 584fa6c273SMatthias Springer .Default([&](auto op) { return Value(); }); 594fa6c273SMatthias Springer } 604fa6c273SMatthias Springer return false; 614fa6c273SMatthias Springer } 624fa6c273SMatthias Springer 63c7d569b8SMatthias Springer namespace { 64c7d569b8SMatthias Springer /// Fold dim ops of iter_args to dim ops of their respective init args. E.g.: 65c7d569b8SMatthias Springer /// 66c7d569b8SMatthias Springer /// ``` 67c7d569b8SMatthias Springer /// %0 = ... : tensor<?x?xf32> 68c7d569b8SMatthias Springer /// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) { 69c7d569b8SMatthias Springer /// %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32> 70c7d569b8SMatthias Springer /// ... 71c7d569b8SMatthias Springer /// } 72c7d569b8SMatthias Springer /// ``` 73c7d569b8SMatthias Springer /// 74c7d569b8SMatthias Springer /// is folded to: 75c7d569b8SMatthias Springer /// 76c7d569b8SMatthias Springer /// ``` 77c7d569b8SMatthias Springer /// %0 = ... : tensor<?x?xf32> 78c7d569b8SMatthias Springer /// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) { 79c7d569b8SMatthias Springer /// %1 = tensor.dim %0, %c0 : tensor<?x?xf32> 80c7d569b8SMatthias Springer /// ... 81c7d569b8SMatthias Springer /// } 82c7d569b8SMatthias Springer /// ``` 83c7d569b8SMatthias Springer /// 84c7d569b8SMatthias Springer /// Note: Dim ops are folded only if it can be proven that the runtime type of 85c7d569b8SMatthias Springer /// the iter arg does not change with loop iterations. 86c7d569b8SMatthias Springer template <typename OpTy> 87c7d569b8SMatthias Springer struct DimOfIterArgFolder : public OpRewritePattern<OpTy> { 88c7d569b8SMatthias Springer using OpRewritePattern<OpTy>::OpRewritePattern; 89c7d569b8SMatthias Springer 90d18ffd61SMatthias Springer LogicalResult matchAndRewrite(OpTy dimOp, 91d18ffd61SMatthias Springer PatternRewriter &rewriter) const override { 925550c821STres Popp auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource()); 93d18ffd61SMatthias Springer if (!blockArg) 94d18ffd61SMatthias Springer return failure(); 95d18ffd61SMatthias Springer auto forOp = dyn_cast<ForOp>(blockArg.getParentBlock()->getParentOp()); 96d18ffd61SMatthias Springer if (!forOp) 97d18ffd61SMatthias Springer return failure(); 984fa6c273SMatthias Springer if (!isShapePreserving(forOp, blockArg.getArgNumber() - 1)) 994fa6c273SMatthias Springer return failure(); 100d18ffd61SMatthias Springer 1013cd2a0bcSMatthias Springer Value initArg = forOp.getTiedLoopInit(blockArg)->get(); 1025fcf907bSMatthias Springer rewriter.modifyOpInPlace( 10304235d07SJacques Pienaar dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); }); 104d18ffd61SMatthias Springer 105d18ffd61SMatthias Springer return success(); 106d18ffd61SMatthias Springer }; 107d18ffd61SMatthias Springer }; 108d18ffd61SMatthias Springer 109c7d569b8SMatthias Springer /// Fold dim ops of loop results to dim ops of their respective init args. E.g.: 110c7d569b8SMatthias Springer /// 111c7d569b8SMatthias Springer /// ``` 112c7d569b8SMatthias Springer /// %0 = ... : tensor<?x?xf32> 113c7d569b8SMatthias Springer /// %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) { 114c7d569b8SMatthias Springer /// ... 115c7d569b8SMatthias Springer /// } 116c7d569b8SMatthias Springer /// %1 = tensor.dim %r, %c0 : tensor<?x?xf32> 117c7d569b8SMatthias Springer /// ``` 118c7d569b8SMatthias Springer /// 119c7d569b8SMatthias Springer /// is folded to: 120c7d569b8SMatthias Springer /// 121c7d569b8SMatthias Springer /// ``` 122c7d569b8SMatthias Springer /// %0 = ... : tensor<?x?xf32> 123c7d569b8SMatthias Springer /// %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) { 124c7d569b8SMatthias Springer /// ... 125c7d569b8SMatthias Springer /// } 126c7d569b8SMatthias Springer /// %1 = tensor.dim %0, %c0 : tensor<?x?xf32> 127c7d569b8SMatthias Springer /// ``` 128c7d569b8SMatthias Springer /// 129c7d569b8SMatthias Springer /// Note: Dim ops are folded only if it can be proven that the runtime type of 130c7d569b8SMatthias Springer /// the iter arg does not change with loop iterations. 131c7d569b8SMatthias Springer template <typename OpTy> 132c7d569b8SMatthias Springer struct DimOfLoopResultFolder : public OpRewritePattern<OpTy> { 133c7d569b8SMatthias Springer using OpRewritePattern<OpTy>::OpRewritePattern; 134c7d569b8SMatthias Springer 135c7d569b8SMatthias Springer LogicalResult matchAndRewrite(OpTy dimOp, 136c7d569b8SMatthias Springer PatternRewriter &rewriter) const override { 13704235d07SJacques Pienaar auto forOp = dimOp.getSource().template getDefiningOp<scf::ForOp>(); 138c7d569b8SMatthias Springer if (!forOp) 139c7d569b8SMatthias Springer return failure(); 1405550c821STres Popp auto opResult = cast<OpResult>(dimOp.getSource()); 141c7d569b8SMatthias Springer unsigned resultNumber = opResult.getResultNumber(); 142c7d569b8SMatthias Springer if (!isShapePreserving(forOp, resultNumber)) 143c7d569b8SMatthias Springer return failure(); 1445fcf907bSMatthias Springer rewriter.modifyOpInPlace(dimOp, [&]() { 1455cf714bbSMatthias Springer dimOp.getSourceMutable().assign(forOp.getInitArgs()[resultNumber]); 146c7d569b8SMatthias Springer }); 147c7d569b8SMatthias Springer return success(); 148c7d569b8SMatthias Springer } 149c7d569b8SMatthias Springer }; 150c7d569b8SMatthias Springer 151d18ffd61SMatthias Springer /// Canonicalize AffineMinOp/AffineMaxOp operations in the context of scf.for 152d18ffd61SMatthias Springer /// and scf.parallel loops with a known range. 1533a5811a3SMatthias Springer template <typename OpTy> 154d18ffd61SMatthias Springer struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern<OpTy> { 155d18ffd61SMatthias Springer using OpRewritePattern<OpTy>::OpRewritePattern; 156d18ffd61SMatthias Springer 157d18ffd61SMatthias Springer LogicalResult matchAndRewrite(OpTy op, 158d18ffd61SMatthias Springer PatternRewriter &rewriter) const override { 15993d640f3SMatthias Springer return scf::canonicalizeMinMaxOpInLoop(rewriter, op, scf::matchForLikeLoop); 160d18ffd61SMatthias Springer } 161d18ffd61SMatthias Springer }; 162d18ffd61SMatthias Springer 163039b969bSMichele Scuttari struct SCFForLoopCanonicalization 16467d0d7acSMichele Scuttari : public impl::SCFForLoopCanonicalizationBase<SCFForLoopCanonicalization> { 16541574554SRiver Riddle void runOnOperation() override { 16654998986SStella Laurenzo auto *parentOp = getOperation(); 16754998986SStella Laurenzo MLIRContext *ctx = parentOp->getContext(); 168d18ffd61SMatthias Springer RewritePatternSet patterns(ctx); 169d18ffd61SMatthias Springer scf::populateSCFForLoopCanonicalizationPatterns(patterns); 170*09dfc571SJacques Pienaar if (failed(applyPatternsGreedily(parentOp, std::move(patterns)))) 171d18ffd61SMatthias Springer signalPassFailure(); 172d18ffd61SMatthias Springer } 173d18ffd61SMatthias Springer }; 174d18ffd61SMatthias Springer } // namespace 175d18ffd61SMatthias Springer 176d18ffd61SMatthias Springer void mlir::scf::populateSCFForLoopCanonicalizationPatterns( 177d18ffd61SMatthias Springer RewritePatternSet &patterns) { 178d18ffd61SMatthias Springer MLIRContext *ctx = patterns.getContext(); 179d18ffd61SMatthias Springer patterns 1804c48f016SMatthias Springer .add<AffineOpSCFCanonicalizationPattern<affine::AffineMinOp>, 1814c48f016SMatthias Springer AffineOpSCFCanonicalizationPattern<affine::AffineMaxOp>, 182b4e0507cSTres Popp DimOfIterArgFolder<tensor::DimOp>, DimOfIterArgFolder<memref::DimOp>, 183c7d569b8SMatthias Springer DimOfLoopResultFolder<tensor::DimOp>, 184c7d569b8SMatthias Springer DimOfLoopResultFolder<memref::DimOp>>(ctx); 185d18ffd61SMatthias Springer } 186039b969bSMichele Scuttari 187039b969bSMichele Scuttari std::unique_ptr<Pass> mlir::createSCFForLoopCanonicalizationPass() { 188039b969bSMichele Scuttari return std::make_unique<SCFForLoopCanonicalization>(); 189039b969bSMichele Scuttari } 190