xref: /llvm-project/mlir/include/mlir/IR/DialectRegistry.h (revision 6ce44266fc2d06dfcbefd8146279473ccada52ca)
1 //===- DialectRegistry.h - Dialect Registration and Extension ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file defines functionality for registering and extending dialects.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_DIALECTREGISTRY_H
14 #define MLIR_IR_DIALECTREGISTRY_H
15 
16 #include "mlir/IR/MLIRContext.h"
17 #include "mlir/Support/TypeID.h"
18 #include "llvm/ADT/ArrayRef.h"
19 #include "llvm/ADT/MapVector.h"
20 
21 #include <map>
22 #include <tuple>
23 
24 namespace mlir {
25 class Dialect;
26 
27 using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>;
28 using DialectAllocatorFunctionRef = function_ref<Dialect *(MLIRContext *)>;
29 using DynamicDialectPopulationFunction =
30     std::function<void(MLIRContext *, DynamicDialect *)>;
31 
32 //===----------------------------------------------------------------------===//
33 // DialectExtension
34 //===----------------------------------------------------------------------===//
35 
36 /// This class represents an opaque dialect extension. It contains a set of
37 /// required dialects and an application function. The required dialects control
38 /// when the extension is applied, i.e. the extension is applied when all
39 /// required dialects are loaded. The application function can be used to attach
40 /// additional functionality to attributes, dialects, operations, types, etc.,
41 /// and may also load additional necessary dialects.
42 class DialectExtensionBase {
43 public:
44   virtual ~DialectExtensionBase();
45 
46   /// Return the dialects that our required by this extension to be loaded
47   /// before applying. If empty then the extension is invoked for every loaded
48   /// dialect indepently.
49   ArrayRef<StringRef> getRequiredDialects() const { return dialectNames; }
50 
51   /// Apply this extension to the given context and the required dialects.
52   virtual void apply(MLIRContext *context,
53                      MutableArrayRef<Dialect *> dialects) const = 0;
54 
55   /// Return a copy of this extension.
56   virtual std::unique_ptr<DialectExtensionBase> clone() const = 0;
57 
58 protected:
59   /// Initialize the extension with a set of required dialects.
60   /// If the list is empty, the extension is invoked for every loaded dialect
61   /// independently.
62   DialectExtensionBase(ArrayRef<StringRef> dialectNames)
63       : dialectNames(dialectNames) {}
64 
65 private:
66   /// The names of the dialects affected by this extension.
67   SmallVector<StringRef> dialectNames;
68 };
69 
70 /// This class represents a dialect extension anchored on the given set of
71 /// dialects. When all of the specified dialects have been loaded, the
72 /// application function of this extension will be executed.
73 template <typename DerivedT, typename... DialectsT>
74 class DialectExtension : public DialectExtensionBase {
75 public:
76   /// Applies this extension to the given context and set of required dialects.
77   virtual void apply(MLIRContext *context, DialectsT *...dialects) const = 0;
78 
79   /// Return a copy of this extension.
80   std::unique_ptr<DialectExtensionBase> clone() const final {
81     return std::make_unique<DerivedT>(static_cast<const DerivedT &>(*this));
82   }
83 
84 protected:
85   DialectExtension()
86       : DialectExtensionBase(
87             ArrayRef<StringRef>({DialectsT::getDialectNamespace()...})) {}
88 
89   /// Override the base apply method to allow providing the exact dialect types.
90   void apply(MLIRContext *context,
91              MutableArrayRef<Dialect *> dialects) const final {
92     unsigned dialectIdx = 0;
93     auto derivedDialects = std::tuple<DialectsT *...>{
94         static_cast<DialectsT *>(dialects[dialectIdx++])...};
95     std::apply([&](DialectsT *...dialect) { apply(context, dialect...); },
96                derivedDialects);
97   }
98 };
99 
100 namespace dialect_extension_detail {
101 
102 /// Checks if the given interface, which is attempting to be used, is a
103 /// promised interface of this dialect that has yet to be implemented. If so,
104 /// emits a fatal error.
105 void handleUseOfUndefinedPromisedInterface(Dialect &dialect,
106                                            TypeID interfaceRequestorID,
107                                            TypeID interfaceID,
108                                            StringRef interfaceName);
109 
110 /// Checks if the given interface, which is attempting to be attached, is a
111 /// promised interface of this dialect that has yet to be implemented. If so,
112 /// the promised interface is marked as resolved.
113 void handleAdditionOfUndefinedPromisedInterface(Dialect &dialect,
114                                                 TypeID interfaceRequestorID,
115                                                 TypeID interfaceID);
116 
117 /// Checks if a promise has been made for the interface/requestor pair.
118 bool hasPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID,
119                           TypeID interfaceID);
120 
121 /// Checks if a promise has been made for the interface/requestor pair.
122 template <typename ConcreteT, typename InterfaceT>
123 bool hasPromisedInterface(Dialect &dialect) {
124   return hasPromisedInterface(dialect, TypeID::get<ConcreteT>(),
125                               InterfaceT::getInterfaceID());
126 }
127 
128 } // namespace dialect_extension_detail
129 
130 //===----------------------------------------------------------------------===//
131 // DialectRegistry
132 //===----------------------------------------------------------------------===//
133 
134 /// The DialectRegistry maps a dialect namespace to a constructor for the
135 /// matching dialect. This allows for decoupling the list of dialects
136 /// "available" from the dialects loaded in the Context. The parser in
137 /// particular will lazily load dialects in the Context as operations are
138 /// encountered.
139 class DialectRegistry {
140   using MapTy =
141       std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>,
142                std::less<>>;
143 
144 public:
145   explicit DialectRegistry();
146 
147   template <typename ConcreteDialect>
148   void insert() {
149     insert(TypeID::get<ConcreteDialect>(),
150            ConcreteDialect::getDialectNamespace(),
151            static_cast<DialectAllocatorFunction>(([](MLIRContext *ctx) {
152              // Just allocate the dialect, the context
153              // takes ownership of it.
154              return ctx->getOrLoadDialect<ConcreteDialect>();
155            })));
156   }
157 
158   template <typename ConcreteDialect, typename OtherDialect,
159             typename... MoreDialects>
160   void insert() {
161     insert<ConcreteDialect>();
162     insert<OtherDialect, MoreDialects...>();
163   }
164 
165   /// Add a new dialect constructor to the registry. The constructor must be
166   /// calling MLIRContext::getOrLoadDialect in order for the context to take
167   /// ownership of the dialect and for delayed interface registration to happen.
168   void insert(TypeID typeID, StringRef name,
169               const DialectAllocatorFunction &ctor);
170 
171   /// Add a new dynamic dialect constructor in the registry. The constructor
172   /// provides as argument the created dynamic dialect, and is expected to
173   /// register the dialect types, attributes, and ops, using the
174   /// methods defined in ExtensibleDialect such as registerDynamicOperation.
175   void insertDynamic(StringRef name,
176                      const DynamicDialectPopulationFunction &ctor);
177 
178   /// Return an allocation function for constructing the dialect identified
179   /// by its namespace, or nullptr if the namespace is not in this registry.
180   DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const;
181 
182   // Register all dialects available in the current registry with the registry
183   // in the provided context.
184   void appendTo(DialectRegistry &destination) const {
185     for (const auto &nameAndRegistrationIt : registry)
186       destination.insert(nameAndRegistrationIt.second.first,
187                          nameAndRegistrationIt.first,
188                          nameAndRegistrationIt.second.second);
189     // Merge the extensions.
190     for (const auto &extension : extensions)
191       destination.extensions.try_emplace(extension.first,
192                                          extension.second->clone());
193   }
194 
195   /// Return the names of dialects known to this registry.
196   auto getDialectNames() const {
197     return llvm::map_range(
198         registry,
199         [](const MapTy::value_type &item) -> StringRef { return item.first; });
200   }
201 
202   /// Apply any held extensions that require the given dialect. Users are not
203   /// expected to call this directly.
204   void applyExtensions(Dialect *dialect) const;
205 
206   /// Apply any applicable extensions to the given context. Users are not
207   /// expected to call this directly.
208   void applyExtensions(MLIRContext *ctx) const;
209 
210   /// Add the given extension to the registry.
211   bool addExtension(TypeID extensionID,
212                     std::unique_ptr<DialectExtensionBase> extension) {
213     return extensions.try_emplace(extensionID, std::move(extension)).second;
214   }
215 
216   /// Add the given extensions to the registry.
217   template <typename... ExtensionsT>
218   void addExtensions() {
219     (addExtension(TypeID::get<ExtensionsT>(), std::make_unique<ExtensionsT>()),
220      ...);
221   }
222 
223   /// Add an extension function that requires the given dialects.
224   /// Note: This bare functor overload is provided in addition to the
225   /// std::function variant to enable dialect type deduction, e.g.:
226   ///  registry.addExtension(+[](MLIRContext *ctx, MyDialect *dialect) {
227   ///  ... })
228   ///
229   /// is equivalent to:
230   ///  registry.addExtension<MyDialect>(
231   ///     [](MLIRContext *ctx, MyDialect *dialect){ ... }
232   ///  )
233   template <typename... DialectsT>
234   bool addExtension(void (*extensionFn)(MLIRContext *, DialectsT *...)) {
235     using ExtensionFnT = void (*)(MLIRContext *, DialectsT *...);
236 
237     struct Extension : public DialectExtension<Extension, DialectsT...> {
238       Extension(const Extension &) = default;
239       Extension(ExtensionFnT extensionFn)
240           : DialectExtension<Extension, DialectsT...>(),
241             extensionFn(extensionFn) {}
242       ~Extension() override = default;
243 
244       void apply(MLIRContext *context, DialectsT *...dialects) const final {
245         extensionFn(context, dialects...);
246       }
247       ExtensionFnT extensionFn;
248     };
249     return addExtension(TypeID::getFromOpaquePointer(
250                             reinterpret_cast<const void *>(extensionFn)),
251                         std::make_unique<Extension>(extensionFn));
252   }
253 
254   /// Returns true if the current registry is a subset of 'rhs', i.e. if 'rhs'
255   /// contains all of the components of this registry.
256   bool isSubsetOf(const DialectRegistry &rhs) const;
257 
258 private:
259   MapTy registry;
260   llvm::MapVector<TypeID, std::unique_ptr<DialectExtensionBase>> extensions;
261 };
262 
263 } // namespace mlir
264 
265 #endif // MLIR_IR_DIALECTREGISTRY_H
266