xref: /llvm-project/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp (revision b0746c68629c26567cd123d8f9b28e796ef26f47)
1 //===- SPIRVSerializationGen.cpp - SPIR-V serialization utility generator -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // SPIRVSerializationGen generates common utility functions for SPIR-V
10 // serialization.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/TableGen/Attribute.h"
15 #include "mlir/TableGen/CodeGenHelpers.h"
16 #include "mlir/TableGen/Format.h"
17 #include "mlir/TableGen/GenInfo.h"
18 #include "mlir/TableGen/Operator.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/Sequence.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/StringMap.h"
24 #include "llvm/ADT/StringRef.h"
25 #include "llvm/ADT/StringSet.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include "llvm/Support/raw_ostream.h"
28 #include "llvm/TableGen/Error.h"
29 #include "llvm/TableGen/Record.h"
30 #include "llvm/TableGen/TableGenBackend.h"
31 
32 #include <list>
33 #include <optional>
34 
35 using llvm::ArrayRef;
36 using llvm::cast;
37 using llvm::formatv;
38 using llvm::isa;
39 using llvm::raw_ostream;
40 using llvm::raw_string_ostream;
41 using llvm::Record;
42 using llvm::RecordKeeper;
43 using llvm::SmallVector;
44 using llvm::SMLoc;
45 using llvm::StringMap;
46 using llvm::StringRef;
47 using mlir::tblgen::Attribute;
48 using mlir::tblgen::EnumAttr;
49 using mlir::tblgen::EnumAttrCase;
50 using mlir::tblgen::NamedAttribute;
51 using mlir::tblgen::NamedTypeConstraint;
52 using mlir::tblgen::NamespaceEmitter;
53 using mlir::tblgen::Operator;
54 
55 //===----------------------------------------------------------------------===//
56 // Availability Wrapper Class
57 //===----------------------------------------------------------------------===//
58 
59 namespace {
60 // Wrapper class with helper methods for accessing availability defined in
61 // TableGen.
62 class Availability {
63 public:
64   explicit Availability(const Record *def);
65 
66   // Returns the name of the direct TableGen class for this availability
67   // instance.
68   StringRef getClass() const;
69 
70   // Returns the generated C++ interface's class namespace.
71   StringRef getInterfaceClassNamespace() const;
72 
73   // Returns the generated C++ interface's class name.
74   StringRef getInterfaceClassName() const;
75 
76   // Returns the generated C++ interface's description.
77   StringRef getInterfaceDescription() const;
78 
79   // Returns the name of the query function insided the generated C++ interface.
80   StringRef getQueryFnName() const;
81 
82   // Returns the return type of the query function insided the generated C++
83   // interface.
84   StringRef getQueryFnRetType() const;
85 
86   // Returns the code for merging availability requirements.
87   StringRef getMergeActionCode() const;
88 
89   // Returns the initializer expression for initializing the final availability
90   // requirements.
91   StringRef getMergeInitializer() const;
92 
93   // Returns the C++ type for an availability instance.
94   StringRef getMergeInstanceType() const;
95 
96   // Returns the C++ statements for preparing availability instance.
97   StringRef getMergeInstancePreparation() const;
98 
99   // Returns the concrete availability instance carried in this case.
100   StringRef getMergeInstance() const;
101 
102   // Returns the underlying LLVM TableGen Record.
103   const Record *getDef() const { return def; }
104 
105 private:
106   // The TableGen definition of this availability.
107   const Record *def;
108 };
109 } // namespace
110 
111 Availability::Availability(const Record *def) : def(def) {
112   assert(def->isSubClassOf("Availability") &&
113          "must be subclass of TableGen 'Availability' class");
114 }
115 
116 StringRef Availability::getClass() const {
117   SmallVector<const Record *, 1> parentClass;
118   def->getDirectSuperClasses(parentClass);
119   if (parentClass.size() != 1) {
120     PrintFatalError(def->getLoc(),
121                     "expected to only have one direct superclass");
122   }
123   return parentClass.front()->getName();
124 }
125 
126 StringRef Availability::getInterfaceClassNamespace() const {
127   return def->getValueAsString("cppNamespace");
128 }
129 
130 StringRef Availability::getInterfaceClassName() const {
131   return def->getValueAsString("interfaceName");
132 }
133 
134 StringRef Availability::getInterfaceDescription() const {
135   return def->getValueAsString("interfaceDescription");
136 }
137 
138 StringRef Availability::getQueryFnRetType() const {
139   return def->getValueAsString("queryFnRetType");
140 }
141 
142 StringRef Availability::getQueryFnName() const {
143   return def->getValueAsString("queryFnName");
144 }
145 
146 StringRef Availability::getMergeActionCode() const {
147   return def->getValueAsString("mergeAction");
148 }
149 
150 StringRef Availability::getMergeInitializer() const {
151   return def->getValueAsString("initializer");
152 }
153 
154 StringRef Availability::getMergeInstanceType() const {
155   return def->getValueAsString("instanceType");
156 }
157 
158 StringRef Availability::getMergeInstancePreparation() const {
159   return def->getValueAsString("instancePreparation");
160 }
161 
162 StringRef Availability::getMergeInstance() const {
163   return def->getValueAsString("instance");
164 }
165 
166 // Returns the availability spec of the given `def`.
167 std::vector<Availability> getAvailabilities(const Record &def) {
168   std::vector<Availability> availabilities;
169 
170   if (def.getValue("availability")) {
171     std::vector<const Record *> availDefs =
172         def.getValueAsListOfDefs("availability");
173     availabilities.reserve(availDefs.size());
174     for (const Record *avail : availDefs)
175       availabilities.emplace_back(avail);
176   }
177 
178   return availabilities;
179 }
180 
181 //===----------------------------------------------------------------------===//
182 // Availability Interface Definitions AutoGen
183 //===----------------------------------------------------------------------===//
184 
185 static void emitInterfaceDef(const Availability &availability,
186                              raw_ostream &os) {
187 
188   os << availability.getQueryFnRetType() << " ";
189 
190   StringRef cppNamespace = availability.getInterfaceClassNamespace();
191   cppNamespace.consume_front("::");
192   if (!cppNamespace.empty())
193     os << cppNamespace << "::";
194 
195   StringRef methodName = availability.getQueryFnName();
196   os << availability.getInterfaceClassName() << "::" << methodName << "() {\n"
197      << "  return getImpl()->" << methodName << "(getImpl(), getOperation());\n"
198      << "}\n";
199 }
200 
201 static bool emitInterfaceDefs(const RecordKeeper &records, raw_ostream &os) {
202   llvm::emitSourceFileHeader("Availability Interface Definitions", os, records);
203 
204   auto defs = records.getAllDerivedDefinitions("Availability");
205   SmallVector<const Record *, 1> handledClasses;
206   for (const Record *def : defs) {
207     SmallVector<const Record *, 1> parent;
208     def->getDirectSuperClasses(parent);
209     if (parent.size() != 1) {
210       PrintFatalError(def->getLoc(),
211                       "expected to only have one direct superclass");
212     }
213     if (llvm::is_contained(handledClasses, parent.front()))
214       continue;
215 
216     Availability availability(def);
217     emitInterfaceDef(availability, os);
218     handledClasses.push_back(parent.front());
219   }
220   return false;
221 }
222 
223 //===----------------------------------------------------------------------===//
224 // Availability Interface Declarations AutoGen
225 //===----------------------------------------------------------------------===//
226 
227 static void emitConceptDecl(const Availability &availability, raw_ostream &os) {
228   os << "  class Concept {\n"
229      << "  public:\n"
230      << "    virtual ~Concept() = default;\n"
231      << "    virtual " << availability.getQueryFnRetType() << " "
232      << availability.getQueryFnName()
233      << "(const Concept *impl, Operation *tblgen_opaque_op) const = 0;\n"
234      << "  };\n";
235 }
236 
237 static void emitModelDecl(const Availability &availability, raw_ostream &os) {
238   for (const char *modelClass : {"Model", "FallbackModel"}) {
239     os << "  template<typename ConcreteOp>\n";
240     os << "  class " << modelClass << " : public Concept {\n"
241        << "  public:\n"
242        << "    using Interface = " << availability.getInterfaceClassName()
243        << ";\n"
244        << "    " << availability.getQueryFnRetType() << " "
245        << availability.getQueryFnName()
246        << "(const Concept *impl, Operation *tblgen_opaque_op) const final {\n"
247        << "      auto op = llvm::cast<ConcreteOp>(tblgen_opaque_op);\n"
248        << "      (void)op;\n"
249        // Forward to the method on the concrete operation type.
250        << "      return op." << availability.getQueryFnName() << "();\n"
251        << "    }\n"
252        << "  };\n";
253   }
254   os << "  template<typename ConcreteModel, typename ConcreteOp>\n";
255   os << "  class ExternalModel : public FallbackModel<ConcreteOp> {};\n";
256 }
257 
258 static void emitInterfaceDecl(const Availability &availability,
259                               raw_ostream &os) {
260   StringRef interfaceName = availability.getInterfaceClassName();
261   std::string interfaceTraitsName =
262       std::string(formatv("{0}Traits", interfaceName));
263 
264   StringRef cppNamespace = availability.getInterfaceClassNamespace();
265   NamespaceEmitter nsEmitter(os, cppNamespace);
266   os << "class " << interfaceName << ";\n\n";
267 
268   // Emit the traits struct containing the concept and model declarations.
269   os << "namespace detail {\n"
270      << "struct " << interfaceTraitsName << " {\n";
271   emitConceptDecl(availability, os);
272   os << '\n';
273   emitModelDecl(availability, os);
274   os << "};\n} // namespace detail\n\n";
275 
276   // Emit the main interface class declaration.
277   os << "/*\n" << availability.getInterfaceDescription().trim() << "\n*/\n";
278   os << llvm::formatv("class {0} : public OpInterface<{1}, detail::{2}> {\n"
279                       "public:\n"
280                       "  using OpInterface<{1}, detail::{2}>::OpInterface;\n",
281                       interfaceName, interfaceName, interfaceTraitsName);
282 
283   // Emit query function declaration.
284   os << "  " << availability.getQueryFnRetType() << " "
285      << availability.getQueryFnName() << "();\n";
286   os << "};\n\n";
287 }
288 
289 static bool emitInterfaceDecls(const RecordKeeper &records, raw_ostream &os) {
290   llvm::emitSourceFileHeader("Availability Interface Declarations", os,
291                              records);
292 
293   auto defs = records.getAllDerivedDefinitions("Availability");
294   SmallVector<const Record *, 4> handledClasses;
295   for (const Record *def : defs) {
296     SmallVector<const Record *, 1> parent;
297     def->getDirectSuperClasses(parent);
298     if (parent.size() != 1) {
299       PrintFatalError(def->getLoc(),
300                       "expected to only have one direct superclass");
301     }
302     if (llvm::is_contained(handledClasses, parent.front()))
303       continue;
304 
305     Availability avail(def);
306     emitInterfaceDecl(avail, os);
307     handledClasses.push_back(parent.front());
308   }
309   return false;
310 }
311 
312 //===----------------------------------------------------------------------===//
313 // Availability Interface Hook Registration
314 //===----------------------------------------------------------------------===//
315 
316 // Registers the operation interface generator to mlir-tblgen.
317 static mlir::GenRegistration
318     genInterfaceDecls("gen-avail-interface-decls",
319                       "Generate availability interface declarations",
320                       [](const RecordKeeper &records, raw_ostream &os) {
321                         return emitInterfaceDecls(records, os);
322                       });
323 
324 // Registers the operation interface generator to mlir-tblgen.
325 static mlir::GenRegistration
326     genInterfaceDefs("gen-avail-interface-defs",
327                      "Generate op interface definitions",
328                      [](const RecordKeeper &records, raw_ostream &os) {
329                        return emitInterfaceDefs(records, os);
330                      });
331 
332 //===----------------------------------------------------------------------===//
333 // Enum Availability Query AutoGen
334 //===----------------------------------------------------------------------===//
335 
336 static void emitAvailabilityQueryForIntEnum(const Record &enumDef,
337                                             raw_ostream &os) {
338   EnumAttr enumAttr(enumDef);
339   StringRef enumName = enumAttr.getEnumClassName();
340   std::vector<EnumAttrCase> enumerants = enumAttr.getAllCases();
341 
342   // Mapping from availability class name to (enumerant, availability
343   // specification) pairs.
344   llvm::StringMap<llvm::SmallVector<std::pair<EnumAttrCase, Availability>, 1>>
345       classCaseMap;
346 
347   // Place all availability specifications to their corresponding
348   // availability classes.
349   for (const EnumAttrCase &enumerant : enumerants)
350     for (const Availability &avail : getAvailabilities(enumerant.getDef()))
351       classCaseMap[avail.getClass()].push_back({enumerant, avail});
352 
353   for (const auto &classCasePair : classCaseMap) {
354     Availability avail = classCasePair.getValue().front().second;
355 
356     os << formatv("std::optional<{0}> {1}({2} value) {{\n",
357                   avail.getMergeInstanceType(), avail.getQueryFnName(),
358                   enumName);
359 
360     os << "  switch (value) {\n";
361     for (const auto &caseSpecPair : classCasePair.getValue()) {
362       EnumAttrCase enumerant = caseSpecPair.first;
363       Availability avail = caseSpecPair.second;
364       os << formatv("  case {0}::{1}: { {2} return {3}({4}); }\n", enumName,
365                     enumerant.getSymbol(), avail.getMergeInstancePreparation(),
366                     avail.getMergeInstanceType(), avail.getMergeInstance());
367     }
368     // Only emit default if uncovered cases.
369     if (classCasePair.getValue().size() < enumAttr.getAllCases().size())
370       os << "  default: break;\n";
371     os << "  }\n"
372        << "  return std::nullopt;\n"
373        << "}\n";
374   }
375 }
376 
377 static void emitAvailabilityQueryForBitEnum(const Record &enumDef,
378                                             raw_ostream &os) {
379   EnumAttr enumAttr(enumDef);
380   StringRef enumName = enumAttr.getEnumClassName();
381   std::string underlyingType = std::string(enumAttr.getUnderlyingType());
382   std::vector<EnumAttrCase> enumerants = enumAttr.getAllCases();
383 
384   // Mapping from availability class name to (enumerant, availability
385   // specification) pairs.
386   llvm::StringMap<llvm::SmallVector<std::pair<EnumAttrCase, Availability>, 1>>
387       classCaseMap;
388 
389   // Place all availability specifications to their corresponding
390   // availability classes.
391   for (const EnumAttrCase &enumerant : enumerants)
392     for (const Availability &avail : getAvailabilities(enumerant.getDef()))
393       classCaseMap[avail.getClass()].push_back({enumerant, avail});
394 
395   for (const auto &classCasePair : classCaseMap) {
396     Availability avail = classCasePair.getValue().front().second;
397 
398     os << formatv("std::optional<{0}> {1}({2} value) {{\n",
399                   avail.getMergeInstanceType(), avail.getQueryFnName(),
400                   enumName);
401 
402     os << formatv(
403         "  assert(::llvm::popcount(static_cast<{0}>(value)) <= 1"
404         " && \"cannot have more than one bit set\");\n",
405         underlyingType);
406 
407     os << "  switch (value) {\n";
408     for (const auto &caseSpecPair : classCasePair.getValue()) {
409       EnumAttrCase enumerant = caseSpecPair.first;
410       Availability avail = caseSpecPair.second;
411       os << formatv("  case {0}::{1}: { {2} return {3}({4}); }\n", enumName,
412                     enumerant.getSymbol(), avail.getMergeInstancePreparation(),
413                     avail.getMergeInstanceType(), avail.getMergeInstance());
414     }
415     os << "  default: break;\n";
416     os << "  }\n"
417        << "  return std::nullopt;\n"
418        << "}\n";
419   }
420 }
421 
422 static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
423   EnumAttr enumAttr(enumDef);
424   StringRef enumName = enumAttr.getEnumClassName();
425   StringRef cppNamespace = enumAttr.getCppNamespace();
426   auto enumerants = enumAttr.getAllCases();
427 
428   llvm::SmallVector<StringRef, 2> namespaces;
429   llvm::SplitString(cppNamespace, namespaces, "::");
430 
431   for (auto ns : namespaces)
432     os << "namespace " << ns << " {\n";
433 
434   llvm::StringSet<> handledClasses;
435 
436   // Place all availability specifications to their corresponding
437   // availability classes.
438   for (const EnumAttrCase &enumerant : enumerants)
439     for (const Availability &avail : getAvailabilities(enumerant.getDef())) {
440       StringRef className = avail.getClass();
441       if (handledClasses.count(className))
442         continue;
443       os << formatv("std::optional<{0}> {1}({2} value);\n",
444                     avail.getMergeInstanceType(), avail.getQueryFnName(),
445                     enumName);
446       handledClasses.insert(className);
447     }
448 
449   for (auto ns : llvm::reverse(namespaces))
450     os << "} // namespace " << ns << "\n";
451 }
452 
453 static bool emitEnumDecls(const RecordKeeper &records, raw_ostream &os) {
454   llvm::emitSourceFileHeader("SPIR-V Enum Availability Declarations", os,
455                              records);
456 
457   auto defs = records.getAllDerivedDefinitions("EnumAttrInfo");
458   for (const auto *def : defs)
459     emitEnumDecl(*def, os);
460 
461   return false;
462 }
463 
464 static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
465   EnumAttr enumAttr(enumDef);
466   StringRef cppNamespace = enumAttr.getCppNamespace();
467 
468   llvm::SmallVector<StringRef, 2> namespaces;
469   llvm::SplitString(cppNamespace, namespaces, "::");
470 
471   for (auto ns : namespaces)
472     os << "namespace " << ns << " {\n";
473 
474   if (enumAttr.isBitEnum()) {
475     emitAvailabilityQueryForBitEnum(enumDef, os);
476   } else {
477     emitAvailabilityQueryForIntEnum(enumDef, os);
478   }
479 
480   for (auto ns : llvm::reverse(namespaces))
481     os << "} // namespace " << ns << "\n";
482   os << "\n";
483 }
484 
485 static bool emitEnumDefs(const RecordKeeper &records, raw_ostream &os) {
486   llvm::emitSourceFileHeader("SPIR-V Enum Availability Definitions", os,
487                              records);
488 
489   auto defs = records.getAllDerivedDefinitions("EnumAttrInfo");
490   for (const auto *def : defs)
491     emitEnumDef(*def, os);
492 
493   return false;
494 }
495 
496 //===----------------------------------------------------------------------===//
497 // Enum Availability Query Hook Registration
498 //===----------------------------------------------------------------------===//
499 
500 // Registers the enum utility generator to mlir-tblgen.
501 static mlir::GenRegistration
502     genEnumDecls("gen-spirv-enum-avail-decls",
503                  "Generate SPIR-V enum availability declarations",
504                  [](const RecordKeeper &records, raw_ostream &os) {
505                    return emitEnumDecls(records, os);
506                  });
507 
508 // Registers the enum utility generator to mlir-tblgen.
509 static mlir::GenRegistration
510     genEnumDefs("gen-spirv-enum-avail-defs",
511                 "Generate SPIR-V enum availability definitions",
512                 [](const RecordKeeper &records, raw_ostream &os) {
513                   return emitEnumDefs(records, os);
514                 });
515 
516 //===----------------------------------------------------------------------===//
517 // Serialization AutoGen
518 //===----------------------------------------------------------------------===//
519 
520 // These enums are encoded as <id> to constant values in SPIR-V blob, but we
521 // directly use the constant value as attribute in SPIR-V dialect. So need
522 // to handle them separately from normal enum attributes.
523 constexpr llvm::StringLiteral constantIdEnumAttrs[] = {
524     "SPIRV_ScopeAttr", "SPIRV_KHR_CooperativeMatrixUseAttr",
525     "SPIRV_KHR_CooperativeMatrixLayoutAttr", "SPIRV_MemorySemanticsAttr",
526     "SPIRV_MatrixLayoutAttr"};
527 
528 /// Generates code to serialize attributes of a SPIRV_Op `op` into `os`. The
529 /// generates code extracts the attribute with name `attrName` from
530 /// `operandList` of `op`.
531 static void emitAttributeSerialization(const Attribute &attr,
532                                        ArrayRef<SMLoc> loc, StringRef tabs,
533                                        StringRef opVar, StringRef operandList,
534                                        StringRef attrName, raw_ostream &os) {
535   os << tabs
536      << formatv("if (auto attr = {0}->getAttr(\"{1}\")) {{\n", opVar, attrName);
537   if (llvm::is_contained(constantIdEnumAttrs, attr.getAttrDefName())) {
538     EnumAttr baseEnum(attr.getDef().getValueAsDef("enum"));
539     os << tabs
540        << formatv("  {0}.push_back(prepareConstantInt({1}.getLoc(), "
541                   "Builder({1}).getI32IntegerAttr(static_cast<uint32_t>("
542                   "::llvm::cast<{2}::{3}Attr>(attr).getValue()))));\n",
543                   operandList, opVar, baseEnum.getCppNamespace(),
544                   baseEnum.getEnumClassName());
545   } else if (attr.isSubClassOf("SPIRV_BitEnumAttr") ||
546              attr.isSubClassOf("SPIRV_I32EnumAttr")) {
547     EnumAttr baseEnum(attr.getDef().getValueAsDef("enum"));
548     os << tabs
549        << formatv("  {0}.push_back(static_cast<uint32_t>("
550                   "::llvm::cast<{1}::{2}Attr>(attr).getValue()));\n",
551                   operandList, baseEnum.getCppNamespace(),
552                   baseEnum.getEnumClassName());
553   } else if (attr.getAttrDefName() == "I32ArrayAttr") {
554     // Serialize all the elements of the array
555     os << tabs << "  for (auto attrElem : llvm::cast<ArrayAttr>(attr)) {\n";
556     os << tabs
557        << formatv("    {0}.push_back(static_cast<uint32_t>("
558                   "llvm::cast<IntegerAttr>(attrElem).getValue().getZExtValue())"
559                   ");\n",
560                   operandList);
561     os << tabs << "  }\n";
562   } else if (attr.getAttrDefName() == "I32Attr") {
563     os << tabs
564        << formatv(
565               "  {0}.push_back(static_cast<uint32_t>("
566               "llvm::cast<IntegerAttr>(attr).getValue().getZExtValue()));\n",
567               operandList);
568   } else if (attr.isEnumAttr() || attr.isTypeAttr()) {
569     // It may be the first time this type appears in the IR, so we need to
570     // process it.
571     StringRef attrTypeID = "attrTypeID";
572     os << tabs << formatv("  uint32_t {0} = 0;\n", attrTypeID);
573     os << tabs
574        << formatv("  if (failed(processType({0}.getLoc(), "
575                   "llvm::cast<TypeAttr>(attr).getValue(), {1}))) {{\n",
576                   opVar, attrTypeID);
577     os << tabs << "    return failure();\n";
578     os << tabs << "  }\n";
579     os << tabs << formatv("  {0}.push_back(attrTypeID);\n", operandList);
580   } else {
581     PrintFatalError(
582         loc,
583         llvm::Twine(
584             "unhandled attribute type in SPIR-V serialization generation : '") +
585             attr.getAttrDefName() + llvm::Twine("'"));
586   }
587   os << tabs << "}\n";
588 }
589 
590 /// Generates code to serialize the operands of a SPIRV_Op `op` into `os`. The
591 /// generated queries the SSA-ID if operand is a SSA-Value, or serializes the
592 /// attributes. The `operands` vector is updated appropriately. `elidedAttrs`
593 /// updated as well to include the serialized attributes.
594 static void emitArgumentSerialization(const Operator &op, ArrayRef<SMLoc> loc,
595                                       StringRef tabs, StringRef opVar,
596                                       StringRef operands, StringRef elidedAttrs,
597                                       raw_ostream &os) {
598   using mlir::tblgen::Argument;
599 
600   // SPIR-V ops can mix operands and attributes in the definition. These
601   // operands and attributes are serialized in the exact order of the definition
602   // to match SPIR-V binary format requirements. It can cause excessive
603   // generated code bloat because we are emitting code to handle each
604   // operand/attribute separately. So here we probe first to check whether all
605   // the operands are ahead of attributes. Then we can serialize all operands
606   // together.
607 
608   // Whether all operands are ahead of all attributes in the op's spec.
609   bool areOperandsAheadOfAttrs = true;
610   // Find the first attribute.
611   const Argument *it = llvm::find_if(op.getArgs(), [](const Argument &arg) {
612     return isa<NamedAttribute *>(arg);
613   });
614   // Check whether all following arguments are attributes.
615   for (const Argument *ie = op.arg_end(); it != ie; ++it) {
616     if (!isa<NamedAttribute *>(*it)) {
617       areOperandsAheadOfAttrs = false;
618       break;
619     }
620   }
621 
622   // Serialize all operands together.
623   if (areOperandsAheadOfAttrs) {
624     if (op.getNumOperands() != 0) {
625       os << tabs
626          << formatv("for (Value operand : {0}->getOperands()) {{\n", opVar);
627       os << tabs << "  auto id = getValueID(operand);\n";
628       os << tabs << "  assert(id && \"use before def!\");\n";
629       os << tabs << formatv("  {0}.push_back(id);\n", operands);
630       os << tabs << "}\n";
631     }
632     for (const NamedAttribute &attr : op.getAttributes()) {
633       emitAttributeSerialization(
634           (attr.attr.isOptional() ? attr.attr.getBaseAttr() : attr.attr), loc,
635           tabs, opVar, operands, attr.name, os);
636       os << tabs
637          << formatv("{0}.push_back(\"{1}\");\n", elidedAttrs, attr.name);
638     }
639     return;
640   }
641 
642   // Serialize operands separately.
643   auto operandNum = 0;
644   for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) {
645     auto argument = op.getArg(i);
646     os << tabs << "{\n";
647     if (isa<NamedTypeConstraint *>(argument)) {
648       os << tabs
649          << formatv("  for (auto arg : {0}.getODSOperands({1})) {{\n", opVar,
650                     operandNum);
651       os << tabs << "    auto argID = getValueID(arg);\n";
652       os << tabs << "    if (!argID) {\n";
653       os << tabs
654          << formatv("      return emitError({0}.getLoc(), "
655                     "\"operand #{1} has a use before def\");\n",
656                     opVar, operandNum);
657       os << tabs << "    }\n";
658       os << tabs << formatv("    {0}.push_back(argID);\n", operands);
659       os << "    }\n";
660       operandNum++;
661     } else {
662       NamedAttribute *attr = cast<NamedAttribute *>(argument);
663       auto newtabs = tabs.str() + "  ";
664       emitAttributeSerialization(
665           (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr),
666           loc, newtabs, opVar, operands, attr->name, os);
667       os << newtabs
668          << formatv("{0}.push_back(\"{1}\");\n", elidedAttrs, attr->name);
669     }
670     os << tabs << "}\n";
671   }
672 }
673 
674 /// Generates code to serializes the result of SPIRV_Op `op` into `os`. The
675 /// generated gets the ID for the type of the result (if any), the SSA-ID of
676 /// the result and updates `resultID` with the SSA-ID.
677 static void emitResultSerialization(const Operator &op, ArrayRef<SMLoc> loc,
678                                     StringRef tabs, StringRef opVar,
679                                     StringRef operands, StringRef resultID,
680                                     raw_ostream &os) {
681   if (op.getNumResults() == 1) {
682     StringRef resultTypeID("resultTypeID");
683     os << tabs << formatv("uint32_t {0} = 0;\n", resultTypeID);
684     os << tabs
685        << formatv(
686               "if (failed(processType({0}.getLoc(), {0}.getType(), {1}))) {{\n",
687               opVar, resultTypeID);
688     os << tabs << "  return failure();\n";
689     os << tabs << "}\n";
690     os << tabs << formatv("{0}.push_back({1});\n", operands, resultTypeID);
691     // Create an SSA result <id> for the op
692     os << tabs << formatv("{0} = getNextID();\n", resultID);
693     os << tabs
694        << formatv("valueIDMap[{0}.getResult()] = {1};\n", opVar, resultID);
695     os << tabs << formatv("{0}.push_back({1});\n", operands, resultID);
696   } else if (op.getNumResults() != 0) {
697     PrintFatalError(loc, "SPIR-V ops can only have zero or one result");
698   }
699 }
700 
701 /// Generates code to serialize attributes of SPIRV_Op `op` that become
702 /// decorations on the `resultID` of the serialized operation `opVar` in the
703 /// SPIR-V binary.
704 static void emitDecorationSerialization(const Operator &op, StringRef tabs,
705                                         StringRef opVar, StringRef elidedAttrs,
706                                         StringRef resultID, raw_ostream &os) {
707   if (op.getNumResults() == 1) {
708     // All non-argument attributes translated into OpDecorate instruction
709     os << tabs << formatv("for (auto attr : {0}->getAttrs()) {{\n", opVar);
710     os << tabs
711        << formatv("  if (llvm::is_contained({0}, attr.getName())) {{",
712                   elidedAttrs);
713     os << tabs << "    continue;\n";
714     os << tabs << "  }\n";
715     os << tabs
716        << formatv(
717               "  if (failed(processDecoration({0}.getLoc(), {1}, attr))) {{\n",
718               opVar, resultID);
719     os << tabs << "    return failure();\n";
720     os << tabs << "  }\n";
721     os << tabs << "}\n";
722   }
723 }
724 
725 /// Generates code to serialize an SPIRV_Op `op` into `os`.
726 static void emitSerializationFunction(const Record *attrClass,
727                                       const Record *record, const Operator &op,
728                                       raw_ostream &os) {
729   // If the record has 'autogenSerialization' set to 0, nothing to do
730   if (!record->getValueAsBit("autogenSerialization"))
731     return;
732 
733   StringRef opVar("op"), operands("operands"), elidedAttrs("elidedAttrs"),
734       resultID("resultID");
735 
736   os << formatv(
737       "template <> LogicalResult\nSerializer::processOp<{0}>({0} {1}) {{\n",
738       op.getQualCppClassName(), opVar);
739 
740   // Special case for ops without attributes in TableGen definitions
741   if (op.getNumAttributes() == 0 && op.getNumVariableLengthOperands() == 0) {
742     std::string extInstSet;
743     std::string opcode;
744     if (record->isSubClassOf("SPIRV_ExtInstOp")) {
745       extInstSet =
746           formatv("\"{0}\"", record->getValueAsString("extendedInstSetName"));
747       opcode = std::to_string(record->getValueAsInt("extendedInstOpcode"));
748     } else {
749       extInstSet = "\"\"";
750       opcode = formatv("static_cast<uint32_t>(spirv::Opcode::{0})",
751                        record->getValueAsString("spirvOpName"));
752     }
753 
754     os << formatv("  return processOpWithoutGrammarAttr({0}, {1}, {2});\n}\n\n",
755                   opVar, extInstSet, opcode);
756     return;
757   }
758 
759   os << formatv("  SmallVector<uint32_t, 4> {0};\n", operands);
760   os << formatv("  SmallVector<StringRef, 2> {0};\n", elidedAttrs);
761 
762   // Serialize result information.
763   if (op.getNumResults() == 1) {
764     os << formatv("  uint32_t {0} = 0;\n", resultID);
765     emitResultSerialization(op, record->getLoc(), "  ", opVar, operands,
766                             resultID, os);
767   }
768 
769   // Process arguments.
770   emitArgumentSerialization(op, record->getLoc(), "  ", opVar, operands,
771                             elidedAttrs, os);
772 
773   if (record->isSubClassOf("SPIRV_ExtInstOp")) {
774     os << formatv(
775         "  (void)encodeExtensionInstruction({0}, \"{1}\", {2}, {3});\n", opVar,
776         record->getValueAsString("extendedInstSetName"),
777         record->getValueAsInt("extendedInstOpcode"), operands);
778   } else {
779     // Emit debug info.
780     os << formatv("  (void)emitDebugLine(functionBody, {0}.getLoc());\n",
781                   opVar);
782     os << formatv("  (void)encodeInstructionInto("
783                   "functionBody, spirv::Opcode::{0}, {1});\n",
784                   record->getValueAsString("spirvOpName"), operands);
785   }
786 
787   // Process decorations.
788   emitDecorationSerialization(op, "  ", opVar, elidedAttrs, resultID, os);
789 
790   os << "  return success();\n";
791   os << "}\n\n";
792 }
793 
794 /// Generates the prologue for the function that dispatches the serialization of
795 /// the operation `opVar` based on its opcode.
796 static void initDispatchSerializationFn(StringRef opVar, raw_ostream &os) {
797   os << formatv(
798       "LogicalResult Serializer::dispatchToAutogenSerialization(Operation "
799       "*{0}) {{\n",
800       opVar);
801 }
802 
803 /// Generates the body of the dispatch function. This function generates the
804 /// check that if satisfied, will call the serialization function generated for
805 /// the `op`.
806 static void emitSerializationDispatch(const Operator &op, StringRef tabs,
807                                       StringRef opVar, raw_ostream &os) {
808   os << tabs
809      << formatv("if (isa<{0}>({1})) {{\n", op.getQualCppClassName(), opVar);
810   os << tabs
811      << formatv("  return processOp(cast<{0}>({1}));\n",
812                 op.getQualCppClassName(), opVar);
813   os << tabs << "}\n";
814 }
815 
816 /// Generates the epilogue for the function that dispatches the serialization of
817 /// the operation.
818 static void finalizeDispatchSerializationFn(StringRef opVar, raw_ostream &os) {
819   os << formatv(
820       "  return {0}->emitError(\"unhandled operation serialization\");\n",
821       opVar);
822   os << "}\n\n";
823 }
824 
825 /// Generates code to deserialize the attribute of a SPIRV_Op into `os`. The
826 /// generated code reads the `words` of the serialized instruction at
827 /// position `wordIndex` and adds the deserialized attribute into `attrList`.
828 static void emitAttributeDeserialization(const Attribute &attr,
829                                          ArrayRef<SMLoc> loc, StringRef tabs,
830                                          StringRef attrList, StringRef attrName,
831                                          StringRef words, StringRef wordIndex,
832                                          raw_ostream &os) {
833   if (llvm::is_contained(constantIdEnumAttrs, attr.getAttrDefName())) {
834     EnumAttr baseEnum(attr.getDef().getValueAsDef("enum"));
835     os << tabs
836        << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
837                   "opBuilder.getAttr<{2}::{3}Attr>(static_cast<{2}::{3}>("
838                   "getConstantInt({4}[{5}++]).getValue().getZExtValue()))));\n",
839                   attrList, attrName, baseEnum.getCppNamespace(),
840                   baseEnum.getEnumClassName(), words, wordIndex);
841   } else if (attr.isSubClassOf("SPIRV_BitEnumAttr") ||
842              attr.isSubClassOf("SPIRV_I32EnumAttr")) {
843     EnumAttr baseEnum(attr.getDef().getValueAsDef("enum"));
844     os << tabs
845        << formatv("  {0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
846                   "opBuilder.getAttr<{2}::{3}Attr>("
847                   "static_cast<{2}::{3}>({4}[{5}++]))));\n",
848                   attrList, attrName, baseEnum.getCppNamespace(),
849                   baseEnum.getEnumClassName(), words, wordIndex);
850   } else if (attr.getAttrDefName() == "I32ArrayAttr") {
851     os << tabs << "SmallVector<Attribute, 4> attrListElems;\n";
852     os << tabs << formatv("while ({0} < {1}.size()) {{\n", wordIndex, words);
853     os << tabs
854        << formatv(
855               "  "
856               "attrListElems.push_back(opBuilder.getI32IntegerAttr({0}[{1}++]))"
857               ";\n",
858               words, wordIndex);
859     os << tabs << "}\n";
860     os << tabs
861        << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
862                   "opBuilder.getArrayAttr(attrListElems)));\n",
863                   attrList, attrName);
864   } else if (attr.getAttrDefName() == "I32Attr") {
865     os << tabs
866        << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
867                   "opBuilder.getI32IntegerAttr({2}[{3}++])));\n",
868                   attrList, attrName, words, wordIndex);
869   } else if (attr.isEnumAttr() || attr.isTypeAttr()) {
870     os << tabs
871        << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
872                   "TypeAttr::get(getType({2}[{3}++]))));\n",
873                   attrList, attrName, words, wordIndex);
874   } else {
875     PrintFatalError(
876         loc, llvm::Twine(
877                  "unhandled attribute type in deserialization generation : '") +
878                  attrName + llvm::Twine("'"));
879   }
880 }
881 
882 /// Generates the code to deserialize the result of an SPIRV_Op `op` into
883 /// `os`. The generated code gets the type of the result specified at
884 /// `words`[`wordIndex`], the SSA ID for the result at position `wordIndex` + 1
885 /// and updates the `resultType` and `valueID` with the parsed type and SSA ID,
886 /// respectively.
887 static void emitResultDeserialization(const Operator &op, ArrayRef<SMLoc> loc,
888                                       StringRef tabs, StringRef words,
889                                       StringRef wordIndex,
890                                       StringRef resultTypes, StringRef valueID,
891                                       raw_ostream &os) {
892   // Deserialize result information if it exists
893   if (op.getNumResults() == 1) {
894     os << tabs << "{\n";
895     os << tabs << formatv("  if ({0} >= {1}.size()) {{\n", wordIndex, words);
896     os << tabs
897        << formatv(
898               "    return emitError(unknownLoc, \"expected result type <id> "
899               "while deserializing {0}\");\n",
900               op.getQualCppClassName());
901     os << tabs << "  }\n";
902     os << tabs << formatv("  auto ty = getType({0}[{1}]);\n", words, wordIndex);
903     os << tabs << "  if (!ty) {\n";
904     os << tabs
905        << formatv(
906               "    return emitError(unknownLoc, \"unknown type result <id> : "
907               "\") << {0}[{1}];\n",
908               words, wordIndex);
909     os << tabs << "  }\n";
910     os << tabs << formatv("  {0}.push_back(ty);\n", resultTypes);
911     os << tabs << formatv("  {0}++;\n", wordIndex);
912     os << tabs << formatv("  if ({0} >= {1}.size()) {{\n", wordIndex, words);
913     os << tabs
914        << formatv(
915               "    return emitError(unknownLoc, \"expected result <id> while "
916               "deserializing {0}\");\n",
917               op.getQualCppClassName());
918     os << tabs << "  }\n";
919     os << tabs << "}\n";
920     os << tabs << formatv("{0} = {1}[{2}++];\n", valueID, words, wordIndex);
921   } else if (op.getNumResults() != 0) {
922     PrintFatalError(loc, "SPIR-V ops can have only zero or one result");
923   }
924 }
925 
926 /// Generates the code to deserialize the operands of an SPIRV_Op `op` into
927 /// `os`. The generated code reads the `words` of the binary instruction, from
928 /// position `wordIndex` to the end, and either gets the Value corresponding to
929 /// the ID encoded, or deserializes the attributes encoded. The parsed
930 /// operand(attribute) is added to the `operands` list or `attributes` list.
931 static void emitOperandDeserialization(const Operator &op, ArrayRef<SMLoc> loc,
932                                        StringRef tabs, StringRef words,
933                                        StringRef wordIndex, StringRef operands,
934                                        StringRef attributes, raw_ostream &os) {
935   // Process operands/attributes
936   for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) {
937     auto argument = op.getArg(i);
938     if (auto *valueArg = llvm::dyn_cast_if_present<NamedTypeConstraint *>(argument)) {
939       if (valueArg->isVariableLength()) {
940         if (i != e - 1) {
941           PrintFatalError(
942               loc, "SPIR-V ops can have Variadic<..> or "
943                    "Optional<...> arguments only if it's the last argument");
944         }
945         os << tabs
946            << formatv("for (; {0} < {1}.size(); ++{0})", wordIndex, words);
947       } else {
948         os << tabs << formatv("if ({0} < {1}.size())", wordIndex, words);
949       }
950       os << " {\n";
951       os << tabs
952          << formatv("  auto arg = getValue({0}[{1}]);\n", words, wordIndex);
953       os << tabs << "  if (!arg) {\n";
954       os << tabs
955          << formatv(
956                 "    return emitError(unknownLoc, \"unknown result <id> : \") "
957                 "<< {0}[{1}];\n",
958                 words, wordIndex);
959       os << tabs << "  }\n";
960       os << tabs << formatv("  {0}.push_back(arg);\n", operands);
961       if (!valueArg->isVariableLength()) {
962         os << tabs << formatv("  {0}++;\n", wordIndex);
963       }
964       os << tabs << "}\n";
965     } else {
966       os << tabs << formatv("if ({0} < {1}.size()) {{\n", wordIndex, words);
967       auto *attr = cast<NamedAttribute *>(argument);
968       auto newtabs = tabs.str() + "  ";
969       emitAttributeDeserialization(
970           (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr),
971           loc, newtabs, attributes, attr->name, words, wordIndex, os);
972       os << "  }\n";
973     }
974   }
975 
976   os << tabs << formatv("if ({0} != {1}.size()) {{\n", wordIndex, words);
977   os << tabs
978      << formatv(
979             "  return emitError(unknownLoc, \"found more operands than "
980             "expected when deserializing {0}, only \") << {1} << \" of \" << "
981             "{2}.size() << \" processed\";\n",
982             op.getQualCppClassName(), wordIndex, words);
983   os << tabs << "}\n\n";
984 }
985 
986 /// Generates code to update the `attributes` vector with the attributes
987 /// obtained from parsing the decorations in the SPIR-V binary associated with
988 /// an <id> `valueID`
989 static void emitDecorationDeserialization(const Operator &op, StringRef tabs,
990                                           StringRef valueID,
991                                           StringRef attributes,
992                                           raw_ostream &os) {
993   // Import decorations parsed
994   if (op.getNumResults() == 1) {
995     os << tabs << formatv("if (decorations.count({0})) {{\n", valueID);
996     os << tabs
997        << formatv("  auto attrs = decorations[{0}].getAttrs();\n", valueID);
998     os << tabs
999        << formatv("  {0}.append(attrs.begin(), attrs.end());\n", attributes);
1000     os << tabs << "}\n";
1001   }
1002 }
1003 
1004 /// Generates code to deserialize an SPIRV_Op `op` into `os`.
1005 static void emitDeserializationFunction(const Record *attrClass,
1006                                         const Record *record,
1007                                         const Operator &op, raw_ostream &os) {
1008   // If the record has 'autogenSerialization' set to 0, nothing to do
1009   if (!record->getValueAsBit("autogenSerialization"))
1010     return;
1011 
1012   StringRef resultTypes("resultTypes"), valueID("valueID"), words("words"),
1013       wordIndex("wordIndex"), opVar("op"), operands("operands"),
1014       attributes("attributes");
1015 
1016   // Method declaration
1017   os << formatv("template <> "
1018                 "LogicalResult\nDeserializer::processOp<{0}>(ArrayRef<"
1019                 "uint32_t> {1}) {{\n",
1020                 op.getQualCppClassName(), words);
1021 
1022   // Special case for ops without attributes in TableGen definitions
1023   if (op.getNumAttributes() == 0 && op.getNumVariableLengthOperands() == 0) {
1024     os << formatv("  return processOpWithoutGrammarAttr("
1025                   "{0}, \"{1}\", {2}, {3});\n}\n\n",
1026                   words, op.getOperationName(),
1027                   op.getNumResults() ? "true" : "false", op.getNumOperands());
1028     return;
1029   }
1030 
1031   os << formatv("  SmallVector<Type, 1> {0};\n", resultTypes);
1032   os << formatv("  size_t {0} = 0; (void){0};\n", wordIndex);
1033   os << formatv("  uint32_t {0} = 0; (void){0};\n", valueID);
1034 
1035   // Deserialize result information
1036   emitResultDeserialization(op, record->getLoc(), "  ", words, wordIndex,
1037                             resultTypes, valueID, os);
1038 
1039   os << formatv("  SmallVector<Value, 4> {0};\n", operands);
1040   os << formatv("  SmallVector<NamedAttribute, 4> {0};\n", attributes);
1041   // Operand deserialization
1042   emitOperandDeserialization(op, record->getLoc(), "  ", words, wordIndex,
1043                              operands, attributes, os);
1044 
1045   // Decorations
1046   emitDecorationDeserialization(op, "  ", valueID, attributes, os);
1047 
1048   os << formatv("  Location loc = createFileLineColLoc(opBuilder);\n");
1049   os << formatv("  auto {1} = opBuilder.create<{0}>(loc, {2}, {3}, {4}); "
1050                 "(void){1};\n",
1051                 op.getQualCppClassName(), opVar, resultTypes, operands,
1052                 attributes);
1053   if (op.getNumResults() == 1) {
1054     os << formatv("  valueMap[{0}] = {1}.getResult();\n\n", valueID, opVar);
1055   }
1056 
1057   // According to SPIR-V spec:
1058   // This location information applies to the instructions physically following
1059   // this instruction, up to the first occurrence of any of the following: the
1060   // next end of block.
1061   os << formatv("  if ({0}.hasTrait<OpTrait::IsTerminator>())\n", opVar);
1062   os << formatv("    (void)clearDebugLine();\n");
1063   os << "  return success();\n";
1064   os << "}\n\n";
1065 }
1066 
1067 /// Generates the prologue for the function that dispatches the deserialization
1068 /// based on the `opcode`.
1069 static void initDispatchDeserializationFn(StringRef opcode, StringRef words,
1070                                           raw_ostream &os) {
1071   os << formatv("LogicalResult spirv::Deserializer::"
1072                 "dispatchToAutogenDeserialization(spirv::Opcode {0},"
1073                 " ArrayRef<uint32_t> {1}) {{\n",
1074                 opcode, words);
1075   os << formatv("  switch ({0}) {{\n", opcode);
1076 }
1077 
1078 /// Generates the body of the dispatch function, by generating the case label
1079 /// for an opcode and the call to the method to perform the deserialization.
1080 static void emitDeserializationDispatch(const Operator &op, const Record *def,
1081                                         StringRef tabs, StringRef words,
1082                                         raw_ostream &os) {
1083   os << tabs
1084      << formatv("case spirv::Opcode::{0}:\n",
1085                 def->getValueAsString("spirvOpName"));
1086   os << tabs
1087      << formatv("  return processOp<{0}>({1});\n", op.getQualCppClassName(),
1088                 words);
1089 }
1090 
1091 /// Generates the epilogue for the function that dispatches the deserialization
1092 /// of the operation.
1093 static void finalizeDispatchDeserializationFn(StringRef opcode,
1094                                               raw_ostream &os) {
1095   os << "  default:\n";
1096   os << "    ;\n";
1097   os << "  }\n";
1098   StringRef opcodeVar("opcodeString");
1099   os << formatv("  auto {0} = spirv::stringifyOpcode({1});\n", opcodeVar,
1100                 opcode);
1101   os << formatv("  if (!{0}.empty()) {{\n", opcodeVar);
1102   os << formatv("    return emitError(unknownLoc, \"unhandled deserialization "
1103                 "of \") << {0};\n",
1104                 opcodeVar);
1105   os << "  } else {\n";
1106   os << formatv("   return emitError(unknownLoc, \"unhandled opcode \") << "
1107                 "static_cast<uint32_t>({0});\n",
1108                 opcode);
1109   os << "  }\n";
1110   os << "}\n";
1111 }
1112 
1113 static void initExtendedSetDeserializationDispatch(StringRef extensionSetName,
1114                                                    StringRef instructionID,
1115                                                    StringRef words,
1116                                                    raw_ostream &os) {
1117   os << formatv("LogicalResult spirv::Deserializer::"
1118                 "dispatchToExtensionSetAutogenDeserialization("
1119                 "StringRef {0}, uint32_t {1}, ArrayRef<uint32_t> {2}) {{\n",
1120                 extensionSetName, instructionID, words);
1121 }
1122 
1123 static void emitExtendedSetDeserializationDispatch(const RecordKeeper &records,
1124                                                    raw_ostream &os) {
1125   StringRef extensionSetName("extensionSetName"),
1126       instructionID("instructionID"), words("words");
1127 
1128   // First iterate over all ops derived from SPIRV_ExtensionSetOps to get all
1129   // extensionSets.
1130 
1131   // For each of the extensions a separate raw_string_ostream is used to
1132   // generate code into. These are then concatenated at the end. Since
1133   // raw_string_ostream needs a string&, use a vector to store all the string
1134   // that are captured by reference within raw_string_ostream.
1135   StringMap<raw_string_ostream> extensionSets;
1136   std::list<std::string> extensionSetNames;
1137 
1138   initExtendedSetDeserializationDispatch(extensionSetName, instructionID, words,
1139                                          os);
1140   auto defs = records.getAllDerivedDefinitions("SPIRV_ExtInstOp");
1141   for (const auto *def : defs) {
1142     if (!def->getValueAsBit("autogenSerialization")) {
1143       continue;
1144     }
1145     Operator op(def);
1146     auto setName = def->getValueAsString("extendedInstSetName");
1147     if (!extensionSets.count(setName)) {
1148       extensionSetNames.emplace_back("");
1149       extensionSets.try_emplace(setName, extensionSetNames.back());
1150       auto &setos = extensionSets.find(setName)->second;
1151       setos << formatv("  if ({0} == \"{1}\") {{\n", extensionSetName, setName);
1152       setos << formatv("    switch ({0}) {{\n", instructionID);
1153     }
1154     auto &setos = extensionSets.find(setName)->second;
1155     setos << formatv("    case {0}:\n",
1156                      def->getValueAsInt("extendedInstOpcode"));
1157     setos << formatv("      return processOp<{0}>({1});\n",
1158                      op.getQualCppClassName(), words);
1159   }
1160 
1161   // Append the dispatch code for all the extended sets.
1162   for (auto &extensionSet : extensionSets) {
1163     os << extensionSet.second.str();
1164     os << "    default:\n";
1165     os << formatv(
1166         "      return emitError(unknownLoc, \"unhandled deserializations of "
1167         "\") << {0} << \" from extension set \" << {1};\n",
1168         instructionID, extensionSetName);
1169     os << "    }\n";
1170     os << "  }\n";
1171   }
1172 
1173   os << formatv("  return emitError(unknownLoc, \"unhandled deserialization of "
1174                 "extended instruction set {0}\");\n",
1175                 extensionSetName);
1176   os << "}\n";
1177 }
1178 
1179 /// Emits all the autogenerated serialization/deserializations functions for the
1180 /// SPIRV_Ops.
1181 static bool emitSerializationFns(const RecordKeeper &records, raw_ostream &os) {
1182   llvm::emitSourceFileHeader("SPIR-V Serialization Utilities/Functions", os,
1183                              records);
1184 
1185   std::string dSerFnString, dDesFnString, serFnString, deserFnString,
1186       utilsString;
1187   raw_string_ostream dSerFn(dSerFnString), dDesFn(dDesFnString),
1188       serFn(serFnString), deserFn(deserFnString);
1189   const Record *attrClass = records.getClass("Attr");
1190 
1191   // Emit the serialization and deserialization functions simultaneously.
1192   StringRef opVar("op");
1193   StringRef opcode("opcode"), words("words");
1194 
1195   // Handle the SPIR-V ops.
1196   initDispatchSerializationFn(opVar, dSerFn);
1197   initDispatchDeserializationFn(opcode, words, dDesFn);
1198   auto defs = records.getAllDerivedDefinitions("SPIRV_Op");
1199   for (const auto *def : defs) {
1200     Operator op(def);
1201     emitSerializationFunction(attrClass, def, op, serFn);
1202     emitDeserializationFunction(attrClass, def, op, deserFn);
1203     if (def->getValueAsBit("hasOpcode") ||
1204         def->isSubClassOf("SPIRV_ExtInstOp")) {
1205       emitSerializationDispatch(op, "  ", opVar, dSerFn);
1206     }
1207     if (def->getValueAsBit("hasOpcode")) {
1208       emitDeserializationDispatch(op, def, "  ", words, dDesFn);
1209     }
1210   }
1211   finalizeDispatchSerializationFn(opVar, dSerFn);
1212   finalizeDispatchDeserializationFn(opcode, dDesFn);
1213 
1214   emitExtendedSetDeserializationDispatch(records, dDesFn);
1215 
1216   os << "#ifdef GET_SERIALIZATION_FNS\n\n";
1217   os << serFn.str();
1218   os << dSerFn.str();
1219   os << "#endif // GET_SERIALIZATION_FNS\n\n";
1220 
1221   os << "#ifdef GET_DESERIALIZATION_FNS\n\n";
1222   os << deserFn.str();
1223   os << dDesFn.str();
1224   os << "#endif // GET_DESERIALIZATION_FNS\n\n";
1225 
1226   return false;
1227 }
1228 
1229 //===----------------------------------------------------------------------===//
1230 // Serialization Hook Registration
1231 //===----------------------------------------------------------------------===//
1232 
1233 static mlir::GenRegistration genSerialization(
1234     "gen-spirv-serialization",
1235     "Generate SPIR-V (de)serialization utilities and functions",
1236     [](const RecordKeeper &records, raw_ostream &os) {
1237       return emitSerializationFns(records, os);
1238     });
1239 
1240 //===----------------------------------------------------------------------===//
1241 // Op Utils AutoGen
1242 //===----------------------------------------------------------------------===//
1243 
1244 static void emitEnumGetAttrNameFnDecl(raw_ostream &os) {
1245   os << formatv("template <typename EnumClass> inline constexpr StringRef "
1246                 "attributeName();\n");
1247 }
1248 
1249 static void emitEnumGetAttrNameFnDefn(const EnumAttr &enumAttr,
1250                                       raw_ostream &os) {
1251   auto enumName = enumAttr.getEnumClassName();
1252   os << formatv("template <> inline StringRef attributeName<{0}>() {{\n",
1253                 enumName);
1254   os << "  "
1255      << formatv("static constexpr const char attrName[] = \"{0}\";\n",
1256                 llvm::convertToSnakeFromCamelCase(enumName));
1257   os << "  return attrName;\n";
1258   os << "}\n";
1259 }
1260 
1261 static bool emitAttrUtils(const RecordKeeper &records, raw_ostream &os) {
1262   llvm::emitSourceFileHeader("SPIR-V Attribute Utilities", os, records);
1263 
1264   auto defs = records.getAllDerivedDefinitions("EnumAttrInfo");
1265   os << "#ifndef MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H_\n";
1266   os << "#define MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H_\n";
1267   emitEnumGetAttrNameFnDecl(os);
1268   for (const auto *def : defs) {
1269     EnumAttr enumAttr(*def);
1270     emitEnumGetAttrNameFnDefn(enumAttr, os);
1271   }
1272   os << "#endif // MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H\n";
1273   return false;
1274 }
1275 
1276 //===----------------------------------------------------------------------===//
1277 // Op Utils Hook Registration
1278 //===----------------------------------------------------------------------===//
1279 
1280 static mlir::GenRegistration
1281     genOpUtils("gen-spirv-attr-utils",
1282                "Generate SPIR-V attribute utility definitions",
1283                [](const RecordKeeper &records, raw_ostream &os) {
1284                  return emitAttrUtils(records, os);
1285                });
1286 
1287 //===----------------------------------------------------------------------===//
1288 // SPIR-V Availability Impl AutoGen
1289 //===----------------------------------------------------------------------===//
1290 
1291 static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
1292   mlir::tblgen::FmtContext fctx;
1293   fctx.addSubst("overall", "tblgen_overall");
1294 
1295   std::vector<Availability> opAvailabilities =
1296       getAvailabilities(srcOp.getDef());
1297 
1298   // First collect all availability classes this op should implement.
1299   // All availability instances keep information for the generated interface and
1300   // the instance's specific requirement. Here we remember a random instance so
1301   // we can get the information regarding the generated interface.
1302   llvm::StringMap<Availability> availClasses;
1303   for (const Availability &avail : opAvailabilities)
1304     availClasses.try_emplace(avail.getClass(), avail);
1305   for (const NamedAttribute &namedAttr : srcOp.getAttributes()) {
1306     if (!namedAttr.attr.isSubClassOf("SPIRV_BitEnumAttr") &&
1307         !namedAttr.attr.isSubClassOf("SPIRV_I32EnumAttr"))
1308       continue;
1309     EnumAttr enumAttr(namedAttr.attr.getDef().getValueAsDef("enum"));
1310 
1311     for (const EnumAttrCase &enumerant : enumAttr.getAllCases())
1312       for (const Availability &caseAvail :
1313            getAvailabilities(enumerant.getDef()))
1314         availClasses.try_emplace(caseAvail.getClass(), caseAvail);
1315   }
1316 
1317   // Then generate implementation for each availability class.
1318   for (const auto &availClass : availClasses) {
1319     StringRef availClassName = availClass.getKey();
1320     Availability avail = availClass.getValue();
1321 
1322     // Generate the implementation method signature.
1323     os << formatv("{0} {1}::{2}() {{\n", avail.getQueryFnRetType(),
1324                   srcOp.getCppClassName(), avail.getQueryFnName());
1325 
1326     // Create the variable for the final requirement and initialize it.
1327     os << formatv("  {0} tblgen_overall = {1};\n", avail.getQueryFnRetType(),
1328                   avail.getMergeInitializer());
1329 
1330     // Update with the op's specific availability spec.
1331     for (const Availability &avail : opAvailabilities)
1332       if (avail.getClass() == availClassName &&
1333           (!avail.getMergeInstancePreparation().empty() ||
1334            !avail.getMergeActionCode().empty())) {
1335         os << "  {\n    "
1336            // Prepare this instance.
1337            << avail.getMergeInstancePreparation()
1338            << "\n    "
1339            // Merge this instance.
1340            << std::string(
1341                   tgfmt(avail.getMergeActionCode(),
1342                         &fctx.addSubst("instance", avail.getMergeInstance())))
1343            << ";\n  }\n";
1344       }
1345 
1346     // Update with enum attributes' specific availability spec.
1347     for (const NamedAttribute &namedAttr : srcOp.getAttributes()) {
1348       if (!namedAttr.attr.isSubClassOf("SPIRV_BitEnumAttr") &&
1349           !namedAttr.attr.isSubClassOf("SPIRV_I32EnumAttr"))
1350         continue;
1351       EnumAttr enumAttr(namedAttr.attr.getDef().getValueAsDef("enum"));
1352 
1353       // (enumerant, availability specification) pairs for this availability
1354       // class.
1355       SmallVector<std::pair<EnumAttrCase, Availability>, 1> caseSpecs;
1356 
1357       // Collect all cases' availability specs.
1358       for (const EnumAttrCase &enumerant : enumAttr.getAllCases())
1359         for (const Availability &caseAvail :
1360              getAvailabilities(enumerant.getDef()))
1361           if (availClassName == caseAvail.getClass())
1362             caseSpecs.push_back({enumerant, caseAvail});
1363 
1364       // If this attribute kind does not have any availability spec from any of
1365       // its cases, no more work to do.
1366       if (caseSpecs.empty())
1367         continue;
1368 
1369       if (enumAttr.isBitEnum()) {
1370         // For BitEnumAttr, we need to iterate over each bit to query its
1371         // availability spec.
1372         os << formatv("  for (unsigned i = 0; "
1373                       "i < std::numeric_limits<{0}>::digits; ++i) {{\n",
1374                       enumAttr.getUnderlyingType());
1375         os << formatv("    {0}::{1} tblgen_attrVal = this->{2}() & "
1376                       "static_cast<{0}::{1}>(1 << i);\n",
1377                       enumAttr.getCppNamespace(), enumAttr.getEnumClassName(),
1378                       srcOp.getGetterName(namedAttr.name));
1379         os << formatv(
1380             "    if (static_cast<{0}>(tblgen_attrVal) == 0) continue;\n",
1381             enumAttr.getUnderlyingType());
1382       } else {
1383         // For IntEnumAttr, we just need to query the value as a whole.
1384         os << "  {\n";
1385         os << formatv("    auto tblgen_attrVal = this->{0}();\n",
1386                       srcOp.getGetterName(namedAttr.name));
1387       }
1388       os << formatv("    auto tblgen_instance = {0}::{1}(tblgen_attrVal);\n",
1389                     enumAttr.getCppNamespace(), avail.getQueryFnName());
1390       os << "    if (tblgen_instance) "
1391          // TODO` here once ODS supports
1392          // dialect-specific contents so that we can use not implementing the
1393          // availability interface as indication of no requirements.
1394          << std::string(tgfmt(caseSpecs.front().second.getMergeActionCode(),
1395                               &fctx.addSubst("instance", "*tblgen_instance")))
1396          << ";\n";
1397       os << "  }\n";
1398     }
1399 
1400     os << "  return tblgen_overall;\n";
1401     os << "}\n";
1402   }
1403 }
1404 
1405 static bool emitAvailabilityImpl(const RecordKeeper &records, raw_ostream &os) {
1406   llvm::emitSourceFileHeader("SPIR-V Op Availability Implementations", os,
1407                              records);
1408 
1409   auto defs = records.getAllDerivedDefinitions("SPIRV_Op");
1410   for (const auto *def : defs) {
1411     Operator op(def);
1412     if (def->getValueAsBit("autogenAvailability"))
1413       emitAvailabilityImpl(op, os);
1414   }
1415   return false;
1416 }
1417 
1418 //===----------------------------------------------------------------------===//
1419 // Op Availability Implementation Hook Registration
1420 //===----------------------------------------------------------------------===//
1421 
1422 static mlir::GenRegistration
1423     genOpAvailabilityImpl("gen-spirv-avail-impls",
1424                           "Generate SPIR-V operation utility definitions",
1425                           [](const RecordKeeper &records, raw_ostream &os) {
1426                             return emitAvailabilityImpl(records, os);
1427                           });
1428 
1429 //===----------------------------------------------------------------------===//
1430 // SPIR-V Capability Implication AutoGen
1431 //===----------------------------------------------------------------------===//
1432 
1433 static bool emitCapabilityImplication(const RecordKeeper &records,
1434                                       raw_ostream &os) {
1435   llvm::emitSourceFileHeader("SPIR-V Capability Implication", os, records);
1436 
1437   EnumAttr enumAttr(
1438       records.getDef("SPIRV_CapabilityAttr")->getValueAsDef("enum"));
1439 
1440   os << "ArrayRef<spirv::Capability> "
1441         "spirv::getDirectImpliedCapabilities(spirv::Capability cap) {\n"
1442      << "  switch (cap) {\n"
1443      << "  default: return {};\n";
1444   for (const EnumAttrCase &enumerant : enumAttr.getAllCases()) {
1445     const Record &def = enumerant.getDef();
1446     if (!def.getValue("implies"))
1447       continue;
1448 
1449     std::vector<const Record *> impliedCapsDefs =
1450         def.getValueAsListOfDefs("implies");
1451     os << "  case spirv::Capability::" << enumerant.getSymbol()
1452        << ": {static const spirv::Capability implies[" << impliedCapsDefs.size()
1453        << "] = {";
1454     llvm::interleaveComma(impliedCapsDefs, os, [&](const Record *capDef) {
1455       os << "spirv::Capability::" << EnumAttrCase(capDef).getSymbol();
1456     });
1457     os << "}; return ArrayRef<spirv::Capability>(implies, "
1458        << impliedCapsDefs.size() << "); }\n";
1459   }
1460   os << "  }\n";
1461   os << "}\n";
1462 
1463   return false;
1464 }
1465 
1466 //===----------------------------------------------------------------------===//
1467 // SPIR-V Capability Implication Hook Registration
1468 //===----------------------------------------------------------------------===//
1469 
1470 static mlir::GenRegistration
1471     genCapabilityImplication("gen-spirv-capability-implication",
1472                              "Generate utility function to return implied "
1473                              "capabilities for a given capability",
1474                              [](const RecordKeeper &records, raw_ostream &os) {
1475                                return emitCapabilityImplication(records, os);
1476                              });
1477