xref: /llvm-project/mlir/tools/mlir-tblgen/PassGen.cpp (revision e813750354bbc08551cf23ff559a54b4a9ea1f29)
//===- PassGen.cpp - MLIR pass registration generator ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // PassGen uses the description of passes to generate base classes for passes
10 // and command line registration.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/TableGen/GenInfo.h"
15 #include "mlir/TableGen/Pass.h"
16 #include "llvm/ADT/StringExtras.h"
17 #include "llvm/Support/CommandLine.h"
18 #include "llvm/Support/FormatVariadic.h"
19 #include "llvm/TableGen/Error.h"
20 #include "llvm/TableGen/Record.h"
21 
22 using namespace mlir;
23 using namespace mlir::tblgen;
24 using llvm::formatv;
25 using llvm::RecordKeeper;
26 
27 static llvm::cl::OptionCategory passGenCat("Options for -gen-pass-decls");
28 static llvm::cl::opt<std::string>
29     groupName("name", llvm::cl::desc("The name of this group of passes"),
30               llvm::cl::cat(passGenCat));
31 
32 /// Extract the list of passes from the TableGen records.
33 static std::vector<Pass> getPasses(const RecordKeeper &records) {
34   std::vector<Pass> passes;
35 
36   for (const auto *def : records.getAllDerivedDefinitions("PassBase"))
37     passes.emplace_back(def);
38 
39   return passes;
40 }
41 
/// Banner comment emitted before all the generated code of a single pass.
///
/// {0}: The def name of the pass record.
const char *const passHeader = R"(
//===----------------------------------------------------------------------===//
// {0}
//===----------------------------------------------------------------------===//
)";
47 
48 //===----------------------------------------------------------------------===//
49 // GEN: Pass registration generation
50 //===----------------------------------------------------------------------===//
51 
/// The code snippet used to generate a pass registration.
///
/// It defines `register<Name>()` and the legacy `register<Name>Pass()`. Both
/// call `::mlir::registerPass` with a factory lambda that returns the given
/// constructor call.
///
/// {0}: The def name of the pass record.
/// {1}: The pass constructor call.
const char *const passRegistrationCode = R"(
//===----------------------------------------------------------------------===//
// {0} Registration
//===----------------------------------------------------------------------===//

inline void register{0}() {{
  ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
    return {1};
  });
}

// Old registration code, kept for temporary backwards compatibility.
inline void register{0}Pass() {{
  ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{
    return {1};
  });
}
)";
74 
/// The code snippet used to generate a function to register all passes in a
/// group.
///
/// Only the opening of `register<GroupName>Passes()` is produced here. The
/// caller appends one `register<PassName>();` call per pass and the closing
/// brace.
///
/// {0}: The name of the pass group.
const char *const passGroupRegistrationCode = R"(
//===----------------------------------------------------------------------===//
// {0} Registration
//===----------------------------------------------------------------------===//

