1 //===- CPPGen.cpp ---------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file contains a PDLL generator that outputs C++ code that defines PDLL
10 // patterns as individual C++ PDLPatternModules for direct use in native code,
11 // and also defines any native constraints whose bodies were defined in PDLL.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "mlir/Tools/PDLL/CodeGen/CPPGen.h"
16 #include "mlir/Dialect/PDL/IR/PDL.h"
17 #include "mlir/Dialect/PDL/IR/PDLOps.h"
18 #include "mlir/IR/BuiltinOps.h"
19 #include "mlir/Tools/PDLL/AST/Nodes.h"
20 #include "mlir/Tools/PDLL/ODS/Operation.h"
21 #include "llvm/ADT/SmallString.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/StringSet.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/Support/ErrorHandling.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include <optional>
28
29 using namespace mlir;
30 using namespace mlir::pdll;
31
32 //===----------------------------------------------------------------------===//
33 // CodeGen
34 //===----------------------------------------------------------------------===//
35
36 namespace {
37 class CodeGen {
38 public:
CodeGen(raw_ostream & os)39 CodeGen(raw_ostream &os) : os(os) {}
40
41 /// Generate C++ code for the given PDL pattern module.
42 void generate(const ast::Module &astModule, ModuleOp module);
43
44 private:
45 void generate(pdl::PatternOp pattern, StringRef patternName,
46 StringSet<> &nativeFunctions);
47
48 /// Generate C++ code for all user defined constraints and rewrites with
49 /// native code.
50 void generateConstraintAndRewrites(const ast::Module &astModule,
51 ModuleOp module,
52 StringSet<> &nativeFunctions);
53 void generate(const ast::UserConstraintDecl *decl,
54 StringSet<> &nativeFunctions);
55 void generate(const ast::UserRewriteDecl *decl, StringSet<> &nativeFunctions);
56 void generateConstraintOrRewrite(const ast::CallableDecl *decl,
57 bool isConstraint,
58 StringSet<> &nativeFunctions);
59
60 /// Return the native name for the type of the given type.
61 StringRef getNativeTypeName(ast::Type type);
62
63 /// Return the native name for the type of the given variable decl.
64 StringRef getNativeTypeName(ast::VariableDecl *decl);
65
66 /// The stream to output to.
67 raw_ostream &os;
68 };
69 } // namespace
70
generate(const ast::Module & astModule,ModuleOp module)71 void CodeGen::generate(const ast::Module &astModule, ModuleOp module) {
72 SetVector<std::string, SmallVector<std::string>, StringSet<>> patternNames;
73 StringSet<> nativeFunctions;
74
75 // Generate code for any native functions within the module.
76 generateConstraintAndRewrites(astModule, module, nativeFunctions);
77
78 os << "namespace {\n";
79 std::string basePatternName = "GeneratedPDLLPattern";
80 int patternIndex = 0;
81 for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
82 // If the pattern has a name, use that. Otherwise, generate a unique name.
83 if (std::optional<StringRef> patternName = pattern.getSymName()) {
84 patternNames.insert(patternName->str());
85 } else {
86 std::string name;
87 do {
88 name = (basePatternName + Twine(patternIndex++)).str();
89 } while (!patternNames.insert(name));
90 }
91
92 generate(pattern, patternNames.back(), nativeFunctions);
93 }
94 os << "} // end namespace\n\n";
95
96 // Emit function to add the generated matchers to the pattern list.
97 os << "template <typename... ConfigsT>\n"
98 "static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns("
99 "::mlir::RewritePatternSet &patterns, ConfigsT &&...configs) {\n";
100 for (const auto &name : patternNames)
101 os << " patterns.add<" << name
102 << ">(patterns.getContext(), configs...);\n";
103 os << "}\n";
104 }
105
generate(pdl::PatternOp pattern,StringRef patternName,StringSet<> & nativeFunctions)106 void CodeGen::generate(pdl::PatternOp pattern, StringRef patternName,
107 StringSet<> &nativeFunctions) {
108 const char *patternClassStartStr = R"(
109 struct {0} : ::mlir::PDLPatternModule {{
110 template <typename... ConfigsT>
111 {0}(::mlir::MLIRContext *context, ConfigsT &&...configs)
112 : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>(
113 )";
114 os << llvm::formatv(patternClassStartStr, patternName);
115
116 os << "R\"mlir(";
117 pattern->print(os, OpPrintingFlags().enableDebugInfo());
118 os << "\n )mlir\", context), std::forward<ConfigsT>(configs)...) {\n";
119
120 // Register any native functions used within the pattern.
121 StringSet<> registeredNativeFunctions;
122 auto checkRegisterNativeFn = [&](StringRef fnName, StringRef fnType) {
123 if (!nativeFunctions.count(fnName) ||
124 !registeredNativeFunctions.insert(fnName).second)
125 return;
126 os << " register" << fnType << "Function(\"" << fnName << "\", "
127 << fnName << "PDLFn);\n";
128 };
129 pattern.walk([&](Operation *op) {
130 if (auto constraintOp = dyn_cast<pdl::ApplyNativeConstraintOp>(op))
131 checkRegisterNativeFn(constraintOp.getName(), "Constraint");
132 else if (auto rewriteOp = dyn_cast<pdl::ApplyNativeRewriteOp>(op))
133 checkRegisterNativeFn(rewriteOp.getName(), "Rewrite");
134 });
135 os << " }\n};\n\n";
136 }
137
generateConstraintAndRewrites(const ast::Module & astModule,ModuleOp module,StringSet<> & nativeFunctions)138 void CodeGen::generateConstraintAndRewrites(const ast::Module &astModule,
139 ModuleOp module,
140 StringSet<> &nativeFunctions) {
141 // First check to see which constraints and rewrites are actually referenced
142 // in the module.
143 StringSet<> usedFns;
144 module.walk([&](Operation *op) {
145 TypeSwitch<Operation *>(op)
146 .Case<pdl::ApplyNativeConstraintOp, pdl::ApplyNativeRewriteOp>(
147 [&](auto op) { usedFns.insert(op.getName()); });
148 });
149
150 for (const ast::Decl *decl : astModule.getChildren()) {
151 TypeSwitch<const ast::Decl *>(decl)
152 .Case<ast::UserConstraintDecl, ast::UserRewriteDecl>(
153 [&](const auto *decl) {
154 // We only generate code for inline native decls that have been
155 // referenced.
156 if (decl->getCodeBlock() &&
157 usedFns.contains(decl->getName().getName()))
158 this->generate(decl, nativeFunctions);
159 });
160 }
161 }
162
generate(const ast::UserConstraintDecl * decl,StringSet<> & nativeFunctions)163 void CodeGen::generate(const ast::UserConstraintDecl *decl,
164 StringSet<> &nativeFunctions) {
165 return generateConstraintOrRewrite(cast<ast::CallableDecl>(decl),
166 /*isConstraint=*/true, nativeFunctions);
167 }
168
generate(const ast::UserRewriteDecl * decl,StringSet<> & nativeFunctions)169 void CodeGen::generate(const ast::UserRewriteDecl *decl,
170 StringSet<> &nativeFunctions) {
171 return generateConstraintOrRewrite(cast<ast::CallableDecl>(decl),
172 /*isConstraint=*/false, nativeFunctions);
173 }
174
getNativeTypeName(ast::Type type)175 StringRef CodeGen::getNativeTypeName(ast::Type type) {
176 return llvm::TypeSwitch<ast::Type, StringRef>(type)
177 .Case([&](ast::AttributeType) { return "::mlir::Attribute"; })
178 .Case([&](ast::OperationType opType) -> StringRef {
179 // Use the derived Op class when available.
180 if (const auto *odsOp = opType.getODSOperation())
181 return odsOp->getNativeClassName();
182 return "::mlir::Operation *";
183 })
184 .Case([&](ast::TypeType) { return "::mlir::Type"; })
185 .Case([&](ast::ValueType) { return "::mlir::Value"; })
186 .Case([&](ast::TypeRangeType) { return "::mlir::TypeRange"; })
187 .Case([&](ast::ValueRangeType) { return "::mlir::ValueRange"; });
188 }
189
getNativeTypeName(ast::VariableDecl * decl)190 StringRef CodeGen::getNativeTypeName(ast::VariableDecl *decl) {
191 // Try to extract a type name from the variable's constraints.
192 for (ast::ConstraintRef &cst : decl->getConstraints()) {
193 if (auto *userCst = dyn_cast<ast::UserConstraintDecl>(cst.constraint)) {
194 if (std::optional<StringRef> name = userCst->getNativeInputType(0))
195 return *name;
196 return getNativeTypeName(userCst->getInputs()[0]);
197 }
198 }
199
200 // Otherwise, use the type of the variable.
201 return getNativeTypeName(decl->getType());
202 }
203
generateConstraintOrRewrite(const ast::CallableDecl * decl,bool isConstraint,StringSet<> & nativeFunctions)204 void CodeGen::generateConstraintOrRewrite(const ast::CallableDecl *decl,
205 bool isConstraint,
206 StringSet<> &nativeFunctions) {
207 StringRef name = decl->getName()->getName();
208 nativeFunctions.insert(name);
209
210 os << "static ";
211
212 // TODO: Work out a proper modeling for "optionality".
213
214 // Emit the result type.
215 // If this is a constraint, we always return a LogicalResult.
216 // TODO: This will need to change if we allow Constraints to return values as
217 // well.
218 if (isConstraint) {
219 os << "::llvm::LogicalResult";
220 } else {
221 // Otherwise, generate a type based on the results of the callable.
222 // If the callable has explicit results, use those to build the result.
223 // Otherwise, use the type of the callable.
224 ArrayRef<ast::VariableDecl *> results = decl->getResults();
225 if (results.empty()) {
226 os << "void";
227 } else if (results.size() == 1) {
228 os << getNativeTypeName(results[0]);
229 } else {
230 os << "std::tuple<";
231 llvm::interleaveComma(results, os, [&](ast::VariableDecl *result) {
232 os << getNativeTypeName(result);
233 });
234 os << ">";
235 }
236 }
237
238 os << " " << name << "PDLFn(::mlir::PatternRewriter &rewriter";
239 if (!decl->getInputs().empty()) {
240 os << ", ";
241 llvm::interleaveComma(decl->getInputs(), os, [&](ast::VariableDecl *input) {
242 os << getNativeTypeName(input) << " " << input->getName().getName();
243 });
244 }
245 os << ") {\n";
246 os << " " << decl->getCodeBlock()->trim() << "\n}\n\n";
247 }
248
249 //===----------------------------------------------------------------------===//
250 // CPPGen
251 //===----------------------------------------------------------------------===//
252
codegenPDLLToCPP(const ast::Module & astModule,ModuleOp module,raw_ostream & os)253 void mlir::pdll::codegenPDLLToCPP(const ast::Module &astModule, ModuleOp module,
254 raw_ostream &os) {
255 CodeGen codegen(os);
256 codegen.generate(astModule, module);
257 }
258