xref: /llvm-project/mlir/lib/IR/Dialect.cpp (revision 6ce44266fc2d06dfcbefd8146279473ccada52ca)
19eedf6adSChris Lattner //===- Dialect.cpp - Dialect implementation -------------------------------===//
29eedf6adSChris Lattner //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
69eedf6adSChris Lattner //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
89eedf6adSChris Lattner 
99eedf6adSChris Lattner #include "mlir/IR/Dialect.h"
10d7e89121SAlex Zinenko #include "mlir/IR/BuiltinDialect.h"
11ff6e7cf5SRiver Riddle #include "mlir/IR/Diagnostics.h"
12445cc3f6SRiver Riddle #include "mlir/IR/DialectImplementation.h"
1392a7b108SRiver Riddle #include "mlir/IR/DialectInterface.h"
1484cc1865SNikhil Kalra #include "mlir/IR/DialectRegistry.h"
15ba8424a2SMathieu Fehr #include "mlir/IR/ExtensibleDialect.h"
169eedf6adSChris Lattner #include "mlir/IR/MLIRContext.h"
1792a7b108SRiver Riddle #include "mlir/IR/Operation.h"
1884cc1865SNikhil Kalra #include "mlir/Support/TypeID.h"
19b72e13c2SGeoffrey Martin-Noble #include "llvm/ADT/MapVector.h"
2084cc1865SNikhil Kalra #include "llvm/ADT/SetOperations.h"
2184cc1865SNikhil Kalra #include "llvm/ADT/SmallVector.h"
2284cc1865SNikhil Kalra #include "llvm/ADT/SmallVectorExtras.h"
23b4f033f6SRiver Riddle #include "llvm/ADT/Twine.h"
2434ea608aSAlex Zinenko #include "llvm/Support/Debug.h"
259eedf6adSChris Lattner #include "llvm/Support/ManagedStatic.h"
263b3e11daSRiver Riddle #include "llvm/Support/Regex.h"
2784cc1865SNikhil Kalra #include <memory>
2892a7b108SRiver Riddle 
2934ea608aSAlex Zinenko #define DEBUG_TYPE "dialect"
3034ea608aSAlex Zinenko 
319eedf6adSChris Lattner using namespace mlir;
3292a7b108SRiver Riddle using namespace detail;
3392a7b108SRiver Riddle 
343da51522SAlex Zinenko //===----------------------------------------------------------------------===//
3592a7b108SRiver Riddle // Dialect
3692a7b108SRiver Riddle //===----------------------------------------------------------------------===//
3792a7b108SRiver Riddle 
38575b22b5SMehdi Amini Dialect::Dialect(StringRef name, MLIRContext *context, TypeID id)
39575b22b5SMehdi Amini     : name(name), dialectID(id), context(context) {
40dfc58742SRiver Riddle   assert(isValidNamespace(name) && "invalid dialect namespace");
419eedf6adSChris Lattner }
42f8f723cfSFeng Liu 
43e5639b3fSMehdi Amini Dialect::~Dialect() = default;
44b4f033f6SRiver Riddle 
4554cd6a7eSRiver Riddle /// Verify an attribute from this dialect on the argument at 'argIndex' for
46136ccd49SRiver Riddle /// the region at 'regionIndex' on the given operation. Returns failure if
47136ccd49SRiver Riddle /// the verification failed, success otherwise. This hook may optionally be
48136ccd49SRiver Riddle /// invoked from any operation containing a region.
49136ccd49SRiver Riddle LogicalResult Dialect::verifyRegionArgAttribute(Operation *, unsigned, unsigned,
5054cd6a7eSRiver Riddle                                                 NamedAttribute) {
5154cd6a7eSRiver Riddle   return success();
5254cd6a7eSRiver Riddle }
5354cd6a7eSRiver Riddle 
549c9a7e92SSean Silva /// Verify an attribute from this dialect on the result at 'resultIndex' for
559c9a7e92SSean Silva /// the region at 'regionIndex' on the given operation. Returns failure if
569c9a7e92SSean Silva /// the verification failed, success otherwise. This hook may optionally be
579c9a7e92SSean Silva /// invoked from any operation containing a region.
589c9a7e92SSean Silva LogicalResult Dialect::verifyRegionResultAttribute(Operation *, unsigned,
599c9a7e92SSean Silva                                                    unsigned, NamedAttribute) {
609c9a7e92SSean Silva   return success();
619c9a7e92SSean Silva }
629c9a7e92SSean Silva 
638d5bd823SRiver Riddle /// Parse an attribute registered to this dialect.
642ba4d802SRiver Riddle Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type) const {
652ba4d802SRiver Riddle   parser.emitError(parser.getNameLoc())
662ba4d802SRiver Riddle       << "dialect '" << getNamespace()
678d5bd823SRiver Riddle       << "' provides no attribute parsing hook";
688d5bd823SRiver Riddle   return Attribute();
698d5bd823SRiver Riddle }
708d5bd823SRiver Riddle 
71b4f033f6SRiver Riddle /// Parse a type registered to this dialect.
722ba4d802SRiver Riddle Type Dialect::parseType(DialectAsmParser &parser) const {
73a477fbafSChris Lattner   // If this dialect allows unknown types, then represent this with OpaqueType.
74a477fbafSChris Lattner   if (allowsUnknownTypes()) {
75195730a6SRiver Riddle     StringAttr ns = StringAttr::get(getContext(), getNamespace());
76e6260ad0SRiver Riddle     return OpaqueType::get(ns, parser.getFullSymbolSpec());
77a477fbafSChris Lattner   }
78a477fbafSChris Lattner 
792ba4d802SRiver Riddle   parser.emitError(parser.getNameLoc())
802ba4d802SRiver Riddle       << "dialect '" << getNamespace() << "' provides no type parsing hook";
81b4f033f6SRiver Riddle   return Type();
82b4f033f6SRiver Riddle }
833b3e11daSRiver Riddle 
8422426110SRamkumar Ramachandra std::optional<Dialect::ParseOpHook>
85a0c776fcSMehdi Amini Dialect::getParseOperationHook(StringRef opName) const {
861a36588eSKazu Hirata   return std::nullopt;
87a0c776fcSMehdi Amini }
88a0c776fcSMehdi Amini 
89fd87963eSMehdi Amini llvm::unique_function<void(Operation *, OpAsmPrinter &printer)>
90fd87963eSMehdi Amini Dialect::getOperationPrinter(Operation *op) const {
91a0c776fcSMehdi Amini   assert(op->getDialect() == this &&
92a0c776fcSMehdi Amini          "Dialect hook invoked on non-dialect owned operation");
93fd87963eSMehdi Amini   return nullptr;
94a0c776fcSMehdi Amini }
95a0c776fcSMehdi Amini 
963b3e11daSRiver Riddle /// Utility function that returns if the given string is a valid dialect
97be0a7e9fSMehdi Amini /// namespace
983b3e11daSRiver Riddle bool Dialect::isValidNamespace(StringRef str) {
993b3e11daSRiver Riddle   llvm::Regex dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$");
1003b3e11daSRiver Riddle   return dialectNameRegex.match(str);
1013b3e11daSRiver Riddle }
10292a7b108SRiver Riddle 
10392a7b108SRiver Riddle /// Register a set of dialect interfaces with this dialect instance.
10492a7b108SRiver Riddle void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
105a5ef51d7SRiver Riddle   // Handle the case where the models resolve a promised interface.
106d0e6fd99SFabian Mora   handleAdditionOfUndefinedPromisedInterface(getTypeID(), interface->getID());
107a5ef51d7SRiver Riddle 
10892a7b108SRiver Riddle   auto it = registeredInterfaces.try_emplace(interface->getID(),
10992a7b108SRiver Riddle                                              std::move(interface));
11092a7b108SRiver Riddle   (void)it;
11177eee579SRiver Riddle   LLVM_DEBUG({
11277eee579SRiver Riddle     if (!it.second) {
11377eee579SRiver Riddle       llvm::dbgs() << "[" DEBUG_TYPE
11477eee579SRiver Riddle                       "] repeated interface registration for dialect "
11577eee579SRiver Riddle                    << getNamespace();
11677eee579SRiver Riddle     }
11777eee579SRiver Riddle   });
11892a7b108SRiver Riddle }
11992a7b108SRiver Riddle 
12092a7b108SRiver Riddle //===----------------------------------------------------------------------===//
12192a7b108SRiver Riddle // Dialect Interface
12292a7b108SRiver Riddle //===----------------------------------------------------------------------===//
12392a7b108SRiver Riddle 
124e5639b3fSMehdi Amini DialectInterface::~DialectInterface() = default;
12592a7b108SRiver Riddle 
12602c2ecb9SRiver Riddle MLIRContext *DialectInterface::getContext() const {
12702c2ecb9SRiver Riddle   return dialect->getContext();
12802c2ecb9SRiver Riddle }
12902c2ecb9SRiver Riddle 
13092a7b108SRiver Riddle DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
13191ae1f5dSMehdi Amini     MLIRContext *ctx, TypeID interfaceKind, StringRef interfaceName) {
132f9dc2b70SMehdi Amini   for (auto *dialect : ctx->getLoadedDialects()) {
13391ae1f5dSMehdi Amini #ifndef NDEBUG
134d0e6fd99SFabian Mora     dialect->handleUseOfUndefinedPromisedInterface(
135d0e6fd99SFabian Mora         dialect->getTypeID(), interfaceKind, interfaceName);
13691ae1f5dSMehdi Amini #endif
1377e1af594SRiver Riddle     if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
1385e17730cSRiver Riddle       interfaces.insert(interface);
1397e1af594SRiver Riddle       orderedInterfaces.push_back(interface);
1407e1af594SRiver Riddle     }
1417e1af594SRiver Riddle   }
14292a7b108SRiver Riddle }
14392a7b108SRiver Riddle 
144e5639b3fSMehdi Amini DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() = default;
14592a7b108SRiver Riddle 
14692a7b108SRiver Riddle /// Get the interface for the dialect of given operation, or null if one
14792a7b108SRiver Riddle /// is not registered.
14892a7b108SRiver Riddle const DialectInterface *
14992a7b108SRiver Riddle DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const {
1505e17730cSRiver Riddle   return getInterfaceFor(op->getDialect());
15192a7b108SRiver Riddle }
15277eee579SRiver Riddle 
15377eee579SRiver Riddle //===----------------------------------------------------------------------===//
15477eee579SRiver Riddle // DialectExtension
15577eee579SRiver Riddle //===----------------------------------------------------------------------===//
15677eee579SRiver Riddle 
15777eee579SRiver Riddle DialectExtensionBase::~DialectExtensionBase() = default;
15877eee579SRiver Riddle 
159a5ef51d7SRiver Riddle void dialect_extension_detail::handleUseOfUndefinedPromisedInterface(
160d0e6fd99SFabian Mora     Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID,
161d0e6fd99SFabian Mora     StringRef interfaceName) {
162d0e6fd99SFabian Mora   dialect.handleUseOfUndefinedPromisedInterface(interfaceRequestorID,
163d0e6fd99SFabian Mora                                                 interfaceID, interfaceName);
164a5ef51d7SRiver Riddle }
165a5ef51d7SRiver Riddle 
166a5ef51d7SRiver Riddle void dialect_extension_detail::handleAdditionOfUndefinedPromisedInterface(
167d0e6fd99SFabian Mora     Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID) {
168d0e6fd99SFabian Mora   dialect.handleAdditionOfUndefinedPromisedInterface(interfaceRequestorID,
169d0e6fd99SFabian Mora                                                      interfaceID);
170d0e6fd99SFabian Mora }
171d0e6fd99SFabian Mora 
172d0e6fd99SFabian Mora bool dialect_extension_detail::hasPromisedInterface(Dialect &dialect,
173d0e6fd99SFabian Mora                                                     TypeID interfaceRequestorID,
174d0e6fd99SFabian Mora                                                     TypeID interfaceID) {
175d0e6fd99SFabian Mora   return dialect.hasPromisedInterface(interfaceRequestorID, interfaceID);
176a5ef51d7SRiver Riddle }
177a5ef51d7SRiver Riddle 
17877eee579SRiver Riddle //===----------------------------------------------------------------------===//
17977eee579SRiver Riddle // DialectRegistry
18077eee579SRiver Riddle //===----------------------------------------------------------------------===//
18177eee579SRiver Riddle 
18284cc1865SNikhil Kalra namespace {
18384cc1865SNikhil Kalra template <typename Fn>
18484cc1865SNikhil Kalra void applyExtensionsFn(
18584cc1865SNikhil Kalra     Fn &&applyExtension,
18684cc1865SNikhil Kalra     const llvm::MapVector<TypeID, std::unique_ptr<DialectExtensionBase>>
18784cc1865SNikhil Kalra         &extensions) {
18884cc1865SNikhil Kalra   // Note: Additional extensions may be added while applying an extension.
18984cc1865SNikhil Kalra   // The iterators will be invalidated if extensions are added so we'll keep
19084cc1865SNikhil Kalra   // a copy of the extensions for ourselves.
19184cc1865SNikhil Kalra 
19284cc1865SNikhil Kalra   const auto extractExtension =
19384cc1865SNikhil Kalra       [](const auto &entry) -> DialectExtensionBase * {
19484cc1865SNikhil Kalra     return entry.second.get();
19584cc1865SNikhil Kalra   };
19684cc1865SNikhil Kalra 
19784cc1865SNikhil Kalra   auto startIt = extensions.begin(), endIt = extensions.end();
19884cc1865SNikhil Kalra   size_t count = 0;
19984cc1865SNikhil Kalra   while (startIt != endIt) {
20084cc1865SNikhil Kalra     count += endIt - startIt;
20184cc1865SNikhil Kalra 
20284cc1865SNikhil Kalra     // Grab the subset of extensions we'll apply in this iteration.
20384cc1865SNikhil Kalra     const auto subset =
20484cc1865SNikhil Kalra         llvm::map_to_vector(llvm::make_range(startIt, endIt), extractExtension);
20584cc1865SNikhil Kalra 
20684cc1865SNikhil Kalra     for (const auto *ext : subset)
20784cc1865SNikhil Kalra       applyExtension(*ext);
20884cc1865SNikhil Kalra 
20984cc1865SNikhil Kalra     // Book-keep for the next iteration.
21084cc1865SNikhil Kalra     startIt = extensions.begin() + count;
21184cc1865SNikhil Kalra     endIt = extensions.end();
21284cc1865SNikhil Kalra   }
21384cc1865SNikhil Kalra }
21484cc1865SNikhil Kalra } // namespace
21584cc1865SNikhil Kalra 
21677eee579SRiver Riddle DialectRegistry::DialectRegistry() { insert<BuiltinDialect>(); }
21777eee579SRiver Riddle 
21877eee579SRiver Riddle DialectAllocatorFunctionRef
21977eee579SRiver Riddle DialectRegistry::getDialectAllocator(StringRef name) const {
220*6ce44266SKazu Hirata   auto it = registry.find(name);
22177eee579SRiver Riddle   if (it == registry.end())
22277eee579SRiver Riddle     return nullptr;
22377eee579SRiver Riddle   return it->second.second;
22477eee579SRiver Riddle }
22577eee579SRiver Riddle 
22677eee579SRiver Riddle void DialectRegistry::insert(TypeID typeID, StringRef name,
22777eee579SRiver Riddle                              const DialectAllocatorFunction &ctor) {
22877eee579SRiver Riddle   auto inserted = registry.insert(
22977eee579SRiver Riddle       std::make_pair(std::string(name), std::make_pair(typeID, ctor)));
23077eee579SRiver Riddle   if (!inserted.second && inserted.first->second.first != typeID) {
23177eee579SRiver Riddle     llvm::report_fatal_error(
23277eee579SRiver Riddle         "Trying to register different dialects for the same namespace: " +
23377eee579SRiver Riddle         name);
23477eee579SRiver Riddle   }
23577eee579SRiver Riddle }
23677eee579SRiver Riddle 
237ba8424a2SMathieu Fehr void DialectRegistry::insertDynamic(
238ba8424a2SMathieu Fehr     StringRef name, const DynamicDialectPopulationFunction &ctor) {
239ba8424a2SMathieu Fehr   // This TypeID marks dynamic dialects. We cannot give a TypeID for the
240ba8424a2SMathieu Fehr   // dialect yet, since the TypeID of a dynamic dialect is defined at its
241ba8424a2SMathieu Fehr   // construction.
242ba8424a2SMathieu Fehr   TypeID typeID = TypeID::get<void>();
243ba8424a2SMathieu Fehr 
244ba8424a2SMathieu Fehr   // Create the dialect, and then call ctor, which allocates its components.
245ba8424a2SMathieu Fehr   auto constructor = [nameStr = name.str(), ctor](MLIRContext *ctx) {
246ba8424a2SMathieu Fehr     auto *dynDialect = ctx->getOrLoadDynamicDialect(
247ba8424a2SMathieu Fehr         nameStr, [ctx, ctor](DynamicDialect *dialect) { ctor(ctx, dialect); });
248ba8424a2SMathieu Fehr     assert(dynDialect && "Dynamic dialect creation unexpectedly failed");
249ba8424a2SMathieu Fehr     return dynDialect;
250ba8424a2SMathieu Fehr   };
251ba8424a2SMathieu Fehr 
252ba8424a2SMathieu Fehr   insert(typeID, name, constructor);
253ba8424a2SMathieu Fehr }
254ba8424a2SMathieu Fehr 
25577eee579SRiver Riddle void DialectRegistry::applyExtensions(Dialect *dialect) const {
25677eee579SRiver Riddle   MLIRContext *ctx = dialect->getContext();
25777eee579SRiver Riddle   StringRef dialectName = dialect->getNamespace();
25877eee579SRiver Riddle 
25977eee579SRiver Riddle   // Functor used to try to apply the given extension.
26077eee579SRiver Riddle   auto applyExtension = [&](const DialectExtensionBase &extension) {
26177eee579SRiver Riddle     ArrayRef<StringRef> dialectNames = extension.getRequiredDialects();
2624529797aSMehdi Amini     // An empty set is equivalent to always invoke.
2634529797aSMehdi Amini     if (dialectNames.empty()) {
2644529797aSMehdi Amini       extension.apply(ctx, dialect);
2654529797aSMehdi Amini       return;
2664529797aSMehdi Amini     }
26777eee579SRiver Riddle 
26877eee579SRiver Riddle     // Handle the simple case of a single dialect name. In this case, the
26977eee579SRiver Riddle     // required dialect should be the current dialect.
27077eee579SRiver Riddle     if (dialectNames.size() == 1) {
27177eee579SRiver Riddle       if (dialectNames.front() == dialectName)
27277eee579SRiver Riddle         extension.apply(ctx, dialect);
27377eee579SRiver Riddle       return;
27477eee579SRiver Riddle     }
27577eee579SRiver Riddle 
27677eee579SRiver Riddle     // Otherwise, check to see if this extension requires this dialect.
27777eee579SRiver Riddle     const StringRef *nameIt = llvm::find(dialectNames, dialectName);
27877eee579SRiver Riddle     if (nameIt == dialectNames.end())
27977eee579SRiver Riddle       return;
28077eee579SRiver Riddle 
28177eee579SRiver Riddle     // If it does, ensure that all of the other required dialects have been
28277eee579SRiver Riddle     // loaded.
28377eee579SRiver Riddle     SmallVector<Dialect *> requiredDialects;
28477eee579SRiver Riddle     requiredDialects.reserve(dialectNames.size());
28577eee579SRiver Riddle     for (auto it = dialectNames.begin(), e = dialectNames.end(); it != e;
28677eee579SRiver Riddle          ++it) {
28777eee579SRiver Riddle       // The current dialect is known to be loaded.
28877eee579SRiver Riddle       if (it == nameIt) {
28977eee579SRiver Riddle         requiredDialects.push_back(dialect);
29077eee579SRiver Riddle         continue;
29177eee579SRiver Riddle       }
29277eee579SRiver Riddle       // Otherwise, check if it is loaded.
29377eee579SRiver Riddle       Dialect *loadedDialect = ctx->getLoadedDialect(*it);
29477eee579SRiver Riddle       if (!loadedDialect)
29577eee579SRiver Riddle         return;
29677eee579SRiver Riddle       requiredDialects.push_back(loadedDialect);
29777eee579SRiver Riddle     }
29877eee579SRiver Riddle     extension.apply(ctx, requiredDialects);
29977eee579SRiver Riddle   };
30077eee579SRiver Riddle 
30184cc1865SNikhil Kalra   applyExtensionsFn(applyExtension, extensions);
30277eee579SRiver Riddle }
30377eee579SRiver Riddle 
30477eee579SRiver Riddle void DialectRegistry::applyExtensions(MLIRContext *ctx) const {
30577eee579SRiver Riddle   // Functor used to try to apply the given extension.
30677eee579SRiver Riddle   auto applyExtension = [&](const DialectExtensionBase &extension) {
30777eee579SRiver Riddle     ArrayRef<StringRef> dialectNames = extension.getRequiredDialects();
3084529797aSMehdi Amini     if (dialectNames.empty()) {
3094529797aSMehdi Amini       auto loadedDialects = ctx->getLoadedDialects();
3104529797aSMehdi Amini       extension.apply(ctx, loadedDialects);
3114529797aSMehdi Amini       return;
3124529797aSMehdi Amini     }
31377eee579SRiver Riddle 
31477eee579SRiver Riddle     // Check to see if all of the dialects for this extension are loaded.
31577eee579SRiver Riddle     SmallVector<Dialect *> requiredDialects;
31677eee579SRiver Riddle     requiredDialects.reserve(dialectNames.size());
31777eee579SRiver Riddle     for (StringRef dialectName : dialectNames) {
31877eee579SRiver Riddle       Dialect *loadedDialect = ctx->getLoadedDialect(dialectName);
31977eee579SRiver Riddle       if (!loadedDialect)
32077eee579SRiver Riddle         return;
32177eee579SRiver Riddle       requiredDialects.push_back(loadedDialect);
32277eee579SRiver Riddle     }
32377eee579SRiver Riddle     extension.apply(ctx, requiredDialects);
32477eee579SRiver Riddle   };
32577eee579SRiver Riddle 
32684cc1865SNikhil Kalra   applyExtensionsFn(applyExtension, extensions);
32777eee579SRiver Riddle }
3280f304ef0SRiver Riddle 
3290f304ef0SRiver Riddle bool DialectRegistry::isSubsetOf(const DialectRegistry &rhs) const {
33084cc1865SNikhil Kalra   // Check that all extension keys are present in 'rhs'.
33184cc1865SNikhil Kalra   const auto hasExtension = [&](const auto &key) {
33284cc1865SNikhil Kalra     return rhs.extensions.contains(key);
33384cc1865SNikhil Kalra   };
33484cc1865SNikhil Kalra   if (!llvm::all_of(make_first_range(extensions), hasExtension))
3350f304ef0SRiver Riddle     return false;
33684cc1865SNikhil Kalra 
3370f304ef0SRiver Riddle   // Check that the current dialects fully overlap with the dialects in 'rhs'.
3380f304ef0SRiver Riddle   return llvm::all_of(
3390f304ef0SRiver Riddle       registry, [&](const auto &it) { return rhs.registry.count(it.first); });
3400f304ef0SRiver Riddle }
341