xref: /llvm-project/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp (revision d01f559ce9eedc7ab036b5a9e3d03fc861bc6a5d)
1 //===- OpInterfacesGen.cpp - MLIR op interface utility generator ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpInterfacesGen generates definitions for operation interfaces.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "DocGenUtilities.h"
14 #include "mlir/TableGen/Format.h"
15 #include "mlir/TableGen/GenInfo.h"
16 #include "mlir/TableGen/Interfaces.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/ADT/StringExtras.h"
19 #include "llvm/Support/FormatVariadic.h"
20 #include "llvm/Support/raw_ostream.h"
21 #include "llvm/TableGen/Error.h"
22 #include "llvm/TableGen/Record.h"
23 #include "llvm/TableGen/TableGenBackend.h"
24 
25 using namespace mlir;
26 using mlir::tblgen::Interface;
27 using mlir::tblgen::InterfaceMethod;
28 using mlir::tblgen::OpInterface;
29 
30 /// Emit a string corresponding to a C++ type, followed by a space if necessary.
31 static raw_ostream &emitCPPType(StringRef type, raw_ostream &os) {
32   type = type.trim();
33   os << type;
34   if (type.back() != '&' && type.back() != '*')
35     os << " ";
36   return os;
37 }
38 
39 /// Emit the method name and argument list for the given method. If 'addThisArg'
40 /// is true, then an argument is added to the beginning of the argument list for
41 /// the concrete value.
42 static void emitMethodNameAndArgs(const InterfaceMethod &method,
43                                   raw_ostream &os, StringRef valueType,
44                                   bool addThisArg, bool addConst) {
45   os << method.getName() << '(';
46   if (addThisArg) {
47     if (addConst)
48       os << "const ";
49     os << "const Concept *impl, ";
50     emitCPPType(valueType, os)
51         << "tablegen_opaque_val" << (method.arg_empty() ? "" : ", ");
52   }
53   llvm::interleaveComma(method.getArguments(), os,
54                         [&](const InterfaceMethod::Argument &arg) {
55                           os << arg.type << " " << arg.name;
56                         });
57   os << ')';
58   if (addConst)
59     os << " const";
60 }
61 
62 /// Get an array of all OpInterface definitions but exclude those subclassing
63 /// "DeclareOpInterfaceMethods".
64 static std::vector<llvm::Record *>
65 getAllInterfaceDefinitions(const llvm::RecordKeeper &recordKeeper,
66                            StringRef name) {
67   std::vector<llvm::Record *> defs =
68       recordKeeper.getAllDerivedDefinitions((name + "Interface").str());
69 
70   std::string declareName = ("Declare" + name + "InterfaceMethods").str();
71   llvm::erase_if(defs, [&](const llvm::Record *def) {
72     // Ignore any "declare methods" interfaces.
73     if (def->isSubClassOf(declareName))
74       return true;
75     // Ignore interfaces defined outside of the top-level file.
76     return llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) !=
77            llvm::SrcMgr.getMainFileID();
78   });
79   return defs;
80 }
81 
82 namespace {
83 /// This struct is the base generator used when processing tablegen interfaces.
84 class InterfaceGenerator {
85 public:
86   bool emitInterfaceDefs();
87   bool emitInterfaceDecls();
88   bool emitInterfaceDocs();
89 
90 protected:
91   InterfaceGenerator(std::vector<llvm::Record *> &&defs, raw_ostream &os)
92       : defs(std::move(defs)), os(os) {}
93 
94   void emitConceptDecl(const Interface &interface);
95   void emitModelDecl(const Interface &interface);
96   void emitModelMethodsDef(const Interface &interface);
97   void emitTraitDecl(const Interface &interface, StringRef interfaceName,
98                      StringRef interfaceTraitsName);
99   void emitInterfaceDecl(const Interface &interface);
100 
101   /// The set of interface records to emit.
102   std::vector<llvm::Record *> defs;
103   // The stream to emit to.
104   raw_ostream &os;
105   /// The C++ value type of the interface, e.g. Operation*.
106   StringRef valueType;
107   /// The C++ base interface type.
108   StringRef interfaceBaseType;
109   /// The name of the typename for the value template.
110   StringRef valueTemplate;
111   /// The name of the substituion variable for the value.
112   StringRef substVar;
113   /// The format context to use for methods.
114   tblgen::FmtContext nonStaticMethodFmt;
115   tblgen::FmtContext traitMethodFmt;
116   tblgen::FmtContext extraDeclsFmt;
117 };
118 
119 /// A specialized generator for attribute interfaces.
120 struct AttrInterfaceGenerator : public InterfaceGenerator {
121   AttrInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
122       : InterfaceGenerator(getAllInterfaceDefinitions(records, "Attr"), os) {
123     valueType = "::mlir::Attribute";
124     interfaceBaseType = "AttributeInterface";
125     valueTemplate = "ConcreteAttr";
126     substVar = "_attr";
127     StringRef castCode = "(::llvm::cast<ConcreteAttr>(tablegen_opaque_val))";
128     nonStaticMethodFmt.addSubst(substVar, castCode).withSelf(castCode);
129     traitMethodFmt.addSubst(substVar,
130                             "(*static_cast<const ConcreteAttr *>(this))");
131     extraDeclsFmt.addSubst(substVar, "(*this)");
132   }
133 };
134 /// A specialized generator for operation interfaces.
135 struct OpInterfaceGenerator : public InterfaceGenerator {
136   OpInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
137       : InterfaceGenerator(getAllInterfaceDefinitions(records, "Op"), os) {
138     valueType = "::mlir::Operation *";
139     interfaceBaseType = "OpInterface";
140     valueTemplate = "ConcreteOp";
141     substVar = "_op";
142     StringRef castCode = "(llvm::cast<ConcreteOp>(tablegen_opaque_val))";
143     nonStaticMethodFmt.addSubst("_this", "impl")
144         .addSubst(substVar, castCode)
145         .withSelf(castCode);
146     traitMethodFmt.addSubst(substVar, "(*static_cast<ConcreteOp *>(this))");
147     extraDeclsFmt.addSubst(substVar, "(*this)");
148   }
149 };
150 /// A specialized generator for type interfaces.
151 struct TypeInterfaceGenerator : public InterfaceGenerator {
152   TypeInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
153       : InterfaceGenerator(getAllInterfaceDefinitions(records, "Type"), os) {
154     valueType = "::mlir::Type";
155     interfaceBaseType = "TypeInterface";
156     valueTemplate = "ConcreteType";
157     substVar = "_type";
158     StringRef castCode = "(::llvm::cast<ConcreteType>(tablegen_opaque_val))";
159     nonStaticMethodFmt.addSubst(substVar, castCode).withSelf(castCode);
160     traitMethodFmt.addSubst(substVar,
161                             "(*static_cast<const ConcreteType *>(this))");
162     extraDeclsFmt.addSubst(substVar, "(*this)");
163   }
164 };
165 } // namespace
166 
167 //===----------------------------------------------------------------------===//
168 // GEN: Interface definitions
169 //===----------------------------------------------------------------------===//
170 
171 static void emitInterfaceMethodDoc(const InterfaceMethod &method,
172                                    raw_ostream &os, StringRef prefix = "") {
173   if (std::optional<StringRef> description = method.getDescription())
174     tblgen::emitDescriptionComment(*description, os, prefix);
175 }
176 static void emitInterfaceDefMethods(StringRef interfaceQualName,
177                                     const Interface &interface,
178                                     StringRef valueType, const Twine &implValue,
179                                     raw_ostream &os, bool isOpInterface) {
180   for (auto &method : interface.getMethods()) {
181     emitInterfaceMethodDoc(method, os);
182     emitCPPType(method.getReturnType(), os);
183     os << interfaceQualName << "::";
184     emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
185                           /*addConst=*/!isOpInterface);
186 
187     // Forward to the method on the concrete operation type.
188     os << " {\n      return " << implValue << "->" << method.getName() << '(';
189     if (!method.isStatic()) {
190       os << implValue << ", ";
191       os << (isOpInterface ? "getOperation()" : "*this");
192       os << (method.arg_empty() ? "" : ", ");
193     }
194     llvm::interleaveComma(
195         method.getArguments(), os,
196         [&](const InterfaceMethod::Argument &arg) { os << arg.name; });
197     os << ");\n  }\n";
198   }
199 }
200 
201 static void emitInterfaceDef(const Interface &interface, StringRef valueType,
202                              raw_ostream &os) {
203   std::string interfaceQualNameStr = interface.getFullyQualifiedName();
204   StringRef interfaceQualName = interfaceQualNameStr;
205   interfaceQualName.consume_front("::");
206 
207   // Insert the method definitions.
208   bool isOpInterface = isa<OpInterface>(interface);
209   emitInterfaceDefMethods(interfaceQualName, interface, valueType, "getImpl()",
210                           os, isOpInterface);
211 
212   // Insert the method definitions for base classes.
213   for (auto &base : interface.getBaseInterfaces()) {
214     emitInterfaceDefMethods(interfaceQualName, base, valueType,
215                             "getImpl()->impl" + base.getName(), os,
216                             isOpInterface);
217   }
218 }
219 
220 bool InterfaceGenerator::emitInterfaceDefs() {
221   llvm::emitSourceFileHeader("Interface Definitions", os);
222 
223   for (const auto *def : defs)
224     emitInterfaceDef(Interface(def), valueType, os);
225   return false;
226 }
227 
228 //===----------------------------------------------------------------------===//
229 // GEN: Interface declarations
230 //===----------------------------------------------------------------------===//
231 
232 void InterfaceGenerator::emitConceptDecl(const Interface &interface) {
233   os << "  struct Concept {\n";
234 
235   // Insert each of the pure virtual concept methods.
236   os << "    /// The methods defined by the interface.\n";
237   for (auto &method : interface.getMethods()) {
238     os << "    ";
239     emitCPPType(method.getReturnType(), os);
240     os << "(*" << method.getName() << ")(";
241     if (!method.isStatic()) {
242       os << "const Concept *impl, ";
243       emitCPPType(valueType, os) << (method.arg_empty() ? "" : ", ");
244     }
245     llvm::interleaveComma(
246         method.getArguments(), os,
247         [&](const InterfaceMethod::Argument &arg) { os << arg.type; });
248     os << ");\n";
249   }
250 
251   // Insert a field containing a concept for each of the base interfaces.
252   auto baseInterfaces = interface.getBaseInterfaces();
253   if (!baseInterfaces.empty()) {
254     os << "    /// The base classes of this interface.\n";
255     for (const auto &base : interface.getBaseInterfaces()) {
256       os << "    const " << base.getFullyQualifiedName() << "::Concept *impl"
257          << base.getName() << " = nullptr;\n";
258     }
259 
260     // Define an "initialize" method that allows for the initialization of the
261     // base class concepts.
262     os << "\n    void initializeInterfaceConcept(::mlir::detail::InterfaceMap "
263           "&interfaceMap) {\n";
264     std::string interfaceQualName = interface.getFullyQualifiedName();
265     for (const auto &base : interface.getBaseInterfaces()) {
266       StringRef baseName = base.getName();
267       std::string baseQualName = base.getFullyQualifiedName();
268       os << "      impl" << baseName << " = interfaceMap.lookup<"
269          << baseQualName << ">();\n"
270          << "      assert(impl" << baseName << " && \"`" << interfaceQualName
271          << "` expected its base interface `" << baseQualName
272          << "` to be registered\");\n";
273     }
274     os << "    }\n";
275   }
276 
277   os << "  };\n";
278 }
279 
280 void InterfaceGenerator::emitModelDecl(const Interface &interface) {
281   // Emit the basic model and the fallback model.
282   for (const char *modelClass : {"Model", "FallbackModel"}) {
283     os << "  template<typename " << valueTemplate << ">\n";
284     os << "  class " << modelClass << " : public Concept {\n  public:\n";
285     os << "    using Interface = " << interface.getFullyQualifiedName()
286        << ";\n";
287     os << "    " << modelClass << "() : Concept{";
288     llvm::interleaveComma(
289         interface.getMethods(), os,
290         [&](const InterfaceMethod &method) { os << method.getName(); });
291     os << "} {}\n\n";
292 
293     // Insert each of the virtual method overrides.
294     for (auto &method : interface.getMethods()) {
295       emitCPPType(method.getReturnType(), os << "    static inline ");
296       emitMethodNameAndArgs(method, os, valueType,
297                             /*addThisArg=*/!method.isStatic(),
298                             /*addConst=*/false);
299       os << ";\n";
300     }
301     os << "  };\n";
302   }
303 
304   // Emit the template for the external model.
305   os << "  template<typename ConcreteModel, typename " << valueTemplate
306      << ">\n";
307   os << "  class ExternalModel : public FallbackModel<ConcreteModel> {\n";
308   os << "  public:\n";
309   os << "    using ConcreteEntity = " << valueTemplate << ";\n";
310 
311   // Emit declarations for methods that have default implementations. Other
312   // methods are expected to be implemented by the concrete derived model.
313   for (auto &method : interface.getMethods()) {
314     if (!method.getDefaultImplementation())
315       continue;
316     os << "    ";
317     if (method.isStatic())
318       os << "static ";
319     emitCPPType(method.getReturnType(), os);
320     os << method.getName() << "(";
321     if (!method.isStatic()) {
322       emitCPPType(valueType, os);
323       os << "tablegen_opaque_val";
324       if (!method.arg_empty())
325         os << ", ";
326     }
327     llvm::interleaveComma(method.getArguments(), os,
328                           [&](const InterfaceMethod::Argument &arg) {
329                             emitCPPType(arg.type, os);
330                             os << arg.name;
331                           });
332     os << ")";
333     if (!method.isStatic())
334       os << " const";
335     os << ";\n";
336   }
337   os << "  };\n";
338 }
339 
340 void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
341   llvm::SmallVector<StringRef, 2> namespaces;
342   llvm::SplitString(interface.getCppNamespace(), namespaces, "::");
343   for (StringRef ns : namespaces)
344     os << "namespace " << ns << " {\n";
345 
346   for (auto &method : interface.getMethods()) {
347     os << "template<typename " << valueTemplate << ">\n";
348     emitCPPType(method.getReturnType(), os);
349     os << "detail::" << interface.getName() << "InterfaceTraits::Model<"
350        << valueTemplate << ">::";
351     emitMethodNameAndArgs(method, os, valueType,
352                           /*addThisArg=*/!method.isStatic(),
353                           /*addConst=*/false);
354     os << " {\n  ";
355 
356     // Check for a provided body to the function.
357     if (std::optional<StringRef> body = method.getBody()) {
358       if (method.isStatic())
359         os << body->trim();
360       else
361         os << tblgen::tgfmt(body->trim(), &nonStaticMethodFmt);
362       os << "\n}\n";
363       continue;
364     }
365 
366     // Forward to the method on the concrete operation type.
367     if (method.isStatic())
368       os << "return " << valueTemplate << "::";
369     else
370       os << tblgen::tgfmt("return $_self.", &nonStaticMethodFmt);
371 
372     // Add the arguments to the call.
373     os << method.getName() << '(';
374     llvm::interleaveComma(
375         method.getArguments(), os,
376         [&](const InterfaceMethod::Argument &arg) { os << arg.name; });
377     os << ");\n}\n";
378   }
379 
380   for (auto &method : interface.getMethods()) {
381     os << "template<typename " << valueTemplate << ">\n";
382     emitCPPType(method.getReturnType(), os);
383     os << "detail::" << interface.getName() << "InterfaceTraits::FallbackModel<"
384        << valueTemplate << ">::";
385     emitMethodNameAndArgs(method, os, valueType,
386                           /*addThisArg=*/!method.isStatic(),
387                           /*addConst=*/false);
388     os << " {\n  ";
389 
390     // Forward to the method on the concrete Model implementation.
391     if (method.isStatic())
392       os << "return " << valueTemplate << "::";
393     else
394       os << "return static_cast<const " << valueTemplate << " *>(impl)->";
395 
396     // Add the arguments to the call.
397     os << method.getName() << '(';
398     if (!method.isStatic())
399       os << "tablegen_opaque_val" << (method.arg_empty() ? "" : ", ");
400     llvm::interleaveComma(
401         method.getArguments(), os,
402         [&](const InterfaceMethod::Argument &arg) { os << arg.name; });
403     os << ");\n}\n";
404   }
405 
406   // Emit default implementations for the external model.
407   for (auto &method : interface.getMethods()) {
408     if (!method.getDefaultImplementation())
409       continue;
410     os << "template<typename ConcreteModel, typename " << valueTemplate
411        << ">\n";
412     emitCPPType(method.getReturnType(), os);
413     os << "detail::" << interface.getName()
414        << "InterfaceTraits::ExternalModel<ConcreteModel, " << valueTemplate
415        << ">::";
416 
417     os << method.getName() << "(";
418     if (!method.isStatic()) {
419       emitCPPType(valueType, os);
420       os << "tablegen_opaque_val";
421       if (!method.arg_empty())
422         os << ", ";
423     }
424     llvm::interleaveComma(method.getArguments(), os,
425                           [&](const InterfaceMethod::Argument &arg) {
426                             emitCPPType(arg.type, os);
427                             os << arg.name;
428                           });
429     os << ")";
430     if (!method.isStatic())
431       os << " const";
432 
433     os << " {\n";
434 
435     // Use the empty context for static methods.
436     tblgen::FmtContext ctx;
437     os << tblgen::tgfmt(method.getDefaultImplementation()->trim(),
438                         method.isStatic() ? &ctx : &nonStaticMethodFmt);
439     os << "\n}\n";
440   }
441 
442   for (StringRef ns : llvm::reverse(namespaces))
443     os << "} // namespace " << ns << "\n";
444 }
445 
446 void InterfaceGenerator::emitTraitDecl(const Interface &interface,
447                                        StringRef interfaceName,
448                                        StringRef interfaceTraitsName) {
449   os << llvm::formatv("  template <typename {3}>\n"
450                       "  struct {0}Trait : public ::mlir::{2}<{0},"
451                       " detail::{1}>::Trait<{3}> {{\n",
452                       interfaceName, interfaceTraitsName, interfaceBaseType,
453                       valueTemplate);
454 
455   // Insert the default implementation for any methods.
456   bool isOpInterface = isa<OpInterface>(interface);
457   for (auto &method : interface.getMethods()) {
458     // Flag interface methods named verifyTrait.
459     if (method.getName() == "verifyTrait")
460       PrintFatalError(
461           formatv("'verifyTrait' method cannot be specified as interface "
462                   "method for '{0}'; use the 'verify' field instead",
463                   interfaceName));
464     auto defaultImpl = method.getDefaultImplementation();
465     if (!defaultImpl)
466       continue;
467 
468     emitInterfaceMethodDoc(method, os, "    ");
469     os << "    " << (method.isStatic() ? "static " : "");
470     emitCPPType(method.getReturnType(), os);
471     emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
472                           /*addConst=*/!isOpInterface && !method.isStatic());
473     os << " {\n      " << tblgen::tgfmt(defaultImpl->trim(), &traitMethodFmt)
474        << "\n    }\n";
475   }
476 
477   if (auto verify = interface.getVerify()) {
478     assert(isa<OpInterface>(interface) && "only OpInterface supports 'verify'");
479 
480     tblgen::FmtContext verifyCtx;
481     verifyCtx.addSubst("_op", "op");
482     os << llvm::formatv(
483               "    static ::mlir::LogicalResult {0}(::mlir::Operation *op) ",
484               (interface.verifyWithRegions() ? "verifyRegionTrait"
485                                              : "verifyTrait"))
486        << "{\n      " << tblgen::tgfmt(verify->trim(), &verifyCtx)
487        << "\n    }\n";
488   }
489   if (auto extraTraitDecls = interface.getExtraTraitClassDeclaration())
490     os << tblgen::tgfmt(*extraTraitDecls, &traitMethodFmt) << "\n";
491   if (auto extraTraitDecls = interface.getExtraSharedClassDeclaration())
492     os << tblgen::tgfmt(*extraTraitDecls, &traitMethodFmt) << "\n";
493 
494   os << "  };\n";
495 }
496 
497 static void emitInterfaceDeclMethods(const Interface &interface,
498                                      raw_ostream &os, StringRef valueType,
499                                      bool isOpInterface,
500                                      tblgen::FmtContext &extraDeclsFmt) {
501   for (auto &method : interface.getMethods()) {
502     emitInterfaceMethodDoc(method, os, "  ");
503     emitCPPType(method.getReturnType(), os << "  ");
504     emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
505                           /*addConst=*/!isOpInterface);
506     os << ";\n";
507   }
508 
509   // Emit any extra declarations.
510   if (std::optional<StringRef> extraDecls =
511           interface.getExtraClassDeclaration())
512     os << extraDecls->rtrim() << "\n";
513   if (std::optional<StringRef> extraDecls =
514           interface.getExtraSharedClassDeclaration())
515     os << tblgen::tgfmt(extraDecls->rtrim(), &extraDeclsFmt) << "\n";
516 }
517 
518 void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
519   llvm::SmallVector<StringRef, 2> namespaces;
520   llvm::SplitString(interface.getCppNamespace(), namespaces, "::");
521   for (StringRef ns : namespaces)
522     os << "namespace " << ns << " {\n";
523 
524   StringRef interfaceName = interface.getName();
525   auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str();
526 
527   // Emit a forward declaration of the interface class so that it becomes usable
528   // in the signature of its methods.
529   os << "class " << interfaceName << ";\n";
530 
531   // Emit the traits struct containing the concept and model declarations.
532   os << "namespace detail {\n"
533      << "struct " << interfaceTraitsName << " {\n";
534   emitConceptDecl(interface);
535   emitModelDecl(interface);
536   os << "};";
537 
538   // Emit the derived trait for the interface.
539   os << "template <typename " << valueTemplate << ">\n";
540   os << "struct " << interface.getName() << "Trait;\n";
541 
542   os << "\n} // namespace detail\n";
543 
544   // Emit the main interface class declaration.
545   os << llvm::formatv("class {0} : public ::mlir::{3}<{1}, detail::{2}> {\n"
546                       "public:\n"
547                       "  using ::mlir::{3}<{1}, detail::{2}>::{3};\n",
548                       interfaceName, interfaceName, interfaceTraitsName,
549                       interfaceBaseType);
550 
551   // Emit a utility wrapper trait class.
552   os << llvm::formatv("  template <typename {1}>\n"
553                       "  struct Trait : public detail::{0}Trait<{1}> {{};\n",
554                       interfaceName, valueTemplate);
555 
556   // Insert the method declarations.
557   bool isOpInterface = isa<OpInterface>(interface);
558   emitInterfaceDeclMethods(interface, os, valueType, isOpInterface,
559                            extraDeclsFmt);
560 
561   // Insert the method declarations for base classes.
562   for (auto &base : interface.getBaseInterfaces()) {
563     std::string baseQualName = base.getFullyQualifiedName();
564     os << "  //"
565           "===---------------------------------------------------------------"
566           "-===//\n"
567        << "  // Inherited from " << baseQualName << "\n"
568        << "  //"
569           "===---------------------------------------------------------------"
570           "-===//\n\n";
571 
572     // Allow implicit conversion to the base interface.
573     os << "  operator " << baseQualName << " () const {\n"
574        << "    return " << baseQualName << "(*this, getImpl()->impl"
575        << base.getName() << ");\n"
576        << "  }\n\n";
577 
578     // Inherit the base interface's methods.
579     emitInterfaceDeclMethods(base, os, valueType, isOpInterface, extraDeclsFmt);
580   }
581 
582   // Emit classof code if necessary.
583   if (std::optional<StringRef> extraClassOf = interface.getExtraClassOf()) {
584     auto extraClassOfFmt = tblgen::FmtContext();
585     extraClassOfFmt.addSubst(substVar, "odsInterfaceInstance");
586     os << "  static bool classof(" << valueType << " base) {\n"
587        << "    auto* interface = getInterfaceFor(base);\n"
588        << "    if (!interface)\n"
589           "      return false;\n"
590           "    " << interfaceName << " odsInterfaceInstance(base, interface);\n"
591        << "    " << tblgen::tgfmt(extraClassOf->trim(), &extraClassOfFmt)
592        << "\n  }\n";
593   }
594 
595   os << "};\n";
596 
597   os << "namespace detail {\n";
598   emitTraitDecl(interface, interfaceName, interfaceTraitsName);
599   os << "}// namespace detail\n";
600 
601   for (StringRef ns : llvm::reverse(namespaces))
602     os << "} // namespace " << ns << "\n";
603 }
604 
605 bool InterfaceGenerator::emitInterfaceDecls() {
606   llvm::emitSourceFileHeader("Interface Declarations", os);
607   // Sort according to ID, so defs are emitted in the order in which they appear
608   // in the Tablegen file.
609   std::vector<llvm::Record *> sortedDefs(defs);
610   llvm::sort(sortedDefs, [](llvm::Record *lhs, llvm::Record *rhs) {
611     return lhs->getID() < rhs->getID();
612   });
613   for (const llvm::Record *def : sortedDefs)
614     emitInterfaceDecl(Interface(def));
615   for (const llvm::Record *def : sortedDefs)
616     emitModelMethodsDef(Interface(def));
617   return false;
618 }
619 
620 //===----------------------------------------------------------------------===//
621 // GEN: Interface documentation
622 //===----------------------------------------------------------------------===//
623 
624 static void emitInterfaceDoc(const llvm::Record &interfaceDef,
625                              raw_ostream &os) {
626   Interface interface(&interfaceDef);
627 
628   // Emit the interface name followed by the description.
629   os << "## " << interface.getName() << " (`" << interfaceDef.getName()
630      << "`)\n\n";
631   if (auto description = interface.getDescription())
632     mlir::tblgen::emitDescription(*description, os);
633 
634   // Emit the methods required by the interface.
635   os << "\n### Methods:\n";
636   for (const auto &method : interface.getMethods()) {
637     // Emit the method name.
638     os << "#### `" << method.getName() << "`\n\n```c++\n";
639 
640     // Emit the method signature.
641     if (method.isStatic())
642       os << "static ";
643     emitCPPType(method.getReturnType(), os) << method.getName() << '(';
644     llvm::interleaveComma(method.getArguments(), os,
645                           [&](const InterfaceMethod::Argument &arg) {
646                             emitCPPType(arg.type, os) << arg.name;
647                           });
648     os << ");\n```\n";
649 
650     // Emit the description.
651     if (auto description = method.getDescription())
652       mlir::tblgen::emitDescription(*description, os);
653 
654     // If the body is not provided, this method must be provided by the user.
655     if (!method.getBody())
656       os << "\nNOTE: This method *must* be implemented by the user.";
657 
658     os << "\n\n";
659   }
660 }
661 
662 bool InterfaceGenerator::emitInterfaceDocs() {
663   os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
664   os << "# " << interfaceBaseType << " definitions\n";
665 
666   for (const auto *def : defs)
667     emitInterfaceDoc(*def, os);
668   return false;
669 }
670 
671 //===----------------------------------------------------------------------===//
672 // GEN: Interface registration hooks
673 //===----------------------------------------------------------------------===//
674 
675 namespace {
676 template <typename GeneratorT>
677 struct InterfaceGenRegistration {
678   InterfaceGenRegistration(StringRef genArg, StringRef genDesc)
679       : genDeclArg(("gen-" + genArg + "-interface-decls").str()),
680         genDefArg(("gen-" + genArg + "-interface-defs").str()),
681         genDocArg(("gen-" + genArg + "-interface-docs").str()),
682         genDeclDesc(("Generate " + genDesc + " interface declarations").str()),
683         genDefDesc(("Generate " + genDesc + " interface definitions").str()),
684         genDocDesc(("Generate " + genDesc + " interface documentation").str()),
685         genDecls(genDeclArg, genDeclDesc,
686                  [](const llvm::RecordKeeper &records, raw_ostream &os) {
687                    return GeneratorT(records, os).emitInterfaceDecls();
688                  }),
689         genDefs(genDefArg, genDefDesc,
690                 [](const llvm::RecordKeeper &records, raw_ostream &os) {
691                   return GeneratorT(records, os).emitInterfaceDefs();
692                 }),
693         genDocs(genDocArg, genDocDesc,
694                 [](const llvm::RecordKeeper &records, raw_ostream &os) {
695                   return GeneratorT(records, os).emitInterfaceDocs();
696                 }) {}
697 
698   std::string genDeclArg, genDefArg, genDocArg;
699   std::string genDeclDesc, genDefDesc, genDocDesc;
700   mlir::GenRegistration genDecls, genDefs, genDocs;
701 };
702 } // namespace
703 
704 static InterfaceGenRegistration<AttrInterfaceGenerator> attrGen("attr",
705                                                                 "attribute");
706 static InterfaceGenRegistration<OpInterfaceGenerator> opGen("op", "op");
707 static InterfaceGenRegistration<TypeInterfaceGenerator> typeGen("type", "type");
708