xref: /llvm-project/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp (revision 4c48f016effde67d500fc95290096aec9f3bdb70)
1 //===- FoldTensorSubsetOps.cpp - Fold tensor subset ops -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Fold tensor subset ops with producer / consumers.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
15 #include "mlir/Dialect/SCF/IR/SCF.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/Dialect/Tensor/Transforms/Passes.h"
18 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
19 #include "mlir/Dialect/Utils/IndexingUtils.h"
20 #include "mlir/Dialect/Vector/IR/VectorOps.h"
21 #include "mlir/IR/AffineMap.h"
22 #include "mlir/IR/BuiltinAttributes.h"
23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include <type_traits>
26 
27 namespace mlir {
28 namespace tensor {
29 #define GEN_PASS_DEF_FOLDTENSORSUBSETOPS
30 #include "mlir/Dialect/Tensor/Transforms/Passes.h.inc"
31 } // namespace tensor
32 } // namespace mlir
33 
34 using namespace mlir;
35 
/// Returns the tensor that `op` reads from (the transfer_read source operand).
static Value getTensorOperand(vector::TransferReadOp op) {
  return op.getSource();
}
39 
/// Returns the tensor that `op` inserts into its destination (the
/// insert_slice source operand).
static Value getTensorOperand(tensor::InsertSliceOp op) {
  return op.getSource();
}
43 
44 //===----------------------------------------------------------------------===//
45 // Patterns
46 //===----------------------------------------------------------------------===//
47 
namespace {
/// Merge extract_slice operation with load/transferRead operation.
/// Rewrites a vector.transfer_read of a tensor.extract_slice result into a
/// transfer_read of the un-sliced source tensor with rebased indices.
class TransferReadOfExtractSliceOpFolder final
    : public OpRewritePattern<vector::TransferReadOp> {
public:
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override;
};

/// Merge insert_slice operation with store/transferWriteOp operation.
/// Rewrites a tensor.insert_slice whose source is a vector.transfer_write
/// into a transfer_write directly on the insert_slice destination.
class InsertSliceOfTransferWriteOpFolder final
    : public OpRewritePattern<tensor::InsertSliceOp> {
public:
  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
                                PatternRewriter &rewriter) const override;
};
} // namespace
69 
/// Common preconditions for folding an extract_slice/insert_slice op into a
/// vector transfer op. The transfer must have no out-of-bounds dims and no
/// mask (neither is propagated through the fold), and the slice op must have
/// unit strides. Emits a match-failure diagnostic and returns failure()
/// otherwise.
template <typename XferOp, typename ExtractOrInsertOp>
static LogicalResult preconditionsFoldExtractOrInsertWithTransferOp(
    RewriterBase &rewriter, XferOp xferOp,
    ExtractOrInsertOp extractOrInsertSliceOp) {
  // Rebasing indices onto the un-sliced tensor would invalidate the
  // out-of-bounds semantics of the original transfer.
  if (xferOp.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
  // The mask is defined w.r.t. the sliced tensor and is not rewritten.
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "masked transfer");
  if (!extractOrInsertSliceOp.hasUnitStride()) {
    return rewriter.notifyMatchFailure(
        xferOp, "non-1 stride insert/extract, requires keeping track of "
                "strides, this may result in needing to insert "
                "vector.insert_strided_slice/extract_strided_slice ops");
  }
  return success();
}
86 
87 LogicalResult TransferReadOfExtractSliceOpFolder::matchAndRewrite(
88     vector::TransferReadOp readOp, PatternRewriter &rewriter) const {
89   auto extractSliceOp =
90       getTensorOperand(readOp).getDefiningOp<tensor::ExtractSliceOp>();
91   if (!extractSliceOp)
92     return rewriter.notifyMatchFailure(readOp, "not an extract_slice");
93 
94   LogicalResult preconditionResult =
95       preconditionsFoldExtractOrInsertWithTransferOp(rewriter, readOp,
96                                                      extractSliceOp);
97   if (failed(preconditionResult))
98     return preconditionResult;
99 
100   SmallVector<Value> indices(readOp.getIndices().begin(),
101                              readOp.getIndices().end());
102   SmallVector<Value> sourceIndices;
103   affine::resolveIndicesIntoOpWithOffsetsAndStrides(
104       rewriter, readOp.getLoc(), extractSliceOp.getMixedOffsets(),
105       extractSliceOp.getMixedStrides(), extractSliceOp.getDroppedDims(),
106       indices, sourceIndices);
107 
108   rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
109       readOp, readOp.getVectorType(), extractSliceOp.getSource(), sourceIndices,
110       AffineMapAttr::get(expandDimsToRank(
111           readOp.getPermutationMap(), extractSliceOp.getSourceType().getRank(),
112           extractSliceOp.getDroppedDims())),
113       readOp.getPadding(),
114       /*mask=*/Value(), readOp.getInBoundsAttr());
115 
116   return success();
117 }
118 
119 LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
120     tensor::InsertSliceOp insertSliceOp, PatternRewriter &rewriter) const {
121   auto writeOp = getTensorOperand(insertSliceOp)
122                      .template getDefiningOp<vector::TransferWriteOp>();
123   if (!writeOp)
124     return rewriter.notifyMatchFailure(insertSliceOp, "not a transfer_write");
125 
126   LogicalResult preconditionResult =
127       preconditionsFoldExtractOrInsertWithTransferOp(rewriter, writeOp,
128                                                      insertSliceOp);
129   if (failed(preconditionResult))
130     return preconditionResult;
131 
132   SmallVector<Value> indices(writeOp.getIndices().begin(),
133                              writeOp.getIndices().end());
134   SmallVector<Value> sourceIndices;
135   affine::resolveIndicesIntoOpWithOffsetsAndStrides(
136       rewriter, writeOp.getLoc(), insertSliceOp.getMixedOffsets(),
137       insertSliceOp.getMixedStrides(), insertSliceOp.getDroppedDims(), indices,
138       sourceIndices);
139 
140   rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
141       insertSliceOp, writeOp.getValue(), insertSliceOp.getDest(), sourceIndices,
142       AffineMapAttr::get(expandDimsToRank(writeOp.getPermutationMap(),
143                                           insertSliceOp.getDestType().getRank(),
144                                           insertSliceOp.getDroppedDims())),
145       writeOp.getInBoundsAttr());
146 
147   return success();
148 }
149 
/// Folds a tensor.insert_slice (or tensor.parallel_insert_slice) whose source
/// is itself produced by a tensor.insert_slice, composing the two subset
/// insertions into a single insert on the outer destination.
template <typename OpTy>
struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    auto sourceInsertSliceOp =
        insertSliceOp.getSource()
            .template getDefiningOp<tensor::InsertSliceOp>();
    if (!sourceInsertSliceOp)
      return failure();

    // TODO: relax unit stride assumption where possible.
    if (!insertSliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(insertSliceOp,
                                         "requires unit strides");
    }
    if (!sourceInsertSliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(sourceInsertSliceOp,
                                         "requires unit strides");
    }

    // The fold is only valid when the inner insert fills the whole of the
    // outer source: require the outer (non-dropped) sizes to match the inner
    // sizes pairwise. `srcDim` walks the rank-reduced source dims in step
    // with the non-dropped destination dims.
    int64_t srcDim = 0;
    llvm::SmallBitVector droppedDims = insertSliceOp.getDroppedDims();
    for (int64_t d = 0, e = insertSliceOp.getDestType().getRank(); d < e; ++d) {
      if (droppedDims[d])
        continue;
      if (insertSliceOp.getMixedSizes()[d] !=
          sourceInsertSliceOp.getMixedSizes()[srcDim++]) {
        return rewriter.notifyMatchFailure(
            sourceInsertSliceOp,
            "requires matching sizes to fold, otherwise a copy is needed");
      }
    }

    // Resolve sizes according to dropped dims.
    SmallVector<OpFoldResult> resolvedSizes;
    // Note: the "insertSlice" case is symmetrical to the extract/subview case:
    // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is
    // passed as the destination to the helper function.
    affine::resolveSizesIntoOpWithSizes(insertSliceOp.getMixedSizes(),
                                        sourceInsertSliceOp.getMixedSizes(),
                                        droppedDims, resolvedSizes);

    // If we are inside an InParallel region, temporarily set the insertion
    // point outside: only tensor.parallel_insert_slice ops are allowed in
    // there.
    if (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
      rewriter.setInsertionPoint(
          insertSliceOp->template getParentOfType<scf::InParallelOp>());
    }

    // Resolve offsets according to source offsets and strides.
    SmallVector<Value> resolvedOffsets;
    // Note: the "insertSlice" case is symmetrical to the extract/subview case:
    // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is
    // passed as the destination to the helper function.
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, insertSliceOp.getLoc(), insertSliceOp.getMixedOffsets(),
        insertSliceOp.getMixedStrides(), droppedDims,
        sourceInsertSliceOp.getMixedOffsets(), resolvedOffsets);

    // Reset the insertion point.
    rewriter.setInsertionPoint(insertSliceOp);
    // Replace original op: insert the inner source directly into the outer
    // destination with the composed offsets/sizes and the outer strides.
    rewriter.replaceOpWithNewOp<OpTy>(
        insertSliceOp, sourceInsertSliceOp.getSource(), insertSliceOp.getDest(),
        getAsOpFoldResult(resolvedOffsets), resolvedSizes,
        insertSliceOp.getMixedStrides());

    return success();
  }
};
223 
224 void tensor::populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns) {
225   patterns.add<TransferReadOfExtractSliceOpFolder,
226                InsertSliceOfTransferWriteOpFolder,
227                InsertSliceOfInsertSliceFolder<tensor::InsertSliceOp>,
228                InsertSliceOfInsertSliceFolder<tensor::ParallelInsertSliceOp>>(
229       patterns.getContext());
230 }
231 //===----------------------------------------------------------------------===//
232 // Pass registration
233 //===----------------------------------------------------------------------===//
234 
namespace {

/// Pass that greedily applies the tensor subset folding patterns registered
/// by populateFoldTensorSubsetOpPatterns.
struct FoldTensorSubsetOpsPass final
    : public tensor::impl::FoldTensorSubsetOpsBase<FoldTensorSubsetOpsPass> {
  void runOnOperation() override;
};

} // namespace
243 
244 void FoldTensorSubsetOpsPass::runOnOperation() {
245   RewritePatternSet patterns(&getContext());
246   tensor::populateFoldTensorSubsetOpPatterns(patterns);
247   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
248 }
249 
/// Public factory for the fold-tensor-subset-ops pass (declared in Passes.h).
std::unique_ptr<Pass> tensor::createFoldTensorSubsetOpsPass() {
  return std::make_unique<FoldTensorSubsetOpsPass>();
}
253