xref: /llvm-project/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1 //===- ResolveShapedTypeResultDims.cpp - Resolve dim ops of result values -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This pass resolves `memref.dim` and `tensor.dim` operations of result values
// in terms of the shapes of their operands, using the
// `ReifyRankedShapedTypeOpInterface` and the `InferShapedTypeOpInterface`.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
15 
16 #include "mlir/Dialect/Affine/IR/AffineOps.h"
17 #include "mlir/Dialect/Arith/IR/Arith.h"
18 #include "mlir/Dialect/Arith/Utils/Utils.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
21 #include "mlir/Dialect/SCF/IR/SCF.h"
22 #include "mlir/Dialect/Tensor/IR/Tensor.h"
23 #include "mlir/Interfaces/InferTypeOpInterface.h"
24 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25 
26 namespace mlir {
27 namespace memref {
28 #define GEN_PASS_DEF_RESOLVERANKEDSHAPETYPERESULTDIMS
29 #define GEN_PASS_DEF_RESOLVESHAPEDTYPERESULTDIMS
30 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
31 } // namespace memref
32 } // namespace mlir
33 
34 using namespace mlir;
35 
36 namespace {
37 /// Fold dim of an operation that implements the InferShapedTypeOpInterface
38 template <typename OpTy>
39 struct DimOfShapedTypeOpInterface : public OpRewritePattern<OpTy> {
40   using OpRewritePattern<OpTy>::OpRewritePattern;
41 
42   LogicalResult matchAndRewrite(OpTy dimOp,
43                                 PatternRewriter &rewriter) const override {
44     OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource());
45     if (!dimValue)
46       return failure();
47     auto shapedTypeOp =
48         dyn_cast<InferShapedTypeOpInterface>(dimValue.getOwner());
49     if (!shapedTypeOp)
50       return failure();
51 
52     std::optional<int64_t> dimIndex = dimOp.getConstantIndex();
53     if (!dimIndex)
54       return failure();
55 
56     SmallVector<Value> reifiedResultShapes;
57     if (failed(shapedTypeOp.reifyReturnTypeShapes(
58             rewriter, shapedTypeOp->getOperands(), reifiedResultShapes)))
59       return failure();
60 
61     if (reifiedResultShapes.size() != shapedTypeOp->getNumResults())
62       return failure();
63 
64     Value resultShape = reifiedResultShapes[dimValue.getResultNumber()];
65     auto resultShapeType = dyn_cast<RankedTensorType>(resultShape.getType());
66     if (!resultShapeType || !isa<IndexType>(resultShapeType.getElementType()))
67       return failure();
68 
69     Location loc = dimOp->getLoc();
70     rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
71         dimOp, resultShape,
72         rewriter.create<arith::ConstantIndexOp>(loc, *dimIndex).getResult());
73     return success();
74   }
75 };
76 
77 /// Fold dim of an operation that implements the InferShapedTypeOpInterface
78 template <typename OpTy>
79 struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
80   using OpRewritePattern<OpTy>::OpRewritePattern;
81 
82   void initialize() { OpRewritePattern<OpTy>::setHasBoundedRewriteRecursion(); }
83 
84   LogicalResult matchAndRewrite(OpTy dimOp,
85                                 PatternRewriter &rewriter) const override {
86     OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource());
87     if (!dimValue)
88       return failure();
89     std::optional<int64_t> dimIndex = dimOp.getConstantIndex();
90     if (!dimIndex)
91       return failure();
92 
93     ReifiedRankedShapedTypeDims reifiedResultShapes;
94     if (failed(reifyResultShapes(rewriter, dimValue.getOwner(),
95                                  reifiedResultShapes)))
96       return failure();
97     unsigned resultNumber = dimValue.getResultNumber();
98     // Do not apply pattern if the IR is invalid (dim out of bounds).
99     if ((size_t)(*dimIndex) >= reifiedResultShapes[resultNumber].size())
100       return rewriter.notifyMatchFailure(dimOp, "dimension is out of bounds");
101     Value replacement = getValueOrCreateConstantIndexOp(
102         rewriter, dimOp.getLoc(), reifiedResultShapes[resultNumber][*dimIndex]);
103     rewriter.replaceOp(dimOp, replacement);
104     return success();
105   }
106 };
107 
108 /// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
109 ///
110 /// ```
111 /// %0 = ... : tensor<?x?xf32>
112 /// scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
113 ///   %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
114 ///   ...
115 /// }
116 /// ```
117 ///
118 /// is folded to:
119 ///
120 /// ```
121 /// %0 = ... : tensor<?x?xf32>
122 /// scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
123 ///   %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
124 ///   ...
125 /// }
126 /// ```
127 struct IterArgsToInitArgs : public OpRewritePattern<tensor::DimOp> {
128   using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
129 
130   LogicalResult matchAndRewrite(tensor::DimOp dimOp,
131                                 PatternRewriter &rewriter) const final {
132     auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
133     if (!blockArg)
134       return failure();
135     // TODO: Enable this for loopLikeInterface. Restricting for scf.for
136     // because the init args shape might change in the loop body.
137     // For e.g.:
138     // ```
139     //  %0 = tensor.empty(%c1) : tensor<?xf32>
140     //  %r = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0 = %0) ->
141     //  tensor<?xf32> {
142     //    %1 = tensor.dim %arg0, %c0 : tensor<?xf32>
143     //    %2 = arith.addi %c1, %1 : index
144     //    %3 = tensor.empty(%2) : tensor<?xf32>
145     //    scf.yield %3 : tensor<?xf32>
146     //  }
147     //
148     // ```
149     auto forAllOp =
150         dyn_cast<scf::ForallOp>(blockArg.getParentBlock()->getParentOp());
151     if (!forAllOp)
152       return failure();
153     Value initArg = forAllOp.getTiedLoopInit(blockArg)->get();
154     rewriter.modifyOpInPlace(
155         dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
156     return success();
157   }
158 };
159 } // namespace
160 
161 //===----------------------------------------------------------------------===//
162 // Pass registration
163 //===----------------------------------------------------------------------===//
164 
165 namespace {
166 struct ResolveRankedShapeTypeResultDimsPass final
167     : public memref::impl::ResolveRankedShapeTypeResultDimsBase<
168           ResolveRankedShapeTypeResultDimsPass> {
169   void runOnOperation() override;
170 };
171 
172 struct ResolveShapedTypeResultDimsPass final
173     : public memref::impl::ResolveShapedTypeResultDimsBase<
174           ResolveShapedTypeResultDimsPass> {
175   void runOnOperation() override;
176 };
177 
178 } // namespace
179 
180 void memref::populateResolveRankedShapedTypeResultDimsPatterns(
181     RewritePatternSet &patterns) {
182   patterns.add<DimOfReifyRankedShapedTypeOpInterface<memref::DimOp>,
183                DimOfReifyRankedShapedTypeOpInterface<tensor::DimOp>,
184                IterArgsToInitArgs>(patterns.getContext());
185 }
186 
187 void memref::populateResolveShapedTypeResultDimsPatterns(
188     RewritePatternSet &patterns) {
189   // TODO: Move tensor::DimOp pattern to the Tensor dialect.
190   patterns.add<DimOfShapedTypeOpInterface<memref::DimOp>,
191                DimOfShapedTypeOpInterface<tensor::DimOp>>(
192       patterns.getContext());
193 }
194 
195 void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
196   RewritePatternSet patterns(&getContext());
197   memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
198   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
199     return signalPassFailure();
200 }
201 
202 void ResolveShapedTypeResultDimsPass::runOnOperation() {
203   RewritePatternSet patterns(&getContext());
204   memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
205   memref::populateResolveShapedTypeResultDimsPatterns(patterns);
206   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
207     return signalPassFailure();
208 }
209 
210 std::unique_ptr<Pass> memref::createResolveShapedTypeResultDimsPass() {
211   return std::make_unique<ResolveShapedTypeResultDimsPass>();
212 }
213 
214 std::unique_ptr<Pass> memref::createResolveRankedShapeTypeResultDimsPass() {
215   return std::make_unique<ResolveRankedShapeTypeResultDimsPass>();
216 }
217