xref: /llvm-project/mlir/tools/mlir-tblgen/DialectGen.cpp (revision e768b076e3b7ed38485a29244a0b989076e4b131)
1 //===- DialectGen.cpp - MLIR dialect definitions generator ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // DialectGen uses the description of dialects to generate C++ definitions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "DialectGenUtilities.h"
14 #include "mlir/TableGen/Class.h"
15 #include "mlir/TableGen/CodeGenHelpers.h"
16 #include "mlir/TableGen/Format.h"
17 #include "mlir/TableGen/GenInfo.h"
18 #include "mlir/TableGen/Interfaces.h"
19 #include "mlir/TableGen/Operator.h"
20 #include "mlir/TableGen/Trait.h"
21 #include "llvm/ADT/Sequence.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/Support/CommandLine.h"
24 #include "llvm/Support/Signals.h"
25 #include "llvm/TableGen/Error.h"
26 #include "llvm/TableGen/Record.h"
27 #include "llvm/TableGen/TableGenBackend.h"
28 
29 #define DEBUG_TYPE "mlir-tblgen-opdefgen"
30 
31 using namespace mlir;
32 using namespace mlir::tblgen;
33 using llvm::Record;
34 using llvm::RecordKeeper;
35 
36 static llvm::cl::OptionCategory dialectGenCat("Options for -gen-dialect-*");
37 llvm::cl::opt<std::string>
38     selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"),
39                     llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated);
40 
41 /// Utility iterator used for filtering records for a specific dialect.
42 namespace {
43 using DialectFilterIterator =
44     llvm::filter_iterator<ArrayRef<Record *>::iterator,
45                           std::function<bool(const Record *)>>;
46 } // namespace
47 
48 static void populateDiscardableAttributes(
49     Dialect &dialect, const llvm::DagInit *discardableAttrDag,
50     SmallVector<std::pair<std::string, std::string>> &discardableAttributes) {
51   for (int i : llvm::seq<int>(0, discardableAttrDag->getNumArgs())) {
52     const llvm::Init *arg = discardableAttrDag->getArg(i);
53 
54     StringRef givenName = discardableAttrDag->getArgNameStr(i);
55     if (givenName.empty())
56       PrintFatalError(dialect.getDef()->getLoc(),
57                       "discardable attributes must be named");
58     discardableAttributes.push_back(
59         {givenName.str(), arg->getAsUnquotedString()});
60   }
61 }
62 
63 /// Given a set of records for a T, filter the ones that correspond to
64 /// the given dialect.
65 template <typename T>
66 static iterator_range<DialectFilterIterator>
67 filterForDialect(ArrayRef<Record *> records, Dialect &dialect) {
68   auto filterFn = [&](const Record *record) {
69     return T(record).getDialect() == dialect;
70   };
71   return {DialectFilterIterator(records.begin(), records.end(), filterFn),
72           DialectFilterIterator(records.end(), records.end(), filterFn)};
73 }
74 
75 std::optional<Dialect>
76 tblgen::findDialectToGenerate(ArrayRef<Dialect> dialects) {
77   if (dialects.empty()) {
78     llvm::errs() << "no dialect was found\n";
79     return std::nullopt;
80   }
81 
82   // Select the dialect to gen for.
83   if (dialects.size() == 1 && selectedDialect.getNumOccurrences() == 0)
84     return dialects.front();
85 
86   if (selectedDialect.getNumOccurrences() == 0) {
87     llvm::errs() << "when more than 1 dialect is present, one must be selected "
88                     "via '-dialect'\n";
89     return std::nullopt;
90   }
91 
92   const auto *dialectIt = llvm::find_if(dialects, [](const Dialect &dialect) {
93     return dialect.getName() == selectedDialect;
94   });
95   if (dialectIt == dialects.end()) {
96     llvm::errs() << "selected dialect with '-dialect' does not exist\n";
97     return std::nullopt;
98   }
99   return *dialectIt;
100 }
101 
102 //===----------------------------------------------------------------------===//
103 // GEN: Dialect declarations
104 //===----------------------------------------------------------------------===//
105 
/// The code block for the start of a dialect class declaration.
///
/// {0}: The name of the dialect class.
/// {1}: The dialect namespace.
/// {2}: The dialect parent class.
///
/// Literal braces are escaped as `{{` so formatv never has to guess whether a
/// `{` begins a replacement field. The class body is intentionally left open;
/// the emitter appends the remaining members and the closing `};`.
static const char *const dialectDeclBeginStr = R"(
class {0} : public ::mlir::{2} {{
  explicit {0}(::mlir::MLIRContext *context);

