xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp (revision d18ffd61d4f2500dc4ae267f4705102abb2cf02f)
1*d18ffd61SMatthias Springer //===- LoopCanonicalization.cpp - Cross-dialect canonicalization patterns -===//
2*d18ffd61SMatthias Springer //
3*d18ffd61SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*d18ffd61SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
5*d18ffd61SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*d18ffd61SMatthias Springer //
7*d18ffd61SMatthias Springer //===----------------------------------------------------------------------===//
8*d18ffd61SMatthias Springer //
// This file contains cross-dialect canonicalization patterns that cannot be
// registered as regular canonicalization patterns of their ops, because doing
// so would introduce undesired dependencies between dialects.
11*d18ffd61SMatthias Springer //
12*d18ffd61SMatthias Springer //===----------------------------------------------------------------------===//
13*d18ffd61SMatthias Springer 
14*d18ffd61SMatthias Springer #include "PassDetail.h"
15*d18ffd61SMatthias Springer #include "mlir/Dialect/Affine/IR/AffineOps.h"
16*d18ffd61SMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
17*d18ffd61SMatthias Springer #include "mlir/Dialect/SCF/Passes.h"
18*d18ffd61SMatthias Springer #include "mlir/Dialect/SCF/SCF.h"
19*d18ffd61SMatthias Springer #include "mlir/Dialect/SCF/Transforms.h"
20*d18ffd61SMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h"
21*d18ffd61SMatthias Springer #include "mlir/IR/PatternMatch.h"
22*d18ffd61SMatthias Springer #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23*d18ffd61SMatthias Springer 
24*d18ffd61SMatthias Springer using namespace mlir;
25*d18ffd61SMatthias Springer using namespace mlir::scf;
26*d18ffd61SMatthias Springer 
27*d18ffd61SMatthias Springer namespace {
28*d18ffd61SMatthias Springer /// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
29*d18ffd61SMatthias Springer ///
30*d18ffd61SMatthias Springer /// ```
31*d18ffd61SMatthias Springer /// %0 = ... : tensor<?x?xf32>
32*d18ffd61SMatthias Springer /// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
33*d18ffd61SMatthias Springer ///   %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
34*d18ffd61SMatthias Springer ///   ...
35*d18ffd61SMatthias Springer /// }
36*d18ffd61SMatthias Springer /// ```
37*d18ffd61SMatthias Springer ///
38*d18ffd61SMatthias Springer /// is folded to:
39*d18ffd61SMatthias Springer ///
40*d18ffd61SMatthias Springer /// ```
41*d18ffd61SMatthias Springer /// %0 = ... : tensor<?x?xf32>
42*d18ffd61SMatthias Springer /// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
43*d18ffd61SMatthias Springer ///   %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
44*d18ffd61SMatthias Springer ///   ...
45*d18ffd61SMatthias Springer /// }
46*d18ffd61SMatthias Springer /// ```
47*d18ffd61SMatthias Springer template <typename OpTy>
48*d18ffd61SMatthias Springer struct DimOfIterArgFolder : public OpRewritePattern<OpTy> {
49*d18ffd61SMatthias Springer   using OpRewritePattern<OpTy>::OpRewritePattern;
50*d18ffd61SMatthias Springer 
51*d18ffd61SMatthias Springer   LogicalResult matchAndRewrite(OpTy dimOp,
52*d18ffd61SMatthias Springer                                 PatternRewriter &rewriter) const override {
53*d18ffd61SMatthias Springer     auto blockArg = dimOp.source().template dyn_cast<BlockArgument>();
54*d18ffd61SMatthias Springer     if (!blockArg)
55*d18ffd61SMatthias Springer       return failure();
56*d18ffd61SMatthias Springer     auto forOp = dyn_cast<ForOp>(blockArg.getParentBlock()->getParentOp());
57*d18ffd61SMatthias Springer     if (!forOp)
58*d18ffd61SMatthias Springer       return failure();
59*d18ffd61SMatthias Springer 
60*d18ffd61SMatthias Springer     Value initArg = forOp.getOpOperandForRegionIterArg(blockArg).get();
61*d18ffd61SMatthias Springer     rewriter.updateRootInPlace(
62*d18ffd61SMatthias Springer         dimOp, [&]() { dimOp.sourceMutable().assign(initArg); });
63*d18ffd61SMatthias Springer 
64*d18ffd61SMatthias Springer     return success();
65*d18ffd61SMatthias Springer   };
66*d18ffd61SMatthias Springer };
67*d18ffd61SMatthias Springer 
68*d18ffd61SMatthias Springer /// Canonicalize AffineMinOp/AffineMaxOp operations in the context of scf.for
69*d18ffd61SMatthias Springer /// and scf.parallel loops with a known range.
70*d18ffd61SMatthias Springer template <typename OpTy, bool IsMin>
71*d18ffd61SMatthias Springer struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern<OpTy> {
72*d18ffd61SMatthias Springer   using OpRewritePattern<OpTy>::OpRewritePattern;
73*d18ffd61SMatthias Springer 
74*d18ffd61SMatthias Springer   LogicalResult matchAndRewrite(OpTy op,
75*d18ffd61SMatthias Springer                                 PatternRewriter &rewriter) const override {
76*d18ffd61SMatthias Springer     auto loopMatcher = [](Value iv, Value &lb, Value &ub, Value &step) {
77*d18ffd61SMatthias Springer       if (scf::ForOp forOp = scf::getForInductionVarOwner(iv)) {
78*d18ffd61SMatthias Springer         lb = forOp.lowerBound();
79*d18ffd61SMatthias Springer         ub = forOp.upperBound();
80*d18ffd61SMatthias Springer         step = forOp.step();
81*d18ffd61SMatthias Springer         return success();
82*d18ffd61SMatthias Springer       }
83*d18ffd61SMatthias Springer       if (scf::ParallelOp parOp = scf::getParallelForInductionVarOwner(iv)) {
84*d18ffd61SMatthias Springer         for (unsigned idx = 0; idx < parOp.getNumLoops(); ++idx) {
85*d18ffd61SMatthias Springer           if (parOp.getInductionVars()[idx] == iv) {
86*d18ffd61SMatthias Springer             lb = parOp.lowerBound()[idx];
87*d18ffd61SMatthias Springer             ub = parOp.upperBound()[idx];
88*d18ffd61SMatthias Springer             step = parOp.step()[idx];
89*d18ffd61SMatthias Springer             return success();
90*d18ffd61SMatthias Springer           }
91*d18ffd61SMatthias Springer         }
92*d18ffd61SMatthias Springer         return failure();
93*d18ffd61SMatthias Springer       }
94*d18ffd61SMatthias Springer       return failure();
95*d18ffd61SMatthias Springer     };
96*d18ffd61SMatthias Springer 
97*d18ffd61SMatthias Springer     return scf::canonicalizeMinMaxOpInLoop(rewriter, op, op.getAffineMap(),
98*d18ffd61SMatthias Springer                                            op.operands(), IsMin, loopMatcher);
99*d18ffd61SMatthias Springer   }
100*d18ffd61SMatthias Springer };
101*d18ffd61SMatthias Springer 
102*d18ffd61SMatthias Springer struct SCFForLoopCanonicalization
103*d18ffd61SMatthias Springer     : public SCFForLoopCanonicalizationBase<SCFForLoopCanonicalization> {
104*d18ffd61SMatthias Springer   void runOnFunction() override {
105*d18ffd61SMatthias Springer     FuncOp funcOp = getFunction();
106*d18ffd61SMatthias Springer     MLIRContext *ctx = funcOp.getContext();
107*d18ffd61SMatthias Springer     RewritePatternSet patterns(ctx);
108*d18ffd61SMatthias Springer     scf::populateSCFForLoopCanonicalizationPatterns(patterns);
109*d18ffd61SMatthias Springer     if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
110*d18ffd61SMatthias Springer       signalPassFailure();
111*d18ffd61SMatthias Springer   }
112*d18ffd61SMatthias Springer };
113*d18ffd61SMatthias Springer } // namespace
114*d18ffd61SMatthias Springer 
115*d18ffd61SMatthias Springer void mlir::scf::populateSCFForLoopCanonicalizationPatterns(
116*d18ffd61SMatthias Springer     RewritePatternSet &patterns) {
117*d18ffd61SMatthias Springer   MLIRContext *ctx = patterns.getContext();
118*d18ffd61SMatthias Springer   patterns
119*d18ffd61SMatthias Springer       .insert<AffineOpSCFCanonicalizationPattern<AffineMinOp, /*IsMin=*/true>,
120*d18ffd61SMatthias Springer               AffineOpSCFCanonicalizationPattern<AffineMaxOp, /*IsMin=*/false>,
121*d18ffd61SMatthias Springer               DimOfIterArgFolder<tensor::DimOp>,
122*d18ffd61SMatthias Springer               DimOfIterArgFolder<memref::DimOp>>(ctx);
123*d18ffd61SMatthias Springer }
124*d18ffd61SMatthias Springer 
125*d18ffd61SMatthias Springer std::unique_ptr<Pass> mlir::createSCFForLoopCanonicalizationPass() {
126*d18ffd61SMatthias Springer   return std::make_unique<SCFForLoopCanonicalization>();
127*d18ffd61SMatthias Springer }
128