xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
//===- LoopCanonicalization.cpp - Cross-dialect canonicalization patterns -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains cross-dialect canonicalization patterns that cannot be
// actual canonicalization patterns due to undesired additional dependencies.
//
//===----------------------------------------------------------------------===//
13d18ffd61SMatthias Springer 
1467d0d7acSMichele Scuttari #include "mlir/Dialect/SCF/Transforms/Passes.h"
1567d0d7acSMichele Scuttari 
16d18ffd61SMatthias Springer #include "mlir/Dialect/Affine/IR/AffineOps.h"
17d18ffd61SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
188b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
194a6b31b8SAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Patterns.h"
20f40475c7SAdrian Kuegel #include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
21d18ffd61SMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h"
22d18ffd61SMatthias Springer #include "mlir/IR/PatternMatch.h"
23d18ffd61SMatthias Springer #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
244fa6c273SMatthias Springer #include "llvm/ADT/TypeSwitch.h"
25d18ffd61SMatthias Springer 
2667d0d7acSMichele Scuttari namespace mlir {
2767d0d7acSMichele Scuttari #define GEN_PASS_DEF_SCFFORLOOPCANONICALIZATION
2867d0d7acSMichele Scuttari #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
2967d0d7acSMichele Scuttari } // namespace mlir
3067d0d7acSMichele Scuttari 
31d18ffd61SMatthias Springer using namespace mlir;
32d18ffd61SMatthias Springer using namespace mlir::scf;
33d18ffd61SMatthias Springer 
344fa6c273SMatthias Springer /// A simple, conservative analysis to determine if the loop is shape
354fa6c273SMatthias Springer /// conserving. I.e., the type of the arg-th yielded value is the same as the
364fa6c273SMatthias Springer /// type of the corresponding basic block argument of the loop.
374fa6c273SMatthias Springer /// Note: This function handles only simple cases. Expand as needed.
384fa6c273SMatthias Springer static bool isShapePreserving(ForOp forOp, int64_t arg) {
39ab737a86SMatthias Springer   assert(arg < static_cast<int64_t>(forOp.getNumResults()) &&
404fa6c273SMatthias Springer          "arg is out of bounds");
41ab737a86SMatthias Springer   Value value = forOp.getYieldedValues()[arg];
424fa6c273SMatthias Springer   while (value) {
434fa6c273SMatthias Springer     if (value == forOp.getRegionIterArgs()[arg])
444fa6c273SMatthias Springer       return true;
455550c821STres Popp     OpResult opResult = dyn_cast<OpResult>(value);
464fa6c273SMatthias Springer     if (!opResult)
474fa6c273SMatthias Springer       return false;
484fa6c273SMatthias Springer 
494fa6c273SMatthias Springer     using tensor::InsertSliceOp;
505cf714bbSMatthias Springer     value = llvm::TypeSwitch<Operation *, Value>(opResult.getOwner())
514fa6c273SMatthias Springer                 .template Case<InsertSliceOp>(
5204235d07SJacques Pienaar                     [&](InsertSliceOp op) { return op.getDest(); })
534fa6c273SMatthias Springer                 .template Case<ForOp>([&](ForOp forOp) {
544fa6c273SMatthias Springer                   return isShapePreserving(forOp, opResult.getResultNumber())
555cf714bbSMatthias Springer                              ? forOp.getInitArgs()[opResult.getResultNumber()]
564fa6c273SMatthias Springer                              : Value();
574fa6c273SMatthias Springer                 })
584fa6c273SMatthias Springer                 .Default([&](auto op) { return Value(); });
594fa6c273SMatthias Springer   }
604fa6c273SMatthias Springer   return false;
614fa6c273SMatthias Springer }
624fa6c273SMatthias Springer 
63c7d569b8SMatthias Springer namespace {
64c7d569b8SMatthias Springer /// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
65c7d569b8SMatthias Springer ///
66c7d569b8SMatthias Springer /// ```
67c7d569b8SMatthias Springer /// %0 = ... : tensor<?x?xf32>
68c7d569b8SMatthias Springer /// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
69c7d569b8SMatthias Springer ///   %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
70c7d569b8SMatthias Springer ///   ...
71c7d569b8SMatthias Springer /// }
72c7d569b8SMatthias Springer /// ```
73c7d569b8SMatthias Springer ///
74c7d569b8SMatthias Springer /// is folded to:
75c7d569b8SMatthias Springer ///
76c7d569b8SMatthias Springer /// ```
77c7d569b8SMatthias Springer /// %0 = ... : tensor<?x?xf32>
78c7d569b8SMatthias Springer /// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
79c7d569b8SMatthias Springer ///   %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
80c7d569b8SMatthias Springer ///   ...
81c7d569b8SMatthias Springer /// }
82c7d569b8SMatthias Springer /// ```
83c7d569b8SMatthias Springer ///
84c7d569b8SMatthias Springer /// Note: Dim ops are folded only if it can be proven that the runtime type of
85c7d569b8SMatthias Springer /// the iter arg does not change with loop iterations.
86c7d569b8SMatthias Springer template <typename OpTy>
87c7d569b8SMatthias Springer struct DimOfIterArgFolder : public OpRewritePattern<OpTy> {
88c7d569b8SMatthias Springer   using OpRewritePattern<OpTy>::OpRewritePattern;
89c7d569b8SMatthias Springer 
90d18ffd61SMatthias Springer   LogicalResult matchAndRewrite(OpTy dimOp,
91d18ffd61SMatthias Springer                                 PatternRewriter &rewriter) const override {
925550c821STres Popp     auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
93d18ffd61SMatthias Springer     if (!blockArg)
94d18ffd61SMatthias Springer       return failure();
95d18ffd61SMatthias Springer     auto forOp = dyn_cast<ForOp>(blockArg.getParentBlock()->getParentOp());
96d18ffd61SMatthias Springer     if (!forOp)
97d18ffd61SMatthias Springer       return failure();
984fa6c273SMatthias Springer     if (!isShapePreserving(forOp, blockArg.getArgNumber() - 1))
994fa6c273SMatthias Springer       return failure();
100d18ffd61SMatthias Springer 
1013cd2a0bcSMatthias Springer     Value initArg = forOp.getTiedLoopInit(blockArg)->get();
1025fcf907bSMatthias Springer     rewriter.modifyOpInPlace(
10304235d07SJacques Pienaar         dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
104d18ffd61SMatthias Springer 
105d18ffd61SMatthias Springer     return success();
106d18ffd61SMatthias Springer   };
107d18ffd61SMatthias Springer };
108d18ffd61SMatthias Springer 
109c7d569b8SMatthias Springer /// Fold dim ops of loop results to dim ops of their respective init args. E.g.:
110c7d569b8SMatthias Springer ///
111c7d569b8SMatthias Springer /// ```
112c7d569b8SMatthias Springer /// %0 = ... : tensor<?x?xf32>
113c7d569b8SMatthias Springer /// %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
114c7d569b8SMatthias Springer ///   ...
115c7d569b8SMatthias Springer /// }
116c7d569b8SMatthias Springer /// %1 = tensor.dim %r, %c0 : tensor<?x?xf32>
117c7d569b8SMatthias Springer /// ```
118c7d569b8SMatthias Springer ///
119c7d569b8SMatthias Springer /// is folded to:
120c7d569b8SMatthias Springer ///
121c7d569b8SMatthias Springer /// ```
122c7d569b8SMatthias Springer /// %0 = ... : tensor<?x?xf32>
123c7d569b8SMatthias Springer /// %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
124c7d569b8SMatthias Springer ///   ...
125c7d569b8SMatthias Springer /// }
126c7d569b8SMatthias Springer /// %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
127c7d569b8SMatthias Springer /// ```
128c7d569b8SMatthias Springer ///
129c7d569b8SMatthias Springer /// Note: Dim ops are folded only if it can be proven that the runtime type of
130c7d569b8SMatthias Springer /// the iter arg does not change with loop iterations.
131c7d569b8SMatthias Springer template <typename OpTy>
132c7d569b8SMatthias Springer struct DimOfLoopResultFolder : public OpRewritePattern<OpTy> {
133c7d569b8SMatthias Springer   using OpRewritePattern<OpTy>::OpRewritePattern;
134c7d569b8SMatthias Springer 
135c7d569b8SMatthias Springer   LogicalResult matchAndRewrite(OpTy dimOp,
136c7d569b8SMatthias Springer                                 PatternRewriter &rewriter) const override {
13704235d07SJacques Pienaar     auto forOp = dimOp.getSource().template getDefiningOp<scf::ForOp>();
138c7d569b8SMatthias Springer     if (!forOp)
139c7d569b8SMatthias Springer       return failure();
1405550c821STres Popp     auto opResult = cast<OpResult>(dimOp.getSource());
141c7d569b8SMatthias Springer     unsigned resultNumber = opResult.getResultNumber();
142c7d569b8SMatthias Springer     if (!isShapePreserving(forOp, resultNumber))
143c7d569b8SMatthias Springer       return failure();
1445fcf907bSMatthias Springer     rewriter.modifyOpInPlace(dimOp, [&]() {
1455cf714bbSMatthias Springer       dimOp.getSourceMutable().assign(forOp.getInitArgs()[resultNumber]);
146c7d569b8SMatthias Springer     });
147c7d569b8SMatthias Springer     return success();
148c7d569b8SMatthias Springer   }
149c7d569b8SMatthias Springer };
150c7d569b8SMatthias Springer 
151d18ffd61SMatthias Springer /// Canonicalize AffineMinOp/AffineMaxOp operations in the context of scf.for
152d18ffd61SMatthias Springer /// and scf.parallel loops with a known range.
1533a5811a3SMatthias Springer template <typename OpTy>
154d18ffd61SMatthias Springer struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern<OpTy> {
155d18ffd61SMatthias Springer   using OpRewritePattern<OpTy>::OpRewritePattern;
156d18ffd61SMatthias Springer 
157d18ffd61SMatthias Springer   LogicalResult matchAndRewrite(OpTy op,
158d18ffd61SMatthias Springer                                 PatternRewriter &rewriter) const override {
15993d640f3SMatthias Springer     return scf::canonicalizeMinMaxOpInLoop(rewriter, op, scf::matchForLikeLoop);
160d18ffd61SMatthias Springer   }
161d18ffd61SMatthias Springer };
162d18ffd61SMatthias Springer 
163039b969bSMichele Scuttari struct SCFForLoopCanonicalization
16467d0d7acSMichele Scuttari     : public impl::SCFForLoopCanonicalizationBase<SCFForLoopCanonicalization> {
16541574554SRiver Riddle   void runOnOperation() override {
16654998986SStella Laurenzo     auto *parentOp = getOperation();
16754998986SStella Laurenzo     MLIRContext *ctx = parentOp->getContext();
168d18ffd61SMatthias Springer     RewritePatternSet patterns(ctx);
169d18ffd61SMatthias Springer     scf::populateSCFForLoopCanonicalizationPatterns(patterns);
170*09dfc571SJacques Pienaar     if (failed(applyPatternsGreedily(parentOp, std::move(patterns))))
171d18ffd61SMatthias Springer       signalPassFailure();
172d18ffd61SMatthias Springer   }
173d18ffd61SMatthias Springer };
174d18ffd61SMatthias Springer } // namespace
175d18ffd61SMatthias Springer 
176d18ffd61SMatthias Springer void mlir::scf::populateSCFForLoopCanonicalizationPatterns(
177d18ffd61SMatthias Springer     RewritePatternSet &patterns) {
178d18ffd61SMatthias Springer   MLIRContext *ctx = patterns.getContext();
179d18ffd61SMatthias Springer   patterns
1804c48f016SMatthias Springer       .add<AffineOpSCFCanonicalizationPattern<affine::AffineMinOp>,
1814c48f016SMatthias Springer            AffineOpSCFCanonicalizationPattern<affine::AffineMaxOp>,
182b4e0507cSTres Popp            DimOfIterArgFolder<tensor::DimOp>, DimOfIterArgFolder<memref::DimOp>,
183c7d569b8SMatthias Springer            DimOfLoopResultFolder<tensor::DimOp>,
184c7d569b8SMatthias Springer            DimOfLoopResultFolder<memref::DimOp>>(ctx);
185d18ffd61SMatthias Springer }
186039b969bSMichele Scuttari 
187039b969bSMichele Scuttari std::unique_ptr<Pass> mlir::createSCFForLoopCanonicalizationPass() {
188039b969bSMichele Scuttari   return std::make_unique<SCFForLoopCanonicalization>();
189039b969bSMichele Scuttari }
190