1 //===- DialectRegistry.h - Dialect Registration and Extension ---*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines functionality for registring and extending dialects. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_IR_DIALECTREGISTRY_H 14 #define MLIR_IR_DIALECTREGISTRY_H 15 16 #include "mlir/IR/MLIRContext.h" 17 #include "mlir/Support/TypeID.h" 18 #include "llvm/ADT/ArrayRef.h" 19 #include "llvm/ADT/MapVector.h" 20 21 #include <map> 22 #include <tuple> 23 24 namespace mlir { 25 class Dialect; 26 27 using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>; 28 using DialectAllocatorFunctionRef = function_ref<Dialect *(MLIRContext *)>; 29 using DynamicDialectPopulationFunction = 30 std::function<void(MLIRContext *, DynamicDialect *)>; 31 32 //===----------------------------------------------------------------------===// 33 // DialectExtension 34 //===----------------------------------------------------------------------===// 35 36 /// This class represents an opaque dialect extension. It contains a set of 37 /// required dialects and an application function. The required dialects control 38 /// when the extension is applied, i.e. the extension is applied when all 39 /// required dialects are loaded. The application function can be used to attach 40 /// additional functionality to attributes, dialects, operations, types, etc., 41 /// and may also load additional necessary dialects. 42 class DialectExtensionBase { 43 public: 44 virtual ~DialectExtensionBase(); 45 46 /// Return the dialects that our required by this extension to be loaded 47 /// before applying. If empty then the extension is invoked for every loaded 48 /// dialect indepently. 49 ArrayRef<StringRef> getRequiredDialects() const { return dialectNames; } 50 51 /// Apply this extension to the given context and the required dialects. 52 virtual void apply(MLIRContext *context, 53 MutableArrayRef<Dialect *> dialects) const = 0; 54 55 /// Return a copy of this extension. 56 virtual std::unique_ptr<DialectExtensionBase> clone() const = 0; 57 58 protected: 59 /// Initialize the extension with a set of required dialects. 60 /// If the list is empty, the extension is invoked for every loaded dialect 61 /// independently. 62 DialectExtensionBase(ArrayRef<StringRef> dialectNames) 63 : dialectNames(dialectNames) {} 64 65 private: 66 /// The names of the dialects affected by this extension. 67 SmallVector<StringRef> dialectNames; 68 }; 69 70 /// This class represents a dialect extension anchored on the given set of 71 /// dialects. When all of the specified dialects have been loaded, the 72 /// application function of this extension will be executed. 73 template <typename DerivedT, typename... DialectsT> 74 class DialectExtension : public DialectExtensionBase { 75 public: 76 /// Applies this extension to the given context and set of required dialects. 77 virtual void apply(MLIRContext *context, DialectsT *...dialects) const = 0; 78 79 /// Return a copy of this extension. 80 std::unique_ptr<DialectExtensionBase> clone() const final { 81 return std::make_unique<DerivedT>(static_cast<const DerivedT &>(*this)); 82 } 83 84 protected: 85 DialectExtension() 86 : DialectExtensionBase( 87 ArrayRef<StringRef>({DialectsT::getDialectNamespace()...})) {} 88 89 /// Override the base apply method to allow providing the exact dialect types. 90 void apply(MLIRContext *context, 91 MutableArrayRef<Dialect *> dialects) const final { 92 unsigned dialectIdx = 0; 93 auto derivedDialects = std::tuple<DialectsT *...>{ 94 static_cast<DialectsT *>(dialects[dialectIdx++])...}; 95 std::apply([&](DialectsT *...dialect) { apply(context, dialect...); }, 96 derivedDialects); 97 } 98 }; 99 100 namespace dialect_extension_detail { 101 102 /// Checks if the given interface, which is attempting to be used, is a 103 /// promised interface of this dialect that has yet to be implemented. If so, 104 /// emits a fatal error. 105 void handleUseOfUndefinedPromisedInterface(Dialect &dialect, 106 TypeID interfaceRequestorID, 107 TypeID interfaceID, 108 StringRef interfaceName); 109 110 /// Checks if the given interface, which is attempting to be attached, is a 111 /// promised interface of this dialect that has yet to be implemented. If so, 112 /// the promised interface is marked as resolved. 113 void handleAdditionOfUndefinedPromisedInterface(Dialect &dialect, 114 TypeID interfaceRequestorID, 115 TypeID interfaceID); 116 117 /// Checks if a promise has been made for the interface/requestor pair. 118 bool hasPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, 119 TypeID interfaceID); 120 121 /// Checks if a promise has been made for the interface/requestor pair. 122 template <typename ConcreteT, typename InterfaceT> 123 bool hasPromisedInterface(Dialect &dialect) { 124 return hasPromisedInterface(dialect, TypeID::get<ConcreteT>(), 125 InterfaceT::getInterfaceID()); 126 } 127 128 } // namespace dialect_extension_detail 129 130 //===----------------------------------------------------------------------===// 131 // DialectRegistry 132 //===----------------------------------------------------------------------===// 133 134 /// The DialectRegistry maps a dialect namespace to a constructor for the 135 /// matching dialect. This allows for decoupling the list of dialects 136 /// "available" from the dialects loaded in the Context. The parser in 137 /// particular will lazily load dialects in the Context as operations are 138 /// encountered. 139 class DialectRegistry { 140 using MapTy = 141 std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>, 142 std::less<>>; 143 144 public: 145 explicit DialectRegistry(); 146 147 template <typename ConcreteDialect> 148 void insert() { 149 insert(TypeID::get<ConcreteDialect>(), 150 ConcreteDialect::getDialectNamespace(), 151 static_cast<DialectAllocatorFunction>(([](MLIRContext *ctx) { 152 // Just allocate the dialect, the context 153 // takes ownership of it. 154 return ctx->getOrLoadDialect<ConcreteDialect>(); 155 }))); 156 } 157 158 template <typename ConcreteDialect, typename OtherDialect, 159 typename... MoreDialects> 160 void insert() { 161 insert<ConcreteDialect>(); 162 insert<OtherDialect, MoreDialects...>(); 163 } 164 165 /// Add a new dialect constructor to the registry. The constructor must be 166 /// calling MLIRContext::getOrLoadDialect in order for the context to take 167 /// ownership of the dialect and for delayed interface registration to happen. 168 void insert(TypeID typeID, StringRef name, 169 const DialectAllocatorFunction &ctor); 170 171 /// Add a new dynamic dialect constructor in the registry. The constructor 172 /// provides as argument the created dynamic dialect, and is expected to 173 /// register the dialect types, attributes, and ops, using the 174 /// methods defined in ExtensibleDialect such as registerDynamicOperation. 175 void insertDynamic(StringRef name, 176 const DynamicDialectPopulationFunction &ctor); 177 178 /// Return an allocation function for constructing the dialect identified 179 /// by its namespace, or nullptr if the namespace is not in this registry. 180 DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const; 181 182 // Register all dialects available in the current registry with the registry 183 // in the provided context. 184 void appendTo(DialectRegistry &destination) const { 185 for (const auto &nameAndRegistrationIt : registry) 186 destination.insert(nameAndRegistrationIt.second.first, 187 nameAndRegistrationIt.first, 188 nameAndRegistrationIt.second.second); 189 // Merge the extensions. 190 for (const auto &extension : extensions) 191 destination.extensions.try_emplace(extension.first, 192 extension.second->clone()); 193 } 194 195 /// Return the names of dialects known to this registry. 196 auto getDialectNames() const { 197 return llvm::map_range( 198 registry, 199 [](const MapTy::value_type &item) -> StringRef { return item.first; }); 200 } 201 202 /// Apply any held extensions that require the given dialect. Users are not 203 /// expected to call this directly. 204 void applyExtensions(Dialect *dialect) const; 205 206 /// Apply any applicable extensions to the given context. Users are not 207 /// expected to call this directly. 208 void applyExtensions(MLIRContext *ctx) const; 209 210 /// Add the given extension to the registry. 211 bool addExtension(TypeID extensionID, 212 std::unique_ptr<DialectExtensionBase> extension) { 213 return extensions.try_emplace(extensionID, std::move(extension)).second; 214 } 215 216 /// Add the given extensions to the registry. 217 template <typename... ExtensionsT> 218 void addExtensions() { 219 (addExtension(TypeID::get<ExtensionsT>(), std::make_unique<ExtensionsT>()), 220 ...); 221 } 222 223 /// Add an extension function that requires the given dialects. 224 /// Note: This bare functor overload is provided in addition to the 225 /// std::function variant to enable dialect type deduction, e.g.: 226 /// registry.addExtension(+[](MLIRContext *ctx, MyDialect *dialect) { 227 /// ... }) 228 /// 229 /// is equivalent to: 230 /// registry.addExtension<MyDialect>( 231 /// [](MLIRContext *ctx, MyDialect *dialect){ ... } 232 /// ) 233 template <typename... DialectsT> 234 bool addExtension(void (*extensionFn)(MLIRContext *, DialectsT *...)) { 235 using ExtensionFnT = void (*)(MLIRContext *, DialectsT *...); 236 237 struct Extension : public DialectExtension<Extension, DialectsT...> { 238 Extension(const Extension &) = default; 239 Extension(ExtensionFnT extensionFn) 240 : DialectExtension<Extension, DialectsT...>(), 241 extensionFn(extensionFn) {} 242 ~Extension() override = default; 243 244 void apply(MLIRContext *context, DialectsT *...dialects) const final { 245 extensionFn(context, dialects...); 246 } 247 ExtensionFnT extensionFn; 248 }; 249 return addExtension(TypeID::getFromOpaquePointer( 250 reinterpret_cast<const void *>(extensionFn)), 251 std::make_unique<Extension>(extensionFn)); 252 } 253 254 /// Returns true if the current registry is a subset of 'rhs', i.e. if 'rhs' 255 /// contains all of the components of this registry. 256 bool isSubsetOf(const DialectRegistry &rhs) const; 257 258 private: 259 MapTy registry; 260 llvm::MapVector<TypeID, std::unique_ptr<DialectExtensionBase>> extensions; 261 }; 262 263 } // namespace mlir 264 265 #endif // MLIR_IR_DIALECTREGISTRY_H 266