xref: /llvm-project/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp (revision 5ab6ef758f0f549fb39bf9b34a6a743e989b212a)
1 //===- DeserializationTest.cpp - SPIR-V Deserialization Tests -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The purpose of this file is to provide negative deserialization tests.
10 // For positive deserialization tests, please use serialization and
11 // deserialization for roundtripping.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Target/SPIRV/Deserialization.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18 #include "mlir/IR/Diagnostics.h"
19 #include "mlir/IR/MLIRContext.h"
20 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
21 #include "gmock/gmock.h"
22 
23 #include <memory>
24 
25 using namespace mlir;
26 
27 using ::testing::StrEq;
28 
29 //===----------------------------------------------------------------------===//
30 // Test Fixture
31 //===----------------------------------------------------------------------===//
32 
33 /// A deserialization test fixture providing minimal SPIR-V building and
34 /// diagnostic checking utilities.
35 class DeserializationTest : public ::testing::Test {
36 protected:
DeserializationTest()37   DeserializationTest() {
38     context.getOrLoadDialect<mlir::spirv::SPIRVDialect>();
39     // Register a diagnostic handler to capture the diagnostic so that we can
40     // check it later.
41     context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
42       diagnostic = std::make_unique<Diagnostic>(std::move(diag));
43     });
44   }
45 
46   /// Performs deserialization and returns the constructed spirv.module op.
deserialize()47   OwningOpRef<spirv::ModuleOp> deserialize() {
48     return spirv::deserialize(binary, &context);
49   }
50 
51   /// Checks there is a diagnostic generated with the given `errorMessage`.
expectDiagnostic(StringRef errorMessage)52   void expectDiagnostic(StringRef errorMessage) {
53     ASSERT_NE(nullptr, diagnostic.get());
54 
55     // TODO: check error location too.
56     EXPECT_THAT(diagnostic->str(), StrEq(std::string(errorMessage)));
57   }
58 
59   //===--------------------------------------------------------------------===//
60   // SPIR-V builder methods
61   //===--------------------------------------------------------------------===//
62 
63   /// Adds the SPIR-V module header to `binary`.
addHeader()64   void addHeader() {
65     spirv::appendModuleHeader(binary, spirv::Version::V_1_0, /*idBound=*/0);
66   }
67 
68   /// Adds the SPIR-V instruction into `binary`.
addInstruction(spirv::Opcode op,ArrayRef<uint32_t> operands)69   void addInstruction(spirv::Opcode op, ArrayRef<uint32_t> operands) {
70     uint32_t wordCount = 1 + operands.size();
71     binary.push_back(spirv::getPrefixedOpcode(wordCount, op));
72     binary.append(operands.begin(), operands.end());
73   }
74 
addVoidType()75   uint32_t addVoidType() {
76     auto id = nextID++;
77     addInstruction(spirv::Opcode::OpTypeVoid, {id});
78     return id;
79   }
80 
addIntType(uint32_t bitwidth)81   uint32_t addIntType(uint32_t bitwidth) {
82     auto id = nextID++;
83     addInstruction(spirv::Opcode::OpTypeInt, {id, bitwidth, /*signedness=*/1});
84     return id;
85   }
86 
addStructType(ArrayRef<uint32_t> memberTypes)87   uint32_t addStructType(ArrayRef<uint32_t> memberTypes) {
88     auto id = nextID++;
89     SmallVector<uint32_t, 2> words;
90     words.push_back(id);
91     words.append(memberTypes.begin(), memberTypes.end());
92     addInstruction(spirv::Opcode::OpTypeStruct, words);
93     return id;
94   }
95 
addFunctionType(uint32_t retType,ArrayRef<uint32_t> paramTypes)96   uint32_t addFunctionType(uint32_t retType, ArrayRef<uint32_t> paramTypes) {
97     auto id = nextID++;
98     SmallVector<uint32_t, 4> operands;
99     operands.push_back(id);
100     operands.push_back(retType);
101     operands.append(paramTypes.begin(), paramTypes.end());
102     addInstruction(spirv::Opcode::OpTypeFunction, operands);
103     return id;
104   }
105 
addFunction(uint32_t retType,uint32_t fnType)106   uint32_t addFunction(uint32_t retType, uint32_t fnType) {
107     auto id = nextID++;
108     addInstruction(spirv::Opcode::OpFunction,
109                    {retType, id,
110                     static_cast<uint32_t>(spirv::FunctionControl::None),
111                     fnType});
112     return id;
113   }
114 
addFunctionEnd()115   void addFunctionEnd() { addInstruction(spirv::Opcode::OpFunctionEnd, {}); }
116 
addReturn()117   void addReturn() { addInstruction(spirv::Opcode::OpReturn, {}); }
118 
119 protected:
120   SmallVector<uint32_t, 5> binary;
121   uint32_t nextID = 1;
122   MLIRContext context;
123   std::unique_ptr<Diagnostic> diagnostic;
124 };
125 
126 //===----------------------------------------------------------------------===//
127 // Basics
128 //===----------------------------------------------------------------------===//
129 
// A completely empty word stream has no header, so deserialization must fail
// with the missing-header diagnostic.
TEST_F(DeserializationTest, EmptyModuleFailure) {
  ASSERT_FALSE(deserialize());
  expectDiagnostic("SPIR-V binary module must have a 5-word header");
}
134 
// A header whose first word has been overwritten must be rejected.
TEST_F(DeserializationTest, WrongMagicNumberFailure) {
  addHeader();
  // Clobber the first header word with a bogus magic number.
  binary.front() = 0xdeadbeef;

  ASSERT_FALSE(deserialize());
  expectDiagnostic("incorrect magic number");
}
141 
// A word stream consisting of just the module header is expected to
// deserialize.
TEST_F(DeserializationTest, OnlyHeaderSuccess) {
  addHeader();

  EXPECT_TRUE(deserialize());
}
146 
// An instruction word whose word-count field is zero must be rejected.
TEST_F(DeserializationTest, ZeroWordCountFailure) {
  addHeader();
  // An all-zero word: OpNop with a word count of zero.
  binary.push_back(0);

  ASSERT_FALSE(deserialize());
  expectDiagnostic("word count cannot be zero");
}
154 
// An instruction that claims more words than the stream actually holds must be
// rejected.
TEST_F(DeserializationTest, InsufficientWordFailure) {
  addHeader();
  // Announce a two-word OpTypeVoid but emit only its leading word, leaving out
  // the word for the type <id>.
  const uint32_t claimedWordCount = 2u;
  const auto opcodeBits = static_cast<uint32_t>(spirv::Opcode::OpTypeVoid);
  binary.push_back((claimedWordCount << 16) | opcodeBits);

  ASSERT_FALSE(deserialize());
  expectDiagnostic("insufficient words for the last instruction");
}
164 
165 //===----------------------------------------------------------------------===//
166 // Types
167 //===----------------------------------------------------------------------===//
168 
// OpTypeInt emitted with a result <id> and a bitwidth but no signedness operand
// must be rejected.
TEST_F(DeserializationTest, IntTypeMissingSignednessFailure) {
  addHeader();
  const uint32_t intTypeID = nextID++;
  addInstruction(spirv::Opcode::OpTypeInt, {intTypeID, 32});

  ASSERT_FALSE(deserialize());
  expectDiagnostic("OpTypeInt must have bitwidth and signedness parameters");
}
176 
177 //===----------------------------------------------------------------------===//
178 // StructType
179 //===----------------------------------------------------------------------===//
180 
// Well-formed OpMemberName instructions for both members of a two-member
// struct are expected to deserialize.
TEST_F(DeserializationTest, OpMemberNameSuccess) {
  addHeader();
  // Park the header in a side buffer so the type declarations are built into
  // an empty `binary`; the second swap restores the header and leaves the
  // type declarations in the side buffer.
  SmallVector<uint32_t, 5> typeWords;
  std::swap(typeWords, binary);

  auto i32 = addIntType(32);
  auto structTy = addStructType({i32, i32});
  std::swap(typeWords, binary);

  // Emit the member names right after the header, ahead of the types.
  SmallVector<uint32_t, 5> firstName = {structTy, 0};
  (void)spirv::encodeStringLiteralInto(firstName, "i1");
  addInstruction(spirv::Opcode::OpMemberName, firstName);

  SmallVector<uint32_t, 5> secondName = {structTy, 1};
  (void)spirv::encodeStringLiteralInto(secondName, "i2");
  addInstruction(spirv::Opcode::OpMemberName, secondName);

  // The type declarations follow the names.
  binary.append(typeWords.begin(), typeWords.end());
  EXPECT_TRUE(deserialize());
}
201 
// OpMemberName carrying only the struct type <id> (no member index, no name)
// must be rejected.
TEST_F(DeserializationTest, OpMemberNameMissingOperands) {
  addHeader();
  // Park the header in a side buffer so the type declarations are built into
  // an empty `binary`; the second swap restores the header and leaves the
  // type declarations in the side buffer.
  SmallVector<uint32_t, 5> typeWords;
  std::swap(typeWords, binary);

  auto i32 = addIntType(32);
  auto i64 = addIntType(64);
  auto structTy = addStructType({i32, i64});
  std::swap(typeWords, binary);

  // Only the struct type <id> is supplied.
  SmallVector<uint32_t, 5> truncatedName = {structTy};
  addInstruction(spirv::Opcode::OpMemberName, truncatedName);

  binary.append(typeWords.begin(), typeWords.end());
  ASSERT_FALSE(deserialize());
  expectDiagnostic("OpMemberName must have at least 3 operands");
}
219 
// OpMemberName with an extra word after the encoded name string must be
// rejected.
TEST_F(DeserializationTest, OpMemberNameExcessOperands) {
  addHeader();
  // Park the header in a side buffer so the type declarations are built into
  // an empty `binary`; the second swap restores the header and leaves the
  // type declarations in the side buffer.
  SmallVector<uint32_t, 5> typeWords;
  std::swap(typeWords, binary);

  auto i32 = addIntType(32);
  auto structTy = addStructType({i32});
  std::swap(typeWords, binary);

  SmallVector<uint32_t, 5> nameWords = {structTy, 0};
  (void)spirv::encodeStringLiteralInto(nameWords, "int32");
  // One stray word past the end of the string literal.
  nameWords.push_back(42);
  addInstruction(spirv::Opcode::OpMemberName, nameWords);

  binary.append(typeWords.begin(), typeWords.end());
  ASSERT_FALSE(deserialize());
  expectDiagnostic("unexpected trailing words in OpMemberName instruction");
}
238 
239 //===----------------------------------------------------------------------===//
240 // Functions
241 //===----------------------------------------------------------------------===//
242 
// A function that is opened with OpFunction but never closed must be rejected.
TEST_F(DeserializationTest, FunctionMissingEndFailure) {
  addHeader();
  auto voidTy = addVoidType();
  auto signatureTy = addFunctionType(voidTy, {});
  addFunction(voidTy, signatureTy);
  // Deliberately no OpFunctionEnd.

  ASSERT_FALSE(deserialize());
  expectDiagnostic("expected OpFunctionEnd instruction");
}
253 
// A function whose type declares one parameter but whose OpFunction is not
// followed by OpFunctionParameter must be rejected.
TEST_F(DeserializationTest, FunctionMissingParameterFailure) {
  addHeader();
  auto voidTy = addVoidType();
  auto i32 = addIntType(32);
  auto signatureTy = addFunctionType(voidTy, {i32});
  addFunction(voidTy, signatureTy);
  // Deliberately no OpFunctionParameter.

  ASSERT_FALSE(deserialize());
  expectDiagnostic("expected OpFunctionParameter instruction");
}
265 
// A function body that begins with OpReturn rather than OpLabel must be
// rejected.
TEST_F(DeserializationTest, FunctionMissingLabelForFirstBlockFailure) {
  addHeader();
  auto voidTy = addVoidType();
  auto signatureTy = addFunctionType(voidTy, {});
  addFunction(voidTy, signatureTy);
  // Deliberately no OpLabel before the first instruction of the body.
  addReturn();
  addFunctionEnd();

  ASSERT_FALSE(deserialize());
  expectDiagnostic("a basic block must start with OpLabel");
}
278 
// An OpLabel emitted without any operands (no result <id>) must be rejected.
TEST_F(DeserializationTest, FunctionMalformedLabelFailure) {
  addHeader();
  auto voidTy = addVoidType();
  auto signatureTy = addFunctionType(voidTy, {});
  addFunction(voidTy, signatureTy);
  // Malformed OpLabel: it carries no operands at all.
  addInstruction(spirv::Opcode::OpLabel, {});
  addReturn();
  addFunctionEnd();

  ASSERT_FALSE(deserialize());
  expectDiagnostic("OpLabel should only have result <id>");
}
291