xref: /llvm-project/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp (revision a60e83fe7cebbfb9f0a6a3ad7504c034f5d0e40f)
1 //===- OpInterfacesGen.cpp - MLIR op interface utility generator ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpInterfacesGen generates definitions for operation interfaces.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "DocGenUtilities.h"
14 #include "mlir/TableGen/Format.h"
15 #include "mlir/TableGen/GenInfo.h"
16 #include "mlir/TableGen/Interfaces.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/ADT/StringExtras.h"
19 #include "llvm/Support/FormatVariadic.h"
20 #include "llvm/Support/raw_ostream.h"
21 #include "llvm/TableGen/Error.h"
22 #include "llvm/TableGen/Record.h"
23 #include "llvm/TableGen/TableGenBackend.h"
24 
25 using namespace mlir;
26 using mlir::tblgen::Interface;
27 using mlir::tblgen::InterfaceMethod;
28 using mlir::tblgen::OpInterface;
29 
30 /// Emit a string corresponding to a C++ type, followed by a space if necessary.
31 static raw_ostream &emitCPPType(StringRef type, raw_ostream &os) {
32   type = type.trim();
33   os << type;
34   if (type.back() != '&' && type.back() != '*')
35     os << " ";
36   return os;
37 }
38 
39 /// Emit the method name and argument list for the given method. If 'addThisArg'
40 /// is true, then an argument is added to the beginning of the argument list for
41 /// the concrete value.
42 static void emitMethodNameAndArgs(const InterfaceMethod &method,
43                                   raw_ostream &os, StringRef valueType,
44                                   bool addThisArg, bool addConst) {
45   os << method.getName() << '(';
46   if (addThisArg) {
47     if (addConst)
48       os << "const ";
49     os << "const Concept *impl, ";
50     emitCPPType(valueType, os)
51         << "tablegen_opaque_val" << (method.arg_empty() ? "" : ", ");
52   }
53   llvm::interleaveComma(method.getArguments(), os,
54                         [&](const InterfaceMethod::Argument &arg) {
55                           os << arg.type << " " << arg.name;
56                         });
57   os << ')';
58   if (addConst)
59     os << " const";
60 }
61 
62 /// Get an array of all OpInterface definitions but exclude those subclassing
63 /// "DeclareOpInterfaceMethods".
64 static std::vector<llvm::Record *>
65 getAllOpInterfaceDefinitions(const llvm::RecordKeeper &recordKeeper) {
66   std::vector<llvm::Record *> defs =
67       recordKeeper.getAllDerivedDefinitions("OpInterface");
68 
69   llvm::erase_if(defs, [](const llvm::Record *def) {
70     return def->isSubClassOf("DeclareOpInterfaceMethods");
71   });
72   return defs;
73 }
74 
75 namespace {
76 /// This struct is the base generator used when processing tablegen interfaces.
77 class InterfaceGenerator {
78 public:
79   bool emitInterfaceDefs();
80   bool emitInterfaceDecls();
81   bool emitInterfaceDocs();
82 
83 protected:
84   InterfaceGenerator(std::vector<llvm::Record *> &&defs, raw_ostream &os)
85       : defs(std::move(defs)), os(os) {}
86 
87   void emitConceptDecl(Interface &interface);
88   void emitModelDecl(Interface &interface);
89   void emitModelMethodsDef(Interface &interface);
90   void emitTraitDecl(Interface &interface, StringRef interfaceName,
91                      StringRef interfaceTraitsName);
92   void emitInterfaceDecl(Interface interface);
93 
94   /// The set of interface records to emit.
95   std::vector<llvm::Record *> defs;
96   // The stream to emit to.
97   raw_ostream &os;
98   /// The C++ value type of the interface, e.g. Operation*.
99   StringRef valueType;
100   /// The C++ base interface type.
101   StringRef interfaceBaseType;
102   /// The name of the typename for the value template.
103   StringRef valueTemplate;
104   /// The format context to use for methods.
105   tblgen::FmtContext nonStaticMethodFmt;
106   tblgen::FmtContext traitMethodFmt;
107   tblgen::FmtContext extraDeclsFmt;
108 };
109 
110 /// A specialized generator for attribute interfaces.
111 struct AttrInterfaceGenerator : public InterfaceGenerator {
112   AttrInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
113       : InterfaceGenerator(records.getAllDerivedDefinitions("AttrInterface"),
114                            os) {
115     valueType = "::mlir::Attribute";
116     interfaceBaseType = "AttributeInterface";
117     valueTemplate = "ConcreteAttr";
118     StringRef castCode = "(tablegen_opaque_val.cast<ConcreteAttr>())";
119     nonStaticMethodFmt.addSubst("_attr", castCode).withSelf(castCode);
120     traitMethodFmt.addSubst("_attr",
121                             "(*static_cast<const ConcreteAttr *>(this))");
122     extraDeclsFmt.addSubst("_attr", "(*this)");
123   }
124 };
125 /// A specialized generator for operation interfaces.
126 struct OpInterfaceGenerator : public InterfaceGenerator {
127   OpInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
128       : InterfaceGenerator(getAllOpInterfaceDefinitions(records), os) {
129     valueType = "::mlir::Operation *";
130     interfaceBaseType = "OpInterface";
131     valueTemplate = "ConcreteOp";
132     StringRef castCode = "(llvm::cast<ConcreteOp>(tablegen_opaque_val))";
133     nonStaticMethodFmt.addSubst("_this", "impl")
134         .withOp(castCode)
135         .withSelf(castCode);
136     traitMethodFmt.withOp("(*static_cast<ConcreteOp *>(this))");
137     extraDeclsFmt.withOp("(*this)");
138   }
139 };
140 /// A specialized generator for type interfaces.
141 struct TypeInterfaceGenerator : public InterfaceGenerator {
142   TypeInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
143       : InterfaceGenerator(records.getAllDerivedDefinitions("TypeInterface"),
144                            os) {
145     valueType = "::mlir::Type";
146     interfaceBaseType = "TypeInterface";
147     valueTemplate = "ConcreteType";
148     StringRef castCode = "(tablegen_opaque_val.cast<ConcreteType>())";
149     nonStaticMethodFmt.addSubst("_type", castCode).withSelf(castCode);
150     traitMethodFmt.addSubst("_type",
151                             "(*static_cast<const ConcreteType *>(this))");
152     extraDeclsFmt.addSubst("_type", "(*this)");
153   }
154 };
155 } // namespace
156 
157 //===----------------------------------------------------------------------===//
158 // GEN: Interface definitions
159 //===----------------------------------------------------------------------===//
160 
161 static void emitInterfaceDef(const Interface &interface, StringRef valueType,
162                              raw_ostream &os) {
163   StringRef interfaceName = interface.getName();
164   StringRef cppNamespace = interface.getCppNamespace();
165   cppNamespace.consume_front("::");
166 
167   // Insert the method definitions.
168   bool isOpInterface = isa<OpInterface>(interface);
169   for (auto &method : interface.getMethods()) {
170     emitCPPType(method.getReturnType(), os);
171     if (!cppNamespace.empty())
172       os << cppNamespace << "::";
173     os << interfaceName << "::";
174     emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
175                           /*addConst=*/!isOpInterface);
176 
177     // Forward to the method on the concrete operation type.
178     os << " {\n      return getImpl()->" << method.getName() << '(';
179     if (!method.isStatic()) {
180       os << "getImpl(), ";
181       os << (isOpInterface ? "getOperation()" : "*this");
182       os << (method.arg_empty() ? "" : ", ");
183     }
184     llvm::interleaveComma(
185         method.getArguments(), os,
186         [&](const InterfaceMethod::Argument &arg) { os << arg.name; });
187     os << ");\n  }\n";
188   }
189 }
190 
191 bool InterfaceGenerator::emitInterfaceDefs() {
192   llvm::emitSourceFileHeader("Interface Definitions", os);
193 
194   for (const auto *def : defs)
195     emitInterfaceDef(Interface(def), valueType, os);
196   return false;
197 }
198 
199 //===----------------------------------------------------------------------===//
200 // GEN: Interface declarations
201 //===----------------------------------------------------------------------===//
202 
203 void InterfaceGenerator::emitConceptDecl(Interface &interface) {
204   os << "  struct Concept {\n";
205 
206   // Insert each of the pure virtual concept methods.
207   for (auto &method : interface.getMethods()) {
208     os << "    ";
209     emitCPPType(method.getReturnType(), os);
210     os << "(*" << method.getName() << ")(";
211     if (!method.isStatic()) {
212       os << "const Concept *impl, ";
213       emitCPPType(valueType, os) << (method.arg_empty() ? "" : ", ");
214     }
215     llvm::interleaveComma(
216         method.getArguments(), os,
217         [&](const InterfaceMethod::Argument &arg) { os << arg.type; });
218     os << ");\n";
219   }
220   os << "  };\n";
221 }
222 
223 void InterfaceGenerator::emitModelDecl(Interface &interface) {
224   // Emit the basic model and the fallback model.
225   for (const char *modelClass : {"Model", "FallbackModel"}) {
226     os << "  template<typename " << valueTemplate << ">\n";
227     os << "  class " << modelClass << " : public Concept {\n  public:\n";
228     os << "    using Interface = " << interface.getCppNamespace()
229        << (interface.getCppNamespace().empty() ? "" : "::")
230        << interface.getName() << ";\n";
231     os << "    " << modelClass << "() : Concept{";
232     llvm::interleaveComma(
233         interface.getMethods(), os,
234         [&](const InterfaceMethod &method) { os << method.getName(); });
235     os << "} {}\n\n";
236 
237     // Insert each of the virtual method overrides.
238     for (auto &method : interface.getMethods()) {
239       emitCPPType(method.getReturnType(), os << "    static inline ");
240       emitMethodNameAndArgs(method, os, valueType,
241                             /*addThisArg=*/!method.isStatic(),
242                             /*addConst=*/false);
243       os << ";\n";
244     }
245     os << "  };\n";
246   }
247 
248   // Emit the template for the external model.
249   os << "  template<typename ConcreteModel, typename " << valueTemplate
250      << ">\n";
251   os << "  class ExternalModel : public FallbackModel<ConcreteModel> {\n";
252   os << "  public:\n";
253 
254   // Emit declarations for methods that have default implementations. Other
255   // methods are expected to be implemented by the concrete derived model.
256   for (auto &method : interface.getMethods()) {
257     if (!method.getDefaultImplementation())
258       continue;
259     os << "    ";
260     if (method.isStatic())
261       os << "static ";
262     emitCPPType(method.getReturnType(), os);
263     os << method.getName() << "(";
264     if (!method.isStatic()) {
265       emitCPPType(valueType, os);
266       os << "tablegen_opaque_val";
267       if (!method.arg_empty())
268         os << ", ";
269     }
270     llvm::interleaveComma(method.getArguments(), os,
271                           [&](const InterfaceMethod::Argument &arg) {
272                             emitCPPType(arg.type, os);
273                             os << arg.name;
274                           });
275     os << ")";
276     if (!method.isStatic())
277       os << " const";
278     os << ";\n";
279   }
280   os << "  };\n";
281 }
282 
283 void InterfaceGenerator::emitModelMethodsDef(Interface &interface) {
284   for (auto &method : interface.getMethods()) {
285     os << "template<typename " << valueTemplate << ">\n";
286     emitCPPType(method.getReturnType(), os);
287     os << "detail::" << interface.getName() << "InterfaceTraits::Model<"
288        << valueTemplate << ">::";
289     emitMethodNameAndArgs(method, os, valueType,
290                           /*addThisArg=*/!method.isStatic(),
291                           /*addConst=*/false);
292     os << " {\n  ";
293 
294     // Check for a provided body to the function.
295     if (Optional<StringRef> body = method.getBody()) {
296       if (method.isStatic())
297         os << body->trim();
298       else
299         os << tblgen::tgfmt(body->trim(), &nonStaticMethodFmt);
300       os << "\n}\n";
301       continue;
302     }
303 
304     // Forward to the method on the concrete operation type.
305     if (method.isStatic())
306       os << "return " << valueTemplate << "::";
307     else
308       os << tblgen::tgfmt("return $_self.", &nonStaticMethodFmt);
309 
310     // Add the arguments to the call.
311     os << method.getName() << '(';
312     llvm::interleaveComma(
313         method.getArguments(), os,
314         [&](const InterfaceMethod::Argument &arg) { os << arg.name; });
315     os << ");\n}\n";
316   }
317 
318   for (auto &method : interface.getMethods()) {
319     os << "template<typename " << valueTemplate << ">\n";
320     emitCPPType(method.getReturnType(), os);
321     os << "detail::" << interface.getName() << "InterfaceTraits::FallbackModel<"
322        << valueTemplate << ">::";
323     emitMethodNameAndArgs(method, os, valueType,
324                           /*addThisArg=*/!method.isStatic(),
325                           /*addConst=*/false);
326     os << " {\n  ";
327 
328     // Forward to the method on the concrete Model implementation.
329     if (method.isStatic())
330       os << "return " << valueTemplate << "::";
331     else
332       os << "return static_cast<const " << valueTemplate << " *>(impl)->";
333 
334     // Add the arguments to the call.
335     os << method.getName() << '(';
336     if (!method.isStatic())
337       os << "tablegen_opaque_val" << (method.arg_empty() ? "" : ", ");
338     llvm::interleaveComma(
339         method.getArguments(), os,
340         [&](const InterfaceMethod::Argument &arg) { os << arg.name; });
341     os << ");\n}\n";
342   }
343 
344   // Emit default implementations for the external model.
345   for (auto &method : interface.getMethods()) {
346     if (!method.getDefaultImplementation())
347       continue;
348     os << "template<typename ConcreteModel, typename " << valueTemplate
349        << ">\n";
350     emitCPPType(method.getReturnType(), os);
351     os << "detail::" << interface.getName()
352        << "InterfaceTraits::ExternalModel<ConcreteModel, " << valueTemplate
353        << ">::";
354 
355     os << method.getName() << "(";
356     if (!method.isStatic()) {
357       emitCPPType(valueType, os);
358       os << "tablegen_opaque_val";
359       if (!method.arg_empty())
360         os << ", ";
361     }
362     llvm::interleaveComma(method.getArguments(), os,
363                           [&](const InterfaceMethod::Argument &arg) {
364                             emitCPPType(arg.type, os);
365                             os << arg.name;
366                           });
367     os << ")";
368     if (!method.isStatic())
369       os << " const";
370 
371     os << " {\n";
372 
373     // Use the empty context for static methods.
374     tblgen::FmtContext ctx;
375     os << tblgen::tgfmt(method.getDefaultImplementation()->trim(),
376                         method.isStatic() ? &ctx : &nonStaticMethodFmt);
377     os << "\n}\n";
378   }
379 }
380 
381 void InterfaceGenerator::emitTraitDecl(Interface &interface,
382                                        StringRef interfaceName,
383                                        StringRef interfaceTraitsName) {
384   os << llvm::formatv("  template <typename {3}>\n"
385                       "  struct {0}Trait : public ::mlir::{2}<{0},"
386                       " detail::{1}>::Trait<{3}> {{\n",
387                       interfaceName, interfaceTraitsName, interfaceBaseType,
388                       valueTemplate);
389 
390   // Insert the default implementation for any methods.
391   bool isOpInterface = isa<OpInterface>(interface);
392   for (auto &method : interface.getMethods()) {
393     // Flag interface methods named verifyTrait.
394     if (method.getName() == "verifyTrait")
395       PrintFatalError(
396           formatv("'verifyTrait' method cannot be specified as interface "
397                   "method for '{0}'; use the 'verify' field instead",
398                   interfaceName));
399     auto defaultImpl = method.getDefaultImplementation();
400     if (!defaultImpl)
401       continue;
402 
403     os << "    " << (method.isStatic() ? "static " : "");
404     emitCPPType(method.getReturnType(), os);
405     emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
406                           /*addConst=*/!isOpInterface && !method.isStatic());
407     os << " {\n      " << tblgen::tgfmt(defaultImpl->trim(), &traitMethodFmt)
408        << "\n    }\n";
409   }
410 
411   if (auto verify = interface.getVerify()) {
412     assert(isa<OpInterface>(interface) && "only OpInterface supports 'verify'");
413 
414     tblgen::FmtContext verifyCtx;
415     verifyCtx.withOp("op");
416     os << "    static ::mlir::LogicalResult verifyTrait(::mlir::Operation *op) "
417           "{\n      "
418        << tblgen::tgfmt(verify->trim(), &verifyCtx) << "\n    }\n";
419   }
420   if (auto extraTraitDecls = interface.getExtraTraitClassDeclaration())
421     os << tblgen::tgfmt(*extraTraitDecls, &traitMethodFmt) << "\n";
422   if (auto extraTraitDecls = interface.getExtraSharedClassDeclaration())
423     os << tblgen::tgfmt(*extraTraitDecls, &traitMethodFmt) << "\n";
424 
425   os << "  };\n";
426 }
427 
428 void InterfaceGenerator::emitInterfaceDecl(Interface interface) {
429   llvm::SmallVector<StringRef, 2> namespaces;
430   llvm::SplitString(interface.getCppNamespace(), namespaces, "::");
431   for (StringRef ns : namespaces)
432     os << "namespace " << ns << " {\n";
433 
434   StringRef interfaceName = interface.getName();
435   auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str();
436 
437   // Emit a forward declaration of the interface class so that it becomes usable
438   // in the signature of its methods.
439   os << "class " << interfaceName << ";\n";
440 
441   // Emit the traits struct containing the concept and model declarations.
442   os << "namespace detail {\n"
443      << "struct " << interfaceTraitsName << " {\n";
444   emitConceptDecl(interface);
445   emitModelDecl(interface);
446   os << "};";
447 
448   // Emit the derived trait for the interface.
449   os << "template <typename " << valueTemplate << ">\n";
450   os << "struct " << interface.getName() << "Trait;\n";
451 
452   os << "\n} // namespace detail\n";
453 
454   // Emit the main interface class declaration.
455   os << llvm::formatv("class {0} : public ::mlir::{3}<{1}, detail::{2}> {\n"
456                       "public:\n"
457                       "  using ::mlir::{3}<{1}, detail::{2}>::{3};\n",
458                       interfaceName, interfaceName, interfaceTraitsName,
459                       interfaceBaseType);
460 
461   // Emit a utility wrapper trait class.
462   os << llvm::formatv("  template <typename {1}>\n"
463                       "  struct Trait : public detail::{0}Trait<{1}> {{};\n",
464                       interfaceName, valueTemplate);
465 
466   // Insert the method declarations.
467   bool isOpInterface = isa<OpInterface>(interface);
468   for (auto &method : interface.getMethods()) {
469     emitCPPType(method.getReturnType(), os << "  ");
470     emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
471                           /*addConst=*/!isOpInterface);
472     os << ";\n";
473   }
474 
475   // Emit any extra declarations.
476   if (Optional<StringRef> extraDecls = interface.getExtraClassDeclaration())
477     os << *extraDecls << "\n";
478   if (Optional<StringRef> extraDecls =
479           interface.getExtraSharedClassDeclaration())
480     os << tblgen::tgfmt(*extraDecls, &extraDeclsFmt);
481 
482   os << "};\n";
483 
484   os << "namespace detail {\n";
485   emitTraitDecl(interface, interfaceName, interfaceTraitsName);
486   os << "}// namespace detail\n";
487 
488   emitModelMethodsDef(interface);
489 
490   for (StringRef ns : llvm::reverse(namespaces))
491     os << "} // namespace " << ns << "\n";
492 }
493 
494 bool InterfaceGenerator::emitInterfaceDecls() {
495   llvm::emitSourceFileHeader("Interface Declarations", os);
496 
497   for (const auto *def : defs)
498     emitInterfaceDecl(Interface(def));
499   return false;
500 }
501 
502 //===----------------------------------------------------------------------===//
503 // GEN: Interface documentation
504 //===----------------------------------------------------------------------===//
505 
506 static void emitInterfaceDoc(const llvm::Record &interfaceDef,
507                              raw_ostream &os) {
508   Interface interface(&interfaceDef);
509 
510   // Emit the interface name followed by the description.
511   os << "## " << interface.getName() << " (`" << interfaceDef.getName()
512      << "`)\n\n";
513   if (auto description = interface.getDescription())
514     mlir::tblgen::emitDescription(*description, os);
515 
516   // Emit the methods required by the interface.
517   os << "\n### Methods:\n";
518   for (const auto &method : interface.getMethods()) {
519     // Emit the method name.
520     os << "#### `" << method.getName() << "`\n\n```c++\n";
521 
522     // Emit the method signature.
523     if (method.isStatic())
524       os << "static ";
525     emitCPPType(method.getReturnType(), os) << method.getName() << '(';
526     llvm::interleaveComma(method.getArguments(), os,
527                           [&](const InterfaceMethod::Argument &arg) {
528                             emitCPPType(arg.type, os) << arg.name;
529                           });
530     os << ");\n```\n";
531 
532     // Emit the description.
533     if (auto description = method.getDescription())
534       mlir::tblgen::emitDescription(*description, os);
535 
536     // If the body is not provided, this method must be provided by the user.
537     if (!method.getBody())
538       os << "\nNOTE: This method *must* be implemented by the user.\n\n";
539   }
540 }
541 
542 bool InterfaceGenerator::emitInterfaceDocs() {
543   os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
544   os << "# " << interfaceBaseType << " definitions\n";
545 
546   for (const auto *def : defs)
547     emitInterfaceDoc(*def, os);
548   return false;
549 }
550 
551 //===----------------------------------------------------------------------===//
552 // GEN: Interface registration hooks
553 //===----------------------------------------------------------------------===//
554 
555 namespace {
556 template <typename GeneratorT>
557 struct InterfaceGenRegistration {
558   InterfaceGenRegistration(StringRef genArg, StringRef genDesc)
559       : genDeclArg(("gen-" + genArg + "-interface-decls").str()),
560         genDefArg(("gen-" + genArg + "-interface-defs").str()),
561         genDocArg(("gen-" + genArg + "-interface-docs").str()),
562         genDeclDesc(("Generate " + genDesc + " interface declarations").str()),
563         genDefDesc(("Generate " + genDesc + " interface definitions").str()),
564         genDocDesc(("Generate " + genDesc + " interface documentation").str()),
565         genDecls(genDeclArg, genDeclDesc,
566                  [](const llvm::RecordKeeper &records, raw_ostream &os) {
567                    return GeneratorT(records, os).emitInterfaceDecls();
568                  }),
569         genDefs(genDefArg, genDefDesc,
570                 [](const llvm::RecordKeeper &records, raw_ostream &os) {
571                   return GeneratorT(records, os).emitInterfaceDefs();
572                 }),
573         genDocs(genDocArg, genDocDesc,
574                 [](const llvm::RecordKeeper &records, raw_ostream &os) {
575                   return GeneratorT(records, os).emitInterfaceDocs();
576                 }) {}
577 
578   std::string genDeclArg, genDefArg, genDocArg;
579   std::string genDeclDesc, genDefDesc, genDocDesc;
580   mlir::GenRegistration genDecls, genDefs, genDocs;
581 };
582 } // namespace
583 
584 static InterfaceGenRegistration<AttrInterfaceGenerator> attrGen("attr",
585                                                                 "attribute");
586 static InterfaceGenRegistration<OpInterfaceGenerator> opGen("op", "op");
587 static InterfaceGenRegistration<TypeInterfaceGenerator> typeGen("type", "type");
588