xref: /llvm-project/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp (revision a95ad2da36b6a996b05c79df6b385cd98bac286d)
//===- ReshapePatterns.cpp - Patterns related to reshapes ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"

using namespace mlir;
using namespace mlir::tensor;

namespace {
/// Fold expand_shape(extract_slice) op pairs that cancel each other out.
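///
/// An illustrative sketch (values and shapes invented for exposition): the
/// expand_shape restores exactly the unit dimensions that the rank-reducing
/// extract_slice dropped, so
///
///   %0 = tensor.extract_slice %t[0, 0, 0] [1, 5, 1] [1, 1, 1]
///       : tensor<2x5x2xf32> to tensor<5xf32>
///   %1 = tensor.expand_shape %0 [[0, 1, 2]] output_shape [1, 5, 1]
///       : tensor<5xf32> into tensor<1x5x1xf32>
///
/// folds to a single non-rank-reducing extract_slice:
///
///   %1 = tensor.extract_slice %t[0, 0, 0] [1, 5, 1] [1, 1, 1]
///       : tensor<2x5x2xf32> to tensor<1x5x1xf32>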
struct FoldExpandOfRankReducingExtract
    : public OpRewritePattern<ExpandShapeOp> {
  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType resultType = expandShapeOp.getResultType();
    auto extractSliceOp =
        expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
    if (!extractSliceOp)
      return failure();
    RankedTensorType srcType = extractSliceOp.getSourceType();

    // Only cases where the ExpandShapeOp can be folded away entirely are
    // supported. Moreover, only simple cases where the resulting ExtractSliceOp
    // has no rank-reduction anymore are supported at the moment.
    RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
        srcType, extractSliceOp.getStaticOffsets(),
        extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
    if (nonReducingExtractType != resultType)
      return failure();

    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
        expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
        mixedStrides);
    return success();
  }
};

/// Fold a collapse_shape op that only removes static dimensions of size `1`
/// into an extract_slice.
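///
/// A minimal sketch (values and shapes invented for exposition):
///
///   %0 = tensor.extract_slice %t[0, 0] [1, 5] [1, 1]
///       : tensor<4x5xf32> to tensor<1x5xf32>
///   %1 = tensor.collapse_shape %0 [[0, 1]]
///       : tensor<1x5xf32> into tensor<5xf32>
///
/// becomes a single rank-reducing extract_slice:
///
///   %1 = tensor.extract_slice %t[0, 0] [1, 5] [1, 1]
///       : tensor<4x5xf32> to tensor<5xf32>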
struct FoldUnPaddingCollapseIntoExtract
    : public OpRewritePattern<tensor::CollapseShapeOp> {
  using OpRewritePattern<tensor::CollapseShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseShapeOp,
                                PatternRewriter &rewriter) const override {
    auto extractSliceOp =
        collapseShapeOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>();
    // The collapse cannot be folded away when the extract_slice has multiple
    // users, and it is not necessarily beneficial to only convert the
    // collapse into another extract_slice.
    if (!extractSliceOp || !extractSliceOp->hasOneUse())
      return failure();

    // Only fold away a simple collapse where all removed dimensions have
    // static size `1`.
    SliceVerificationResult res = isRankReducedType(
        collapseShapeOp.getSrcType(), collapseShapeOp.getResultType());
    if (res != SliceVerificationResult::Success)
      return rewriter.notifyMatchFailure(collapseShapeOp,
                                         "expected unpadding collapse");

    Value unPaddedExtractSlice = rewriter.create<tensor::ExtractSliceOp>(
        extractSliceOp.getLoc(), collapseShapeOp.getResultType(),
        extractSliceOp.getSource(), extractSliceOp.getMixedOffsets(),
        extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
    rewriter.replaceOp(collapseShapeOp, unPaddedExtractSlice);
    return success();
  }
};

/// Fold insert_slice(collapse_shape) op pairs that cancel each other out.
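///
/// An illustrative sketch (values and shapes invented for exposition): the
/// insert_slice re-adds exactly the unit dimensions that the collapse_shape
/// removed, so
///
///   %0 = tensor.collapse_shape %src [[0, 1], [2]]
///       : tensor<1x5x2xf32> into tensor<5x2xf32>
///   %1 = tensor.insert_slice %0 into %dest[0, 0, 0] [1, 5, 2] [1, 1, 1]
///       : tensor<5x2xf32> into tensor<4x5x2xf32>
///
/// folds to a single non-rank-reducing insert_slice:
///
///   %1 = tensor.insert_slice %src into %dest[0, 0, 0] [1, 5, 2] [1, 1, 1]
///       : tensor<1x5x2xf32> into tensor<4x5x2xf32>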
template <typename OpTy>
struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    auto collapseShapeOp =
        insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>();
    if (!collapseShapeOp)
      return failure();
    RankedTensorType srcType = collapseShapeOp.getSrcType();

    // Only cases where the CollapseShapeOp can be folded away entirely are
    // supported. Moreover, only simple cases where the resulting InsertSliceOp
    // has no rank-reduction anymore are supported at the moment.
    RankedTensorType nonReducingInsertType =
        RankedTensorType::get(insertSliceOp.getStaticSizes(),
                              insertSliceOp.getDestType().getElementType());
    if (nonReducingInsertType != srcType)
      return failure();

    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    rewriter.replaceOpWithNewOp<OpTy>(insertSliceOp, collapseShapeOp.getSrc(),
                                      insertSliceOp.getDest(), mixedOffsets,
                                      mixedSizes, mixedStrides);
    return success();
  }
};

/// Fold an expand_shape op that only adds static dimensions of size `1`
/// into an insert_slice.
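///
/// A minimal sketch (values and shapes invented for exposition):
///
///   %0 = tensor.expand_shape %src [[0, 1]] output_shape [1, 5]
///       : tensor<5xf32> into tensor<1x5xf32>
///   %1 = tensor.insert_slice %0 into %dest[0, 0] [1, 5] [1, 1]
///       : tensor<1x5xf32> into tensor<4x5xf32>
///
/// becomes a single rank-reducing insert_slice:
///
///   %1 = tensor.insert_slice %src into %dest[0, 0] [1, 5] [1, 1]
///       : tensor<5xf32> into tensor<4x5xf32>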
template <typename OpTy>
struct FoldPaddingExpandIntoInsert : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    auto expandShapeOp = insertSliceOp.getSource()
                             .template getDefiningOp<tensor::ExpandShapeOp>();
    if (!expandShapeOp)
      return failure();

    // Only fold away a simple expansion where all added dimensions have
    // static size `1`.
    SliceVerificationResult res = isRankReducedType(
        expandShapeOp.getResultType(), expandShapeOp.getSrcType());
    if (res != SliceVerificationResult::Success)
      return rewriter.notifyMatchFailure(insertSliceOp,
                                         "expected rank-increasing expansion");

    rewriter.modifyOpInPlace(insertSliceOp, [&]() {
      insertSliceOp.getSourceMutable().assign(expandShapeOp.getSrc());
    });
    return success();
  }
};

/// Pattern to bubble up a tensor.expand_shape op through a producer
/// tensor.collapse_shape op that has non-intersecting reassociations.
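///
/// An illustrative sketch (values and shapes invented for exposition):
///
///   %0 = tensor.collapse_shape %src [[0, 1], [2]]
///       : tensor<2x3x20xf32> into tensor<6x20xf32>
///   %1 = tensor.expand_shape %0 [[0], [1, 2]] output_shape [6, 4, 5]
///       : tensor<6x20xf32> into tensor<6x4x5xf32>
///
/// is rewritten so that the expansion happens first:
///
///   %0 = tensor.expand_shape %src [[0], [1], [2, 3]]
///       output_shape [2, 3, 4, 5]
///       : tensor<2x3x20xf32> into tensor<2x3x4x5xf32>
///   %1 = tensor.collapse_shape %0 [[0, 1], [2], [3]]
///       : tensor<2x3x4x5xf32> into tensor<6x4x5xf32>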
struct BubbleUpExpandThroughParallelCollapse
    : public OpRewritePattern<tensor::ExpandShapeOp> {
  using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
                                PatternRewriter &rewriter) const override {
    auto collapseOp =
        expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
    if (!collapseOp)
      return failure();
    auto expandReInds = expandOp.getReassociationIndices();
    auto collapseReInds = collapseOp.getReassociationIndices();

    // The reshapes are "parallel" when, for each pair of reassociation
    // groups, at most one of the two reshapes has more than one index in the
    // group, i.e. no dimension is both collapsed and expanded.
    for (auto [expandReassociation, collapseReassociation] :
         llvm::zip_equal(expandReInds, collapseReInds)) {
      if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
        return failure();
    }

    // Compute new reassociation indices and expanded/collapsed shapes.
    SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
    Location loc = expandOp->getLoc();
    SmallVector<OpFoldResult> collapseSizes =
        tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
    SmallVector<OpFoldResult> expandSizes(getMixedValues(
        expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
    SmallVector<OpFoldResult> newExpandSizes;
    int64_t index = 0, expandIndex = 0, collapseIndex = 0;
    for (auto [idx, collapseReassociation] : llvm::enumerate(collapseReInds)) {
      if (collapseReassociation.size() != 1) {
        // This group is collapsed (and passed through the expand unchanged):
        // keep the collapse and give the new expand identity reassociations
        // over the original source dims.
        ReassociationIndices newCollapseReassociation;
        for (size_t i = 0; i < collapseReassociation.size(); ++i) {
          newCollapseReassociation.push_back(index);
          newExpandReInds.push_back({index++});
          newExpandSizes.push_back(collapseSizes[collapseIndex++]);
        }
        newCollapseReInds.push_back(newCollapseReassociation);
        expandIndex++;
        continue;
      }
      // This group is expanded (and passed through the collapse unchanged):
      // keep the expansion and give the new collapse identity reassociations
      // over the expanded dims.
      ReassociationIndices newExpandReassociation;
      auto expandReassociation = expandReInds[idx];
      for (size_t i = 0; i < expandReassociation.size(); ++i) {
        newExpandReassociation.push_back(index);
        newCollapseReInds.push_back({index++});
        newExpandSizes.push_back(expandSizes[expandIndex++]);
      }
      newExpandReInds.push_back(newExpandReassociation);
      collapseIndex++;
    }

    // Swap the reshape order: expand first, then collapse.
    SmallVector<Value> dynamicSizes;
    SmallVector<int64_t> staticSizes;
    dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
    auto expandResultType = expandOp.getResultType().clone(staticSizes);
    auto newExpand = rewriter.create<tensor::ExpandShapeOp>(
        loc, expandResultType, collapseOp.getSrc(), newExpandReInds,
        newExpandSizes);
    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        expandOp, newExpand.getResult(), newCollapseReInds);
    return success();
  }
};

} // namespace

void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<FoldExpandOfRankReducingExtract, FoldUnPaddingCollapseIntoExtract,
           FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
           FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
           FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
           FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
          patterns.getContext());
}

void mlir::tensor::populateBubbleUpExpandShapePatterns(
    RewritePatternSet &patterns) {
  patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
}