xref: /llvm-project/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp (revision bccd37f69fdc7b5cd00d9231cabbe74bfe38f598)
1 //===- AttrOrTypeDefGen.cpp - MLIR AttrOrType definitions generator -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "AttrOrTypeFormatGen.h"
10 #include "mlir/TableGen/AttrOrTypeDef.h"
11 #include "mlir/TableGen/Class.h"
12 #include "mlir/TableGen/CodeGenHelpers.h"
13 #include "mlir/TableGen/Format.h"
14 #include "mlir/TableGen/GenInfo.h"
15 #include "mlir/TableGen/Interfaces.h"
16 #include "llvm/ADT/StringSet.h"
17 #include "llvm/Support/CommandLine.h"
18 #include "llvm/TableGen/Error.h"
19 #include "llvm/TableGen/TableGenBackend.h"
20 
21 #define DEBUG_TYPE "mlir-tblgen-attrortypedefgen"
22 
23 using namespace mlir;
24 using namespace mlir::tblgen;
25 using llvm::Record;
26 using llvm::RecordKeeper;
27 
28 //===----------------------------------------------------------------------===//
29 // Utility Functions
30 //===----------------------------------------------------------------------===//
31 
32 /// Find all the AttrOrTypeDef for the specified dialect. If no dialect
33 /// specified and can only find one dialect's defs, use that.
34 static void collectAllDefs(StringRef selectedDialect,
35                            ArrayRef<const Record *> records,
36                            SmallVectorImpl<AttrOrTypeDef> &resultDefs) {
37   // Nothing to do if no defs were found.
38   if (records.empty())
39     return;
40 
41   auto defs = llvm::map_range(
42       records, [&](const Record *rec) { return AttrOrTypeDef(rec); });
43   if (selectedDialect.empty()) {
44     // If a dialect was not specified, ensure that all found defs belong to the
45     // same dialect.
46     if (!llvm::all_equal(llvm::map_range(
47             defs, [](const auto &def) { return def.getDialect(); }))) {
48       llvm::PrintFatalError("defs belonging to more than one dialect. Must "
49                             "select one via '--(attr|type)defs-dialect'");
50     }
51     resultDefs.assign(defs.begin(), defs.end());
52   } else {
53     // Otherwise, generate the defs that belong to the selected dialect.
54     auto dialectDefs = llvm::make_filter_range(defs, [&](const auto &def) {
55       return def.getDialect().getName() == selectedDialect;
56     });
57     resultDefs.assign(dialectDefs.begin(), dialectDefs.end());
58   }
59 }
60 
61 //===----------------------------------------------------------------------===//
62 // DefGen
63 //===----------------------------------------------------------------------===//
64 
65 namespace {
66 class DefGen {
67 public:
68   /// Create the attribute or type class.
69   DefGen(const AttrOrTypeDef &def);
70 
71   void emitDecl(raw_ostream &os) const {
72     if (storageCls && def.genStorageClass()) {
73       NamespaceEmitter ns(os, def.getStorageNamespace());
74       os << "struct " << def.getStorageClassName() << ";\n";
75     }
76     defCls.writeDeclTo(os);
77   }
78   void emitDef(raw_ostream &os) const {
79     if (storageCls && def.genStorageClass()) {
80       NamespaceEmitter ns(os, def.getStorageNamespace());
81       storageCls->writeDeclTo(os); // everything is inline
82     }
83     defCls.writeDefTo(os);
84   }
85 
86 private:
87   /// Add traits from the TableGen definition to the class.
88   void createParentWithTraits();
89   /// Emit top-level declarations: using declarations and any extra class
90   /// declarations.
91   void emitTopLevelDeclarations();
92   /// Emit the function that returns the type or attribute name.
93   void emitName();
94   /// Emit the dialect name as a static member variable.
95   void emitDialectName();
96   /// Emit attribute or type builders.
97   void emitBuilders();
98   /// Emit a verifier declaration for custom verification (impl. provided by
99   /// the users).
100   void emitVerifierDecl();
101   /// Emit a verifier that checks type constraints.
102   void emitInvariantsVerifierImpl();
103   /// Emit an entry poiunt for verification that calls the invariants and
104   /// custom verifier.
105   void emitInvariantsVerifier(bool hasImpl, bool hasCustomVerifier);
106   /// Emit parsers and printers.
107   void emitParserPrinter();
108   /// Emit parameter accessors, if required.
109   void emitAccessors();
110   /// Emit interface methods.
111   void emitInterfaceMethods();
112 
113   //===--------------------------------------------------------------------===//
114   // Builder Emission
115 
116   /// Emit the default builder `Attribute::get`
117   void emitDefaultBuilder();
118   /// Emit the checked builder `Attribute::getChecked`
119   void emitCheckedBuilder();
120   /// Emit a custom builder.
121   void emitCustomBuilder(const AttrOrTypeBuilder &builder);
122   /// Emit a checked custom builder.
123   void emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder);
124 
125   //===--------------------------------------------------------------------===//
126   // Interface Method Emission
127 
128   /// Emit methods for a trait.
129   void emitTraitMethods(const InterfaceTrait &trait);
130   /// Emit a trait method.
131   void emitTraitMethod(const InterfaceMethod &method);
132 
133   //===--------------------------------------------------------------------===//
134   // Storage Class Emission
135   void emitStorageClass();
136   /// Generate the storage class constructor.
137   void emitStorageConstructor();
138   /// Emit the key type `KeyTy`.
139   void emitKeyType();
140   /// Emit the equality comparison operator.
141   void emitEquals();
142   /// Emit the key hash function.
143   void emitHashKey();
144   /// Emit the function to construct the storage class.
145   void emitConstruct();
146 
147   //===--------------------------------------------------------------------===//
148   // Utility Function Declarations
149 
150   /// Get the method parameters for a def builder, where the first several
151   /// parameters may be different.
152   SmallVector<MethodParameter>
153   getBuilderParams(std::initializer_list<MethodParameter> prefix) const;
154 
155   //===--------------------------------------------------------------------===//
156   // Class fields
157 
158   /// The attribute or type definition.
159   const AttrOrTypeDef &def;
160   /// The list of attribute or type parameters.
161   ArrayRef<AttrOrTypeParameter> params;
162   /// The attribute or type class.
163   Class defCls;
164   /// An optional attribute or type storage class. The storage class will
165   /// exist if and only if the def has more than zero parameters.
166   std::optional<Class> storageCls;
167 
168   /// The C++ base value of the def, either "Attribute" or "Type".
169   StringRef valueType;
170   /// The prefix/suffix of the TableGen def name, either "Attr" or "Type".
171   StringRef defType;
172 };
173 } // namespace
174 
175 DefGen::DefGen(const AttrOrTypeDef &def)
176     : def(def), params(def.getParameters()), defCls(def.getCppClassName()),
177       valueType(isa<AttrDef>(def) ? "Attribute" : "Type"),
178       defType(isa<AttrDef>(def) ? "Attr" : "Type") {
179   // Check that all parameters have names.
180   for (const AttrOrTypeParameter &param : def.getParameters())
181     if (param.isAnonymous())
182       llvm::PrintFatalError("all parameters must have a name");
183 
184   // If a storage class is needed, create one.
185   if (def.getNumParameters() > 0)
186     storageCls.emplace(def.getStorageClassName(), /*isStruct=*/true);
187 
188   // Create the parent class with any indicated traits.
189   createParentWithTraits();
190   // Emit top-level declarations.
191   emitTopLevelDeclarations();
192   // Emit builders for defs with parameters
193   if (storageCls)
194     emitBuilders();
195   // Emit the type name.
196   emitName();
197   // Emit the dialect name.
198   emitDialectName();
199   // Emit verification of type constraints.
200   bool genVerifyInvariantsImpl = def.genVerifyInvariantsImpl();
201   if (storageCls && genVerifyInvariantsImpl)
202     emitInvariantsVerifierImpl();
203   // Emit the custom verifier (written by the user).
204   bool genVerifyDecl = def.genVerifyDecl();
205   if (storageCls && genVerifyDecl)
206     emitVerifierDecl();
207   // Emit the "verifyInvariants" function if there is any verification at all.
208   if (storageCls)
209     emitInvariantsVerifier(genVerifyInvariantsImpl, genVerifyDecl);
210   // Emit the mnemonic, if there is one, and any associated parser and printer.
211   if (def.getMnemonic())
212     emitParserPrinter();
213   // Emit accessors
214   if (def.genAccessors())
215     emitAccessors();
216   // Emit trait interface methods
217   emitInterfaceMethods();
218   defCls.finalize();
219   // Emit a storage class if one is needed
220   if (storageCls && def.genStorageClass())
221     emitStorageClass();
222 }
223 
224 void DefGen::createParentWithTraits() {
225   ParentClass defParent(strfmt("::mlir::{0}::{1}Base", valueType, defType));
226   defParent.addTemplateParam(def.getCppClassName());
227   defParent.addTemplateParam(def.getCppBaseClassName());
228   defParent.addTemplateParam(storageCls
229                                  ? strfmt("{0}::{1}", def.getStorageNamespace(),
230                                           def.getStorageClassName())
231                                  : strfmt("::mlir::{0}Storage", valueType));
232   for (auto &trait : def.getTraits()) {
233     defParent.addTemplateParam(
234         isa<NativeTrait>(&trait)
235             ? cast<NativeTrait>(&trait)->getFullyQualifiedTraitName()
236             : cast<InterfaceTrait>(&trait)->getFullyQualifiedTraitName());
237   }
238   defCls.addParent(std::move(defParent));
239 }
240 
241 /// Include declarations specified on NativeTrait
242 static std::string formatExtraDeclarations(const AttrOrTypeDef &def) {
243   SmallVector<StringRef> extraDeclarations;
244   // Include extra class declarations from NativeTrait
245   for (const auto &trait : def.getTraits()) {
246     if (auto *attrOrTypeTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
247       StringRef value = attrOrTypeTrait->getExtraConcreteClassDeclaration();
248       if (value.empty())
249         continue;
250       extraDeclarations.push_back(value);
251     }
252   }
253   if (std::optional<StringRef> extraDecl = def.getExtraDecls()) {
254     extraDeclarations.push_back(*extraDecl);
255   }
256   return llvm::join(extraDeclarations, "\n");
257 }
258 
259 /// Extra class definitions have a `$cppClass` substitution that is to be
260 /// replaced by the C++ class name.
261 static std::string formatExtraDefinitions(const AttrOrTypeDef &def) {
262   SmallVector<StringRef> extraDefinitions;
263   // Include extra class definitions from NativeTrait
264   for (const auto &trait : def.getTraits()) {
265     if (auto *attrOrTypeTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
266       StringRef value = attrOrTypeTrait->getExtraConcreteClassDefinition();
267       if (value.empty())
268         continue;
269       extraDefinitions.push_back(value);
270     }
271   }
272   if (std::optional<StringRef> extraDef = def.getExtraDefs()) {
273     extraDefinitions.push_back(*extraDef);
274   }
275   FmtContext ctx = FmtContext().addSubst("cppClass", def.getCppClassName());
276   return tgfmt(llvm::join(extraDefinitions, "\n"), &ctx).str();
277 }
278 
279 void DefGen::emitTopLevelDeclarations() {
280   // Inherit constructors from the attribute or type class.
281   defCls.declare<VisibilityDeclaration>(Visibility::Public);
282   defCls.declare<UsingDeclaration>("Base::Base");
283 
284   // Emit the extra declarations first in case there's a definition in there.
285   std::string extraDecl = formatExtraDeclarations(def);
286   std::string extraDef = formatExtraDefinitions(def);
287   defCls.declare<ExtraClassDeclaration>(std::move(extraDecl),
288                                         std::move(extraDef));
289 }
290 
291 void DefGen::emitName() {
292   StringRef name;
293   if (auto *attrDef = dyn_cast<AttrDef>(&def)) {
294     name = attrDef->getAttrName();
295   } else {
296     auto *typeDef = cast<TypeDef>(&def);
297     name = typeDef->getTypeName();
298   }
299   std::string nameDecl =
300       strfmt("static constexpr ::llvm::StringLiteral name = \"{0}\";\n", name);
301   defCls.declare<ExtraClassDeclaration>(std::move(nameDecl));
302 }
303 
304 void DefGen::emitDialectName() {
305   std::string decl =
306       strfmt("static constexpr ::llvm::StringLiteral dialectName = \"{0}\";\n",
307              def.getDialect().getName());
308   defCls.declare<ExtraClassDeclaration>(std::move(decl));
309 }
310 
311 void DefGen::emitBuilders() {
312   if (!def.skipDefaultBuilders()) {
313     emitDefaultBuilder();
314     if (def.genVerifyDecl() || def.genVerifyInvariantsImpl())
315       emitCheckedBuilder();
316   }
317   for (auto &builder : def.getBuilders()) {
318     emitCustomBuilder(builder);
319     if (def.genVerifyDecl() || def.genVerifyInvariantsImpl())
320       emitCheckedCustomBuilder(builder);
321   }
322 }
323 
324 void DefGen::emitVerifierDecl() {
325   defCls.declareStaticMethod(
326       "::llvm::LogicalResult", "verify",
327       getBuilderParams({{"::llvm::function_ref<::mlir::InFlightDiagnostic()>",
328                          "emitError"}}));
329 }
330 
/// Code snippet checking one parameter constraint inside the generated
/// `verifyInvariantsImpl`. Format arguments: {0} is the constraint condition
/// with `$_self` already substituted, {1} is the parameter name, and {2} is the
/// constraint summary used in the diagnostic.
static constexpr const char *patternParameterVerificationCode = R"(
if (!({0})) {
  emitError() << "failed to verify '{1}': {2}";
  return ::mlir::failure();
}
)";
337 
338 void DefGen::emitInvariantsVerifierImpl() {
339   SmallVector<MethodParameter> builderParams = getBuilderParams(
340       {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}});
341   Method *verifier =
342       defCls.addMethod("::llvm::LogicalResult", "verifyInvariantsImpl",
343                        Method::Static, builderParams);
344   verifier->body().indent();
345 
346   // Generate verification for each parameter that is a type constraint.
347   for (auto it : llvm::enumerate(def.getParameters())) {
348     const AttrOrTypeParameter &param = it.value();
349     std::optional<Constraint> constraint = param.getConstraint();
350     // No verification needed for parameters that are not type constraints.
351     if (!constraint.has_value())
352       continue;
353     FmtContext ctx;
354     // Note: Skip over the first method parameter (`emitError`).
355     ctx.withSelf(builderParams[it.index() + 1].getName());
356     std::string condition = tgfmt(constraint->getConditionTemplate(), &ctx);
357     verifier->body() << formatv(patternParameterVerificationCode, condition,
358                                 param.getName(), constraint->getSummary())
359                      << "\n";
360   }
361   verifier->body() << "return ::mlir::success();";
362 }
363 
364 void DefGen::emitInvariantsVerifier(bool hasImpl, bool hasCustomVerifier) {
365   if (!hasImpl && !hasCustomVerifier)
366     return;
367   defCls.declare<UsingDeclaration>("Base::getChecked");
368   SmallVector<MethodParameter> builderParams = getBuilderParams(
369       {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}});
370   Method *verifier =
371       defCls.addMethod("::llvm::LogicalResult", "verifyInvariants",
372                        Method::Static, builderParams);
373   verifier->body().indent();
374 
375   auto emitVerifierCall = [&](StringRef name) {
376     verifier->body() << strfmt("if (::mlir::failed({0}(", name);
377     llvm::interleaveComma(
378         llvm::map_range(builderParams,
379                         [](auto &param) { return param.getName(); }),
380         verifier->body());
381     verifier->body() << ")))\n";
382     verifier->body() << "  return ::mlir::failure();\n";
383   };
384 
385   if (hasImpl) {
386     // Call the verifier that checks the type constraints.
387     emitVerifierCall("verifyInvariantsImpl");
388   }
389   if (hasCustomVerifier) {
390     // Call the custom verifier that is provided by the user.
391     emitVerifierCall("verify");
392   }
393   verifier->body() << "return ::mlir::success();";
394 }
395 
396 void DefGen::emitParserPrinter() {
397   auto *mnemonic = defCls.addStaticMethod<Method::Constexpr>(
398       "::llvm::StringLiteral", "getMnemonic");
399   mnemonic->body().indent() << strfmt("return {\"{0}\"};", *def.getMnemonic());
400 
401   // Declare the parser and printer, if needed.
402   bool hasAssemblyFormat = def.getAssemblyFormat().has_value();
403   if (!def.hasCustomAssemblyFormat() && !hasAssemblyFormat)
404     return;
405 
406   // Declare the parser.
407   SmallVector<MethodParameter> parserParams;
408   parserParams.emplace_back("::mlir::AsmParser &", "odsParser");
409   if (isa<AttrDef>(&def))
410     parserParams.emplace_back("::mlir::Type", "odsType");
411   auto *parser = defCls.addMethod(strfmt("::mlir::{0}", valueType), "parse",
412                                   hasAssemblyFormat ? Method::Static
413                                                     : Method::StaticDeclaration,
414                                   std::move(parserParams));
415   // Declare the printer.
416   auto props = hasAssemblyFormat ? Method::Const : Method::ConstDeclaration;
417   Method *printer =
418       defCls.addMethod("void", "print", props,
419                        MethodParameter("::mlir::AsmPrinter &", "odsPrinter"));
420   // Emit the bodies if we are using the declarative format.
421   if (hasAssemblyFormat)
422     return generateAttrOrTypeFormat(def, parser->body(), printer->body());
423 }
424 
425 void DefGen::emitAccessors() {
426   for (auto &param : params) {
427     Method *m = defCls.addMethod(
428         param.getCppAccessorType(), param.getAccessorName(),
429         def.genStorageClass() ? Method::Const : Method::ConstDeclaration);
430     // Generate accessor definitions only if we also generate the storage
431     // class. Otherwise, let the user define the exact accessor definition.
432     if (!def.genStorageClass())
433       continue;
434     m->body().indent() << "return getImpl()->" << param.getName() << ";";
435   }
436 }
437 
438 void DefGen::emitInterfaceMethods() {
439   for (auto &traitDef : def.getTraits())
440     if (auto *trait = dyn_cast<InterfaceTrait>(&traitDef))
441       if (trait->shouldDeclareMethods())
442         emitTraitMethods(*trait);
443 }
444 
445 //===----------------------------------------------------------------------===//
446 // Builder Emission
447 
448 SmallVector<MethodParameter>
449 DefGen::getBuilderParams(std::initializer_list<MethodParameter> prefix) const {
450   SmallVector<MethodParameter> builderParams;
451   builderParams.append(prefix.begin(), prefix.end());
452   for (auto &param : params)
453     builderParams.emplace_back(param.getCppType(), param.getName());
454   return builderParams;
455 }
456 
457 void DefGen::emitDefaultBuilder() {
458   Method *m = defCls.addStaticMethod(
459       def.getCppClassName(), "get",
460       getBuilderParams({{"::mlir::MLIRContext *", "context"}}));
461   MethodBody &body = m->body().indent();
462   auto scope = body.scope("return Base::get(context", ");");
463   for (const auto &param : params)
464     body << ", std::move(" << param.getName() << ")";
465 }
466 
467 void DefGen::emitCheckedBuilder() {
468   Method *m = defCls.addStaticMethod(
469       def.getCppClassName(), "getChecked",
470       getBuilderParams(
471           {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"},
472            {"::mlir::MLIRContext *", "context"}}));
473   MethodBody &body = m->body().indent();
474   auto scope = body.scope("return Base::getChecked(emitError, context", ");");
475   for (const auto &param : params)
476     body << ", " << param.getName();
477 }
478 
479 static SmallVector<MethodParameter>
480 getCustomBuilderParams(std::initializer_list<MethodParameter> prefix,
481                        const AttrOrTypeBuilder &builder) {
482   auto params = builder.getParameters();
483   SmallVector<MethodParameter> builderParams;
484   builderParams.append(prefix.begin(), prefix.end());
485   if (!builder.hasInferredContextParameter())
486     builderParams.emplace_back("::mlir::MLIRContext *", "context");
487   for (auto &param : params) {
488     builderParams.emplace_back(param.getCppType(), *param.getName(),
489                                param.getDefaultValue());
490   }
491   return builderParams;
492 }
493 
494 void DefGen::emitCustomBuilder(const AttrOrTypeBuilder &builder) {
495   // Don't emit a body if there isn't one.
496   auto props = builder.getBody() ? Method::Static : Method::StaticDeclaration;
497   StringRef returnType = def.getCppClassName();
498   if (std::optional<StringRef> builderReturnType = builder.getReturnType())
499     returnType = *builderReturnType;
500   Method *m = defCls.addMethod(returnType, "get", props,
501                                getCustomBuilderParams({}, builder));
502   if (!builder.getBody())
503     return;
504 
505   // Format the body and emit it.
506   FmtContext ctx;
507   ctx.addSubst("_get", "Base::get");
508   if (!builder.hasInferredContextParameter())
509     ctx.addSubst("_ctxt", "context");
510   std::string bodyStr = tgfmt(*builder.getBody(), &ctx);
511   m->body().indent().getStream().printReindented(bodyStr);
512 }
513 
514 /// Replace all instances of 'from' to 'to' in `str` and return the new string.
515 static std::string replaceInStr(std::string str, StringRef from, StringRef to) {
516   size_t pos = 0;
517   while ((pos = str.find(from.data(), pos, from.size())) != std::string::npos)
518     str.replace(pos, from.size(), to.data(), to.size());
519   return str;
520 }
521 
522 void DefGen::emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder) {
523   // Don't emit a body if there isn't one.
524   auto props = builder.getBody() ? Method::Static : Method::StaticDeclaration;
525   StringRef returnType = def.getCppClassName();
526   if (std::optional<StringRef> builderReturnType = builder.getReturnType())
527     returnType = *builderReturnType;
528   Method *m = defCls.addMethod(
529       returnType, "getChecked", props,
530       getCustomBuilderParams(
531           {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}},
532           builder));
533   if (!builder.getBody())
534     return;
535 
536   // Format the body and emit it. Replace $_get(...) with
537   // Base::getChecked(emitError, ...)
538   FmtContext ctx;
539   if (!builder.hasInferredContextParameter())
540     ctx.addSubst("_ctxt", "context");
541   std::string bodyStr = replaceInStr(builder.getBody()->str(), "$_get(",
542                                      "Base::getChecked(emitError, ");
543   bodyStr = tgfmt(bodyStr, &ctx);
544   m->body().indent().getStream().printReindented(bodyStr);
545 }
546 
547 //===----------------------------------------------------------------------===//
548 // Interface Method Emission
549 
550 void DefGen::emitTraitMethods(const InterfaceTrait &trait) {
551   // Get the set of methods that should always be declared.
552   auto alwaysDeclaredMethods = trait.getAlwaysDeclaredMethods();
553   StringSet<> alwaysDeclared;
554   alwaysDeclared.insert(alwaysDeclaredMethods.begin(),
555                         alwaysDeclaredMethods.end());
556 
557   Interface iface = trait.getInterface(); // causes strange bugs if elided
558   for (auto &method : iface.getMethods()) {
559     // Don't declare if the method has a body. Or if the method has a default
560     // implementation and the def didn't request that it always be declared.
561     if (method.getBody() || (method.getDefaultImplementation() &&
562                              !alwaysDeclared.count(method.getName())))
563       continue;
564     emitTraitMethod(method);
565   }
566 }
567 
568 void DefGen::emitTraitMethod(const InterfaceMethod &method) {
569   // All interface methods are declaration-only.
570   auto props =
571       method.isStatic() ? Method::StaticDeclaration : Method::ConstDeclaration;
572   SmallVector<MethodParameter> params;
573   for (auto &param : method.getArguments())
574     params.emplace_back(param.type, param.name);
575   defCls.addMethod(method.getReturnType(), method.getName(), props,
576                    std::move(params));
577 }
578 
579 //===----------------------------------------------------------------------===//
580 // Storage Class Emission
581 
582 void DefGen::emitStorageConstructor() {
583   Constructor *ctor =
584       storageCls->addConstructor<Method::Inline>(getBuilderParams({}));
585   for (auto &param : params) {
586     std::string movedValue = ("std::move(" + param.getName() + ")").str();
587     ctor->addMemberInitializer(param.getName(), movedValue);
588   }
589 }
590 
591 void DefGen::emitKeyType() {
592   std::string keyType("std::tuple<");
593   llvm::raw_string_ostream os(keyType);
594   llvm::interleaveComma(params, os,
595                         [&](auto &param) { os << param.getCppType(); });
596   os << '>';
597   storageCls->declare<UsingDeclaration>("KeyTy", std::move(os.str()));
598 
599   // Add a method to construct the key type from the storage.
600   Method *m = storageCls->addConstMethod<Method::Inline>("KeyTy", "getAsKey");
601   m->body().indent() << "return KeyTy(";
602   llvm::interleaveComma(params, m->body().indent(),
603                         [&](auto &param) { m->body() << param.getName(); });
604   m->body() << ");";
605 }
606 
607 void DefGen::emitEquals() {
608   Method *eq = storageCls->addConstMethod<Method::Inline>(
609       "bool", "operator==", MethodParameter("const KeyTy &", "tblgenKey"));
610   auto &body = eq->body().indent();
611   auto scope = body.scope("return (", ");");
612   const auto eachFn = [&](auto it) {
613     FmtContext ctx({{"_lhs", it.value().getName()},
614                     {"_rhs", strfmt("std::get<{0}>(tblgenKey)", it.index())}});
615     body << tgfmt(it.value().getComparator(), &ctx);
616   };
617   llvm::interleave(llvm::enumerate(params), body, eachFn, ") && (");
618 }
619 
620 void DefGen::emitHashKey() {
621   Method *hash = storageCls->addStaticInlineMethod(
622       "::llvm::hash_code", "hashKey",
623       MethodParameter("const KeyTy &", "tblgenKey"));
624   auto &body = hash->body().indent();
625   auto scope = body.scope("return ::llvm::hash_combine(", ");");
626   llvm::interleaveComma(llvm::enumerate(params), body, [&](auto it) {
627     body << llvm::formatv("std::get<{0}>(tblgenKey)", it.index());
628   });
629 }
630 
631 void DefGen::emitConstruct() {
632   Method *construct = storageCls->addMethod<Method::Inline>(
633       strfmt("{0} *", def.getStorageClassName()), "construct",
634       def.hasStorageCustomConstructor() ? Method::StaticDeclaration
635                                         : Method::Static,
636       MethodParameter(strfmt("::mlir::{0}StorageAllocator &", valueType),
637                       "allocator"),
638       MethodParameter("KeyTy &&", "tblgenKey"));
639   if (!def.hasStorageCustomConstructor()) {
640     auto &body = construct->body().indent();
641     for (const auto &it : llvm::enumerate(params)) {
642       body << formatv("auto {0} = std::move(std::get<{1}>(tblgenKey));\n",
643                       it.value().getName(), it.index());
644     }
645     // Use the parameters' custom allocator code, if provided.
646     FmtContext ctx = FmtContext().addSubst("_allocator", "allocator");
647     for (auto &param : params) {
648       if (std::optional<StringRef> allocCode = param.getAllocator()) {
649         ctx.withSelf(param.getName()).addSubst("_dst", param.getName());
650         body << tgfmt(*allocCode, &ctx) << '\n';
651       }
652     }
653     auto scope =
654         body.scope(strfmt("return new (allocator.allocate<{0}>()) {0}(",
655                           def.getStorageClassName()),
656                    ");");
657     llvm::interleaveComma(params, body, [&](auto &param) {
658       body << "std::move(" << param.getName() << ")";
659     });
660   }
661 }
662 
663 void DefGen::emitStorageClass() {
664   // Add the appropriate parent class.
665   storageCls->addParent(strfmt("::mlir::{0}Storage", valueType));
666   // Add the constructor.
667   emitStorageConstructor();
668   // Declare the key type.
669   emitKeyType();
670   // Add the comparison method.
671   emitEquals();
672   // Emit the key hash method.
673   emitHashKey();
674   // Emit the storage constructor. Just declare it if the user wants to define
675   // it themself.
676   emitConstruct();
677   // Emit the storage class members as public, at the very end of the struct.
678   storageCls->finalize();
679   for (auto &param : params)
680     storageCls->declare<Field>(param.getCppType(), param.getName());
681 }
682 
683 //===----------------------------------------------------------------------===//
684 // DefGenerator
685 //===----------------------------------------------------------------------===//
686 
687 namespace {
688 /// This struct is the base generator used when processing tablegen interfaces.
689 class DefGenerator {
690 public:
691   bool emitDecls(StringRef selectedDialect);
692   bool emitDefs(StringRef selectedDialect);
693 
694 protected:
695   DefGenerator(ArrayRef<const Record *> defs, raw_ostream &os,
696                StringRef defType, StringRef valueType, bool isAttrGenerator)
697       : defRecords(defs), os(os), defType(defType), valueType(valueType),
698         isAttrGenerator(isAttrGenerator) {
699     // Sort by occurrence in file.
700     llvm::sort(defRecords, [](const Record *lhs, const Record *rhs) {
701       return lhs->getID() < rhs->getID();
702     });
703   }
704 
705   /// Emit the list of def type names.
706   void emitTypeDefList(ArrayRef<AttrOrTypeDef> defs);
707   /// Emit the code to dispatch between different defs during parsing/printing.
708   void emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs);
709 
710   /// The set of def records to emit.
711   std::vector<const Record *> defRecords;
712   /// The attribute or type class to emit.
713   /// The stream to emit to.
714   raw_ostream &os;
715   /// The prefix of the tablegen def name, e.g. Attr or Type.
716   StringRef defType;
717   /// The C++ base value type of the def, e.g. Attribute or Type.
718   StringRef valueType;
719   /// Flag indicating if this generator is for Attributes. False if the
720   /// generator is for types.
721   bool isAttrGenerator;
722 };
723 
724 /// A specialized generator for AttrDefs.
725 struct AttrDefGenerator : public DefGenerator {
726   AttrDefGenerator(const RecordKeeper &records, raw_ostream &os)
727       : DefGenerator(records.getAllDerivedDefinitionsIfDefined("AttrDef"), os,
728                      "Attr", "Attribute", /*isAttrGenerator=*/true) {}
729 };
730 /// A specialized generator for TypeDefs.
731 struct TypeDefGenerator : public DefGenerator {
732   TypeDefGenerator(const RecordKeeper &records, raw_ostream &os)
733       : DefGenerator(records.getAllDerivedDefinitionsIfDefined("TypeDef"), os,
734                      "Type", "Type", /*isAttrGenerator=*/false) {}
735 };
736 } // namespace
737 
738 //===----------------------------------------------------------------------===//
739 // GEN: Declarations
740 //===----------------------------------------------------------------------===//
741 
742 /// Print this above all the other declarations. Contains type declarations used
743 /// later on.
744 static const char *const typeDefDeclHeader = R"(
745 namespace mlir {
746 class AsmParser;
747 class AsmPrinter;
748 } // namespace mlir
749 )";
750 
751 bool DefGenerator::emitDecls(StringRef selectedDialect) {
752   emitSourceFileHeader((defType + "Def Declarations").str(), os);
753   IfDefScope scope("GET_" + defType.upper() + "DEF_CLASSES", os);
754 
755   // Output the common "header".
756   os << typeDefDeclHeader;
757 
758   SmallVector<AttrOrTypeDef, 16> defs;
759   collectAllDefs(selectedDialect, defRecords, defs);
760   if (defs.empty())
761     return false;
762   {
763     NamespaceEmitter nsEmitter(os, defs.front().getDialect());
764 
765     // Declare all the def classes first (in case they reference each other).
766     for (const AttrOrTypeDef &def : defs)
767       os << "class " << def.getCppClassName() << ";\n";
768 
769     // Emit the declarations.
770     for (const AttrOrTypeDef &def : defs)
771       DefGen(def).emitDecl(os);
772   }
773   // Emit the TypeID explicit specializations to have a single definition for
774   // each of these.
775   for (const AttrOrTypeDef &def : defs)
776     if (!def.getDialect().getCppNamespace().empty())
777       os << "MLIR_DECLARE_EXPLICIT_TYPE_ID("
778          << def.getDialect().getCppNamespace() << "::" << def.getCppClassName()
779          << ")\n";
780 
781   return false;
782 }
783 
784 //===----------------------------------------------------------------------===//
785 // GEN: Def List
786 //===----------------------------------------------------------------------===//
787 
788 void DefGenerator::emitTypeDefList(ArrayRef<AttrOrTypeDef> defs) {
789   IfDefScope scope("GET_" + defType.upper() + "DEF_LIST", os);
790   auto interleaveFn = [&](const AttrOrTypeDef &def) {
791     os << def.getDialect().getCppNamespace() << "::" << def.getCppClassName();
792   };
793   llvm::interleave(defs, os, interleaveFn, ",\n");
794   os << "\n";
795 }
796 
797 //===----------------------------------------------------------------------===//
798 // GEN: Definitions
799 //===----------------------------------------------------------------------===//
800 
/// Template for a dialect's default `parseAttribute`/`printAttribute` hooks,
/// built on the generated `generatedAttributeParser`/`generatedAttributePrinter`
/// dispatchers. It is expanded with llvm::formatv (see emitDefs), so literal
/// braces in the generated code are written as `{{`.
/// {0}: the dialect's C++ class name; emitDefs expands the template inside the
///      dialect's namespace.
/// {1}: the optional code for the dynamic attribute parser dispatch (empty for
///      non-extensible dialects).
/// {2}: the optional code for the dynamic attribute printer dispatch (empty
///      for non-extensible dialects).
/// NOTE(review): despite its name, the generated local `typeLoc` records the
/// location where *attribute* parsing starts; it is used for the
/// "unknown attribute" diagnostic.
static const char *const dialectDefaultAttrPrinterParserDispatch = R"(
/// Parse an attribute registered to this dialect.
::mlir::Attribute {0}::parseAttribute(::mlir::DialectAsmParser &parser,
                                      ::mlir::Type type) const {{
  ::llvm::SMLoc typeLoc = parser.getCurrentLocation();
  ::llvm::StringRef attrTag;
  {{
    ::mlir::Attribute attr;
    auto parseResult = generatedAttributeParser(parser, &attrTag, type, attr);
    if (parseResult.has_value())
      return attr;
  }
  {1}
  parser.emitError(typeLoc) << "unknown attribute `"
      << attrTag << "` in dialect `" << getNamespace() << "`";
  return {{};
}
/// Print an attribute registered to this dialect.
void {0}::printAttribute(::mlir::Attribute attr,
                         ::mlir::DialectAsmPrinter &printer) const {{
  if (::mlir::succeeded(generatedAttributePrinter(attr, printer)))
    return;
  {2}
}
)";
830 
/// The code block for dynamic attribute parser dispatch boilerplate.
/// Substituted as {1} into dialectDefaultAttrPrinterParserDispatch for
/// extensible dialects, and reuses that template's `attrTag` and `parser`.
/// It is passed to formatv as an argument (not as a format string), so its
/// braces are not escaped.
/// The generated code is expanded inside the dialect's namespace, so every
/// MLIR name is fully qualified: an unqualified `Attribute` would only resolve
/// when that namespace is nested in `mlir` or a using-directive is in effect.
static const char *const dialectDynamicAttrParserDispatch = R"(
  {
    ::mlir::Attribute genAttr;
    auto parseResult = parseOptionalDynamicAttr(attrTag, parser, genAttr);
    if (parseResult.has_value()) {
      if (::mlir::succeeded(parseResult.value()))
        return genAttr;
      return ::mlir::Attribute();
    }
  }
)";
843 
/// The code block for dynamic attribute printer dispatch boilerplate.
/// Substituted as {2} into dialectDefaultAttrPrinterParserDispatch for
/// extensible dialects. It only runs after `generatedAttributePrinter` has
/// failed to print `attr`.
static const char *const dialectDynamicAttrPrinterDispatch = R"(
  if (::mlir::succeeded(printIfDynamicAttr(attr, printer)))
    return;
)";
849 
/// Template for a dialect's default `parseType`/`printType` hooks, built on
/// the generated `generatedTypeParser`/`generatedTypePrinter` dispatchers. It
/// is expanded with llvm::formatv (see emitDefs), so literal braces in the
/// generated code are written as `{{`.
/// {0}: the dialect's C++ class name; emitDefs expands the template inside the
///      dialect's namespace.
/// {1}: the optional code for the dynamic type parser dispatch (empty for
///      non-extensible dialects).
/// {2}: the optional code for the dynamic type printer dispatch (empty for
///      non-extensible dialects).
/// The "unknown type" diagnostic uses the same single-spaced wording as the
/// attribute template's "unknown attribute" diagnostic.
static const char *const dialectDefaultTypePrinterParserDispatch = R"(
/// Parse a type registered to this dialect.
::mlir::Type {0}::parseType(::mlir::DialectAsmParser &parser) const {{
  ::llvm::SMLoc typeLoc = parser.getCurrentLocation();
  ::llvm::StringRef mnemonic;
  ::mlir::Type genType;
  auto parseResult = generatedTypeParser(parser, &mnemonic, genType);
  if (parseResult.has_value())
    return genType;
  {1}
  parser.emitError(typeLoc) << "unknown type `"
      << mnemonic << "` in dialect `" << getNamespace() << "`";
  return {{};
}
/// Print a type registered to this dialect.
void {0}::printType(::mlir::Type type,
                    ::mlir::DialectAsmPrinter &printer) const {{
  if (::mlir::succeeded(generatedTypePrinter(type, printer)))
    return;
  {2}
}
)";
876 
/// The code block for dynamic type parser dispatch boilerplate.
/// Substituted as {1} into dialectDefaultTypePrinterParserDispatch for
/// extensible dialects; it reuses that template's `mnemonic`, `parser` and
/// `genType` locals. It is passed to formatv as an argument (not as a format
/// string), so its braces are not escaped.
static const char *const dialectDynamicTypeParserDispatch = R"(
  {
    auto parseResult = parseOptionalDynamicType(mnemonic, parser, genType);
    if (parseResult.has_value()) {
      if (::mlir::succeeded(parseResult.value()))
        return genType;
      return ::mlir::Type();
    }
  }
)";
888 
/// The code block for dynamic type printer dispatch boilerplate.
/// Substituted as {2} into dialectDefaultTypePrinterParserDispatch for
/// extensible dialects. It only runs after `generatedTypePrinter` has failed
/// to print `type`.
static const char *const dialectDynamicTypePrinterDispatch = R"(
  if (::mlir::succeeded(printIfDynamicType(type, printer)))
    return;
)";
894 
895 /// Emit the dialect printer/parser dispatcher. User's code should call these
896 /// functions from their dialect's print/parse methods.
897 void DefGenerator::emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs) {
898   if (llvm::none_of(defs, [](const AttrOrTypeDef &def) {
899         return def.getMnemonic().has_value();
900       })) {
901     return;
902   }
903   // Declare the parser.
904   SmallVector<MethodParameter> params = {{"::mlir::AsmParser &", "parser"},
905                                          {"::llvm::StringRef *", "mnemonic"}};
906   if (isAttrGenerator)
907     params.emplace_back("::mlir::Type", "type");
908   params.emplace_back(strfmt("::mlir::{0} &", valueType), "value");
909   Method parse("::mlir::OptionalParseResult",
910                strfmt("generated{0}Parser", valueType), Method::StaticInline,
911                std::move(params));
912   // Declare the printer.
913   Method printer("::llvm::LogicalResult",
914                  strfmt("generated{0}Printer", valueType), Method::StaticInline,
915                  {{strfmt("::mlir::{0}", valueType), "def"},
916                   {"::mlir::AsmPrinter &", "printer"}});
917 
918   // The parser dispatch uses a KeywordSwitch, matching on the mnemonic and
919   // calling the def's parse function.
920   parse.body() << "  return "
921                   "::mlir::AsmParser::KeywordSwitch<::mlir::"
922                   "OptionalParseResult>(parser)\n";
923   const char *const getValueForMnemonic =
924       R"(    .Case({0}::getMnemonic(), [&](llvm::StringRef, llvm::SMLoc) {{
925       value = {0}::{1};
926       return ::mlir::success(!!value);
927     })
928 )";
929 
930   // The printer dispatch uses llvm::TypeSwitch to find and call the correct
931   // printer.
932   printer.body() << "  return ::llvm::TypeSwitch<::mlir::" << valueType
933                  << ", ::llvm::LogicalResult>(def)";
934   const char *const printValue = R"(    .Case<{0}>([&](auto t) {{
935       printer << {0}::getMnemonic();{1}
936       return ::mlir::success();
937     })
938 )";
939   for (auto &def : defs) {
940     if (!def.getMnemonic())
941       continue;
942     bool hasParserPrinterDecl =
943         def.hasCustomAssemblyFormat() || def.getAssemblyFormat();
944     std::string defClass = strfmt(
945         "{0}::{1}", def.getDialect().getCppNamespace(), def.getCppClassName());
946 
947     // If the def has no parameters or parser code, invoke a normal `get`.
948     std::string parseOrGet =
949         hasParserPrinterDecl
950             ? strfmt("parse(parser{0})", isAttrGenerator ? ", type" : "")
951             : "get(parser.getContext())";
952     parse.body() << llvm::formatv(getValueForMnemonic, defClass, parseOrGet);
953 
954     // If the def has no parameters and no printer, just print the mnemonic.
955     StringRef printDef = "";
956     if (hasParserPrinterDecl)
957       printDef = "\nt.print(printer);";
958     printer.body() << llvm::formatv(printValue, defClass, printDef);
959   }
960   parse.body() << "    .Default([&](llvm::StringRef keyword, llvm::SMLoc) {\n"
961                   "      *mnemonic = keyword;\n"
962                   "      return std::nullopt;\n"
963                   "    });";
964   printer.body() << "    .Default([](auto) { return ::mlir::failure(); });";
965 
966   raw_indented_ostream indentedOs(os);
967   parse.writeDeclTo(indentedOs);
968   printer.writeDeclTo(indentedOs);
969 }
970 
971 bool DefGenerator::emitDefs(StringRef selectedDialect) {
972   emitSourceFileHeader((defType + "Def Definitions").str(), os);
973 
974   SmallVector<AttrOrTypeDef, 16> defs;
975   collectAllDefs(selectedDialect, defRecords, defs);
976   if (defs.empty())
977     return false;
978   emitTypeDefList(defs);
979 
980   IfDefScope scope("GET_" + defType.upper() + "DEF_CLASSES", os);
981   emitParsePrintDispatch(defs);
982   for (const AttrOrTypeDef &def : defs) {
983     {
984       NamespaceEmitter ns(os, def.getDialect());
985       DefGen gen(def);
986       gen.emitDef(os);
987     }
988     // Emit the TypeID explicit specializations to have a single symbol def.
989     if (!def.getDialect().getCppNamespace().empty())
990       os << "MLIR_DEFINE_EXPLICIT_TYPE_ID("
991          << def.getDialect().getCppNamespace() << "::" << def.getCppClassName()
992          << ")\n";
993   }
994 
995   Dialect firstDialect = defs.front().getDialect();
996 
997   // Emit the default parser/printer for Attributes if the dialect asked for it.
998   if (isAttrGenerator && firstDialect.useDefaultAttributePrinterParser()) {
999     NamespaceEmitter nsEmitter(os, firstDialect);
1000     if (firstDialect.isExtensible()) {
1001       os << llvm::formatv(dialectDefaultAttrPrinterParserDispatch,
1002                           firstDialect.getCppClassName(),
1003                           dialectDynamicAttrParserDispatch,
1004                           dialectDynamicAttrPrinterDispatch);
1005     } else {
1006       os << llvm::formatv(dialectDefaultAttrPrinterParserDispatch,
1007                           firstDialect.getCppClassName(), "", "");
1008     }
1009   }
1010 
1011   // Emit the default parser/printer for Types if the dialect asked for it.
1012   if (!isAttrGenerator && firstDialect.useDefaultTypePrinterParser()) {
1013     NamespaceEmitter nsEmitter(os, firstDialect);
1014     if (firstDialect.isExtensible()) {
1015       os << llvm::formatv(dialectDefaultTypePrinterParserDispatch,
1016                           firstDialect.getCppClassName(),
1017                           dialectDynamicTypeParserDispatch,
1018                           dialectDynamicTypePrinterDispatch);
1019     } else {
1020       os << llvm::formatv(dialectDefaultTypePrinterParserDispatch,
1021                           firstDialect.getCppClassName(), "", "");
1022     }
1023   }
1024 
1025   return false;
1026 }
1027 
1028 //===----------------------------------------------------------------------===//
1029 // Type Constraints
1030 //===----------------------------------------------------------------------===//
1031 
1032 /// Find all type constraints for which a C++ function should be generated.
1033 static std::vector<Constraint>
1034 getAllTypeConstraints(const RecordKeeper &records) {
1035   std::vector<Constraint> result;
1036   for (const Record *def :
1037        records.getAllDerivedDefinitionsIfDefined("TypeConstraint")) {
1038     // Ignore constraints defined outside of the top-level file.
1039     if (llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) !=
1040         llvm::SrcMgr.getMainFileID())
1041       continue;
1042     Constraint constr(def);
1043     // Generate C++ function only if "cppFunctionName" is set.
1044     if (!constr.getCppFunctionName())
1045       continue;
1046     result.push_back(constr);
1047   }
1048   return result;
1049 }
1050 
1051 static void emitTypeConstraintDecls(const RecordKeeper &records,
1052                                     raw_ostream &os) {
1053   static const char *const typeConstraintDecl = R"(
1054 bool {0}(::mlir::Type type);
1055 )";
1056 
1057   for (Constraint constr : getAllTypeConstraints(records))
1058     os << strfmt(typeConstraintDecl, *constr.getCppFunctionName());
1059 }
1060 
1061 static void emitTypeConstraintDefs(const RecordKeeper &records,
1062                                    raw_ostream &os) {
1063   static const char *const typeConstraintDef = R"(
1064 bool {0}(::mlir::Type type) {
1065   return ({1});
1066 }
1067 )";
1068 
1069   for (Constraint constr : getAllTypeConstraints(records)) {
1070     FmtContext ctx;
1071     ctx.withSelf("type");
1072     std::string condition = tgfmt(constr.getConditionTemplate(), &ctx);
1073     os << strfmt(typeConstraintDef, *constr.getCppFunctionName(), condition);
1074   }
1075 }
1076 
1077 //===----------------------------------------------------------------------===//
1078 // GEN: Registration hooks
1079 //===----------------------------------------------------------------------===//
1080 
1081 //===----------------------------------------------------------------------===//
1082 // AttrDef
1083 
1084 static llvm::cl::OptionCategory attrdefGenCat("Options for -gen-attrdef-*");
1085 static llvm::cl::opt<std::string>
1086     attrDialect("attrdefs-dialect",
1087                 llvm::cl::desc("Generate attributes for this dialect"),
1088                 llvm::cl::cat(attrdefGenCat), llvm::cl::CommaSeparated);
1089 
1090 static mlir::GenRegistration
1091     genAttrDefs("gen-attrdef-defs", "Generate AttrDef definitions",
1092                 [](const RecordKeeper &records, raw_ostream &os) {
1093                   AttrDefGenerator generator(records, os);
1094                   return generator.emitDefs(attrDialect);
1095                 });
1096 static mlir::GenRegistration
1097     genAttrDecls("gen-attrdef-decls", "Generate AttrDef declarations",
1098                  [](const RecordKeeper &records, raw_ostream &os) {
1099                    AttrDefGenerator generator(records, os);
1100                    return generator.emitDecls(attrDialect);
1101                  });
1102 
1103 //===----------------------------------------------------------------------===//
1104 // TypeDef
1105 
1106 static llvm::cl::OptionCategory typedefGenCat("Options for -gen-typedef-*");
1107 static llvm::cl::opt<std::string>
1108     typeDialect("typedefs-dialect",
1109                 llvm::cl::desc("Generate types for this dialect"),
1110                 llvm::cl::cat(typedefGenCat), llvm::cl::CommaSeparated);
1111 
1112 static mlir::GenRegistration
1113     genTypeDefs("gen-typedef-defs", "Generate TypeDef definitions",
1114                 [](const RecordKeeper &records, raw_ostream &os) {
1115                   TypeDefGenerator generator(records, os);
1116                   return generator.emitDefs(typeDialect);
1117                 });
1118 static mlir::GenRegistration
1119     genTypeDecls("gen-typedef-decls", "Generate TypeDef declarations",
1120                  [](const RecordKeeper &records, raw_ostream &os) {
1121                    TypeDefGenerator generator(records, os);
1122                    return generator.emitDecls(typeDialect);
1123                  });
1124 
1125 static mlir::GenRegistration
1126     genTypeConstrDefs("gen-type-constraint-defs",
1127                       "Generate type constraint definitions",
1128                       [](const RecordKeeper &records, raw_ostream &os) {
1129                         emitTypeConstraintDefs(records, os);
1130                         return false;
1131                       });
1132 static mlir::GenRegistration
1133     genTypeConstrDecls("gen-type-constraint-decls",
1134                        "Generate type constraint declarations",
1135                        [](const RecordKeeper &records, raw_ostream &os) {
1136                          emitTypeConstraintDecls(records, os);
1137                          return false;
1138                        });
1139