xref: /llvm-project/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp (revision bd81524e7f52c38ad0bc689934343a476e545265)
1 //===- MergeConsecutiveInsertExtractSlicePatterns.cpp ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Affine/IR/AffineOps.h"
10 #include "mlir/Dialect/Tensor/IR/Tensor.h"
11 #include "mlir/Dialect/Tensor/Transforms/TransformUtils.h"
12 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/OpDefinition.h"
15 #include "mlir/IR/PatternMatch.h"
16 
17 using namespace mlir;
18 using namespace mlir::tensor;
19 
20 LogicalResult tensor::mergeOffsetsSizesAndStrides(
21     OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> producerOffsets,
22     ArrayRef<OpFoldResult> producerSizes,
23     ArrayRef<OpFoldResult> producerStrides,
24     const llvm::SmallBitVector &droppedProducerDims,
25     ArrayRef<OpFoldResult> consumerOffsets,
26     ArrayRef<OpFoldResult> consumerSizes,
27     ArrayRef<OpFoldResult> consumerStrides,
28     SmallVector<OpFoldResult> &combinedOffsets,
29     SmallVector<OpFoldResult> &combinedSizes,
30     SmallVector<OpFoldResult> &combinedStrides) {
31   combinedOffsets.resize(producerOffsets.size());
32   combinedSizes.resize(producerOffsets.size());
33   combinedStrides.resize(producerOffsets.size());
34 
35   AffineExpr s0, s1, s2;
36   bindSymbols(builder.getContext(), s0, s1, s2);
37 
38   unsigned consumerPos = 0;
39   for (auto i : llvm::seq<unsigned>(0, producerOffsets.size())) {
40     if (droppedProducerDims.test(i)) {
41       // For dropped dims, get the values from the producer.
42       combinedOffsets[i] = producerOffsets[i];
43       combinedSizes[i] = producerSizes[i];
44       combinedStrides[i] = producerStrides[i];
45       continue;
46     }
47     SmallVector<OpFoldResult> offsetSymbols, strideSymbols;
48     // The combined offset is computed as
49     //    producer_offset + consumer_offset * producer_strides.
50     combinedOffsets[i] = makeComposedFoldedAffineApply(
51         builder, loc, s0 * s1 + s2,
52         {consumerOffsets[consumerPos], producerStrides[i], producerOffsets[i]});
53     combinedSizes[i] = consumerSizes[consumerPos];
54     // The combined stride is computed as
55     //    consumer_stride * producer_stride.
56     combinedStrides[i] = makeComposedFoldedAffineApply(
57         builder, loc, s0 * s1,
58         {consumerStrides[consumerPos], producerStrides[i]});
59 
60     consumerPos++;
61   }
62   return success();
63 }
64 
65 LogicalResult tensor::mergeOffsetsSizesAndStrides(
66     OpBuilder &builder, Location loc, OffsetSizeAndStrideOpInterface producer,
67     OffsetSizeAndStrideOpInterface consumer,
68     const llvm::SmallBitVector &droppedProducerDims,
69     SmallVector<OpFoldResult> &combinedOffsets,
70     SmallVector<OpFoldResult> &combinedSizes,
71     SmallVector<OpFoldResult> &combinedStrides) {
72   SmallVector<OpFoldResult> consumerOffsets = consumer.getMixedOffsets();
73   SmallVector<OpFoldResult> consumerSizes = consumer.getMixedSizes();
74   SmallVector<OpFoldResult> consumerStrides = consumer.getMixedStrides();
75   SmallVector<OpFoldResult> producerOffsets = producer.getMixedOffsets();
76   SmallVector<OpFoldResult> producerSizes = producer.getMixedSizes();
77   SmallVector<OpFoldResult> producerStrides = producer.getMixedStrides();
78   return tensor::mergeOffsetsSizesAndStrides(
79       builder, loc, producerOffsets, producerSizes, producerStrides,
80       droppedProducerDims, consumerOffsets, consumerSizes, consumerStrides,
81       combinedOffsets, combinedSizes, combinedStrides);
82 }
83 
84 namespace {
85 /// Merges consecutive tensor.extract_slice ops into one.
86 struct MergeConsecutiveExtractSlice : public OpRewritePattern<ExtractSliceOp> {
87   using OpRewritePattern::OpRewritePattern;
88 
89   LogicalResult matchAndRewrite(ExtractSliceOp nextOp,
90                                 PatternRewriter &rewriter) const override {
91     auto prevOp = nextOp.getSource().getDefiningOp<ExtractSliceOp>();
92     if (!prevOp)
93       return failure();
94 
95     SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
96     if (failed(mergeOffsetsSizesAndStrides(rewriter, nextOp.getLoc(), prevOp,
97                                            nextOp, prevOp.getDroppedDims(),
98                                            newOffsets, newSizes, newStrides)))
99       return failure();
100 
101     rewriter.replaceOpWithNewOp<ExtractSliceOp>(nextOp, nextOp.getType(),
102                                                 prevOp.getSource(), newOffsets,
103                                                 newSizes, newStrides);
104     return success();
105   }
106 };
107 
108 /// Merges consecutive tensor.insert_slice ops into one.
109 struct MergeConsecutiveInsertSlice : public OpRewritePattern<InsertSliceOp> {
110   using OpRewritePattern::OpRewritePattern;
111 
112   LogicalResult matchAndRewrite(InsertSliceOp nextOp,
113                                 PatternRewriter &rewriter) const override {
114     auto prevOp = nextOp.getSource().getDefiningOp<InsertSliceOp>();
115     if (!prevOp)
116       return failure();
117 
118     if (!prevOp.hasUnitStride() || !nextOp.hasUnitStride())
119       return failure();
120 
121     // The first insert_slice op should be rank reducing to make sure we cover
122     // the full source tensor to be inserted in the second insert_slice op.
123     SliceVerificationResult result =
124         isRankReducedType(prevOp.getDestType(), prevOp.getSourceType());
125     if (result != SliceVerificationResult::Success)
126       return failure();
127 
128     // Dynamic dimensions can pass rank reducing check in the above, e.g,
129     // inserting <?xf32> into <1x?x1xf32>. For such cases we cannot be certain
130     // the dynamic size covers the full tensor.
131     if (!prevOp.getSourceType().hasStaticShape() ||
132         !prevOp.getDestType().hasStaticShape())
133       return failure();
134 
135     rewriter.replaceOpWithNewOp<InsertSliceOp>(
136         nextOp, prevOp.getSource(), nextOp.getDest(), nextOp.getMixedOffsets(),
137         nextOp.getMixedSizes(), nextOp.getMixedStrides());
138     return success();
139   }
140 };
141 } // namespace
142 
143 void mlir::tensor::populateMergeConsecutiveInsertExtractSlicePatterns(
144     RewritePatternSet &patterns) {
145   patterns.add<MergeConsecutiveExtractSlice, MergeConsecutiveInsertSlice>(
146       patterns.getContext());
147 }
148