inline void register{0}Passes() {{
)";
86 
87 /// Emits the definition of the struct to be used to control the pass options.
88 static void emitPassOptionsStruct(const Pass &pass, raw_ostream &os) {
89   StringRef passName = pass.getDef()->getName();
90   ArrayRef<PassOption> options = pass.getOptions();
91 
92   // Emit the struct only if the pass has at least one option.
93   if (options.empty())
94     return;
95 
96   os << formatv("struct {0}Options {{\n", passName);
97 
98   for (const PassOption &opt : options) {
99     std::string type = opt.getType().str();
100 
101     if (opt.isListOption())
102       type = "::llvm::SmallVector<" + type + ">";
103 
104     os.indent(2) << formatv("{0} {1}", type, opt.getCppVariableName());
105 
106     if (std::optional<StringRef> defaultVal = opt.getDefaultValue())
107       os << " = " << defaultVal;
108 
109     os << ";\n";
110   }
111 
112   os << "};\n";
113 }
114 
115 static std::string getPassDeclVarName(const Pass &pass) {
116   return "GEN_PASS_DECL_" + pass.getDef()->getName().upper();
117 }
118 
119 /// Emit the code to be included in the public header of the pass.
120 static void emitPassDecls(const Pass &pass, raw_ostream &os) {
121   StringRef passName = pass.getDef()->getName();
122   std::string enableVarName = getPassDeclVarName(pass);
123 
124   os << "#ifdef " << enableVarName << "\n";
125   emitPassOptionsStruct(pass, os);
126 
127   if (StringRef constructor = pass.getConstructor(); constructor.empty()) {
128     // Default constructor declaration.
129     os << "std::unique_ptr<::mlir::Pass> create" << passName << "();\n";
130 
131     // Declaration of the constructor with options.
132     if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty())
133       os << formatv("std::unique_ptr<::mlir::Pass> create{0}("
134                     "{0}Options options);\n",
135                     passName);
136   }
137 
138   os << "#undef " << enableVarName << "\n";
139   os << "#endif // " << enableVarName << "\n";
140 }
141 
142 /// Emit the code for registering each of the given passes with the global
143 /// PassRegistry.
144 static void emitRegistrations(llvm::ArrayRef<Pass> passes, raw_ostream &os) {
145   os << "#ifdef GEN_PASS_REGISTRATION\n";
146 
147   for (const Pass &pass : passes) {
148     std::string constructorCall;
149     if (StringRef constructor = pass.getConstructor(); !constructor.empty())
150       constructorCall = constructor.str();
151     else
152       constructorCall = formatv("create{0}()", pass.getDef()->getName()).str();
153 
154     os << formatv(passRegistrationCode, pass.getDef()->getName(),
155                   constructorCall);
156   }
157 
158   os << formatv(passGroupRegistrationCode, groupName);
159 
160   for (const Pass &pass : passes)
161     os << "  register" << pass.getDef()->getName() << "();\n";
162 
163   os << "}\n";
164   os << "#undef GEN_PASS_REGISTRATION\n";
165   os << "#endif // GEN_PASS_REGISTRATION\n";
166 }
167 
168 //===----------------------------------------------------------------------===//
169 // GEN: Pass base class generation
170 //===----------------------------------------------------------------------===//
171 
/// The code snippet used to generate the start of a pass base class.
///
/// The snippet leaves the class body open. The caller appends the optional
/// options constructor, the protected/private sections and the closing `};`.
///
/// {0}: The def name of the pass record.
/// {1}: The base class for the pass.
/// {2}: The command line argument for the pass.
/// {3}: The summary for the pass.
/// {4}: The dependent dialects registration.
const char *const baseClassBegin = R"(
template <typename DerivedT>
class {0}Base : public {1} {
public:
  using Base = {0}Base;

