1f6fb0a4fSAlexander Belyaev //===- EmptyOpPatterns.cpp - Patterns related to tensor.empty folding ----===// 2f6fb0a4fSAlexander Belyaev // 3f6fb0a4fSAlexander Belyaev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4f6fb0a4fSAlexander Belyaev // See https://llvm.org/LICENSE.txt for license information. 5f6fb0a4fSAlexander Belyaev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6f6fb0a4fSAlexander Belyaev // 7f6fb0a4fSAlexander Belyaev //===----------------------------------------------------------------------===// 8f6fb0a4fSAlexander Belyaev // 9f6fb0a4fSAlexander Belyaev #include "mlir/Dialect/Tensor/IR/Tensor.h" 10f6fb0a4fSAlexander Belyaev #include "mlir/Dialect/Tensor/Transforms/Transforms.h" 11f6fb0a4fSAlexander Belyaev #include "mlir/IR/PatternMatch.h" 12f6fb0a4fSAlexander Belyaev #include "llvm/Support/Debug.h" 13f6fb0a4fSAlexander Belyaev 14f6fb0a4fSAlexander Belyaev using namespace mlir; 15f6fb0a4fSAlexander Belyaev using namespace mlir::tensor; 16f6fb0a4fSAlexander Belyaev 17f6fb0a4fSAlexander Belyaev namespace { 18f6fb0a4fSAlexander Belyaev 19f6fb0a4fSAlexander Belyaev template <typename ReshapeOp> 20f6fb0a4fSAlexander Belyaev struct FoldEmptyTensorWithReshapeOp : public OpRewritePattern<ReshapeOp> { 2140052b08SMatthias Springer FoldEmptyTensorWithReshapeOp(MLIRContext *ctx, PatternBenefit benefit = 1, 2240052b08SMatthias Springer bool foldSingleUseOnly = false) 2340052b08SMatthias Springer : OpRewritePattern<ReshapeOp>(ctx, benefit), 2440052b08SMatthias Springer foldSingleUseOnly(foldSingleUseOnly) {} 25f6fb0a4fSAlexander Belyaev 26f6fb0a4fSAlexander Belyaev LogicalResult matchAndRewrite(ReshapeOp reshapeOp, 27f6fb0a4fSAlexander Belyaev PatternRewriter &rewriter) const override { 2840052b08SMatthias Springer // Check for tensor.empty source. 2940052b08SMatthias Springer auto emptyOp = reshapeOp.getSrc().template getDefiningOp<EmptyOp>(); 3040052b08SMatthias Springer if (!emptyOp) 31f6fb0a4fSAlexander Belyaev return failure(); 3240052b08SMatthias Springer 3340052b08SMatthias Springer // Check for single use. 3440052b08SMatthias Springer if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses())) 3540052b08SMatthias Springer return failure(); 3640052b08SMatthias Springer 3740052b08SMatthias Springer // Reify result shape. 38f6fb0a4fSAlexander Belyaev Location loc = reshapeOp.getLoc(); 39f6fb0a4fSAlexander Belyaev ReifiedRankedShapedTypeDims resultShapes; 40758329dcSMatthias Springer if (failed(reifyResultShapes(rewriter, reshapeOp, resultShapes)) || 41f6fb0a4fSAlexander Belyaev !llvm::hasSingleElement(resultShapes)) 42f6fb0a4fSAlexander Belyaev return failure(); 4340052b08SMatthias Springer 4440052b08SMatthias Springer // Create new tensor.empty op. 45f6fb0a4fSAlexander Belyaev // TODO: Do not drop tensor type encoding. 462a5b13e7SMatthias Springer Value emptyTensor = rewriter.create<EmptyOp>( 472a5b13e7SMatthias Springer loc, resultShapes[0], reshapeOp.getResultType().getElementType()); 48f6fb0a4fSAlexander Belyaev if (emptyTensor.getType() != reshapeOp.getResultType()) { 49f6fb0a4fSAlexander Belyaev rewriter.replaceOpWithNewOp<tensor::CastOp>( 50f6fb0a4fSAlexander Belyaev reshapeOp, reshapeOp.getResultType(), emptyTensor); 51f6fb0a4fSAlexander Belyaev } else { 52f6fb0a4fSAlexander Belyaev rewriter.replaceOp(reshapeOp, emptyTensor); 53f6fb0a4fSAlexander Belyaev } 54f6fb0a4fSAlexander Belyaev return success(); 55f6fb0a4fSAlexander Belyaev } 5640052b08SMatthias Springer 5740052b08SMatthias Springer private: 5840052b08SMatthias Springer bool foldSingleUseOnly = false; 59f6fb0a4fSAlexander Belyaev }; 60f6fb0a4fSAlexander Belyaev 6140052b08SMatthias Springer /// tensor.empty does not define any tensor contents, so a slice of a 6240052b08SMatthias Springer /// tensor.empty can be folded to a smaller tensor.empty. 63f6fb0a4fSAlexander Belyaev struct FoldEmptyTensorWithExtractSliceOp 64f6fb0a4fSAlexander Belyaev : public OpRewritePattern<ExtractSliceOp> { 6540052b08SMatthias Springer FoldEmptyTensorWithExtractSliceOp(MLIRContext *ctx, 6640052b08SMatthias Springer PatternBenefit benefit = 1, 6740052b08SMatthias Springer bool foldSingleUseOnly = false) 6840052b08SMatthias Springer : OpRewritePattern<ExtractSliceOp>(ctx, benefit), 6940052b08SMatthias Springer foldSingleUseOnly(foldSingleUseOnly) {} 70f6fb0a4fSAlexander Belyaev 71f6fb0a4fSAlexander Belyaev LogicalResult matchAndRewrite(ExtractSliceOp sliceOp, 72f6fb0a4fSAlexander Belyaev PatternRewriter &rewriter) const override { 7340052b08SMatthias Springer // Check for tensor.empty source. 7440052b08SMatthias Springer auto emptyOp = sliceOp.getSource().template getDefiningOp<EmptyOp>(); 7540052b08SMatthias Springer if (!emptyOp) 76f6fb0a4fSAlexander Belyaev return failure(); 77f6fb0a4fSAlexander Belyaev 7840052b08SMatthias Springer // Check for single use. 7940052b08SMatthias Springer if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses())) 8040052b08SMatthias Springer return failure(); 8140052b08SMatthias Springer 8240052b08SMatthias Springer // Create new tensor.empty op. tensor.extract_slice may be rank-reducing; 8340052b08SMatthias Springer // its dynamic sizes must be preserved as well as its result type. 84f6fb0a4fSAlexander Belyaev auto tensorType = RankedTensorType::get(sliceOp.getType().getShape(), 85f6fb0a4fSAlexander Belyaev sliceOp.getType().getElementType(), 86f6fb0a4fSAlexander Belyaev sliceOp.getType().getEncoding()); 87f6fb0a4fSAlexander Belyaev rewriter.replaceOpWithNewOp<EmptyOp>(sliceOp, tensorType, 88f6fb0a4fSAlexander Belyaev sliceOp.getSizes()); 89f6fb0a4fSAlexander Belyaev return success(); 90f6fb0a4fSAlexander Belyaev } 9140052b08SMatthias Springer 9240052b08SMatthias Springer private: 9340052b08SMatthias Springer bool foldSingleUseOnly = false; 94f6fb0a4fSAlexander Belyaev }; 95f6fb0a4fSAlexander Belyaev 96b5861494SAdam Siemieniuk /// tensor.empty does not define any tensor contents, so an unpadded pack 97b5861494SAdam Siemieniuk /// can be folded away. 98b5861494SAdam Siemieniuk struct FoldEmptyTensorWithPackOp : public OpRewritePattern<PackOp> { 99b5861494SAdam Siemieniuk using OpRewritePattern<PackOp>::OpRewritePattern; 100b5861494SAdam Siemieniuk 101b5861494SAdam Siemieniuk LogicalResult matchAndRewrite(PackOp packOp, 102b5861494SAdam Siemieniuk PatternRewriter &rewriter) const override { 103b5861494SAdam Siemieniuk // Check for tensor.empty source. 104b5861494SAdam Siemieniuk auto emptyOp = packOp.getSource().getDefiningOp<EmptyOp>(); 105b5861494SAdam Siemieniuk if (!emptyOp) 106b5861494SAdam Siemieniuk return failure(); 107b5861494SAdam Siemieniuk 108b5861494SAdam Siemieniuk // Check for padding. 109b5861494SAdam Siemieniuk // Packing with padding cannot be simply removed. 110b5861494SAdam Siemieniuk if (packOp.getPaddingValue()) 111b5861494SAdam Siemieniuk return rewriter.notifyMatchFailure(packOp, "expects no padding value"); 112b5861494SAdam Siemieniuk 113b5861494SAdam Siemieniuk // Replace the pack directly with its destination. 114b5861494SAdam Siemieniuk rewriter.replaceOp(packOp, packOp.getDest()); 115b5861494SAdam Siemieniuk 116b5861494SAdam Siemieniuk return success(); 117b5861494SAdam Siemieniuk } 118b5861494SAdam Siemieniuk }; 119b5861494SAdam Siemieniuk 120b5861494SAdam Siemieniuk /// tensor.empty does not define any tensor contents, so an unpack 121b5861494SAdam Siemieniuk /// can be folded away. 122b5861494SAdam Siemieniuk struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> { 123b5861494SAdam Siemieniuk using OpRewritePattern<UnPackOp>::OpRewritePattern; 124b5861494SAdam Siemieniuk 125b5861494SAdam Siemieniuk LogicalResult matchAndRewrite(UnPackOp unPackOp, 126b5861494SAdam Siemieniuk PatternRewriter &rewriter) const override { 127b5861494SAdam Siemieniuk // Check for tensor.empty source. 128b5861494SAdam Siemieniuk auto emptyOp = unPackOp.getSource().getDefiningOp<EmptyOp>(); 129b5861494SAdam Siemieniuk if (!emptyOp) 130b5861494SAdam Siemieniuk return failure(); 131b5861494SAdam Siemieniuk 132b5861494SAdam Siemieniuk // Replace the unpack directly with its destination. 133b5861494SAdam Siemieniuk rewriter.replaceOp(unPackOp, unPackOp.getDest()); 134b5861494SAdam Siemieniuk 135b5861494SAdam Siemieniuk return success(); 136b5861494SAdam Siemieniuk } 137b5861494SAdam Siemieniuk }; 138b5861494SAdam Siemieniuk 139*c077a4f3SMaheshRavishankar // Fold concat operation where all the operands are empty. 140*c077a4f3SMaheshRavishankar struct FoldConcatsOfEmpty : public OpRewritePattern<ConcatOp> { 141*c077a4f3SMaheshRavishankar using OpRewritePattern<ConcatOp>::OpRewritePattern; 142*c077a4f3SMaheshRavishankar 143*c077a4f3SMaheshRavishankar LogicalResult matchAndRewrite(tensor::ConcatOp concatOp, 144*c077a4f3SMaheshRavishankar PatternRewriter &rewriter) const override { 145*c077a4f3SMaheshRavishankar auto concatOperands = concatOp.getInputs(); 146*c077a4f3SMaheshRavishankar if (concatOperands.empty()) { 147*c077a4f3SMaheshRavishankar return failure(); 148*c077a4f3SMaheshRavishankar } 149*c077a4f3SMaheshRavishankar auto firstEmptyOp = concatOperands.front().getDefiningOp<tensor::EmptyOp>(); 150*c077a4f3SMaheshRavishankar if (!firstEmptyOp) { 151*c077a4f3SMaheshRavishankar return failure(); 152*c077a4f3SMaheshRavishankar } 153*c077a4f3SMaheshRavishankar auto isDefinedByEmptyOp = [](Value v) -> bool { 154*c077a4f3SMaheshRavishankar return v.getDefiningOp<tensor::EmptyOp>(); 155*c077a4f3SMaheshRavishankar }; 156*c077a4f3SMaheshRavishankar if (!llvm::all_of(concatOperands.drop_front(), isDefinedByEmptyOp)) { 157*c077a4f3SMaheshRavishankar return rewriter.notifyMatchFailure( 158*c077a4f3SMaheshRavishankar concatOp, "not all operands are defined by an empty op"); 159*c077a4f3SMaheshRavishankar } 160*c077a4f3SMaheshRavishankar SmallVector<SmallVector<OpFoldResult>> resultShape; 161*c077a4f3SMaheshRavishankar if (failed(concatOp.reifyResultShapes(rewriter, resultShape))) { 162*c077a4f3SMaheshRavishankar return rewriter.notifyMatchFailure(concatOp, 163*c077a4f3SMaheshRavishankar "failed to get result shape"); 164*c077a4f3SMaheshRavishankar } 165*c077a4f3SMaheshRavishankar rewriter.replaceOpWithNewOp<tensor::EmptyOp>( 166*c077a4f3SMaheshRavishankar concatOp, resultShape[0], concatOp.getResultType().getElementType()); 167*c077a4f3SMaheshRavishankar return success(); 168*c077a4f3SMaheshRavishankar } 169*c077a4f3SMaheshRavishankar }; 170*c077a4f3SMaheshRavishankar 171f6fb0a4fSAlexander Belyaev } // namespace 172f6fb0a4fSAlexander Belyaev 17340052b08SMatthias Springer void mlir::tensor::populateFoldTensorEmptyPatterns(RewritePatternSet &patterns, 17440052b08SMatthias Springer bool foldSingleUseOnly) { 175f6fb0a4fSAlexander Belyaev patterns.add<FoldEmptyTensorWithExtractSliceOp, 176f6fb0a4fSAlexander Belyaev FoldEmptyTensorWithReshapeOp<tensor::ExpandShapeOp>, 177f6fb0a4fSAlexander Belyaev FoldEmptyTensorWithReshapeOp<tensor::CollapseShapeOp>>( 17840052b08SMatthias Springer patterns.getContext(), /*benefit=*/1, foldSingleUseOnly); 179*c077a4f3SMaheshRavishankar patterns.add<FoldConcatsOfEmpty, FoldEmptyTensorWithPackOp, 180*c077a4f3SMaheshRavishankar FoldEmptyTensorWithUnPackOp>(patterns.getContext(), 181*c077a4f3SMaheshRavishankar /*benefit=*/1); 182f6fb0a4fSAlexander Belyaev } 183