xref: /llvm-project/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp (revision 758329dc7cd3b0da835a4f865b89003263050080)
1 //===- EmptyOpPatterns.cpp - Patterns related to tensor.empty folding ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 #include "mlir/Dialect/Tensor/IR/Tensor.h"
10 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
11 #include "mlir/IR/PatternMatch.h"
12 #include "llvm/Support/Debug.h"
13 
14 using namespace mlir;
15 using namespace mlir::tensor;
16 
17 namespace {
18 
19 template <typename ReshapeOp>
20 struct FoldEmptyTensorWithReshapeOp : public OpRewritePattern<ReshapeOp> {
21   using OpRewritePattern<ReshapeOp>::OpRewritePattern;
22 
23   LogicalResult matchAndRewrite(ReshapeOp reshapeOp,
24                                 PatternRewriter &rewriter) const override {
25     if (!reshapeOp.getSrc().template getDefiningOp<EmptyOp>())
26       return failure();
27     Location loc = reshapeOp.getLoc();
28     ReifiedRankedShapedTypeDims resultShapes;
29     if (failed(reifyResultShapes(rewriter, reshapeOp, resultShapes)) ||
30         !llvm::hasSingleElement(resultShapes))
31       return failure();
32     // TODO: Do not drop tensor type encoding.
33     Value emptyTensor = rewriter.create<EmptyOp>(
34         loc, resultShapes[0], reshapeOp.getResultType().getElementType());
35     if (emptyTensor.getType() != reshapeOp.getResultType()) {
36       rewriter.replaceOpWithNewOp<tensor::CastOp>(
37           reshapeOp, reshapeOp.getResultType(), emptyTensor);
38     } else {
39       rewriter.replaceOp(reshapeOp, emptyTensor);
40     }
41     return success();
42   }
43 };
44 
45 /// `tensor.empty` does not define any tensor contents, so a slice of a
46 /// `tensor.empty` can be canonicalized to a smaller `tensor.empty`.
47 struct FoldEmptyTensorWithExtractSliceOp
48     : public OpRewritePattern<ExtractSliceOp> {
49   using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
50 
51   LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
52                                 PatternRewriter &rewriter) const override {
53     if (!sliceOp.getSource().getDefiningOp<EmptyOp>())
54       return failure();
55 
56     // ExtractSliceOp may be rank-reducing; its dynamic sizes must be
57     // preserved as well as its result type.
58     auto tensorType = RankedTensorType::get(sliceOp.getType().getShape(),
59                                             sliceOp.getType().getElementType(),
60                                             sliceOp.getType().getEncoding());
61     rewriter.replaceOpWithNewOp<EmptyOp>(sliceOp, tensorType,
62                                          sliceOp.getSizes());
63     return success();
64   }
65 };
66 
67 } // namespace
68 
69 void mlir::tensor::populateFoldTensorEmptyPatterns(
70     RewritePatternSet &patterns) {
71   patterns.add<FoldEmptyTensorWithExtractSliceOp,
72                FoldEmptyTensorWithReshapeOp<tensor::ExpandShapeOp>,
73                FoldEmptyTensorWithReshapeOp<tensor::CollapseShapeOp>>(
74       patterns.getContext());
75 }
76