1 //===- DeserializationTest.cpp - SPIR-V Deserialization Tests -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // The purpose of this file is to provide negative deserialization tests. 10 // For positive deserialization tests, please use serialization and 11 // deserialization for roundtripping. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h" 16 #include "mlir/Dialect/SPIRV/SPIRVDialect.h" 17 #include "mlir/Dialect/SPIRV/SPIRVModule.h" 18 #include "mlir/Dialect/SPIRV/SPIRVOps.h" 19 #include "mlir/Dialect/SPIRV/Serialization.h" 20 #include "mlir/IR/Diagnostics.h" 21 #include "mlir/IR/MLIRContext.h" 22 #include "gmock/gmock.h" 23 24 #include <memory> 25 26 using namespace mlir; 27 28 /// Load the SPIRV dialect. 29 static DialectRegistration<spirv::SPIRVDialect> SPIRVRegistration; 30 31 using ::testing::StrEq; 32 33 //===----------------------------------------------------------------------===// 34 // Test Fixture 35 //===----------------------------------------------------------------------===// 36 37 /// A deserialization test fixture providing minimal SPIR-V building and 38 /// diagnostic checking utilities. 39 class DeserializationTest : public ::testing::Test { 40 protected: 41 DeserializationTest() : context(/*loadAllDialects=*/false) { 42 context.getOrLoadDialect<mlir::spirv::SPIRVDialect>(); 43 // Register a diagnostic handler to capture the diagnostic so that we can 44 // check it later. 45 context.getDiagEngine().registerHandler([&](Diagnostic &diag) { 46 diagnostic.reset(new Diagnostic(std::move(diag))); 47 }); 48 } 49 50 /// Performs deserialization and returns the constructed spv.module op. 51 spirv::OwningSPIRVModuleRef deserialize() { 52 return spirv::deserialize(binary, &context); 53 } 54 55 /// Checks there is a diagnostic generated with the given `errorMessage`. 56 void expectDiagnostic(StringRef errorMessage) { 57 ASSERT_NE(nullptr, diagnostic.get()); 58 59 // TODO: check error location too. 60 EXPECT_THAT(diagnostic->str(), StrEq(std::string(errorMessage))); 61 } 62 63 //===--------------------------------------------------------------------===// 64 // SPIR-V builder methods 65 //===--------------------------------------------------------------------===// 66 67 /// Adds the SPIR-V module header to `binary`. 68 void addHeader() { 69 spirv::appendModuleHeader(binary, spirv::Version::V_1_0, /*idBound=*/0); 70 } 71 72 /// Adds the SPIR-V instruction into `binary`. 73 void addInstruction(spirv::Opcode op, ArrayRef<uint32_t> operands) { 74 uint32_t wordCount = 1 + operands.size(); 75 binary.push_back(spirv::getPrefixedOpcode(wordCount, op)); 76 binary.append(operands.begin(), operands.end()); 77 } 78 79 uint32_t addVoidType() { 80 auto id = nextID++; 81 addInstruction(spirv::Opcode::OpTypeVoid, {id}); 82 return id; 83 } 84 85 uint32_t addIntType(uint32_t bitwidth) { 86 auto id = nextID++; 87 addInstruction(spirv::Opcode::OpTypeInt, {id, bitwidth, /*signedness=*/1}); 88 return id; 89 } 90 91 uint32_t addStructType(ArrayRef<uint32_t> memberTypes) { 92 auto id = nextID++; 93 SmallVector<uint32_t, 2> words; 94 words.push_back(id); 95 words.append(memberTypes.begin(), memberTypes.end()); 96 addInstruction(spirv::Opcode::OpTypeStruct, words); 97 return id; 98 } 99 100 uint32_t addFunctionType(uint32_t retType, ArrayRef<uint32_t> paramTypes) { 101 auto id = nextID++; 102 SmallVector<uint32_t, 4> operands; 103 operands.push_back(id); 104 operands.push_back(retType); 105 operands.append(paramTypes.begin(), paramTypes.end()); 106 addInstruction(spirv::Opcode::OpTypeFunction, operands); 107 return id; 108 } 109 110 uint32_t addFunction(uint32_t retType, uint32_t fnType) { 111 auto id = nextID++; 112 addInstruction(spirv::Opcode::OpFunction, 113 {retType, id, 114 static_cast<uint32_t>(spirv::FunctionControl::None), 115 fnType}); 116 return id; 117 } 118 119 void addFunctionEnd() { addInstruction(spirv::Opcode::OpFunctionEnd, {}); } 120 121 void addReturn() { addInstruction(spirv::Opcode::OpReturn, {}); } 122 123 protected: 124 SmallVector<uint32_t, 5> binary; 125 uint32_t nextID = 1; 126 MLIRContext context; 127 std::unique_ptr<Diagnostic> diagnostic; 128 }; 129 130 //===----------------------------------------------------------------------===// 131 // Basics 132 //===----------------------------------------------------------------------===// 133 134 TEST_F(DeserializationTest, EmptyModuleFailure) { 135 ASSERT_FALSE(deserialize()); 136 expectDiagnostic("SPIR-V binary module must have a 5-word header"); 137 } 138 139 TEST_F(DeserializationTest, WrongMagicNumberFailure) { 140 addHeader(); 141 binary.front() = 0xdeadbeef; // Change to a wrong magic number 142 ASSERT_FALSE(deserialize()); 143 expectDiagnostic("incorrect magic number"); 144 } 145 146 TEST_F(DeserializationTest, OnlyHeaderSuccess) { 147 addHeader(); 148 EXPECT_TRUE(deserialize()); 149 } 150 151 TEST_F(DeserializationTest, ZeroWordCountFailure) { 152 addHeader(); 153 binary.push_back(0); // OpNop with zero word count 154 155 ASSERT_FALSE(deserialize()); 156 expectDiagnostic("word count cannot be zero"); 157 } 158 159 TEST_F(DeserializationTest, InsufficientWordFailure) { 160 addHeader(); 161 binary.push_back((2u << 16) | 162 static_cast<uint32_t>(spirv::Opcode::OpTypeVoid)); 163 // Missing word for type <id>. 164 165 ASSERT_FALSE(deserialize()); 166 expectDiagnostic("insufficient words for the last instruction"); 167 } 168 169 //===----------------------------------------------------------------------===// 170 // Types 171 //===----------------------------------------------------------------------===// 172 173 TEST_F(DeserializationTest, IntTypeMissingSignednessFailure) { 174 addHeader(); 175 addInstruction(spirv::Opcode::OpTypeInt, {nextID++, 32}); 176 177 ASSERT_FALSE(deserialize()); 178 expectDiagnostic("OpTypeInt must have bitwidth and signedness parameters"); 179 } 180 181 //===----------------------------------------------------------------------===// 182 // StructType 183 //===----------------------------------------------------------------------===// 184 185 TEST_F(DeserializationTest, OpMemberNameSuccess) { 186 addHeader(); 187 SmallVector<uint32_t, 5> typeDecl; 188 std::swap(typeDecl, binary); 189 190 auto int32Type = addIntType(32); 191 auto structType = addStructType({int32Type, int32Type}); 192 std::swap(typeDecl, binary); 193 194 SmallVector<uint32_t, 5> operands1 = {structType, 0}; 195 spirv::encodeStringLiteralInto(operands1, "i1"); 196 addInstruction(spirv::Opcode::OpMemberName, operands1); 197 198 SmallVector<uint32_t, 5> operands2 = {structType, 1}; 199 spirv::encodeStringLiteralInto(operands2, "i2"); 200 addInstruction(spirv::Opcode::OpMemberName, operands2); 201 202 binary.append(typeDecl.begin(), typeDecl.end()); 203 EXPECT_TRUE(deserialize()); 204 } 205 206 TEST_F(DeserializationTest, OpMemberNameMissingOperands) { 207 addHeader(); 208 SmallVector<uint32_t, 5> typeDecl; 209 std::swap(typeDecl, binary); 210 211 auto int32Type = addIntType(32); 212 auto int64Type = addIntType(64); 213 auto structType = addStructType({int32Type, int64Type}); 214 std::swap(typeDecl, binary); 215 216 SmallVector<uint32_t, 5> operands1 = {structType}; 217 addInstruction(spirv::Opcode::OpMemberName, operands1); 218 219 binary.append(typeDecl.begin(), typeDecl.end()); 220 ASSERT_FALSE(deserialize()); 221 expectDiagnostic("OpMemberName must have at least 3 operands"); 222 } 223 224 TEST_F(DeserializationTest, OpMemberNameExcessOperands) { 225 addHeader(); 226 SmallVector<uint32_t, 5> typeDecl; 227 std::swap(typeDecl, binary); 228 229 auto int32Type = addIntType(32); 230 auto structType = addStructType({int32Type}); 231 std::swap(typeDecl, binary); 232 233 SmallVector<uint32_t, 5> operands = {structType, 0}; 234 spirv::encodeStringLiteralInto(operands, "int32"); 235 operands.push_back(42); 236 addInstruction(spirv::Opcode::OpMemberName, operands); 237 238 binary.append(typeDecl.begin(), typeDecl.end()); 239 ASSERT_FALSE(deserialize()); 240 expectDiagnostic("unexpected trailing words in OpMemberName instruction"); 241 } 242 243 //===----------------------------------------------------------------------===// 244 // Functions 245 //===----------------------------------------------------------------------===// 246 247 TEST_F(DeserializationTest, FunctionMissingEndFailure) { 248 addHeader(); 249 auto voidType = addVoidType(); 250 auto fnType = addFunctionType(voidType, {}); 251 addFunction(voidType, fnType); 252 // Missing OpFunctionEnd. 253 254 ASSERT_FALSE(deserialize()); 255 expectDiagnostic("expected OpFunctionEnd instruction"); 256 } 257 258 TEST_F(DeserializationTest, FunctionMissingParameterFailure) { 259 addHeader(); 260 auto voidType = addVoidType(); 261 auto i32Type = addIntType(32); 262 auto fnType = addFunctionType(voidType, {i32Type}); 263 addFunction(voidType, fnType); 264 // Missing OpFunctionParameter. 265 266 ASSERT_FALSE(deserialize()); 267 expectDiagnostic("expected OpFunctionParameter instruction"); 268 } 269 270 TEST_F(DeserializationTest, FunctionMissingLabelForFirstBlockFailure) { 271 addHeader(); 272 auto voidType = addVoidType(); 273 auto fnType = addFunctionType(voidType, {}); 274 addFunction(voidType, fnType); 275 // Missing OpLabel. 276 addReturn(); 277 addFunctionEnd(); 278 279 ASSERT_FALSE(deserialize()); 280 expectDiagnostic("a basic block must start with OpLabel"); 281 } 282 283 TEST_F(DeserializationTest, FunctionMalformedLabelFailure) { 284 addHeader(); 285 auto voidType = addVoidType(); 286 auto fnType = addFunctionType(voidType, {}); 287 addFunction(voidType, fnType); 288 addInstruction(spirv::Opcode::OpLabel, {}); // Malformed OpLabel 289 addReturn(); 290 addFunctionEnd(); 291 292 ASSERT_FALSE(deserialize()); 293 expectDiagnostic("OpLabel should only have result <id>"); 294 } 295