1 //===- DeserializationTest.cpp - SPIR-V Deserialization Tests -------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // The purpose of this file is to provide negative deserialization tests. 19 // For positive deserialization tests, please use serialization and 20 // deserialization for roundtripping. 21 // 22 //===----------------------------------------------------------------------===// 23 24 #include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h" 25 #include "mlir/Dialect/SPIRV/SPIRVOps.h" 26 #include "mlir/Dialect/SPIRV/Serialization.h" 27 #include "mlir/IR/Diagnostics.h" 28 #include "mlir/IR/MLIRContext.h" 29 #include "gmock/gmock.h" 30 31 #include <memory> 32 33 using namespace mlir; 34 35 using ::testing::StrEq; 36 37 //===----------------------------------------------------------------------===// 38 // Test Fixture 39 //===----------------------------------------------------------------------===// 40 41 /// A deserialization test fixture providing minimal SPIR-V building and 42 /// diagnostic checking utilities. 43 class DeserializationTest : public ::testing::Test { 44 protected: 45 DeserializationTest() { 46 // Register a diagnostic handler to capture the diagnostic so that we can 47 // check it later. 48 context.getDiagEngine().registerHandler([&](Diagnostic &diag) { 49 diagnostic.reset(new Diagnostic(std::move(diag))); 50 }); 51 } 52 53 /// Performs deserialization and returns the constructed spv.module op. 54 Optional<spirv::ModuleOp> deserialize() { 55 return spirv::deserialize(binary, &context); 56 } 57 58 /// Checks there is a diagnostic generated with the given `errorMessage`. 59 void expectDiagnostic(StringRef errorMessage) { 60 ASSERT_NE(nullptr, diagnostic.get()); 61 62 // TODO(antiagainst): check error location too. 63 EXPECT_THAT(diagnostic->str(), StrEq(errorMessage)); 64 } 65 66 //===--------------------------------------------------------------------===// 67 // SPIR-V builder methods 68 //===--------------------------------------------------------------------===// 69 70 /// Adds the SPIR-V module header to `binary`. 71 void addHeader() { spirv::appendModuleHeader(binary, /*idBound=*/0); } 72 73 /// Adds the SPIR-V instruction into `binary`. 74 void addInstruction(spirv::Opcode op, ArrayRef<uint32_t> operands) { 75 uint32_t wordCount = 1 + operands.size(); 76 assert(((wordCount >> 16) == 0) && "word count out of range!"); 77 78 uint32_t prefixedOpcode = (wordCount << 16) | static_cast<uint32_t>(op); 79 binary.push_back(prefixedOpcode); 80 binary.append(operands.begin(), operands.end()); 81 } 82 83 uint32_t addVoidType() { 84 auto id = nextID++; 85 addInstruction(spirv::Opcode::OpTypeVoid, {id}); 86 return id; 87 } 88 89 uint32_t addIntType(uint32_t bitwidth) { 90 auto id = nextID++; 91 addInstruction(spirv::Opcode::OpTypeInt, {id, bitwidth, /*signedness=*/1}); 92 return id; 93 } 94 95 uint32_t addFunctionType(uint32_t retType, ArrayRef<uint32_t> paramTypes) { 96 auto id = nextID++; 97 SmallVector<uint32_t, 4> operands; 98 operands.push_back(id); 99 operands.push_back(retType); 100 operands.append(paramTypes.begin(), paramTypes.end()); 101 addInstruction(spirv::Opcode::OpTypeFunction, operands); 102 return id; 103 } 104 105 uint32_t addFunction(uint32_t retType, uint32_t fnType) { 106 auto id = nextID++; 107 addInstruction(spirv::Opcode::OpFunction, 108 {retType, id, 109 static_cast<uint32_t>(spirv::FunctionControl::None), 110 fnType}); 111 return id; 112 } 113 114 void addFunctionEnd() { addInstruction(spirv::Opcode::OpFunctionEnd, {}); } 115 116 void addReturn() { addInstruction(spirv::Opcode::OpReturn, {}); } 117 118 protected: 119 SmallVector<uint32_t, 5> binary; 120 uint32_t nextID = 1; 121 MLIRContext context; 122 std::unique_ptr<Diagnostic> diagnostic; 123 }; 124 125 //===----------------------------------------------------------------------===// 126 // Basics 127 //===----------------------------------------------------------------------===// 128 129 TEST_F(DeserializationTest, EmptyModuleFailure) { 130 ASSERT_EQ(llvm::None, deserialize()); 131 expectDiagnostic("SPIR-V binary module must have a 5-word header"); 132 } 133 134 TEST_F(DeserializationTest, WrongMagicNumberFailure) { 135 addHeader(); 136 binary.front() = 0xdeadbeef; // Change to a wrong magic number 137 ASSERT_EQ(llvm::None, deserialize()); 138 expectDiagnostic("incorrect magic number"); 139 } 140 141 TEST_F(DeserializationTest, OnlyHeaderSuccess) { 142 addHeader(); 143 EXPECT_NE(llvm::None, deserialize()); 144 } 145 146 TEST_F(DeserializationTest, ZeroWordCountFailure) { 147 addHeader(); 148 binary.push_back(0); // OpNop with zero word count 149 150 ASSERT_EQ(llvm::None, deserialize()); 151 expectDiagnostic("word count cannot be zero"); 152 } 153 154 TEST_F(DeserializationTest, InsufficientWordFailure) { 155 addHeader(); 156 binary.push_back((2u << 16) | 157 static_cast<uint32_t>(spirv::Opcode::OpTypeVoid)); 158 // Missing word for type <id> 159 160 ASSERT_EQ(llvm::None, deserialize()); 161 expectDiagnostic("insufficient words for the last instruction"); 162 } 163 164 //===----------------------------------------------------------------------===// 165 // Types 166 //===----------------------------------------------------------------------===// 167 168 TEST_F(DeserializationTest, IntTypeMissingSignednessFailure) { 169 addHeader(); 170 addInstruction(spirv::Opcode::OpTypeInt, {nextID++, 32}); 171 172 ASSERT_EQ(llvm::None, deserialize()); 173 expectDiagnostic("OpTypeInt must have bitwidth and signedness parameters"); 174 } 175 176 //===----------------------------------------------------------------------===// 177 // Functions 178 //===----------------------------------------------------------------------===// 179 180 TEST_F(DeserializationTest, FunctionMissingEndFailure) { 181 addHeader(); 182 auto voidType = addVoidType(); 183 auto fnType = addFunctionType(voidType, {}); 184 addFunction(voidType, fnType); 185 // Missing OpFunctionEnd 186 187 ASSERT_EQ(llvm::None, deserialize()); 188 expectDiagnostic("expected OpFunctionEnd instruction"); 189 } 190 191 TEST_F(DeserializationTest, FunctionMissingParameterFailure) { 192 addHeader(); 193 auto voidType = addVoidType(); 194 auto i32Type = addIntType(32); 195 auto fnType = addFunctionType(voidType, {i32Type}); 196 addFunction(voidType, fnType); 197 // Missing OpFunctionParameter 198 199 ASSERT_EQ(llvm::None, deserialize()); 200 expectDiagnostic("expected OpFunctionParameter instruction"); 201 } 202 203 TEST_F(DeserializationTest, FunctionMissingLabelForFirstBlockFailure) { 204 addHeader(); 205 auto voidType = addVoidType(); 206 auto fnType = addFunctionType(voidType, {}); 207 addFunction(voidType, fnType); 208 // Missing OpLabel 209 addReturn(); 210 addFunctionEnd(); 211 212 ASSERT_EQ(llvm::None, deserialize()); 213 expectDiagnostic("a basic block must start with OpLabel"); 214 } 215 216 TEST_F(DeserializationTest, FunctionMalformedLabelFailure) { 217 addHeader(); 218 auto voidType = addVoidType(); 219 auto fnType = addFunctionType(voidType, {}); 220 addFunction(voidType, fnType); 221 addInstruction(spirv::Opcode::OpLabel, {}); // Malformed OpLabel 222 addReturn(); 223 addFunctionEnd(); 224 225 ASSERT_EQ(llvm::None, deserialize()); 226 expectDiagnostic("OpLabel should only have result <id>"); 227 } 228