1 //===- AttrOrTypeDefGen.cpp - MLIR AttrOrType definitions generator -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "AttrOrTypeFormatGen.h" 10 #include "mlir/TableGen/AttrOrTypeDef.h" 11 #include "mlir/TableGen/Class.h" 12 #include "mlir/TableGen/CodeGenHelpers.h" 13 #include "mlir/TableGen/Format.h" 14 #include "mlir/TableGen/GenInfo.h" 15 #include "mlir/TableGen/Interfaces.h" 16 #include "llvm/ADT/StringSet.h" 17 #include "llvm/Support/CommandLine.h" 18 #include "llvm/TableGen/Error.h" 19 #include "llvm/TableGen/TableGenBackend.h" 20 21 #define DEBUG_TYPE "mlir-tblgen-attrortypedefgen" 22 23 using namespace mlir; 24 using namespace mlir::tblgen; 25 using llvm::Record; 26 using llvm::RecordKeeper; 27 28 //===----------------------------------------------------------------------===// 29 // Utility Functions 30 //===----------------------------------------------------------------------===// 31 32 /// Find all the AttrOrTypeDef for the specified dialect. If no dialect 33 /// specified and can only find one dialect's defs, use that. 34 static void collectAllDefs(StringRef selectedDialect, 35 ArrayRef<const Record *> records, 36 SmallVectorImpl<AttrOrTypeDef> &resultDefs) { 37 // Nothing to do if no defs were found. 38 if (records.empty()) 39 return; 40 41 auto defs = llvm::map_range( 42 records, [&](const Record *rec) { return AttrOrTypeDef(rec); }); 43 if (selectedDialect.empty()) { 44 // If a dialect was not specified, ensure that all found defs belong to the 45 // same dialect. 46 if (!llvm::all_equal(llvm::map_range( 47 defs, [](const auto &def) { return def.getDialect(); }))) { 48 llvm::PrintFatalError("defs belonging to more than one dialect. Must " 49 "select one via '--(attr|type)defs-dialect'"); 50 } 51 resultDefs.assign(defs.begin(), defs.end()); 52 } else { 53 // Otherwise, generate the defs that belong to the selected dialect. 54 auto dialectDefs = llvm::make_filter_range(defs, [&](const auto &def) { 55 return def.getDialect().getName() == selectedDialect; 56 }); 57 resultDefs.assign(dialectDefs.begin(), dialectDefs.end()); 58 } 59 } 60 61 //===----------------------------------------------------------------------===// 62 // DefGen 63 //===----------------------------------------------------------------------===// 64 65 namespace { 66 class DefGen { 67 public: 68 /// Create the attribute or type class. 69 DefGen(const AttrOrTypeDef &def); 70 71 void emitDecl(raw_ostream &os) const { 72 if (storageCls && def.genStorageClass()) { 73 NamespaceEmitter ns(os, def.getStorageNamespace()); 74 os << "struct " << def.getStorageClassName() << ";\n"; 75 } 76 defCls.writeDeclTo(os); 77 } 78 void emitDef(raw_ostream &os) const { 79 if (storageCls && def.genStorageClass()) { 80 NamespaceEmitter ns(os, def.getStorageNamespace()); 81 storageCls->writeDeclTo(os); // everything is inline 82 } 83 defCls.writeDefTo(os); 84 } 85 86 private: 87 /// Add traits from the TableGen definition to the class. 88 void createParentWithTraits(); 89 /// Emit top-level declarations: using declarations and any extra class 90 /// declarations. 91 void emitTopLevelDeclarations(); 92 /// Emit the function that returns the type or attribute name. 93 void emitName(); 94 /// Emit the dialect name as a static member variable. 95 void emitDialectName(); 96 /// Emit attribute or type builders. 97 void emitBuilders(); 98 /// Emit a verifier declaration for custom verification (impl. provided by 99 /// the users). 100 void emitVerifierDecl(); 101 /// Emit a verifier that checks type constraints. 102 void emitInvariantsVerifierImpl(); 103 /// Emit an entry poiunt for verification that calls the invariants and 104 /// custom verifier. 105 void emitInvariantsVerifier(bool hasImpl, bool hasCustomVerifier); 106 /// Emit parsers and printers. 107 void emitParserPrinter(); 108 /// Emit parameter accessors, if required. 109 void emitAccessors(); 110 /// Emit interface methods. 111 void emitInterfaceMethods(); 112 113 //===--------------------------------------------------------------------===// 114 // Builder Emission 115 116 /// Emit the default builder `Attribute::get` 117 void emitDefaultBuilder(); 118 /// Emit the checked builder `Attribute::getChecked` 119 void emitCheckedBuilder(); 120 /// Emit a custom builder. 121 void emitCustomBuilder(const AttrOrTypeBuilder &builder); 122 /// Emit a checked custom builder. 123 void emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder); 124 125 //===--------------------------------------------------------------------===// 126 // Interface Method Emission 127 128 /// Emit methods for a trait. 129 void emitTraitMethods(const InterfaceTrait &trait); 130 /// Emit a trait method. 131 void emitTraitMethod(const InterfaceMethod &method); 132 133 //===--------------------------------------------------------------------===// 134 // Storage Class Emission 135 void emitStorageClass(); 136 /// Generate the storage class constructor. 137 void emitStorageConstructor(); 138 /// Emit the key type `KeyTy`. 139 void emitKeyType(); 140 /// Emit the equality comparison operator. 141 void emitEquals(); 142 /// Emit the key hash function. 143 void emitHashKey(); 144 /// Emit the function to construct the storage class. 145 void emitConstruct(); 146 147 //===--------------------------------------------------------------------===// 148 // Utility Function Declarations 149 150 /// Get the method parameters for a def builder, where the first several 151 /// parameters may be different. 152 SmallVector<MethodParameter> 153 getBuilderParams(std::initializer_list<MethodParameter> prefix) const; 154 155 //===--------------------------------------------------------------------===// 156 // Class fields 157 158 /// The attribute or type definition. 159 const AttrOrTypeDef &def; 160 /// The list of attribute or type parameters. 161 ArrayRef<AttrOrTypeParameter> params; 162 /// The attribute or type class. 163 Class defCls; 164 /// An optional attribute or type storage class. The storage class will 165 /// exist if and only if the def has more than zero parameters. 166 std::optional<Class> storageCls; 167 168 /// The C++ base value of the def, either "Attribute" or "Type". 169 StringRef valueType; 170 /// The prefix/suffix of the TableGen def name, either "Attr" or "Type". 171 StringRef defType; 172 }; 173 } // namespace 174 175 DefGen::DefGen(const AttrOrTypeDef &def) 176 : def(def), params(def.getParameters()), defCls(def.getCppClassName()), 177 valueType(isa<AttrDef>(def) ? "Attribute" : "Type"), 178 defType(isa<AttrDef>(def) ? "Attr" : "Type") { 179 // Check that all parameters have names. 180 for (const AttrOrTypeParameter ¶m : def.getParameters()) 181 if (param.isAnonymous()) 182 llvm::PrintFatalError("all parameters must have a name"); 183 184 // If a storage class is needed, create one. 185 if (def.getNumParameters() > 0) 186 storageCls.emplace(def.getStorageClassName(), /*isStruct=*/true); 187 188 // Create the parent class with any indicated traits. 189 createParentWithTraits(); 190 // Emit top-level declarations. 191 emitTopLevelDeclarations(); 192 // Emit builders for defs with parameters 193 if (storageCls) 194 emitBuilders(); 195 // Emit the type name. 196 emitName(); 197 // Emit the dialect name. 198 emitDialectName(); 199 // Emit verification of type constraints. 200 bool genVerifyInvariantsImpl = def.genVerifyInvariantsImpl(); 201 if (storageCls && genVerifyInvariantsImpl) 202 emitInvariantsVerifierImpl(); 203 // Emit the custom verifier (written by the user). 204 bool genVerifyDecl = def.genVerifyDecl(); 205 if (storageCls && genVerifyDecl) 206 emitVerifierDecl(); 207 // Emit the "verifyInvariants" function if there is any verification at all. 208 if (storageCls) 209 emitInvariantsVerifier(genVerifyInvariantsImpl, genVerifyDecl); 210 // Emit the mnemonic, if there is one, and any associated parser and printer. 211 if (def.getMnemonic()) 212 emitParserPrinter(); 213 // Emit accessors 214 if (def.genAccessors()) 215 emitAccessors(); 216 // Emit trait interface methods 217 emitInterfaceMethods(); 218 defCls.finalize(); 219 // Emit a storage class if one is needed 220 if (storageCls && def.genStorageClass()) 221 emitStorageClass(); 222 } 223 224 void DefGen::createParentWithTraits() { 225 ParentClass defParent(strfmt("::mlir::{0}::{1}Base", valueType, defType)); 226 defParent.addTemplateParam(def.getCppClassName()); 227 defParent.addTemplateParam(def.getCppBaseClassName()); 228 defParent.addTemplateParam(storageCls 229 ? strfmt("{0}::{1}", def.getStorageNamespace(), 230 def.getStorageClassName()) 231 : strfmt("::mlir::{0}Storage", valueType)); 232 for (auto &trait : def.getTraits()) { 233 defParent.addTemplateParam( 234 isa<NativeTrait>(&trait) 235 ? cast<NativeTrait>(&trait)->getFullyQualifiedTraitName() 236 : cast<InterfaceTrait>(&trait)->getFullyQualifiedTraitName()); 237 } 238 defCls.addParent(std::move(defParent)); 239 } 240 241 /// Include declarations specified on NativeTrait 242 static std::string formatExtraDeclarations(const AttrOrTypeDef &def) { 243 SmallVector<StringRef> extraDeclarations; 244 // Include extra class declarations from NativeTrait 245 for (const auto &trait : def.getTraits()) { 246 if (auto *attrOrTypeTrait = dyn_cast<tblgen::NativeTrait>(&trait)) { 247 StringRef value = attrOrTypeTrait->getExtraConcreteClassDeclaration(); 248 if (value.empty()) 249 continue; 250 extraDeclarations.push_back(value); 251 } 252 } 253 if (std::optional<StringRef> extraDecl = def.getExtraDecls()) { 254 extraDeclarations.push_back(*extraDecl); 255 } 256 return llvm::join(extraDeclarations, "\n"); 257 } 258 259 /// Extra class definitions have a `$cppClass` substitution that is to be 260 /// replaced by the C++ class name. 261 static std::string formatExtraDefinitions(const AttrOrTypeDef &def) { 262 SmallVector<StringRef> extraDefinitions; 263 // Include extra class definitions from NativeTrait 264 for (const auto &trait : def.getTraits()) { 265 if (auto *attrOrTypeTrait = dyn_cast<tblgen::NativeTrait>(&trait)) { 266 StringRef value = attrOrTypeTrait->getExtraConcreteClassDefinition(); 267 if (value.empty()) 268 continue; 269 extraDefinitions.push_back(value); 270 } 271 } 272 if (std::optional<StringRef> extraDef = def.getExtraDefs()) { 273 extraDefinitions.push_back(*extraDef); 274 } 275 FmtContext ctx = FmtContext().addSubst("cppClass", def.getCppClassName()); 276 return tgfmt(llvm::join(extraDefinitions, "\n"), &ctx).str(); 277 } 278 279 void DefGen::emitTopLevelDeclarations() { 280 // Inherit constructors from the attribute or type class. 281 defCls.declare<VisibilityDeclaration>(Visibility::Public); 282 defCls.declare<UsingDeclaration>("Base::Base"); 283 284 // Emit the extra declarations first in case there's a definition in there. 285 std::string extraDecl = formatExtraDeclarations(def); 286 std::string extraDef = formatExtraDefinitions(def); 287 defCls.declare<ExtraClassDeclaration>(std::move(extraDecl), 288 std::move(extraDef)); 289 } 290 291 void DefGen::emitName() { 292 StringRef name; 293 if (auto *attrDef = dyn_cast<AttrDef>(&def)) { 294 name = attrDef->getAttrName(); 295 } else { 296 auto *typeDef = cast<TypeDef>(&def); 297 name = typeDef->getTypeName(); 298 } 299 std::string nameDecl = 300 strfmt("static constexpr ::llvm::StringLiteral name = \"{0}\";\n", name); 301 defCls.declare<ExtraClassDeclaration>(std::move(nameDecl)); 302 } 303 304 void DefGen::emitDialectName() { 305 std::string decl = 306 strfmt("static constexpr ::llvm::StringLiteral dialectName = \"{0}\";\n", 307 def.getDialect().getName()); 308 defCls.declare<ExtraClassDeclaration>(std::move(decl)); 309 } 310 311 void DefGen::emitBuilders() { 312 if (!def.skipDefaultBuilders()) { 313 emitDefaultBuilder(); 314 if (def.genVerifyDecl() || def.genVerifyInvariantsImpl()) 315 emitCheckedBuilder(); 316 } 317 for (auto &builder : def.getBuilders()) { 318 emitCustomBuilder(builder); 319 if (def.genVerifyDecl() || def.genVerifyInvariantsImpl()) 320 emitCheckedCustomBuilder(builder); 321 } 322 } 323 324 void DefGen::emitVerifierDecl() { 325 defCls.declareStaticMethod( 326 "::llvm::LogicalResult", "verify", 327 getBuilderParams({{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", 328 "emitError"}})); 329 } 330 331 static const char *const patternParameterVerificationCode = R"( 332 if (!({0})) { 333 emitError() << "failed to verify '{1}': {2}"; 334 return ::mlir::failure(); 335 } 336 )"; 337 338 void DefGen::emitInvariantsVerifierImpl() { 339 SmallVector<MethodParameter> builderParams = getBuilderParams( 340 {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}}); 341 Method *verifier = 342 defCls.addMethod("::llvm::LogicalResult", "verifyInvariantsImpl", 343 Method::Static, builderParams); 344 verifier->body().indent(); 345 346 // Generate verification for each parameter that is a type constraint. 347 for (auto it : llvm::enumerate(def.getParameters())) { 348 const AttrOrTypeParameter ¶m = it.value(); 349 std::optional<Constraint> constraint = param.getConstraint(); 350 // No verification needed for parameters that are not type constraints. 351 if (!constraint.has_value()) 352 continue; 353 FmtContext ctx; 354 // Note: Skip over the first method parameter (`emitError`). 355 ctx.withSelf(builderParams[it.index() + 1].getName()); 356 std::string condition = tgfmt(constraint->getConditionTemplate(), &ctx); 357 verifier->body() << formatv(patternParameterVerificationCode, condition, 358 param.getName(), constraint->getSummary()) 359 << "\n"; 360 } 361 verifier->body() << "return ::mlir::success();"; 362 } 363 364 void DefGen::emitInvariantsVerifier(bool hasImpl, bool hasCustomVerifier) { 365 if (!hasImpl && !hasCustomVerifier) 366 return; 367 defCls.declare<UsingDeclaration>("Base::getChecked"); 368 SmallVector<MethodParameter> builderParams = getBuilderParams( 369 {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}}); 370 Method *verifier = 371 defCls.addMethod("::llvm::LogicalResult", "verifyInvariants", 372 Method::Static, builderParams); 373 verifier->body().indent(); 374 375 auto emitVerifierCall = [&](StringRef name) { 376 verifier->body() << strfmt("if (::mlir::failed({0}(", name); 377 llvm::interleaveComma( 378 llvm::map_range(builderParams, 379 [](auto ¶m) { return param.getName(); }), 380 verifier->body()); 381 verifier->body() << ")))\n"; 382 verifier->body() << " return ::mlir::failure();\n"; 383 }; 384 385 if (hasImpl) { 386 // Call the verifier that checks the type constraints. 387 emitVerifierCall("verifyInvariantsImpl"); 388 } 389 if (hasCustomVerifier) { 390 // Call the custom verifier that is provided by the user. 391 emitVerifierCall("verify"); 392 } 393 verifier->body() << "return ::mlir::success();"; 394 } 395 396 void DefGen::emitParserPrinter() { 397 auto *mnemonic = defCls.addStaticMethod<Method::Constexpr>( 398 "::llvm::StringLiteral", "getMnemonic"); 399 mnemonic->body().indent() << strfmt("return {\"{0}\"};", *def.getMnemonic()); 400 401 // Declare the parser and printer, if needed. 402 bool hasAssemblyFormat = def.getAssemblyFormat().has_value(); 403 if (!def.hasCustomAssemblyFormat() && !hasAssemblyFormat) 404 return; 405 406 // Declare the parser. 407 SmallVector<MethodParameter> parserParams; 408 parserParams.emplace_back("::mlir::AsmParser &", "odsParser"); 409 if (isa<AttrDef>(&def)) 410 parserParams.emplace_back("::mlir::Type", "odsType"); 411 auto *parser = defCls.addMethod(strfmt("::mlir::{0}", valueType), "parse", 412 hasAssemblyFormat ? Method::Static 413 : Method::StaticDeclaration, 414 std::move(parserParams)); 415 // Declare the printer. 416 auto props = hasAssemblyFormat ? Method::Const : Method::ConstDeclaration; 417 Method *printer = 418 defCls.addMethod("void", "print", props, 419 MethodParameter("::mlir::AsmPrinter &", "odsPrinter")); 420 // Emit the bodies if we are using the declarative format. 421 if (hasAssemblyFormat) 422 return generateAttrOrTypeFormat(def, parser->body(), printer->body()); 423 } 424 425 void DefGen::emitAccessors() { 426 for (auto ¶m : params) { 427 Method *m = defCls.addMethod( 428 param.getCppAccessorType(), param.getAccessorName(), 429 def.genStorageClass() ? Method::Const : Method::ConstDeclaration); 430 // Generate accessor definitions only if we also generate the storage 431 // class. Otherwise, let the user define the exact accessor definition. 432 if (!def.genStorageClass()) 433 continue; 434 m->body().indent() << "return getImpl()->" << param.getName() << ";"; 435 } 436 } 437 438 void DefGen::emitInterfaceMethods() { 439 for (auto &traitDef : def.getTraits()) 440 if (auto *trait = dyn_cast<InterfaceTrait>(&traitDef)) 441 if (trait->shouldDeclareMethods()) 442 emitTraitMethods(*trait); 443 } 444 445 //===----------------------------------------------------------------------===// 446 // Builder Emission 447 448 SmallVector<MethodParameter> 449 DefGen::getBuilderParams(std::initializer_list<MethodParameter> prefix) const { 450 SmallVector<MethodParameter> builderParams; 451 builderParams.append(prefix.begin(), prefix.end()); 452 for (auto ¶m : params) 453 builderParams.emplace_back(param.getCppType(), param.getName()); 454 return builderParams; 455 } 456 457 void DefGen::emitDefaultBuilder() { 458 Method *m = defCls.addStaticMethod( 459 def.getCppClassName(), "get", 460 getBuilderParams({{"::mlir::MLIRContext *", "context"}})); 461 MethodBody &body = m->body().indent(); 462 auto scope = body.scope("return Base::get(context", ");"); 463 for (const auto ¶m : params) 464 body << ", std::move(" << param.getName() << ")"; 465 } 466 467 void DefGen::emitCheckedBuilder() { 468 Method *m = defCls.addStaticMethod( 469 def.getCppClassName(), "getChecked", 470 getBuilderParams( 471 {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}, 472 {"::mlir::MLIRContext *", "context"}})); 473 MethodBody &body = m->body().indent(); 474 auto scope = body.scope("return Base::getChecked(emitError, context", ");"); 475 for (const auto ¶m : params) 476 body << ", " << param.getName(); 477 } 478 479 static SmallVector<MethodParameter> 480 getCustomBuilderParams(std::initializer_list<MethodParameter> prefix, 481 const AttrOrTypeBuilder &builder) { 482 auto params = builder.getParameters(); 483 SmallVector<MethodParameter> builderParams; 484 builderParams.append(prefix.begin(), prefix.end()); 485 if (!builder.hasInferredContextParameter()) 486 builderParams.emplace_back("::mlir::MLIRContext *", "context"); 487 for (auto ¶m : params) { 488 builderParams.emplace_back(param.getCppType(), *param.getName(), 489 param.getDefaultValue()); 490 } 491 return builderParams; 492 } 493 494 void DefGen::emitCustomBuilder(const AttrOrTypeBuilder &builder) { 495 // Don't emit a body if there isn't one. 496 auto props = builder.getBody() ? Method::Static : Method::StaticDeclaration; 497 StringRef returnType = def.getCppClassName(); 498 if (std::optional<StringRef> builderReturnType = builder.getReturnType()) 499 returnType = *builderReturnType; 500 Method *m = defCls.addMethod(returnType, "get", props, 501 getCustomBuilderParams({}, builder)); 502 if (!builder.getBody()) 503 return; 504 505 // Format the body and emit it. 506 FmtContext ctx; 507 ctx.addSubst("_get", "Base::get"); 508 if (!builder.hasInferredContextParameter()) 509 ctx.addSubst("_ctxt", "context"); 510 std::string bodyStr = tgfmt(*builder.getBody(), &ctx); 511 m->body().indent().getStream().printReindented(bodyStr); 512 } 513 514 /// Replace all instances of 'from' to 'to' in `str` and return the new string. 515 static std::string replaceInStr(std::string str, StringRef from, StringRef to) { 516 size_t pos = 0; 517 while ((pos = str.find(from.data(), pos, from.size())) != std::string::npos) 518 str.replace(pos, from.size(), to.data(), to.size()); 519 return str; 520 } 521 522 void DefGen::emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder) { 523 // Don't emit a body if there isn't one. 524 auto props = builder.getBody() ? Method::Static : Method::StaticDeclaration; 525 StringRef returnType = def.getCppClassName(); 526 if (std::optional<StringRef> builderReturnType = builder.getReturnType()) 527 returnType = *builderReturnType; 528 Method *m = defCls.addMethod( 529 returnType, "getChecked", props, 530 getCustomBuilderParams( 531 {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}}, 532 builder)); 533 if (!builder.getBody()) 534 return; 535 536 // Format the body and emit it. Replace $_get(...) with 537 // Base::getChecked(emitError, ...) 538 FmtContext ctx; 539 if (!builder.hasInferredContextParameter()) 540 ctx.addSubst("_ctxt", "context"); 541 std::string bodyStr = replaceInStr(builder.getBody()->str(), "$_get(", 542 "Base::getChecked(emitError, "); 543 bodyStr = tgfmt(bodyStr, &ctx); 544 m->body().indent().getStream().printReindented(bodyStr); 545 } 546 547 //===----------------------------------------------------------------------===// 548 // Interface Method Emission 549 550 void DefGen::emitTraitMethods(const InterfaceTrait &trait) { 551 // Get the set of methods that should always be declared. 552 auto alwaysDeclaredMethods = trait.getAlwaysDeclaredMethods(); 553 StringSet<> alwaysDeclared; 554 alwaysDeclared.insert(alwaysDeclaredMethods.begin(), 555 alwaysDeclaredMethods.end()); 556 557 Interface iface = trait.getInterface(); // causes strange bugs if elided 558 for (auto &method : iface.getMethods()) { 559 // Don't declare if the method has a body. Or if the method has a default 560 // implementation and the def didn't request that it always be declared. 561 if (method.getBody() || (method.getDefaultImplementation() && 562 !alwaysDeclared.count(method.getName()))) 563 continue; 564 emitTraitMethod(method); 565 } 566 } 567 568 void DefGen::emitTraitMethod(const InterfaceMethod &method) { 569 // All interface methods are declaration-only. 570 auto props = 571 method.isStatic() ? Method::StaticDeclaration : Method::ConstDeclaration; 572 SmallVector<MethodParameter> params; 573 for (auto ¶m : method.getArguments()) 574 params.emplace_back(param.type, param.name); 575 defCls.addMethod(method.getReturnType(), method.getName(), props, 576 std::move(params)); 577 } 578 579 //===----------------------------------------------------------------------===// 580 // Storage Class Emission 581 582 void DefGen::emitStorageConstructor() { 583 Constructor *ctor = 584 storageCls->addConstructor<Method::Inline>(getBuilderParams({})); 585 for (auto ¶m : params) { 586 std::string movedValue = ("std::move(" + param.getName() + ")").str(); 587 ctor->addMemberInitializer(param.getName(), movedValue); 588 } 589 } 590 591 void DefGen::emitKeyType() { 592 std::string keyType("std::tuple<"); 593 llvm::raw_string_ostream os(keyType); 594 llvm::interleaveComma(params, os, 595 [&](auto ¶m) { os << param.getCppType(); }); 596 os << '>'; 597 storageCls->declare<UsingDeclaration>("KeyTy", std::move(os.str())); 598 599 // Add a method to construct the key type from the storage. 600 Method *m = storageCls->addConstMethod<Method::Inline>("KeyTy", "getAsKey"); 601 m->body().indent() << "return KeyTy("; 602 llvm::interleaveComma(params, m->body().indent(), 603 [&](auto ¶m) { m->body() << param.getName(); }); 604 m->body() << ");"; 605 } 606 607 void DefGen::emitEquals() { 608 Method *eq = storageCls->addConstMethod<Method::Inline>( 609 "bool", "operator==", MethodParameter("const KeyTy &", "tblgenKey")); 610 auto &body = eq->body().indent(); 611 auto scope = body.scope("return (", ");"); 612 const auto eachFn = [&](auto it) { 613 FmtContext ctx({{"_lhs", it.value().getName()}, 614 {"_rhs", strfmt("std::get<{0}>(tblgenKey)", it.index())}}); 615 body << tgfmt(it.value().getComparator(), &ctx); 616 }; 617 llvm::interleave(llvm::enumerate(params), body, eachFn, ") && ("); 618 } 619 620 void DefGen::emitHashKey() { 621 Method *hash = storageCls->addStaticInlineMethod( 622 "::llvm::hash_code", "hashKey", 623 MethodParameter("const KeyTy &", "tblgenKey")); 624 auto &body = hash->body().indent(); 625 auto scope = body.scope("return ::llvm::hash_combine(", ");"); 626 llvm::interleaveComma(llvm::enumerate(params), body, [&](auto it) { 627 body << llvm::formatv("std::get<{0}>(tblgenKey)", it.index()); 628 }); 629 } 630 631 void DefGen::emitConstruct() { 632 Method *construct = storageCls->addMethod<Method::Inline>( 633 strfmt("{0} *", def.getStorageClassName()), "construct", 634 def.hasStorageCustomConstructor() ? Method::StaticDeclaration 635 : Method::Static, 636 MethodParameter(strfmt("::mlir::{0}StorageAllocator &", valueType), 637 "allocator"), 638 MethodParameter("KeyTy &&", "tblgenKey")); 639 if (!def.hasStorageCustomConstructor()) { 640 auto &body = construct->body().indent(); 641 for (const auto &it : llvm::enumerate(params)) { 642 body << formatv("auto {0} = std::move(std::get<{1}>(tblgenKey));\n", 643 it.value().getName(), it.index()); 644 } 645 // Use the parameters' custom allocator code, if provided. 646 FmtContext ctx = FmtContext().addSubst("_allocator", "allocator"); 647 for (auto ¶m : params) { 648 if (std::optional<StringRef> allocCode = param.getAllocator()) { 649 ctx.withSelf(param.getName()).addSubst("_dst", param.getName()); 650 body << tgfmt(*allocCode, &ctx) << '\n'; 651 } 652 } 653 auto scope = 654 body.scope(strfmt("return new (allocator.allocate<{0}>()) {0}(", 655 def.getStorageClassName()), 656 ");"); 657 llvm::interleaveComma(params, body, [&](auto ¶m) { 658 body << "std::move(" << param.getName() << ")"; 659 }); 660 } 661 } 662 663 void DefGen::emitStorageClass() { 664 // Add the appropriate parent class. 665 storageCls->addParent(strfmt("::mlir::{0}Storage", valueType)); 666 // Add the constructor. 667 emitStorageConstructor(); 668 // Declare the key type. 669 emitKeyType(); 670 // Add the comparison method. 671 emitEquals(); 672 // Emit the key hash method. 673 emitHashKey(); 674 // Emit the storage constructor. Just declare it if the user wants to define 675 // it themself. 676 emitConstruct(); 677 // Emit the storage class members as public, at the very end of the struct. 678 storageCls->finalize(); 679 for (auto ¶m : params) 680 storageCls->declare<Field>(param.getCppType(), param.getName()); 681 } 682 683 //===----------------------------------------------------------------------===// 684 // DefGenerator 685 //===----------------------------------------------------------------------===// 686 687 namespace { 688 /// This struct is the base generator used when processing tablegen interfaces. 689 class DefGenerator { 690 public: 691 bool emitDecls(StringRef selectedDialect); 692 bool emitDefs(StringRef selectedDialect); 693 694 protected: 695 DefGenerator(ArrayRef<const Record *> defs, raw_ostream &os, 696 StringRef defType, StringRef valueType, bool isAttrGenerator) 697 : defRecords(defs), os(os), defType(defType), valueType(valueType), 698 isAttrGenerator(isAttrGenerator) { 699 // Sort by occurrence in file. 700 llvm::sort(defRecords, [](const Record *lhs, const Record *rhs) { 701 return lhs->getID() < rhs->getID(); 702 }); 703 } 704 705 /// Emit the list of def type names. 706 void emitTypeDefList(ArrayRef<AttrOrTypeDef> defs); 707 /// Emit the code to dispatch between different defs during parsing/printing. 708 void emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs); 709 710 /// The set of def records to emit. 711 std::vector<const Record *> defRecords; 712 /// The attribute or type class to emit. 713 /// The stream to emit to. 714 raw_ostream &os; 715 /// The prefix of the tablegen def name, e.g. Attr or Type. 716 StringRef defType; 717 /// The C++ base value type of the def, e.g. Attribute or Type. 718 StringRef valueType; 719 /// Flag indicating if this generator is for Attributes. False if the 720 /// generator is for types. 721 bool isAttrGenerator; 722 }; 723 724 /// A specialized generator for AttrDefs. 725 struct AttrDefGenerator : public DefGenerator { 726 AttrDefGenerator(const RecordKeeper &records, raw_ostream &os) 727 : DefGenerator(records.getAllDerivedDefinitionsIfDefined("AttrDef"), os, 728 "Attr", "Attribute", /*isAttrGenerator=*/true) {} 729 }; 730 /// A specialized generator for TypeDefs. 731 struct TypeDefGenerator : public DefGenerator { 732 TypeDefGenerator(const RecordKeeper &records, raw_ostream &os) 733 : DefGenerator(records.getAllDerivedDefinitionsIfDefined("TypeDef"), os, 734 "Type", "Type", /*isAttrGenerator=*/false) {} 735 }; 736 } // namespace 737 738 //===----------------------------------------------------------------------===// 739 // GEN: Declarations 740 //===----------------------------------------------------------------------===// 741 742 /// Print this above all the other declarations. Contains type declarations used 743 /// later on. 744 static const char *const typeDefDeclHeader = R"( 745 namespace mlir { 746 class AsmParser; 747 class AsmPrinter; 748 } // namespace mlir 749 )"; 750 751 bool DefGenerator::emitDecls(StringRef selectedDialect) { 752 emitSourceFileHeader((defType + "Def Declarations").str(), os); 753 IfDefScope scope("GET_" + defType.upper() + "DEF_CLASSES", os); 754 755 // Output the common "header". 756 os << typeDefDeclHeader; 757 758 SmallVector<AttrOrTypeDef, 16> defs; 759 collectAllDefs(selectedDialect, defRecords, defs); 760 if (defs.empty()) 761 return false; 762 { 763 NamespaceEmitter nsEmitter(os, defs.front().getDialect()); 764 765 // Declare all the def classes first (in case they reference each other). 766 for (const AttrOrTypeDef &def : defs) 767 os << "class " << def.getCppClassName() << ";\n"; 768 769 // Emit the declarations. 770 for (const AttrOrTypeDef &def : defs) 771 DefGen(def).emitDecl(os); 772 } 773 // Emit the TypeID explicit specializations to have a single definition for 774 // each of these. 775 for (const AttrOrTypeDef &def : defs) 776 if (!def.getDialect().getCppNamespace().empty()) 777 os << "MLIR_DECLARE_EXPLICIT_TYPE_ID(" 778 << def.getDialect().getCppNamespace() << "::" << def.getCppClassName() 779 << ")\n"; 780 781 return false; 782 } 783 784 //===----------------------------------------------------------------------===// 785 // GEN: Def List 786 //===----------------------------------------------------------------------===// 787 788 void DefGenerator::emitTypeDefList(ArrayRef<AttrOrTypeDef> defs) { 789 IfDefScope scope("GET_" + defType.upper() + "DEF_LIST", os); 790 auto interleaveFn = [&](const AttrOrTypeDef &def) { 791 os << def.getDialect().getCppNamespace() << "::" << def.getCppClassName(); 792 }; 793 llvm::interleave(defs, os, interleaveFn, ",\n"); 794 os << "\n"; 795 } 796 797 //===----------------------------------------------------------------------===// 798 // GEN: Definitions 799 //===----------------------------------------------------------------------===// 800 801 /// The code block for default attribute parser/printer dispatch boilerplate. 802 /// {0}: the dialect fully qualified class name. 803 /// {1}: the optional code for the dynamic attribute parser dispatch. 804 /// {2}: the optional code for the dynamic attribute printer dispatch. 805 static const char *const dialectDefaultAttrPrinterParserDispatch = R"( 806 /// Parse an attribute registered to this dialect. 807 ::mlir::Attribute {0}::parseAttribute(::mlir::DialectAsmParser &parser, 808 ::mlir::Type type) const {{ 809 ::llvm::SMLoc typeLoc = parser.getCurrentLocation(); 810 ::llvm::StringRef attrTag; 811 {{ 812 ::mlir::Attribute attr; 813 auto parseResult = generatedAttributeParser(parser, &attrTag, type, attr); 814 if (parseResult.has_value()) 815 return attr; 816 } 817 {1} 818 parser.emitError(typeLoc) << "unknown attribute `" 819 << attrTag << "` in dialect `" << getNamespace() << "`"; 820 return {{}; 821 } 822 /// Print an attribute registered to this dialect. 823 void {0}::printAttribute(::mlir::Attribute attr, 824 ::mlir::DialectAsmPrinter &printer) const {{ 825 if (::mlir::succeeded(generatedAttributePrinter(attr, printer))) 826 return; 827 {2} 828 } 829 )"; 830 831 /// The code block for dynamic attribute parser dispatch boilerplate. 832 static const char *const dialectDynamicAttrParserDispatch = R"( 833 { 834 ::mlir::Attribute genAttr; 835 auto parseResult = parseOptionalDynamicAttr(attrTag, parser, genAttr); 836 if (parseResult.has_value()) { 837 if (::mlir::succeeded(parseResult.value())) 838 return genAttr; 839 return Attribute(); 840 } 841 } 842 )"; 843 844 /// The code block for dynamic type printer dispatch boilerplate. 845 static const char *const dialectDynamicAttrPrinterDispatch = R"( 846 if (::mlir::succeeded(printIfDynamicAttr(attr, printer))) 847 return; 848 )"; 849 850 /// The code block for default type parser/printer dispatch boilerplate. 851 /// {0}: the dialect fully qualified class name. 852 /// {1}: the optional code for the dynamic type parser dispatch. 853 /// {2}: the optional code for the dynamic type printer dispatch. 854 static const char *const dialectDefaultTypePrinterParserDispatch = R"( 855 /// Parse a type registered to this dialect. 856 ::mlir::Type {0}::parseType(::mlir::DialectAsmParser &parser) const {{ 857 ::llvm::SMLoc typeLoc = parser.getCurrentLocation(); 858 ::llvm::StringRef mnemonic; 859 ::mlir::Type genType; 860 auto parseResult = generatedTypeParser(parser, &mnemonic, genType); 861 if (parseResult.has_value()) 862 return genType; 863 {1} 864 parser.emitError(typeLoc) << "unknown type `" 865 << mnemonic << "` in dialect `" << getNamespace() << "`"; 866 return {{}; 867 } 868 /// Print a type registered to this dialect. 869 void {0}::printType(::mlir::Type type, 870 ::mlir::DialectAsmPrinter &printer) const {{ 871 if (::mlir::succeeded(generatedTypePrinter(type, printer))) 872 return; 873 {2} 874 } 875 )"; 876 877 /// The code block for dynamic type parser dispatch boilerplate. 878 static const char *const dialectDynamicTypeParserDispatch = R"( 879 { 880 auto parseResult = parseOptionalDynamicType(mnemonic, parser, genType); 881 if (parseResult.has_value()) { 882 if (::mlir::succeeded(parseResult.value())) 883 return genType; 884 return ::mlir::Type(); 885 } 886 } 887 )"; 888 889 /// The code block for dynamic type printer dispatch boilerplate. 890 static const char *const dialectDynamicTypePrinterDispatch = R"( 891 if (::mlir::succeeded(printIfDynamicType(type, printer))) 892 return; 893 )"; 894 895 /// Emit the dialect printer/parser dispatcher. User's code should call these 896 /// functions from their dialect's print/parse methods. 897 void DefGenerator::emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs) { 898 if (llvm::none_of(defs, [](const AttrOrTypeDef &def) { 899 return def.getMnemonic().has_value(); 900 })) { 901 return; 902 } 903 // Declare the parser. 904 SmallVector<MethodParameter> params = {{"::mlir::AsmParser &", "parser"}, 905 {"::llvm::StringRef *", "mnemonic"}}; 906 if (isAttrGenerator) 907 params.emplace_back("::mlir::Type", "type"); 908 params.emplace_back(strfmt("::mlir::{0} &", valueType), "value"); 909 Method parse("::mlir::OptionalParseResult", 910 strfmt("generated{0}Parser", valueType), Method::StaticInline, 911 std::move(params)); 912 // Declare the printer. 913 Method printer("::llvm::LogicalResult", 914 strfmt("generated{0}Printer", valueType), Method::StaticInline, 915 {{strfmt("::mlir::{0}", valueType), "def"}, 916 {"::mlir::AsmPrinter &", "printer"}}); 917 918 // The parser dispatch uses a KeywordSwitch, matching on the mnemonic and 919 // calling the def's parse function. 920 parse.body() << " return " 921 "::mlir::AsmParser::KeywordSwitch<::mlir::" 922 "OptionalParseResult>(parser)\n"; 923 const char *const getValueForMnemonic = 924 R"( .Case({0}::getMnemonic(), [&](llvm::StringRef, llvm::SMLoc) {{ 925 value = {0}::{1}; 926 return ::mlir::success(!!value); 927 }) 928 )"; 929 930 // The printer dispatch uses llvm::TypeSwitch to find and call the correct 931 // printer. 932 printer.body() << " return ::llvm::TypeSwitch<::mlir::" << valueType 933 << ", ::llvm::LogicalResult>(def)"; 934 const char *const printValue = R"( .Case<{0}>([&](auto t) {{ 935 printer << {0}::getMnemonic();{1} 936 return ::mlir::success(); 937 }) 938 )"; 939 for (auto &def : defs) { 940 if (!def.getMnemonic()) 941 continue; 942 bool hasParserPrinterDecl = 943 def.hasCustomAssemblyFormat() || def.getAssemblyFormat(); 944 std::string defClass = strfmt( 945 "{0}::{1}", def.getDialect().getCppNamespace(), def.getCppClassName()); 946 947 // If the def has no parameters or parser code, invoke a normal `get`. 948 std::string parseOrGet = 949 hasParserPrinterDecl 950 ? strfmt("parse(parser{0})", isAttrGenerator ? ", type" : "") 951 : "get(parser.getContext())"; 952 parse.body() << llvm::formatv(getValueForMnemonic, defClass, parseOrGet); 953 954 // If the def has no parameters and no printer, just print the mnemonic. 955 StringRef printDef = ""; 956 if (hasParserPrinterDecl) 957 printDef = "\nt.print(printer);"; 958 printer.body() << llvm::formatv(printValue, defClass, printDef); 959 } 960 parse.body() << " .Default([&](llvm::StringRef keyword, llvm::SMLoc) {\n" 961 " *mnemonic = keyword;\n" 962 " return std::nullopt;\n" 963 " });"; 964 printer.body() << " .Default([](auto) { return ::mlir::failure(); });"; 965 966 raw_indented_ostream indentedOs(os); 967 parse.writeDeclTo(indentedOs); 968 printer.writeDeclTo(indentedOs); 969 } 970 971 bool DefGenerator::emitDefs(StringRef selectedDialect) { 972 emitSourceFileHeader((defType + "Def Definitions").str(), os); 973 974 SmallVector<AttrOrTypeDef, 16> defs; 975 collectAllDefs(selectedDialect, defRecords, defs); 976 if (defs.empty()) 977 return false; 978 emitTypeDefList(defs); 979 980 IfDefScope scope("GET_" + defType.upper() + "DEF_CLASSES", os); 981 emitParsePrintDispatch(defs); 982 for (const AttrOrTypeDef &def : defs) { 983 { 984 NamespaceEmitter ns(os, def.getDialect()); 985 DefGen gen(def); 986 gen.emitDef(os); 987 } 988 // Emit the TypeID explicit specializations to have a single symbol def. 989 if (!def.getDialect().getCppNamespace().empty()) 990 os << "MLIR_DEFINE_EXPLICIT_TYPE_ID(" 991 << def.getDialect().getCppNamespace() << "::" << def.getCppClassName() 992 << ")\n"; 993 } 994 995 Dialect firstDialect = defs.front().getDialect(); 996 997 // Emit the default parser/printer for Attributes if the dialect asked for it. 998 if (isAttrGenerator && firstDialect.useDefaultAttributePrinterParser()) { 999 NamespaceEmitter nsEmitter(os, firstDialect); 1000 if (firstDialect.isExtensible()) { 1001 os << llvm::formatv(dialectDefaultAttrPrinterParserDispatch, 1002 firstDialect.getCppClassName(), 1003 dialectDynamicAttrParserDispatch, 1004 dialectDynamicAttrPrinterDispatch); 1005 } else { 1006 os << llvm::formatv(dialectDefaultAttrPrinterParserDispatch, 1007 firstDialect.getCppClassName(), "", ""); 1008 } 1009 } 1010 1011 // Emit the default parser/printer for Types if the dialect asked for it. 1012 if (!isAttrGenerator && firstDialect.useDefaultTypePrinterParser()) { 1013 NamespaceEmitter nsEmitter(os, firstDialect); 1014 if (firstDialect.isExtensible()) { 1015 os << llvm::formatv(dialectDefaultTypePrinterParserDispatch, 1016 firstDialect.getCppClassName(), 1017 dialectDynamicTypeParserDispatch, 1018 dialectDynamicTypePrinterDispatch); 1019 } else { 1020 os << llvm::formatv(dialectDefaultTypePrinterParserDispatch, 1021 firstDialect.getCppClassName(), "", ""); 1022 } 1023 } 1024 1025 return false; 1026 } 1027 1028 //===----------------------------------------------------------------------===// 1029 // Type Constraints 1030 //===----------------------------------------------------------------------===// 1031 1032 /// Find all type constraints for which a C++ function should be generated. 1033 static std::vector<Constraint> 1034 getAllTypeConstraints(const RecordKeeper &records) { 1035 std::vector<Constraint> result; 1036 for (const Record *def : 1037 records.getAllDerivedDefinitionsIfDefined("TypeConstraint")) { 1038 // Ignore constraints defined outside of the top-level file. 1039 if (llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) != 1040 llvm::SrcMgr.getMainFileID()) 1041 continue; 1042 Constraint constr(def); 1043 // Generate C++ function only if "cppFunctionName" is set. 1044 if (!constr.getCppFunctionName()) 1045 continue; 1046 result.push_back(constr); 1047 } 1048 return result; 1049 } 1050 1051 static void emitTypeConstraintDecls(const RecordKeeper &records, 1052 raw_ostream &os) { 1053 static const char *const typeConstraintDecl = R"( 1054 bool {0}(::mlir::Type type); 1055 )"; 1056 1057 for (Constraint constr : getAllTypeConstraints(records)) 1058 os << strfmt(typeConstraintDecl, *constr.getCppFunctionName()); 1059 } 1060 1061 static void emitTypeConstraintDefs(const RecordKeeper &records, 1062 raw_ostream &os) { 1063 static const char *const typeConstraintDef = R"( 1064 bool {0}(::mlir::Type type) { 1065 return ({1}); 1066 } 1067 )"; 1068 1069 for (Constraint constr : getAllTypeConstraints(records)) { 1070 FmtContext ctx; 1071 ctx.withSelf("type"); 1072 std::string condition = tgfmt(constr.getConditionTemplate(), &ctx); 1073 os << strfmt(typeConstraintDef, *constr.getCppFunctionName(), condition); 1074 } 1075 } 1076 1077 //===----------------------------------------------------------------------===// 1078 // GEN: Registration hooks 1079 //===----------------------------------------------------------------------===// 1080 1081 //===----------------------------------------------------------------------===// 1082 // AttrDef 1083 1084 static llvm::cl::OptionCategory attrdefGenCat("Options for -gen-attrdef-*"); 1085 static llvm::cl::opt<std::string> 1086 attrDialect("attrdefs-dialect", 1087 llvm::cl::desc("Generate attributes for this dialect"), 1088 llvm::cl::cat(attrdefGenCat), llvm::cl::CommaSeparated); 1089 1090 static mlir::GenRegistration 1091 genAttrDefs("gen-attrdef-defs", "Generate AttrDef definitions", 1092 [](const RecordKeeper &records, raw_ostream &os) { 1093 AttrDefGenerator generator(records, os); 1094 return generator.emitDefs(attrDialect); 1095 }); 1096 static mlir::GenRegistration 1097 genAttrDecls("gen-attrdef-decls", "Generate AttrDef declarations", 1098 [](const RecordKeeper &records, raw_ostream &os) { 1099 AttrDefGenerator generator(records, os); 1100 return generator.emitDecls(attrDialect); 1101 }); 1102 1103 //===----------------------------------------------------------------------===// 1104 // TypeDef 1105 1106 static llvm::cl::OptionCategory typedefGenCat("Options for -gen-typedef-*"); 1107 static llvm::cl::opt<std::string> 1108 typeDialect("typedefs-dialect", 1109 llvm::cl::desc("Generate types for this dialect"), 1110 llvm::cl::cat(typedefGenCat), llvm::cl::CommaSeparated); 1111 1112 static mlir::GenRegistration 1113 genTypeDefs("gen-typedef-defs", "Generate TypeDef definitions", 1114 [](const RecordKeeper &records, raw_ostream &os) { 1115 TypeDefGenerator generator(records, os); 1116 return generator.emitDefs(typeDialect); 1117 }); 1118 static mlir::GenRegistration 1119 genTypeDecls("gen-typedef-decls", "Generate TypeDef declarations", 1120 [](const RecordKeeper &records, raw_ostream &os) { 1121 TypeDefGenerator generator(records, os); 1122 return generator.emitDecls(typeDialect); 1123 }); 1124 1125 static mlir::GenRegistration 1126 genTypeConstrDefs("gen-type-constraint-defs", 1127 "Generate type constraint definitions", 1128 [](const RecordKeeper &records, raw_ostream &os) { 1129 emitTypeConstraintDefs(records, os); 1130 return false; 1131 }); 1132 static mlir::GenRegistration 1133 genTypeConstrDecls("gen-type-constraint-decls", 1134 "Generate type constraint declarations", 1135 [](const RecordKeeper &records, raw_ostream &os) { 1136 emitTypeConstraintDecls(records, os); 1137 return false; 1138 }); 1139