1 //===- SerializationTest.cpp - SPIR-V Serialization Tests -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains corner case tests for the SPIR-V serializer that are not 10 // covered by normal serialization and deserialization roundtripping. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Target/SPIRV/Serialization.h" 15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" 18 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 19 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 20 #include "mlir/IR/Builders.h" 21 #include "mlir/IR/Location.h" 22 #include "mlir/IR/MLIRContext.h" 23 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" 24 #include "llvm/ADT/DenseSet.h" 25 #include "llvm/ADT/STLExtras.h" 26 #include "llvm/ADT/Sequence.h" 27 #include "llvm/ADT/SmallVector.h" 28 #include "llvm/ADT/StringRef.h" 29 #include "gmock/gmock.h" 30 31 using namespace mlir; 32 33 //===----------------------------------------------------------------------===// 34 // Test Fixture 35 //===----------------------------------------------------------------------===// 36 37 class SerializationTest : public ::testing::Test { 38 protected: 39 SerializationTest() { 40 context.getOrLoadDialect<mlir::spirv::SPIRVDialect>(); 41 initModuleOp(); 42 } 43 44 /// Initializes an empty SPIR-V module op. 45 void initModuleOp() { 46 OpBuilder builder(&context); 47 OperationState state(UnknownLoc::get(&context), 48 spirv::ModuleOp::getOperationName()); 49 state.addAttribute("addressing_model", 50 builder.getAttr<spirv::AddressingModelAttr>( 51 spirv::AddressingModel::Logical)); 52 state.addAttribute("memory_model", builder.getAttr<spirv::MemoryModelAttr>( 53 spirv::MemoryModel::GLSL450)); 54 state.addAttribute("vce_triple", 55 spirv::VerCapExtAttr::get( 56 spirv::Version::V_1_0, ArrayRef<spirv::Capability>(), 57 ArrayRef<spirv::Extension>(), &context)); 58 spirv::ModuleOp::build(builder, state); 59 module = cast<spirv::ModuleOp>(Operation::create(state)); 60 } 61 62 /// Gets the `struct { float }` type. 63 spirv::StructType getFloatStructType() { 64 OpBuilder builder(module->getRegion()); 65 llvm::SmallVector<Type, 1> elementTypes{builder.getF32Type()}; 66 llvm::SmallVector<spirv::StructType::OffsetInfo, 1> offsetInfo{0}; 67 return spirv::StructType::get(elementTypes, offsetInfo); 68 } 69 70 /// Inserts a global variable of the given `type` and `name`. 71 spirv::GlobalVariableOp addGlobalVar(Type type, llvm::StringRef name) { 72 OpBuilder builder(module->getRegion()); 73 auto ptrType = spirv::PointerType::get(type, spirv::StorageClass::Uniform); 74 return builder.create<spirv::GlobalVariableOp>( 75 UnknownLoc::get(&context), TypeAttr::get(ptrType), 76 builder.getStringAttr(name), nullptr); 77 } 78 79 // Inserts an Integer or a Vector of Integers constant of value 'val'. 80 spirv::ConstantOp addConstInt(Type type, const APInt &val) { 81 OpBuilder builder(module->getRegion()); 82 auto loc = UnknownLoc::get(&context); 83 84 if (auto intType = dyn_cast<IntegerType>(type)) { 85 return builder.create<spirv::ConstantOp>( 86 loc, type, builder.getIntegerAttr(type, val)); 87 } 88 if (auto vectorType = dyn_cast<VectorType>(type)) { 89 Type elemType = vectorType.getElementType(); 90 if (auto intType = dyn_cast<IntegerType>(elemType)) { 91 return builder.create<spirv::ConstantOp>( 92 loc, type, 93 DenseElementsAttr::get(vectorType, 94 IntegerAttr::get(elemType, val).getValue())); 95 } 96 } 97 llvm_unreachable("unimplemented types for AddConstInt()"); 98 } 99 100 /// Handles a SPIR-V instruction with the given `opcode` and `operand`. 101 /// Returns true to interrupt. 102 using HandleFn = llvm::function_ref<bool(spirv::Opcode opcode, 103 ArrayRef<uint32_t> operands)>; 104 105 /// Returns true if we can find a matching instruction in the SPIR-V blob. 106 bool scanInstruction(HandleFn handleFn) { 107 auto binarySize = binary.size(); 108 auto *begin = binary.begin(); 109 auto currOffset = spirv::kHeaderWordCount; 110 111 while (currOffset < binarySize) { 112 auto wordCount = binary[currOffset] >> 16; 113 if (!wordCount || (currOffset + wordCount > binarySize)) 114 return false; 115 116 spirv::Opcode opcode = 117 static_cast<spirv::Opcode>(binary[currOffset] & 0xffff); 118 llvm::ArrayRef<uint32_t> operands(begin + currOffset + 1, 119 begin + currOffset + wordCount); 120 if (handleFn(opcode, operands)) 121 return true; 122 123 currOffset += wordCount; 124 } 125 return false; 126 } 127 128 protected: 129 MLIRContext context; 130 OwningOpRef<spirv::ModuleOp> module; 131 SmallVector<uint32_t, 0> binary; 132 }; 133 134 //===----------------------------------------------------------------------===// 135 // Block decoration 136 //===----------------------------------------------------------------------===// 137 138 TEST_F(SerializationTest, ContainsBlockDecoration) { 139 auto structType = getFloatStructType(); 140 addGlobalVar(structType, "var0"); 141 142 ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary))); 143 144 auto hasBlockDecoration = [](spirv::Opcode opcode, 145 ArrayRef<uint32_t> operands) { 146 return opcode == spirv::Opcode::OpDecorate && operands.size() == 2 && 147 operands[1] == static_cast<uint32_t>(spirv::Decoration::Block); 148 }; 149 EXPECT_TRUE(scanInstruction(hasBlockDecoration)); 150 } 151 152 TEST_F(SerializationTest, ContainsNoDuplicatedBlockDecoration) { 153 auto structType = getFloatStructType(); 154 // Two global variables using the same type should not decorate the type with 155 // duplicated `Block` decorations. 156 addGlobalVar(structType, "var0"); 157 addGlobalVar(structType, "var1"); 158 159 ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary))); 160 161 unsigned count = 0; 162 auto countBlockDecoration = [&count](spirv::Opcode opcode, 163 ArrayRef<uint32_t> operands) { 164 if (opcode == spirv::Opcode::OpDecorate && operands.size() == 2 && 165 operands[1] == static_cast<uint32_t>(spirv::Decoration::Block)) 166 ++count; 167 return false; 168 }; 169 ASSERT_FALSE(scanInstruction(countBlockDecoration)); 170 EXPECT_EQ(count, 1u); 171 } 172 173 TEST_F(SerializationTest, SignlessVsSignedIntegerConstantBitExtension) { 174 175 auto signlessInt16Type = 176 IntegerType::get(&context, 16, IntegerType::Signless); 177 auto signedInt16Type = IntegerType::get(&context, 16, IntegerType::Signed); 178 // Check the bit extension of same value under different signedness semantics. 179 APInt signlessIntConstVal(signlessInt16Type.getWidth(), 0xffff, 180 signlessInt16Type.getSignedness()); 181 APInt signedIntConstVal(signedInt16Type.getWidth(), -1, 182 signedInt16Type.getSignedness()); 183 184 addConstInt(signlessInt16Type, signlessIntConstVal); 185 addConstInt(signedInt16Type, signedIntConstVal); 186 ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary))); 187 188 auto hasSignlessVal = [&](spirv::Opcode opcode, ArrayRef<uint32_t> operands) { 189 return opcode == spirv::Opcode::OpConstant && operands.size() == 3 && 190 operands[2] == 65535; 191 }; 192 EXPECT_TRUE(scanInstruction(hasSignlessVal)); 193 194 auto hasSignedVal = [&](spirv::Opcode opcode, ArrayRef<uint32_t> operands) { 195 return opcode == spirv::Opcode::OpConstant && operands.size() == 3 && 196 operands[2] == 4294967295; 197 }; 198 EXPECT_TRUE(scanInstruction(hasSignedVal)); 199 } 200 201 TEST_F(SerializationTest, ContainsSymbolName) { 202 auto structType = getFloatStructType(); 203 addGlobalVar(structType, "var0"); 204 205 spirv::SerializationOptions options; 206 options.emitSymbolName = true; 207 ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary, options))); 208 209 auto hasVarName = [](spirv::Opcode opcode, ArrayRef<uint32_t> operands) { 210 unsigned index = 1; // Skip the result <id> 211 return opcode == spirv::Opcode::OpName && 212 spirv::decodeStringLiteral(operands, index) == "var0"; 213 }; 214 EXPECT_TRUE(scanInstruction(hasVarName)); 215 } 216 217 TEST_F(SerializationTest, DoesNotContainSymbolName) { 218 auto structType = getFloatStructType(); 219 addGlobalVar(structType, "var0"); 220 221 spirv::SerializationOptions options; 222 options.emitSymbolName = false; 223 ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary, options))); 224 225 auto hasVarName = [](spirv::Opcode opcode, ArrayRef<uint32_t> operands) { 226 unsigned index = 1; // Skip the result <id> 227 return opcode == spirv::Opcode::OpName && 228 spirv::decodeStringLiteral(operands, index) == "var0"; 229 }; 230 EXPECT_FALSE(scanInstruction(hasVarName)); 231 } 232