xref: /llvm-project/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp (revision 9db53a182705ac1f652c6ee375735bea5539272c)
1 //===- DeserializationTest.cpp - SPIR-V Deserialization Tests -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The purpose of this file is to provide negative deserialization tests.
10 // For positive deserialization tests, please use serialization and
11 // deserialization for roundtripping.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
16 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
18 #include "mlir/Dialect/SPIRV/Serialization.h"
19 #include "mlir/IR/Diagnostics.h"
20 #include "mlir/IR/MLIRContext.h"
21 #include "gmock/gmock.h"
22 
23 #include <memory>
24 
25 using namespace mlir;
26 
27 // Load the SPIRV dialect
28 static DialectRegistration<spirv::SPIRVDialect> SPIRVRegistration;
29 
30 using ::testing::StrEq;
31 
32 //===----------------------------------------------------------------------===//
33 // Test Fixture
34 //===----------------------------------------------------------------------===//
35 
36 /// A deserialization test fixture providing minimal SPIR-V building and
37 /// diagnostic checking utilities.
38 class DeserializationTest : public ::testing::Test {
39 protected:
40   DeserializationTest() {
41     // Register a diagnostic handler to capture the diagnostic so that we can
42     // check it later.
43     context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
44       diagnostic.reset(new Diagnostic(std::move(diag)));
45     });
46   }
47 
48   /// Performs deserialization and returns the constructed spv.module op.
49   Optional<spirv::ModuleOp> deserialize() {
50     return spirv::deserialize(binary, &context);
51   }
52 
53   /// Checks there is a diagnostic generated with the given `errorMessage`.
54   void expectDiagnostic(StringRef errorMessage) {
55     ASSERT_NE(nullptr, diagnostic.get());
56 
57     // TODO: check error location too.
58     EXPECT_THAT(diagnostic->str(), StrEq(std::string(errorMessage)));
59   }
60 
61   //===--------------------------------------------------------------------===//
62   // SPIR-V builder methods
63   //===--------------------------------------------------------------------===//
64 
65   /// Adds the SPIR-V module header to `binary`.
66   void addHeader() {
67     spirv::appendModuleHeader(binary, spirv::Version::V_1_0, /*idBound=*/0);
68   }
69 
70   /// Adds the SPIR-V instruction into `binary`.
71   void addInstruction(spirv::Opcode op, ArrayRef<uint32_t> operands) {
72     uint32_t wordCount = 1 + operands.size();
73     binary.push_back(spirv::getPrefixedOpcode(wordCount, op));
74     binary.append(operands.begin(), operands.end());
75   }
76 
77   uint32_t addVoidType() {
78     auto id = nextID++;
79     addInstruction(spirv::Opcode::OpTypeVoid, {id});
80     return id;
81   }
82 
83   uint32_t addIntType(uint32_t bitwidth) {
84     auto id = nextID++;
85     addInstruction(spirv::Opcode::OpTypeInt, {id, bitwidth, /*signedness=*/1});
86     return id;
87   }
88 
89   uint32_t addStructType(ArrayRef<uint32_t> memberTypes) {
90     auto id = nextID++;
91     SmallVector<uint32_t, 2> words;
92     words.push_back(id);
93     words.append(memberTypes.begin(), memberTypes.end());
94     addInstruction(spirv::Opcode::OpTypeStruct, words);
95     return id;
96   }
97 
98   uint32_t addFunctionType(uint32_t retType, ArrayRef<uint32_t> paramTypes) {
99     auto id = nextID++;
100     SmallVector<uint32_t, 4> operands;
101     operands.push_back(id);
102     operands.push_back(retType);
103     operands.append(paramTypes.begin(), paramTypes.end());
104     addInstruction(spirv::Opcode::OpTypeFunction, operands);
105     return id;
106   }
107 
108   uint32_t addFunction(uint32_t retType, uint32_t fnType) {
109     auto id = nextID++;
110     addInstruction(spirv::Opcode::OpFunction,
111                    {retType, id,
112                     static_cast<uint32_t>(spirv::FunctionControl::None),
113                     fnType});
114     return id;
115   }
116 
117   void addFunctionEnd() { addInstruction(spirv::Opcode::OpFunctionEnd, {}); }
118 
119   void addReturn() { addInstruction(spirv::Opcode::OpReturn, {}); }
120 
121 protected:
122   SmallVector<uint32_t, 5> binary;
123   uint32_t nextID = 1;
124   MLIRContext context;
125   std::unique_ptr<Diagnostic> diagnostic;
126 };
127 
128 //===----------------------------------------------------------------------===//
129 // Basics
130 //===----------------------------------------------------------------------===//
131 
132 TEST_F(DeserializationTest, EmptyModuleFailure) {
133   ASSERT_EQ(llvm::None, deserialize());
134   expectDiagnostic("SPIR-V binary module must have a 5-word header");
135 }
136 
137 TEST_F(DeserializationTest, WrongMagicNumberFailure) {
138   addHeader();
139   binary.front() = 0xdeadbeef; // Change to a wrong magic number
140   ASSERT_EQ(llvm::None, deserialize());
141   expectDiagnostic("incorrect magic number");
142 }
143 
144 TEST_F(DeserializationTest, OnlyHeaderSuccess) {
145   addHeader();
146   EXPECT_NE(llvm::None, deserialize());
147 }
148 
149 TEST_F(DeserializationTest, ZeroWordCountFailure) {
150   addHeader();
151   binary.push_back(0); // OpNop with zero word count
152 
153   ASSERT_EQ(llvm::None, deserialize());
154   expectDiagnostic("word count cannot be zero");
155 }
156 
157 TEST_F(DeserializationTest, InsufficientWordFailure) {
158   addHeader();
159   binary.push_back((2u << 16) |
160                    static_cast<uint32_t>(spirv::Opcode::OpTypeVoid));
161   // Missing word for type <id>
162 
163   ASSERT_EQ(llvm::None, deserialize());
164   expectDiagnostic("insufficient words for the last instruction");
165 }
166 
167 //===----------------------------------------------------------------------===//
168 // Types
169 //===----------------------------------------------------------------------===//
170 
171 TEST_F(DeserializationTest, IntTypeMissingSignednessFailure) {
172   addHeader();
173   addInstruction(spirv::Opcode::OpTypeInt, {nextID++, 32});
174 
175   ASSERT_EQ(llvm::None, deserialize());
176   expectDiagnostic("OpTypeInt must have bitwidth and signedness parameters");
177 }
178 
179 //===----------------------------------------------------------------------===//
180 // StructType
181 //===----------------------------------------------------------------------===//
182 
183 TEST_F(DeserializationTest, OpMemberNameSuccess) {
184   addHeader();
185   SmallVector<uint32_t, 5> typeDecl;
186   std::swap(typeDecl, binary);
187 
188   auto int32Type = addIntType(32);
189   auto structType = addStructType({int32Type, int32Type});
190   std::swap(typeDecl, binary);
191 
192   SmallVector<uint32_t, 5> operands1 = {structType, 0};
193   spirv::encodeStringLiteralInto(operands1, "i1");
194   addInstruction(spirv::Opcode::OpMemberName, operands1);
195 
196   SmallVector<uint32_t, 5> operands2 = {structType, 1};
197   spirv::encodeStringLiteralInto(operands2, "i2");
198   addInstruction(spirv::Opcode::OpMemberName, operands2);
199 
200   binary.append(typeDecl.begin(), typeDecl.end());
201   EXPECT_NE(llvm::None, deserialize());
202 }
203 
204 TEST_F(DeserializationTest, OpMemberNameMissingOperands) {
205   addHeader();
206   SmallVector<uint32_t, 5> typeDecl;
207   std::swap(typeDecl, binary);
208 
209   auto int32Type = addIntType(32);
210   auto int64Type = addIntType(64);
211   auto structType = addStructType({int32Type, int64Type});
212   std::swap(typeDecl, binary);
213 
214   SmallVector<uint32_t, 5> operands1 = {structType};
215   addInstruction(spirv::Opcode::OpMemberName, operands1);
216 
217   binary.append(typeDecl.begin(), typeDecl.end());
218   ASSERT_EQ(llvm::None, deserialize());
219   expectDiagnostic("OpMemberName must have at least 3 operands");
220 }
221 
222 TEST_F(DeserializationTest, OpMemberNameExcessOperands) {
223   addHeader();
224   SmallVector<uint32_t, 5> typeDecl;
225   std::swap(typeDecl, binary);
226 
227   auto int32Type = addIntType(32);
228   auto structType = addStructType({int32Type});
229   std::swap(typeDecl, binary);
230 
231   SmallVector<uint32_t, 5> operands = {structType, 0};
232   spirv::encodeStringLiteralInto(operands, "int32");
233   operands.push_back(42);
234   addInstruction(spirv::Opcode::OpMemberName, operands);
235 
236   binary.append(typeDecl.begin(), typeDecl.end());
237   ASSERT_EQ(llvm::None, deserialize());
238   expectDiagnostic("unexpected trailing words in OpMemberName instruction");
239 }
240 
241 //===----------------------------------------------------------------------===//
242 // Functions
243 //===----------------------------------------------------------------------===//
244 
245 TEST_F(DeserializationTest, FunctionMissingEndFailure) {
246   addHeader();
247   auto voidType = addVoidType();
248   auto fnType = addFunctionType(voidType, {});
249   addFunction(voidType, fnType);
250   // Missing OpFunctionEnd
251 
252   ASSERT_EQ(llvm::None, deserialize());
253   expectDiagnostic("expected OpFunctionEnd instruction");
254 }
255 
256 TEST_F(DeserializationTest, FunctionMissingParameterFailure) {
257   addHeader();
258   auto voidType = addVoidType();
259   auto i32Type = addIntType(32);
260   auto fnType = addFunctionType(voidType, {i32Type});
261   addFunction(voidType, fnType);
262   // Missing OpFunctionParameter
263 
264   ASSERT_EQ(llvm::None, deserialize());
265   expectDiagnostic("expected OpFunctionParameter instruction");
266 }
267 
268 TEST_F(DeserializationTest, FunctionMissingLabelForFirstBlockFailure) {
269   addHeader();
270   auto voidType = addVoidType();
271   auto fnType = addFunctionType(voidType, {});
272   addFunction(voidType, fnType);
273   // Missing OpLabel
274   addReturn();
275   addFunctionEnd();
276 
277   ASSERT_EQ(llvm::None, deserialize());
278   expectDiagnostic("a basic block must start with OpLabel");
279 }
280 
281 TEST_F(DeserializationTest, FunctionMalformedLabelFailure) {
282   addHeader();
283   auto voidType = addVoidType();
284   auto fnType = addFunctionType(voidType, {});
285   addFunction(voidType, fnType);
286   addInstruction(spirv::Opcode::OpLabel, {}); // Malformed OpLabel
287   addReturn();
288   addFunctionEnd();
289 
290   ASSERT_EQ(llvm::None, deserialize());
291   expectDiagnostic("OpLabel should only have result <id>");
292 }
293