1 //===- EmptyOpPatterns.cpp - Patterns related to tensor.empty folding ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 #include "mlir/Dialect/Tensor/IR/Tensor.h" 10 #include "mlir/Dialect/Tensor/Transforms/Transforms.h" 11 #include "mlir/IR/PatternMatch.h" 12 #include "llvm/Support/Debug.h" 13 14 using namespace mlir; 15 using namespace mlir::tensor; 16 17 namespace { 18 19 template <typename ReshapeOp> 20 struct FoldEmptyTensorWithReshapeOp : public OpRewritePattern<ReshapeOp> { 21 using OpRewritePattern<ReshapeOp>::OpRewritePattern; 22 23 LogicalResult matchAndRewrite(ReshapeOp reshapeOp, 24 PatternRewriter &rewriter) const override { 25 if (!reshapeOp.getSrc().template getDefiningOp<EmptyOp>()) 26 return failure(); 27 Location loc = reshapeOp.getLoc(); 28 ReifiedRankedShapedTypeDims resultShapes; 29 if (failed(reifyResultShapes(rewriter, reshapeOp, resultShapes)) || 30 !llvm::hasSingleElement(resultShapes)) 31 return failure(); 32 // TODO: Do not drop tensor type encoding. 33 Value emptyTensor = rewriter.create<EmptyOp>( 34 loc, resultShapes[0], reshapeOp.getResultType().getElementType()); 35 if (emptyTensor.getType() != reshapeOp.getResultType()) { 36 rewriter.replaceOpWithNewOp<tensor::CastOp>( 37 reshapeOp, reshapeOp.getResultType(), emptyTensor); 38 } else { 39 rewriter.replaceOp(reshapeOp, emptyTensor); 40 } 41 return success(); 42 } 43 }; 44 45 /// `tensor.empty` does not define any tensor contents, so a slice of a 46 /// `tensor.empty` can be canonicalized to a smaller `tensor.empty`. 47 struct FoldEmptyTensorWithExtractSliceOp 48 : public OpRewritePattern<ExtractSliceOp> { 49 using OpRewritePattern<ExtractSliceOp>::OpRewritePattern; 50 51 LogicalResult matchAndRewrite(ExtractSliceOp sliceOp, 52 PatternRewriter &rewriter) const override { 53 if (!sliceOp.getSource().getDefiningOp<EmptyOp>()) 54 return failure(); 55 56 // ExtractSliceOp may be rank-reducing; its dynamic sizes must be 57 // preserved as well as its result type. 58 auto tensorType = RankedTensorType::get(sliceOp.getType().getShape(), 59 sliceOp.getType().getElementType(), 60 sliceOp.getType().getEncoding()); 61 rewriter.replaceOpWithNewOp<EmptyOp>(sliceOp, tensorType, 62 sliceOp.getSizes()); 63 return success(); 64 } 65 }; 66 67 } // namespace 68 69 void mlir::tensor::populateFoldTensorEmptyPatterns( 70 RewritePatternSet &patterns) { 71 patterns.add<FoldEmptyTensorWithExtractSliceOp, 72 FoldEmptyTensorWithReshapeOp<tensor::ExpandShapeOp>, 73 FoldEmptyTensorWithReshapeOp<tensor::CollapseShapeOp>>( 74 patterns.getContext()); 75 } 76