//===- FoldTensorSubsetOps.cpp - Fold tensor subset ops -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Fold tensor subset ops into their producers and consumers.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
#include <type_traits>

namespace mlir {
namespace tensor {
#define GEN_PASS_DEF_FOLDTENSORSUBSETOPS
#include "mlir/Dialect/Tensor/Transforms/Passes.h.inc"
} // namespace tensor
} // namespace mlir

using namespace mlir;

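/// Return the tensor operand the folding patterns below look through: the
/// tensor a transfer_read reads from, or the slice an insert_slice inserts.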
static Value getTensorOperand(vector::TransferReadOp op) {
  return op.getSource();
}

static Value getTensorOperand(tensor::InsertSliceOp op) {
  return op.getSource();
}

//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//

namespace {
/// Fold a tensor.extract_slice into the vector.transfer_read that reads from
/// its result.
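///
/// A sketch of the rewrite, on illustrative IR (the SSA names, types, and
/// sizes here are hypothetical):
///
///   %s = tensor.extract_slice %t[%a, %b] [8, 16] [1, 1]
///       : tensor<?x?xf32> to tensor<8x16xf32>
///   %v = vector.transfer_read %s[%c0, %c0], %pad {in_bounds = [true, true]}
///       : tensor<8x16xf32>, vector<8x16xf32>
///
/// becomes a read of the original tensor at the resolved indices (in general
/// computed with affine.apply; here the zero indices fold away):
///
///   %v = vector.transfer_read %t[%a, %b], %pad {in_bounds = [true, true]}
///       : tensor<?x?xf32>, vector<8x16xf32>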
class TransferReadOfExtractSliceOpFolder final
    : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp readOp,
                            vector::MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;
};

/// Fold a tensor.insert_slice into the vector.transfer_write that produces
/// its source.
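///
/// A sketch of the rewrite, on illustrative IR (the SSA names, types, and
/// sizes here are hypothetical):
///
///   %w = vector.transfer_write %v, %init[%c0, %c0] {in_bounds = [true, true]}
///       : vector<8x16xf32>, tensor<8x16xf32>
///   %r = tensor.insert_slice %w into %d[%a, %b] [8, 16] [1, 1]
///       : tensor<8x16xf32> into tensor<?x?xf32>
///
/// becomes, provided the write covers the whole inserted slice, a write
/// directly into the destination:
///
///   %r = vector.transfer_write %v, %d[%a, %b] {in_bounds = [true, true]}
///       : vector<8x16xf32>, tensor<?x?xf32>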
class InsertSliceOfTransferWriteOpFolder final
    : public OpRewritePattern<tensor::InsertSliceOp> {
public:
  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
                                PatternRewriter &rewriter) const override;

private:
  static bool
  doesTransferWriteCoverInsertSlice(vector::TransferWriteOp writeOp);
};
} // namespace

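/// Common preconditions for folding a tensor.extract_slice/insert_slice with
/// a vector transfer op: the transfer must be unmasked and fully in-bounds,
/// and the slice op must use unit strides only.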
template <typename XferOp, typename ExtractOrInsertOp>
static LogicalResult preconditionsFoldExtractOrInsertWithTransferOp(
    RewriterBase &rewriter, XferOp xferOp,
    ExtractOrInsertOp extractOrInsertSliceOp) {
  if (xferOp.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "masked transfer");
  if (!extractOrInsertSliceOp.hasUnitStride()) {
    return rewriter.notifyMatchFailure(
        xferOp, "non-unit stride insert/extract; folding would require "
                "tracking strides and may need to introduce "
                "vector.insert_strided_slice/extract_strided_slice ops");
  }
  return success();
}

FailureOr<mlir::Value>
TransferReadOfExtractSliceOpFolder::matchAndRewriteMaskableOp(
    vector::TransferReadOp readOp, vector::MaskingOpInterface maskOp,
    PatternRewriter &rewriter) const {
  auto extractSliceOp =
      getTensorOperand(readOp).getDefiningOp<tensor::ExtractSliceOp>();
  if (!extractSliceOp)
    return rewriter.notifyMatchFailure(readOp, "not an extract_slice");

  LogicalResult preconditionResult =
      preconditionsFoldExtractOrInsertWithTransferOp(rewriter, readOp,
                                                     extractSliceOp);
  if (failed(preconditionResult))
    return rewriter.notifyMatchFailure(readOp, "failed preconditions");

  SmallVector<Value> indices(readOp.getIndices().begin(),
                             readOp.getIndices().end());
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, readOp.getLoc(), extractSliceOp.getMixedOffsets(),
      extractSliceOp.getMixedStrides(), extractSliceOp.getDroppedDims(),
      indices, sourceIndices);

  Operation *newOp = rewriter.create<vector::TransferReadOp>(
      readOp.getLoc(), readOp.getVectorType(), extractSliceOp.getSource(),
      sourceIndices,
      AffineMapAttr::get(expandDimsToRank(
          readOp.getPermutationMap(), extractSliceOp.getSourceType().getRank(),
          extractSliceOp.getDroppedDims())),
      readOp.getPadding(),
      /*mask=*/Value(), readOp.getInBoundsAttr());
  if (maskOp)
    newOp = mlir::vector::maskOperation(rewriter, newOp, maskOp.getMask());
  return newOp->getResults()[0];
}

LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
    tensor::InsertSliceOp insertSliceOp, PatternRewriter &rewriter) const {
  auto writeOp = getTensorOperand(insertSliceOp)
                     .getDefiningOp<vector::TransferWriteOp>();
  if (!writeOp)
    return rewriter.notifyMatchFailure(insertSliceOp, "not a transfer_write");

  LogicalResult preconditionResult =
      preconditionsFoldExtractOrInsertWithTransferOp(rewriter, writeOp,
                                                     insertSliceOp);
  if (failed(preconditionResult))
    return preconditionResult;

  if (!doesTransferWriteCoverInsertSlice(writeOp))
    return rewriter.notifyMatchFailure(
        insertSliceOp, "transfer_write does not cover insert_slice");

  SmallVector<Value> indices(writeOp.getIndices().begin(),
                             writeOp.getIndices().end());
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, writeOp.getLoc(), insertSliceOp.getMixedOffsets(),
      insertSliceOp.getMixedStrides(), insertSliceOp.getDroppedDims(), indices,
      sourceIndices);

  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
      insertSliceOp, writeOp.getValue(), insertSliceOp.getDest(), sourceIndices,
      AffineMapAttr::get(expandDimsToRank(writeOp.getPermutationMap(),
                                          insertSliceOp.getDestType().getRank(),
                                          insertSliceOp.getDroppedDims())),
      writeOp.getInBoundsAttr());

  return success();
}

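/// Return true if the transfer_write fully overwrites its destination: the
/// written vector's shape must match the static shape of the tensor being
/// written. Dynamic shapes are conservatively rejected for now.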
bool InsertSliceOfTransferWriteOpFolder::doesTransferWriteCoverInsertSlice(
    vector::TransferWriteOp writeOp) {
  if (writeOp.getShapedType().hasStaticShape())
    return llvm::equal(writeOp.getVectorType().getShape(),
                       writeOp.getShapedType().getShape());

  // TODO: Use ValueBoundsConstraintSet for dynamic shapes.

  return false;
}

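/// Fold a tensor.insert_slice (or tensor.parallel_insert_slice) whose source
/// is itself produced by a tensor.insert_slice, composing the two subset
/// descriptions into a single op. A sketch of the rewrite, on illustrative IR
/// (the SSA names, types, and sizes here are hypothetical):
///
///   %0 = tensor.insert_slice %src into %mid[0] [8] [1]
///       : tensor<8xf32> into tensor<8xf32>
///   %1 = tensor.insert_slice %0 into %dst[%b] [8] [1]
///       : tensor<8xf32> into tensor<?xf32>
///
/// folds, when the two ops insert matching sizes, to a single insert whose
/// offsets are the composition of the outer and inner offsets (here %b + 0):
///
///   %1 = tensor.insert_slice %src into %dst[%b] [8] [1]
///       : tensor<8xf32> into tensor<?xf32>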
template <typename OpTy>
struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    auto sourceInsertSliceOp =
        insertSliceOp.getSource()
            .template getDefiningOp<tensor::InsertSliceOp>();
    if (!sourceInsertSliceOp)
      return failure();

    // TODO: relax unit stride assumption where possible.
    if (!insertSliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(insertSliceOp,
                                         "requires unit strides");
    }
    if (!sourceInsertSliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(sourceInsertSliceOp,
                                         "requires unit strides");
    }

    int64_t srcDim = 0;
    llvm::SmallBitVector droppedDims = insertSliceOp.getDroppedDims();
    for (int64_t d = 0, e = insertSliceOp.getDestType().getRank(); d < e; ++d) {
      if (droppedDims[d])
        continue;
      if (insertSliceOp.getMixedSizes()[d] !=
          sourceInsertSliceOp.getMixedSizes()[srcDim++]) {
        return rewriter.notifyMatchFailure(
            sourceInsertSliceOp,
            "requires matching sizes to fold, otherwise a copy is needed");
      }
    }

    // Resolve sizes according to dropped dims.
    SmallVector<OpFoldResult> resolvedSizes;
    // Note: the "insertSlice" case is symmetrical to the extract/subview case:
    // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is
    // passed as the destination to the helper function.
    affine::resolveSizesIntoOpWithSizes(insertSliceOp.getMixedSizes(),
                                        sourceInsertSliceOp.getMixedSizes(),
                                        droppedDims, resolvedSizes);

    // If we are inside an InParallel region, temporarily set the insertion
    // point outside of it: only tensor.parallel_insert_slice ops are allowed
    // there.
    if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
      rewriter.setInsertionPoint(
          insertSliceOp->template getParentOfType<scf::InParallelOp>());
    }

    // Resolve offsets according to source offsets and strides.
    SmallVector<Value> resolvedOffsets;
    // Note: the "insertSlice" case is symmetrical to the extract/subview case:
    // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is
    // passed as the destination to the helper function.
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, insertSliceOp.getLoc(), insertSliceOp.getMixedOffsets(),
        insertSliceOp.getMixedStrides(), droppedDims,
        sourceInsertSliceOp.getMixedOffsets(), resolvedOffsets);

    // Reset the insertion point.
    rewriter.setInsertionPoint(insertSliceOp);
    // Replace original op.
    rewriter.replaceOpWithNewOp<OpTy>(
        insertSliceOp, sourceInsertSliceOp.getSource(), insertSliceOp.getDest(),
        getAsOpFoldResult(resolvedOffsets), resolvedSizes,
        insertSliceOp.getMixedStrides());

    return success();
  }
};

void tensor::populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns) {
  populateFoldTensorSubsetIntoVectorTransferPatterns(patterns);
  patterns.add<InsertSliceOfInsertSliceFolder<tensor::InsertSliceOp>,
               InsertSliceOfInsertSliceFolder<tensor::ParallelInsertSliceOp>>(
      patterns.getContext());
}

void tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(
    RewritePatternSet &patterns) {
  patterns.add<TransferReadOfExtractSliceOpFolder,
               InsertSliceOfTransferWriteOpFolder>(patterns.getContext());
}

//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

namespace {

struct FoldTensorSubsetOpsPass final
    : public tensor::impl::FoldTensorSubsetOpsBase<FoldTensorSubsetOpsPass> {
  void runOnOperation() override;
};

} // namespace

void FoldTensorSubsetOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  tensor::populateFoldTensorSubsetOpPatterns(patterns);
  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
}

std::unique_ptr<Pass> tensor::createFoldTensorSubsetOpsPass() {
  return std::make_unique<FoldTensorSubsetOpsPass>();
}