xref: /llvm-project/mlir/lib/IR/Dialect.cpp (revision 6ce44266fc2d06dfcbefd8146279473ccada52ca)
1 //===- Dialect.cpp - Dialect implementation -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Dialect.h"
10 #include "mlir/IR/BuiltinDialect.h"
11 #include "mlir/IR/Diagnostics.h"
12 #include "mlir/IR/DialectImplementation.h"
13 #include "mlir/IR/DialectInterface.h"
14 #include "mlir/IR/DialectRegistry.h"
15 #include "mlir/IR/ExtensibleDialect.h"
16 #include "mlir/IR/MLIRContext.h"
17 #include "mlir/IR/Operation.h"
18 #include "mlir/Support/TypeID.h"
19 #include "llvm/ADT/MapVector.h"
20 #include "llvm/ADT/SetOperations.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/ADT/SmallVectorExtras.h"
23 #include "llvm/ADT/Twine.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/ManagedStatic.h"
26 #include "llvm/Support/Regex.h"
27 #include <memory>
28 
29 #define DEBUG_TYPE "dialect"
30 
31 using namespace mlir;
32 using namespace detail;
33 
34 //===----------------------------------------------------------------------===//
35 // Dialect
36 //===----------------------------------------------------------------------===//
37 
38 Dialect::Dialect(StringRef name, MLIRContext *context, TypeID id)
39     : name(name), dialectID(id), context(context) {
40   assert(isValidNamespace(name) && "invalid dialect namespace");
41 }
42 
43 Dialect::~Dialect() = default;
44 
45 /// Verify an attribute from this dialect on the argument at 'argIndex' for
46 /// the region at 'regionIndex' on the given operation. Returns failure if
47 /// the verification failed, success otherwise. This hook may optionally be
48 /// invoked from any operation containing a region.
49 LogicalResult Dialect::verifyRegionArgAttribute(Operation *, unsigned, unsigned,
50                                                 NamedAttribute) {
51   return success();
52 }
53 
54 /// Verify an attribute from this dialect on the result at 'resultIndex' for
55 /// the region at 'regionIndex' on the given operation. Returns failure if
56 /// the verification failed, success otherwise. This hook may optionally be
57 /// invoked from any operation containing a region.
58 LogicalResult Dialect::verifyRegionResultAttribute(Operation *, unsigned,
59                                                    unsigned, NamedAttribute) {
60   return success();
61 }
62 
63 /// Parse an attribute registered to this dialect.
64 Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type) const {
65   parser.emitError(parser.getNameLoc())
66       << "dialect '" << getNamespace()
67       << "' provides no attribute parsing hook";
68   return Attribute();
69 }
70 
71 /// Parse a type registered to this dialect.
72 Type Dialect::parseType(DialectAsmParser &parser) const {
73   // If this dialect allows unknown types, then represent this with OpaqueType.
74   if (allowsUnknownTypes()) {
75     StringAttr ns = StringAttr::get(getContext(), getNamespace());
76     return OpaqueType::get(ns, parser.getFullSymbolSpec());
77   }
78 
79   parser.emitError(parser.getNameLoc())
80       << "dialect '" << getNamespace() << "' provides no type parsing hook";
81   return Type();
82 }
83 
84 std::optional<Dialect::ParseOpHook>
85 Dialect::getParseOperationHook(StringRef opName) const {
86   return std::nullopt;
87 }
88 
89 llvm::unique_function<void(Operation *, OpAsmPrinter &printer)>
90 Dialect::getOperationPrinter(Operation *op) const {
91   assert(op->getDialect() == this &&
92          "Dialect hook invoked on non-dialect owned operation");
93   return nullptr;
94 }
95 
96 /// Utility function that returns if the given string is a valid dialect
97 /// namespace
98 bool Dialect::isValidNamespace(StringRef str) {
99   llvm::Regex dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$");
100   return dialectNameRegex.match(str);
101 }
102 
103 /// Register a set of dialect interfaces with this dialect instance.
104 void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
105   // Handle the case where the models resolve a promised interface.
106   handleAdditionOfUndefinedPromisedInterface(getTypeID(), interface->getID());
107 
108   auto it = registeredInterfaces.try_emplace(interface->getID(),
109                                              std::move(interface));
110   (void)it;
111   LLVM_DEBUG({
112     if (!it.second) {
113       llvm::dbgs() << "[" DEBUG_TYPE
114                       "] repeated interface registration for dialect "
115                    << getNamespace();
116     }
117   });
118 }
119 
120 //===----------------------------------------------------------------------===//
121 // Dialect Interface
122 //===----------------------------------------------------------------------===//
123 
124 DialectInterface::~DialectInterface() = default;
125 
126 MLIRContext *DialectInterface::getContext() const {
127   return dialect->getContext();
128 }
129 
130 DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
131     MLIRContext *ctx, TypeID interfaceKind, StringRef interfaceName) {
132   for (auto *dialect : ctx->getLoadedDialects()) {
133 #ifndef NDEBUG
134     dialect->handleUseOfUndefinedPromisedInterface(
135         dialect->getTypeID(), interfaceKind, interfaceName);
136 #endif
137     if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
138       interfaces.insert(interface);
139       orderedInterfaces.push_back(interface);
140     }
141   }
142 }
143 
144 DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() = default;
145 
146 /// Get the interface for the dialect of given operation, or null if one
147 /// is not registered.
148 const DialectInterface *
149 DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const {
150   return getInterfaceFor(op->getDialect());
151 }
152 
153 //===----------------------------------------------------------------------===//
154 // DialectExtension
155 //===----------------------------------------------------------------------===//
156 
157 DialectExtensionBase::~DialectExtensionBase() = default;
158 
159 void dialect_extension_detail::handleUseOfUndefinedPromisedInterface(
160     Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID,
161     StringRef interfaceName) {
162   dialect.handleUseOfUndefinedPromisedInterface(interfaceRequestorID,
163                                                 interfaceID, interfaceName);
164 }
165 
166 void dialect_extension_detail::handleAdditionOfUndefinedPromisedInterface(
167     Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID) {
168   dialect.handleAdditionOfUndefinedPromisedInterface(interfaceRequestorID,
169                                                      interfaceID);
170 }
171 
172 bool dialect_extension_detail::hasPromisedInterface(Dialect &dialect,
173                                                     TypeID interfaceRequestorID,
174                                                     TypeID interfaceID) {
175   return dialect.hasPromisedInterface(interfaceRequestorID, interfaceID);
176 }
177 
178 //===----------------------------------------------------------------------===//
179 // DialectRegistry
180 //===----------------------------------------------------------------------===//
181 
182 namespace {
183 template <typename Fn>
184 void applyExtensionsFn(
185     Fn &&applyExtension,
186     const llvm::MapVector<TypeID, std::unique_ptr<DialectExtensionBase>>
187         &extensions) {
188   // Note: Additional extensions may be added while applying an extension.
189   // The iterators will be invalidated if extensions are added so we'll keep
190   // a copy of the extensions for ourselves.
191 
192   const auto extractExtension =
193       [](const auto &entry) -> DialectExtensionBase * {
194     return entry.second.get();
195   };
196 
197   auto startIt = extensions.begin(), endIt = extensions.end();
198   size_t count = 0;
199   while (startIt != endIt) {
200     count += endIt - startIt;
201 
202     // Grab the subset of extensions we'll apply in this iteration.
203     const auto subset =
204         llvm::map_to_vector(llvm::make_range(startIt, endIt), extractExtension);
205 
206     for (const auto *ext : subset)
207       applyExtension(*ext);
208 
209     // Book-keep for the next iteration.
210     startIt = extensions.begin() + count;
211     endIt = extensions.end();
212   }
213 }
214 } // namespace
215 
216 DialectRegistry::DialectRegistry() { insert<BuiltinDialect>(); }
217 
218 DialectAllocatorFunctionRef
219 DialectRegistry::getDialectAllocator(StringRef name) const {
220   auto it = registry.find(name);
221   if (it == registry.end())
222     return nullptr;
223   return it->second.second;
224 }
225 
226 void DialectRegistry::insert(TypeID typeID, StringRef name,
227                              const DialectAllocatorFunction &ctor) {
228   auto inserted = registry.insert(
229       std::make_pair(std::string(name), std::make_pair(typeID, ctor)));
230   if (!inserted.second && inserted.first->second.first != typeID) {
231     llvm::report_fatal_error(
232         "Trying to register different dialects for the same namespace: " +
233         name);
234   }
235 }
236 
237 void DialectRegistry::insertDynamic(
238     StringRef name, const DynamicDialectPopulationFunction &ctor) {
239   // This TypeID marks dynamic dialects. We cannot give a TypeID for the
240   // dialect yet, since the TypeID of a dynamic dialect is defined at its
241   // construction.
242   TypeID typeID = TypeID::get<void>();
243 
244   // Create the dialect, and then call ctor, which allocates its components.
245   auto constructor = [nameStr = name.str(), ctor](MLIRContext *ctx) {
246     auto *dynDialect = ctx->getOrLoadDynamicDialect(
247         nameStr, [ctx, ctor](DynamicDialect *dialect) { ctor(ctx, dialect); });
248     assert(dynDialect && "Dynamic dialect creation unexpectedly failed");
249     return dynDialect;
250   };
251 
252   insert(typeID, name, constructor);
253 }
254 
255 void DialectRegistry::applyExtensions(Dialect *dialect) const {
256   MLIRContext *ctx = dialect->getContext();
257   StringRef dialectName = dialect->getNamespace();
258 
259   // Functor used to try to apply the given extension.
260   auto applyExtension = [&](const DialectExtensionBase &extension) {
261     ArrayRef<StringRef> dialectNames = extension.getRequiredDialects();
262     // An empty set is equivalent to always invoke.
263     if (dialectNames.empty()) {
264       extension.apply(ctx, dialect);
265       return;
266     }
267 
268     // Handle the simple case of a single dialect name. In this case, the
269     // required dialect should be the current dialect.
270     if (dialectNames.size() == 1) {
271       if (dialectNames.front() == dialectName)
272         extension.apply(ctx, dialect);
273       return;
274     }
275 
276     // Otherwise, check to see if this extension requires this dialect.
277     const StringRef *nameIt = llvm::find(dialectNames, dialectName);
278     if (nameIt == dialectNames.end())
279       return;
280 
281     // If it does, ensure that all of the other required dialects have been
282     // loaded.
283     SmallVector<Dialect *> requiredDialects;
284     requiredDialects.reserve(dialectNames.size());
285     for (auto it = dialectNames.begin(), e = dialectNames.end(); it != e;
286          ++it) {
287       // The current dialect is known to be loaded.
288       if (it == nameIt) {
289         requiredDialects.push_back(dialect);
290         continue;
291       }
292       // Otherwise, check if it is loaded.
293       Dialect *loadedDialect = ctx->getLoadedDialect(*it);
294       if (!loadedDialect)
295         return;
296       requiredDialects.push_back(loadedDialect);
297     }
298     extension.apply(ctx, requiredDialects);
299   };
300 
301   applyExtensionsFn(applyExtension, extensions);
302 }
303 
304 void DialectRegistry::applyExtensions(MLIRContext *ctx) const {
305   // Functor used to try to apply the given extension.
306   auto applyExtension = [&](const DialectExtensionBase &extension) {
307     ArrayRef<StringRef> dialectNames = extension.getRequiredDialects();
308     if (dialectNames.empty()) {
309       auto loadedDialects = ctx->getLoadedDialects();
310       extension.apply(ctx, loadedDialects);
311       return;
312     }
313 
314     // Check to see if all of the dialects for this extension are loaded.
315     SmallVector<Dialect *> requiredDialects;
316     requiredDialects.reserve(dialectNames.size());
317     for (StringRef dialectName : dialectNames) {
318       Dialect *loadedDialect = ctx->getLoadedDialect(dialectName);
319       if (!loadedDialect)
320         return;
321       requiredDialects.push_back(loadedDialect);
322     }
323     extension.apply(ctx, requiredDialects);
324   };
325 
326   applyExtensionsFn(applyExtension, extensions);
327 }
328 
329 bool DialectRegistry::isSubsetOf(const DialectRegistry &rhs) const {
330   // Check that all extension keys are present in 'rhs'.
331   const auto hasExtension = [&](const auto &key) {
332     return rhs.extensions.contains(key);
333   };
334   if (!llvm::all_of(make_first_range(extensions), hasExtension))
335     return false;
336 
337   // Check that the current dialects fully overlap with the dialects in 'rhs'.
338   return llvm::all_of(
339       registry, [&](const auto &it) { return rhs.registry.count(it.first); });
340 }
341