//===- LoopCanonicalization.cpp - Cross-dialect canonicalization patterns ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains cross-dialect canonicalization patterns that cannot be
// actual canonicalization patterns due to undesired additional dependencies.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Transforms/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

namespace mlir {
#define GEN_PASS_DEF_SCFFORLOOPCANONICALIZATION
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::scf;

/// A simple, conservative analysis to determine if the loop is shape
/// preserving. I.e., the type of the arg-th yielded value is the same as the
/// type of the corresponding basic block argument of the loop.
/// Note: This function handles only simple cases. Expand as needed.
static bool isShapePreserving(ForOp forOp, int64_t arg) {
  assert(arg < static_cast<int64_t>(forOp.getNumResults()) &&
         "arg is out of bounds");
  Value value = forOp.getYieldedValues()[arg];
  while (value) {
    if (value == forOp.getRegionIterArgs()[arg])
      return true;
    OpResult opResult = dyn_cast<OpResult>(value);
    if (!opResult)
      return false;

    using tensor::InsertSliceOp;
    value =
        llvm::TypeSwitch<Operation *, Value>(opResult.getOwner())
            .template Case<InsertSliceOp>(
                [&](InsertSliceOp op) { return op.getDest(); })
            .template Case<ForOp>([&](ForOp forOp) {
              return isShapePreserving(forOp, opResult.getResultNumber())
                         ? forOp.getInitArgs()[opResult.getResultNumber()]
                         : Value();
            })
            .Default([&](auto op) { return Value(); });
  }
  return false;
}
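// Illustrative example (not part of the upstream file): a loop such as
//
//   %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
//     %s = tensor.insert_slice %v into %arg0[...] [...] [...]
//         : tensor<1x1xf32> into tensor<?x?xf32>
//     scf.yield %s : tensor<?x?xf32>
//   }
//
// is classified as shape preserving for result #0: the yielded value is an
// insert_slice whose destination chains back to the region iter arg, so the
// runtime shape of %arg0 cannot change across iterations. By contrast, a loop
// that yields, e.g., the result of a tensor.cast hits the Default case above
// and is conservatively rejected.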
namespace {
/// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
///
/// ```
/// %0 = ... : tensor<?x?xf32>
/// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
///   %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
///   ...
/// }
/// ```
///
/// is folded to:
///
/// ```
/// %0 = ... : tensor<?x?xf32>
/// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
///   %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
///   ...
/// }
/// ```
///
/// Note: Dim ops are folded only if it can be proven that the runtime type of
/// the iter arg does not change with loop iterations.
template <typename OpTy>
struct DimOfIterArgFolder : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy dimOp,
                                PatternRewriter &rewriter) const override {
    auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
    if (!blockArg)
      return failure();
    auto forOp = dyn_cast<ForOp>(blockArg.getParentBlock()->getParentOp());
    if (!forOp)
      return failure();
    if (!isShapePreserving(forOp, blockArg.getArgNumber() - 1))
      return failure();

    Value initArg = forOp.getTiedLoopInit(blockArg)->get();
    rewriter.modifyOpInPlace(
        dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });

    return success();
  };
};

/// Fold dim ops of loop results to dim ops of their respective init args.
/// E.g.:
///
/// ```
/// %0 = ... : tensor<?x?xf32>
/// %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
///   ...
/// }
/// %1 = tensor.dim %r, %c0 : tensor<?x?xf32>
/// ```
///
/// is folded to:
///
/// ```
/// %0 = ... : tensor<?x?xf32>
/// %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
///   ...
/// }
/// %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
/// ```
///
/// Note: Dim ops are folded only if it can be proven that the runtime type of
/// the iter arg does not change with loop iterations.
template <typename OpTy>
struct DimOfLoopResultFolder : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy dimOp,
                                PatternRewriter &rewriter) const override {
    auto forOp = dimOp.getSource().template getDefiningOp<scf::ForOp>();
    if (!forOp)
      return failure();
    auto opResult = cast<OpResult>(dimOp.getSource());
    unsigned resultNumber = opResult.getResultNumber();
    if (!isShapePreserving(forOp, resultNumber))
      return failure();
    rewriter.modifyOpInPlace(dimOp, [&]() {
      dimOp.getSourceMutable().assign(forOp.getInitArgs()[resultNumber]);
    });
    return success();
  }
};

/// Canonicalize AffineMinOp/AffineMaxOp operations in the context of scf.for
/// and scf.parallel loops with a known range.
template <typename OpTy>
struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    return scf::canonicalizeMinMaxOpInLoop(rewriter, op,
                                           scf::matchForLikeLoop);
  }
};

struct SCFForLoopCanonicalization
    : public impl::SCFForLoopCanonicalizationBase<SCFForLoopCanonicalization> {
  void runOnOperation() override {
    auto *parentOp = getOperation();
    MLIRContext *ctx = parentOp->getContext();
    RewritePatternSet patterns(ctx);
    scf::populateSCFForLoopCanonicalizationPatterns(patterns);
    if (failed(applyPatternsGreedily(parentOp, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

void mlir::scf::populateSCFForLoopCanonicalizationPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns
      .add<AffineOpSCFCanonicalizationPattern<affine::AffineMinOp>,
           AffineOpSCFCanonicalizationPattern<affine::AffineMaxOp>,
           DimOfIterArgFolder<tensor::DimOp>, DimOfIterArgFolder<memref::DimOp>,
           DimOfLoopResultFolder<tensor::DimOp>,
           DimOfLoopResultFolder<memref::DimOp>>(ctx);
}

std::unique_ptr<Pass> mlir::createSCFForLoopCanonicalizationPass() {
  return std::make_unique<SCFForLoopCanonicalization>();
}
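// Usage sketch (illustrative, not part of the upstream file): a downstream
// pass that wants these rewrites without scheduling the full pass can
// populate the patterns directly, mirroring runOnOperation() above (`op` is
// assumed to be the Operation * the pass is anchored on):
//
//   RewritePatternSet patterns(op->getContext());
//   scf::populateSCFForLoopCanonicalizationPatterns(patterns);
//   (void)applyPatternsGreedily(op, std::move(patterns));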