xref: /llvm-project/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp (revision e692af85966903614d470a7742ed89d124baf1a6)
1 //===- SerializationTest.cpp - SPIR-V Serialization Tests -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains corner case tests for the SPIR-V serializer that are not
10 // covered by normal serialization and deserialization roundtripping.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Target/SPIRV/Serialization.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
19 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/Location.h"
22 #include "mlir/IR/MLIRContext.h"
23 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
24 #include "llvm/ADT/DenseSet.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/Sequence.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/ADT/StringRef.h"
29 #include "gmock/gmock.h"
30 
31 using namespace mlir;
32 
33 //===----------------------------------------------------------------------===//
34 // Test Fixture
35 //===----------------------------------------------------------------------===//
36 
37 class SerializationTest : public ::testing::Test {
38 protected:
39   SerializationTest() {
40     context.getOrLoadDialect<mlir::spirv::SPIRVDialect>();
41     initModuleOp();
42   }
43 
44   /// Initializes an empty SPIR-V module op.
45   void initModuleOp() {
46     OpBuilder builder(&context);
47     OperationState state(UnknownLoc::get(&context),
48                          spirv::ModuleOp::getOperationName());
49     state.addAttribute("addressing_model",
50                        builder.getAttr<spirv::AddressingModelAttr>(
51                            spirv::AddressingModel::Logical));
52     state.addAttribute("memory_model", builder.getAttr<spirv::MemoryModelAttr>(
53                                            spirv::MemoryModel::GLSL450));
54     state.addAttribute("vce_triple",
55                        spirv::VerCapExtAttr::get(
56                            spirv::Version::V_1_0, ArrayRef<spirv::Capability>(),
57                            ArrayRef<spirv::Extension>(), &context));
58     spirv::ModuleOp::build(builder, state);
59     module = cast<spirv::ModuleOp>(Operation::create(state));
60   }
61 
62   /// Gets the `struct { float }` type.
63   spirv::StructType getFloatStructType() {
64     OpBuilder builder(module->getRegion());
65     llvm::SmallVector<Type, 1> elementTypes{builder.getF32Type()};
66     llvm::SmallVector<spirv::StructType::OffsetInfo, 1> offsetInfo{0};
67     return spirv::StructType::get(elementTypes, offsetInfo);
68   }
69 
70   /// Inserts a global variable of the given `type` and `name`.
71   spirv::GlobalVariableOp addGlobalVar(Type type, llvm::StringRef name) {
72     OpBuilder builder(module->getRegion());
73     auto ptrType = spirv::PointerType::get(type, spirv::StorageClass::Uniform);
74     return builder.create<spirv::GlobalVariableOp>(
75         UnknownLoc::get(&context), TypeAttr::get(ptrType),
76         builder.getStringAttr(name), nullptr);
77   }
78 
79   // Inserts an Integer or a Vector of Integers constant of value 'val'.
80   spirv::ConstantOp addConstInt(Type type, const APInt &val) {
81     OpBuilder builder(module->getRegion());
82     auto loc = UnknownLoc::get(&context);
83 
84     if (auto intType = dyn_cast<IntegerType>(type)) {
85       return builder.create<spirv::ConstantOp>(
86           loc, type, builder.getIntegerAttr(type, val));
87     }
88     if (auto vectorType = dyn_cast<VectorType>(type)) {
89       Type elemType = vectorType.getElementType();
90       if (auto intType = dyn_cast<IntegerType>(elemType)) {
91         return builder.create<spirv::ConstantOp>(
92             loc, type,
93             DenseElementsAttr::get(vectorType,
94                                    IntegerAttr::get(elemType, val).getValue()));
95       }
96     }
97     llvm_unreachable("unimplemented types for AddConstInt()");
98   }
99 
100   /// Handles a SPIR-V instruction with the given `opcode` and `operand`.
101   /// Returns true to interrupt.
102   using HandleFn = llvm::function_ref<bool(spirv::Opcode opcode,
103                                            ArrayRef<uint32_t> operands)>;
104 
105   /// Returns true if we can find a matching instruction in the SPIR-V blob.
106   bool scanInstruction(HandleFn handleFn) {
107     auto binarySize = binary.size();
108     auto *begin = binary.begin();
109     auto currOffset = spirv::kHeaderWordCount;
110 
111     while (currOffset < binarySize) {
112       auto wordCount = binary[currOffset] >> 16;
113       if (!wordCount || (currOffset + wordCount > binarySize))
114         return false;
115 
116       spirv::Opcode opcode =
117           static_cast<spirv::Opcode>(binary[currOffset] & 0xffff);
118       llvm::ArrayRef<uint32_t> operands(begin + currOffset + 1,
119                                         begin + currOffset + wordCount);
120       if (handleFn(opcode, operands))
121         return true;
122 
123       currOffset += wordCount;
124     }
125     return false;
126   }
127 
128 protected:
129   MLIRContext context;
130   OwningOpRef<spirv::ModuleOp> module;
131   SmallVector<uint32_t, 0> binary;
132 };
133 
134 //===----------------------------------------------------------------------===//
135 // Block decoration
136 //===----------------------------------------------------------------------===//
137 
138 TEST_F(SerializationTest, ContainsBlockDecoration) {
139   auto structType = getFloatStructType();
140   addGlobalVar(structType, "var0");
141 
142   ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));
143 
144   auto hasBlockDecoration = [](spirv::Opcode opcode,
145                                ArrayRef<uint32_t> operands) {
146     return opcode == spirv::Opcode::OpDecorate && operands.size() == 2 &&
147            operands[1] == static_cast<uint32_t>(spirv::Decoration::Block);
148   };
149   EXPECT_TRUE(scanInstruction(hasBlockDecoration));
150 }
151 
152 TEST_F(SerializationTest, ContainsNoDuplicatedBlockDecoration) {
153   auto structType = getFloatStructType();
154   // Two global variables using the same type should not decorate the type with
155   // duplicated `Block` decorations.
156   addGlobalVar(structType, "var0");
157   addGlobalVar(structType, "var1");
158 
159   ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));
160 
161   unsigned count = 0;
162   auto countBlockDecoration = [&count](spirv::Opcode opcode,
163                                        ArrayRef<uint32_t> operands) {
164     if (opcode == spirv::Opcode::OpDecorate && operands.size() == 2 &&
165         operands[1] == static_cast<uint32_t>(spirv::Decoration::Block))
166       ++count;
167     return false;
168   };
169   ASSERT_FALSE(scanInstruction(countBlockDecoration));
170   EXPECT_EQ(count, 1u);
171 }
172 
173 TEST_F(SerializationTest, SignlessVsSignedIntegerConstantBitExtension) {
174 
175   auto signlessInt16Type =
176       IntegerType::get(&context, 16, IntegerType::Signless);
177   auto signedInt16Type = IntegerType::get(&context, 16, IntegerType::Signed);
178   // Check the bit extension of same value under different signedness semantics.
179   APInt signlessIntConstVal(signlessInt16Type.getWidth(), 0xffff,
180                             signlessInt16Type.getSignedness());
181   APInt signedIntConstVal(signedInt16Type.getWidth(), -1,
182                           signedInt16Type.getSignedness());
183 
184   addConstInt(signlessInt16Type, signlessIntConstVal);
185   addConstInt(signedInt16Type, signedIntConstVal);
186   ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));
187 
188   auto hasSignlessVal = [&](spirv::Opcode opcode, ArrayRef<uint32_t> operands) {
189     return opcode == spirv::Opcode::OpConstant && operands.size() == 3 &&
190            operands[2] == 65535;
191   };
192   EXPECT_TRUE(scanInstruction(hasSignlessVal));
193 
194   auto hasSignedVal = [&](spirv::Opcode opcode, ArrayRef<uint32_t> operands) {
195     return opcode == spirv::Opcode::OpConstant && operands.size() == 3 &&
196            operands[2] == 4294967295;
197   };
198   EXPECT_TRUE(scanInstruction(hasSignedVal));
199 }
200 
201 TEST_F(SerializationTest, ContainsSymbolName) {
202   auto structType = getFloatStructType();
203   addGlobalVar(structType, "var0");
204 
205   spirv::SerializationOptions options;
206   options.emitSymbolName = true;
207   ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary, options)));
208 
209   auto hasVarName = [](spirv::Opcode opcode, ArrayRef<uint32_t> operands) {
210     unsigned index = 1; // Skip the result <id>
211     return opcode == spirv::Opcode::OpName &&
212            spirv::decodeStringLiteral(operands, index) == "var0";
213   };
214   EXPECT_TRUE(scanInstruction(hasVarName));
215 }
216 
217 TEST_F(SerializationTest, DoesNotContainSymbolName) {
218   auto structType = getFloatStructType();
219   addGlobalVar(structType, "var0");
220 
221   spirv::SerializationOptions options;
222   options.emitSymbolName = false;
223   ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary, options)));
224 
225   auto hasVarName = [](spirv::Opcode opcode, ArrayRef<uint32_t> operands) {
226     unsigned index = 1; // Skip the result <id>
227     return opcode == spirv::Opcode::OpName &&
228            spirv::decodeStringLiteral(operands, index) == "var0";
229   };
230   EXPECT_FALSE(scanInstruction(hasVarName));
231 }
232