  {0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{}
  {0}Base(const {0}Base &other) : {1}(other) {{}
  {0}Base& operator=(const {0}Base &) = delete;
  {0}Base({0}Base &&) = delete;
  {0}Base& operator=({0}Base &&) = delete;
  ~{0}Base() = default;

  /// Returns the command-line argument attached to this pass.
  static constexpr ::llvm::StringLiteral getArgumentName() {
    return ::llvm::StringLiteral("{2}");
  }
  ::llvm::StringRef getArgument() const override { return "{2}"; }

  ::llvm::StringRef getDescription() const override { return "{3}"; }

  /// Returns the derived pass name.
  static constexpr ::llvm::StringLiteral getPassName() {
    return ::llvm::StringLiteral("{0}");
  }
  ::llvm::StringRef getName() const override { return "{0}"; }

  /// Support isa/dyn_cast functionality for the derived pass class.
  static bool classof(const ::mlir::Pass *pass) {{
    return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
  }

  /// A clone method to create a copy of this pass.
  std::unique_ptr<::mlir::Pass> clonePass() const override {{
    return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
  }

  /// Return the dialect that must be loaded in the context before this pass.
  void getDependentDialects(::mlir::DialectRegistry &registry) const override {
    {4}
  }

  /// Explicitly declare the TypeID for this class. We declare an explicit private
  /// instantiation because Pass classes should only be visible by the current
  /// library.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID({0}Base<DerivedT>)

)";
227 
/// Registration for a single dependent dialect, to be inserted for each
/// dependent dialect in the generated `getDependentDialects` method body.
///
/// {0}: The dependent dialect, as listed on the pass record. It is used as the
///      template argument of `registry.insert<>`, so it must name a C++ type.
const char *const dialectRegistrationTemplate = "registry.insert<{0}>();";
231 
/// Namespace-scope declaration of the default pass factory in `impl`.
///
/// It is emitted before the base class. That lets the friend definition
/// emitted inside the class (`friendDefaultConstructorDefTemplate`) be called
/// as `impl::create<Name>()`.
///
/// {0}: The def name of the pass record.
const char *const friendDefaultConstructorDeclTemplate = R"(
namespace impl {{
  std::unique_ptr<::mlir::Pass> create{0}();
} // namespace impl
)";

/// Same as `friendDefaultConstructorDeclTemplate`, for the factory taking the
/// pass options struct.
///
/// {0}: The def name of the pass record.
const char *const friendDefaultConstructorWithOptionsDeclTemplate = R"(
namespace impl {{
  std::unique_ptr<::mlir::Pass> create{0}({0}Options options);
} // namespace impl
)";

/// Friend definition of the default factory, emitted in the private section of
/// the base class. It default-constructs the derived pass.
///
/// {0}: The def name of the pass record.
const char *const friendDefaultConstructorDefTemplate = R"(
  friend std::unique_ptr<::mlir::Pass> create{0}() {{
    return std::make_unique<DerivedT>();
  }
)";

/// Friend definition of the factory taking the options struct, emitted in the
/// private section of the base class. It constructs the derived pass from the
/// moved options.
///
/// {0}: The def name of the pass record.
const char *const friendDefaultConstructorWithOptionsDefTemplate = R"(
  friend std::unique_ptr<::mlir::Pass> create{0}({0}Options options) {{
    return std::make_unique<DerivedT>(std::move(options));
  }
)";

/// Public default factory, emitted after the `impl` namespace. It forwards to
/// the friend factory of the base class.
///
/// {0}: The def name of the pass record.
const char *const defaultConstructorDefTemplate = R"(
std::unique_ptr<::mlir::Pass> create{0}() {{
  return impl::create{0}();
}
)";

/// Public factory taking the options struct, emitted after the `impl`
/// namespace. It forwards to the corresponding friend factory.
///
/// {0}: The def name of the pass record.
const char *const defaultConstructorWithOptionsDefTemplate = R"(
std::unique_ptr<::mlir::Pass> create{0}({0}Options options) {{
  return impl::create{0}(std::move(options));
}
)";
267 
268 /// Emit the declarations for each of the pass options.
269 static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) {
270   for (const PassOption &opt : pass.getOptions()) {
271     os.indent(2) << "::mlir::Pass::"
272                  << (opt.isListOption() ? "ListOption" : "Option");
273 
274     os << formatv(R"(<{0}> {1}{{*this, "{2}", ::llvm::cl::desc("{3}"))",
275                   opt.getType(), opt.getCppVariableName(), opt.getArgument(),
276                   opt.getDescription());
277     if (std::optional<StringRef> defaultVal = opt.getDefaultValue())
278       os << ", ::llvm::cl::init(" << defaultVal << ")";
279     if (std::optional<StringRef> additionalFlags = opt.getAdditionalFlags())
280       os << ", " << *additionalFlags;
281     os << "};\n";
282   }
283 }
284 
285 /// Emit the declarations for each of the pass statistics.
286 static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) {
287   for (const PassStatistic &stat : pass.getStatistics()) {
288     os << formatv("  ::mlir::Pass::Statistic {0}{{this, \"{1}\", \"{2}\"};\n",
289                   stat.getCppVariableName(), stat.getName(),
290                   stat.getDescription());
291   }
292 }
293 
294 /// Emit the code to be used in the implementation of the pass.
295 static void emitPassDefs(const Pass &pass, raw_ostream &os) {
296   StringRef passName = pass.getDef()->getName();
297   std::string enableVarName = "GEN_PASS_DEF_" + passName.upper();
298   bool emitDefaultConstructors = pass.getConstructor().empty();
299   bool emitDefaultConstructorWithOptions = !pass.getOptions().empty();
300 
301   os << "#ifdef " << enableVarName << "\n";
302 
303   if (emitDefaultConstructors) {
304     os << formatv(friendDefaultConstructorDeclTemplate, passName);
305 
306     if (emitDefaultConstructorWithOptions)
307       os << formatv(friendDefaultConstructorWithOptionsDeclTemplate, passName);
308   }
309 
310   std::string dependentDialectRegistrations;
311   {
312     llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
313     llvm::interleave(
314         pass.getDependentDialects(), dialectsOs,
315         [&](StringRef dependentDialect) {
316           dialectsOs << formatv(dialectRegistrationTemplate, dependentDialect);
317         },
318         "\n    ");
319   }
320 
321   os << "namespace impl {\n";
322   os << formatv(baseClassBegin, passName, pass.getBaseClass(),
323                 pass.getArgument(), pass.getSummary(),
324                 dependentDialectRegistrations);
325 
326   if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty()) {
327     os.indent(2) << formatv("{0}Base({0}Options options) : {0}Base() {{\n",
328                             passName);
329 
330     for (const PassOption &opt : pass.getOptions())
331       os.indent(4) << formatv("{0} = std::move(options.{0});\n",
332                               opt.getCppVariableName());
333 
334     os.indent(2) << "}\n";
335   }
336 
337   // Protected content
338   os << "protected:\n";
339   emitPassOptionDecls(pass, os);
340   emitPassStatisticDecls(pass, os);
341 
342   // Private content
343   os << "private:\n";
344 
345   if (emitDefaultConstructors) {
346     os << formatv(friendDefaultConstructorDefTemplate, passName);
347 
348     if (!pass.getOptions().empty())
349       os << formatv(friendDefaultConstructorWithOptionsDefTemplate, passName);
350   }
351 
352   os << "};\n";
353   os << "} // namespace impl\n";
354 
355   if (emitDefaultConstructors) {
356     os << formatv(defaultConstructorDefTemplate, passName);
357 
358     if (emitDefaultConstructorWithOptions)
359       os << formatv(defaultConstructorWithOptionsDefTemplate, passName);
360   }
361 
362   os << "#undef " << enableVarName << "\n";
363   os << "#endif // " << enableVarName << "\n";
364 }
365 
366 static void emitPass(const Pass &pass, raw_ostream &os) {
367   StringRef passName = pass.getDef()->getName();
368   os << formatv(passHeader, passName);
369 
370   emitPassDecls(pass, os);
371   emitPassDefs(pass, os);
372 }
373 
// TODO: Drop old pass declarations.
// The old pass base class is being kept until all the passes have switched to
// the new decls/defs design.
/// The code snippet used to generate the start of a legacy pass base class
/// (emitted under `GEN_PASS_CLASSES`). Unlike `baseClassBegin`, it is not
/// wrapped in `namespace impl`. It ends by opening the `protected:` section,
/// where the caller appends the options, the statistics and the closing `};`.
///
/// {0}: The def name of the pass record.
/// {1}: The base class for the pass.
/// {2}: The command line argument for the pass.
/// {3}: The summary for the pass.
/// {4}: The dependent dialects registration.
const char *const oldPassDeclBegin = R"(
template <typename DerivedT>
class {0}Base : public {1} {
public:
  using Base = {0}Base;

