1 //===- EmptyOpPatterns.cpp - Patterns related to tensor.empty folding ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 #include "mlir/Dialect/Tensor/IR/Tensor.h" 10 #include "mlir/Dialect/Tensor/Transforms/Transforms.h" 11 #include "mlir/IR/PatternMatch.h" 12 #include "llvm/Support/Debug.h" 13 14 using namespace mlir; 15 using namespace mlir::tensor; 16 17 namespace { 18 19 template <typename ReshapeOp> 20 struct FoldEmptyTensorWithReshapeOp : public OpRewritePattern<ReshapeOp> { 21 FoldEmptyTensorWithReshapeOp(MLIRContext *ctx, PatternBenefit benefit = 1, 22 bool foldSingleUseOnly = false) 23 : OpRewritePattern<ReshapeOp>(ctx, benefit), 24 foldSingleUseOnly(foldSingleUseOnly) {} 25 26 LogicalResult matchAndRewrite(ReshapeOp reshapeOp, 27 PatternRewriter &rewriter) const override { 28 // Check for tensor.empty source. 29 auto emptyOp = reshapeOp.getSrc().template getDefiningOp<EmptyOp>(); 30 if (!emptyOp) 31 return failure(); 32 33 // Check for single use. 34 if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses())) 35 return failure(); 36 37 // Reify result shape. 38 Location loc = reshapeOp.getLoc(); 39 ReifiedRankedShapedTypeDims resultShapes; 40 if (failed(reifyResultShapes(rewriter, reshapeOp, resultShapes)) || 41 !llvm::hasSingleElement(resultShapes)) 42 return failure(); 43 44 // Create new tensor.empty op. 45 // TODO: Do not drop tensor type encoding. 46 Value emptyTensor = rewriter.create<EmptyOp>( 47 loc, resultShapes[0], reshapeOp.getResultType().getElementType()); 48 if (emptyTensor.getType() != reshapeOp.getResultType()) { 49 rewriter.replaceOpWithNewOp<tensor::CastOp>( 50 reshapeOp, reshapeOp.getResultType(), emptyTensor); 51 } else { 52 rewriter.replaceOp(reshapeOp, emptyTensor); 53 } 54 return success(); 55 } 56 57 private: 58 bool foldSingleUseOnly = false; 59 }; 60 61 /// tensor.empty does not define any tensor contents, so a slice of a 62 /// tensor.empty can be folded to a smaller tensor.empty. 63 struct FoldEmptyTensorWithExtractSliceOp 64 : public OpRewritePattern<ExtractSliceOp> { 65 FoldEmptyTensorWithExtractSliceOp(MLIRContext *ctx, 66 PatternBenefit benefit = 1, 67 bool foldSingleUseOnly = false) 68 : OpRewritePattern<ExtractSliceOp>(ctx, benefit), 69 foldSingleUseOnly(foldSingleUseOnly) {} 70 71 LogicalResult matchAndRewrite(ExtractSliceOp sliceOp, 72 PatternRewriter &rewriter) const override { 73 // Check for tensor.empty source. 74 auto emptyOp = sliceOp.getSource().template getDefiningOp<EmptyOp>(); 75 if (!emptyOp) 76 return failure(); 77 78 // Check for single use. 79 if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses())) 80 return failure(); 81 82 // Create new tensor.empty op. tensor.extract_slice may be rank-reducing; 83 // its dynamic sizes must be preserved as well as its result type. 84 auto tensorType = RankedTensorType::get(sliceOp.getType().getShape(), 85 sliceOp.getType().getElementType(), 86 sliceOp.getType().getEncoding()); 87 rewriter.replaceOpWithNewOp<EmptyOp>(sliceOp, tensorType, 88 sliceOp.getSizes()); 89 return success(); 90 } 91 92 private: 93 bool foldSingleUseOnly = false; 94 }; 95 96 /// tensor.empty does not define any tensor contents, so an unpadded pack 97 /// can be folded away. 98 struct FoldEmptyTensorWithPackOp : public OpRewritePattern<PackOp> { 99 using OpRewritePattern<PackOp>::OpRewritePattern; 100 101 LogicalResult matchAndRewrite(PackOp packOp, 102 PatternRewriter &rewriter) const override { 103 // Check for tensor.empty source. 104 auto emptyOp = packOp.getSource().getDefiningOp<EmptyOp>(); 105 if (!emptyOp) 106 return failure(); 107 108 // Check for padding. 109 // Packing with padding cannot be simply removed. 110 if (packOp.getPaddingValue()) 111 return rewriter.notifyMatchFailure(packOp, "expects no padding value"); 112 113 // Replace the pack directly with its destination. 114 rewriter.replaceOp(packOp, packOp.getDest()); 115 116 return success(); 117 } 118 }; 119 120 /// tensor.empty does not define any tensor contents, so an unpack 121 /// can be folded away. 122 struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> { 123 using OpRewritePattern<UnPackOp>::OpRewritePattern; 124 125 LogicalResult matchAndRewrite(UnPackOp unPackOp, 126 PatternRewriter &rewriter) const override { 127 // Check for tensor.empty source. 128 auto emptyOp = unPackOp.getSource().getDefiningOp<EmptyOp>(); 129 if (!emptyOp) 130 return failure(); 131 132 // Replace the unpack directly with its destination. 133 rewriter.replaceOp(unPackOp, unPackOp.getDest()); 134 135 return success(); 136 } 137 }; 138 139 // Fold concat operation where all the operands are empty. 140 struct FoldConcatsOfEmpty : public OpRewritePattern<ConcatOp> { 141 using OpRewritePattern<ConcatOp>::OpRewritePattern; 142 143 LogicalResult matchAndRewrite(tensor::ConcatOp concatOp, 144 PatternRewriter &rewriter) const override { 145 auto concatOperands = concatOp.getInputs(); 146 if (concatOperands.empty()) { 147 return failure(); 148 } 149 auto firstEmptyOp = concatOperands.front().getDefiningOp<tensor::EmptyOp>(); 150 if (!firstEmptyOp) { 151 return failure(); 152 } 153 auto isDefinedByEmptyOp = [](Value v) -> bool { 154 return v.getDefiningOp<tensor::EmptyOp>(); 155 }; 156 if (!llvm::all_of(concatOperands.drop_front(), isDefinedByEmptyOp)) { 157 return rewriter.notifyMatchFailure( 158 concatOp, "not all operands are defined by an empty op"); 159 } 160 SmallVector<SmallVector<OpFoldResult>> resultShape; 161 if (failed(concatOp.reifyResultShapes(rewriter, resultShape))) { 162 return rewriter.notifyMatchFailure(concatOp, 163 "failed to get result shape"); 164 } 165 rewriter.replaceOpWithNewOp<tensor::EmptyOp>( 166 concatOp, resultShape[0], concatOp.getResultType().getElementType()); 167 return success(); 168 } 169 }; 170 171 } // namespace 172 173 void mlir::tensor::populateFoldTensorEmptyPatterns(RewritePatternSet &patterns, 174 bool foldSingleUseOnly) { 175 patterns.add<FoldEmptyTensorWithExtractSliceOp, 176 FoldEmptyTensorWithReshapeOp<tensor::ExpandShapeOp>, 177 FoldEmptyTensorWithReshapeOp<tensor::CollapseShapeOp>>( 178 patterns.getContext(), /*benefit=*/1, foldSingleUseOnly); 179 patterns.add<FoldConcatsOfEmpty, FoldEmptyTensorWithPackOp, 180 FoldEmptyTensorWithUnPackOp>(patterns.getContext(), 181 /*benefit=*/1); 182 } 183