xref: /llvm-project/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp (revision d84fe55e0d4ddc2d2d65a6ff988368281a01b385)
//===- DeserializationTest.cpp - SPIR-V Deserialization Tests -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// The purpose of this file is to provide negative deserialization tests, plus
// a few minimal positive sanity checks. For other positive deserialization
// tests, please use serialization and deserialization for roundtripping.
//
//===----------------------------------------------------------------------===//
14 
15 #include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
16 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/SPIRVModule.h"
18 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
19 #include "mlir/Dialect/SPIRV/Serialization.h"
20 #include "mlir/IR/Diagnostics.h"
21 #include "mlir/IR/MLIRContext.h"
22 #include "gmock/gmock.h"
23 
24 #include <memory>
25 
26 using namespace mlir;
27 
28 /// Load the SPIRV dialect.
29 static DialectRegistration<spirv::SPIRVDialect> SPIRVRegistration;
30 
31 using ::testing::StrEq;
32 
33 //===----------------------------------------------------------------------===//
34 // Test Fixture
35 //===----------------------------------------------------------------------===//
36 
37 /// A deserialization test fixture providing minimal SPIR-V building and
38 /// diagnostic checking utilities.
39 class DeserializationTest : public ::testing::Test {
40 protected:
41   DeserializationTest() {
42     // Register a diagnostic handler to capture the diagnostic so that we can
43     // check it later.
44     context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
45       diagnostic.reset(new Diagnostic(std::move(diag)));
46     });
47   }
48 
49   /// Performs deserialization and returns the constructed spv.module op.
50   spirv::OwningSPIRVModuleRef deserialize() {
51     return spirv::deserialize(binary, &context);
52   }
53 
54   /// Checks there is a diagnostic generated with the given `errorMessage`.
55   void expectDiagnostic(StringRef errorMessage) {
56     ASSERT_NE(nullptr, diagnostic.get());
57 
58     // TODO: check error location too.
59     EXPECT_THAT(diagnostic->str(), StrEq(std::string(errorMessage)));
60   }
61 
62   //===--------------------------------------------------------------------===//
63   // SPIR-V builder methods
64   //===--------------------------------------------------------------------===//
65 
66   /// Adds the SPIR-V module header to `binary`.
67   void addHeader() {
68     spirv::appendModuleHeader(binary, spirv::Version::V_1_0, /*idBound=*/0);
69   }
70 
71   /// Adds the SPIR-V instruction into `binary`.
72   void addInstruction(spirv::Opcode op, ArrayRef<uint32_t> operands) {
73     uint32_t wordCount = 1 + operands.size();
74     binary.push_back(spirv::getPrefixedOpcode(wordCount, op));
75     binary.append(operands.begin(), operands.end());
76   }
77 
78   uint32_t addVoidType() {
79     auto id = nextID++;
80     addInstruction(spirv::Opcode::OpTypeVoid, {id});
81     return id;
82   }
83 
84   uint32_t addIntType(uint32_t bitwidth) {
85     auto id = nextID++;
86     addInstruction(spirv::Opcode::OpTypeInt, {id, bitwidth, /*signedness=*/1});
87     return id;
88   }
89 
90   uint32_t addStructType(ArrayRef<uint32_t> memberTypes) {
91     auto id = nextID++;
92     SmallVector<uint32_t, 2> words;
93     words.push_back(id);
94     words.append(memberTypes.begin(), memberTypes.end());
95     addInstruction(spirv::Opcode::OpTypeStruct, words);
96     return id;
97   }
98 
99   uint32_t addFunctionType(uint32_t retType, ArrayRef<uint32_t> paramTypes) {
100     auto id = nextID++;
101     SmallVector<uint32_t, 4> operands;
102     operands.push_back(id);
103     operands.push_back(retType);
104     operands.append(paramTypes.begin(), paramTypes.end());
105     addInstruction(spirv::Opcode::OpTypeFunction, operands);
106     return id;
107   }
108 
109   uint32_t addFunction(uint32_t retType, uint32_t fnType) {
110     auto id = nextID++;
111     addInstruction(spirv::Opcode::OpFunction,
112                    {retType, id,
113                     static_cast<uint32_t>(spirv::FunctionControl::None),
114                     fnType});
115     return id;
116   }
117 
118   void addFunctionEnd() { addInstruction(spirv::Opcode::OpFunctionEnd, {}); }
119 
120   void addReturn() { addInstruction(spirv::Opcode::OpReturn, {}); }
121 
122 protected:
123   SmallVector<uint32_t, 5> binary;
124   uint32_t nextID = 1;
125   MLIRContext context;
126   std::unique_ptr<Diagnostic> diagnostic;
127 };
128 
129 //===----------------------------------------------------------------------===//
130 // Basics
131 //===----------------------------------------------------------------------===//
132 
133 TEST_F(DeserializationTest, EmptyModuleFailure) {
134   ASSERT_FALSE(deserialize());
135   expectDiagnostic("SPIR-V binary module must have a 5-word header");
136 }
137 
138 TEST_F(DeserializationTest, WrongMagicNumberFailure) {
139   addHeader();
140   binary.front() = 0xdeadbeef; // Change to a wrong magic number
141   ASSERT_FALSE(deserialize());
142   expectDiagnostic("incorrect magic number");
143 }
144 
145 TEST_F(DeserializationTest, OnlyHeaderSuccess) {
146   addHeader();
147   EXPECT_TRUE(deserialize());
148 }
149 
150 TEST_F(DeserializationTest, ZeroWordCountFailure) {
151   addHeader();
152   binary.push_back(0); // OpNop with zero word count
153 
154   ASSERT_FALSE(deserialize());
155   expectDiagnostic("word count cannot be zero");
156 }
157 
158 TEST_F(DeserializationTest, InsufficientWordFailure) {
159   addHeader();
160   binary.push_back((2u << 16) |
161                    static_cast<uint32_t>(spirv::Opcode::OpTypeVoid));
162   // Missing word for type <id>.
163 
164   ASSERT_FALSE(deserialize());
165   expectDiagnostic("insufficient words for the last instruction");
166 }
167 
168 //===----------------------------------------------------------------------===//
169 // Types
170 //===----------------------------------------------------------------------===//
171 
172 TEST_F(DeserializationTest, IntTypeMissingSignednessFailure) {
173   addHeader();
174   addInstruction(spirv::Opcode::OpTypeInt, {nextID++, 32});
175 
176   ASSERT_FALSE(deserialize());
177   expectDiagnostic("OpTypeInt must have bitwidth and signedness parameters");
178 }
179 
180 //===----------------------------------------------------------------------===//
181 // StructType
182 //===----------------------------------------------------------------------===//
183 
184 TEST_F(DeserializationTest, OpMemberNameSuccess) {
185   addHeader();
186   SmallVector<uint32_t, 5> typeDecl;
187   std::swap(typeDecl, binary);
188 
189   auto int32Type = addIntType(32);
190   auto structType = addStructType({int32Type, int32Type});
191   std::swap(typeDecl, binary);
192 
193   SmallVector<uint32_t, 5> operands1 = {structType, 0};
194   spirv::encodeStringLiteralInto(operands1, "i1");
195   addInstruction(spirv::Opcode::OpMemberName, operands1);
196 
197   SmallVector<uint32_t, 5> operands2 = {structType, 1};
198   spirv::encodeStringLiteralInto(operands2, "i2");
199   addInstruction(spirv::Opcode::OpMemberName, operands2);
200 
201   binary.append(typeDecl.begin(), typeDecl.end());
202   EXPECT_TRUE(deserialize());
203 }
204 
205 TEST_F(DeserializationTest, OpMemberNameMissingOperands) {
206   addHeader();
207   SmallVector<uint32_t, 5> typeDecl;
208   std::swap(typeDecl, binary);
209 
210   auto int32Type = addIntType(32);
211   auto int64Type = addIntType(64);
212   auto structType = addStructType({int32Type, int64Type});
213   std::swap(typeDecl, binary);
214 
215   SmallVector<uint32_t, 5> operands1 = {structType};
216   addInstruction(spirv::Opcode::OpMemberName, operands1);
217 
218   binary.append(typeDecl.begin(), typeDecl.end());
219   ASSERT_FALSE(deserialize());
220   expectDiagnostic("OpMemberName must have at least 3 operands");
221 }
222 
223 TEST_F(DeserializationTest, OpMemberNameExcessOperands) {
224   addHeader();
225   SmallVector<uint32_t, 5> typeDecl;
226   std::swap(typeDecl, binary);
227 
228   auto int32Type = addIntType(32);
229   auto structType = addStructType({int32Type});
230   std::swap(typeDecl, binary);
231 
232   SmallVector<uint32_t, 5> operands = {structType, 0};
233   spirv::encodeStringLiteralInto(operands, "int32");
234   operands.push_back(42);
235   addInstruction(spirv::Opcode::OpMemberName, operands);
236 
237   binary.append(typeDecl.begin(), typeDecl.end());
238   ASSERT_FALSE(deserialize());
239   expectDiagnostic("unexpected trailing words in OpMemberName instruction");
240 }
241 
242 //===----------------------------------------------------------------------===//
243 // Functions
244 //===----------------------------------------------------------------------===//
245 
246 TEST_F(DeserializationTest, FunctionMissingEndFailure) {
247   addHeader();
248   auto voidType = addVoidType();
249   auto fnType = addFunctionType(voidType, {});
250   addFunction(voidType, fnType);
251   // Missing OpFunctionEnd.
252 
253   ASSERT_FALSE(deserialize());
254   expectDiagnostic("expected OpFunctionEnd instruction");
255 }
256 
257 TEST_F(DeserializationTest, FunctionMissingParameterFailure) {
258   addHeader();
259   auto voidType = addVoidType();
260   auto i32Type = addIntType(32);
261   auto fnType = addFunctionType(voidType, {i32Type});
262   addFunction(voidType, fnType);
263   // Missing OpFunctionParameter.
264 
265   ASSERT_FALSE(deserialize());
266   expectDiagnostic("expected OpFunctionParameter instruction");
267 }
268 
269 TEST_F(DeserializationTest, FunctionMissingLabelForFirstBlockFailure) {
270   addHeader();
271   auto voidType = addVoidType();
272   auto fnType = addFunctionType(voidType, {});
273   addFunction(voidType, fnType);
274   // Missing OpLabel.
275   addReturn();
276   addFunctionEnd();
277 
278   ASSERT_FALSE(deserialize());
279   expectDiagnostic("a basic block must start with OpLabel");
280 }
281 
282 TEST_F(DeserializationTest, FunctionMalformedLabelFailure) {
283   addHeader();
284   auto voidType = addVoidType();
285   auto fnType = addFunctionType(voidType, {});
286   addFunction(voidType, fnType);
287   addInstruction(spirv::Opcode::OpLabel, {}); // Malformed OpLabel
288   addReturn();
289   addFunctionEnd();
290 
291   ASSERT_FALSE(deserialize());
292   expectDiagnostic("OpLabel should only have result <id>");
293 }
294