xref: /llvm-project/mlir/tools/mlir-tblgen/EnumsGen.cpp (revision db273c6c242f51792ed4298a24bd2c344214ce38)
1 //===- EnumsGen.cpp - MLIR enum utility generator -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // EnumsGen generates common utility functions for enums.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "FormatGen.h"
14 #include "mlir/TableGen/Attribute.h"
15 #include "mlir/TableGen/Format.h"
16 #include "mlir/TableGen/GenInfo.h"
17 #include "llvm/ADT/BitVector.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/ADT/StringExtras.h"
20 #include "llvm/Support/FormatVariadic.h"
21 #include "llvm/Support/raw_ostream.h"
22 #include "llvm/TableGen/Error.h"
23 #include "llvm/TableGen/Record.h"
24 #include "llvm/TableGen/TableGenBackend.h"
25 
26 using llvm::formatv;
27 using llvm::isDigit;
28 using llvm::PrintFatalError;
29 using llvm::Record;
30 using llvm::RecordKeeper;
31 using namespace mlir;
32 using mlir::tblgen::Attribute;
33 using mlir::tblgen::EnumAttr;
34 using mlir::tblgen::EnumAttrCase;
35 using mlir::tblgen::FmtContext;
36 using mlir::tblgen::tgfmt;
37 
38 static std::string makeIdentifier(StringRef str) {
39   if (!str.empty() && isDigit(static_cast<unsigned char>(str.front()))) {
40     std::string newStr = std::string("_") + str.str();
41     return newStr;
42   }
43   return str.str();
44 }
45 
46 static void emitEnumClass(const Record &enumDef, StringRef enumName,
47                           StringRef underlyingType, StringRef description,
48                           const std::vector<EnumAttrCase> &enumerants,
49                           raw_ostream &os) {
50   os << "// " << description << "\n";
51   os << "enum class " << enumName;
52 
53   if (!underlyingType.empty())
54     os << " : " << underlyingType;
55   os << " {\n";
56 
57   for (const auto &enumerant : enumerants) {
58     auto symbol = makeIdentifier(enumerant.getSymbol());
59     auto value = enumerant.getValue();
60     if (value >= 0) {
61       os << formatv("  {0} = {1},\n", symbol, value);
62     } else {
63       os << formatv("  {0},\n", symbol);
64     }
65   }
66   os << "};\n\n";
67 }
68 
69 static void emitParserPrinter(const EnumAttr &enumAttr, StringRef qualName,
70                               StringRef cppNamespace, raw_ostream &os) {
71   if (enumAttr.getUnderlyingType().empty() ||
72       enumAttr.getConstBuilderTemplate().empty())
73     return;
74   auto cases = enumAttr.getAllCases();
75 
76   // Check which cases shouldn't be printed using a keyword.
77   llvm::BitVector nonKeywordCases(cases.size());
78   for (auto [index, caseVal] : llvm::enumerate(cases))
79     if (!mlir::tblgen::canFormatStringAsKeyword(caseVal.getStr()))
80       nonKeywordCases.set(index);
81 
82   // Generate the parser and the start of the printer for the enum.
83   const char *parsedAndPrinterStart = R"(
84 namespace mlir {
85 template <typename T, typename>
86 struct FieldParser;
87 
88 template<>
89 struct FieldParser<{0}, {0}> {{
90   template <typename ParserT>
91   static FailureOr<{0}> parse(ParserT &parser) {{
92     // Parse the keyword/string containing the enum.
93     std::string enumKeyword;
94     auto loc = parser.getCurrentLocation();
95     if (failed(parser.parseOptionalKeywordOrString(&enumKeyword)))
96       return parser.emitError(loc, "expected keyword for {2}");
97 
98     // Symbolize the keyword.
99     if (::std::optional<{0}> attr = {1}::symbolizeEnum<{0}>(enumKeyword))
100       return *attr;
101     return parser.emitError(loc, "invalid {2} specification: ") << enumKeyword;
102   }
103 };
104 
105 /// Support for std::optional, useful in attribute/type definition where the enum is
106 /// used as:
107 ///
108 ///    let parameters = (ins OptionalParameter<"std::optional<TheEnumName>">:$value);
109 template<>
110 struct FieldParser<std::optional<{0}>, std::optional<{0}>> {{
111   template <typename ParserT>
112   static FailureOr<std::optional<{0}>> parse(ParserT &parser) {{
113     // Parse the keyword/string containing the enum.
114     std::string enumKeyword;
115     auto loc = parser.getCurrentLocation();
116     if (failed(parser.parseOptionalKeywordOrString(&enumKeyword)))
117       return std::optional<{0}>{{};
118 
119     // Symbolize the keyword.
120     if (::std::optional<{0}> attr = {1}::symbolizeEnum<{0}>(enumKeyword))
121       return attr;
122     return parser.emitError(loc, "invalid {2} specification: ") << enumKeyword;
123   }
124 };
125 } // namespace mlir
126 
127 namespace llvm {
128 inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{
129   auto valueStr = stringifyEnum(value);
130 )";
131   os << formatv(parsedAndPrinterStart, qualName, cppNamespace,
132                 enumAttr.getSummary());
133 
134   // If all cases require a string, always wrap.
135   if (nonKeywordCases.all()) {
136     os << "  return p << '\"' << valueStr << '\"';\n"
137           "}\n"
138           "} // namespace llvm\n";
139     return;
140   }
141 
142   // If there are any cases that can't be used with a keyword, switch on the
143   // case value to determine when to print in the string form.
144   if (nonKeywordCases.any()) {
145     os << "  switch (value) {\n";
146     for (auto it : llvm::enumerate(cases)) {
147       if (nonKeywordCases.test(it.index()))
148         continue;
149       StringRef symbol = it.value().getSymbol();
150       os << llvm::formatv("  case {0}::{1}:\n", qualName,
151                           makeIdentifier(symbol));
152     }
153     os << "    break;\n"
154           "  default:\n"
155           "    return p << '\"' << valueStr << '\"';\n"
156           "  }\n";
157 
158     // If this is a bit enum, conservatively print the string form if the value
159     // is not a power of two (i.e. not a single bit case) and not a known case.
160   } else if (enumAttr.isBitEnum()) {
161     // Process the known multi-bit cases that use valid keywords.
162     SmallVector<EnumAttrCase *> validMultiBitCases;
163     for (auto [index, caseVal] : llvm::enumerate(cases)) {
164       uint64_t value = caseVal.getValue();
165       if (value && !llvm::has_single_bit(value) && !nonKeywordCases.test(index))
166         validMultiBitCases.push_back(&caseVal);
167     }
168     if (!validMultiBitCases.empty()) {
169       os << "  switch (value) {\n";
170       for (EnumAttrCase *caseVal : validMultiBitCases) {
171         StringRef symbol = caseVal->getSymbol();
172         os << llvm::formatv("  case {0}::{1}:\n", qualName,
173                             llvm::isDigit(symbol.front()) ? ("_" + symbol)
174                                                           : symbol);
175       }
176       os << "    return p << valueStr;\n"
177             "  default:\n"
178             "    break;\n"
179             "  }\n";
180     }
181 
182     // All other multi-bit cases should be printed as strings.
183     os << formatv("  auto underlyingValue = "
184                   "static_cast<std::make_unsigned_t<{0}>>(value);\n",
185                   qualName);
186     os << "  if (underlyingValue && !llvm::has_single_bit(underlyingValue))\n"
187           "    return p << '\"' << valueStr << '\"';\n";
188   }
189   os << "  return p << valueStr;\n"
190         "}\n"
191         "} // namespace llvm\n";
192 }
193 
194 static void emitDenseMapInfo(StringRef qualName, std::string underlyingType,
195                              StringRef cppNamespace, raw_ostream &os) {
196   if (underlyingType.empty())
197     underlyingType =
198         std::string(formatv("std::underlying_type_t<{0}>", qualName));
199 
200   const char *const mapInfo = R"(
201 namespace llvm {
202 template<> struct DenseMapInfo<{0}> {{
203   using StorageInfo = ::llvm::DenseMapInfo<{1}>;
204 
205   static inline {0} getEmptyKey() {{
206     return static_cast<{0}>(StorageInfo::getEmptyKey());
207   }
208 
209   static inline {0} getTombstoneKey() {{
210     return static_cast<{0}>(StorageInfo::getTombstoneKey());
211   }
212 
213   static unsigned getHashValue(const {0} &val) {{
214     return StorageInfo::getHashValue(static_cast<{1}>(val));
215   }
216 
217   static bool isEqual(const {0} &lhs, const {0} &rhs) {{
218     return lhs == rhs;
219   }
220 };
221 })";
222   os << formatv(mapInfo, qualName, underlyingType);
223   os << "\n\n";
224 }
225 
226 static void emitMaxValueFn(const Record &enumDef, raw_ostream &os) {
227   EnumAttr enumAttr(enumDef);
228   StringRef maxEnumValFnName = enumAttr.getMaxEnumValFnName();
229   auto enumerants = enumAttr.getAllCases();
230 
231   unsigned maxEnumVal = 0;
232   for (const auto &enumerant : enumerants) {
233     int64_t value = enumerant.getValue();
234     // Avoid generating the max value function if there is an enumerant without
235     // explicit value.
236     if (value < 0)
237       return;
238 
239     maxEnumVal = std::max(maxEnumVal, static_cast<unsigned>(value));
240   }
241 
242   // Emit the function to return the max enum value
243   os << formatv("inline constexpr unsigned {0}() {{\n", maxEnumValFnName);
244   os << formatv("  return {0};\n", maxEnumVal);
245   os << "}\n\n";
246 }
247 
248 // Returns the EnumAttrCase whose value is zero if exists; returns std::nullopt
249 // otherwise.
250 static std::optional<EnumAttrCase>
251 getAllBitsUnsetCase(llvm::ArrayRef<EnumAttrCase> cases) {
252   for (auto attrCase : cases) {
253     if (attrCase.getValue() == 0)
254       return attrCase;
255   }
256   return std::nullopt;
257 }
258 
259 // Emits the following inline function for bit enums:
260 //
261 // inline constexpr <enum-type> operator|(<enum-type> a, <enum-type> b);
262 // inline constexpr <enum-type> operator&(<enum-type> a, <enum-type> b);
263 // inline constexpr <enum-type> operator^(<enum-type> a, <enum-type> b);
264 // inline constexpr <enum-type> operator~(<enum-type> bits);
265 // inline constexpr bool bitEnumContainsAll(<enum-type> bits, <enum-type> bit);
266 // inline constexpr bool bitEnumContainsAny(<enum-type> bits, <enum-type> bit);
267 // inline constexpr <enum-type> bitEnumClear(<enum-type> bits, <enum-type> bit);
268 // inline constexpr <enum-type> bitEnumSet(<enum-type> bits, <enum-type> bit,
269 // bool value=true);
270 static void emitOperators(const Record &enumDef, raw_ostream &os) {
271   EnumAttr enumAttr(enumDef);
272   StringRef enumName = enumAttr.getEnumClassName();
273   std::string underlyingType = std::string(enumAttr.getUnderlyingType());
274   int64_t validBits = enumDef.getValueAsInt("validBits");
275   const char *const operators = R"(
276 inline constexpr {0} operator|({0} a, {0} b) {{
277   return static_cast<{0}>(static_cast<{1}>(a) | static_cast<{1}>(b));
278 }
279 inline constexpr {0} operator&({0} a, {0} b) {{
280   return static_cast<{0}>(static_cast<{1}>(a) & static_cast<{1}>(b));
281 }
282 inline constexpr {0} operator^({0} a, {0} b) {{
283   return static_cast<{0}>(static_cast<{1}>(a) ^ static_cast<{1}>(b));
284 }
285 inline constexpr {0} operator~({0} bits) {{
286   // Ensure only bits that can be present in the enum are set
287   return static_cast<{0}>(~static_cast<{1}>(bits) & static_cast<{1}>({2}u));
288 }
289 inline constexpr bool bitEnumContainsAll({0} bits, {0} bit) {{
290   return (bits & bit) == bit;
291 }
292 inline constexpr bool bitEnumContainsAny({0} bits, {0} bit) {{
293   return (static_cast<{1}>(bits) & static_cast<{1}>(bit)) != 0;
294 }
295 inline constexpr {0} bitEnumClear({0} bits, {0} bit) {{
296   return bits & ~bit;
297 }
298 inline constexpr {0} bitEnumSet({0} bits, {0} bit, /*optional*/bool value=true) {{
299   return value ? (bits | bit) : bitEnumClear(bits, bit);
300 }
301   )";
302   os << formatv(operators, enumName, underlyingType, validBits);
303 }
304 
305 static void emitSymToStrFnForIntEnum(const Record &enumDef, raw_ostream &os) {
306   EnumAttr enumAttr(enumDef);
307   StringRef enumName = enumAttr.getEnumClassName();
308   StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
309   StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
310   auto enumerants = enumAttr.getAllCases();
311 
312   os << formatv("{2} {1}({0} val) {{\n", enumName, symToStrFnName,
313                 symToStrFnRetType);
314   os << "  switch (val) {\n";
315   for (const auto &enumerant : enumerants) {
316     auto symbol = enumerant.getSymbol();
317     auto str = enumerant.getStr();
318     os << formatv("    case {0}::{1}: return \"{2}\";\n", enumName,
319                   makeIdentifier(symbol), str);
320   }
321   os << "  }\n";
322   os << "  return \"\";\n";
323   os << "}\n\n";
324 }
325 
326 static void emitSymToStrFnForBitEnum(const Record &enumDef, raw_ostream &os) {
327   EnumAttr enumAttr(enumDef);
328   StringRef enumName = enumAttr.getEnumClassName();
329   StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
330   StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
331   StringRef separator = enumDef.getValueAsString("separator");
332   auto enumerants = enumAttr.getAllCases();
333   auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);
334 
335   os << formatv("{2} {1}({0} symbol) {{\n", enumName, symToStrFnName,
336                 symToStrFnRetType);
337 
338   os << formatv("  auto val = static_cast<{0}>(symbol);\n",
339                 enumAttr.getUnderlyingType());
340   // If we have unknown bit set, return an empty string to signal errors.
341   int64_t validBits = enumDef.getValueAsInt("validBits");
342   os << formatv("  assert({0}u == ({0}u | val) && \"invalid bits set in bit "
343                 "enum\");\n",
344                 validBits);
345   if (allBitsUnsetCase) {
346     os << "  // Special case for all bits unset.\n";
347     os << formatv("  if (val == 0) return \"{0}\";\n\n",
348                   allBitsUnsetCase->getStr());
349   }
350   os << "  ::llvm::SmallVector<::llvm::StringRef, 2> strs;\n";
351 
352   // Add case string if the value has all case bits, and remove them to avoid
353   // printing again. Used only for groups, when printBitEnumPrimaryGroups is 1.
354   const char *const formatCompareRemove = R"(
355   if ({0}u == ({0}u & val)) {{
356     strs.push_back("{1}");
357     val &= ~static_cast<{2}>({0});
358   }
359 )";
360   // Add case string if the value has all case bits. Used for individual bit
361   // cases, and for groups when printBitEnumPrimaryGroups is 0.
362   const char *const formatCompare = R"(
363   if ({0}u == ({0}u & val))
364     strs.push_back("{1}");
365 )";
366   // Optionally elide bits that are members of groups that will also be printed
367   // for more concise output.
368   if (enumAttr.printBitEnumPrimaryGroups()) {
369     os << "  // Print bit enum groups before individual bits\n";
370     // Emit comparisons for group bit cases in reverse tablegen declaration
371     // order, removing bits for groups with all bits present.
372     for (const auto &enumerant : llvm::reverse(enumerants)) {
373       if ((enumerant.getValue() != 0) &&
374           enumerant.getDef().isSubClassOf("BitEnumAttrCaseGroup")) {
375         os << formatv(formatCompareRemove, enumerant.getValue(),
376                       enumerant.getStr(), enumAttr.getUnderlyingType());
377       }
378     }
379     // Emit comparisons for individual bit cases in tablegen declaration order.
380     for (const auto &enumerant : enumerants) {
381       if ((enumerant.getValue() != 0) &&
382           enumerant.getDef().isSubClassOf("BitEnumAttrCaseBit"))
383         os << formatv(formatCompare, enumerant.getValue(), enumerant.getStr());
384     }
385   } else {
386     // Emit comparisons for ALL nonzero cases (individual bits and groups) in
387     // tablegen declaration order.
388     for (const auto &enumerant : enumerants) {
389       if (enumerant.getValue() != 0)
390         os << formatv(formatCompare, enumerant.getValue(), enumerant.getStr());
391     }
392   }
393   os << formatv("  return ::llvm::join(strs, \"{0}\");\n", separator);
394 
395   os << "}\n\n";
396 }
397 
398 static void emitStrToSymFnForIntEnum(const Record &enumDef, raw_ostream &os) {
399   EnumAttr enumAttr(enumDef);
400   StringRef enumName = enumAttr.getEnumClassName();
401   StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
402   auto enumerants = enumAttr.getAllCases();
403 
404   os << formatv("::std::optional<{0}> {1}(::llvm::StringRef str) {{\n",
405                 enumName, strToSymFnName);
406   os << formatv("  return ::llvm::StringSwitch<::std::optional<{0}>>(str)\n",
407                 enumName);
408   for (const auto &enumerant : enumerants) {
409     auto symbol = enumerant.getSymbol();
410     auto str = enumerant.getStr();
411     os << formatv("      .Case(\"{1}\", {0}::{2})\n", enumName, str,
412                   makeIdentifier(symbol));
413   }
414   os << "      .Default(::std::nullopt);\n";
415   os << "}\n";
416 }
417 
418 static void emitStrToSymFnForBitEnum(const Record &enumDef, raw_ostream &os) {
419   EnumAttr enumAttr(enumDef);
420   StringRef enumName = enumAttr.getEnumClassName();
421   std::string underlyingType = std::string(enumAttr.getUnderlyingType());
422   StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
423   StringRef separator = enumDef.getValueAsString("separator");
424   StringRef separatorTrimmed = separator.trim();
425   auto enumerants = enumAttr.getAllCases();
426   auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);
427 
428   os << formatv("::std::optional<{0}> {1}(::llvm::StringRef str) {{\n",
429                 enumName, strToSymFnName);
430 
431   if (allBitsUnsetCase) {
432     os << "  // Special case for all bits unset.\n";
433     StringRef caseSymbol = allBitsUnsetCase->getSymbol();
434     os << formatv("  if (str == \"{1}\") return {0}::{2};\n\n", enumName,
435                   allBitsUnsetCase->getStr(), makeIdentifier(caseSymbol));
436   }
437 
438   // Split the string to get symbols for all the bits.
439   os << "  ::llvm::SmallVector<::llvm::StringRef, 2> symbols;\n";
440   // Remove whitespace from the separator string when parsing.
441   os << formatv("  str.split(symbols, \"{0}\");\n\n", separatorTrimmed);
442 
443   os << formatv("  {0} val = 0;\n", underlyingType);
444   os << "  for (auto symbol : symbols) {\n";
445 
446   // Convert each symbol to the bit ordinal and set the corresponding bit.
447   os << formatv("    auto bit = "
448                 "llvm::StringSwitch<::std::optional<{0}>>(symbol.trim())\n",
449                 underlyingType);
450   for (const auto &enumerant : enumerants) {
451     // Skip the special enumerant for None.
452     if (auto val = enumerant.getValue())
453       os.indent(6) << formatv(".Case(\"{0}\", {1})\n", enumerant.getStr(), val);
454   }
455   os.indent(6) << ".Default(::std::nullopt);\n";
456 
457   os << "    if (bit) { val |= *bit; } else { return ::std::nullopt; }\n";
458   os << "  }\n";
459 
460   os << formatv("  return static_cast<{0}>(val);\n", enumName);
461   os << "}\n\n";
462 }
463 
464 static void emitUnderlyingToSymFnForIntEnum(const Record &enumDef,
465                                             raw_ostream &os) {
466   EnumAttr enumAttr(enumDef);
467   StringRef enumName = enumAttr.getEnumClassName();
468   std::string underlyingType = std::string(enumAttr.getUnderlyingType());
469   StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
470   auto enumerants = enumAttr.getAllCases();
471 
472   // Avoid generating the underlying value to symbol conversion function if
473   // there is an enumerant without explicit value.
474   if (llvm::any_of(enumerants, [](EnumAttrCase enumerant) {
475         return enumerant.getValue() < 0;
476       }))
477     return;
478 
479   os << formatv("::std::optional<{0}> {1}({2} value) {{\n", enumName,
480                 underlyingToSymFnName,
481                 underlyingType.empty() ? std::string("unsigned")
482                                        : underlyingType)
483      << "  switch (value) {\n";
484   for (const auto &enumerant : enumerants) {
485     auto symbol = enumerant.getSymbol();
486     auto value = enumerant.getValue();
487     os << formatv("  case {0}: return {1}::{2};\n", value, enumName,
488                   makeIdentifier(symbol));
489   }
490   os << "  default: return ::std::nullopt;\n"
491      << "  }\n"
492      << "}\n\n";
493 }
494 
495 static void emitSpecializedAttrDef(const Record &enumDef, raw_ostream &os) {
496   EnumAttr enumAttr(enumDef);
497   StringRef enumName = enumAttr.getEnumClassName();
498   StringRef attrClassName = enumAttr.getSpecializedAttrClassName();
499   const Record *baseAttrDef = enumAttr.getBaseAttrClass();
500   Attribute baseAttr(baseAttrDef);
501 
502   // Emit classof method
503 
504   os << formatv("bool {0}::classof(::mlir::Attribute attr) {{\n",
505                 attrClassName);
506 
507   mlir::tblgen::Pred baseAttrPred = baseAttr.getPredicate();
508   if (baseAttrPred.isNull())
509     PrintFatalError("ERROR: baseAttrClass for EnumAttr has no Predicate\n");
510 
511   std::string condition = baseAttrPred.getCondition();
512   FmtContext verifyCtx;
513   verifyCtx.withSelf("attr");
514   os << tgfmt("  return $0;\n", /*ctx=*/nullptr, tgfmt(condition, &verifyCtx));
515 
516   os << "}\n";
517 
518   // Emit get method
519 
520   os << formatv("{0} {0}::get(::mlir::MLIRContext *context, {1} val) {{\n",
521                 attrClassName, enumName);
522 
523   StringRef underlyingType = enumAttr.getUnderlyingType();
524 
525   // Assuming that it is IntegerAttr constraint
526   int64_t bitwidth = 64;
527   if (baseAttrDef->getValue("valueType")) {
528     auto *valueTypeDef = baseAttrDef->getValueAsDef("valueType");
529     if (valueTypeDef->getValue("bitwidth"))
530       bitwidth = valueTypeDef->getValueAsInt("bitwidth");
531   }
532 
533   os << formatv("  ::mlir::IntegerType intType = "
534                 "::mlir::IntegerType::get(context, {0});\n",
535                 bitwidth);
536   os << formatv("  ::mlir::IntegerAttr baseAttr = "
537                 "::mlir::IntegerAttr::get(intType, static_cast<{0}>(val));\n",
538                 underlyingType);
539   os << formatv("  return ::llvm::cast<{0}>(baseAttr);\n", attrClassName);
540 
541   os << "}\n";
542 
543   // Emit getValue method
544 
545   os << formatv("{0} {1}::getValue() const {{\n", enumName, attrClassName);
546 
547   os << formatv("  return static_cast<{0}>(::mlir::IntegerAttr::getInt());\n",
548                 enumName);
549 
550   os << "}\n";
551 }
552 
553 static void emitUnderlyingToSymFnForBitEnum(const Record &enumDef,
554                                             raw_ostream &os) {
555   EnumAttr enumAttr(enumDef);
556   StringRef enumName = enumAttr.getEnumClassName();
557   std::string underlyingType = std::string(enumAttr.getUnderlyingType());
558   StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
559   auto enumerants = enumAttr.getAllCases();
560   auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);
561 
562   os << formatv("::std::optional<{0}> {1}({2} value) {{\n", enumName,
563                 underlyingToSymFnName, underlyingType);
564   if (allBitsUnsetCase) {
565     os << "  // Special case for all bits unset.\n";
566     os << formatv("  if (value == 0) return {0}::{1};\n\n", enumName,
567                   makeIdentifier(allBitsUnsetCase->getSymbol()));
568   }
569   int64_t validBits = enumDef.getValueAsInt("validBits");
570   os << formatv("  if (value & ~static_cast<{0}>({1}u)) return std::nullopt;\n",
571                 underlyingType, validBits);
572   os << formatv("  return static_cast<{0}>(value);\n", enumName);
573   os << "}\n";
574 }
575 
576 static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
577   EnumAttr enumAttr(enumDef);
578   StringRef enumName = enumAttr.getEnumClassName();
579   StringRef cppNamespace = enumAttr.getCppNamespace();
580   std::string underlyingType = std::string(enumAttr.getUnderlyingType());
581   StringRef description = enumAttr.getSummary();
582   StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
583   StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
584   StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
585   StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
586   auto enumerants = enumAttr.getAllCases();
587 
588   SmallVector<StringRef, 2> namespaces;
589   llvm::SplitString(cppNamespace, namespaces, "::");
590 
591   for (auto ns : namespaces)
592     os << "namespace " << ns << " {\n";
593 
594   // Emit the enum class definition
595   emitEnumClass(enumDef, enumName, underlyingType, description, enumerants, os);
596 
597   // Emit conversion function declarations
598   if (llvm::all_of(enumerants, [](EnumAttrCase enumerant) {
599         return enumerant.getValue() >= 0;
600       })) {
601     os << formatv(
602         "::std::optional<{0}> {1}({2});\n", enumName, underlyingToSymFnName,
603         underlyingType.empty() ? std::string("unsigned") : underlyingType);
604   }
605   os << formatv("{2} {1}({0});\n", enumName, symToStrFnName, symToStrFnRetType);
606   os << formatv("::std::optional<{0}> {1}(::llvm::StringRef);\n", enumName,
607                 strToSymFnName);
608 
609   if (enumAttr.isBitEnum()) {
610     emitOperators(enumDef, os);
611   } else {
612     emitMaxValueFn(enumDef, os);
613   }
614 
615   // Generate a generic `stringifyEnum` function that forwards to the method
616   // specified by the user.
617   const char *const stringifyEnumStr = R"(
618 inline {0} stringifyEnum({1} enumValue) {{
619   return {2}(enumValue);
620 }
621 )";
622   os << formatv(stringifyEnumStr, symToStrFnRetType, enumName, symToStrFnName);
623 
624   // Generate a generic `symbolizeEnum` function that forwards to the method
625   // specified by the user.
626   const char *const symbolizeEnumStr = R"(
627 template <typename EnumType>
628 ::std::optional<EnumType> symbolizeEnum(::llvm::StringRef);
629 
630 template <>
631 inline ::std::optional<{0}> symbolizeEnum<{0}>(::llvm::StringRef str) {
632   return {1}(str);
633 }
634 )";
635   os << formatv(symbolizeEnumStr, enumName, strToSymFnName);
636 
637   const char *const attrClassDecl = R"(
638 class {1} : public ::mlir::{2} {
639 public:
640   using ValueType = {0};
641   using ::mlir::{2}::{2};
642   static bool classof(::mlir::Attribute attr);
643   static {1} get(::mlir::MLIRContext *context, {0} val);
644   {0} getValue() const;
645 };
646 )";
647   if (enumAttr.genSpecializedAttr()) {
648     StringRef attrClassName = enumAttr.getSpecializedAttrClassName();
649     StringRef baseAttrClassName = "IntegerAttr";
650     os << formatv(attrClassDecl, enumName, attrClassName, baseAttrClassName);
651   }
652 
653   for (auto ns : llvm::reverse(namespaces))
654     os << "} // namespace " << ns << "\n";
655 
656   // Generate a generic parser and printer for the enum.
657   std::string qualName =
658       std::string(formatv("{0}::{1}", cppNamespace, enumName));
659   emitParserPrinter(enumAttr, qualName, cppNamespace, os);
660 
661   // Emit DenseMapInfo for this enum class
662   emitDenseMapInfo(qualName, underlyingType, cppNamespace, os);
663 }
664 
665 static bool emitEnumDecls(const RecordKeeper &records, raw_ostream &os) {
666   llvm::emitSourceFileHeader("Enum Utility Declarations", os, records);
667 
668   for (const Record *def :
669        records.getAllDerivedDefinitionsIfDefined("EnumAttrInfo"))
670     emitEnumDecl(*def, os);
671 
672   return false;
673 }
674 
675 static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
676   EnumAttr enumAttr(enumDef);
677   StringRef cppNamespace = enumAttr.getCppNamespace();
678 
679   SmallVector<StringRef, 2> namespaces;
680   llvm::SplitString(cppNamespace, namespaces, "::");
681 
682   for (auto ns : namespaces)
683     os << "namespace " << ns << " {\n";
684 
685   if (enumAttr.isBitEnum()) {
686     emitSymToStrFnForBitEnum(enumDef, os);
687     emitStrToSymFnForBitEnum(enumDef, os);
688     emitUnderlyingToSymFnForBitEnum(enumDef, os);
689   } else {
690     emitSymToStrFnForIntEnum(enumDef, os);
691     emitStrToSymFnForIntEnum(enumDef, os);
692     emitUnderlyingToSymFnForIntEnum(enumDef, os);
693   }
694 
695   if (enumAttr.genSpecializedAttr())
696     emitSpecializedAttrDef(enumDef, os);
697 
698   for (auto ns : llvm::reverse(namespaces))
699     os << "} // namespace " << ns << "\n";
700   os << "\n";
701 }
702 
703 static bool emitEnumDefs(const RecordKeeper &records, raw_ostream &os) {
704   llvm::emitSourceFileHeader("Enum Utility Definitions", os, records);
705 
706   for (const Record *def :
707        records.getAllDerivedDefinitionsIfDefined("EnumAttrInfo"))
708     emitEnumDef(*def, os);
709 
710   return false;
711 }
712 
713 // Registers the enum utility generator to mlir-tblgen.
714 static mlir::GenRegistration
715     genEnumDecls("gen-enum-decls", "Generate enum utility declarations",
716                  [](const RecordKeeper &records, raw_ostream &os) {
717                    return emitEnumDecls(records, os);
718                  });
719 
720 // Registers the enum utility generator to mlir-tblgen.
721 static mlir::GenRegistration
722     genEnumDefs("gen-enum-defs", "Generate enum utility definitions",
723                 [](const RecordKeeper &records, raw_ostream &os) {
724                   return emitEnumDefs(records, os);
725                 });
726