1 //===- Pass.cpp - MLIR pass registration generator ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // PassGen uses the description of passes to generate base classes for passes 10 // and command line registration. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/TableGen/GenInfo.h" 15 #include "mlir/TableGen/Pass.h" 16 #include "llvm/ADT/StringExtras.h" 17 #include "llvm/Support/CommandLine.h" 18 #include "llvm/Support/FormatVariadic.h" 19 #include "llvm/TableGen/Error.h" 20 #include "llvm/TableGen/Record.h" 21 22 using namespace mlir; 23 using namespace mlir::tblgen; 24 using llvm::formatv; 25 using llvm::RecordKeeper; 26 27 static llvm::cl::OptionCategory passGenCat("Options for -gen-pass-decls"); 28 static llvm::cl::opt<std::string> 29 groupName("name", llvm::cl::desc("The name of this group of passes"), 30 llvm::cl::cat(passGenCat)); 31 32 /// Extract the list of passes from the TableGen records. 33 static std::vector<Pass> getPasses(const RecordKeeper &records) { 34 std::vector<Pass> passes; 35 36 for (const auto *def : records.getAllDerivedDefinitions("PassBase")) 37 passes.emplace_back(def); 38 39 return passes; 40 } 41 42 const char *const passHeader = R"( 43 //===----------------------------------------------------------------------===// 44 // {0} 45 //===----------------------------------------------------------------------===// 46 )"; 47 48 //===----------------------------------------------------------------------===// 49 // GEN: Pass registration generation 50 //===----------------------------------------------------------------------===// 51 52 /// The code snippet used to generate a pass registration. 53 /// 54 /// {0}: The def name of the pass record. 55 /// {1}: The pass constructor call. 56 const char *const passRegistrationCode = R"( 57 //===----------------------------------------------------------------------===// 58 // {0} Registration 59 //===----------------------------------------------------------------------===// 60 61 inline void register{0}() {{ 62 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{ 63 return {1}; 64 }); 65 } 66 67 // Old registration code, kept for temporary backwards compatibility. 68 inline void register{0}Pass() {{ 69 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{ 70 return {1}; 71 }); 72 } 73 )"; 74 75 /// The code snippet used to generate a function to register all passes in a 76 /// group. 77 /// 78 /// {0}: The name of the pass group. 79 const char *const passGroupRegistrationCode = R"( 80 //===----------------------------------------------------------------------===// 81 // {0} Registration 82 //===----------------------------------------------------------------------===// 83 84 inline void register{0}Passes() {{ 85 )"; 86 87 /// Emits the definition of the struct to be used to control the pass options. 88 static void emitPassOptionsStruct(const Pass &pass, raw_ostream &os) { 89 StringRef passName = pass.getDef()->getName(); 90 ArrayRef<PassOption> options = pass.getOptions(); 91 92 // Emit the struct only if the pass has at least one option. 93 if (options.empty()) 94 return; 95 96 os << formatv("struct {0}Options {{\n", passName); 97 98 for (const PassOption &opt : options) { 99 std::string type = opt.getType().str(); 100 101 if (opt.isListOption()) 102 type = "::llvm::SmallVector<" + type + ">"; 103 104 os.indent(2) << formatv("{0} {1}", type, opt.getCppVariableName()); 105 106 if (std::optional<StringRef> defaultVal = opt.getDefaultValue()) 107 os << " = " << defaultVal; 108 109 os << ";\n"; 110 } 111 112 os << "};\n"; 113 } 114 115 static std::string getPassDeclVarName(const Pass &pass) { 116 return "GEN_PASS_DECL_" + pass.getDef()->getName().upper(); 117 } 118 119 /// Emit the code to be included in the public header of the pass. 120 static void emitPassDecls(const Pass &pass, raw_ostream &os) { 121 StringRef passName = pass.getDef()->getName(); 122 std::string enableVarName = getPassDeclVarName(pass); 123 124 os << "#ifdef " << enableVarName << "\n"; 125 emitPassOptionsStruct(pass, os); 126 127 if (StringRef constructor = pass.getConstructor(); constructor.empty()) { 128 // Default constructor declaration. 129 os << "std::unique_ptr<::mlir::Pass> create" << passName << "();\n"; 130 131 // Declaration of the constructor with options. 132 if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty()) 133 os << formatv("std::unique_ptr<::mlir::Pass> create{0}(" 134 "{0}Options options);\n", 135 passName); 136 } 137 138 os << "#undef " << enableVarName << "\n"; 139 os << "#endif // " << enableVarName << "\n"; 140 } 141 142 /// Emit the code for registering each of the given passes with the global 143 /// PassRegistry. 144 static void emitRegistrations(llvm::ArrayRef<Pass> passes, raw_ostream &os) { 145 os << "#ifdef GEN_PASS_REGISTRATION\n"; 146 147 for (const Pass &pass : passes) { 148 std::string constructorCall; 149 if (StringRef constructor = pass.getConstructor(); !constructor.empty()) 150 constructorCall = constructor.str(); 151 else 152 constructorCall = formatv("create{0}()", pass.getDef()->getName()).str(); 153 154 os << formatv(passRegistrationCode, pass.getDef()->getName(), 155 constructorCall); 156 } 157 158 os << formatv(passGroupRegistrationCode, groupName); 159 160 for (const Pass &pass : passes) 161 os << " register" << pass.getDef()->getName() << "();\n"; 162 163 os << "}\n"; 164 os << "#undef GEN_PASS_REGISTRATION\n"; 165 os << "#endif // GEN_PASS_REGISTRATION\n"; 166 } 167 168 //===----------------------------------------------------------------------===// 169 // GEN: Pass base class generation 170 //===----------------------------------------------------------------------===// 171 172 /// The code snippet used to generate the start of a pass base class. 173 /// 174 /// {0}: The def name of the pass record. 175 /// {1}: The base class for the pass. 176 /// {2): The command line argument for the pass. 177 /// {3}: The summary for the pass. 178 /// {4}: The dependent dialects registration. 179 const char *const baseClassBegin = R"( 180 template <typename DerivedT> 181 class {0}Base : public {1} { 182 public: 183 using Base = {0}Base; 184 185 {0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{} 186 {0}Base(const {0}Base &other) : {1}(other) {{} 187 {0}Base& operator=(const {0}Base &) = delete; 188 {0}Base({0}Base &&) = delete; 189 {0}Base& operator=({0}Base &&) = delete; 190 ~{0}Base() = default; 191 192 /// Returns the command-line argument attached to this pass. 193 static constexpr ::llvm::StringLiteral getArgumentName() { 194 return ::llvm::StringLiteral("{2}"); 195 } 196 ::llvm::StringRef getArgument() const override { return "{2}"; } 197 198 ::llvm::StringRef getDescription() const override { return "{3}"; } 199 200 /// Returns the derived pass name. 201 static constexpr ::llvm::StringLiteral getPassName() { 202 return ::llvm::StringLiteral("{0}"); 203 } 204 ::llvm::StringRef getName() const override { return "{0}"; } 205 206 /// Support isa/dyn_cast functionality for the derived pass class. 207 static bool classof(const ::mlir::Pass *pass) {{ 208 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); 209 } 210 211 /// A clone method to create a copy of this pass. 212 std::unique_ptr<::mlir::Pass> clonePass() const override {{ 213 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); 214 } 215 216 /// Return the dialect that must be loaded in the context before this pass. 217 void getDependentDialects(::mlir::DialectRegistry ®istry) const override { 218 {4} 219 } 220 221 /// Explicitly declare the TypeID for this class. We declare an explicit private 222 /// instantiation because Pass classes should only be visible by the current 223 /// library. 224 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID({0}Base<DerivedT>) 225 226 )"; 227 228 /// Registration for a single dependent dialect, to be inserted for each 229 /// dependent dialect in the `getDependentDialects` above. 230 const char *const dialectRegistrationTemplate = "registry.insert<{0}>();"; 231 232 const char *const friendDefaultConstructorDeclTemplate = R"( 233 namespace impl {{ 234 std::unique_ptr<::mlir::Pass> create{0}(); 235 } // namespace impl 236 )"; 237 238 const char *const friendDefaultConstructorWithOptionsDeclTemplate = R"( 239 namespace impl {{ 240 std::unique_ptr<::mlir::Pass> create{0}({0}Options options); 241 } // namespace impl 242 )"; 243 244 const char *const friendDefaultConstructorDefTemplate = R"( 245 friend std::unique_ptr<::mlir::Pass> create{0}() {{ 246 return std::make_unique<DerivedT>(); 247 } 248 )"; 249 250 const char *const friendDefaultConstructorWithOptionsDefTemplate = R"( 251 friend std::unique_ptr<::mlir::Pass> create{0}({0}Options options) {{ 252 return std::make_unique<DerivedT>(std::move(options)); 253 } 254 )"; 255 256 const char *const defaultConstructorDefTemplate = R"( 257 std::unique_ptr<::mlir::Pass> create{0}() {{ 258 return impl::create{0}(); 259 } 260 )"; 261 262 const char *const defaultConstructorWithOptionsDefTemplate = R"( 263 std::unique_ptr<::mlir::Pass> create{0}({0}Options options) {{ 264 return impl::create{0}(std::move(options)); 265 } 266 )"; 267 268 /// Emit the declarations for each of the pass options. 269 static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) { 270 for (const PassOption &opt : pass.getOptions()) { 271 os.indent(2) << "::mlir::Pass::" 272 << (opt.isListOption() ? "ListOption" : "Option"); 273 274 os << formatv(R"(<{0}> {1}{{*this, "{2}", ::llvm::cl::desc("{3}"))", 275 opt.getType(), opt.getCppVariableName(), opt.getArgument(), 276 opt.getDescription()); 277 if (std::optional<StringRef> defaultVal = opt.getDefaultValue()) 278 os << ", ::llvm::cl::init(" << defaultVal << ")"; 279 if (std::optional<StringRef> additionalFlags = opt.getAdditionalFlags()) 280 os << ", " << *additionalFlags; 281 os << "};\n"; 282 } 283 } 284 285 /// Emit the declarations for each of the pass statistics. 286 static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) { 287 for (const PassStatistic &stat : pass.getStatistics()) { 288 os << formatv(" ::mlir::Pass::Statistic {0}{{this, \"{1}\", \"{2}\"};\n", 289 stat.getCppVariableName(), stat.getName(), 290 stat.getDescription()); 291 } 292 } 293 294 /// Emit the code to be used in the implementation of the pass. 295 static void emitPassDefs(const Pass &pass, raw_ostream &os) { 296 StringRef passName = pass.getDef()->getName(); 297 std::string enableVarName = "GEN_PASS_DEF_" + passName.upper(); 298 bool emitDefaultConstructors = pass.getConstructor().empty(); 299 bool emitDefaultConstructorWithOptions = !pass.getOptions().empty(); 300 301 os << "#ifdef " << enableVarName << "\n"; 302 303 if (emitDefaultConstructors) { 304 os << formatv(friendDefaultConstructorDeclTemplate, passName); 305 306 if (emitDefaultConstructorWithOptions) 307 os << formatv(friendDefaultConstructorWithOptionsDeclTemplate, passName); 308 } 309 310 std::string dependentDialectRegistrations; 311 { 312 llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations); 313 llvm::interleave( 314 pass.getDependentDialects(), dialectsOs, 315 [&](StringRef dependentDialect) { 316 dialectsOs << formatv(dialectRegistrationTemplate, dependentDialect); 317 }, 318 "\n "); 319 } 320 321 os << "namespace impl {\n"; 322 os << formatv(baseClassBegin, passName, pass.getBaseClass(), 323 pass.getArgument(), pass.getSummary(), 324 dependentDialectRegistrations); 325 326 if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty()) { 327 os.indent(2) << formatv("{0}Base({0}Options options) : {0}Base() {{\n", 328 passName); 329 330 for (const PassOption &opt : pass.getOptions()) 331 os.indent(4) << formatv("{0} = std::move(options.{0});\n", 332 opt.getCppVariableName()); 333 334 os.indent(2) << "}\n"; 335 } 336 337 // Protected content 338 os << "protected:\n"; 339 emitPassOptionDecls(pass, os); 340 emitPassStatisticDecls(pass, os); 341 342 // Private content 343 os << "private:\n"; 344 345 if (emitDefaultConstructors) { 346 os << formatv(friendDefaultConstructorDefTemplate, passName); 347 348 if (!pass.getOptions().empty()) 349 os << formatv(friendDefaultConstructorWithOptionsDefTemplate, passName); 350 } 351 352 os << "};\n"; 353 os << "} // namespace impl\n"; 354 355 if (emitDefaultConstructors) { 356 os << formatv(defaultConstructorDefTemplate, passName); 357 358 if (emitDefaultConstructorWithOptions) 359 os << formatv(defaultConstructorWithOptionsDefTemplate, passName); 360 } 361 362 os << "#undef " << enableVarName << "\n"; 363 os << "#endif // " << enableVarName << "\n"; 364 } 365 366 static void emitPass(const Pass &pass, raw_ostream &os) { 367 StringRef passName = pass.getDef()->getName(); 368 os << formatv(passHeader, passName); 369 370 emitPassDecls(pass, os); 371 emitPassDefs(pass, os); 372 } 373 374 // TODO: Drop old pass declarations. 375 // The old pass base class is being kept until all the passes have switched to 376 // the new decls/defs design. 377 const char *const oldPassDeclBegin = R"( 378 template <typename DerivedT> 379 class {0}Base : public {1} { 380 public: 381 using Base = {0}Base; 382 383 {0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{} 384 {0}Base(const {0}Base &other) : {1}(other) {{} 385 {0}Base& operator=(const {0}Base &) = delete; 386 {0}Base({0}Base &&) = delete; 387 {0}Base& operator=({0}Base &&) = delete; 388 ~{0}Base() = default; 389 390 /// Returns the command-line argument attached to this pass. 391 static constexpr ::llvm::StringLiteral getArgumentName() { 392 return ::llvm::StringLiteral("{2}"); 393 } 394 ::llvm::StringRef getArgument() const override { return "{2}"; } 395 396 ::llvm::StringRef getDescription() const override { return "{3}"; } 397 398 /// Returns the derived pass name. 399 static constexpr ::llvm::StringLiteral getPassName() { 400 return ::llvm::StringLiteral("{0}"); 401 } 402 ::llvm::StringRef getName() const override { return "{0}"; } 403 404 /// Support isa/dyn_cast functionality for the derived pass class. 405 static bool classof(const ::mlir::Pass *pass) {{ 406 return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); 407 } 408 409 /// A clone method to create a copy of this pass. 410 std::unique_ptr<::mlir::Pass> clonePass() const override {{ 411 return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); 412 } 413 414 /// Register the dialects that must be loaded in the context before this pass. 415 void getDependentDialects(::mlir::DialectRegistry ®istry) const override { 416 {4} 417 } 418 419 /// Explicitly declare the TypeID for this class. We declare an explicit private 420 /// instantiation because Pass classes should only be visible by the current 421 /// library. 422 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID({0}Base<DerivedT>) 423 424 protected: 425 )"; 426 427 // TODO: Drop old pass declarations. 428 /// Emit a backward-compatible declaration of the pass base class. 429 static void emitOldPassDecl(const Pass &pass, raw_ostream &os) { 430 StringRef defName = pass.getDef()->getName(); 431 std::string dependentDialectRegistrations; 432 { 433 llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations); 434 llvm::interleave( 435 pass.getDependentDialects(), dialectsOs, 436 [&](StringRef dependentDialect) { 437 dialectsOs << formatv(dialectRegistrationTemplate, dependentDialect); 438 }, 439 "\n "); 440 } 441 os << formatv(oldPassDeclBegin, defName, pass.getBaseClass(), 442 pass.getArgument(), pass.getSummary(), 443 dependentDialectRegistrations); 444 emitPassOptionDecls(pass, os); 445 emitPassStatisticDecls(pass, os); 446 os << "};\n"; 447 } 448 449 static void emitPasses(const RecordKeeper &records, raw_ostream &os) { 450 std::vector<Pass> passes = getPasses(records); 451 os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n"; 452 453 os << "\n"; 454 os << "#ifdef GEN_PASS_DECL\n"; 455 os << "// Generate declarations for all passes.\n"; 456 for (const Pass &pass : passes) 457 os << "#define " << getPassDeclVarName(pass) << "\n"; 458 os << "#undef GEN_PASS_DECL\n"; 459 os << "#endif // GEN_PASS_DECL\n"; 460 461 for (const Pass &pass : passes) 462 emitPass(pass, os); 463 464 emitRegistrations(passes, os); 465 466 // TODO: Drop old pass declarations. 467 // Emit the old code until all the passes have switched to the new design. 468 os << "// Deprecated. Please use the new per-pass macros.\n"; 469 os << "#ifdef GEN_PASS_CLASSES\n"; 470 for (const Pass &pass : passes) 471 emitOldPassDecl(pass, os); 472 os << "#undef GEN_PASS_CLASSES\n"; 473 os << "#endif // GEN_PASS_CLASSES\n"; 474 } 475 476 static mlir::GenRegistration 477 genPassDecls("gen-pass-decls", "Generate pass declarations", 478 [](const RecordKeeper &records, raw_ostream &os) { 479 emitPasses(records, os); 480 return false; 481 }); 482