xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp (revision 039b969b32b64b64123dce30dd28ec4e343d893f)
1d18ffd61SMatthias Springer //===- LoopCanonicalization.cpp - Cross-dialect canonicalization patterns -===//
2d18ffd61SMatthias Springer //
3d18ffd61SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4d18ffd61SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
5d18ffd61SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6d18ffd61SMatthias Springer //
7d18ffd61SMatthias Springer //===----------------------------------------------------------------------===//
8d18ffd61SMatthias Springer //
9d18ffd61SMatthias Springer // This file contains cross-dialect canonicalization patterns that cannot be
10d18ffd61SMatthias Springer // actual canonicalization patterns due to undesired additional dependencies.
11d18ffd61SMatthias Springer //
12d18ffd61SMatthias Springer //===----------------------------------------------------------------------===//
13d18ffd61SMatthias Springer 
14*039b969bSMichele Scuttari #include "PassDetail.h"
15d18ffd61SMatthias Springer #include "mlir/Dialect/Affine/IR/AffineOps.h"
16d18ffd61SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
178b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
18*039b969bSMichele Scuttari #include "mlir/Dialect/SCF/Transforms/Passes.h"
198b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Transforms.h"
20f40475c7SAdrian Kuegel #include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
21d18ffd61SMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h"
22d18ffd61SMatthias Springer #include "mlir/IR/PatternMatch.h"
23d18ffd61SMatthias Springer #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
244fa6c273SMatthias Springer #include "llvm/ADT/TypeSwitch.h"
25d18ffd61SMatthias Springer 
26d18ffd61SMatthias Springer using namespace mlir;
27d18ffd61SMatthias Springer using namespace mlir::scf;
28d18ffd61SMatthias Springer 
294fa6c273SMatthias Springer /// A simple, conservative analysis to determine if the loop is shape
304fa6c273SMatthias Springer /// conserving. I.e., the type of the arg-th yielded value is the same as the
314fa6c273SMatthias Springer /// type of the corresponding basic block argument of the loop.
324fa6c273SMatthias Springer /// Note: This function handles only simple cases. Expand as needed.
334fa6c273SMatthias Springer static bool isShapePreserving(ForOp forOp, int64_t arg) {
344fa6c273SMatthias Springer   auto yieldOp = cast<YieldOp>(forOp.getBody()->getTerminator());
35c0342a2dSJacques Pienaar   assert(arg < static_cast<int64_t>(yieldOp.getResults().size()) &&
364fa6c273SMatthias Springer          "arg is out of bounds");
37c0342a2dSJacques Pienaar   Value value = yieldOp.getResults()[arg];
384fa6c273SMatthias Springer   while (value) {
394fa6c273SMatthias Springer     if (value == forOp.getRegionIterArgs()[arg])
404fa6c273SMatthias Springer       return true;
414fa6c273SMatthias Springer     OpResult opResult = value.dyn_cast<OpResult>();
424fa6c273SMatthias Springer     if (!opResult)
434fa6c273SMatthias Springer       return false;
444fa6c273SMatthias Springer 
454fa6c273SMatthias Springer     using tensor::InsertSliceOp;
464fa6c273SMatthias Springer     value =
474fa6c273SMatthias Springer         llvm::TypeSwitch<Operation *, Value>(opResult.getOwner())
484fa6c273SMatthias Springer             .template Case<InsertSliceOp>(
4904235d07SJacques Pienaar                 [&](InsertSliceOp op) { return op.getDest(); })
504fa6c273SMatthias Springer             .template Case<ForOp>([&](ForOp forOp) {
514fa6c273SMatthias Springer               return isShapePreserving(forOp, opResult.getResultNumber())
524fa6c273SMatthias Springer                          ? forOp.getIterOperands()[opResult.getResultNumber()]
534fa6c273SMatthias Springer                          : Value();
544fa6c273SMatthias Springer             })
554fa6c273SMatthias Springer             .Default([&](auto op) { return Value(); });
564fa6c273SMatthias Springer   }
574fa6c273SMatthias Springer   return false;
584fa6c273SMatthias Springer }
594fa6c273SMatthias Springer 
60c7d569b8SMatthias Springer namespace {
61c7d569b8SMatthias Springer /// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
62c7d569b8SMatthias Springer ///
63c7d569b8SMatthias Springer /// ```
64c7d569b8SMatthias Springer /// %0 = ... : tensor<?x?xf32>
65c7d569b8SMatthias Springer /// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
66c7d569b8SMatthias Springer ///   %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
67c7d569b8SMatthias Springer ///   ...
68c7d569b8SMatthias Springer /// }
69c7d569b8SMatthias Springer /// ```
70c7d569b8SMatthias Springer ///
71c7d569b8SMatthias Springer /// is folded to:
72c7d569b8SMatthias Springer ///
73c7d569b8SMatthias Springer /// ```
74c7d569b8SMatthias Springer /// %0 = ... : tensor<?x?xf32>
75c7d569b8SMatthias Springer /// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
76c7d569b8SMatthias Springer ///   %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
77c7d569b8SMatthias Springer ///   ...
78c7d569b8SMatthias Springer /// }
79c7d569b8SMatthias Springer /// ```
80c7d569b8SMatthias Springer ///
81c7d569b8SMatthias Springer /// Note: Dim ops are folded only if it can be proven that the runtime type of
82c7d569b8SMatthias Springer /// the iter arg does not change with loop iterations.
83c7d569b8SMatthias Springer template <typename OpTy>
84c7d569b8SMatthias Springer struct DimOfIterArgFolder : public OpRewritePattern<OpTy> {
85c7d569b8SMatthias Springer   using OpRewritePattern<OpTy>::OpRewritePattern;
86c7d569b8SMatthias Springer 
87d18ffd61SMatthias Springer   LogicalResult matchAndRewrite(OpTy dimOp,
88d18ffd61SMatthias Springer                                 PatternRewriter &rewriter) const override {
8904235d07SJacques Pienaar     auto blockArg = dimOp.getSource().template dyn_cast<BlockArgument>();
90d18ffd61SMatthias Springer     if (!blockArg)
91d18ffd61SMatthias Springer       return failure();
92d18ffd61SMatthias Springer     auto forOp = dyn_cast<ForOp>(blockArg.getParentBlock()->getParentOp());
93d18ffd61SMatthias Springer     if (!forOp)
94d18ffd61SMatthias Springer       return failure();
954fa6c273SMatthias Springer     if (!isShapePreserving(forOp, blockArg.getArgNumber() - 1))
964fa6c273SMatthias Springer       return failure();
97d18ffd61SMatthias Springer 
98d18ffd61SMatthias Springer     Value initArg = forOp.getOpOperandForRegionIterArg(blockArg).get();
99d18ffd61SMatthias Springer     rewriter.updateRootInPlace(
10004235d07SJacques Pienaar         dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
101d18ffd61SMatthias Springer 
102d18ffd61SMatthias Springer     return success();
103d18ffd61SMatthias Springer   };
104d18ffd61SMatthias Springer };
105d18ffd61SMatthias Springer 
106c7d569b8SMatthias Springer /// Fold dim ops of loop results to dim ops of their respective init args. E.g.:
107c7d569b8SMatthias Springer ///
108c7d569b8SMatthias Springer /// ```
109c7d569b8SMatthias Springer /// %0 = ... : tensor<?x?xf32>
110c7d569b8SMatthias Springer /// %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
111c7d569b8SMatthias Springer ///   ...
112c7d569b8SMatthias Springer /// }
113c7d569b8SMatthias Springer /// %1 = tensor.dim %r, %c0 : tensor<?x?xf32>
114c7d569b8SMatthias Springer /// ```
115c7d569b8SMatthias Springer ///
116c7d569b8SMatthias Springer /// is folded to:
117c7d569b8SMatthias Springer ///
118c7d569b8SMatthias Springer /// ```
119c7d569b8SMatthias Springer /// %0 = ... : tensor<?x?xf32>
120c7d569b8SMatthias Springer /// %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
121c7d569b8SMatthias Springer ///   ...
122c7d569b8SMatthias Springer /// }
123c7d569b8SMatthias Springer /// %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
124c7d569b8SMatthias Springer /// ```
125c7d569b8SMatthias Springer ///
126c7d569b8SMatthias Springer /// Note: Dim ops are folded only if it can be proven that the runtime type of
127c7d569b8SMatthias Springer /// the iter arg does not change with loop iterations.
128c7d569b8SMatthias Springer template <typename OpTy>
129c7d569b8SMatthias Springer struct DimOfLoopResultFolder : public OpRewritePattern<OpTy> {
130c7d569b8SMatthias Springer   using OpRewritePattern<OpTy>::OpRewritePattern;
131c7d569b8SMatthias Springer 
132c7d569b8SMatthias Springer   LogicalResult matchAndRewrite(OpTy dimOp,
133c7d569b8SMatthias Springer                                 PatternRewriter &rewriter) const override {
13404235d07SJacques Pienaar     auto forOp = dimOp.getSource().template getDefiningOp<scf::ForOp>();
135c7d569b8SMatthias Springer     if (!forOp)
136c7d569b8SMatthias Springer       return failure();
13704235d07SJacques Pienaar     auto opResult = dimOp.getSource().template cast<OpResult>();
138c7d569b8SMatthias Springer     unsigned resultNumber = opResult.getResultNumber();
139c7d569b8SMatthias Springer     if (!isShapePreserving(forOp, resultNumber))
140c7d569b8SMatthias Springer       return failure();
141c7d569b8SMatthias Springer     rewriter.updateRootInPlace(dimOp, [&]() {
14204235d07SJacques Pienaar       dimOp.getSourceMutable().assign(forOp.getIterOperands()[resultNumber]);
143c7d569b8SMatthias Springer     });
144c7d569b8SMatthias Springer     return success();
145c7d569b8SMatthias Springer   }
146c7d569b8SMatthias Springer };
147c7d569b8SMatthias Springer 
148d18ffd61SMatthias Springer /// Canonicalize AffineMinOp/AffineMaxOp operations in the context of scf.for
149d18ffd61SMatthias Springer /// and scf.parallel loops with a known range.
150d18ffd61SMatthias Springer template <typename OpTy, bool IsMin>
151d18ffd61SMatthias Springer struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern<OpTy> {
152d18ffd61SMatthias Springer   using OpRewritePattern<OpTy>::OpRewritePattern;
153d18ffd61SMatthias Springer 
154d18ffd61SMatthias Springer   LogicalResult matchAndRewrite(OpTy op,
155d18ffd61SMatthias Springer                                 PatternRewriter &rewriter) const override {
156a489aa74SNicolas Vasilache     auto loopMatcher = [](Value iv, OpFoldResult &lb, OpFoldResult &ub,
157a489aa74SNicolas Vasilache                           OpFoldResult &step) {
158d18ffd61SMatthias Springer       if (scf::ForOp forOp = scf::getForInductionVarOwner(iv)) {
159c0342a2dSJacques Pienaar         lb = forOp.getLowerBound();
160c0342a2dSJacques Pienaar         ub = forOp.getUpperBound();
161c0342a2dSJacques Pienaar         step = forOp.getStep();
162d18ffd61SMatthias Springer         return success();
163d18ffd61SMatthias Springer       }
164d18ffd61SMatthias Springer       if (scf::ParallelOp parOp = scf::getParallelForInductionVarOwner(iv)) {
165d18ffd61SMatthias Springer         for (unsigned idx = 0; idx < parOp.getNumLoops(); ++idx) {
166d18ffd61SMatthias Springer           if (parOp.getInductionVars()[idx] == iv) {
167c0342a2dSJacques Pienaar             lb = parOp.getLowerBound()[idx];
168c0342a2dSJacques Pienaar             ub = parOp.getUpperBound()[idx];
169c0342a2dSJacques Pienaar             step = parOp.getStep()[idx];
170d18ffd61SMatthias Springer             return success();
171d18ffd61SMatthias Springer           }
172d18ffd61SMatthias Springer         }
173d18ffd61SMatthias Springer         return failure();
174d18ffd61SMatthias Springer       }
175a489aa74SNicolas Vasilache       if (scf::ForeachThreadOp foreachThreadOp =
176a489aa74SNicolas Vasilache               scf::getForeachThreadOpThreadIndexOwner(iv)) {
177a489aa74SNicolas Vasilache         for (int64_t idx = 0; idx < foreachThreadOp.getRank(); ++idx) {
178a489aa74SNicolas Vasilache           if (foreachThreadOp.getThreadIndices()[idx] == iv) {
179a489aa74SNicolas Vasilache             lb = OpBuilder(iv.getContext()).getIndexAttr(0);
180a489aa74SNicolas Vasilache             ub = foreachThreadOp.getNumThreads()[idx];
181a489aa74SNicolas Vasilache             step = OpBuilder(iv.getContext()).getIndexAttr(1);
182a489aa74SNicolas Vasilache             return success();
183a489aa74SNicolas Vasilache           }
184a489aa74SNicolas Vasilache         }
185a489aa74SNicolas Vasilache         return failure();
186a489aa74SNicolas Vasilache       }
187d18ffd61SMatthias Springer       return failure();
188d18ffd61SMatthias Springer     };
189d18ffd61SMatthias Springer 
190d18ffd61SMatthias Springer     return scf::canonicalizeMinMaxOpInLoop(rewriter, op, op.getAffineMap(),
191d18ffd61SMatthias Springer                                            op.operands(), IsMin, loopMatcher);
192d18ffd61SMatthias Springer   }
193d18ffd61SMatthias Springer };
194d18ffd61SMatthias Springer 
195*039b969bSMichele Scuttari struct SCFForLoopCanonicalization
196*039b969bSMichele Scuttari     : public SCFForLoopCanonicalizationBase<SCFForLoopCanonicalization> {
19741574554SRiver Riddle   void runOnOperation() override {
19854998986SStella Laurenzo     auto *parentOp = getOperation();
19954998986SStella Laurenzo     MLIRContext *ctx = parentOp->getContext();
200d18ffd61SMatthias Springer     RewritePatternSet patterns(ctx);
201d18ffd61SMatthias Springer     scf::populateSCFForLoopCanonicalizationPatterns(patterns);
20254998986SStella Laurenzo     if (failed(applyPatternsAndFoldGreedily(parentOp, std::move(patterns))))
203d18ffd61SMatthias Springer       signalPassFailure();
204d18ffd61SMatthias Springer   }
205d18ffd61SMatthias Springer };
206d18ffd61SMatthias Springer } // namespace
207d18ffd61SMatthias Springer 
208d18ffd61SMatthias Springer void mlir::scf::populateSCFForLoopCanonicalizationPatterns(
209d18ffd61SMatthias Springer     RewritePatternSet &patterns) {
210d18ffd61SMatthias Springer   MLIRContext *ctx = patterns.getContext();
211d18ffd61SMatthias Springer   patterns
212b4e0507cSTres Popp       .add<AffineOpSCFCanonicalizationPattern<AffineMinOp, /*IsMin=*/true>,
213d18ffd61SMatthias Springer            AffineOpSCFCanonicalizationPattern<AffineMaxOp, /*IsMin=*/false>,
214b4e0507cSTres Popp            DimOfIterArgFolder<tensor::DimOp>, DimOfIterArgFolder<memref::DimOp>,
215c7d569b8SMatthias Springer            DimOfLoopResultFolder<tensor::DimOp>,
216c7d569b8SMatthias Springer            DimOfLoopResultFolder<memref::DimOp>>(ctx);
217d18ffd61SMatthias Springer }
218*039b969bSMichele Scuttari 
219*039b969bSMichele Scuttari std::unique_ptr<Pass> mlir::createSCFForLoopCanonicalizationPass() {
220*039b969bSMichele Scuttari   return std::make_unique<SCFForLoopCanonicalization>();
221*039b969bSMichele Scuttari }
222