1 //===- DeserializationTest.cpp - SPIR-V Deserialization Tests -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // The purpose of this file is to provide negative deserialization tests. 10 // For positive deserialization tests, please use serialization and 11 // deserialization for roundtripping. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h" 16 #include "mlir/Dialect/SPIRV/SPIRVDialect.h" 17 #include "mlir/Dialect/SPIRV/SPIRVOps.h" 18 #include "mlir/Dialect/SPIRV/Serialization.h" 19 #include "mlir/IR/Diagnostics.h" 20 #include "mlir/IR/MLIRContext.h" 21 #include "gmock/gmock.h" 22 23 #include <memory> 24 25 using namespace mlir; 26 27 // Load the SPIRV dialect 28 static DialectRegistration<spirv::SPIRVDialect> SPIRVRegistration; 29 30 using ::testing::StrEq; 31 32 //===----------------------------------------------------------------------===// 33 // Test Fixture 34 //===----------------------------------------------------------------------===// 35 36 /// A deserialization test fixture providing minimal SPIR-V building and 37 /// diagnostic checking utilities. 38 class DeserializationTest : public ::testing::Test { 39 protected: 40 DeserializationTest() { 41 // Register a diagnostic handler to capture the diagnostic so that we can 42 // check it later. 43 context.getDiagEngine().registerHandler([&](Diagnostic &diag) { 44 diagnostic.reset(new Diagnostic(std::move(diag))); 45 }); 46 } 47 48 /// Performs deserialization and returns the constructed spv.module op. 
49 Optional<spirv::ModuleOp> deserialize() { 50 return spirv::deserialize(binary, &context); 51 } 52 53 /// Checks there is a diagnostic generated with the given `errorMessage`. 54 void expectDiagnostic(StringRef errorMessage) { 55 ASSERT_NE(nullptr, diagnostic.get()); 56 57 // TODO: check error location too. 58 EXPECT_THAT(diagnostic->str(), StrEq(std::string(errorMessage))); 59 } 60 61 //===--------------------------------------------------------------------===// 62 // SPIR-V builder methods 63 //===--------------------------------------------------------------------===// 64 65 /// Adds the SPIR-V module header to `binary`. 66 void addHeader() { 67 spirv::appendModuleHeader(binary, spirv::Version::V_1_0, /*idBound=*/0); 68 } 69 70 /// Adds the SPIR-V instruction into `binary`. 71 void addInstruction(spirv::Opcode op, ArrayRef<uint32_t> operands) { 72 uint32_t wordCount = 1 + operands.size(); 73 binary.push_back(spirv::getPrefixedOpcode(wordCount, op)); 74 binary.append(operands.begin(), operands.end()); 75 } 76 77 uint32_t addVoidType() { 78 auto id = nextID++; 79 addInstruction(spirv::Opcode::OpTypeVoid, {id}); 80 return id; 81 } 82 83 uint32_t addIntType(uint32_t bitwidth) { 84 auto id = nextID++; 85 addInstruction(spirv::Opcode::OpTypeInt, {id, bitwidth, /*signedness=*/1}); 86 return id; 87 } 88 89 uint32_t addStructType(ArrayRef<uint32_t> memberTypes) { 90 auto id = nextID++; 91 SmallVector<uint32_t, 2> words; 92 words.push_back(id); 93 words.append(memberTypes.begin(), memberTypes.end()); 94 addInstruction(spirv::Opcode::OpTypeStruct, words); 95 return id; 96 } 97 98 uint32_t addFunctionType(uint32_t retType, ArrayRef<uint32_t> paramTypes) { 99 auto id = nextID++; 100 SmallVector<uint32_t, 4> operands; 101 operands.push_back(id); 102 operands.push_back(retType); 103 operands.append(paramTypes.begin(), paramTypes.end()); 104 addInstruction(spirv::Opcode::OpTypeFunction, operands); 105 return id; 106 } 107 108 uint32_t addFunction(uint32_t retType, uint32_t 
fnType) { 109 auto id = nextID++; 110 addInstruction(spirv::Opcode::OpFunction, 111 {retType, id, 112 static_cast<uint32_t>(spirv::FunctionControl::None), 113 fnType}); 114 return id; 115 } 116 117 void addFunctionEnd() { addInstruction(spirv::Opcode::OpFunctionEnd, {}); } 118 119 void addReturn() { addInstruction(spirv::Opcode::OpReturn, {}); } 120 121 protected: 122 SmallVector<uint32_t, 5> binary; 123 uint32_t nextID = 1; 124 MLIRContext context; 125 std::unique_ptr<Diagnostic> diagnostic; 126 }; 127 128 //===----------------------------------------------------------------------===// 129 // Basics 130 //===----------------------------------------------------------------------===// 131 132 TEST_F(DeserializationTest, EmptyModuleFailure) { 133 ASSERT_EQ(llvm::None, deserialize()); 134 expectDiagnostic("SPIR-V binary module must have a 5-word header"); 135 } 136 137 TEST_F(DeserializationTest, WrongMagicNumberFailure) { 138 addHeader(); 139 binary.front() = 0xdeadbeef; // Change to a wrong magic number 140 ASSERT_EQ(llvm::None, deserialize()); 141 expectDiagnostic("incorrect magic number"); 142 } 143 144 TEST_F(DeserializationTest, OnlyHeaderSuccess) { 145 addHeader(); 146 EXPECT_NE(llvm::None, deserialize()); 147 } 148 149 TEST_F(DeserializationTest, ZeroWordCountFailure) { 150 addHeader(); 151 binary.push_back(0); // OpNop with zero word count 152 153 ASSERT_EQ(llvm::None, deserialize()); 154 expectDiagnostic("word count cannot be zero"); 155 } 156 157 TEST_F(DeserializationTest, InsufficientWordFailure) { 158 addHeader(); 159 binary.push_back((2u << 16) | 160 static_cast<uint32_t>(spirv::Opcode::OpTypeVoid)); 161 // Missing word for type <id> 162 163 ASSERT_EQ(llvm::None, deserialize()); 164 expectDiagnostic("insufficient words for the last instruction"); 165 } 166 167 //===----------------------------------------------------------------------===// 168 // Types 169 //===----------------------------------------------------------------------===// 170 171 
TEST_F(DeserializationTest, IntTypeMissingSignednessFailure) {
  addHeader();
  // OpTypeInt carrying a result <id> and a bitwidth, but no signedness word.
  uint32_t intTypeID = nextID++;
  addInstruction(spirv::Opcode::OpTypeInt, {intTypeID, 32});

  ASSERT_EQ(llvm::None, deserialize());
  expectDiagnostic("OpTypeInt must have bitwidth and signedness parameters");
}

