xref: /llvm-project/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp (revision c61991ef01c34aa2d09fe6d16aead943b7fba2fa)
1 //===- DeserializationTest.cpp - SPIR-V Deserialization Tests -------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // The purpose of this file is to provide negative deserialization tests.
19 // For positive deserialization tests, please use serialization and
20 // deserialization for roundtripping.
21 //
22 //===----------------------------------------------------------------------===//
23 
24 #include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
25 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
26 #include "mlir/Dialect/SPIRV/Serialization.h"
27 #include "mlir/IR/Diagnostics.h"
28 #include "mlir/IR/MLIRContext.h"
29 #include "gmock/gmock.h"
30 
31 #include <memory>
32 
33 using namespace mlir;
34 
35 using ::testing::StrEq;
36 
37 //===----------------------------------------------------------------------===//
38 // Test Fixture
39 //===----------------------------------------------------------------------===//
40 
41 /// A deserialization test fixture providing minimal SPIR-V building and
42 /// diagnostic checking utilities.
43 class DeserializationTest : public ::testing::Test {
44 protected:
45   DeserializationTest() {
46     // Register a diagnostic handler to capture the diagnostic so that we can
47     // check it later.
48     context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
49       diagnostic.reset(new Diagnostic(std::move(diag)));
50     });
51   }
52 
53   /// Performs deserialization and returns the constructed spv.module op.
54   Optional<spirv::ModuleOp> deserialize() {
55     return spirv::deserialize(binary, &context);
56   }
57 
58   /// Checks there is a diagnostic generated with the given `errorMessage`.
59   void expectDiagnostic(StringRef errorMessage) {
60     ASSERT_NE(nullptr, diagnostic.get());
61 
62     // TODO(antiagainst): check error location too.
63     EXPECT_THAT(diagnostic->str(), StrEq(errorMessage));
64   }
65 
66   //===--------------------------------------------------------------------===//
67   // SPIR-V builder methods
68   //===--------------------------------------------------------------------===//
69 
70   /// Adds the SPIR-V module header to `binary`.
71   void addHeader() { spirv::appendModuleHeader(binary, /*idBound=*/0); }
72 
73   /// Adds the SPIR-V instruction into `binary`.
74   void addInstruction(spirv::Opcode op, ArrayRef<uint32_t> operands) {
75     uint32_t wordCount = 1 + operands.size();
76     assert(((wordCount >> 16) == 0) && "word count out of range!");
77 
78     uint32_t prefixedOpcode = (wordCount << 16) | static_cast<uint32_t>(op);
79     binary.push_back(prefixedOpcode);
80     binary.append(operands.begin(), operands.end());
81   }
82 
83   uint32_t addVoidType() {
84     auto id = nextID++;
85     addInstruction(spirv::Opcode::OpTypeVoid, {id});
86     return id;
87   }
88 
89   uint32_t addIntType(uint32_t bitwidth) {
90     auto id = nextID++;
91     addInstruction(spirv::Opcode::OpTypeInt, {id, bitwidth, /*signedness=*/1});
92     return id;
93   }
94 
95   uint32_t addFunctionType(uint32_t retType, ArrayRef<uint32_t> paramTypes) {
96     auto id = nextID++;
97     SmallVector<uint32_t, 4> operands;
98     operands.push_back(id);
99     operands.push_back(retType);
100     operands.append(paramTypes.begin(), paramTypes.end());
101     addInstruction(spirv::Opcode::OpTypeFunction, operands);
102     return id;
103   }
104 
105   uint32_t addFunction(uint32_t retType, uint32_t fnType) {
106     auto id = nextID++;
107     addInstruction(spirv::Opcode::OpFunction,
108                    {retType, id,
109                     static_cast<uint32_t>(spirv::FunctionControl::None),
110                     fnType});
111     return id;
112   }
113 
114   void addFunctionEnd() { addInstruction(spirv::Opcode::OpFunctionEnd, {}); }
115 
116   void addReturn() { addInstruction(spirv::Opcode::OpReturn, {}); }
117 
118 protected:
119   SmallVector<uint32_t, 5> binary;
120   uint32_t nextID = 1;
121   MLIRContext context;
122   std::unique_ptr<Diagnostic> diagnostic;
123 };
124 
125 //===----------------------------------------------------------------------===//
126 // Basics
127 //===----------------------------------------------------------------------===//
128 
129 TEST_F(DeserializationTest, EmptyModuleFailure) {
130   ASSERT_EQ(llvm::None, deserialize());
131   expectDiagnostic("SPIR-V binary module must have a 5-word header");
132 }
133 
134 TEST_F(DeserializationTest, WrongMagicNumberFailure) {
135   addHeader();
136   binary.front() = 0xdeadbeef; // Change to a wrong magic number
137   ASSERT_EQ(llvm::None, deserialize());
138   expectDiagnostic("incorrect magic number");
139 }
140 
141 TEST_F(DeserializationTest, OnlyHeaderSuccess) {
142   addHeader();
143   EXPECT_NE(llvm::None, deserialize());
144 }
145 
146 TEST_F(DeserializationTest, ZeroWordCountFailure) {
147   addHeader();
148   binary.push_back(0); // OpNop with zero word count
149 
150   ASSERT_EQ(llvm::None, deserialize());
151   expectDiagnostic("word count cannot be zero");
152 }
153 
154 TEST_F(DeserializationTest, InsufficientWordFailure) {
155   addHeader();
156   binary.push_back((2u << 16) |
157                    static_cast<uint32_t>(spirv::Opcode::OpTypeVoid));
158   // Missing word for type <id>
159 
160   ASSERT_EQ(llvm::None, deserialize());
161   expectDiagnostic("insufficient words for the last instruction");
162 }
163 
164 //===----------------------------------------------------------------------===//
165 // Types
166 //===----------------------------------------------------------------------===//
167 
168 TEST_F(DeserializationTest, IntTypeMissingSignednessFailure) {
169   addHeader();
170   addInstruction(spirv::Opcode::OpTypeInt, {nextID++, 32});
171 
172   ASSERT_EQ(llvm::None, deserialize());
173   expectDiagnostic("OpTypeInt must have bitwidth and signedness parameters");
174 }
175 
176 //===----------------------------------------------------------------------===//
177 // Functions
178 //===----------------------------------------------------------------------===//
179 
180 TEST_F(DeserializationTest, FunctionMissingEndFailure) {
181   addHeader();
182   auto voidType = addVoidType();
183   auto fnType = addFunctionType(voidType, {});
184   addFunction(voidType, fnType);
185   // Missing OpFunctionEnd
186 
187   ASSERT_EQ(llvm::None, deserialize());
188   expectDiagnostic("expected OpFunctionEnd instruction");
189 }
190 
191 TEST_F(DeserializationTest, FunctionMissingParameterFailure) {
192   addHeader();
193   auto voidType = addVoidType();
194   auto i32Type = addIntType(32);
195   auto fnType = addFunctionType(voidType, {i32Type});
196   addFunction(voidType, fnType);
197   // Missing OpFunctionParameter
198 
199   ASSERT_EQ(llvm::None, deserialize());
200   expectDiagnostic("expected OpFunctionParameter instruction");
201 }
202 
203 TEST_F(DeserializationTest, FunctionMissingLabelForFirstBlockFailure) {
204   addHeader();
205   auto voidType = addVoidType();
206   auto fnType = addFunctionType(voidType, {});
207   addFunction(voidType, fnType);
208   // Missing OpLabel
209   addReturn();
210   addFunctionEnd();
211 
212   ASSERT_EQ(llvm::None, deserialize());
213   expectDiagnostic("a basic block must start with OpLabel");
214 }
215 
216 TEST_F(DeserializationTest, FunctionMalformedLabelFailure) {
217   addHeader();
218   auto voidType = addVoidType();
219   auto fnType = addFunctionType(voidType, {});
220   addFunction(voidType, fnType);
221   addInstruction(spirv::Opcode::OpLabel, {}); // Malformed OpLabel
222   addReturn();
223   addFunctionEnd();
224 
225   ASSERT_EQ(llvm::None, deserialize());
226   expectDiagnostic("OpLabel should only have result <id>");
227 }
228