1 //===- EnumsGen.cpp - MLIR enum utility generator -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // EnumsGen generates common utility functions for enums. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "FormatGen.h" 14 #include "mlir/TableGen/Attribute.h" 15 #include "mlir/TableGen/Format.h" 16 #include "mlir/TableGen/GenInfo.h" 17 #include "llvm/ADT/BitVector.h" 18 #include "llvm/ADT/SmallVector.h" 19 #include "llvm/ADT/StringExtras.h" 20 #include "llvm/Support/FormatVariadic.h" 21 #include "llvm/Support/raw_ostream.h" 22 #include "llvm/TableGen/Error.h" 23 #include "llvm/TableGen/Record.h" 24 #include "llvm/TableGen/TableGenBackend.h" 25 26 using llvm::formatv; 27 using llvm::isDigit; 28 using llvm::PrintFatalError; 29 using llvm::Record; 30 using llvm::RecordKeeper; 31 using namespace mlir; 32 using mlir::tblgen::Attribute; 33 using mlir::tblgen::EnumAttr; 34 using mlir::tblgen::EnumAttrCase; 35 using mlir::tblgen::FmtContext; 36 using mlir::tblgen::tgfmt; 37 38 static std::string makeIdentifier(StringRef str) { 39 if (!str.empty() && isDigit(static_cast<unsigned char>(str.front()))) { 40 std::string newStr = std::string("_") + str.str(); 41 return newStr; 42 } 43 return str.str(); 44 } 45 46 static void emitEnumClass(const Record &enumDef, StringRef enumName, 47 StringRef underlyingType, StringRef description, 48 const std::vector<EnumAttrCase> &enumerants, 49 raw_ostream &os) { 50 os << "// " << description << "\n"; 51 os << "enum class " << enumName; 52 53 if (!underlyingType.empty()) 54 os << " : " << underlyingType; 55 os << " {\n"; 56 57 for (const auto &enumerant : enumerants) { 58 auto symbol = makeIdentifier(enumerant.getSymbol()); 59 auto value = enumerant.getValue(); 60 if (value >= 0) { 61 os << formatv(" {0} = {1},\n", symbol, value); 62 } else { 63 os << formatv(" {0},\n", symbol); 64 } 65 } 66 os << "};\n\n"; 67 } 68 69 static void emitParserPrinter(const EnumAttr &enumAttr, StringRef qualName, 70 StringRef cppNamespace, raw_ostream &os) { 71 if (enumAttr.getUnderlyingType().empty() || 72 enumAttr.getConstBuilderTemplate().empty()) 73 return; 74 auto cases = enumAttr.getAllCases(); 75 76 // Check which cases shouldn't be printed using a keyword. 77 llvm::BitVector nonKeywordCases(cases.size()); 78 for (auto [index, caseVal] : llvm::enumerate(cases)) 79 if (!mlir::tblgen::canFormatStringAsKeyword(caseVal.getStr())) 80 nonKeywordCases.set(index); 81 82 // Generate the parser and the start of the printer for the enum. 83 const char *parsedAndPrinterStart = R"( 84 namespace mlir { 85 template <typename T, typename> 86 struct FieldParser; 87 88 template<> 89 struct FieldParser<{0}, {0}> {{ 90 template <typename ParserT> 91 static FailureOr<{0}> parse(ParserT &parser) {{ 92 // Parse the keyword/string containing the enum. 93 std::string enumKeyword; 94 auto loc = parser.getCurrentLocation(); 95 if (failed(parser.parseOptionalKeywordOrString(&enumKeyword))) 96 return parser.emitError(loc, "expected keyword for {2}"); 97 98 // Symbolize the keyword. 99 if (::std::optional<{0}> attr = {1}::symbolizeEnum<{0}>(enumKeyword)) 100 return *attr; 101 return parser.emitError(loc, "invalid {2} specification: ") << enumKeyword; 102 } 103 }; 104 105 /// Support for std::optional, useful in attribute/type definition where the enum is 106 /// used as: 107 /// 108 /// let parameters = (ins OptionalParameter<"std::optional<TheEnumName>">:$value); 109 template<> 110 struct FieldParser<std::optional<{0}>, std::optional<{0}>> {{ 111 template <typename ParserT> 112 static FailureOr<std::optional<{0}>> parse(ParserT &parser) {{ 113 // Parse the keyword/string containing the enum. 114 std::string enumKeyword; 115 auto loc = parser.getCurrentLocation(); 116 if (failed(parser.parseOptionalKeywordOrString(&enumKeyword))) 117 return std::optional<{0}>{{}; 118 119 // Symbolize the keyword. 120 if (::std::optional<{0}> attr = {1}::symbolizeEnum<{0}>(enumKeyword)) 121 return attr; 122 return parser.emitError(loc, "invalid {2} specification: ") << enumKeyword; 123 } 124 }; 125 } // namespace mlir 126 127 namespace llvm { 128 inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{ 129 auto valueStr = stringifyEnum(value); 130 )"; 131 os << formatv(parsedAndPrinterStart, qualName, cppNamespace, 132 enumAttr.getSummary()); 133 134 // If all cases require a string, always wrap. 135 if (nonKeywordCases.all()) { 136 os << " return p << '\"' << valueStr << '\"';\n" 137 "}\n" 138 "} // namespace llvm\n"; 139 return; 140 } 141 142 // If there are any cases that can't be used with a keyword, switch on the 143 // case value to determine when to print in the string form. 144 if (nonKeywordCases.any()) { 145 os << " switch (value) {\n"; 146 for (auto it : llvm::enumerate(cases)) { 147 if (nonKeywordCases.test(it.index())) 148 continue; 149 StringRef symbol = it.value().getSymbol(); 150 os << llvm::formatv(" case {0}::{1}:\n", qualName, 151 makeIdentifier(symbol)); 152 } 153 os << " break;\n" 154 " default:\n" 155 " return p << '\"' << valueStr << '\"';\n" 156 " }\n"; 157 158 // If this is a bit enum, conservatively print the string form if the value 159 // is not a power of two (i.e. not a single bit case) and not a known case. 160 } else if (enumAttr.isBitEnum()) { 161 // Process the known multi-bit cases that use valid keywords. 162 SmallVector<EnumAttrCase *> validMultiBitCases; 163 for (auto [index, caseVal] : llvm::enumerate(cases)) { 164 uint64_t value = caseVal.getValue(); 165 if (value && !llvm::has_single_bit(value) && !nonKeywordCases.test(index)) 166 validMultiBitCases.push_back(&caseVal); 167 } 168 if (!validMultiBitCases.empty()) { 169 os << " switch (value) {\n"; 170 for (EnumAttrCase *caseVal : validMultiBitCases) { 171 StringRef symbol = caseVal->getSymbol(); 172 os << llvm::formatv(" case {0}::{1}:\n", qualName, 173 llvm::isDigit(symbol.front()) ? ("_" + symbol) 174 : symbol); 175 } 176 os << " return p << valueStr;\n" 177 " default:\n" 178 " break;\n" 179 " }\n"; 180 } 181 182 // All other multi-bit cases should be printed as strings. 183 os << formatv(" auto underlyingValue = " 184 "static_cast<std::make_unsigned_t<{0}>>(value);\n", 185 qualName); 186 os << " if (underlyingValue && !llvm::has_single_bit(underlyingValue))\n" 187 " return p << '\"' << valueStr << '\"';\n"; 188 } 189 os << " return p << valueStr;\n" 190 "}\n" 191 "} // namespace llvm\n"; 192 } 193 194 static void emitDenseMapInfo(StringRef qualName, std::string underlyingType, 195 StringRef cppNamespace, raw_ostream &os) { 196 if (underlyingType.empty()) 197 underlyingType = 198 std::string(formatv("std::underlying_type_t<{0}>", qualName)); 199 200 const char *const mapInfo = R"( 201 namespace llvm { 202 template<> struct DenseMapInfo<{0}> {{ 203 using StorageInfo = ::llvm::DenseMapInfo<{1}>; 204 205 static inline {0} getEmptyKey() {{ 206 return static_cast<{0}>(StorageInfo::getEmptyKey()); 207 } 208 209 static inline {0} getTombstoneKey() {{ 210 return static_cast<{0}>(StorageInfo::getTombstoneKey()); 211 } 212 213 static unsigned getHashValue(const {0} &val) {{ 214 return StorageInfo::getHashValue(static_cast<{1}>(val)); 215 } 216 217 static bool isEqual(const {0} &lhs, const {0} &rhs) {{ 218 return lhs == rhs; 219 } 220 }; 221 })"; 222 os << formatv(mapInfo, qualName, underlyingType); 223 os << "\n\n"; 224 } 225 226 static void emitMaxValueFn(const Record &enumDef, raw_ostream &os) { 227 EnumAttr enumAttr(enumDef); 228 StringRef maxEnumValFnName = enumAttr.getMaxEnumValFnName(); 229 auto enumerants = enumAttr.getAllCases(); 230 231 unsigned maxEnumVal = 0; 232 for (const auto &enumerant : enumerants) { 233 int64_t value = enumerant.getValue(); 234 // Avoid generating the max value function if there is an enumerant without 235 // explicit value. 236 if (value < 0) 237 return; 238 239 maxEnumVal = std::max(maxEnumVal, static_cast<unsigned>(value)); 240 } 241 242 // Emit the function to return the max enum value 243 os << formatv("inline constexpr unsigned {0}() {{\n", maxEnumValFnName); 244 os << formatv(" return {0};\n", maxEnumVal); 245 os << "}\n\n"; 246 } 247 248 // Returns the EnumAttrCase whose value is zero if exists; returns std::nullopt 249 // otherwise. 250 static std::optional<EnumAttrCase> 251 getAllBitsUnsetCase(llvm::ArrayRef<EnumAttrCase> cases) { 252 for (auto attrCase : cases) { 253 if (attrCase.getValue() == 0) 254 return attrCase; 255 } 256 return std::nullopt; 257 } 258 259 // Emits the following inline function for bit enums: 260 // 261 // inline constexpr <enum-type> operator|(<enum-type> a, <enum-type> b); 262 // inline constexpr <enum-type> operator&(<enum-type> a, <enum-type> b); 263 // inline constexpr <enum-type> operator^(<enum-type> a, <enum-type> b); 264 // inline constexpr <enum-type> operator~(<enum-type> bits); 265 // inline constexpr bool bitEnumContainsAll(<enum-type> bits, <enum-type> bit); 266 // inline constexpr bool bitEnumContainsAny(<enum-type> bits, <enum-type> bit); 267 // inline constexpr <enum-type> bitEnumClear(<enum-type> bits, <enum-type> bit); 268 // inline constexpr <enum-type> bitEnumSet(<enum-type> bits, <enum-type> bit, 269 // bool value=true); 270 static void emitOperators(const Record &enumDef, raw_ostream &os) { 271 EnumAttr enumAttr(enumDef); 272 StringRef enumName = enumAttr.getEnumClassName(); 273 std::string underlyingType = std::string(enumAttr.getUnderlyingType()); 274 int64_t validBits = enumDef.getValueAsInt("validBits"); 275 const char *const operators = R"( 276 inline constexpr {0} operator|({0} a, {0} b) {{ 277 return static_cast<{0}>(static_cast<{1}>(a) | static_cast<{1}>(b)); 278 } 279 inline constexpr {0} operator&({0} a, {0} b) {{ 280 return static_cast<{0}>(static_cast<{1}>(a) & static_cast<{1}>(b)); 281 } 282 inline constexpr {0} operator^({0} a, {0} b) {{ 283 return static_cast<{0}>(static_cast<{1}>(a) ^ static_cast<{1}>(b)); 284 } 285 inline constexpr {0} operator~({0} bits) {{ 286 // Ensure only bits that can be present in the enum are set 287 return static_cast<{0}>(~static_cast<{1}>(bits) & static_cast<{1}>({2}u)); 288 } 289 inline constexpr bool bitEnumContainsAll({0} bits, {0} bit) {{ 290 return (bits & bit) == bit; 291 } 292 inline constexpr bool bitEnumContainsAny({0} bits, {0} bit) {{ 293 return (static_cast<{1}>(bits) & static_cast<{1}>(bit)) != 0; 294 } 295 inline constexpr {0} bitEnumClear({0} bits, {0} bit) {{ 296 return bits & ~bit; 297 } 298 inline constexpr {0} bitEnumSet({0} bits, {0} bit, /*optional*/bool value=true) {{ 299 return value ? (bits | bit) : bitEnumClear(bits, bit); 300 } 301 )"; 302 os << formatv(operators, enumName, underlyingType, validBits); 303 } 304 305 static void emitSymToStrFnForIntEnum(const Record &enumDef, raw_ostream &os) { 306 EnumAttr enumAttr(enumDef); 307 StringRef enumName = enumAttr.getEnumClassName(); 308 StringRef symToStrFnName = enumAttr.getSymbolToStringFnName(); 309 StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType(); 310 auto enumerants = enumAttr.getAllCases(); 311 312 os << formatv("{2} {1}({0} val) {{\n", enumName, symToStrFnName, 313 symToStrFnRetType); 314 os << " switch (val) {\n"; 315 for (const auto &enumerant : enumerants) { 316 auto symbol = enumerant.getSymbol(); 317 auto str = enumerant.getStr(); 318 os << formatv(" case {0}::{1}: return \"{2}\";\n", enumName, 319 makeIdentifier(symbol), str); 320 } 321 os << " }\n"; 322 os << " return \"\";\n"; 323 os << "}\n\n"; 324 } 325 326 static void emitSymToStrFnForBitEnum(const Record &enumDef, raw_ostream &os) { 327 EnumAttr enumAttr(enumDef); 328 StringRef enumName = enumAttr.getEnumClassName(); 329 StringRef symToStrFnName = enumAttr.getSymbolToStringFnName(); 330 StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType(); 331 StringRef separator = enumDef.getValueAsString("separator"); 332 auto enumerants = enumAttr.getAllCases(); 333 auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants); 334 335 os << formatv("{2} {1}({0} symbol) {{\n", enumName, symToStrFnName, 336 symToStrFnRetType); 337 338 os << formatv(" auto val = static_cast<{0}>(symbol);\n", 339 enumAttr.getUnderlyingType()); 340 // If we have unknown bit set, return an empty string to signal errors. 341 int64_t validBits = enumDef.getValueAsInt("validBits"); 342 os << formatv(" assert({0}u == ({0}u | val) && \"invalid bits set in bit " 343 "enum\");\n", 344 validBits); 345 if (allBitsUnsetCase) { 346 os << " // Special case for all bits unset.\n"; 347 os << formatv(" if (val == 0) return \"{0}\";\n\n", 348 allBitsUnsetCase->getStr()); 349 } 350 os << " ::llvm::SmallVector<::llvm::StringRef, 2> strs;\n"; 351 352 // Add case string if the value has all case bits, and remove them to avoid 353 // printing again. Used only for groups, when printBitEnumPrimaryGroups is 1. 354 const char *const formatCompareRemove = R"( 355 if ({0}u == ({0}u & val)) {{ 356 strs.push_back("{1}"); 357 val &= ~static_cast<{2}>({0}); 358 } 359 )"; 360 // Add case string if the value has all case bits. Used for individual bit 361 // cases, and for groups when printBitEnumPrimaryGroups is 0. 362 const char *const formatCompare = R"( 363 if ({0}u == ({0}u & val)) 364 strs.push_back("{1}"); 365 )"; 366 // Optionally elide bits that are members of groups that will also be printed 367 // for more concise output. 368 if (enumAttr.printBitEnumPrimaryGroups()) { 369 os << " // Print bit enum groups before individual bits\n"; 370 // Emit comparisons for group bit cases in reverse tablegen declaration 371 // order, removing bits for groups with all bits present. 372 for (const auto &enumerant : llvm::reverse(enumerants)) { 373 if ((enumerant.getValue() != 0) && 374 enumerant.getDef().isSubClassOf("BitEnumAttrCaseGroup")) { 375 os << formatv(formatCompareRemove, enumerant.getValue(), 376 enumerant.getStr(), enumAttr.getUnderlyingType()); 377 } 378 } 379 // Emit comparisons for individual bit cases in tablegen declaration order. 380 for (const auto &enumerant : enumerants) { 381 if ((enumerant.getValue() != 0) && 382 enumerant.getDef().isSubClassOf("BitEnumAttrCaseBit")) 383 os << formatv(formatCompare, enumerant.getValue(), enumerant.getStr()); 384 } 385 } else { 386 // Emit comparisons for ALL nonzero cases (individual bits and groups) in 387 // tablegen declaration order. 388 for (const auto &enumerant : enumerants) { 389 if (enumerant.getValue() != 0) 390 os << formatv(formatCompare, enumerant.getValue(), enumerant.getStr()); 391 } 392 } 393 os << formatv(" return ::llvm::join(strs, \"{0}\");\n", separator); 394 395 os << "}\n\n"; 396 } 397 398 static void emitStrToSymFnForIntEnum(const Record &enumDef, raw_ostream &os) { 399 EnumAttr enumAttr(enumDef); 400 StringRef enumName = enumAttr.getEnumClassName(); 401 StringRef strToSymFnName = enumAttr.getStringToSymbolFnName(); 402 auto enumerants = enumAttr.getAllCases(); 403 404 os << formatv("::std::optional<{0}> {1}(::llvm::StringRef str) {{\n", 405 enumName, strToSymFnName); 406 os << formatv(" return ::llvm::StringSwitch<::std::optional<{0}>>(str)\n", 407 enumName); 408 for (const auto &enumerant : enumerants) { 409 auto symbol = enumerant.getSymbol(); 410 auto str = enumerant.getStr(); 411 os << formatv(" .Case(\"{1}\", {0}::{2})\n", enumName, str, 412 makeIdentifier(symbol)); 413 } 414 os << " .Default(::std::nullopt);\n"; 415 os << "}\n"; 416 } 417 418 static void emitStrToSymFnForBitEnum(const Record &enumDef, raw_ostream &os) { 419 EnumAttr enumAttr(enumDef); 420 StringRef enumName = enumAttr.getEnumClassName(); 421 std::string underlyingType = std::string(enumAttr.getUnderlyingType()); 422 StringRef strToSymFnName = enumAttr.getStringToSymbolFnName(); 423 StringRef separator = enumDef.getValueAsString("separator"); 424 StringRef separatorTrimmed = separator.trim(); 425 auto enumerants = enumAttr.getAllCases(); 426 auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants); 427 428 os << formatv("::std::optional<{0}> {1}(::llvm::StringRef str) {{\n", 429 enumName, strToSymFnName); 430 431 if (allBitsUnsetCase) { 432 os << " // Special case for all bits unset.\n"; 433 StringRef caseSymbol = allBitsUnsetCase->getSymbol(); 434 os << formatv(" if (str == \"{1}\") return {0}::{2};\n\n", enumName, 435 allBitsUnsetCase->getStr(), makeIdentifier(caseSymbol)); 436 } 437 438 // Split the string to get symbols for all the bits. 439 os << " ::llvm::SmallVector<::llvm::StringRef, 2> symbols;\n"; 440 // Remove whitespace from the separator string when parsing. 441 os << formatv(" str.split(symbols, \"{0}\");\n\n", separatorTrimmed); 442 443 os << formatv(" {0} val = 0;\n", underlyingType); 444 os << " for (auto symbol : symbols) {\n"; 445 446 // Convert each symbol to the bit ordinal and set the corresponding bit. 447 os << formatv(" auto bit = " 448 "llvm::StringSwitch<::std::optional<{0}>>(symbol.trim())\n", 449 underlyingType); 450 for (const auto &enumerant : enumerants) { 451 // Skip the special enumerant for None. 452 if (auto val = enumerant.getValue()) 453 os.indent(6) << formatv(".Case(\"{0}\", {1})\n", enumerant.getStr(), val); 454 } 455 os.indent(6) << ".Default(::std::nullopt);\n"; 456 457 os << " if (bit) { val |= *bit; } else { return ::std::nullopt; }\n"; 458 os << " }\n"; 459 460 os << formatv(" return static_cast<{0}>(val);\n", enumName); 461 os << "}\n\n"; 462 } 463 464 static void emitUnderlyingToSymFnForIntEnum(const Record &enumDef, 465 raw_ostream &os) { 466 EnumAttr enumAttr(enumDef); 467 StringRef enumName = enumAttr.getEnumClassName(); 468 std::string underlyingType = std::string(enumAttr.getUnderlyingType()); 469 StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName(); 470 auto enumerants = enumAttr.getAllCases(); 471 472 // Avoid generating the underlying value to symbol conversion function if 473 // there is an enumerant without explicit value. 474 if (llvm::any_of(enumerants, [](EnumAttrCase enumerant) { 475 return enumerant.getValue() < 0; 476 })) 477 return; 478 479 os << formatv("::std::optional<{0}> {1}({2} value) {{\n", enumName, 480 underlyingToSymFnName, 481 underlyingType.empty() ? std::string("unsigned") 482 : underlyingType) 483 << " switch (value) {\n"; 484 for (const auto &enumerant : enumerants) { 485 auto symbol = enumerant.getSymbol(); 486 auto value = enumerant.getValue(); 487 os << formatv(" case {0}: return {1}::{2};\n", value, enumName, 488 makeIdentifier(symbol)); 489 } 490 os << " default: return ::std::nullopt;\n" 491 << " }\n" 492 << "}\n\n"; 493 } 494 495 static void emitSpecializedAttrDef(const Record &enumDef, raw_ostream &os) { 496 EnumAttr enumAttr(enumDef); 497 StringRef enumName = enumAttr.getEnumClassName(); 498 StringRef attrClassName = enumAttr.getSpecializedAttrClassName(); 499 const Record *baseAttrDef = enumAttr.getBaseAttrClass(); 500 Attribute baseAttr(baseAttrDef); 501 502 // Emit classof method 503 504 os << formatv("bool {0}::classof(::mlir::Attribute attr) {{\n", 505 attrClassName); 506 507 mlir::tblgen::Pred baseAttrPred = baseAttr.getPredicate(); 508 if (baseAttrPred.isNull()) 509 PrintFatalError("ERROR: baseAttrClass for EnumAttr has no Predicate\n"); 510 511 std::string condition = baseAttrPred.getCondition(); 512 FmtContext verifyCtx; 513 verifyCtx.withSelf("attr"); 514 os << tgfmt(" return $0;\n", /*ctx=*/nullptr, tgfmt(condition, &verifyCtx)); 515 516 os << "}\n"; 517 518 // Emit get method 519 520 os << formatv("{0} {0}::get(::mlir::MLIRContext *context, {1} val) {{\n", 521 attrClassName, enumName); 522 523 StringRef underlyingType = enumAttr.getUnderlyingType(); 524 525 // Assuming that it is IntegerAttr constraint 526 int64_t bitwidth = 64; 527 if (baseAttrDef->getValue("valueType")) { 528 auto *valueTypeDef = baseAttrDef->getValueAsDef("valueType"); 529 if (valueTypeDef->getValue("bitwidth")) 530 bitwidth = valueTypeDef->getValueAsInt("bitwidth"); 531 } 532 533 os << formatv(" ::mlir::IntegerType intType = " 534 "::mlir::IntegerType::get(context, {0});\n", 535 bitwidth); 536 os << formatv(" ::mlir::IntegerAttr baseAttr = " 537 "::mlir::IntegerAttr::get(intType, static_cast<{0}>(val));\n", 538 underlyingType); 539 os << formatv(" return ::llvm::cast<{0}>(baseAttr);\n", attrClassName); 540 541 os << "}\n"; 542 543 // Emit getValue method 544 545 os << formatv("{0} {1}::getValue() const {{\n", enumName, attrClassName); 546 547 os << formatv(" return static_cast<{0}>(::mlir::IntegerAttr::getInt());\n", 548 enumName); 549 550 os << "}\n"; 551 } 552 553 static void emitUnderlyingToSymFnForBitEnum(const Record &enumDef, 554 raw_ostream &os) { 555 EnumAttr enumAttr(enumDef); 556 StringRef enumName = enumAttr.getEnumClassName(); 557 std::string underlyingType = std::string(enumAttr.getUnderlyingType()); 558 StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName(); 559 auto enumerants = enumAttr.getAllCases(); 560 auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants); 561 562 os << formatv("::std::optional<{0}> {1}({2} value) {{\n", enumName, 563 underlyingToSymFnName, underlyingType); 564 if (allBitsUnsetCase) { 565 os << " // Special case for all bits unset.\n"; 566 os << formatv(" if (value == 0) return {0}::{1};\n\n", enumName, 567 makeIdentifier(allBitsUnsetCase->getSymbol())); 568 } 569 int64_t validBits = enumDef.getValueAsInt("validBits"); 570 os << formatv(" if (value & ~static_cast<{0}>({1}u)) return std::nullopt;\n", 571 underlyingType, validBits); 572 os << formatv(" return static_cast<{0}>(value);\n", enumName); 573 os << "}\n"; 574 } 575 576 static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { 577 EnumAttr enumAttr(enumDef); 578 StringRef enumName = enumAttr.getEnumClassName(); 579 StringRef cppNamespace = enumAttr.getCppNamespace(); 580 std::string underlyingType = std::string(enumAttr.getUnderlyingType()); 581 StringRef description = enumAttr.getSummary(); 582 StringRef strToSymFnName = enumAttr.getStringToSymbolFnName(); 583 StringRef symToStrFnName = enumAttr.getSymbolToStringFnName(); 584 StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType(); 585 StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName(); 586 auto enumerants = enumAttr.getAllCases(); 587 588 SmallVector<StringRef, 2> namespaces; 589 llvm::SplitString(cppNamespace, namespaces, "::"); 590 591 for (auto ns : namespaces) 592 os << "namespace " << ns << " {\n"; 593 594 // Emit the enum class definition 595 emitEnumClass(enumDef, enumName, underlyingType, description, enumerants, os); 596 597 // Emit conversion function declarations 598 if (llvm::all_of(enumerants, [](EnumAttrCase enumerant) { 599 return enumerant.getValue() >= 0; 600 })) { 601 os << formatv( 602 "::std::optional<{0}> {1}({2});\n", enumName, underlyingToSymFnName, 603 underlyingType.empty() ? std::string("unsigned") : underlyingType); 604 } 605 os << formatv("{2} {1}({0});\n", enumName, symToStrFnName, symToStrFnRetType); 606 os << formatv("::std::optional<{0}> {1}(::llvm::StringRef);\n", enumName, 607 strToSymFnName); 608 609 if (enumAttr.isBitEnum()) { 610 emitOperators(enumDef, os); 611 } else { 612 emitMaxValueFn(enumDef, os); 613 } 614 615 // Generate a generic `stringifyEnum` function that forwards to the method 616 // specified by the user. 617 const char *const stringifyEnumStr = R"( 618 inline {0} stringifyEnum({1} enumValue) {{ 619 return {2}(enumValue); 620 } 621 )"; 622 os << formatv(stringifyEnumStr, symToStrFnRetType, enumName, symToStrFnName); 623 624 // Generate a generic `symbolizeEnum` function that forwards to the method 625 // specified by the user. 626 const char *const symbolizeEnumStr = R"( 627 template <typename EnumType> 628 ::std::optional<EnumType> symbolizeEnum(::llvm::StringRef); 629 630 template <> 631 inline ::std::optional<{0}> symbolizeEnum<{0}>(::llvm::StringRef str) { 632 return {1}(str); 633 } 634 )"; 635 os << formatv(symbolizeEnumStr, enumName, strToSymFnName); 636 637 const char *const attrClassDecl = R"( 638 class {1} : public ::mlir::{2} { 639 public: 640 using ValueType = {0}; 641 using ::mlir::{2}::{2}; 642 static bool classof(::mlir::Attribute attr); 643 static {1} get(::mlir::MLIRContext *context, {0} val); 644 {0} getValue() const; 645 }; 646 )"; 647 if (enumAttr.genSpecializedAttr()) { 648 StringRef attrClassName = enumAttr.getSpecializedAttrClassName(); 649 StringRef baseAttrClassName = "IntegerAttr"; 650 os << formatv(attrClassDecl, enumName, attrClassName, baseAttrClassName); 651 } 652 653 for (auto ns : llvm::reverse(namespaces)) 654 os << "} // namespace " << ns << "\n"; 655 656 // Generate a generic parser and printer for the enum. 657 std::string qualName = 658 std::string(formatv("{0}::{1}", cppNamespace, enumName)); 659 emitParserPrinter(enumAttr, qualName, cppNamespace, os); 660 661 // Emit DenseMapInfo for this enum class 662 emitDenseMapInfo(qualName, underlyingType, cppNamespace, os); 663 } 664 665 static bool emitEnumDecls(const RecordKeeper &records, raw_ostream &os) { 666 llvm::emitSourceFileHeader("Enum Utility Declarations", os, records); 667 668 for (const Record *def : 669 records.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) 670 emitEnumDecl(*def, os); 671 672 return false; 673 } 674 675 static void emitEnumDef(const Record &enumDef, raw_ostream &os) { 676 EnumAttr enumAttr(enumDef); 677 StringRef cppNamespace = enumAttr.getCppNamespace(); 678 679 SmallVector<StringRef, 2> namespaces; 680 llvm::SplitString(cppNamespace, namespaces, "::"); 681 682 for (auto ns : namespaces) 683 os << "namespace " << ns << " {\n"; 684 685 if (enumAttr.isBitEnum()) { 686 emitSymToStrFnForBitEnum(enumDef, os); 687 emitStrToSymFnForBitEnum(enumDef, os); 688 emitUnderlyingToSymFnForBitEnum(enumDef, os); 689 } else { 690 emitSymToStrFnForIntEnum(enumDef, os); 691 emitStrToSymFnForIntEnum(enumDef, os); 692 emitUnderlyingToSymFnForIntEnum(enumDef, os); 693 } 694 695 if (enumAttr.genSpecializedAttr()) 696 emitSpecializedAttrDef(enumDef, os); 697 698 for (auto ns : llvm::reverse(namespaces)) 699 os << "} // namespace " << ns << "\n"; 700 os << "\n"; 701 } 702 703 static bool emitEnumDefs(const RecordKeeper &records, raw_ostream &os) { 704 llvm::emitSourceFileHeader("Enum Utility Definitions", os, records); 705 706 for (const Record *def : 707 records.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) 708 emitEnumDef(*def, os); 709 710 return false; 711 } 712 713 // Registers the enum utility generator to mlir-tblgen. 714 static mlir::GenRegistration 715 genEnumDecls("gen-enum-decls", "Generate enum utility declarations", 716 [](const RecordKeeper &records, raw_ostream &os) { 717 return emitEnumDecls(records, os); 718 }); 719 720 // Registers the enum utility generator to mlir-tblgen. 721 static mlir::GenRegistration 722 genEnumDefs("gen-enum-defs", "Generate enum utility definitions", 723 [](const RecordKeeper &records, raw_ostream &os) { 724 return emitEnumDefs(records, os); 725 }); 726