xref: /llvm-project/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp (revision db791b278a414fb6df1acc1799adcf11d8fb9169)
19ad64a5cSRiver Riddle //===- CPPGen.cpp ---------------------------------------------------------===//
29ad64a5cSRiver Riddle //
39ad64a5cSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
49ad64a5cSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
59ad64a5cSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
69ad64a5cSRiver Riddle //
79ad64a5cSRiver Riddle //===----------------------------------------------------------------------===//
89ad64a5cSRiver Riddle //
99ad64a5cSRiver Riddle // This file contains a PDLL generator that outputs C++ code that defines PDLL
109ad64a5cSRiver Riddle // patterns as individual C++ PDLPatternModules for direct use in native code,
119ad64a5cSRiver Riddle // and also defines any native constraints whose bodies were defined in PDLL.
129ad64a5cSRiver Riddle //
139ad64a5cSRiver Riddle //===----------------------------------------------------------------------===//
149ad64a5cSRiver Riddle 
159ad64a5cSRiver Riddle #include "mlir/Tools/PDLL/CodeGen/CPPGen.h"
169ad64a5cSRiver Riddle #include "mlir/Dialect/PDL/IR/PDL.h"
179ad64a5cSRiver Riddle #include "mlir/Dialect/PDL/IR/PDLOps.h"
189ad64a5cSRiver Riddle #include "mlir/IR/BuiltinOps.h"
199ad64a5cSRiver Riddle #include "mlir/Tools/PDLL/AST/Nodes.h"
201c2edb02SRiver Riddle #include "mlir/Tools/PDLL/ODS/Operation.h"
219ad64a5cSRiver Riddle #include "llvm/ADT/SmallString.h"
229ad64a5cSRiver Riddle #include "llvm/ADT/StringExtras.h"
239ad64a5cSRiver Riddle #include "llvm/ADT/StringSet.h"
249ad64a5cSRiver Riddle #include "llvm/ADT/TypeSwitch.h"
259ad64a5cSRiver Riddle #include "llvm/Support/ErrorHandling.h"
269ad64a5cSRiver Riddle #include "llvm/Support/FormatVariadic.h"
27a1fe1f5fSKazu Hirata #include <optional>
289ad64a5cSRiver Riddle 
299ad64a5cSRiver Riddle using namespace mlir;
309ad64a5cSRiver Riddle using namespace mlir::pdll;
319ad64a5cSRiver Riddle 
329ad64a5cSRiver Riddle //===----------------------------------------------------------------------===//
339ad64a5cSRiver Riddle // CodeGen
349ad64a5cSRiver Riddle //===----------------------------------------------------------------------===//
359ad64a5cSRiver Riddle 
369ad64a5cSRiver Riddle namespace {
379ad64a5cSRiver Riddle class CodeGen {
389ad64a5cSRiver Riddle public:
CodeGen(raw_ostream & os)399ad64a5cSRiver Riddle   CodeGen(raw_ostream &os) : os(os) {}
409ad64a5cSRiver Riddle 
419ad64a5cSRiver Riddle   /// Generate C++ code for the given PDL pattern module.
429ad64a5cSRiver Riddle   void generate(const ast::Module &astModule, ModuleOp module);
439ad64a5cSRiver Riddle 
449ad64a5cSRiver Riddle private:
459ad64a5cSRiver Riddle   void generate(pdl::PatternOp pattern, StringRef patternName,
469ad64a5cSRiver Riddle                 StringSet<> &nativeFunctions);
479ad64a5cSRiver Riddle 
489ad64a5cSRiver Riddle   /// Generate C++ code for all user defined constraints and rewrites with
499ad64a5cSRiver Riddle   /// native code.
509ad64a5cSRiver Riddle   void generateConstraintAndRewrites(const ast::Module &astModule,
519ad64a5cSRiver Riddle                                      ModuleOp module,
529ad64a5cSRiver Riddle                                      StringSet<> &nativeFunctions);
539ad64a5cSRiver Riddle   void generate(const ast::UserConstraintDecl *decl,
549ad64a5cSRiver Riddle                 StringSet<> &nativeFunctions);
559ad64a5cSRiver Riddle   void generate(const ast::UserRewriteDecl *decl, StringSet<> &nativeFunctions);
561c2edb02SRiver Riddle   void generateConstraintOrRewrite(const ast::CallableDecl *decl,
571c2edb02SRiver Riddle                                    bool isConstraint,
589ad64a5cSRiver Riddle                                    StringSet<> &nativeFunctions);
599ad64a5cSRiver Riddle 
601c2edb02SRiver Riddle   /// Return the native name for the type of the given type.
611c2edb02SRiver Riddle   StringRef getNativeTypeName(ast::Type type);
621c2edb02SRiver Riddle 
631c2edb02SRiver Riddle   /// Return the native name for the type of the given variable decl.
641c2edb02SRiver Riddle   StringRef getNativeTypeName(ast::VariableDecl *decl);
651c2edb02SRiver Riddle 
669ad64a5cSRiver Riddle   /// The stream to output to.
679ad64a5cSRiver Riddle   raw_ostream &os;
689ad64a5cSRiver Riddle };
699ad64a5cSRiver Riddle } // namespace
709ad64a5cSRiver Riddle 
generate(const ast::Module & astModule,ModuleOp module)719ad64a5cSRiver Riddle void CodeGen::generate(const ast::Module &astModule, ModuleOp module) {
729ad64a5cSRiver Riddle   SetVector<std::string, SmallVector<std::string>, StringSet<>> patternNames;
739ad64a5cSRiver Riddle   StringSet<> nativeFunctions;
749ad64a5cSRiver Riddle 
759ad64a5cSRiver Riddle   // Generate code for any native functions within the module.
769ad64a5cSRiver Riddle   generateConstraintAndRewrites(astModule, module, nativeFunctions);
779ad64a5cSRiver Riddle 
789ad64a5cSRiver Riddle   os << "namespace {\n";
799ad64a5cSRiver Riddle   std::string basePatternName = "GeneratedPDLLPattern";
809ad64a5cSRiver Riddle   int patternIndex = 0;
819ad64a5cSRiver Riddle   for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
829ad64a5cSRiver Riddle     // If the pattern has a name, use that. Otherwise, generate a unique name.
8322426110SRamkumar Ramachandra     if (std::optional<StringRef> patternName = pattern.getSymName()) {
849ad64a5cSRiver Riddle       patternNames.insert(patternName->str());
859ad64a5cSRiver Riddle     } else {
869ad64a5cSRiver Riddle       std::string name;
879ad64a5cSRiver Riddle       do {
889ad64a5cSRiver Riddle         name = (basePatternName + Twine(patternIndex++)).str();
899ad64a5cSRiver Riddle       } while (!patternNames.insert(name));
909ad64a5cSRiver Riddle     }
919ad64a5cSRiver Riddle 
929ad64a5cSRiver Riddle     generate(pattern, patternNames.back(), nativeFunctions);
939ad64a5cSRiver Riddle   }
949ad64a5cSRiver Riddle   os << "} // end namespace\n\n";
959ad64a5cSRiver Riddle 
969ad64a5cSRiver Riddle   // Emit function to add the generated matchers to the pattern list.
978c66344eSRiver Riddle   os << "template <typename... ConfigsT>\n"
988c66344eSRiver Riddle         "static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns("
998c66344eSRiver Riddle         "::mlir::RewritePatternSet &patterns, ConfigsT &&...configs) {\n";
1009ad64a5cSRiver Riddle   for (const auto &name : patternNames)
1018c66344eSRiver Riddle     os << "  patterns.add<" << name
1028c66344eSRiver Riddle        << ">(patterns.getContext(), configs...);\n";
1039ad64a5cSRiver Riddle   os << "}\n";
1049ad64a5cSRiver Riddle }
1059ad64a5cSRiver Riddle 
generate(pdl::PatternOp pattern,StringRef patternName,StringSet<> & nativeFunctions)1069ad64a5cSRiver Riddle void CodeGen::generate(pdl::PatternOp pattern, StringRef patternName,
1079ad64a5cSRiver Riddle                        StringSet<> &nativeFunctions) {
1089ad64a5cSRiver Riddle   const char *patternClassStartStr = R"(
1099ad64a5cSRiver Riddle struct {0} : ::mlir::PDLPatternModule {{
1108c66344eSRiver Riddle   template <typename... ConfigsT>
1118c66344eSRiver Riddle   {0}(::mlir::MLIRContext *context, ConfigsT &&...configs)
1129ad64a5cSRiver Riddle     : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>(
1139ad64a5cSRiver Riddle )";
1149ad64a5cSRiver Riddle   os << llvm::formatv(patternClassStartStr, patternName);
1159ad64a5cSRiver Riddle 
1169ad64a5cSRiver Riddle   os << "R\"mlir(";
1179ad64a5cSRiver Riddle   pattern->print(os, OpPrintingFlags().enableDebugInfo());
1188c66344eSRiver Riddle   os << "\n    )mlir\", context), std::forward<ConfigsT>(configs)...) {\n";
1199ad64a5cSRiver Riddle 
1209ad64a5cSRiver Riddle   // Register any native functions used within the pattern.
1219ad64a5cSRiver Riddle   StringSet<> registeredNativeFunctions;
1229ad64a5cSRiver Riddle   auto checkRegisterNativeFn = [&](StringRef fnName, StringRef fnType) {
1239ad64a5cSRiver Riddle     if (!nativeFunctions.count(fnName) ||
1249ad64a5cSRiver Riddle         !registeredNativeFunctions.insert(fnName).second)
1259ad64a5cSRiver Riddle       return;
1269ad64a5cSRiver Riddle     os << "    register" << fnType << "Function(\"" << fnName << "\", "
1279ad64a5cSRiver Riddle        << fnName << "PDLFn);\n";
1289ad64a5cSRiver Riddle   };
1299ad64a5cSRiver Riddle   pattern.walk([&](Operation *op) {
1309ad64a5cSRiver Riddle     if (auto constraintOp = dyn_cast<pdl::ApplyNativeConstraintOp>(op))
131310c3ee4SRiver Riddle       checkRegisterNativeFn(constraintOp.getName(), "Constraint");
1329ad64a5cSRiver Riddle     else if (auto rewriteOp = dyn_cast<pdl::ApplyNativeRewriteOp>(op))
133310c3ee4SRiver Riddle       checkRegisterNativeFn(rewriteOp.getName(), "Rewrite");
1349ad64a5cSRiver Riddle   });
1359ad64a5cSRiver Riddle   os << "  }\n};\n\n";
1369ad64a5cSRiver Riddle }
1379ad64a5cSRiver Riddle 
generateConstraintAndRewrites(const ast::Module & astModule,ModuleOp module,StringSet<> & nativeFunctions)1389ad64a5cSRiver Riddle void CodeGen::generateConstraintAndRewrites(const ast::Module &astModule,
1399ad64a5cSRiver Riddle                                             ModuleOp module,
1409ad64a5cSRiver Riddle                                             StringSet<> &nativeFunctions) {
1419ad64a5cSRiver Riddle   // First check to see which constraints and rewrites are actually referenced
1429ad64a5cSRiver Riddle   // in the module.
1439ad64a5cSRiver Riddle   StringSet<> usedFns;
1449ad64a5cSRiver Riddle   module.walk([&](Operation *op) {
1459ad64a5cSRiver Riddle     TypeSwitch<Operation *>(op)
1469ad64a5cSRiver Riddle         .Case<pdl::ApplyNativeConstraintOp, pdl::ApplyNativeRewriteOp>(
147310c3ee4SRiver Riddle             [&](auto op) { usedFns.insert(op.getName()); });
1489ad64a5cSRiver Riddle   });
1499ad64a5cSRiver Riddle 
1509ad64a5cSRiver Riddle   for (const ast::Decl *decl : astModule.getChildren()) {
1519ad64a5cSRiver Riddle     TypeSwitch<const ast::Decl *>(decl)
1529ad64a5cSRiver Riddle         .Case<ast::UserConstraintDecl, ast::UserRewriteDecl>(
1539ad64a5cSRiver Riddle             [&](const auto *decl) {
1549ad64a5cSRiver Riddle               // We only generate code for inline native decls that have been
1559ad64a5cSRiver Riddle               // referenced.
1569ad64a5cSRiver Riddle               if (decl->getCodeBlock() &&
1579ad64a5cSRiver Riddle                   usedFns.contains(decl->getName().getName()))
1589ad64a5cSRiver Riddle                 this->generate(decl, nativeFunctions);
1599ad64a5cSRiver Riddle             });
1609ad64a5cSRiver Riddle   }
1619ad64a5cSRiver Riddle }
1629ad64a5cSRiver Riddle 
generate(const ast::UserConstraintDecl * decl,StringSet<> & nativeFunctions)1639ad64a5cSRiver Riddle void CodeGen::generate(const ast::UserConstraintDecl *decl,
1649ad64a5cSRiver Riddle                        StringSet<> &nativeFunctions) {
1651c2edb02SRiver Riddle   return generateConstraintOrRewrite(cast<ast::CallableDecl>(decl),
1661c2edb02SRiver Riddle                                      /*isConstraint=*/true, nativeFunctions);
1679ad64a5cSRiver Riddle }
1689ad64a5cSRiver Riddle 
generate(const ast::UserRewriteDecl * decl,StringSet<> & nativeFunctions)1699ad64a5cSRiver Riddle void CodeGen::generate(const ast::UserRewriteDecl *decl,
1709ad64a5cSRiver Riddle                        StringSet<> &nativeFunctions) {
1711c2edb02SRiver Riddle   return generateConstraintOrRewrite(cast<ast::CallableDecl>(decl),
1721c2edb02SRiver Riddle                                      /*isConstraint=*/false, nativeFunctions);
1739ad64a5cSRiver Riddle }
1749ad64a5cSRiver Riddle 
getNativeTypeName(ast::Type type)1751c2edb02SRiver Riddle StringRef CodeGen::getNativeTypeName(ast::Type type) {
1769ad64a5cSRiver Riddle   return llvm::TypeSwitch<ast::Type, StringRef>(type)
1779ad64a5cSRiver Riddle       .Case([&](ast::AttributeType) { return "::mlir::Attribute"; })
1781c2edb02SRiver Riddle       .Case([&](ast::OperationType opType) -> StringRef {
1791c2edb02SRiver Riddle         // Use the derived Op class when available.
1801c2edb02SRiver Riddle         if (const auto *odsOp = opType.getODSOperation())
1811c2edb02SRiver Riddle           return odsOp->getNativeClassName();
1829ad64a5cSRiver Riddle         return "::mlir::Operation *";
1839ad64a5cSRiver Riddle       })
1849ad64a5cSRiver Riddle       .Case([&](ast::TypeType) { return "::mlir::Type"; })
1859ad64a5cSRiver Riddle       .Case([&](ast::ValueType) { return "::mlir::Value"; })
1869ad64a5cSRiver Riddle       .Case([&](ast::TypeRangeType) { return "::mlir::TypeRange"; })
1879ad64a5cSRiver Riddle       .Case([&](ast::ValueRangeType) { return "::mlir::ValueRange"; });
1889ad64a5cSRiver Riddle }
1899ad64a5cSRiver Riddle 
getNativeTypeName(ast::VariableDecl * decl)1901c2edb02SRiver Riddle StringRef CodeGen::getNativeTypeName(ast::VariableDecl *decl) {
1911c2edb02SRiver Riddle   // Try to extract a type name from the variable's constraints.
1921c2edb02SRiver Riddle   for (ast::ConstraintRef &cst : decl->getConstraints()) {
1931c2edb02SRiver Riddle     if (auto *userCst = dyn_cast<ast::UserConstraintDecl>(cst.constraint)) {
1940a81ace0SKazu Hirata       if (std::optional<StringRef> name = userCst->getNativeInputType(0))
1951c2edb02SRiver Riddle         return *name;
1961c2edb02SRiver Riddle       return getNativeTypeName(userCst->getInputs()[0]);
1971c2edb02SRiver Riddle     }
1981c2edb02SRiver Riddle   }
1991c2edb02SRiver Riddle 
2001c2edb02SRiver Riddle   // Otherwise, use the type of the variable.
2011c2edb02SRiver Riddle   return getNativeTypeName(decl->getType());
2021c2edb02SRiver Riddle }
2031c2edb02SRiver Riddle 
generateConstraintOrRewrite(const ast::CallableDecl * decl,bool isConstraint,StringSet<> & nativeFunctions)2041c2edb02SRiver Riddle void CodeGen::generateConstraintOrRewrite(const ast::CallableDecl *decl,
2051c2edb02SRiver Riddle                                           bool isConstraint,
2061c2edb02SRiver Riddle                                           StringSet<> &nativeFunctions) {
2071c2edb02SRiver Riddle   StringRef name = decl->getName()->getName();
2081c2edb02SRiver Riddle   nativeFunctions.insert(name);
2091c2edb02SRiver Riddle 
2101c2edb02SRiver Riddle   os << "static ";
2111c2edb02SRiver Riddle 
2121c2edb02SRiver Riddle   // TODO: Work out a proper modeling for "optionality".
2131c2edb02SRiver Riddle 
2141c2edb02SRiver Riddle   // Emit the result type.
2151c2edb02SRiver Riddle   // If this is a constraint, we always return a LogicalResult.
2161c2edb02SRiver Riddle   // TODO: This will need to change if we allow Constraints to return values as
2171c2edb02SRiver Riddle   // well.
2181c2edb02SRiver Riddle   if (isConstraint) {
219*db791b27SRamkumar Ramachandra     os << "::llvm::LogicalResult";
2201c2edb02SRiver Riddle   } else {
2211c2edb02SRiver Riddle     // Otherwise, generate a type based on the results of the callable.
2221c2edb02SRiver Riddle     // If the callable has explicit results, use those to build the result.
2231c2edb02SRiver Riddle     // Otherwise, use the type of the callable.
2241c2edb02SRiver Riddle     ArrayRef<ast::VariableDecl *> results = decl->getResults();
2251c2edb02SRiver Riddle     if (results.empty()) {
2261c2edb02SRiver Riddle       os << "void";
2271c2edb02SRiver Riddle     } else if (results.size() == 1) {
2281c2edb02SRiver Riddle       os << getNativeTypeName(results[0]);
2291c2edb02SRiver Riddle     } else {
2301c2edb02SRiver Riddle       os << "std::tuple<";
2311c2edb02SRiver Riddle       llvm::interleaveComma(results, os, [&](ast::VariableDecl *result) {
2321c2edb02SRiver Riddle         os << getNativeTypeName(result);
2331c2edb02SRiver Riddle       });
2341c2edb02SRiver Riddle       os << ">";
2351c2edb02SRiver Riddle     }
2361c2edb02SRiver Riddle   }
2371c2edb02SRiver Riddle 
2381c2edb02SRiver Riddle   os << " " << name << "PDLFn(::mlir::PatternRewriter &rewriter";
2391c2edb02SRiver Riddle   if (!decl->getInputs().empty()) {
2401c2edb02SRiver Riddle     os << ", ";
2411c2edb02SRiver Riddle     llvm::interleaveComma(decl->getInputs(), os, [&](ast::VariableDecl *input) {
2421c2edb02SRiver Riddle       os << getNativeTypeName(input) << " " << input->getName().getName();
2431c2edb02SRiver Riddle     });
2441c2edb02SRiver Riddle   }
2451c2edb02SRiver Riddle   os << ") {\n";
2461c2edb02SRiver Riddle   os << "  " << decl->getCodeBlock()->trim() << "\n}\n\n";
2479ad64a5cSRiver Riddle }
2489ad64a5cSRiver Riddle 
2499ad64a5cSRiver Riddle //===----------------------------------------------------------------------===//
2509ad64a5cSRiver Riddle // CPPGen
2519ad64a5cSRiver Riddle //===----------------------------------------------------------------------===//
2529ad64a5cSRiver Riddle 
codegenPDLLToCPP(const ast::Module & astModule,ModuleOp module,raw_ostream & os)2539ad64a5cSRiver Riddle void mlir::pdll::codegenPDLLToCPP(const ast::Module &astModule, ModuleOp module,
2549ad64a5cSRiver Riddle                                   raw_ostream &os) {
2559ad64a5cSRiver Riddle   CodeGen codegen(os);
2569ad64a5cSRiver Riddle   codegen.generate(astModule, module);
2579ad64a5cSRiver Riddle }
258