1 //===- DeserializationTest.cpp - SPIR-V Deserialization Tests -------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // The purpose of this file is to provide negative deserialization tests. 19 // For positive deserialization tests, please use serialization and 20 // deserialization for roundtripping. 21 // 22 //===----------------------------------------------------------------------===// 23 24 #include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h" 25 #include "mlir/Dialect/SPIRV/SPIRVOps.h" 26 #include "mlir/Dialect/SPIRV/Serialization.h" 27 #include "mlir/IR/Diagnostics.h" 28 #include "mlir/IR/MLIRContext.h" 29 #include "gmock/gmock.h" 30 31 #include <memory> 32 33 using namespace mlir; 34 35 using ::testing::StrEq; 36 37 //===----------------------------------------------------------------------===// 38 // Test Fixture 39 //===----------------------------------------------------------------------===// 40 41 /// A deserialization test fixture providing minimal SPIR-V building and 42 /// diagnostic checking utilities. 43 class DeserializationTest : public ::testing::Test { 44 protected: 45 DeserializationTest() { 46 // Register a diagnostic handler to capture the diagnostic so that we can 47 // check it later. 
48 context.getDiagEngine().setHandler([&](Diagnostic diag) { 49 diagnostic.reset(new Diagnostic(std::move(diag))); 50 }); 51 } 52 53 /// Performs deserialization and returns the constructed spv.module op. 54 Optional<spirv::ModuleOp> deserialize() { 55 return spirv::deserialize(binary, &context); 56 } 57 58 /// Checks there is a diagnostic generated with the given `errorMessage`. 59 void expectDiagnostic(StringRef errorMessage) { 60 ASSERT_NE(nullptr, diagnostic.get()); 61 62 // TODO(antiagainst): check error location too. 63 EXPECT_THAT(diagnostic->str(), StrEq(errorMessage)); 64 } 65 66 //===--------------------------------------------------------------------===// 67 // SPIR-V builder methods 68 //===--------------------------------------------------------------------===// 69 70 /// Adds the SPIR-V module header to `binary`. 71 void addHeader() { spirv::appendModuleHeader(binary, /*idBound=*/0); } 72 73 /// Adds the SPIR-V instruction into `binary`. 74 void addInstruction(spirv::Opcode op, ArrayRef<uint32_t> operands) { 75 uint32_t wordCount = 1 + operands.size(); 76 assert(((wordCount >> 16) == 0) && "word count out of range!"); 77 78 uint32_t prefixedOpcode = (wordCount << 16) | static_cast<uint32_t>(op); 79 binary.push_back(prefixedOpcode); 80 binary.append(operands.begin(), operands.end()); 81 } 82 83 uint32_t addVoidType() { 84 auto id = nextID++; 85 addInstruction(spirv::Opcode::OpTypeVoid, {id}); 86 return id; 87 } 88 89 uint32_t addIntType(uint32_t bitwidth) { 90 auto id = nextID++; 91 addInstruction(spirv::Opcode::OpTypeInt, {id, bitwidth, /*signedness=*/1}); 92 return id; 93 } 94 95 uint32_t addFunctionType(uint32_t retType, ArrayRef<uint32_t> paramTypes) { 96 auto id = nextID++; 97 SmallVector<uint32_t, 4> operands; 98 operands.push_back(id); 99 operands.push_back(retType); 100 operands.append(paramTypes.begin(), paramTypes.end()); 101 addInstruction(spirv::Opcode::OpTypeFunction, operands); 102 return id; 103 } 104 105 uint32_t addFunction(uint32_t 
retType, uint32_t fnType) { 106 auto id = nextID++; 107 addInstruction(spirv::Opcode::OpFunction, 108 {retType, id, 109 static_cast<uint32_t>(spirv::FunctionControl::None), 110 fnType}); 111 return id; 112 } 113 114 uint32_t addFunctionEnd() { 115 auto id = nextID++; 116 addInstruction(spirv::Opcode::OpFunctionEnd, {id}); 117 return id; 118 } 119 120 protected: 121 SmallVector<uint32_t, 5> binary; 122 uint32_t nextID = 1; 123 MLIRContext context; 124 std::unique_ptr<Diagnostic> diagnostic; 125 }; 126 127 //===----------------------------------------------------------------------===// 128 // Basics 129 //===----------------------------------------------------------------------===// 130 131 TEST_F(DeserializationTest, EmptyModuleFailure) { 132 ASSERT_EQ(llvm::None, deserialize()); 133 expectDiagnostic("SPIR-V binary module must have a 5-word header"); 134 } 135 136 TEST_F(DeserializationTest, WrongMagicNumberFailure) { 137 addHeader(); 138 binary.front() = 0xdeadbeef; // Change to a wrong magic number 139 ASSERT_EQ(llvm::None, deserialize()); 140 expectDiagnostic("incorrect magic number"); 141 } 142 143 TEST_F(DeserializationTest, OnlyHeaderSuccess) { 144 addHeader(); 145 EXPECT_NE(llvm::None, deserialize()); 146 } 147 148 TEST_F(DeserializationTest, ZeroWordCountFailure) { 149 addHeader(); 150 binary.push_back(0); // OpNop with zero word count 151 152 ASSERT_EQ(llvm::None, deserialize()); 153 expectDiagnostic("word count cannot be zero"); 154 } 155 156 TEST_F(DeserializationTest, InsufficientWordFailure) { 157 addHeader(); 158 binary.push_back((2u << 16) | 159 static_cast<uint32_t>(spirv::Opcode::OpTypeVoid)); 160 // Missing word for type <id> 161 162 ASSERT_EQ(llvm::None, deserialize()); 163 expectDiagnostic("insufficient words for the last instruction"); 164 } 165 166 //===----------------------------------------------------------------------===// 167 // Types 168 //===----------------------------------------------------------------------===// 169 170 
TEST_F(DeserializationTest, IntTypeMissingSignednessFailure) {
  addHeader();
  // OpTypeInt is emitted with only its result id and bitwidth; the trailing
  // signedness operand is deliberately omitted.
  uint32_t intTypeId = nextID++;
  addInstruction(spirv::Opcode::OpTypeInt, {intTypeId, /*bitwidth=*/32});

  ASSERT_EQ(llvm::None, deserialize());
  expectDiagnostic("OpTypeInt must have bitwidth and signedness parameters");
}

//===----------------------------------------------------------------------===//
// Functions
//===----------------------------------------------------------------------===//

TEST_F(DeserializationTest, FunctionMissingEndFailure) {
  addHeader();
  const uint32_t voidTy = addVoidType();
  const uint32_t voidFnTy = addFunctionType(voidTy, {});
  addFunction(voidTy, voidFnTy);
  // The function is never closed: no OpFunctionEnd follows the OpFunction.

  ASSERT_EQ(llvm::None, deserialize());
  expectDiagnostic("expected OpFunctionEnd instruction");
}

TEST_F(DeserializationTest, FunctionMissingParameterFailure) {
  addHeader();
  const uint32_t voidTy = addVoidType();
  const uint32_t i32Ty = addIntType(/*bitwidth=*/32);
  const uint32_t fnTy = addFunctionType(voidTy, {i32Ty});
  addFunction(voidTy, fnTy);
  // The function type declares one 32-bit integer parameter, yet no
  // OpFunctionParameter is emitted for it.

  ASSERT_EQ(llvm::None, deserialize());
  expectDiagnostic("expected OpFunctionParameter instruction");
}