// Source: llvm-project/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp (revision c64770506b89a2376fe13080bc3b72789e6c752d)
//===- DeserializationTest.cpp - SPIR-V Deserialization Tests -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// The purpose of this file is primarily to provide negative deserialization
// tests. For positive deserialization tests, please prefer round-trip tests
// (serialize, then deserialize) instead.
//
//===----------------------------------------------------------------------===//
14 
15 #include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
16 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
18 #include "mlir/Dialect/SPIRV/Serialization.h"
19 #include "mlir/IR/Diagnostics.h"
20 #include "mlir/IR/MLIRContext.h"
21 #include "gmock/gmock.h"
22 
23 #include <memory>
24 
25 using namespace mlir;
26 
27 // Load the SPIRV dialect
28 static DialectRegistration<spirv::SPIRVDialect> SPIRVRegistration;
29 
30 using ::testing::StrEq;
31 
32 //===----------------------------------------------------------------------===//
33 // Test Fixture
34 //===----------------------------------------------------------------------===//
35 
36 /// A deserialization test fixture providing minimal SPIR-V building and
37 /// diagnostic checking utilities.
38 class DeserializationTest : public ::testing::Test {
39 protected:
40   DeserializationTest() {
41     // Register a diagnostic handler to capture the diagnostic so that we can
42     // check it later.
43     context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
44       diagnostic.reset(new Diagnostic(std::move(diag)));
45     });
46   }
47 
48   /// Performs deserialization and returns the constructed spv.module op.
49   Optional<spirv::ModuleOp> deserialize() {
50     return spirv::deserialize(binary, &context);
51   }
52 
53   /// Checks there is a diagnostic generated with the given `errorMessage`.
54   void expectDiagnostic(StringRef errorMessage) {
55     ASSERT_NE(nullptr, diagnostic.get());
56 
57     // TODO(antiagainst): check error location too.
58     EXPECT_THAT(diagnostic->str(), StrEq(std::string(errorMessage)));
59   }
60 
61   //===--------------------------------------------------------------------===//
62   // SPIR-V builder methods
63   //===--------------------------------------------------------------------===//
64 
65   /// Adds the SPIR-V module header to `binary`.
66   void addHeader() { spirv::appendModuleHeader(binary, /*idBound=*/0); }
67 
68   /// Adds the SPIR-V instruction into `binary`.
69   void addInstruction(spirv::Opcode op, ArrayRef<uint32_t> operands) {
70     uint32_t wordCount = 1 + operands.size();
71     binary.push_back(spirv::getPrefixedOpcode(wordCount, op));
72     binary.append(operands.begin(), operands.end());
73   }
74 
75   uint32_t addVoidType() {
76     auto id = nextID++;
77     addInstruction(spirv::Opcode::OpTypeVoid, {id});
78     return id;
79   }
80 
81   uint32_t addIntType(uint32_t bitwidth) {
82     auto id = nextID++;
83     addInstruction(spirv::Opcode::OpTypeInt, {id, bitwidth, /*signedness=*/1});
84     return id;
85   }
86 
87   uint32_t addStructType(ArrayRef<uint32_t> memberTypes) {
88     auto id = nextID++;
89     SmallVector<uint32_t, 2> words;
90     words.push_back(id);
91     words.append(memberTypes.begin(), memberTypes.end());
92     addInstruction(spirv::Opcode::OpTypeStruct, words);
93     return id;
94   }
95 
96   uint32_t addFunctionType(uint32_t retType, ArrayRef<uint32_t> paramTypes) {
97     auto id = nextID++;
98     SmallVector<uint32_t, 4> operands;
99     operands.push_back(id);
100     operands.push_back(retType);
101     operands.append(paramTypes.begin(), paramTypes.end());
102     addInstruction(spirv::Opcode::OpTypeFunction, operands);
103     return id;
104   }
105 
106   uint32_t addFunction(uint32_t retType, uint32_t fnType) {
107     auto id = nextID++;
108     addInstruction(spirv::Opcode::OpFunction,
109                    {retType, id,
110                     static_cast<uint32_t>(spirv::FunctionControl::None),
111                     fnType});
112     return id;
113   }
114 
115   void addFunctionEnd() { addInstruction(spirv::Opcode::OpFunctionEnd, {}); }
116 
117   void addReturn() { addInstruction(spirv::Opcode::OpReturn, {}); }
118 
119 protected:
120   SmallVector<uint32_t, 5> binary;
121   uint32_t nextID = 1;
122   MLIRContext context;
123   std::unique_ptr<Diagnostic> diagnostic;
124 };
125 
126 //===----------------------------------------------------------------------===//
127 // Basics
128 //===----------------------------------------------------------------------===//
129 
130 TEST_F(DeserializationTest, EmptyModuleFailure) {
131   ASSERT_EQ(llvm::None, deserialize());
132   expectDiagnostic("SPIR-V binary module must have a 5-word header");
133 }
134 
135 TEST_F(DeserializationTest, WrongMagicNumberFailure) {
136   addHeader();
137   binary.front() = 0xdeadbeef; // Change to a wrong magic number
138   ASSERT_EQ(llvm::None, deserialize());
139   expectDiagnostic("incorrect magic number");
140 }
141 
142 TEST_F(DeserializationTest, OnlyHeaderSuccess) {
143   addHeader();
144   EXPECT_NE(llvm::None, deserialize());
145 }
146 
147 TEST_F(DeserializationTest, ZeroWordCountFailure) {
148   addHeader();
149   binary.push_back(0); // OpNop with zero word count
150 
151   ASSERT_EQ(llvm::None, deserialize());
152   expectDiagnostic("word count cannot be zero");
153 }
154 
155 TEST_F(DeserializationTest, InsufficientWordFailure) {
156   addHeader();
157   binary.push_back((2u << 16) |
158                    static_cast<uint32_t>(spirv::Opcode::OpTypeVoid));
159   // Missing word for type <id>
160 
161   ASSERT_EQ(llvm::None, deserialize());
162   expectDiagnostic("insufficient words for the last instruction");
163 }
164 
165 //===----------------------------------------------------------------------===//
166 // Types
167 //===----------------------------------------------------------------------===//
168 
169 TEST_F(DeserializationTest, IntTypeMissingSignednessFailure) {
170   addHeader();
171   addInstruction(spirv::Opcode::OpTypeInt, {nextID++, 32});
172 
173   ASSERT_EQ(llvm::None, deserialize());
174   expectDiagnostic("OpTypeInt must have bitwidth and signedness parameters");
175 }
176 
177 //===----------------------------------------------------------------------===//
178 // StructType
179 //===----------------------------------------------------------------------===//
180 
181 TEST_F(DeserializationTest, OpMemberNameSuccess) {
182   addHeader();
183   SmallVector<uint32_t, 5> typeDecl;
184   std::swap(typeDecl, binary);
185 
186   auto int32Type = addIntType(32);
187   auto structType = addStructType({int32Type, int32Type});
188   std::swap(typeDecl, binary);
189 
190   SmallVector<uint32_t, 5> operands1 = {structType, 0};
191   spirv::encodeStringLiteralInto(operands1, "i1");
192   addInstruction(spirv::Opcode::OpMemberName, operands1);
193 
194   SmallVector<uint32_t, 5> operands2 = {structType, 1};
195   spirv::encodeStringLiteralInto(operands2, "i2");
196   addInstruction(spirv::Opcode::OpMemberName, operands2);
197 
198   binary.append(typeDecl.begin(), typeDecl.end());
199   EXPECT_NE(llvm::None, deserialize());
200 }
201 
202 TEST_F(DeserializationTest, OpMemberNameMissingOperands) {
203   addHeader();
204   SmallVector<uint32_t, 5> typeDecl;
205   std::swap(typeDecl, binary);
206 
207   auto int32Type = addIntType(32);
208   auto int64Type = addIntType(64);
209   auto structType = addStructType({int32Type, int64Type});
210   std::swap(typeDecl, binary);
211 
212   SmallVector<uint32_t, 5> operands1 = {structType};
213   addInstruction(spirv::Opcode::OpMemberName, operands1);
214 
215   binary.append(typeDecl.begin(), typeDecl.end());
216   ASSERT_EQ(llvm::None, deserialize());
217   expectDiagnostic("OpMemberName must have at least 3 operands");
218 }
219 
220 TEST_F(DeserializationTest, OpMemberNameExcessOperands) {
221   addHeader();
222   SmallVector<uint32_t, 5> typeDecl;
223   std::swap(typeDecl, binary);
224 
225   auto int32Type = addIntType(32);
226   auto structType = addStructType({int32Type});
227   std::swap(typeDecl, binary);
228 
229   SmallVector<uint32_t, 5> operands = {structType, 0};
230   spirv::encodeStringLiteralInto(operands, "int32");
231   operands.push_back(42);
232   addInstruction(spirv::Opcode::OpMemberName, operands);
233 
234   binary.append(typeDecl.begin(), typeDecl.end());
235   ASSERT_EQ(llvm::None, deserialize());
236   expectDiagnostic("unexpected trailing words in OpMemberName instruction");
237 }
238 
239 //===----------------------------------------------------------------------===//
240 // Functions
241 //===----------------------------------------------------------------------===//
242 
243 TEST_F(DeserializationTest, FunctionMissingEndFailure) {
244   addHeader();
245   auto voidType = addVoidType();
246   auto fnType = addFunctionType(voidType, {});
247   addFunction(voidType, fnType);
248   // Missing OpFunctionEnd
249 
250   ASSERT_EQ(llvm::None, deserialize());
251   expectDiagnostic("expected OpFunctionEnd instruction");
252 }
253 
254 TEST_F(DeserializationTest, FunctionMissingParameterFailure) {
255   addHeader();
256   auto voidType = addVoidType();
257   auto i32Type = addIntType(32);
258   auto fnType = addFunctionType(voidType, {i32Type});
259   addFunction(voidType, fnType);
260   // Missing OpFunctionParameter
261 
262   ASSERT_EQ(llvm::None, deserialize());
263   expectDiagnostic("expected OpFunctionParameter instruction");
264 }
265 
266 TEST_F(DeserializationTest, FunctionMissingLabelForFirstBlockFailure) {
267   addHeader();
268   auto voidType = addVoidType();
269   auto fnType = addFunctionType(voidType, {});
270   addFunction(voidType, fnType);
271   // Missing OpLabel
272   addReturn();
273   addFunctionEnd();
274 
275   ASSERT_EQ(llvm::None, deserialize());
276   expectDiagnostic("a basic block must start with OpLabel");
277 }
278 
279 TEST_F(DeserializationTest, FunctionMalformedLabelFailure) {
280   addHeader();
281   auto voidType = addVoidType();
282   auto fnType = addFunctionType(voidType, {});
283   addFunction(voidType, fnType);
284   addInstruction(spirv::Opcode::OpLabel, {}); // Malformed OpLabel
285   addReturn();
286   addFunctionEnd();
287 
288   ASSERT_EQ(llvm::None, deserialize());
289   expectDiagnostic("OpLabel should only have result <id>");
290 }
291