xref: /llvm-project/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp (revision 5d4603a02d0c3e0106b10d245322b1d2072c0c3d)
1 //===- MergeConsecutiveInsertExtractSlicePatterns.cpp ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Affine/IR/AffineOps.h"
10 #include "mlir/Dialect/Tensor/IR/Tensor.h"
11 #include "mlir/Dialect/Tensor/Transforms/TransformUtils.h"
12 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/OpDefinition.h"
15 #include "mlir/IR/PatternMatch.h"
16 
17 using namespace mlir;
18 using namespace mlir::tensor;
19 
20 /// Creates AffineExpr from `ofr`: if the OpFoldResult is a Value, creates a
21 /// AffineSymbolExpr and appends it to `symbols`; otherwise creates a
22 /// AffineConstantExpr.
23 static AffineExpr getAffineExpr(OpFoldResult ofr,
24                                 SmallVector<OpFoldResult> &symbols) {
25   if (auto attr = ofr.dyn_cast<Attribute>()) {
26     return getAffineConstantExpr(attr.cast<IntegerAttr>().getInt(),
27                                  attr.getContext());
28   }
29   Value v = ofr.get<Value>();
30   AffineExpr expr = getAffineSymbolExpr(symbols.size(), v.getContext());
31   symbols.push_back(v);
32   return expr;
33 }
34 
35 /// Builds the AffineExpr incrementally for arithmetic operations.
36 static AffineExpr add(AffineExpr expr, OpFoldResult ofr,
37                       SmallVector<OpFoldResult> &symbols) {
38   return expr + getAffineExpr(ofr, symbols);
39 }
40 static AffineExpr mul(OpFoldResult lhs, OpFoldResult rhs,
41                       SmallVector<OpFoldResult> &symbols) {
42   return getAffineExpr(lhs, symbols) * getAffineExpr(rhs, symbols);
43 }
44 
45 /// Converts an AffineExpr to OpFoldResult by generating an `affine.apply`
46 /// op and fold it.
47 static OpFoldResult getOpFoldResult(OpBuilder &builder, Location loc,
48                                     AffineExpr expr,
49                                     SmallVector<OpFoldResult> &symbols) {
50   AffineMap m = AffineMap::get(0, symbols.size(), expr);
51   return makeComposedFoldedAffineApply(builder, loc, m, symbols);
52 }
53 
54 LogicalResult tensor::mergeOffsetsSizesAndStrides(
55     OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> producerOffsets,
56     ArrayRef<OpFoldResult> producerSizes,
57     ArrayRef<OpFoldResult> producerStrides,
58     const llvm::SmallBitVector &droppedProducerDims,
59     ArrayRef<OpFoldResult> consumerOffsets,
60     ArrayRef<OpFoldResult> consumerSizes,
61     ArrayRef<OpFoldResult> consumerStrides,
62     SmallVector<OpFoldResult> &combinedOffsets,
63     SmallVector<OpFoldResult> &combinedSizes,
64     SmallVector<OpFoldResult> &combinedStrides) {
65   combinedOffsets.resize(producerOffsets.size());
66   combinedSizes.resize(producerOffsets.size());
67   combinedStrides.resize(producerOffsets.size());
68   unsigned consumerPos = 0;
69   for (auto i : llvm::seq<unsigned>(0, producerOffsets.size())) {
70     if (droppedProducerDims.test(i)) {
71       // For dropped dims, get the values from the producer.
72       combinedOffsets[i] = producerOffsets[i];
73       combinedSizes[i] = producerSizes[i];
74       combinedStrides[i] = producerStrides[i];
75       continue;
76     }
77     SmallVector<OpFoldResult> offsetSymbols, strideSymbols;
78     // The combined offset is computed as
79     //    producer_offset + consumer_offset * producer_strides.
80     combinedOffsets[i] =
81         getOpFoldResult(builder, loc,
82                         add(mul(consumerOffsets[consumerPos],
83                                 producerStrides[i], offsetSymbols),
84                             producerOffsets[i], offsetSymbols),
85                         offsetSymbols);
86     combinedSizes[i] = consumerSizes[consumerPos];
87     // The combined stride is computed as
88     //    consumer_stride * producer_stride.
89     combinedStrides[i] = getOpFoldResult(
90         builder, loc,
91         mul(consumerStrides[consumerPos], producerStrides[i], strideSymbols),
92         strideSymbols);
93     consumerPos++;
94   }
95   return success();
96 }
97 
98 LogicalResult tensor::mergeOffsetsSizesAndStrides(
99     OpBuilder &builder, Location loc, OffsetSizeAndStrideOpInterface producer,
100     OffsetSizeAndStrideOpInterface consumer,
101     const llvm::SmallBitVector &droppedProducerDims,
102     SmallVector<OpFoldResult> &combinedOffsets,
103     SmallVector<OpFoldResult> &combinedSizes,
104     SmallVector<OpFoldResult> &combinedStrides) {
105   SmallVector<OpFoldResult> consumerOffsets = consumer.getMixedOffsets();
106   SmallVector<OpFoldResult> consumerSizes = consumer.getMixedSizes();
107   SmallVector<OpFoldResult> consumerStrides = consumer.getMixedStrides();
108   SmallVector<OpFoldResult> producerOffsets = producer.getMixedOffsets();
109   SmallVector<OpFoldResult> producerSizes = producer.getMixedSizes();
110   SmallVector<OpFoldResult> producerStrides = producer.getMixedStrides();
111   return tensor::mergeOffsetsSizesAndStrides(
112       builder, loc, producerOffsets, producerSizes, producerStrides,
113       droppedProducerDims, consumerOffsets, consumerSizes, consumerStrides,
114       combinedOffsets, combinedSizes, combinedStrides);
115 }
116 
117 namespace {
118 /// Merges consecutive tensor.extract_slice ops into one.
119 struct MergeConsecutiveExtractSlice : public OpRewritePattern<ExtractSliceOp> {
120   using OpRewritePattern::OpRewritePattern;
121 
122   LogicalResult matchAndRewrite(ExtractSliceOp nextOp,
123                                 PatternRewriter &rewriter) const override {
124     auto prevOp = nextOp.getSource().getDefiningOp<ExtractSliceOp>();
125     if (!prevOp)
126       return failure();
127 
128     SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
129     if (failed(mergeOffsetsSizesAndStrides(rewriter, nextOp.getLoc(), prevOp,
130                                            nextOp, prevOp.getDroppedDims(),
131                                            newOffsets, newSizes, newStrides)))
132       return failure();
133 
134     rewriter.replaceOpWithNewOp<ExtractSliceOp>(nextOp, nextOp.getType(),
135                                                 prevOp.getSource(), newOffsets,
136                                                 newSizes, newStrides);
137     return success();
138   }
139 };
140 
141 /// Merges consecutive tensor.insert_slice ops into one.
142 struct MergeConsecutiveInsertSlice : public OpRewritePattern<InsertSliceOp> {
143   using OpRewritePattern::OpRewritePattern;
144 
145   LogicalResult matchAndRewrite(InsertSliceOp nextOp,
146                                 PatternRewriter &rewriter) const override {
147     auto prevOp = nextOp.getSource().getDefiningOp<InsertSliceOp>();
148     if (!prevOp)
149       return failure();
150 
151     if (!prevOp.hasUnitStride() || !nextOp.hasUnitStride())
152       return failure();
153 
154     // The first insert_slice op should be rank reducing to make sure we cover
155     // the full source tensor to be inserted in the second insert_slice op.
156     SliceVerificationResult result =
157         isRankReducedType(prevOp.getDestType(), prevOp.getSourceType());
158     if (result != SliceVerificationResult::Success)
159       return failure();
160 
161     // Dynamic dimensions can pass rank reducing check in the above, e.g,
162     // inserting <?xf32> into <1x?x1xf32>. For such cases we cannot be certain
163     // the dynamic size covers the full tensor.
164     if (!prevOp.getSourceType().hasStaticShape() ||
165         !prevOp.getDestType().hasStaticShape())
166       return failure();
167 
168     rewriter.replaceOpWithNewOp<InsertSliceOp>(
169         nextOp, prevOp.getSource(), nextOp.getDest(), nextOp.getMixedOffsets(),
170         nextOp.getMixedSizes(), nextOp.getMixedStrides());
171     return success();
172   }
173 };
174 } // namespace
175 
176 void mlir::tensor::populateMergeConsecutiveInsertExtractSlicePatterns(
177     RewritePatternSet &patterns) {
178   patterns.add<MergeConsecutiveExtractSlice, MergeConsecutiveInsertSlice>(
179       patterns.getContext());
180 }
181