1 //===- OpInterfacesGen.cpp - MLIR op interface utility generator ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // OpInterfacesGen generates definitions for operation interfaces. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "DocGenUtilities.h" 14 #include "mlir/TableGen/Format.h" 15 #include "mlir/TableGen/GenInfo.h" 16 #include "mlir/TableGen/Interfaces.h" 17 #include "llvm/ADT/SmallVector.h" 18 #include "llvm/ADT/StringExtras.h" 19 #include "llvm/Support/FormatVariadic.h" 20 #include "llvm/Support/raw_ostream.h" 21 #include "llvm/TableGen/Error.h" 22 #include "llvm/TableGen/Record.h" 23 #include "llvm/TableGen/TableGenBackend.h" 24 25 using namespace mlir; 26 using llvm::Record; 27 using llvm::RecordKeeper; 28 using mlir::tblgen::Interface; 29 using mlir::tblgen::InterfaceMethod; 30 using mlir::tblgen::OpInterface; 31 32 /// Emit a string corresponding to a C++ type, followed by a space if necessary. 33 static raw_ostream &emitCPPType(StringRef type, raw_ostream &os) { 34 type = type.trim(); 35 os << type; 36 if (type.back() != '&' && type.back() != '*') 37 os << " "; 38 return os; 39 } 40 41 /// Emit the method name and argument list for the given method. If 'addThisArg' 42 /// is true, then an argument is added to the beginning of the argument list for 43 /// the concrete value. 44 static void emitMethodNameAndArgs(const InterfaceMethod &method, 45 raw_ostream &os, StringRef valueType, 46 bool addThisArg, bool addConst) { 47 os << method.getName() << '('; 48 if (addThisArg) { 49 if (addConst) 50 os << "const "; 51 os << "const Concept *impl, "; 52 emitCPPType(valueType, os) 53 << "tablegen_opaque_val" << (method.arg_empty() ? "" : ", "); 54 } 55 llvm::interleaveComma(method.getArguments(), os, 56 [&](const InterfaceMethod::Argument &arg) { 57 os << arg.type << " " << arg.name; 58 }); 59 os << ')'; 60 if (addConst) 61 os << " const"; 62 } 63 64 /// Get an array of all OpInterface definitions but exclude those subclassing 65 /// "DeclareOpInterfaceMethods". 66 static std::vector<const Record *> 67 getAllInterfaceDefinitions(const RecordKeeper &records, StringRef name) { 68 std::vector<const Record *> defs = 69 records.getAllDerivedDefinitions((name + "Interface").str()); 70 71 std::string declareName = ("Declare" + name + "InterfaceMethods").str(); 72 llvm::erase_if(defs, [&](const Record *def) { 73 // Ignore any "declare methods" interfaces. 74 if (def->isSubClassOf(declareName)) 75 return true; 76 // Ignore interfaces defined outside of the top-level file. 77 return llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) != 78 llvm::SrcMgr.getMainFileID(); 79 }); 80 return defs; 81 } 82 83 namespace { 84 /// This struct is the base generator used when processing tablegen interfaces. 85 class InterfaceGenerator { 86 public: 87 bool emitInterfaceDefs(); 88 bool emitInterfaceDecls(); 89 bool emitInterfaceDocs(); 90 91 protected: 92 InterfaceGenerator(std::vector<const Record *> &&defs, raw_ostream &os) 93 : defs(std::move(defs)), os(os) {} 94 95 void emitConceptDecl(const Interface &interface); 96 void emitModelDecl(const Interface &interface); 97 void emitModelMethodsDef(const Interface &interface); 98 void emitTraitDecl(const Interface &interface, StringRef interfaceName, 99 StringRef interfaceTraitsName); 100 void emitInterfaceDecl(const Interface &interface); 101 102 /// The set of interface records to emit. 103 std::vector<const Record *> defs; 104 // The stream to emit to. 105 raw_ostream &os; 106 /// The C++ value type of the interface, e.g. Operation*. 107 StringRef valueType; 108 /// The C++ base interface type. 109 StringRef interfaceBaseType; 110 /// The name of the typename for the value template. 111 StringRef valueTemplate; 112 /// The name of the substituion variable for the value. 113 StringRef substVar; 114 /// The format context to use for methods. 115 tblgen::FmtContext nonStaticMethodFmt; 116 tblgen::FmtContext traitMethodFmt; 117 tblgen::FmtContext extraDeclsFmt; 118 }; 119 120 /// A specialized generator for attribute interfaces. 121 struct AttrInterfaceGenerator : public InterfaceGenerator { 122 AttrInterfaceGenerator(const RecordKeeper &records, raw_ostream &os) 123 : InterfaceGenerator(getAllInterfaceDefinitions(records, "Attr"), os) { 124 valueType = "::mlir::Attribute"; 125 interfaceBaseType = "AttributeInterface"; 126 valueTemplate = "ConcreteAttr"; 127 substVar = "_attr"; 128 StringRef castCode = "(::llvm::cast<ConcreteAttr>(tablegen_opaque_val))"; 129 nonStaticMethodFmt.addSubst(substVar, castCode).withSelf(castCode); 130 traitMethodFmt.addSubst(substVar, 131 "(*static_cast<const ConcreteAttr *>(this))"); 132 extraDeclsFmt.addSubst(substVar, "(*this)"); 133 } 134 }; 135 /// A specialized generator for operation interfaces. 136 struct OpInterfaceGenerator : public InterfaceGenerator { 137 OpInterfaceGenerator(const RecordKeeper &records, raw_ostream &os) 138 : InterfaceGenerator(getAllInterfaceDefinitions(records, "Op"), os) { 139 valueType = "::mlir::Operation *"; 140 interfaceBaseType = "OpInterface"; 141 valueTemplate = "ConcreteOp"; 142 substVar = "_op"; 143 StringRef castCode = "(llvm::cast<ConcreteOp>(tablegen_opaque_val))"; 144 nonStaticMethodFmt.addSubst("_this", "impl") 145 .addSubst(substVar, castCode) 146 .withSelf(castCode); 147 traitMethodFmt.addSubst(substVar, "(*static_cast<ConcreteOp *>(this))"); 148 extraDeclsFmt.addSubst(substVar, "(*this)"); 149 } 150 }; 151 /// A specialized generator for type interfaces. 152 struct TypeInterfaceGenerator : public InterfaceGenerator { 153 TypeInterfaceGenerator(const RecordKeeper &records, raw_ostream &os) 154 : InterfaceGenerator(getAllInterfaceDefinitions(records, "Type"), os) { 155 valueType = "::mlir::Type"; 156 interfaceBaseType = "TypeInterface"; 157 valueTemplate = "ConcreteType"; 158 substVar = "_type"; 159 StringRef castCode = "(::llvm::cast<ConcreteType>(tablegen_opaque_val))"; 160 nonStaticMethodFmt.addSubst(substVar, castCode).withSelf(castCode); 161 traitMethodFmt.addSubst(substVar, 162 "(*static_cast<const ConcreteType *>(this))"); 163 extraDeclsFmt.addSubst(substVar, "(*this)"); 164 } 165 }; 166 } // namespace 167 168 //===----------------------------------------------------------------------===// 169 // GEN: Interface definitions 170 //===----------------------------------------------------------------------===// 171 172 static void emitInterfaceMethodDoc(const InterfaceMethod &method, 173 raw_ostream &os, StringRef prefix = "") { 174 if (std::optional<StringRef> description = method.getDescription()) 175 tblgen::emitDescriptionComment(*description, os, prefix); 176 } 177 static void emitInterfaceDefMethods(StringRef interfaceQualName, 178 const Interface &interface, 179 StringRef valueType, const Twine &implValue, 180 raw_ostream &os, bool isOpInterface) { 181 for (auto &method : interface.getMethods()) { 182 emitInterfaceMethodDoc(method, os); 183 emitCPPType(method.getReturnType(), os); 184 os << interfaceQualName << "::"; 185 emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false, 186 /*addConst=*/!isOpInterface); 187 188 // Forward to the method on the concrete operation type. 189 os << " {\n return " << implValue << "->" << method.getName() << '('; 190 if (!method.isStatic()) { 191 os << implValue << ", "; 192 os << (isOpInterface ? "getOperation()" : "*this"); 193 os << (method.arg_empty() ? "" : ", "); 194 } 195 llvm::interleaveComma( 196 method.getArguments(), os, 197 [&](const InterfaceMethod::Argument &arg) { os << arg.name; }); 198 os << ");\n }\n"; 199 } 200 } 201 202 static void emitInterfaceDef(const Interface &interface, StringRef valueType, 203 raw_ostream &os) { 204 std::string interfaceQualNameStr = interface.getFullyQualifiedName(); 205 StringRef interfaceQualName = interfaceQualNameStr; 206 interfaceQualName.consume_front("::"); 207 208 // Insert the method definitions. 209 bool isOpInterface = isa<OpInterface>(interface); 210 emitInterfaceDefMethods(interfaceQualName, interface, valueType, "getImpl()", 211 os, isOpInterface); 212 213 // Insert the method definitions for base classes. 214 for (auto &base : interface.getBaseInterfaces()) { 215 emitInterfaceDefMethods(interfaceQualName, base, valueType, 216 "getImpl()->impl" + base.getName(), os, 217 isOpInterface); 218 } 219 } 220 221 bool InterfaceGenerator::emitInterfaceDefs() { 222 llvm::emitSourceFileHeader("Interface Definitions", os); 223 224 for (const auto *def : defs) 225 emitInterfaceDef(Interface(def), valueType, os); 226 return false; 227 } 228 229 //===----------------------------------------------------------------------===// 230 // GEN: Interface declarations 231 //===----------------------------------------------------------------------===// 232 233 void InterfaceGenerator::emitConceptDecl(const Interface &interface) { 234 os << " struct Concept {\n"; 235 236 // Insert each of the pure virtual concept methods. 237 os << " /// The methods defined by the interface.\n"; 238 for (auto &method : interface.getMethods()) { 239 os << " "; 240 emitCPPType(method.getReturnType(), os); 241 os << "(*" << method.getName() << ")("; 242 if (!method.isStatic()) { 243 os << "const Concept *impl, "; 244 emitCPPType(valueType, os) << (method.arg_empty() ? "" : ", "); 245 } 246 llvm::interleaveComma( 247 method.getArguments(), os, 248 [&](const InterfaceMethod::Argument &arg) { os << arg.type; }); 249 os << ");\n"; 250 } 251 252 // Insert a field containing a concept for each of the base interfaces. 253 auto baseInterfaces = interface.getBaseInterfaces(); 254 if (!baseInterfaces.empty()) { 255 os << " /// The base classes of this interface.\n"; 256 for (const auto &base : interface.getBaseInterfaces()) { 257 os << " const " << base.getFullyQualifiedName() << "::Concept *impl" 258 << base.getName() << " = nullptr;\n"; 259 } 260 261 // Define an "initialize" method that allows for the initialization of the 262 // base class concepts. 263 os << "\n void initializeInterfaceConcept(::mlir::detail::InterfaceMap " 264 "&interfaceMap) {\n"; 265 std::string interfaceQualName = interface.getFullyQualifiedName(); 266 for (const auto &base : interface.getBaseInterfaces()) { 267 StringRef baseName = base.getName(); 268 std::string baseQualName = base.getFullyQualifiedName(); 269 os << " impl" << baseName << " = interfaceMap.lookup<" 270 << baseQualName << ">();\n" 271 << " assert(impl" << baseName << " && \"`" << interfaceQualName 272 << "` expected its base interface `" << baseQualName 273 << "` to be registered\");\n"; 274 } 275 os << " }\n"; 276 } 277 278 os << " };\n"; 279 } 280 281 void InterfaceGenerator::emitModelDecl(const Interface &interface) { 282 // Emit the basic model and the fallback model. 283 for (const char *modelClass : {"Model", "FallbackModel"}) { 284 os << " template<typename " << valueTemplate << ">\n"; 285 os << " class " << modelClass << " : public Concept {\n public:\n"; 286 os << " using Interface = " << interface.getFullyQualifiedName() 287 << ";\n"; 288 os << " " << modelClass << "() : Concept{"; 289 llvm::interleaveComma( 290 interface.getMethods(), os, 291 [&](const InterfaceMethod &method) { os << method.getName(); }); 292 os << "} {}\n\n"; 293 294 // Insert each of the virtual method overrides. 295 for (auto &method : interface.getMethods()) { 296 emitCPPType(method.getReturnType(), os << " static inline "); 297 emitMethodNameAndArgs(method, os, valueType, 298 /*addThisArg=*/!method.isStatic(), 299 /*addConst=*/false); 300 os << ";\n"; 301 } 302 os << " };\n"; 303 } 304 305 // Emit the template for the external model. 306 os << " template<typename ConcreteModel, typename " << valueTemplate 307 << ">\n"; 308 os << " class ExternalModel : public FallbackModel<ConcreteModel> {\n"; 309 os << " public:\n"; 310 os << " using ConcreteEntity = " << valueTemplate << ";\n"; 311 312 // Emit declarations for methods that have default implementations. Other 313 // methods are expected to be implemented by the concrete derived model. 314 for (auto &method : interface.getMethods()) { 315 if (!method.getDefaultImplementation()) 316 continue; 317 os << " "; 318 if (method.isStatic()) 319 os << "static "; 320 emitCPPType(method.getReturnType(), os); 321 os << method.getName() << "("; 322 if (!method.isStatic()) { 323 emitCPPType(valueType, os); 324 os << "tablegen_opaque_val"; 325 if (!method.arg_empty()) 326 os << ", "; 327 } 328 llvm::interleaveComma(method.getArguments(), os, 329 [&](const InterfaceMethod::Argument &arg) { 330 emitCPPType(arg.type, os); 331 os << arg.name; 332 }); 333 os << ")"; 334 if (!method.isStatic()) 335 os << " const"; 336 os << ";\n"; 337 } 338 os << " };\n"; 339 } 340 341 void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) { 342 llvm::SmallVector<StringRef, 2> namespaces; 343 llvm::SplitString(interface.getCppNamespace(), namespaces, "::"); 344 for (StringRef ns : namespaces) 345 os << "namespace " << ns << " {\n"; 346 347 for (auto &method : interface.getMethods()) { 348 os << "template<typename " << valueTemplate << ">\n"; 349 emitCPPType(method.getReturnType(), os); 350 os << "detail::" << interface.getName() << "InterfaceTraits::Model<" 351 << valueTemplate << ">::"; 352 emitMethodNameAndArgs(method, os, valueType, 353 /*addThisArg=*/!method.isStatic(), 354 /*addConst=*/false); 355 os << " {\n "; 356 357 // Check for a provided body to the function. 358 if (std::optional<StringRef> body = method.getBody()) { 359 if (method.isStatic()) 360 os << body->trim(); 361 else 362 os << tblgen::tgfmt(body->trim(), &nonStaticMethodFmt); 363 os << "\n}\n"; 364 continue; 365 } 366 367 // Forward to the method on the concrete operation type. 368 if (method.isStatic()) 369 os << "return " << valueTemplate << "::"; 370 else 371 os << tblgen::tgfmt("return $_self.", &nonStaticMethodFmt); 372 373 // Add the arguments to the call. 374 os << method.getName() << '('; 375 llvm::interleaveComma( 376 method.getArguments(), os, 377 [&](const InterfaceMethod::Argument &arg) { os << arg.name; }); 378 os << ");\n}\n"; 379 } 380 381 for (auto &method : interface.getMethods()) { 382 os << "template<typename " << valueTemplate << ">\n"; 383 emitCPPType(method.getReturnType(), os); 384 os << "detail::" << interface.getName() << "InterfaceTraits::FallbackModel<" 385 << valueTemplate << ">::"; 386 emitMethodNameAndArgs(method, os, valueType, 387 /*addThisArg=*/!method.isStatic(), 388 /*addConst=*/false); 389 os << " {\n "; 390 391 // Forward to the method on the concrete Model implementation. 392 if (method.isStatic()) 393 os << "return " << valueTemplate << "::"; 394 else 395 os << "return static_cast<const " << valueTemplate << " *>(impl)->"; 396 397 // Add the arguments to the call. 398 os << method.getName() << '('; 399 if (!method.isStatic()) 400 os << "tablegen_opaque_val" << (method.arg_empty() ? "" : ", "); 401 llvm::interleaveComma( 402 method.getArguments(), os, 403 [&](const InterfaceMethod::Argument &arg) { os << arg.name; }); 404 os << ");\n}\n"; 405 } 406 407 // Emit default implementations for the external model. 408 for (auto &method : interface.getMethods()) { 409 if (!method.getDefaultImplementation()) 410 continue; 411 os << "template<typename ConcreteModel, typename " << valueTemplate 412 << ">\n"; 413 emitCPPType(method.getReturnType(), os); 414 os << "detail::" << interface.getName() 415 << "InterfaceTraits::ExternalModel<ConcreteModel, " << valueTemplate 416 << ">::"; 417 418 os << method.getName() << "("; 419 if (!method.isStatic()) { 420 emitCPPType(valueType, os); 421 os << "tablegen_opaque_val"; 422 if (!method.arg_empty()) 423 os << ", "; 424 } 425 llvm::interleaveComma(method.getArguments(), os, 426 [&](const InterfaceMethod::Argument &arg) { 427 emitCPPType(arg.type, os); 428 os << arg.name; 429 }); 430 os << ")"; 431 if (!method.isStatic()) 432 os << " const"; 433 434 os << " {\n"; 435 436 // Use the empty context for static methods. 437 tblgen::FmtContext ctx; 438 os << tblgen::tgfmt(method.getDefaultImplementation()->trim(), 439 method.isStatic() ? &ctx : &nonStaticMethodFmt); 440 os << "\n}\n"; 441 } 442 443 for (StringRef ns : llvm::reverse(namespaces)) 444 os << "} // namespace " << ns << "\n"; 445 } 446 447 void InterfaceGenerator::emitTraitDecl(const Interface &interface, 448 StringRef interfaceName, 449 StringRef interfaceTraitsName) { 450 os << llvm::formatv(" template <typename {3}>\n" 451 " struct {0}Trait : public ::mlir::{2}<{0}," 452 " detail::{1}>::Trait<{3}> {{\n", 453 interfaceName, interfaceTraitsName, interfaceBaseType, 454 valueTemplate); 455 456 // Insert the default implementation for any methods. 457 bool isOpInterface = isa<OpInterface>(interface); 458 for (auto &method : interface.getMethods()) { 459 // Flag interface methods named verifyTrait. 460 if (method.getName() == "verifyTrait") 461 PrintFatalError( 462 formatv("'verifyTrait' method cannot be specified as interface " 463 "method for '{0}'; use the 'verify' field instead", 464 interfaceName)); 465 auto defaultImpl = method.getDefaultImplementation(); 466 if (!defaultImpl) 467 continue; 468 469 emitInterfaceMethodDoc(method, os, " "); 470 os << " " << (method.isStatic() ? "static " : ""); 471 emitCPPType(method.getReturnType(), os); 472 emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false, 473 /*addConst=*/!isOpInterface && !method.isStatic()); 474 os << " {\n " << tblgen::tgfmt(defaultImpl->trim(), &traitMethodFmt) 475 << "\n }\n"; 476 } 477 478 if (auto verify = interface.getVerify()) { 479 assert(isa<OpInterface>(interface) && "only OpInterface supports 'verify'"); 480 481 tblgen::FmtContext verifyCtx; 482 verifyCtx.addSubst("_op", "op"); 483 os << llvm::formatv( 484 " static ::llvm::LogicalResult {0}(::mlir::Operation *op) ", 485 (interface.verifyWithRegions() ? "verifyRegionTrait" 486 : "verifyTrait")) 487 << "{\n " << tblgen::tgfmt(verify->trim(), &verifyCtx) 488 << "\n }\n"; 489 } 490 if (auto extraTraitDecls = interface.getExtraTraitClassDeclaration()) 491 os << tblgen::tgfmt(*extraTraitDecls, &traitMethodFmt) << "\n"; 492 if (auto extraTraitDecls = interface.getExtraSharedClassDeclaration()) 493 os << tblgen::tgfmt(*extraTraitDecls, &traitMethodFmt) << "\n"; 494 495 os << " };\n"; 496 } 497 498 static void emitInterfaceDeclMethods(const Interface &interface, 499 raw_ostream &os, StringRef valueType, 500 bool isOpInterface, 501 tblgen::FmtContext &extraDeclsFmt) { 502 for (auto &method : interface.getMethods()) { 503 emitInterfaceMethodDoc(method, os, " "); 504 emitCPPType(method.getReturnType(), os << " "); 505 emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false, 506 /*addConst=*/!isOpInterface); 507 os << ";\n"; 508 } 509 510 // Emit any extra declarations. 511 if (std::optional<StringRef> extraDecls = 512 interface.getExtraClassDeclaration()) 513 os << extraDecls->rtrim() << "\n"; 514 if (std::optional<StringRef> extraDecls = 515 interface.getExtraSharedClassDeclaration()) 516 os << tblgen::tgfmt(extraDecls->rtrim(), &extraDeclsFmt) << "\n"; 517 } 518 519 void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) { 520 llvm::SmallVector<StringRef, 2> namespaces; 521 llvm::SplitString(interface.getCppNamespace(), namespaces, "::"); 522 for (StringRef ns : namespaces) 523 os << "namespace " << ns << " {\n"; 524 525 StringRef interfaceName = interface.getName(); 526 auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str(); 527 528 // Emit a forward declaration of the interface class so that it becomes usable 529 // in the signature of its methods. 530 os << "class " << interfaceName << ";\n"; 531 532 // Emit the traits struct containing the concept and model declarations. 533 os << "namespace detail {\n" 534 << "struct " << interfaceTraitsName << " {\n"; 535 emitConceptDecl(interface); 536 emitModelDecl(interface); 537 os << "};\n"; 538 539 // Emit the derived trait for the interface. 540 os << "template <typename " << valueTemplate << ">\n"; 541 os << "struct " << interface.getName() << "Trait;\n"; 542 543 os << "\n} // namespace detail\n"; 544 545 // Emit the main interface class declaration. 546 os << llvm::formatv("class {0} : public ::mlir::{3}<{1}, detail::{2}> {\n" 547 "public:\n" 548 " using ::mlir::{3}<{1}, detail::{2}>::{3};\n", 549 interfaceName, interfaceName, interfaceTraitsName, 550 interfaceBaseType); 551 552 // Emit a utility wrapper trait class. 553 os << llvm::formatv(" template <typename {1}>\n" 554 " struct Trait : public detail::{0}Trait<{1}> {{};\n", 555 interfaceName, valueTemplate); 556 557 // Insert the method declarations. 558 bool isOpInterface = isa<OpInterface>(interface); 559 emitInterfaceDeclMethods(interface, os, valueType, isOpInterface, 560 extraDeclsFmt); 561 562 // Insert the method declarations for base classes. 563 for (auto &base : interface.getBaseInterfaces()) { 564 std::string baseQualName = base.getFullyQualifiedName(); 565 os << " //" 566 "===---------------------------------------------------------------" 567 "-===//\n" 568 << " // Inherited from " << baseQualName << "\n" 569 << " //" 570 "===---------------------------------------------------------------" 571 "-===//\n\n"; 572 573 // Allow implicit conversion to the base interface. 574 os << " operator " << baseQualName << " () const {\n" 575 << " if (!*this) return nullptr;\n" 576 << " return " << baseQualName << "(*this, getImpl()->impl" 577 << base.getName() << ");\n" 578 << " }\n\n"; 579 580 // Inherit the base interface's methods. 581 emitInterfaceDeclMethods(base, os, valueType, isOpInterface, extraDeclsFmt); 582 } 583 584 // Emit classof code if necessary. 585 if (std::optional<StringRef> extraClassOf = interface.getExtraClassOf()) { 586 auto extraClassOfFmt = tblgen::FmtContext(); 587 extraClassOfFmt.addSubst(substVar, "odsInterfaceInstance"); 588 os << " static bool classof(" << valueType << " base) {\n" 589 << " auto* interface = getInterfaceFor(base);\n" 590 << " if (!interface)\n" 591 " return false;\n" 592 " " << interfaceName << " odsInterfaceInstance(base, interface);\n" 593 << " " << tblgen::tgfmt(extraClassOf->trim(), &extraClassOfFmt) 594 << "\n }\n"; 595 } 596 597 os << "};\n"; 598 599 os << "namespace detail {\n"; 600 emitTraitDecl(interface, interfaceName, interfaceTraitsName); 601 os << "}// namespace detail\n"; 602 603 for (StringRef ns : llvm::reverse(namespaces)) 604 os << "} // namespace " << ns << "\n"; 605 } 606 607 bool InterfaceGenerator::emitInterfaceDecls() { 608 llvm::emitSourceFileHeader("Interface Declarations", os); 609 // Sort according to ID, so defs are emitted in the order in which they appear 610 // in the Tablegen file. 611 std::vector<const Record *> sortedDefs(defs); 612 llvm::sort(sortedDefs, [](const Record *lhs, const Record *rhs) { 613 return lhs->getID() < rhs->getID(); 614 }); 615 for (const Record *def : sortedDefs) 616 emitInterfaceDecl(Interface(def)); 617 for (const Record *def : sortedDefs) 618 emitModelMethodsDef(Interface(def)); 619 return false; 620 } 621 622 //===----------------------------------------------------------------------===// 623 // GEN: Interface documentation 624 //===----------------------------------------------------------------------===// 625 626 static void emitInterfaceDoc(const Record &interfaceDef, raw_ostream &os) { 627 Interface interface(&interfaceDef); 628 629 // Emit the interface name followed by the description. 630 os << "## " << interface.getName() << " (`" << interfaceDef.getName() 631 << "`)\n\n"; 632 if (auto description = interface.getDescription()) 633 mlir::tblgen::emitDescription(*description, os); 634 635 // Emit the methods required by the interface. 636 os << "\n### Methods:\n"; 637 for (const auto &method : interface.getMethods()) { 638 // Emit the method name. 639 os << "#### `" << method.getName() << "`\n\n```c++\n"; 640 641 // Emit the method signature. 642 if (method.isStatic()) 643 os << "static "; 644 emitCPPType(method.getReturnType(), os) << method.getName() << '('; 645 llvm::interleaveComma(method.getArguments(), os, 646 [&](const InterfaceMethod::Argument &arg) { 647 emitCPPType(arg.type, os) << arg.name; 648 }); 649 os << ");\n```\n"; 650 651 // Emit the description. 652 if (auto description = method.getDescription()) 653 mlir::tblgen::emitDescription(*description, os); 654 655 // If the body is not provided, this method must be provided by the user. 656 if (!method.getBody()) 657 os << "\nNOTE: This method *must* be implemented by the user."; 658 659 os << "\n\n"; 660 } 661 } 662 663 bool InterfaceGenerator::emitInterfaceDocs() { 664 os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n"; 665 os << "# " << interfaceBaseType << " definitions\n"; 666 667 for (const auto *def : defs) 668 emitInterfaceDoc(*def, os); 669 return false; 670 } 671 672 //===----------------------------------------------------------------------===// 673 // GEN: Interface registration hooks 674 //===----------------------------------------------------------------------===// 675 676 namespace { 677 template <typename GeneratorT> 678 struct InterfaceGenRegistration { 679 InterfaceGenRegistration(StringRef genArg, StringRef genDesc) 680 : genDeclArg(("gen-" + genArg + "-interface-decls").str()), 681 genDefArg(("gen-" + genArg + "-interface-defs").str()), 682 genDocArg(("gen-" + genArg + "-interface-docs").str()), 683 genDeclDesc(("Generate " + genDesc + " interface declarations").str()), 684 genDefDesc(("Generate " + genDesc + " interface definitions").str()), 685 genDocDesc(("Generate " + genDesc + " interface documentation").str()), 686 genDecls(genDeclArg, genDeclDesc, 687 [](const RecordKeeper &records, raw_ostream &os) { 688 return GeneratorT(records, os).emitInterfaceDecls(); 689 }), 690 genDefs(genDefArg, genDefDesc, 691 [](const RecordKeeper &records, raw_ostream &os) { 692 return GeneratorT(records, os).emitInterfaceDefs(); 693 }), 694 genDocs(genDocArg, genDocDesc, 695 [](const RecordKeeper &records, raw_ostream &os) { 696 return GeneratorT(records, os).emitInterfaceDocs(); 697 }) {} 698 699 std::string genDeclArg, genDefArg, genDocArg; 700 std::string genDeclDesc, genDefDesc, genDocDesc; 701 mlir::GenRegistration genDecls, genDefs, genDocs; 702 }; 703 } // namespace 704 705 static InterfaceGenRegistration<AttrInterfaceGenerator> attrGen("attr", 706 "attribute"); 707 static InterfaceGenRegistration<OpInterfaceGenerator> opGen("op", "op"); 708 static InterfaceGenRegistration<TypeInterfaceGenerator> typeGen("type", "type"); 709