xref: /llvm-project/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp (revision db791b278a414fb6df1acc1799adcf11d8fb9169)
1 //===- CPPGen.cpp ---------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains a PDLL generator that outputs C++ code that defines PDLL
10 // patterns as individual C++ PDLPatternModules for direct use in native code,
11 // and also defines any native constraints whose bodies were defined in PDLL.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Tools/PDLL/CodeGen/CPPGen.h"
16 #include "mlir/Dialect/PDL/IR/PDL.h"
17 #include "mlir/Dialect/PDL/IR/PDLOps.h"
18 #include "mlir/IR/BuiltinOps.h"
19 #include "mlir/Tools/PDLL/AST/Nodes.h"
20 #include "mlir/Tools/PDLL/ODS/Operation.h"
21 #include "llvm/ADT/SmallString.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/StringSet.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/Support/ErrorHandling.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include <optional>
28 
29 using namespace mlir;
30 using namespace mlir::pdll;
31 
32 //===----------------------------------------------------------------------===//
33 // CodeGen
34 //===----------------------------------------------------------------------===//
35 
36 namespace {
37 class CodeGen {
38 public:
CodeGen(raw_ostream & os)39   CodeGen(raw_ostream &os) : os(os) {}
40 
41   /// Generate C++ code for the given PDL pattern module.
42   void generate(const ast::Module &astModule, ModuleOp module);
43 
44 private:
45   void generate(pdl::PatternOp pattern, StringRef patternName,
46                 StringSet<> &nativeFunctions);
47 
48   /// Generate C++ code for all user defined constraints and rewrites with
49   /// native code.
50   void generateConstraintAndRewrites(const ast::Module &astModule,
51                                      ModuleOp module,
52                                      StringSet<> &nativeFunctions);
53   void generate(const ast::UserConstraintDecl *decl,
54                 StringSet<> &nativeFunctions);
55   void generate(const ast::UserRewriteDecl *decl, StringSet<> &nativeFunctions);
56   void generateConstraintOrRewrite(const ast::CallableDecl *decl,
57                                    bool isConstraint,
58                                    StringSet<> &nativeFunctions);
59 
60   /// Return the native name for the type of the given type.
61   StringRef getNativeTypeName(ast::Type type);
62 
63   /// Return the native name for the type of the given variable decl.
64   StringRef getNativeTypeName(ast::VariableDecl *decl);
65 
66   /// The stream to output to.
67   raw_ostream &os;
68 };
69 } // namespace
70 
generate(const ast::Module & astModule,ModuleOp module)71 void CodeGen::generate(const ast::Module &astModule, ModuleOp module) {
72   SetVector<std::string, SmallVector<std::string>, StringSet<>> patternNames;
73   StringSet<> nativeFunctions;
74 
75   // Generate code for any native functions within the module.
76   generateConstraintAndRewrites(astModule, module, nativeFunctions);
77 
78   os << "namespace {\n";
79   std::string basePatternName = "GeneratedPDLLPattern";
80   int patternIndex = 0;
81   for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
82     // If the pattern has a name, use that. Otherwise, generate a unique name.
83     if (std::optional<StringRef> patternName = pattern.getSymName()) {
84       patternNames.insert(patternName->str());
85     } else {
86       std::string name;
87       do {
88         name = (basePatternName + Twine(patternIndex++)).str();
89       } while (!patternNames.insert(name));
90     }
91 
92     generate(pattern, patternNames.back(), nativeFunctions);
93   }
94   os << "} // end namespace\n\n";
95 
96   // Emit function to add the generated matchers to the pattern list.
97   os << "template <typename... ConfigsT>\n"
98         "static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns("
99         "::mlir::RewritePatternSet &patterns, ConfigsT &&...configs) {\n";
100   for (const auto &name : patternNames)
101     os << "  patterns.add<" << name
102        << ">(patterns.getContext(), configs...);\n";
103   os << "}\n";
104 }
105 
generate(pdl::PatternOp pattern,StringRef patternName,StringSet<> & nativeFunctions)106 void CodeGen::generate(pdl::PatternOp pattern, StringRef patternName,
107                        StringSet<> &nativeFunctions) {
108   const char *patternClassStartStr = R"(
109 struct {0} : ::mlir::PDLPatternModule {{
110   template <typename... ConfigsT>
111   {0}(::mlir::MLIRContext *context, ConfigsT &&...configs)
112     : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>(
113 )";
114   os << llvm::formatv(patternClassStartStr, patternName);
115 
116   os << "R\"mlir(";
117   pattern->print(os, OpPrintingFlags().enableDebugInfo());
118   os << "\n    )mlir\", context), std::forward<ConfigsT>(configs)...) {\n";
119 
120   // Register any native functions used within the pattern.
121   StringSet<> registeredNativeFunctions;
122   auto checkRegisterNativeFn = [&](StringRef fnName, StringRef fnType) {
123     if (!nativeFunctions.count(fnName) ||
124         !registeredNativeFunctions.insert(fnName).second)
125       return;
126     os << "    register" << fnType << "Function(\"" << fnName << "\", "
127        << fnName << "PDLFn);\n";
128   };
129   pattern.walk([&](Operation *op) {
130     if (auto constraintOp = dyn_cast<pdl::ApplyNativeConstraintOp>(op))
131       checkRegisterNativeFn(constraintOp.getName(), "Constraint");
132     else if (auto rewriteOp = dyn_cast<pdl::ApplyNativeRewriteOp>(op))
133       checkRegisterNativeFn(rewriteOp.getName(), "Rewrite");
134   });
135   os << "  }\n};\n\n";
136 }
137 
generateConstraintAndRewrites(const ast::Module & astModule,ModuleOp module,StringSet<> & nativeFunctions)138 void CodeGen::generateConstraintAndRewrites(const ast::Module &astModule,
139                                             ModuleOp module,
140                                             StringSet<> &nativeFunctions) {
141   // First check to see which constraints and rewrites are actually referenced
142   // in the module.
143   StringSet<> usedFns;
144   module.walk([&](Operation *op) {
145     TypeSwitch<Operation *>(op)
146         .Case<pdl::ApplyNativeConstraintOp, pdl::ApplyNativeRewriteOp>(
147             [&](auto op) { usedFns.insert(op.getName()); });
148   });
149 
150   for (const ast::Decl *decl : astModule.getChildren()) {
151     TypeSwitch<const ast::Decl *>(decl)
152         .Case<ast::UserConstraintDecl, ast::UserRewriteDecl>(
153             [&](const auto *decl) {
154               // We only generate code for inline native decls that have been
155               // referenced.
156               if (decl->getCodeBlock() &&
157                   usedFns.contains(decl->getName().getName()))
158                 this->generate(decl, nativeFunctions);
159             });
160   }
161 }
162 
generate(const ast::UserConstraintDecl * decl,StringSet<> & nativeFunctions)163 void CodeGen::generate(const ast::UserConstraintDecl *decl,
164                        StringSet<> &nativeFunctions) {
165   return generateConstraintOrRewrite(cast<ast::CallableDecl>(decl),
166                                      /*isConstraint=*/true, nativeFunctions);
167 }
168 
generate(const ast::UserRewriteDecl * decl,StringSet<> & nativeFunctions)169 void CodeGen::generate(const ast::UserRewriteDecl *decl,
170                        StringSet<> &nativeFunctions) {
171   return generateConstraintOrRewrite(cast<ast::CallableDecl>(decl),
172                                      /*isConstraint=*/false, nativeFunctions);
173 }
174 
getNativeTypeName(ast::Type type)175 StringRef CodeGen::getNativeTypeName(ast::Type type) {
176   return llvm::TypeSwitch<ast::Type, StringRef>(type)
177       .Case([&](ast::AttributeType) { return "::mlir::Attribute"; })
178       .Case([&](ast::OperationType opType) -> StringRef {
179         // Use the derived Op class when available.
180         if (const auto *odsOp = opType.getODSOperation())
181           return odsOp->getNativeClassName();
182         return "::mlir::Operation *";
183       })
184       .Case([&](ast::TypeType) { return "::mlir::Type"; })
185       .Case([&](ast::ValueType) { return "::mlir::Value"; })
186       .Case([&](ast::TypeRangeType) { return "::mlir::TypeRange"; })
187       .Case([&](ast::ValueRangeType) { return "::mlir::ValueRange"; });
188 }
189 
getNativeTypeName(ast::VariableDecl * decl)190 StringRef CodeGen::getNativeTypeName(ast::VariableDecl *decl) {
191   // Try to extract a type name from the variable's constraints.
192   for (ast::ConstraintRef &cst : decl->getConstraints()) {
193     if (auto *userCst = dyn_cast<ast::UserConstraintDecl>(cst.constraint)) {
194       if (std::optional<StringRef> name = userCst->getNativeInputType(0))
195         return *name;
196       return getNativeTypeName(userCst->getInputs()[0]);
197     }
198   }
199 
200   // Otherwise, use the type of the variable.
201   return getNativeTypeName(decl->getType());
202 }
203 
generateConstraintOrRewrite(const ast::CallableDecl * decl,bool isConstraint,StringSet<> & nativeFunctions)204 void CodeGen::generateConstraintOrRewrite(const ast::CallableDecl *decl,
205                                           bool isConstraint,
206                                           StringSet<> &nativeFunctions) {
207   StringRef name = decl->getName()->getName();
208   nativeFunctions.insert(name);
209 
210   os << "static ";
211 
212   // TODO: Work out a proper modeling for "optionality".
213 
214   // Emit the result type.
215   // If this is a constraint, we always return a LogicalResult.
216   // TODO: This will need to change if we allow Constraints to return values as
217   // well.
218   if (isConstraint) {
219     os << "::llvm::LogicalResult";
220   } else {
221     // Otherwise, generate a type based on the results of the callable.
222     // If the callable has explicit results, use those to build the result.
223     // Otherwise, use the type of the callable.
224     ArrayRef<ast::VariableDecl *> results = decl->getResults();
225     if (results.empty()) {
226       os << "void";
227     } else if (results.size() == 1) {
228       os << getNativeTypeName(results[0]);
229     } else {
230       os << "std::tuple<";
231       llvm::interleaveComma(results, os, [&](ast::VariableDecl *result) {
232         os << getNativeTypeName(result);
233       });
234       os << ">";
235     }
236   }
237 
238   os << " " << name << "PDLFn(::mlir::PatternRewriter &rewriter";
239   if (!decl->getInputs().empty()) {
240     os << ", ";
241     llvm::interleaveComma(decl->getInputs(), os, [&](ast::VariableDecl *input) {
242       os << getNativeTypeName(input) << " " << input->getName().getName();
243     });
244   }
245   os << ") {\n";
246   os << "  " << decl->getCodeBlock()->trim() << "\n}\n\n";
247 }
248 
249 //===----------------------------------------------------------------------===//
250 // CPPGen
251 //===----------------------------------------------------------------------===//
252 
codegenPDLLToCPP(const ast::Module & astModule,ModuleOp module,raw_ostream & os)253 void mlir::pdll::codegenPDLLToCPP(const ast::Module &astModule, ModuleOp module,
254                                   raw_ostream &os) {
255   CodeGen codegen(os);
256   codegen.generate(astModule, module);
257 }
258