xref: /llvm-project/mlir/lib/Target/SPIRV/TranslateRegistration.cpp (revision 984b800a036fc61ccb129a8da7592af9cadc94dd)
1 //===- TranslateRegistration.cpp - hooks to mlir-translate ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
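// This file registers the SPIR-V translations with mlir-translate:
// deserialization of a SPIR-V binary module into the MLIR SPIR-V dialect,
// serialization of a SPIR-V ModuleOp to binary, and the round-trip test
// translations.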
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/Dialect.h"
19 #include "mlir/IR/Verifier.h"
20 #include "mlir/Parser/Parser.h"
21 #include "mlir/Support/FileUtilities.h"
22 #include "mlir/Target/SPIRV/Deserialization.h"
23 #include "mlir/Target/SPIRV/Serialization.h"
24 #include "mlir/Tools/mlir-translate/Translation.h"
25 #include "llvm/ADT/StringRef.h"
26 #include "llvm/Support/MemoryBuffer.h"
27 #include "llvm/Support/SMLoc.h"
28 #include "llvm/Support/SourceMgr.h"
29 #include "llvm/Support/ToolOutputFile.h"
30 
31 using namespace mlir;
32 
33 //===----------------------------------------------------------------------===//
34 // Deserialization registration
35 //===----------------------------------------------------------------------===//
36 
37 // Deserializes the SPIR-V binary module stored in the file named as
38 // `inputFilename` and returns a module containing the SPIR-V module.
39 static OwningOpRef<Operation *>
deserializeModule(const llvm::MemoryBuffer * input,MLIRContext * context)40 deserializeModule(const llvm::MemoryBuffer *input, MLIRContext *context) {
41   context->loadDialect<spirv::SPIRVDialect>();
42 
43   // Make sure the input stream can be treated as a stream of SPIR-V words
44   auto *start = input->getBufferStart();
45   auto size = input->getBufferSize();
46   if (size % sizeof(uint32_t) != 0) {
47     emitError(UnknownLoc::get(context))
48         << "SPIR-V binary module must contain integral number of 32-bit words";
49     return {};
50   }
51 
52   auto binary = llvm::ArrayRef(reinterpret_cast<const uint32_t *>(start),
53                                size / sizeof(uint32_t));
54   return spirv::deserialize(binary, context);
55 }
56 
57 namespace mlir {
registerFromSPIRVTranslation()58 void registerFromSPIRVTranslation() {
59   TranslateToMLIRRegistration fromBinary(
60       "deserialize-spirv", "deserializes the SPIR-V module",
61       [](llvm::SourceMgr &sourceMgr, MLIRContext *context) {
62         assert(sourceMgr.getNumBuffers() == 1 && "expected one buffer");
63         return deserializeModule(
64             sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), context);
65       });
66 }
67 } // namespace mlir
68 
69 //===----------------------------------------------------------------------===//
70 // Serialization registration
71 //===----------------------------------------------------------------------===//
72 
serializeModule(spirv::ModuleOp module,raw_ostream & output)73 static LogicalResult serializeModule(spirv::ModuleOp module,
74                                      raw_ostream &output) {
75   SmallVector<uint32_t, 0> binary;
76   if (failed(spirv::serialize(module, binary)))
77     return failure();
78 
79   output.write(reinterpret_cast<char *>(binary.data()),
80                binary.size() * sizeof(uint32_t));
81 
82   return mlir::success();
83 }
84 
85 namespace mlir {
registerToSPIRVTranslation()86 void registerToSPIRVTranslation() {
87   TranslateFromMLIRRegistration toBinary(
88       "serialize-spirv", "serialize SPIR-V dialect",
89       [](spirv::ModuleOp module, raw_ostream &output) {
90         return serializeModule(module, output);
91       },
92       [](DialectRegistry &registry) {
93         registry.insert<spirv::SPIRVDialect>();
94       });
95 }
96 } // namespace mlir
97 
98 //===----------------------------------------------------------------------===//
99 // Round-trip registration
100 //===----------------------------------------------------------------------===//
101 
roundTripModule(spirv::ModuleOp module,bool emitDebugInfo,raw_ostream & output)102 static LogicalResult roundTripModule(spirv::ModuleOp module, bool emitDebugInfo,
103                                      raw_ostream &output) {
104   SmallVector<uint32_t, 0> binary;
105   MLIRContext *context = module->getContext();
106 
107   spirv::SerializationOptions options;
108   options.emitDebugInfo = emitDebugInfo;
109   if (failed(spirv::serialize(module, binary, options)))
110     return failure();
111 
112   MLIRContext deserializationContext(context->getDialectRegistry());
113   // TODO: we should only load the required dialects instead of all dialects.
114   deserializationContext.loadAllAvailableDialects();
115   // Then deserialize to get back a SPIR-V module.
116   OwningOpRef<spirv::ModuleOp> spirvModule =
117       spirv::deserialize(binary, &deserializationContext);
118   if (!spirvModule)
119     return failure();
120   spirvModule->print(output);
121 
122   return mlir::success();
123 }
124 
125 namespace mlir {
registerTestRoundtripSPIRV()126 void registerTestRoundtripSPIRV() {
127   TranslateFromMLIRRegistration roundtrip(
128       "test-spirv-roundtrip", "test roundtrip in SPIR-V dialect",
129       [](spirv::ModuleOp module, raw_ostream &output) {
130         return roundTripModule(module, /*emitDebugInfo=*/false, output);
131       },
132       [](DialectRegistry &registry) {
133         registry.insert<spirv::SPIRVDialect>();
134       });
135 }
136 
registerTestRoundtripDebugSPIRV()137 void registerTestRoundtripDebugSPIRV() {
138   TranslateFromMLIRRegistration roundtrip(
139       "test-spirv-roundtrip-debug", "test roundtrip debug in SPIR-V",
140       [](spirv::ModuleOp module, raw_ostream &output) {
141         return roundTripModule(module, /*emitDebugInfo=*/true, output);
142       },
143       [](DialectRegistry &registry) {
144         registry.insert<spirv::SPIRVDialect>();
145       });
146 }
147 } // namespace mlir
148