1 //===- DeserializationTest.cpp - SPIR-V Deserialization Tests -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // The purpose of this file is to provide negative deserialization tests. 10 // For positive deserialization tests, please use serialization and 11 // deserialization for roundtripping. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h" 16 #include "mlir/Dialect/SPIRV/SPIRVOps.h" 17 #include "mlir/Dialect/SPIRV/Serialization.h" 18 #include "mlir/IR/Diagnostics.h" 19 #include "mlir/IR/MLIRContext.h" 20 #include "gmock/gmock.h" 21 22 #include <memory> 23 24 using namespace mlir; 25 26 using ::testing::StrEq; 27 28 //===----------------------------------------------------------------------===// 29 // Test Fixture 30 //===----------------------------------------------------------------------===// 31 32 /// A deserialization test fixture providing minimal SPIR-V building and 33 /// diagnostic checking utilities. 34 class DeserializationTest : public ::testing::Test { 35 protected: 36 DeserializationTest() { 37 // Register a diagnostic handler to capture the diagnostic so that we can 38 // check it later. 39 context.getDiagEngine().registerHandler([&](Diagnostic &diag) { 40 diagnostic.reset(new Diagnostic(std::move(diag))); 41 }); 42 } 43 44 /// Performs deserialization and returns the constructed spv.module op. 45 Optional<spirv::ModuleOp> deserialize() { 46 return spirv::deserialize(binary, &context); 47 } 48 49 /// Checks there is a diagnostic generated with the given `errorMessage`. 50 void expectDiagnostic(StringRef errorMessage) { 51 ASSERT_NE(nullptr, diagnostic.get()); 52 53 // TODO(antiagainst): check error location too. 54 EXPECT_THAT(diagnostic->str(), StrEq(errorMessage)); 55 } 56 57 //===--------------------------------------------------------------------===// 58 // SPIR-V builder methods 59 //===--------------------------------------------------------------------===// 60 61 /// Adds the SPIR-V module header to `binary`. 62 void addHeader() { spirv::appendModuleHeader(binary, /*idBound=*/0); } 63 64 /// Adds the SPIR-V instruction into `binary`. 65 void addInstruction(spirv::Opcode op, ArrayRef<uint32_t> operands) { 66 uint32_t wordCount = 1 + operands.size(); 67 binary.push_back(spirv::getPrefixedOpcode(wordCount, op)); 68 binary.append(operands.begin(), operands.end()); 69 } 70 71 uint32_t addVoidType() { 72 auto id = nextID++; 73 addInstruction(spirv::Opcode::OpTypeVoid, {id}); 74 return id; 75 } 76 77 uint32_t addIntType(uint32_t bitwidth) { 78 auto id = nextID++; 79 addInstruction(spirv::Opcode::OpTypeInt, {id, bitwidth, /*signedness=*/1}); 80 return id; 81 } 82 83 uint32_t addStructType(ArrayRef<uint32_t> memberTypes) { 84 auto id = nextID++; 85 SmallVector<uint32_t, 2> words; 86 words.push_back(id); 87 words.append(memberTypes.begin(), memberTypes.end()); 88 addInstruction(spirv::Opcode::OpTypeStruct, words); 89 return id; 90 } 91 92 uint32_t addFunctionType(uint32_t retType, ArrayRef<uint32_t> paramTypes) { 93 auto id = nextID++; 94 SmallVector<uint32_t, 4> operands; 95 operands.push_back(id); 96 operands.push_back(retType); 97 operands.append(paramTypes.begin(), paramTypes.end()); 98 addInstruction(spirv::Opcode::OpTypeFunction, operands); 99 return id; 100 } 101 102 uint32_t addFunction(uint32_t retType, uint32_t fnType) { 103 auto id = nextID++; 104 addInstruction(spirv::Opcode::OpFunction, 105 {retType, id, 106 static_cast<uint32_t>(spirv::FunctionControl::None), 107 fnType}); 108 return id; 109 } 110 111 void addFunctionEnd() { addInstruction(spirv::Opcode::OpFunctionEnd, {}); } 112 113 void addReturn() { addInstruction(spirv::Opcode::OpReturn, {}); } 114 115 protected: 116 SmallVector<uint32_t, 5> binary; 117 uint32_t nextID = 1; 118 MLIRContext context; 119 std::unique_ptr<Diagnostic> diagnostic; 120 }; 121 122 //===----------------------------------------------------------------------===// 123 // Basics 124 //===----------------------------------------------------------------------===// 125 126 TEST_F(DeserializationTest, EmptyModuleFailure) { 127 ASSERT_EQ(llvm::None, deserialize()); 128 expectDiagnostic("SPIR-V binary module must have a 5-word header"); 129 } 130 131 TEST_F(DeserializationTest, WrongMagicNumberFailure) { 132 addHeader(); 133 binary.front() = 0xdeadbeef; // Change to a wrong magic number 134 ASSERT_EQ(llvm::None, deserialize()); 135 expectDiagnostic("incorrect magic number"); 136 } 137 138 TEST_F(DeserializationTest, OnlyHeaderSuccess) { 139 addHeader(); 140 EXPECT_NE(llvm::None, deserialize()); 141 } 142 143 TEST_F(DeserializationTest, ZeroWordCountFailure) { 144 addHeader(); 145 binary.push_back(0); // OpNop with zero word count 146 147 ASSERT_EQ(llvm::None, deserialize()); 148 expectDiagnostic("word count cannot be zero"); 149 } 150 151 TEST_F(DeserializationTest, InsufficientWordFailure) { 152 addHeader(); 153 binary.push_back((2u << 16) | 154 static_cast<uint32_t>(spirv::Opcode::OpTypeVoid)); 155 // Missing word for type <id> 156 157 ASSERT_EQ(llvm::None, deserialize()); 158 expectDiagnostic("insufficient words for the last instruction"); 159 } 160 161 //===----------------------------------------------------------------------===// 162 // Types 163 //===----------------------------------------------------------------------===// 164 165 TEST_F(DeserializationTest, IntTypeMissingSignednessFailure) { 166 addHeader(); 167 addInstruction(spirv::Opcode::OpTypeInt, {nextID++, 32}); 168 169 ASSERT_EQ(llvm::None, deserialize()); 170 expectDiagnostic("OpTypeInt must have bitwidth and signedness parameters"); 171 } 172 173 //===----------------------------------------------------------------------===// 174 // StructType 175 //===----------------------------------------------------------------------===// 176 177 TEST_F(DeserializationTest, OpMemberNameSuccess) { 178 addHeader(); 179 SmallVector<uint32_t, 5> typeDecl; 180 std::swap(typeDecl, binary); 181 182 auto int32Type = addIntType(32); 183 auto structType = addStructType({int32Type, int32Type}); 184 std::swap(typeDecl, binary); 185 186 SmallVector<uint32_t, 5> operands1 = {structType, 0}; 187 spirv::encodeStringLiteralInto(operands1, "i1"); 188 addInstruction(spirv::Opcode::OpMemberName, operands1); 189 190 SmallVector<uint32_t, 5> operands2 = {structType, 1}; 191 spirv::encodeStringLiteralInto(operands2, "i2"); 192 addInstruction(spirv::Opcode::OpMemberName, operands2); 193 194 binary.append(typeDecl.begin(), typeDecl.end()); 195 EXPECT_NE(llvm::None, deserialize()); 196 } 197 198 TEST_F(DeserializationTest, OpMemberNameMissingOperands) { 199 addHeader(); 200 SmallVector<uint32_t, 5> typeDecl; 201 std::swap(typeDecl, binary); 202 203 auto int32Type = addIntType(32); 204 auto int64Type = addIntType(64); 205 auto structType = addStructType({int32Type, int64Type}); 206 std::swap(typeDecl, binary); 207 208 SmallVector<uint32_t, 5> operands1 = {structType}; 209 addInstruction(spirv::Opcode::OpMemberName, operands1); 210 211 binary.append(typeDecl.begin(), typeDecl.end()); 212 ASSERT_EQ(llvm::None, deserialize()); 213 expectDiagnostic("OpMemberName must have at least 3 operands"); 214 } 215 216 TEST_F(DeserializationTest, OpMemberNameExcessOperands) { 217 addHeader(); 218 SmallVector<uint32_t, 5> typeDecl; 219 std::swap(typeDecl, binary); 220 221 auto int32Type = addIntType(32); 222 auto structType = addStructType({int32Type}); 223 std::swap(typeDecl, binary); 224 225 SmallVector<uint32_t, 5> operands = {structType, 0}; 226 spirv::encodeStringLiteralInto(operands, "int32"); 227 operands.push_back(42); 228 addInstruction(spirv::Opcode::OpMemberName, operands); 229 230 binary.append(typeDecl.begin(), typeDecl.end()); 231 ASSERT_EQ(llvm::None, deserialize()); 232 expectDiagnostic("unexpected trailing words in OpMemberName instruction"); 233 } 234 235 //===----------------------------------------------------------------------===// 236 // Functions 237 //===----------------------------------------------------------------------===// 238 239 TEST_F(DeserializationTest, FunctionMissingEndFailure) { 240 addHeader(); 241 auto voidType = addVoidType(); 242 auto fnType = addFunctionType(voidType, {}); 243 addFunction(voidType, fnType); 244 // Missing OpFunctionEnd 245 246 ASSERT_EQ(llvm::None, deserialize()); 247 expectDiagnostic("expected OpFunctionEnd instruction"); 248 } 249 250 TEST_F(DeserializationTest, FunctionMissingParameterFailure) { 251 addHeader(); 252 auto voidType = addVoidType(); 253 auto i32Type = addIntType(32); 254 auto fnType = addFunctionType(voidType, {i32Type}); 255 addFunction(voidType, fnType); 256 // Missing OpFunctionParameter 257 258 ASSERT_EQ(llvm::None, deserialize()); 259 expectDiagnostic("expected OpFunctionParameter instruction"); 260 } 261 262 TEST_F(DeserializationTest, FunctionMissingLabelForFirstBlockFailure) { 263 addHeader(); 264 auto voidType = addVoidType(); 265 auto fnType = addFunctionType(voidType, {}); 266 addFunction(voidType, fnType); 267 // Missing OpLabel 268 addReturn(); 269 addFunctionEnd(); 270 271 ASSERT_EQ(llvm::None, deserialize()); 272 expectDiagnostic("a basic block must start with OpLabel"); 273 } 274 275 TEST_F(DeserializationTest, FunctionMalformedLabelFailure) { 276 addHeader(); 277 auto voidType = addVoidType(); 278 auto fnType = addFunctionType(voidType, {}); 279 addFunction(voidType, fnType); 280 addInstruction(spirv::Opcode::OpLabel, {}); // Malformed OpLabel 281 addReturn(); 282 addFunctionEnd(); 283 284 ASSERT_EQ(llvm::None, deserialize()); 285 expectDiagnostic("OpLabel should only have result <id>"); 286 } 287