1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"
10
11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
13 #include "mlir/Dialect/MemRef/IR/MemRef.h"
14
15 using namespace mlir;
16 using namespace mlir::bufferization;
17 using namespace mlir::ml_program;
18
19 namespace mlir {
20 namespace ml_program {
21 namespace {
22
23 template <typename Interface, typename Op>
24 struct ExternalModelBase
25 : public BufferizableOpInterface::ExternalModel<Interface, Op> {
26
getAliasingValuesmlir::ml_program::__anonc21192f20111::ExternalModelBase27 AliasingValueList getAliasingValues(Operation *, OpOperand &,
28 const AnalysisState &) const {
29 return {};
30 }
31
bufferRelationmlir::ml_program::__anonc21192f20111::ExternalModelBase32 BufferRelation bufferRelation(Operation *, OpResult,
33 const AnalysisState &) const {
34 return BufferRelation::Unknown;
35 }
36 };
37
38 /// Bufferization of ml_program.global into a memref.global
39 struct GlobalOpInterface
40 : public ExternalModelBase<GlobalOpInterface, GlobalOp> {
41
bufferizesToMemoryReadmlir::ml_program::__anonc21192f20111::GlobalOpInterface42 bool bufferizesToMemoryRead(Operation *, OpOperand &,
43 const AnalysisState &) const {
44 return false;
45 }
46
bufferizesToMemoryWritemlir::ml_program::__anonc21192f20111::GlobalOpInterface47 bool bufferizesToMemoryWrite(Operation *, OpOperand &,
48 const AnalysisState &) const {
49 return false;
50 }
51
hasTensorSemanticsmlir::ml_program::__anonc21192f20111::GlobalOpInterface52 bool hasTensorSemantics(Operation *) const { return true; }
53
bufferizemlir::ml_program::__anonc21192f20111::GlobalOpInterface54 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
55 const BufferizationOptions &) const {
56 auto globalOp = cast<GlobalOp>(op);
57 if (!globalOp.getValue().has_value())
58 return globalOp.emitError("global op must have a value");
59
60 auto tensorType = cast<TensorType>(globalOp.getType());
61 auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
62
63 replaceOpWithNewBufferizedOp<memref::GlobalOp>(
64 rewriter, globalOp, globalOp.getSymName(),
65 /*sym_visibility=*/globalOp.getSymVisibilityAttr(),
66 /*type=*/cast<MemRefType>(memrefType),
67 /*initial_value=*/globalOp.getValue().value(),
68 /*constant=*/!globalOp.getIsMutable(),
69 /*alignment=*/nullptr);
70
71 return success();
72 }
73 };
74
75 /// Bufferization of ml_program.global_load into a memref.get_global
76 struct GlobalLoadOpInterface
77 : public ExternalModelBase<GlobalLoadOpInterface, GlobalLoadOp> {
78
bufferizesToMemoryReadmlir::ml_program::__anonc21192f20111::GlobalLoadOpInterface79 bool bufferizesToMemoryRead(Operation *, OpOperand &,
80 const AnalysisState &) const {
81 return false;
82 }
83
bufferizesToMemoryWritemlir::ml_program::__anonc21192f20111::GlobalLoadOpInterface84 bool bufferizesToMemoryWrite(Operation *, OpOperand &,
85 const AnalysisState &) const {
86 return false;
87 }
88
isWritablemlir::ml_program::__anonc21192f20111::GlobalLoadOpInterface89 bool isWritable(Operation *, Value, const AnalysisState &) const {
90 return false;
91 }
92
bufferizemlir::ml_program::__anonc21192f20111::GlobalLoadOpInterface93 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
94 const BufferizationOptions &) const {
95 auto globalLoadOp = cast<GlobalLoadOp>(op);
96
97 auto tensorType = cast<TensorType>(globalLoadOp.getType());
98 auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
99
100 replaceOpWithNewBufferizedOp<memref::GetGlobalOp>(
101 rewriter, globalLoadOp, memrefType,
102 globalLoadOp.getGlobalAttr().getLeafReference());
103
104 return success();
105 }
106 };
107
108 /// Bufferization of ml_program.global_store into a memref.get_global and
109 /// memcpy
110 struct GlobalStoreOpInterface
111 : public ExternalModelBase<GlobalStoreOpInterface, GlobalStoreOp> {
112
bufferizesToMemoryReadmlir::ml_program::__anonc21192f20111::GlobalStoreOpInterface113 bool bufferizesToMemoryRead(Operation *, OpOperand &,
114 const AnalysisState &) const {
115 return false;
116 }
117
bufferizesToMemoryWritemlir::ml_program::__anonc21192f20111::GlobalStoreOpInterface118 bool bufferizesToMemoryWrite(Operation *, OpOperand &,
119 const AnalysisState &) const {
120 return true;
121 }
122
bufferizemlir::ml_program::__anonc21192f20111::GlobalStoreOpInterface123 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
124 const BufferizationOptions &options) const {
125 auto globalStoreOp = cast<GlobalStoreOp>(op);
126
127 auto tensorType = cast<TensorType>(globalStoreOp.getValue().getType());
128 auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
129
130 auto loc = globalStoreOp.getLoc();
131 auto targetMemref = rewriter.create<memref::GetGlobalOp>(
132 loc, memrefType, globalStoreOp.getGlobalAttr().getLeafReference());
133
134 auto sourceMemref = getBuffer(rewriter, globalStoreOp.getValue(), options);
135 if (failed(sourceMemref)) {
136 return failure();
137 }
138
139 auto memcpy =
140 options.createMemCpy(rewriter, loc, sourceMemref.value(), targetMemref);
141 if (failed(memcpy)) {
142 return failure();
143 }
144 rewriter.eraseOp(globalStoreOp);
145
146 return success();
147 }
148 };
149 } // namespace
150
registerBufferizableOpInterfaceExternalModels(DialectRegistry & registry)151 void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) {
152 registry.addExtension(+[](MLIRContext *ctx, MLProgramDialect *) {
153 GlobalOp::attachInterface<GlobalOpInterface>(*ctx);
154 GlobalLoadOp::attachInterface<GlobalLoadOpInterface>(*ctx);
155 GlobalStoreOp::attachInterface<GlobalStoreOpInterface>(*ctx);
156 });
157 }
158 } // namespace ml_program
159 } // namespace mlir
160