  {0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{}
  {0}Base(const {0}Base &other) : {1}(other) {{}
  {0}Base& operator=(const {0}Base &) = delete;
  {0}Base({0}Base &&) = delete;
  {0}Base& operator=({0}Base &&) = delete;
  ~{0}Base() = default;

  /// Returns the command-line argument attached to this pass.
  static constexpr ::llvm::StringLiteral getArgumentName() {
    return ::llvm::StringLiteral("{2}");
  }
  ::llvm::StringRef getArgument() const override { return "{2}"; }

  ::llvm::StringRef getDescription() const override { return "{3}"; }

  /// Returns the derived pass name.
  static constexpr ::llvm::StringLiteral getPassName() {
    return ::llvm::StringLiteral("{0}");
  }
  ::llvm::StringRef getName() const override { return "{0}"; }

  /// Support isa/dyn_cast functionality for the derived pass class.
  static bool classof(const ::mlir::Pass *pass) {{
    return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
  }

  /// A clone method to create a copy of this pass.
  std::unique_ptr<::mlir::Pass> clonePass() const override {{
    return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
  }

  /// Register the dialects that must be loaded in the context before this pass.
  void getDependentDialects(::mlir::DialectRegistry &registry) const override {
    {4}
  }

  /// Explicitly declare the TypeID for this class. We declare an explicit private
  /// instantiation because Pass classes should only be visible by the current
  /// library.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID({0}Base<DerivedT>)

protected:
)";
426 
427 // TODO: Drop old pass declarations.
428 /// Emit a backward-compatible declaration of the pass base class.
429 static void emitOldPassDecl(const Pass &pass, raw_ostream &os) {
430   StringRef defName = pass.getDef()->getName();
431   std::string dependentDialectRegistrations;
432   {
433     llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
434     llvm::interleave(
435         pass.getDependentDialects(), dialectsOs,
436         [&](StringRef dependentDialect) {
437           dialectsOs << formatv(dialectRegistrationTemplate, dependentDialect);
438         },
439         "\n    ");
440   }
441   os << formatv(oldPassDeclBegin, defName, pass.getBaseClass(),
442                 pass.getArgument(), pass.getSummary(),
443                 dependentDialectRegistrations);
444   emitPassOptionDecls(pass, os);
445   emitPassStatisticDecls(pass, os);
446   os << "};\n";
447 }
448 
449 static void emitPasses(const RecordKeeper &records, raw_ostream &os) {
450   std::vector<Pass> passes = getPasses(records);
451   os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n";
452 
453   os << "\n";
454   os << "#ifdef GEN_PASS_DECL\n";
455   os << "// Generate declarations for all passes.\n";
456   for (const Pass &pass : passes)
457     os << "#define " << getPassDeclVarName(pass) << "\n";
458   os << "#undef GEN_PASS_DECL\n";
459   os << "#endif // GEN_PASS_DECL\n";
460 
461   for (const Pass &pass : passes)
462     emitPass(pass, os);
463 
464   emitRegistrations(passes, os);
465 
466   // TODO: Drop old pass declarations.
467   // Emit the old code until all the passes have switched to the new design.
468   os << "// Deprecated. Please use the new per-pass macros.\n";
469   os << "#ifdef GEN_PASS_CLASSES\n";
470   for (const Pass &pass : passes)
471     emitOldPassDecl(pass, os);
472   os << "#undef GEN_PASS_CLASSES\n";
473   os << "#endif // GEN_PASS_CLASSES\n";
474 }
475 
476 static mlir::GenRegistration
477     genPassDecls("gen-pass-decls", "Generate pass declarations",
478                  [](const RecordKeeper &records, raw_ostream &os) {
479                    emitPasses(records, os);
480                    return false;
481                  });
482