1 //===- InitTensorToAllocTensor.cpp - Lower tensor.empty to alloc_tensor ---===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Bufferization/Transforms/Passes.h" 10 11 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 12 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h" 13 #include "mlir/Dialect/Tensor/IR/Tensor.h" 14 #include "mlir/Pass/Pass.h" 15 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 16 17 namespace mlir { 18 namespace bufferization { 19 #define GEN_PASS_DEF_EMPTYTENSORTOALLOCTENSOR 20 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc" 21 } // namespace bufferization 22 } // namespace mlir 23 24 using namespace mlir; 25 using namespace mlir::bufferization; 26 using namespace mlir::tensor; 27 28 namespace { 29 struct EmptyTensorLoweringPattern : public OpRewritePattern<tensor::EmptyOp> { 30 using OpRewritePattern<tensor::EmptyOp>::OpRewritePattern; 31 32 LogicalResult matchAndRewrite(tensor::EmptyOp op, 33 PatternRewriter &rewriter) const override { 34 rewriter.replaceOpWithNewOp<bufferization::AllocTensorOp>( 35 op, op.getType(), op.getDynamicSizes()); 36 return success(); 37 } 38 }; 39 40 struct EmptyTensorToAllocTensor 41 : public bufferization::impl::EmptyTensorToAllocTensorBase< 42 EmptyTensorToAllocTensor> { 43 EmptyTensorToAllocTensor() = default; 44 45 void runOnOperation() override; 46 47 void getDependentDialects(DialectRegistry ®istry) const override { 48 registry 49 .insert<tensor::TensorDialect, bufferization::BufferizationDialect>(); 50 } 51 }; 52 } // namespace 53 54 void bufferization::populateEmptyTensorToAllocTensorPattern( 55 RewritePatternSet &patterns) { 56 patterns.insert<EmptyTensorLoweringPattern>(patterns.getContext()); 57 } 58 59 void EmptyTensorToAllocTensor::runOnOperation() { 60 Operation *op = getOperation(); 61 RewritePatternSet patterns(op->getContext()); 62 populateEmptyTensorToAllocTensorPattern(patterns); 63 if (failed(applyPatternsGreedily(op, std::move(patterns)))) 64 signalPassFailure(); 65 } 66 67 std::unique_ptr<Pass> 68 mlir::bufferization::createEmptyTensorToAllocTensorPass() { 69 return std::make_unique<EmptyTensorToAllocTensor>(); 70 } 71