1 //===- Pass.cpp - MLIR pass registration generator ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // PassCAPIGen uses the description of passes to generate C API for the passes. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/TableGen/GenInfo.h" 14 #include "mlir/TableGen/Pass.h" 15 #include "llvm/ADT/StringExtras.h" 16 #include "llvm/Support/CommandLine.h" 17 #include "llvm/Support/FormatVariadic.h" 18 #include "llvm/TableGen/Error.h" 19 #include "llvm/TableGen/Record.h" 20 21 using namespace mlir; 22 using namespace mlir::tblgen; 23 using llvm::formatv; 24 using llvm::RecordKeeper; 25 26 static llvm::cl::OptionCategory 27 passGenCat("Options for -gen-pass-capi-header and -gen-pass-capi-impl"); 28 static llvm::cl::opt<std::string> 29 groupName("prefix", 30 llvm::cl::desc("The prefix to use for this group of passes. The " 31 "form will be mlirCreate<prefix><passname>, the " 32 "prefix can avoid conflicts across libraries."), 33 llvm::cl::cat(passGenCat)); 34 35 const char *const passDecl = R"( 36 /* Create {0} Pass. */ 37 MLIR_CAPI_EXPORTED MlirPass mlirCreate{0}{1}(void); 38 MLIR_CAPI_EXPORTED void mlirRegister{0}{1}(void); 39 40 )"; 41 42 const char *const fileHeader = R"( 43 /* Autogenerated by mlir-tblgen; don't manually edit. */ 44 45 #include "mlir-c/Pass.h" 46 47 #ifdef __cplusplus 48 extern "C" { 49 #endif 50 51 )"; 52 53 const char *const fileFooter = R"( 54 55 #ifdef __cplusplus 56 } 57 #endif 58 )"; 59 60 /// Emit TODO 61 static bool emitCAPIHeader(const RecordKeeper &records, raw_ostream &os) { 62 os << fileHeader; 63 os << "// Registration for the entire group\n"; 64 os << "MLIR_CAPI_EXPORTED void mlirRegister" << groupName 65 << "Passes(void);\n\n"; 66 for (const auto *def : records.getAllDerivedDefinitions("PassBase")) { 67 Pass pass(def); 68 StringRef defName = pass.getDef()->getName(); 69 os << formatv(passDecl, groupName, defName); 70 } 71 os << fileFooter; 72 return false; 73 } 74 75 const char *const passCreateDef = R"( 76 MlirPass mlirCreate{0}{1}(void) { 77 return wrap({2}.release()); 78 } 79 void mlirRegister{0}{1}(void) { 80 register{1}(); 81 } 82 83 )"; 84 85 /// {0}: The name of the pass group. 86 const char *const passGroupRegistrationCode = R"( 87 //===----------------------------------------------------------------------===// 88 // {0} Group Registration 89 //===----------------------------------------------------------------------===// 90 91 void mlirRegister{0}Passes(void) {{ 92 register{0}Passes(); 93 } 94 )"; 95 96 static bool emitCAPIImpl(const RecordKeeper &records, raw_ostream &os) { 97 os << "/* Autogenerated by mlir-tblgen; don't manually edit. */"; 98 os << formatv(passGroupRegistrationCode, groupName); 99 100 for (const auto *def : records.getAllDerivedDefinitions("PassBase")) { 101 Pass pass(def); 102 StringRef defName = pass.getDef()->getName(); 103 104 std::string constructorCall; 105 if (StringRef constructor = pass.getConstructor(); !constructor.empty()) 106 constructorCall = constructor.str(); 107 else 108 constructorCall = formatv("create{0}()", pass.getDef()->getName()).str(); 109 110 os << formatv(passCreateDef, groupName, defName, constructorCall); 111 } 112 return false; 113 } 114 115 static mlir::GenRegistration genCAPIHeader("gen-pass-capi-header", 116 "Generate pass C API header", 117 &emitCAPIHeader); 118 119 static mlir::GenRegistration genCAPIImpl("gen-pass-capi-impl", 120 "Generate pass C API implementation", 121 &emitCAPIImpl); 122