  void initialize();
  friend class ::mlir::MLIRContext;
public:
  ~{0}() override;
  static constexpr ::llvm::StringLiteral getDialectNamespace() {{
    return ::llvm::StringLiteral("{1}");
  }
)";
123 
/// Registration for a single dependent dialect. One copy is emitted per
/// dependent dialect into the dialect constructor body, before the call to
/// initialize() (the `{1}` slot of the constructor template).
///
/// {0}: The C++ class of the dependent dialect (used as the `loadDialect`
///      template argument).
static const char *const dialectRegistrationTemplate =
    "getContext()->loadDialect<{0}>();";
128 
/// The code block for the attribute parser/printer hooks.
///
/// Emitted into the dialect class declaration when the dialect requests the
/// default attribute printer and parser. Contains no format placeholders.
static const char *const attrParserDecl = R"(
  /// Parse an attribute registered to this dialect.
  ::mlir::Attribute parseAttribute(::mlir::DialectAsmParser &parser,
                                   ::mlir::Type type) const override;

  /// Print an attribute registered to this dialect.
  void printAttribute(::mlir::Attribute attr,
                      ::mlir::DialectAsmPrinter &os) const override;
)";

/// The code block for the type parser/printer hooks.
///
/// Emitted into the dialect class declaration when the dialect requests the
/// default type printer and parser. Contains no format placeholders.
static const char *const typeParserDecl = R"(
  /// Parse a type registered to this dialect.
  ::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override;

  /// Print a type registered to this dialect.
  void printType(::mlir::Type type,
                 ::mlir::DialectAsmPrinter &os) const override;
)";

/// The code block for the canonicalization pattern registration hook.
///
/// Emitted when the dialect sets `hasCanonicalizer`; the user supplies the
/// definition.
static const char *const canonicalizerDecl = R"(
  /// Register canonicalization patterns.
  void getCanonicalizationPatterns(
      ::mlir::RewritePatternSet &results) const override;
)";

/// The code block for the constant materializer hook.
///
/// Emitted when the dialect sets `hasConstantMaterializer`; the user supplies
/// the definition.
static const char *const constantMaterializerDecl = R"(
  /// Materialize a single constant operation from a given attribute value with
  /// the desired resultant type.
  ::mlir::Operation *materializeConstant(::mlir::OpBuilder &builder,
                                         ::mlir::Attribute value,
                                         ::mlir::Type type,
                                         ::mlir::Location loc) override;
)";

/// The code block for the operation attribute verifier hook.
///
/// Emitted when the dialect sets `hasOperationAttrVerify`.
static const char *const opAttrVerifierDecl = R"(
    /// Provides a hook for verifying dialect attributes attached to the given
    /// op.
    ::llvm::LogicalResult verifyOperationAttribute(
        ::mlir::Operation *op, ::mlir::NamedAttribute attribute) override;
)";

/// The code block for the region argument attribute verifier hook.
///
/// Emitted when the dialect sets `hasRegionArgAttrVerify`.
static const char *const regionArgAttrVerifierDecl = R"(
    /// Provides a hook for verifying dialect attributes attached to the given
    /// op's region argument.
    ::llvm::LogicalResult verifyRegionArgAttribute(
        ::mlir::Operation *op, unsigned regionIndex, unsigned argIndex,
        ::mlir::NamedAttribute attribute) override;
)";

