xref: /llvm-project/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp (revision e813750354bbc08551cf23ff559a54b4a9ea1f29)
1 //===- OpInterfacesGen.cpp - MLIR op interface utility generator ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpInterfacesGen generates definitions for operation interfaces.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "DocGenUtilities.h"
14 #include "mlir/TableGen/Format.h"
15 #include "mlir/TableGen/GenInfo.h"
16 #include "mlir/TableGen/Interfaces.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/ADT/StringExtras.h"
19 #include "llvm/Support/FormatVariadic.h"
20 #include "llvm/Support/raw_ostream.h"
21 #include "llvm/TableGen/Error.h"
22 #include "llvm/TableGen/Record.h"
23 #include "llvm/TableGen/TableGenBackend.h"
24 
25 using namespace mlir;
26 using llvm::Record;
27 using llvm::RecordKeeper;
28 using mlir::tblgen::Interface;
29 using mlir::tblgen::InterfaceMethod;
30 using mlir::tblgen::OpInterface;
31 
32 /// Emit a string corresponding to a C++ type, followed by a space if necessary.
33 static raw_ostream &emitCPPType(StringRef type, raw_ostream &os) {
34   type = type.trim();
35   os << type;
36   if (type.back() != '&' && type.back() != '*')
37     os << " ";
38   return os;
39 }
40 
41 /// Emit the method name and argument list for the given method. If 'addThisArg'
42 /// is true, then an argument is added to the beginning of the argument list for
43 /// the concrete value.
44 static void emitMethodNameAndArgs(const InterfaceMethod &method,
45                                   raw_ostream &os, StringRef valueType,
46                                   bool addThisArg, bool addConst) {
47   os << method.getName() << '(';
48   if (addThisArg) {
49     if (addConst)
50       os << "const ";
51     os << "const Concept *impl, ";
52     emitCPPType(valueType, os)
53         << "tablegen_opaque_val" << (method.arg_empty() ? "" : ", ");
54   }
55   llvm::interleaveComma(method.getArguments(), os,
56                         [&](const InterfaceMethod::Argument &arg) {
57                           os << arg.type << " " << arg.name;
58                         });
59   os << ')';
60   if (addConst)
61     os << " const";
62 }
63 
64 /// Get an array of all OpInterface definitions but exclude those subclassing
65 /// "DeclareOpInterfaceMethods".
66 static std::vector<const Record *>
67 getAllInterfaceDefinitions(const RecordKeeper &records, StringRef name) {
68   std::vector<const Record *> defs =
69       records.getAllDerivedDefinitions((name + "Interface").str());
70 
71   std::string declareName = ("Declare" + name + "InterfaceMethods").str();
72   llvm::erase_if(defs, [&](const Record *def) {
73     // Ignore any "declare methods" interfaces.
74     if (def->isSubClassOf(declareName))
75       return true;
76     // Ignore interfaces defined outside of the top-level file.
77     return llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) !=
78            llvm::SrcMgr.getMainFileID();
79   });
80   return defs;
81 }
82 
83 namespace {
/// This struct is the base generator used when processing tablegen interfaces.
/// The emission logic is shared; subclasses only configure the value type, the
/// base interface type, the value template name and the substitution contexts
/// for one kind of interface (attribute, operation or type).
class InterfaceGenerator {
public:
  /// Emit the out-of-line definitions of the interface methods. Returns false.
  bool emitInterfaceDefs();
  /// Emit the interface declarations followed by the out-of-line model method
  /// definitions. Returns false.
  bool emitInterfaceDecls();
  /// Emit markdown documentation for the interfaces. Returns false.
  bool emitInterfaceDocs();

protected:
  InterfaceGenerator(std::vector<const Record *> &&defs, raw_ostream &os)
      : defs(std::move(defs)), os(os) {}

  /// Emit the `Concept` struct (one function pointer per interface method).
  void emitConceptDecl(const Interface &interface);
  /// Emit the `Model`, `FallbackModel` and `ExternalModel` class templates.
  void emitModelDecl(const Interface &interface);
  /// Emit the out-of-line definitions of the model methods.
  void emitModelMethodsDef(const Interface &interface);
  /// Emit the `<Name>Trait` class template holding default implementations.
  void emitTraitDecl(const Interface &interface, StringRef interfaceName,
                     StringRef interfaceTraitsName);
  /// Emit the complete declaration of a single interface.
  void emitInterfaceDecl(const Interface &interface);

  /// The set of interface records to emit.
  std::vector<const Record *> defs;
  /// The stream to emit to.
  raw_ostream &os;
  /// The C++ value type of the interface, e.g. Operation*.
  StringRef valueType;
  /// The C++ base interface type.
  StringRef interfaceBaseType;
  /// The name of the typename for the value template.
  StringRef valueTemplate;
  /// The name of the substitution variable for the value.
  StringRef substVar;
  /// The format context used for non-static method bodies and default
  /// implementations emitted in the models.
  tblgen::FmtContext nonStaticMethodFmt;
  /// The format context used for default implementations and extra
  /// declarations emitted in the trait.
  tblgen::FmtContext traitMethodFmt;
  /// The format context used for the extra shared declarations emitted in the
  /// interface class.
  tblgen::FmtContext extraDeclsFmt;
};
119 
120 /// A specialized generator for attribute interfaces.
121 struct AttrInterfaceGenerator : public InterfaceGenerator {
122   AttrInterfaceGenerator(const RecordKeeper &records, raw_ostream &os)
123       : InterfaceGenerator(getAllInterfaceDefinitions(records, "Attr"), os) {
124     valueType = "::mlir::Attribute";
125     interfaceBaseType = "AttributeInterface";
126     valueTemplate = "ConcreteAttr";
127     substVar = "_attr";
128     StringRef castCode = "(::llvm::cast<ConcreteAttr>(tablegen_opaque_val))";
129     nonStaticMethodFmt.addSubst(substVar, castCode).withSelf(castCode);
130     traitMethodFmt.addSubst(substVar,
131                             "(*static_cast<const ConcreteAttr *>(this))");
132     extraDeclsFmt.addSubst(substVar, "(*this)");
133   }
134 };
135 /// A specialized generator for operation interfaces.
136 struct OpInterfaceGenerator : public InterfaceGenerator {
137   OpInterfaceGenerator(const RecordKeeper &records, raw_ostream &os)
138       : InterfaceGenerator(getAllInterfaceDefinitions(records, "Op"), os) {
139     valueType = "::mlir::Operation *";
140     interfaceBaseType = "OpInterface";
141     valueTemplate = "ConcreteOp";
142     substVar = "_op";
143     StringRef castCode = "(llvm::cast<ConcreteOp>(tablegen_opaque_val))";
144     nonStaticMethodFmt.addSubst("_this", "impl")
145         .addSubst(substVar, castCode)
146         .withSelf(castCode);
147     traitMethodFmt.addSubst(substVar, "(*static_cast<ConcreteOp *>(this))");
148     extraDeclsFmt.addSubst(substVar, "(*this)");
149   }
150 };
151 /// A specialized generator for type interfaces.
152 struct TypeInterfaceGenerator : public InterfaceGenerator {
153   TypeInterfaceGenerator(const RecordKeeper &records, raw_ostream &os)
154       : InterfaceGenerator(getAllInterfaceDefinitions(records, "Type"), os) {
155     valueType = "::mlir::Type";
156     interfaceBaseType = "TypeInterface";
157     valueTemplate = "ConcreteType";
158     substVar = "_type";
159     StringRef castCode = "(::llvm::cast<ConcreteType>(tablegen_opaque_val))";
160     nonStaticMethodFmt.addSubst(substVar, castCode).withSelf(castCode);
161     traitMethodFmt.addSubst(substVar,
162                             "(*static_cast<const ConcreteType *>(this))");
163     extraDeclsFmt.addSubst(substVar, "(*this)");
164   }
165 };
166 } // namespace
167 
168 //===----------------------------------------------------------------------===//
169 // GEN: Interface definitions
170 //===----------------------------------------------------------------------===//
171 
172 static void emitInterfaceMethodDoc(const InterfaceMethod &method,
173                                    raw_ostream &os, StringRef prefix = "") {
174   if (std::optional<StringRef> description = method.getDescription())
175     tblgen::emitDescriptionComment(*description, os, prefix);
176 }
177 static void emitInterfaceDefMethods(StringRef interfaceQualName,
178                                     const Interface &interface,
179                                     StringRef valueType, const Twine &implValue,
180                                     raw_ostream &os, bool isOpInterface) {
181   for (auto &method : interface.getMethods()) {
182     emitInterfaceMethodDoc(method, os);
183     emitCPPType(method.getReturnType(), os);
184     os << interfaceQualName << "::";
185     emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
186                           /*addConst=*/!isOpInterface);
187 
188     // Forward to the method on the concrete operation type.
189     os << " {\n      return " << implValue << "->" << method.getName() << '(';
190     if (!method.isStatic()) {
191       os << implValue << ", ";
192       os << (isOpInterface ? "getOperation()" : "*this");
193       os << (method.arg_empty() ? "" : ", ");
194     }
195     llvm::interleaveComma(
196         method.getArguments(), os,
197         [&](const InterfaceMethod::Argument &arg) { os << arg.name; });
198     os << ");\n  }\n";
199   }
200 }
201 
202 static void emitInterfaceDef(const Interface &interface, StringRef valueType,
203                              raw_ostream &os) {
204   std::string interfaceQualNameStr = interface.getFullyQualifiedName();
205   StringRef interfaceQualName = interfaceQualNameStr;
206   interfaceQualName.consume_front("::");
207 
208   // Insert the method definitions.
209   bool isOpInterface = isa<OpInterface>(interface);
210   emitInterfaceDefMethods(interfaceQualName, interface, valueType, "getImpl()",
211                           os, isOpInterface);
212 
213   // Insert the method definitions for base classes.
214   for (auto &base : interface.getBaseInterfaces()) {
215     emitInterfaceDefMethods(interfaceQualName, base, valueType,
216                             "getImpl()->impl" + base.getName(), os,
217                             isOpInterface);
218   }
219 }
220 
221 bool InterfaceGenerator::emitInterfaceDefs() {
222   llvm::emitSourceFileHeader("Interface Definitions", os);
223 
224   for (const auto *def : defs)
225     emitInterfaceDef(Interface(def), valueType, os);
226   return false;
227 }
228 
229 //===----------------------------------------------------------------------===//
230 // GEN: Interface declarations
231 //===----------------------------------------------------------------------===//
232 
233 void InterfaceGenerator::emitConceptDecl(const Interface &interface) {
234   os << "  struct Concept {\n";
235 
236   // Insert each of the pure virtual concept methods.
237   os << "    /// The methods defined by the interface.\n";
238   for (auto &method : interface.getMethods()) {
239     os << "    ";
240     emitCPPType(method.getReturnType(), os);
241     os << "(*" << method.getName() << ")(";
242     if (!method.isStatic()) {
243       os << "const Concept *impl, ";
244       emitCPPType(valueType, os) << (method.arg_empty() ? "" : ", ");
245     }
246     llvm::interleaveComma(
247         method.getArguments(), os,
248         [&](const InterfaceMethod::Argument &arg) { os << arg.type; });
249     os << ");\n";
250   }
251 
252   // Insert a field containing a concept for each of the base interfaces.
253   auto baseInterfaces = interface.getBaseInterfaces();
254   if (!baseInterfaces.empty()) {
255     os << "    /// The base classes of this interface.\n";
256     for (const auto &base : interface.getBaseInterfaces()) {
257       os << "    const " << base.getFullyQualifiedName() << "::Concept *impl"
258          << base.getName() << " = nullptr;\n";
259     }
260 
261     // Define an "initialize" method that allows for the initialization of the
262     // base class concepts.
263     os << "\n    void initializeInterfaceConcept(::mlir::detail::InterfaceMap "
264           "&interfaceMap) {\n";
265     std::string interfaceQualName = interface.getFullyQualifiedName();
266     for (const auto &base : interface.getBaseInterfaces()) {
267       StringRef baseName = base.getName();
268       std::string baseQualName = base.getFullyQualifiedName();
269       os << "      impl" << baseName << " = interfaceMap.lookup<"
270          << baseQualName << ">();\n"
271          << "      assert(impl" << baseName << " && \"`" << interfaceQualName
272          << "` expected its base interface `" << baseQualName
273          << "` to be registered\");\n";
274     }
275     os << "    }\n";
276   }
277 
278   os << "  };\n";
279 }
280 
281 void InterfaceGenerator::emitModelDecl(const Interface &interface) {
282   // Emit the basic model and the fallback model.
283   for (const char *modelClass : {"Model", "FallbackModel"}) {
284     os << "  template<typename " << valueTemplate << ">\n";
285     os << "  class " << modelClass << " : public Concept {\n  public:\n";
286     os << "    using Interface = " << interface.getFullyQualifiedName()
287        << ";\n";
288     os << "    " << modelClass << "() : Concept{";
289     llvm::interleaveComma(
290         interface.getMethods(), os,
291         [&](const InterfaceMethod &method) { os << method.getName(); });
292     os << "} {}\n\n";
293 
294     // Insert each of the virtual method overrides.
295     for (auto &method : interface.getMethods()) {
296       emitCPPType(method.getReturnType(), os << "    static inline ");
297       emitMethodNameAndArgs(method, os, valueType,
298                             /*addThisArg=*/!method.isStatic(),
299                             /*addConst=*/false);
300       os << ";\n";
301     }
302     os << "  };\n";
303   }
304 
305   // Emit the template for the external model.
306   os << "  template<typename ConcreteModel, typename " << valueTemplate
307      << ">\n";
308   os << "  class ExternalModel : public FallbackModel<ConcreteModel> {\n";
309   os << "  public:\n";
310   os << "    using ConcreteEntity = " << valueTemplate << ";\n";
311 
312   // Emit declarations for methods that have default implementations. Other
313   // methods are expected to be implemented by the concrete derived model.
314   for (auto &method : interface.getMethods()) {
315     if (!method.getDefaultImplementation())
316       continue;
317     os << "    ";
318     if (method.isStatic())
319       os << "static ";
320     emitCPPType(method.getReturnType(), os);
321     os << method.getName() << "(";
322     if (!method.isStatic()) {
323       emitCPPType(valueType, os);
324       os << "tablegen_opaque_val";
325       if (!method.arg_empty())
326         os << ", ";
327     }
328     llvm::interleaveComma(method.getArguments(), os,
329                           [&](const InterfaceMethod::Argument &arg) {
330                             emitCPPType(arg.type, os);
331                             os << arg.name;
332                           });
333     os << ")";
334     if (!method.isStatic())
335       os << " const";
336     os << ";\n";
337   }
338   os << "  };\n";
339 }
340 
341 void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
342   llvm::SmallVector<StringRef, 2> namespaces;
343   llvm::SplitString(interface.getCppNamespace(), namespaces, "::");
344   for (StringRef ns : namespaces)
345     os << "namespace " << ns << " {\n";
346 
347   for (auto &method : interface.getMethods()) {
348     os << "template<typename " << valueTemplate << ">\n";
349     emitCPPType(method.getReturnType(), os);
350     os << "detail::" << interface.getName() << "InterfaceTraits::Model<"
351        << valueTemplate << ">::";
352     emitMethodNameAndArgs(method, os, valueType,
353                           /*addThisArg=*/!method.isStatic(),
354                           /*addConst=*/false);
355     os << " {\n  ";
356 
357     // Check for a provided body to the function.
358     if (std::optional<StringRef> body = method.getBody()) {
359       if (method.isStatic())
360         os << body->trim();
361       else
362         os << tblgen::tgfmt(body->trim(), &nonStaticMethodFmt);
363       os << "\n}\n";
364       continue;
365     }
366 
367     // Forward to the method on the concrete operation type.
368     if (method.isStatic())
369       os << "return " << valueTemplate << "::";
370     else
371       os << tblgen::tgfmt("return $_self.", &nonStaticMethodFmt);
372 
373     // Add the arguments to the call.
374     os << method.getName() << '(';
375     llvm::interleaveComma(
376         method.getArguments(), os,
377         [&](const InterfaceMethod::Argument &arg) { os << arg.name; });
378     os << ");\n}\n";
379   }
380 
381   for (auto &method : interface.getMethods()) {
382     os << "template<typename " << valueTemplate << ">\n";
383     emitCPPType(method.getReturnType(), os);
384     os << "detail::" << interface.getName() << "InterfaceTraits::FallbackModel<"
385        << valueTemplate << ">::";
386     emitMethodNameAndArgs(method, os, valueType,
387                           /*addThisArg=*/!method.isStatic(),
388                           /*addConst=*/false);
389     os << " {\n  ";
390 
391     // Forward to the method on the concrete Model implementation.
392     if (method.isStatic())
393       os << "return " << valueTemplate << "::";
394     else
395       os << "return static_cast<const " << valueTemplate << " *>(impl)->";
396 
397     // Add the arguments to the call.
398     os << method.getName() << '(';
399     if (!method.isStatic())
400       os << "tablegen_opaque_val" << (method.arg_empty() ? "" : ", ");
401     llvm::interleaveComma(
402         method.getArguments(), os,
403         [&](const InterfaceMethod::Argument &arg) { os << arg.name; });
404     os << ");\n}\n";
405   }
406 
407   // Emit default implementations for the external model.
408   for (auto &method : interface.getMethods()) {
409     if (!method.getDefaultImplementation())
410       continue;
411     os << "template<typename ConcreteModel, typename " << valueTemplate
412        << ">\n";
413     emitCPPType(method.getReturnType(), os);
414     os << "detail::" << interface.getName()
415        << "InterfaceTraits::ExternalModel<ConcreteModel, " << valueTemplate
416        << ">::";
417 
418     os << method.getName() << "(";
419     if (!method.isStatic()) {
420       emitCPPType(valueType, os);
421       os << "tablegen_opaque_val";
422       if (!method.arg_empty())
423         os << ", ";
424     }
425     llvm::interleaveComma(method.getArguments(), os,
426                           [&](const InterfaceMethod::Argument &arg) {
427                             emitCPPType(arg.type, os);
428                             os << arg.name;
429                           });
430     os << ")";
431     if (!method.isStatic())
432       os << " const";
433 
434     os << " {\n";
435 
436     // Use the empty context for static methods.
437     tblgen::FmtContext ctx;
438     os << tblgen::tgfmt(method.getDefaultImplementation()->trim(),
439                         method.isStatic() ? &ctx : &nonStaticMethodFmt);
440     os << "\n}\n";
441   }
442 
443   for (StringRef ns : llvm::reverse(namespaces))
444     os << "} // namespace " << ns << "\n";
445 }
446 
447 void InterfaceGenerator::emitTraitDecl(const Interface &interface,
448                                        StringRef interfaceName,
449                                        StringRef interfaceTraitsName) {
450   os << llvm::formatv("  template <typename {3}>\n"
451                       "  struct {0}Trait : public ::mlir::{2}<{0},"
452                       " detail::{1}>::Trait<{3}> {{\n",
453                       interfaceName, interfaceTraitsName, interfaceBaseType,
454                       valueTemplate);
455 
456   // Insert the default implementation for any methods.
457   bool isOpInterface = isa<OpInterface>(interface);
458   for (auto &method : interface.getMethods()) {
459     // Flag interface methods named verifyTrait.
460     if (method.getName() == "verifyTrait")
461       PrintFatalError(
462           formatv("'verifyTrait' method cannot be specified as interface "
463                   "method for '{0}'; use the 'verify' field instead",
464                   interfaceName));
465     auto defaultImpl = method.getDefaultImplementation();
466     if (!defaultImpl)
467       continue;
468 
469     emitInterfaceMethodDoc(method, os, "    ");
470     os << "    " << (method.isStatic() ? "static " : "");
471     emitCPPType(method.getReturnType(), os);
472     emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
473                           /*addConst=*/!isOpInterface && !method.isStatic());
474     os << " {\n      " << tblgen::tgfmt(defaultImpl->trim(), &traitMethodFmt)
475        << "\n    }\n";
476   }
477 
478   if (auto verify = interface.getVerify()) {
479     assert(isa<OpInterface>(interface) && "only OpInterface supports 'verify'");
480 
481     tblgen::FmtContext verifyCtx;
482     verifyCtx.addSubst("_op", "op");
483     os << llvm::formatv(
484               "    static ::llvm::LogicalResult {0}(::mlir::Operation *op) ",
485               (interface.verifyWithRegions() ? "verifyRegionTrait"
486                                              : "verifyTrait"))
487        << "{\n      " << tblgen::tgfmt(verify->trim(), &verifyCtx)
488        << "\n    }\n";
489   }
490   if (auto extraTraitDecls = interface.getExtraTraitClassDeclaration())
491     os << tblgen::tgfmt(*extraTraitDecls, &traitMethodFmt) << "\n";
492   if (auto extraTraitDecls = interface.getExtraSharedClassDeclaration())
493     os << tblgen::tgfmt(*extraTraitDecls, &traitMethodFmt) << "\n";
494 
495   os << "  };\n";
496 }
497 
498 static void emitInterfaceDeclMethods(const Interface &interface,
499                                      raw_ostream &os, StringRef valueType,
500                                      bool isOpInterface,
501                                      tblgen::FmtContext &extraDeclsFmt) {
502   for (auto &method : interface.getMethods()) {
503     emitInterfaceMethodDoc(method, os, "  ");
504     emitCPPType(method.getReturnType(), os << "  ");
505     emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
506                           /*addConst=*/!isOpInterface);
507     os << ";\n";
508   }
509 
510   // Emit any extra declarations.
511   if (std::optional<StringRef> extraDecls =
512           interface.getExtraClassDeclaration())
513     os << extraDecls->rtrim() << "\n";
514   if (std::optional<StringRef> extraDecls =
515           interface.getExtraSharedClassDeclaration())
516     os << tblgen::tgfmt(extraDecls->rtrim(), &extraDeclsFmt) << "\n";
517 }
518 
519 void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
520   llvm::SmallVector<StringRef, 2> namespaces;
521   llvm::SplitString(interface.getCppNamespace(), namespaces, "::");
522   for (StringRef ns : namespaces)
523     os << "namespace " << ns << " {\n";
524 
525   StringRef interfaceName = interface.getName();
526   auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str();
527 
528   // Emit a forward declaration of the interface class so that it becomes usable
529   // in the signature of its methods.
530   os << "class " << interfaceName << ";\n";
531 
532   // Emit the traits struct containing the concept and model declarations.
533   os << "namespace detail {\n"
534      << "struct " << interfaceTraitsName << " {\n";
535   emitConceptDecl(interface);
536   emitModelDecl(interface);
537   os << "};\n";
538 
539   // Emit the derived trait for the interface.
540   os << "template <typename " << valueTemplate << ">\n";
541   os << "struct " << interface.getName() << "Trait;\n";
542 
543   os << "\n} // namespace detail\n";
544 
545   // Emit the main interface class declaration.
546   os << llvm::formatv("class {0} : public ::mlir::{3}<{1}, detail::{2}> {\n"
547                       "public:\n"
548                       "  using ::mlir::{3}<{1}, detail::{2}>::{3};\n",
549                       interfaceName, interfaceName, interfaceTraitsName,
550                       interfaceBaseType);
551 
552   // Emit a utility wrapper trait class.
553   os << llvm::formatv("  template <typename {1}>\n"
554                       "  struct Trait : public detail::{0}Trait<{1}> {{};\n",
555                       interfaceName, valueTemplate);
556 
557   // Insert the method declarations.
558   bool isOpInterface = isa<OpInterface>(interface);
559   emitInterfaceDeclMethods(interface, os, valueType, isOpInterface,
560                            extraDeclsFmt);
561 
562   // Insert the method declarations for base classes.
563   for (auto &base : interface.getBaseInterfaces()) {
564     std::string baseQualName = base.getFullyQualifiedName();
565     os << "  //"
566           "===---------------------------------------------------------------"
567           "-===//\n"
568        << "  // Inherited from " << baseQualName << "\n"
569        << "  //"
570           "===---------------------------------------------------------------"
571           "-===//\n\n";
572 
573     // Allow implicit conversion to the base interface.
574     os << "  operator " << baseQualName << " () const {\n"
575        << "    if (!*this) return nullptr;\n"
576        << "    return " << baseQualName << "(*this, getImpl()->impl"
577        << base.getName() << ");\n"
578        << "  }\n\n";
579 
580     // Inherit the base interface's methods.
581     emitInterfaceDeclMethods(base, os, valueType, isOpInterface, extraDeclsFmt);
582   }
583 
584   // Emit classof code if necessary.
585   if (std::optional<StringRef> extraClassOf = interface.getExtraClassOf()) {
586     auto extraClassOfFmt = tblgen::FmtContext();
587     extraClassOfFmt.addSubst(substVar, "odsInterfaceInstance");
588     os << "  static bool classof(" << valueType << " base) {\n"
589        << "    auto* interface = getInterfaceFor(base);\n"
590        << "    if (!interface)\n"
591           "      return false;\n"
592           "    " << interfaceName << " odsInterfaceInstance(base, interface);\n"
593        << "    " << tblgen::tgfmt(extraClassOf->trim(), &extraClassOfFmt)
594        << "\n  }\n";
595   }
596 
597   os << "};\n";
598 
599   os << "namespace detail {\n";
600   emitTraitDecl(interface, interfaceName, interfaceTraitsName);
601   os << "}// namespace detail\n";
602 
603   for (StringRef ns : llvm::reverse(namespaces))
604     os << "} // namespace " << ns << "\n";
605 }
606 
607 bool InterfaceGenerator::emitInterfaceDecls() {
608   llvm::emitSourceFileHeader("Interface Declarations", os);
609   // Sort according to ID, so defs are emitted in the order in which they appear
610   // in the Tablegen file.
611   std::vector<const Record *> sortedDefs(defs);
612   llvm::sort(sortedDefs, [](const Record *lhs, const Record *rhs) {
613     return lhs->getID() < rhs->getID();
614   });
615   for (const Record *def : sortedDefs)
616     emitInterfaceDecl(Interface(def));
617   for (const Record *def : sortedDefs)
618     emitModelMethodsDef(Interface(def));
619   return false;
620 }
621 
622 //===----------------------------------------------------------------------===//
623 // GEN: Interface documentation
624 //===----------------------------------------------------------------------===//
625 
626 static void emitInterfaceDoc(const Record &interfaceDef, raw_ostream &os) {
627   Interface interface(&interfaceDef);
628 
629   // Emit the interface name followed by the description.
630   os << "## " << interface.getName() << " (`" << interfaceDef.getName()
631      << "`)\n\n";
632   if (auto description = interface.getDescription())
633     mlir::tblgen::emitDescription(*description, os);
634 
635   // Emit the methods required by the interface.
636   os << "\n### Methods:\n";
637   for (const auto &method : interface.getMethods()) {
638     // Emit the method name.
639     os << "#### `" << method.getName() << "`\n\n```c++\n";
640 
641     // Emit the method signature.
642     if (method.isStatic())
643       os << "static ";
644     emitCPPType(method.getReturnType(), os) << method.getName() << '(';
645     llvm::interleaveComma(method.getArguments(), os,
646                           [&](const InterfaceMethod::Argument &arg) {
647                             emitCPPType(arg.type, os) << arg.name;
648                           });
649     os << ");\n```\n";
650 
651     // Emit the description.
652     if (auto description = method.getDescription())
653       mlir::tblgen::emitDescription(*description, os);
654 
655     // If the body is not provided, this method must be provided by the user.
656     if (!method.getBody())
657       os << "\nNOTE: This method *must* be implemented by the user.";
658 
659     os << "\n\n";
660   }
661 }
662 
663 bool InterfaceGenerator::emitInterfaceDocs() {
664   os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
665   os << "# " << interfaceBaseType << " definitions\n";
666 
667   for (const auto *def : defs)
668     emitInterfaceDoc(*def, os);
669   return false;
670 }
671 
672 //===----------------------------------------------------------------------===//
673 // GEN: Interface registration hooks
674 //===----------------------------------------------------------------------===//
675 
676 namespace {
677 template <typename GeneratorT>
678 struct InterfaceGenRegistration {
679   InterfaceGenRegistration(StringRef genArg, StringRef genDesc)
680       : genDeclArg(("gen-" + genArg + "-interface-decls").str()),
681         genDefArg(("gen-" + genArg + "-interface-defs").str()),
682         genDocArg(("gen-" + genArg + "-interface-docs").str()),
683         genDeclDesc(("Generate " + genDesc + " interface declarations").str()),
684         genDefDesc(("Generate " + genDesc + " interface definitions").str()),
685         genDocDesc(("Generate " + genDesc + " interface documentation").str()),
686         genDecls(genDeclArg, genDeclDesc,
687                  [](const RecordKeeper &records, raw_ostream &os) {
688                    return GeneratorT(records, os).emitInterfaceDecls();
689                  }),
690         genDefs(genDefArg, genDefDesc,
691                 [](const RecordKeeper &records, raw_ostream &os) {
692                   return GeneratorT(records, os).emitInterfaceDefs();
693                 }),
694         genDocs(genDocArg, genDocDesc,
695                 [](const RecordKeeper &records, raw_ostream &os) {
696                   return GeneratorT(records, os).emitInterfaceDocs();
697                 }) {}
698 
699   std::string genDeclArg, genDefArg, genDocArg;
700   std::string genDeclDesc, genDefDesc, genDocDesc;
701   mlir::GenRegistration genDecls, genDefs, genDocs;
702 };
703 } // namespace
704 
705 static InterfaceGenRegistration<AttrInterfaceGenerator> attrGen("attr",
706                                                                 "attribute");
707 static InterfaceGenRegistration<OpInterfaceGenerator> opGen("op", "op");
708 static InterfaceGenRegistration<TypeInterfaceGenerator> typeGen("type", "type");
709