xref: /llvm-project/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp (revision bb4c53b7ba113b274ad0fd8d881313509947c896)
1 //===- MergeConsecutiveInsertExtractSlicePatterns.cpp ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Affine/IR/AffineOps.h"
10 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
11 #include "mlir/Dialect/Tensor/IR/Tensor.h"
12 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/OpDefinition.h"
15 #include "mlir/IR/PatternMatch.h"
16 
17 using namespace mlir;
18 using namespace mlir::tensor;
19 
20 /// Adds each corresponding pair of offsets in `offsets1` and `offsets2` and
21 /// returns the results.
22 static SmallVector<OpFoldResult> mergeOffsets(Location loc,
23                                               ArrayRef<OpFoldResult> offsets1,
24                                               ArrayRef<OpFoldResult> offsets2,
25                                               OpBuilder &builder) {
26   SmallVector<OpFoldResult> foldedOffsets;
27   assert(offsets1.size() == offsets2.size());
28   foldedOffsets.reserve(offsets1.size());
29 
30   AffineExpr dim1, dim2;
31   bindDims(builder.getContext(), dim1, dim2);
32 
33   for (const auto &pair : llvm::zip(offsets1, offsets2)) {
34     auto offset0 =
35         getValueOrCreateConstantIndexOp(builder, loc, std::get<0>(pair));
36     auto offset1 =
37         getValueOrCreateConstantIndexOp(builder, loc, std::get<1>(pair));
38     auto foldedOffset =
39         makeComposedAffineApply(builder, loc, dim1 + dim2, {offset0, offset1});
40     foldedOffsets.push_back(foldedOffset.getResult());
41   }
42   return foldedOffsets;
43 }
44 
45 namespace {
46 /// Merges consecutive tensor.extract_slice ops into one.
47 struct MergeConsecutiveExtractSlice : public OpRewritePattern<ExtractSliceOp> {
48   using OpRewritePattern::OpRewritePattern;
49 
50   LogicalResult matchAndRewrite(ExtractSliceOp nextOp,
51                                 PatternRewriter &rewriter) const override {
52     auto prevOp = nextOp.getSource().getDefiningOp<ExtractSliceOp>();
53     if (!prevOp)
54       return failure();
55 
56     if (!prevOp.hasUnitStride() || !nextOp.hasUnitStride())
57       return failure();
58 
59     auto prevResultType = prevOp.getType().cast<ShapedType>();
60     if (prevOp.getSourceType().getRank() != prevResultType.getRank())
61       return rewriter.notifyMatchFailure(
62           prevOp, "rank-reducing producder case unimplemented");
63 
64     Location loc = nextOp.getLoc();
65 
66     SmallVector<OpFoldResult> prevOffsets = prevOp.getMixedOffsets();
67     SmallVector<OpFoldResult> nextOffsets = nextOp.getMixedOffsets();
68     SmallVector<OpFoldResult> foldedOffsets =
69         mergeOffsets(loc, prevOffsets, nextOffsets, rewriter);
70 
71     rewriter.replaceOpWithNewOp<ExtractSliceOp>(
72         nextOp, nextOp.getType(), prevOp.getSource(), foldedOffsets,
73         nextOp.getMixedSizes(), nextOp.getMixedStrides());
74     return success();
75   }
76 };
77 
78 /// Merges consecutive tensor.insert_slice ops into one.
79 struct MergeConsecutiveInsertSlice : public OpRewritePattern<InsertSliceOp> {
80   using OpRewritePattern::OpRewritePattern;
81 
82   LogicalResult matchAndRewrite(InsertSliceOp nextOp,
83                                 PatternRewriter &rewriter) const override {
84     auto prevOp = nextOp.getSource().getDefiningOp<InsertSliceOp>();
85     if (!prevOp)
86       return failure();
87 
88     if (!prevOp.hasUnitStride() || !nextOp.hasUnitStride())
89       return failure();
90 
91     // The first insert_slice op should be rank reducing to make sure we cover
92     // the full source tensor to be inserted in the second insert_slice op.
93     SliceVerificationResult result =
94         isRankReducedType(prevOp.getDestType(), prevOp.getSourceType());
95     if (result != SliceVerificationResult::Success)
96       return failure();
97 
98     // Dynamic dimensions can pass rank reducing check in the above, e.g,
99     // inserting <?xf32> into <1x?x1xf32>. For such cases we cannot be certain
100     // the dynamic size covers the full tensor.
101     if (!prevOp.getSourceType().hasStaticShape() ||
102         !prevOp.getDestType().hasStaticShape())
103       return failure();
104 
105     rewriter.replaceOpWithNewOp<InsertSliceOp>(
106         nextOp, prevOp.getSource(), nextOp.getDest(), nextOp.getMixedOffsets(),
107         nextOp.getMixedSizes(), nextOp.getMixedStrides());
108     return success();
109   }
110 };
111 } // namespace
112 
113 void mlir::tensor::populateMergeConsecutiveInsertExtractSlicePatterns(
114     RewritePatternSet &patterns) {
115   patterns.add<MergeConsecutiveExtractSlice, MergeConsecutiveInsertSlice>(
116       patterns.getContext());
117 }
118