//===- DialectGen.cpp - MLIR dialect definitions generator ----------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // DialectGen uses the description of dialects to generate C++ definitions. // //===----------------------------------------------------------------------===// #include "DialectGenUtilities.h" #include "mlir/TableGen/Class.h" #include "mlir/TableGen/CodeGenHelpers.h" #include "mlir/TableGen/Format.h" #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/Interfaces.h" #include "mlir/TableGen/Operator.h" #include "mlir/TableGen/Trait.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Signals.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" #define DEBUG_TYPE "mlir-tblgen-opdefgen" using namespace mlir; using namespace mlir::tblgen; using llvm::Record; using llvm::RecordKeeper; static llvm::cl::OptionCategory dialectGenCat("Options for -gen-dialect-*"); llvm::cl::opt selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"), llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated); /// Utility iterator used for filtering records for a specific dialect. namespace { using DialectFilterIterator = llvm::filter_iterator::iterator, std::function>; } // namespace static void populateDiscardableAttributes( Dialect &dialect, const llvm::DagInit *discardableAttrDag, SmallVector> &discardableAttributes) { for (int i : llvm::seq(0, discardableAttrDag->getNumArgs())) { const llvm::Init *arg = discardableAttrDag->getArg(i); StringRef givenName = discardableAttrDag->getArgNameStr(i); if (givenName.empty()) PrintFatalError(dialect.getDef()->getLoc(), "discardable attributes must be named"); discardableAttributes.push_back( {givenName.str(), arg->getAsUnquotedString()}); } } /// Given a set of records for a T, filter the ones that correspond to /// the given dialect. template static iterator_range filterForDialect(ArrayRef records, Dialect &dialect) { auto filterFn = [&](const Record *record) { return T(record).getDialect() == dialect; }; return {DialectFilterIterator(records.begin(), records.end(), filterFn), DialectFilterIterator(records.end(), records.end(), filterFn)}; } std::optional tblgen::findDialectToGenerate(ArrayRef dialects) { if (dialects.empty()) { llvm::errs() << "no dialect was found\n"; return std::nullopt; } // Select the dialect to gen for. if (dialects.size() == 1 && selectedDialect.getNumOccurrences() == 0) return dialects.front(); if (selectedDialect.getNumOccurrences() == 0) { llvm::errs() << "when more than 1 dialect is present, one must be selected " "via '-dialect'\n"; return std::nullopt; } const auto *dialectIt = llvm::find_if(dialects, [](const Dialect &dialect) { return dialect.getName() == selectedDialect; }); if (dialectIt == dialects.end()) { llvm::errs() << "selected dialect with '-dialect' does not exist\n"; return std::nullopt; } return *dialectIt; } //===----------------------------------------------------------------------===// // GEN: Dialect declarations //===----------------------------------------------------------------------===// /// The code block for the start of a dialect class declaration. /// /// {0}: The name of the dialect class. /// {1}: The dialect namespace. /// {2}: The dialect parent class. static const char *const dialectDeclBeginStr = R"( class {0} : public ::mlir::{2} { explicit {0}(::mlir::MLIRContext *context); void initialize(); friend class ::mlir::MLIRContext; public: ~{0}() override; static constexpr ::llvm::StringLiteral getDialectNamespace() { return ::llvm::StringLiteral("{1}"); } )"; /// Registration for a single dependent dialect: to be inserted in the ctor /// above for each dependent dialect. const char *const dialectRegistrationTemplate = "getContext()->loadDialect<{0}>();"; /// The code block for the attribute parser/printer hooks. static const char *const attrParserDecl = R"( /// Parse an attribute registered to this dialect. ::mlir::Attribute parseAttribute(::mlir::DialectAsmParser &parser, ::mlir::Type type) const override; /// Print an attribute registered to this dialect. void printAttribute(::mlir::Attribute attr, ::mlir::DialectAsmPrinter &os) const override; )"; /// The code block for the type parser/printer hooks. static const char *const typeParserDecl = R"( /// Parse a type registered to this dialect. ::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override; /// Print a type registered to this dialect. void printType(::mlir::Type type, ::mlir::DialectAsmPrinter &os) const override; )"; /// The code block for the canonicalization pattern registration hook. static const char *const canonicalizerDecl = R"( /// Register canonicalization patterns. void getCanonicalizationPatterns( ::mlir::RewritePatternSet &results) const override; )"; /// The code block for the constant materializer hook. static const char *const constantMaterializerDecl = R"( /// Materialize a single constant operation from a given attribute value with /// the desired resultant type. ::mlir::Operation *materializeConstant(::mlir::OpBuilder &builder, ::mlir::Attribute value, ::mlir::Type type, ::mlir::Location loc) override; )"; /// The code block for the operation attribute verifier hook. static const char *const opAttrVerifierDecl = R"( /// Provides a hook for verifying dialect attributes attached to the given /// op. ::llvm::LogicalResult verifyOperationAttribute( ::mlir::Operation *op, ::mlir::NamedAttribute attribute) override; )"; /// The code block for the region argument attribute verifier hook. static const char *const regionArgAttrVerifierDecl = R"( /// Provides a hook for verifying dialect attributes attached to the given /// op's region argument. ::llvm::LogicalResult verifyRegionArgAttribute( ::mlir::Operation *op, unsigned regionIndex, unsigned argIndex, ::mlir::NamedAttribute attribute) override; )"; /// The code block for the region result attribute verifier hook. static const char *const regionResultAttrVerifierDecl = R"( /// Provides a hook for verifying dialect attributes attached to the given /// op's region result. ::llvm::LogicalResult verifyRegionResultAttribute( ::mlir::Operation *op, unsigned regionIndex, unsigned resultIndex, ::mlir::NamedAttribute attribute) override; )"; /// The code block for the op interface fallback hook. static const char *const operationInterfaceFallbackDecl = R"( /// Provides a hook for op interface. void *getRegisteredInterfaceForOp(mlir::TypeID interfaceID, mlir::OperationName opName) override; )"; /// The code block for the discardable attribute helper. static const char *const discardableAttrHelperDecl = R"( /// Helper to manage the discardable attribute `{1}`. class {0}AttrHelper {{ ::mlir::StringAttr name; public: static constexpr ::llvm::StringLiteral getNameStr() {{ return "{4}.{1}"; } constexpr ::mlir::StringAttr getName() {{ return name; } {0}AttrHelper(::mlir::MLIRContext *ctx) : name(::mlir::StringAttr::get(ctx, getNameStr())) {{} {2} getAttr(::mlir::Operation *op) {{ return op->getAttrOfType<{2}>(name); } void setAttr(::mlir::Operation *op, {2} val) {{ op->setAttr(name, val); } bool isAttrPresent(::mlir::Operation *op) {{ return op->hasAttrOfType<{2}>(name); } void removeAttr(::mlir::Operation *op) {{ assert(op->hasAttrOfType<{2}>(name)); op->removeAttr(name); } }; {0}AttrHelper get{0}AttrHelper() { return {3}AttrName; } private: {0}AttrHelper {3}AttrName; public: )"; /// Generate the declaration for the given dialect class. static void emitDialectDecl(Dialect &dialect, raw_ostream &os) { // Emit all nested namespaces. { NamespaceEmitter nsEmitter(os, dialect); // Emit the start of the decl. std::string cppName = dialect.getCppClassName(); StringRef superClassName = dialect.isExtensible() ? "ExtensibleDialect" : "Dialect"; os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName(), superClassName); // If the dialect requested the default attribute printer and parser, emit // the declarations for the hooks. if (dialect.useDefaultAttributePrinterParser()) os << attrParserDecl; // If the dialect requested the default type printer and parser, emit the // delcarations for the hooks. if (dialect.useDefaultTypePrinterParser()) os << typeParserDecl; // Add the decls for the various features of the dialect. if (dialect.hasCanonicalizer()) os << canonicalizerDecl; if (dialect.hasConstantMaterializer()) os << constantMaterializerDecl; if (dialect.hasOperationAttrVerify()) os << opAttrVerifierDecl; if (dialect.hasRegionArgAttrVerify()) os << regionArgAttrVerifierDecl; if (dialect.hasRegionResultAttrVerify()) os << regionResultAttrVerifierDecl; if (dialect.hasOperationInterfaceFallback()) os << operationInterfaceFallbackDecl; const llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes(); SmallVector> discardableAttributes; populateDiscardableAttributes(dialect, discardableAttrDag, discardableAttributes); for (const auto &attrPair : discardableAttributes) { std::string camelNameUpper = llvm::convertToCamelFromSnakeCase( attrPair.first, /*capitalizeFirst=*/true); std::string camelName = llvm::convertToCamelFromSnakeCase( attrPair.first, /*capitalizeFirst=*/false); os << llvm::formatv(discardableAttrHelperDecl, camelNameUpper, attrPair.first, attrPair.second, camelName, dialect.getName()); } if (std::optional extraDecl = dialect.getExtraClassDeclaration()) os << *extraDecl; // End the dialect decl. os << "};\n"; } if (!dialect.getCppNamespace().empty()) os << "MLIR_DECLARE_EXPLICIT_TYPE_ID(" << dialect.getCppNamespace() << "::" << dialect.getCppClassName() << ")\n"; } static bool emitDialectDecls(const RecordKeeper &records, raw_ostream &os) { emitSourceFileHeader("Dialect Declarations", os, records); auto dialectDefs = records.getAllDerivedDefinitions("Dialect"); if (dialectDefs.empty()) return false; SmallVector dialects(dialectDefs.begin(), dialectDefs.end()); std::optional dialect = findDialectToGenerate(dialects); if (!dialect) return true; emitDialectDecl(*dialect, os); return false; } //===----------------------------------------------------------------------===// // GEN: Dialect definitions //===----------------------------------------------------------------------===// /// The code block to generate a dialect constructor definition. /// /// {0}: The name of the dialect class. /// {1}: Initialization code that is emitted in the ctor body before calling /// initialize(), such as dependent dialect registration. /// {2}: The dialect parent class. /// {3}: Extra members to initialize static const char *const dialectConstructorStr = R"( {0}::{0}(::mlir::MLIRContext *context) : ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>()) {3} {{ {1} initialize(); } )"; /// The code block to generate a default destructor definition. /// /// {0}: The name of the dialect class. static const char *const dialectDestructorStr = R"( {0}::~{0}() = default; )"; static void emitDialectDef(Dialect &dialect, const RecordKeeper &records, raw_ostream &os) { std::string cppClassName = dialect.getCppClassName(); // Emit the TypeID explicit specializations to have a single symbol def. if (!dialect.getCppNamespace().empty()) os << "MLIR_DEFINE_EXPLICIT_TYPE_ID(" << dialect.getCppNamespace() << "::" << cppClassName << ")\n"; // Emit all nested namespaces. NamespaceEmitter nsEmitter(os, dialect); /// Build the list of dependent dialects. std::string dependentDialectRegistrations; { llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations); llvm::interleave( dialect.getDependentDialects(), dialectsOs, [&](StringRef dependentDialect) { dialectsOs << llvm::formatv(dialectRegistrationTemplate, dependentDialect); }, "\n "); } // Emit the constructor and destructor. StringRef superClassName = dialect.isExtensible() ? "ExtensibleDialect" : "Dialect"; const llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes(); SmallVector> discardableAttributes; populateDiscardableAttributes(dialect, discardableAttrDag, discardableAttributes); std::string discardableAttributesInit; for (const auto &attrPair : discardableAttributes) { std::string camelName = llvm::convertToCamelFromSnakeCase( attrPair.first, /*capitalizeFirst=*/false); llvm::raw_string_ostream os(discardableAttributesInit); os << ", " << camelName << "AttrName(context)"; } os << llvm::formatv(dialectConstructorStr, cppClassName, dependentDialectRegistrations, superClassName, discardableAttributesInit); if (!dialect.hasNonDefaultDestructor()) os << llvm::formatv(dialectDestructorStr, cppClassName); } static bool emitDialectDefs(const RecordKeeper &records, raw_ostream &os) { emitSourceFileHeader("Dialect Definitions", os, records); auto dialectDefs = records.getAllDerivedDefinitions("Dialect"); if (dialectDefs.empty()) return false; SmallVector dialects(dialectDefs.begin(), dialectDefs.end()); std::optional dialect = findDialectToGenerate(dialects); if (!dialect) return true; emitDialectDef(*dialect, records, os); return false; } //===----------------------------------------------------------------------===// // GEN: Dialect registration hooks //===----------------------------------------------------------------------===// static mlir::GenRegistration genDialectDecls("gen-dialect-decls", "Generate dialect declarations", [](const RecordKeeper &records, raw_ostream &os) { return emitDialectDecls(records, os); }); static mlir::GenRegistration genDialectDefs("gen-dialect-defs", "Generate dialect definitions", [](const RecordKeeper &records, raw_ostream &os) { return emitDialectDefs(records, os); });