xref: /llvm-project/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp (revision c077a4f305aa7faf92a1438b239078c1da1563a9)
1f6fb0a4fSAlexander Belyaev //===- EmptyOpPatterns.cpp - Patterns related to tensor.empty folding ----===//
2f6fb0a4fSAlexander Belyaev //
3f6fb0a4fSAlexander Belyaev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4f6fb0a4fSAlexander Belyaev // See https://llvm.org/LICENSE.txt for license information.
5f6fb0a4fSAlexander Belyaev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6f6fb0a4fSAlexander Belyaev //
7f6fb0a4fSAlexander Belyaev //===----------------------------------------------------------------------===//
8f6fb0a4fSAlexander Belyaev //
9f6fb0a4fSAlexander Belyaev #include "mlir/Dialect/Tensor/IR/Tensor.h"
10f6fb0a4fSAlexander Belyaev #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
11f6fb0a4fSAlexander Belyaev #include "mlir/IR/PatternMatch.h"
12f6fb0a4fSAlexander Belyaev #include "llvm/Support/Debug.h"
13f6fb0a4fSAlexander Belyaev 
14f6fb0a4fSAlexander Belyaev using namespace mlir;
15f6fb0a4fSAlexander Belyaev using namespace mlir::tensor;
16f6fb0a4fSAlexander Belyaev 
17f6fb0a4fSAlexander Belyaev namespace {
18f6fb0a4fSAlexander Belyaev 
19f6fb0a4fSAlexander Belyaev template <typename ReshapeOp>
20f6fb0a4fSAlexander Belyaev struct FoldEmptyTensorWithReshapeOp : public OpRewritePattern<ReshapeOp> {
2140052b08SMatthias Springer   FoldEmptyTensorWithReshapeOp(MLIRContext *ctx, PatternBenefit benefit = 1,
2240052b08SMatthias Springer                                bool foldSingleUseOnly = false)
2340052b08SMatthias Springer       : OpRewritePattern<ReshapeOp>(ctx, benefit),
2440052b08SMatthias Springer         foldSingleUseOnly(foldSingleUseOnly) {}
25f6fb0a4fSAlexander Belyaev 
26f6fb0a4fSAlexander Belyaev   LogicalResult matchAndRewrite(ReshapeOp reshapeOp,
27f6fb0a4fSAlexander Belyaev                                 PatternRewriter &rewriter) const override {
2840052b08SMatthias Springer     // Check for tensor.empty source.
2940052b08SMatthias Springer     auto emptyOp = reshapeOp.getSrc().template getDefiningOp<EmptyOp>();
3040052b08SMatthias Springer     if (!emptyOp)
31f6fb0a4fSAlexander Belyaev       return failure();
3240052b08SMatthias Springer 
3340052b08SMatthias Springer     // Check for single use.
3440052b08SMatthias Springer     if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses()))
3540052b08SMatthias Springer       return failure();
3640052b08SMatthias Springer 
3740052b08SMatthias Springer     // Reify result shape.
38f6fb0a4fSAlexander Belyaev     Location loc = reshapeOp.getLoc();
39f6fb0a4fSAlexander Belyaev     ReifiedRankedShapedTypeDims resultShapes;
40758329dcSMatthias Springer     if (failed(reifyResultShapes(rewriter, reshapeOp, resultShapes)) ||
41f6fb0a4fSAlexander Belyaev         !llvm::hasSingleElement(resultShapes))
42f6fb0a4fSAlexander Belyaev       return failure();
4340052b08SMatthias Springer 
4440052b08SMatthias Springer     // Create new tensor.empty op.
45f6fb0a4fSAlexander Belyaev     // TODO: Do not drop tensor type encoding.
462a5b13e7SMatthias Springer     Value emptyTensor = rewriter.create<EmptyOp>(
472a5b13e7SMatthias Springer         loc, resultShapes[0], reshapeOp.getResultType().getElementType());
48f6fb0a4fSAlexander Belyaev     if (emptyTensor.getType() != reshapeOp.getResultType()) {
49f6fb0a4fSAlexander Belyaev       rewriter.replaceOpWithNewOp<tensor::CastOp>(
50f6fb0a4fSAlexander Belyaev           reshapeOp, reshapeOp.getResultType(), emptyTensor);
51f6fb0a4fSAlexander Belyaev     } else {
52f6fb0a4fSAlexander Belyaev       rewriter.replaceOp(reshapeOp, emptyTensor);
53f6fb0a4fSAlexander Belyaev     }
54f6fb0a4fSAlexander Belyaev     return success();
55f6fb0a4fSAlexander Belyaev   }
5640052b08SMatthias Springer 
5740052b08SMatthias Springer private:
5840052b08SMatthias Springer   bool foldSingleUseOnly = false;
59f6fb0a4fSAlexander Belyaev };
60f6fb0a4fSAlexander Belyaev 
6140052b08SMatthias Springer /// tensor.empty does not define any tensor contents, so a slice of a
6240052b08SMatthias Springer /// tensor.empty can be folded to a smaller tensor.empty.
63f6fb0a4fSAlexander Belyaev struct FoldEmptyTensorWithExtractSliceOp
64f6fb0a4fSAlexander Belyaev     : public OpRewritePattern<ExtractSliceOp> {
6540052b08SMatthias Springer   FoldEmptyTensorWithExtractSliceOp(MLIRContext *ctx,
6640052b08SMatthias Springer                                     PatternBenefit benefit = 1,
6740052b08SMatthias Springer                                     bool foldSingleUseOnly = false)
6840052b08SMatthias Springer       : OpRewritePattern<ExtractSliceOp>(ctx, benefit),
6940052b08SMatthias Springer         foldSingleUseOnly(foldSingleUseOnly) {}
70f6fb0a4fSAlexander Belyaev 
71f6fb0a4fSAlexander Belyaev   LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
72f6fb0a4fSAlexander Belyaev                                 PatternRewriter &rewriter) const override {
7340052b08SMatthias Springer     // Check for tensor.empty source.
7440052b08SMatthias Springer     auto emptyOp = sliceOp.getSource().template getDefiningOp<EmptyOp>();
7540052b08SMatthias Springer     if (!emptyOp)
76f6fb0a4fSAlexander Belyaev       return failure();
77f6fb0a4fSAlexander Belyaev 
7840052b08SMatthias Springer     // Check for single use.
7940052b08SMatthias Springer     if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses()))
8040052b08SMatthias Springer       return failure();
8140052b08SMatthias Springer 
8240052b08SMatthias Springer     // Create new tensor.empty op. tensor.extract_slice may be rank-reducing;
8340052b08SMatthias Springer     // its dynamic sizes must be preserved as well as its result type.
84f6fb0a4fSAlexander Belyaev     auto tensorType = RankedTensorType::get(sliceOp.getType().getShape(),
85f6fb0a4fSAlexander Belyaev                                             sliceOp.getType().getElementType(),
86f6fb0a4fSAlexander Belyaev                                             sliceOp.getType().getEncoding());
87f6fb0a4fSAlexander Belyaev     rewriter.replaceOpWithNewOp<EmptyOp>(sliceOp, tensorType,
88f6fb0a4fSAlexander Belyaev                                          sliceOp.getSizes());
89f6fb0a4fSAlexander Belyaev     return success();
90f6fb0a4fSAlexander Belyaev   }
9140052b08SMatthias Springer 
9240052b08SMatthias Springer private:
9340052b08SMatthias Springer   bool foldSingleUseOnly = false;
94f6fb0a4fSAlexander Belyaev };
95f6fb0a4fSAlexander Belyaev 
96b5861494SAdam Siemieniuk /// tensor.empty does not define any tensor contents, so an unpadded pack
97b5861494SAdam Siemieniuk /// can be folded away.
98b5861494SAdam Siemieniuk struct FoldEmptyTensorWithPackOp : public OpRewritePattern<PackOp> {
99b5861494SAdam Siemieniuk   using OpRewritePattern<PackOp>::OpRewritePattern;
100b5861494SAdam Siemieniuk 
101b5861494SAdam Siemieniuk   LogicalResult matchAndRewrite(PackOp packOp,
102b5861494SAdam Siemieniuk                                 PatternRewriter &rewriter) const override {
103b5861494SAdam Siemieniuk     // Check for tensor.empty source.
104b5861494SAdam Siemieniuk     auto emptyOp = packOp.getSource().getDefiningOp<EmptyOp>();
105b5861494SAdam Siemieniuk     if (!emptyOp)
106b5861494SAdam Siemieniuk       return failure();
107b5861494SAdam Siemieniuk 
108b5861494SAdam Siemieniuk     // Check for padding.
109b5861494SAdam Siemieniuk     // Packing with padding cannot be simply removed.
110b5861494SAdam Siemieniuk     if (packOp.getPaddingValue())
111b5861494SAdam Siemieniuk       return rewriter.notifyMatchFailure(packOp, "expects no padding value");
112b5861494SAdam Siemieniuk 
113b5861494SAdam Siemieniuk     // Replace the pack directly with its destination.
114b5861494SAdam Siemieniuk     rewriter.replaceOp(packOp, packOp.getDest());
115b5861494SAdam Siemieniuk 
116b5861494SAdam Siemieniuk     return success();
117b5861494SAdam Siemieniuk   }
118b5861494SAdam Siemieniuk };
119b5861494SAdam Siemieniuk 
120b5861494SAdam Siemieniuk /// tensor.empty does not define any tensor contents, so an unpack
121b5861494SAdam Siemieniuk /// can be folded away.
122b5861494SAdam Siemieniuk struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> {
123b5861494SAdam Siemieniuk   using OpRewritePattern<UnPackOp>::OpRewritePattern;
124b5861494SAdam Siemieniuk 
125b5861494SAdam Siemieniuk   LogicalResult matchAndRewrite(UnPackOp unPackOp,
126b5861494SAdam Siemieniuk                                 PatternRewriter &rewriter) const override {
127b5861494SAdam Siemieniuk     // Check for tensor.empty source.
128b5861494SAdam Siemieniuk     auto emptyOp = unPackOp.getSource().getDefiningOp<EmptyOp>();
129b5861494SAdam Siemieniuk     if (!emptyOp)
130b5861494SAdam Siemieniuk       return failure();
131b5861494SAdam Siemieniuk 
132b5861494SAdam Siemieniuk     // Replace the unpack directly with its destination.
133b5861494SAdam Siemieniuk     rewriter.replaceOp(unPackOp, unPackOp.getDest());
134b5861494SAdam Siemieniuk 
135b5861494SAdam Siemieniuk     return success();
136b5861494SAdam Siemieniuk   }
137b5861494SAdam Siemieniuk };
138b5861494SAdam Siemieniuk 
139*c077a4f3SMaheshRavishankar // Fold concat operation where all the operands are empty.
140*c077a4f3SMaheshRavishankar struct FoldConcatsOfEmpty : public OpRewritePattern<ConcatOp> {
141*c077a4f3SMaheshRavishankar   using OpRewritePattern<ConcatOp>::OpRewritePattern;
142*c077a4f3SMaheshRavishankar 
143*c077a4f3SMaheshRavishankar   LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
144*c077a4f3SMaheshRavishankar                                 PatternRewriter &rewriter) const override {
145*c077a4f3SMaheshRavishankar     auto concatOperands = concatOp.getInputs();
146*c077a4f3SMaheshRavishankar     if (concatOperands.empty()) {
147*c077a4f3SMaheshRavishankar       return failure();
148*c077a4f3SMaheshRavishankar     }
149*c077a4f3SMaheshRavishankar     auto firstEmptyOp = concatOperands.front().getDefiningOp<tensor::EmptyOp>();
150*c077a4f3SMaheshRavishankar     if (!firstEmptyOp) {
151*c077a4f3SMaheshRavishankar       return failure();
152*c077a4f3SMaheshRavishankar     }
153*c077a4f3SMaheshRavishankar     auto isDefinedByEmptyOp = [](Value v) -> bool {
154*c077a4f3SMaheshRavishankar       return v.getDefiningOp<tensor::EmptyOp>();
155*c077a4f3SMaheshRavishankar     };
156*c077a4f3SMaheshRavishankar     if (!llvm::all_of(concatOperands.drop_front(), isDefinedByEmptyOp)) {
157*c077a4f3SMaheshRavishankar       return rewriter.notifyMatchFailure(
158*c077a4f3SMaheshRavishankar           concatOp, "not all operands are defined by an empty op");
159*c077a4f3SMaheshRavishankar     }
160*c077a4f3SMaheshRavishankar     SmallVector<SmallVector<OpFoldResult>> resultShape;
161*c077a4f3SMaheshRavishankar     if (failed(concatOp.reifyResultShapes(rewriter, resultShape))) {
162*c077a4f3SMaheshRavishankar       return rewriter.notifyMatchFailure(concatOp,
163*c077a4f3SMaheshRavishankar                                          "failed to get result shape");
164*c077a4f3SMaheshRavishankar     }
165*c077a4f3SMaheshRavishankar     rewriter.replaceOpWithNewOp<tensor::EmptyOp>(
166*c077a4f3SMaheshRavishankar         concatOp, resultShape[0], concatOp.getResultType().getElementType());
167*c077a4f3SMaheshRavishankar     return success();
168*c077a4f3SMaheshRavishankar   }
169*c077a4f3SMaheshRavishankar };
170*c077a4f3SMaheshRavishankar 
171f6fb0a4fSAlexander Belyaev } // namespace
172f6fb0a4fSAlexander Belyaev 
17340052b08SMatthias Springer void mlir::tensor::populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
17440052b08SMatthias Springer                                                    bool foldSingleUseOnly) {
175f6fb0a4fSAlexander Belyaev   patterns.add<FoldEmptyTensorWithExtractSliceOp,
176f6fb0a4fSAlexander Belyaev                FoldEmptyTensorWithReshapeOp<tensor::ExpandShapeOp>,
177f6fb0a4fSAlexander Belyaev                FoldEmptyTensorWithReshapeOp<tensor::CollapseShapeOp>>(
17840052b08SMatthias Springer       patterns.getContext(), /*benefit=*/1, foldSingleUseOnly);
179*c077a4f3SMaheshRavishankar   patterns.add<FoldConcatsOfEmpty, FoldEmptyTensorWithPackOp,
180*c077a4f3SMaheshRavishankar                FoldEmptyTensorWithUnPackOp>(patterns.getContext(),
181*c077a4f3SMaheshRavishankar                                             /*benefit=*/1);
182f6fb0a4fSAlexander Belyaev }
183