xref: /llvm-project/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1 //===- LoopCanonicalization.cpp - Cross-dialect canonicalization patterns -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains cross-dialect canonicalization patterns that cannot be
10 // actual canonicalization patterns due to undesired additional dependencies.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/SCF/Transforms/Passes.h"
15 
16 #include "mlir/Dialect/Affine/IR/AffineOps.h"
17 #include "mlir/Dialect/MemRef/IR/MemRef.h"
18 #include "mlir/Dialect/SCF/IR/SCF.h"
19 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
20 #include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
21 #include "mlir/Dialect/Tensor/IR/Tensor.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 
26 namespace mlir {
27 #define GEN_PASS_DEF_SCFFORLOOPCANONICALIZATION
28 #include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
29 } // namespace mlir
30 
31 using namespace mlir;
32 using namespace mlir::scf;
33 
34 /// A simple, conservative analysis to determine if the loop is shape
35 /// conserving. I.e., the type of the arg-th yielded value is the same as the
36 /// type of the corresponding basic block argument of the loop.
37 /// Note: This function handles only simple cases. Expand as needed.
38 static bool isShapePreserving(ForOp forOp, int64_t arg) {
39   assert(arg < static_cast<int64_t>(forOp.getNumResults()) &&
40          "arg is out of bounds");
41   Value value = forOp.getYieldedValues()[arg];
42   while (value) {
43     if (value == forOp.getRegionIterArgs()[arg])
44       return true;
45     OpResult opResult = dyn_cast<OpResult>(value);
46     if (!opResult)
47       return false;
48 
49     using tensor::InsertSliceOp;
50     value = llvm::TypeSwitch<Operation *, Value>(opResult.getOwner())
51                 .template Case<InsertSliceOp>(
52                     [&](InsertSliceOp op) { return op.getDest(); })
53                 .template Case<ForOp>([&](ForOp forOp) {
54                   return isShapePreserving(forOp, opResult.getResultNumber())
55                              ? forOp.getInitArgs()[opResult.getResultNumber()]
56                              : Value();
57                 })
58                 .Default([&](auto op) { return Value(); });
59   }
60   return false;
61 }
62 
63 namespace {
64 /// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
65 ///
66 /// ```
67 /// %0 = ... : tensor<?x?xf32>
68 /// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
69 ///   %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
70 ///   ...
71 /// }
72 /// ```
73 ///
74 /// is folded to:
75 ///
76 /// ```
77 /// %0 = ... : tensor<?x?xf32>
78 /// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
79 ///   %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
80 ///   ...
81 /// }
82 /// ```
83 ///
84 /// Note: Dim ops are folded only if it can be proven that the runtime type of
85 /// the iter arg does not change with loop iterations.
86 template <typename OpTy>
87 struct DimOfIterArgFolder : public OpRewritePattern<OpTy> {
88   using OpRewritePattern<OpTy>::OpRewritePattern;
89 
90   LogicalResult matchAndRewrite(OpTy dimOp,
91                                 PatternRewriter &rewriter) const override {
92     auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
93     if (!blockArg)
94       return failure();
95     auto forOp = dyn_cast<ForOp>(blockArg.getParentBlock()->getParentOp());
96     if (!forOp)
97       return failure();
98     if (!isShapePreserving(forOp, blockArg.getArgNumber() - 1))
99       return failure();
100 
101     Value initArg = forOp.getTiedLoopInit(blockArg)->get();
102     rewriter.modifyOpInPlace(
103         dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
104 
105     return success();
106   };
107 };
108 
109 /// Fold dim ops of loop results to dim ops of their respective init args. E.g.:
110 ///
111 /// ```
112 /// %0 = ... : tensor<?x?xf32>
113 /// %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
114 ///   ...
115 /// }
116 /// %1 = tensor.dim %r, %c0 : tensor<?x?xf32>
117 /// ```
118 ///
119 /// is folded to:
120 ///
121 /// ```
122 /// %0 = ... : tensor<?x?xf32>
123 /// %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
124 ///   ...
125 /// }
126 /// %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
127 /// ```
128 ///
129 /// Note: Dim ops are folded only if it can be proven that the runtime type of
130 /// the iter arg does not change with loop iterations.
131 template <typename OpTy>
132 struct DimOfLoopResultFolder : public OpRewritePattern<OpTy> {
133   using OpRewritePattern<OpTy>::OpRewritePattern;
134 
135   LogicalResult matchAndRewrite(OpTy dimOp,
136                                 PatternRewriter &rewriter) const override {
137     auto forOp = dimOp.getSource().template getDefiningOp<scf::ForOp>();
138     if (!forOp)
139       return failure();
140     auto opResult = cast<OpResult>(dimOp.getSource());
141     unsigned resultNumber = opResult.getResultNumber();
142     if (!isShapePreserving(forOp, resultNumber))
143       return failure();
144     rewriter.modifyOpInPlace(dimOp, [&]() {
145       dimOp.getSourceMutable().assign(forOp.getInitArgs()[resultNumber]);
146     });
147     return success();
148   }
149 };
150 
151 /// Canonicalize AffineMinOp/AffineMaxOp operations in the context of scf.for
152 /// and scf.parallel loops with a known range.
153 template <typename OpTy>
154 struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern<OpTy> {
155   using OpRewritePattern<OpTy>::OpRewritePattern;
156 
157   LogicalResult matchAndRewrite(OpTy op,
158                                 PatternRewriter &rewriter) const override {
159     return scf::canonicalizeMinMaxOpInLoop(rewriter, op, scf::matchForLikeLoop);
160   }
161 };
162 
163 struct SCFForLoopCanonicalization
164     : public impl::SCFForLoopCanonicalizationBase<SCFForLoopCanonicalization> {
165   void runOnOperation() override {
166     auto *parentOp = getOperation();
167     MLIRContext *ctx = parentOp->getContext();
168     RewritePatternSet patterns(ctx);
169     scf::populateSCFForLoopCanonicalizationPatterns(patterns);
170     if (failed(applyPatternsGreedily(parentOp, std::move(patterns))))
171       signalPassFailure();
172   }
173 };
174 } // namespace
175 
176 void mlir::scf::populateSCFForLoopCanonicalizationPatterns(
177     RewritePatternSet &patterns) {
178   MLIRContext *ctx = patterns.getContext();
179   patterns
180       .add<AffineOpSCFCanonicalizationPattern<affine::AffineMinOp>,
181            AffineOpSCFCanonicalizationPattern<affine::AffineMaxOp>,
182            DimOfIterArgFolder<tensor::DimOp>, DimOfIterArgFolder<memref::DimOp>,
183            DimOfLoopResultFolder<tensor::DimOp>,
184            DimOfLoopResultFolder<memref::DimOp>>(ctx);
185 }
186 
187 std::unique_ptr<Pass> mlir::createSCFForLoopCanonicalizationPass() {
188   return std::make_unique<SCFForLoopCanonicalization>();
189 }
190