xref: /llvm-project/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorToAllocTensor.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
//===- EmptyTensorToAllocTensor.cpp - Lower tensor.empty to alloc_tensor --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
10 
11 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
13 #include "mlir/Dialect/Tensor/IR/Tensor.h"
14 #include "mlir/Pass/Pass.h"
15 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
16 
17 namespace mlir {
18 namespace bufferization {
19 #define GEN_PASS_DEF_EMPTYTENSORTOALLOCTENSOR
20 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
21 } // namespace bufferization
22 } // namespace mlir
23 
24 using namespace mlir;
25 using namespace mlir::bufferization;
26 using namespace mlir::tensor;
27 
28 namespace {
29 struct EmptyTensorLoweringPattern : public OpRewritePattern<tensor::EmptyOp> {
30   using OpRewritePattern<tensor::EmptyOp>::OpRewritePattern;
31 
32   LogicalResult matchAndRewrite(tensor::EmptyOp op,
33                                 PatternRewriter &rewriter) const override {
34     rewriter.replaceOpWithNewOp<bufferization::AllocTensorOp>(
35         op, op.getType(), op.getDynamicSizes());
36     return success();
37   }
38 };
39 
40 struct EmptyTensorToAllocTensor
41     : public bufferization::impl::EmptyTensorToAllocTensorBase<
42           EmptyTensorToAllocTensor> {
43   EmptyTensorToAllocTensor() = default;
44 
45   void runOnOperation() override;
46 
47   void getDependentDialects(DialectRegistry &registry) const override {
48     registry
49         .insert<tensor::TensorDialect, bufferization::BufferizationDialect>();
50   }
51 };
52 } // namespace
53 
54 void bufferization::populateEmptyTensorToAllocTensorPattern(
55     RewritePatternSet &patterns) {
56   patterns.insert<EmptyTensorLoweringPattern>(patterns.getContext());
57 }
58 
59 void EmptyTensorToAllocTensor::runOnOperation() {
60   Operation *op = getOperation();
61   RewritePatternSet patterns(op->getContext());
62   populateEmptyTensorToAllocTensorPattern(patterns);
63   if (failed(applyPatternsGreedily(op, std::move(patterns))))
64     signalPassFailure();
65 }
66 
67 std::unique_ptr<Pass>
68 mlir::bufferization::createEmptyTensorToAllocTensorPass() {
69   return std::make_unique<EmptyTensorToAllocTensor>();
70 }
71