xref: /llvm-project/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp (revision c077a4f305aa7faf92a1438b239078c1da1563a9)
1 //===- EmptyOpPatterns.cpp - Patterns related to tensor.empty folding ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 #include "mlir/Dialect/Tensor/IR/Tensor.h"
10 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
11 #include "mlir/IR/PatternMatch.h"
12 #include "llvm/Support/Debug.h"
13 
14 using namespace mlir;
15 using namespace mlir::tensor;
16 
17 namespace {
18 
19 template <typename ReshapeOp>
20 struct FoldEmptyTensorWithReshapeOp : public OpRewritePattern<ReshapeOp> {
21   FoldEmptyTensorWithReshapeOp(MLIRContext *ctx, PatternBenefit benefit = 1,
22                                bool foldSingleUseOnly = false)
23       : OpRewritePattern<ReshapeOp>(ctx, benefit),
24         foldSingleUseOnly(foldSingleUseOnly) {}
25 
26   LogicalResult matchAndRewrite(ReshapeOp reshapeOp,
27                                 PatternRewriter &rewriter) const override {
28     // Check for tensor.empty source.
29     auto emptyOp = reshapeOp.getSrc().template getDefiningOp<EmptyOp>();
30     if (!emptyOp)
31       return failure();
32 
33     // Check for single use.
34     if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses()))
35       return failure();
36 
37     // Reify result shape.
38     Location loc = reshapeOp.getLoc();
39     ReifiedRankedShapedTypeDims resultShapes;
40     if (failed(reifyResultShapes(rewriter, reshapeOp, resultShapes)) ||
41         !llvm::hasSingleElement(resultShapes))
42       return failure();
43 
44     // Create new tensor.empty op.
45     // TODO: Do not drop tensor type encoding.
46     Value emptyTensor = rewriter.create<EmptyOp>(
47         loc, resultShapes[0], reshapeOp.getResultType().getElementType());
48     if (emptyTensor.getType() != reshapeOp.getResultType()) {
49       rewriter.replaceOpWithNewOp<tensor::CastOp>(
50           reshapeOp, reshapeOp.getResultType(), emptyTensor);
51     } else {
52       rewriter.replaceOp(reshapeOp, emptyTensor);
53     }
54     return success();
55   }
56 
57 private:
58   bool foldSingleUseOnly = false;
59 };
60 
61 /// tensor.empty does not define any tensor contents, so a slice of a
62 /// tensor.empty can be folded to a smaller tensor.empty.
63 struct FoldEmptyTensorWithExtractSliceOp
64     : public OpRewritePattern<ExtractSliceOp> {
65   FoldEmptyTensorWithExtractSliceOp(MLIRContext *ctx,
66                                     PatternBenefit benefit = 1,
67                                     bool foldSingleUseOnly = false)
68       : OpRewritePattern<ExtractSliceOp>(ctx, benefit),
69         foldSingleUseOnly(foldSingleUseOnly) {}
70 
71   LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
72                                 PatternRewriter &rewriter) const override {
73     // Check for tensor.empty source.
74     auto emptyOp = sliceOp.getSource().template getDefiningOp<EmptyOp>();
75     if (!emptyOp)
76       return failure();
77 
78     // Check for single use.
79     if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses()))
80       return failure();
81 
82     // Create new tensor.empty op. tensor.extract_slice may be rank-reducing;
83     // its dynamic sizes must be preserved as well as its result type.
84     auto tensorType = RankedTensorType::get(sliceOp.getType().getShape(),
85                                             sliceOp.getType().getElementType(),
86                                             sliceOp.getType().getEncoding());
87     rewriter.replaceOpWithNewOp<EmptyOp>(sliceOp, tensorType,
88                                          sliceOp.getSizes());
89     return success();
90   }
91 
92 private:
93   bool foldSingleUseOnly = false;
94 };
95 
96 /// tensor.empty does not define any tensor contents, so an unpadded pack
97 /// can be folded away.
98 struct FoldEmptyTensorWithPackOp : public OpRewritePattern<PackOp> {
99   using OpRewritePattern<PackOp>::OpRewritePattern;
100 
101   LogicalResult matchAndRewrite(PackOp packOp,
102                                 PatternRewriter &rewriter) const override {
103     // Check for tensor.empty source.
104     auto emptyOp = packOp.getSource().getDefiningOp<EmptyOp>();
105     if (!emptyOp)
106       return failure();
107 
108     // Check for padding.
109     // Packing with padding cannot be simply removed.
110     if (packOp.getPaddingValue())
111       return rewriter.notifyMatchFailure(packOp, "expects no padding value");
112 
113     // Replace the pack directly with its destination.
114     rewriter.replaceOp(packOp, packOp.getDest());
115 
116     return success();
117   }
118 };
119 
120 /// tensor.empty does not define any tensor contents, so an unpack
121 /// can be folded away.
122 struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> {
123   using OpRewritePattern<UnPackOp>::OpRewritePattern;
124 
125   LogicalResult matchAndRewrite(UnPackOp unPackOp,
126                                 PatternRewriter &rewriter) const override {
127     // Check for tensor.empty source.
128     auto emptyOp = unPackOp.getSource().getDefiningOp<EmptyOp>();
129     if (!emptyOp)
130       return failure();
131 
132     // Replace the unpack directly with its destination.
133     rewriter.replaceOp(unPackOp, unPackOp.getDest());
134 
135     return success();
136   }
137 };
138 
139 // Fold concat operation where all the operands are empty.
140 struct FoldConcatsOfEmpty : public OpRewritePattern<ConcatOp> {
141   using OpRewritePattern<ConcatOp>::OpRewritePattern;
142 
143   LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
144                                 PatternRewriter &rewriter) const override {
145     auto concatOperands = concatOp.getInputs();
146     if (concatOperands.empty()) {
147       return failure();
148     }
149     auto firstEmptyOp = concatOperands.front().getDefiningOp<tensor::EmptyOp>();
150     if (!firstEmptyOp) {
151       return failure();
152     }
153     auto isDefinedByEmptyOp = [](Value v) -> bool {
154       return v.getDefiningOp<tensor::EmptyOp>();
155     };
156     if (!llvm::all_of(concatOperands.drop_front(), isDefinedByEmptyOp)) {
157       return rewriter.notifyMatchFailure(
158           concatOp, "not all operands are defined by an empty op");
159     }
160     SmallVector<SmallVector<OpFoldResult>> resultShape;
161     if (failed(concatOp.reifyResultShapes(rewriter, resultShape))) {
162       return rewriter.notifyMatchFailure(concatOp,
163                                          "failed to get result shape");
164     }
165     rewriter.replaceOpWithNewOp<tensor::EmptyOp>(
166         concatOp, resultShape[0], concatOp.getResultType().getElementType());
167     return success();
168   }
169 };
170 
171 } // namespace
172 
173 void mlir::tensor::populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
174                                                    bool foldSingleUseOnly) {
175   patterns.add<FoldEmptyTensorWithExtractSliceOp,
176                FoldEmptyTensorWithReshapeOp<tensor::ExpandShapeOp>,
177                FoldEmptyTensorWithReshapeOp<tensor::CollapseShapeOp>>(
178       patterns.getContext(), /*benefit=*/1, foldSingleUseOnly);
179   patterns.add<FoldConcatsOfEmpty, FoldEmptyTensorWithPackOp,
180                FoldEmptyTensorWithUnPackOp>(patterns.getContext(),
181                                             /*benefit=*/1);
182 }
183