1 //===- DeserializationTest.cpp - SPIR-V Deserialization Tests -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // The purpose of this file is to provide negative deserialization tests. 10 // For positive deserialization tests, please use serialization and 11 // deserialization for roundtripping. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h" 16 #include "mlir/Dialect/SPIRV/SPIRVDialect.h" 17 #include "mlir/Dialect/SPIRV/SPIRVModule.h" 18 #include "mlir/Dialect/SPIRV/SPIRVOps.h" 19 #include "mlir/Dialect/SPIRV/Serialization.h" 20 #include "mlir/IR/Diagnostics.h" 21 #include "mlir/IR/MLIRContext.h" 22 #include "gmock/gmock.h" 23 24 #include <memory> 25 26 using namespace mlir; 27 28 /// Load the SPIRV dialect. 29 static DialectRegistration<spirv::SPIRVDialect> SPIRVRegistration; 30 31 using ::testing::StrEq; 32 33 //===----------------------------------------------------------------------===// 34 // Test Fixture 35 //===----------------------------------------------------------------------===// 36 37 /// A deserialization test fixture providing minimal SPIR-V building and 38 /// diagnostic checking utilities. 39 class DeserializationTest : public ::testing::Test { 40 protected: 41 DeserializationTest() { 42 // Register a diagnostic handler to capture the diagnostic so that we can 43 // check it later. 44 context.getDiagEngine().registerHandler([&](Diagnostic &diag) { 45 diagnostic.reset(new Diagnostic(std::move(diag))); 46 }); 47 } 48 49 /// Performs deserialization and returns the constructed spv.module op. 50 spirv::OwningSPIRVModuleRef deserialize() { 51 return spirv::deserialize(binary, &context); 52 } 53 54 /// Checks there is a diagnostic generated with the given `errorMessage`. 55 void expectDiagnostic(StringRef errorMessage) { 56 ASSERT_NE(nullptr, diagnostic.get()); 57 58 // TODO: check error location too. 59 EXPECT_THAT(diagnostic->str(), StrEq(std::string(errorMessage))); 60 } 61 62 //===--------------------------------------------------------------------===// 63 // SPIR-V builder methods 64 //===--------------------------------------------------------------------===// 65 66 /// Adds the SPIR-V module header to `binary`. 67 void addHeader() { 68 spirv::appendModuleHeader(binary, spirv::Version::V_1_0, /*idBound=*/0); 69 } 70 71 /// Adds the SPIR-V instruction into `binary`. 72 void addInstruction(spirv::Opcode op, ArrayRef<uint32_t> operands) { 73 uint32_t wordCount = 1 + operands.size(); 74 binary.push_back(spirv::getPrefixedOpcode(wordCount, op)); 75 binary.append(operands.begin(), operands.end()); 76 } 77 78 uint32_t addVoidType() { 79 auto id = nextID++; 80 addInstruction(spirv::Opcode::OpTypeVoid, {id}); 81 return id; 82 } 83 84 uint32_t addIntType(uint32_t bitwidth) { 85 auto id = nextID++; 86 addInstruction(spirv::Opcode::OpTypeInt, {id, bitwidth, /*signedness=*/1}); 87 return id; 88 } 89 90 uint32_t addStructType(ArrayRef<uint32_t> memberTypes) { 91 auto id = nextID++; 92 SmallVector<uint32_t, 2> words; 93 words.push_back(id); 94 words.append(memberTypes.begin(), memberTypes.end()); 95 addInstruction(spirv::Opcode::OpTypeStruct, words); 96 return id; 97 } 98 99 uint32_t addFunctionType(uint32_t retType, ArrayRef<uint32_t> paramTypes) { 100 auto id = nextID++; 101 SmallVector<uint32_t, 4> operands; 102 operands.push_back(id); 103 operands.push_back(retType); 104 operands.append(paramTypes.begin(), paramTypes.end()); 105 addInstruction(spirv::Opcode::OpTypeFunction, operands); 106 return id; 107 } 108 109 uint32_t addFunction(uint32_t retType, uint32_t fnType) { 110 auto id = nextID++; 111 addInstruction(spirv::Opcode::OpFunction, 112 {retType, id, 113 static_cast<uint32_t>(spirv::FunctionControl::None), 114 fnType}); 115 return id; 116 } 117 118 void addFunctionEnd() { addInstruction(spirv::Opcode::OpFunctionEnd, {}); } 119 120 void addReturn() { addInstruction(spirv::Opcode::OpReturn, {}); } 121 122 protected: 123 SmallVector<uint32_t, 5> binary; 124 uint32_t nextID = 1; 125 MLIRContext context; 126 std::unique_ptr<Diagnostic> diagnostic; 127 }; 128 129 //===----------------------------------------------------------------------===// 130 // Basics 131 //===----------------------------------------------------------------------===// 132 133 TEST_F(DeserializationTest, EmptyModuleFailure) { 134 ASSERT_FALSE(deserialize()); 135 expectDiagnostic("SPIR-V binary module must have a 5-word header"); 136 } 137 138 TEST_F(DeserializationTest, WrongMagicNumberFailure) { 139 addHeader(); 140 binary.front() = 0xdeadbeef; // Change to a wrong magic number 141 ASSERT_FALSE(deserialize()); 142 expectDiagnostic("incorrect magic number"); 143 } 144 145 TEST_F(DeserializationTest, OnlyHeaderSuccess) { 146 addHeader(); 147 EXPECT_TRUE(deserialize()); 148 } 149 150 TEST_F(DeserializationTest, ZeroWordCountFailure) { 151 addHeader(); 152 binary.push_back(0); // OpNop with zero word count 153 154 ASSERT_FALSE(deserialize()); 155 expectDiagnostic("word count cannot be zero"); 156 } 157 158 TEST_F(DeserializationTest, InsufficientWordFailure) { 159 addHeader(); 160 binary.push_back((2u << 16) | 161 static_cast<uint32_t>(spirv::Opcode::OpTypeVoid)); 162 // Missing word for type <id>. 163 164 ASSERT_FALSE(deserialize()); 165 expectDiagnostic("insufficient words for the last instruction"); 166 } 167 168 //===----------------------------------------------------------------------===// 169 // Types 170 //===----------------------------------------------------------------------===// 171 172 TEST_F(DeserializationTest, IntTypeMissingSignednessFailure) { 173 addHeader(); 174 addInstruction(spirv::Opcode::OpTypeInt, {nextID++, 32}); 175 176 ASSERT_FALSE(deserialize()); 177 expectDiagnostic("OpTypeInt must have bitwidth and signedness parameters"); 178 } 179 180 //===----------------------------------------------------------------------===// 181 // StructType 182 //===----------------------------------------------------------------------===// 183 184 TEST_F(DeserializationTest, OpMemberNameSuccess) { 185 addHeader(); 186 SmallVector<uint32_t, 5> typeDecl; 187 std::swap(typeDecl, binary); 188 189 auto int32Type = addIntType(32); 190 auto structType = addStructType({int32Type, int32Type}); 191 std::swap(typeDecl, binary); 192 193 SmallVector<uint32_t, 5> operands1 = {structType, 0}; 194 spirv::encodeStringLiteralInto(operands1, "i1"); 195 addInstruction(spirv::Opcode::OpMemberName, operands1); 196 197 SmallVector<uint32_t, 5> operands2 = {structType, 1}; 198 spirv::encodeStringLiteralInto(operands2, "i2"); 199 addInstruction(spirv::Opcode::OpMemberName, operands2); 200 201 binary.append(typeDecl.begin(), typeDecl.end()); 202 EXPECT_TRUE(deserialize()); 203 } 204 205 TEST_F(DeserializationTest, OpMemberNameMissingOperands) { 206 addHeader(); 207 SmallVector<uint32_t, 5> typeDecl; 208 std::swap(typeDecl, binary); 209 210 auto int32Type = addIntType(32); 211 auto int64Type = addIntType(64); 212 auto structType = addStructType({int32Type, int64Type}); 213 std::swap(typeDecl, binary); 214 215 SmallVector<uint32_t, 5> operands1 = {structType}; 216 addInstruction(spirv::Opcode::OpMemberName, operands1); 217 218 binary.append(typeDecl.begin(), typeDecl.end()); 219 ASSERT_FALSE(deserialize()); 220 expectDiagnostic("OpMemberName must have at least 3 operands"); 221 } 222 223 TEST_F(DeserializationTest, OpMemberNameExcessOperands) { 224 addHeader(); 225 SmallVector<uint32_t, 5> typeDecl; 226 std::swap(typeDecl, binary); 227 228 auto int32Type = addIntType(32); 229 auto structType = addStructType({int32Type}); 230 std::swap(typeDecl, binary); 231 232 SmallVector<uint32_t, 5> operands = {structType, 0}; 233 spirv::encodeStringLiteralInto(operands, "int32"); 234 operands.push_back(42); 235 addInstruction(spirv::Opcode::OpMemberName, operands); 236 237 binary.append(typeDecl.begin(), typeDecl.end()); 238 ASSERT_FALSE(deserialize()); 239 expectDiagnostic("unexpected trailing words in OpMemberName instruction"); 240 } 241 242 //===----------------------------------------------------------------------===// 243 // Functions 244 //===----------------------------------------------------------------------===// 245 246 TEST_F(DeserializationTest, FunctionMissingEndFailure) { 247 addHeader(); 248 auto voidType = addVoidType(); 249 auto fnType = addFunctionType(voidType, {}); 250 addFunction(voidType, fnType); 251 // Missing OpFunctionEnd. 252 253 ASSERT_FALSE(deserialize()); 254 expectDiagnostic("expected OpFunctionEnd instruction"); 255 } 256 257 TEST_F(DeserializationTest, FunctionMissingParameterFailure) { 258 addHeader(); 259 auto voidType = addVoidType(); 260 auto i32Type = addIntType(32); 261 auto fnType = addFunctionType(voidType, {i32Type}); 262 addFunction(voidType, fnType); 263 // Missing OpFunctionParameter. 264 265 ASSERT_FALSE(deserialize()); 266 expectDiagnostic("expected OpFunctionParameter instruction"); 267 } 268 269 TEST_F(DeserializationTest, FunctionMissingLabelForFirstBlockFailure) { 270 addHeader(); 271 auto voidType = addVoidType(); 272 auto fnType = addFunctionType(voidType, {}); 273 addFunction(voidType, fnType); 274 // Missing OpLabel. 275 addReturn(); 276 addFunctionEnd(); 277 278 ASSERT_FALSE(deserialize()); 279 expectDiagnostic("a basic block must start with OpLabel"); 280 } 281 282 TEST_F(DeserializationTest, FunctionMalformedLabelFailure) { 283 addHeader(); 284 auto voidType = addVoidType(); 285 auto fnType = addFunctionType(voidType, {}); 286 addFunction(voidType, fnType); 287 addInstruction(spirv::Opcode::OpLabel, {}); // Malformed OpLabel 288 addReturn(); 289 addFunctionEnd(); 290 291 ASSERT_FALSE(deserialize()); 292 expectDiagnostic("OpLabel should only have result <id>"); 293 } 294