1d18ffd61SMatthias Springer //===- LoopCanonicalization.cpp - Cross-dialect canonicalization patterns -===// 2d18ffd61SMatthias Springer // 3d18ffd61SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4d18ffd61SMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 5d18ffd61SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6d18ffd61SMatthias Springer // 7d18ffd61SMatthias Springer //===----------------------------------------------------------------------===// 8d18ffd61SMatthias Springer // 9d18ffd61SMatthias Springer // This file contains cross-dialect canonicalization patterns that cannot be 10d18ffd61SMatthias Springer // actual canonicalization patterns due to undesired additional dependencies. 11d18ffd61SMatthias Springer // 12d18ffd61SMatthias Springer //===----------------------------------------------------------------------===// 13d18ffd61SMatthias Springer 14*039b969bSMichele Scuttari #include "PassDetail.h" 15d18ffd61SMatthias Springer #include "mlir/Dialect/Affine/IR/AffineOps.h" 16d18ffd61SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h" 178b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h" 18*039b969bSMichele Scuttari #include "mlir/Dialect/SCF/Transforms/Passes.h" 198b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Transforms.h" 20f40475c7SAdrian Kuegel #include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h" 21d18ffd61SMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h" 22d18ffd61SMatthias Springer #include "mlir/IR/PatternMatch.h" 23d18ffd61SMatthias Springer #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 244fa6c273SMatthias Springer #include "llvm/ADT/TypeSwitch.h" 25d18ffd61SMatthias Springer 26d18ffd61SMatthias Springer using namespace mlir; 27d18ffd61SMatthias Springer using namespace mlir::scf; 28d18ffd61SMatthias Springer 294fa6c273SMatthias Springer /// A simple, conservative analysis to determine if the loop is shape 304fa6c273SMatthias Springer /// conserving. I.e., the type of the arg-th yielded value is the same as the 314fa6c273SMatthias Springer /// type of the corresponding basic block argument of the loop. 324fa6c273SMatthias Springer /// Note: This function handles only simple cases. Expand as needed. 334fa6c273SMatthias Springer static bool isShapePreserving(ForOp forOp, int64_t arg) { 344fa6c273SMatthias Springer auto yieldOp = cast<YieldOp>(forOp.getBody()->getTerminator()); 35c0342a2dSJacques Pienaar assert(arg < static_cast<int64_t>(yieldOp.getResults().size()) && 364fa6c273SMatthias Springer "arg is out of bounds"); 37c0342a2dSJacques Pienaar Value value = yieldOp.getResults()[arg]; 384fa6c273SMatthias Springer while (value) { 394fa6c273SMatthias Springer if (value == forOp.getRegionIterArgs()[arg]) 404fa6c273SMatthias Springer return true; 414fa6c273SMatthias Springer OpResult opResult = value.dyn_cast<OpResult>(); 424fa6c273SMatthias Springer if (!opResult) 434fa6c273SMatthias Springer return false; 444fa6c273SMatthias Springer 454fa6c273SMatthias Springer using tensor::InsertSliceOp; 464fa6c273SMatthias Springer value = 474fa6c273SMatthias Springer llvm::TypeSwitch<Operation *, Value>(opResult.getOwner()) 484fa6c273SMatthias Springer .template Case<InsertSliceOp>( 4904235d07SJacques Pienaar [&](InsertSliceOp op) { return op.getDest(); }) 504fa6c273SMatthias Springer .template Case<ForOp>([&](ForOp forOp) { 514fa6c273SMatthias Springer return isShapePreserving(forOp, opResult.getResultNumber()) 524fa6c273SMatthias Springer ? forOp.getIterOperands()[opResult.getResultNumber()] 534fa6c273SMatthias Springer : Value(); 544fa6c273SMatthias Springer }) 554fa6c273SMatthias Springer .Default([&](auto op) { return Value(); }); 564fa6c273SMatthias Springer } 574fa6c273SMatthias Springer return false; 584fa6c273SMatthias Springer } 594fa6c273SMatthias Springer 60c7d569b8SMatthias Springer namespace { 61c7d569b8SMatthias Springer /// Fold dim ops of iter_args to dim ops of their respective init args. E.g.: 62c7d569b8SMatthias Springer /// 63c7d569b8SMatthias Springer /// ``` 64c7d569b8SMatthias Springer /// %0 = ... : tensor<?x?xf32> 65c7d569b8SMatthias Springer /// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) { 66c7d569b8SMatthias Springer /// %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32> 67c7d569b8SMatthias Springer /// ... 68c7d569b8SMatthias Springer /// } 69c7d569b8SMatthias Springer /// ``` 70c7d569b8SMatthias Springer /// 71c7d569b8SMatthias Springer /// is folded to: 72c7d569b8SMatthias Springer /// 73c7d569b8SMatthias Springer /// ``` 74c7d569b8SMatthias Springer /// %0 = ... : tensor<?x?xf32> 75c7d569b8SMatthias Springer /// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) { 76c7d569b8SMatthias Springer /// %1 = tensor.dim %0, %c0 : tensor<?x?xf32> 77c7d569b8SMatthias Springer /// ... 78c7d569b8SMatthias Springer /// } 79c7d569b8SMatthias Springer /// ``` 80c7d569b8SMatthias Springer /// 81c7d569b8SMatthias Springer /// Note: Dim ops are folded only if it can be proven that the runtime type of 82c7d569b8SMatthias Springer /// the iter arg does not change with loop iterations. 83c7d569b8SMatthias Springer template <typename OpTy> 84c7d569b8SMatthias Springer struct DimOfIterArgFolder : public OpRewritePattern<OpTy> { 85c7d569b8SMatthias Springer using OpRewritePattern<OpTy>::OpRewritePattern; 86c7d569b8SMatthias Springer 87d18ffd61SMatthias Springer LogicalResult matchAndRewrite(OpTy dimOp, 88d18ffd61SMatthias Springer PatternRewriter &rewriter) const override { 8904235d07SJacques Pienaar auto blockArg = dimOp.getSource().template dyn_cast<BlockArgument>(); 90d18ffd61SMatthias Springer if (!blockArg) 91d18ffd61SMatthias Springer return failure(); 92d18ffd61SMatthias Springer auto forOp = dyn_cast<ForOp>(blockArg.getParentBlock()->getParentOp()); 93d18ffd61SMatthias Springer if (!forOp) 94d18ffd61SMatthias Springer return failure(); 954fa6c273SMatthias Springer if (!isShapePreserving(forOp, blockArg.getArgNumber() - 1)) 964fa6c273SMatthias Springer return failure(); 97d18ffd61SMatthias Springer 98d18ffd61SMatthias Springer Value initArg = forOp.getOpOperandForRegionIterArg(blockArg).get(); 99d18ffd61SMatthias Springer rewriter.updateRootInPlace( 10004235d07SJacques Pienaar dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); }); 101d18ffd61SMatthias Springer 102d18ffd61SMatthias Springer return success(); 103d18ffd61SMatthias Springer }; 104d18ffd61SMatthias Springer }; 105d18ffd61SMatthias Springer 106c7d569b8SMatthias Springer /// Fold dim ops of loop results to dim ops of their respective init args. E.g.: 107c7d569b8SMatthias Springer /// 108c7d569b8SMatthias Springer /// ``` 109c7d569b8SMatthias Springer /// %0 = ... : tensor<?x?xf32> 110c7d569b8SMatthias Springer /// %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) { 111c7d569b8SMatthias Springer /// ... 112c7d569b8SMatthias Springer /// } 113c7d569b8SMatthias Springer /// %1 = tensor.dim %r, %c0 : tensor<?x?xf32> 114c7d569b8SMatthias Springer /// ``` 115c7d569b8SMatthias Springer /// 116c7d569b8SMatthias Springer /// is folded to: 117c7d569b8SMatthias Springer /// 118c7d569b8SMatthias Springer /// ``` 119c7d569b8SMatthias Springer /// %0 = ... : tensor<?x?xf32> 120c7d569b8SMatthias Springer /// %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) { 121c7d569b8SMatthias Springer /// ... 122c7d569b8SMatthias Springer /// } 123c7d569b8SMatthias Springer /// %1 = tensor.dim %0, %c0 : tensor<?x?xf32> 124c7d569b8SMatthias Springer /// ``` 125c7d569b8SMatthias Springer /// 126c7d569b8SMatthias Springer /// Note: Dim ops are folded only if it can be proven that the runtime type of 127c7d569b8SMatthias Springer /// the iter arg does not change with loop iterations. 128c7d569b8SMatthias Springer template <typename OpTy> 129c7d569b8SMatthias Springer struct DimOfLoopResultFolder : public OpRewritePattern<OpTy> { 130c7d569b8SMatthias Springer using OpRewritePattern<OpTy>::OpRewritePattern; 131c7d569b8SMatthias Springer 132c7d569b8SMatthias Springer LogicalResult matchAndRewrite(OpTy dimOp, 133c7d569b8SMatthias Springer PatternRewriter &rewriter) const override { 13404235d07SJacques Pienaar auto forOp = dimOp.getSource().template getDefiningOp<scf::ForOp>(); 135c7d569b8SMatthias Springer if (!forOp) 136c7d569b8SMatthias Springer return failure(); 13704235d07SJacques Pienaar auto opResult = dimOp.getSource().template cast<OpResult>(); 138c7d569b8SMatthias Springer unsigned resultNumber = opResult.getResultNumber(); 139c7d569b8SMatthias Springer if (!isShapePreserving(forOp, resultNumber)) 140c7d569b8SMatthias Springer return failure(); 141c7d569b8SMatthias Springer rewriter.updateRootInPlace(dimOp, [&]() { 14204235d07SJacques Pienaar dimOp.getSourceMutable().assign(forOp.getIterOperands()[resultNumber]); 143c7d569b8SMatthias Springer }); 144c7d569b8SMatthias Springer return success(); 145c7d569b8SMatthias Springer } 146c7d569b8SMatthias Springer }; 147c7d569b8SMatthias Springer 148d18ffd61SMatthias Springer /// Canonicalize AffineMinOp/AffineMaxOp operations in the context of scf.for 149d18ffd61SMatthias Springer /// and scf.parallel loops with a known range. 150d18ffd61SMatthias Springer template <typename OpTy, bool IsMin> 151d18ffd61SMatthias Springer struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern<OpTy> { 152d18ffd61SMatthias Springer using OpRewritePattern<OpTy>::OpRewritePattern; 153d18ffd61SMatthias Springer 154d18ffd61SMatthias Springer LogicalResult matchAndRewrite(OpTy op, 155d18ffd61SMatthias Springer PatternRewriter &rewriter) const override { 156a489aa74SNicolas Vasilache auto loopMatcher = [](Value iv, OpFoldResult &lb, OpFoldResult &ub, 157a489aa74SNicolas Vasilache OpFoldResult &step) { 158d18ffd61SMatthias Springer if (scf::ForOp forOp = scf::getForInductionVarOwner(iv)) { 159c0342a2dSJacques Pienaar lb = forOp.getLowerBound(); 160c0342a2dSJacques Pienaar ub = forOp.getUpperBound(); 161c0342a2dSJacques Pienaar step = forOp.getStep(); 162d18ffd61SMatthias Springer return success(); 163d18ffd61SMatthias Springer } 164d18ffd61SMatthias Springer if (scf::ParallelOp parOp = scf::getParallelForInductionVarOwner(iv)) { 165d18ffd61SMatthias Springer for (unsigned idx = 0; idx < parOp.getNumLoops(); ++idx) { 166d18ffd61SMatthias Springer if (parOp.getInductionVars()[idx] == iv) { 167c0342a2dSJacques Pienaar lb = parOp.getLowerBound()[idx]; 168c0342a2dSJacques Pienaar ub = parOp.getUpperBound()[idx]; 169c0342a2dSJacques Pienaar step = parOp.getStep()[idx]; 170d18ffd61SMatthias Springer return success(); 171d18ffd61SMatthias Springer } 172d18ffd61SMatthias Springer } 173d18ffd61SMatthias Springer return failure(); 174d18ffd61SMatthias Springer } 175a489aa74SNicolas Vasilache if (scf::ForeachThreadOp foreachThreadOp = 176a489aa74SNicolas Vasilache scf::getForeachThreadOpThreadIndexOwner(iv)) { 177a489aa74SNicolas Vasilache for (int64_t idx = 0; idx < foreachThreadOp.getRank(); ++idx) { 178a489aa74SNicolas Vasilache if (foreachThreadOp.getThreadIndices()[idx] == iv) { 179a489aa74SNicolas Vasilache lb = OpBuilder(iv.getContext()).getIndexAttr(0); 180a489aa74SNicolas Vasilache ub = foreachThreadOp.getNumThreads()[idx]; 181a489aa74SNicolas Vasilache step = OpBuilder(iv.getContext()).getIndexAttr(1); 182a489aa74SNicolas Vasilache return success(); 183a489aa74SNicolas Vasilache } 184a489aa74SNicolas Vasilache } 185a489aa74SNicolas Vasilache return failure(); 186a489aa74SNicolas Vasilache } 187d18ffd61SMatthias Springer return failure(); 188d18ffd61SMatthias Springer }; 189d18ffd61SMatthias Springer 190d18ffd61SMatthias Springer return scf::canonicalizeMinMaxOpInLoop(rewriter, op, op.getAffineMap(), 191d18ffd61SMatthias Springer op.operands(), IsMin, loopMatcher); 192d18ffd61SMatthias Springer } 193d18ffd61SMatthias Springer }; 194d18ffd61SMatthias Springer 195*039b969bSMichele Scuttari struct SCFForLoopCanonicalization 196*039b969bSMichele Scuttari : public SCFForLoopCanonicalizationBase<SCFForLoopCanonicalization> { 19741574554SRiver Riddle void runOnOperation() override { 19854998986SStella Laurenzo auto *parentOp = getOperation(); 19954998986SStella Laurenzo MLIRContext *ctx = parentOp->getContext(); 200d18ffd61SMatthias Springer RewritePatternSet patterns(ctx); 201d18ffd61SMatthias Springer scf::populateSCFForLoopCanonicalizationPatterns(patterns); 20254998986SStella Laurenzo if (failed(applyPatternsAndFoldGreedily(parentOp, std::move(patterns)))) 203d18ffd61SMatthias Springer signalPassFailure(); 204d18ffd61SMatthias Springer } 205d18ffd61SMatthias Springer }; 206d18ffd61SMatthias Springer } // namespace 207d18ffd61SMatthias Springer 208d18ffd61SMatthias Springer void mlir::scf::populateSCFForLoopCanonicalizationPatterns( 209d18ffd61SMatthias Springer RewritePatternSet &patterns) { 210d18ffd61SMatthias Springer MLIRContext *ctx = patterns.getContext(); 211d18ffd61SMatthias Springer patterns 212b4e0507cSTres Popp .add<AffineOpSCFCanonicalizationPattern<AffineMinOp, /*IsMin=*/true>, 213d18ffd61SMatthias Springer AffineOpSCFCanonicalizationPattern<AffineMaxOp, /*IsMin=*/false>, 214b4e0507cSTres Popp DimOfIterArgFolder<tensor::DimOp>, DimOfIterArgFolder<memref::DimOp>, 215c7d569b8SMatthias Springer DimOfLoopResultFolder<tensor::DimOp>, 216c7d569b8SMatthias Springer DimOfLoopResultFolder<memref::DimOp>>(ctx); 217d18ffd61SMatthias Springer } 218*039b969bSMichele Scuttari 219*039b969bSMichele Scuttari std::unique_ptr<Pass> mlir::createSCFForLoopCanonicalizationPass() { 220*039b969bSMichele Scuttari return std::make_unique<SCFForLoopCanonicalization>(); 221*039b969bSMichele Scuttari } 222