xref: /llvm-project/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp (revision fe0d16ff60fa2ef7fdbda3574493534979cce742)
1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"
10 
11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
13 #include "mlir/Dialect/MemRef/IR/MemRef.h"
14 
15 using namespace mlir;
16 using namespace mlir::bufferization;
17 using namespace mlir::ml_program;
18 
19 namespace mlir {
20 namespace ml_program {
21 namespace {
22 
23 template <typename Interface, typename Op>
24 struct ExternalModelBase
25     : public BufferizableOpInterface::ExternalModel<Interface, Op> {
26 
getAliasingValuesmlir::ml_program::__anonc21192f20111::ExternalModelBase27   AliasingValueList getAliasingValues(Operation *, OpOperand &,
28                                       const AnalysisState &) const {
29     return {};
30   }
31 
bufferRelationmlir::ml_program::__anonc21192f20111::ExternalModelBase32   BufferRelation bufferRelation(Operation *, OpResult,
33                                 const AnalysisState &) const {
34     return BufferRelation::Unknown;
35   }
36 };
37 
38 /// Bufferization of ml_program.global into a memref.global
39 struct GlobalOpInterface
40     : public ExternalModelBase<GlobalOpInterface, GlobalOp> {
41 
bufferizesToMemoryReadmlir::ml_program::__anonc21192f20111::GlobalOpInterface42   bool bufferizesToMemoryRead(Operation *, OpOperand &,
43                               const AnalysisState &) const {
44     return false;
45   }
46 
bufferizesToMemoryWritemlir::ml_program::__anonc21192f20111::GlobalOpInterface47   bool bufferizesToMemoryWrite(Operation *, OpOperand &,
48                                const AnalysisState &) const {
49     return false;
50   }
51 
hasTensorSemanticsmlir::ml_program::__anonc21192f20111::GlobalOpInterface52   bool hasTensorSemantics(Operation *) const { return true; }
53 
bufferizemlir::ml_program::__anonc21192f20111::GlobalOpInterface54   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
55                           const BufferizationOptions &) const {
56     auto globalOp = cast<GlobalOp>(op);
57     if (!globalOp.getValue().has_value())
58       return globalOp.emitError("global op must have a value");
59 
60     auto tensorType = cast<TensorType>(globalOp.getType());
61     auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
62 
63     replaceOpWithNewBufferizedOp<memref::GlobalOp>(
64         rewriter, globalOp, globalOp.getSymName(),
65         /*sym_visibility=*/globalOp.getSymVisibilityAttr(),
66         /*type=*/cast<MemRefType>(memrefType),
67         /*initial_value=*/globalOp.getValue().value(),
68         /*constant=*/!globalOp.getIsMutable(),
69         /*alignment=*/nullptr);
70 
71     return success();
72   }
73 };
74 
75 /// Bufferization of ml_program.global_load into a memref.get_global
76 struct GlobalLoadOpInterface
77     : public ExternalModelBase<GlobalLoadOpInterface, GlobalLoadOp> {
78 
bufferizesToMemoryReadmlir::ml_program::__anonc21192f20111::GlobalLoadOpInterface79   bool bufferizesToMemoryRead(Operation *, OpOperand &,
80                               const AnalysisState &) const {
81     return false;
82   }
83 
bufferizesToMemoryWritemlir::ml_program::__anonc21192f20111::GlobalLoadOpInterface84   bool bufferizesToMemoryWrite(Operation *, OpOperand &,
85                                const AnalysisState &) const {
86     return false;
87   }
88 
isWritablemlir::ml_program::__anonc21192f20111::GlobalLoadOpInterface89   bool isWritable(Operation *, Value, const AnalysisState &) const {
90     return false;
91   }
92 
bufferizemlir::ml_program::__anonc21192f20111::GlobalLoadOpInterface93   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
94                           const BufferizationOptions &) const {
95     auto globalLoadOp = cast<GlobalLoadOp>(op);
96 
97     auto tensorType = cast<TensorType>(globalLoadOp.getType());
98     auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
99 
100     replaceOpWithNewBufferizedOp<memref::GetGlobalOp>(
101         rewriter, globalLoadOp, memrefType,
102         globalLoadOp.getGlobalAttr().getLeafReference());
103 
104     return success();
105   }
106 };
107 
108 /// Bufferization of ml_program.global_store into a memref.get_global and
109 /// memcpy
110 struct GlobalStoreOpInterface
111     : public ExternalModelBase<GlobalStoreOpInterface, GlobalStoreOp> {
112 
bufferizesToMemoryReadmlir::ml_program::__anonc21192f20111::GlobalStoreOpInterface113   bool bufferizesToMemoryRead(Operation *, OpOperand &,
114                               const AnalysisState &) const {
115     return false;
116   }
117 
bufferizesToMemoryWritemlir::ml_program::__anonc21192f20111::GlobalStoreOpInterface118   bool bufferizesToMemoryWrite(Operation *, OpOperand &,
119                                const AnalysisState &) const {
120     return true;
121   }
122 
bufferizemlir::ml_program::__anonc21192f20111::GlobalStoreOpInterface123   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
124                           const BufferizationOptions &options) const {
125     auto globalStoreOp = cast<GlobalStoreOp>(op);
126 
127     auto tensorType = cast<TensorType>(globalStoreOp.getValue().getType());
128     auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
129 
130     auto loc = globalStoreOp.getLoc();
131     auto targetMemref = rewriter.create<memref::GetGlobalOp>(
132         loc, memrefType, globalStoreOp.getGlobalAttr().getLeafReference());
133 
134     auto sourceMemref = getBuffer(rewriter, globalStoreOp.getValue(), options);
135     if (failed(sourceMemref)) {
136       return failure();
137     }
138 
139     auto memcpy =
140         options.createMemCpy(rewriter, loc, sourceMemref.value(), targetMemref);
141     if (failed(memcpy)) {
142       return failure();
143     }
144     rewriter.eraseOp(globalStoreOp);
145 
146     return success();
147   }
148 };
149 } // namespace
150 
registerBufferizableOpInterfaceExternalModels(DialectRegistry & registry)151 void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
152   registry.addExtension(+[](MLIRContext *ctx, MLProgramDialect *) {
153     GlobalOp::attachInterface<GlobalOpInterface>(*ctx);
154     GlobalLoadOp::attachInterface<GlobalLoadOpInterface>(*ctx);
155     GlobalStoreOp::attachInterface<GlobalStoreOpInterface>(*ctx);
156   });
157 }
158 } // namespace ml_program
159 } // namespace mlir
160