xref: /llvm-project/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp (revision fe0d16ff60fa2ef7fdbda3574493534979cce742)
1fa101214SRyan Holt //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2fa101214SRyan Holt //
3fa101214SRyan Holt // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4fa101214SRyan Holt // See https://llvm.org/LICENSE.txt for license information.
5fa101214SRyan Holt // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6fa101214SRyan Holt //
7fa101214SRyan Holt //===----------------------------------------------------------------------===//
8fa101214SRyan Holt 
9*fe0d16ffSJoel Wee #include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"
10fa101214SRyan Holt 
11fa101214SRyan Holt #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12fa101214SRyan Holt #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
13fa101214SRyan Holt #include "mlir/Dialect/MemRef/IR/MemRef.h"
14fa101214SRyan Holt 
15fa101214SRyan Holt using namespace mlir;
16fa101214SRyan Holt using namespace mlir::bufferization;
17fa101214SRyan Holt using namespace mlir::ml_program;
18fa101214SRyan Holt 
19fa101214SRyan Holt namespace mlir {
20fa101214SRyan Holt namespace ml_program {
21fa101214SRyan Holt namespace {
22fa101214SRyan Holt 
23fa101214SRyan Holt template <typename Interface, typename Op>
24fa101214SRyan Holt struct ExternalModelBase
25fa101214SRyan Holt     : public BufferizableOpInterface::ExternalModel<Interface, Op> {
26fa101214SRyan Holt 
getAliasingValuesmlir::ml_program::__anonc21192f20111::ExternalModelBase27fa101214SRyan Holt   AliasingValueList getAliasingValues(Operation *, OpOperand &,
28fa101214SRyan Holt                                       const AnalysisState &) const {
29fa101214SRyan Holt     return {};
30fa101214SRyan Holt   }
31fa101214SRyan Holt 
bufferRelationmlir::ml_program::__anonc21192f20111::ExternalModelBase32fa101214SRyan Holt   BufferRelation bufferRelation(Operation *, OpResult,
33fa101214SRyan Holt                                 const AnalysisState &) const {
34fa101214SRyan Holt     return BufferRelation::Unknown;
35fa101214SRyan Holt   }
36fa101214SRyan Holt };
37fa101214SRyan Holt 
38fa101214SRyan Holt /// Bufferization of ml_program.global into a memref.global
39fa101214SRyan Holt struct GlobalOpInterface
40fa101214SRyan Holt     : public ExternalModelBase<GlobalOpInterface, GlobalOp> {
41fa101214SRyan Holt 
bufferizesToMemoryReadmlir::ml_program::__anonc21192f20111::GlobalOpInterface42fa101214SRyan Holt   bool bufferizesToMemoryRead(Operation *, OpOperand &,
43fa101214SRyan Holt                               const AnalysisState &) const {
44fa101214SRyan Holt     return false;
45fa101214SRyan Holt   }
46fa101214SRyan Holt 
bufferizesToMemoryWritemlir::ml_program::__anonc21192f20111::GlobalOpInterface47fa101214SRyan Holt   bool bufferizesToMemoryWrite(Operation *, OpOperand &,
48fa101214SRyan Holt                                const AnalysisState &) const {
49fa101214SRyan Holt     return false;
50fa101214SRyan Holt   }
51fa101214SRyan Holt 
hasTensorSemanticsmlir::ml_program::__anonc21192f20111::GlobalOpInterface52fa101214SRyan Holt   bool hasTensorSemantics(Operation *) const { return true; }
53fa101214SRyan Holt 
bufferizemlir::ml_program::__anonc21192f20111::GlobalOpInterface54fa101214SRyan Holt   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
55fa101214SRyan Holt                           const BufferizationOptions &) const {
56fa101214SRyan Holt     auto globalOp = cast<GlobalOp>(op);
57fa101214SRyan Holt     if (!globalOp.getValue().has_value())
58fa101214SRyan Holt       return globalOp.emitError("global op must have a value");
59fa101214SRyan Holt 
60fa101214SRyan Holt     auto tensorType = cast<TensorType>(globalOp.getType());
61fa101214SRyan Holt     auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
62fa101214SRyan Holt 
63fa101214SRyan Holt     replaceOpWithNewBufferizedOp<memref::GlobalOp>(
64fa101214SRyan Holt         rewriter, globalOp, globalOp.getSymName(),
65fa101214SRyan Holt         /*sym_visibility=*/globalOp.getSymVisibilityAttr(),
66fa101214SRyan Holt         /*type=*/cast<MemRefType>(memrefType),
67fa101214SRyan Holt         /*initial_value=*/globalOp.getValue().value(),
68fa101214SRyan Holt         /*constant=*/!globalOp.getIsMutable(),
69fa101214SRyan Holt         /*alignment=*/nullptr);
70fa101214SRyan Holt 
71fa101214SRyan Holt     return success();
72fa101214SRyan Holt   }
73fa101214SRyan Holt };
74fa101214SRyan Holt 
75fa101214SRyan Holt /// Bufferization of ml_program.global_load into a memref.get_global
76fa101214SRyan Holt struct GlobalLoadOpInterface
77fa101214SRyan Holt     : public ExternalModelBase<GlobalLoadOpInterface, GlobalLoadOp> {
78fa101214SRyan Holt 
bufferizesToMemoryReadmlir::ml_program::__anonc21192f20111::GlobalLoadOpInterface79fa101214SRyan Holt   bool bufferizesToMemoryRead(Operation *, OpOperand &,
80fa101214SRyan Holt                               const AnalysisState &) const {
81fa101214SRyan Holt     return false;
82fa101214SRyan Holt   }
83fa101214SRyan Holt 
bufferizesToMemoryWritemlir::ml_program::__anonc21192f20111::GlobalLoadOpInterface84fa101214SRyan Holt   bool bufferizesToMemoryWrite(Operation *, OpOperand &,
85fa101214SRyan Holt                                const AnalysisState &) const {
86fa101214SRyan Holt     return false;
87fa101214SRyan Holt   }
88fa101214SRyan Holt 
isWritablemlir::ml_program::__anonc21192f20111::GlobalLoadOpInterface89fa101214SRyan Holt   bool isWritable(Operation *, Value, const AnalysisState &) const {
90fa101214SRyan Holt     return false;
91fa101214SRyan Holt   }
92fa101214SRyan Holt 
bufferizemlir::ml_program::__anonc21192f20111::GlobalLoadOpInterface93fa101214SRyan Holt   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
94fa101214SRyan Holt                           const BufferizationOptions &) const {
95fa101214SRyan Holt     auto globalLoadOp = cast<GlobalLoadOp>(op);
96fa101214SRyan Holt 
97fa101214SRyan Holt     auto tensorType = cast<TensorType>(globalLoadOp.getType());
98fa101214SRyan Holt     auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
99fa101214SRyan Holt 
100fa101214SRyan Holt     replaceOpWithNewBufferizedOp<memref::GetGlobalOp>(
101fa101214SRyan Holt         rewriter, globalLoadOp, memrefType,
102fa101214SRyan Holt         globalLoadOp.getGlobalAttr().getLeafReference());
103fa101214SRyan Holt 
104fa101214SRyan Holt     return success();
105fa101214SRyan Holt   }
106fa101214SRyan Holt };
107fa101214SRyan Holt 
108fa101214SRyan Holt /// Bufferization of ml_program.global_store into a memref.get_global and
109fa101214SRyan Holt /// memcpy
110fa101214SRyan Holt struct GlobalStoreOpInterface
111fa101214SRyan Holt     : public ExternalModelBase<GlobalStoreOpInterface, GlobalStoreOp> {
112fa101214SRyan Holt 
bufferizesToMemoryReadmlir::ml_program::__anonc21192f20111::GlobalStoreOpInterface113fa101214SRyan Holt   bool bufferizesToMemoryRead(Operation *, OpOperand &,
114fa101214SRyan Holt                               const AnalysisState &) const {
115fa101214SRyan Holt     return false;
116fa101214SRyan Holt   }
117fa101214SRyan Holt 
bufferizesToMemoryWritemlir::ml_program::__anonc21192f20111::GlobalStoreOpInterface118fa101214SRyan Holt   bool bufferizesToMemoryWrite(Operation *, OpOperand &,
119fa101214SRyan Holt                                const AnalysisState &) const {
120fa101214SRyan Holt     return true;
121fa101214SRyan Holt   }
122fa101214SRyan Holt 
bufferizemlir::ml_program::__anonc21192f20111::GlobalStoreOpInterface123fa101214SRyan Holt   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
124fa101214SRyan Holt                           const BufferizationOptions &options) const {
125fa101214SRyan Holt     auto globalStoreOp = cast<GlobalStoreOp>(op);
126fa101214SRyan Holt 
127fa101214SRyan Holt     auto tensorType = cast<TensorType>(globalStoreOp.getValue().getType());
128fa101214SRyan Holt     auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
129fa101214SRyan Holt 
130fa101214SRyan Holt     auto loc = globalStoreOp.getLoc();
131fa101214SRyan Holt     auto targetMemref = rewriter.create<memref::GetGlobalOp>(
132fa101214SRyan Holt         loc, memrefType, globalStoreOp.getGlobalAttr().getLeafReference());
133fa101214SRyan Holt 
134fa101214SRyan Holt     auto sourceMemref = getBuffer(rewriter, globalStoreOp.getValue(), options);
135fa101214SRyan Holt     if (failed(sourceMemref)) {
136fa101214SRyan Holt       return failure();
137fa101214SRyan Holt     }
138fa101214SRyan Holt 
139fa101214SRyan Holt     auto memcpy =
140fa101214SRyan Holt         options.createMemCpy(rewriter, loc, sourceMemref.value(), targetMemref);
141fa101214SRyan Holt     if (failed(memcpy)) {
142fa101214SRyan Holt       return failure();
143fa101214SRyan Holt     }
144fa101214SRyan Holt     rewriter.eraseOp(globalStoreOp);
145fa101214SRyan Holt 
146fa101214SRyan Holt     return success();
147fa101214SRyan Holt   }
148fa101214SRyan Holt };
149fa101214SRyan Holt } // namespace
150fa101214SRyan Holt 
registerBufferizableOpInterfaceExternalModels(DialectRegistry & registry)151fa101214SRyan Holt void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
152fa101214SRyan Holt   registry.addExtension(+[](MLIRContext *ctx, MLProgramDialect *) {
153fa101214SRyan Holt     GlobalOp::attachInterface<GlobalOpInterface>(*ctx);
154fa101214SRyan Holt     GlobalLoadOp::attachInterface<GlobalLoadOpInterface>(*ctx);
155fa101214SRyan Holt     GlobalStoreOp::attachInterface<GlobalStoreOpInterface>(*ctx);
156fa101214SRyan Holt   });
157fa101214SRyan Holt }
158fa101214SRyan Holt } // namespace ml_program
159fa101214SRyan Holt } // namespace mlir
160