xref: /llvm-project/mlir/lib/IR/Dialect.cpp (revision 012b148dd9402c6ad58805701863d31fb506caff)
1 //===- Dialect.cpp - Dialect implementation -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Dialect.h"
10 #include "mlir/IR/BuiltinDialect.h"
11 #include "mlir/IR/Diagnostics.h"
12 #include "mlir/IR/DialectImplementation.h"
13 #include "mlir/IR/DialectInterface.h"
14 #include "mlir/IR/ExtensibleDialect.h"
15 #include "mlir/IR/MLIRContext.h"
16 #include "mlir/IR/Operation.h"
17 #include "llvm/ADT/MapVector.h"
18 #include "llvm/ADT/Twine.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/ManagedStatic.h"
21 #include "llvm/Support/Regex.h"
22 
23 #define DEBUG_TYPE "dialect"
24 
25 using namespace mlir;
26 using namespace detail;
27 
28 //===----------------------------------------------------------------------===//
29 // Dialect
30 //===----------------------------------------------------------------------===//
31 
32 Dialect::Dialect(StringRef name, MLIRContext *context, TypeID id)
33     : name(name), dialectID(id), context(context) {
34   assert(isValidNamespace(name) && "invalid dialect namespace");
35 }
36 
37 Dialect::~Dialect() = default;
38 
39 /// Verify an attribute from this dialect on the argument at 'argIndex' for
40 /// the region at 'regionIndex' on the given operation. Returns failure if
41 /// the verification failed, success otherwise. This hook may optionally be
42 /// invoked from any operation containing a region.
43 LogicalResult Dialect::verifyRegionArgAttribute(Operation *, unsigned, unsigned,
44                                                 NamedAttribute) {
45   return success();
46 }
47 
48 /// Verify an attribute from this dialect on the result at 'resultIndex' for
49 /// the region at 'regionIndex' on the given operation. Returns failure if
50 /// the verification failed, success otherwise. This hook may optionally be
51 /// invoked from any operation containing a region.
52 LogicalResult Dialect::verifyRegionResultAttribute(Operation *, unsigned,
53                                                    unsigned, NamedAttribute) {
54   return success();
55 }
56 
57 /// Parse an attribute registered to this dialect.
58 Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type) const {
59   parser.emitError(parser.getNameLoc())
60       << "dialect '" << getNamespace()
61       << "' provides no attribute parsing hook";
62   return Attribute();
63 }
64 
65 /// Parse a type registered to this dialect.
66 Type Dialect::parseType(DialectAsmParser &parser) const {
67   // If this dialect allows unknown types, then represent this with OpaqueType.
68   if (allowsUnknownTypes()) {
69     StringAttr ns = StringAttr::get(getContext(), getNamespace());
70     return OpaqueType::get(ns, parser.getFullSymbolSpec());
71   }
72 
73   parser.emitError(parser.getNameLoc())
74       << "dialect '" << getNamespace() << "' provides no type parsing hook";
75   return Type();
76 }
77 
78 std::optional<Dialect::ParseOpHook>
79 Dialect::getParseOperationHook(StringRef opName) const {
80   return std::nullopt;
81 }
82 
83 llvm::unique_function<void(Operation *, OpAsmPrinter &printer)>
84 Dialect::getOperationPrinter(Operation *op) const {
85   assert(op->getDialect() == this &&
86          "Dialect hook invoked on non-dialect owned operation");
87   return nullptr;
88 }
89 
90 /// Utility function that returns if the given string is a valid dialect
91 /// namespace
92 bool Dialect::isValidNamespace(StringRef str) {
93   llvm::Regex dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$");
94   return dialectNameRegex.match(str);
95 }
96 
97 /// Register a set of dialect interfaces with this dialect instance.
98 void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
99   // Handle the case where the models resolve a promised interface.
100   handleAdditionOfUndefinedPromisedInterface(interface->getID());
101 
102   auto it = registeredInterfaces.try_emplace(interface->getID(),
103                                              std::move(interface));
104   (void)it;
105   LLVM_DEBUG({
106     if (!it.second) {
107       llvm::dbgs() << "[" DEBUG_TYPE
108                       "] repeated interface registration for dialect "
109                    << getNamespace();
110     }
111   });
112 }
113 
114 //===----------------------------------------------------------------------===//
115 // Dialect Interface
116 //===----------------------------------------------------------------------===//
117 
118 DialectInterface::~DialectInterface() = default;
119 
120 MLIRContext *DialectInterface::getContext() const {
121   return dialect->getContext();
122 }
123 
124 DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
125     MLIRContext *ctx, TypeID interfaceKind, StringRef interfaceName) {
126   for (auto *dialect : ctx->getLoadedDialects()) {
127 #ifndef NDEBUG
128     dialect->handleUseOfUndefinedPromisedInterface(interfaceKind,
129                                                    interfaceName);
130 #endif
131     if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
132       interfaces.insert(interface);
133       orderedInterfaces.push_back(interface);
134     }
135   }
136 }
137 
138 DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() = default;
139 
140 /// Get the interface for the dialect of given operation, or null if one
141 /// is not registered.
142 const DialectInterface *
143 DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const {
144   return getInterfaceFor(op->getDialect());
145 }
146 
147 //===----------------------------------------------------------------------===//
148 // DialectExtension
149 //===----------------------------------------------------------------------===//
150 
151 DialectExtensionBase::~DialectExtensionBase() = default;
152 
153 void dialect_extension_detail::handleUseOfUndefinedPromisedInterface(
154     Dialect &dialect, TypeID interfaceID, StringRef interfaceName) {
155   dialect.handleUseOfUndefinedPromisedInterface(interfaceID, interfaceName);
156 }
157 
158 void dialect_extension_detail::handleAdditionOfUndefinedPromisedInterface(
159     Dialect &dialect, TypeID interfaceID) {
160   dialect.handleAdditionOfUndefinedPromisedInterface(interfaceID);
161 }
162 
163 //===----------------------------------------------------------------------===//
164 // DialectRegistry
165 //===----------------------------------------------------------------------===//
166 
167 DialectRegistry::DialectRegistry() { insert<BuiltinDialect>(); }
168 
169 DialectAllocatorFunctionRef
170 DialectRegistry::getDialectAllocator(StringRef name) const {
171   auto it = registry.find(name.str());
172   if (it == registry.end())
173     return nullptr;
174   return it->second.second;
175 }
176 
177 void DialectRegistry::insert(TypeID typeID, StringRef name,
178                              const DialectAllocatorFunction &ctor) {
179   auto inserted = registry.insert(
180       std::make_pair(std::string(name), std::make_pair(typeID, ctor)));
181   if (!inserted.second && inserted.first->second.first != typeID) {
182     llvm::report_fatal_error(
183         "Trying to register different dialects for the same namespace: " +
184         name);
185   }
186 }
187 
188 void DialectRegistry::insertDynamic(
189     StringRef name, const DynamicDialectPopulationFunction &ctor) {
190   // This TypeID marks dynamic dialects. We cannot give a TypeID for the
191   // dialect yet, since the TypeID of a dynamic dialect is defined at its
192   // construction.
193   TypeID typeID = TypeID::get<void>();
194 
195   // Create the dialect, and then call ctor, which allocates its components.
196   auto constructor = [nameStr = name.str(), ctor](MLIRContext *ctx) {
197     auto *dynDialect = ctx->getOrLoadDynamicDialect(
198         nameStr, [ctx, ctor](DynamicDialect *dialect) { ctor(ctx, dialect); });
199     assert(dynDialect && "Dynamic dialect creation unexpectedly failed");
200     return dynDialect;
201   };
202 
203   insert(typeID, name, constructor);
204 }
205 
206 void DialectRegistry::applyExtensions(Dialect *dialect) const {
207   MLIRContext *ctx = dialect->getContext();
208   StringRef dialectName = dialect->getNamespace();
209 
210   // Functor used to try to apply the given extension.
211   auto applyExtension = [&](const DialectExtensionBase &extension) {
212     ArrayRef<StringRef> dialectNames = extension.getRequiredDialects();
213     // An empty set is equivalent to always invoke.
214     if (dialectNames.empty()) {
215       extension.apply(ctx, dialect);
216       return;
217     }
218 
219     // Handle the simple case of a single dialect name. In this case, the
220     // required dialect should be the current dialect.
221     if (dialectNames.size() == 1) {
222       if (dialectNames.front() == dialectName)
223         extension.apply(ctx, dialect);
224       return;
225     }
226 
227     // Otherwise, check to see if this extension requires this dialect.
228     const StringRef *nameIt = llvm::find(dialectNames, dialectName);
229     if (nameIt == dialectNames.end())
230       return;
231 
232     // If it does, ensure that all of the other required dialects have been
233     // loaded.
234     SmallVector<Dialect *> requiredDialects;
235     requiredDialects.reserve(dialectNames.size());
236     for (auto it = dialectNames.begin(), e = dialectNames.end(); it != e;
237          ++it) {
238       // The current dialect is known to be loaded.
239       if (it == nameIt) {
240         requiredDialects.push_back(dialect);
241         continue;
242       }
243       // Otherwise, check if it is loaded.
244       Dialect *loadedDialect = ctx->getLoadedDialect(*it);
245       if (!loadedDialect)
246         return;
247       requiredDialects.push_back(loadedDialect);
248     }
249     extension.apply(ctx, requiredDialects);
250   };
251 
252   // Note: Additional extensions may be added while applying an extension.
253   for (int i = 0; i < static_cast<int>(extensions.size()); ++i)
254     applyExtension(*extensions[i]);
255 }
256 
257 void DialectRegistry::applyExtensions(MLIRContext *ctx) const {
258   // Functor used to try to apply the given extension.
259   auto applyExtension = [&](const DialectExtensionBase &extension) {
260     ArrayRef<StringRef> dialectNames = extension.getRequiredDialects();
261     if (dialectNames.empty()) {
262       auto loadedDialects = ctx->getLoadedDialects();
263       extension.apply(ctx, loadedDialects);
264       return;
265     }
266 
267     // Check to see if all of the dialects for this extension are loaded.
268     SmallVector<Dialect *> requiredDialects;
269     requiredDialects.reserve(dialectNames.size());
270     for (StringRef dialectName : dialectNames) {
271       Dialect *loadedDialect = ctx->getLoadedDialect(dialectName);
272       if (!loadedDialect)
273         return;
274       requiredDialects.push_back(loadedDialect);
275     }
276     extension.apply(ctx, requiredDialects);
277   };
278 
279   // Note: Additional extensions may be added while applying an extension.
280   for (int i = 0; i < static_cast<int>(extensions.size()); ++i)
281     applyExtension(*extensions[i]);
282 }
283 
284 bool DialectRegistry::isSubsetOf(const DialectRegistry &rhs) const {
285   // Treat any extensions conservatively.
286   if (!extensions.empty())
287     return false;
288   // Check that the current dialects fully overlap with the dialects in 'rhs'.
289   return llvm::all_of(
290       registry, [&](const auto &it) { return rhs.registry.count(it.first); });
291 }
292