19ad64a5cSRiver Riddle //===- CPPGen.cpp ---------------------------------------------------------===//
29ad64a5cSRiver Riddle //
39ad64a5cSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
49ad64a5cSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
59ad64a5cSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
69ad64a5cSRiver Riddle //
79ad64a5cSRiver Riddle //===----------------------------------------------------------------------===//
89ad64a5cSRiver Riddle //
// This file contains a PDLL generator that outputs C++ code that defines PDLL
109ad64a5cSRiver Riddle // patterns as individual C++ PDLPatternModules for direct use in native code,
119ad64a5cSRiver Riddle // and also defines any native constraints whose bodies were defined in PDLL.
129ad64a5cSRiver Riddle //
139ad64a5cSRiver Riddle //===----------------------------------------------------------------------===//
149ad64a5cSRiver Riddle
159ad64a5cSRiver Riddle #include "mlir/Tools/PDLL/CodeGen/CPPGen.h"
169ad64a5cSRiver Riddle #include "mlir/Dialect/PDL/IR/PDL.h"
179ad64a5cSRiver Riddle #include "mlir/Dialect/PDL/IR/PDLOps.h"
189ad64a5cSRiver Riddle #include "mlir/IR/BuiltinOps.h"
199ad64a5cSRiver Riddle #include "mlir/Tools/PDLL/AST/Nodes.h"
201c2edb02SRiver Riddle #include "mlir/Tools/PDLL/ODS/Operation.h"
219ad64a5cSRiver Riddle #include "llvm/ADT/SmallString.h"
229ad64a5cSRiver Riddle #include "llvm/ADT/StringExtras.h"
239ad64a5cSRiver Riddle #include "llvm/ADT/StringSet.h"
249ad64a5cSRiver Riddle #include "llvm/ADT/TypeSwitch.h"
259ad64a5cSRiver Riddle #include "llvm/Support/ErrorHandling.h"
269ad64a5cSRiver Riddle #include "llvm/Support/FormatVariadic.h"
27a1fe1f5fSKazu Hirata #include <optional>
289ad64a5cSRiver Riddle
299ad64a5cSRiver Riddle using namespace mlir;
309ad64a5cSRiver Riddle using namespace mlir::pdll;
319ad64a5cSRiver Riddle
329ad64a5cSRiver Riddle //===----------------------------------------------------------------------===//
339ad64a5cSRiver Riddle // CodeGen
349ad64a5cSRiver Riddle //===----------------------------------------------------------------------===//
359ad64a5cSRiver Riddle
369ad64a5cSRiver Riddle namespace {
379ad64a5cSRiver Riddle class CodeGen {
389ad64a5cSRiver Riddle public:
CodeGen(raw_ostream & os)399ad64a5cSRiver Riddle CodeGen(raw_ostream &os) : os(os) {}
409ad64a5cSRiver Riddle
419ad64a5cSRiver Riddle /// Generate C++ code for the given PDL pattern module.
429ad64a5cSRiver Riddle void generate(const ast::Module &astModule, ModuleOp module);
439ad64a5cSRiver Riddle
449ad64a5cSRiver Riddle private:
459ad64a5cSRiver Riddle void generate(pdl::PatternOp pattern, StringRef patternName,
469ad64a5cSRiver Riddle StringSet<> &nativeFunctions);
479ad64a5cSRiver Riddle
489ad64a5cSRiver Riddle /// Generate C++ code for all user defined constraints and rewrites with
499ad64a5cSRiver Riddle /// native code.
509ad64a5cSRiver Riddle void generateConstraintAndRewrites(const ast::Module &astModule,
519ad64a5cSRiver Riddle ModuleOp module,
529ad64a5cSRiver Riddle StringSet<> &nativeFunctions);
539ad64a5cSRiver Riddle void generate(const ast::UserConstraintDecl *decl,
549ad64a5cSRiver Riddle StringSet<> &nativeFunctions);
559ad64a5cSRiver Riddle void generate(const ast::UserRewriteDecl *decl, StringSet<> &nativeFunctions);
561c2edb02SRiver Riddle void generateConstraintOrRewrite(const ast::CallableDecl *decl,
571c2edb02SRiver Riddle bool isConstraint,
589ad64a5cSRiver Riddle StringSet<> &nativeFunctions);
599ad64a5cSRiver Riddle
601c2edb02SRiver Riddle /// Return the native name for the type of the given type.
611c2edb02SRiver Riddle StringRef getNativeTypeName(ast::Type type);
621c2edb02SRiver Riddle
631c2edb02SRiver Riddle /// Return the native name for the type of the given variable decl.
641c2edb02SRiver Riddle StringRef getNativeTypeName(ast::VariableDecl *decl);
651c2edb02SRiver Riddle
669ad64a5cSRiver Riddle /// The stream to output to.
679ad64a5cSRiver Riddle raw_ostream &os;
689ad64a5cSRiver Riddle };
699ad64a5cSRiver Riddle } // namespace
709ad64a5cSRiver Riddle
generate(const ast::Module & astModule,ModuleOp module)719ad64a5cSRiver Riddle void CodeGen::generate(const ast::Module &astModule, ModuleOp module) {
729ad64a5cSRiver Riddle SetVector<std::string, SmallVector<std::string>, StringSet<>> patternNames;
739ad64a5cSRiver Riddle StringSet<> nativeFunctions;
749ad64a5cSRiver Riddle
759ad64a5cSRiver Riddle // Generate code for any native functions within the module.
769ad64a5cSRiver Riddle generateConstraintAndRewrites(astModule, module, nativeFunctions);
779ad64a5cSRiver Riddle
789ad64a5cSRiver Riddle os << "namespace {\n";
799ad64a5cSRiver Riddle std::string basePatternName = "GeneratedPDLLPattern";
809ad64a5cSRiver Riddle int patternIndex = 0;
819ad64a5cSRiver Riddle for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
829ad64a5cSRiver Riddle // If the pattern has a name, use that. Otherwise, generate a unique name.
8322426110SRamkumar Ramachandra if (std::optional<StringRef> patternName = pattern.getSymName()) {
849ad64a5cSRiver Riddle patternNames.insert(patternName->str());
859ad64a5cSRiver Riddle } else {
869ad64a5cSRiver Riddle std::string name;
879ad64a5cSRiver Riddle do {
889ad64a5cSRiver Riddle name = (basePatternName + Twine(patternIndex++)).str();
899ad64a5cSRiver Riddle } while (!patternNames.insert(name));
909ad64a5cSRiver Riddle }
919ad64a5cSRiver Riddle
929ad64a5cSRiver Riddle generate(pattern, patternNames.back(), nativeFunctions);
939ad64a5cSRiver Riddle }
949ad64a5cSRiver Riddle os << "} // end namespace\n\n";
959ad64a5cSRiver Riddle
969ad64a5cSRiver Riddle // Emit function to add the generated matchers to the pattern list.
978c66344eSRiver Riddle os << "template <typename... ConfigsT>\n"
988c66344eSRiver Riddle "static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns("
998c66344eSRiver Riddle "::mlir::RewritePatternSet &patterns, ConfigsT &&...configs) {\n";
1009ad64a5cSRiver Riddle for (const auto &name : patternNames)
1018c66344eSRiver Riddle os << " patterns.add<" << name
1028c66344eSRiver Riddle << ">(patterns.getContext(), configs...);\n";
1039ad64a5cSRiver Riddle os << "}\n";
1049ad64a5cSRiver Riddle }
1059ad64a5cSRiver Riddle
generate(pdl::PatternOp pattern,StringRef patternName,StringSet<> & nativeFunctions)1069ad64a5cSRiver Riddle void CodeGen::generate(pdl::PatternOp pattern, StringRef patternName,
1079ad64a5cSRiver Riddle StringSet<> &nativeFunctions) {
1089ad64a5cSRiver Riddle const char *patternClassStartStr = R"(
1099ad64a5cSRiver Riddle struct {0} : ::mlir::PDLPatternModule {{
1108c66344eSRiver Riddle template <typename... ConfigsT>
1118c66344eSRiver Riddle {0}(::mlir::MLIRContext *context, ConfigsT &&...configs)
1129ad64a5cSRiver Riddle : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>(
1139ad64a5cSRiver Riddle )";
1149ad64a5cSRiver Riddle os << llvm::formatv(patternClassStartStr, patternName);
1159ad64a5cSRiver Riddle
1169ad64a5cSRiver Riddle os << "R\"mlir(";
1179ad64a5cSRiver Riddle pattern->print(os, OpPrintingFlags().enableDebugInfo());
1188c66344eSRiver Riddle os << "\n )mlir\", context), std::forward<ConfigsT>(configs)...) {\n";
1199ad64a5cSRiver Riddle
1209ad64a5cSRiver Riddle // Register any native functions used within the pattern.
1219ad64a5cSRiver Riddle StringSet<> registeredNativeFunctions;
1229ad64a5cSRiver Riddle auto checkRegisterNativeFn = [&](StringRef fnName, StringRef fnType) {
1239ad64a5cSRiver Riddle if (!nativeFunctions.count(fnName) ||
1249ad64a5cSRiver Riddle !registeredNativeFunctions.insert(fnName).second)
1259ad64a5cSRiver Riddle return;
1269ad64a5cSRiver Riddle os << " register" << fnType << "Function(\"" << fnName << "\", "
1279ad64a5cSRiver Riddle << fnName << "PDLFn);\n";
1289ad64a5cSRiver Riddle };
1299ad64a5cSRiver Riddle pattern.walk([&](Operation *op) {
1309ad64a5cSRiver Riddle if (auto constraintOp = dyn_cast<pdl::ApplyNativeConstraintOp>(op))
131310c3ee4SRiver Riddle checkRegisterNativeFn(constraintOp.getName(), "Constraint");
1329ad64a5cSRiver Riddle else if (auto rewriteOp = dyn_cast<pdl::ApplyNativeRewriteOp>(op))
133310c3ee4SRiver Riddle checkRegisterNativeFn(rewriteOp.getName(), "Rewrite");
1349ad64a5cSRiver Riddle });
1359ad64a5cSRiver Riddle os << " }\n};\n\n";
1369ad64a5cSRiver Riddle }
1379ad64a5cSRiver Riddle
generateConstraintAndRewrites(const ast::Module & astModule,ModuleOp module,StringSet<> & nativeFunctions)1389ad64a5cSRiver Riddle void CodeGen::generateConstraintAndRewrites(const ast::Module &astModule,
1399ad64a5cSRiver Riddle ModuleOp module,
1409ad64a5cSRiver Riddle StringSet<> &nativeFunctions) {
1419ad64a5cSRiver Riddle // First check to see which constraints and rewrites are actually referenced
1429ad64a5cSRiver Riddle // in the module.
1439ad64a5cSRiver Riddle StringSet<> usedFns;
1449ad64a5cSRiver Riddle module.walk([&](Operation *op) {
1459ad64a5cSRiver Riddle TypeSwitch<Operation *>(op)
1469ad64a5cSRiver Riddle .Case<pdl::ApplyNativeConstraintOp, pdl::ApplyNativeRewriteOp>(
147310c3ee4SRiver Riddle [&](auto op) { usedFns.insert(op.getName()); });
1489ad64a5cSRiver Riddle });
1499ad64a5cSRiver Riddle
1509ad64a5cSRiver Riddle for (const ast::Decl *decl : astModule.getChildren()) {
1519ad64a5cSRiver Riddle TypeSwitch<const ast::Decl *>(decl)
1529ad64a5cSRiver Riddle .Case<ast::UserConstraintDecl, ast::UserRewriteDecl>(
1539ad64a5cSRiver Riddle [&](const auto *decl) {
1549ad64a5cSRiver Riddle // We only generate code for inline native decls that have been
1559ad64a5cSRiver Riddle // referenced.
1569ad64a5cSRiver Riddle if (decl->getCodeBlock() &&
1579ad64a5cSRiver Riddle usedFns.contains(decl->getName().getName()))
1589ad64a5cSRiver Riddle this->generate(decl, nativeFunctions);
1599ad64a5cSRiver Riddle });
1609ad64a5cSRiver Riddle }
1619ad64a5cSRiver Riddle }
1629ad64a5cSRiver Riddle
generate(const ast::UserConstraintDecl * decl,StringSet<> & nativeFunctions)1639ad64a5cSRiver Riddle void CodeGen::generate(const ast::UserConstraintDecl *decl,
1649ad64a5cSRiver Riddle StringSet<> &nativeFunctions) {
1651c2edb02SRiver Riddle return generateConstraintOrRewrite(cast<ast::CallableDecl>(decl),
1661c2edb02SRiver Riddle /*isConstraint=*/true, nativeFunctions);
1679ad64a5cSRiver Riddle }
1689ad64a5cSRiver Riddle
generate(const ast::UserRewriteDecl * decl,StringSet<> & nativeFunctions)1699ad64a5cSRiver Riddle void CodeGen::generate(const ast::UserRewriteDecl *decl,
1709ad64a5cSRiver Riddle StringSet<> &nativeFunctions) {
1711c2edb02SRiver Riddle return generateConstraintOrRewrite(cast<ast::CallableDecl>(decl),
1721c2edb02SRiver Riddle /*isConstraint=*/false, nativeFunctions);
1739ad64a5cSRiver Riddle }
1749ad64a5cSRiver Riddle
getNativeTypeName(ast::Type type)1751c2edb02SRiver Riddle StringRef CodeGen::getNativeTypeName(ast::Type type) {
1769ad64a5cSRiver Riddle return llvm::TypeSwitch<ast::Type, StringRef>(type)
1779ad64a5cSRiver Riddle .Case([&](ast::AttributeType) { return "::mlir::Attribute"; })
1781c2edb02SRiver Riddle .Case([&](ast::OperationType opType) -> StringRef {
1791c2edb02SRiver Riddle // Use the derived Op class when available.
1801c2edb02SRiver Riddle if (const auto *odsOp = opType.getODSOperation())
1811c2edb02SRiver Riddle return odsOp->getNativeClassName();
1829ad64a5cSRiver Riddle return "::mlir::Operation *";
1839ad64a5cSRiver Riddle })
1849ad64a5cSRiver Riddle .Case([&](ast::TypeType) { return "::mlir::Type"; })
1859ad64a5cSRiver Riddle .Case([&](ast::ValueType) { return "::mlir::Value"; })
1869ad64a5cSRiver Riddle .Case([&](ast::TypeRangeType) { return "::mlir::TypeRange"; })
1879ad64a5cSRiver Riddle .Case([&](ast::ValueRangeType) { return "::mlir::ValueRange"; });
1889ad64a5cSRiver Riddle }
1899ad64a5cSRiver Riddle
getNativeTypeName(ast::VariableDecl * decl)1901c2edb02SRiver Riddle StringRef CodeGen::getNativeTypeName(ast::VariableDecl *decl) {
1911c2edb02SRiver Riddle // Try to extract a type name from the variable's constraints.
1921c2edb02SRiver Riddle for (ast::ConstraintRef &cst : decl->getConstraints()) {
1931c2edb02SRiver Riddle if (auto *userCst = dyn_cast<ast::UserConstraintDecl>(cst.constraint)) {
1940a81ace0SKazu Hirata if (std::optional<StringRef> name = userCst->getNativeInputType(0))
1951c2edb02SRiver Riddle return *name;
1961c2edb02SRiver Riddle return getNativeTypeName(userCst->getInputs()[0]);
1971c2edb02SRiver Riddle }
1981c2edb02SRiver Riddle }
1991c2edb02SRiver Riddle
2001c2edb02SRiver Riddle // Otherwise, use the type of the variable.
2011c2edb02SRiver Riddle return getNativeTypeName(decl->getType());
2021c2edb02SRiver Riddle }
2031c2edb02SRiver Riddle
generateConstraintOrRewrite(const ast::CallableDecl * decl,bool isConstraint,StringSet<> & nativeFunctions)2041c2edb02SRiver Riddle void CodeGen::generateConstraintOrRewrite(const ast::CallableDecl *decl,
2051c2edb02SRiver Riddle bool isConstraint,
2061c2edb02SRiver Riddle StringSet<> &nativeFunctions) {
2071c2edb02SRiver Riddle StringRef name = decl->getName()->getName();
2081c2edb02SRiver Riddle nativeFunctions.insert(name);
2091c2edb02SRiver Riddle
2101c2edb02SRiver Riddle os << "static ";
2111c2edb02SRiver Riddle
2121c2edb02SRiver Riddle // TODO: Work out a proper modeling for "optionality".
2131c2edb02SRiver Riddle
2141c2edb02SRiver Riddle // Emit the result type.
2151c2edb02SRiver Riddle // If this is a constraint, we always return a LogicalResult.
2161c2edb02SRiver Riddle // TODO: This will need to change if we allow Constraints to return values as
2171c2edb02SRiver Riddle // well.
2181c2edb02SRiver Riddle if (isConstraint) {
219*db791b27SRamkumar Ramachandra os << "::llvm::LogicalResult";
2201c2edb02SRiver Riddle } else {
2211c2edb02SRiver Riddle // Otherwise, generate a type based on the results of the callable.
2221c2edb02SRiver Riddle // If the callable has explicit results, use those to build the result.
2231c2edb02SRiver Riddle // Otherwise, use the type of the callable.
2241c2edb02SRiver Riddle ArrayRef<ast::VariableDecl *> results = decl->getResults();
2251c2edb02SRiver Riddle if (results.empty()) {
2261c2edb02SRiver Riddle os << "void";
2271c2edb02SRiver Riddle } else if (results.size() == 1) {
2281c2edb02SRiver Riddle os << getNativeTypeName(results[0]);
2291c2edb02SRiver Riddle } else {
2301c2edb02SRiver Riddle os << "std::tuple<";
2311c2edb02SRiver Riddle llvm::interleaveComma(results, os, [&](ast::VariableDecl *result) {
2321c2edb02SRiver Riddle os << getNativeTypeName(result);
2331c2edb02SRiver Riddle });
2341c2edb02SRiver Riddle os << ">";
2351c2edb02SRiver Riddle }
2361c2edb02SRiver Riddle }
2371c2edb02SRiver Riddle
2381c2edb02SRiver Riddle os << " " << name << "PDLFn(::mlir::PatternRewriter &rewriter";
2391c2edb02SRiver Riddle if (!decl->getInputs().empty()) {
2401c2edb02SRiver Riddle os << ", ";
2411c2edb02SRiver Riddle llvm::interleaveComma(decl->getInputs(), os, [&](ast::VariableDecl *input) {
2421c2edb02SRiver Riddle os << getNativeTypeName(input) << " " << input->getName().getName();
2431c2edb02SRiver Riddle });
2441c2edb02SRiver Riddle }
2451c2edb02SRiver Riddle os << ") {\n";
2461c2edb02SRiver Riddle os << " " << decl->getCodeBlock()->trim() << "\n}\n\n";
2479ad64a5cSRiver Riddle }
2489ad64a5cSRiver Riddle
2499ad64a5cSRiver Riddle //===----------------------------------------------------------------------===//
2509ad64a5cSRiver Riddle // CPPGen
2519ad64a5cSRiver Riddle //===----------------------------------------------------------------------===//
2529ad64a5cSRiver Riddle
codegenPDLLToCPP(const ast::Module & astModule,ModuleOp module,raw_ostream & os)2539ad64a5cSRiver Riddle void mlir::pdll::codegenPDLLToCPP(const ast::Module &astModule, ModuleOp module,
2549ad64a5cSRiver Riddle raw_ostream &os) {
2559ad64a5cSRiver Riddle CodeGen codegen(os);
2569ad64a5cSRiver Riddle codegen.generate(astModule, module);
2579ad64a5cSRiver Riddle }
258