/// The code block for the region result attribute verifier hook.
///
/// Emitted when the dialect sets `hasRegionResultAttrVerify`.
static const char *const regionResultAttrVerifierDecl = R"(
    /// Provides a hook for verifying dialect attributes attached to the given
    /// op's region result.
    ::llvm::LogicalResult verifyRegionResultAttribute(
        ::mlir::Operation *op, unsigned regionIndex, unsigned resultIndex,
        ::mlir::NamedAttribute attribute) override;
)";
192 
/// The code block for the op interface fallback hook.
///
/// Emitted into the dialect class declaration when the dialect requests an
/// operation interface fallback. Types are spelled with a leading `::` like
/// every other generated declaration. Otherwise a nested `mlir` namespace
/// inside the dialect's C++ namespace would hijack the lookup.
static const char *const operationInterfaceFallbackDecl = R"(
    /// Provides a hook for op interface.
    void *getRegisteredInterfaceForOp(::mlir::TypeID interfaceID,
                                      ::mlir::OperationName opName) override;
)";
199 
/// The code block for the discardable attribute helper.
///
/// Emitted once per discardable attribute inside the dialect class. It
/// declares a helper class, a public accessor returning it, and a private
/// member holding it.
///
/// {0}: The attribute name in UpperCamelCase; names the helper class and the
///      `get{0}AttrHelper` accessor.
/// {1}: The attribute name as written in the dialect definition.
/// {2}: The C++ type of the attribute (the `getAttrOfType` template argument).
/// {3}: The attribute name in lowerCamelCase; names the `{3}AttrName` member,
///      which the dialect constructor must initialize.
/// {4}: The dialect name; `{4}.{1}` is the attribute's full name.
///
/// Every literal brace that formatv could misread is escaped as `{{`.
static const char *const discardableAttrHelperDecl = R"(
    /// Helper to manage the discardable attribute `{1}`.
    class {0}AttrHelper {{
      ::mlir::StringAttr name;
    public:
      static constexpr ::llvm::StringLiteral getNameStr() {{
        return "{4}.{1}";
      }
      constexpr ::mlir::StringAttr getName() {{
        return name;
      }

      {0}AttrHelper(::mlir::MLIRContext *ctx)
        : name(::mlir::StringAttr::get(ctx, getNameStr())) {{}

      {2} getAttr(::mlir::Operation *op) {{
        return op->getAttrOfType<{2}>(name);
      }
      void setAttr(::mlir::Operation *op, {2} val) {{
        op->setAttr(name, val);
      }
      bool isAttrPresent(::mlir::Operation *op) {{
        return op->hasAttrOfType<{2}>(name);
      }
      void removeAttr(::mlir::Operation *op) {{
        assert(op->hasAttrOfType<{2}>(name));
        op->removeAttr(name);
      }
    };
    {0}AttrHelper get{0}AttrHelper() {{
      return {3}AttrName;
    }
  private:
    {0}AttrHelper {3}AttrName;
  public:
)";
237 
238 /// Generate the declaration for the given dialect class.
239 static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
240   // Emit all nested namespaces.
241   {
242     NamespaceEmitter nsEmitter(os, dialect);
243 
244     // Emit the start of the decl.
245     std::string cppName = dialect.getCppClassName();
246     StringRef superClassName =
247         dialect.isExtensible() ? "ExtensibleDialect" : "Dialect";
248     os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName(),
249                         superClassName);
250 
251     // If the dialect requested the default attribute printer and parser, emit
252     // the declarations for the hooks.
253     if (dialect.useDefaultAttributePrinterParser())
254       os << attrParserDecl;
255     // If the dialect requested the default type printer and parser, emit the
256     // delcarations for the hooks.
257     if (dialect.useDefaultTypePrinterParser())
258       os << typeParserDecl;
259 
260     // Add the decls for the various features of the dialect.
261     if (dialect.hasCanonicalizer())
262       os << canonicalizerDecl;
263     if (dialect.hasConstantMaterializer())
264       os << constantMaterializerDecl;
265     if (dialect.hasOperationAttrVerify())
266       os << opAttrVerifierDecl;
267     if (dialect.hasRegionArgAttrVerify())
268       os << regionArgAttrVerifierDecl;
269     if (dialect.hasRegionResultAttrVerify())
270       os << regionResultAttrVerifierDecl;
271     if (dialect.hasOperationInterfaceFallback())
272       os << operationInterfaceFallbackDecl;
273 
274     const llvm::DagInit *discardableAttrDag =
275         dialect.getDiscardableAttributes();
276     SmallVector<std::pair<std::string, std::string>> discardableAttributes;
277     populateDiscardableAttributes(dialect, discardableAttrDag,
278                                   discardableAttributes);
279 
280     for (const auto &attrPair : discardableAttributes) {
281       std::string camelNameUpper = llvm::convertToCamelFromSnakeCase(
282           attrPair.first, /*capitalizeFirst=*/true);
283       std::string camelName = llvm::convertToCamelFromSnakeCase(
284           attrPair.first, /*capitalizeFirst=*/false);
285       os << llvm::formatv(discardableAttrHelperDecl, camelNameUpper,
286                           attrPair.first, attrPair.second, camelName,
287                           dialect.getName());
288     }
289 
290     if (std::optional<StringRef> extraDecl = dialect.getExtraClassDeclaration())
291       os << *extraDecl;
292 
293     // End the dialect decl.
294     os << "};\n";
295   }
296   if (!dialect.getCppNamespace().empty())
297     os << "MLIR_DECLARE_EXPLICIT_TYPE_ID(" << dialect.getCppNamespace()
298        << "::" << dialect.getCppClassName() << ")\n";
299 }
300 
301 static bool emitDialectDecls(const RecordKeeper &records, raw_ostream &os) {
302   emitSourceFileHeader("Dialect Declarations", os, records);
303 
304   auto dialectDefs = records.getAllDerivedDefinitions("Dialect");
305   if (dialectDefs.empty())
306     return false;
307 
308   SmallVector<Dialect> dialects(dialectDefs.begin(), dialectDefs.end());
309   std::optional<Dialect> dialect = findDialectToGenerate(dialects);
310   if (!dialect)
311     return true;
312   emitDialectDecl(*dialect, os);
313   return false;
314 }
315 
316 //===----------------------------------------------------------------------===//
317 // GEN: Dialect definitions
318 //===----------------------------------------------------------------------===//
319 
/// The code block to generate a dialect constructor definition.
///
/// {0}: The name of the dialect class.
/// {1}: Initialization code that is emitted in the ctor body before calling
///      initialize(), such as dependent dialect registration.
/// {2}: The dialect parent class.
/// {3}: Extra members to initialize. Each entry carries its own leading
///      `, ` (one `, <name>AttrName(context)` per discardable attribute), and
///      the slot is empty when there are none.
static const char *const dialectConstructorStr = R"(
{0}::{0}(::mlir::MLIRContext *context)
    : ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>())
    {3}
     {{
  {1}
  initialize();
}
)";

/// The code block to generate a default destructor definition.
///
/// Only emitted when the dialect does not declare a non-default destructor.
///
/// {0}: The name of the dialect class.
static const char *const dialectDestructorStr = R"(
{0}::~{0}() = default;

)";
344 
345 static void emitDialectDef(Dialect &dialect, const RecordKeeper &records,
346                            raw_ostream &os) {
347   std::string cppClassName = dialect.getCppClassName();
348 
349   // Emit the TypeID explicit specializations to have a single symbol def.
350   if (!dialect.getCppNamespace().empty())
351     os << "MLIR_DEFINE_EXPLICIT_TYPE_ID(" << dialect.getCppNamespace()
352        << "::" << cppClassName << ")\n";
353 
354   // Emit all nested namespaces.
355   NamespaceEmitter nsEmitter(os, dialect);
356 
357   /// Build the list of dependent dialects.
358   std::string dependentDialectRegistrations;
359   {
360     llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
361     llvm::interleave(
362         dialect.getDependentDialects(), dialectsOs,
363         [&](StringRef dependentDialect) {
364           dialectsOs << llvm::formatv(dialectRegistrationTemplate,
365                                       dependentDialect);
366         },
367         "\n  ");
368   }
369 
370   // Emit the constructor and destructor.
371   StringRef superClassName =
372       dialect.isExtensible() ? "ExtensibleDialect" : "Dialect";
373 
374   const llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes();
375   SmallVector<std::pair<std::string, std::string>> discardableAttributes;
376   populateDiscardableAttributes(dialect, discardableAttrDag,
377                                 discardableAttributes);
378   std::string discardableAttributesInit;
379   for (const auto &attrPair : discardableAttributes) {
380     std::string camelName = llvm::convertToCamelFromSnakeCase(
381         attrPair.first, /*capitalizeFirst=*/false);
382     llvm::raw_string_ostream os(discardableAttributesInit);
383     os << ", " << camelName << "AttrName(context)";
384   }
385 
386   os << llvm::formatv(dialectConstructorStr, cppClassName,
387                       dependentDialectRegistrations, superClassName,
388                       discardableAttributesInit);
389   if (!dialect.hasNonDefaultDestructor())
390     os << llvm::formatv(dialectDestructorStr, cppClassName);
391 }
392 
393 static bool emitDialectDefs(const RecordKeeper &records, raw_ostream &os) {
394   emitSourceFileHeader("Dialect Definitions", os, records);
395 
396   auto dialectDefs = records.getAllDerivedDefinitions("Dialect");
397   if (dialectDefs.empty())
398     return false;
399 
400   SmallVector<Dialect> dialects(dialectDefs.begin(), dialectDefs.end());
401   std::optional<Dialect> dialect = findDialectToGenerate(dialects);
402   if (!dialect)
403     return true;
404   emitDialectDef(*dialect, records, os);
405   return false;
406 }
407 
408 //===----------------------------------------------------------------------===//
409 // GEN: Dialect registration hooks
410 //===----------------------------------------------------------------------===//
411 
412 static mlir::GenRegistration
413     genDialectDecls("gen-dialect-decls", "Generate dialect declarations",
414                     [](const RecordKeeper &records, raw_ostream &os) {
415                       return emitDialectDecls(records, os);
416                     });
417 
418 static mlir::GenRegistration
419     genDialectDefs("gen-dialect-defs", "Generate dialect definitions",
420                    [](const RecordKeeper &records, raw_ostream &os) {
421                      return emitDialectDefs(records, os);
422                    });
423