1 //===- DeserializationTest.cpp - SPIR-V Deserialization Tests -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // The purpose of this file is to provide negative deserialization tests. 10 // For positive deserialization tests, please use serialization and 11 // deserialization for roundtripping. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "mlir/Target/SPIRV/Deserialization.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVModule.h" 18 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 19 #include "mlir/IR/Diagnostics.h" 20 #include "mlir/IR/MLIRContext.h" 21 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" 22 #include "gmock/gmock.h" 23 24 #include <memory> 25 26 using namespace mlir; 27 28 using ::testing::StrEq; 29 30 //===----------------------------------------------------------------------===// 31 // Test Fixture 32 //===----------------------------------------------------------------------===// 33 34 /// A deserialization test fixture providing minimal SPIR-V building and 35 /// diagnostic checking utilities. 36 class DeserializationTest : public ::testing::Test { 37 protected: 38 DeserializationTest() { 39 context.getOrLoadDialect<mlir::spirv::SPIRVDialect>(); 40 // Register a diagnostic handler to capture the diagnostic so that we can 41 // check it later. 42 context.getDiagEngine().registerHandler([&](Diagnostic &diag) { 43 diagnostic.reset(new Diagnostic(std::move(diag))); 44 }); 45 } 46 47 /// Performs deserialization and returns the constructed spv.module op. 48 spirv::OwningSPIRVModuleRef deserialize() { 49 return spirv::deserialize(binary, &context); 50 } 51 52 /// Checks there is a diagnostic generated with the given `errorMessage`. 53 void expectDiagnostic(StringRef errorMessage) { 54 ASSERT_NE(nullptr, diagnostic.get()); 55 56 // TODO: check error location too. 57 EXPECT_THAT(diagnostic->str(), StrEq(std::string(errorMessage))); 58 } 59 60 //===--------------------------------------------------------------------===// 61 // SPIR-V builder methods 62 //===--------------------------------------------------------------------===// 63 64 /// Adds the SPIR-V module header to `binary`. 65 void addHeader() { 66 spirv::appendModuleHeader(binary, spirv::Version::V_1_0, /*idBound=*/0); 67 } 68 69 /// Adds the SPIR-V instruction into `binary`. 70 void addInstruction(spirv::Opcode op, ArrayRef<uint32_t> operands) { 71 uint32_t wordCount = 1 + operands.size(); 72 binary.push_back(spirv::getPrefixedOpcode(wordCount, op)); 73 binary.append(operands.begin(), operands.end()); 74 } 75 76 uint32_t addVoidType() { 77 auto id = nextID++; 78 addInstruction(spirv::Opcode::OpTypeVoid, {id}); 79 return id; 80 } 81 82 uint32_t addIntType(uint32_t bitwidth) { 83 auto id = nextID++; 84 addInstruction(spirv::Opcode::OpTypeInt, {id, bitwidth, /*signedness=*/1}); 85 return id; 86 } 87 88 uint32_t addStructType(ArrayRef<uint32_t> memberTypes) { 89 auto id = nextID++; 90 SmallVector<uint32_t, 2> words; 91 words.push_back(id); 92 words.append(memberTypes.begin(), memberTypes.end()); 93 addInstruction(spirv::Opcode::OpTypeStruct, words); 94 return id; 95 } 96 97 uint32_t addFunctionType(uint32_t retType, ArrayRef<uint32_t> paramTypes) { 98 auto id = nextID++; 99 SmallVector<uint32_t, 4> operands; 100 operands.push_back(id); 101 operands.push_back(retType); 102 operands.append(paramTypes.begin(), paramTypes.end()); 103 addInstruction(spirv::Opcode::OpTypeFunction, operands); 104 return id; 105 } 106 107 uint32_t addFunction(uint32_t retType, uint32_t fnType) { 108 auto id = nextID++; 109 addInstruction(spirv::Opcode::OpFunction, 110 {retType, id, 111 static_cast<uint32_t>(spirv::FunctionControl::None), 112 fnType}); 113 return id; 114 } 115 116 void addFunctionEnd() { addInstruction(spirv::Opcode::OpFunctionEnd, {}); } 117 118 void addReturn() { addInstruction(spirv::Opcode::OpReturn, {}); } 119 120 protected: 121 SmallVector<uint32_t, 5> binary; 122 uint32_t nextID = 1; 123 MLIRContext context; 124 std::unique_ptr<Diagnostic> diagnostic; 125 }; 126 127 //===----------------------------------------------------------------------===// 128 // Basics 129 //===----------------------------------------------------------------------===// 130 131 TEST_F(DeserializationTest, EmptyModuleFailure) { 132 ASSERT_FALSE(deserialize()); 133 expectDiagnostic("SPIR-V binary module must have a 5-word header"); 134 } 135 136 TEST_F(DeserializationTest, WrongMagicNumberFailure) { 137 addHeader(); 138 binary.front() = 0xdeadbeef; // Change to a wrong magic number 139 ASSERT_FALSE(deserialize()); 140 expectDiagnostic("incorrect magic number"); 141 } 142 143 TEST_F(DeserializationTest, OnlyHeaderSuccess) { 144 addHeader(); 145 EXPECT_TRUE(deserialize()); 146 } 147 148 TEST_F(DeserializationTest, ZeroWordCountFailure) { 149 addHeader(); 150 binary.push_back(0); // OpNop with zero word count 151 152 ASSERT_FALSE(deserialize()); 153 expectDiagnostic("word count cannot be zero"); 154 } 155 156 TEST_F(DeserializationTest, InsufficientWordFailure) { 157 addHeader(); 158 binary.push_back((2u << 16) | 159 static_cast<uint32_t>(spirv::Opcode::OpTypeVoid)); 160 // Missing word for type <id>. 161 162 ASSERT_FALSE(deserialize()); 163 expectDiagnostic("insufficient words for the last instruction"); 164 } 165 166 //===----------------------------------------------------------------------===// 167 // Types 168 //===----------------------------------------------------------------------===// 169 170 TEST_F(DeserializationTest, IntTypeMissingSignednessFailure) { 171 addHeader(); 172 addInstruction(spirv::Opcode::OpTypeInt, {nextID++, 32}); 173 174 ASSERT_FALSE(deserialize()); 175 expectDiagnostic("OpTypeInt must have bitwidth and signedness parameters"); 176 } 177 178 //===----------------------------------------------------------------------===// 179 // StructType 180 //===----------------------------------------------------------------------===// 181 182 TEST_F(DeserializationTest, OpMemberNameSuccess) { 183 addHeader(); 184 SmallVector<uint32_t, 5> typeDecl; 185 std::swap(typeDecl, binary); 186 187 auto int32Type = addIntType(32); 188 auto structType = addStructType({int32Type, int32Type}); 189 std::swap(typeDecl, binary); 190 191 SmallVector<uint32_t, 5> operands1 = {structType, 0}; 192 spirv::encodeStringLiteralInto(operands1, "i1"); 193 addInstruction(spirv::Opcode::OpMemberName, operands1); 194 195 SmallVector<uint32_t, 5> operands2 = {structType, 1}; 196 spirv::encodeStringLiteralInto(operands2, "i2"); 197 addInstruction(spirv::Opcode::OpMemberName, operands2); 198 199 binary.append(typeDecl.begin(), typeDecl.end()); 200 EXPECT_TRUE(deserialize()); 201 } 202 203 TEST_F(DeserializationTest, OpMemberNameMissingOperands) { 204 addHeader(); 205 SmallVector<uint32_t, 5> typeDecl; 206 std::swap(typeDecl, binary); 207 208 auto int32Type = addIntType(32); 209 auto int64Type = addIntType(64); 210 auto structType = addStructType({int32Type, int64Type}); 211 std::swap(typeDecl, binary); 212 213 SmallVector<uint32_t, 5> operands1 = {structType}; 214 addInstruction(spirv::Opcode::OpMemberName, operands1); 215 216 binary.append(typeDecl.begin(), typeDecl.end()); 217 ASSERT_FALSE(deserialize()); 218 expectDiagnostic("OpMemberName must have at least 3 operands"); 219 } 220 221 TEST_F(DeserializationTest, OpMemberNameExcessOperands) { 222 addHeader(); 223 SmallVector<uint32_t, 5> typeDecl; 224 std::swap(typeDecl, binary); 225 226 auto int32Type = addIntType(32); 227 auto structType = addStructType({int32Type}); 228 std::swap(typeDecl, binary); 229 230 SmallVector<uint32_t, 5> operands = {structType, 0}; 231 spirv::encodeStringLiteralInto(operands, "int32"); 232 operands.push_back(42); 233 addInstruction(spirv::Opcode::OpMemberName, operands); 234 235 binary.append(typeDecl.begin(), typeDecl.end()); 236 ASSERT_FALSE(deserialize()); 237 expectDiagnostic("unexpected trailing words in OpMemberName instruction"); 238 } 239 240 //===----------------------------------------------------------------------===// 241 // Functions 242 //===----------------------------------------------------------------------===// 243 244 TEST_F(DeserializationTest, FunctionMissingEndFailure) { 245 addHeader(); 246 auto voidType = addVoidType(); 247 auto fnType = addFunctionType(voidType, {}); 248 addFunction(voidType, fnType); 249 // Missing OpFunctionEnd. 250 251 ASSERT_FALSE(deserialize()); 252 expectDiagnostic("expected OpFunctionEnd instruction"); 253 } 254 255 TEST_F(DeserializationTest, FunctionMissingParameterFailure) { 256 addHeader(); 257 auto voidType = addVoidType(); 258 auto i32Type = addIntType(32); 259 auto fnType = addFunctionType(voidType, {i32Type}); 260 addFunction(voidType, fnType); 261 // Missing OpFunctionParameter. 262 263 ASSERT_FALSE(deserialize()); 264 expectDiagnostic("expected OpFunctionParameter instruction"); 265 } 266 267 TEST_F(DeserializationTest, FunctionMissingLabelForFirstBlockFailure) { 268 addHeader(); 269 auto voidType = addVoidType(); 270 auto fnType = addFunctionType(voidType, {}); 271 addFunction(voidType, fnType); 272 // Missing OpLabel. 273 addReturn(); 274 addFunctionEnd(); 275 276 ASSERT_FALSE(deserialize()); 277 expectDiagnostic("a basic block must start with OpLabel"); 278 } 279 280 TEST_F(DeserializationTest, FunctionMalformedLabelFailure) { 281 addHeader(); 282 auto voidType = addVoidType(); 283 auto fnType = addFunctionType(voidType, {}); 284 addFunction(voidType, fnType); 285 addInstruction(spirv::Opcode::OpLabel, {}); // Malformed OpLabel 286 addReturn(); 287 addFunctionEnd(); 288 289 ASSERT_FALSE(deserialize()); 290 expectDiagnostic("OpLabel should only have result <id>"); 291 } 292