xref: /llvm-project/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp (revision 4a55bd5f28e64a0c134adfbbcc20e3ea3af937c6)
1 //===- DeserializationTest.cpp - SPIR-V Deserialization Tests -------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // The purpose of this file is to provide negative deserialization tests.
19 // For positive deserialization tests, please use serialization and
20 // deserialization for roundtripping.
21 //
22 //===----------------------------------------------------------------------===//
23 
24 #include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
25 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
26 #include "mlir/Dialect/SPIRV/Serialization.h"
27 #include "mlir/IR/Diagnostics.h"
28 #include "mlir/IR/MLIRContext.h"
29 #include "gmock/gmock.h"
30 
31 #include <memory>
32 
33 using namespace mlir;
34 
35 using ::testing::StrEq;
36 
37 //===----------------------------------------------------------------------===//
38 // Test Fixture
39 //===----------------------------------------------------------------------===//
40 
41 /// A deserialization test fixture providing minimal SPIR-V building and
42 /// diagnostic checking utilities.
43 class DeserializationTest : public ::testing::Test {
44 protected:
45   DeserializationTest() {
46     // Register a diagnostic handler to capture the diagnostic so that we can
47     // check it later.
48     context.getDiagEngine().setHandler([&](Diagnostic diag) {
49       diagnostic.reset(new Diagnostic(std::move(diag)));
50     });
51   }
52 
53   /// Performs deserialization and returns the constructed spv.module op.
54   Optional<spirv::ModuleOp> deserialize() {
55     return spirv::deserialize(binary, &context);
56   }
57 
58   /// Checks there is a diagnostic generated with the given `errorMessage`.
59   void expectDiagnostic(StringRef errorMessage) {
60     ASSERT_NE(nullptr, diagnostic.get());
61 
62     // TODO(antiagainst): check error location too.
63     EXPECT_THAT(diagnostic->str(), StrEq(errorMessage));
64   }
65 
66   //===--------------------------------------------------------------------===//
67   // SPIR-V builder methods
68   //===--------------------------------------------------------------------===//
69 
70   /// Adds the SPIR-V module header to `binary`.
71   void addHeader() { spirv::appendModuleHeader(binary, /*idBound=*/0); }
72 
73   /// Adds the SPIR-V instruction into `binary`.
74   void addInstruction(spirv::Opcode op, ArrayRef<uint32_t> operands) {
75     uint32_t wordCount = 1 + operands.size();
76     assert(((wordCount >> 16) == 0) && "word count out of range!");
77 
78     uint32_t prefixedOpcode = (wordCount << 16) | static_cast<uint32_t>(op);
79     binary.push_back(prefixedOpcode);
80     binary.append(operands.begin(), operands.end());
81   }
82 
83   uint32_t addVoidType() {
84     auto id = nextID++;
85     addInstruction(spirv::Opcode::OpTypeVoid, {id});
86     return id;
87   }
88 
89   uint32_t addIntType(uint32_t bitwidth) {
90     auto id = nextID++;
91     addInstruction(spirv::Opcode::OpTypeInt, {id, bitwidth, /*signedness=*/1});
92     return id;
93   }
94 
95   uint32_t addFunctionType(uint32_t retType, ArrayRef<uint32_t> paramTypes) {
96     auto id = nextID++;
97     SmallVector<uint32_t, 4> operands;
98     operands.push_back(id);
99     operands.push_back(retType);
100     operands.append(paramTypes.begin(), paramTypes.end());
101     addInstruction(spirv::Opcode::OpTypeFunction, operands);
102     return id;
103   }
104 
105   uint32_t addFunction(uint32_t retType, uint32_t fnType) {
106     auto id = nextID++;
107     addInstruction(spirv::Opcode::OpFunction,
108                    {retType, id,
109                     static_cast<uint32_t>(spirv::FunctionControl::None),
110                     fnType});
111     return id;
112   }
113 
114   uint32_t addFunctionEnd() {
115     auto id = nextID++;
116     addInstruction(spirv::Opcode::OpFunctionEnd, {id});
117     return id;
118   }
119 
120 protected:
121   SmallVector<uint32_t, 5> binary;
122   uint32_t nextID = 1;
123   MLIRContext context;
124   std::unique_ptr<Diagnostic> diagnostic;
125 };
126 
127 //===----------------------------------------------------------------------===//
128 // Basics
129 //===----------------------------------------------------------------------===//
130 
131 TEST_F(DeserializationTest, EmptyModuleFailure) {
132   ASSERT_EQ(llvm::None, deserialize());
133   expectDiagnostic("SPIR-V binary module must have a 5-word header");
134 }
135 
136 TEST_F(DeserializationTest, WrongMagicNumberFailure) {
137   addHeader();
138   binary.front() = 0xdeadbeef; // Change to a wrong magic number
139   ASSERT_EQ(llvm::None, deserialize());
140   expectDiagnostic("incorrect magic number");
141 }
142 
143 TEST_F(DeserializationTest, OnlyHeaderSuccess) {
144   addHeader();
145   EXPECT_NE(llvm::None, deserialize());
146 }
147 
148 TEST_F(DeserializationTest, ZeroWordCountFailure) {
149   addHeader();
150   binary.push_back(0); // OpNop with zero word count
151 
152   ASSERT_EQ(llvm::None, deserialize());
153   expectDiagnostic("word count cannot be zero");
154 }
155 
156 TEST_F(DeserializationTest, InsufficientWordFailure) {
157   addHeader();
158   binary.push_back((2u << 16) |
159                    static_cast<uint32_t>(spirv::Opcode::OpTypeVoid));
160   // Missing word for type <id>
161 
162   ASSERT_EQ(llvm::None, deserialize());
163   expectDiagnostic("insufficient words for the last instruction");
164 }
165 
166 //===----------------------------------------------------------------------===//
167 // Types
168 //===----------------------------------------------------------------------===//
169 
170 TEST_F(DeserializationTest, IntTypeMissingSignednessFailure) {
171   addHeader();
172   addInstruction(spirv::Opcode::OpTypeInt, {nextID++, 32});
173 
174   ASSERT_EQ(llvm::None, deserialize());
175   expectDiagnostic("OpTypeInt must have bitwidth and signedness parameters");
176 }
177 
178 //===----------------------------------------------------------------------===//
179 // Functions
180 //===----------------------------------------------------------------------===//
181 
182 TEST_F(DeserializationTest, FunctionMissingEndFailure) {
183   addHeader();
184   auto voidType = addVoidType();
185   auto fnType = addFunctionType(voidType, {});
186   addFunction(voidType, fnType);
187   // Missing OpFunctionEnd
188 
189   ASSERT_EQ(llvm::None, deserialize());
190   expectDiagnostic("expected OpFunctionEnd instruction");
191 }
192 
193 TEST_F(DeserializationTest, FunctionMissingParameterFailure) {
194   addHeader();
195   auto voidType = addVoidType();
196   auto i32Type = addIntType(32);
197   auto fnType = addFunctionType(voidType, {i32Type});
198   addFunction(voidType, fnType);
199   // Missing OpFunctionParameter
200 
201   ASSERT_EQ(llvm::None, deserialize());
202   expectDiagnostic("expected OpFunctionParameter instruction");
203 }
204