//===- EmptyOpPatterns.cpp - Patterns related to tensor.empty folding ----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/IR/PatternMatch.h" #include "llvm/Support/Debug.h" using namespace mlir; using namespace mlir::tensor; namespace { template struct FoldEmptyTensorWithReshapeOp : public OpRewritePattern { FoldEmptyTensorWithReshapeOp(MLIRContext *ctx, PatternBenefit benefit = 1, bool foldSingleUseOnly = false) : OpRewritePattern(ctx, benefit), foldSingleUseOnly(foldSingleUseOnly) {} LogicalResult matchAndRewrite(ReshapeOp reshapeOp, PatternRewriter &rewriter) const override { // Check for tensor.empty source. auto emptyOp = reshapeOp.getSrc().template getDefiningOp(); if (!emptyOp) return failure(); // Check for single use. if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses())) return failure(); // Reify result shape. Location loc = reshapeOp.getLoc(); ReifiedRankedShapedTypeDims resultShapes; if (failed(reifyResultShapes(rewriter, reshapeOp, resultShapes)) || !llvm::hasSingleElement(resultShapes)) return failure(); // Create new tensor.empty op. // TODO: Do not drop tensor type encoding. Value emptyTensor = rewriter.create( loc, resultShapes[0], reshapeOp.getResultType().getElementType()); if (emptyTensor.getType() != reshapeOp.getResultType()) { rewriter.replaceOpWithNewOp( reshapeOp, reshapeOp.getResultType(), emptyTensor); } else { rewriter.replaceOp(reshapeOp, emptyTensor); } return success(); } private: bool foldSingleUseOnly = false; }; /// tensor.empty does not define any tensor contents, so a slice of a /// tensor.empty can be folded to a smaller tensor.empty. struct FoldEmptyTensorWithExtractSliceOp : public OpRewritePattern { FoldEmptyTensorWithExtractSliceOp(MLIRContext *ctx, PatternBenefit benefit = 1, bool foldSingleUseOnly = false) : OpRewritePattern(ctx, benefit), foldSingleUseOnly(foldSingleUseOnly) {} LogicalResult matchAndRewrite(ExtractSliceOp sliceOp, PatternRewriter &rewriter) const override { // Check for tensor.empty source. auto emptyOp = sliceOp.getSource().template getDefiningOp(); if (!emptyOp) return failure(); // Check for single use. if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses())) return failure(); // Create new tensor.empty op. tensor.extract_slice may be rank-reducing; // its dynamic sizes must be preserved as well as its result type. auto tensorType = RankedTensorType::get(sliceOp.getType().getShape(), sliceOp.getType().getElementType(), sliceOp.getType().getEncoding()); rewriter.replaceOpWithNewOp(sliceOp, tensorType, sliceOp.getSizes()); return success(); } private: bool foldSingleUseOnly = false; }; /// tensor.empty does not define any tensor contents, so an unpadded pack /// can be folded away. struct FoldEmptyTensorWithPackOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(PackOp packOp, PatternRewriter &rewriter) const override { // Check for tensor.empty source. auto emptyOp = packOp.getSource().getDefiningOp(); if (!emptyOp) return failure(); // Check for padding. // Packing with padding cannot be simply removed. if (packOp.getPaddingValue()) return rewriter.notifyMatchFailure(packOp, "expects no padding value"); // Replace the pack directly with its destination. rewriter.replaceOp(packOp, packOp.getDest()); return success(); } }; /// tensor.empty does not define any tensor contents, so an unpack /// can be folded away. struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(UnPackOp unPackOp, PatternRewriter &rewriter) const override { // Check for tensor.empty source. auto emptyOp = unPackOp.getSource().getDefiningOp(); if (!emptyOp) return failure(); // Replace the unpack directly with its destination. rewriter.replaceOp(unPackOp, unPackOp.getDest()); return success(); } }; // Fold concat operation where all the operands are empty. struct FoldConcatsOfEmpty : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(tensor::ConcatOp concatOp, PatternRewriter &rewriter) const override { auto concatOperands = concatOp.getInputs(); if (concatOperands.empty()) { return failure(); } auto firstEmptyOp = concatOperands.front().getDefiningOp(); if (!firstEmptyOp) { return failure(); } auto isDefinedByEmptyOp = [](Value v) -> bool { return v.getDefiningOp(); }; if (!llvm::all_of(concatOperands.drop_front(), isDefinedByEmptyOp)) { return rewriter.notifyMatchFailure( concatOp, "not all operands are defined by an empty op"); } SmallVector> resultShape; if (failed(concatOp.reifyResultShapes(rewriter, resultShape))) { return rewriter.notifyMatchFailure(concatOp, "failed to get result shape"); } rewriter.replaceOpWithNewOp( concatOp, resultShape[0], concatOp.getResultType().getElementType()); return success(); } }; } // namespace void mlir::tensor::populateFoldTensorEmptyPatterns(RewritePatternSet &patterns, bool foldSingleUseOnly) { patterns.add, FoldEmptyTensorWithReshapeOp>( patterns.getContext(), /*benefit=*/1, foldSingleUseOnly); patterns.add(patterns.getContext(), /*benefit=*/1); }