xref: /llvm-project/mlir/lib/Target/SPIRV/Serialization/Serializer.h (revision 747d8fb01c2417546ebaa774874ff8c3005e058a)
1 //===- Serializer.h - MLIR SPIR-V Serializer ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file declares the MLIR SPIR-V module to SPIR-V binary serializer.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_LIB_TARGET_SPIRV_SERIALIZATION_SERIALIZER_H
14 #define MLIR_LIB_TARGET_SPIRV_SERIALIZATION_SERIALIZER_H
15 
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/Target/SPIRV/Serialization.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/Support/raw_ostream.h"
22 
23 namespace mlir {
24 namespace spirv {
25 
26 void encodeInstructionInto(SmallVectorImpl<uint32_t> &binary, spirv::Opcode op,
27                            ArrayRef<uint32_t> operands);
28 
29 /// A SPIR-V module serializer.
30 ///
31 /// A SPIR-V binary module is a single linear stream of instructions; each
32 /// instruction is composed of 32-bit words with the layout:
33 ///
34 ///   | <word-count>|<opcode> |  <operand>   |  <operand>   | ... |
35 ///   | <------ word -------> | <-- word --> | <-- word --> | ... |
36 ///
37 /// For the first word, the 16 high-order bits are the word count of the
38 /// instruction, the 16 low-order bits are the opcode enumerant. The
39 /// instructions then belong to different sections, which must be laid out in
40 /// the particular order as specified in "2.4 Logical Layout of a Module" of
41 /// the SPIR-V spec.
42 class Serializer {
43 public:
44   /// Creates a serializer for the given SPIR-V `module`.
45   explicit Serializer(spirv::ModuleOp module,
46                       const SerializationOptions &options);
47 
48   /// Serializes the remembered SPIR-V module.
49   LogicalResult serialize();
50 
51   /// Collects the final SPIR-V `binary`.
52   void collect(SmallVectorImpl<uint32_t> &binary);
53 
54 #ifndef NDEBUG
55   /// (For debugging) prints each value and its corresponding result <id>.
56   void printValueIDMap(raw_ostream &os);
57 #endif
58 
59 private:
60   // Note that there are two main categories of methods in this class:
61   // * process*() methods are meant to fully serialize a SPIR-V module entity
62   //   (header, type, op, etc.). They update internal vectors containing
63   //   different binary sections. They are not meant to be called except from
64   //   the top-level serialization loop.
65   // * prepare*() methods are meant to be helpers that prepare for serializing
66   //   a certain entity. They may or may not update internal vectors containing
67   //   different binary sections. They are meant to be called among themselves
68   //   or by other process*() methods for subtasks.
69 
70   //===--------------------------------------------------------------------===//
71   // <id>
72   //===--------------------------------------------------------------------===//
73 
74   // Note that it is illegal to use id <0> in SPIR-V binary module. Various
75   // methods in this class, if using SPIR-V word (uint32_t) as interface,
76   // check or return id <0> to indicate error in processing.
77 
78   /// Consumes the next unused <id>. This method will never return 0.
getNextID()79   uint32_t getNextID() { return nextID++; }
80 
81   //===--------------------------------------------------------------------===//
82   // Module structure
83   //===--------------------------------------------------------------------===//
84 
getSpecConstID(StringRef constName)85   uint32_t getSpecConstID(StringRef constName) const {
86     return specConstIDMap.lookup(constName);
87   }
88 
getVariableID(StringRef varName)89   uint32_t getVariableID(StringRef varName) const {
90     return globalVarIDMap.lookup(varName);
91   }
92 
getFunctionID(StringRef fnName)93   uint32_t getFunctionID(StringRef fnName) const {
94     return funcIDMap.lookup(fnName);
95   }
96 
97   /// Gets the <id> for the function with the given name. Assigns the next
98   /// available <id> if the function has not been assigned one yet.
99   uint32_t getOrCreateFunctionID(StringRef fnName);
100 
101   void processCapability();
102 
103   void processDebugInfo();
104 
105   void processExtension();
106 
107   void processMemoryModel();
108 
109   LogicalResult processConstantOp(spirv::ConstantOp op);
110 
111   LogicalResult processSpecConstantOp(spirv::SpecConstantOp op);
112 
113   LogicalResult
114   processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op);
115 
116   LogicalResult
117   processSpecConstantOperationOp(spirv::SpecConstantOperationOp op);
118 
119   /// SPIR-V dialect supports OpUndef using spirv.UndefOp that produces an SSA
120   /// value to use with other operations. The SPIR-V spec recommends that
121   /// OpUndef be generated at module level. The serialization generates an
122   /// OpUndef for each type needed at module level.
123   LogicalResult processUndefOp(spirv::UndefOp op);
124 
125   /// Emit OpName for the given `resultID`.
126   LogicalResult processName(uint32_t resultID, StringRef name);
127 
128   /// Processes a SPIR-V function op.
129   LogicalResult processFuncOp(spirv::FuncOp op);
130   LogicalResult processFuncParameter(spirv::FuncOp op);
131 
132   LogicalResult processVariableOp(spirv::VariableOp op);
133 
134   /// Process a SPIR-V GlobalVariableOp
135   LogicalResult processGlobalVariableOp(spirv::GlobalVariableOp varOp);
136 
137   /// Process attributes that translate to decorations on the result <id>
138   LogicalResult processDecorationAttr(Location loc, uint32_t resultID,
139                                       Decoration decoration, Attribute attr);
140   LogicalResult processDecoration(Location loc, uint32_t resultID,
141                                   NamedAttribute attr);
142 
143   template <typename DType>
processTypeDecoration(Location loc,DType type,uint32_t resultId)144   LogicalResult processTypeDecoration(Location loc, DType type,
145                                       uint32_t resultId) {
146     return emitError(loc, "unhandled decoration for type:") << type;
147   }
148 
149   /// Process member decoration
150   LogicalResult processMemberDecoration(
151       uint32_t structID,
152       const spirv::StructType::MemberDecorationInfo &memberDecorationInfo);
153 
154   //===--------------------------------------------------------------------===//
155   // Types
156   //===--------------------------------------------------------------------===//
157 
getTypeID(Type type)158   uint32_t getTypeID(Type type) const { return typeIDMap.lookup(type); }
159 
getVoidType()160   Type getVoidType() { return mlirBuilder.getNoneType(); }
161 
isVoidType(Type type)162   bool isVoidType(Type type) const { return isa<NoneType>(type); }
163 
164   /// Returns true if the given type is a pointer type to a struct in some
165   /// interface storage class.
166   bool isInterfaceStructPtrType(Type type) const;
167 
168   /// Main dispatch method for serializing a type. The result <id> of the
169   /// serialized type will be returned as `typeID`.
170   LogicalResult processType(Location loc, Type type, uint32_t &typeID);
171   LogicalResult processTypeImpl(Location loc, Type type, uint32_t &typeID,
172                                 SetVector<StringRef> &serializationCtx);
173 
174   /// Method for preparing basic SPIR-V type serialization. Returns the type's
175   /// opcode and operands for the instruction via `typeEnum` and `operands`.
176   LogicalResult prepareBasicType(Location loc, Type type, uint32_t resultID,
177                                  spirv::Opcode &typeEnum,
178                                  SmallVectorImpl<uint32_t> &operands,
179                                  bool &deferSerialization,
180                                  SetVector<StringRef> &serializationCtx);
181 
182   LogicalResult prepareFunctionType(Location loc, FunctionType type,
183                                     spirv::Opcode &typeEnum,
184                                     SmallVectorImpl<uint32_t> &operands);
185 
186   //===--------------------------------------------------------------------===//
187   // Constant
188   //===--------------------------------------------------------------------===//
189 
getConstantID(Attribute value)190   uint32_t getConstantID(Attribute value) const {
191     return constIDMap.lookup(value);
192   }
193 
194   /// Main dispatch method for processing a constant with the given `constType`
195   /// and `valueAttr`. `constType` is needed here because we can interpret the
196   /// `valueAttr` as a different type than the type of `valueAttr` itself; for
197   /// example, ArrayAttr, whose type is NoneType, is used for spirv::ArrayType
198   /// constants.
199   uint32_t prepareConstant(Location loc, Type constType, Attribute valueAttr);
200 
201   /// Prepares array attribute serialization. This method emits corresponding
202   /// OpConstant* and returns the result <id> associated with it. Returns 0 if
203   /// failed.
204   uint32_t prepareArrayConstant(Location loc, Type constType, ArrayAttr attr);
205 
206   /// Prepares bool/int/float DenseElementsAttr serialization. This method
207   /// iterates the DenseElementsAttr to construct the constant array, and
208   /// returns the result <id> associated with it. Returns 0 if failed. Note
209   /// that the size of `index` must match the rank.
210   /// TODO: Consider enhancing the splat-elements case. For splat cases,
211   /// we don't need to loop over all elements, especially when the splat value
212   /// is zero. We can use OpConstantNull when the value is zero.
213   uint32_t prepareDenseElementsConstant(Location loc, Type constType,
214                                         DenseElementsAttr valueAttr, int dim,
215                                         MutableArrayRef<uint64_t> index);
216 
217   /// Prepares scalar attribute serialization. This method emits corresponding
218   /// OpConstant* and returns the result <id> associated with it. Returns 0 if
219   /// the attribute is not for a scalar bool/integer/float value. If `isSpec` is
220   /// true, then the constant will be serialized as a specialization constant.
221   uint32_t prepareConstantScalar(Location loc, Attribute valueAttr,
222                                  bool isSpec = false);
223 
224   uint32_t prepareConstantBool(Location loc, BoolAttr boolAttr,
225                                bool isSpec = false);
226 
227   uint32_t prepareConstantInt(Location loc, IntegerAttr intAttr,
228                               bool isSpec = false);
229 
230   uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr,
231                              bool isSpec = false);
232 
233   //===--------------------------------------------------------------------===//
234   // Control flow
235   //===--------------------------------------------------------------------===//
236 
237   /// Returns the result <id> for the given block.
getBlockID(Block * block)238   uint32_t getBlockID(Block *block) const { return blockIDMap.lookup(block); }
239 
240   /// Returns the result <id> for the given block. If no <id> has been assigned,
241   /// assigns the next available <id>
242   uint32_t getOrCreateBlockID(Block *block);
243 
244 #ifndef NDEBUG
245   /// (For debugging) prints the block with its result <id>.
246   void printBlock(Block *block, raw_ostream &os);
247 #endif
248 
249   /// Processes the given `block` and emits SPIR-V instructions for all ops
250   /// inside. Does not emit OpLabel for this block if `omitLabel` is true.
251   /// `emitMerge` is a callback that will be invoked before handling the
252   /// terminator op to inject the Op*Merge instruction if this is a SPIR-V
253   /// selection/loop header block.
254   LogicalResult processBlock(Block *block, bool omitLabel = false,
255                              function_ref<LogicalResult()> emitMerge = nullptr);
256 
257   /// Emits OpPhi instructions for the given block if it has block arguments.
258   LogicalResult emitPhiForBlockArguments(Block *block);
259 
260   LogicalResult processSelectionOp(spirv::SelectionOp selectionOp);
261 
262   LogicalResult processLoopOp(spirv::LoopOp loopOp);
263 
264   LogicalResult processBranchConditionalOp(spirv::BranchConditionalOp);
265 
266   LogicalResult processBranchOp(spirv::BranchOp branchOp);
267 
268   //===--------------------------------------------------------------------===//
269   // Operations
270   //===--------------------------------------------------------------------===//
271 
272   LogicalResult encodeExtensionInstruction(Operation *op,
273                                            StringRef extensionSetName,
274                                            uint32_t opcode,
275                                            ArrayRef<uint32_t> operands);
276 
getValueID(Value val)277   uint32_t getValueID(Value val) const { return valueIDMap.lookup(val); }
278 
279   LogicalResult processAddressOfOp(spirv::AddressOfOp addressOfOp);
280 
281   LogicalResult processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp);
282 
283   /// Main dispatch method for serializing an operation.
284   LogicalResult processOperation(Operation *op);
285 
286   /// Serializes an operation `op` as core instruction with `opcode` if
287   /// `extInstSet` is empty. Otherwise serializes it as an extended instruction
288   /// with `opcode` from `extInstSet`.
289   /// This method is a generic one for dispatching any SPIR-V ops that have no
290   /// variadic operands and attributes in TableGen definitions.
291   LogicalResult processOpWithoutGrammarAttr(Operation *op, StringRef extInstSet,
292                                             uint32_t opcode);
293 
294   /// Dispatches to the serialization function for an operation in SPIR-V
295   /// dialect that is a mirror of an instruction in the SPIR-V spec. This is
296   /// auto-generated from ODS. Dispatch is handled for all operations in SPIR-V
297   /// dialect that have hasOpcode == 1.
298   LogicalResult dispatchToAutogenSerialization(Operation *op);
299 
300   /// Serializes an operation in the SPIR-V dialect that is a mirror of an
301   /// instruction in the SPIR-V spec. This is auto generated if hasOpcode == 1
302   /// and autogenSerialization == 1 in ODS.
303   template <typename OpTy>
processOp(OpTy op)304   LogicalResult processOp(OpTy op) {
305     return op.emitError("unsupported op serialization");
306   }
307 
308   //===--------------------------------------------------------------------===//
309   // Utilities
310   //===--------------------------------------------------------------------===//
311 
312   /// Emits an OpDecorate instruction to decorate the given `target` with the
313   /// given `decoration`.
314   LogicalResult emitDecoration(uint32_t target, spirv::Decoration decoration,
315                                ArrayRef<uint32_t> params = {});
316 
317   /// Emits an OpLine instruction with the given `loc` location information into
318   /// the given `binary` vector.
319   LogicalResult emitDebugLine(SmallVectorImpl<uint32_t> &binary, Location loc);
320 
321 private:
322   /// The SPIR-V module to be serialized.
323   spirv::ModuleOp module;
324 
325   /// An MLIR builder for getting MLIR constructs.
326   mlir::Builder mlirBuilder;
327 
328   /// Serialization options.
329   SerializationOptions options;
330 
331   /// A flag which indicates if the last processed instruction was a merge
332   /// instruction.
333   /// According to SPIR-V spec: "If a branch merge instruction is used, the last
334   /// OpLine in the block must be before its merge instruction".
335   bool lastProcessedWasMergeInst = false;
336 
337   /// The <id> of the OpString instruction, which specifies a file name, for
338   /// use by other debug instructions.
339   uint32_t fileID = 0;
340 
341   /// The next available result <id>.
342   uint32_t nextID = 1;
343 
344   // The following are for different SPIR-V instruction sections. They follow
345   // the logical layout of a SPIR-V module.
346 
347   SmallVector<uint32_t, 4> capabilities;
348   SmallVector<uint32_t, 0> extensions;
349   SmallVector<uint32_t, 0> extendedSets;
350   SmallVector<uint32_t, 3> memoryModel;
351   SmallVector<uint32_t, 0> entryPoints;
352   SmallVector<uint32_t, 4> executionModes;
353   SmallVector<uint32_t, 0> debug;
354   SmallVector<uint32_t, 0> names;
355   SmallVector<uint32_t, 0> decorations;
356   SmallVector<uint32_t, 0> typesGlobalValues;
357   SmallVector<uint32_t, 0> functions;
358 
359   /// Recursive struct references are serialized as OpTypePointer instructions
360   /// to the recursive struct type. However, the OpTypePointer instruction
361   /// cannot be emitted before the recursive struct's OpTypeStruct.
362   /// RecursiveStructPointerInfo stores the data needed to emit such
363   /// OpTypePointer instructions after forward references to such types.
364   struct RecursiveStructPointerInfo {
365     uint32_t pointerTypeID; // Presumably the <id> of the deferred OpTypePointer; confirm in the .cpp.
366     spirv::StorageClass storageClass; // Presumably that pointer type's storage class; confirm in the .cpp.
367   };
368 
369   // Maps spirv::StructType to its recursive reference member info.
370   DenseMap<Type, SmallVector<RecursiveStructPointerInfo, 0>>
371       recursiveStructInfos;
372 
373   /// `functionHeader` contains all the instructions that must be in the first
374   /// block in the function, and `functionBody` contains the rest. After
375   /// processing FuncOp, the encoded instructions of a function are appended to
376   /// `functions`. An example of instructions in `functionHeader` in order:
377   /// OpFunction ...
378   /// OpFunctionParameter ...
379   /// OpFunctionParameter ...
380   /// OpLabel ...
381   /// OpVariable ...
382   /// OpVariable ...
383   SmallVector<uint32_t, 0> functionHeader;
384   SmallVector<uint32_t, 0> functionBody;
385 
386   /// Map from type used in SPIR-V module to their <id>s.
387   DenseMap<Type, uint32_t> typeIDMap;
388 
389   /// Map from constant values to their <id>s.
390   DenseMap<Attribute, uint32_t> constIDMap;
391 
392   /// Map from specialization constant names to their <id>s.
393   llvm::StringMap<uint32_t> specConstIDMap;
394 
395   /// Map from GlobalVariableOps name to <id>s.
396   llvm::StringMap<uint32_t> globalVarIDMap;
397 
398   /// Map from FuncOps name to <id>s.
399   llvm::StringMap<uint32_t> funcIDMap;
400 
401   /// Map from blocks to their <id>s.
402   DenseMap<Block *, uint32_t> blockIDMap;
403 
404   /// Map from the Type to the <id> that represents undef value of that type.
405   DenseMap<Type, uint32_t> undefValIDMap;
406 
407   /// Map from results of normal operations to their <id>s.
408   DenseMap<Value, uint32_t> valueIDMap;
409 
410   /// Map from extended instruction set name to <id>s.
411   llvm::StringMap<uint32_t> extendedInstSetIDMap;
412 
413   /// Map from values used in OpPhi instructions to their offset in the
414   /// `functions` section.
415   ///
416   /// When processing a block with arguments, we need to emit OpPhi
417   /// instructions to record the predecessor block <id>s and the values they
418   /// send to the block in question. But it's not guaranteed all values are
419   /// visited and thus assigned result <id>s. So we need this list to capture
420   /// the offsets into `functions` where a value is used so that we can fix it
421   /// up later after processing all the blocks in a function.
422   ///
423   /// More concretely, say if we are visiting the following blocks:
424   ///
425   /// ```mlir
426   /// ^phi(%arg0: i32):
427   ///   ...
428   /// ^parent1:
429   ///   ...
430   ///   spirv.Branch ^phi(%val0: i32)
431   /// ^parent2:
432   ///   ...
433   ///   spirv.Branch ^phi(%val1: i32)
434   /// ```
435   ///
436   /// When we are serializing the `^phi` block, we need to emit at the beginning
437   /// of the block OpPhi instructions which have the following parameters:
438   ///
439   /// OpPhi id-for-i32 id-for-%arg0 id-for-%val0 id-for-^parent1
440   ///                               id-for-%val1 id-for-^parent2
441   ///
442   /// But we don't know the <id> for %val0 and %val1 yet. One way is to visit
443   /// all the blocks twice and use the first visit to assign an <id> to each
444   /// value. But that pays extra overhead just for OpPhi emission. Instead,
445   /// we still visit the blocks once for emission. When we emit the OpPhi
446   /// instructions, we use 0 as a placeholder for the <id>s for %val0 and %val1.
447   /// At the same time, we record their offsets in the emitted binary (which is
448   /// placed inside `functions`) here. And then after emitting all blocks, we
449   /// replace the dummy <id> 0 with the real result <id> by overwriting
450   /// `functions[offset]`.
451   DenseMap<Value, SmallVector<size_t, 1>> deferredPhiValues;
452 };
453 } // namespace spirv
454 } // namespace mlir
455 
456 #endif // MLIR_LIB_TARGET_SPIRV_SERIALIZATION_SERIALIZER_H
457