//===----------------------------------------------------------------------===//
// StructType
//===----------------------------------------------------------------------===//

// The OpMemberName tests below need the struct's <id> before they can emit
// OpMemberName, yet SPIR-V's logical layout places debug instructions ahead of
// type declarations. So the types are emitted first, sliced off the end of
// `binary`, and re-appended after the OpMemberName instructions.

TEST_F(DeserializationTest, OpMemberNameSuccess) {
  addHeader();
  auto headerSize = binary.size();

  uint32_t int32Type = addIntType(32);
  uint32_t structType = addStructType({int32Type, int32Type});
  SmallVector<uint32_t, 5> typeDecl(binary.begin() + headerSize, binary.end());
  binary.resize(headerSize);

  // Emits `OpMemberName %structType <index> "<name>"`.
  auto addMemberName = [&](uint32_t index, StringRef name) {
    SmallVector<uint32_t, 5> memberNameOperands = {structType, index};
    spirv::encodeStringLiteralInto(memberNameOperands, name);
    addInstruction(spirv::Opcode::OpMemberName, memberNameOperands);
  };
  addMemberName(0, "i1");
  addMemberName(1, "i2");

  binary.append(typeDecl.begin(), typeDecl.end());
  EXPECT_NE(llvm::None, deserialize());
}

