xref: /llvm-project/mlir/lib/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.cpp (revision 65341b09b0d54ce5318e26a63b84138695d2ac35)
1 //===- AllocationOpInterfaceImpl.cpp - Impl. of AllocationOpInterface -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h"
10 
11 #include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
12 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13 #include "mlir/Dialect/MemRef/IR/MemRef.h"
14 #include "mlir/IR/Dialect.h"
15 #include "mlir/IR/Operation.h"
16 
17 using namespace mlir;
18 
19 namespace {
20 struct DefaultAllocationInterface
21     : public bufferization::AllocationOpInterface::ExternalModel<
22           DefaultAllocationInterface, memref::AllocOp> {
buildDealloc__anoned7c1d120111::DefaultAllocationInterface23   static std::optional<Operation *> buildDealloc(OpBuilder &builder,
24                                                  Value alloc) {
25     return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
26         .getOperation();
27   }
buildClone__anoned7c1d120111::DefaultAllocationInterface28   static std::optional<Value> buildClone(OpBuilder &builder, Value alloc) {
29     return builder.create<bufferization::CloneOp>(alloc.getLoc(), alloc)
30         .getResult();
31   }
getHoistingKind__anoned7c1d120111::DefaultAllocationInterface32   static ::mlir::HoistingKind getHoistingKind() {
33     return HoistingKind::Loop | HoistingKind::Block;
34   }
35   static ::std::optional<::mlir::Operation *>
buildPromotedAlloc__anoned7c1d120111::DefaultAllocationInterface36   buildPromotedAlloc(OpBuilder &builder, Value alloc) {
37     Operation *definingOp = alloc.getDefiningOp();
38     return builder.create<memref::AllocaOp>(
39         definingOp->getLoc(), cast<MemRefType>(definingOp->getResultTypes()[0]),
40         definingOp->getOperands(), definingOp->getAttrs());
41   }
42 };
43 
44 struct DefaultAutomaticAllocationHoistingInterface
45     : public bufferization::AllocationOpInterface::ExternalModel<
46           DefaultAutomaticAllocationHoistingInterface, memref::AllocaOp> {
getHoistingKind__anoned7c1d120111::DefaultAutomaticAllocationHoistingInterface47   static ::mlir::HoistingKind getHoistingKind() { return HoistingKind::Loop; }
48 };
49 
50 struct DefaultReallocationInterface
51     : public bufferization::AllocationOpInterface::ExternalModel<
52           DefaultAllocationInterface, memref::ReallocOp> {
buildDealloc__anoned7c1d120111::DefaultReallocationInterface53   static std::optional<Operation *> buildDealloc(OpBuilder &builder,
54                                                  Value realloc) {
55     return builder.create<memref::DeallocOp>(realloc.getLoc(), realloc)
56         .getOperation();
57   }
58 };
59 } // namespace
60 
registerAllocationOpInterfaceExternalModels(DialectRegistry & registry)61 void mlir::memref::registerAllocationOpInterfaceExternalModels(
62     DialectRegistry &registry) {
63   registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
64     memref::AllocOp::attachInterface<DefaultAllocationInterface>(*ctx);
65     memref::AllocaOp::attachInterface<
66         DefaultAutomaticAllocationHoistingInterface>(*ctx);
67     memref::ReallocOp::attachInterface<DefaultReallocationInterface>(*ctx);
68   });
69 }
70