1 //===- OpInterfacesGen.cpp - MLIR op interface utility generator ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // OpInterfacesGen generates definitions for operation interfaces. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "DocGenUtilities.h" 14 #include "mlir/TableGen/Format.h" 15 #include "mlir/TableGen/GenInfo.h" 16 #include "mlir/TableGen/Interfaces.h" 17 #include "llvm/ADT/SmallVector.h" 18 #include "llvm/ADT/StringExtras.h" 19 #include "llvm/Support/FormatVariadic.h" 20 #include "llvm/Support/raw_ostream.h" 21 #include "llvm/TableGen/Error.h" 22 #include "llvm/TableGen/Record.h" 23 #include "llvm/TableGen/TableGenBackend.h" 24 25 using namespace mlir; 26 using mlir::tblgen::Interface; 27 using mlir::tblgen::InterfaceMethod; 28 using mlir::tblgen::OpInterface; 29 30 /// Emit a string corresponding to a C++ type, followed by a space if necessary. 31 static raw_ostream &emitCPPType(StringRef type, raw_ostream &os) { 32 type = type.trim(); 33 os << type; 34 if (type.back() != '&' && type.back() != '*') 35 os << " "; 36 return os; 37 } 38 39 /// Emit the method name and argument list for the given method. If 'addThisArg' 40 /// is true, then an argument is added to the beginning of the argument list for 41 /// the concrete value. 42 static void emitMethodNameAndArgs(const InterfaceMethod &method, 43 raw_ostream &os, StringRef valueType, 44 bool addThisArg, bool addConst) { 45 os << method.getName() << '('; 46 if (addThisArg) { 47 if (addConst) 48 os << "const "; 49 os << "const Concept *impl, "; 50 emitCPPType(valueType, os) 51 << "tablegen_opaque_val" << (method.arg_empty() ? "" : ", "); 52 } 53 llvm::interleaveComma(method.getArguments(), os, 54 [&](const InterfaceMethod::Argument &arg) { 55 os << arg.type << " " << arg.name; 56 }); 57 os << ')'; 58 if (addConst) 59 os << " const"; 60 } 61 62 /// Get an array of all OpInterface definitions but exclude those subclassing 63 /// "DeclareOpInterfaceMethods". 64 static std::vector<llvm::Record *> 65 getAllInterfaceDefinitions(const llvm::RecordKeeper &recordKeeper, 66 StringRef name) { 67 std::vector<llvm::Record *> defs = 68 recordKeeper.getAllDerivedDefinitions((name + "Interface").str()); 69 70 std::string declareName = ("Declare" + name + "InterfaceMethods").str(); 71 llvm::erase_if(defs, [&](const llvm::Record *def) { 72 // Ignore any "declare methods" interfaces. 73 if (def->isSubClassOf(declareName)) 74 return true; 75 // Ignore interfaces defined outside of the top-level file. 76 return llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) != 77 llvm::SrcMgr.getMainFileID(); 78 }); 79 return defs; 80 } 81 82 namespace { 83 /// This struct is the base generator used when processing tablegen interfaces. 84 class InterfaceGenerator { 85 public: 86 bool emitInterfaceDefs(); 87 bool emitInterfaceDecls(); 88 bool emitInterfaceDocs(); 89 90 protected: 91 InterfaceGenerator(std::vector<llvm::Record *> &&defs, raw_ostream &os) 92 : defs(std::move(defs)), os(os) {} 93 94 void emitConceptDecl(const Interface &interface); 95 void emitModelDecl(const Interface &interface); 96 void emitModelMethodsDef(const Interface &interface); 97 void emitTraitDecl(const Interface &interface, StringRef interfaceName, 98 StringRef interfaceTraitsName); 99 void emitInterfaceDecl(const Interface &interface); 100 101 /// The set of interface records to emit. 102 std::vector<llvm::Record *> defs; 103 // The stream to emit to. 104 raw_ostream &os; 105 /// The C++ value type of the interface, e.g. Operation*. 106 StringRef valueType; 107 /// The C++ base interface type. 108 StringRef interfaceBaseType; 109 /// The name of the typename for the value template. 110 StringRef valueTemplate; 111 /// The name of the substituion variable for the value. 112 StringRef substVar; 113 /// The format context to use for methods. 114 tblgen::FmtContext nonStaticMethodFmt; 115 tblgen::FmtContext traitMethodFmt; 116 tblgen::FmtContext extraDeclsFmt; 117 }; 118 119 /// A specialized generator for attribute interfaces. 120 struct AttrInterfaceGenerator : public InterfaceGenerator { 121 AttrInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os) 122 : InterfaceGenerator(getAllInterfaceDefinitions(records, "Attr"), os) { 123 valueType = "::mlir::Attribute"; 124 interfaceBaseType = "AttributeInterface"; 125 valueTemplate = "ConcreteAttr"; 126 substVar = "_attr"; 127 StringRef castCode = "(::llvm::cast<ConcreteAttr>(tablegen_opaque_val))"; 128 nonStaticMethodFmt.addSubst(substVar, castCode).withSelf(castCode); 129 traitMethodFmt.addSubst(substVar, 130 "(*static_cast<const ConcreteAttr *>(this))"); 131 extraDeclsFmt.addSubst(substVar, "(*this)"); 132 } 133 }; 134 /// A specialized generator for operation interfaces. 135 struct OpInterfaceGenerator : public InterfaceGenerator { 136 OpInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os) 137 : InterfaceGenerator(getAllInterfaceDefinitions(records, "Op"), os) { 138 valueType = "::mlir::Operation *"; 139 interfaceBaseType = "OpInterface"; 140 valueTemplate = "ConcreteOp"; 141 substVar = "_op"; 142 StringRef castCode = "(llvm::cast<ConcreteOp>(tablegen_opaque_val))"; 143 nonStaticMethodFmt.addSubst("_this", "impl") 144 .addSubst(substVar, castCode) 145 .withSelf(castCode); 146 traitMethodFmt.addSubst(substVar, "(*static_cast<ConcreteOp *>(this))"); 147 extraDeclsFmt.addSubst(substVar, "(*this)"); 148 } 149 }; 150 /// A specialized generator for type interfaces. 151 struct TypeInterfaceGenerator : public InterfaceGenerator { 152 TypeInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os) 153 : InterfaceGenerator(getAllInterfaceDefinitions(records, "Type"), os) { 154 valueType = "::mlir::Type"; 155 interfaceBaseType = "TypeInterface"; 156 valueTemplate = "ConcreteType"; 157 substVar = "_type"; 158 StringRef castCode = "(::llvm::cast<ConcreteType>(tablegen_opaque_val))"; 159 nonStaticMethodFmt.addSubst(substVar, castCode).withSelf(castCode); 160 traitMethodFmt.addSubst(substVar, 161 "(*static_cast<const ConcreteType *>(this))"); 162 extraDeclsFmt.addSubst(substVar, "(*this)"); 163 } 164 }; 165 } // namespace 166 167 //===----------------------------------------------------------------------===// 168 // GEN: Interface definitions 169 //===----------------------------------------------------------------------===// 170 171 static void emitInterfaceMethodDoc(const InterfaceMethod &method, 172 raw_ostream &os, StringRef prefix = "") { 173 if (std::optional<StringRef> description = method.getDescription()) 174 tblgen::emitDescriptionComment(*description, os, prefix); 175 } 176 static void emitInterfaceDefMethods(StringRef interfaceQualName, 177 const Interface &interface, 178 StringRef valueType, const Twine &implValue, 179 raw_ostream &os, bool isOpInterface) { 180 for (auto &method : interface.getMethods()) { 181 emitInterfaceMethodDoc(method, os); 182 emitCPPType(method.getReturnType(), os); 183 os << interfaceQualName << "::"; 184 emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false, 185 /*addConst=*/!isOpInterface); 186 187 // Forward to the method on the concrete operation type. 188 os << " {\n return " << implValue << "->" << method.getName() << '('; 189 if (!method.isStatic()) { 190 os << implValue << ", "; 191 os << (isOpInterface ? "getOperation()" : "*this"); 192 os << (method.arg_empty() ? "" : ", "); 193 } 194 llvm::interleaveComma( 195 method.getArguments(), os, 196 [&](const InterfaceMethod::Argument &arg) { os << arg.name; }); 197 os << ");\n }\n"; 198 } 199 } 200 201 static void emitInterfaceDef(const Interface &interface, StringRef valueType, 202 raw_ostream &os) { 203 std::string interfaceQualNameStr = interface.getFullyQualifiedName(); 204 StringRef interfaceQualName = interfaceQualNameStr; 205 interfaceQualName.consume_front("::"); 206 207 // Insert the method definitions. 208 bool isOpInterface = isa<OpInterface>(interface); 209 emitInterfaceDefMethods(interfaceQualName, interface, valueType, "getImpl()", 210 os, isOpInterface); 211 212 // Insert the method definitions for base classes. 213 for (auto &base : interface.getBaseInterfaces()) { 214 emitInterfaceDefMethods(interfaceQualName, base, valueType, 215 "getImpl()->impl" + base.getName(), os, 216 isOpInterface); 217 } 218 } 219 220 bool InterfaceGenerator::emitInterfaceDefs() { 221 llvm::emitSourceFileHeader("Interface Definitions", os); 222 223 for (const auto *def : defs) 224 emitInterfaceDef(Interface(def), valueType, os); 225 return false; 226 } 227 228 //===----------------------------------------------------------------------===// 229 // GEN: Interface declarations 230 //===----------------------------------------------------------------------===// 231 232 void InterfaceGenerator::emitConceptDecl(const Interface &interface) { 233 os << " struct Concept {\n"; 234 235 // Insert each of the pure virtual concept methods. 236 os << " /// The methods defined by the interface.\n"; 237 for (auto &method : interface.getMethods()) { 238 os << " "; 239 emitCPPType(method.getReturnType(), os); 240 os << "(*" << method.getName() << ")("; 241 if (!method.isStatic()) { 242 os << "const Concept *impl, "; 243 emitCPPType(valueType, os) << (method.arg_empty() ? "" : ", "); 244 } 245 llvm::interleaveComma( 246 method.getArguments(), os, 247 [&](const InterfaceMethod::Argument &arg) { os << arg.type; }); 248 os << ");\n"; 249 } 250 251 // Insert a field containing a concept for each of the base interfaces. 252 auto baseInterfaces = interface.getBaseInterfaces(); 253 if (!baseInterfaces.empty()) { 254 os << " /// The base classes of this interface.\n"; 255 for (const auto &base : interface.getBaseInterfaces()) { 256 os << " const " << base.getFullyQualifiedName() << "::Concept *impl" 257 << base.getName() << " = nullptr;\n"; 258 } 259 260 // Define an "initialize" method that allows for the initialization of the 261 // base class concepts. 262 os << "\n void initializeInterfaceConcept(::mlir::detail::InterfaceMap " 263 "&interfaceMap) {\n"; 264 std::string interfaceQualName = interface.getFullyQualifiedName(); 265 for (const auto &base : interface.getBaseInterfaces()) { 266 StringRef baseName = base.getName(); 267 std::string baseQualName = base.getFullyQualifiedName(); 268 os << " impl" << baseName << " = interfaceMap.lookup<" 269 << baseQualName << ">();\n" 270 << " assert(impl" << baseName << " && \"`" << interfaceQualName 271 << "` expected its base interface `" << baseQualName 272 << "` to be registered\");\n"; 273 } 274 os << " }\n"; 275 } 276 277 os << " };\n"; 278 } 279 280 void InterfaceGenerator::emitModelDecl(const Interface &interface) { 281 // Emit the basic model and the fallback model. 282 for (const char *modelClass : {"Model", "FallbackModel"}) { 283 os << " template<typename " << valueTemplate << ">\n"; 284 os << " class " << modelClass << " : public Concept {\n public:\n"; 285 os << " using Interface = " << interface.getFullyQualifiedName() 286 << ";\n"; 287 os << " " << modelClass << "() : Concept{"; 288 llvm::interleaveComma( 289 interface.getMethods(), os, 290 [&](const InterfaceMethod &method) { os << method.getName(); }); 291 os << "} {}\n\n"; 292 293 // Insert each of the virtual method overrides. 294 for (auto &method : interface.getMethods()) { 295 emitCPPType(method.getReturnType(), os << " static inline "); 296 emitMethodNameAndArgs(method, os, valueType, 297 /*addThisArg=*/!method.isStatic(), 298 /*addConst=*/false); 299 os << ";\n"; 300 } 301 os << " };\n"; 302 } 303 304 // Emit the template for the external model. 305 os << " template<typename ConcreteModel, typename " << valueTemplate 306 << ">\n"; 307 os << " class ExternalModel : public FallbackModel<ConcreteModel> {\n"; 308 os << " public:\n"; 309 os << " using ConcreteEntity = " << valueTemplate << ";\n"; 310 311 // Emit declarations for methods that have default implementations. Other 312 // methods are expected to be implemented by the concrete derived model. 313 for (auto &method : interface.getMethods()) { 314 if (!method.getDefaultImplementation()) 315 continue; 316 os << " "; 317 if (method.isStatic()) 318 os << "static "; 319 emitCPPType(method.getReturnType(), os); 320 os << method.getName() << "("; 321 if (!method.isStatic()) { 322 emitCPPType(valueType, os); 323 os << "tablegen_opaque_val"; 324 if (!method.arg_empty()) 325 os << ", "; 326 } 327 llvm::interleaveComma(method.getArguments(), os, 328 [&](const InterfaceMethod::Argument &arg) { 329 emitCPPType(arg.type, os); 330 os << arg.name; 331 }); 332 os << ")"; 333 if (!method.isStatic()) 334 os << " const"; 335 os << ";\n"; 336 } 337 os << " };\n"; 338 } 339 340 void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) { 341 llvm::SmallVector<StringRef, 2> namespaces; 342 llvm::SplitString(interface.getCppNamespace(), namespaces, "::"); 343 for (StringRef ns : namespaces) 344 os << "namespace " << ns << " {\n"; 345 346 for (auto &method : interface.getMethods()) { 347 os << "template<typename " << valueTemplate << ">\n"; 348 emitCPPType(method.getReturnType(), os); 349 os << "detail::" << interface.getName() << "InterfaceTraits::Model<" 350 << valueTemplate << ">::"; 351 emitMethodNameAndArgs(method, os, valueType, 352 /*addThisArg=*/!method.isStatic(), 353 /*addConst=*/false); 354 os << " {\n "; 355 356 // Check for a provided body to the function. 357 if (std::optional<StringRef> body = method.getBody()) { 358 if (method.isStatic()) 359 os << body->trim(); 360 else 361 os << tblgen::tgfmt(body->trim(), &nonStaticMethodFmt); 362 os << "\n}\n"; 363 continue; 364 } 365 366 // Forward to the method on the concrete operation type. 367 if (method.isStatic()) 368 os << "return " << valueTemplate << "::"; 369 else 370 os << tblgen::tgfmt("return $_self.", &nonStaticMethodFmt); 371 372 // Add the arguments to the call. 373 os << method.getName() << '('; 374 llvm::interleaveComma( 375 method.getArguments(), os, 376 [&](const InterfaceMethod::Argument &arg) { os << arg.name; }); 377 os << ");\n}\n"; 378 } 379 380 for (auto &method : interface.getMethods()) { 381 os << "template<typename " << valueTemplate << ">\n"; 382 emitCPPType(method.getReturnType(), os); 383 os << "detail::" << interface.getName() << "InterfaceTraits::FallbackModel<" 384 << valueTemplate << ">::"; 385 emitMethodNameAndArgs(method, os, valueType, 386 /*addThisArg=*/!method.isStatic(), 387 /*addConst=*/false); 388 os << " {\n "; 389 390 // Forward to the method on the concrete Model implementation. 391 if (method.isStatic()) 392 os << "return " << valueTemplate << "::"; 393 else 394 os << "return static_cast<const " << valueTemplate << " *>(impl)->"; 395 396 // Add the arguments to the call. 397 os << method.getName() << '('; 398 if (!method.isStatic()) 399 os << "tablegen_opaque_val" << (method.arg_empty() ? "" : ", "); 400 llvm::interleaveComma( 401 method.getArguments(), os, 402 [&](const InterfaceMethod::Argument &arg) { os << arg.name; }); 403 os << ");\n}\n"; 404 } 405 406 // Emit default implementations for the external model. 407 for (auto &method : interface.getMethods()) { 408 if (!method.getDefaultImplementation()) 409 continue; 410 os << "template<typename ConcreteModel, typename " << valueTemplate 411 << ">\n"; 412 emitCPPType(method.getReturnType(), os); 413 os << "detail::" << interface.getName() 414 << "InterfaceTraits::ExternalModel<ConcreteModel, " << valueTemplate 415 << ">::"; 416 417 os << method.getName() << "("; 418 if (!method.isStatic()) { 419 emitCPPType(valueType, os); 420 os << "tablegen_opaque_val"; 421 if (!method.arg_empty()) 422 os << ", "; 423 } 424 llvm::interleaveComma(method.getArguments(), os, 425 [&](const InterfaceMethod::Argument &arg) { 426 emitCPPType(arg.type, os); 427 os << arg.name; 428 }); 429 os << ")"; 430 if (!method.isStatic()) 431 os << " const"; 432 433 os << " {\n"; 434 435 // Use the empty context for static methods. 436 tblgen::FmtContext ctx; 437 os << tblgen::tgfmt(method.getDefaultImplementation()->trim(), 438 method.isStatic() ? &ctx : &nonStaticMethodFmt); 439 os << "\n}\n"; 440 } 441 442 for (StringRef ns : llvm::reverse(namespaces)) 443 os << "} // namespace " << ns << "\n"; 444 } 445 446 void InterfaceGenerator::emitTraitDecl(const Interface &interface, 447 StringRef interfaceName, 448 StringRef interfaceTraitsName) { 449 os << llvm::formatv(" template <typename {3}>\n" 450 " struct {0}Trait : public ::mlir::{2}<{0}," 451 " detail::{1}>::Trait<{3}> {{\n", 452 interfaceName, interfaceTraitsName, interfaceBaseType, 453 valueTemplate); 454 455 // Insert the default implementation for any methods. 456 bool isOpInterface = isa<OpInterface>(interface); 457 for (auto &method : interface.getMethods()) { 458 // Flag interface methods named verifyTrait. 459 if (method.getName() == "verifyTrait") 460 PrintFatalError( 461 formatv("'verifyTrait' method cannot be specified as interface " 462 "method for '{0}'; use the 'verify' field instead", 463 interfaceName)); 464 auto defaultImpl = method.getDefaultImplementation(); 465 if (!defaultImpl) 466 continue; 467 468 emitInterfaceMethodDoc(method, os, " "); 469 os << " " << (method.isStatic() ? "static " : ""); 470 emitCPPType(method.getReturnType(), os); 471 emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false, 472 /*addConst=*/!isOpInterface && !method.isStatic()); 473 os << " {\n " << tblgen::tgfmt(defaultImpl->trim(), &traitMethodFmt) 474 << "\n }\n"; 475 } 476 477 if (auto verify = interface.getVerify()) { 478 assert(isa<OpInterface>(interface) && "only OpInterface supports 'verify'"); 479 480 tblgen::FmtContext verifyCtx; 481 verifyCtx.addSubst("_op", "op"); 482 os << llvm::formatv( 483 " static ::mlir::LogicalResult {0}(::mlir::Operation *op) ", 484 (interface.verifyWithRegions() ? "verifyRegionTrait" 485 : "verifyTrait")) 486 << "{\n " << tblgen::tgfmt(verify->trim(), &verifyCtx) 487 << "\n }\n"; 488 } 489 if (auto extraTraitDecls = interface.getExtraTraitClassDeclaration()) 490 os << tblgen::tgfmt(*extraTraitDecls, &traitMethodFmt) << "\n"; 491 if (auto extraTraitDecls = interface.getExtraSharedClassDeclaration()) 492 os << tblgen::tgfmt(*extraTraitDecls, &traitMethodFmt) << "\n"; 493 494 os << " };\n"; 495 } 496 497 static void emitInterfaceDeclMethods(const Interface &interface, 498 raw_ostream &os, StringRef valueType, 499 bool isOpInterface, 500 tblgen::FmtContext &extraDeclsFmt) { 501 for (auto &method : interface.getMethods()) { 502 emitInterfaceMethodDoc(method, os, " "); 503 emitCPPType(method.getReturnType(), os << " "); 504 emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false, 505 /*addConst=*/!isOpInterface); 506 os << ";\n"; 507 } 508 509 // Emit any extra declarations. 510 if (std::optional<StringRef> extraDecls = 511 interface.getExtraClassDeclaration()) 512 os << extraDecls->rtrim() << "\n"; 513 if (std::optional<StringRef> extraDecls = 514 interface.getExtraSharedClassDeclaration()) 515 os << tblgen::tgfmt(extraDecls->rtrim(), &extraDeclsFmt) << "\n"; 516 } 517 518 void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) { 519 llvm::SmallVector<StringRef, 2> namespaces; 520 llvm::SplitString(interface.getCppNamespace(), namespaces, "::"); 521 for (StringRef ns : namespaces) 522 os << "namespace " << ns << " {\n"; 523 524 StringRef interfaceName = interface.getName(); 525 auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str(); 526 527 // Emit a forward declaration of the interface class so that it becomes usable 528 // in the signature of its methods. 529 os << "class " << interfaceName << ";\n"; 530 531 // Emit the traits struct containing the concept and model declarations. 532 os << "namespace detail {\n" 533 << "struct " << interfaceTraitsName << " {\n"; 534 emitConceptDecl(interface); 535 emitModelDecl(interface); 536 os << "};"; 537 538 // Emit the derived trait for the interface. 539 os << "template <typename " << valueTemplate << ">\n"; 540 os << "struct " << interface.getName() << "Trait;\n"; 541 542 os << "\n} // namespace detail\n"; 543 544 // Emit the main interface class declaration. 545 os << llvm::formatv("class {0} : public ::mlir::{3}<{1}, detail::{2}> {\n" 546 "public:\n" 547 " using ::mlir::{3}<{1}, detail::{2}>::{3};\n", 548 interfaceName, interfaceName, interfaceTraitsName, 549 interfaceBaseType); 550 551 // Emit a utility wrapper trait class. 552 os << llvm::formatv(" template <typename {1}>\n" 553 " struct Trait : public detail::{0}Trait<{1}> {{};\n", 554 interfaceName, valueTemplate); 555 556 // Insert the method declarations. 557 bool isOpInterface = isa<OpInterface>(interface); 558 emitInterfaceDeclMethods(interface, os, valueType, isOpInterface, 559 extraDeclsFmt); 560 561 // Insert the method declarations for base classes. 562 for (auto &base : interface.getBaseInterfaces()) { 563 std::string baseQualName = base.getFullyQualifiedName(); 564 os << " //" 565 "===---------------------------------------------------------------" 566 "-===//\n" 567 << " // Inherited from " << baseQualName << "\n" 568 << " //" 569 "===---------------------------------------------------------------" 570 "-===//\n\n"; 571 572 // Allow implicit conversion to the base interface. 573 os << " operator " << baseQualName << " () const {\n" 574 << " return " << baseQualName << "(*this, getImpl()->impl" 575 << base.getName() << ");\n" 576 << " }\n\n"; 577 578 // Inherit the base interface's methods. 579 emitInterfaceDeclMethods(base, os, valueType, isOpInterface, extraDeclsFmt); 580 } 581 582 // Emit classof code if necessary. 583 if (std::optional<StringRef> extraClassOf = interface.getExtraClassOf()) { 584 auto extraClassOfFmt = tblgen::FmtContext(); 585 extraClassOfFmt.addSubst(substVar, "odsInterfaceInstance"); 586 os << " static bool classof(" << valueType << " base) {\n" 587 << " auto* interface = getInterfaceFor(base);\n" 588 << " if (!interface)\n" 589 " return false;\n" 590 " " << interfaceName << " odsInterfaceInstance(base, interface);\n" 591 << " " << tblgen::tgfmt(extraClassOf->trim(), &extraClassOfFmt) 592 << "\n }\n"; 593 } 594 595 os << "};\n"; 596 597 os << "namespace detail {\n"; 598 emitTraitDecl(interface, interfaceName, interfaceTraitsName); 599 os << "}// namespace detail\n"; 600 601 for (StringRef ns : llvm::reverse(namespaces)) 602 os << "} // namespace " << ns << "\n"; 603 } 604 605 bool InterfaceGenerator::emitInterfaceDecls() { 606 llvm::emitSourceFileHeader("Interface Declarations", os); 607 // Sort according to ID, so defs are emitted in the order in which they appear 608 // in the Tablegen file. 609 std::vector<llvm::Record *> sortedDefs(defs); 610 llvm::sort(sortedDefs, [](llvm::Record *lhs, llvm::Record *rhs) { 611 return lhs->getID() < rhs->getID(); 612 }); 613 for (const llvm::Record *def : sortedDefs) 614 emitInterfaceDecl(Interface(def)); 615 for (const llvm::Record *def : sortedDefs) 616 emitModelMethodsDef(Interface(def)); 617 return false; 618 } 619 620 //===----------------------------------------------------------------------===// 621 // GEN: Interface documentation 622 //===----------------------------------------------------------------------===// 623 624 static void emitInterfaceDoc(const llvm::Record &interfaceDef, 625 raw_ostream &os) { 626 Interface interface(&interfaceDef); 627 628 // Emit the interface name followed by the description. 629 os << "## " << interface.getName() << " (`" << interfaceDef.getName() 630 << "`)\n\n"; 631 if (auto description = interface.getDescription()) 632 mlir::tblgen::emitDescription(*description, os); 633 634 // Emit the methods required by the interface. 635 os << "\n### Methods:\n"; 636 for (const auto &method : interface.getMethods()) { 637 // Emit the method name. 638 os << "#### `" << method.getName() << "`\n\n```c++\n"; 639 640 // Emit the method signature. 641 if (method.isStatic()) 642 os << "static "; 643 emitCPPType(method.getReturnType(), os) << method.getName() << '('; 644 llvm::interleaveComma(method.getArguments(), os, 645 [&](const InterfaceMethod::Argument &arg) { 646 emitCPPType(arg.type, os) << arg.name; 647 }); 648 os << ");\n```\n"; 649 650 // Emit the description. 651 if (auto description = method.getDescription()) 652 mlir::tblgen::emitDescription(*description, os); 653 654 // If the body is not provided, this method must be provided by the user. 655 if (!method.getBody()) 656 os << "\nNOTE: This method *must* be implemented by the user."; 657 658 os << "\n\n"; 659 } 660 } 661 662 bool InterfaceGenerator::emitInterfaceDocs() { 663 os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n"; 664 os << "# " << interfaceBaseType << " definitions\n"; 665 666 for (const auto *def : defs) 667 emitInterfaceDoc(*def, os); 668 return false; 669 } 670 671 //===----------------------------------------------------------------------===// 672 // GEN: Interface registration hooks 673 //===----------------------------------------------------------------------===// 674 675 namespace { 676 template <typename GeneratorT> 677 struct InterfaceGenRegistration { 678 InterfaceGenRegistration(StringRef genArg, StringRef genDesc) 679 : genDeclArg(("gen-" + genArg + "-interface-decls").str()), 680 genDefArg(("gen-" + genArg + "-interface-defs").str()), 681 genDocArg(("gen-" + genArg + "-interface-docs").str()), 682 genDeclDesc(("Generate " + genDesc + " interface declarations").str()), 683 genDefDesc(("Generate " + genDesc + " interface definitions").str()), 684 genDocDesc(("Generate " + genDesc + " interface documentation").str()), 685 genDecls(genDeclArg, genDeclDesc, 686 [](const llvm::RecordKeeper &records, raw_ostream &os) { 687 return GeneratorT(records, os).emitInterfaceDecls(); 688 }), 689 genDefs(genDefArg, genDefDesc, 690 [](const llvm::RecordKeeper &records, raw_ostream &os) { 691 return GeneratorT(records, os).emitInterfaceDefs(); 692 }), 693 genDocs(genDocArg, genDocDesc, 694 [](const llvm::RecordKeeper &records, raw_ostream &os) { 695 return GeneratorT(records, os).emitInterfaceDocs(); 696 }) {} 697 698 std::string genDeclArg, genDefArg, genDocArg; 699 std::string genDeclDesc, genDefDesc, genDocDesc; 700 mlir::GenRegistration genDecls, genDefs, genDocs; 701 }; 702 } // namespace 703 704 static InterfaceGenRegistration<AttrInterfaceGenerator> attrGen("attr", 705 "attribute"); 706 static InterfaceGenRegistration<OpInterfaceGenerator> opGen("op", "op"); 707 static InterfaceGenRegistration<TypeInterfaceGenerator> typeGen("type", "type"); 708