xref: /llvm-project/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp (revision 308571074c13ea2a0758aa085aa02f72150f891e)
1 //===- DeserializationTest.cpp - SPIR-V Deserialization Tests -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The purpose of this file is to provide negative deserialization tests.
10 // For positive deserialization tests, please use serialization and
11 // deserialization for roundtripping.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
16 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
17 #include "mlir/Dialect/SPIRV/Serialization.h"
18 #include "mlir/IR/Diagnostics.h"
19 #include "mlir/IR/MLIRContext.h"
20 #include "gmock/gmock.h"
21 
22 #include <memory>
23 
24 using namespace mlir;
25 
26 using ::testing::StrEq;
27 
28 //===----------------------------------------------------------------------===//
29 // Test Fixture
30 //===----------------------------------------------------------------------===//
31 
32 /// A deserialization test fixture providing minimal SPIR-V building and
33 /// diagnostic checking utilities.
34 class DeserializationTest : public ::testing::Test {
35 protected:
36   DeserializationTest() {
37     // Register a diagnostic handler to capture the diagnostic so that we can
38     // check it later.
39     context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
40       diagnostic.reset(new Diagnostic(std::move(diag)));
41     });
42   }
43 
44   /// Performs deserialization and returns the constructed spv.module op.
45   Optional<spirv::ModuleOp> deserialize() {
46     return spirv::deserialize(binary, &context);
47   }
48 
49   /// Checks there is a diagnostic generated with the given `errorMessage`.
50   void expectDiagnostic(StringRef errorMessage) {
51     ASSERT_NE(nullptr, diagnostic.get());
52 
53     // TODO(antiagainst): check error location too.
54     EXPECT_THAT(diagnostic->str(), StrEq(errorMessage));
55   }
56 
57   //===--------------------------------------------------------------------===//
58   // SPIR-V builder methods
59   //===--------------------------------------------------------------------===//
60 
61   /// Adds the SPIR-V module header to `binary`.
62   void addHeader() { spirv::appendModuleHeader(binary, /*idBound=*/0); }
63 
64   /// Adds the SPIR-V instruction into `binary`.
65   void addInstruction(spirv::Opcode op, ArrayRef<uint32_t> operands) {
66     uint32_t wordCount = 1 + operands.size();
67     binary.push_back(spirv::getPrefixedOpcode(wordCount, op));
68     binary.append(operands.begin(), operands.end());
69   }
70 
71   uint32_t addVoidType() {
72     auto id = nextID++;
73     addInstruction(spirv::Opcode::OpTypeVoid, {id});
74     return id;
75   }
76 
77   uint32_t addIntType(uint32_t bitwidth) {
78     auto id = nextID++;
79     addInstruction(spirv::Opcode::OpTypeInt, {id, bitwidth, /*signedness=*/1});
80     return id;
81   }
82 
83   uint32_t addStructType(ArrayRef<uint32_t> memberTypes) {
84     auto id = nextID++;
85     SmallVector<uint32_t, 2> words;
86     words.push_back(id);
87     words.append(memberTypes.begin(), memberTypes.end());
88     addInstruction(spirv::Opcode::OpTypeStruct, words);
89     return id;
90   }
91 
92   uint32_t addFunctionType(uint32_t retType, ArrayRef<uint32_t> paramTypes) {
93     auto id = nextID++;
94     SmallVector<uint32_t, 4> operands;
95     operands.push_back(id);
96     operands.push_back(retType);
97     operands.append(paramTypes.begin(), paramTypes.end());
98     addInstruction(spirv::Opcode::OpTypeFunction, operands);
99     return id;
100   }
101 
102   uint32_t addFunction(uint32_t retType, uint32_t fnType) {
103     auto id = nextID++;
104     addInstruction(spirv::Opcode::OpFunction,
105                    {retType, id,
106                     static_cast<uint32_t>(spirv::FunctionControl::None),
107                     fnType});
108     return id;
109   }
110 
111   void addFunctionEnd() { addInstruction(spirv::Opcode::OpFunctionEnd, {}); }
112 
113   void addReturn() { addInstruction(spirv::Opcode::OpReturn, {}); }
114 
115 protected:
116   SmallVector<uint32_t, 5> binary;
117   uint32_t nextID = 1;
118   MLIRContext context;
119   std::unique_ptr<Diagnostic> diagnostic;
120 };
121 
122 //===----------------------------------------------------------------------===//
123 // Basics
124 //===----------------------------------------------------------------------===//
125 
126 TEST_F(DeserializationTest, EmptyModuleFailure) {
127   ASSERT_EQ(llvm::None, deserialize());
128   expectDiagnostic("SPIR-V binary module must have a 5-word header");
129 }
130 
131 TEST_F(DeserializationTest, WrongMagicNumberFailure) {
132   addHeader();
133   binary.front() = 0xdeadbeef; // Change to a wrong magic number
134   ASSERT_EQ(llvm::None, deserialize());
135   expectDiagnostic("incorrect magic number");
136 }
137 
138 TEST_F(DeserializationTest, OnlyHeaderSuccess) {
139   addHeader();
140   EXPECT_NE(llvm::None, deserialize());
141 }
142 
143 TEST_F(DeserializationTest, ZeroWordCountFailure) {
144   addHeader();
145   binary.push_back(0); // OpNop with zero word count
146 
147   ASSERT_EQ(llvm::None, deserialize());
148   expectDiagnostic("word count cannot be zero");
149 }
150 
151 TEST_F(DeserializationTest, InsufficientWordFailure) {
152   addHeader();
153   binary.push_back((2u << 16) |
154                    static_cast<uint32_t>(spirv::Opcode::OpTypeVoid));
155   // Missing word for type <id>
156 
157   ASSERT_EQ(llvm::None, deserialize());
158   expectDiagnostic("insufficient words for the last instruction");
159 }
160 
161 //===----------------------------------------------------------------------===//
162 // Types
163 //===----------------------------------------------------------------------===//
164 
165 TEST_F(DeserializationTest, IntTypeMissingSignednessFailure) {
166   addHeader();
167   addInstruction(spirv::Opcode::OpTypeInt, {nextID++, 32});
168 
169   ASSERT_EQ(llvm::None, deserialize());
170   expectDiagnostic("OpTypeInt must have bitwidth and signedness parameters");
171 }
172 
173 //===----------------------------------------------------------------------===//
174 // StructType
175 //===----------------------------------------------------------------------===//
176 
177 TEST_F(DeserializationTest, OpMemberNameSuccess) {
178   addHeader();
179   SmallVector<uint32_t, 5> typeDecl;
180   std::swap(typeDecl, binary);
181 
182   auto int32Type = addIntType(32);
183   auto structType = addStructType({int32Type, int32Type});
184   std::swap(typeDecl, binary);
185 
186   SmallVector<uint32_t, 5> operands1 = {structType, 0};
187   spirv::encodeStringLiteralInto(operands1, "i1");
188   addInstruction(spirv::Opcode::OpMemberName, operands1);
189 
190   SmallVector<uint32_t, 5> operands2 = {structType, 1};
191   spirv::encodeStringLiteralInto(operands2, "i2");
192   addInstruction(spirv::Opcode::OpMemberName, operands2);
193 
194   binary.append(typeDecl.begin(), typeDecl.end());
195   EXPECT_NE(llvm::None, deserialize());
196 }
197 
198 TEST_F(DeserializationTest, OpMemberNameMissingOperands) {
199   addHeader();
200   SmallVector<uint32_t, 5> typeDecl;
201   std::swap(typeDecl, binary);
202 
203   auto int32Type = addIntType(32);
204   auto int64Type = addIntType(64);
205   auto structType = addStructType({int32Type, int64Type});
206   std::swap(typeDecl, binary);
207 
208   SmallVector<uint32_t, 5> operands1 = {structType};
209   addInstruction(spirv::Opcode::OpMemberName, operands1);
210 
211   binary.append(typeDecl.begin(), typeDecl.end());
212   ASSERT_EQ(llvm::None, deserialize());
213   expectDiagnostic("OpMemberName must have at least 3 operands");
214 }
215 
216 TEST_F(DeserializationTest, OpMemberNameExcessOperands) {
217   addHeader();
218   SmallVector<uint32_t, 5> typeDecl;
219   std::swap(typeDecl, binary);
220 
221   auto int32Type = addIntType(32);
222   auto structType = addStructType({int32Type});
223   std::swap(typeDecl, binary);
224 
225   SmallVector<uint32_t, 5> operands = {structType, 0};
226   spirv::encodeStringLiteralInto(operands, "int32");
227   operands.push_back(42);
228   addInstruction(spirv::Opcode::OpMemberName, operands);
229 
230   binary.append(typeDecl.begin(), typeDecl.end());
231   ASSERT_EQ(llvm::None, deserialize());
232   expectDiagnostic("unexpected trailing words in OpMemberName instruction");
233 }
234 
235 //===----------------------------------------------------------------------===//
236 // Functions
237 //===----------------------------------------------------------------------===//
238 
239 TEST_F(DeserializationTest, FunctionMissingEndFailure) {
240   addHeader();
241   auto voidType = addVoidType();
242   auto fnType = addFunctionType(voidType, {});
243   addFunction(voidType, fnType);
244   // Missing OpFunctionEnd
245 
246   ASSERT_EQ(llvm::None, deserialize());
247   expectDiagnostic("expected OpFunctionEnd instruction");
248 }
249 
250 TEST_F(DeserializationTest, FunctionMissingParameterFailure) {
251   addHeader();
252   auto voidType = addVoidType();
253   auto i32Type = addIntType(32);
254   auto fnType = addFunctionType(voidType, {i32Type});
255   addFunction(voidType, fnType);
256   // Missing OpFunctionParameter
257 
258   ASSERT_EQ(llvm::None, deserialize());
259   expectDiagnostic("expected OpFunctionParameter instruction");
260 }
261 
262 TEST_F(DeserializationTest, FunctionMissingLabelForFirstBlockFailure) {
263   addHeader();
264   auto voidType = addVoidType();
265   auto fnType = addFunctionType(voidType, {});
266   addFunction(voidType, fnType);
267   // Missing OpLabel
268   addReturn();
269   addFunctionEnd();
270 
271   ASSERT_EQ(llvm::None, deserialize());
272   expectDiagnostic("a basic block must start with OpLabel");
273 }
274 
275 TEST_F(DeserializationTest, FunctionMalformedLabelFailure) {
276   addHeader();
277   auto voidType = addVoidType();
278   auto fnType = addFunctionType(voidType, {});
279   addFunction(voidType, fnType);
280   addInstruction(spirv::Opcode::OpLabel, {}); // Malformed OpLabel
281   addReturn();
282   addFunctionEnd();
283 
284   ASSERT_EQ(llvm::None, deserialize());
285   expectDiagnostic("OpLabel should only have result <id>");
286 }
287