xref: /llvm-project/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp (revision 4f78f8519056953d26102c7426fbb028caf13bc9)
1 //===- Serializer.cpp - MLIR SPIR-V Serializer ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the MLIR SPIR-V module to SPIR-V binary serializer.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Serializer.h"
14 
15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/Sequence.h"
22 #include "llvm/ADT/SmallPtrSet.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/ADT/bit.h"
26 #include "llvm/Support/Debug.h"
27 #include <cstdint>
28 #include <optional>
29 
30 #define DEBUG_TYPE "spirv-serialization"
31 
32 using namespace mlir;
33 
34 /// Returns the merge block if the given `op` is a structured control flow op.
35 /// Otherwise returns nullptr.
36 static Block *getStructuredControlFlowOpMergeBlock(Operation *op) {
37   if (auto selectionOp = dyn_cast<spirv::SelectionOp>(op))
38     return selectionOp.getMergeBlock();
39   if (auto loopOp = dyn_cast<spirv::LoopOp>(op))
40     return loopOp.getMergeBlock();
41   return nullptr;
42 }
43 
44 /// Given a predecessor `block` for a block with arguments, returns the block
45 /// that should be used as the parent block for SPIR-V OpPhi instructions
46 /// corresponding to the block arguments.
47 static Block *getPhiIncomingBlock(Block *block) {
48   // If the predecessor block in question is the entry block for a
49   // spirv.mlir.loop, we jump to this spirv.mlir.loop from its enclosing block.
50   if (block->isEntryBlock()) {
51     if (auto loopOp = dyn_cast<spirv::LoopOp>(block->getParentOp())) {
52       // Then the incoming parent block for OpPhi should be the merge block of
53       // the structured control flow op before this loop.
54       Operation *op = loopOp.getOperation();
55       while ((op = op->getPrevNode()) != nullptr)
56         if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op))
57           return incomingBlock;
58       // Or the enclosing block itself if no structured control flow ops
59       // exists before this loop.
60       return loopOp->getBlock();
61     }
62   }
63 
64   // Otherwise, we jump from the given predecessor block. Try to see if there is
65   // a structured control flow op inside it.
66   for (Operation &op : llvm::reverse(block->getOperations())) {
67     if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(&op))
68       return incomingBlock;
69   }
70   return block;
71 }
72 
73 namespace mlir {
74 namespace spirv {
75 
76 /// Encodes an SPIR-V instruction with the given `opcode` and `operands` into
77 /// the given `binary` vector.
78 void encodeInstructionInto(SmallVectorImpl<uint32_t> &binary, spirv::Opcode op,
79                            ArrayRef<uint32_t> operands) {
80   uint32_t wordCount = 1 + operands.size();
81   binary.push_back(spirv::getPrefixedOpcode(wordCount, op));
82   binary.append(operands.begin(), operands.end());
83 }
84 
85 Serializer::Serializer(spirv::ModuleOp module,
86                        const SerializationOptions &options)
87     : module(module), mlirBuilder(module.getContext()), options(options) {}
88 
89 LogicalResult Serializer::serialize() {
90   LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n");
91 
92   if (failed(module.verifyInvariants()))
93     return failure();
94 
95   // TODO: handle the other sections
96   processCapability();
97   processExtension();
98   processMemoryModel();
99   processDebugInfo();
100 
101   // Iterate over the module body to serialize it. Assumptions are that there is
102   // only one basic block in the moduleOp
103   for (auto &op : *module.getBody()) {
104     if (failed(processOperation(&op))) {
105       return failure();
106     }
107   }
108 
109   LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n");
110   return success();
111 }
112 
113 void Serializer::collect(SmallVectorImpl<uint32_t> &binary) {
114   auto moduleSize = spirv::kHeaderWordCount + capabilities.size() +
115                     extensions.size() + extendedSets.size() +
116                     memoryModel.size() + entryPoints.size() +
117                     executionModes.size() + decorations.size() +
118                     typesGlobalValues.size() + functions.size();
119 
120   binary.clear();
121   binary.reserve(moduleSize);
122 
123   spirv::appendModuleHeader(binary, module.getVceTriple()->getVersion(),
124                             nextID);
125   binary.append(capabilities.begin(), capabilities.end());
126   binary.append(extensions.begin(), extensions.end());
127   binary.append(extendedSets.begin(), extendedSets.end());
128   binary.append(memoryModel.begin(), memoryModel.end());
129   binary.append(entryPoints.begin(), entryPoints.end());
130   binary.append(executionModes.begin(), executionModes.end());
131   binary.append(debug.begin(), debug.end());
132   binary.append(names.begin(), names.end());
133   binary.append(decorations.begin(), decorations.end());
134   binary.append(typesGlobalValues.begin(), typesGlobalValues.end());
135   binary.append(functions.begin(), functions.end());
136 }
137 
138 #ifndef NDEBUG
139 void Serializer::printValueIDMap(raw_ostream &os) {
140   os << "\n= Value <id> Map =\n\n";
141   for (auto valueIDPair : valueIDMap) {
142     Value val = valueIDPair.first;
143     os << "  " << val << " "
144        << "id = " << valueIDPair.second << ' ';
145     if (auto *op = val.getDefiningOp()) {
146       os << "from op '" << op->getName() << "'";
147     } else if (auto arg = dyn_cast<BlockArgument>(val)) {
148       Block *block = arg.getOwner();
149       os << "from argument of block " << block << ' ';
150       os << " in op '" << block->getParentOp()->getName() << "'";
151     }
152     os << '\n';
153   }
154 }
155 #endif
156 
157 //===----------------------------------------------------------------------===//
158 // Module structure
159 //===----------------------------------------------------------------------===//
160 
161 uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) {
162   auto funcID = funcIDMap.lookup(fnName);
163   if (!funcID) {
164     funcID = getNextID();
165     funcIDMap[fnName] = funcID;
166   }
167   return funcID;
168 }
169 
170 void Serializer::processCapability() {
171   for (auto cap : module.getVceTriple()->getCapabilities())
172     encodeInstructionInto(capabilities, spirv::Opcode::OpCapability,
173                           {static_cast<uint32_t>(cap)});
174 }
175 
176 void Serializer::processDebugInfo() {
177   if (!options.emitDebugInfo)
178     return;
179   auto fileLoc = dyn_cast<FileLineColLoc>(module.getLoc());
180   auto fileName = fileLoc ? fileLoc.getFilename().strref() : "<unknown>";
181   fileID = getNextID();
182   SmallVector<uint32_t, 16> operands;
183   operands.push_back(fileID);
184   spirv::encodeStringLiteralInto(operands, fileName);
185   encodeInstructionInto(debug, spirv::Opcode::OpString, operands);
186   // TODO: Encode more debug instructions.
187 }
188 
189 void Serializer::processExtension() {
190   llvm::SmallVector<uint32_t, 16> extName;
191   for (spirv::Extension ext : module.getVceTriple()->getExtensions()) {
192     extName.clear();
193     spirv::encodeStringLiteralInto(extName, spirv::stringifyExtension(ext));
194     encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName);
195   }
196 }
197 
198 void Serializer::processMemoryModel() {
199   StringAttr memoryModelName = module.getMemoryModelAttrName();
200   auto mm = static_cast<uint32_t>(
201       module->getAttrOfType<spirv::MemoryModelAttr>(memoryModelName)
202           .getValue());
203 
204   StringAttr addressingModelName = module.getAddressingModelAttrName();
205   auto am = static_cast<uint32_t>(
206       module->getAttrOfType<spirv::AddressingModelAttr>(addressingModelName)
207           .getValue());
208 
209   encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm});
210 }
211 
212 static std::string getDecorationName(StringRef attrName) {
213   // convertToCamelFromSnakeCase will convert this to FpFastMathMode instead of
214   // expected FPFastMathMode.
215   if (attrName == "fp_fast_math_mode")
216     return "FPFastMathMode";
217   // similar here
218   if (attrName == "fp_rounding_mode")
219     return "FPRoundingMode";
220   // convertToCamelFromSnakeCase will not capitalize "INTEL".
221   if (attrName == "cache_control_load_intel")
222     return "CacheControlLoadINTEL";
223   if (attrName == "cache_control_store_intel")
224     return "CacheControlStoreINTEL";
225 
226   return llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
227 }
228 
229 template <typename AttrTy, typename EmitF>
230 LogicalResult processDecorationList(Location loc, Decoration decoration,
231                                     Attribute attrList, StringRef attrName,
232                                     EmitF emitter) {
233   auto arrayAttr = dyn_cast<ArrayAttr>(attrList);
234   if (!arrayAttr) {
235     return emitError(loc, "expecting array attribute of ")
236            << attrName << " for " << stringifyDecoration(decoration);
237   }
238   if (arrayAttr.empty()) {
239     return emitError(loc, "expecting non-empty array attribute of ")
240            << attrName << " for " << stringifyDecoration(decoration);
241   }
242   for (Attribute attr : arrayAttr.getValue()) {
243     auto cacheControlAttr = dyn_cast<AttrTy>(attr);
244     if (!cacheControlAttr) {
245       return emitError(loc, "expecting array attribute of ")
246              << attrName << " for " << stringifyDecoration(decoration);
247     }
248     // This named attribute encodes several decorations. Emit one per
249     // element in the array.
250     if (failed(emitter(cacheControlAttr)))
251       return failure();
252   }
253   return success();
254 }
255 
256 LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
257                                                 Decoration decoration,
258                                                 Attribute attr) {
259   SmallVector<uint32_t, 1> args;
260   switch (decoration) {
261   case spirv::Decoration::LinkageAttributes: {
262     // Get the value of the Linkage Attributes
263     // e.g., LinkageAttributes=["linkageName", linkageType].
264     auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr);
265     auto linkageName = linkageAttr.getLinkageName();
266     auto linkageType = linkageAttr.getLinkageType().getValue();
267     // Encode the Linkage Name (string literal to uint32_t).
268     spirv::encodeStringLiteralInto(args, linkageName);
269     // Encode LinkageType & Add the Linkagetype to the args.
270     args.push_back(static_cast<uint32_t>(linkageType));
271     break;
272   }
273   case spirv::Decoration::FPFastMathMode:
274     if (auto intAttr = dyn_cast<FPFastMathModeAttr>(attr)) {
275       args.push_back(static_cast<uint32_t>(intAttr.getValue()));
276       break;
277     }
278     return emitError(loc, "expected FPFastMathModeAttr attribute for ")
279            << stringifyDecoration(decoration);
280   case spirv::Decoration::FPRoundingMode:
281     if (auto intAttr = dyn_cast<FPRoundingModeAttr>(attr)) {
282       args.push_back(static_cast<uint32_t>(intAttr.getValue()));
283       break;
284     }
285     return emitError(loc, "expected FPRoundingModeAttr attribute for ")
286            << stringifyDecoration(decoration);
287   case spirv::Decoration::Binding:
288   case spirv::Decoration::DescriptorSet:
289   case spirv::Decoration::Location:
290     if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
291       args.push_back(intAttr.getValue().getZExtValue());
292       break;
293     }
294     return emitError(loc, "expected integer attribute for ")
295            << stringifyDecoration(decoration);
296   case spirv::Decoration::BuiltIn:
297     if (auto strAttr = dyn_cast<StringAttr>(attr)) {
298       auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
299       if (enumVal) {
300         args.push_back(static_cast<uint32_t>(*enumVal));
301         break;
302       }
303       return emitError(loc, "invalid ")
304              << stringifyDecoration(decoration) << " decoration attribute "
305              << strAttr.getValue();
306     }
307     return emitError(loc, "expected string attribute for ")
308            << stringifyDecoration(decoration);
309   case spirv::Decoration::Aliased:
310   case spirv::Decoration::AliasedPointer:
311   case spirv::Decoration::Flat:
312   case spirv::Decoration::NonReadable:
313   case spirv::Decoration::NonWritable:
314   case spirv::Decoration::NoPerspective:
315   case spirv::Decoration::NoSignedWrap:
316   case spirv::Decoration::NoUnsignedWrap:
317   case spirv::Decoration::RelaxedPrecision:
318   case spirv::Decoration::Restrict:
319   case spirv::Decoration::RestrictPointer:
320   case spirv::Decoration::NoContraction:
321   case spirv::Decoration::Constant:
322     // For unit attributes and decoration attributes, the args list
323     // has no values so we do nothing.
324     if (isa<UnitAttr, DecorationAttr>(attr))
325       break;
326     return emitError(loc,
327                      "expected unit attribute or decoration attribute for ")
328            << stringifyDecoration(decoration);
329   case spirv::Decoration::CacheControlLoadINTEL:
330     return processDecorationList<CacheControlLoadINTELAttr>(
331         loc, decoration, attr, "CacheControlLoadINTEL",
332         [&](CacheControlLoadINTELAttr attr) {
333           unsigned cacheLevel = attr.getCacheLevel();
334           LoadCacheControl loadCacheControl = attr.getLoadCacheControl();
335           return emitDecoration(
336               resultID, decoration,
337               {cacheLevel, static_cast<uint32_t>(loadCacheControl)});
338         });
339   case spirv::Decoration::CacheControlStoreINTEL:
340     return processDecorationList<CacheControlStoreINTELAttr>(
341         loc, decoration, attr, "CacheControlStoreINTEL",
342         [&](CacheControlStoreINTELAttr attr) {
343           unsigned cacheLevel = attr.getCacheLevel();
344           StoreCacheControl storeCacheControl = attr.getStoreCacheControl();
345           return emitDecoration(
346               resultID, decoration,
347               {cacheLevel, static_cast<uint32_t>(storeCacheControl)});
348         });
349   default:
350     return emitError(loc, "unhandled decoration ")
351            << stringifyDecoration(decoration);
352   }
353   return emitDecoration(resultID, decoration, args);
354 }
355 
356 LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
357                                             NamedAttribute attr) {
358   StringRef attrName = attr.getName().strref();
359   std::string decorationName = getDecorationName(attrName);
360   std::optional<Decoration> decoration =
361       spirv::symbolizeDecoration(decorationName);
362   if (!decoration) {
363     return emitError(
364                loc, "non-argument attributes expected to have snake-case-ified "
365                     "decoration name, unhandled attribute with name : ")
366            << attrName;
367   }
368   return processDecorationAttr(loc, resultID, *decoration, attr.getValue());
369 }
370 
371 LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
372   assert(!name.empty() && "unexpected empty string for OpName");
373   if (!options.emitSymbolName)
374     return success();
375 
376   SmallVector<uint32_t, 4> nameOperands;
377   nameOperands.push_back(resultID);
378   spirv::encodeStringLiteralInto(nameOperands, name);
379   encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands);
380   return success();
381 }
382 
383 template <>
384 LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
385     Location loc, spirv::ArrayType type, uint32_t resultID) {
386   if (unsigned stride = type.getArrayStride()) {
387     // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
388     return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
389   }
390   return success();
391 }
392 
393 template <>
394 LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>(
395     Location loc, spirv::RuntimeArrayType type, uint32_t resultID) {
396   if (unsigned stride = type.getArrayStride()) {
397     // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
398     return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
399   }
400   return success();
401 }
402 
403 LogicalResult Serializer::processMemberDecoration(
404     uint32_t structID,
405     const spirv::StructType::MemberDecorationInfo &memberDecoration) {
406   SmallVector<uint32_t, 4> args(
407       {structID, memberDecoration.memberIndex,
408        static_cast<uint32_t>(memberDecoration.decoration)});
409   if (memberDecoration.hasValue) {
410     args.push_back(memberDecoration.decorationValue);
411   }
412   encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate, args);
413   return success();
414 }
415 
416 //===----------------------------------------------------------------------===//
417 // Type
418 //===----------------------------------------------------------------------===//
419 
420 // According to the SPIR-V spec "Validation Rules for Shader Capabilities":
421 // "Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and
422 // PushConstant Storage Classes must be explicitly laid out."
423 bool Serializer::isInterfaceStructPtrType(Type type) const {
424   if (auto ptrType = dyn_cast<spirv::PointerType>(type)) {
425     switch (ptrType.getStorageClass()) {
426     case spirv::StorageClass::PhysicalStorageBuffer:
427     case spirv::StorageClass::PushConstant:
428     case spirv::StorageClass::StorageBuffer:
429     case spirv::StorageClass::Uniform:
430       return isa<spirv::StructType>(ptrType.getPointeeType());
431     default:
432       break;
433     }
434   }
435   return false;
436 }
437 
438 LogicalResult Serializer::processType(Location loc, Type type,
439                                       uint32_t &typeID) {
440   // Maintains a set of names for nested identified struct types. This is used
441   // to properly serialize recursive references.
442   SetVector<StringRef> serializationCtx;
443   return processTypeImpl(loc, type, typeID, serializationCtx);
444 }
445 
446 LogicalResult
447 Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
448                             SetVector<StringRef> &serializationCtx) {
449   typeID = getTypeID(type);
450   if (typeID)
451     return success();
452 
453   typeID = getNextID();
454   SmallVector<uint32_t, 4> operands;
455 
456   operands.push_back(typeID);
457   auto typeEnum = spirv::Opcode::OpTypeVoid;
458   bool deferSerialization = false;
459 
460   if ((isa<FunctionType>(type) &&
461        succeeded(prepareFunctionType(loc, cast<FunctionType>(type), typeEnum,
462                                      operands))) ||
463       succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands,
464                                  deferSerialization, serializationCtx))) {
465     if (deferSerialization)
466       return success();
467 
468     typeIDMap[type] = typeID;
469 
470     encodeInstructionInto(typesGlobalValues, typeEnum, operands);
471 
472     if (recursiveStructInfos.count(type) != 0) {
473       // This recursive struct type is emitted already, now the OpTypePointer
474       // instructions referring to recursive references are emitted as well.
475       for (auto &ptrInfo : recursiveStructInfos[type]) {
476         // TODO: This might not work if more than 1 recursive reference is
477         // present in the struct.
478         SmallVector<uint32_t, 4> ptrOperands;
479         ptrOperands.push_back(ptrInfo.pointerTypeID);
480         ptrOperands.push_back(static_cast<uint32_t>(ptrInfo.storageClass));
481         ptrOperands.push_back(typeIDMap[type]);
482 
483         encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpTypePointer,
484                               ptrOperands);
485       }
486 
487       recursiveStructInfos[type].clear();
488     }
489 
490     return success();
491   }
492 
493   return failure();
494 }
495 
496 LogicalResult Serializer::prepareBasicType(
497     Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum,
498     SmallVectorImpl<uint32_t> &operands, bool &deferSerialization,
499     SetVector<StringRef> &serializationCtx) {
500   deferSerialization = false;
501 
502   if (isVoidType(type)) {
503     typeEnum = spirv::Opcode::OpTypeVoid;
504     return success();
505   }
506 
507   if (auto intType = dyn_cast<IntegerType>(type)) {
508     if (intType.getWidth() == 1) {
509       typeEnum = spirv::Opcode::OpTypeBool;
510       return success();
511     }
512 
513     typeEnum = spirv::Opcode::OpTypeInt;
514     operands.push_back(intType.getWidth());
515     // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
516     // to preserve or validate.
517     // 0 indicates unsigned, or no signedness semantics
518     // 1 indicates signed semantics."
519     operands.push_back(intType.isSigned() ? 1 : 0);
520     return success();
521   }
522 
523   if (auto floatType = dyn_cast<FloatType>(type)) {
524     typeEnum = spirv::Opcode::OpTypeFloat;
525     operands.push_back(floatType.getWidth());
526     return success();
527   }
528 
529   if (auto vectorType = dyn_cast<VectorType>(type)) {
530     uint32_t elementTypeID = 0;
531     if (failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID,
532                                serializationCtx))) {
533       return failure();
534     }
535     typeEnum = spirv::Opcode::OpTypeVector;
536     operands.push_back(elementTypeID);
537     operands.push_back(vectorType.getNumElements());
538     return success();
539   }
540 
541   if (auto imageType = dyn_cast<spirv::ImageType>(type)) {
542     typeEnum = spirv::Opcode::OpTypeImage;
543     uint32_t sampledTypeID = 0;
544     if (failed(processType(loc, imageType.getElementType(), sampledTypeID)))
545       return failure();
546 
547     llvm::append_values(operands, sampledTypeID,
548                         static_cast<uint32_t>(imageType.getDim()),
549                         static_cast<uint32_t>(imageType.getDepthInfo()),
550                         static_cast<uint32_t>(imageType.getArrayedInfo()),
551                         static_cast<uint32_t>(imageType.getSamplingInfo()),
552                         static_cast<uint32_t>(imageType.getSamplerUseInfo()),
553                         static_cast<uint32_t>(imageType.getImageFormat()));
554     return success();
555   }
556 
557   if (auto arrayType = dyn_cast<spirv::ArrayType>(type)) {
558     typeEnum = spirv::Opcode::OpTypeArray;
559     uint32_t elementTypeID = 0;
560     if (failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID,
561                                serializationCtx))) {
562       return failure();
563     }
564     operands.push_back(elementTypeID);
565     if (auto elementCountID = prepareConstantInt(
566             loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) {
567       operands.push_back(elementCountID);
568     }
569     return processTypeDecoration(loc, arrayType, resultID);
570   }
571 
572   if (auto ptrType = dyn_cast<spirv::PointerType>(type)) {
573     uint32_t pointeeTypeID = 0;
574     spirv::StructType pointeeStruct =
575         dyn_cast<spirv::StructType>(ptrType.getPointeeType());
576 
577     if (pointeeStruct && pointeeStruct.isIdentified() &&
578         serializationCtx.count(pointeeStruct.getIdentifier()) != 0) {
579       // A recursive reference to an enclosing struct is found.
580       //
581       // 1. Prepare an OpTypeForwardPointer with resultID and the ptr storage
582       // class as operands.
583       SmallVector<uint32_t, 2> forwardPtrOperands;
584       forwardPtrOperands.push_back(resultID);
585       forwardPtrOperands.push_back(
586           static_cast<uint32_t>(ptrType.getStorageClass()));
587 
588       encodeInstructionInto(typesGlobalValues,
589                             spirv::Opcode::OpTypeForwardPointer,
590                             forwardPtrOperands);
591 
592       // 2. Find the pointee (enclosing) struct.
593       auto structType = spirv::StructType::getIdentified(
594           module.getContext(), pointeeStruct.getIdentifier());
595 
596       if (!structType)
597         return failure();
598 
599       // 3. Mark the OpTypePointer that is supposed to be emitted by this call
600       // as deferred.
601       deferSerialization = true;
602 
603       // 4. Record the info needed to emit the deferred OpTypePointer
604       // instruction when the enclosing struct is completely serialized.
605       recursiveStructInfos[structType].push_back(
606           {resultID, ptrType.getStorageClass()});
607     } else {
608       if (failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID,
609                                  serializationCtx)))
610         return failure();
611     }
612 
613     typeEnum = spirv::Opcode::OpTypePointer;
614     operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass()));
615     operands.push_back(pointeeTypeID);
616 
617     if (isInterfaceStructPtrType(ptrType)) {
618       if (failed(emitDecoration(getTypeID(pointeeStruct),
619                                 spirv::Decoration::Block)))
620         return emitError(loc, "cannot decorate ")
621                << pointeeStruct << " with Block decoration";
622     }
623 
624     return success();
625   }
626 
627   if (auto runtimeArrayType = dyn_cast<spirv::RuntimeArrayType>(type)) {
628     uint32_t elementTypeID = 0;
629     if (failed(processTypeImpl(loc, runtimeArrayType.getElementType(),
630                                elementTypeID, serializationCtx))) {
631       return failure();
632     }
633     typeEnum = spirv::Opcode::OpTypeRuntimeArray;
634     operands.push_back(elementTypeID);
635     return processTypeDecoration(loc, runtimeArrayType, resultID);
636   }
637 
638   if (auto sampledImageType = dyn_cast<spirv::SampledImageType>(type)) {
639     typeEnum = spirv::Opcode::OpTypeSampledImage;
640     uint32_t imageTypeID = 0;
641     if (failed(
642             processType(loc, sampledImageType.getImageType(), imageTypeID))) {
643       return failure();
644     }
645     operands.push_back(imageTypeID);
646     return success();
647   }
648 
649   if (auto structType = dyn_cast<spirv::StructType>(type)) {
650     if (structType.isIdentified()) {
651       if (failed(processName(resultID, structType.getIdentifier())))
652         return failure();
653       serializationCtx.insert(structType.getIdentifier());
654     }
655 
656     bool hasOffset = structType.hasOffset();
657     for (auto elementIndex :
658          llvm::seq<uint32_t>(0, structType.getNumElements())) {
659       uint32_t elementTypeID = 0;
660       if (failed(processTypeImpl(loc, structType.getElementType(elementIndex),
661                                  elementTypeID, serializationCtx))) {
662         return failure();
663       }
664       operands.push_back(elementTypeID);
665       if (hasOffset) {
666         // Decorate each struct member with an offset
667         spirv::StructType::MemberDecorationInfo offsetDecoration{
668             elementIndex, /*hasValue=*/1, spirv::Decoration::Offset,
669             static_cast<uint32_t>(structType.getMemberOffset(elementIndex))};
670         if (failed(processMemberDecoration(resultID, offsetDecoration))) {
671           return emitError(loc, "cannot decorate ")
672                  << elementIndex << "-th member of " << structType
673                  << " with its offset";
674         }
675       }
676     }
677     SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
678     structType.getMemberDecorations(memberDecorations);
679 
680     for (auto &memberDecoration : memberDecorations) {
681       if (failed(processMemberDecoration(resultID, memberDecoration))) {
682         return emitError(loc, "cannot decorate ")
683                << static_cast<uint32_t>(memberDecoration.memberIndex)
684                << "-th member of " << structType << " with "
685                << stringifyDecoration(memberDecoration.decoration);
686       }
687     }
688 
689     typeEnum = spirv::Opcode::OpTypeStruct;
690 
691     if (structType.isIdentified())
692       serializationCtx.remove(structType.getIdentifier());
693 
694     return success();
695   }
696 
697   if (auto cooperativeMatrixType =
698           dyn_cast<spirv::CooperativeMatrixType>(type)) {
699     uint32_t elementTypeID = 0;
700     if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
701                                elementTypeID, serializationCtx))) {
702       return failure();
703     }
704     typeEnum = spirv::Opcode::OpTypeCooperativeMatrixKHR;
705     auto getConstantOp = [&](uint32_t id) {
706       auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
707       return prepareConstantInt(loc, attr);
708     };
709     llvm::append_values(
710         operands, elementTypeID,
711         getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())),
712         getConstantOp(cooperativeMatrixType.getRows()),
713         getConstantOp(cooperativeMatrixType.getColumns()),
714         getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getUse())));
715     return success();
716   }
717 
718   if (auto matrixType = dyn_cast<spirv::MatrixType>(type)) {
719     uint32_t elementTypeID = 0;
720     if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID,
721                                serializationCtx))) {
722       return failure();
723     }
724     typeEnum = spirv::Opcode::OpTypeMatrix;
725     llvm::append_values(operands, elementTypeID, matrixType.getNumColumns());
726     return success();
727   }
728 
729   // TODO: Handle other types.
730   return emitError(loc, "unhandled type in serialization: ") << type;
731 }
732 
733 LogicalResult
734 Serializer::prepareFunctionType(Location loc, FunctionType type,
735                                 spirv::Opcode &typeEnum,
736                                 SmallVectorImpl<uint32_t> &operands) {
737   typeEnum = spirv::Opcode::OpTypeFunction;
738   assert(type.getNumResults() <= 1 &&
739          "serialization supports only a single return value");
740   uint32_t resultID = 0;
741   if (failed(processType(
742           loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(),
743           resultID))) {
744     return failure();
745   }
746   operands.push_back(resultID);
747   for (auto &res : type.getInputs()) {
748     uint32_t argTypeID = 0;
749     if (failed(processType(loc, res, argTypeID))) {
750       return failure();
751     }
752     operands.push_back(argTypeID);
753   }
754   return success();
755 }
756 
757 //===----------------------------------------------------------------------===//
758 // Constant
759 //===----------------------------------------------------------------------===//
760 
761 uint32_t Serializer::prepareConstant(Location loc, Type constType,
762                                      Attribute valueAttr) {
763   if (auto id = prepareConstantScalar(loc, valueAttr)) {
764     return id;
765   }
766 
767   // This is a composite literal. We need to handle each component separately
768   // and then emit an OpConstantComposite for the whole.
769 
770   if (auto id = getConstantID(valueAttr)) {
771     return id;
772   }
773 
774   uint32_t typeID = 0;
775   if (failed(processType(loc, constType, typeID))) {
776     return 0;
777   }
778 
779   uint32_t resultID = 0;
780   if (auto attr = dyn_cast<DenseElementsAttr>(valueAttr)) {
781     int rank = dyn_cast<ShapedType>(attr.getType()).getRank();
782     SmallVector<uint64_t, 4> index(rank);
783     resultID = prepareDenseElementsConstant(loc, constType, attr,
784                                             /*dim=*/0, index);
785   } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
786     resultID = prepareArrayConstant(loc, constType, arrayAttr);
787   }
788 
789   if (resultID == 0) {
790     emitError(loc, "cannot serialize attribute: ") << valueAttr;
791     return 0;
792   }
793 
794   constIDMap[valueAttr] = resultID;
795   return resultID;
796 }
797 
798 uint32_t Serializer::prepareArrayConstant(Location loc, Type constType,
799                                           ArrayAttr attr) {
800   uint32_t typeID = 0;
801   if (failed(processType(loc, constType, typeID))) {
802     return 0;
803   }
804 
805   uint32_t resultID = getNextID();
806   SmallVector<uint32_t, 4> operands = {typeID, resultID};
807   operands.reserve(attr.size() + 2);
808   auto elementType = cast<spirv::ArrayType>(constType).getElementType();
809   for (Attribute elementAttr : attr) {
810     if (auto elementID = prepareConstant(loc, elementType, elementAttr)) {
811       operands.push_back(elementID);
812     } else {
813       return 0;
814     }
815   }
816   spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
817   encodeInstructionInto(typesGlobalValues, opcode, operands);
818 
819   return resultID;
820 }
821 
822 // TODO: Turn the below function into iterative function, instead of
823 // recursive function.
824 uint32_t
825 Serializer::prepareDenseElementsConstant(Location loc, Type constType,
826                                          DenseElementsAttr valueAttr, int dim,
827                                          MutableArrayRef<uint64_t> index) {
828   auto shapedType = dyn_cast<ShapedType>(valueAttr.getType());
829   assert(dim <= shapedType.getRank());
830   if (shapedType.getRank() == dim) {
831     if (auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) {
832       return attr.getType().getElementType().isInteger(1)
833                  ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[index])
834                  : prepareConstantInt(loc,
835                                       attr.getValues<IntegerAttr>()[index]);
836     }
837     if (auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) {
838       return prepareConstantFp(loc, attr.getValues<FloatAttr>()[index]);
839     }
840     return 0;
841   }
842 
843   uint32_t typeID = 0;
844   if (failed(processType(loc, constType, typeID))) {
845     return 0;
846   }
847 
848   uint32_t resultID = getNextID();
849   SmallVector<uint32_t, 4> operands = {typeID, resultID};
850   operands.reserve(shapedType.getDimSize(dim) + 2);
851   auto elementType = cast<spirv::CompositeType>(constType).getElementType(0);
852   for (int i = 0; i < shapedType.getDimSize(dim); ++i) {
853     index[dim] = i;
854     if (auto elementID = prepareDenseElementsConstant(
855             loc, elementType, valueAttr, dim + 1, index)) {
856       operands.push_back(elementID);
857     } else {
858       return 0;
859     }
860   }
861   spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
862   encodeInstructionInto(typesGlobalValues, opcode, operands);
863 
864   return resultID;
865 }
866 
867 uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr,
868                                            bool isSpec) {
869   if (auto floatAttr = dyn_cast<FloatAttr>(valueAttr)) {
870     return prepareConstantFp(loc, floatAttr, isSpec);
871   }
872   if (auto boolAttr = dyn_cast<BoolAttr>(valueAttr)) {
873     return prepareConstantBool(loc, boolAttr, isSpec);
874   }
875   if (auto intAttr = dyn_cast<IntegerAttr>(valueAttr)) {
876     return prepareConstantInt(loc, intAttr, isSpec);
877   }
878 
879   return 0;
880 }
881 
882 uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr,
883                                          bool isSpec) {
884   if (!isSpec) {
885     // We can de-duplicate normal constants, but not specialization constants.
886     if (auto id = getConstantID(boolAttr)) {
887       return id;
888     }
889   }
890 
891   // Process the type for this bool literal
892   uint32_t typeID = 0;
893   if (failed(processType(loc, cast<IntegerAttr>(boolAttr).getType(), typeID))) {
894     return 0;
895   }
896 
897   auto resultID = getNextID();
898   auto opcode = boolAttr.getValue()
899                     ? (isSpec ? spirv::Opcode::OpSpecConstantTrue
900                               : spirv::Opcode::OpConstantTrue)
901                     : (isSpec ? spirv::Opcode::OpSpecConstantFalse
902                               : spirv::Opcode::OpConstantFalse);
903   encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID});
904 
905   if (!isSpec) {
906     constIDMap[boolAttr] = resultID;
907   }
908   return resultID;
909 }
910 
911 uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr,
912                                         bool isSpec) {
913   if (!isSpec) {
914     // We can de-duplicate normal constants, but not specialization constants.
915     if (auto id = getConstantID(intAttr)) {
916       return id;
917     }
918   }
919 
920   // Process the type for this integer literal
921   uint32_t typeID = 0;
922   if (failed(processType(loc, intAttr.getType(), typeID))) {
923     return 0;
924   }
925 
926   auto resultID = getNextID();
927   APInt value = intAttr.getValue();
928   unsigned bitwidth = value.getBitWidth();
929   bool isSigned = intAttr.getType().isSignedInteger();
930   auto opcode =
931       isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
932 
933   switch (bitwidth) {
934     // According to SPIR-V spec, "When the type's bit width is less than
935     // 32-bits, the literal's value appears in the low-order bits of the word,
936     // and the high-order bits must be 0 for a floating-point type, or 0 for an
937     // integer type with Signedness of 0, or sign extended when Signedness
938     // is 1."
939   case 32:
940   case 16:
941   case 8: {
942     uint32_t word = 0;
943     if (isSigned) {
944       word = static_cast<int32_t>(value.getSExtValue());
945     } else {
946       word = static_cast<uint32_t>(value.getZExtValue());
947     }
948     encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
949   } break;
950     // According to SPIR-V spec: "When the type's bit width is larger than one
951     // word, the literal’s low-order words appear first."
952   case 64: {
953     struct DoubleWord {
954       uint32_t word1;
955       uint32_t word2;
956     } words;
957     if (isSigned) {
958       words = llvm::bit_cast<DoubleWord>(value.getSExtValue());
959     } else {
960       words = llvm::bit_cast<DoubleWord>(value.getZExtValue());
961     }
962     encodeInstructionInto(typesGlobalValues, opcode,
963                           {typeID, resultID, words.word1, words.word2});
964   } break;
965   default: {
966     std::string valueStr;
967     llvm::raw_string_ostream rss(valueStr);
968     value.print(rss, /*isSigned=*/false);
969 
970     emitError(loc, "cannot serialize ")
971         << bitwidth << "-bit integer literal: " << valueStr;
972     return 0;
973   }
974   }
975 
976   if (!isSpec) {
977     constIDMap[intAttr] = resultID;
978   }
979   return resultID;
980 }
981 
982 uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
983                                        bool isSpec) {
984   if (!isSpec) {
985     // We can de-duplicate normal constants, but not specialization constants.
986     if (auto id = getConstantID(floatAttr)) {
987       return id;
988     }
989   }
990 
991   // Process the type for this float literal
992   uint32_t typeID = 0;
993   if (failed(processType(loc, floatAttr.getType(), typeID))) {
994     return 0;
995   }
996 
997   auto resultID = getNextID();
998   APFloat value = floatAttr.getValue();
999   APInt intValue = value.bitcastToAPInt();
1000 
1001   auto opcode =
1002       isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
1003 
1004   if (&value.getSemantics() == &APFloat::IEEEsingle()) {
1005     uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
1006     encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
1007   } else if (&value.getSemantics() == &APFloat::IEEEdouble()) {
1008     struct DoubleWord {
1009       uint32_t word1;
1010       uint32_t word2;
1011     } words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
1012     encodeInstructionInto(typesGlobalValues, opcode,
1013                           {typeID, resultID, words.word1, words.word2});
1014   } else if (&value.getSemantics() == &APFloat::IEEEhalf()) {
1015     uint32_t word =
1016         static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue());
1017     encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
1018   } else {
1019     std::string valueStr;
1020     llvm::raw_string_ostream rss(valueStr);
1021     value.print(rss);
1022 
1023     emitError(loc, "cannot serialize ")
1024         << floatAttr.getType() << "-typed float literal: " << valueStr;
1025     return 0;
1026   }
1027 
1028   if (!isSpec) {
1029     constIDMap[floatAttr] = resultID;
1030   }
1031   return resultID;
1032 }
1033 
1034 //===----------------------------------------------------------------------===//
1035 // Control flow
1036 //===----------------------------------------------------------------------===//
1037 
1038 uint32_t Serializer::getOrCreateBlockID(Block *block) {
1039   if (uint32_t id = getBlockID(block))
1040     return id;
1041   return blockIDMap[block] = getNextID();
1042 }
1043 
1044 #ifndef NDEBUG
1045 void Serializer::printBlock(Block *block, raw_ostream &os) {
1046   os << "block " << block << " (id = ";
1047   if (uint32_t id = getBlockID(block))
1048     os << id;
1049   else
1050     os << "unknown";
1051   os << ")\n";
1052 }
1053 #endif
1054 
1055 LogicalResult
1056 Serializer::processBlock(Block *block, bool omitLabel,
1057                          function_ref<LogicalResult()> emitMerge) {
1058   LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n");
1059   LLVM_DEBUG(block->print(llvm::dbgs()));
1060   LLVM_DEBUG(llvm::dbgs() << '\n');
1061   if (!omitLabel) {
1062     uint32_t blockID = getOrCreateBlockID(block);
1063     LLVM_DEBUG(printBlock(block, llvm::dbgs()));
1064 
1065     // Emit OpLabel for this block.
1066     encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID});
1067   }
1068 
1069   // Emit OpPhi instructions for block arguments, if any.
1070   if (failed(emitPhiForBlockArguments(block)))
1071     return failure();
1072 
1073   // If we need to emit merge instructions, it must happen in this block. Check
1074   // whether we have other structured control flow ops, which will be expanded
1075   // into multiple basic blocks. If that's the case, we need to emit the merge
1076   // right now and then create new blocks for further serialization of the ops
1077   // in this block.
1078   if (emitMerge &&
1079       llvm::any_of(block->getOperations(),
1080                    llvm::IsaPred<spirv::LoopOp, spirv::SelectionOp>)) {
1081     if (failed(emitMerge()))
1082       return failure();
1083     emitMerge = nullptr;
1084 
1085     // Start a new block for further serialization.
1086     uint32_t blockID = getNextID();
1087     encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {blockID});
1088     encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID});
1089   }
1090 
1091   // Process each op in this block except the terminator.
1092   for (Operation &op : llvm::drop_end(*block)) {
1093     if (failed(processOperation(&op)))
1094       return failure();
1095   }
1096 
1097   // Process the terminator.
1098   if (emitMerge)
1099     if (failed(emitMerge()))
1100       return failure();
1101   if (failed(processOperation(&block->back())))
1102     return failure();
1103 
1104   return success();
1105 }
1106 
1107 LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
1108   // Nothing to do if this block has no arguments or it's the entry block, which
1109   // always has the same arguments as the function signature.
1110   if (block->args_empty() || block->isEntryBlock())
1111     return success();
1112 
1113   LLVM_DEBUG(llvm::dbgs() << "emitting phi instructions..\n");
1114 
1115   // If the block has arguments, we need to create SPIR-V OpPhi instructions.
1116   // A SPIR-V OpPhi instruction is of the syntax:
1117   //   OpPhi | result type | result <id> | (value <id>, parent block <id>) pair
1118   // So we need to collect all predecessor blocks and the arguments they send
1119   // to this block.
1120   SmallVector<std::pair<Block *, OperandRange>, 4> predecessors;
1121   for (Block *mlirPredecessor : block->getPredecessors()) {
1122     auto *terminator = mlirPredecessor->getTerminator();
1123     LLVM_DEBUG(llvm::dbgs() << "  mlir predecessor ");
1124     LLVM_DEBUG(printBlock(mlirPredecessor, llvm::dbgs()));
1125     LLVM_DEBUG(llvm::dbgs() << "    terminator: " << *terminator << "\n");
1126     // The predecessor here is the immediate one according to MLIR's IR
1127     // structure. It does not directly map to the incoming parent block for the
1128     // OpPhi instructions at SPIR-V binary level. This is because structured
1129     // control flow ops are serialized to multiple SPIR-V blocks. If there is a
1130     // spirv.mlir.selection/spirv.mlir.loop op in the MLIR predecessor block,
1131     // the branch op jumping to the OpPhi's block then resides in the previous
1132     // structured control flow op's merge block.
1133     Block *spirvPredecessor = getPhiIncomingBlock(mlirPredecessor);
1134     LLVM_DEBUG(llvm::dbgs() << "  spirv predecessor ");
1135     LLVM_DEBUG(printBlock(spirvPredecessor, llvm::dbgs()));
1136     if (auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) {
1137       predecessors.emplace_back(spirvPredecessor, branchOp.getOperands());
1138     } else if (auto branchCondOp =
1139                    dyn_cast<spirv::BranchConditionalOp>(terminator)) {
1140       std::optional<OperandRange> blockOperands;
1141       if (branchCondOp.getTrueTarget() == block) {
1142         blockOperands = branchCondOp.getTrueTargetOperands();
1143       } else {
1144         assert(branchCondOp.getFalseTarget() == block);
1145         blockOperands = branchCondOp.getFalseTargetOperands();
1146       }
1147 
1148       assert(!blockOperands->empty() &&
1149              "expected non-empty block operand range");
1150       predecessors.emplace_back(spirvPredecessor, *blockOperands);
1151     } else {
1152       return terminator->emitError("unimplemented terminator for Phi creation");
1153     }
1154     LLVM_DEBUG({
1155       llvm::dbgs() << "    block arguments:\n";
1156       for (Value v : predecessors.back().second)
1157         llvm::dbgs() << "      " << v << "\n";
1158     });
1159   }
1160 
1161   // Then create OpPhi instruction for each of the block argument.
1162   for (auto argIndex : llvm::seq<unsigned>(0, block->getNumArguments())) {
1163     BlockArgument arg = block->getArgument(argIndex);
1164 
1165     // Get the type <id> and result <id> for this OpPhi instruction.
1166     uint32_t phiTypeID = 0;
1167     if (failed(processType(arg.getLoc(), arg.getType(), phiTypeID)))
1168       return failure();
1169     uint32_t phiID = getNextID();
1170 
1171     LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' '
1172                             << arg << " (id = " << phiID << ")\n");
1173 
1174     // Prepare the (value <id>, parent block <id>) pairs.
1175     SmallVector<uint32_t, 8> phiArgs;
1176     phiArgs.push_back(phiTypeID);
1177     phiArgs.push_back(phiID);
1178 
1179     for (auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) {
1180       Value value = predecessors[predIndex].second[argIndex];
1181       uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first);
1182       LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId
1183                               << ") value " << value << ' ');
1184       // Each pair is a value <id> ...
1185       uint32_t valueId = getValueID(value);
1186       if (valueId == 0) {
1187         // The op generating this value hasn't been visited yet so we don't have
1188         // an <id> assigned yet. Record this to fix up later.
1189         LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n");
1190         deferredPhiValues[value].push_back(functionBody.size() + 1 +
1191                                            phiArgs.size());
1192       } else {
1193         LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n");
1194       }
1195       phiArgs.push_back(valueId);
1196       // ... and a parent block <id>.
1197       phiArgs.push_back(predBlockId);
1198     }
1199 
1200     encodeInstructionInto(functionBody, spirv::Opcode::OpPhi, phiArgs);
1201     valueIDMap[arg] = phiID;
1202   }
1203 
1204   return success();
1205 }
1206 
1207 //===----------------------------------------------------------------------===//
1208 // Operation
1209 //===----------------------------------------------------------------------===//
1210 
1211 LogicalResult Serializer::encodeExtensionInstruction(
1212     Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
1213     ArrayRef<uint32_t> operands) {
1214   // Check if the extension has been imported.
1215   auto &setID = extendedInstSetIDMap[extensionSetName];
1216   if (!setID) {
1217     setID = getNextID();
1218     SmallVector<uint32_t, 16> importOperands;
1219     importOperands.push_back(setID);
1220     spirv::encodeStringLiteralInto(importOperands, extensionSetName);
1221     encodeInstructionInto(extendedSets, spirv::Opcode::OpExtInstImport,
1222                           importOperands);
1223   }
1224 
1225   // The first two operands are the result type <id> and result <id>. The set
1226   // <id> and the opcode need to be insert after this.
1227   if (operands.size() < 2) {
1228     return op->emitError("extended instructions must have a result encoding");
1229   }
1230   SmallVector<uint32_t, 8> extInstOperands;
1231   extInstOperands.reserve(operands.size() + 2);
1232   extInstOperands.append(operands.begin(), std::next(operands.begin(), 2));
1233   extInstOperands.push_back(setID);
1234   extInstOperands.push_back(extensionOpcode);
1235   extInstOperands.append(std::next(operands.begin(), 2), operands.end());
1236   encodeInstructionInto(functionBody, spirv::Opcode::OpExtInst,
1237                         extInstOperands);
1238   return success();
1239 }
1240 
1241 LogicalResult Serializer::processOperation(Operation *opInst) {
1242   LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n");
1243 
1244   // First dispatch the ops that do not directly mirror an instruction from
1245   // the SPIR-V spec.
1246   return TypeSwitch<Operation *, LogicalResult>(opInst)
1247       .Case([&](spirv::AddressOfOp op) { return processAddressOfOp(op); })
1248       .Case([&](spirv::BranchOp op) { return processBranchOp(op); })
1249       .Case([&](spirv::BranchConditionalOp op) {
1250         return processBranchConditionalOp(op);
1251       })
1252       .Case([&](spirv::ConstantOp op) { return processConstantOp(op); })
1253       .Case([&](spirv::FuncOp op) { return processFuncOp(op); })
1254       .Case([&](spirv::GlobalVariableOp op) {
1255         return processGlobalVariableOp(op);
1256       })
1257       .Case([&](spirv::LoopOp op) { return processLoopOp(op); })
1258       .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); })
1259       .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); })
1260       .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); })
1261       .Case([&](spirv::SpecConstantCompositeOp op) {
1262         return processSpecConstantCompositeOp(op);
1263       })
1264       .Case([&](spirv::SpecConstantOperationOp op) {
1265         return processSpecConstantOperationOp(op);
1266       })
1267       .Case([&](spirv::UndefOp op) { return processUndefOp(op); })
1268       .Case([&](spirv::VariableOp op) { return processVariableOp(op); })
1269 
1270       // Then handle all the ops that directly mirror SPIR-V instructions with
1271       // auto-generated methods.
1272       .Default(
1273           [&](Operation *op) { return dispatchToAutogenSerialization(op); });
1274 }
1275 
1276 LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op,
1277                                                       StringRef extInstSet,
1278                                                       uint32_t opcode) {
1279   SmallVector<uint32_t, 4> operands;
1280   Location loc = op->getLoc();
1281 
1282   uint32_t resultID = 0;
1283   if (op->getNumResults() != 0) {
1284     uint32_t resultTypeID = 0;
1285     if (failed(processType(loc, op->getResult(0).getType(), resultTypeID)))
1286       return failure();
1287     operands.push_back(resultTypeID);
1288 
1289     resultID = getNextID();
1290     operands.push_back(resultID);
1291     valueIDMap[op->getResult(0)] = resultID;
1292   };
1293 
1294   for (Value operand : op->getOperands())
1295     operands.push_back(getValueID(operand));
1296 
1297   if (failed(emitDebugLine(functionBody, loc)))
1298     return failure();
1299 
1300   if (extInstSet.empty()) {
1301     encodeInstructionInto(functionBody, static_cast<spirv::Opcode>(opcode),
1302                           operands);
1303   } else {
1304     if (failed(encodeExtensionInstruction(op, extInstSet, opcode, operands)))
1305       return failure();
1306   }
1307 
1308   if (op->getNumResults() != 0) {
1309     for (auto attr : op->getAttrs()) {
1310       if (failed(processDecoration(loc, resultID, attr)))
1311         return failure();
1312     }
1313   }
1314 
1315   return success();
1316 }
1317 
1318 LogicalResult Serializer::emitDecoration(uint32_t target,
1319                                          spirv::Decoration decoration,
1320                                          ArrayRef<uint32_t> params) {
1321   uint32_t wordCount = 3 + params.size();
1322   llvm::append_values(
1323       decorations,
1324       spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate), target,
1325       static_cast<uint32_t>(decoration));
1326   llvm::append_range(decorations, params);
1327   return success();
1328 }
1329 
1330 LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary,
1331                                         Location loc) {
1332   if (!options.emitDebugInfo)
1333     return success();
1334 
1335   if (lastProcessedWasMergeInst) {
1336     lastProcessedWasMergeInst = false;
1337     return success();
1338   }
1339 
1340   auto fileLoc = dyn_cast<FileLineColLoc>(loc);
1341   if (fileLoc)
1342     encodeInstructionInto(binary, spirv::Opcode::OpLine,
1343                           {fileID, fileLoc.getLine(), fileLoc.getColumn()});
1344   return success();
1345 }
1346 } // namespace spirv
1347 } // namespace mlir
1348