TEST_F(DeserializationTest, OpMemberNameMissingOperands) {
  addHeader();
  auto headerSize = binary.size();

  uint32_t int32Type = addIntType(32);
  uint32_t int64Type = addIntType(64);
  uint32_t structType = addStructType({int32Type, int64Type});
  SmallVector<uint32_t, 5> typeDecl(binary.begin() + headerSize, binary.end());
  binary.resize(headerSize);

  // Only the struct <id>: the member index and name literal are missing.
  addInstruction(spirv::Opcode::OpMemberName, {structType});

  binary.append(typeDecl.begin(), typeDecl.end());
  ASSERT_EQ(llvm::None, deserialize());
  expectDiagnostic("OpMemberName must have at least 3 operands");
}

TEST_F(DeserializationTest, OpMemberNameExcessOperands) {
  addHeader();
  auto headerSize = binary.size();

  uint32_t int32Type = addIntType(32);
  uint32_t structType = addStructType({int32Type});
  SmallVector<uint32_t, 5> typeDecl(binary.begin() + headerSize, binary.end());
  binary.resize(headerSize);

  SmallVector<uint32_t, 5> memberNameOperands = {structType, 0};
  spirv::encodeStringLiteralInto(memberNameOperands, "int32");
  // One stray word after the name literal.
  memberNameOperands.push_back(42);
  addInstruction(spirv::Opcode::OpMemberName, memberNameOperands);

  binary.append(typeDecl.begin(), typeDecl.end());
  ASSERT_EQ(llvm::None, deserialize());
  expectDiagnostic("unexpected trailing words in OpMemberName instruction");
}

//===----------------------------------------------------------------------===//
// Functions
//===----------------------------------------------------------------------===//

TEST_F(DeserializationTest, FunctionMissingEndFailure) {
  addHeader();
  uint32_t voidTy = addVoidType();
  // The function type is emitted (while evaluating the argument) before
  // OpFunction itself.
  addFunction(voidTy, addFunctionType(voidTy, {}));
  // No OpFunctionEnd follows.

  ASSERT_EQ(llvm::None, deserialize());
  expectDiagnostic("expected OpFunctionEnd instruction");
}

TEST_F(DeserializationTest, FunctionMissingParameterFailure) {
  addHeader();
  uint32_t voidTy = addVoidType();
  uint32_t i32Ty = addIntType(32);
  // The function type declares one i32 parameter ...
  addFunction(voidTy, addFunctionType(voidTy, {i32Ty}));
  // ... but no OpFunctionParameter follows.

  ASSERT_EQ(llvm::None, deserialize());
  expectDiagnostic("expected OpFunctionParameter instruction");
}

TEST_F(DeserializationTest, FunctionMissingLabelForFirstBlockFailure) {
  addHeader();
  uint32_t voidTy = addVoidType();
  uint32_t fnTy = addFunctionType(voidTy, {});
  addFunction(voidTy, fnTy);
  // The body starts directly with OpReturn; there is no OpLabel before it.
  addReturn();
  addFunctionEnd();

  ASSERT_EQ(llvm::None, deserialize());
  expectDiagnostic("a basic block must start with OpLabel");
}

TEST_F(DeserializationTest, FunctionMalformedLabelFailure) {
  addHeader();
  uint32_t voidTy = addVoidType();
  uint32_t fnTy = addFunctionType(voidTy, {});
  addFunction(voidTy, fnTy);
  // OpLabel without its result <id> operand.
  addInstruction(spirv::Opcode::OpLabel, {});
  addReturn();
  addFunctionEnd();

  ASSERT_EQ(llvm::None, deserialize());
  expectDiagnostic("OpLabel should only have result <id>");
}