xref: /llvm-project/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp (revision ebf521e78483693e5cc9ba4ad3298f0397923624)
1 //===- DeserializationTest.cpp - SPIR-V Deserialization Tests -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The purpose of this file is to provide negative deserialization tests.
10 // For positive deserialization tests, please use serialization and
11 // deserialization for roundtripping.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
16 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/SPIRVModule.h"
18 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
19 #include "mlir/Dialect/SPIRV/Serialization.h"
20 #include "mlir/IR/Diagnostics.h"
21 #include "mlir/IR/MLIRContext.h"
22 #include "gmock/gmock.h"
23 
24 #include <memory>
25 
26 using namespace mlir;
27 
28 /// Load the SPIRV dialect.
29 static DialectRegistration<spirv::SPIRVDialect> SPIRVRegistration;
30 
31 using ::testing::StrEq;
32 
33 //===----------------------------------------------------------------------===//
34 // Test Fixture
35 //===----------------------------------------------------------------------===//
36 
37 /// A deserialization test fixture providing minimal SPIR-V building and
38 /// diagnostic checking utilities.
39 class DeserializationTest : public ::testing::Test {
40 protected:
41   DeserializationTest() : context(/*loadAllDialects=*/false) {
42     context.getOrLoadDialect<mlir::spirv::SPIRVDialect>();
43     // Register a diagnostic handler to capture the diagnostic so that we can
44     // check it later.
45     context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
46       diagnostic.reset(new Diagnostic(std::move(diag)));
47     });
48   }
49 
50   /// Performs deserialization and returns the constructed spv.module op.
51   spirv::OwningSPIRVModuleRef deserialize() {
52     return spirv::deserialize(binary, &context);
53   }
54 
55   /// Checks there is a diagnostic generated with the given `errorMessage`.
56   void expectDiagnostic(StringRef errorMessage) {
57     ASSERT_NE(nullptr, diagnostic.get());
58 
59     // TODO: check error location too.
60     EXPECT_THAT(diagnostic->str(), StrEq(std::string(errorMessage)));
61   }
62 
63   //===--------------------------------------------------------------------===//
64   // SPIR-V builder methods
65   //===--------------------------------------------------------------------===//
66 
67   /// Adds the SPIR-V module header to `binary`.
68   void addHeader() {
69     spirv::appendModuleHeader(binary, spirv::Version::V_1_0, /*idBound=*/0);
70   }
71 
72   /// Adds the SPIR-V instruction into `binary`.
73   void addInstruction(spirv::Opcode op, ArrayRef<uint32_t> operands) {
74     uint32_t wordCount = 1 + operands.size();
75     binary.push_back(spirv::getPrefixedOpcode(wordCount, op));
76     binary.append(operands.begin(), operands.end());
77   }
78 
79   uint32_t addVoidType() {
80     auto id = nextID++;
81     addInstruction(spirv::Opcode::OpTypeVoid, {id});
82     return id;
83   }
84 
85   uint32_t addIntType(uint32_t bitwidth) {
86     auto id = nextID++;
87     addInstruction(spirv::Opcode::OpTypeInt, {id, bitwidth, /*signedness=*/1});
88     return id;
89   }
90 
91   uint32_t addStructType(ArrayRef<uint32_t> memberTypes) {
92     auto id = nextID++;
93     SmallVector<uint32_t, 2> words;
94     words.push_back(id);
95     words.append(memberTypes.begin(), memberTypes.end());
96     addInstruction(spirv::Opcode::OpTypeStruct, words);
97     return id;
98   }
99 
100   uint32_t addFunctionType(uint32_t retType, ArrayRef<uint32_t> paramTypes) {
101     auto id = nextID++;
102     SmallVector<uint32_t, 4> operands;
103     operands.push_back(id);
104     operands.push_back(retType);
105     operands.append(paramTypes.begin(), paramTypes.end());
106     addInstruction(spirv::Opcode::OpTypeFunction, operands);
107     return id;
108   }
109 
110   uint32_t addFunction(uint32_t retType, uint32_t fnType) {
111     auto id = nextID++;
112     addInstruction(spirv::Opcode::OpFunction,
113                    {retType, id,
114                     static_cast<uint32_t>(spirv::FunctionControl::None),
115                     fnType});
116     return id;
117   }
118 
119   void addFunctionEnd() { addInstruction(spirv::Opcode::OpFunctionEnd, {}); }
120 
121   void addReturn() { addInstruction(spirv::Opcode::OpReturn, {}); }
122 
123 protected:
124   SmallVector<uint32_t, 5> binary;
125   uint32_t nextID = 1;
126   MLIRContext context;
127   std::unique_ptr<Diagnostic> diagnostic;
128 };
129 
130 //===----------------------------------------------------------------------===//
131 // Basics
132 //===----------------------------------------------------------------------===//
133 
134 TEST_F(DeserializationTest, EmptyModuleFailure) {
135   ASSERT_FALSE(deserialize());
136   expectDiagnostic("SPIR-V binary module must have a 5-word header");
137 }
138 
139 TEST_F(DeserializationTest, WrongMagicNumberFailure) {
140   addHeader();
141   binary.front() = 0xdeadbeef; // Change to a wrong magic number
142   ASSERT_FALSE(deserialize());
143   expectDiagnostic("incorrect magic number");
144 }
145 
146 TEST_F(DeserializationTest, OnlyHeaderSuccess) {
147   addHeader();
148   EXPECT_TRUE(deserialize());
149 }
150 
151 TEST_F(DeserializationTest, ZeroWordCountFailure) {
152   addHeader();
153   binary.push_back(0); // OpNop with zero word count
154 
155   ASSERT_FALSE(deserialize());
156   expectDiagnostic("word count cannot be zero");
157 }
158 
159 TEST_F(DeserializationTest, InsufficientWordFailure) {
160   addHeader();
161   binary.push_back((2u << 16) |
162                    static_cast<uint32_t>(spirv::Opcode::OpTypeVoid));
163   // Missing word for type <id>.
164 
165   ASSERT_FALSE(deserialize());
166   expectDiagnostic("insufficient words for the last instruction");
167 }
168 
169 //===----------------------------------------------------------------------===//
170 // Types
171 //===----------------------------------------------------------------------===//
172 
173 TEST_F(DeserializationTest, IntTypeMissingSignednessFailure) {
174   addHeader();
175   addInstruction(spirv::Opcode::OpTypeInt, {nextID++, 32});
176 
177   ASSERT_FALSE(deserialize());
178   expectDiagnostic("OpTypeInt must have bitwidth and signedness parameters");
179 }
180 
181 //===----------------------------------------------------------------------===//
182 // StructType
183 //===----------------------------------------------------------------------===//
184 
185 TEST_F(DeserializationTest, OpMemberNameSuccess) {
186   addHeader();
187   SmallVector<uint32_t, 5> typeDecl;
188   std::swap(typeDecl, binary);
189 
190   auto int32Type = addIntType(32);
191   auto structType = addStructType({int32Type, int32Type});
192   std::swap(typeDecl, binary);
193 
194   SmallVector<uint32_t, 5> operands1 = {structType, 0};
195   spirv::encodeStringLiteralInto(operands1, "i1");
196   addInstruction(spirv::Opcode::OpMemberName, operands1);
197 
198   SmallVector<uint32_t, 5> operands2 = {structType, 1};
199   spirv::encodeStringLiteralInto(operands2, "i2");
200   addInstruction(spirv::Opcode::OpMemberName, operands2);
201 
202   binary.append(typeDecl.begin(), typeDecl.end());
203   EXPECT_TRUE(deserialize());
204 }
205 
206 TEST_F(DeserializationTest, OpMemberNameMissingOperands) {
207   addHeader();
208   SmallVector<uint32_t, 5> typeDecl;
209   std::swap(typeDecl, binary);
210 
211   auto int32Type = addIntType(32);
212   auto int64Type = addIntType(64);
213   auto structType = addStructType({int32Type, int64Type});
214   std::swap(typeDecl, binary);
215 
216   SmallVector<uint32_t, 5> operands1 = {structType};
217   addInstruction(spirv::Opcode::OpMemberName, operands1);
218 
219   binary.append(typeDecl.begin(), typeDecl.end());
220   ASSERT_FALSE(deserialize());
221   expectDiagnostic("OpMemberName must have at least 3 operands");
222 }
223 
224 TEST_F(DeserializationTest, OpMemberNameExcessOperands) {
225   addHeader();
226   SmallVector<uint32_t, 5> typeDecl;
227   std::swap(typeDecl, binary);
228 
229   auto int32Type = addIntType(32);
230   auto structType = addStructType({int32Type});
231   std::swap(typeDecl, binary);
232 
233   SmallVector<uint32_t, 5> operands = {structType, 0};
234   spirv::encodeStringLiteralInto(operands, "int32");
235   operands.push_back(42);
236   addInstruction(spirv::Opcode::OpMemberName, operands);
237 
238   binary.append(typeDecl.begin(), typeDecl.end());
239   ASSERT_FALSE(deserialize());
240   expectDiagnostic("unexpected trailing words in OpMemberName instruction");
241 }
242 
243 //===----------------------------------------------------------------------===//
244 // Functions
245 //===----------------------------------------------------------------------===//
246 
247 TEST_F(DeserializationTest, FunctionMissingEndFailure) {
248   addHeader();
249   auto voidType = addVoidType();
250   auto fnType = addFunctionType(voidType, {});
251   addFunction(voidType, fnType);
252   // Missing OpFunctionEnd.
253 
254   ASSERT_FALSE(deserialize());
255   expectDiagnostic("expected OpFunctionEnd instruction");
256 }
257 
258 TEST_F(DeserializationTest, FunctionMissingParameterFailure) {
259   addHeader();
260   auto voidType = addVoidType();
261   auto i32Type = addIntType(32);
262   auto fnType = addFunctionType(voidType, {i32Type});
263   addFunction(voidType, fnType);
264   // Missing OpFunctionParameter.
265 
266   ASSERT_FALSE(deserialize());
267   expectDiagnostic("expected OpFunctionParameter instruction");
268 }
269 
270 TEST_F(DeserializationTest, FunctionMissingLabelForFirstBlockFailure) {
271   addHeader();
272   auto voidType = addVoidType();
273   auto fnType = addFunctionType(voidType, {});
274   addFunction(voidType, fnType);
275   // Missing OpLabel.
276   addReturn();
277   addFunctionEnd();
278 
279   ASSERT_FALSE(deserialize());
280   expectDiagnostic("a basic block must start with OpLabel");
281 }
282 
283 TEST_F(DeserializationTest, FunctionMalformedLabelFailure) {
284   addHeader();
285   auto voidType = addVoidType();
286   auto fnType = addFunctionType(voidType, {});
287   addFunction(voidType, fnType);
288   addInstruction(spirv::Opcode::OpLabel, {}); // Malformed OpLabel
289   addReturn();
290   addFunctionEnd();
291 
292   ASSERT_FALSE(deserialize());
293   expectDiagnostic("OpLabel should only have result <id>");
294 }
295