1 //===- Serializer.h - MLIR SPIR-V Serializer ------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file declares the MLIR SPIR-V module to SPIR-V binary serializer. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_LIB_TARGET_SPIRV_SERIALIZATION_SERIALIZER_H 14 #define MLIR_LIB_TARGET_SPIRV_SERIALIZATION_SERIALIZER_H 15 16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 17 #include "mlir/IR/Builders.h" 18 #include "mlir/Target/SPIRV/Serialization.h" 19 #include "llvm/ADT/SetVector.h" 20 #include "llvm/ADT/SmallVector.h" 21 #include "llvm/Support/raw_ostream.h" 22 23 namespace mlir { 24 namespace spirv { 25 26 void encodeInstructionInto(SmallVectorImpl<uint32_t> &binary, spirv::Opcode op, 27 ArrayRef<uint32_t> operands); 28 29 /// A SPIR-V module serializer. 30 /// 31 /// A SPIR-V binary module is a single linear stream of instructions; each 32 /// instruction is composed of 32-bit words with the layout: 33 /// 34 /// | <word-count>|<opcode> | <operand> | <operand> | ... | 35 /// | <------ word -------> | <-- word --> | <-- word --> | ... | 36 /// 37 /// For the first word, the 16 high-order bits are the word count of the 38 /// instruction, the 16 low-order bits are the opcode enumerant. The 39 /// instructions then belong to different sections, which must be laid out in 40 /// the particular order as specified in "2.4 Logical Layout of a Module" of 41 /// the SPIR-V spec. 42 class Serializer { 43 public: 44 /// Creates a serializer for the given SPIR-V `module`. 45 explicit Serializer(spirv::ModuleOp module, 46 const SerializationOptions &options); 47 48 /// Serializes the remembered SPIR-V module. 49 LogicalResult serialize(); 50 51 /// Collects the final SPIR-V `binary`. 52 void collect(SmallVectorImpl<uint32_t> &binary); 53 54 #ifndef NDEBUG 55 /// (For debugging) prints each value and its corresponding result <id>. 56 void printValueIDMap(raw_ostream &os); 57 #endif 58 59 private: 60 // Note that there are two main categories of methods in this class: 61 // * process*() methods are meant to fully serialize a SPIR-V module entity 62 // (header, type, op, etc.). They update internal vectors containing 63 // different binary sections. They are not meant to be called except the 64 // top-level serialization loop. 65 // * prepare*() methods are meant to be helpers that prepare for serializing 66 // certain entity. They may or may not update internal vectors containing 67 // different binary sections. They are meant to be called among themselves 68 // or by other process*() methods for subtasks. 69 70 //===--------------------------------------------------------------------===// 71 // <id> 72 //===--------------------------------------------------------------------===// 73 74 // Note that it is illegal to use id <0> in SPIR-V binary module. Various 75 // methods in this class, if using SPIR-V word (uint32_t) as interface, 76 // check or return id <0> to indicate error in processing. 77 78 /// Consumes the next unused <id>. This method will never return 0. getNextID()79 uint32_t getNextID() { return nextID++; } 80 81 //===--------------------------------------------------------------------===// 82 // Module structure 83 //===--------------------------------------------------------------------===// 84 getSpecConstID(StringRef constName)85 uint32_t getSpecConstID(StringRef constName) const { 86 return specConstIDMap.lookup(constName); 87 } 88 getVariableID(StringRef varName)89 uint32_t getVariableID(StringRef varName) const { 90 return globalVarIDMap.lookup(varName); 91 } 92 getFunctionID(StringRef fnName)93 uint32_t getFunctionID(StringRef fnName) const { 94 return funcIDMap.lookup(fnName); 95 } 96 97 /// Gets the <id> for the function with the given name. Assigns the next 98 /// available <id> if the function haven't been deserialized. 99 uint32_t getOrCreateFunctionID(StringRef fnName); 100 101 void processCapability(); 102 103 void processDebugInfo(); 104 105 void processExtension(); 106 107 void processMemoryModel(); 108 109 LogicalResult processConstantOp(spirv::ConstantOp op); 110 111 LogicalResult processSpecConstantOp(spirv::SpecConstantOp op); 112 113 LogicalResult 114 processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op); 115 116 LogicalResult 117 processSpecConstantOperationOp(spirv::SpecConstantOperationOp op); 118 119 /// SPIR-V dialect supports OpUndef using spirv.UndefOp that produces a SSA 120 /// value to use with other operations. The SPIR-V spec recommends that 121 /// OpUndef be generated at module level. The serialization generates an 122 /// OpUndef for each type needed at module level. 123 LogicalResult processUndefOp(spirv::UndefOp op); 124 125 /// Emit OpName for the given `resultID`. 126 LogicalResult processName(uint32_t resultID, StringRef name); 127 128 /// Processes a SPIR-V function op. 129 LogicalResult processFuncOp(spirv::FuncOp op); 130 LogicalResult processFuncParameter(spirv::FuncOp op); 131 132 LogicalResult processVariableOp(spirv::VariableOp op); 133 134 /// Process a SPIR-V GlobalVariableOp 135 LogicalResult processGlobalVariableOp(spirv::GlobalVariableOp varOp); 136 137 /// Process attributes that translate to decorations on the result <id> 138 LogicalResult processDecorationAttr(Location loc, uint32_t resultID, 139 Decoration decoration, Attribute attr); 140 LogicalResult processDecoration(Location loc, uint32_t resultID, 141 NamedAttribute attr); 142 143 template <typename DType> processTypeDecoration(Location loc,DType type,uint32_t resultId)144 LogicalResult processTypeDecoration(Location loc, DType type, 145 uint32_t resultId) { 146 return emitError(loc, "unhandled decoration for type:") << type; 147 } 148 149 /// Process member decoration 150 LogicalResult processMemberDecoration( 151 uint32_t structID, 152 const spirv::StructType::MemberDecorationInfo &memberDecorationInfo); 153 154 //===--------------------------------------------------------------------===// 155 // Types 156 //===--------------------------------------------------------------------===// 157 getTypeID(Type type)158 uint32_t getTypeID(Type type) const { return typeIDMap.lookup(type); } 159 getVoidType()160 Type getVoidType() { return mlirBuilder.getNoneType(); } 161 isVoidType(Type type)162 bool isVoidType(Type type) const { return isa<NoneType>(type); } 163 164 /// Returns true if the given type is a pointer type to a struct in some 165 /// interface storage class. 166 bool isInterfaceStructPtrType(Type type) const; 167 168 /// Main dispatch method for serializing a type. The result <id> of the 169 /// serialized type will be returned as `typeID`. 170 LogicalResult processType(Location loc, Type type, uint32_t &typeID); 171 LogicalResult processTypeImpl(Location loc, Type type, uint32_t &typeID, 172 SetVector<StringRef> &serializationCtx); 173 174 /// Method for preparing basic SPIR-V type serialization. Returns the type's 175 /// opcode and operands for the instruction via `typeEnum` and `operands`. 176 LogicalResult prepareBasicType(Location loc, Type type, uint32_t resultID, 177 spirv::Opcode &typeEnum, 178 SmallVectorImpl<uint32_t> &operands, 179 bool &deferSerialization, 180 SetVector<StringRef> &serializationCtx); 181 182 LogicalResult prepareFunctionType(Location loc, FunctionType type, 183 spirv::Opcode &typeEnum, 184 SmallVectorImpl<uint32_t> &operands); 185 186 //===--------------------------------------------------------------------===// 187 // Constant 188 //===--------------------------------------------------------------------===// 189 getConstantID(Attribute value)190 uint32_t getConstantID(Attribute value) const { 191 return constIDMap.lookup(value); 192 } 193 194 /// Main dispatch method for processing a constant with the given `constType` 195 /// and `valueAttr`. `constType` is needed here because we can interpret the 196 /// `valueAttr` as a different type than the type of `valueAttr` itself; for 197 /// example, ArrayAttr, whose type is NoneType, is used for spirv::ArrayType 198 /// constants. 199 uint32_t prepareConstant(Location loc, Type constType, Attribute valueAttr); 200 201 /// Prepares array attribute serialization. This method emits corresponding 202 /// OpConstant* and returns the result <id> associated with it. Returns 0 if 203 /// failed. 204 uint32_t prepareArrayConstant(Location loc, Type constType, ArrayAttr attr); 205 206 /// Prepares bool/int/float DenseElementsAttr serialization. This method 207 /// iterates the DenseElementsAttr to construct the constant array, and 208 /// returns the result <id> associated with it. Returns 0 if failed. Note 209 /// that the size of `index` must match the rank. 210 /// TODO: Consider to enhance splat elements cases. For splat cases, 211 /// we don't need to loop over all elements, especially when the splat value 212 /// is zero. We can use OpConstantNull when the value is zero. 213 uint32_t prepareDenseElementsConstant(Location loc, Type constType, 214 DenseElementsAttr valueAttr, int dim, 215 MutableArrayRef<uint64_t> index); 216 217 /// Prepares scalar attribute serialization. This method emits corresponding 218 /// OpConstant* and returns the result <id> associated with it. Returns 0 if 219 /// the attribute is not for a scalar bool/integer/float value. If `isSpec` is 220 /// true, then the constant will be serialized as a specialization constant. 221 uint32_t prepareConstantScalar(Location loc, Attribute valueAttr, 222 bool isSpec = false); 223 224 uint32_t prepareConstantBool(Location loc, BoolAttr boolAttr, 225 bool isSpec = false); 226 227 uint32_t prepareConstantInt(Location loc, IntegerAttr intAttr, 228 bool isSpec = false); 229 230 uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr, 231 bool isSpec = false); 232 233 //===--------------------------------------------------------------------===// 234 // Control flow 235 //===--------------------------------------------------------------------===// 236 237 /// Returns the result <id> for the given block. getBlockID(Block * block)238 uint32_t getBlockID(Block *block) const { return blockIDMap.lookup(block); } 239 240 /// Returns the result <id> for the given block. If no <id> has been assigned, 241 /// assigns the next available <id> 242 uint32_t getOrCreateBlockID(Block *block); 243 244 #ifndef NDEBUG 245 /// (For debugging) prints the block with its result <id>. 246 void printBlock(Block *block, raw_ostream &os); 247 #endif 248 249 /// Processes the given `block` and emits SPIR-V instructions for all ops 250 /// inside. Does not emit OpLabel for this block if `omitLabel` is true. 251 /// `emitMerge` is a callback that will be invoked before handling the 252 /// terminator op to inject the Op*Merge instruction if this is a SPIR-V 253 /// selection/loop header block. 254 LogicalResult processBlock(Block *block, bool omitLabel = false, 255 function_ref<LogicalResult()> emitMerge = nullptr); 256 257 /// Emits OpPhi instructions for the given block if it has block arguments. 258 LogicalResult emitPhiForBlockArguments(Block *block); 259 260 LogicalResult processSelectionOp(spirv::SelectionOp selectionOp); 261 262 LogicalResult processLoopOp(spirv::LoopOp loopOp); 263 264 LogicalResult processBranchConditionalOp(spirv::BranchConditionalOp); 265 266 LogicalResult processBranchOp(spirv::BranchOp branchOp); 267 268 //===--------------------------------------------------------------------===// 269 // Operations 270 //===--------------------------------------------------------------------===// 271 272 LogicalResult encodeExtensionInstruction(Operation *op, 273 StringRef extensionSetName, 274 uint32_t opcode, 275 ArrayRef<uint32_t> operands); 276 getValueID(Value val)277 uint32_t getValueID(Value val) const { return valueIDMap.lookup(val); } 278 279 LogicalResult processAddressOfOp(spirv::AddressOfOp addressOfOp); 280 281 LogicalResult processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp); 282 283 /// Main dispatch method for serializing an operation. 284 LogicalResult processOperation(Operation *op); 285 286 /// Serializes an operation `op` as core instruction with `opcode` if 287 /// `extInstSet` is empty. Otherwise serializes it as an extended instruction 288 /// with `opcode` from `extInstSet`. 289 /// This method is a generic one for dispatching any SPIR-V ops that has no 290 /// variadic operands and attributes in TableGen definitions. 291 LogicalResult processOpWithoutGrammarAttr(Operation *op, StringRef extInstSet, 292 uint32_t opcode); 293 294 /// Dispatches to the serialization function for an operation in SPIR-V 295 /// dialect that is a mirror of an instruction in the SPIR-V spec. This is 296 /// auto-generated from ODS. Dispatch is handled for all operations in SPIR-V 297 /// dialect that have hasOpcode == 1. 298 LogicalResult dispatchToAutogenSerialization(Operation *op); 299 300 /// Serializes an operation in the SPIR-V dialect that is a mirror of an 301 /// instruction in the SPIR-V spec. This is auto generated if hasOpcode == 1 302 /// and autogenSerialization == 1 in ODS. 303 template <typename OpTy> processOp(OpTy op)304 LogicalResult processOp(OpTy op) { 305 return op.emitError("unsupported op serialization"); 306 } 307 308 //===--------------------------------------------------------------------===// 309 // Utilities 310 //===--------------------------------------------------------------------===// 311 312 /// Emits an OpDecorate instruction to decorate the given `target` with the 313 /// given `decoration`. 314 LogicalResult emitDecoration(uint32_t target, spirv::Decoration decoration, 315 ArrayRef<uint32_t> params = {}); 316 317 /// Emits an OpLine instruction with the given `loc` location information into 318 /// the given `binary` vector. 319 LogicalResult emitDebugLine(SmallVectorImpl<uint32_t> &binary, Location loc); 320 321 private: 322 /// The SPIR-V module to be serialized. 323 spirv::ModuleOp module; 324 325 /// An MLIR builder for getting MLIR constructs. 326 mlir::Builder mlirBuilder; 327 328 /// Serialization options. 329 SerializationOptions options; 330 331 /// A flag which indicates if the last processed instruction was a merge 332 /// instruction. 333 /// According to SPIR-V spec: "If a branch merge instruction is used, the last 334 /// OpLine in the block must be before its merge instruction". 335 bool lastProcessedWasMergeInst = false; 336 337 /// The <id> of the OpString instruction, which specifies a file name, for 338 /// use by other debug instructions. 339 uint32_t fileID = 0; 340 341 /// The next available result <id>. 342 uint32_t nextID = 1; 343 344 // The following are for different SPIR-V instruction sections. They follow 345 // the logical layout of a SPIR-V module. 346 347 SmallVector<uint32_t, 4> capabilities; 348 SmallVector<uint32_t, 0> extensions; 349 SmallVector<uint32_t, 0> extendedSets; 350 SmallVector<uint32_t, 3> memoryModel; 351 SmallVector<uint32_t, 0> entryPoints; 352 SmallVector<uint32_t, 4> executionModes; 353 SmallVector<uint32_t, 0> debug; 354 SmallVector<uint32_t, 0> names; 355 SmallVector<uint32_t, 0> decorations; 356 SmallVector<uint32_t, 0> typesGlobalValues; 357 SmallVector<uint32_t, 0> functions; 358 359 /// Recursive struct references are serialized as OpTypePointer instructions 360 /// to the recursive struct type. However, the OpTypePointer instruction 361 /// cannot be emitted before the recursive struct's OpTypeStruct. 362 /// RecursiveStructPointerInfo stores the data needed to emit such 363 /// OpTypePointer instructions after forward references to such types. 364 struct RecursiveStructPointerInfo { 365 uint32_t pointerTypeID; 366 spirv::StorageClass storageClass; 367 }; 368 369 // Maps spirv::StructType to its recursive reference member info. 370 DenseMap<Type, SmallVector<RecursiveStructPointerInfo, 0>> 371 recursiveStructInfos; 372 373 /// `functionHeader` contains all the instructions that must be in the first 374 /// block in the function, and `functionBody` contains the rest. After 375 /// processing FuncOp, the encoded instructions of a function are appended to 376 /// `functions`. An example of instructions in `functionHeader` in order: 377 /// OpFunction ... 378 /// OpFunctionParameter ... 379 /// OpFunctionParameter ... 380 /// OpLabel ... 381 /// OpVariable ... 382 /// OpVariable ... 383 SmallVector<uint32_t, 0> functionHeader; 384 SmallVector<uint32_t, 0> functionBody; 385 386 /// Map from type used in SPIR-V module to their <id>s. 387 DenseMap<Type, uint32_t> typeIDMap; 388 389 /// Map from constant values to their <id>s. 390 DenseMap<Attribute, uint32_t> constIDMap; 391 392 /// Map from specialization constant names to their <id>s. 393 llvm::StringMap<uint32_t> specConstIDMap; 394 395 /// Map from GlobalVariableOps name to <id>s. 396 llvm::StringMap<uint32_t> globalVarIDMap; 397 398 /// Map from FuncOps name to <id>s. 399 llvm::StringMap<uint32_t> funcIDMap; 400 401 /// Map from blocks to their <id>s. 402 DenseMap<Block *, uint32_t> blockIDMap; 403 404 /// Map from the Type to the <id> that represents undef value of that type. 405 DenseMap<Type, uint32_t> undefValIDMap; 406 407 /// Map from results of normal operations to their <id>s. 408 DenseMap<Value, uint32_t> valueIDMap; 409 410 /// Map from extended instruction set name to <id>s. 411 llvm::StringMap<uint32_t> extendedInstSetIDMap; 412 413 /// Map from values used in OpPhi instructions to their offset in the 414 /// `functions` section. 415 /// 416 /// When processing a block with arguments, we need to emit OpPhi 417 /// instructions to record the predecessor block <id>s and the values they 418 /// send to the block in question. But it's not guaranteed all values are 419 /// visited and thus assigned result <id>s. So we need this list to capture 420 /// the offsets into `functions` where a value is used so that we can fix it 421 /// up later after processing all the blocks in a function. 422 /// 423 /// More concretely, say if we are visiting the following blocks: 424 /// 425 /// ```mlir 426 /// ^phi(%arg0: i32): 427 /// ... 428 /// ^parent1: 429 /// ... 430 /// spirv.Branch ^phi(%val0: i32) 431 /// ^parent2: 432 /// ... 433 /// spirv.Branch ^phi(%val1: i32) 434 /// ``` 435 /// 436 /// When we are serializing the `^phi` block, we need to emit at the beginning 437 /// of the block OpPhi instructions which has the following parameters: 438 /// 439 /// OpPhi id-for-i32 id-for-%arg0 id-for-%val0 id-for-^parent1 440 /// id-for-%val1 id-for-^parent2 441 /// 442 /// But we don't know the <id> for %val0 and %val1 yet. One way is to visit 443 /// all the blocks twice and use the first visit to assign an <id> to each 444 /// value. But it's paying the overheads just for OpPhi emission. Instead, 445 /// we still visit the blocks once for emission. When we emit the OpPhi 446 /// instructions, we use 0 as a placeholder for the <id>s for %val0 and %val1. 447 /// At the same time, we record their offsets in the emitted binary (which is 448 /// placed inside `functions`) here. And then after emitting all blocks, we 449 /// replace the dummy <id> 0 with the real result <id> by overwriting 450 /// `functions[offset]`. 451 DenseMap<Value, SmallVector<size_t, 1>> deferredPhiValues; 452 }; 453 } // namespace spirv 454 } // namespace mlir 455 456 #endif // MLIR_LIB_TARGET_SPIRV_SERIALIZATION_SERIALIZER_H 457