1 //===- OpInterfacesGen.cpp - MLIR op interface utility generator ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // OpInterfacesGen generates definitions for operation interfaces. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "DocGenUtilities.h" 14 #include "mlir/TableGen/Format.h" 15 #include "mlir/TableGen/GenInfo.h" 16 #include "mlir/TableGen/Interfaces.h" 17 #include "llvm/ADT/SmallVector.h" 18 #include "llvm/ADT/StringExtras.h" 19 #include "llvm/Support/FormatVariadic.h" 20 #include "llvm/Support/raw_ostream.h" 21 #include "llvm/TableGen/Error.h" 22 #include "llvm/TableGen/Record.h" 23 #include "llvm/TableGen/TableGenBackend.h" 24 25 using namespace mlir; 26 using mlir::tblgen::Interface; 27 using mlir::tblgen::InterfaceMethod; 28 using mlir::tblgen::OpInterface; 29 30 /// Emit a string corresponding to a C++ type, followed by a space if necessary. 31 static raw_ostream &emitCPPType(StringRef type, raw_ostream &os) { 32 type = type.trim(); 33 os << type; 34 if (type.back() != '&' && type.back() != '*') 35 os << " "; 36 return os; 37 } 38 39 /// Emit the method name and argument list for the given method. If 'addThisArg' 40 /// is true, then an argument is added to the beginning of the argument list for 41 /// the concrete value. 42 static void emitMethodNameAndArgs(const InterfaceMethod &method, 43 raw_ostream &os, StringRef valueType, 44 bool addThisArg, bool addConst) { 45 os << method.getName() << '('; 46 if (addThisArg) { 47 if (addConst) 48 os << "const "; 49 os << "const Concept *impl, "; 50 emitCPPType(valueType, os) 51 << "tablegen_opaque_val" << (method.arg_empty() ? "" : ", "); 52 } 53 llvm::interleaveComma(method.getArguments(), os, 54 [&](const InterfaceMethod::Argument &arg) { 55 os << arg.type << " " << arg.name; 56 }); 57 os << ')'; 58 if (addConst) 59 os << " const"; 60 } 61 62 /// Get an array of all OpInterface definitions but exclude those subclassing 63 /// "DeclareOpInterfaceMethods". 64 static std::vector<llvm::Record *> 65 getAllOpInterfaceDefinitions(const llvm::RecordKeeper &recordKeeper) { 66 std::vector<llvm::Record *> defs = 67 recordKeeper.getAllDerivedDefinitions("OpInterface"); 68 69 llvm::erase_if(defs, [](const llvm::Record *def) { 70 return def->isSubClassOf("DeclareOpInterfaceMethods"); 71 }); 72 return defs; 73 } 74 75 namespace { 76 /// This struct is the base generator used when processing tablegen interfaces. 77 class InterfaceGenerator { 78 public: 79 bool emitInterfaceDefs(); 80 bool emitInterfaceDecls(); 81 bool emitInterfaceDocs(); 82 83 protected: 84 InterfaceGenerator(std::vector<llvm::Record *> &&defs, raw_ostream &os) 85 : defs(std::move(defs)), os(os) {} 86 87 void emitConceptDecl(Interface &interface); 88 void emitModelDecl(Interface &interface); 89 void emitModelMethodsDef(Interface &interface); 90 void emitTraitDecl(Interface &interface, StringRef interfaceName, 91 StringRef interfaceTraitsName); 92 void emitInterfaceDecl(Interface interface); 93 94 /// The set of interface records to emit. 95 std::vector<llvm::Record *> defs; 96 // The stream to emit to. 97 raw_ostream &os; 98 /// The C++ value type of the interface, e.g. Operation*. 99 StringRef valueType; 100 /// The C++ base interface type. 101 StringRef interfaceBaseType; 102 /// The name of the typename for the value template. 103 StringRef valueTemplate; 104 /// The format context to use for methods. 105 tblgen::FmtContext nonStaticMethodFmt; 106 tblgen::FmtContext traitMethodFmt; 107 tblgen::FmtContext extraDeclsFmt; 108 }; 109 110 /// A specialized generator for attribute interfaces. 111 struct AttrInterfaceGenerator : public InterfaceGenerator { 112 AttrInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os) 113 : InterfaceGenerator(records.getAllDerivedDefinitions("AttrInterface"), 114 os) { 115 valueType = "::mlir::Attribute"; 116 interfaceBaseType = "AttributeInterface"; 117 valueTemplate = "ConcreteAttr"; 118 StringRef castCode = "(tablegen_opaque_val.cast<ConcreteAttr>())"; 119 nonStaticMethodFmt.addSubst("_attr", castCode).withSelf(castCode); 120 traitMethodFmt.addSubst("_attr", 121 "(*static_cast<const ConcreteAttr *>(this))"); 122 extraDeclsFmt.addSubst("_attr", "(*this)"); 123 } 124 }; 125 /// A specialized generator for operation interfaces. 126 struct OpInterfaceGenerator : public InterfaceGenerator { 127 OpInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os) 128 : InterfaceGenerator(getAllOpInterfaceDefinitions(records), os) { 129 valueType = "::mlir::Operation *"; 130 interfaceBaseType = "OpInterface"; 131 valueTemplate = "ConcreteOp"; 132 StringRef castCode = "(llvm::cast<ConcreteOp>(tablegen_opaque_val))"; 133 nonStaticMethodFmt.addSubst("_this", "impl") 134 .withOp(castCode) 135 .withSelf(castCode); 136 traitMethodFmt.withOp("(*static_cast<ConcreteOp *>(this))"); 137 extraDeclsFmt.withOp("(*this)"); 138 } 139 }; 140 /// A specialized generator for type interfaces. 141 struct TypeInterfaceGenerator : public InterfaceGenerator { 142 TypeInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os) 143 : InterfaceGenerator(records.getAllDerivedDefinitions("TypeInterface"), 144 os) { 145 valueType = "::mlir::Type"; 146 interfaceBaseType = "TypeInterface"; 147 valueTemplate = "ConcreteType"; 148 StringRef castCode = "(tablegen_opaque_val.cast<ConcreteType>())"; 149 nonStaticMethodFmt.addSubst("_type", castCode).withSelf(castCode); 150 traitMethodFmt.addSubst("_type", 151 "(*static_cast<const ConcreteType *>(this))"); 152 extraDeclsFmt.addSubst("_type", "(*this)"); 153 } 154 }; 155 } // namespace 156 157 //===----------------------------------------------------------------------===// 158 // GEN: Interface definitions 159 //===----------------------------------------------------------------------===// 160 161 static void emitInterfaceDef(const Interface &interface, StringRef valueType, 162 raw_ostream &os) { 163 StringRef interfaceName = interface.getName(); 164 StringRef cppNamespace = interface.getCppNamespace(); 165 cppNamespace.consume_front("::"); 166 167 // Insert the method definitions. 168 bool isOpInterface = isa<OpInterface>(interface); 169 for (auto &method : interface.getMethods()) { 170 emitCPPType(method.getReturnType(), os); 171 if (!cppNamespace.empty()) 172 os << cppNamespace << "::"; 173 os << interfaceName << "::"; 174 emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false, 175 /*addConst=*/!isOpInterface); 176 177 // Forward to the method on the concrete operation type. 178 os << " {\n return getImpl()->" << method.getName() << '('; 179 if (!method.isStatic()) { 180 os << "getImpl(), "; 181 os << (isOpInterface ? "getOperation()" : "*this"); 182 os << (method.arg_empty() ? "" : ", "); 183 } 184 llvm::interleaveComma( 185 method.getArguments(), os, 186 [&](const InterfaceMethod::Argument &arg) { os << arg.name; }); 187 os << ");\n }\n"; 188 } 189 } 190 191 bool InterfaceGenerator::emitInterfaceDefs() { 192 llvm::emitSourceFileHeader("Interface Definitions", os); 193 194 for (const auto *def : defs) 195 emitInterfaceDef(Interface(def), valueType, os); 196 return false; 197 } 198 199 //===----------------------------------------------------------------------===// 200 // GEN: Interface declarations 201 //===----------------------------------------------------------------------===// 202 203 void InterfaceGenerator::emitConceptDecl(Interface &interface) { 204 os << " struct Concept {\n"; 205 206 // Insert each of the pure virtual concept methods. 207 for (auto &method : interface.getMethods()) { 208 os << " "; 209 emitCPPType(method.getReturnType(), os); 210 os << "(*" << method.getName() << ")("; 211 if (!method.isStatic()) { 212 os << "const Concept *impl, "; 213 emitCPPType(valueType, os) << (method.arg_empty() ? "" : ", "); 214 } 215 llvm::interleaveComma( 216 method.getArguments(), os, 217 [&](const InterfaceMethod::Argument &arg) { os << arg.type; }); 218 os << ");\n"; 219 } 220 os << " };\n"; 221 } 222 223 void InterfaceGenerator::emitModelDecl(Interface &interface) { 224 // Emit the basic model and the fallback model. 225 for (const char *modelClass : {"Model", "FallbackModel"}) { 226 os << " template<typename " << valueTemplate << ">\n"; 227 os << " class " << modelClass << " : public Concept {\n public:\n"; 228 os << " using Interface = " << interface.getCppNamespace() 229 << (interface.getCppNamespace().empty() ? "" : "::") 230 << interface.getName() << ";\n"; 231 os << " " << modelClass << "() : Concept{"; 232 llvm::interleaveComma( 233 interface.getMethods(), os, 234 [&](const InterfaceMethod &method) { os << method.getName(); }); 235 os << "} {}\n\n"; 236 237 // Insert each of the virtual method overrides. 238 for (auto &method : interface.getMethods()) { 239 emitCPPType(method.getReturnType(), os << " static inline "); 240 emitMethodNameAndArgs(method, os, valueType, 241 /*addThisArg=*/!method.isStatic(), 242 /*addConst=*/false); 243 os << ";\n"; 244 } 245 os << " };\n"; 246 } 247 248 // Emit the template for the external model. 249 os << " template<typename ConcreteModel, typename " << valueTemplate 250 << ">\n"; 251 os << " class ExternalModel : public FallbackModel<ConcreteModel> {\n"; 252 os << " public:\n"; 253 254 // Emit declarations for methods that have default implementations. Other 255 // methods are expected to be implemented by the concrete derived model. 256 for (auto &method : interface.getMethods()) { 257 if (!method.getDefaultImplementation()) 258 continue; 259 os << " "; 260 if (method.isStatic()) 261 os << "static "; 262 emitCPPType(method.getReturnType(), os); 263 os << method.getName() << "("; 264 if (!method.isStatic()) { 265 emitCPPType(valueType, os); 266 os << "tablegen_opaque_val"; 267 if (!method.arg_empty()) 268 os << ", "; 269 } 270 llvm::interleaveComma(method.getArguments(), os, 271 [&](const InterfaceMethod::Argument &arg) { 272 emitCPPType(arg.type, os); 273 os << arg.name; 274 }); 275 os << ")"; 276 if (!method.isStatic()) 277 os << " const"; 278 os << ";\n"; 279 } 280 os << " };\n"; 281 } 282 283 void InterfaceGenerator::emitModelMethodsDef(Interface &interface) { 284 for (auto &method : interface.getMethods()) { 285 os << "template<typename " << valueTemplate << ">\n"; 286 emitCPPType(method.getReturnType(), os); 287 os << "detail::" << interface.getName() << "InterfaceTraits::Model<" 288 << valueTemplate << ">::"; 289 emitMethodNameAndArgs(method, os, valueType, 290 /*addThisArg=*/!method.isStatic(), 291 /*addConst=*/false); 292 os << " {\n "; 293 294 // Check for a provided body to the function. 295 if (Optional<StringRef> body = method.getBody()) { 296 if (method.isStatic()) 297 os << body->trim(); 298 else 299 os << tblgen::tgfmt(body->trim(), &nonStaticMethodFmt); 300 os << "\n}\n"; 301 continue; 302 } 303 304 // Forward to the method on the concrete operation type. 305 if (method.isStatic()) 306 os << "return " << valueTemplate << "::"; 307 else 308 os << tblgen::tgfmt("return $_self.", &nonStaticMethodFmt); 309 310 // Add the arguments to the call. 311 os << method.getName() << '('; 312 llvm::interleaveComma( 313 method.getArguments(), os, 314 [&](const InterfaceMethod::Argument &arg) { os << arg.name; }); 315 os << ");\n}\n"; 316 } 317 318 for (auto &method : interface.getMethods()) { 319 os << "template<typename " << valueTemplate << ">\n"; 320 emitCPPType(method.getReturnType(), os); 321 os << "detail::" << interface.getName() << "InterfaceTraits::FallbackModel<" 322 << valueTemplate << ">::"; 323 emitMethodNameAndArgs(method, os, valueType, 324 /*addThisArg=*/!method.isStatic(), 325 /*addConst=*/false); 326 os << " {\n "; 327 328 // Forward to the method on the concrete Model implementation. 329 if (method.isStatic()) 330 os << "return " << valueTemplate << "::"; 331 else 332 os << "return static_cast<const " << valueTemplate << " *>(impl)->"; 333 334 // Add the arguments to the call. 335 os << method.getName() << '('; 336 if (!method.isStatic()) 337 os << "tablegen_opaque_val" << (method.arg_empty() ? "" : ", "); 338 llvm::interleaveComma( 339 method.getArguments(), os, 340 [&](const InterfaceMethod::Argument &arg) { os << arg.name; }); 341 os << ");\n}\n"; 342 } 343 344 // Emit default implementations for the external model. 345 for (auto &method : interface.getMethods()) { 346 if (!method.getDefaultImplementation()) 347 continue; 348 os << "template<typename ConcreteModel, typename " << valueTemplate 349 << ">\n"; 350 emitCPPType(method.getReturnType(), os); 351 os << "detail::" << interface.getName() 352 << "InterfaceTraits::ExternalModel<ConcreteModel, " << valueTemplate 353 << ">::"; 354 355 os << method.getName() << "("; 356 if (!method.isStatic()) { 357 emitCPPType(valueType, os); 358 os << "tablegen_opaque_val"; 359 if (!method.arg_empty()) 360 os << ", "; 361 } 362 llvm::interleaveComma(method.getArguments(), os, 363 [&](const InterfaceMethod::Argument &arg) { 364 emitCPPType(arg.type, os); 365 os << arg.name; 366 }); 367 os << ")"; 368 if (!method.isStatic()) 369 os << " const"; 370 371 os << " {\n"; 372 373 // Use the empty context for static methods. 374 tblgen::FmtContext ctx; 375 os << tblgen::tgfmt(method.getDefaultImplementation()->trim(), 376 method.isStatic() ? &ctx : &nonStaticMethodFmt); 377 os << "\n}\n"; 378 } 379 } 380 381 void InterfaceGenerator::emitTraitDecl(Interface &interface, 382 StringRef interfaceName, 383 StringRef interfaceTraitsName) { 384 os << llvm::formatv(" template <typename {3}>\n" 385 " struct {0}Trait : public ::mlir::{2}<{0}," 386 " detail::{1}>::Trait<{3}> {{\n", 387 interfaceName, interfaceTraitsName, interfaceBaseType, 388 valueTemplate); 389 390 // Insert the default implementation for any methods. 391 bool isOpInterface = isa<OpInterface>(interface); 392 for (auto &method : interface.getMethods()) { 393 // Flag interface methods named verifyTrait. 394 if (method.getName() == "verifyTrait") 395 PrintFatalError( 396 formatv("'verifyTrait' method cannot be specified as interface " 397 "method for '{0}'; use the 'verify' field instead", 398 interfaceName)); 399 auto defaultImpl = method.getDefaultImplementation(); 400 if (!defaultImpl) 401 continue; 402 403 os << " " << (method.isStatic() ? "static " : ""); 404 emitCPPType(method.getReturnType(), os); 405 emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false, 406 /*addConst=*/!isOpInterface && !method.isStatic()); 407 os << " {\n " << tblgen::tgfmt(defaultImpl->trim(), &traitMethodFmt) 408 << "\n }\n"; 409 } 410 411 if (auto verify = interface.getVerify()) { 412 assert(isa<OpInterface>(interface) && "only OpInterface supports 'verify'"); 413 414 tblgen::FmtContext verifyCtx; 415 verifyCtx.withOp("op"); 416 os << " static ::mlir::LogicalResult verifyTrait(::mlir::Operation *op) " 417 "{\n " 418 << tblgen::tgfmt(verify->trim(), &verifyCtx) << "\n }\n"; 419 } 420 if (auto extraTraitDecls = interface.getExtraTraitClassDeclaration()) 421 os << tblgen::tgfmt(*extraTraitDecls, &traitMethodFmt) << "\n"; 422 if (auto extraTraitDecls = interface.getExtraSharedClassDeclaration()) 423 os << tblgen::tgfmt(*extraTraitDecls, &traitMethodFmt) << "\n"; 424 425 os << " };\n"; 426 } 427 428 void InterfaceGenerator::emitInterfaceDecl(Interface interface) { 429 llvm::SmallVector<StringRef, 2> namespaces; 430 llvm::SplitString(interface.getCppNamespace(), namespaces, "::"); 431 for (StringRef ns : namespaces) 432 os << "namespace " << ns << " {\n"; 433 434 StringRef interfaceName = interface.getName(); 435 auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str(); 436 437 // Emit a forward declaration of the interface class so that it becomes usable 438 // in the signature of its methods. 439 os << "class " << interfaceName << ";\n"; 440 441 // Emit the traits struct containing the concept and model declarations. 442 os << "namespace detail {\n" 443 << "struct " << interfaceTraitsName << " {\n"; 444 emitConceptDecl(interface); 445 emitModelDecl(interface); 446 os << "};"; 447 448 // Emit the derived trait for the interface. 449 os << "template <typename " << valueTemplate << ">\n"; 450 os << "struct " << interface.getName() << "Trait;\n"; 451 452 os << "\n} // namespace detail\n"; 453 454 // Emit the main interface class declaration. 455 os << llvm::formatv("class {0} : public ::mlir::{3}<{1}, detail::{2}> {\n" 456 "public:\n" 457 " using ::mlir::{3}<{1}, detail::{2}>::{3};\n", 458 interfaceName, interfaceName, interfaceTraitsName, 459 interfaceBaseType); 460 461 // Emit a utility wrapper trait class. 462 os << llvm::formatv(" template <typename {1}>\n" 463 " struct Trait : public detail::{0}Trait<{1}> {{};\n", 464 interfaceName, valueTemplate); 465 466 // Insert the method declarations. 467 bool isOpInterface = isa<OpInterface>(interface); 468 for (auto &method : interface.getMethods()) { 469 emitCPPType(method.getReturnType(), os << " "); 470 emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false, 471 /*addConst=*/!isOpInterface); 472 os << ";\n"; 473 } 474 475 // Emit any extra declarations. 476 if (Optional<StringRef> extraDecls = interface.getExtraClassDeclaration()) 477 os << *extraDecls << "\n"; 478 if (Optional<StringRef> extraDecls = 479 interface.getExtraSharedClassDeclaration()) 480 os << tblgen::tgfmt(*extraDecls, &extraDeclsFmt); 481 482 os << "};\n"; 483 484 os << "namespace detail {\n"; 485 emitTraitDecl(interface, interfaceName, interfaceTraitsName); 486 os << "}// namespace detail\n"; 487 488 emitModelMethodsDef(interface); 489 490 for (StringRef ns : llvm::reverse(namespaces)) 491 os << "} // namespace " << ns << "\n"; 492 } 493 494 bool InterfaceGenerator::emitInterfaceDecls() { 495 llvm::emitSourceFileHeader("Interface Declarations", os); 496 497 for (const auto *def : defs) 498 emitInterfaceDecl(Interface(def)); 499 return false; 500 } 501 502 //===----------------------------------------------------------------------===// 503 // GEN: Interface documentation 504 //===----------------------------------------------------------------------===// 505 506 static void emitInterfaceDoc(const llvm::Record &interfaceDef, 507 raw_ostream &os) { 508 Interface interface(&interfaceDef); 509 510 // Emit the interface name followed by the description. 511 os << "## " << interface.getName() << " (`" << interfaceDef.getName() 512 << "`)\n\n"; 513 if (auto description = interface.getDescription()) 514 mlir::tblgen::emitDescription(*description, os); 515 516 // Emit the methods required by the interface. 517 os << "\n### Methods:\n"; 518 for (const auto &method : interface.getMethods()) { 519 // Emit the method name. 520 os << "#### `" << method.getName() << "`\n\n```c++\n"; 521 522 // Emit the method signature. 523 if (method.isStatic()) 524 os << "static "; 525 emitCPPType(method.getReturnType(), os) << method.getName() << '('; 526 llvm::interleaveComma(method.getArguments(), os, 527 [&](const InterfaceMethod::Argument &arg) { 528 emitCPPType(arg.type, os) << arg.name; 529 }); 530 os << ");\n```\n"; 531 532 // Emit the description. 533 if (auto description = method.getDescription()) 534 mlir::tblgen::emitDescription(*description, os); 535 536 // If the body is not provided, this method must be provided by the user. 537 if (!method.getBody()) 538 os << "\nNOTE: This method *must* be implemented by the user.\n\n"; 539 } 540 } 541 542 bool InterfaceGenerator::emitInterfaceDocs() { 543 os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n"; 544 os << "# " << interfaceBaseType << " definitions\n"; 545 546 for (const auto *def : defs) 547 emitInterfaceDoc(*def, os); 548 return false; 549 } 550 551 //===----------------------------------------------------------------------===// 552 // GEN: Interface registration hooks 553 //===----------------------------------------------------------------------===// 554 555 namespace { 556 template <typename GeneratorT> 557 struct InterfaceGenRegistration { 558 InterfaceGenRegistration(StringRef genArg, StringRef genDesc) 559 : genDeclArg(("gen-" + genArg + "-interface-decls").str()), 560 genDefArg(("gen-" + genArg + "-interface-defs").str()), 561 genDocArg(("gen-" + genArg + "-interface-docs").str()), 562 genDeclDesc(("Generate " + genDesc + " interface declarations").str()), 563 genDefDesc(("Generate " + genDesc + " interface definitions").str()), 564 genDocDesc(("Generate " + genDesc + " interface documentation").str()), 565 genDecls(genDeclArg, genDeclDesc, 566 [](const llvm::RecordKeeper &records, raw_ostream &os) { 567 return GeneratorT(records, os).emitInterfaceDecls(); 568 }), 569 genDefs(genDefArg, genDefDesc, 570 [](const llvm::RecordKeeper &records, raw_ostream &os) { 571 return GeneratorT(records, os).emitInterfaceDefs(); 572 }), 573 genDocs(genDocArg, genDocDesc, 574 [](const llvm::RecordKeeper &records, raw_ostream &os) { 575 return GeneratorT(records, os).emitInterfaceDocs(); 576 }) {} 577 578 std::string genDeclArg, genDefArg, genDocArg; 579 std::string genDeclDesc, genDefDesc, genDocDesc; 580 mlir::GenRegistration genDecls, genDefs, genDocs; 581 }; 582 } // namespace 583 584 static InterfaceGenRegistration<AttrInterfaceGenerator> attrGen("attr", 585 "attribute"); 586 static InterfaceGenRegistration<OpInterfaceGenerator> opGen("op", "op"); 587 static InterfaceGenRegistration<TypeInterfaceGenerator> typeGen("type", "type"); 588