1 //===- DeserializationTest.cpp - SPIR-V Deserialization Tests -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // The purpose of this file is to provide negative deserialization tests. 10 // For positive deserialization tests, please use serialization and 11 // deserialization for roundtripping. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "mlir/Target/SPIRV/Deserialization.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 18 #include "mlir/IR/Diagnostics.h" 19 #include "mlir/IR/MLIRContext.h" 20 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" 21 #include "gmock/gmock.h" 22 23 #include <memory> 24 25 using namespace mlir; 26 27 using ::testing::StrEq; 28 29 //===----------------------------------------------------------------------===// 30 // Test Fixture 31 //===----------------------------------------------------------------------===// 32 33 /// A deserialization test fixture providing minimal SPIR-V building and 34 /// diagnostic checking utilities. 35 class DeserializationTest : public ::testing::Test { 36 protected: 37 DeserializationTest() { 38 context.getOrLoadDialect<mlir::spirv::SPIRVDialect>(); 39 // Register a diagnostic handler to capture the diagnostic so that we can 40 // check it later. 41 context.getDiagEngine().registerHandler([&](Diagnostic &diag) { 42 diagnostic.reset(new Diagnostic(std::move(diag))); 43 }); 44 } 45 46 /// Performs deserialization and returns the constructed spv.module op. 47 OwningOpRef<spirv::ModuleOp> deserialize() { 48 return spirv::deserialize(binary, &context); 49 } 50 51 /// Checks there is a diagnostic generated with the given `errorMessage`. 52 void expectDiagnostic(StringRef errorMessage) { 53 ASSERT_NE(nullptr, diagnostic.get()); 54 55 // TODO: check error location too. 56 EXPECT_THAT(diagnostic->str(), StrEq(std::string(errorMessage))); 57 } 58 59 //===--------------------------------------------------------------------===// 60 // SPIR-V builder methods 61 //===--------------------------------------------------------------------===// 62 63 /// Adds the SPIR-V module header to `binary`. 64 void addHeader() { 65 spirv::appendModuleHeader(binary, spirv::Version::V_1_0, /*idBound=*/0); 66 } 67 68 /// Adds the SPIR-V instruction into `binary`. 69 void addInstruction(spirv::Opcode op, ArrayRef<uint32_t> operands) { 70 uint32_t wordCount = 1 + operands.size(); 71 binary.push_back(spirv::getPrefixedOpcode(wordCount, op)); 72 binary.append(operands.begin(), operands.end()); 73 } 74 75 uint32_t addVoidType() { 76 auto id = nextID++; 77 addInstruction(spirv::Opcode::OpTypeVoid, {id}); 78 return id; 79 } 80 81 uint32_t addIntType(uint32_t bitwidth) { 82 auto id = nextID++; 83 addInstruction(spirv::Opcode::OpTypeInt, {id, bitwidth, /*signedness=*/1}); 84 return id; 85 } 86 87 uint32_t addStructType(ArrayRef<uint32_t> memberTypes) { 88 auto id = nextID++; 89 SmallVector<uint32_t, 2> words; 90 words.push_back(id); 91 words.append(memberTypes.begin(), memberTypes.end()); 92 addInstruction(spirv::Opcode::OpTypeStruct, words); 93 return id; 94 } 95 96 uint32_t addFunctionType(uint32_t retType, ArrayRef<uint32_t> paramTypes) { 97 auto id = nextID++; 98 SmallVector<uint32_t, 4> operands; 99 operands.push_back(id); 100 operands.push_back(retType); 101 operands.append(paramTypes.begin(), paramTypes.end()); 102 addInstruction(spirv::Opcode::OpTypeFunction, operands); 103 return id; 104 } 105 106 uint32_t addFunction(uint32_t retType, uint32_t fnType) { 107 auto id = nextID++; 108 addInstruction(spirv::Opcode::OpFunction, 109 {retType, id, 110 static_cast<uint32_t>(spirv::FunctionControl::None), 111 fnType}); 112 return id; 113 } 114 115 void addFunctionEnd() { addInstruction(spirv::Opcode::OpFunctionEnd, {}); } 116 117 void addReturn() { addInstruction(spirv::Opcode::OpReturn, {}); } 118 119 protected: 120 SmallVector<uint32_t, 5> binary; 121 uint32_t nextID = 1; 122 MLIRContext context; 123 std::unique_ptr<Diagnostic> diagnostic; 124 }; 125 126 //===----------------------------------------------------------------------===// 127 // Basics 128 //===----------------------------------------------------------------------===// 129 130 TEST_F(DeserializationTest, EmptyModuleFailure) { 131 ASSERT_FALSE(deserialize()); 132 expectDiagnostic("SPIR-V binary module must have a 5-word header"); 133 } 134 135 TEST_F(DeserializationTest, WrongMagicNumberFailure) { 136 addHeader(); 137 binary.front() = 0xdeadbeef; // Change to a wrong magic number 138 ASSERT_FALSE(deserialize()); 139 expectDiagnostic("incorrect magic number"); 140 } 141 142 TEST_F(DeserializationTest, OnlyHeaderSuccess) { 143 addHeader(); 144 EXPECT_TRUE(deserialize()); 145 } 146 147 TEST_F(DeserializationTest, ZeroWordCountFailure) { 148 addHeader(); 149 binary.push_back(0); // OpNop with zero word count 150 151 ASSERT_FALSE(deserialize()); 152 expectDiagnostic("word count cannot be zero"); 153 } 154 155 TEST_F(DeserializationTest, InsufficientWordFailure) { 156 addHeader(); 157 binary.push_back((2u << 16) | 158 static_cast<uint32_t>(spirv::Opcode::OpTypeVoid)); 159 // Missing word for type <id>. 160 161 ASSERT_FALSE(deserialize()); 162 expectDiagnostic("insufficient words for the last instruction"); 163 } 164 165 //===----------------------------------------------------------------------===// 166 // Types 167 //===----------------------------------------------------------------------===// 168 169 TEST_F(DeserializationTest, IntTypeMissingSignednessFailure) { 170 addHeader(); 171 addInstruction(spirv::Opcode::OpTypeInt, {nextID++, 32}); 172 173 ASSERT_FALSE(deserialize()); 174 expectDiagnostic("OpTypeInt must have bitwidth and signedness parameters"); 175 } 176 177 //===----------------------------------------------------------------------===// 178 // StructType 179 //===----------------------------------------------------------------------===// 180 181 TEST_F(DeserializationTest, OpMemberNameSuccess) { 182 addHeader(); 183 SmallVector<uint32_t, 5> typeDecl; 184 std::swap(typeDecl, binary); 185 186 auto int32Type = addIntType(32); 187 auto structType = addStructType({int32Type, int32Type}); 188 std::swap(typeDecl, binary); 189 190 SmallVector<uint32_t, 5> operands1 = {structType, 0}; 191 (void)spirv::encodeStringLiteralInto(operands1, "i1"); 192 addInstruction(spirv::Opcode::OpMemberName, operands1); 193 194 SmallVector<uint32_t, 5> operands2 = {structType, 1}; 195 (void)spirv::encodeStringLiteralInto(operands2, "i2"); 196 addInstruction(spirv::Opcode::OpMemberName, operands2); 197 198 binary.append(typeDecl.begin(), typeDecl.end()); 199 EXPECT_TRUE(deserialize()); 200 } 201 202 TEST_F(DeserializationTest, OpMemberNameMissingOperands) { 203 addHeader(); 204 SmallVector<uint32_t, 5> typeDecl; 205 std::swap(typeDecl, binary); 206 207 auto int32Type = addIntType(32); 208 auto int64Type = addIntType(64); 209 auto structType = addStructType({int32Type, int64Type}); 210 std::swap(typeDecl, binary); 211 212 SmallVector<uint32_t, 5> operands1 = {structType}; 213 addInstruction(spirv::Opcode::OpMemberName, operands1); 214 215 binary.append(typeDecl.begin(), typeDecl.end()); 216 ASSERT_FALSE(deserialize()); 217 expectDiagnostic("OpMemberName must have at least 3 operands"); 218 } 219 220 TEST_F(DeserializationTest, OpMemberNameExcessOperands) { 221 addHeader(); 222 SmallVector<uint32_t, 5> typeDecl; 223 std::swap(typeDecl, binary); 224 225 auto int32Type = addIntType(32); 226 auto structType = addStructType({int32Type}); 227 std::swap(typeDecl, binary); 228 229 SmallVector<uint32_t, 5> operands = {structType, 0}; 230 (void)spirv::encodeStringLiteralInto(operands, "int32"); 231 operands.push_back(42); 232 addInstruction(spirv::Opcode::OpMemberName, operands); 233 234 binary.append(typeDecl.begin(), typeDecl.end()); 235 ASSERT_FALSE(deserialize()); 236 expectDiagnostic("unexpected trailing words in OpMemberName instruction"); 237 } 238 239 //===----------------------------------------------------------------------===// 240 // Functions 241 //===----------------------------------------------------------------------===// 242 243 TEST_F(DeserializationTest, FunctionMissingEndFailure) { 244 addHeader(); 245 auto voidType = addVoidType(); 246 auto fnType = addFunctionType(voidType, {}); 247 addFunction(voidType, fnType); 248 // Missing OpFunctionEnd. 249 250 ASSERT_FALSE(deserialize()); 251 expectDiagnostic("expected OpFunctionEnd instruction"); 252 } 253 254 TEST_F(DeserializationTest, FunctionMissingParameterFailure) { 255 addHeader(); 256 auto voidType = addVoidType(); 257 auto i32Type = addIntType(32); 258 auto fnType = addFunctionType(voidType, {i32Type}); 259 addFunction(voidType, fnType); 260 // Missing OpFunctionParameter. 261 262 ASSERT_FALSE(deserialize()); 263 expectDiagnostic("expected OpFunctionParameter instruction"); 264 } 265 266 TEST_F(DeserializationTest, FunctionMissingLabelForFirstBlockFailure) { 267 addHeader(); 268 auto voidType = addVoidType(); 269 auto fnType = addFunctionType(voidType, {}); 270 addFunction(voidType, fnType); 271 // Missing OpLabel. 272 addReturn(); 273 addFunctionEnd(); 274 275 ASSERT_FALSE(deserialize()); 276 expectDiagnostic("a basic block must start with OpLabel"); 277 } 278 279 TEST_F(DeserializationTest, FunctionMalformedLabelFailure) { 280 addHeader(); 281 auto voidType = addVoidType(); 282 auto fnType = addFunctionType(voidType, {}); 283 addFunction(voidType, fnType); 284 addInstruction(spirv::Opcode::OpLabel, {}); // Malformed OpLabel 285 addReturn(); 286 addFunctionEnd(); 287 288 ASSERT_FALSE(deserialize()); 289 expectDiagnostic("OpLabel should only have result <id>"); 290 } 291