xref: /llvm-project/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp (revision bccd37f69fdc7b5cd00d9231cabbe74bfe38f598)
183ef862fSRiver Riddle //===- AttrOrTypeDefGen.cpp - MLIR AttrOrType definitions generator -------===//
283ef862fSRiver Riddle //
383ef862fSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
483ef862fSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
583ef862fSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
683ef862fSRiver Riddle //
783ef862fSRiver Riddle //===----------------------------------------------------------------------===//
883ef862fSRiver Riddle 
99a2fdc36SJeff Niu #include "AttrOrTypeFormatGen.h"
1083ef862fSRiver Riddle #include "mlir/TableGen/AttrOrTypeDef.h"
11ca6bd9cdSMogball #include "mlir/TableGen/Class.h"
1283ef862fSRiver Riddle #include "mlir/TableGen/CodeGenHelpers.h"
1383ef862fSRiver Riddle #include "mlir/TableGen/Format.h"
1483ef862fSRiver Riddle #include "mlir/TableGen/GenInfo.h"
1594662ee0SRiver Riddle #include "mlir/TableGen/Interfaces.h"
1694662ee0SRiver Riddle #include "llvm/ADT/StringSet.h"
1783ef862fSRiver Riddle #include "llvm/Support/CommandLine.h"
1883ef862fSRiver Riddle #include "llvm/TableGen/Error.h"
1983ef862fSRiver Riddle #include "llvm/TableGen/TableGenBackend.h"
2083ef862fSRiver Riddle 
2183ef862fSRiver Riddle #define DEBUG_TYPE "mlir-tblgen-attrortypedefgen"
2283ef862fSRiver Riddle 
2383ef862fSRiver Riddle using namespace mlir;
2483ef862fSRiver Riddle using namespace mlir::tblgen;
25*bccd37f6SRahul Joshi using llvm::Record;
26*bccd37f6SRahul Joshi using llvm::RecordKeeper;
2783ef862fSRiver Riddle 
289a2fdc36SJeff Niu //===----------------------------------------------------------------------===//
299a2fdc36SJeff Niu // Utility Functions
309a2fdc36SJeff Niu //===----------------------------------------------------------------------===//
319a2fdc36SJeff Niu 
3283ef862fSRiver Riddle /// Find all the AttrOrTypeDef for the specified dialect. If no dialect
3383ef862fSRiver Riddle /// specified and can only find one dialect's defs, use that.
3483ef862fSRiver Riddle static void collectAllDefs(StringRef selectedDialect,
35*bccd37f6SRahul Joshi                            ArrayRef<const Record *> records,
3683ef862fSRiver Riddle                            SmallVectorImpl<AttrOrTypeDef> &resultDefs) {
3775dfeef9SMogball   // Nothing to do if no defs were found.
3875dfeef9SMogball   if (records.empty())
3975dfeef9SMogball     return;
4075dfeef9SMogball 
4183ef862fSRiver Riddle   auto defs = llvm::map_range(
42*bccd37f6SRahul Joshi       records, [&](const Record *rec) { return AttrOrTypeDef(rec); });
4383ef862fSRiver Riddle   if (selectedDialect.empty()) {
4475dfeef9SMogball     // If a dialect was not specified, ensure that all found defs belong to the
4575dfeef9SMogball     // same dialect.
466fa87ec1SJakub Kuderski     if (!llvm::all_equal(llvm::map_range(
4778fdbdbfSMehdi Amini             defs, [](const auto &def) { return def.getDialect(); }))) {
4883ef862fSRiver Riddle       llvm::PrintFatalError("defs belonging to more than one dialect. Must "
4983ef862fSRiver Riddle                             "select one via '--(attr|type)defs-dialect'");
5083ef862fSRiver Riddle     }
51ca6bd9cdSMogball     resultDefs.assign(defs.begin(), defs.end());
5283ef862fSRiver Riddle   } else {
5375dfeef9SMogball     // Otherwise, generate the defs that belong to the selected dialect.
5478fdbdbfSMehdi Amini     auto dialectDefs = llvm::make_filter_range(defs, [&](const auto &def) {
55dec8055aSKazu Hirata       return def.getDialect().getName() == selectedDialect;
56ca6bd9cdSMogball     });
57ca6bd9cdSMogball     resultDefs.assign(dialectDefs.begin(), dialectDefs.end());
5883ef862fSRiver Riddle   }
5983ef862fSRiver Riddle }
6083ef862fSRiver Riddle 
6183ef862fSRiver Riddle //===----------------------------------------------------------------------===//
62ca6bd9cdSMogball // DefGen
6383ef862fSRiver Riddle //===----------------------------------------------------------------------===//
6483ef862fSRiver Riddle 
6583ef862fSRiver Riddle namespace {
66ca6bd9cdSMogball class DefGen {
6783ef862fSRiver Riddle public:
68ca6bd9cdSMogball   /// Create the attribute or type class.
69ca6bd9cdSMogball   DefGen(const AttrOrTypeDef &def);
7083ef862fSRiver Riddle 
71ca6bd9cdSMogball   void emitDecl(raw_ostream &os) const {
72cf40fde4SHideto Ueno     if (storageCls && def.genStorageClass()) {
73ca6bd9cdSMogball       NamespaceEmitter ns(os, def.getStorageNamespace());
74ca6bd9cdSMogball       os << "struct " << def.getStorageClassName() << ";\n";
7583ef862fSRiver Riddle     }
76ca6bd9cdSMogball     defCls.writeDeclTo(os);
77ca6bd9cdSMogball   }
78ca6bd9cdSMogball   void emitDef(raw_ostream &os) const {
79ca6bd9cdSMogball     if (storageCls && def.genStorageClass()) {
80ca6bd9cdSMogball       NamespaceEmitter ns(os, def.getStorageNamespace());
81ca6bd9cdSMogball       storageCls->writeDeclTo(os); // everything is inline
82ca6bd9cdSMogball     }
83ca6bd9cdSMogball     defCls.writeDefTo(os);
8483ef862fSRiver Riddle   }
8583ef862fSRiver Riddle 
8683ef862fSRiver Riddle private:
87ca6bd9cdSMogball   /// Add traits from the TableGen definition to the class.
88ca6bd9cdSMogball   void createParentWithTraits();
89ca6bd9cdSMogball   /// Emit top-level declarations: using declarations and any extra class
90ca6bd9cdSMogball   /// declarations.
91ca6bd9cdSMogball   void emitTopLevelDeclarations();
923dbac2c0SFehr Mathieu   /// Emit the function that returns the type or attribute name.
933dbac2c0SFehr Mathieu   void emitName();
9407c157a4SJeremy Kun   /// Emit the dialect name as a static member variable.
9507c157a4SJeremy Kun   void emitDialectName();
96ca6bd9cdSMogball   /// Emit attribute or type builders.
97ca6bd9cdSMogball   void emitBuilders();
987359a6b7SMatthias Springer   /// Emit a verifier declaration for custom verification (impl. provided by
997359a6b7SMatthias Springer   /// the users).
1007359a6b7SMatthias Springer   void emitVerifierDecl();
1017359a6b7SMatthias Springer   /// Emit a verifier that checks type constraints.
1027359a6b7SMatthias Springer   void emitInvariantsVerifierImpl();
1037359a6b7SMatthias Springer   /// Emit an entry poiunt for verification that calls the invariants and
1047359a6b7SMatthias Springer   /// custom verifier.
1057359a6b7SMatthias Springer   void emitInvariantsVerifier(bool hasImpl, bool hasCustomVerifier);
106ca6bd9cdSMogball   /// Emit parsers and printers.
107ca6bd9cdSMogball   void emitParserPrinter();
108ca6bd9cdSMogball   /// Emit parameter accessors, if required.
109ca6bd9cdSMogball   void emitAccessors();
110ca6bd9cdSMogball   /// Emit interface methods.
111ca6bd9cdSMogball   void emitInterfaceMethods();
11283ef862fSRiver Riddle 
113ca6bd9cdSMogball   //===--------------------------------------------------------------------===//
114ca6bd9cdSMogball   // Builder Emission
115ca6bd9cdSMogball 
116ca6bd9cdSMogball   /// Emit the default builder `Attribute::get`
117ca6bd9cdSMogball   void emitDefaultBuilder();
118ca6bd9cdSMogball   /// Emit the checked builder `Attribute::getChecked`
119ca6bd9cdSMogball   void emitCheckedBuilder();
120ca6bd9cdSMogball   /// Emit a custom builder.
121ca6bd9cdSMogball   void emitCustomBuilder(const AttrOrTypeBuilder &builder);
122ca6bd9cdSMogball   /// Emit a checked custom builder.
123ca6bd9cdSMogball   void emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder);
124ca6bd9cdSMogball 
125ca6bd9cdSMogball   //===--------------------------------------------------------------------===//
126ca6bd9cdSMogball   // Interface Method Emission
127ca6bd9cdSMogball 
128ca6bd9cdSMogball   /// Emit methods for a trait.
129ca6bd9cdSMogball   void emitTraitMethods(const InterfaceTrait &trait);
130ca6bd9cdSMogball   /// Emit a trait method.
131ca6bd9cdSMogball   void emitTraitMethod(const InterfaceMethod &method);
132ca6bd9cdSMogball 
133ca6bd9cdSMogball   //===--------------------------------------------------------------------===//
134ca6bd9cdSMogball   // Storage Class Emission
135ca6bd9cdSMogball   void emitStorageClass();
136ca6bd9cdSMogball   /// Generate the storage class constructor.
137ca6bd9cdSMogball   void emitStorageConstructor();
138ca6bd9cdSMogball   /// Emit the key type `KeyTy`.
139ca6bd9cdSMogball   void emitKeyType();
140ca6bd9cdSMogball   /// Emit the equality comparison operator.
141ca6bd9cdSMogball   void emitEquals();
142ca6bd9cdSMogball   /// Emit the key hash function.
143ca6bd9cdSMogball   void emitHashKey();
144ca6bd9cdSMogball   /// Emit the function to construct the storage class.
145ca6bd9cdSMogball   void emitConstruct();
146ca6bd9cdSMogball 
147ca6bd9cdSMogball   //===--------------------------------------------------------------------===//
148ca6bd9cdSMogball   // Utility Function Declarations
149ca6bd9cdSMogball 
150ca6bd9cdSMogball   /// Get the method parameters for a def builder, where the first several
151ca6bd9cdSMogball   /// parameters may be different.
152ca6bd9cdSMogball   SmallVector<MethodParameter>
153ca6bd9cdSMogball   getBuilderParams(std::initializer_list<MethodParameter> prefix) const;
154ca6bd9cdSMogball 
155ca6bd9cdSMogball   //===--------------------------------------------------------------------===//
156ca6bd9cdSMogball   // Class fields
157ca6bd9cdSMogball 
158ca6bd9cdSMogball   /// The attribute or type definition.
159ca6bd9cdSMogball   const AttrOrTypeDef &def;
160ca6bd9cdSMogball   /// The list of attribute or type parameters.
16183ef862fSRiver Riddle   ArrayRef<AttrOrTypeParameter> params;
162ca6bd9cdSMogball   /// The attribute or type class.
163ca6bd9cdSMogball   Class defCls;
164ca6bd9cdSMogball   /// An optional attribute or type storage class. The storage class will
165ca6bd9cdSMogball   /// exist if and only if the def has more than zero parameters.
1663cfe412eSFangrui Song   std::optional<Class> storageCls;
16783ef862fSRiver Riddle 
168ca6bd9cdSMogball   /// The C++ base value of the def, either "Attribute" or "Type".
169ca6bd9cdSMogball   StringRef valueType;
170ca6bd9cdSMogball   /// The prefix/suffix of the TableGen def name, either "Attr" or "Type".
171ca6bd9cdSMogball   StringRef defType;
172ca6bd9cdSMogball };
173be0a7e9fSMehdi Amini } // namespace
17483ef862fSRiver Riddle 
175ca6bd9cdSMogball DefGen::DefGen(const AttrOrTypeDef &def)
176ca6bd9cdSMogball     : def(def), params(def.getParameters()), defCls(def.getCppClassName()),
177ca6bd9cdSMogball       valueType(isa<AttrDef>(def) ? "Attribute" : "Type"),
178ca6bd9cdSMogball       defType(isa<AttrDef>(def) ? "Attr" : "Type") {
179761bc83aSMogball   // Check that all parameters have names.
180761bc83aSMogball   for (const AttrOrTypeParameter &param : def.getParameters())
181761bc83aSMogball     if (param.isAnonymous())
182761bc83aSMogball       llvm::PrintFatalError("all parameters must have a name");
183761bc83aSMogball 
184ca6bd9cdSMogball   // If a storage class is needed, create one.
185ca6bd9cdSMogball   if (def.getNumParameters() > 0)
186ca6bd9cdSMogball     storageCls.emplace(def.getStorageClassName(), /*isStruct=*/true);
187ca6bd9cdSMogball 
188ca6bd9cdSMogball   // Create the parent class with any indicated traits.
189ca6bd9cdSMogball   createParentWithTraits();
190ca6bd9cdSMogball   // Emit top-level declarations.
191ca6bd9cdSMogball   emitTopLevelDeclarations();
192ca6bd9cdSMogball   // Emit builders for defs with parameters
193ca6bd9cdSMogball   if (storageCls)
194ca6bd9cdSMogball     emitBuilders();
1953dbac2c0SFehr Mathieu   // Emit the type name.
1963dbac2c0SFehr Mathieu   emitName();
19707c157a4SJeremy Kun   // Emit the dialect name.
19807c157a4SJeremy Kun   emitDialectName();
1997359a6b7SMatthias Springer   // Emit verification of type constraints.
2007359a6b7SMatthias Springer   bool genVerifyInvariantsImpl = def.genVerifyInvariantsImpl();
2017359a6b7SMatthias Springer   if (storageCls && genVerifyInvariantsImpl)
2027359a6b7SMatthias Springer     emitInvariantsVerifierImpl();
2037359a6b7SMatthias Springer   // Emit the custom verifier (written by the user).
2047359a6b7SMatthias Springer   bool genVerifyDecl = def.genVerifyDecl();
2057359a6b7SMatthias Springer   if (storageCls && genVerifyDecl)
2067359a6b7SMatthias Springer     emitVerifierDecl();
2077359a6b7SMatthias Springer   // Emit the "verifyInvariants" function if there is any verification at all.
2087359a6b7SMatthias Springer   if (storageCls)
2097359a6b7SMatthias Springer     emitInvariantsVerifier(genVerifyInvariantsImpl, genVerifyDecl);
210ca6bd9cdSMogball   // Emit the mnemonic, if there is one, and any associated parser and printer.
211ca6bd9cdSMogball   if (def.getMnemonic())
212ca6bd9cdSMogball     emitParserPrinter();
213ca6bd9cdSMogball   // Emit accessors
214ca6bd9cdSMogball   if (def.genAccessors())
215ca6bd9cdSMogball     emitAccessors();
216ca6bd9cdSMogball   // Emit trait interface methods
217ca6bd9cdSMogball   emitInterfaceMethods();
218ca6bd9cdSMogball   defCls.finalize();
219ca6bd9cdSMogball   // Emit a storage class if one is needed
220ca6bd9cdSMogball   if (storageCls && def.genStorageClass())
221ca6bd9cdSMogball     emitStorageClass();
222ca6bd9cdSMogball }
223ca6bd9cdSMogball 
224ca6bd9cdSMogball void DefGen::createParentWithTraits() {
225ca6bd9cdSMogball   ParentClass defParent(strfmt("::mlir::{0}::{1}Base", valueType, defType));
226ca6bd9cdSMogball   defParent.addTemplateParam(def.getCppClassName());
227ca6bd9cdSMogball   defParent.addTemplateParam(def.getCppBaseClassName());
228ca6bd9cdSMogball   defParent.addTemplateParam(storageCls
229ca6bd9cdSMogball                                  ? strfmt("{0}::{1}", def.getStorageNamespace(),
230ca6bd9cdSMogball                                           def.getStorageClassName())
231ca6bd9cdSMogball                                  : strfmt("::mlir::{0}Storage", valueType));
232ca6bd9cdSMogball   for (auto &trait : def.getTraits()) {
233ca6bd9cdSMogball     defParent.addTemplateParam(
234ca6bd9cdSMogball         isa<NativeTrait>(&trait)
235ca6bd9cdSMogball             ? cast<NativeTrait>(&trait)->getFullyQualifiedTraitName()
236ca6bd9cdSMogball             : cast<InterfaceTrait>(&trait)->getFullyQualifiedTraitName());
237ca6bd9cdSMogball   }
238ca6bd9cdSMogball   defCls.addParent(std::move(defParent));
239ca6bd9cdSMogball }
240ca6bd9cdSMogball 
24147b0a9b9SAmanda Tang /// Include declarations specified on NativeTrait
24247b0a9b9SAmanda Tang static std::string formatExtraDeclarations(const AttrOrTypeDef &def) {
24347b0a9b9SAmanda Tang   SmallVector<StringRef> extraDeclarations;
24447b0a9b9SAmanda Tang   // Include extra class declarations from NativeTrait
24547b0a9b9SAmanda Tang   for (const auto &trait : def.getTraits()) {
24647b0a9b9SAmanda Tang     if (auto *attrOrTypeTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
24747b0a9b9SAmanda Tang       StringRef value = attrOrTypeTrait->getExtraConcreteClassDeclaration();
24847b0a9b9SAmanda Tang       if (value.empty())
24947b0a9b9SAmanda Tang         continue;
25047b0a9b9SAmanda Tang       extraDeclarations.push_back(value);
25147b0a9b9SAmanda Tang     }
25247b0a9b9SAmanda Tang   }
25347b0a9b9SAmanda Tang   if (std::optional<StringRef> extraDecl = def.getExtraDecls()) {
25447b0a9b9SAmanda Tang     extraDeclarations.push_back(*extraDecl);
25547b0a9b9SAmanda Tang   }
25647b0a9b9SAmanda Tang   return llvm::join(extraDeclarations, "\n");
25747b0a9b9SAmanda Tang }
25847b0a9b9SAmanda Tang 
2597f76471eSbhatuzdaname /// Extra class definitions have a `$cppClass` substitution that is to be
2607f76471eSbhatuzdaname /// replaced by the C++ class name.
2617f76471eSbhatuzdaname static std::string formatExtraDefinitions(const AttrOrTypeDef &def) {
26247b0a9b9SAmanda Tang   SmallVector<StringRef> extraDefinitions;
26347b0a9b9SAmanda Tang   // Include extra class definitions from NativeTrait
26447b0a9b9SAmanda Tang   for (const auto &trait : def.getTraits()) {
26547b0a9b9SAmanda Tang     if (auto *attrOrTypeTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
26647b0a9b9SAmanda Tang       StringRef value = attrOrTypeTrait->getExtraConcreteClassDefinition();
26747b0a9b9SAmanda Tang       if (value.empty())
26847b0a9b9SAmanda Tang         continue;
26947b0a9b9SAmanda Tang       extraDefinitions.push_back(value);
2707f76471eSbhatuzdaname     }
27147b0a9b9SAmanda Tang   }
27247b0a9b9SAmanda Tang   if (std::optional<StringRef> extraDef = def.getExtraDefs()) {
27347b0a9b9SAmanda Tang     extraDefinitions.push_back(*extraDef);
27447b0a9b9SAmanda Tang   }
27547b0a9b9SAmanda Tang   FmtContext ctx = FmtContext().addSubst("cppClass", def.getCppClassName());
27647b0a9b9SAmanda Tang   return tgfmt(llvm::join(extraDefinitions, "\n"), &ctx).str();
2777f76471eSbhatuzdaname }
2787f76471eSbhatuzdaname 
279ca6bd9cdSMogball void DefGen::emitTopLevelDeclarations() {
280ca6bd9cdSMogball   // Inherit constructors from the attribute or type class.
281ca6bd9cdSMogball   defCls.declare<VisibilityDeclaration>(Visibility::Public);
282ca6bd9cdSMogball   defCls.declare<UsingDeclaration>("Base::Base");
283ca6bd9cdSMogball 
284ca6bd9cdSMogball   // Emit the extra declarations first in case there's a definition in there.
28547b0a9b9SAmanda Tang   std::string extraDecl = formatExtraDeclarations(def);
2867f76471eSbhatuzdaname   std::string extraDef = formatExtraDefinitions(def);
28747b0a9b9SAmanda Tang   defCls.declare<ExtraClassDeclaration>(std::move(extraDecl),
2887f76471eSbhatuzdaname                                         std::move(extraDef));
289ca6bd9cdSMogball }
290ca6bd9cdSMogball 
2913dbac2c0SFehr Mathieu void DefGen::emitName() {
2923dbac2c0SFehr Mathieu   StringRef name;
2933dbac2c0SFehr Mathieu   if (auto *attrDef = dyn_cast<AttrDef>(&def)) {
2943dbac2c0SFehr Mathieu     name = attrDef->getAttrName();
2953dbac2c0SFehr Mathieu   } else {
2963dbac2c0SFehr Mathieu     auto *typeDef = cast<TypeDef>(&def);
2973dbac2c0SFehr Mathieu     name = typeDef->getTypeName();
2983dbac2c0SFehr Mathieu   }
2993dbac2c0SFehr Mathieu   std::string nameDecl =
3003dbac2c0SFehr Mathieu       strfmt("static constexpr ::llvm::StringLiteral name = \"{0}\";\n", name);
3013dbac2c0SFehr Mathieu   defCls.declare<ExtraClassDeclaration>(std::move(nameDecl));
3023dbac2c0SFehr Mathieu }
3033dbac2c0SFehr Mathieu 
30407c157a4SJeremy Kun void DefGen::emitDialectName() {
30507c157a4SJeremy Kun   std::string decl =
30607c157a4SJeremy Kun       strfmt("static constexpr ::llvm::StringLiteral dialectName = \"{0}\";\n",
30707c157a4SJeremy Kun              def.getDialect().getName());
30807c157a4SJeremy Kun   defCls.declare<ExtraClassDeclaration>(std::move(decl));
30907c157a4SJeremy Kun }
31007c157a4SJeremy Kun 
311ca6bd9cdSMogball void DefGen::emitBuilders() {
312ca6bd9cdSMogball   if (!def.skipDefaultBuilders()) {
313ca6bd9cdSMogball     emitDefaultBuilder();
3147359a6b7SMatthias Springer     if (def.genVerifyDecl() || def.genVerifyInvariantsImpl())
315ca6bd9cdSMogball       emitCheckedBuilder();
316ca6bd9cdSMogball   }
317ca6bd9cdSMogball   for (auto &builder : def.getBuilders()) {
318ca6bd9cdSMogball     emitCustomBuilder(builder);
3197359a6b7SMatthias Springer     if (def.genVerifyDecl() || def.genVerifyInvariantsImpl())
320ca6bd9cdSMogball       emitCheckedCustomBuilder(builder);
321ca6bd9cdSMogball   }
322ca6bd9cdSMogball }
323ca6bd9cdSMogball 
3247359a6b7SMatthias Springer void DefGen::emitVerifierDecl() {
325ca6bd9cdSMogball   defCls.declareStaticMethod(
326db791b27SRamkumar Ramachandra       "::llvm::LogicalResult", "verify",
327ca6bd9cdSMogball       getBuilderParams({{"::llvm::function_ref<::mlir::InFlightDiagnostic()>",
328ca6bd9cdSMogball                          "emitError"}}));
329ca6bd9cdSMogball }
330ca6bd9cdSMogball 
3317359a6b7SMatthias Springer static const char *const patternParameterVerificationCode = R"(
3327359a6b7SMatthias Springer if (!({0})) {
3337359a6b7SMatthias Springer   emitError() << "failed to verify '{1}': {2}";
3347359a6b7SMatthias Springer   return ::mlir::failure();
3357359a6b7SMatthias Springer }
3367359a6b7SMatthias Springer )";
3377359a6b7SMatthias Springer 
3387359a6b7SMatthias Springer void DefGen::emitInvariantsVerifierImpl() {
3397359a6b7SMatthias Springer   SmallVector<MethodParameter> builderParams = getBuilderParams(
3407359a6b7SMatthias Springer       {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}});
3417359a6b7SMatthias Springer   Method *verifier =
3427359a6b7SMatthias Springer       defCls.addMethod("::llvm::LogicalResult", "verifyInvariantsImpl",
3437359a6b7SMatthias Springer                        Method::Static, builderParams);
3447359a6b7SMatthias Springer   verifier->body().indent();
3457359a6b7SMatthias Springer 
3467359a6b7SMatthias Springer   // Generate verification for each parameter that is a type constraint.
3477359a6b7SMatthias Springer   for (auto it : llvm::enumerate(def.getParameters())) {
3487359a6b7SMatthias Springer     const AttrOrTypeParameter &param = it.value();
3497359a6b7SMatthias Springer     std::optional<Constraint> constraint = param.getConstraint();
3507359a6b7SMatthias Springer     // No verification needed for parameters that are not type constraints.
3517359a6b7SMatthias Springer     if (!constraint.has_value())
3527359a6b7SMatthias Springer       continue;
3537359a6b7SMatthias Springer     FmtContext ctx;
3547359a6b7SMatthias Springer     // Note: Skip over the first method parameter (`emitError`).
3557359a6b7SMatthias Springer     ctx.withSelf(builderParams[it.index() + 1].getName());
3567359a6b7SMatthias Springer     std::string condition = tgfmt(constraint->getConditionTemplate(), &ctx);
3577359a6b7SMatthias Springer     verifier->body() << formatv(patternParameterVerificationCode, condition,
3587359a6b7SMatthias Springer                                 param.getName(), constraint->getSummary())
3597359a6b7SMatthias Springer                      << "\n";
3607359a6b7SMatthias Springer   }
3617359a6b7SMatthias Springer   verifier->body() << "return ::mlir::success();";
3627359a6b7SMatthias Springer }
3637359a6b7SMatthias Springer 
3647359a6b7SMatthias Springer void DefGen::emitInvariantsVerifier(bool hasImpl, bool hasCustomVerifier) {
3657359a6b7SMatthias Springer   if (!hasImpl && !hasCustomVerifier)
3667359a6b7SMatthias Springer     return;
3677359a6b7SMatthias Springer   defCls.declare<UsingDeclaration>("Base::getChecked");
3687359a6b7SMatthias Springer   SmallVector<MethodParameter> builderParams = getBuilderParams(
3697359a6b7SMatthias Springer       {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}});
3707359a6b7SMatthias Springer   Method *verifier =
3717359a6b7SMatthias Springer       defCls.addMethod("::llvm::LogicalResult", "verifyInvariants",
3727359a6b7SMatthias Springer                        Method::Static, builderParams);
3737359a6b7SMatthias Springer   verifier->body().indent();
3747359a6b7SMatthias Springer 
3757359a6b7SMatthias Springer   auto emitVerifierCall = [&](StringRef name) {
3767359a6b7SMatthias Springer     verifier->body() << strfmt("if (::mlir::failed({0}(", name);
3777359a6b7SMatthias Springer     llvm::interleaveComma(
3787359a6b7SMatthias Springer         llvm::map_range(builderParams,
3797359a6b7SMatthias Springer                         [](auto &param) { return param.getName(); }),
3807359a6b7SMatthias Springer         verifier->body());
3817359a6b7SMatthias Springer     verifier->body() << ")))\n";
3827359a6b7SMatthias Springer     verifier->body() << "  return ::mlir::failure();\n";
3837359a6b7SMatthias Springer   };
3847359a6b7SMatthias Springer 
3857359a6b7SMatthias Springer   if (hasImpl) {
3867359a6b7SMatthias Springer     // Call the verifier that checks the type constraints.
3877359a6b7SMatthias Springer     emitVerifierCall("verifyInvariantsImpl");
3887359a6b7SMatthias Springer   }
3897359a6b7SMatthias Springer   if (hasCustomVerifier) {
3907359a6b7SMatthias Springer     // Call the custom verifier that is provided by the user.
3917359a6b7SMatthias Springer     emitVerifierCall("verify");
3927359a6b7SMatthias Springer   }
3937359a6b7SMatthias Springer   verifier->body() << "return ::mlir::success();";
3947359a6b7SMatthias Springer }
3957359a6b7SMatthias Springer 
396ca6bd9cdSMogball void DefGen::emitParserPrinter() {
397ca6bd9cdSMogball   auto *mnemonic = defCls.addStaticMethod<Method::Constexpr>(
398ca6bd9cdSMogball       "::llvm::StringLiteral", "getMnemonic");
399ecaad4a8SMogball   mnemonic->body().indent() << strfmt("return {\"{0}\"};", *def.getMnemonic());
40023e3cbe2SRiver Riddle 
401ca6bd9cdSMogball   // Declare the parser and printer, if needed.
4020916d96dSKazu Hirata   bool hasAssemblyFormat = def.getAssemblyFormat().has_value();
40323e3cbe2SRiver Riddle   if (!def.hasCustomAssemblyFormat() && !hasAssemblyFormat)
404ca6bd9cdSMogball     return;
405ca6bd9cdSMogball 
406ca6bd9cdSMogball   // Declare the parser.
407ca6bd9cdSMogball   SmallVector<MethodParameter> parserParams;
40807486395SMogball   parserParams.emplace_back("::mlir::AsmParser &", "odsParser");
409ca6bd9cdSMogball   if (isa<AttrDef>(&def))
41007486395SMogball     parserParams.emplace_back("::mlir::Type", "odsType");
41123e3cbe2SRiver Riddle   auto *parser = defCls.addMethod(strfmt("::mlir::{0}", valueType), "parse",
41223e3cbe2SRiver Riddle                                   hasAssemblyFormat ? Method::Static
41323e3cbe2SRiver Riddle                                                     : Method::StaticDeclaration,
414ca6bd9cdSMogball                                   std::move(parserParams));
415ca6bd9cdSMogball   // Declare the printer.
41623e3cbe2SRiver Riddle   auto props = hasAssemblyFormat ? Method::Const : Method::ConstDeclaration;
417ca6bd9cdSMogball   Method *printer =
418ca6bd9cdSMogball       defCls.addMethod("void", "print", props,
41907486395SMogball                        MethodParameter("::mlir::AsmPrinter &", "odsPrinter"));
42023e3cbe2SRiver Riddle   // Emit the bodies if we are using the declarative format.
42123e3cbe2SRiver Riddle   if (hasAssemblyFormat)
42223e3cbe2SRiver Riddle     return generateAttrOrTypeFormat(def, parser->body(), printer->body());
423ca6bd9cdSMogball }
424ca6bd9cdSMogball 
425ca6bd9cdSMogball void DefGen::emitAccessors() {
426ca6bd9cdSMogball   for (auto &param : params) {
427ca6bd9cdSMogball     Method *m = defCls.addMethod(
428baca1c1aSMogball         param.getCppAccessorType(), param.getAccessorName(),
429ca6bd9cdSMogball         def.genStorageClass() ? Method::Const : Method::ConstDeclaration);
430ca6bd9cdSMogball     // Generate accessor definitions only if we also generate the storage
431ca6bd9cdSMogball     // class. Otherwise, let the user define the exact accessor definition.
432ca6bd9cdSMogball     if (!def.genStorageClass())
433ca6bd9cdSMogball       continue;
434e1795322SJeff Niu     m->body().indent() << "return getImpl()->" << param.getName() << ";";
435ca6bd9cdSMogball   }
436ca6bd9cdSMogball }
437ca6bd9cdSMogball 
438ca6bd9cdSMogball void DefGen::emitInterfaceMethods() {
439ca6bd9cdSMogball   for (auto &traitDef : def.getTraits())
440ca6bd9cdSMogball     if (auto *trait = dyn_cast<InterfaceTrait>(&traitDef))
441ca6bd9cdSMogball       if (trait->shouldDeclareMethods())
442ca6bd9cdSMogball         emitTraitMethods(*trait);
443ca6bd9cdSMogball }
444ca6bd9cdSMogball 
445ca6bd9cdSMogball //===----------------------------------------------------------------------===//
446ca6bd9cdSMogball // Builder Emission
447ca6bd9cdSMogball 
448ca6bd9cdSMogball SmallVector<MethodParameter>
449ca6bd9cdSMogball DefGen::getBuilderParams(std::initializer_list<MethodParameter> prefix) const {
450ca6bd9cdSMogball   SmallVector<MethodParameter> builderParams;
451ca6bd9cdSMogball   builderParams.append(prefix.begin(), prefix.end());
452ca6bd9cdSMogball   for (auto &param : params)
4531461bd13SMehdi Amini     builderParams.emplace_back(param.getCppType(), param.getName());
454ca6bd9cdSMogball   return builderParams;
455ca6bd9cdSMogball }
456ca6bd9cdSMogball 
457ca6bd9cdSMogball void DefGen::emitDefaultBuilder() {
458ca6bd9cdSMogball   Method *m = defCls.addStaticMethod(
459ca6bd9cdSMogball       def.getCppClassName(), "get",
460ca6bd9cdSMogball       getBuilderParams({{"::mlir::MLIRContext *", "context"}}));
461ca6bd9cdSMogball   MethodBody &body = m->body().indent();
462ca6bd9cdSMogball   auto scope = body.scope("return Base::get(context", ");");
463c730f9a1SKazu Hirata   for (const auto &param : params)
4645fc28e7aSMehdi Amini     body << ", std::move(" << param.getName() << ")";
465ca6bd9cdSMogball }
466ca6bd9cdSMogball 
467ca6bd9cdSMogball void DefGen::emitCheckedBuilder() {
468ca6bd9cdSMogball   Method *m = defCls.addStaticMethod(
469ca6bd9cdSMogball       def.getCppClassName(), "getChecked",
470ca6bd9cdSMogball       getBuilderParams(
471ca6bd9cdSMogball           {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"},
472ca6bd9cdSMogball            {"::mlir::MLIRContext *", "context"}}));
473ca6bd9cdSMogball   MethodBody &body = m->body().indent();
474ca6bd9cdSMogball   auto scope = body.scope("return Base::getChecked(emitError, context", ");");
475c730f9a1SKazu Hirata   for (const auto &param : params)
476c730f9a1SKazu Hirata     body << ", " << param.getName();
477ca6bd9cdSMogball }
478ca6bd9cdSMogball 
479ca6bd9cdSMogball static SmallVector<MethodParameter>
480ca6bd9cdSMogball getCustomBuilderParams(std::initializer_list<MethodParameter> prefix,
481ca6bd9cdSMogball                        const AttrOrTypeBuilder &builder) {
482ca6bd9cdSMogball   auto params = builder.getParameters();
483ca6bd9cdSMogball   SmallVector<MethodParameter> builderParams;
484ca6bd9cdSMogball   builderParams.append(prefix.begin(), prefix.end());
485ca6bd9cdSMogball   if (!builder.hasInferredContextParameter())
486ca6bd9cdSMogball     builderParams.emplace_back("::mlir::MLIRContext *", "context");
487ca6bd9cdSMogball   for (auto &param : params) {
488ca6bd9cdSMogball     builderParams.emplace_back(param.getCppType(), *param.getName(),
489ca6bd9cdSMogball                                param.getDefaultValue());
490ca6bd9cdSMogball   }
491ca6bd9cdSMogball   return builderParams;
492ca6bd9cdSMogball }
493ca6bd9cdSMogball 
494ca6bd9cdSMogball void DefGen::emitCustomBuilder(const AttrOrTypeBuilder &builder) {
495ca6bd9cdSMogball   // Don't emit a body if there isn't one.
496ca6bd9cdSMogball   auto props = builder.getBody() ? Method::Static : Method::StaticDeclaration;
4977fe2294eSJeff Niu   StringRef returnType = def.getCppClassName();
4983cfe412eSFangrui Song   if (std::optional<StringRef> builderReturnType = builder.getReturnType())
4997fe2294eSJeff Niu     returnType = *builderReturnType;
5007fe2294eSJeff Niu   Method *m = defCls.addMethod(returnType, "get", props,
501ca6bd9cdSMogball                                getCustomBuilderParams({}, builder));
502ca6bd9cdSMogball   if (!builder.getBody())
503ca6bd9cdSMogball     return;
504ca6bd9cdSMogball 
505ca6bd9cdSMogball   // Format the body and emit it.
506ca6bd9cdSMogball   FmtContext ctx;
507ca6bd9cdSMogball   ctx.addSubst("_get", "Base::get");
508ca6bd9cdSMogball   if (!builder.hasInferredContextParameter())
509ca6bd9cdSMogball     ctx.addSubst("_ctxt", "context");
510ca6bd9cdSMogball   std::string bodyStr = tgfmt(*builder.getBody(), &ctx);
511ca6bd9cdSMogball   m->body().indent().getStream().printReindented(bodyStr);
512ca6bd9cdSMogball }
513ca6bd9cdSMogball 
/// Return a copy of `str` in which every occurrence of `from` is replaced by
/// `to`. Occurrences are matched left to right and do not overlap. Text
/// inserted by a replacement is never searched again, so `to` may itself
/// contain `from`. An empty `from` matches nothing and `str` is returned
/// unchanged.
static std::string replaceInStr(std::string str, StringRef from, StringRef to) {
  // An empty pattern would match at every position and the loop below would
  // never terminate.
  if (from.empty())
    return str;

  size_t pos = 0;
  while ((pos = str.find(from.data(), pos, from.size())) != std::string::npos) {
    str.replace(pos, from.size(), to.data(), to.size());
    // Continue after the inserted text so that it is not matched again.
    pos += to.size();
  }
  return str;
}
521ca6bd9cdSMogball 
522ca6bd9cdSMogball void DefGen::emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder) {
523ca6bd9cdSMogball   // Don't emit a body if there isn't one.
524ca6bd9cdSMogball   auto props = builder.getBody() ? Method::Static : Method::StaticDeclaration;
5257fe2294eSJeff Niu   StringRef returnType = def.getCppClassName();
5263cfe412eSFangrui Song   if (std::optional<StringRef> builderReturnType = builder.getReturnType())
5277fe2294eSJeff Niu     returnType = *builderReturnType;
528ca6bd9cdSMogball   Method *m = defCls.addMethod(
5297fe2294eSJeff Niu       returnType, "getChecked", props,
530ca6bd9cdSMogball       getCustomBuilderParams(
531ca6bd9cdSMogball           {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}},
532ca6bd9cdSMogball           builder));
533ca6bd9cdSMogball   if (!builder.getBody())
534ca6bd9cdSMogball     return;
535ca6bd9cdSMogball 
536ca6bd9cdSMogball   // Format the body and emit it. Replace $_get(...) with
537ca6bd9cdSMogball   // Base::getChecked(emitError, ...)
538ca6bd9cdSMogball   FmtContext ctx;
539ca6bd9cdSMogball   if (!builder.hasInferredContextParameter())
540ca6bd9cdSMogball     ctx.addSubst("_ctxt", "context");
541ca6bd9cdSMogball   std::string bodyStr = replaceInStr(builder.getBody()->str(), "$_get(",
542ca6bd9cdSMogball                                      "Base::getChecked(emitError, ");
543ca6bd9cdSMogball   bodyStr = tgfmt(bodyStr, &ctx);
544ca6bd9cdSMogball   m->body().indent().getStream().printReindented(bodyStr);
545ca6bd9cdSMogball }
546ca6bd9cdSMogball 
547ca6bd9cdSMogball //===----------------------------------------------------------------------===//
548ca6bd9cdSMogball // Interface Method Emission
549ca6bd9cdSMogball 
550ca6bd9cdSMogball void DefGen::emitTraitMethods(const InterfaceTrait &trait) {
551ca6bd9cdSMogball   // Get the set of methods that should always be declared.
552ca6bd9cdSMogball   auto alwaysDeclaredMethods = trait.getAlwaysDeclaredMethods();
553ca6bd9cdSMogball   StringSet<> alwaysDeclared;
554ca6bd9cdSMogball   alwaysDeclared.insert(alwaysDeclaredMethods.begin(),
555ca6bd9cdSMogball                         alwaysDeclaredMethods.end());
556ca6bd9cdSMogball 
557ca6bd9cdSMogball   Interface iface = trait.getInterface(); // causes strange bugs if elided
558ca6bd9cdSMogball   for (auto &method : iface.getMethods()) {
559ca6bd9cdSMogball     // Don't declare if the method has a body. Or if the method has a default
560ca6bd9cdSMogball     // implementation and the def didn't request that it always be declared.
561ca6bd9cdSMogball     if (method.getBody() || (method.getDefaultImplementation() &&
562ca6bd9cdSMogball                              !alwaysDeclared.count(method.getName())))
563ca6bd9cdSMogball       continue;
564ca6bd9cdSMogball     emitTraitMethod(method);
565ca6bd9cdSMogball   }
566ca6bd9cdSMogball }
567ca6bd9cdSMogball 
568ca6bd9cdSMogball void DefGen::emitTraitMethod(const InterfaceMethod &method) {
569ca6bd9cdSMogball   // All interface methods are declaration-only.
570ca6bd9cdSMogball   auto props =
571ca6bd9cdSMogball       method.isStatic() ? Method::StaticDeclaration : Method::ConstDeclaration;
572ca6bd9cdSMogball   SmallVector<MethodParameter> params;
573ca6bd9cdSMogball   for (auto &param : method.getArguments())
574ca6bd9cdSMogball     params.emplace_back(param.type, param.name);
575ca6bd9cdSMogball   defCls.addMethod(method.getReturnType(), method.getName(), props,
576ca6bd9cdSMogball                    std::move(params));
577ca6bd9cdSMogball }
578ca6bd9cdSMogball 
579ca6bd9cdSMogball //===----------------------------------------------------------------------===//
580ca6bd9cdSMogball // Storage Class Emission
581ca6bd9cdSMogball 
582ca6bd9cdSMogball void DefGen::emitStorageConstructor() {
583ca6bd9cdSMogball   Constructor *ctor =
584ca6bd9cdSMogball       storageCls->addConstructor<Method::Inline>(getBuilderParams({}));
5855fc28e7aSMehdi Amini   for (auto &param : params) {
5865fc28e7aSMehdi Amini     std::string movedValue = ("std::move(" + param.getName() + ")").str();
5875fc28e7aSMehdi Amini     ctor->addMemberInitializer(param.getName(), movedValue);
5885fc28e7aSMehdi Amini   }
589ca6bd9cdSMogball }
590ca6bd9cdSMogball 
591ca6bd9cdSMogball void DefGen::emitKeyType() {
592ca6bd9cdSMogball   std::string keyType("std::tuple<");
593ca6bd9cdSMogball   llvm::raw_string_ostream os(keyType);
594ca6bd9cdSMogball   llvm::interleaveComma(params, os,
595ca6bd9cdSMogball                         [&](auto &param) { os << param.getCppType(); });
596ca6bd9cdSMogball   os << '>';
597ca6bd9cdSMogball   storageCls->declare<UsingDeclaration>("KeyTy", std::move(os.str()));
59838c219b4SRiver Riddle 
59938c219b4SRiver Riddle   // Add a method to construct the key type from the storage.
60038c219b4SRiver Riddle   Method *m = storageCls->addConstMethod<Method::Inline>("KeyTy", "getAsKey");
60138c219b4SRiver Riddle   m->body().indent() << "return KeyTy(";
60238c219b4SRiver Riddle   llvm::interleaveComma(params, m->body().indent(),
60338c219b4SRiver Riddle                         [&](auto &param) { m->body() << param.getName(); });
60438c219b4SRiver Riddle   m->body() << ");";
605ca6bd9cdSMogball }
606ca6bd9cdSMogball 
607ca6bd9cdSMogball void DefGen::emitEquals() {
608ca6bd9cdSMogball   Method *eq = storageCls->addConstMethod<Method::Inline>(
609ca6bd9cdSMogball       "bool", "operator==", MethodParameter("const KeyTy &", "tblgenKey"));
610ca6bd9cdSMogball   auto &body = eq->body().indent();
611ca6bd9cdSMogball   auto scope = body.scope("return (", ");");
612ca6bd9cdSMogball   const auto eachFn = [&](auto it) {
613e1795322SJeff Niu     FmtContext ctx({{"_lhs", it.value().getName()},
614ca6bd9cdSMogball                     {"_rhs", strfmt("std::get<{0}>(tblgenKey)", it.index())}});
615761bc83aSMogball     body << tgfmt(it.value().getComparator(), &ctx);
616ca6bd9cdSMogball   };
617ca6bd9cdSMogball   llvm::interleave(llvm::enumerate(params), body, eachFn, ") && (");
618ca6bd9cdSMogball }
619ca6bd9cdSMogball 
620ca6bd9cdSMogball void DefGen::emitHashKey() {
621ca6bd9cdSMogball   Method *hash = storageCls->addStaticInlineMethod(
622ca6bd9cdSMogball       "::llvm::hash_code", "hashKey",
623ca6bd9cdSMogball       MethodParameter("const KeyTy &", "tblgenKey"));
624ca6bd9cdSMogball   auto &body = hash->body().indent();
625ca6bd9cdSMogball   auto scope = body.scope("return ::llvm::hash_combine(", ");");
626ca6bd9cdSMogball   llvm::interleaveComma(llvm::enumerate(params), body, [&](auto it) {
627ca6bd9cdSMogball     body << llvm::formatv("std::get<{0}>(tblgenKey)", it.index());
628ca6bd9cdSMogball   });
629ca6bd9cdSMogball }
630ca6bd9cdSMogball 
631ca6bd9cdSMogball void DefGen::emitConstruct() {
632ca6bd9cdSMogball   Method *construct = storageCls->addMethod<Method::Inline>(
633ca6bd9cdSMogball       strfmt("{0} *", def.getStorageClassName()), "construct",
634ca6bd9cdSMogball       def.hasStorageCustomConstructor() ? Method::StaticDeclaration
635ca6bd9cdSMogball                                         : Method::Static,
636ca6bd9cdSMogball       MethodParameter(strfmt("::mlir::{0}StorageAllocator &", valueType),
637ca6bd9cdSMogball                       "allocator"),
6385fc28e7aSMehdi Amini       MethodParameter("KeyTy &&", "tblgenKey"));
639ca6bd9cdSMogball   if (!def.hasStorageCustomConstructor()) {
640ca6bd9cdSMogball     auto &body = construct->body().indent();
64189de9cc8SMehdi Amini     for (const auto &it : llvm::enumerate(params)) {
6425fc28e7aSMehdi Amini       body << formatv("auto {0} = std::move(std::get<{1}>(tblgenKey));\n",
643ca6bd9cdSMogball                       it.value().getName(), it.index());
644ca6bd9cdSMogball     }
645ca6bd9cdSMogball     // Use the parameters' custom allocator code, if provided.
646ca6bd9cdSMogball     FmtContext ctx = FmtContext().addSubst("_allocator", "allocator");
647ca6bd9cdSMogball     for (auto &param : params) {
6483cfe412eSFangrui Song       if (std::optional<StringRef> allocCode = param.getAllocator()) {
649ca6bd9cdSMogball         ctx.withSelf(param.getName()).addSubst("_dst", param.getName());
650ca6bd9cdSMogball         body << tgfmt(*allocCode, &ctx) << '\n';
651ca6bd9cdSMogball       }
652ca6bd9cdSMogball     }
653ca6bd9cdSMogball     auto scope =
654ca6bd9cdSMogball         body.scope(strfmt("return new (allocator.allocate<{0}>()) {0}(",
655ca6bd9cdSMogball                           def.getStorageClassName()),
656ca6bd9cdSMogball                    ");");
6575fc28e7aSMehdi Amini     llvm::interleaveComma(params, body, [&](auto &param) {
6585fc28e7aSMehdi Amini       body << "std::move(" << param.getName() << ")";
6595fc28e7aSMehdi Amini     });
660ca6bd9cdSMogball   }
661ca6bd9cdSMogball }
662ca6bd9cdSMogball 
663ca6bd9cdSMogball void DefGen::emitStorageClass() {
664ca6bd9cdSMogball   // Add the appropriate parent class.
665ca6bd9cdSMogball   storageCls->addParent(strfmt("::mlir::{0}Storage", valueType));
666ca6bd9cdSMogball   // Add the constructor.
667ca6bd9cdSMogball   emitStorageConstructor();
668ca6bd9cdSMogball   // Declare the key type.
669ca6bd9cdSMogball   emitKeyType();
670ca6bd9cdSMogball   // Add the comparison method.
671ca6bd9cdSMogball   emitEquals();
672ca6bd9cdSMogball   // Emit the key hash method.
673ca6bd9cdSMogball   emitHashKey();
674ca6bd9cdSMogball   // Emit the storage constructor. Just declare it if the user wants to define
675ca6bd9cdSMogball   // it themself.
676ca6bd9cdSMogball   emitConstruct();
677ca6bd9cdSMogball   // Emit the storage class members as public, at the very end of the struct.
678ca6bd9cdSMogball   storageCls->finalize();
679ca6bd9cdSMogball   for (auto &param : params)
680ca6bd9cdSMogball     storageCls->declare<Field>(param.getCppType(), param.getName());
681ca6bd9cdSMogball }
682ca6bd9cdSMogball 
68383ef862fSRiver Riddle //===----------------------------------------------------------------------===//
68483ef862fSRiver Riddle // DefGenerator
68583ef862fSRiver Riddle //===----------------------------------------------------------------------===//
68683ef862fSRiver Riddle 
68783ef862fSRiver Riddle namespace {
68883ef862fSRiver Riddle /// This struct is the base generator used when processing tablegen interfaces.
68983ef862fSRiver Riddle class DefGenerator {
69083ef862fSRiver Riddle public:
69183ef862fSRiver Riddle   bool emitDecls(StringRef selectedDialect);
69283ef862fSRiver Riddle   bool emitDefs(StringRef selectedDialect);
69383ef862fSRiver Riddle 
69483ef862fSRiver Riddle protected:
695*bccd37f6SRahul Joshi   DefGenerator(ArrayRef<const Record *> defs, raw_ostream &os,
6964957518eSMogball                StringRef defType, StringRef valueType, bool isAttrGenerator)
69716df489fSRahul Joshi       : defRecords(defs), os(os), defType(defType), valueType(valueType),
69816df489fSRahul Joshi         isAttrGenerator(isAttrGenerator) {
699308f58ceSJacques Pienaar     // Sort by occurrence in file.
700*bccd37f6SRahul Joshi     llvm::sort(defRecords, [](const Record *lhs, const Record *rhs) {
701308f58ceSJacques Pienaar       return lhs->getID() < rhs->getID();
702308f58ceSJacques Pienaar     });
703308f58ceSJacques Pienaar   }
70483ef862fSRiver Riddle 
70583ef862fSRiver Riddle   /// Emit the list of def type names.
70683ef862fSRiver Riddle   void emitTypeDefList(ArrayRef<AttrOrTypeDef> defs);
70783ef862fSRiver Riddle   /// Emit the code to dispatch between different defs during parsing/printing.
70883ef862fSRiver Riddle   void emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs);
70983ef862fSRiver Riddle 
71083ef862fSRiver Riddle   /// The set of def records to emit.
711*bccd37f6SRahul Joshi   std::vector<const Record *> defRecords;
712ca6bd9cdSMogball   /// The attribute or type class to emit.
71383ef862fSRiver Riddle   /// The stream to emit to.
71483ef862fSRiver Riddle   raw_ostream &os;
71583ef862fSRiver Riddle   /// The prefix of the tablegen def name, e.g. Attr or Type.
716ca6bd9cdSMogball   StringRef defType;
71783ef862fSRiver Riddle   /// The C++ base value type of the def, e.g. Attribute or Type.
71883ef862fSRiver Riddle   StringRef valueType;
71983ef862fSRiver Riddle   /// Flag indicating if this generator is for Attributes. False if the
72083ef862fSRiver Riddle   /// generator is for types.
72183ef862fSRiver Riddle   bool isAttrGenerator;
72283ef862fSRiver Riddle };
72383ef862fSRiver Riddle 
72483ef862fSRiver Riddle /// A specialized generator for AttrDefs.
72583ef862fSRiver Riddle struct AttrDefGenerator : public DefGenerator {
726*bccd37f6SRahul Joshi   AttrDefGenerator(const RecordKeeper &records, raw_ostream &os)
7271d7120c6SRiver Riddle       : DefGenerator(records.getAllDerivedDefinitionsIfDefined("AttrDef"), os,
7284957518eSMogball                      "Attr", "Attribute", /*isAttrGenerator=*/true) {}
72983ef862fSRiver Riddle };
73083ef862fSRiver Riddle /// A specialized generator for TypeDefs.
73183ef862fSRiver Riddle struct TypeDefGenerator : public DefGenerator {
732*bccd37f6SRahul Joshi   TypeDefGenerator(const RecordKeeper &records, raw_ostream &os)
7331d7120c6SRiver Riddle       : DefGenerator(records.getAllDerivedDefinitionsIfDefined("TypeDef"), os,
7344957518eSMogball                      "Type", "Type", /*isAttrGenerator=*/false) {}
73583ef862fSRiver Riddle };
736be0a7e9fSMehdi Amini } // namespace
73783ef862fSRiver Riddle 
73883ef862fSRiver Riddle //===----------------------------------------------------------------------===//
73983ef862fSRiver Riddle // GEN: Declarations
74083ef862fSRiver Riddle //===----------------------------------------------------------------------===//
74183ef862fSRiver Riddle 
/// Common preamble written ahead of all generated declarations. It forward
/// declares the ::mlir::AsmParser and ::mlir::AsmPrinter classes.
static constexpr const char *typeDefDeclHeader = R"(
namespace mlir {
class AsmParser;
class AsmPrinter;
} // namespace mlir
)";
75083ef862fSRiver Riddle 
75183ef862fSRiver Riddle bool DefGenerator::emitDecls(StringRef selectedDialect) {
752ca6bd9cdSMogball   emitSourceFileHeader((defType + "Def Declarations").str(), os);
753ca6bd9cdSMogball   IfDefScope scope("GET_" + defType.upper() + "DEF_CLASSES", os);
75483ef862fSRiver Riddle 
75583ef862fSRiver Riddle   // Output the common "header".
75683ef862fSRiver Riddle   os << typeDefDeclHeader;
75783ef862fSRiver Riddle 
75883ef862fSRiver Riddle   SmallVector<AttrOrTypeDef, 16> defs;
75983ef862fSRiver Riddle   collectAllDefs(selectedDialect, defRecords, defs);
76083ef862fSRiver Riddle   if (defs.empty())
76183ef862fSRiver Riddle     return false;
7624bb0ad23SMehdi Amini   {
76383ef862fSRiver Riddle     NamespaceEmitter nsEmitter(os, defs.front().getDialect());
76483ef862fSRiver Riddle 
76583ef862fSRiver Riddle     // Declare all the def classes first (in case they reference each other).
76683ef862fSRiver Riddle     for (const AttrOrTypeDef &def : defs)
76783ef862fSRiver Riddle       os << "class " << def.getCppClassName() << ";\n";
76883ef862fSRiver Riddle 
76983ef862fSRiver Riddle     // Emit the declarations.
77083ef862fSRiver Riddle     for (const AttrOrTypeDef &def : defs)
771ca6bd9cdSMogball       DefGen(def).emitDecl(os);
7724bb0ad23SMehdi Amini   }
7734bb0ad23SMehdi Amini   // Emit the TypeID explicit specializations to have a single definition for
7744bb0ad23SMehdi Amini   // each of these.
7754bb0ad23SMehdi Amini   for (const AttrOrTypeDef &def : defs)
7764bb0ad23SMehdi Amini     if (!def.getDialect().getCppNamespace().empty())
7775e50dd04SRiver Riddle       os << "MLIR_DECLARE_EXPLICIT_TYPE_ID("
7785e50dd04SRiver Riddle          << def.getDialect().getCppNamespace() << "::" << def.getCppClassName()
7795e50dd04SRiver Riddle          << ")\n";
7804bb0ad23SMehdi Amini 
78183ef862fSRiver Riddle   return false;
78283ef862fSRiver Riddle }
78383ef862fSRiver Riddle 
78483ef862fSRiver Riddle //===----------------------------------------------------------------------===//
78583ef862fSRiver Riddle // GEN: Def List
78683ef862fSRiver Riddle //===----------------------------------------------------------------------===//
78783ef862fSRiver Riddle 
78883ef862fSRiver Riddle void DefGenerator::emitTypeDefList(ArrayRef<AttrOrTypeDef> defs) {
789ca6bd9cdSMogball   IfDefScope scope("GET_" + defType.upper() + "DEF_LIST", os);
79083ef862fSRiver Riddle   auto interleaveFn = [&](const AttrOrTypeDef &def) {
79183ef862fSRiver Riddle     os << def.getDialect().getCppNamespace() << "::" << def.getCppClassName();
79283ef862fSRiver Riddle   };
79383ef862fSRiver Riddle   llvm::interleave(defs, os, interleaveFn, ",\n");
79483ef862fSRiver Riddle   os << "\n";
79583ef862fSRiver Riddle }
79683ef862fSRiver Riddle 
79783ef862fSRiver Riddle //===----------------------------------------------------------------------===//
79883ef862fSRiver Riddle // GEN: Definitions
79983ef862fSRiver Riddle //===----------------------------------------------------------------------===//
80083ef862fSRiver Riddle 
/// The code block for default attribute parser/printer dispatch boilerplate.
/// {0}: the dialect fully qualified class name.
/// {1}: the optional code for the dynamic attribute parser dispatch.
/// {2}: the optional code for the dynamic attribute printer dispatch.
static const char *const dialectDefaultAttrPrinterParserDispatch = R"(
/// Parse an attribute registered to this dialect.
::mlir::Attribute {0}::parseAttribute(::mlir::DialectAsmParser &parser,
                                      ::mlir::Type type) const {{
  ::llvm::SMLoc typeLoc = parser.getCurrentLocation();
  ::llvm::StringRef attrTag;
  {{
    ::mlir::Attribute attr;
    auto parseResult = generatedAttributeParser(parser, &attrTag, type, attr);
    if (parseResult.has_value())
      return attr;
  }
  {1}
  parser.emitError(typeLoc) << "unknown attribute `"
      << attrTag << "` in dialect `" << getNamespace() << "`";
  return {{};
}
/// Print an attribute registered to this dialect.
void {0}::printAttribute(::mlir::Attribute attr,
                         ::mlir::DialectAsmPrinter &printer) const {{
  if (::mlir::succeeded(generatedAttributePrinter(attr, printer)))
    return;
  {2}
}
)";

/// The code block for dynamic attribute parser dispatch boilerplate. It is
/// substituted verbatim into the parser above, so every name it uses is fully
/// qualified: the generated code must also compile for dialects that live
/// outside of the `mlir` namespace.
static const char *const dialectDynamicAttrParserDispatch = R"(
  {
    ::mlir::Attribute genAttr;
    auto parseResult = parseOptionalDynamicAttr(attrTag, parser, genAttr);
    if (parseResult.has_value()) {
      if (::mlir::succeeded(parseResult.value()))
        return genAttr;
      return ::mlir::Attribute();
    }
  }
)";

/// The code block for dynamic attribute printer dispatch boilerplate.
static const char *const dialectDynamicAttrPrinterDispatch = R"(
  if (::mlir::succeeded(printIfDynamicAttr(attr, printer)))
    return;
)";
8499e0b5533SMathieu Fehr 
/// Format template for the default `parseType`/`printType` dialect hooks.
/// {0}: the dialect fully qualified class name.
/// {1}: the optional code for the dynamic type parser dispatch.
/// {2}: the optional code for the dynamic type printer dispatch.
static constexpr const char *dialectDefaultTypePrinterParserDispatch = R"(
/// Parse a type registered to this dialect.
::mlir::Type {0}::parseType(::mlir::DialectAsmParser &parser) const {{
  ::llvm::SMLoc typeLoc = parser.getCurrentLocation();
  ::llvm::StringRef mnemonic;
  ::mlir::Type genType;
  auto parseResult = generatedTypeParser(parser, &mnemonic, genType);
  if (parseResult.has_value())
    return genType;
  {1}
  parser.emitError(typeLoc) << "unknown  type `"
      << mnemonic << "` in dialect `" << getNamespace() << "`";
  return {{};
}
/// Print a type registered to this dialect.
void {0}::printType(::mlir::Type type,
                    ::mlir::DialectAsmPrinter &printer) const {{
  if (::mlir::succeeded(generatedTypePrinter(type, printer)))
    return;
  {2}
}
)";

/// Snippet spliced into the type parser above to try the dynamic types; it
/// relies on the `mnemonic`, `parser` and `genType` names defined there.
static constexpr const char *dialectDynamicTypeParserDispatch = R"(
  {
    auto parseResult = parseOptionalDynamicType(mnemonic, parser, genType);
    if (parseResult.has_value()) {
      if (::mlir::succeeded(parseResult.value()))
        return genType;
      return ::mlir::Type();
    }
  }
)";

/// Snippet spliced into the type printer above to try the dynamic types.
static constexpr const char *dialectDynamicTypePrinterDispatch = R"(
  if (::mlir::succeeded(printIfDynamicType(type, printer)))
    return;
)";
8949e0b5533SMathieu Fehr 
89583ef862fSRiver Riddle /// Emit the dialect printer/parser dispatcher. User's code should call these
89683ef862fSRiver Riddle /// functions from their dialect's print/parse methods.
89783ef862fSRiver Riddle void DefGenerator::emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs) {
89883ef862fSRiver Riddle   if (llvm::none_of(defs, [](const AttrOrTypeDef &def) {
899064a08cdSKazu Hirata         return def.getMnemonic().has_value();
90083ef862fSRiver Riddle       })) {
90183ef862fSRiver Riddle     return;
90283ef862fSRiver Riddle   }
903ca6bd9cdSMogball   // Declare the parser.
904ca6bd9cdSMogball   SmallVector<MethodParameter> params = {{"::mlir::AsmParser &", "parser"},
905fe4f512bSRiver Riddle                                          {"::llvm::StringRef *", "mnemonic"}};
906ca6bd9cdSMogball   if (isAttrGenerator)
907ca6bd9cdSMogball     params.emplace_back("::mlir::Type", "type");
908ca6bd9cdSMogball   params.emplace_back(strfmt("::mlir::{0} &", valueType), "value");
909ca6bd9cdSMogball   Method parse("::mlir::OptionalParseResult",
910ca6bd9cdSMogball                strfmt("generated{0}Parser", valueType), Method::StaticInline,
911ca6bd9cdSMogball                std::move(params));
912ca6bd9cdSMogball   // Declare the printer.
913db791b27SRamkumar Ramachandra   Method printer("::llvm::LogicalResult",
914ca6bd9cdSMogball                  strfmt("generated{0}Printer", valueType), Method::StaticInline,
915ca6bd9cdSMogball                  {{strfmt("::mlir::{0}", valueType), "def"},
916ca6bd9cdSMogball                   {"::mlir::AsmPrinter &", "printer"}});
91783ef862fSRiver Riddle 
918fe4f512bSRiver Riddle   // The parser dispatch uses a KeywordSwitch, matching on the mnemonic and
919fe4f512bSRiver Riddle   // calling the def's parse function.
920fe4f512bSRiver Riddle   parse.body() << "  return "
921fe4f512bSRiver Riddle                   "::mlir::AsmParser::KeywordSwitch<::mlir::"
922fe4f512bSRiver Riddle                   "OptionalParseResult>(parser)\n";
923ca6bd9cdSMogball   const char *const getValueForMnemonic =
924fe4f512bSRiver Riddle       R"(    .Case({0}::getMnemonic(), [&](llvm::StringRef, llvm::SMLoc) {{
925ca6bd9cdSMogball       value = {0}::{1};
926ca6bd9cdSMogball       return ::mlir::success(!!value);
927fe4f512bSRiver Riddle     })
928ca6bd9cdSMogball )";
929fe4f512bSRiver Riddle 
93083ef862fSRiver Riddle   // The printer dispatch uses llvm::TypeSwitch to find and call the correct
93183ef862fSRiver Riddle   // printer.
932ca6bd9cdSMogball   printer.body() << "  return ::llvm::TypeSwitch<::mlir::" << valueType
933db791b27SRamkumar Ramachandra                  << ", ::llvm::LogicalResult>(def)";
934ca6bd9cdSMogball   const char *const printValue = R"(    .Case<{0}>([&](auto t) {{
935ca6bd9cdSMogball       printer << {0}::getMnemonic();{1}
936ca6bd9cdSMogball       return ::mlir::success();
937ca6bd9cdSMogball     })
938ca6bd9cdSMogball )";
939ca6bd9cdSMogball   for (auto &def : defs) {
940ca6bd9cdSMogball     if (!def.getMnemonic())
94183ef862fSRiver Riddle       continue;
94223e3cbe2SRiver Riddle     bool hasParserPrinterDecl =
94323e3cbe2SRiver Riddle         def.hasCustomAssemblyFormat() || def.getAssemblyFormat();
944ca6bd9cdSMogball     std::string defClass = strfmt(
945ca6bd9cdSMogball         "{0}::{1}", def.getDialect().getCppNamespace(), def.getCppClassName());
94623e3cbe2SRiver Riddle 
947ca6bd9cdSMogball     // If the def has no parameters or parser code, invoke a normal `get`.
948ca6bd9cdSMogball     std::string parseOrGet =
94923e3cbe2SRiver Riddle         hasParserPrinterDecl
950ca6bd9cdSMogball             ? strfmt("parse(parser{0})", isAttrGenerator ? ", type" : "")
951ca6bd9cdSMogball             : "get(parser.getContext())";
952ca6bd9cdSMogball     parse.body() << llvm::formatv(getValueForMnemonic, defClass, parseOrGet);
95383ef862fSRiver Riddle 
954f30a8a6fSMehdi Amini     // If the def has no parameters and no printer, just print the mnemonic.
955ca6bd9cdSMogball     StringRef printDef = "";
95623e3cbe2SRiver Riddle     if (hasParserPrinterDecl)
957ca6bd9cdSMogball       printDef = "\nt.print(printer);";
958ca6bd9cdSMogball     printer.body() << llvm::formatv(printValue, defClass, printDef);
95983ef862fSRiver Riddle   }
960fe4f512bSRiver Riddle   parse.body() << "    .Default([&](llvm::StringRef keyword, llvm::SMLoc) {\n"
961fe4f512bSRiver Riddle                   "      *mnemonic = keyword;\n"
96255210971SKazu Hirata                   "      return std::nullopt;\n"
963fe4f512bSRiver Riddle                   "    });";
964ca6bd9cdSMogball   printer.body() << "    .Default([](auto) { return ::mlir::failure(); });";
965ca6bd9cdSMogball 
966ca6bd9cdSMogball   raw_indented_ostream indentedOs(os);
967ca6bd9cdSMogball   parse.writeDeclTo(indentedOs);
968ca6bd9cdSMogball   printer.writeDeclTo(indentedOs);
96983ef862fSRiver Riddle }
97083ef862fSRiver Riddle 
97183ef862fSRiver Riddle bool DefGenerator::emitDefs(StringRef selectedDialect) {
972ca6bd9cdSMogball   emitSourceFileHeader((defType + "Def Definitions").str(), os);
97383ef862fSRiver Riddle 
97483ef862fSRiver Riddle   SmallVector<AttrOrTypeDef, 16> defs;
97583ef862fSRiver Riddle   collectAllDefs(selectedDialect, defRecords, defs);
97683ef862fSRiver Riddle   if (defs.empty())
97783ef862fSRiver Riddle     return false;
97883ef862fSRiver Riddle   emitTypeDefList(defs);
97983ef862fSRiver Riddle 
980ca6bd9cdSMogball   IfDefScope scope("GET_" + defType.upper() + "DEF_CLASSES", os);
98183ef862fSRiver Riddle   emitParsePrintDispatch(defs);
9824bb0ad23SMehdi Amini   for (const AttrOrTypeDef &def : defs) {
983ca6bd9cdSMogball     {
984ca6bd9cdSMogball       NamespaceEmitter ns(os, def.getDialect());
985ca6bd9cdSMogball       DefGen gen(def);
986ca6bd9cdSMogball       gen.emitDef(os);
987ca6bd9cdSMogball     }
9884bb0ad23SMehdi Amini     // Emit the TypeID explicit specializations to have a single symbol def.
9894bb0ad23SMehdi Amini     if (!def.getDialect().getCppNamespace().empty())
9905e50dd04SRiver Riddle       os << "MLIR_DEFINE_EXPLICIT_TYPE_ID("
9915e50dd04SRiver Riddle          << def.getDialect().getCppNamespace() << "::" << def.getCppClassName()
9925e50dd04SRiver Riddle          << ")\n";
9934bb0ad23SMehdi Amini   }
99483ef862fSRiver Riddle 
9950a8a5902SMarkus Böck   Dialect firstDialect = defs.front().getDialect();
9964957518eSMogball 
9974957518eSMogball   // Emit the default parser/printer for Attributes if the dialect asked for it.
9984957518eSMogball   if (isAttrGenerator && firstDialect.useDefaultAttributePrinterParser()) {
9990a8a5902SMarkus Böck     NamespaceEmitter nsEmitter(os, firstDialect);
10009e0b5533SMathieu Fehr     if (firstDialect.isExtensible()) {
1001fd6b4041SMehdi Amini       os << llvm::formatv(dialectDefaultAttrPrinterParserDispatch,
10029e0b5533SMathieu Fehr                           firstDialect.getCppClassName(),
10039e0b5533SMathieu Fehr                           dialectDynamicAttrParserDispatch,
10049e0b5533SMathieu Fehr                           dialectDynamicAttrPrinterDispatch);
10059e0b5533SMathieu Fehr     } else {
10069e0b5533SMathieu Fehr       os << llvm::formatv(dialectDefaultAttrPrinterParserDispatch,
10079e0b5533SMathieu Fehr                           firstDialect.getCppClassName(), "", "");
10089e0b5533SMathieu Fehr     }
10090a8a5902SMarkus Böck   }
1010fd6b4041SMehdi Amini 
1011c27d85a9SMehdi Amini   // Emit the default parser/printer for Types if the dialect asked for it.
10124957518eSMogball   if (!isAttrGenerator && firstDialect.useDefaultTypePrinterParser()) {
10130a8a5902SMarkus Böck     NamespaceEmitter nsEmitter(os, firstDialect);
10149e0b5533SMathieu Fehr     if (firstDialect.isExtensible()) {
1015c27d85a9SMehdi Amini       os << llvm::formatv(dialectDefaultTypePrinterParserDispatch,
10169e0b5533SMathieu Fehr                           firstDialect.getCppClassName(),
10179e0b5533SMathieu Fehr                           dialectDynamicTypeParserDispatch,
10189e0b5533SMathieu Fehr                           dialectDynamicTypePrinterDispatch);
10199e0b5533SMathieu Fehr     } else {
10209e0b5533SMathieu Fehr       os << llvm::formatv(dialectDefaultTypePrinterParserDispatch,
10219e0b5533SMathieu Fehr                           firstDialect.getCppClassName(), "", "");
10229e0b5533SMathieu Fehr     }
10230a8a5902SMarkus Böck   }
1024c27d85a9SMehdi Amini 
102583ef862fSRiver Riddle   return false;
102683ef862fSRiver Riddle }
102783ef862fSRiver Riddle 
102883ef862fSRiver Riddle //===----------------------------------------------------------------------===//
1029a3d41879SMatthias Springer // Type Constraints
1030a3d41879SMatthias Springer //===----------------------------------------------------------------------===//
1031a3d41879SMatthias Springer 
1032a3d41879SMatthias Springer /// Find all type constraints for which a C++ function should be generated.
1033a3d41879SMatthias Springer static std::vector<Constraint>
1034*bccd37f6SRahul Joshi getAllTypeConstraints(const RecordKeeper &records) {
1035a3d41879SMatthias Springer   std::vector<Constraint> result;
1036*bccd37f6SRahul Joshi   for (const Record *def :
1037a3d41879SMatthias Springer        records.getAllDerivedDefinitionsIfDefined("TypeConstraint")) {
1038a3d41879SMatthias Springer     // Ignore constraints defined outside of the top-level file.
1039a3d41879SMatthias Springer     if (llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) !=
1040a3d41879SMatthias Springer         llvm::SrcMgr.getMainFileID())
1041a3d41879SMatthias Springer       continue;
1042a3d41879SMatthias Springer     Constraint constr(def);
1043a3d41879SMatthias Springer     // Generate C++ function only if "cppFunctionName" is set.
1044a3d41879SMatthias Springer     if (!constr.getCppFunctionName())
1045a3d41879SMatthias Springer       continue;
1046a3d41879SMatthias Springer     result.push_back(constr);
1047a3d41879SMatthias Springer   }
1048a3d41879SMatthias Springer   return result;
1049a3d41879SMatthias Springer }
1050a3d41879SMatthias Springer 
1051*bccd37f6SRahul Joshi static void emitTypeConstraintDecls(const RecordKeeper &records,
1052a3d41879SMatthias Springer                                     raw_ostream &os) {
1053a3d41879SMatthias Springer   static const char *const typeConstraintDecl = R"(
1054a3d41879SMatthias Springer bool {0}(::mlir::Type type);
1055a3d41879SMatthias Springer )";
1056a3d41879SMatthias Springer 
1057a3d41879SMatthias Springer   for (Constraint constr : getAllTypeConstraints(records))
1058a3d41879SMatthias Springer     os << strfmt(typeConstraintDecl, *constr.getCppFunctionName());
1059a3d41879SMatthias Springer }
1060a3d41879SMatthias Springer 
1061*bccd37f6SRahul Joshi static void emitTypeConstraintDefs(const RecordKeeper &records,
1062a3d41879SMatthias Springer                                    raw_ostream &os) {
1063a3d41879SMatthias Springer   static const char *const typeConstraintDef = R"(
1064a3d41879SMatthias Springer bool {0}(::mlir::Type type) {
1065a3d41879SMatthias Springer   return ({1});
1066a3d41879SMatthias Springer }
1067a3d41879SMatthias Springer )";
1068a3d41879SMatthias Springer 
1069a3d41879SMatthias Springer   for (Constraint constr : getAllTypeConstraints(records)) {
1070a3d41879SMatthias Springer     FmtContext ctx;
1071a3d41879SMatthias Springer     ctx.withSelf("type");
1072a3d41879SMatthias Springer     std::string condition = tgfmt(constr.getConditionTemplate(), &ctx);
1073a3d41879SMatthias Springer     os << strfmt(typeConstraintDef, *constr.getCppFunctionName(), condition);
1074a3d41879SMatthias Springer   }
1075a3d41879SMatthias Springer }
1076a3d41879SMatthias Springer 
1077a3d41879SMatthias Springer //===----------------------------------------------------------------------===//
107883ef862fSRiver Riddle // GEN: Registration hooks
107983ef862fSRiver Riddle //===----------------------------------------------------------------------===//
108083ef862fSRiver Riddle 
108183ef862fSRiver Riddle //===----------------------------------------------------------------------===//
108283ef862fSRiver Riddle // AttrDef
108383ef862fSRiver Riddle 
108483ef862fSRiver Riddle static llvm::cl::OptionCategory attrdefGenCat("Options for -gen-attrdef-*");
108583ef862fSRiver Riddle static llvm::cl::opt<std::string>
108683ef862fSRiver Riddle     attrDialect("attrdefs-dialect",
108783ef862fSRiver Riddle                 llvm::cl::desc("Generate attributes for this dialect"),
108883ef862fSRiver Riddle                 llvm::cl::cat(attrdefGenCat), llvm::cl::CommaSeparated);
108983ef862fSRiver Riddle 
109083ef862fSRiver Riddle static mlir::GenRegistration
109183ef862fSRiver Riddle     genAttrDefs("gen-attrdef-defs", "Generate AttrDef definitions",
1092*bccd37f6SRahul Joshi                 [](const RecordKeeper &records, raw_ostream &os) {
109383ef862fSRiver Riddle                   AttrDefGenerator generator(records, os);
109483ef862fSRiver Riddle                   return generator.emitDefs(attrDialect);
109583ef862fSRiver Riddle                 });
109683ef862fSRiver Riddle static mlir::GenRegistration
109783ef862fSRiver Riddle     genAttrDecls("gen-attrdef-decls", "Generate AttrDef declarations",
1098*bccd37f6SRahul Joshi                  [](const RecordKeeper &records, raw_ostream &os) {
109983ef862fSRiver Riddle                    AttrDefGenerator generator(records, os);
110083ef862fSRiver Riddle                    return generator.emitDecls(attrDialect);
110183ef862fSRiver Riddle                  });
110283ef862fSRiver Riddle 
110383ef862fSRiver Riddle //===----------------------------------------------------------------------===//
110483ef862fSRiver Riddle // TypeDef
110583ef862fSRiver Riddle 
110683ef862fSRiver Riddle static llvm::cl::OptionCategory typedefGenCat("Options for -gen-typedef-*");
110783ef862fSRiver Riddle static llvm::cl::opt<std::string>
110883ef862fSRiver Riddle     typeDialect("typedefs-dialect",
110983ef862fSRiver Riddle                 llvm::cl::desc("Generate types for this dialect"),
111083ef862fSRiver Riddle                 llvm::cl::cat(typedefGenCat), llvm::cl::CommaSeparated);
111183ef862fSRiver Riddle 
111283ef862fSRiver Riddle static mlir::GenRegistration
111383ef862fSRiver Riddle     genTypeDefs("gen-typedef-defs", "Generate TypeDef definitions",
1114*bccd37f6SRahul Joshi                 [](const RecordKeeper &records, raw_ostream &os) {
111583ef862fSRiver Riddle                   TypeDefGenerator generator(records, os);
111683ef862fSRiver Riddle                   return generator.emitDefs(typeDialect);
111783ef862fSRiver Riddle                 });
111883ef862fSRiver Riddle static mlir::GenRegistration
111983ef862fSRiver Riddle     genTypeDecls("gen-typedef-decls", "Generate TypeDef declarations",
1120*bccd37f6SRahul Joshi                  [](const RecordKeeper &records, raw_ostream &os) {
112183ef862fSRiver Riddle                    TypeDefGenerator generator(records, os);
112283ef862fSRiver Riddle                    return generator.emitDecls(typeDialect);
112383ef862fSRiver Riddle                  });
1124a3d41879SMatthias Springer 
1125a3d41879SMatthias Springer static mlir::GenRegistration
1126a3d41879SMatthias Springer     genTypeConstrDefs("gen-type-constraint-defs",
1127a3d41879SMatthias Springer                       "Generate type constraint definitions",
1128*bccd37f6SRahul Joshi                       [](const RecordKeeper &records, raw_ostream &os) {
1129a3d41879SMatthias Springer                         emitTypeConstraintDefs(records, os);
1130a3d41879SMatthias Springer                         return false;
1131a3d41879SMatthias Springer                       });
1132a3d41879SMatthias Springer static mlir::GenRegistration
1133a3d41879SMatthias Springer     genTypeConstrDecls("gen-type-constraint-decls",
1134a3d41879SMatthias Springer                        "Generate type constraint declarations",
1135*bccd37f6SRahul Joshi                        [](const RecordKeeper &records, raw_ostream &os) {
1136a3d41879SMatthias Springer                          emitTypeConstraintDecls(records, os);
1137a3d41879SMatthias Springer                          return false;
1138a3d41879SMatthias Springer                        });
1139